<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Neural-Networks on Hunter Heidenreich | AI Research Scientist &amp; Engineer</title><link>https://hunterheidenreich.com/tags/neural-networks/</link><description>Recent content in Neural-Networks on Hunter Heidenreich | AI Research Scientist &amp; Engineer</description><image><title>Hunter Heidenreich | AI Research Scientist &amp; Engineer</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Fri, 13 Mar 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/tags/neural-networks/index.xml" rel="self" type="application/rss+xml"/><item><title>ChemDFM-R: Chemical Reasoning LLM with Atomized Knowledge</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemdfm-r/</link><pubDate>Fri, 26 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemdfm-r/</guid><description>A 14B-parameter chemical reasoning LLM enhanced with atomized functional group knowledge and mix-sourced distillation strategy.</description><content:encoded><![CDATA[<h2 id="method-and-resource-contributions">Method and Resource Contributions</h2>
<p>This is primarily a <strong>Method</strong> paper with significant <strong>Resource</strong> contributions.</p>
<ul>
<li><strong>Methodological Basis</strong>: The paper introduces a training pipeline (&ldquo;mix-sourced distillation&rdquo;) and domain-specific reinforcement learning to improve reasoning capabilities in chemical LLMs. It validates the approach through ablation studies across training stages.</li>
<li><strong>Resource Contribution</strong>: The authors constructed <strong>ChemFG</strong>, a 101 billion-token corpus annotated with &ldquo;atomized&rdquo; knowledge regarding functional groups and reaction centers.</li>
</ul>
<h2 id="bridging-the-chemical-reasoning-gap">Bridging the Chemical Reasoning Gap</h2>
<p>Current chemical LLMs struggle to reason logically for two main reasons:</p>
<ol>
<li><strong>Shallow Domain Understanding</strong>: Models generally learn molecule-level properties directly, bypassing the intermediate &ldquo;atomized&rdquo; characteristics (e.g., functional groups) that ultimately dictate chemical behavior.</li>
<li><strong>Specialized Reasoning Logic</strong>: Chemical logic differs fundamentally from math or code. Distilling reasoning from general teacher models like DeepSeek-R1 frequently fails because the teachers lack the domain intuition required to generate valid chemical rationales.</li>
</ol>
<h2 id="atomized-knowledge-and-mixed-source-distillation">Atomized Knowledge and Mixed-Source Distillation</h2>
<p>The authors introduce three structural innovations to solve the reasoning gap:</p>
<ol>
<li><strong>Atomized Knowledge Enhancement (ChemFG)</strong>: A toolkit was built leveraging SMARTS notations to identify functional group changes during reactions. A critique of this approach is that it relies heavily on 2D cheminformatics abstractions, potentially missing deeper 3D stereochemical interactions.</li>
<li><strong>Mix-Sourced Distillation</strong>: General models (DeepSeek-R1/o3-mini) are fed &ldquo;pseudo-reasoning&rdquo; prompts that include ground truth answers and functional group data. While this forces the teacher to generate high-quality rationales for the student to learn, it introduces a layer of hindsight bias into the generated reasoning chains. During inference, the student model lacks both the pre-calculated functional group metadata and the ground truth, forcing it to bridge an artificially steep generalization gap.</li>
<li><strong>Chemical Reinforcement Learning</strong>: The intermediate model undergoes domain-specific reinforcement learning. The RL details are described in the paper&rsquo;s Appendix D, with the authors citing the open-source DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) framework. The optimization relies on rule-based rewards (format adherence and canonicalized SMILES accuracy) across a variety of chemical tasks.</li>
</ol>
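<p>The atomized-knowledge idea above boils down to tagging molecules with functional groups from a pattern dictionary. The sketch below is a minimal stand-in: real matching would use RDKit SMARTS substructure search over SMARTS patterns like <code>[CX3](=O)[OX2H1]</code>; the substring matcher and toy patterns here are assumptions for illustration only.</p>

```python
# Toy functional-group tagging with a pluggable matcher. The ChemFG toolkit
# uses SMARTS substructure matching (via a cheminformatics library); the
# naive substring matcher below is only a self-contained stand-in.

FUNCTIONAL_GROUPS = {
    # name -> pattern (real entries would be SMARTS, not raw substrings)
    "carboxylic_acid": "C(=O)O",
    "carbonyl": "C=O",
    "nitrile": "C#N",
}

def naive_match(smiles: str, pattern: str) -> bool:
    """Stand-in for substructure matching: plain substring search."""
    return pattern in smiles

def identify_functional_groups(smiles, patterns=FUNCTIONAL_GROUPS, match=naive_match):
    """Return the set of group names whose pattern matches the molecule."""
    return {name for name, pat in patterns.items() if match(smiles, pat)}
```

<p>Swapping <code>naive_match</code> for an RDKit-backed matcher keeps the same interface while making the matching chemically correct.</p>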
<h2 id="benchmark-evaluation-and-ablation-studies">Benchmark Evaluation and Ablation Studies</h2>
<p>The model was evaluated on comprehensive chemical benchmarks: <strong>SciKnowEval</strong> (19 tasks) and <strong>ChemEval</strong> (36 tasks).</p>
<ul>
<li><strong>Baselines</strong>: Compared against similarly sized open models (Qwen2.5-14B-Instruct, Qwen3-14B), domain models (ChemLLM, MolInst), and frontier models (GPT-4o, DeepSeek-R1).</li>
<li><strong>Ablation</strong>: Evaluated across training stages (Base → ChemDFM-I → ChemDFM-R) to measure the specific impact of the instruction tuning versus the reasoning stages.</li>
<li><strong>Qualitative Analysis</strong>: The paper includes case studies demonstrating the model&rsquo;s step-by-step chemical reasoning and its potential for human-AI collaboration (Sections 4.2 and 4.3).</li>
</ul>
<h2 id="performance-outcomes-and-numerical-limitations">Performance Outcomes and Numerical Limitations</h2>
<ul>
<li><strong>Performance vs. Baselines</strong>: ChemDFM-R outperforms similarly sized open models and domain models on molecule-centric and reaction-centric tasks, and surpasses the much larger DeepSeek-R1 on ChemEval (0.78 vs. 0.58 overall). It shows competitive results relative to o4-mini, though o4-mini leads on SciKnowEval (0.74 vs. 0.70).</li>
<li><strong>Reasoning Interactivity</strong>: The model generates readable rationales that allow users to catch structural errors or identify reaction mechanisms accurately. Section 4.3 of the paper demonstrates human-AI collaboration scenarios.</li>
<li><strong>Quantitative Limitations</strong>: The model struggles with tasks involving numerical prediction and calculation (e.g., yield extraction, molecular property calculation). The paper notes that all molecule-centric and reaction-centric tasks where ChemDFM-R falls short of Qwen2.5-14B-Instruct involve numerical reasoning.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is constructed in three phases:</p>
<p><strong>1. Domain Pre-training (ChemFG)</strong>:</p>
<ul>
<li><strong>Size</strong>: 101 billion tokens</li>
<li><strong>Composition</strong>:
<ul>
<li>12M literature documents (79B tokens)</li>
<li>30M molecules from PubChem/PubChemQC</li>
<li>7M reactions from USPTO-FULL</li>
</ul>
</li>
<li><strong>Augmentation</strong>: SMILES augmentation (10x) using R-SMILES</li>
<li><strong>Atomized Features</strong>: Annotated with a custom &ldquo;Functional Group Identification Toolkit&rdquo; that identifies 241 functional group types and tracks changes in reaction centers. <em>Note: Data and toolkit are partially reproduced; while the toolkit (<a href="https://github.com/OpenDFM/ChemFG-Tool">ChemFG-Tool</a>) was open-sourced on GitHub, the 101 billion-token ChemFG dataset itself has not been publicly released.</em></li>
</ul>
<p><strong>2. Instruction Tuning</strong>:</p>
<ul>
<li><strong>Sources</strong>: Molecule-centric (PubChem, MoleculeNet), Reaction-centric (USPTO), and Knowledge-centric (Exams, Literature QA) tasks</li>
<li><strong>Mixing</strong>: Mixed with general instruction data in a 1:2 ratio</li>
</ul>
<p><strong>3. Distillation Dataset</strong>:</p>
<ul>
<li><strong>Sources</strong>:
<ul>
<li>~70% ChemDFM-R instruction data</li>
<li>~22% constructed pseudo-reasoning (functional group descriptions)</li>
<li>~8% teacher rationales (from DeepSeek-R1/o3-mini)</li>
</ul>
</li>
<li><strong>Mixing</strong>: Mixed with general data (including AM-Deepseek-R1-Distill-1.4M) in a 1:2 ratio</li>
</ul>
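<p>The constructed pseudo-reasoning portion of the distillation set pairs each question with its ground truth and functional group annotations before querying the teacher. A hypothetical prompt builder is sketched below; the paper's exact template wording and field names are not given, so everything here is an assumption for illustration.</p>

```python
# Hypothetical pseudo-reasoning prompt builder for mix-sourced distillation.
# The concrete template used in the paper is not published; this wording is
# invented to show the Question + Ground Truth + Functional Group Info shape.

def build_pseudo_reasoning_prompt(question: str, ground_truth: str, fg_info: str) -> str:
    """Assemble the teacher-facing prompt that conditions on the answer."""
    return (
        "You are a chemistry expert. Produce a step-by-step rationale that\n"
        "arrives at the given answer, using the functional group analysis.\n\n"
        f"Question: {question}\n"
        f"Functional groups: {fg_info}\n"
        f"Ground-truth answer: {ground_truth}\n"
        "Rationale:"
    )
```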
<h3 id="algorithms">Algorithms</h3>
<p><strong>Functional Group Identification</strong>:</p>
<ul>
<li>Extends the <code>thermo</code> library&rsquo;s SMARTS list</li>
<li>For reactions, identifies &ldquo;reacting functional groups&rdquo; by finding reactants containing atoms involved in bond changes (reaction centers) that do not appear in the product</li>
</ul>
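<p>The reacting-group rule above is essentially set logic once atom mapping is done. The sketch below assumes an upstream step (RDKit atom mapping and substructure matching in practice) has already produced per-group atom index sets, the reaction-center atom indices, and the product-side group names.</p>

```python
# Sketch of "reacting functional group" selection: a reactant-side group is
# flagged when (a) at least one of its atoms lies in the reaction center and
# (b) the group type no longer appears in the product.

def reacting_groups(reactant_groups, product_group_names, reaction_center_atoms):
    """reactant_groups: list of (group_name, set_of_atom_indices) tuples."""
    center = set(reaction_center_atoms)
    return [
        name
        for name, atoms in reactant_groups
        if atoms & center and name not in product_group_names
    ]
```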
<p><strong>Mix-Sourced Distillation</strong>:</p>
<ul>
<li>Teacher models (DeepSeek-R1, o3-mini) are prompted with Question + Ground Truth + Functional Group Info to generate high-quality &ldquo;Thoughts&rdquo;</li>
<li>These rationales are distilled into the student model using a supervised fine-tuning loss across target tokens $y_t$:
$$ \mathcal{L}_{\text{SFT}} = - \sum_{t=1}^T \log P_\theta(y_t \mid x, y_{&lt;t}) $$</li>
</ul>
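<p>Numerically, the SFT objective above is just the summed negative log-likelihood of the target tokens. A minimal sketch, taking the per-token probabilities $P_\theta(y_t \mid x, y_{&lt;t})$ as given (in a real pipeline these come from a softmax over model logits):</p>

```python
import math

# L_SFT = -sum_t log P(y_t | x, y_<t), computed from per-token probabilities.

def sft_loss(target_token_probs):
    """Negative log-likelihood summed over the target token sequence."""
    return -sum(math.log(p) for p in target_token_probs)
```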
<p><strong>Reinforcement Learning</strong>:</p>
<ul>
<li><strong>Algorithm</strong>: The paper cites DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) as the RL framework; full details are in Appendix D of the paper. <em>Note: While the underlying DAPO framework is open-source, the specific chemistry-oriented RL pipeline and environment used for ChemDFM-R has not been publicly released.</em></li>
<li><strong>Hyperparameters</strong> (from paper appendix): Learning rate <code>5e-7</code>, rollout batch size <code>512</code>, training batch size <code>128</code></li>
<li><strong>Rewards</strong>: The reward system applies rule-based constraints focusing on physical form and chemical validity. The total reward $R(y, y^*)$ for a generated response $y$ given target $y^*$ combines a format adherence reward ($R_{\text{format}}$) and an accuracy reward ($R_{\text{acc}}$) evaluated on canonicalized SMILES:
$$ R(y, y^*) = R_{\text{format}}(y) + R_{\text{acc}}(\text{canonicalize}(y), \text{canonicalize}(y^*)) $$</li>
</ul>
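<p>A minimal sketch of the reward decomposition above. The format check (presence of <code>&lt;think&gt;</code>/<code>&lt;answer&gt;</code> tags) and the reward magnitudes are assumptions, not the paper's values, and the canonicalizer is passed in as a callable (in practice it would wrap RDKit SMILES canonicalization).</p>

```python
# R(y, y*) = R_format(y) + R_acc(canonicalize(y), canonicalize(y*)).
# Tag names and reward values below are illustrative placeholders.

def total_reward(response, answer, target, canonicalize,
                 fmt_reward=0.1, acc_reward=1.0):
    """response: full generated text; answer/target: predicted vs. gold SMILES."""
    r = fmt_reward if ("<think>" in response and "<answer>" in response) else 0.0
    if canonicalize(answer) == canonicalize(target):
        r += acc_reward
    return r
```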
<h3 id="models">Models</h3>
<ul>
<li><strong>Base Model</strong>: Qwen2.5-14B</li>
<li><strong>ChemDFM-I</strong>: Result of instruction tuning the domain-pretrained model for 2 epochs</li>
<li><strong>ChemDFM-R</strong>: Result of applying mix-sourced distillation (1 epoch) followed by RL on ChemDFM-I. <em>Note: Model weights are publicly available on <a href="https://huggingface.co/OpenDFM/ChemDFM-R-14B">Hugging Face</a>.</em></li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Hardware and training time are reported in the paper&rsquo;s appendices. The figures below are taken from the paper but could not be independently cross-verified against the extracted main text:</p>
<ul>
<li><strong>Compute</strong>: NVIDIA A800 Tensor Core GPUs</li>
<li><strong>Training Time</strong>: 30,840 GPU hours total (Domain Pretraining: 24,728 hours; Instruction Tuning: 3,785 hours; Distillation: 2,059 hours; Reinforcement Learning: 268 hours)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>SciKnowEval</strong>: 19 tasks (text-centric, molecule-centric, reaction-centric)</li>
<li><strong>ChemEval</strong>: 36 tasks, categorized similarly</li>
</ul>
<p><strong>Key Metrics</strong>: Accuracy, F1 Score, BLEU score (with PRS normalization for ChemEval)</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>SciKnowEval (all)</th>
          <th>ChemEval* (all)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Qwen2.5-14B-Instruct</td>
          <td>0.61</td>
          <td>0.57</td>
          <td>General-domain baseline</td>
      </tr>
      <tr>
          <td>ChemDFM-I</td>
          <td>0.69</td>
          <td>0.72</td>
          <td>After domain pretraining + instruction tuning</td>
      </tr>
      <tr>
          <td>ChemDFM-R</td>
          <td><strong>0.70</strong></td>
          <td><strong>0.78</strong></td>
          <td>After distillation + RL</td>
      </tr>
      <tr>
          <td>DeepSeek-R1</td>
          <td>0.62</td>
          <td>0.58</td>
          <td>General-domain reasoning model</td>
      </tr>
      <tr>
          <td>o4-mini</td>
          <td><strong>0.74</strong></td>
          <td>0.69</td>
          <td>Frontier reasoning model</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/OpenDFM/ChemDFM-R-14B">ChemDFM-R-14B</a></td>
          <td>Model</td>
          <td>AGPL-3.0</td>
          <td>Final reasoning model weights on Hugging Face</td>
      </tr>
      <tr>
          <td><a href="https://github.com/OpenDFM/ChemFG-Tool">ChemFG-Tool</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Functional group identification toolkit (241 groups)</td>
      </tr>
  </tbody>
</table>
<p><strong>Missing components</strong>: The 101B-token ChemFG pretraining dataset is not publicly released. The chemistry-oriented RL pipeline and training code are not open-sourced. The instruction tuning and distillation datasets are not available.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhao, Z., Chen, B., Wan, Z., Chen, L., Lin, X., Yu, S., Zhang, S., Ma, D., Zhu, Z., Zhang, D., Wang, H., Dai, Z., Wen, L., Chen, X., &amp; Yu, K. (2025). ChemDFM-R: A Chemical Reasoning LLM Enhanced with Atomized Chemical Knowledge. <em>arXiv preprint arXiv:2507.21990</em>. <a href="https://doi.org/10.48550/arXiv.2507.21990">https://doi.org/10.48550/arXiv.2507.21990</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{zhao2025chemdfmr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemDFM-R: A Chemical Reasoning LLM Enhanced with Atomized Chemical Knowledge}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zihan Zhao and Bo Chen and Ziping Wan and Lu Chen and Xuanze Lin and Shiyang Yu and Situo Zhang and Da Ma and Zichen Zhu and Danyang Zhang and Huayang Wang and Zhongyang Dai and Liyang Wen and Xin Chen and Kai Yu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2507.21990}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.CE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2507.21990}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa-3: Open Source Chemical Foundation Models</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta-3/</link><pubDate>Fri, 26 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta-3/</guid><description>An open-source framework integrating DeepChem and Ray for training and benchmarking chemical foundation models like MoLFormer and GROVER at scale.</description><content:encoded><![CDATA[<h2 id="core-contribution-an-open-source-framework">Core Contribution: An Open-Source Framework</h2>
<p>This is primarily a <strong>Resource ($\Psi_{\text{Resource}}$)</strong> paper, with secondary <strong>Method ($\Psi_{\text{Method}}$)</strong> contributions.</p>
<ul>
<li><strong>Resource Basis</strong>: The core contribution is &ldquo;ChemBERTa-3,&rdquo; an open-source framework integrated into DeepChem that standardizes the pretraining and benchmarking of chemical foundation models. The authors focus heavily on infrastructure (AWS/Ray integration) and correcting benchmarking inconsistencies in the field.</li>
<li><strong>Method Basis</strong>: It trains models like &ldquo;c3-MoLFormer&rdquo; to reproduce and validate the infrastructure.</li>
</ul>
<h2 id="the-pretraining-scalability-challenge">The Pretraining Scalability Challenge</h2>
<ul>
<li><strong>Scalability Challenges</strong>: Building robust molecular models is difficult due to the vast size of chemical space and the computational intensity of pretraining on large datasets.</li>
<li><strong>Proprietary Barriers</strong>: Many high-performing chemical foundation models (e.g., the full MoLFormer-XL) are partially closed-source or difficult to reproduce.</li>
<li><strong>Benchmarking Inconsistencies</strong>: There is a lack of systematic comparison between architectures (e.g., Graph vs. Transformer) using unified protocols. Specifically, previous comparisons relied on reported results that used differing scaffold splitting algorithms, making cross-paper comparisons unreliable.</li>
</ul>
<h2 id="unified-infrastructure--standardized-benchmarking">Unified Infrastructure &amp; Standardized Benchmarking</h2>
<ul>
<li><strong>Unified Infrastructure</strong>: Integration of DeepChem with Ray for distributed, scalable pretraining and fine-tuning of both graph and transformer models.</li>
<li><strong>Standardized Benchmarking</strong>: Identification that MoLFormer&rsquo;s scaffold splitting algorithm differs from the standard DeepChem/MoleculeNet splitter, and the subsequent standardization of these benchmarks for fair comparison.</li>
<li><strong>New DeepChem Tools</strong>: Introduction of the <code>ModularTorchModel</code> class for flexible loss computation and <code>HuggingFaceModel</code> wrappers to bridge ecosystems.</li>
</ul>
<h2 id="benchmarking-transformers-vs-graph-models">Benchmarking Transformers vs. Graph Models</h2>
<ul>
<li><strong>Architecture Comparison</strong>: Benchmarked Transformers (<a href="/notes/computational-chemistry/chemical-language-models/chemberta/">ChemBERTa</a>, MoLFormer) against Graph models (GROVER, InfoGraph, InfoMax3D, DMPNN, GCN) and baselines (Random Forest).</li>
<li><strong>Pretraining Scale Disparity</strong>:
<ul>
<li>Transformers were pretrained on ZINC20 subsets ranging from 10M to 1.1B molecules (combining ZINC and PubChem).</li>
<li>Graph models were limited to 250K molecule subsets due to memory and computational overhead of message passing on large graphs. While this highlights the superior scalability of Transformer architectures, comparing a 1.1B-trained Transformer to a 250K-trained Graph model provides an unbalanced evaluation of architectural capacity.</li>
</ul>
</li>
<li><strong>Reproducibility Validation</strong>: Trained &ldquo;c3-MoLFormer&rdquo; (a reproduction of MoLFormer) on 1.1B molecules using two distinct hardware setups: AWS spot instances (Ray) and a local HPC cluster.</li>
<li><strong>Scaffold Split Analysis</strong>: Compared performance metrics using &ldquo;DeepChem scaffold splits&rdquo; vs. &ldquo;MoLFormer scaffold splits&rdquo; to quantify the impact of data leakage/overlap.</li>
</ul>
<h2 id="overcoming-scaffold-splitting-inconsistencies">Overcoming Scaffold Splitting Inconsistencies</h2>
<ul>
<li><strong>Scaling Transformers vs. Graphs</strong>: Transformer-based models are significantly easier to scale to large datasets than current graph-based approaches, though performance is comparable at small scales.</li>
<li><strong>Benchmarking Sensitivity</strong>: MoLFormer&rsquo;s reported superiority over baselines was partly inflated by its specific scaffold splitting method, which had higher structural overlap between train and test sets (yielding a lower Tanimoto distance, generally quantified via $1 - \frac{|A \cap B|}{|A \cup B|}$) than DeepChem splits. When standardized, baselines like DMPNN perform more competitively.</li>
<li><strong>Infrastructure Viability</strong>: The framework successfully replicated large-scale training (MoLFormer-1.1B) on both cloud and on-premise HPC, confirming reproducibility.</li>
<li><strong>Open Source Release</strong>: All code, configurations, and the c3-MoLFormer-1.1B model weights are released to facilitate future research.</li>
</ul>
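<p>The Tanimoto distance used to quantify train/test overlap can be computed directly from the set formula above, with molecules represented as sets of fingerprint on-bits (in practice, RDKit Morgan fingerprint bits):</p>

```python
# Tanimoto distance over fingerprint bit sets: 1 - |A ∩ B| / |A ∪ B|.

def tanimoto_distance(bits_a, bits_b):
    """Lower values mean higher structural overlap between the two sets."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    if not union:  # two empty fingerprints: define distance as 0
        return 0.0
    return 1.0 - len(a & b) / len(union)
```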
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Pretraining</strong>:
<ul>
<li><strong>Source</strong>: ZINC20 (1.4B compounds) and PubChem.</li>
<li><strong>Scale</strong>: Subsets of 10M, 100M, and 1.1B (100% ZINC20 + 100% PubChem) were used for Transformers. Graph models used a 250K subset.</li>
</ul>
</li>
<li><strong>Fine-tuning</strong>:
<ul>
<li><strong>Suite</strong>: MoleculeNet.</li>
<li><strong>Tasks</strong>: Classification (BACE, BBBP, Tox21, HIV, SIDER, ClinTox) and Regression (ESOL, FreeSolv, Lipophilicity, QM9).</li>
<li><strong>Splits</strong>: Critical distinction made between &ldquo;DeepChem scaffold splits&rdquo; (80/10/10) and &ldquo;MoLFormer scaffold splits&rdquo; (which can be downloaded from <a href="https://ibm.ent.box.com/v/MoLFormer-data"><code>https://ibm.ent.box.com/v/MoLFormer-data</code></a>). The paper notes these algorithms differ.</li>
</ul>
</li>
</ul>
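<p>The key property of any scaffold split is that all molecules sharing a scaffold land in the same partition. A minimal sketch of an 80/10/10 scaffold-grouped split, assuming scaffold labels are precomputed (DeepChem derives Bemis&ndash;Murcko scaffolds with RDKit); the greedy largest-group-first assignment mirrors the common splitter heuristic but is not a reproduction of either algorithm compared in the paper:</p>

```python
from collections import defaultdict

# Greedy scaffold-grouped split: whole scaffold groups are assigned to
# train, then valid, then test, so no scaffold spans two partitions.

def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """scaffolds[i] is the scaffold label of molecule i.
    Returns (train_idx, valid_idx, test_idx)."""
    groups = defaultdict(list)
    for i, s in enumerate(scaffolds):
        groups[s].append(i)
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(scaffolds)
    train, valid, test = [], [], []
    for g in ordered:
        if len(train) + len(g) <= frac_train * n:
            train += g
        elif len(valid) + len(g) <= frac_valid * n:
            valid += g
        else:
            test += g
    return train, valid, test
```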
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework</strong>: DeepChem integrated with Ray for distributed training. To recreate the environment, the repository relies on a nightly version of DeepChem (<code>pip install --pre deepchem</code>) and specific dependencies found within the <code>requirements.txt</code>. Pretraining scripts are available in the <code>chemberta3_benchmarking/pretraining</code> directory of the repository.</li>
<li><strong>Data Preparation</strong>: Featurization workflows (e.g., <code>CircularFingerprint</code>, <code>RDKitConformer</code>) are documented under <code>chemberta3_benchmarking/data/data_preprocessing/</code> in the codebase.</li>
<li><strong>Modular Training</strong>: Uses <code>ModularTorchModel</code> to allow loss computation from intermediate values and flexible component connection.</li>
<li><strong>Training Brittleness</strong>:
<ul>
<li><strong>Optimizer</strong>: Linear learning rate scheduler with warmup.</li>
<li><strong>Instability Handling</strong>: The authors observed significant loss spikes during warmup. Their primary mitigation strategy involved checkpointing frequently and restarting from the last stable state upon a spike, highlighting a persistent brittleness in optimizing these large chemical foundation models.</li>
<li><strong>Numerical Issues</strong>: Addressed NaN values by pretraining on a small dataset with low LR before scaling up.</li>
</ul>
</li>
</ul>
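<p>The warmup schedule mentioned above can be sketched as a simple piecewise-linear function. The peak learning rate, warmup length, and total step count below are illustrative placeholders, not the paper&rsquo;s configuration:</p>

```python
# Linear warmup to peak_lr over warmup_steps, then linear decay to 0
# at total_steps -- the schedule shape described for pretraining.

def linear_warmup_decay_lr(step, peak_lr, warmup_steps, total_steps):
    """Return the learning rate at a given optimizer step."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)
```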
<h3 id="models">Models</h3>
<ul>
<li><strong><a href="/notes/computational-chemistry/chemical-language-models/chemberta/">ChemBERTa</a></strong>: RoBERTa-based architecture trained with Masked Language Modeling (MLM) and Multitask Regression (MTR). Specific model identifiers (e.g., <a href="https://huggingface.co/DeepChem/ChemBERTa-100M-MLM"><code>DeepChem/ChemBERTa-100M-MLM</code></a>) are hosted on Hugging Face so researchers can pull them directly via the <code>transformers</code> library. The core pretraining objective minimized the standard MLM loss:
$$ \mathcal{L}_{\text{MLM}} = - \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log \hat{y}_{i} $$
where $\mathcal{M}$ represents the set of masked SMILES token indices, and $\hat{y}_{i}$ is the model&rsquo;s predicted probability for the correct token given the corrupted sequence context.</li>
<li><strong>MoLFormer (c3-MoLFormer)</strong>: Re-implementation of the MoLFormer architecture (Rotary embeddings, linear attention). Specific model identifiers (e.g., <a href="https://huggingface.co/DeepChem/MoLFormer-c3-1.1B"><code>DeepChem/MoLFormer-c3-1.1B</code></a>) are similarly available on Hugging Face.
<ul>
<li>Tokenizer: <code>ibm/MoLFormer-XL-both-10pct</code> tokenizer.</li>
</ul>
</li>
<li><strong>Graph Models</strong>:
<ul>
<li><strong>GROVER</strong>: Graph Transformer with node/edge/graph level self-supervision.</li>
<li><strong>InfoGraph</strong>: Maximizes mutual information between graph-level and substructure representations.</li>
<li><strong>InfoMax3D</strong>: Incorporates 3D conformer data (via RDKit ETKDGv2) into contrastive pretraining.</li>
<li><strong>DMPNN</strong>: Directed Message Passing Neural Network (Chemprop variant).</li>
</ul>
</li>
</ul>
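<p>The MLM objective given above for ChemBERTa can be checked numerically: it is the mean negative log probability of the correct token, taken over masked positions only. A minimal sketch, with per-position probabilities supplied directly (in training they come from the model&rsquo;s softmax output):</p>

```python
import math

# L_MLM = -(1/|M|) * sum_{i in M} log y_hat_i over masked positions M.

def mlm_loss(token_probs, masked_indices):
    """token_probs[i]: model probability of the *correct* token at position i."""
    masked = list(masked_indices)
    return -sum(math.log(token_probs[i]) for i in masked) / len(masked)
```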
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: ROC-AUC for classification; RMSE for regression (MAE for QM9).</li>
<li><strong>Baselines</strong>: Random Forest, GCN, DMPNN trained on fine-tuning splits only.</li>
<li><strong>Protocol</strong>: Three independent runs per configuration to report mean and range (not a confidence interval), with the exception of the compute-heavy QM9 dataset, which only received a single run. Benchmarking execution scripts (e.g., GCN, RF, DMPNN, ChemBERTa) are stored in the repo under <code>chemberta3_benchmarking/models_benchmarking/</code> and contain the specific fine-tuning hyperparameters and optimizer configurations used for each downstream task.</li>
<li><strong>Key Results</strong>:
<ul>
<li><em>c3-MoLFormer-1.1B</em> achieved ~0.848 ROC-AUC on BACE and ~0.900 on BBBP (using MoLFormer splits). This closely matches the original IBM MoLFormer metrics, validating the reproducibility of the open-source framework.</li>
<li>When constrained to the equivalent 250K subset, Graph models (InfoGraph, GROVER) performed comparably to Transformers, indicating that Transformer superiority in chemistry is largely driven by data scalability rather than an inherent architectural advantage at small scales.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Cloud (AWS)</strong>:
<ul>
<li><strong>Compute</strong>: 40 NVIDIA T4 GPUs (<code>g4dn.12xlarge</code> spot instances for pretraining, <code>g4dn.2xlarge</code> for benchmarking).</li>
<li><strong>Cost</strong>: ~$4000 for MoLFormer 1.1B pretraining.</li>
<li><strong>Time</strong>: ~10 days (260 hours) for 1.1B model pretraining.</li>
<li><strong>Setup</strong>: Setup scripts for single-node and multi-node spot EC2 clusters are provided in the GitHub repository&rsquo;s <code>infra/</code> and <code>spot/</code> folders.</li>
</ul>
</li>
<li><strong>On-Premise HPC</strong>:
<ul>
<li><strong>Compute</strong>: 16 nodes (AMD EPYC), each with 4 AMD MI300A APUs.</li>
<li><strong>Environment</strong>: Ray multi-node multi-GPU framework.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/deepforestsci/chemberta3">ChemBERTa-3 GitHub Repository</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training, fine-tuning, and benchmarking framework</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-1.1B">DeepChem/MoLFormer-c3-1.1B</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer re-implementation pretrained on 1.1B molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/ChemBERTa-100M-MLM">DeepChem/ChemBERTa-100M-MLM</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>ChemBERTa pretrained on 100M ZINC molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-100M">DeepChem/MoLFormer-c3-100M</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer pretrained on 100M molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-550M">DeepChem/MoLFormer-c3-550M</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer pretrained on 550M molecules</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Singh, R. et al. (2026). ChemBERTa-3: an open source training framework for chemical foundation models. <em>Digital Discovery</em>, 5, 662-685. <a href="https://doi.org/10.1039/D5DD00348B">https://doi.org/10.1039/D5DD00348B</a></p>
<p><strong>Publication</strong>: Digital Discovery 2026</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/deepforestsci/chemberta3">ChemBERTa-3 GitHub Repository</a></li>
<li><a href="https://deepchem.io/">DeepChem Project</a></li>
<li><a href="https://huggingface.co/DeepChem">DeepChem Hugging Face Models</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{singhChemBERTa3OpenSource2026,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Singh, Riya and Barsainyan, Aryan Amit and Irfan, Rida and Amorin, Connor Joseph and He, Stewart and Davis, Tony and Thiagarajan, Arun and Sankaran, Shiva and Chithrananda, Seyone and Ahmad, Walid and Jones, Derek and McLoughlin, Kevin and Kim, Hyojin and Bhutani, Anoushka and Sathyanarayana, Shreyas Vinaya and Viswanathan, Venkat and Allen, Jonathan E. and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa-3}}: an open source training framework for chemical foundation models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2026}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{662-685}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{The Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D5DD00348B}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1039/D5DD00348B}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>GP-MoLFormer: Molecular Generation via Transformers</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/gp-molformer/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/gp-molformer/</guid><description>A 46.8M parameter transformer for molecular generation trained on 1.1B SMILES, introducing pair-tuning for efficient property optimization.</description><content:encoded><![CDATA[<h2 id="contribution-and-taxonomic-focus">Contribution and Taxonomic Focus</h2>
<p>This is primarily a <strong>Methodological</strong> paper, as it proposes a specific neural architecture (GP-MoLFormer) and a novel fine-tuning algorithm (Pair-tuning) for molecular generation. It validates these contributions against standard baselines (e.g., JT-VAE, MolGen-7b).</p>
<p>It also contains a secondary <strong>Theoretical</strong> contribution by establishing an empirical scaling law that relates inference compute (generation size) to the novelty of the generated molecules.</p>
<h2 id="motivation-data-scale-and-prompt-based-optimization">Motivation: Data Scale and Prompt-Based Optimization</h2>
<p>While large language models (LLMs) have transformed text generation, the impact of training data scale and memorization on <em>molecular</em> generative models remains under-explored. Specifically, there is a need to understand how training on billion-scale datasets affects the novelty of generated molecules and whether biases in public databases (like ZINC and PubChem) perpetuate memorization. Furthermore, existing optimization methods often require computationally expensive property predictors or reinforcement learning loops; there is a practical need for more efficient &ldquo;prompt-based&rdquo; optimization techniques.</p>
<h2 id="core-innovations-architecture-and-pair-tuning">Core Innovations: Architecture and Pair-Tuning</h2>
<ol>
<li><strong>Architecture</strong>: The application of a linear-attention transformer decoder with Rotary Positional Embeddings (RoPE) to generative chemistry, allowing for efficient training on 1.1 billion SMILES.</li>
<li><strong>Pair-Tuning</strong>: A novel, parameter-efficient fine-tuning method that uses property-ordered molecular pairs to learn &ldquo;soft prompts&rdquo; for optimization without updating the base model weights.</li>
<li><strong>Scaling Analysis</strong>: An extensive empirical investigation mapping the trade-off between inference compute (up to 10B generations) and chemical novelty, fitting an exponential decay curve that demonstrates how novelty saturates as generation volume grows.</li>
</ol>
<h2 id="experimental-methodology-and-downstream-tasks">Experimental Methodology and Downstream Tasks</h2>
<p>The authors evaluated GP-MoLFormer on three distinct tasks, though the comparisons highlight the difficulty of evaluating foundation models against classical baselines:</p>
<ol>
<li><strong>De Novo Generation</strong>: Comparing validity, uniqueness, and novelty against baselines (CharRNN, VAE, LIMO, MolGen-7b) on a held-out test set. Notably, this is an unequal comparison; most baselines were trained on the 1.6M molecule MOSES dataset, whereas GP-MoLFormer uses up to 1.1B molecules, meaning performance gains are heavily driven by data scale.</li>
<li><strong>Scaffold-Constrained Decoration</strong>: Generating molecules from DRD2 active binder scaffolds and measuring the hit rate of active compounds against specialized scaffold decorators.</li>
<li><strong>Property-Guided Optimization</strong>: Using Pair-tuning to optimize for Drug-likeness (QED), Penalized logP, and DRD2 binding activity, comparing the results to graph-based and reinforcement learning benchmarks.</li>
</ol>
<p>Additionally, they performed a <strong>Scaling Study</strong>:</p>
<ul>
<li>Comparing models trained on raw (1.1B) vs. de-duplicated (650M) data.</li>
<li>Generating up to 10 billion molecules to fit empirical scaling laws for novelty.</li>
</ul>
<h2 id="key-findings-and-scaling-laws">Key Findings and Scaling Laws</h2>
<ul>
<li><strong>Scale Driven Performance</strong>: GP-MoLFormer achieves high internal diversity and validity on generation metrics. However, its baseline novelty percentage (~32%) is considerably lower than classical models. The authors attribute this to the massive training scale forcing the model to heavily prioritize matching real-world molecule frequencies over pure exploration. GP-MoLFormer&rsquo;s advantage in generation metrics over LLM-baselines like MolGen-7b likely stems heavily from its 10x larger training dataset rather than fundamental architectural superiority.</li>
<li><strong>Pair-Tuning Efficacy</strong>: The proposed pair-tuning method effectively optimizes properties (e.g., improving DRD2 activity scores) without requiring full model fine-tuning or external reward loops. While successful, the text-based generation yields ~94.5% validity during optimization, which lags behind graph and SELFIES-based baselines that guarantee 100% structural validity.</li>
<li><strong>Memorization vs. Novelty</strong>: Training on de-duplicated data (GP-MoLFormer-UNIQ) yields higher novelty (approx. 5-8% higher) than training on raw data, confirming that duplication bias in public databases leads directly to memorization.</li>
<li><strong>Inference Scaling Law</strong>: Novelty decays exponentially with generation size ($y = ae^{-bx}$), yet the model maintains generative capability (~16.7% novelty) even after generating an unprecedented 10 billion molecules.</li>
</ul>
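<p>The reported exponential decay can be fitted with simple log-linear least squares. The points below are synthetic illustrations (apart from the ~16.7% novelty at 10B generations, they are not the paper&rsquo;s measurements):</p>

```python
import numpy as np

# Synthetic (generation size, novelty) points for illustration only; apart
# from the ~16.7% value at 10B generations, these are NOT the paper's data.
x = np.array([0.1, 1.0, 3.0, 10.0])      # generations, in billions
y = np.array([30.0, 25.0, 21.0, 16.7])   # % novel molecules

# Fit y = a * exp(-b * x) via least squares on log(y) = log(a) - b * x.
slope, intercept = np.polyfit(x, np.log(y), 1)
a, b = np.exp(intercept), -slope

pred_10b = a * np.exp(-b * 10.0)  # fitted novelty at 10B generations
```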
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Sources</strong>: A combination of <strong>PubChem</strong> (111M SMILES) and <strong>ZINC</strong> (1B SMILES) databases. Downloading and pre-training instructions are located in the repository&rsquo;s <code>data/README.md</code>.</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>All SMILES were canonicalized using RDKit (no isomeric information).</li>
<li><strong>GP-MoLFormer (Base)</strong>: Trained on the full 1.1B dataset (includes duplicates).</li>
<li><strong>GP-MoLFormer-UNIQ</strong>: Trained on a de-duplicated subset of 650M SMILES.</li>
</ul>
</li>
<li><strong>Tokenization</strong>: Uses the tokenizer from Schwaller et al. (2019) with a vocabulary size of <strong>2,362 tokens</strong>.</li>
<li><strong>Filtering</strong>: Sequences restricted to a maximum length of <strong>202 tokens</strong>.</li>
</ul>
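<p>The Schwaller et al. (2019) tokenizer is regex-based; a minimal sketch in that style is shown below (the regex is the commonly reproduced variant, an assumption rather than the paper&rsquo;s exact code, and the 2,362-token vocabulary would be built separately from the corpus):</p>

```python
import re

# Regex-based SMILES tokenizer in the style of Schwaller et al. (2019).
# This regex is the commonly reproduced variant (assumption).
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\."
    r"|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

def tokenize_smiles(smiles):
    tokens = SMILES_REGEX.findall(smiles)
    # a faithful tokenization must reconstruct the input exactly
    assert "".join(tokens) == smiles, "unrecognized SMILES characters"
    return tokens
```

<p>For example, aspirin (<code>CC(=O)Oc1ccccc1C(=O)O</code>) splits into 21 single-character tokens, while bracket atoms such as <code>[C@H]</code> remain intact as one token.</p>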
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pair-Tuning (Algorithm 1)</strong>:</p>
<ul>
<li><strong>Objective</strong>: Learn task-specific soft prompts $\phi_T$ to maximize the conditional probability of target molecule $b$ given a seed molecule $a$, where pair $(a, b)$ satisfies the property condition $b &gt; a$. The base model parameters $\theta$ remain frozen.</li>
<li><strong>Prompt Structure</strong>: Autoregressive training optimizes the continuous embeddings of $n$ enhancement tokens against the cross-entropy loss of the target sequence:
$$ \mathcal{L}(\phi_T) = - \sum_{i=1}^{|b|} \log P_{\theta}(b_i | \phi_T, a, b_{&lt;i}) $$</li>
<li><strong>Hyperparameters</strong>: Trained for 1,000 epochs with a batch size of 35 and a fixed learning rate of $3 \times 10^{-2}$.</li>
<li><strong>Inference</strong>: The learned prompt $\phi_T$ and seed molecule $a$ are prepended as context, and candidates are sampled autoregressively until a termination token is produced.</li>
</ul>
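<p>A minimal sketch of the pair-tuning loss masking, assuming next-token logits from the frozen base model over the concatenated [prompt | seed | target] sequence (the array layout and function name are illustrative, not from the paper):</p>

```python
import numpy as np

def pair_tuning_loss(logits, b_ids, n_prompt, n_seed):
    """Cross-entropy over the target molecule b only (illustrative sketch).

    logits: (L, V) next-token logits from the frozen base model run over the
            concatenated sequence [soft prompt phi_T | seed a | target b].
    b_ids:  integer token ids of the target molecule b.

    In a decoder, position i predicts token i+1, so the positions scoring b
    begin at n_prompt + n_seed - 1. In training, only the prompt embeddings
    would receive gradients; the base model parameters stay frozen.
    """
    start = n_prompt + n_seed - 1
    rows = logits[start : start + len(b_ids)]          # (|b|, V)
    rows = rows - rows.max(axis=1, keepdims=True)      # stable log-softmax
    log_probs = rows - np.log(np.exp(rows).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(b_ids)), b_ids].sum()
```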
<h3 id="models">Models</h3>
<ul>
<li><strong>Availability</strong>: The model trained on deduplicated data (GP-MoLFormer-UNIQ) is publicly available on <a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">Hugging Face</a>. The full 1.1B base model is not explicitly hosted. The source code repository includes a disclosure that IBM will not maintain the code going forward.</li>
<li><strong>Architecture</strong>: Transformer decoder (~47M parameters: 12 layers, 12 heads, hidden size 768).</li>
<li><strong>Attention Mechanism</strong>: Combines Linear Attention (Generalized Random Feature map, $\phi$) with Rotary Positional Embeddings (RoPE). To avoid the quadratic complexity of standard attention while maintaining relative positional awareness, RoPE is applied to queries ($Q$) and keys ($K$) prior to the random feature mapping:
$$ \text{Attention}(Q, K, V)_m = \frac{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle v_n}{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle} $$</li>
<li><strong>Inference Speed</strong>: ~3ms per forward pass on a single A100 GPU.</li>
</ul>
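<p>A toy NumPy sketch of this attention: RoPE is applied to queries and keys before the feature map, and the non-causal sums run over all positions as in the formula above. For brevity, $\mathrm{elu}(x)+1$ stands in for the paper&rsquo;s generalized random feature map (an assumption):</p>

```python
import numpy as np

def rope(x, base=10000.0):
    """Rotary positional embedding for a (seq_len, d) array, d even."""
    seq, d = x.shape
    half = d // 2
    freqs = base ** (-np.arange(half) / half)          # per-pair frequencies
    angles = np.arange(seq)[:, None] * freqs[None, :]  # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def feature_map(x):
    # positive feature map; elu(x) + 1 stands in for the paper's
    # generalized random features (an assumption for brevity)
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(q, k, v):
    """Non-causal linear attention: RoPE first, then the feature map."""
    qf = feature_map(rope(q))   # phi(R_m q_m), shape (N, d)
    kf = feature_map(rope(k))   # phi(R_n k_n), shape (N, d)
    kv = kf.T @ v               # sum_n phi(k_n) v_n^T, shape (d, d_v)
    z = kf.sum(axis=0)          # normalizer accumulator, shape (d,)
    return (qf @ kv) / (qf @ z)[:, None]   # O(N d d_v) instead of O(N^2 d)
```

<p>Accumulating <code>kv</code> and <code>z</code> once is what removes the quadratic dependence on sequence length.</p>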
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Generation Quality Metrics</strong>: Validity, Uniqueness, Novelty (MOSES suite), Fréchet ChemNet Distance (FCD), Scaffold similarity (Scaf), and Similarity to Nearest Neighbor (SNN).</li>
<li><strong>MoLFormer-Based Metrics</strong>: The authors introduce Fréchet MoLFormer Distance (FMD) and MoLFormer-space IntDiv2 to measure distributional similarity using their own pre-trained continuous embeddings instead of standard fingerprints.</li>
<li><strong>Optimization Metrics</strong>: Penalized logP (calculated as $\text{logP} - \text{SA} - \max(\text{largest ring size} - 6, 0)$), Drug-likeness (QED), and DRD2 activity scores.</li>
<li><strong>Scaling Metrics</strong>: Empirical fit for novelty decay: $y = ae^{-bx}$.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 16 x NVIDIA A100 (80 GB) GPUs across 2 nodes connected via EDR Infiniband.</li>
<li><strong>Training Time</strong>:
<ul>
<li>GP-MoLFormer (1.1B data): ~115 hours total (28.75 hours/epoch for 4 epochs).</li>
<li>GP-MoLFormer-UNIQ (650M data): ~80 hours total.</li>
</ul>
</li>
<li><strong>Hyperparameters</strong>: Used a batch size of 1,600 molecules per GPU with a fixed learning rate of $1.6 \times 10^{-4}$ (scaled up by a factor of up to $8\times$ as the number of GPUs increased).</li>
<li><strong>Optimization</strong>: Used distributed data-parallel training and adaptive bucketing by sequence length to handle scale.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/gp-molformer/">GP-MoLFormer (GitHub)</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official implementation; IBM will not maintain going forward</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">GP-MoLFormer-Uniq (Hugging Face)</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained on 650M de-duplicated SMILES</td>
      </tr>
  </tbody>
</table>
<p>The full 1.1B base model weights are not publicly hosted. The training data (PubChem and ZINC) is publicly available, and instructions for downloading and pre-processing are in the repository&rsquo;s <code>data/README.md</code>.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ross, J., Belgodere, B., Hoffman, S. C., Chenthamarakshan, V., Navratil, J., Mroueh, Y., &amp; Das, P. (2025). GP-MoLFormer: A Foundation Model For Molecular Generation. <em>Digital Discovery</em> (2025). <a href="https://doi.org/10.1039/D5DD00122F">https://doi.org/10.1039/D5DD00122F</a></p>
<p><strong>Publication</strong>: Digital Discovery (2025)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ross2025gpmolformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{GP-MoLFormer: a foundation model for molecular generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ross, Jerret and Belgodere, Brian and Hoffman, Samuel C and Chenthamarakshan, Vijil and Navratil, Jiri and Mroueh, Youssef and Das, Payel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D5DD00122F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa-2: Scaling Molecular Transformers to 77M</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta-2/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta-2/</guid><description>Optimizing transformer pretraining for molecules using MLM vs MTR objectives, scaling to 77M compounds from PubChem for improved property prediction.</description><content:encoded><![CDATA[<h2 id="classifying-chemberta-2s-methodological-contributions">Classifying ChemBERTa-2&rsquo;s Methodological Contributions</h2>
<p>This is primarily a <strong>Methodological</strong> paper with a secondary <strong>Resource</strong> contribution.</p>
<p>It fits the Method classification because it focuses on optimizing the architecture and pretraining pipeline for molecular transformers. The authors perform extensive ablation studies (varying dataset size from 5M to 77M, comparing MLM vs. MTR objectives) to determine &ldquo;how well&rdquo; these strategies work compared to baselines. The secondary Resource classification applies because they open-source the trained models and establish a benchmark on a massive 77M compound dataset.</p>
<p><strong>Key methodological indicators</strong>:</p>
<ul>
<li><strong>Baseline comparison</strong>: The paper explicitly compares ChemBERTa-2 against standard baselines (D-MPNN, Random Forest, GCN) and its predecessor (ChemBERTa-1) with prominent benchmark tables</li>
<li><strong>Ablation studies</strong>: Extensive experiments comparing multi-task and self-supervised pretraining by varying hyperparameters and pretraining dataset size</li>
<li><strong>Scaling analysis</strong>: Systematic investigation of whether larger datasets (up to 77M compounds) yield better performance</li>
</ul>
<h2 id="motivations-for-scaling-molecular-transformers">Motivations for Scaling Molecular Transformers</h2>
<p>The authors aim to bridge the gap between NLP success stories (like GPT-3) and molecular machine learning by developing a &ldquo;chemical foundation model&rdquo;.</p>
<p><strong>Key motivations</strong>:</p>
<ul>
<li><strong>Label scarcity</strong>: Experimental labels for molecular properties are rare and expensive, but unlabeled SMILES strings are abundant</li>
<li><strong>Scaling hypothesis</strong>: Testing if scaling pretraining data (up to 77M compounds) yields consistent downstream improvements, similar to scaling laws in NLP</li>
<li><strong>Efficiency</strong>: Optimizing the pretraining process introduced in the original ChemBERTa by comparing self-supervised (MLM) and weakly supervised (MTR, using RDKit computed properties as labels) approaches</li>
</ul>
<h2 id="novelty-in-multi-task-regression-objectives">Novelty in Multi-Task Regression Objectives</h2>
<p><strong>Scale</strong>: Training on 77M unique SMILES from PubChem, which is one of the largest molecular pretraining datasets used to date (compared to 10M for ChemBERTa-1 or 18.7M for SMILES-BERT).</p>
<p><strong>Pipeline optimization</strong>: A direct, controlled comparison of <strong>Masked Language Modeling (MLM)</strong> vs. <strong>Multi-Task Regression (MTR)</strong> pretraining objectives on identical datasets.</p>
<p><strong>Proxy selection</strong>: The finding that MLM loss correlates well with MTR loss, allowing the cheaper MLM task to be used for hyperparameter tuning before running the expensive MTR pretraining.</p>
<h2 id="experimental-pretraining-setup-on-77m-compounds">Experimental Pretraining Setup on 77M Compounds</h2>
<h3 id="pretraining-setup">Pretraining Setup</h3>
<p><strong>Datasets</strong>: Subsets of PubChem containing 5M, 10M, and 77M unique SMILES.</p>
<p><strong>Tasks</strong>:</p>
<ul>
<li><strong>MLM</strong>: Masking 15% of tokens (following RoBERTa procedure). The model is optimized by minimizing the cross-entropy loss over the predicted masked tokens:
$$ \mathcal{L}_{MLM} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \mathbf{x}_{\setminus \mathcal{M}}) $$
where $\mathcal{M}$ represents the set of masked token indices.</li>
<li><strong>MTR</strong>: Predicting 200 calculated molecular properties (via RDKit) simultaneously using a mean squared error objective:
$$ \mathcal{L}_{MTR} = \frac{1}{200} \sum_{j=1}^{200} \frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_{ij} - y_{ij} \right)^2 $$
Continuous target labels $y_{ij}$ are mean-normalized prior to training to equilibrate the disparate scales of different chemical properties.</li>
</ul>
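<p>Both objectives can be sketched in NumPy. The mask-selection step and the per-property mean-normalization below are assumptions following the RoBERTa procedure and the stated loss, not the authors&rsquo; code:</p>

```python
import numpy as np

def select_mlm_mask(n_tokens, rng, rate=0.15):
    """Choose 15% of positions to corrupt (RoBERTa-style selection).

    Of the chosen positions, RoBERTa then uses 80% <mask> / 10% random
    token / 10% unchanged; only the selection step is sketched here.
    """
    n_mask = max(1, int(round(rate * n_tokens)))
    return rng.choice(n_tokens, size=n_mask, replace=False)

def mean_normalize(labels):
    # per-property mean subtraction; the exact normalization scheme behind
    # "mean-normalized" in the paper is an assumption here
    return labels - labels.mean(axis=0)

def mtr_loss(preds, labels):
    """L_MTR = (1/200) sum_j (1/N) sum_i (yhat_ij - y_ij)^2."""
    return ((preds - labels) ** 2).mean(axis=0).mean()
```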
<p><strong>Hyperparameter search</strong>: Ran 50 random configurations on the 5M dataset; selected the top 5 to scale up to 10M and 77M.</p>
<h3 id="downstream-validation">Downstream Validation</h3>
<p><strong>Finetuning</strong>: Evaluated on 8 tasks from <strong>MoleculeNet</strong> (BACE, BBBP, ClinTox, Delaney, etc.) using scaffold splits (80/10/10).</p>
<p><strong>Analysis</strong>: Used UMAP to visualize embeddings from MLM, MTR, and ECFP to check for clustering by label without finetuning.</p>
<h2 id="key-performance-outcomes-and-scaling-realities">Key Performance Outcomes and Scaling Realities</h2>
<p><strong>Highly competitive performance</strong>: ChemBERTa-2 outperforms the D-MPNN baseline (chemprop) on 6 out of 8 MoleculeNet tasks, though the margins demonstrate that task-specific baselines remain notably robust.</p>
<p><strong>MTR superiority</strong>: Models pretrained on Multi-Task Regression (MTR) consistently perform better on downstream tasks than those pretrained on MLM on every finetuning task evaluated. MTR is substantially slower than MLM due to the larger input size from the 200-element label vector, but MLM loss serves as a reliable proxy for MTR loss, enabling cheaper architecture search before committing to full MTR pretraining.</p>
<p><strong>Scaling laws versus downstream utility</strong>: Pretraining loss improved by 25-35% when increasing the dataset from 5M to 77M compounds. However, this improvement in pretraining loss does not uniformly transfer to downstream tasks. For MTR models, SR-p53 ROC-AUC decreases monotonically from 0.834 (5M) to 0.827 (10M) to 0.817 (77M), and Lipophilicity RMSE is worse at 77M (0.798) than at 5M (0.758), despite a dip at 10M (0.744). This variability in transfer challenges the assumption that pretraining improvements always yield downstream gains.</p>
<p><strong>Transfer learning</strong>: The correlation between pretraining loss and downstream performance is task-dependent; it is strong for Lipophilicity but weaker for BACE classification.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The pretraining corpus is derived from <strong>PubChem</strong>.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Pretraining</strong></td>
          <td>PubChem</td>
          <td>77M SMILES</td>
          <td>Canonicalized and globally shuffled. Subsets of 5M and 10M used. <strong>Note: Exact splits and datasets are not published.</strong></td>
      </tr>
      <tr>
          <td><strong>Validation</strong></td>
          <td>PubChem</td>
          <td>100k SMILES</td>
          <td>A fixed set held out from the 77M corpus. <strong>Note: Exact 100k subset is not published.</strong></td>
      </tr>
      <tr>
          <td><strong>MTR Labels</strong></td>
          <td>RDKit</td>
          <td>200 props</td>
          <td>200 molecular properties calculated from SMILES using RDKit. Labels are mean-normalized. <strong>Note: Calculated labels are not published and must be re-computed.</strong></td>
      </tr>
      <tr>
          <td><strong>Finetuning</strong></td>
          <td>MoleculeNet</td>
          <td>1.5k - 8k</td>
          <td>Tasks: BACE, Clearance, Delaney, Lipophilicity, BBBP, ClinTox, HIV, Tox21. Split 80/10/10 via scaffold splitter.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pretraining Objectives:</strong></p>
<ol>
<li><strong>Masked Language Modeling (MLM)</strong>: Follows RoBERTa procedure. Masks 15% of tokens. Max sequence length 512.</li>
<li><strong>Multi-Task Regression (MTR)</strong>: Predicting 200 RDKit properties. Labels are mean-normalized.</li>
</ol>
<p><strong>Tokenizer:</strong></p>
<ul>
<li>Dictionary of common SMILES characters</li>
<li>Maximum vocabulary size: <strong>591 tokens</strong></li>
</ul>
<p><strong>Optimization:</strong></p>
<ul>
<li><strong>Patience</strong>: Early-stopping patience set to one full pass through the dataset to ensure complete data coverage</li>
<li><strong>Hyperparameter search</strong>: Random search (50 configs) varying hidden size, attention heads, dropout, intermediate size, hidden layers, and learning rate. <strong>Note: The precise configuration of the winning models that were scaled to 77M is absent from the paper.</strong></li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Based on <strong>RoBERTa</strong> (HuggingFace implementation)</li>
<li><strong>Parameter scale</strong>: Models ranged between <strong>5M and 46M parameters</strong></li>
<li><strong>Selection</strong>: Top 5 configurations from the 5M-dataset random search were trained on the full 77M dataset</li>
<li><strong>Checkpoints</strong>: Pre-trained weights are hosted by DeepChem on <a href="https://huggingface.co/DeepChem">Hugging Face</a>. Direct links include <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MTR">DeepChem/ChemBERTa-77M-MTR</a> and <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MLM">DeepChem/ChemBERTa-77M-MLM</a> (Note: Model cards are currently empty).</li>
<li><strong>Code Reference</strong>: While the <a href="https://github.com/deepchem/deepchem">DeepChem</a> repository is referenced for code, isolated training scripts tailored to recreate ChemBERTa-2&rsquo;s exact pipeline are not separated from the generalized deepchem library tooling.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Benchmarks were performed on <strong>MoleculeNet</strong> using DeepChem.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Tasks</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>RMSE</strong> ($\downarrow$)</td>
          <td>Delaney, Lipo, BACE (Reg), Clearance</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 outperformed D-MPNN on Delaney (0.889 vs 1.105) and Clearance (48.5 vs 49.8).</td>
      </tr>
      <tr>
          <td><strong>ROC-AUC</strong> ($\uparrow$)</td>
          <td>BBBP, ClinTox, HIV, Tox21, BACE (Cls)</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 generally competitive; MTR-77M achieved 0.728 on BBBP vs D-MPNN 0.697.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: AWS EC2 instances with <strong>Nvidia T4 GPUs</strong></li>
<li><strong>Strategy</strong>: AWS Spot instances were used to reduce cost; implemented frequent checkpointing to handle interruptions.</li>
<li><strong>Note</strong>: For MTR, they wrote a custom data loader wrapper around HuggingFace&rsquo;s text loader to handle CSV parsing efficiency, as the default CSV loader was a major bottleneck for the 200-element target vectors.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ahmad, W., Simon, E., Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2022). ChemBERTa-2: Towards Chemical Foundation Models. <em>arXiv preprint arXiv:2209.01712</em>. <a href="https://doi.org/10.48550/arXiv.2209.01712">https://doi.org/10.48550/arXiv.2209.01712</a></p>
<p><strong>Publication</strong>: arXiv 2022 (Presented at 2021 ELLIS ML for Molecule Discovery Workshop)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2010.09885">ChemBERTa-1 Paper</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{ahmadChemBERTa2ChemicalFoundation2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}: {{Towards Chemical Foundation Models}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-25}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemformer: A Pre-trained Transformer for Comp Chem</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemformer/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemformer/</guid><description>BART-based Transformer pre-trained on 100M molecules using self-supervision to accelerate convergence on chemical sequence tasks.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-methodological-classification">Paper Contribution and Methodological Classification</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It proposes an architecture adaptation (Chemformer based on BART) and a specific pre-training strategy (&ldquo;Combined&rdquo; masking and augmentation). The paper validates this method by benchmarking against established models on multiple tasks, including direct synthesis, retrosynthesis, and molecular optimization. It also includes a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution by making the pre-trained models and code available.</p>
<h2 id="motivation-computational-bottlenecks-in-cheminformatics">Motivation: Computational Bottlenecks in Cheminformatics</h2>
<p>Existing Transformer models for cheminformatics are often developed for single applications and are computationally expensive to train from scratch. For example, training a Molecular Transformer for reaction prediction can take days, limiting hyperparameter exploration. Self-supervised pre-training (like BERT or T5) has significantly advanced NLP by reducing fine-tuning time and improving performance. In chemistry, applications have traditionally focused on task-specific datasets or encoder-only architectures, which perform poorly on sequence generation tasks. The authors aim to use transfer learning on a large unlabelled dataset to create a model that converges quickly and performs well across diverse sequence-to-sequence and discriminative tasks.</p>
<h2 id="core-innovation-bart-architecture-and-combined-pre-training">Core Innovation: BART Architecture and Combined Pre-training</h2>
<p>The primary insight lies in the adaptation of the <strong>BART architecture</strong> for chemistry and the introduction of a <strong>&ldquo;Combined&rdquo; self-supervised pre-training task</strong>.</p>
<ul>
<li><strong>Architecture</strong>: Chemformer uses the BART encoder-decoder structure, allowing it to handle both discriminative (property prediction) and generative (reaction prediction) tasks efficiently. This provides an alternative to encoder-only (BERT) or decoder-only (GPT) models.</li>
<li><strong>Combined Pre-training</strong>: The authors introduce a task that applies both <strong>Span Masking</strong> (randomly replacing tokens with <code>&lt;mask&gt;</code>) and <strong>SMILES Augmentation</strong> (permuting atom order) simultaneously. Formally, given a canonical SMILES sequence $x$, a corrupted sequence $\tilde{x} = \text{Mask}(\text{Augment}(x))$ is generated. The model is trained using an autoregressive cross-entropy loss to reconstruct the canonical sequence from the corrupted input:
$$ \mathcal{L}_{\text{pre-train}} = -\sum_{t=1}^{|x|} \log P(x_t \mid x_{&lt;t}, \tilde{x}) $$</li>
<li><strong>Tunable Augmentation</strong>: A downstream augmentation strategy is proposed where the probability of augmenting the input/output SMILES ($p_{aug}$) is a tunable hyperparameter, performed on-the-fly.</li>
</ul>
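<p>A minimal sketch of the span-masking half of the &ldquo;Combined&rdquo; corruption. The masking rate and Poisson span lengths are assumptions in the spirit of BART-style span masking; the Augment step (atom-order permutation via RDKit) would precede this and is omitted:</p>

```python
import numpy as np

def span_mask(tokens, rng, mask_prob=0.10, mean_span=3):
    """Replace random token spans with a single <mask> token (sketch).

    mask_prob and the Poisson span-length distribution are illustrative
    assumptions; consult the paper/repository for the exact settings.
    """
    out, i = [], 0
    while i < len(tokens):
        if rng.random() < mask_prob:
            span = max(1, int(rng.poisson(mean_span)))
            out.append("<mask>")   # the whole span collapses to one <mask>
            i += span
        else:
            out.append(tokens[i])
            i += 1
    return out
```

<p>The model would then be trained to reconstruct the canonical SMILES from this corrupted token sequence via the autoregressive cross-entropy loss above.</p>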
<h2 id="experimental-setup-and-pre-training-tasks">Experimental Setup and Pre-training Tasks</h2>
<p>The authors pre-trained Chemformer on <strong>100 million molecules</strong> from ZINC-15 and fine-tuned it on three distinct task types:</p>
<ol>
<li><strong>Seq2Seq Reaction Prediction</strong>:
<ul>
<li><em>Direct Synthesis</em>: USPTO-MIT dataset (Mixed and Separated).</li>
<li><em>Retrosynthesis</em>: USPTO-50K dataset.</li>
</ul>
</li>
<li><strong>Molecular Optimization</strong>: Generating molecules with improved properties (LogD, solubility, clearance) starting from ChEMBL matched molecular pairs.</li>
<li><strong>Discriminative Tasks</strong>:
<ul>
<li><em>QSAR</em>: Predicting properties (ESOL, FreeSolv, Lipophilicity) from MoleculeNet.</li>
<li><em>Bioactivity</em>: Predicting pXC50 values for 133 genes using ExCAPE data.</li>
</ul>
</li>
</ol>
<p>Ablation studies compared three pre-training strategies (Masking, Augmentation, Combined) against a randomly initialized baseline.</p>
<h2 id="results-trade-offs-and-conclusions">Results, Trade-offs, and Conclusions</h2>
<ul>
<li><strong>Performance</strong>: Chemformer achieved <strong>competitive top-1 accuracy</strong> on USPTO-MIT (91.3% Mixed) and USPTO-50K (53.6-54.3%), outperforming the Augmented Transformer and graph-based models (GLN, GraphRetro).</li>
<li><strong>Convergence Speed</strong>: Pre-training significantly accelerated training; fine-tuning for just 20 epochs (~30 minutes) outperformed baselines that were trained for significantly longer.</li>
<li><strong>Pre-training Tasks</strong>: The &ldquo;Combined&rdquo; task generally performed best for reaction prediction and bioactivity, while &ldquo;Masking&rdquo; was superior for molecular optimization.</li>
<li><strong>Augmentation Trade-off</strong>: The augmentation strategy improved top-1 accuracy but significantly degraded top-5/10 accuracy because beam search outputs became populated with augmented versions of the same molecule. This presents a considerable limitation for practical applications like retrosynthesis mapping, where retrieving a diverse set of candidate reactions is often critical.</li>
<li><strong>Discriminative Evaluation Caveats</strong>: Chemformer underperformed specialized baselines (like D-MPNN or MolBERT) on small discriminative datasets. The authors note that direct comparison is difficult: Chemformer was trained simultaneously on multiple subtasks (multi-task learning), while the literature baselines were trained and tuned on each subtask separately. Additionally, the Chemformer encoder uses fewer than 20M parameters compared to MolBERT&rsquo;s approximately 85M, and Chemformer&rsquo;s pre-training does not include molecular property objectives.</li>
<li><strong>Pre-training Data Scope</strong>: The 100M pre-training dataset from ZINC-15 was selected with constraints on molecular weight ($\le 500$ Da) and LogP ($\le 5$), focusing the learned representations on small, drug-like molecules.</li>
</ul>
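<p>The beam-diversity issue above can be mitigated post hoc by collapsing candidates that are augmented renderings of the same molecule. The sketch below is a hedged illustration, not part of the paper's pipeline: the canonicalizer is injected as a function (in practice one would pass RDKit's <code>Chem.MolToSmiles</code> round-trip), and the toy canonicalizer used here is chemically meaningless.</p>

```python
def dedup_beam(candidates, canonicalize):
    """Keep the highest-ranked candidate per canonical form.

    `candidates` is a list of (smiles, log_prob) pairs sorted best-first,
    as returned by beam search; `canonicalize` maps any SMILES variant of
    a molecule to a single key.
    """
    seen = set()
    unique = []
    for smiles, log_prob in candidates:
        key = canonicalize(smiles)
        if key not in seen:
            seen.add(key)
            unique.append((smiles, log_prob))
    return unique

# Toy canonicalizer for illustration only: sorting characters maps
# trivially permuted strings to one key (NOT chemically valid).
toy_canon = lambda s: "".join(sorted(s))

beam = [("CCO", -0.1), ("OCC", -0.2), ("CCN", -0.5)]
print(dedup_beam(beam, toy_canon))  # "OCC" collapses into "CCO"
```

<p>After deduplication, the surviving top-5/10 candidates represent distinct molecules, which is what matters for retrosynthesis ranking.</p>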
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><em>Note: The primary GitHub repository for Chemformer was officially archived on February 11, 2026. Pre-trained weights and datasets used in the paper are still hosted externally on <a href="https://az.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Box</a>. Active development of Chemformer models has moved to the <a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels</a> repository.</em></p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/Chemformer">Chemformer (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Archived; original PyTorch implementation</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Active successor repository</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://az.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Pre-trained weights (Box)</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Base and Large model checkpoints</td>
      </tr>
  </tbody>
</table>
<p>The following datasets were used for pre-training and benchmarking.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Pre-training</strong></td>
          <td style="text-align: left">ZINC-15</td>
          <td style="text-align: left">100M</td>
          <td style="text-align: left">Selected subset (reactive, annotated purchasability, MW $\le 500$, LogP $\le 5$). Split: 99% Train / 0.5% Val / 0.5% Test.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Direct Synthesis</strong></td>
          <td style="text-align: left">USPTO-MIT</td>
          <td style="text-align: left">~470k</td>
          <td style="text-align: left">Evaluated on &ldquo;Mixed&rdquo; and &ldquo;Separated&rdquo; variants.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Retrosynthesis</strong></td>
          <td style="text-align: left">USPTO-50K</td>
          <td style="text-align: left">~50k</td>
          <td style="text-align: left">Standard benchmark for retrosynthesis.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Optimization</strong></td>
          <td style="text-align: left">ChEMBL MMPs</td>
          <td style="text-align: left">~160k Train</td>
          <td style="text-align: left">Matched Molecular Pairs for LogD, solubility, and clearance optimization.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Properties</strong></td>
          <td style="text-align: left">MoleculeNet</td>
          <td style="text-align: left">Small</td>
          <td style="text-align: left">ESOL (1128), FreeSolv (642), Lipophilicity (4200).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Bioactivity</strong></td>
          <td style="text-align: left">ExCAPE</td>
          <td style="text-align: left">~312k</td>
          <td style="text-align: left">133 gene targets; &gt;1200 compounds per gene.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Tokenization</strong>: Regex-based tokenization (523 tokens total) derived from ChEMBL 27 canonical SMILES.</li>
<li><strong>Augmentation</strong>: SMILES enumeration (permuting atom order) used for pre-training and on-the-fly during fine-tuning ($p_{aug}=0.5$ for Seq2Seq, $p_{aug}=1.0$ for discriminative).</li>
</ul>
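<p>For illustration, a minimal regex-based SMILES tokenizer in the same spirit. The pattern below is the atom-level regex commonly used in the reaction-prediction literature, not Chemformer's exact 523-token ChEMBL-derived vocabulary:</p>

```python
import re

# Multi-character atoms (bracket expressions, Cl, Br) must be matched
# before single characters, so they appear first in the alternation.
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\."
    r"|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

def tokenize(smiles):
    """Split a SMILES string into chemically meaningful tokens."""
    tokens = SMILES_REGEX.findall(smiles)
    # A faithful tokenizer should round-trip; fail loudly if characters are lost.
    assert "".join(tokens) == smiles, f"untokenizable input: {smiles}"
    return tokens

print(tokenize("CC(=O)Oc1ccccc1"))  # aromatic ring atoms as single tokens
print(tokenize("c1cc[nH]c1"))       # [nH] kept as one bracket token
```

<p>Keeping <code>Cl</code>, <code>Br</code>, and bracket atoms like <code>[nH]</code> as single tokens is what distinguishes a chemical tokenizer from naive character splitting.</p>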
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pre-training Tasks</strong>:
<ol>
<li><em>Masking</em>: Span masking (BART style).</li>
<li><em>Augmentation</em>: Input is a randomized SMILES; target is canonical SMILES.</li>
<li><em>Combined</em>: Input is augmented <em>then</em> masked; target is canonical SMILES.</li>
</ol>
</li>
<li><strong>Optimization</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$).</li>
<li>Schedule: Linear warm-up (8000 steps) for pre-training; One-cycle schedule for fine-tuning.</li>
</ul>
</li>
<li><strong>Inference</strong>: Beam search with width 10 for Seq2Seq tasks. The scripts <code>molbart/inference_score.py</code> and <code>molbart/retrosynthesis/round_trip_inference.py</code> were used for standard and round-trip validation.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two model sizes were trained. Both use the Pre-Norm Transformer layout with GELU activation.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hyperparameter</th>
          <th style="text-align: left">Chemformer (Base)</th>
          <th style="text-align: left">Chemformer-Large</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Layers</strong></td>
          <td style="text-align: left">6</td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Model Dimension</strong></td>
          <td style="text-align: left">512</td>
          <td style="text-align: left">1024</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Feed-forward Dim</strong></td>
          <td style="text-align: left">2048</td>
          <td style="text-align: left">4096</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Attention Heads</strong></td>
          <td style="text-align: left">8</td>
          <td style="text-align: left">16</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Parameters</strong></td>
          <td style="text-align: left">~45M</td>
          <td style="text-align: left">~230M</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Pre-training Task</strong></td>
          <td style="text-align: left">All 3 variants</td>
          <td style="text-align: left">Combined only</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>Comparisons relied on Top-N accuracy for reaction tasks and validity metrics for optimization.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Key Result</th>
          <th style="text-align: left">Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Direct Synthesis (Sep)</td>
          <td style="text-align: left"><strong>92.8%</strong> (Large)</td>
          <td style="text-align: left">91.1% (Aug Transformer)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Retrosynthesis</td>
          <td style="text-align: left"><strong>54.3%</strong> (Large)</td>
          <td style="text-align: left">53.7% (GraphRetro) / 52.5% (GLN)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Desirable %</strong></td>
          <td style="text-align: left">Mol Optimization</td>
          <td style="text-align: left"><strong>75.0%</strong> (Base-Mask)</td>
          <td style="text-align: left">70.2% (Transformer-R)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>RMSE</strong></td>
          <td style="text-align: left">Lipophilicity</td>
          <td style="text-align: left">0.598 (Combined)</td>
          <td style="text-align: left">0.555 (D-MPNN)</td>
      </tr>
  </tbody>
</table>
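<p>Top-N accuracy itself is straightforward to compute; a minimal sketch follows. In practice both predictions and targets should be canonicalized before comparison so that string variants of the same molecule count as matches:</p>

```python
def top_n_accuracy(predictions, targets, n):
    """Fraction of examples whose target appears among the top-n candidates.

    `predictions` is a list of candidate lists, best-first (e.g. beam
    search output); `targets` holds the ground-truth strings.
    """
    hits = sum(target in preds[:n] for preds, target in zip(predictions, targets))
    return hits / len(targets)

preds = [["CCO", "CCN"], ["OCC", "CCO"], ["CCN", "CCC"]]
gold = ["CCO", "CCO", "CCF"]
print(top_n_accuracy(preds, gold, 1))  # only the first example hits at n=1
print(top_n_accuracy(preds, gold, 2))  # the second example is recovered at n=2
```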
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 NVIDIA V100 GPUs (batch size 128 per GPU).</li>
<li><strong>Training Time</strong>:
<ul>
<li>Pre-training: 2.5 days (Base) / 6 days (Large) for 1M steps.</li>
<li>Fine-tuning: ~20-40 epochs for reaction prediction (&lt;12 hours).</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Irwin, R., Dimitriadis, S., He, J., &amp; Bjerrum, E. J. (2022). Chemformer: a pre-trained transformer for computational chemistry. <em>Machine Learning: Science and Technology</em>, 3(1), 015022. <a href="https://doi.org/10.1088/2632-2153/ac3ffb">https://doi.org/10.1088/2632-2153/ac3ffb</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{irwinChemformerPretrainedTransformer2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Chemformer: A Pre-Trained Transformer for Computational Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Chemformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Irwin, Ross and Dimitriadis, Spyridon and He, Jiazhen and Bjerrum, Esben Jannik}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{015022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2632-2153}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/2632-2153/ac3ffb}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa: Molecular Property Prediction via Transformers</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemberta/</guid><description>A systematic evaluation of RoBERTa transformers pretrained on 77M PubChem SMILES for molecular property prediction tasks.</description><content:encoded><![CDATA[<h2 id="taxonomy-and-paper-contributions">Taxonomy and Paper Contributions</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$), with a significant <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>It is a methodological investigation because it systematically evaluates a specific architecture (Transformers/RoBERTa) against established State-of-the-Art (SOTA) baselines like directed Message Passing Neural Networks (D-MPNNs) to determine &ldquo;how well does this work?&rdquo; in the chemical domain. It ablates dataset size, tokenization, and input representation.</p>
<p>It is also a resource paper as it introduces &ldquo;PubChem-77M,&rdquo; a curated dataset of 77 million SMILES strings designed to facilitate large-scale self-supervised pretraining for the community.</p>
<h2 id="overcoming-data-scarcity-in-property-prediction">Overcoming Data Scarcity in Property Prediction</h2>
<p>The primary motivation is <strong>data scarcity</strong> in molecular property prediction. Graph Neural Networks (GNNs) achieve strong performance on property prediction tasks when provided with sufficient labeled data. Generating these labels requires costly and time-consuming laboratory testing, leading to severe data scarcity in specialized chemical domains.</p>
<p>Massive quantities of <strong>unlabeled chemical structure data</strong> exist in the form of SMILES strings. Inspired by the success of Transformers in NLP, where self-supervised pretraining on large corpora yields strong transfer learning, the authors aim to use these unlabeled datasets to learn effective molecular representations. Additionally, Transformers benefit from a mature software ecosystem (HuggingFace) that offers efficiency advantages over GNNs.</p>
<h2 id="pretraining-scaling-laws-and-novelty">Pretraining Scaling Laws and Novelty</h2>
<p>Previous works applied Transformers to SMILES strings. This paper advances the field by systematically evaluating scaling laws and architectural components for this domain. Specifically:</p>
<ul>
<li><strong>Scaling Analysis</strong>: It explicitly tests how pretraining dataset size (100K to 10M) impacts downstream performance.</li>
<li><strong>Tokenizer Comparison</strong>: It compares standard NLP Byte-Pair Encoding (BPE) against a chemically-aware &ldquo;SmilesTokenizer&rdquo;.</li>
<li><strong>Representation Comparison</strong>: It evaluates if the robust SELFIES string representation offers advantages over standard SMILES in a Transformer context.</li>
</ul>
<h2 id="experimental-setup-pretraining-and-finetuning">Experimental Setup: Pretraining and Finetuning</h2>
<p>The authors trained <strong>ChemBERTa</strong> (based on RoBERTa) using Masked Language Modeling (MLM) on subsets of the PubChem dataset. The core training objective minimizes the cross-entropy loss over a corrupted input where a subset of basic tokens, denoted by $\mathcal{M}$, are masked:</p>
<p>$$
\mathcal{L}_{\text{MLM}} = - \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid x_{\setminus \mathcal{M}}; \theta)
$$</p>
<p>where $x_i$ is the exact masked token, $x_{\setminus \mathcal{M}}$ is the corrupted SMILES context string, and $\theta$ represents the network parameters.</p>
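<p>Numerically, the loss is just the mean negative log-probability of the true token at each masked position. A tiny sketch with made-up probabilities:</p>

```python
import math

def mlm_loss(masked_probs):
    """MLM loss given P(x_i | corrupted context) for each masked position i."""
    return -sum(math.log(p) for p in masked_probs) / len(masked_probs)

# Two masked tokens; the model assigns 0.9 and 0.5 to the true tokens.
print(round(mlm_loss([0.9, 0.5]), 4))  # -(ln 0.9 + ln 0.5) / 2 ≈ 0.3993
```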
<ul>
<li><strong>Pretraining</strong>: Models were pretrained on dataset sizes of 100K, 250K, 1M, and 10M compounds.</li>
<li><strong>Baselines</strong>: Performance was compared against D-MPNN (Graph Neural Network), Random Forest (RF), and SVM using 2048-bit Morgan Fingerprints.</li>
<li><strong>Downstream Tasks</strong>: Finetuning was performed individually on small MoleculeNet classification tasks: BBBP (blood-brain barrier), ClinTox (clinical toxicity), HIV, and Tox21 (p53 stress-response). This poses a transfer learning challenge, as the model must adapt from pretraining on 10 million molecules to classifying datasets ranging from ~1.5K to ~41K examples.</li>
<li><strong>Ablations</strong>:
<ul>
<li><strong>Tokenization</strong>: BPE vs. SmilesTokenizer on the 1M dataset, evaluated on Tox21.</li>
<li><strong>Input</strong>: SMILES vs. SELFIES strings on the Tox21 task.</li>
</ul>
</li>
</ul>
<h2 id="results-vs-graph-neural-network-baselines">Results vs. Graph Neural Network Baselines</h2>
<p>The main comparison between ChemBERTa (pretrained on 10M compounds) and Chemprop baselines on MoleculeNet tasks is summarized below (Table 1 from the paper):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP ROC</th>
          <th>BBBP PRC</th>
          <th>ClinTox ROC</th>
          <th>ClinTox PRC</th>
          <th>HIV ROC</th>
          <th>HIV PRC</th>
          <th>Tox21 ROC</th>
          <th>Tox21 PRC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemBERTa 10M</td>
          <td>0.643</td>
          <td>0.620</td>
          <td>0.733</td>
          <td>0.975</td>
          <td>0.622</td>
          <td>0.119</td>
          <td>0.728</td>
          <td>0.207</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.708</td>
          <td>0.697</td>
          <td>0.906</td>
          <td>0.993</td>
          <td>0.752</td>
          <td>0.152</td>
          <td>0.688</td>
          <td>0.429</td>
      </tr>
      <tr>
          <td>RF</td>
          <td>0.681</td>
          <td>0.692</td>
          <td>0.693</td>
          <td>0.968</td>
          <td>0.780</td>
          <td>0.383</td>
          <td>0.724</td>
          <td>0.335</td>
      </tr>
      <tr>
          <td>SVM</td>
          <td>0.702</td>
          <td>0.724</td>
          <td>0.833</td>
          <td>0.986</td>
          <td>0.763</td>
          <td>0.364</td>
          <td>0.708</td>
          <td>0.345</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Scaling Improvements &amp; Training Dynamics</strong>: Performance scales predictably with pretraining data size. Increasing data from 100K to 10M improved ROC-AUC by +0.110 and PRC-AUC by +0.059 on average across BBBP, ClinTox, and Tox21 (HIV was omitted due to resource constraints). Notably, the authors halted pretraining on the 10M subset after just 3 epochs due to overfitting, suggesting that simple 15% token masking may not provide a sufficiently challenging learning signal for large-scale chemical representation learning.</li>
<li><strong>Performance Limits vs. GNNs</strong>: ChemBERTa generally performs below the D-MPNN baseline. On the Tox21 dataset, ChemBERTa-10M achieved a higher ROC-AUC (0.728) than D-MPNN (0.688); nonetheless, it recorded a substantially lower PRC-AUC (0.207 vs 0.429). This gap indicates that current Transformer iterations lack the explicit inductive biases of graph algorithms and struggle with the severe class imbalances typical of chemical datasets.</li>
<li><strong>Ablation Limitations (Tokenization &amp; SELFIES)</strong>: The authors&rsquo; ablation studies for tokenization (SmilesTokenizer narrowly beating BPE) and input representation (SELFIES performing comparably to SMILES) were evaluated exclusively on the single Tox21 task. Deriving broad architectural conclusions regarding &ldquo;semantically-aware tokenization&rdquo; or string robustness from an $N=1$ empirical evaluation is a significant limitation of the study. Broader benchmarking is required to validate these findings.</li>
<li><strong>Interpretability</strong>: Attention heads organically learn to track chemically relevant substructures (like specific functional groups and aromatic rings), mimicking the inductive biases of graph convolutions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors curated a massive dataset for pretraining and utilized standard benchmarks for evaluation.</p>
<ul>
<li><strong>Pretraining Data</strong>: <strong>PubChem-77M</strong>.
<ul>
<li>Source: 77 million unique SMILES from PubChem.</li>
<li>Preprocessing: Canonicalized and globally shuffled.</li>
<li>Subsets used: 100K, 250K, 1M, and 10M subsets.</li>
<li><em>Availability Note</em>: The authors provided a direct link to the <a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">canonicalized 10M compound subset</a> used for their largest experiments. Full reproducibility of the smaller (100K, 250K, 1M) or full 77M sets may require re-extracting from PubChem.</li>
</ul>
</li>
<li><strong>Evaluation Data</strong>: <strong>MoleculeNet</strong>.
<ul>
<li>Tasks: BBBP (2,039), ClinTox (1,478), HIV (41,127), Tox21 (7,831).</li>
<li>Splitting: 80/10/10 train/valid/test split using a <strong>scaffold splitter</strong> to ensure chemical diversity between splits.</li>
</ul>
</li>
</ul>
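<p>A dependency-free sketch of the scaffold-split idea: group molecules by scaffold, then assign whole groups to splits so no scaffold straddles train and test. In DeepChem the scaffold key is the Bemis-Murcko scaffold computed with RDKit; here the keys are supplied precomputed so the example stays self-contained:</p>

```python
from collections import defaultdict

def scaffold_split(mol_ids, scaffolds, frac_train=0.8, frac_valid=0.1):
    """Assign whole scaffold groups to train/valid/test splits."""
    groups = defaultdict(list)
    for mol_id, scaf in zip(mol_ids, scaffolds):
        groups[scaf].append(mol_id)
    # Largest scaffold families first, so common chemistry lands in train
    # and rare scaffolds (novel chemistry) end up in valid/test.
    ordered = sorted(groups.values(), key=len, reverse=True)
    n = len(mol_ids)
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= frac_train * n:
            train += group
        elif len(valid) + len(group) <= frac_valid * n:
            valid += group
        else:
            test += group
    return train, valid, test

ids = list(range(10))
scafs = ["A"] * 5 + ["B"] * 3 + ["C", "D"]
print(scaffold_split(ids, scafs))  # 8 / 1 / 1 split, no scaffold shared
```

<p>Because held-out scaffolds never appear in training, this split is a harder, more realistic test of generalization than a random split.</p>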
<h3 id="algorithms">Algorithms</h3>
<p>The core training methodology mirrors standard BERT/RoBERTa procedures adapted for chemical strings.</p>
<ul>
<li><strong>Objective</strong>: Masked Language Modeling (MLM) with <strong>15% token masking</strong>.</li>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>BPE</strong>: Byte-Pair Encoder (vocab size 52K).</li>
<li><strong>SmilesTokenizer</strong>: Regex-based custom tokenizer available in DeepChem (documented <a href="https://deepchem.readthedocs.io/en/latest/tokenizers.html#smilestokenizer">here</a>).</li>
</ul>
</li>
<li><strong>Sequence Length</strong>: Maximum sequence length of <strong>512 tokens</strong>.</li>
<li><strong>Finetuning</strong>: Appended a linear classification layer; backpropagated through the base model for up to 25 epochs with early stopping on ROC-AUC.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: <strong>RoBERTa</strong> (via HuggingFace).
<ul>
<li>Layers: 6</li>
<li>Attention Heads: 12 per layer (72 attention heads in total across the 6 layers).</li>
<li><em>Implementation Note</em>: The original training notebooks and scripts are maintained in the authors&rsquo; <a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry repository</a>, alongside the primary downstream tasks integrated into DeepChem. A <a href="https://github.com/deepchem/deepchem/blob/master/examples/tutorials/22_Transfer_Learning_With_HuggingFace_tox21.ipynb">full Tox21 transfer learning tutorial</a> has been incorporated into the DeepChem repository.</li>
</ul>
</li>
<li><strong>Baselines</strong> (via Chemprop library):
<ul>
<li><strong>D-MPNN</strong>: Directed Message Passing Neural Network with default hyperparameters.</li>
<li><strong>RF/SVM</strong>: Scikit-learn Random Forest and SVM using 2048-bit Morgan fingerprints (RDKit).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is measured using dual metrics to account for class imbalance common in toxicity datasets.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ROC-AUC</strong></td>
          <td>Area Under Receiver Operating Characteristic Curve</td>
      </tr>
      <tr>
          <td><strong>PRC-AUC</strong></td>
          <td>Area Under Precision-Recall Curve (vital for imbalanced data)</td>
      </tr>
  </tbody>
</table>
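<p>The gap between ROC-AUC and PRC-AUC on imbalanced data is easiest to see by computing average precision directly. The sketch below is a minimal, dependency-free step-wise estimate of PRC-AUC; in practice one would use <code>sklearn.metrics.average_precision_score</code>:</p>

```python
def average_precision(scores, labels):
    """Precision averaged over the ranks of the true positives.

    `scores` are classifier outputs (higher = more positive), `labels`
    the binary ground truth. Unlike ROC-AUC, this metric collapses when
    positives are rare and ranked poorly.
    """
    ranked = sorted(zip(scores, labels), reverse=True)
    total_pos = sum(labels)
    hits, ap = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            ap += hits / rank
    return ap / total_pos

# One positive ranked first, one ranked last out of four.
print(average_precision([0.9, 0.8, 0.7, 0.6], [1, 0, 0, 1]))  # (1/1 + 2/4) / 2 = 0.75
```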
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Single <strong>NVIDIA V100 GPU</strong>.</li>
<li><strong>Training Time</strong>: Approximately <strong>48 hours</strong> for the 10M compound subset.</li>
<li><strong>Carbon Footprint</strong>: Estimated 17.1 kg $\text{CO}_2\text{eq}$ (offset by Google Cloud).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training notebooks and finetuning scripts</td>
      </tr>
      <tr>
          <td><a href="https://github.com/deepchem/deepchem">DeepChem</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Integration of ChemBERTa and SmilesTokenizer</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">ChemBERTa-zinc-base-v1</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Pre-trained RoBERTa on 100K ZINC SMILES</td>
      </tr>
      <tr>
          <td><a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">PubChem-10M subset</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Canonicalized 10M compound subset used for largest experiments</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. Code and pre-trained models are available, and the 10M pretraining subset is downloadable. However, smaller subsets (100K, 250K, 1M) may need re-extraction from PubChem, and exact hyperparameter details for finetuning (learning rate, batch size) are not fully specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2020). ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction. <em>arXiv preprint arXiv:2010.09885</em>. <a href="https://doi.org/10.48550/arXiv.2010.09885">https://doi.org/10.48550/arXiv.2010.09885</a></p>
<p><strong>Publication</strong>: arXiv 2020 (Preprint)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">HuggingFace Model Hub (ChemBERTa-zinc-base-v1)</a> - <em>Additional pre-trained variations on PubChem &amp; ZINC datasets are available on the author&rsquo;s <a href="https://huggingface.co/seyonec">seyonec</a> HF profile.</em></li>
<li><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry GitHub Repository</a> - <em>Notebooks and scripts used for MLM pretraining and finetuning evaluations.</em></li>
</ul>
<h3 id="bibtex">BibTeX</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{chithranandaChemBERTaLargeScaleSelfSupervised2020,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa}}: {{Large-Scale Self-Supervised Pretraining}} for {{Molecular Property Prediction}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-24}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Score-Based Generative Modeling with SDEs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</guid><description>Unified SDE framework for score-based generative models, introducing Predictor-Corrector samplers and achieving SOTA on CIFAR-10 with FID 2.20.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper. It proposes a unified framework that generalizes previous discrete score-based models (SMLD and DDPM) into continuous-time Stochastic Differential Equations (SDEs). The paper introduces novel algorithms for sampling (Predictor-Corrector) and likelihood computation (Probability Flow ODE), validated by achieving state-of-the-art results. It also contains elements of <strong>Systematization</strong> by showing how existing methods are special cases of this broader framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Prior successful generative models, specifically Score Matching with Langevin Dynamics (SMLD) and Denoising Diffusion Probabilistic Models (DDPM), operate by sequentially corrupting data with slowly increasing noise and learning to reverse the process. Both methods treat the noise scales as a finite set of discrete steps. The authors aim to generalize this to a continuum of noise scales by modeling the diffusion process as a Stochastic Differential Equation (SDE). This continuous formulation enables:</p>
<ul>
<li><strong>Flexible sampling:</strong> Use of general-purpose SDE solvers.</li>
<li><strong>Exact likelihood computation:</strong> Via connection to Neural ODEs.</li>
<li><strong>Controllable generation:</strong> Solving inverse problems (inpainting, colorization) without retraining.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>SDE framework</strong> for score-based generative modeling:</p>
<ul>
<li><strong>Continuous Generalization:</strong> Proving that SMLD and DDPM noise perturbations correspond to discretizations of Variance Exploding (VE) SDEs and Variance Preserving (VP) SDEs, respectively.</li>
<li><strong>Reverse-Time SDE:</strong> Leveraging Anderson&rsquo;s result (Anderson, 1982: a result on time-reversal of diffusion processes showing that the reverse is also a diffusion, with the forward drift reversed and a correction term involving the score of the marginal density) that the reverse of a diffusion process is also a diffusion process, governed by the score (gradient of log density).</li>
<li><strong>Predictor-Corrector (PC) Samplers:</strong> A hybrid sampling strategy where a numerical SDE solver (Predictor) estimates the next step, and a score-based MCMC approach (Corrector) corrects the marginal distribution.</li>
<li><strong>Probability Flow ODE:</strong> Deriving a deterministic ODE that shares the same marginal densities as the SDE, enabling near-exact likelihood computation (accuracy is limited by both numerical ODE solver discretization and variance of the unbiased Hutchinson trace estimator) and latent space manipulation.</li>
<li><strong>Sub-VP SDE:</strong> A new SDE class proposed to improve likelihoods by bounding variance tighter than the VP SDE.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the framework on standard image benchmarks:</p>
<ul>
<li><strong>Datasets:</strong> CIFAR-10 (32x32), CelebA (64x64), LSUN (Bedroom, Church), and CelebA-HQ (256x256 and 1024x1024).</li>
<li><strong>Ablation Studies:</strong> Comparing samplers (Ancestral vs. Reverse Diffusion vs. Probability Flow vs. PC) and SDE types (VE, VP, sub-VP).</li>
<li><strong>Architecture Search:</strong> Exploring improvements like FIR up/downsampling, rescaling skip connections, and increasing depth (leading to NCSN++ and DDPM++ architectures).</li>
<li><strong>Likelihood Evaluation:</strong> Computing Negative Log-Likelihood (NLL) in bits/dim using the Probability Flow ODE.</li>
<li><strong>Inverse Problems:</strong> Testing class-conditional generation, inpainting, and colorization using the conditional reverse-time SDE.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Record Performance:</strong> The <strong>NCSN++ cont. (deep, VE)</strong> model achieved an Inception Score of 9.89 and FID of 2.20 on CIFAR-10 (as of ICLR 2021).</li>
<li><strong>High-Fidelity Generation:</strong> First score-based model to generate 1024x1024 images (CelebA-HQ).</li>
<li><strong>Competitive Likelihoods:</strong> The <strong>DDPM++ cont. (deep, sub-VP)</strong> model achieved 2.99 bits/dim on uniformly dequantized CIFAR-10, a record at the time.</li>
<li><strong>Sampling Efficiency:</strong> PC samplers consistently outperformed predictor-only methods (like standard ancestral sampling) for the same computational cost.</li>
<li><strong>Controllable Generation:</strong> Successful application to inpainting and colorization using a single unconditional model.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>CIFAR-10</strong>: Used for main benchmarking (FID, Inception Score, NLL).</li>
<li><strong>CelebA-HQ</strong>: Used for high-resolution experiments at 256x256 and 1024x1024.</li>
<li><strong>LSUN</strong>: Bedroom and Church Outdoor categories used for inpainting/colorization tests.</li>
<li><strong>Preprocessing</strong>: CIFAR-10 images are 32x32; CelebA pre-processed to 64x64 following Song &amp; Ermon (2020). Data is typically scaled to $[0, 1]$ or standardized depending on the specific SDE config.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Forward SDEs</strong>:</p>
<p>Here $dw$ denotes a Wiener process increment (a small, independent Gaussian noise burst at each timestep).</p>
<ul>
<li><strong>VE SDE (Variance Exploding)</strong>: $dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw$. Corresponds to SMLD. Used with $\sigma_{\min}=0.01$ and $\sigma_{\max}$ chosen via heuristics.</li>
<li><strong>VP SDE (Variance Preserving)</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)} dw$. Corresponds to DDPM.</li>
<li><strong>Sub-VP SDE</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)ds})} dw$. Bounded variance, good for likelihoods.</li>
</ul>
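<p>As a concrete illustration of the VP SDE, the following sketch (schedule values are the paper's defaults, $\beta_{\min}=0.1$, $\beta_{\max}=20$; the rest is illustrative) simulates the forward process with Euler-Maruyama and checks the marginal at $t=1$ against the closed-form perturbation kernel:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Linear beta schedule, as in the VP SDE configuration
beta_min, beta_max = 0.1, 20.0

def beta(t):
    return beta_min + t * (beta_max - beta_min)

def vp_forward(x0, n_steps=1000):
    """Euler-Maruyama simulation of dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw."""
    dt = 1.0 / n_steps
    x = x0.copy()
    for i in range(n_steps):
        t = i * dt
        x = x - 0.5 * beta(t) * x * dt + np.sqrt(beta(t) * dt) * rng.standard_normal(x.shape)
    return x

# Closed-form marginal at t=1: mean shrinks by exp(-0.5 * int beta), variance -> 1 - exp(-int beta)
int_beta = beta_min + 0.5 * (beta_max - beta_min)  # integral of beta over [0, 1]
xT = vp_forward(np.ones(50_000))
print(np.mean(xT), np.var(xT))  # mean ~ exp(-int_beta / 2) ~ 0, variance ~ 1
```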
<p><strong>Reverse-Time SDE Solver (Predictor)</strong>:</p>
<ul>
<li>Discretized via <strong>Reverse Diffusion Sampling</strong>, which matches the forward discretization.</li>
<li><strong>Euler-Maruyama</strong> solver used for continuously-trained models.</li>
</ul>
<p><strong>Corrector Algorithm</strong>:</p>
<ul>
<li><strong>Langevin MCMC</strong>: Applies annealed Langevin dynamics: adds noise and takes a score-guided gradient step to correct the marginal distribution at each timestep.</li>
<li><strong>PC Sampling</strong>: Alternates between one step of the Predictor and one step of the Corrector.</li>
<li><strong>Signal-to-Noise Ratio ($r$)</strong>: A hyperparameter for the corrector step size. Tuned values: $r \approx 0.16$ for VE SDEs on CIFAR-10.</li>
</ul>
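<p>A minimal PC sampling loop can be sketched on a toy problem where the score is analytic (data assumed $\mathcal{N}(0,1)$, so the $\sigma$-perturbed score is $-x/(1+\sigma^2)$); the predictor is the reverse-diffusion step and the corrector a Langevin update with step size set by $r$:</p>

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy assumption: data ~ N(0, 1), so the score of the sigma-perturbed marginal is closed-form.
def score(x, sigma):
    return -x / (1.0 + sigma**2)

sigmas = np.geomspace(50.0, 0.01, 100)  # VE noise schedule from sigma_max down to sigma_min
r = 0.16                                # corrector signal-to-noise ratio

x = sigmas[0] * rng.standard_normal(20_000)  # initialize from the sigma_max prior
for i in range(len(sigmas) - 1):
    s_cur, s_next = sigmas[i], sigmas[i + 1]
    # Predictor: reverse-diffusion discretization of the VE reverse SDE
    g2 = s_cur**2 - s_next**2
    x = x + g2 * score(x, s_cur) + np.sqrt(g2) * rng.standard_normal(x.shape)
    # Corrector: one Langevin MCMC step, step size chosen from the SNR r
    grad = score(x, s_next)
    noise = rng.standard_normal(x.shape)
    eps = 2.0 * (r * np.linalg.norm(noise) / np.linalg.norm(grad)) ** 2
    x = x + eps * grad + np.sqrt(2.0 * eps) * noise

print(np.mean(x), np.std(x))  # samples approach the data distribution N(0, 1)
```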
<h3 id="models">Models</h3>
<ul>
<li><strong>NCSN++</strong>: Optimized architecture for VE SDEs. Key features:
<ul>
<li>4 residual blocks per resolution.</li>
<li>BigGAN-type residual blocks.</li>
<li>Rescaling skip connections by $1/\sqrt{2}$.</li>
<li>FIR (Finite Impulse Response) up/downsampling.</li>
<li>No progressive growing for output.</li>
</ul>
</li>
<li><strong>DDPM++</strong>: Optimized architecture for VP/sub-VP SDEs. Similar to NCSN++ but without FIR up/downsampling and without progressive growing.</li>
<li><strong>Deep Variants</strong>: &ldquo;cont. (deep)&rdquo; models double the depth (from 4 to 8 blocks per resolution) for SOTA results.</li>
<li><strong>Conditioning</strong>: Time $t$ is conditioned via random Fourier feature embeddings (scale 16) for continuous models.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>FID (Fréchet Inception Distance)</strong>: Computed on 50k samples.</li>
<li><strong>Inception Score</strong>: Reported for CIFAR-10.</li>
<li><strong>NLL (Negative Log-Likelihood)</strong>: Reported in bits/dim on uniformly dequantized data using the Probability Flow ODE.</li>
</ul>
<p><strong>Denoising</strong>: A single denoising step using Tweedie&rsquo;s formula is applied at the end of sampling to remove residual noise, which significantly improves FID.</p>
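<p>Tweedie's formula reads $\mathbb{E}[x \mid \tilde{x}] = \tilde{x} + \sigma^2 \nabla_{\tilde{x}} \log p_{\sigma}(\tilde{x})$. A sketch on assumed $\mathcal{N}(0,1)$ data, where the noisy marginal's score is closed-form, shows the single denoising step strictly reducing mean squared error:</p>

```python
import numpy as np

rng = np.random.default_rng(2)
sigma = 0.5

# Assumed toy data ~ N(0, 1), so p_sigma = N(0, 1 + sigma^2) has a closed-form score.
x = rng.standard_normal(100_000)
x_noisy = x + sigma * rng.standard_normal(x.shape)

score = -x_noisy / (1.0 + sigma**2)       # grad log p_sigma(x_noisy)
x_denoised = x_noisy + sigma**2 * score   # Tweedie's formula: E[x | x_noisy]

mse_noisy = np.mean((x_noisy - x) ** 2)        # ~ sigma^2
mse_denoised = np.mean((x_denoised - x) ** 2)  # ~ sigma^2 / (1 + sigma^2), the Bayes MSE
print(mse_noisy, mse_denoised)
```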
<h3 id="hardware">Hardware</h3>
<p><strong>Training</strong>:</p>
<ul>
<li>Batch size: 128 for CIFAR-10, 64 for LSUN, 8 for high-res CelebA-HQ.</li>
<li>Iterations: Models trained for ~1.3M iterations. High-res CelebA-HQ trained for 2.4M iterations.</li>
<li><strong>EMA</strong>: Exponential Moving Average rate of 0.999 used for VE models, 0.9999 for VP models.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., &amp; Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. <em>ICLR 2021</em>. <a href="https://arxiv.org/abs/2011.13456">https://arxiv.org/abs/2011.13456</a></p>
<p><strong>Publication</strong>: ICLR 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2021scorebased,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Score-Based Generative Modeling through Stochastic Differential Equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Sohl-Dickstein, Jascha and Kingma, Diederik P and Kumar, Abhishek and Ermon, Stefano and Poole, Ben}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://openreview.net/forum?id=PxTIG12RRHS}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/yang-song/score_sde">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/">Score Matching and Denoising Autoencoders</a></li>
</ul>
]]></content:encoded></item><item><title>Score Matching and Denoising Autoencoders</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</guid><description>Theoretical paper proving the equivalence between training Denoising Autoencoders and performing Score Matching on a Parzen density estimator.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Theory Paper</strong>.</p>
<p>Its primary contribution is a formal mathematical derivation connecting two previously distinct techniques: Score Matching (SM) and Denoising Autoencoders (DAE). It provides the &ldquo;why&rdquo; behind the empirical success of DAEs by grounding them in the probabilistic framework of energy-based models. It relies on proofs and equivalence relations (e.g., $J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$).</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper bridges a gap between two successful but disconnected approaches in unsupervised learning:</p>
<ol>
<li><strong>Denoising Autoencoders (DAE):</strong> Empirically successful for pre-training deep networks. They previously lacked a clear probabilistic interpretation.</li>
<li><strong>Score Matching (SM):</strong> A theoretically sound method for estimating unnormalized density models that avoids the partition function problem but requires computing expensive second derivatives.</li>
</ol>
<p>By connecting them, the authors aim to define a proper probabilistic model for DAEs (allowing sampling/ranking) and find a simpler way to apply score matching that avoids second derivatives.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Denoising Score Matching (DSM)</strong> framework and the proof of its equivalence to DAEs. Key contributions include:</p>
<ul>
<li><strong>Equivalence Proof:</strong> Showing that training a DAE with Gaussian noise is equivalent to matching the score of a model against a non-parametric Parzen density estimator of the data.</li>
<li><strong>Denoising Score Matching ($J_{DSM}$):</strong> A new objective that learns a score function by trying to denoise corrupted samples. This avoids the explicit second derivatives required by standard Implicit Score Matching ($J_{ISM}$).</li>
<li><strong>Explicit Energy Function:</strong> Deriving the specific energy function $E(x;\theta)$ that corresponds to the standard sigmoid DAE architecture.</li>
<li><strong>Justification for Tied Weights:</strong> Providing a theoretical justification for tying encoder and decoder weights, which arises naturally from differentiating the energy function.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation in this theoretical paper is purely mathematical and focuses on formal proofs:</p>
<ul>
<li><strong>Derivation of Equivalence:</strong> The paper formally proves the chain of equivalences:
$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$
where $q_{\sigma}$ is the Parzen density estimate.</li>
<li><strong>Appendix Proof:</strong> A detailed proof is provided to show that Explicit Score Matching ($J_{ESM}$) on the Parzen density is equivalent to the proposed Denoising Score Matching ($J_{DSM}$) objective.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Theoretical Unification:</strong> DAE training is formally equivalent to Score Matching on a smoothed data distribution ($q_{\sigma}$).</li>
<li><strong>New Training Objective:</strong> The $J_{DSM}$ objective offers a computationally efficient way to perform score matching (no Hessian required) by using a denoising objective.</li>
<li><strong>Probabilistic Interpretation:</strong> DAEs can now be understood as Energy-Based Models (EBMs), allowing for operations like sampling (via Hybrid Monte Carlo) and likelihood ranking, which were previously ill-defined for standard autoencoders.</li>
<li><strong>Regularization Insight:</strong> The smoothing kernel width $\sigma$ in the Parzen estimator corresponds to the noise level in the DAE. This suggests that DAEs are learning a regularized version of the score, which may explain their robustness.</li>
</ul>
<hr>
<h2 id="key-concepts-explained">Key Concepts Explained</h2>
<h3 id="1-score-and-score-matching">1. &ldquo;Score&rdquo; and &ldquo;Score Matching&rdquo;</h3>
<p><strong>What does &ldquo;score&rdquo; actually mean?</strong></p>
<p>In this paper (and probabilistic modeling generally), the <strong>score</strong> is the gradient of the log-density with respect to the <em>data vector</em> $x$.</p>
<ul>
<li><strong>Definition:</strong> $\psi(x) = \nabla_x \log p(x)$.</li>
<li><strong>Intuition:</strong> It is a vector field pointing in the direction of highest probability increase. Crucially, calculating the score avoids the intractable partition function $Z$, because $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$. The constant $Z$ vanishes upon differentiation.</li>
</ul>
<p><strong>What is Score Matching?</strong></p>
<p>Score Matching is a training objective for unnormalized models. It minimizes the squared Euclidean distance between the model&rsquo;s score $\psi(x;\theta)$ and the data&rsquo;s true score $\nabla_x \log q(x)$.</p>
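<p>The cancellation of $Z$ can be checked numerically: an unnormalized Gaussian and its normalized version differ only by a constant, so their finite-difference scores coincide (a small sketch, not from the paper):</p>

```python
import numpy as np

def log_p_unnorm(x):
    return -0.5 * x**2  # log of an unnormalized Gaussian, no partition function

def log_p_norm(x):
    return -0.5 * x**2 - 0.5 * np.log(2 * np.pi)  # includes the constant log Z

# Central finite differences: the constant shifts the log-density but not its slope
x0, h = 1.3, 1e-5
score_unnorm = (log_p_unnorm(x0 + h) - log_p_unnorm(x0 - h)) / (2 * h)
score_norm = (log_p_norm(x0 + h) - log_p_norm(x0 - h)) / (2 * h)
print(score_unnorm, score_norm)  # both ~ -x0 = -1.3
```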
<h3 id="2-the-parzen-density-estimator">2. The Parzen Density Estimator</h3>
<p><strong>What is it?</strong></p>
<p>It is a non-parametric method for estimating a probability density function from finite data. It places a smooth kernel (here, a Gaussian) centered at every data point in the training set $D_n$.</p>
<ul>
<li><strong>Formula:</strong> $q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n \mathcal{N}(\tilde{x}; x^{(t)}, \sigma^2 I)$.</li>
</ul>
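<p>A one-dimensional sketch of the Parzen estimate and its score (a responsibility-weighted pull toward the data points), verified against finite differences; the data and kernel width are illustrative:</p>

```python
import numpy as np

rng = np.random.default_rng(3)

data = rng.standard_normal(500)  # training set D_n
sigma = 0.3                      # kernel width

def parzen_logpdf(x_tilde):
    # log q_sigma(x_tilde) = log mean_t N(x_tilde; x_t, sigma^2)
    sq = -0.5 * ((x_tilde - data) / sigma) ** 2
    return np.log(np.mean(np.exp(sq))) - 0.5 * np.log(2 * np.pi * sigma**2)

def parzen_score(x_tilde):
    # grad log q_sigma: softmax-weighted average of (x_t - x_tilde) / sigma^2
    sq = -0.5 * ((x_tilde - data) / sigma) ** 2
    w = np.exp(sq - sq.max())
    w /= w.sum()
    return np.sum(w * (data - x_tilde)) / sigma**2

x0, h = 0.7, 1e-5
fd = (parzen_logpdf(x0 + h) - parzen_logpdf(x0 - h)) / (2 * h)
print(parzen_score(x0), fd)  # analytic and finite-difference scores agree
```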
<p><strong>Why smooth the data?</strong></p>
<ol>
<li>
<p><strong>To define the score:</strong> The empirical data distribution is a set of Dirac deltas (spikes). The gradient (score) of a Dirac delta is undefined. Smoothing creates a differentiable surface, allowing a valid target score $\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x})$ to be computed.</p>
</li>
<li>
<p><strong>To model corruption:</strong> The Parzen estimator with Gaussian kernels mathematically models the process of taking a clean data point $x$ and adding Gaussian noise - the exact procedure used in Denoising Autoencoders.</p>
</li>
</ol>
<h3 id="3-why-avoiding-second-derivatives-matters">3. Why avoiding second derivatives matters</h3>
<p>Standard <strong>Implicit Score Matching (ISM)</strong> eliminates the need for the unknown data score, but introduces a new cost: it requires computing the trace of the Hessian (the sum of second partial derivatives) of the log-density.</p>
<ul>
<li><strong>The Cost:</strong> For high-dimensional data (like images) and deep networks, computing second derivatives is computationally prohibitive ($O(d^2)$ operations per sample).</li>
<li>This paper shows that <strong>Denoising Score Matching (DSM)</strong> allows you to bypass Hessian computation entirely. By using the Parzen target, the objective simplifies to matching a first-order vector, making it scalable to deep neural networks.</li>
</ul>
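<p>The same first-order-only idea shows up when later work estimates traces stochastically: Hutchinson's estimator, $\operatorname{tr}(A) = \mathbb{E}[v^T A v]$ for probes with $\mathbb{E}[vv^T] = I$, needs only matrix-vector products (a generic sketch, not code from the paper):</p>

```python
import numpy as np

rng = np.random.default_rng(4)

d = 20
A = rng.standard_normal((d, d))
true_trace = np.trace(A)

# Rademacher probes: E[v v^T] = I, so E[v^T A v] = tr(A)
n_probes = 200_000
v = rng.choice([-1.0, 1.0], size=(n_probes, d))
est = np.mean(np.sum((v @ A) * v, axis=1))  # only matrix-vector products needed
print(true_trace, est)                      # unbiased estimate of the trace
```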
<h3 id="4-the-equivalence-chain---why-each-step">4. The equivalence chain - why each step?</h3>
<p>The chain $J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ connects the concepts.</p>
<ul>
<li>
<p><strong>$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}}$ (Implicit $\to$ Explicit):</strong>
<strong>Why:</strong> Integration by parts. This is Hyvärinen&rsquo;s original proof (2005): integration by parts moves the derivative from $\psi$ onto the data density $q$, producing a term involving $q$&rsquo;s gradient (the score). The boundary term vanishes because $q_{\sigma}$ decays to zero at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching). The result allows replacing the unknown data score with a computable term involving only the model&rsquo;s score and its Jacobian.</p>
</li>
<li>
<p><strong>$J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$ (Explicit $\to$ Denoising):</strong>
<strong>Why:</strong> The explicit score of the Parzen density is known. When $x$ is perturbed to $\tilde{x}$ by Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, the gradient of the log-density pointing back to the mean is exactly $\frac{1}{\sigma^2}(x - \tilde{x})$. Minimizing the error against the true score becomes minimizing the error against this restoration vector.</p>
</li>
<li>
<p><strong>$J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ (Denoising $\to$ Autoencoder):</strong>
<strong>Why:</strong> Algebraic substitution. If you define the model&rsquo;s score $\psi(\tilde{x};\theta)$ to be proportional to the reconstruction error ($\propto x^r - \tilde{x}$), the score matching loss $J_{DSM}$ becomes proportional to the standard autoencoder squared loss $|x^r - x|^2$.</p>
</li>
</ul>
<h3 id="5-energy-based-models-ebms-connection">5. Energy-Based Models (EBMs) connection</h3>
<p><strong>What is an EBM?</strong></p>
<p>An EBM defines a probability distribution via an energy function $E(x;\theta)$, where $p(x;\theta) \propto e^{-E(x;\theta)}$.</p>
<p><strong>Why standard autoencoders lack probabilistic interpretation:</strong></p>
<p>A standard autoencoder acts as a deterministic map $x \to x^r$, providing a reconstruction error. It lacks a normalization constant or a defined density function to support sampling or probability queries.</p>
<p><strong>What does this enable?</strong></p>
<p>By proving the equivalence, the DAE is formally defined as an EBM. This enables:</p>
<ol>
<li><strong>Sampling:</strong> Using MCMC methods (like Hybrid Monte Carlo) to generate new data from the DAE.</li>
<li><strong>Ranking:</strong> Calculating the energy of inputs to determine which are more &ldquo;likely&rdquo; or &ldquo;normal&rdquo; (useful for anomaly detection).</li>
</ol>
<h3 id="6-the-specific-energy-function-form">6. The specific energy function form</h3>
<p>The function is:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p><strong>Why does it have that specific form?</strong></p>
<p>It was derived via integration to ensure its derivative matches the DAE architecture. The authors worked backward from the DAE&rsquo;s reconstruction function (sigmoid + linear) to find the scalar field that generates it.</p>
<p><strong>Where does the quadratic term come from?</strong></p>
<p>The score (derivative) needs to look like $\psi(x) \propto c + x + f(Wx + b)$.</p>
<ul>
<li>The term $+x$ (i.e., $\nabla_x (\frac{1}{2}|x|^2) = x$) in the derivative must come from a $-\frac{1}{2}|x|^2$ term in the energy potential.</li>
</ul>
<p><strong>How does differentiating it recover the DAE reconstruction?</strong></p>
<ul>
<li>$\nabla_x \sum_j \text{softplus}(\langle W_j, x \rangle + b_j) = W^T \sigma(Wx + b)$ (The encoder part).</li>
<li>$\nabla_x \langle c, x \rangle = c$ (The bias).</li>
<li>$\nabla_x (-\frac{1}{2}|x|^2) = -x$ (The input subtraction).</li>
<li>Result: $-\nabla_x E \propto c + W^T h - x = x^r - x$.</li>
</ul>
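<p>This derivative identity can be checked numerically with random parameters: a finite-difference gradient of the energy, rescaled by $-\sigma^2$, should reproduce the tied-weight DAE's reconstruction error $x^r - x$ (a sketch with arbitrary, illustrative dimensions):</p>

```python
import numpy as np

rng = np.random.default_rng(5)
d, d_h, sig2 = 4, 6, 0.25  # input dim, hidden dim, noise variance (illustrative)

W = rng.standard_normal((d_h, d))
b = rng.standard_normal(d_h)
c = rng.standard_normal(d)

def energy(x):
    # E(x) = -(1/sigma^2) (<c, x> - 1/2 |x|^2 + sum softplus(Wx + b))
    pre = W @ x + b
    return -(np.dot(c, x) - 0.5 * np.dot(x, x) + np.sum(np.logaddexp(0.0, pre))) / sig2

def reconstruction(x):
    h = 1.0 / (1.0 + np.exp(-(W @ x + b)))  # sigmoid encoder
    return W.T @ h + c                      # tied-weight linear decoder

x = rng.standard_normal(d)
eps = 1e-6
grad = np.array([(energy(x + eps * e) - energy(x - eps * e)) / (2 * eps) for e in np.eye(d)])
print(-sig2 * grad)            # ~ x^r - x
print(reconstruction(x) - x)
```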
<h3 id="7-tied-weights-justification">7. &ldquo;Tied weights&rdquo; justification</h3>
<p><strong>What does it mean for weights to be &ldquo;tied&rdquo;?</strong></p>
<p>The decoder matrix is the transpose of the encoder matrix ($W^T$).</p>
<p><strong>Why is this theoretically justified?</strong></p>
<p>Because the reconstruction function is interpreted as the <strong>gradient</strong> of an energy function. A vector field can only be the gradient of a scalar field if its Jacobian is symmetric.</p>
<ul>
<li>In the DAE energy derivative, the encoder contributes $W^T \sigma(Wx + b)$. If the decoder used a separate matrix $U$, the resulting vector field would not be a valid gradient of any scalar energy function (unless $U = W^T$).</li>
<li>Therefore, for a DAE to correspond to a valid probabilistic Energy-Based Model, the weights <em>must</em> be tied.</li>
</ul>
<p><strong>The necessity of tied weights:</strong></p>
<p>Within this parametrization, tied weights are a mathematical necessity: with a separate decoder matrix $U \neq W^T$, the reconstruction function would not be the gradient of any scalar energy function, breaking the EBM correspondence.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a theoretical paper, the &ldquo;reproducibility&rdquo; lies in the mathematical formulations derived.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Data ($D_n$):</strong> The theory assumes a set of training examples $D_n = \{x^{(1)}, \ldots, x^{(n)}\}$ drawn from an unknown true pdf $q(x)$.</li>
<li><strong>Parzen Density Estimate ($q_{\sigma}$):</strong> The theoretical targets are derived from a kernel-smoothed empirical distribution:
$$q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n q_{\sigma}(\tilde{x}|x^{(t)})$$
where the kernel is an isotropic Gaussian of variance $\sigma^2$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Denoising Score Matching (DSM) Objective</strong></p>
<p>The paper proposes this objective as a tractable alternative to standard score matching. It minimizes the distance between the model score and the gradient of the log-noise density:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left| \psi(\tilde{x};\theta) - \frac{\partial \log q_{\sigma}(\tilde{x}|x)}{\partial \tilde{x}} \right|^2 \right]$$</p>
<p>For Gaussian noise, the target score is simply $\frac{1}{\sigma^2}(x - \tilde{x})$.</p>
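<p>On a toy problem the DSM objective can be minimized in closed form: with data assumed $\mathcal{N}(0,1)$ and a linear score model $\psi(\tilde{x};\theta) = \theta\tilde{x}$ (an illustrative choice, not from the paper), the DSM-optimal $\theta$ recovers the slope of the true perturbed score, $-1/(1+\sigma^2)$:</p>

```python
import numpy as np

rng = np.random.default_rng(6)
sigma = 0.4

# Assumed data ~ N(0, 1); corrupt with Gaussian noise of std sigma
x = rng.standard_normal(200_000)
x_tilde = x + sigma * rng.standard_normal(x.shape)
target = (x - x_tilde) / sigma**2  # grad log q_sigma(x_tilde | x)

# J_DSM(theta) = E[ 1/2 (theta * x_tilde - target)^2 ] is quadratic in theta,
# so the minimizer is ordinary least squares of target on x_tilde
theta = np.sum(x_tilde * target) / np.sum(x_tilde**2)
print(theta, -1.0 / (1.0 + sigma**2))  # learned slope matches the true score slope
```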
<p><strong>2. Equivalence Chain</strong></p>
<p>The central result connects four objectives:</p>
<p>$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$</p>
<p>This implies optimizing the DAE reconstruction error is minimizing a score matching objective.</p>
<h3 id="models">Models</h3>
<p><strong>1. The Denoising Autoencoder (DAE)</strong></p>
<ul>
<li><strong>Corruption:</strong> Additive isotropic Gaussian noise $\tilde{x} = x + \epsilon, \epsilon \sim \mathcal{N}(0, \sigma^2 I)$.</li>
<li><strong>Encoder:</strong> $h = \text{sigmoid}(W\tilde{x} + b)$.</li>
<li><strong>Decoder:</strong> $x^r = W^T h + c$ (Tied weights $W$).</li>
<li><strong>Loss:</strong> Squared reconstruction error $\frac{1}{2\sigma^4} |x^r - x|^2$.</li>
</ul>
<p><strong>2. The Corresponding Energy Function</strong></p>
<p>To make the DAE equivalent to Score Matching, the underlying Energy-Based Model $p(x;\theta) \propto e^{-E(x;\theta)}$ must have the following energy function:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p>Note the scaling by $1/\sigma^2$ and the quadratic term $|x|^2$.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric:</strong> Theoretical Equivalence ($\sim$).</li>
<li><strong>Condition:</strong> The equivalence holds provided $\sigma &gt; 0$ and the density $q_{\sigma}$ is differentiable and vanishes at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Vincent, P. (2011). A Connection Between Score Matching and Denoising Autoencoders. <em>Neural Computation</em>, 23(7), 1661-1674. <a href="https://doi.org/10.1162/NECO_a_00142">https://doi.org/10.1162/NECO_a_00142</a></p>
<p><strong>Publication</strong>: Neural Computation 2011</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vincentConnectionScoreMatching2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Connection Between Score Matching}} and {{Denoising Autoencoders}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Vincent, Pascal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2011</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Neural Computation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1661--1674}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1162/NECO_a_00142}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="http://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Official PDF</a></li>
</ul>
]]></content:encoded></item><item><title>Neural ODEs: Continuous-Depth Deep Learning</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</guid><description>Introduces ODE-Nets, a continuous-depth neural network model parameterized by ODEs, enabling constant memory backpropagation and adaptive computation.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Key Prerequisites</strong>: Before diving in, note that for the ODE solver to guarantee a unique solution, the neural network $f(h(t), t, \theta)$ parameterizing the dynamics must be <a href="https://en.wikipedia.org/wiki/Lipschitz_continuity">Lipschitz continuous</a>. This ensures the <a href="https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem">Picard-Lindelöf theorem</a> holds, preventing trajectories from crossing and guaranteeing a well-defined backward pass.</p></blockquote>
<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel family of deep neural network models where the derivative of the hidden state is parameterized by a neural network. It provides specific algorithms (Algorithm 1) for training these models scalably.</li>
<li><strong>Theory</strong>: It derives the adjoint sensitivity method for backpropagating through black-box ODE solvers and proves the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1) for continuous normalizing flows.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors aim to address limitations in discrete deep learning architectures:</p>
<ul>
<li><strong>Discrete vs. Continuous</strong>: Existing models like Residual Networks build transformations by composing discrete steps, which can be seen as an Euler discretization of a continuous transformation. The authors investigate the limit as step sizes go to zero.</li>
<li><strong>Memory Efficiency</strong>: Backpropagating through deep discrete networks requires storing intermediate activations, leading to linear memory cost in terms of depth, which is a major bottleneck.</li>
<li><strong>Irregular Data</strong>: Recurrent Neural Networks (RNNs) struggle with data arriving at arbitrary times, typically requiring discretization into fixed bins.</li>
<li><strong>Normalizing Flow Costs</strong>: Standard normalizing flows have a bottleneck in computing the determinant of the Jacobian, which is computationally expensive ($O(D^3)$).</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core contribution is the <strong>Neural ODE</strong> formulation:
$$\frac{dh(t)}{dt} = f(h(t), t, \theta)$$
where the output is computed using a black-box differential equation solver.</p>
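<p>The connection to residual networks is easiest to see in code: a fixed-step Euler solve of this ODE is exactly a stack of residual blocks, and refining the step size increases "depth" without changing parameters (a sketch with toy, untrained dynamics; the paper uses adaptive solvers):</p>

```python
import numpy as np

def f(h, t, theta):
    # Toy, hypothetical dynamics function standing in for a small neural net
    return np.tanh(theta * h) * (1.0 - t)

def odenet_forward(h0, theta, n_steps):
    """Fixed-step Euler solve of dh/dt = f(h, t, theta) from t=0 to t=1."""
    h, dt = h0, 1.0 / n_steps
    for i in range(n_steps):
        h = h + dt * f(h, i * dt, theta)  # one 'residual block' per step
    return h

h0, theta = np.array([0.5, -1.0]), 0.8
coarse = odenet_forward(h0, theta, 10)   # a 10-block 'ResNet'
fine = odenet_forward(h0, theta, 1000)   # same parameters, finer discretization
print(coarse, fine)  # outputs converge as the step size shrinks
```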
<p>Key technical innovations include:</p>
<ol>
<li><strong>Adjoint Sensitivity Method for Backprop</strong>: The authors treat the solver as a black box and compute gradients by solving a second, augmented ODE backwards in time. This allows for <strong>constant memory cost</strong> regardless of depth.</li>
<li><strong>Adaptive Computation</strong>: The model uses modern ODE solvers that adapt evaluation steps based on error tolerance, allowing the model to trade precision for speed explicitly.</li>
<li><strong>Continuous Normalizing Flows (CNF)</strong>: By moving to continuous time, the change of variables formula simplifies from a log-determinant (cubic cost) to a trace operation (linear cost), enabling scalable generative modeling.</li>
<li><strong>Latent ODEs</strong>: A generative time-series model that represents time-series as latent trajectories determined by a local initial state and global shared dynamics, handling irregular sampling naturally.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across three distinct domains:</p>
<ol>
<li><strong>Supervised Learning (MNIST)</strong>:
<ul>
<li>Compared <strong>ODE-Net</strong> against a standard <strong>ResNet</strong> and a Runge-Kutta network (<strong>RK-Net</strong>).</li>
<li>Measured test error, parameter count, and memory usage.</li>
<li>Analyzed the trade-off between numerical precision (tolerance) and speed (NFE).</li>
</ul>
</li>
<li><strong>Continuous Normalizing Flows (Generative)</strong>:
<ul>
<li>Compared CNF against standard Normalizing Flows (NF) on density estimation tasks using toy 2D datasets (Target, Two Circles, Two Moons).</li>
<li>Evaluated training loss (KL divergence) and maximum likelihood estimation.</li>
</ul>
</li>
<li><strong>Time-Series Modeling (Latent ODE)</strong>:
<ul>
<li>Tested on a dataset of bi-directional spirals with irregular timestamps and Gaussian noise.</li>
<li>Compared Latent ODEs against RNNs and RNNs with time-concatenation on predictive RMSE.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Efficiency</strong>: ODE-Nets achieved roughly equivalent accuracy to ResNets on MNIST (0.42% vs 0.41% error) but with <strong>constant memory cost</strong> ($O(1)$) compared to ResNet&rsquo;s linear cost ($O(L)$).</li>
<li><strong>Adaptive Depth</strong>: The number of function evaluations (NFE) in ODE-Nets increases with training epoch, suggesting the model adapts its complexity as it learns.</li>
<li><strong>Generative Performance</strong>: Continuous Normalizing Flows (CNF) achieved lower KL divergence loss than standard Normalizing Flows (NF), trained with only 10,000 iterations (Adam) compared to 500,000 iterations (RMSprop) for NF. The paper acknowledges the optimizer difference (Adam vs RMSprop) as a confound, but notes the iteration gap is large enough that CNF&rsquo;s efficiency advantage holds. CNF can also expand capacity by increasing width ($M$) without architectural constraints.</li>
<li><strong>Irregular Time-Series</strong>: Latent ODEs significantly outperformed RNNs across all observation counts on irregular spiral data. The advantage is most pronounced with sparse observations (0.1642 vs 0.3937 RMSE at 30 obs), and the model learns interpretable latent trajectories that switch direction smoothly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: Standard handwritten digit dataset used for supervised learning benchmarks.</li>
<li><strong>Toy 2D Densities</strong>: &ldquo;Two Circles&rdquo; and &ldquo;Two Moons&rdquo; distributions used for visualizing normalizing flows.</li>
<li><strong>Bi-directional Spirals</strong>: A generated dataset of 1,000 2D spirals (half clockwise, half counter-clockwise). Each spiral is sampled at 100 equally-spaced timesteps with added Gaussian noise. For training, each spiral is then subsampled without replacement to $n \in \{30, 50, 100\}$ irregularly-spaced observations, simulating realistic missing data.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Adjoint Sensitivity Method (Backpropagation)</strong></p>
<p>To optimize the parameters of the ODE-Net, the authors use the adjoint sensitivity method to compute gradients. Standard backpropagation would require storing the activations at every step of the ODE solver, incurring a high memory cost that scales linearly with the number of steps.</p>
<p>Instead, this method treats the ODE solver as a &ldquo;black box&rdquo; and computes gradients by solving a second, <strong>augmented ODE</strong> backwards in time from the final state $t_1$ to the initial state $t_0$.</p>
<p>The augmented state contains three components that are solved simultaneously:</p>
<ol>
<li><strong>The State</strong>: The original hidden state $z(t)$, which is reconstructed backwards.</li>
<li><strong>The Adjoint</strong>: The sensitivity of the loss with respect to the state, $a(t) = \partial L / \partial z(t)$.</li>
<li><strong>The Gradient</strong>: The accumulating gradients with respect to parameters, $\partial L / \partial \theta$.</li>
</ol>
<p>The dynamics of this augmented system are defined as:
$$\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \\ \partial L/\partial \theta \end{bmatrix} = \begin{bmatrix} f(z(t), t, \theta) \\ -a(t)^T \frac{\partial f}{\partial z} \\ -a(t)^T \frac{\partial f}{\partial \theta} \end{bmatrix}$$</p>
<p>Using this approach, the vector-Jacobian products (e.g., $a(t)^T \frac{\partial f}{\partial z}$) are evaluated efficiently using automatic differentiation.</p>
<blockquote>
<p><strong>Why:</strong> Reconstructing $z(t)$ backwards avoids storing the forward pass, enabling <strong>constant memory cost</strong> ($O(1)$) regardless of depth.</p>
<p><strong>Origin:</strong> Adapted from Pontryagin&rsquo;s maximum principle (1962) for optimal control.</p></blockquote>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torchdiffeq <span style="color:#f92672">import</span> odeint_adjoint
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, dim):
</span></span><span style="display:flex;"><span>        super(ODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>net <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(dim, <span style="color:#ae81ff">50</span>),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">50</span>, dim),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, y):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Defines dy/dt = f(y, t)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>net(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Usage with adjoint method for O(1) memory backprop</span>
</span></span><span style="display:flex;"><span>func <span style="color:#f92672">=</span> ODEFunc(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>y0 <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([[<span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">0.</span>]]) <span style="color:#75715e"># Initial state</span>
</span></span><span style="display:flex;"><span>t <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linspace(<span style="color:#ae81ff">0.</span>, <span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">10</span>) <span style="color:#75715e"># Time points to solve for</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># &#39;odeint_adjoint&#39; automatically handles the augmented state backward pass</span>
</span></span><span style="display:flex;"><span>out <span style="color:#f92672">=</span> odeint_adjoint(func, y0, t, method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;dopri5&#39;</span>)
</span></span></code></pre></div><p><strong>2. Instantaneous Change of Variables (CNF)</strong></p>
<p>For generative modeling, the authors introduce <strong>Continuous Normalizing Flows (CNF)</strong>. In discrete normalizing flows, the probability density of a transformed variable is calculated using the change of variables theorem, which requires computing the log-determinant of the Jacobian: $\log p(z_1) = \log p(z_0) - \log |\det \frac{\partial z_1}{\partial z_0}|$. This operation is computationally expensive ($O(D^3)$) and often restricts model architectures to ensure the Jacobian is easy to compute (e.g., triangular).</p>
<p>Moving to continuous time simplifies this requirement. The paper proves that if the transformation is defined by an ODE, the change in log-probability follows a differential equation determined by the <strong>trace</strong> of the Jacobian:
$$\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left( \frac{\partial f}{\partial z(t)} \right)$$</p>
<p>The total change in log-density is obtained by integrating this value over time.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_trace</span>(y, f):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes trace of Jacobian df/dy.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    For high dimensions, use Hutchinson&#39;s trace estimator (approximate).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    tr <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(y<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">1</span>)):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Gradients of f&#39;s i-th component w.r.t y&#39;s i-th component</span>
</span></span><span style="display:flex;"><span>        tr <span style="color:#f92672">+=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(f[:, i]<span style="color:#f92672">.</span>sum(), y, create_graph<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)[<span style="color:#ae81ff">0</span>][:, i]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> tr
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In the ODE function:</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># d(log_p)/dt = -trace(df/dy)</span>
</span></span></code></pre></div><blockquote>
<p><strong>Why:</strong> The trace operator has <strong>linear cost</strong> ($O(D)$), whereas the determinant has cubic cost ($O(D^3)$). This allows for unrestricted, &ldquo;wide&rdquo; architectures that are automatically bijective.</p>
<p><strong>Origin:</strong> This is the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1), derived in Appendix A of the paper.</p></blockquote>
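<p>The docstring above points to Hutchinson&rsquo;s trace estimator as the scalable alternative. A minimal sketch (illustrative, not the paper&rsquo;s code; the function name and sample count are assumptions) replaces the per-dimension gradient loop with vector-Jacobian products against random Rademacher probes:</p>

```python
import torch

def hutchinson_trace(f, y, n_samples=64):
    """Estimate tr(df/dy) via E[eps^T (df/dy) eps] with Rademacher noise.
    One vector-Jacobian product per probe, regardless of dimension."""
    y = y.requires_grad_(True)
    out = f(y)
    tr = torch.zeros(y.size(0))
    for _ in range(n_samples):
        eps = torch.randint(0, 2, y.shape).float() * 2 - 1  # entries in {-1, +1}
        vjp = torch.autograd.grad(out, y, grad_outputs=eps, retain_graph=True)[0]
        tr = tr + (vjp * eps).sum(dim=1)  # eps^T J eps
    return tr / n_samples
```

<p>For a map with a diagonal Jacobian the estimate is exact for every probe; in general its variance shrinks with the number of probes.</p>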
<h3 id="models">Models</h3>
<p><strong>ODE-Net (MNIST Classification)</strong>:</p>
<ul>
<li><strong>Input</strong>: The input is downsampled twice before entering the ODE block.</li>
<li><strong>Core</strong>: 6 standard residual blocks replaced by a single <strong>ODESolve</strong> module.</li>
<li><strong>Output</strong>: Global average pooling + Fully connected layer.</li>
<li><strong>Solver</strong>: Implicit Adams method.</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEBlock</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, odefunc):
</span></span><span style="display:flex;"><span>        super(ODEBlock, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>odefunc <span style="color:#f92672">=</span> odefunc
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>])<span style="color:#f92672">.</span>float()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x):
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>integration_time<span style="color:#f92672">.</span>type_as(x)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Returns [x(t0), x(t1)]; we only want final state x(t1)</span>
</span></span><span style="display:flex;"><span>        out <span style="color:#f92672">=</span> odeint_adjoint(self<span style="color:#f92672">.</span>odefunc, x, self<span style="color:#f92672">.</span>integration_time)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> out[<span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># ResNet-like architecture with ODE block</span>
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Conv2d(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>ReLU(inplace<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>),
</span></span><span style="display:flex;"><span>    ODEBlock(ODEFunc(<span style="color:#ae81ff">64</span>)), <span style="color:#75715e"># Continuous-depth layer; a conv-based ODEFunc is needed for 4-D image tensors (the Linear ODEFunc above assumes vector inputs)</span>
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>BatchNorm2d(<span style="color:#ae81ff">64</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>AdaptiveAvgPool2d((<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>)),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">10</span>)
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Latent ODE (Time-Series)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: RNN with 25 hidden units processing data backwards to produce $q(z_0|x)$. It runs backwards so the final RNN state summarizes the entire sequence at $t_0$, parameterizing the initial latent state $z_0$ for the forward-running ODE.</li>
<li><strong>Latent Space</strong>: 4-dimensional latent state $z_0$.</li>
<li><strong>Dynamics ($f$)</strong>: Neural network with one hidden layer of 20 units.</li>
<li><strong>Decoder</strong>: Neural network with one hidden layer of 20 units computing $p(x_{t_i}|z_{t_i})$.</li>
<li><strong>Likelihood</strong>: Gaussian log-likelihood for the spiral reconstruction task. The paper also describes an optional Poisson process likelihood $\lambda(z(t))$ for event-time data (e.g., medical records), but this is not used in the spiral experiment.</li>
</ul>
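<p>The encoder, dynamics, and decoder above can be sketched as a single module. This is a hypothetical reconstruction rather than the authors&rsquo; code: it uses a GRU encoder and a fixed-step Euler integrator in place of an adaptive solver, while the dimensions follow the paper (4-d latent state, 20 hidden units, 25-unit RNN):</p>

```python
import torch
import torch.nn as nn

class LatentODE(nn.Module):
    def __init__(self, obs_dim=2, latent_dim=4, hidden=20, rnn_hidden=25):
        super().__init__()
        self.rnn = nn.GRU(obs_dim, rnn_hidden, batch_first=True)
        self.to_q = nn.Linear(rnn_hidden, 2 * latent_dim)  # mean, log-var of q(z0|x)
        self.dynamics = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.Tanh(), nn.Linear(hidden, latent_dim))
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.Tanh(), nn.Linear(hidden, obs_dim))

    def forward(self, x, ts):
        # Encode the sequence *backwards* so the final RNN state summarizes it at t0
        _, h = self.rnn(torch.flip(x, dims=[1]))
        mean, logvar = self.to_q(h[-1]).chunk(2, dim=-1)
        z = mean + torch.randn_like(mean) * (0.5 * logvar).exp()  # reparameterization
        zs = [z]
        # Fixed-step Euler integration stands in for an adaptive ODE solver here
        for t0, t1 in zip(ts[:-1], ts[1:]):
            z = z + (t1 - t0) * self.dynamics(z)
            zs.append(z)
        return self.decoder(torch.stack(zs, dim=1)), mean, logvar
```

<p>Training would maximize the ELBO: the Gaussian reconstruction log-likelihood at the observation times plus the KL term computed from <code>mean</code> and <code>logvar</code>.</p>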
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Metric</th>
          <th>Baseline (ResNet/RNN)</th>
          <th>ODE Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MNIST</td>
          <td>Test Error</td>
          <td>0.41%</td>
          <td>0.42%</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Parameters</td>
          <td>0.60 M</td>
          <td>0.22 M</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Memory</td>
          <td>$O(L)$</td>
          <td>$O(1)$</td>
      </tr>
      <tr>
          <td>Spirals (30 obs)</td>
          <td>RMSE</td>
          <td>0.3937</td>
          <td><strong>0.1642</strong></td>
      </tr>
      <tr>
          <td>Spirals (50 obs)</td>
          <td>RMSE</td>
          <td>0.3202</td>
          <td><strong>0.1502</strong></td>
      </tr>
      <tr>
          <td>Spirals (100 obs)</td>
          <td>RMSE</td>
          <td>0.1813</td>
          <td><strong>0.1346</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Implementation</strong>: Hidden state dynamics evaluated on GPU using <strong>TensorFlow</strong>.</li>
<li><strong>Solvers</strong>: Fortran ODE solvers (LSODE, VODE) from <code>scipy.integrate</code> were used for the actual integration.</li>
<li><strong>Interface</strong>: Python&rsquo;s <code>autograd</code> framework bridged the TensorFlow dynamics and Scipy solvers.</li>
<li><strong>Note</strong>: While the original paper used TensorFlow/Scipy, the authors later released <code>torchdiffeq</code> (PyTorch), which has become the standard implementation for this architecture. The code samples above reflect this modern standard.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, R. T. Q., Rubanova, Y., Bettencourt, J., &amp; Duvenaud, D. (2018). Neural ordinary differential equations. <em>Proceedings of the 32nd International Conference on Neural Information Processing Systems</em>, 6572-6583.</p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chen2018neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural ordinary differential equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 32nd International Conference on Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6572--6583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/rtqichen/torchdiffeq">Official PyTorch Implementation</a></li>
</ul>
]]></content:encoded></item><item><title>Flow Matching for Generative Modeling: Scalable CNFs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</guid><description>A simulation-free framework for training Continuous Normalizing Flows using Conditional Flow Matching and Optimal Transport paths.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, as it introduces &ldquo;Flow Matching&rdquo; (FM), a novel simulation-free paradigm for training Continuous Normalizing Flows (CNFs) at scale. It is supported by a strong <strong>Theory</strong> basis, providing formal theorems that allow the intractable marginal vector field regression to be solved via a tractable conditional objective. It also touches on <strong>Systematization</strong> by showing that existing diffusion paths are specific instances of the proposed Gaussian probability path framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper aims to overcome the scaling limitations of Continuous Normalizing Flows (CNFs).</p>
<ul>
<li><strong>Problem</strong>: Standard Maximum Likelihood training for CNFs requires expensive numerical ODE simulations during training, which scales poorly. Existing simulation-free methods often involve intractable integrals or result in biased gradients.</li>
<li><strong>Gap</strong>: Diffusion models scale well, yet they are restricted to specific, curved probability paths (e.g., VP, VE) that can result in slow sampling and long training times.</li>
<li><strong>Goal</strong>: To develop an efficient, simulation-free training method for CNFs that supports arbitrary probability paths, specifically allowing for straighter, more efficient trajectories like those from Optimal Transport.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is <strong>Flow Matching (FM)</strong> and specifically the <strong>Conditional Flow Matching (CFM)</strong> objective.</p>
<ul>
<li><strong>Direct Vector Field Regression</strong>: The model regresses a target vector field $u_t$ that generates a desired probability path $p_t$.</li>
<li><strong>Conditional Flow Matching (CFM)</strong>: The authors prove that regressing the vector field of <em>conditional</em> paths (e.g., $p_t(x|x_1)$ given a single data point) yields the same gradients as regressing the intractable marginal vector field. This bypasses the need to know the marginal score or vector field.</li>
<li><strong>Optimal Transport Paths</strong>: The framework enables the use of <strong>Optimal Transport (OT)</strong> displacement interpolation for probability paths. OT paths are straight lines with constant speed, leading to faster training and easier sampling.</li>
</ul>
<p><strong>Concurrent work note</strong>: Rectified Flow (Liu et al., 2023) and Stochastic Interpolants (Albergo &amp; Vanden-Eijnden, 2023) were published concurrently at ICLR 2023 with structurally similar contributions under different names. All three independently propose simulation-free training of continuous flows via direct vector field regression; the differences lie in the specific interpolation schemes, theoretical framing, and experimental focus. The related notes linked in the frontmatter cover both concurrent papers.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<ul>
<li><strong>Domains</strong>: 2D Checkerboard data, CIFAR-10, and ImageNet at resolutions $32 \times 32$, $64 \times 64$, and $128 \times 128$.</li>
<li><strong>Task</strong>: Unconditional generative modeling (density estimation and sample quality) and conditional super-resolution ($64 \to 256$).</li>
<li><strong>Baselines</strong>: Compared against Diffusion-based methods on the same architecture (U-Net): DDPM (Noise Matching), Score Matching (SM), and ScoreFlow.</li>
<li><strong>Ablations</strong>: Specifically compared <strong>FM with Diffusion paths</strong> vs. <strong>FM with Optimal Transport (OT) paths</strong> to isolate the benefit of the training objective vs. the path choice.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Outperforms diffusion baselines</strong>: FM-OT consistently outperforms all diffusion-based methods (DDPM, Score Matching, ScoreFlow) in both Likelihood (NLL) and Sample Quality (FID) across CIFAR-10 and ImageNet, using the same U-Net architecture and training budget. Selected rows from Table 1 (NLL in bits per dimension, BPD; lower is better for all three metrics; &ldquo;FM [w] / OT&rdquo; and &ldquo;FM [w] / Diffusion&rdquo; refer to FM trained with OT paths and Diffusion paths respectively):</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>NLL (BPD) ↓</th>
          <th>FID ↓</th>
          <th>NFE ↓</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>DDPM</td>
          <td>3.12</td>
          <td>7.48</td>
          <td>274</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>FM [w] / OT</td>
          <td><strong>2.99</strong></td>
          <td><strong>6.35</strong></td>
          <td><strong>142</strong></td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>ScoreFlow</td>
          <td>3.36</td>
          <td>24.95</td>
          <td>601</td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>FM [w] / OT</td>
          <td><strong>3.31</strong></td>
          <td><strong>14.45</strong></td>
          <td><strong>138</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Training stability</strong>: FM with diffusion paths (FM [w] / Diffusion) is itself a more stable alternative to diffusion training than DDPM and Score Matching, as shown by training curves in the paper (Figure 2), even before switching to OT paths. The OT path then provides further gains.</li>
<li><strong>Sampling speed</strong>: The straight trajectories of OT paths allow accurate sampling with significantly fewer function evaluations (NFE) compared to diffusion paths.</li>
<li><strong>Generality</strong>: Diffusion is a specific instance of Gaussian probability paths within FM. OT paths are a better-optimized alternative available within the same framework.</li>
<li><strong>Downstream adoption</strong>: Flow matching has been adopted beyond image generation. <a href="/notes/interdisciplinary/computational-biology/dynamicflow/">DynamicFlow</a> uses it as the generative backbone for simultaneously generating ligand molecules and transforming protein pockets, extending flow matching to structure-based drug design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Datasets</strong>: CIFAR-10, ImageNet ($32 \times 32$, $64 \times 64$, $128 \times 128$).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Images are center-cropped and resized.</li>
<li>For $32 \times 32$ and $64 \times 64$, the preprocessing follows Chrabaszcz et al. (2017).</li>
<li>Data is transformed via $\varphi(y) = 2^7(y+1)$ mapping $[-1, 1]$ pixel values to $[0, 256]$ for BPD computation.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Conditional Flow Matching (CFM) Objective</strong></p>
<p>The practical training objective used is the CFM loss, which bypasses intractable marginalization:</p>
<p>$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\| v_t(\psi_t(x_0)) - u_t(\psi_t(x_0) | x_1) \right\|^2$$</p>
<p>Where $t \sim \mathcal{U}[0,1]$, $x_1 \sim q(x_1)$ (data), and $x_0 \sim p(x_0)$ (noise).</p>
<p><strong>2. Optimal Transport (OT) Probability Path</strong></p>
<p>The authors recommend the OT path for efficiency.</p>
<ul>
<li><strong>Mean/Std Schedule</strong>: $\mu_t(x) = t x_1$ and $\sigma_t(x) = 1 - (1 - \sigma_{min})t$.</li>
<li><strong>Conditional Flow Map</strong>: $\psi_t(x) = (1 - (1 - \sigma_{min})t)x + t x_1$.</li>
<li><strong>Target Vector Field</strong>: The closed-form regression target for OT is:
$$u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</li>
</ul>
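<p>Putting the schedule, flow map, and regression target together, one CFM training step with the OT path can be sketched as below (a minimal illustration; the model signature <code>v(t, x)</code> and the batching are assumptions, not the paper&rsquo;s code). Note that along the OT path the target simplifies to $\frac{d}{dt}\psi_t(x_0) = x_1 - (1 - \sigma_{min})x_0$:</p>

```python
import torch

def cfm_ot_loss(v, x1, sigma_min=1e-5):
    """Conditional Flow Matching loss with the OT path for one batch x1 of data."""
    x0 = torch.randn_like(x1)                            # noise sample from p(x0)
    t = torch.rand(x1.size(0), *([1] * (x1.dim() - 1)))  # t ~ U[0,1], broadcastable
    psi_t = (1 - (1 - sigma_min) * t) * x0 + t * x1      # conditional flow map
    target = x1 - (1 - sigma_min) * x0                   # d(psi_t)/dt: regression target
    return ((v(t, psi_t) - target) ** 2).mean()
```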
<p><strong>3. Sampling</strong></p>
<p>Sampling is performed by solving the ODE $\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))$ from $t=0$ to $t=1$ using the learned vector field $v_t$.</p>
<ul>
<li><strong>Solver</strong>: <code>dopri5</code> (adaptive) is used for robust evaluation. Fixed-step solvers (Euler, Midpoint) are used for low-NFE efficiency tests.</li>
</ul>
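<p>With a fixed-step solver, sampling reduces to the loop below (a sketch; the trained field <code>v</code> and its <code>(t, x)</code> signature are assumptions):</p>

```python
import torch

def euler_sample(v, shape, steps=100):
    """Integrate dx/dt = v(t, x) from t=0 (noise) to t=1 (data) with Euler steps."""
    x = torch.randn(shape)                  # x0 ~ N(0, I)
    ts = torch.linspace(0., 1., steps + 1)
    for t0, t1 in zip(ts[:-1], ts[1:]):
        t = t0.expand(shape[0], 1)          # broadcast current time over the batch
        x = x + (t1 - t0) * v(t, x)
    return x
```

<p>Because OT trajectories are nearly straight, even coarse Euler discretizations stay close to the exact solution, which is what drives the low NFE counts reported above.</p>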
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: U-Net architecture from Dhariwal &amp; Nichol (2021) is used for all image experiments.</li>
<li><strong>Toy Data</strong>: 5-layer MLP with 512 neurons.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$, weight decay=0.0).</li>
<li>Learning Rate: Polynomial decay or constant (see Table 3 in paper).</li>
<li>$\sigma_{min}$: Set to a small value (e.g., $10^{-5}$).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>NLL (BPD)</strong>: Computed using the continuous change of variables formula, estimated via the Hutchinson trace estimator to bypass $O(d^3)$ divergence computation.</li>
<li><strong>FID</strong>: Frechet Inception Distance for sample quality.</li>
<li><strong>NFE</strong>: Number of Function Evaluations required by the solver.</li>
</ul>
</li>
<li><strong>Likelihood Computation</strong>: Requires solving an augmented ODE to track the log-density change:
$$\frac{d}{dt} \begin{bmatrix} \phi_t(x) \\ f(t) \end{bmatrix} = \begin{bmatrix} v_t(\phi_t(x)) \\ -\text{div}(v_t(\phi_t(x))) \end{bmatrix}$$</li>
</ul>
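<p>Numerically, this augmented system can be integrated as sketched below. This is an illustrative reconstruction, not the paper&rsquo;s implementation: it uses fixed-step Euler and the exact per-dimension divergence, where the paper uses an adaptive solver and the Hutchinson estimator:</p>

```python
import math
import torch

def log_likelihood(v, x, steps=100):
    """log p_1(x) = log p_0(z(0)) - integral_0^1 div v(t, z(t)) dt, integrating
    the trajectory backwards from the data point x at t=1 to the base at t=0."""
    z = x
    int_div = torch.zeros(x.size(0))
    ts = torch.linspace(1., 0., steps + 1)  # backwards in time
    for t0, t1 in zip(ts[:-1], ts[1:]):     # note t1 < t0
        z = z.detach().requires_grad_(True)
        f = v(t0, z)
        div = sum(torch.autograd.grad(f[:, i].sum(), z, retain_graph=True)[0][:, i]
                  for i in range(z.size(1)))
        int_div = int_div + (t0 - t1) * div  # accumulate the divergence integral
        z = z + (t1 - t0) * f                # Euler step towards t=0
    # Standard-normal base density log p_0(z(0))
    base = -0.5 * (z.detach() ** 2).sum(dim=1) - 0.5 * z.size(1) * math.log(2 * math.pi)
    return base - int_div
```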
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CIFAR-10</strong>: 2 GPUs.</li>
<li><strong>ImageNet-32</strong>: 4 GPUs.</li>
<li><strong>ImageNet-64</strong>: 16 GPUs.</li>
<li><strong>ImageNet-128</strong>: 32 GPUs.</li>
<li><strong>Precision</strong>: Full 32-bit for CIFAR/IM-32; 16-bit mixed precision for IM-64/128.</li>
</ul>
<hr>
<h2 id="theoretical-notes-why-cfm-works">Theoretical Notes: Why CFM Works</h2>
<p>The paper relies on three key theorems to make training tractable.</p>
<p><strong>Theorem 1 (Marginal Generation)</strong>:</p>
<p>Marginalizing conditional vector fields $u_t(x|x_1)$ yields the correct marginal vector field $u_t(x)$ that generates the marginal probability path $p_t(x)$.</p>
<p>$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} dx_1$$</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>To understand why this theorem holds, we have to look at the <strong>Continuity Equation</strong>, which is the fundamental partial differential equation (PDE) that links a probability density path $p_t$ to a vector field $u_t$.</p>
<p>A vector field $u_t$ is said to &ldquo;generate&rdquo; a probability path $p_t$ if and only if they satisfy the continuity equation:</p>
<p>$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$$</p>
<p>The proof of Theorem 1 relies on substituting the definitions of the marginal path and vector field into this equation to see if they balance out.</p>
<p><strong>Step-by-Step Proof:</strong></p>
<ol>
<li>
<p><strong>Start with the time derivative of the marginal path</strong>: We begin by differentiating the marginal probability path $p_t(x)$ with respect to time. By definition, the marginal path is the integral of the conditional paths over the data distribution:
$$\frac{\partial p_t(x)}{\partial t} = \frac{\partial}{\partial t} \int p_t(x|x_1) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Swap derivative and integral</strong>: Assuming standard regularity conditions (Leibniz Rule), we can move the time derivative inside the integral:
$$\frac{\partial p_t(x)}{\partial t} = \int \frac{\partial p_t(x|x_1)}{\partial t} q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Apply the Conditional Continuity Equation</strong>: This is the critical step. We know that the conditional vector field $u_t(x|x_1)$ generates the conditional path $p_t(x|x_1)$. Therefore, for every single sample $x_1$, the pair satisfies the continuity equation:
$$\frac{\partial p_t(x|x_1)}{\partial t} = -\nabla \cdot (p_t(x|x_1) u_t(x|x_1))$$</p>
<p>Substituting this into our integral gives:
$$\frac{\partial p_t(x)}{\partial t} = -\int \nabla \cdot (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Pull the Divergence out</strong>: Since the divergence operator ($\nabla \cdot$) acts on $x$ and the integral is over $x_1$, we can pull the divergence operator outside the integral (by linearity):
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \left( \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1 \right)$$</p>
</li>
<li>
<p><strong>Match with the Marginal Vector Field Definition</strong>: Now, look at the term inside the parentheses. The paper defines the marginal vector field $u_t(x)$ specifically to make this term simpler. Rearranging the definition of $u_t(x)$ provided in the theorem:
$$p_t(x) u_t(x) = \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1$$</p>
<p>Substitute $p_t(x) u_t(x)$ back into our equation from Step 4:
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot (p_t(x) u_t(x))$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: We have just shown that $\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$. This is exactly the continuity equation. Because the marginal path and the aggregated marginal vector field satisfy this equation, the vector field is proven to generate the path.</p></blockquote>
<p><strong>Theorem 2 (Gradient Equivalence)</strong>:</p>
<p>The intractable Flow Matching objective $\mathcal{L}_{FM}$ (which requires $u_t(x)$) has the <strong>same gradients</strong> as the tractable Conditional Flow Matching objective $\mathcal{L}_{CFM}$.</p>
<p>$$\nabla_\theta \mathcal{L}_{FM}(\theta) = \nabla_\theta \mathcal{L}_{CFM}(\theta)$$</p>
<p>This allows the model to learn the marginal vector field by only seeing conditional sample paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The reason Theorem 2 holds is that the &ldquo;Conditional Flow Matching&rdquo; (CFM) objective is essentially an unbiased estimator of the &ldquo;Flow Matching&rdquo; (FM) objective (up to a constant). When we average over all the conditional data points $x_1$, the &ldquo;cross-term&rdquo; in the loss function aligns perfectly with the marginal vector field.</p>
<p><strong>1. Expand the Loss Functions</strong></p>
<p>First, let&rsquo;s look at the squared error in both objectives. Recall that $v_t$ is our neural network (parameterized by $\theta$), $u_t$ is the intractable marginal target, and $u_t(x|x_1)$ is the tractable conditional target.</p>
<p>Expanding the squared norms:</p>
<ul>
<li>
<p><strong>FM Objective</strong>:
$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, p_t(x)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x) + |u_t(x)|^2 \right]$$</p>
</li>
<li>
<p><strong>CFM Objective</strong>:
$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x|x_1) + |u_t(x|x_1)|^2 \right]$$</p>
</li>
</ul>
<p><strong>Key Insight</strong>: When we take the gradient $\nabla_\theta$, the last term in both equations disappears because the targets ($u_t$) are independent of the network weights $\theta$. We only need to show that the expectations of the first two terms match.</p>
<p><strong>2. Matching the First Term ($|v_t(x)|^2$)</strong></p>
<p>This part is straightforward. The expectation of $|v_t(x)|^2$ is the same in both cases because of how the marginal density $p_t(x)$ is defined.</p>
<ul>
<li><strong>FM</strong>: averages over $p_t(x)$.</li>
<li><strong>CFM</strong>: averages over $p_t(x|x_1)q(x_1)$.</li>
</ul>
<p>Since $p_t(x) = \int p_t(x|x_1) q(x_1) dx_1$ (by definition), averaging over the joint distribution is mathematically identical to averaging over the marginal $p_t(x)$.</p>
<p><strong>3. Matching the Cross Term (The &ldquo;Trick&rdquo;)</strong></p>
<p>This is the critical part of the proof. We need to show that the interaction between the network and the marginal field equals the interaction between the network and the conditional field.</p>
<p><strong>The Goal</strong>: Show $\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$.</p>
<p><strong>The Proof</strong>:</p>
<ol>
<li>
<p>Start with the <strong>FM cross-term</strong> (marginal):
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)]$$</p>
</li>
<li>
<p>Substitute the definition of the marginal vector field $u_t(x)$ derived in <strong>Theorem 1</strong>:
$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1$$</p>
</li>
<li>
<p>Plug this into the integral. The $p_t(x)$ terms cancel:
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \int_t \int_x p_t(x) v_t(x) \cdot \left[ \int_{x_1} u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 \right] dx \, dt$$</p>
</li>
<li>
<p>This simplifies to:
$$= \int_t \int_x \int_{x_1} v_t(x) \cdot u_t(x|x_1) p_t(x|x_1) q(x_1) dx_1 dx dt$$</p>
</li>
<li>
<p>This is exactly the definition of the expectation in the <strong>CFM objective</strong>:
$$= \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: Because the expectations of all terms involving $\theta$ are identical, the gradients must be identical.</p>
<p>Intuitively, this works like <strong>Denoising Score Matching</strong> or <strong>Stochastic Gradient Descent</strong>: even though each individual conditional vector field $u_t(x|x_1)$ points to a specific data point $x_1$ (which may differ from the true marginal direction), the <em>average</em> of all these pulls equals the true marginal vector field $u_t(x)$.</p></blockquote>
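<p>The constant-gap property behind the theorem can be verified exactly in one dimension. The sketch below uses illustrative assumptions (a two-point $q(x_1)$, a fixed $t$, a two-parameter affine stand-in for the network, and expectations computed by quadrature on a grid, none of which come from the paper): the gap $\mathcal{L}_{FM} - \mathcal{L}_{CFM}$ comes out the same for different parameter values, so the gradients coincide.</p>

```python
import numpy as np

SIG_MIN = 0.1
XS, W = np.array([-1.0, 2.0]), np.array([0.3, 0.7])  # two-point data dist q(x1)
t = 0.4
x = np.linspace(-6.0, 7.0, 4001)                     # quadrature grid over x
dx = x[1] - x[0]

sig, dsig = 1.0 - (1.0 - SIG_MIN) * t, -(1.0 - SIG_MIN)

def npdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))

p_cond = np.stack([npdf(x, t * x1, sig) for x1 in XS])              # p_t(x|x1)
u_cond = np.stack([(dsig / sig) * (x - t * x1) + x1 for x1 in XS])  # Theorem 3
p_marg = W @ p_cond                                  # mixture (marginal) density
u_marg = (W @ (p_cond * u_cond)) / p_marg            # Theorem 1 marginal field

def loss_gap(theta):
    v = theta[0] + theta[1] * x                      # toy affine "network" v_t
    L_fm = np.sum(p_marg * (v - u_marg) ** 2) * dx
    L_cfm = sum(w * np.sum(pc * (v - uc) ** 2) * dx
                for w, pc, uc in zip(W, p_cond, u_cond))
    return L_fm - L_cfm

gap_a = loss_gap(np.array([0.0, 1.0]))
gap_b = loss_gap(np.array([2.0, -0.5]))
print(gap_a, gap_b)  # equal up to floating point: the gap is theta-independent
```

Note the gap is negative: by Jensen's inequality the conditional targets carry extra variance that the marginal field averages away, which is exactly the constant that drops out of the gradient.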
<p><strong>Theorem 3 (Gaussian Conditional VFs)</strong>:</p>
<p>For any Gaussian probability path $p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$, the unique vector field generating it is available in closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p>This theorem allows explicitly defining targets for both Diffusion (curved) and Optimal Transport (straight) paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The derivation of Theorem 3 comes from the direct relationship between a flow map $\psi_t$ and its generating vector field. Because we chose a specific, simple path (Gaussian), we can invert the flow map to find the vector field in closed form.</p>
<p><strong>1. Define the Flow Map $\psi_t$</strong></p>
<p>We start by defining the conditional probability path as a Gaussian:</p>
<p>$$p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$$</p>
<p>The simplest way to &ldquo;push&rdquo; a standard normal distribution (noise) $p_0 = \mathcal{N}(0, I)$ to this Gaussian is using an affine transformation (scaling and shifting). We define the flow map $\psi_t$ as:</p>
<p>$$\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$$</p>
<p>This map takes a noise sample $x_0$ and transforms it into a sample $x$ at time $t$.</p>
<p><strong>2. The Definition of a Generating Vector Field</strong></p>
<p>By definition, a vector field $u_t$ generates a flow $\psi_t$ if the vector field describes the instantaneous velocity of the flow at any point. Mathematically:</p>
<p>$$u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$$</p>
<p>Let $x = \psi_t(x_0)$ be the position of the particle at time $t$. We want to find $u_t(x)$.</p>
<p><strong>3. Invert the Flow Map</strong></p>
<p>To find $u_t(x)$, we must express the equation in terms of $x$ rather than $x_0$. Since our flow map is a simple affine transformation (multiply and add), it is easily invertible (assuming $\sigma_t(x_1) \neq 0$):</p>
<p>$$x_0 = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}$$</p>
<p>We will call this inverse map $\psi_t^{-1}(x)$.</p>
<p><strong>4. Differentiate the Flow Map</strong></p>
<p>Now we calculate the left side of our definition equation (velocity): $\frac{d}{dt}\psi_t(x_0)$.</p>
<p>Taking the time derivative of $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$:</p>
<p>$$\frac{d}{dt}\psi_t(x_0) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>(Note: $\sigma'_t$ and $\mu'_t$ denote time derivatives.)</p>
<p><strong>5. Substitute and Solve</strong></p>
<p>Now we combine everything. We know $u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$.</p>
<p>Substitute the result from Step 4 into this equation:</p>
<p>$$u_t(\psi_t(x_0)) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>This expresses the vector field in terms of the initial point $x_0$. We must express it in terms of the current point $x$. So, we plug in the inverse formula for $x_0$ derived in Step 3:</p>
<p>$$u_t(x|x_1) = \sigma'_t(x_1) \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \mu'_t(x_1)$$</p>
<p>Rearranging terms gives the final closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p><strong>Why is this useful?</strong></p>
<p>This formula means that as long as you can define a mean schedule $\mu_t(x_1)$ and a standard deviation schedule $\sigma_t(x_1)$ (which is easy to do for both Diffusion and Optimal Transport), you immediately get the exact vector field target $u_t(x|x_1)$ needed to train your neural network, bypassing complex ODE solving or score matching approximations.</p></blockquote>
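<p>The five steps above can be sketched directly. Assuming an illustrative mean schedule $\mu_t(x_1) = t x_1$ and $\sigma_t = 1 - 0.9t$ (example choices, not prescribed by the theorem), a finite-difference check confirms that the closed form is exactly the velocity of the affine flow map:</p>

```python
import numpy as np

SIG_MIN = 0.1

def mu(t, x1):  return t * x1                      # illustrative mean schedule
def sig(t):     return 1.0 - (1.0 - SIG_MIN) * t   # illustrative std schedule

def psi(t, x0, x1):
    # Step 1: affine flow map pushing noise x0 toward the conditional Gaussian
    return sig(t) * x0 + mu(t, x1)

def u_closed(t, x, x1):
    # Theorem 3: u_t(x|x1) = (sig'_t / sig_t)(x - mu_t) + mu'_t
    return (-(1.0 - SIG_MIN) / sig(t)) * (x - mu(t, x1)) + x1

rng = np.random.default_rng(0)
x0, x1, t, h = rng.standard_normal(5), 1.7, 0.3, 1e-6
vel_fd = (psi(t + h, x0, x1) - psi(t - h, x0, x1)) / (2 * h)  # d/dt psi_t(x0)
vel_cf = u_closed(t, psi(t, x0, x1), x1)                      # u_t at x = psi_t(x0)
err = np.max(np.abs(vel_fd - vel_cf))
print(err)  # ~0: the closed form is the instantaneous velocity of the flow
```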
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., &amp; Le, M. (2023). Flow Matching for Generative Modeling. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lipmanFlowMatchingGenerative2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow Matching for Generative Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2210.02747">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Building Normalizing Flows with Stochastic Interpolants</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</guid><description>A continuous-time normalizing flow using stochastic interpolants and quadratic loss to bypass costly ODE backpropagation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with significant <strong>Theory</strong> contributions.</p>
<p>The authors propose a specific algorithm (&ldquo;InterFlow&rdquo;) for constructing generative models based on continuous-time normalizing flows. The work is characterized by the derivation of a new training objective (a simple quadratic loss) that bypasses the computational bottlenecks of previous methods. It includes prominent baseline comparisons against continuous flow methods (FFJORD, OT-Flow) and diffusion models. The theoretical component establishes the validity of the interpolant density satisfying the continuity equation (a conservation law governing how probability mass flows) and bounds the Wasserstein-2 distance (a measure of transport cost between distributions, penalizing squared displacement) of the transport.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to overcome the computational inefficiency of training Continuous Normalizing Flows (CNFs) using Maximum Likelihood Estimation (MLE). Standard CNF training requires backpropagating through numerical ODE solvers, which is costly and limits scalability.</p>
<p>Additionally, while score-based diffusion models (SDEs) have achieved high sample quality, they theoretically require infinite time integration and rely on specific noise schedules. The authors aim to establish a method that works strictly with Probability Flow ODEs on finite time intervals, retaining the flexibility to connect arbitrary densities without the complexity of SDEs or the cost of standard ODE adjoint methods.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Stochastic Interpolant</strong> framework:</p>
<ul>
<li><strong>Explicit Interpolant Construction</strong>: The method defines a time-dependent interpolant $x_t = I_t(x_0, x_1)$ (e.g., trigonometric interpolation) that connects samples from the base density $\rho_0$ and target $\rho_1$.</li>
<li><strong>Simulation-Free Training</strong>: The velocity field $v_t(x)$ of the probability flow is learned by minimizing a simple quadratic objective: $G(\hat{v}) = \mathbb{E}[|\hat{v}_t(x_t)|^2 - 2\partial_t x_t \cdot \hat{v}_t(x_t)]$. Because $\partial_t I_t$ is known analytically from the interpolant definition, the expectation can be estimated by sampling $(x_0, x_1, t)$ directly. This avoids ODE integration during training (ODE integration is still required at inference).</li>
<li><strong>Decoupling Path and Optimization</strong>: The choice of path (interpolant) is separated from the optimization of the velocity field. MLE methods couple the path and objective.</li>
<li><strong>Connection to Score-Based Models</strong>: The authors show that for Gaussian base densities and trigonometric interpolants, the learned velocity field is explicitly related to the score function $\nabla \log \rho_t$, providing a theoretical bridge between CNFs and diffusion models.</li>
</ul>
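<p>The simulation-free objective above can be sketched in a few lines. The degenerate zero base density used in the sanity check is an illustrative construction (not from the paper), chosen so that the exact velocity is available in closed form and attains the completed-square lower bound $-\mathbb{E}[|\partial_t I_t|^2]$:</p>

```python
import numpy as np

def interp(t, x0, x1):
    # trigonometric interpolant I_t(x0, x1)
    return np.cos(np.pi * t / 2) * x0 + np.sin(np.pi * t / 2) * x1

def dinterp_dt(t, x0, x1):
    # analytic time derivative of I_t, used as the regression signal
    return (np.pi / 2) * (np.cos(np.pi * t / 2) * x1 - np.sin(np.pi * t / 2) * x0)

def G_hat(v, t, x0, x1):
    # empirical estimate of G(v) = E[|v_t(x_t)|^2 - 2 dI/dt . v_t(x_t)]
    xt = interp(t, x0, x1)
    vt = v(t, xt)
    return np.mean(vt ** 2 - 2.0 * dinterp_dt(t, x0, x1) * vt)

# Sanity check: completing the square gives G(v) >= -E[|dI/dt|^2], attained by
# the true velocity. With a degenerate base x0 = 0, x_t determines x1, so the
# true velocity has the closed form v*(t, x) = (pi/2) x / tan(pi t / 2).
rng = np.random.default_rng(0)
x1 = rng.standard_normal(1000)
x0 = np.zeros(1000)
t = rng.uniform(0.05, 0.95, 1000)
v_star = lambda t, x: (np.pi / 2) * x / np.tan(np.pi * t / 2)
bound = -np.mean(dinterp_dt(t, x0, x1) ** 2)
slack = G_hat(v_star, t, x0, x1) - bound
print(slack)  # ~0: the true velocity attains the lower bound
```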
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors performed validation across synthetic, tabular, and image domains:</p>
<ul>
<li><strong>2D Density Estimation</strong>: Benchmarked on &ldquo;Checkerboard&rdquo;, &ldquo;8 Gaussians&rdquo;, and &ldquo;Spirals&rdquo; to visualize mode coverage and transport smoothness.</li>
<li><strong>High-Dimensional Tabular Data</strong>: Evaluated on standard benchmarks (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) comparing Negative Log Likelihood (NLL) against FFJORD, OT-Flow, and others.</li>
<li><strong>Image Generation</strong>: Trained models on CIFAR-10 ($32 \times 32$), ImageNet ($32 \times 32$), and Oxford Flowers ($128 \times 128$) to test scalability.</li>
<li><strong>Ablations</strong>: Investigated optimizing the interpolant path itself (e.g., learning Fourier coefficients for the path) to approach optimal transport and minimize path length.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Performance</strong>: The method matches or supersedes conventional ODE flows (like FFJORD) in terms of NLL while being significantly cheaper to train.</li>
<li><strong>Efficiency</strong>: The training cost per epoch is constant (simulation-free), whereas MLE-based ODE methods see growing costs as the dynamics become more complex.</li>
<li><strong>Scalability</strong>: The method successfully scales to $128 \times 128$ resolution on a single GPU, a resolution that prior ab-initio ODE flows had not demonstrated.</li>
<li><strong>Flexibility</strong>: The framework can connect <em>any</em> two arbitrary densities (e.g., connecting two different complex 2D distributions) without needing one to be Gaussian.</li>
<li><strong>Optimal Transport</strong>: For a fixed interpolant, minimizing $G(\hat{v})$ over the velocity field recovers the velocity for that specific path. Additionally optimizing over the interpolant family yields a solution to the Benamou-Brenier optimal transport problem.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Tabular Datasets</strong>: POWER (6D), GAS (8D), HEPMASS (21D), MINIBOONE (43D), BSDS300 (63D).
<ul>
<li>Training points range from ~30k (MINIBOONE) to ~1.6M (POWER).</li>
</ul>
</li>
<li><strong>Image Datasets</strong>:
<ul>
<li>CIFAR-10 ($32 \times 32$, 50k training points).</li>
<li>ImageNet ($32 \times 32$, ~1.28M training points).</li>
<li>Oxford Flowers ($128 \times 128$, ~315k training points).</li>
</ul>
</li>
<li><strong>Time Sampling</strong>: Time $t$ is sampled from a Beta distribution during training (reweighting) to focus learning near the target.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Interpolant</strong>: The primary interpolant used is trigonometric: $I_t(x_0, x_1) = \cos(\frac{\pi t}{2})x_0 + \sin(\frac{\pi t}{2})x_1$.
<ul>
<li>Alternative linear interpolant: $I_t = a_t x_0 + b_t x_1$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>:
$$G(\hat{v}) = \mathbb{E}_{t, x_0, x_1}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$$
<ul>
<li>The expectation is amenable to empirical estimation using batches of $x_0, x_1, t$.</li>
</ul>
</li>
<li><strong>Sampling</strong>: Numerical integration using Dormand-Prince (Runge-Kutta 4/5).</li>
<li><strong>Optimization</strong>: SGD/Adam variants used for optimization.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Tabular Architectures</strong>:
<ul>
<li>Feed-forward networks with 3-5 hidden layers.</li>
<li>Hidden widths: 512-1024 units.</li>
<li>Activation: ReLU (general) or ELU (BSDS300).</li>
</ul>
</li>
<li><strong>Image Architectures</strong>:
<ul>
<li>U-Net based on the DDPM implementation.</li>
<li>Dimensions: 256 hidden dimension.</li>
<li>Sinusoidal time embeddings used.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Negative Log Likelihood (NLL) in nats, Frechet Inception Distance (FID) for images.</li>
<li><strong>Baselines</strong>: FFJORD, Glow, Real NVP, OT-Flow, ScoreFlow, DDPM.</li>
</ul>
<p><strong>Tabular NLL</strong> (nats, lower is better; Table 2 Left):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>POWER</th>
          <th>GAS</th>
          <th>HEPMASS</th>
          <th>MINIBOONE</th>
          <th>BSDS300</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MADE</td>
          <td>3.08</td>
          <td>-3.56</td>
          <td>20.98</td>
          <td>15.59</td>
          <td>-148.85</td>
      </tr>
      <tr>
          <td>Real NVP</td>
          <td>-0.17</td>
          <td>-8.33</td>
          <td>18.71</td>
          <td>13.55</td>
          <td>-153.28</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>-0.17</td>
          <td>-8.15</td>
          <td>18.92</td>
          <td>11.35</td>
          <td>-155.07</td>
      </tr>
      <tr>
          <td>CPF</td>
          <td>-0.52</td>
          <td>-10.36</td>
          <td>16.93</td>
          <td>10.58</td>
          <td>-154.99</td>
      </tr>
      <tr>
          <td>NSP</td>
          <td>-0.64</td>
          <td>-13.09</td>
          <td>14.75</td>
          <td>9.67</td>
          <td>-157.54</td>
      </tr>
      <tr>
          <td>FFJORD</td>
          <td>-0.46</td>
          <td>-8.59</td>
          <td>14.92</td>
          <td>10.43</td>
          <td>-157.40</td>
      </tr>
      <tr>
          <td>OT-Flow</td>
          <td>-0.30</td>
          <td>-9.20</td>
          <td>17.32</td>
          <td>10.55</td>
          <td>-154.20</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>-0.57</strong></td>
          <td><strong>-12.35</strong></td>
          <td><strong>14.85</strong></td>
          <td><strong>10.42</strong></td>
          <td><strong>-156.22</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation NLL and FID</strong> (Table 2 Right; all numbers as reported in Albergo &amp; Vanden-Eijnden (2023)):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>CIFAR-10 NLL</th>
          <th>CIFAR-10 FID</th>
          <th>ImageNet-32 NLL</th>
          <th>ImageNet-32 FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FFJORD</td>
          <td>3.40</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>3.35</td>
          <td>-</td>
          <td>4.09</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM</td>
          <td>≤3.75</td>
          <td>3.17</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM++ (Song et al., 2021)</td>
          <td>≤3.37</td>
          <td>2.90</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ScoreSDE (Song et al., 2021)</td>
          <td>2.99</td>
          <td>2.92</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>VDM</td>
          <td>≤2.65</td>
          <td>7.41</td>
          <td>-</td>
          <td>≤3.72</td>
      </tr>
      <tr>
          <td>Soft Truncation</td>
          <td>2.88</td>
          <td>3.45</td>
          <td>3.85</td>
          <td>8.42</td>
      </tr>
      <tr>
          <td>ScoreFlow</td>
          <td>2.81</td>
          <td>5.40</td>
          <td>3.76</td>
          <td>10.18</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>2.99</strong></td>
          <td><strong>10.27</strong></td>
          <td><strong>3.48</strong></td>
          <td><strong>8.49</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: DDPM++ and ScoreSDE are both architectural variants from Song et al. (2021); the NLL and FID figures in the ScoreSDE row correspond to different model configurations within the same paper (NLL 2.99 from the VP-SDE Probability Flow ODE; FID 2.92 from the VE-SDE configuration). InterFlow matches ScoreSDE on CIFAR-10 NLL (2.99) while being simulation-free. FID is weaker than dedicated image models (10.27 vs 2.92 for ScoreSDE), reflecting the paper&rsquo;s primary focus on tractable likelihood rather than sample quality.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: All models were trained on a single NVIDIA A100 GPU.</li>
<li><strong>Training Time</strong>:
<ul>
<li>Tabular: $10^5$ steps.</li>
<li>Images: $1.5 \times 10^5$ to $6 \times 10^5$ steps.</li>
<li>Speedup: Demonstrated ~400x speedup compared to FFJORD on the MINIBOONE dataset.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>The Eleventh International Conference on Learning Representations</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{albergoBuildingNormalizingFlows2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Building {{Normalizing Flows}} with {{Stochastic Interpolants}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Eleventh International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Albergo, Michael Samuel and {Vanden-Eijnden}, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=li7qeBbCR1t}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=li7qeBbCR1t">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/2209.15571">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Translating InChI to IUPAC Names with Transformers</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/handsel-inchi-iupac-2021/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/handsel-inchi-iupac-2021/</guid><description>Sequence-to-sequence Transformer translating InChI identifiers to IUPAC names with 91% accuracy on organic compounds.</description><content:encoded><![CDATA[<h2 id="primary-contribution-a-transformer-based-method">Primary Contribution: A Transformer-Based Method</h2>
<p>This is primarily a <strong>Method</strong> paper. It adapts a specific architecture (Transformer) to a specific task (InChI-to-IUPAC translation) and evaluates its performance against both machine learning and commercial baselines. It also has a secondary <strong>Resource</strong> contribution, as the trained model and scripts are released as open-source software.</p>
<h2 id="motivation-the-bottleneck-in-algorithmic-iupac-nomenclature">Motivation: The Bottleneck in Algorithmic IUPAC Nomenclature</h2>
<p>Generating correct IUPAC names is difficult due to the comprehensive but complex rules defined by the International Union of Pure and Applied Chemistry. Commercial software generates names from structures but remains closed-source with opaque methodologies and frequent inter-package disagreements. Open identifiers like InChI and SMILES lack direct human readability. This creates a need for an open, automated method to generate informative IUPAC names from standard identifiers like InChI, which are ubiquitous in online chemical databases.</p>
<h2 id="novelty-treating-chemical-translation-as-a-character-level-sequence">Novelty: Treating Chemical Translation as a Character-Level Sequence</h2>
<p>The key novelty is treating chemical nomenclature translation as a character-level sequence-to-sequence problem using a Transformer architecture, specifically using <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a> as the source language.</p>
<ul>
<li>Standard Neural Machine Translation (NMT) uses sub-word tokenization. This model processes InChI and predicts IUPAC names character-by-character.</li>
<li>It demonstrates that character-level tokenization outperforms byte-pair encoding or unigram models for this specific chemical task.</li>
<li>It uses InChI&rsquo;s standardization to avoid the canonicalization issues inherent in SMILES-based approaches.</li>
<li>The attention mechanism allows the decoder to align specific parts of the generated IUPAC name with corresponding structural features in the source InChI string, operating via the standard scaled dot-product attention:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$</li>
</ul>
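<p>The attention formula above can be sketched as a minimal single-head implementation (generic NumPy for illustration, not the OpenNMT code the paper uses):</p>

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # softmax(Q K^T / sqrt(d_k)) V, with a row-wise softmax over key positions
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores -= scores.max(axis=-1, keepdims=True)   # subtract max for stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```

With all-zero queries every key receives equal weight, so each output row reduces to the mean of the value rows; in the trained model these weights are what align IUPAC name fragments with InChI layers.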
<h2 id="methodology--experimental-validation">Methodology &amp; Experimental Validation</h2>
<ul>
<li><strong>Training:</strong> The model was trained on 10 million InChI/IUPAC pairs sampled from PubChem using a character-level objective. The model is supervised using categorical cross-entropy loss across the vocabulary of characters:
$$ \mathcal{L} = -\sum_{i=1}^{N} y_i \log(\hat{y}_i) $$</li>
<li><strong>Ablation Studies:</strong> The authors experimentally validated architecture choices, finding that LSTM models and sub-word tokenization (BPE) performed worse than the Transformer with character tokenization. They also optimized dropout rates.</li>
<li><strong>Performance Benchmarking:</strong> The model was evaluated on a held-out test set of 200,000 samples. Performance was quantified primarily by Whole-Name Accuracy and Normalized Edit Distance (based on the Damerau-Levenshtein distance, scaled by the maximum string length).</li>
<li><strong>Commercial Comparison:</strong> The authors compared their model against four major commercial packages (ACD/I-Labs, ChemAxon, Mestrelab, and PubChem&rsquo;s Lexichem). However, this evaluation used a highly limited test set of only 100 molecules, restricting the statistical confidence of the external baseline.</li>
<li><strong>Error Analysis:</strong> They analyzed performance across different chemical classes (organics, charged species, macrocycles, inorganics) and visualized attention coefficients to interpret model focus.</li>
</ul>
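<p>The Normalized Edit Distance metric can be sketched as follows (the optimal-string-alignment variant of Damerau-Levenshtein, scaled by the longer string&rsquo;s length; an illustrative implementation, not the authors&rsquo; evaluation code):</p>

```python
def normalized_edit_distance(a: str, b: str) -> float:
    """Damerau-Levenshtein distance (optimal string alignment variant),
    scaled by the longer string's length to lie in [0, 1]."""
    if not a and not b:
        return 0.0
    d = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        d[i][0] = i
    for j in range(len(b) + 1):
        d[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + cost) # substitution
            if i > 1 and j > 1 and a[i - 1] == b[j - 2] and a[i - 2] == b[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)  # transposition
    return d[-1][-1] / max(len(a), len(b))
```

For example, <code>normalized_edit_distance("propan-2-ol", "propan-1-ol")</code> is 1/11, about 0.09: a single-character locant error counts as a near-miss rather than a total failure, which is why the metric complements whole-name accuracy.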
<h2 id="key-results-and-the-inorganic-challenge">Key Results and the Inorganic Challenge</h2>
<ul>
<li><strong>High Accuracy on Organics:</strong> The model achieved 91% whole-name accuracy on the test set, performing particularly well on organic compounds.</li>
<li><strong>Comparable to Commercial Tools:</strong> On the limited 100-molecule benchmark, the edit distance between the model&rsquo;s predictions and commercial packages (15-23%) was similar to the variation found <em>between</em> the commercial packages themselves (16-21%).</li>
<li><strong>Limitations on Inorganics:</strong> The model performed poorly on inorganic (14% accuracy) and organometallic compounds (20% accuracy). This is attributed to inherent data limitations in the standard InChI format (which deliberately disconnects metal atoms from their ligands) and low training data coverage for those classes.</li>
<li><strong>Character-Level Superiority:</strong> Character-level tokenization was found to be essential; byte-pair encoding reduced accuracy significantly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset was derived from <a href="https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Extras/">PubChem&rsquo;s public FTP server</a> (<code>CID-SMILES.gz</code> and <code>CID-IUPAC.gz</code>).</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Raw</strong></td>
          <td>PubChem</td>
          <td>100M pairs</td>
          <td>Filtered for length (InChI &lt; 200 chars, IUPAC &lt; 150 chars). 132k unparseable SMILES dropped.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Subsampled</td>
          <td>10M pairs</td>
          <td>Random sample from the filtered set.</td>
      </tr>
      <tr>
          <td><strong>Validation</strong></td>
          <td>Held-out</td>
          <td>10,000 samples</td>
          <td>Limited to InChI length &gt; 50 chars.</td>
      </tr>
      <tr>
          <td><strong>Test</strong></td>
          <td>Held-out</td>
          <td>200,000 samples</td>
          <td>Limited to InChI length &gt; 50 chars.</td>
      </tr>
      <tr>
          <td><strong>Tokenization</strong></td>
          <td>Vocab</td>
          <td>InChI: 66 chars<br>IUPAC: 70 chars</td>
          <td>Character-level tokenization. Spaces treated as tokens.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework</strong>: OpenNMT-py 2.0.0 (using PyTorch). Training scripts and vocabularies are available as supplementary files to the original publication. Pre-trained model weights are hosted on <a href="https://doi.org/10.5281/zenodo.5081159">Zenodo</a>.</li>
<li><strong>Architecture Type</strong>: Transformer Encoder-Decoder.</li>
<li><strong>Optimization</strong>: Adam optimizer ($\beta_1=0.9, \beta_2=0.998$).</li>
<li><strong>Learning Rate</strong>: Linear warmup over 8000 steps to 0.0005, then decayed by inverse square root of iteration.</li>
<li><strong>Regularization</strong>:
<ul>
<li>Dropout: 0.1 (applied to dense and attentional layers).</li>
<li>Label Smoothing: Magnitude 0.1.</li>
</ul>
</li>
<li><strong>Training Strategy</strong>: Teacher forcing used for both training and validation.</li>
<li><strong>Gradient Accumulation</strong>: Gradients accumulated over 4 batches before updating parameters.</li>
<li><strong>Inference</strong>: Beam search with width 10 and length penalty 1.0.</li>
</ul>
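<p>The learning-rate schedule described above matches the standard inverse-square-root (&ldquo;Noam&rdquo;) form. A sketch, assuming the warmup and decay branches meet at the peak (the function name and this anchoring are illustrative, not taken from the paper&rsquo;s scripts):</p>

```python
def learning_rate(step, peak=5e-4, warmup=8000):
    # linear warmup to the peak, then inverse-square-root decay in the step count
    if step < warmup:
        return peak * step / warmup
    return peak * (warmup / step) ** 0.5
```

So the rate climbs linearly for 8000 steps, reaches 0.0005, and thereafter halves every 4x increase in step count (e.g. it is back to 0.00025 at step 32000).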
<h3 id="models">Models</h3>
<ul>
<li><strong>Structure</strong>: 6 layers in encoder, 6 layers in decoder.</li>
<li><strong>Attention</strong>: 8 heads per attention sub-layer.</li>
<li><strong>Dimensions</strong>:
<ul>
<li>Feed-forward hidden state size: 2048.</li>
<li>Embedding vector length: 512.</li>
</ul>
</li>
<li><strong>Initialization</strong>: Glorot&rsquo;s method.</li>
<li><strong>Position</strong>: Positional encoding added to word vectors.</li>
</ul>
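<p>The positional encoding is presumably the standard sinusoidal scheme from the original Transformer; the sketch below assumes that exact form (interleaved sine and cosine channels), since the paper only states that positional encodings are added to the embeddings:</p>

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    # interleaved sin/cos encodings added to the d_model-dim character embeddings
    pos = np.arange(seq_len)[:, None]        # (seq_len, 1) token positions
    i = np.arange(d_model // 2)[None, :]     # (1, d_model/2) channel indices
    angles = pos / (10000.0 ** (2 * i / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe
```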
<h3 id="evaluation">Evaluation</h3>
<p>Metrics reported include <strong>Whole-Name Accuracy</strong> (percentage of exact matches) and <strong>Normalized Edit Distance</strong> (Damerau-Levenshtein, scale 0-1).</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy (All)</td>
          <td>91%</td>
          <td>N/A</td>
          <td>Test set of 200k samples.</td>
      </tr>
      <tr>
          <td>Accuracy (Inorganic)</td>
          <td>14%</td>
          <td>N/A</td>
          <td>Limited by InChI format and data.</td>
      </tr>
      <tr>
          <td>Accuracy (Organometallic)</td>
          <td>20%</td>
          <td>N/A</td>
          <td>Limited by InChI format and data.</td>
      </tr>
      <tr>
          <td>Accuracy (Charged)</td>
          <td>79%</td>
          <td>N/A</td>
          <td>Test set subset.</td>
      </tr>
      <tr>
          <td>Accuracy (Rajan)</td>
          <td>72%</td>
          <td>N/A</td>
          <td>Comparative ML model (STOUT).</td>
      </tr>
      <tr>
          <td>Edit Dist (Organic)</td>
          <td>$0.02 \pm 0.03$</td>
          <td>N/A</td>
          <td>Very high similarity for organics.</td>
      </tr>
      <tr>
          <td>Edit Dist (Inorganic)</td>
          <td>$0.32 \pm 0.20$</td>
          <td>N/A</td>
          <td>Poor performance on inorganics.</td>
      </tr>
      <tr>
          <td>Edit Dist (Organometallic)</td>
          <td>$0.37 \pm 0.24$</td>
          <td>N/A</td>
          <td>Poor performance on organometallics.</td>
      </tr>
  </tbody>
</table>
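<p>The normalized Damerau-Levenshtein distance used above can be sketched in pure Python. This is the optimal-string-alignment variant, normalized by the length of the longer string so that 0 means an exact match; the paper does not specify which variant or normalization it uses, so this is an illustrative assumption:</p>

```python
def damerau_levenshtein(a: str, b: str) -> int:
    # Optimal string alignment: substitutions, insertions, deletions,
    # and adjacent transpositions each cost 1.
    d = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(len(a) + 1):
        d[i][0] = i
    for j in range(len(b) + 1):
        d[0][j] = j
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,          # deletion
                          d[i][j - 1] + 1,          # insertion
                          d[i - 1][j - 1] + cost)   # substitution
            if i > 1 and j > 1 and a[i - 1] == b[j - 2] and a[i - 2] == b[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)  # transposition
    return d[len(a)][len(b)]

def normalized_edit_distance(pred: str, ref: str) -> float:
    # Scale to [0, 1] by the longer string; 0.0 = exact match.
    if not pred and not ref:
        return 0.0
    return damerau_levenshtein(pred, ref) / max(len(pred), len(ref))
```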
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Tesla K80.</li>
<li><strong>Training Time</strong>: 7 days.</li>
<li><strong>Throughput</strong>: ~6000 tokens/sec (InChI) and ~3800 tokens/sec (IUPAC).</li>
<li><strong>Batch Size</strong>: 4096 tokens (approx. 30 compounds).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.5081159">InChI to IUPAC model</a></td>
          <td>Model</td>
          <td>CC BY 4.0</td>
          <td>Pre-trained Transformer weights (551 MB), requires OpenNMT-py 2.0.0</td>
      </tr>
      <tr>
          <td><a href="https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Extras/">PubChem FTP</a></td>
          <td>Dataset</td>
          <td>Public Domain</td>
          <td>Source data: CID-SMILES.gz and CID-IUPAC.gz</td>
      </tr>
      <tr>
          <td>Training scripts &amp; vocabularies</td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Included as supplementary files with the publication</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Handsel, J., Matthews, B., Knight, N. J., &amp; Coles, S. J. (2021). Translating the InChI: Adapting Neural Machine Translation to Predict IUPAC Names from a Chemical Identifier. <em>Journal of Cheminformatics</em>, 13(1), 79. <a href="https://doi.org/10.1186/s13321-021-00535-x">https://doi.org/10.1186/s13321-021-00535-x</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{handselTranslatingInChIAdapting2021a,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Translating the {{InChI}}: Adapting Neural Machine Translation to Predict {{IUPAC}} Names from a Chemical Identifier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Translating the {{InChI}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Handsel, Jennifer and Matthews, Brian and Knight, Nicola J. and Coles, Simon J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{79}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00535-x}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{We present a sequence-to-sequence machine learning model for predicting the IUPAC name of a chemical from its standard International Chemical Identifier (InChI). The model uses two stacks of transformers in an encoder-decoder architecture, a setup similar to the neural networks used in state-of-the-art machine translation. Unlike neural machine translation, which usually tokenizes input and output into words or sub-words, our model processes the InChI and predicts the IUPAC name character by character. The model was trained on a dataset of 10 million InChI/IUPAC name pairs freely downloaded from the National Library of Medicine&#39;s online PubChem service. Training took seven days on a Tesla K80 GPU, and the model achieved a test set accuracy of 91\%. The model performed particularly well on organics, with the exception of macrocycles, and was comparable to commercial IUPAC name generation software. The predictions were less accurate for inorganic and organometallic compounds. This can be explained by inherent limitations of standard InChI for representing inorganics, as well as low coverage in the training data.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{Attention,GPU,InChI,IUPAC,seq2seq,Transformer}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Struct2IUPAC: Translating SMILES to IUPAC via Transformers</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/struct2iupac-2021/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/struct2iupac-2021/</guid><description>A Transformer-based model for translating between SMILES strings and IUPAC names, trained on 47M PubChem examples, achieving 98.9% accuracy with verification.</description><content:encoded><![CDATA[<h2 id="struct2iupac-as-a-methodological-shift">Struct2IUPAC as a Methodological Shift</h2>
<p>This is primarily a <strong>Method</strong> paper with significant elements of <strong>Position</strong>.</p>
<ul>
<li><strong>Method</strong>: The authors propose a specific neural architecture (Transformer with custom tokenization) and a verification pipeline (round-trip check) to solve the SMILES $\leftrightarrow$ IUPAC translation task. They rigorously benchmark this against rule-based baselines (OPSIN).</li>
<li><strong>Position</strong>: The authors explicitly argue for a paradigm shift, suggesting that &ldquo;heavy&rdquo; neural architectures should replace complex, costly rule-based legacy systems even for &ldquo;exact&rdquo; algorithmic tasks.</li>
</ul>
<h2 id="the-cost-of-rule-based-chemical-naming">The Cost of Rule-Based Chemical Naming</h2>
<ul>
<li><strong>Complexity of Naming</strong>: Generating IUPAC names manually is error-prone and requires deep algorithmic knowledge.</li>
<li><strong>Lack of Open Source Tools</strong>: While open-source tools exist for Name-to-Structure (e.g., OPSIN), there were no open-source tools for the inverse &ldquo;Structure-to-Name&rdquo; conversion at the time of writing.</li>
<li><strong>Cost of Development</strong>: Developing rule-based converters &ldquo;from scratch&rdquo; is prohibitively expensive and time-consuming compared to training a neural model on existing data.</li>
</ul>
<h2 id="struct2iupac-core-innovation">Struct2IUPAC Core Innovation</h2>
<ul>
<li><strong>Struct2IUPAC</strong>: The first effective open-source neural model for <a href="/notes/computational-chemistry/chemical-language-models/stout-v2/">converting SMILES to IUPAC names</a>, treating chemical translation as a Neural Machine Translation (NMT) problem.</li>
<li><strong>Verification Loop</strong>: A novel inference pipeline that generates multiple candidates via beam search and validates them using a reverse converter (OPSIN) to ensure the generated name maps back to the original structure.</li>
<li><strong>Custom Tokenization</strong>: A manually curated rule-based tokenizer for IUPAC names that handles specific chemical suffixes, prefixes, and stereochemical markers.</li>
</ul>
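<p>A rule-based IUPAC tokenizer of the kind described above can be sketched with a longest-match-first regular expression. The morpheme list here is a small hypothetical sample; the authors' manually curated inventory of suffixes, prefixes, and stereochemical markers is far larger:</p>

```python
import re

# Hypothetical, abridged morpheme inventory for illustration only.
MORPHEMES = ["methyl", "ethyl", "propan", "butan", "benzene",
             "chloro", "fluoro", "hydroxy", "oxy", "tri", "di",
             "one", "ol", "al", "amine", "yl", "an", "e"]

# Sort by length so the alternation prefers the longest morpheme,
# then fall back to locants, punctuation, and single letters.
PATTERN = re.compile(
    "|".join(map(re.escape, sorted(MORPHEMES, key=len, reverse=True)))
    + r"|\d+|[(),\[\]\-]|[A-Za-z]"
)

def tokenize_iupac(name: str) -> list:
    # Split an IUPAC name into morphemes, locants, and punctuation.
    return PATTERN.findall(name)
```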
<h2 id="experimental-setup-and-stress-testing">Experimental Setup and Stress Testing</h2>
<ul>
<li><strong>Accuracy Benchmarking</strong>: The models were tested on a held-out subset of 100,000 molecules from PubChem. The authors measured accuracy across different beam sizes (1, 3, 5).</li>
<li><strong>Comparison to Rules</strong>: The neural IUPAC2Struct model was compared directly against the rule-based OPSIN tool.</li>
<li><strong>Stress Testing</strong>:
<ul>
<li><strong>Sequence Length</strong>: Evaluated performance across varying token lengths, identifying a &ldquo;sweet spot&rdquo; (10-60 tokens) and failure modes for very short (e.g., methane) or long molecules.</li>
<li><strong>Stereochemistry</strong>: Tested on &ldquo;stereo-dense&rdquo; compounds. The authors define a &ldquo;stereo-density&rdquo; index ($I$) as the ratio of stereocenters ($S$) to total tokens ($N$):
$$I = \frac{S}{N}$$
They observed a performance drop for these dense molecules, though the model still handled many stereocenters robustly.</li>
<li><strong>Tautomers</strong>: Verified the model&rsquo;s ability to handle different tautomeric forms (e.g., Guanine and Uracil variants).</li>
</ul>
</li>
<li><strong>Latency Analysis</strong>: Benchmarked inference speeds on CPU vs. GPU relative to output sequence length.</li>
</ul>
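<p>The stereo-density index $I = S/N$ can be approximated by counting stereocentre annotations in a tokenized SMILES string. Counting <code>@</code> markers inside bracket atoms is my assumption as a proxy; the paper does not state exactly how stereocentres were counted:</p>

```python
import re

# Minimal SMILES token pattern: bracket atoms, two-letter organic-subset
# elements, then any single character.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|.")

def stereo_density(smiles: str) -> float:
    # I = S / N: stereocentre annotations over total token count.
    tokens = SMILES_TOKEN.findall(smiles)
    stereocenters = sum(1 for t in tokens if "@" in t)
    return stereocenters / len(tokens)
```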
<h2 id="benchmarks-and-outcomes">Benchmarks and Outcomes</h2>
<ul>
<li><strong>High Accuracy</strong>: The Struct2IUPAC model achieved <strong>98.9% accuracy</strong> (Beam 5 with verification). The reverse model (IUPAC2Struct) achieved <strong>99.1%</strong>, comparable to OPSIN&rsquo;s 99.4%.</li>
<li><strong>Distribution Modeling vs. Intuition</strong>: The authors claim the model infers &ldquo;chemical logic,&rdquo; because it correctly generates multiple valid IUPAC names for single molecules where naming ambiguity exists (e.g., parent group selection). However, this more likely reflects the Transformer successfully modeling the high-frequency conditional probability distribution of synonymous names present in the PubChem training data, rather than learning intrinsic chemical rules.</li>
<li><strong>Production Readiness</strong>: Inference on GPU takes less than 0.5 seconds even for long names, making it viable for production use.</li>
<li><strong>Paradigm Shift</strong>: The authors conclude that neural networks are a viable, cost-effective alternative to developing rule-based algorithms for legacy notation conversion.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study utilized the PubChem database.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Total</strong></td>
          <td>PubChem</td>
          <td>~95M</td>
          <td>Filtered for RDKit compatibility</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Split A</td>
          <td>47,312,235</td>
          <td>Random 50% split</td>
      </tr>
      <tr>
          <td><strong>Testing</strong></td>
          <td>Split B</td>
          <td>47,413,850</td>
          <td>Random 50% split</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Cleaning</strong>: Molecules that could not be processed by RDKit were removed. Molecules containing tokens not in the tokenizer (e.g., aromatic selenium) were excluded.</li>
<li><strong>Availability</strong>: A subset of 100,000 test molecules is available on GitHub (<code>data/test_100000.csv</code>) and Zenodo. The full train/test splits are not explicitly provided.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>SMILES</strong>: Character-based tokenization.</li>
<li><strong>IUPAC</strong>: Custom rule-based tokenizer splitting suffixes (<code>-one</code>, <code>-al</code>), prefixes (<code>-oxy</code>, <code>-di</code>), and special symbols (<code>(</code>, <code>)</code>, <code>R(S)</code>).</li>
</ul>
</li>
<li><strong>Verification Step</strong>:
<ol>
<li>Generate $N$ names using Beam Search ($N=5$).</li>
<li>Reverse translate the candidate name using OPSIN.</li>
<li>Check if the OPSIN structure matches the original input SMILES.</li>
<li>Display the first verified match; otherwise, report failure.</li>
</ol>
</li>
</ul>
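<p>The four-step verification loop above can be sketched with the translators injected as callables. In the paper these would be the Struct2IUPAC beam search, OPSIN, and a cheminformatics canonicaliser such as RDKit; all the callable names below are placeholders, not the authors' API:</p>

```python
def verified_name(smiles, generate_candidates, name_to_smiles, canonicalize, beam=5):
    # Round-trip verification: accept the first beam candidate whose
    # back-translation matches the canonicalised input structure.
    target = canonicalize(smiles)
    for name in generate_candidates(smiles, beam):
        back = name_to_smiles(name)          # e.g., OPSIN
        if back is not None and canonicalize(back) == target:
            return name
    return None  # no candidate survived verification: report failure
```

Because verification only needs equality of canonical forms, the loop is agnostic to which name generator or parser is plugged in.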
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Standard Transformer with 6 encoder layers and 6 decoder layers.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Attention Heads: 8</li>
<li>Attention Dimension ($d_{\text{model}}$): 512</li>
<li>Feed-Forward Dimension ($d_{\text{ff}}$): 2048</li>
</ul>
</li>
<li><strong>Training Objective</strong>: The models were trained using standard autoregressive cross-entropy loss over the target token sequence $y$ given the input string $x$:
$$\mathcal{L} = - \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, x)$$</li>
<li><strong>Training</strong>: Two separate models were trained: <code>Struct2IUPAC</code> (SMILES $\to$ IUPAC) and <code>IUPAC2Struct</code> (IUPAC $\to$ SMILES).</li>
<li><strong>Availability</strong>: Code for model architecture is provided in the GitHub repository. Pre-trained weights for the IUPAC2Struct model are available, but the Struct2IUPAC model weights are not publicly released, meaning researchers would need to retrain that model on their own PubChem data to reproduce those results.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation was performed on a random subset of 100,000 molecules from the test set.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Beam Size</th>
          <th>Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Struct2IUPAC</td>
          <td>1</td>
          <td>96.1%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Struct2IUPAC</td>
          <td>5</td>
          <td>98.9%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>IUPAC2Struct</td>
          <td>1</td>
          <td>96.6%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>IUPAC2Struct</td>
          <td>5</td>
          <td>99.1%</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Robustness</strong>: Accuracy drops significantly for augmented (non-canonical) SMILES (37.16%) and stereo-enriched compounds (66.52%).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Infrastructure</strong>: 4 $\times$ Tesla V100 GPUs and 36 CPUs.</li>
<li><strong>Training Time</strong>: Approximately 10 days under full load.</li>
<li><strong>Inference Speed</strong>: &lt;0.5s per molecule on GPU; latency scales linearly with output token length.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/sergsb/IUPAC2Struct">IUPAC2Struct (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Transformer code and pre-trained IUPAC2Struct model</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.4280814">Test data (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>100k test molecules, OPSIN failure cases, model failure cases</td>
      </tr>
      <tr>
          <td><a href="https://app.syntelly.com/smiles2iupac">Struct2IUPAC web demo</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online interface for SMILES to IUPAC conversion</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Krasnov, L., Khokhlov, I., Fedorov, M. V., &amp; Sosnin, S. (2021). Transformer-based artificial neural networks for the conversion between chemical notations. <em>Scientific Reports</em>, 11(1), 14798. <a href="https://doi.org/10.1038/s41598-021-94082-y">https://doi.org/10.1038/s41598-021-94082-y</a></p>
<p><strong>Publication</strong>: Scientific Reports 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{krasnovTransformerbasedArtificialNeural2021a,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Transformer-Based Artificial Neural Networks for the Conversion between Chemical Notations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Krasnov, Lev and Khokhlov, Ivan and Fedorov, Maxim V. and Sosnin, Sergey}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{14798}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41598-021-94082-y}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/sergsb/IUPAC2Struct">GitHub Repository</a></li>
<li><a href="https://app.syntelly.com/smiles2iupac">Web Demo</a></li>
</ul>
]]></content:encoded></item><item><title>STOUT: SMILES to IUPAC Names via Neural Machine Translation</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/stout/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/stout/</guid><description>A deep-learning neural machine translation approach to translate between SMILES strings and IUPAC names using the STOUT model.</description><content:encoded><![CDATA[<h2 id="contribution-translating-chemistry-as-a-language">Contribution: Translating Chemistry as a Language</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary contribution as a <strong>Resource</strong> paper.</p>
<ul>
<li><strong>Method</strong>: It proposes a neural machine translation (NMT) architecture to approximate the complex, rule-based algorithm of IUPAC naming, treating it as a language translation task.</li>
<li><strong>Resource</strong>: It provides an open-source tool and trained models to the community, addressing a gap where such functionality was previously limited to proprietary software.</li>
</ul>
<h2 id="motivation-democratizing-iupac-nomenclature">Motivation: Democratizing IUPAC Nomenclature</h2>
<p>The International Union of Pure and Applied Chemistry (IUPAC) naming scheme is universally accepted but algorithmically complex. Generating these names correctly is challenging for humans, and automated generation is largely missing from major open-source toolkits like CDK, RDKit, or Open Babel. While reliable commercial tools exist (e.g., ChemAxon&rsquo;s <code>molconvert</code>), there was a lack of open-source alternatives for the scientific community. STOUT aims to fill this gap using a data-driven approach.</p>
<h2 id="core-innovation-sequence-to-sequence-naming">Core Innovation: Sequence-to-Sequence Naming</h2>
<ul>
<li><strong>Language Translation Approach</strong>: The authors treat chemical representations (<a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>/<a href="/notes/computational-chemistry/molecular-representations/selfies/">SELFIES</a>) and IUPAC names as two different languages, applying Neural Machine Translation (NMT) to translate between them.</li>
<li><strong>Use of SELFIES</strong>: The work establishes SELFIES (Self-Referencing Embedded Strings) as a robust choice over SMILES for deep learning tokenization in this specific task, capitalizing on its syntactic robustness.</li>
<li><strong>Hardware Acceleration</strong>: The paper benchmarks GPU versus TPU training and highlights the practical necessity of Tensor Processing Units (TPUs) for training large-scale chemical language models, reducing training time by an order of magnitude.</li>
</ul>
<h2 id="methodology--translation-validation">Methodology &amp; Translation Validation</h2>
<ul>
<li><strong>Data Scale</strong>: The model was trained on datasets of 30 million and 60 million molecules derived from PubChem.</li>
<li><strong>Hardware Benchmarking</strong>: Training efficiency was compared between an nVidia Tesla V100 GPU and Google TPU v3-8/v3-32 units.</li>
<li><strong>Bidirectional Translation</strong>: The system was tested on two distinct tasks:
<ol>
<li><strong>Forward</strong>: SELFIES → IUPAC names</li>
<li><strong>Reverse</strong>: IUPAC names → SELFIES</li>
</ol>
</li>
<li><strong>Validation</strong>: Performance was evaluated on a held-out test set of 2.2 million molecules.</li>
</ul>
<h2 id="translation-accuracy--hardware-scaling">Translation Accuracy &amp; Hardware Scaling</h2>
<ul>
<li><strong>High Accuracy</strong>: The model achieved an average BLEU score of ~90% and a Tanimoto similarity index &gt; 0.9 for both translation directions.</li>
<li><strong>Generalization</strong>: Even when predictions were textually mismatched (low BLEU score), the underlying chemical structures often remained highly similar (high Tanimoto similarity), suggesting the system captures fundamental chemical semantics rather than merely memorizing strings.</li>
<li><strong>Impact of Data Size</strong>: Expanding training from 30 million to 60 million molecules yielded consistent performance gains without saturating.</li>
<li><strong>Hardware Necessity</strong>: Training on TPUs proved up to 54 times faster than a standard GPU baseline (Tesla V100), making training at this scale computationally tractable.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">STOUT (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Current repo hosts STOUT V2.0 transformer models; V1 RNN code available in earlier commits</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Public Domain</td>
          <td style="text-align: left">Source of 111M molecules; 30M/60M training subsets not directly provided</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The dataset was curated from PubChem (111 million molecules). Note that the specific 30M and 60M subsets are not directly linked in the publication repository, which means a user would have to reconstruct the filtering process.</p>
<p><strong>Preprocessing &amp; Filtering</strong>:</p>
<ul>
<li>Explicit hydrogens removed; converted to canonical SMILES.</li>
<li><strong>Filtering Rules</strong>: MW &lt; 1500 Da, no counter ions, limited element set (C, H, O, N, P, S, F, Cl, Br, I, Se, B), no hydrogen isotopes, 3-40 bonds, no charged groups.</li>
<li><strong>Ground Truth Generation</strong>: ChemAxon&rsquo;s <code>molconvert</code> (Marvin Suite 20.15) was used to generate target IUPAC names for training.</li>
<li><strong>Representation</strong>: All SMILES were converted to SELFIES for training.</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">PubChem Filtered</td>
          <td style="text-align: left">30M &amp; 60M</td>
          <td style="text-align: left">Two distinct training sets created.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">PubChem Held-out</td>
          <td style="text-align: left">2.2M</td>
          <td style="text-align: left">Molecules not present in training sets; uniform token frequency.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>SELFIES</strong>: Split iteratively by brackets <code>[</code> and <code>]</code>.</li>
<li><strong>IUPAC</strong>: Split via punctuation (<code>(</code>, <code>)</code>, <code>{</code>, <code>}</code>, <code>[</code>, <code>]</code>, <code>-</code>, <code>.</code>, <code>,</code>) and a discrete set of sub-word chemical morphemes (e.g., <code>methyl</code>, <code>benzene</code>, <code>fluoro</code>).</li>
<li><strong>Padding</strong>: SELFIES padded to 48 tokens; IUPAC padded to 78 tokens. Start- and end-of-sequence markers are appended to each sequence.</li>
</ul>
</li>
<li><strong>Optimization</strong>: Adam optimizer with a learning rate of $0.0005$.</li>
<li><strong>Objective Function</strong>: Sparse categorical cross-entropy, computed per predicted token over a vocabulary of size $V$:
$$ \mathcal{L} = -\sum_{i=1}^{V} y_i \log(\hat{y}_i) $$</li>
</ul>
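<p>Because every SELFIES symbol is bracketed, the iterative split by <code>[</code> and <code>]</code> described above reduces to a single regular expression:</p>

```python
import re

def tokenize_selfies(selfies: str) -> list:
    # SELFIES strings are sequences of bracketed symbols, so matching
    # [...] groups recovers the token stream directly.
    return re.findall(r"\[[^\]]*\]", selfies)
```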
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder sequence-to-sequence network with Bahdanau attention mechanism context weighting.</li>
<li><strong>Components</strong>:
<ul>
<li><strong>Encoder/Decoder</strong>: Recurrent Neural Networks (RNN) constructed using Gated Recurrent Units (GRU).</li>
<li><strong>Attention</strong>: Bahdanau (additive) soft attention, which computes alignment scores to softly weight the encoder hidden states:
$$ e_{tj} = v_a^\top \tanh(W_a s_{t-1} + U_a h_j) $$</li>
<li><strong>Embedding</strong>: The decoder output passes through a continuous embedding layer before being concatenated with the attention context vector.</li>
</ul>
</li>
<li><strong>Implementation</strong>: Python 3 backend using TensorFlow 2.3.0. <em>Note: The linked GitHub repository currently defaults to the STOUT V2.0 transformer models, so researchers aiming to reproduce this specific V1 RNN paper should reference the older tag/commit history.</em></li>
</ul>
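<p>The additive attention score above can be sketched in plain Python, assuming $W_a$ and $U_a$ project into a shared alignment dimension matching $v_a$. This mirrors the formula, not the authors' TensorFlow implementation:</p>

```python
import math

def bahdanau_weights(s_prev, hs, W_a, U_a, v_a):
    # e_tj = v_a . tanh(W_a s_{t-1} + U_a h_j), then softmax over j.
    def matvec(M, x):
        return [sum(m * xi for m, xi in zip(row, x)) for row in M]
    Ws = matvec(W_a, s_prev)                 # shared across all positions j
    scores = []
    for h in hs:
        Uh = matvec(U_a, h)
        scores.append(sum(v * math.tanh(a + b)
                          for v, a, b in zip(v_a, Ws, Uh)))
    peak = max(scores)                       # numerically stable softmax
    exps = [math.exp(e - peak) for e in scores]
    total = sum(exps)
    return [e / total for e in exps]
```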
<h3 id="evaluation">Evaluation</h3>
<p>Metrics heavily emphasize both linguistic accuracy and cheminformatic structural correctness:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Details</th>
          <th style="text-align: left">Result (60M Model)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>BLEU Score</strong></td>
          <td style="text-align: left">NLTK sentence BLEU (unigram to 4-gram)</td>
          <td style="text-align: left">0.94 (IUPAC $\to$ SELFIES)</td>
          <td style="text-align: left">Exact text overlap. Serves as a strictly syntactic proxy.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Tanimoto Similarity</strong></td>
          <td style="text-align: left">PubChem fingerprints via CDK</td>
          <td style="text-align: left">0.98 (Valid IUPAC names)</td>
          <td style="text-align: left">Evaluates substructure alignment over bit vectors, $T(A, B) = \frac{\vert A \cap B \vert}{\vert A \cup B \vert}$.</td>
      </tr>
  </tbody>
</table>
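<p>The Tanimoto formula from the table can be computed directly over the sets of on-bit indices of two fingerprints (here generic sets stand in for the PubChem/CDK fingerprints used in the paper):</p>

```python
def tanimoto(a: set, b: set) -> float:
    # T(A, B) = |A ∩ B| / |A ∪ B| over on-bit indices.
    if not a and not b:
        return 1.0  # convention: two empty fingerprints are identical
    return len(a & b) / len(a | b)
```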
<h3 id="hardware">Hardware</h3>
<p>Comparison of hardware efficiency for training large chemical language models:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hardware</th>
          <th style="text-align: left">Batch Size</th>
          <th style="text-align: left">Time per Epoch (15M subset)</th>
          <th style="text-align: left">Speedup Factor</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>GPU (Tesla V100)</strong></td>
          <td style="text-align: left">256</td>
          <td style="text-align: left">~27 hours</td>
          <td style="text-align: left">1x</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>TPU v3-8</strong></td>
          <td style="text-align: left">1024 (Global)</td>
          <td style="text-align: left">~2 hours</td>
          <td style="text-align: left">13x</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>TPU v3-32</strong></td>
          <td style="text-align: left">1024 (Global)</td>
          <td style="text-align: left">~0.5 hours</td>
          <td style="text-align: left">54x</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2021). STOUT: SMILES to IUPAC names using neural machine translation. <em>Journal of Cheminformatics</em>, 13(1), 34. <a href="https://doi.org/10.1186/s13321-021-00512-4">https://doi.org/10.1186/s13321-021-00512-4</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanSTOUTSMILESIUPAC2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{STOUT: SMILES to IUPAC Names Using Neural Machine Translation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{STOUT}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00512-4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-09-22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{Chemical compounds can be identified through a graphical depiction, a suitable string representation, or a chemical name. A universally accepted naming scheme for chemistry was established by the International Union of Pure and Applied Chemistry (IUPAC) based on a set of rules. Due to the complexity of this ruleset a correct chemical name assignment remains challenging for human beings and there are only a few rule-based cheminformatics toolkits available that support this task in an automated manner. Here we present STOUT (SMILES-TO-IUPAC-name translator), a deep-learning neural machine translation approach to generate the IUPAC name for a given molecule from its SMILES string as well as the reverse translation, i.e. predicting the SMILES string from the IUPAC name. In both cases, the system is able to predict with an average BLEU score of about 90% and a Tanimoto similarity index of more than 0.9. Also incorrect predictions show a remarkable similarity between true and predicted compounds.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{Attention mechanism,Chemical language,Deep neural network,DeepSMILES,IUPAC names,Neural machine translation,Recurrent neural network,SELFIES,SMILES}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">GitHub Repository</a></li>
<li><a href="/notes/computational-chemistry/chemical-language-models/stout-v2/">STOUT V2.0 Note</a></li>
<li><a href="/notes/computational-chemistry/chemical-language-models/struct2iupac-2021/">Struct2IUPAC Note</a></li>
<li><a href="/notes/computational-chemistry/chemical-language-models/handsel-inchi-iupac-2021/">HandSEL Note (InChI to IUPAC)</a></li>
</ul>
]]></content:encoded></item><item><title>STOUT V2.0: Transformer-Based SMILES to IUPAC Translation</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/stout-v2/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/stout-v2/</guid><description>A Transformer-based model for translating SMILES to IUPAC names, trained on ~1 billion molecules, achieving ~0.99 BLEU score on benchmarks.</description><content:encoded><![CDATA[<h2 id="paper-contribution--methodological-scope">Paper Contribution &amp; Methodological Scope</h2>
<p><strong>Method (Primary) / Resource (Secondary)</strong></p>
<p>This paper presents a <strong>Methodological</strong> contribution by developing and validating a Transformer-based neural machine translation model (STOUT V2) for bidirectional chemical nomenclature (<a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a> $\leftrightarrow$ IUPAC). It systematically compares this new architecture against previous RNN-based baselines (<a href="/notes/computational-chemistry/chemical-language-models/stout/">STOUT V1</a>) and performs ablation studies on tokenization strategies.</p>
<p>It also serves as a significant <strong>Resource</strong> contribution by generating a massive training dataset of nearly 1 billion SMILES-IUPAC pairs (curated via commercial Lexichem software) and releasing the resulting models and code as open-source tools for chemical naming.</p>
<h2 id="the-need-for-robust-open-source-iupac-nomenclature-rules">The Need for Robust Open-Source IUPAC Nomenclature Rules</h2>
<p>Assigning systematic IUPAC names to chemical structures requires adherence to a complex ruleset, which makes consistent manual naming difficult. Deterministic, rule-based software options like OpenEye Lexichem and ChemAxon are reliable commercial solutions, while existing open-source tools like OPSIN focus on the reverse task of parsing names into structures.</p>
<p>The previous version of STOUT (V1), based on RNNs/GRUs, achieved ~90% BLEU accuracy, with known limitations in capturing long-distance dependencies required for stereochemistry handling. This work uses the sequence-learning capabilities of Transformers combined with large-scale datasets to create a competitive open-source IUPAC naming tool.</p>
<h2 id="architectural-shift-and-billion-scale-training">Architectural Shift and Billion-Scale Training</h2>
<p>The primary advancements over previous iterations address both architecture and dataset scale:</p>
<ol>
<li><strong>Architecture Shift</strong>: Moving from an RNN-based Seq2Seq model to a <strong>Transformer-based architecture</strong> (4 layers, 8 heads), which captures intricate chemical patterns better than GRUs.</li>
<li><strong>Billion-Scale Training</strong>: Training on a dataset of nearly <strong>1 billion molecules</strong> (combining PubChem and ZINC15), significantly larger than the 60 million used for STOUT V1.</li>
<li><strong>Tokenization Strategy</strong>: Determining that <strong>character-wise tokenization</strong> for IUPAC names is superior to word-wise tokenization in terms of both accuracy and training efficiency (15% faster).</li>
</ol>
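<p>The tokenization difference is easy to see on a short name. Character-wise splitting is as described above; the word-wise splitter below (on digits and punctuation) is only an illustrative stand-in, not the paper&rsquo;s actual word segmentation:</p>

```python
# Contrast of IUPAC-name tokenization strategies. Character-wise treats
# every character as a token (the strategy STOUT V2 adopted); the
# word-wise split on digits/punctuation is an illustrative assumption.
import re

name = "2-chloro-4-nitroaniline"

char_tokens = list(name)
word_tokens = [t for t in re.split(r"([0-9]+|[^A-Za-z0-9]+)", name) if t]

print(len(char_tokens))  # 23 tokens, one per character
print(word_tokens)       # 7 coarse tokens: substituent names, locants, hyphens
```

<p>Character-wise tokenization yields a tiny, closed vocabulary at the cost of longer sequences, which the paper found both more accurate and faster to train.</p>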
<h2 id="experimental-validation-and-scaling-limits">Experimental Validation and Scaling Limits</h2>
<p>The authors conducted three primary experiments to validate bidirectional translation (SMILES $\rightarrow$ IUPAC and IUPAC $\rightarrow$ SMILES):</p>
<ul>
<li><strong>Experiment 1 (Optimization)</strong>: Assessed the impact of dataset size (1M vs 10M vs 50M) and tokenization strategy on SMILES-to-IUPAC performance.</li>
<li><strong>Experiment 2 (Scaling)</strong>: Trained models on 110 million PubChem molecules for <strong>both</strong> forward and reverse translation tasks to test performance on longer sequences.</li>
<li><strong>Experiment 3 (Generalization)</strong>: Trained on the full ~1 billion dataset (PubChem + ZINC15) for both translation directions.</li>
<li><strong>External Validation</strong>: Benchmarked against an external dataset from ChEBI (1,485 molecules) and ChEMBL34 to test generalization to unseen data.</li>
</ul>
<p><strong>Evaluation Metrics</strong>:</p>
<ul>
<li><strong>Textual Accuracy</strong>: BLEU scores (1-4) and Exact String Match.</li>
<li><strong>Chemical Validity</strong>: Retranslation of generated names back to SMILES using OPSIN, followed by Tanimoto similarity checks (PubChem fingerprints) against the original input.</li>
</ul>
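<p>The two textual metrics can be sketched in pure Python. This is an illustrative reimplementation with naive zero-count smoothing, not the NLTK implementation the authors used:</p>

```python
# Illustrative sentence-level BLEU (uniform weights over 1- to 4-grams)
# plus the exact-string-match check. Pure-Python sketch; the paper
# reports NLTK's sentence BLEU.
import math
from collections import Counter

def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def sentence_bleu(reference, hypothesis, max_n=4):
    if not hypothesis:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        ref, hyp = Counter(ngrams(reference, n)), Counter(ngrams(hypothesis, n))
        overlap = sum((ref & hyp).values())  # clipped n-gram matches
        total = max(sum(hyp.values()), 1)
        log_precisions.append(math.log(max(overlap, 1e-9) / total))
    brevity = min(1.0, math.exp(1 - len(reference) / len(hypothesis)))
    return brevity * math.exp(sum(log_precisions) / max_n)

reference = list("2-chloro-4-nitroaniline")  # character-level tokens
hypothesis = list("2-chloro-4-nitroaniline")
print(sentence_bleu(reference, hypothesis))  # 1.0 for an exact prediction
print(reference == hypothesis)               # exact string match: True
```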
<h2 id="translation-accuracy-and-structural-validity">Translation Accuracy and Structural Validity</h2>
<ul>
<li><strong>Superior Performance</strong>: STOUT V2 achieved an average BLEU score of <strong>0.99</strong> (vs 0.94 for V1). While exact string matches varied by experiment (83-89%), the model notably achieved a perfect BLEU score (1.0) on <strong>97.49%</strong> of a specific test set where STOUT V1 only reached 66.65%.</li>
<li><strong>Structural Validity (&ldquo;Near Misses&rdquo;)</strong>: When the generated name differed from the ground truth string, the re-generated structure often remained chemically valid. For these divergent names, the model maintained an average Tanimoto similarity of <strong>0.68</strong> between bit-vector fingerprints $A$ and $B$, defined as:
$$ T(A,B) = \frac{\vert A \cap B \vert}{\vert A \cup B \vert} $$
<em>Critique</em>: Note that an average Tanimoto coefficient of 0.68 typically suggests moderate structural similarity/drift, not an almost-identical &ldquo;near miss&rdquo; (which would be $&gt;0.85$). This implies the model constructs chemically related but structurally distinct outputs when it fails exact string matching.</li>
<li><strong>Tokenization</strong>: Character-level splitting for IUPAC names outperformed word-level splitting and was more computationally efficient.</li>
<li><strong>Data Imbalance &amp; Generalization</strong>: The model&rsquo;s drop in performance for sequences &gt;600 characters highlights a systemic issue in open chemical databases: long, highly complex SMILES strings are significantly underrepresented. Even billion-scale training datasets are still bound by the chemical diversity of their source material.</li>
<li><strong>Limitations</strong>:
<ul>
<li><strong>Preferred Names (PINs)</strong>: The model mimics Lexichem&rsquo;s naming conventions, generating valid IUPAC names distinct from strict <em>Preferred IUPAC Names</em> (PINs).</li>
<li><strong>Sequence Length</strong>: Performance degrades for very long SMILES (&gt;600 characters) due to scarcity in the training data.</li>
<li><strong>Algorithmic Distillation Bottleneck</strong>: Because the 1 billion training pairs were generated entirely by OpenEye&rsquo;s Lexichem, STOUT V2 acts as a knowledge distillation of that specific commercial algorithm. The model learns Lexichem’s heuristic mapping, specific dialects, and potential systematic errors, rather than deriving true nomenclature rules from first principles.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data was derived from PubChem and ZINC15. Ground truth IUPAC names were generated using OpenEye Lexichem TK 2.8.1 to ensure consistency.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training (Exp 1)</strong></td>
          <td>PubChem Subset</td>
          <td>1M, 10M, 50M</td>
          <td>Selected via MaxMin algorithm for diversity</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 2)</strong></td>
          <td>PubChem</td>
          <td>110M</td>
          <td>Filtered for SMILES length &lt; 600</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 3)</strong></td>
          <td>PubChem + ZINC15</td>
          <td>~1 Billion</td>
          <td>999,637,326 molecules total</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>ChEBI</td>
          <td>1,485</td>
          <td>External validation set, non-overlapping with training</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>SMILES</strong>: Canonicalized, isomeric, and kekulized using RDKit (v2023.03.1).</li>
<li><strong>Formatting</strong>: Converted to TFRecord format in 100 MB chunks for TPU efficiency.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES Tokenization</strong>: Regex-based splitting. Atoms (e.g., &ldquo;Cl&rdquo;, &ldquo;Au&rdquo;), bonds, brackets, and digits are separate tokens.</li>
<li><strong>IUPAC Tokenization</strong>: <strong>Character-wise split</strong> was selected as the optimal strategy (treating every character as a token).</li>
<li><strong>Optimization</strong>: Adam optimizer with a custom learning rate scheduler based on model dimensions.</li>
<li><strong>Loss Function</strong>: Trained to minimize a masked Sparse Categorical Cross-Entropy $L$. With one-hot targets $y_i$, predicted probabilities $p_i$, and padding mask $m_i \in \{0, 1\}$ over the $N$ positions of the sequence:
$$ L = - \sum_{i=1}^{N} m_i \, y_{i} \log(p_{i}) $$
where $m_i$ zeroes out padded positions.</li>
<li><strong>Code Availability</strong>: The <a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">main STOUT V2 repository</a> contains the inference package. The training pipeline/instructions (originally linked to a separate repo that is currently a 404) can still be found within the <a href="https://doi.org/10.5281/zenodo.6559438">Zenodo archive release</a>.</li>
</ul>
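<p>The SMILES tokenization rule described above can be sketched with a short regex. The pattern below is an illustrative reconstruction covering common organic-subset tokens, not the authors&rsquo; exact expression (e.g., two-letter metals such as &ldquo;Au&rdquo; appear inside bracket atoms and are handled by that branch):</p>

```python
# Regex-based SMILES tokenizer in the spirit described above: bracket
# atoms, multi-letter halogens, bonds, ring digits, and branches become
# separate tokens. Illustrative reconstruction, not the authors' regex.
import re

SMILES_PATTERN = re.compile(
    r"(\[[^\]]+\]"    # bracket atoms, e.g. [nH], [C@@H], [Au]
    r"|Cl|Br"         # two-letter organic-subset atoms
    r"|[BCNOPSFI]"    # single-letter atoms
    r"|[bcnops]"      # aromatic atoms
    r"|[=#\-+\\/]"    # bonds and charges
    r"|[()]"          # branch brackets
    r"|%\d{2}|\d)"    # ring-closure digits
)

def tokenize_smiles(smiles: str) -> list[str]:
    return SMILES_PATTERN.findall(smiles)

print(tokenize_smiles("CC(=O)Oc1ccccc1C(=O)O"))  # aspirin
```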
<h3 id="models">Models</h3>
<p>The model follows the standard Transformer architecture from &ldquo;Attention is All You Need&rdquo; (Vaswani et al.).</p>
<ul>
<li><strong>Architecture</strong>: 4 Transformer layers (encoder/decoder stack).</li>
<li><strong>Attention</strong>: Multi-head attention with <strong>8 heads</strong>.</li>
<li><strong>Dimensions</strong>: Embedding size ($d_{model}$) = 512; Feed-forward dimension ($d_{ff}$) = 2048.</li>
<li><strong>Regularization</strong>: Dropout rate of 0.1.</li>
<li><strong>Context Window</strong>: Max input length (SMILES) = 600; Max output length (IUPAC) = 700-1000.</li>
<li><strong>Weights</strong>: Model weights for forward and reverse architectures are <a href="https://doi.org/10.5281/zenodo.13318286">available via Zenodo (v3)</a>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation focused on both string similarity and chemical structural integrity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Method</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>BLEU Score</strong></td>
          <td>N-gram overlap</td>
          <td>Compared predicted IUPAC string to Ground Truth.</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Accuracy</td>
          <td>Binary 1/0 check for identical strings.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto</strong></td>
          <td>Structural Similarity</td>
          <td>Predicted Name $\rightarrow$ OPSIN $\rightarrow$ SMILES $\rightarrow$ Fingerprint comparison to input.</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">STOUT V2 GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Inference package (PyPI: STOUT-pypi)</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.13318286">Model Weights (Zenodo v3)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Forward and reverse translation weights</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/6559438">Code Snapshot (Zenodo)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training pipeline archive</td>
      </tr>
      <tr>
          <td><a href="https://stout.decimer.ai">Web Application</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Demo with Ketcher, bulk submission, DECIMER integration</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was conducted entirely on Google Cloud Platform (GCP) TPUs.</p>
<ul>
<li><strong>STOUT V1</strong>: Trained on TPU v3-8.</li>
<li><strong>STOUT V2</strong>: Trained on <strong>TPU v4-128 pod slices</strong> (128 nodes).</li>
<li><strong>Large Scale (Exp 3)</strong>: Trained on <strong>TPU v4-256 pod slice</strong> (256 nodes).</li>
<li><strong>Training Time</strong>: Average of <strong>15 hours and 2 minutes per epoch</strong> for the 1 billion dataset.</li>
<li><strong>Framework</strong>: TensorFlow 2.15.0-pjrt with Keras.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2024). STOUT V2.0: SMILES to IUPAC name conversion using transformer models. <em>Journal of Cheminformatics</em>, 16(146). <a href="https://doi.org/10.1186/s13321-024-00941-x">https://doi.org/10.1186/s13321-024-00941-x</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanSTOUTV20SMILES2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{STOUT V2}}.0: {{SMILES}} to {{IUPAC}} Name Conversion Using Transformer Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{STOUT V2}}.0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{146}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00941-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://stout.decimer.ai">Web Application</a> (Includes Ketcher drawing, bulk submission, and DECIMER integration)</li>
<li><a href="https://decimer.ai">DECIMER Project</a></li>
<li><a href="/notes/computational-chemistry/chemical-language-models/stout/">STOUT V1 Note</a></li>
<li><a href="https://zenodo.org/records/6559438">Zenodo Archive (Code Snapshot)</a></li>
</ul>
]]></content:encoded></item><item><title>Multimodal Search in Chemical Documents and Reactions</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/shah-multimodal-search-2025/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/shah-multimodal-search-2025/</guid><description>A multimodal search engine that integrates text passages, molecular diagrams, and reaction data to enable passage-level retrieval in chemical literature.</description><content:encoded><![CDATA[<h2 id="contribution-multimodal-synthesis-retrieval">Contribution: Multimodal Synthesis Retrieval</h2>
<p>This is primarily a <strong>Method</strong> paper that proposes a novel architectural pipeline for indexing and searching chemical literature, unifying text, molecular diagrams, and structured reaction records. It also makes a secondary <strong>Resource</strong> contribution, providing a functional demonstration tool and a curated benchmark dataset for Suzuki coupling reactions.</p>
<h2 id="the-gap-in-passage-level-chemical-retrieval">The Gap in Passage-Level Chemical Retrieval</h2>
<p>Scientific literature documents chemical reactions through a combination of text and visual diagrams. Textual descriptions detail parameters like yield and operational temperature, whereas diagrams graphically model these structural transformations. Existing tools such as SciFinder or Reaxys perform document-level or individual compound retrieval. They fail to explicitly link molecular figures to localized textual descriptions. This structure prevents researchers from directly extracting a corresponding reaction diagram alongside the exact textual protocol. Researchers require passage-level retrieval of synthesis protocols to efficiently access complete reaction conditions.</p>
<h2 id="core-innovation-unified-multimodal-indexing">Core Innovation: Unified Multimodal Indexing</h2>
<p>The core methodological innovation is a multimodal passage-level indexing and linking pipeline.</p>
<ul>
<li><strong>Unified Indexing:</strong> The framework processes text and diagrams in parallel and directly links them into a single index structure. This architecture supports search queries utilizing raw text, discrete SMILES strings, or multimodal combinations.</li>
<li><strong>Compound-Passage Linking:</strong> The mechanism applies conflict-resolution logic linking chemical diagrams to specific text citations using two parallel heuristics:
<ol>
<li><strong>Token-based Alignment:</strong> Matching parsed diagram labels against documented text strings (e.g., &ldquo;compound 5&rdquo;) using normalized Levenshtein distance.</li>
<li><strong>Fingerprint-based Alignment:</strong> Matching chemical structures against generated SMILES strings via structural Tanimoto Similarity.</li>
</ol>
</li>
<li><strong>ReactionMiner Integration:</strong> The pipeline parses and incorporates formatted reaction records (reactants, products, catalysts, quantitative yields) directly derived from segmented text passages.</li>
</ul>
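<p>A rough sense of what a unified index entry might hold, with hypothetical field names and example values (the paper does not specify its index schema):</p>

```python
# Hypothetical shape of a unified multimodal index entry: each passage
# keeps its text, diagram-derived SMILES, and parsed reaction records
# side by side. Field names and values are invented for illustration.
from dataclasses import dataclass, field

@dataclass
class PassageEntry:
    passage_id: str
    text: str
    linked_smiles: list[str] = field(default_factory=list)  # from diagrams
    reactions: list[dict] = field(default_factory=list)     # parsed records

entry = PassageEntry(
    passage_id="doc3-p12",
    text="The boronic acid (compound 5) was coupled under Pd catalysis...",
    linked_smiles=["OB(O)c1ccc2ccccc2c1"],
    reactions=[{"product": "5", "yield": "82%", "catalyst": "Pd(PPh3)4"}],
)
print(entry.passage_id, len(entry.linked_smiles))
```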
<h2 id="methodology--expert-evaluation">Methodology &amp; Expert Evaluation</h2>
<p>The authors evaluated the system through a chemistry case study on a targeted synthesis domain, combined with qualitative expert assessment.</p>
<ul>
<li><strong>Dataset:</strong> The evaluation corpus comprises 7 research manuscripts and 6 supplementary data documents detailing Suzuki coupling reactions.</li>
<li><strong>Volume:</strong> The resulting index processed 1,282 extracted passages (indexing 538), extracted 383 unique SMILES, and logged 219 parsed reactions.</li>
<li><strong>Qualitative Evaluation:</strong> Practicing structural chemists developed real-world queries (such as cross-referencing the conceptual &ldquo;Burke group&rdquo; alongside an explicit structural SMARTS pattern) to gauge retrieval capability.</li>
</ul>
<h2 id="key-findings--system-limitations">Key Findings &amp; System Limitations</h2>
<ul>
<li><strong>Diagram-to-Text Linking:</strong> The pipeline accurately paired visual molecular diagrams with structurally derived text details, permitting testers to navigate directly from a molecule query card to the exact origin passage within the source PDF.</li>
<li><strong>Contextual Insight Extraction:</strong> Practicing chemists found the parsed reaction representations (yield metrics, isolated catalysts) useful as high-level extractive summaries.</li>
<li><strong>Extrapolative Retrieval:</strong> The architecture permitted the effective retrieval of targeted chemical derivatives (such as benzo[b]thiophen-2-ylboronic acid) via structurally related input queries (dibenzothiophene).</li>
</ul>
<p>The system evaluation highlights several architectural restrictions:</p>
<ul>
<li><strong>Domain-Restricted Validation:</strong> The initial validation is entirely qualitative and bounded to the specific subclass of Suzuki coupling reactions. The evaluation omits standardized quantitative retrieval baselines (e.g., MAP, NDCG) and lacks systematic ablation data for the fusion scoring mechanism.</li>
<li><strong>Algorithmic Transparency:</strong> The multimodal query router does not reveal which retrieval feature dominated a result, obscuring whether keyword text or structural similarity drove the final placement and limiting operator control.</li>
<li><strong>Optical Processing Brittleness:</strong> The vision inference and primitive-parsing stages are fragile, intermittently failing to associate text passages with correctly parsed molecular diagrams.</li>
<li><strong>Metadata Logging Incompleteness:</strong> Practicing chemists requested additional structured metadata targets (such as specific molar equivalents and parameterized mol% values) to successfully bridge the extracted data stream directly into digital electronic lab notebooks.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.cs.rit.edu/~dprl/reactionminer-demo-landing/">ReactionMiner Demo</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Online demo landing page; source code repository not publicly linked</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source:</strong> The corpus comprises 7 primary research papers and 6 supplementary information documents focusing on Suzuki coupling reactions, sourced from practicing chemists at UIUC. This evaluation dataset is internal and not publicly available.</li>
<li><strong>Preprocessing:</strong>
<ul>
<li>Source PDFs are converted to full-page raster images.</li>
<li>Layout regions and raw text are extracted via <strong>PyTesseract</strong>.</li>
<li>Passages are segmented with emphasis on reaction-related sentences, identified using product-indicative lexicons and topic modeling.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Diagram Extraction:</strong> A <strong>YOLOv8</strong> model identifies and segments molecular regions within structured PDF pages.</li>
<li><strong>Diagram Parsing:</strong> The architecture relies on <strong>ChemScraper</strong> to infer structural semantics from raw diagrams:
<ul>
<li><em>Born-digital PDFs:</em> <strong>SymbolScraper</strong> extracts vector lines and polygons directly from bounding box definitions.</li>
<li><em>Raster images:</em> The system employs the <strong>Line Segment Detector (LSD)</strong> and watershed bounding algorithms to isolate native geometric primitives.</li>
</ul>
</li>
<li><strong>Text Entity Extraction:</strong> The framework deploys <strong>ChemDataExtractor 2.0</strong> to extract explicit molecular aliases. A translation layer maps these entities to string representations via <strong>OPSIN</strong>.</li>
<li><strong>Linking Logic (Fusion Score):</strong>
<ul>
<li><strong>Text Link:</strong> The algorithm calculates a normalized Levenshtein ratio connecting visual diagram labels against proximal text mentions based on calculated edit distance.</li>
<li><strong>Structure Link:</strong> The algorithm computes the discrete Tanimoto Similarity between generated 2048-bit Morgan fingerprints extracted from localized visual diagram features and baseline text SMILES queries:
$$ T(A, B) = \frac{A \cdot B}{|A|^{2} + |B|^{2} - A \cdot B} $$
where $A$ and $B$ represent the boolean bit vectors of the respective fingerprint pairs.</li>
<li><strong>Conflict Resolution Protocol:</strong> The system fuses structural geometry bounds and discrete textual tokenization metrics, prioritizing the ranking sequence that yields a higher terminal similarity score. During final retrieval, the candidate subset is systematically re-ranked leveraging the hybrid calculation of the BM25 explicit metric and the localized count of exact SMILES pattern hits.</li>
</ul>
</li>
</ul>
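<p>The token-based linking heuristic can be sketched as a normalized Levenshtein ratio. Pure-Python sketch; any label normalization (casing, whitespace) is an assumption on my part:</p>

```python
# Normalized Levenshtein ratio for token-based diagram-to-text linking:
# similarity = 1 - edit_distance / max(len). Illustrative sketch only.

def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance, O(len(a) * len(b))."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]

def ratio(a: str, b: str) -> float:
    if not a and not b:
        return 1.0
    return 1.0 - levenshtein(a, b) / max(len(a), len(b))

print(round(ratio("compound 5", "compound 5a"), 3))  # 0.909: likely the same label
```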
<h3 id="models">Models</h3>
<ul>
<li><strong>Reaction Extraction:</strong> A <strong>LLaMA-3.1-8b</strong> model fine-tuned via <strong>LoRA</strong> emits custom tokens for reaction entities (compounds, reagents, temperatures) extracted from text chunks. The exact prompts, fine-tuning dataset, and LoRA hyperparameters are not reported.</li>
<li><strong>Diagram Processing:</strong> ChemScraper incorporates a segmentation-aware multi-task neural network for low-level raster image parsing.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Search Engine Base:</strong> The indexing framework is built on <strong>PyTerrier</strong>.</li>
<li><strong>Text Ranking:</strong> Keyword similarity is scored with standard <strong>BM25</strong>.</li>
<li><strong>Structure Operations:</strong> <strong>RDKit</strong> powers substructure matching and exact molecular similarity searches.</li>
<li><strong>Multimodal Fusion:</strong>
<ul>
<li>Candidate passages are filtered by combining structural matches against SMILES queries with document-level lexical BM25 scores.</li>
<li>The final ranking assigns the strongest weight to passages containing dense clusters of exact, verified SMILES pattern matches.</li>
</ul>
</li>
</ul>
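<p>The re-ranking step above can be sketched as a weighted combination of the lexical score and exact SMILES hit counts. The linear weights below are invented for illustration; the paper does not publish its fusion formula:</p>

```python
# Hedged sketch of multimodal fusion re-ranking: combine a BM25 lexical
# score with the count of exact SMILES pattern hits per passage. The
# weights w_lex/w_struct are made-up, not from the paper.

def rerank(candidates, w_lex=1.0, w_struct=2.0):
    """candidates: list of (passage_id, bm25_score, exact_smiles_hits)."""
    return sorted(
        candidates,
        key=lambda c: w_lex * c[1] + w_struct * c[2],
        reverse=True,
    )

candidates = [
    ("p1", 7.2, 0),  # strong keyword match, no structure hits
    ("p2", 4.1, 3),  # weaker text, but three exact SMILES matches
    ("p3", 5.0, 1),
]
print([pid for pid, _, _ in rerank(candidates)])  # ['p2', 'p1', 'p3']
```

<p>With these weights, structure-verified passages outrank purely lexical hits, matching the stated design preference for exact SMILES matches.</p>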
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute Infrastructure:</strong> The hardware and parameter requirements to host the multi-stage vision extractors (YOLOv8, ChemScraper) alongside a local 8B LLM are entirely unspecified in the paper.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Shah, A. K., et al. (2025). Multimodal Search in Chemical Documents and Reactions. <em>arXiv preprint arXiv:2502.16865</em>. <a href="https://doi.org/10.48550/arXiv.2502.16865">https://doi.org/10.48550/arXiv.2502.16865</a></p>
<p><strong>Publication</strong>: SIGIR &lsquo;25 (Demo Track), 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{shahMultimodalSearchChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Multimodal {{Search}} in {{Chemical Documents}} and {{Reactions}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Shah, Ayush Kumar and Dey, Abhisek and Luo, Leo and Amador, Bryan and Philippy, Patrick and Zhong, Ming and Ouyang, Siru and Friday, David Mark and Bianchi, David and Jackson, Nick and Zanibbi, Richard and Han, Jiawei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.cs.rit.edu/~dprl/reactionminer-demo-landing/">Online Demo</a> (Note: While the landing page advertises the system as open-source, the exact repository URL and installation prerequisites are omitted from the official manuscript.)</li>
</ul>
]]></content:encoded></item><item><title>MOFFlow: Flow Matching for MOF Structure Prediction</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/mofflow/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/mofflow/</guid><description>A Riemannian flow matching framework for generating Metal-Organic Framework structures by treating building blocks as rigid bodies.</description><content:encoded><![CDATA[<h2 id="methodological-contribution-mofflow-architecture">Methodological Contribution: MOFFlow Architecture</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$).</p>
<p>It introduces <strong>MOFFlow</strong>, a generative architecture and training framework designed specifically for the structure prediction of Metal-Organic Frameworks (MOFs). The paper focuses on the algorithmic innovation of decomposing the problem into rigid-body assembly on a Riemannian manifold, validates this through comparison against existing baselines, and performs ablation studies to justify architectural choices. While it leverages the theory of flow matching, its primary contribution is the application-specific architecture and the handling of modular constraints.</p>
<h2 id="motivation-scaling-limits-of-atom-level-generation">Motivation: Scaling Limits of Atom-Level Generation</h2>
<p>The primary motivation is to overcome the scalability and accuracy limitations of existing methods for MOF structure prediction.</p>
<ul>
<li><strong>Computational Cost of DFT:</strong> Conventional approaches rely on <em>ab initio</em> calculations (DFT) combined with random search, which are computationally prohibitive for large, complex systems like MOFs.</li>
<li><strong>Failure of General CSP:</strong> Existing deep generative models for general Crystal Structure Prediction (CSP) operate on an atom-by-atom basis. They fail to scale to MOFs, which often contain hundreds or thousands of atoms per unit cell, and do not exploit the inherent modular nature (building blocks) of MOFs.</li>
<li><strong>Tunability:</strong> MOFs have applications in carbon capture and drug delivery due to their tunable porosity, making automated design tools valuable.</li>
</ul>
<h2 id="core-innovation-rigid-body-flow-matching-on-se3">Core Innovation: Rigid-Body Flow Matching on SE(3)</h2>
<p>MOFFlow introduces a <strong>hierarchical, rigid-body flow matching framework</strong> tailored for MOFs.</p>
<ul>
<li><strong>Rigid Body Decomposition:</strong> MOFFlow treats metal nodes and organic linkers as rigid bodies, reducing the search space from $3N$ (atoms) to $6M$ (roto-translation of $M$ blocks) compared to atom-based methods.</li>
<li><strong>Riemannian Flow Matching on $SE(3)$:</strong> It is the first end-to-end model to jointly generate block-level rotations ($SO(3)$), translations ($\mathbb{R}^3$), and lattice parameters using <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">Riemannian flow matching</a>.</li>
<li><strong>MOFAttention:</strong> A custom attention module designed to encode the geometric relationships between building blocks, lattice parameters, and rotational constraints.</li>
<li><strong>Constraint Handling:</strong> It incorporates domain knowledge by operating on a mean-free system for translation invariance and using canonicalized coordinates for rotation invariance.</li>
</ul>
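<p>The block-level rotations live on $SO(3)$, where the conditional flow follows geodesics rather than straight lines. Below is a minimal numpy sketch of geodesic interpolation between a prior rotation and a data rotation using unit quaternions (slerp, standing in for the exponential/logarithm maps on the manifold); this is illustrative only, not the authors&rsquo; implementation:</p>

```python
import numpy as np

def slerp(q0, q1, t):
    """Geodesic interpolation between unit quaternions q0 -> q1 at time t in [0, 1]."""
    q0, q1 = q0 / np.linalg.norm(q0), q1 / np.linalg.norm(q1)
    dot = np.dot(q0, q1)
    if dot < 0:               # take the shorter arc on the double cover of SO(3)
        q1, dot = -q1, -dot
    dot = np.clip(dot, -1.0, 1.0)
    theta = np.arccos(dot)
    if theta < 1e-8:          # nearly identical rotations: no interpolation needed
        return q0
    return (np.sin((1 - t) * theta) * q0 + np.sin(t * theta) * q1) / np.sin(theta)

rng = np.random.default_rng(0)
q_prior = rng.normal(size=4)
q_prior /= np.linalg.norm(q_prior)        # normalized Gaussian = uniform on SO(3)
q_data = np.array([1.0, 0.0, 0.0, 0.0])   # identity rotation (w, x, y, z convention)
q_half = slerp(q_prior, q_data, 0.5)      # midpoint of the geodesic path
print(q_half)
```

<p>Slerp stays on the unit sphere, so every intermediate point is a valid rotation; flow matching regresses a vector field whose integral curves follow such geodesic paths.</p>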
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The authors evaluated MOFFlow on structure prediction accuracy, physical property preservation, and scalability.</p>
<ul>
<li><strong>Dataset:</strong> The <strong>Boyd et al. (2019)</strong> dataset consisting of 324,426 hypothetical MOF structures, decomposed into building blocks using the <strong>MOFid</strong> algorithm. Filtered to structures with &lt;200 blocks, yielding 308,829 structures (247,066 train / 30,883 val / 30,880 test). Structures contain up to approximately 2,400 atoms per unit cell.</li>
<li><strong>Baselines:</strong>
<ul>
<li><em>Optimization-based:</em> Random Search (RS) and Evolutionary Algorithm (EA) using CrySPY and CHGNet.</li>
<li><em>Deep Learning:</em> DiffCSP (deep generative model for general crystals).</li>
<li><em>Self-Assembly:</em> A heuristic algorithm used in MOFDiff (adapted for comparison).</li>
</ul>
</li>
<li><strong>Metrics:</strong>
<ul>
<li><strong>Match Rate (MR):</strong> Percentage of generated structures matching ground truth within tolerance.</li>
<li><strong>RMSE:</strong> Root mean squared displacement normalized by average free length per atom.</li>
<li><strong>Structural Properties:</strong> Volumetric/Gravimetric Surface Area (VSA/GSA), Pore Limiting Diameter (PLD), Void Fraction, etc., calculated via Zeo++.</li>
<li><strong>Scalability:</strong> Performance vs. number of atoms and building blocks.</li>
</ul>
</li>
</ul>
<h2 id="results-and-generative-performance">Results and Generative Performance</h2>
<p>MOFFlow outperformed all baselines in accuracy and efficiency, particularly for large structures.</p>
<ul>
<li><strong>Accuracy:</strong> With a single sample, MOFFlow achieved a <strong>31.69% match rate</strong> (stol=0.5) and <strong>87.46%</strong> (stol=1.0) on the full test set (30,880 structures). With 5 samples, these rose to <strong>44.75%</strong> (stol=0.5) and <strong>100.0%</strong> (stol=1.0). RS and EA (tested on 100 and 15 samples respectively due to computational cost, generating 20 candidates each) achieved 0.00% MR at both tolerance levels. DiffCSP reached 0.09% (stol=0.5) and 23.12% (stol=1.0) with 1 sample.</li>
<li><strong>Speed:</strong> Inference took <strong>1.94 seconds</strong> per structure, compared to 5.37s for DiffCSP, 332s for RS, and 1,959s for EA.</li>
<li><strong>Scalability:</strong> MOFFlow preserved high match rates across all system sizes, while DiffCSP&rsquo;s match rate dropped sharply beyond 200 atoms.</li>
<li><strong>Property Preservation:</strong> The distributions of physical properties (e.g., surface area, void fraction) for MOFFlow-generated structures closely matched the ground truth. DiffCSP frequently reduced volumetric surface area and void fraction to zero.</li>
<li><strong>Self-Assembly Comparison:</strong> In a controlled comparison where the self-assembly (SA) algorithm received MOFFlow&rsquo;s predicted translations and lattice, MOFFlow (MR=31.69%, RMSE=0.2820) outperformed SA (MR=30.04%, RMSE=0.3084), confirming the value of the learned rotational vector fields. In an extended scalability comparison, SA scaled better for structures with many building blocks, but MOFFlow achieved higher overall match rate (31.69% vs. 27.14%).</li>
<li><strong>Batch Implementation:</strong> A refactored Batch version achieves improved results: <strong>32.73% MR</strong> (stol=0.5), RMSE of 0.2743, inference in <strong>0.19s</strong> per structure (10x faster), and training in roughly 1/3 the GPU hours.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper identifies three key limitations:</p>
<ol>
<li><strong>Hypothetical-only evaluation:</strong> All experiments use the Boyd et al. hypothetical database. Evaluation on more challenging real-world datasets remains needed.</li>
<li><strong>Rigid-body assumption:</strong> The model assumes that local building block structures are known, which may be impractical for rare building blocks whose structural information is missing from existing libraries or is inaccurate.</li>
<li><strong>Periodic invariance:</strong> The model is not invariant to periodic transformations of the input. Explicitly modeling periodic invariance could further improve performance.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source:</strong> MOF dataset by Boyd et al. (2019).</li>
<li><strong>Preprocessing:</strong> Structures were decomposed using the metal-oxo decomposition algorithm from <strong>MOFid</strong>.</li>
<li><strong>Filtering:</strong> Structures with fewer than 200 building blocks were used, yielding 308,829 structures.</li>
<li><strong>Splits:</strong> Train/Validation/Test ratio of 8:1:1 (247,066 / 30,883 / 30,880).</li>
<li><strong>Availability:</strong> Pre-processed dataset is available on <a href="https://zenodo.org/records/15187230">Zenodo</a>.</li>
<li><strong>Representations:</strong>
<ul>
<li><em>Atom-level:</em> Tuple $(X, a, l)$ (coordinates, types, lattice).</li>
<li><em>Block-level:</em> Tuple $(\mathcal{B}, q, \tau, l)$ (blocks, rotations, translations, lattice).</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework:</strong> Riemannian Flow Matching.</li>
<li><strong>Objective:</strong> Conditional Flow Matching (CFM) loss regressing to clean data $q_1, \tau_1, l_1$.
$$
\begin{aligned}
\mathcal{L}(\theta) = \mathbb{E}_{t, \mathcal{S}^{(1)}} \left[ \frac{1}{(1-t)^2} \left( \lambda_1 |\log_{q_t}(\hat{q}_1) - \log_{q_t}(q_1)|^2 + \dots \right) \right]
\end{aligned}
$$</li>
<li><strong>Priors:</strong>
<ul>
<li>Rotations ($q$): Uniform on $SO(3)$.</li>
<li>Translations ($\tau$): Standard normal on $\mathbb{R}^3$.</li>
<li>Lattice ($l$): Log-normal for lengths, Uniform(60, 120) for angles (Niggli reduced).</li>
</ul>
</li>
<li><strong>Inference:</strong> ODE solver with <strong>50 integration steps</strong>.</li>
<li><strong>Local Coordinates:</strong> Defined using PCA axes, corrected for symmetry to ensure consistency.</li>
</ul>
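<p>The mean-free, PCA-based local coordinate construction can be sketched as follows. This is a simplified version (the paper&rsquo;s symmetry corrections are not reproduced) that resolves the SVD sign ambiguity with a third-moment convention so the frame is deterministic:</p>

```python
import numpy as np

def canonical_frame(coords):
    """Project a building block into rotation/translation-invariant local coordinates.

    Sketch: mean-free centering + PCA axes via SVD, with the SVD sign ambiguity
    resolved by forcing the third moment along each axis to be non-negative.
    """
    centered = coords - coords.mean(axis=0)                 # translation invariance
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    local = centered @ vt.T                                 # coordinates in the PCA frame
    for i in range(local.shape[1]):
        if (local[:, i] ** 3).sum() < 0:                    # deterministic sign convention
            local[:, i] *= -1
    return local

rng = np.random.default_rng(1)
block = rng.normal(size=(8, 3))                             # 8 atoms of a toy block

# a rotated + translated copy lands in (numerically) the same local frame
A = rng.normal(size=(3, 3))
Q, _ = np.linalg.qr(A)
if np.linalg.det(Q) < 0:
    Q[:, 0] *= -1                                           # make Q a proper rotation
moved = block @ Q.T + np.array([5.0, -2.0, 1.0])
print(np.abs(canonical_frame(block) - canonical_frame(moved)).max())
```

<p>Canonicalizing each block this way means the network only has to predict per-block roto-translations, not absolute atomic coordinates.</p>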
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture:</strong> Hierarchical structure with two key modules.
<ul>
<li><strong>Atom-level Update Layers:</strong> 4-layer EGNN-like structure to encode building block features $h_m$ from atomic graphs (cutoff 5Å).</li>
<li><strong>Block-level Update Layers:</strong> 6 layers that iteratively update $q, \tau, l$ using the <strong>MOFAttention</strong> module.</li>
</ul>
</li>
<li><strong>MOFAttention:</strong> Modified Invariant Point Attention (IPA) that incorporates lattice parameters as offsets to the attention matrix.</li>
<li><strong>Hyperparameters:</strong>
<ul>
<li>Node dimension: 256 (block-level), 64 (atom-level).</li>
<li>Attention heads: 24.</li>
<li>Loss coefficients: $\lambda_1=1.0$ (rot), $\lambda_2=2.0$ (trans), $\lambda_3=0.1$ (lattice).</li>
</ul>
</li>
<li><strong>Checkpoints:</strong> Pre-trained weights and models are openly provided on <a href="https://zenodo.org/records/15187230">Zenodo</a>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics:</strong>
<ul>
<li><strong>Match Rate:</strong> Using <code>StructureMatcher</code> from <code>pymatgen</code>. Tolerances: <code>stol=0.5/1.0</code>, <code>ltol=0.3</code>, <code>angle_tol=10.0</code>.</li>
<li><strong>RMSE:</strong> Normalized by average free length per atom.</li>
</ul>
</li>
<li><strong>Tools:</strong> <strong>Zeo++</strong> for structural property calculations (Surface Area, Pore Diameter, etc.).</li>
</ul>
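<p>The normalized RMSE can be sketched assuming the standard <code>pymatgen</code> convention, where the &ldquo;average free length per atom&rdquo; is $(V/n)^{1/3}$ for a unit cell of volume $V$ containing $n$ atoms (an assumption for illustration; the paper does not spell out the formula):</p>

```python
import numpy as np

def normalized_rmse(pred, ref, volume):
    """RMS atomic displacement divided by the average free length per atom (V/n)^(1/3).

    Assumes pred/ref are (n, 3) Cartesian coordinates already optimally matched;
    sketch of the pymatgen StructureMatcher convention, not the authors' exact code.
    """
    n = len(ref)
    rms = np.sqrt(((pred - ref) ** 2).sum(axis=1).mean())
    return rms / (volume / n) ** (1 / 3)

ref = np.zeros((4, 3))
pred = ref + np.array([0.1, 0.0, 0.0])  # every atom displaced by 0.1 Angstrom
# free length = (32/4)^(1/3) = 2 Angstrom, so the normalized value is about 0.05
print(normalized_rmse(pred, ref, volume=32.0))
```

<p>Dividing by the free length makes the metric comparable across cells of very different density, which matters for porous MOFs.</p>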
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">MOFFlow</th>
          <th style="text-align: left">DiffCSP</th>
          <th style="text-align: left">RS (20 cands)</th>
          <th style="text-align: left">EA (20 cands)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">MR (stol=0.5, k=1)</td>
          <td style="text-align: left"><strong>31.69%</strong></td>
          <td style="text-align: left">0.09%</td>
          <td style="text-align: left">0.00%</td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=1.0, k=1)</td>
          <td style="text-align: left"><strong>87.46%</strong></td>
          <td style="text-align: left">23.12%</td>
          <td style="text-align: left">0.00%</td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=0.5, k=5)</td>
          <td style="text-align: left"><strong>44.75%</strong></td>
          <td style="text-align: left">0.34%</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=1.0, k=5)</td>
          <td style="text-align: left"><strong>100.0%</strong></td>
          <td style="text-align: left">38.94%</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">RMSE (stol=0.5, k=1)</td>
          <td style="text-align: left"><strong>0.2820</strong></td>
          <td style="text-align: left">0.3961</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">Avg. time per structure</td>
          <td style="text-align: left"><strong>1.94s</strong></td>
          <td style="text-align: left">5.37s</td>
          <td style="text-align: left">332s</td>
          <td style="text-align: left">1,959s</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Hardware:</strong> 8 $\times$ NVIDIA RTX 3090 (24GB VRAM).</li>
<li><strong>Training Time:</strong>
<ul>
<li><em>TimestepBatch version (main paper):</em> ~5 days 15 hours.</li>
<li><em>Batch version:</em> ~1 day 17 hours (332.74 GPU hours). The authors also release this refactored implementation, which achieves comparable performance with faster convergence.</li>
</ul>
</li>
<li><strong>Batch Size:</strong> 160 (capped at $1.6 \times 10^6$ atoms squared for memory).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kim, N., Kim, S., Kim, M., Park, J., &amp; Ahn, S. (2025). MOFFlow: Flow Matching for Structure Prediction of Metal-Organic Frameworks. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kimMOFFlowFlowMatching2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MOFFlow: Flow Matching for Structure Prediction of Metal-Organic Frameworks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kim, Nayoung and Kim, Seongsu and Kim, Minsu and Park, Jinkyoo and Ahn, Sungsoo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Thirteenth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=dNT3abOsLo}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=dNT3abOsLo">OpenReview Discussion</a></li>
<li><a href="https://github.com/nayoung10/MOFFlow">Official Code Repository</a></li>
</ul>
]]></content:encoded></item><item><title>InvMSAFold: Generative Inverse Folding with Potts Models</title><link>https://hunterheidenreich.com/notes/interdisciplinary/computational-biology/invmsafold/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/interdisciplinary/computational-biology/invmsafold/</guid><description>A fast, diverse inverse folding method combining deep learning with Potts models to capture full sequence landscapes.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It introduces a novel architecture, <strong>InvMSAFold</strong>, which hybridizes deep-learning encoders with statistical-physics decoders (Potts models). The paper focuses on architectural innovation (low-rank parameter generation), benchmarks of speed and diversity against baselines (ESM-IF1), and algorithmic efficiency.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard inverse folding models (like ESM-IF1 or ProteinMPNN) solve a &ldquo;one-to-one&rdquo; mapping: given a structure, predict the <em>single</em> native sequence. However, in nature, folding is &ldquo;many-to-one&rdquo;: many homologous sequences fold into the same structure.</p>
<p>The authors identify two key gaps:</p>
<ol>
<li><strong>Lack of Diversity</strong>: Standard autoregressive models maximize probability for the ground truth sequence, often failing to capture the broad evolutionary landscape of viable homologs.</li>
<li><strong>Slow Inference</strong>: Autoregressive sampling requires a full neural network pass for <em>every amino acid</em>, making high-throughput screening (e.g., millions of candidates) computationally prohibitive.</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is shifting the learning objective from predicting <em>sequences</em> to predicting <em>probability distributions</em>.</p>
<p>InvMSAFold outputs the parameters (couplings $\mathbf{J}$ and fields $\mathbf{h}$) of a <strong>Potts Model</strong> (a pairwise Markov Random Field).</p>
<ul>
<li><strong>Low-Rank Decomposition</strong>: To handle the massive parameter space of pairwise couplings ($L \times L \times q \times q$), the model predicts a low-rank approximation $\mathbf{V}$ ($L \times K \times q$), reducing complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L)$.</li>
<li><strong>One-Shot Generation</strong>: The deep network runs only <em>once</em> to generate the Potts parameters. Sampling sequences from this Potts model is then performed on CPU via MCMC or autoregressive approximation, which is orders of magnitude faster than running a Transformer decoder for every step.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the model on three CATH-based test sets (Inter-cluster, Intra-cluster, MSA) to test generalization to unseen folds.</p>
<ul>
<li><strong>Speed Benchmarking</strong>: Compared wall-clock sampling time vs. ESM-IF1 on CPU/GPU.</li>
<li><strong>Covariance Reconstruction</strong>: Checked if generated sequences recover the evolutionary correlations found in natural MSAs (Pearson correlation of covariance matrices).</li>
<li><strong>Structural Fidelity</strong>: Generated sequences with high Hamming distance from native, folded them with AlphaFold 2 (no templates), and measured RMSD to the target structure.</li>
<li><strong>Property Profiling</strong>: Analyzed the distribution of predicted solubility (Protein-Sol) and thermostability (Thermoprot) to demonstrate diversity.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Massive Speedup</strong>: InvMSAFold is orders of magnitude faster than ESM-IF1 (CPU vs. GPU; the comparison is not hardware-matched). Because the &ldquo;heavy lifting&rdquo; (generating Potts parameters) happens once, sampling millions of sequences becomes trivial on CPUs.</li>
<li><strong>Better Diversity</strong>: The model captures evolutionary covariances significantly better than ESM-IF1 and ProteinMPNN (which shares similar covariance recovery to ESM-IF1). A PCA-based KL-divergence analysis (lower is better; 0 means a perfect match to the natural MSA distribution) shows InvMSAFold-AR scores of $0.49$ (Inter-cluster) and $0.67$ (Intra-cluster), compared to $15.8$ and $11.9$ for ESM-IF1, demonstrating that the generated sequences occupy a distribution much closer to natural MSAs.</li>
<li><strong>Robust Folding</strong>: Sequences generated far from the native sequence (high Hamming distance) still fold into the correct structure (low RMSD), whereas ESM-IF1 struggles to produce diverse valid sequences.</li>
<li><strong>Property Expansion</strong>: The method generates a wider spread of biochemical properties (solubility/thermostability), making it more suitable for screening libraries in protein design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Source</strong>: CATH v4.2 database (40% non-redundant dataset).</p>
<p><strong>Splits</strong>:</p>
<ul>
<li><strong>Training</strong>: ~22k domains.</li>
<li><strong>Inter-cluster Test</strong>: 10% of clusters held out (unseen folds).</li>
<li><strong>Intra-cluster Test</strong>: Unseen domains from seen clusters.</li>
<li><strong>Augmentation</strong>: MSAs generated using <strong>MMseqs2</strong> against UniRef50. Training uses random subsamples of these MSAs ($|M_X| = 64$) to teach the model evolutionary variance.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Architecture</strong>:</p>
<ul>
<li><strong>Encoder</strong>: Pre-trained <strong>ESM-IF1</strong> encoder. The paper does not explicitly state whether the encoder is frozen or jointly fine-tuned during training; this is an unresolved reproducibility detail.</li>
<li><strong>Decoder</strong>: 6-layer Transformer (8 heads) that outputs a latent tensor.</li>
<li><strong>Projection</strong>: Linear layers project latent tensor to fields $\mathbf{h}$ ($L \times q$) and low-rank tensor $\mathbf{V}$ ($L \times K \times q$).</li>
</ul>
<p><strong>Coupling Construction</strong>:
The full coupling tensor $\mathcal{J}$ is approximated via:
$$\mathcal{J}_{i,a,j,b} = \frac{1}{\sqrt{K}} \sum_{k=1}^{K} \mathcal{V}_{i,k,a} \mathcal{V}_{j,k,b}$$
Rank $K=48$ was used.</p>
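<p>The low-rank coupling construction maps directly to a single <code>einsum</code>. A numpy sketch with illustrative dimensions (a short toy length rather than a real protein):</p>

```python
import numpy as np

L, K, q = 10, 48, 21   # sequence length, rank, amino-acid alphabet size
rng = np.random.default_rng(0)
V = rng.normal(size=(L, K, q))

# J[i, a, j, b] = (1 / sqrt(K)) * sum_k V[i, k, a] * V[j, k, b]
J = np.einsum("ika,jkb->iajb", V, V) / np.sqrt(K)
print(J.shape)
```

<p>The construction guarantees the symmetry $\mathcal{J}_{i,a,j,b} = \mathcal{J}_{j,b,i,a}$ required of Potts couplings, while the network only ever materializes the $\mathcal{O}(L)$ tensor $\mathbf{V}$.</p>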
<p><strong>Loss Functions</strong>:
Two variants were trained:</p>
<ol>
<li><strong>InvMSAFold-PW</strong>: Trained via <strong>Pseudo-Likelihood (PL)</strong>. Computation is optimized to $\mathcal{O}(L)$ time using the low-rank property.</li>
<li><strong>InvMSAFold-AR</strong>: Trained via <strong>Autoregressive Likelihood</strong>. Couplings are masked ($J_{ij} = 0$ if $i&lt;j$) to allow exact likelihood computation and direct sampling without MCMC.</li>
</ol>
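<p>The autoregressive masking is what enables direct sampling: position $i$ conditions only on already-sampled positions, so no MCMC is required. A minimal numpy sketch (for brevity the demo uses zero couplings, i.e., a fields-only model; a trained model would supply nonzero masked couplings):</p>

```python
import numpy as np

def sample_ar_potts(h, J, rng):
    """Sample one sequence from an autoregressive Potts model.

    h: (L, q) fields; J: (L, q, L, q) couplings with J[i, :, j, :] = 0 for j >= i,
    so each position conditions only on positions sampled before it.
    """
    L, q = h.shape
    seq = np.empty(L, dtype=int)
    for i in range(L):
        logits = h[i].copy()
        for j in range(i):                 # only previous positions contribute
            logits += J[i, :, j, seq[j]]
        p = np.exp(logits - logits.max())  # softmax over the q states
        p /= p.sum()
        seq[i] = rng.choice(q, p=p)
    return seq

rng = np.random.default_rng(0)
L, q = 6, 21
h = rng.normal(size=(L, q))
J = np.zeros((L, q, L, q))                 # sketch: no couplings for the demo
seq = sample_ar_potts(h, J, rng)
print(seq)
```

<p>Each full sequence costs only $\mathcal{O}(L^2 q)$ cheap CPU operations once the Potts parameters exist, which is why sampling millions of candidates is feasible without a GPU.</p>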
<h3 id="models">Models</h3>
<ul>
<li><strong>InvMSAFold-PW</strong>: Requires MCMC sampling (Metropolis-Hastings) at inference.</li>
<li><strong>InvMSAFold-AR</strong>: Allows direct, fast autoregressive sampling.</li>
<li><strong>Hyperparameters</strong>: AdamW optimizer, lr=$10^{-4}$ (PW) / $3.4 \times 10^{-4}$ (AR), 94 epochs. L2 regularization on fields/couplings.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>scRMSD</strong>: Structure fidelity (AlphaFold2 prediction vs. Native).</li>
<li><strong>Covariance Pearson Correlation</strong>: Measures recovery of evolutionary pairwise statistics.</li>
<li><strong>Sampling Speed</strong>: Wall-clock time vs. sequence length/batch size.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: Unspecified; described as &ldquo;fast&rdquo; due to low rank.</li>
<li><strong>Inference</strong>:
<ul>
<li><strong>ESM-IF1</strong>: NVIDIA GeForce RTX 4060 Laptop (8GB).</li>
<li><strong>InvMSAFold</strong>: Single core of Intel i9-13905H CPU.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Silva, L. A., Meynard-Piganeau, B., Lucibello, C., &amp; Feinauer, C. (2025). Fast Uncovering of Protein Sequence Diversity from Structure. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://arxiv.org/abs/2406.11975">https://arxiv.org/abs/2406.11975</a></p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{silvaFastUncoveringProtein2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Fast {{Uncovering}} of {{Protein Sequence Diversity}} from {{Structure}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Thirteenth International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Silva, Luca Alessandro and {Meynard-Piganeau}, Barthelemy and Lucibello, Carlo and Feinauer, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=1iuaxjssVp}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=1iuaxjssVp">OpenReview Page</a></li>
<li><a href="https://github.com/luchinoprince/Potts_Inverse_Folding">GitHub Repository</a></li>
</ul>
]]></content:encoded></item><item><title>InstructMol: Multi-Modal Molecular LLM for Drug Discovery</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/instructmol/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/instructmol/</guid><description>A multi-modal LLM aligning 2D molecular graphs with text via two-stage instruction tuning for drug discovery tasks.</description><content:encoded><![CDATA[<h2 id="instructmol-framework-overview">InstructMol Framework Overview</h2>
<p><strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong></p>
<p>This work proposes <strong>InstructMol</strong>, a novel multi-modal architecture and training paradigm. It focuses on engineering a system that aligns a pre-trained molecular graph encoder with a general-purpose Large Language Model (LLM). The paper&rsquo;s primary contribution is the <strong>Two-Stage Instruction Tuning</strong> strategy (Alignment Pre-training + Task-Specific Tuning) designed to bridge the modality gap between 2D molecular graphs and natural language.</p>
<h2 id="bridging-specialist-and-generalist-models">Bridging Specialist and Generalist Models</h2>
<p>Current AI approaches in drug discovery typically fall into two categories. Specialist models deliver high accuracy on specific tasks (such as property prediction) but require extensive labeled datasets and lack conversational adaptability. Conversely, generalist LLMs offer strong reasoning and dialogue capabilities but struggle to natively interpret complex structural data, often relying on brittle 1D text representations of molecules like SMILES.</p>
<p>There is a practical need for a unified &ldquo;Molecular Assistant&rdquo; capable of visually interpreting molecular graphs, reasoning about structure in natural language, and adapting across tasks like synthesis planning and property analysis without training from scratch.</p>
<h2 id="two-stage-modality-alignment">Two-Stage Modality Alignment</h2>
<p>The core novelty lies in the architecture and the <strong>two-stage training pipeline</strong> designed to align differing modalities efficiently:</p>
<ol>
<li><strong>MoleculeSTM Integration</strong>: InstructMol initializes its graph encoder with <strong>MoleculeSTM</strong>, which is already pre-aligned with text via contrastive learning, facilitating easier downstream alignment.</li>
<li><strong>Two-Stage Alignment Strategy</strong>:
<ul>
<li><strong>Stage 1 (Alignment Pre-training)</strong>: Freezes both the LLM and Graph Encoder; trains <em>only</em> a linear projector using a massive dataset of molecule-description pairs to map graph features into the LLM&rsquo;s token space.</li>
<li><strong>Stage 2 (Task-Specific Instruction Tuning)</strong>: Freezes the Graph Encoder; fine-tunes the Projector and the LLM (using <strong>LoRA</strong>) on specific downstream tasks. This allows the model to adapt its reasoning capabilities while preserving the structural understanding gained in Stage 1.</li>
</ul>
</li>
</ol>
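<p>The freezing schedule across the two stages can be summarized in a small sketch (the component names are hypothetical labels for illustration, not InstructMol&rsquo;s actual module names):</p>

```python
# Which parameter groups receive gradients in each training stage.
components = ["graph_encoder", "projector", "llm_lora", "llm_base"]

def trainable(stage):
    if stage == 1:   # alignment pre-training: only the linear projector learns
        return {"projector"}
    if stage == 2:   # task tuning: projector + LoRA adapters; encoder and base LLM frozen
        return {"projector", "llm_lora"}
    raise ValueError(f"unknown stage: {stage}")

for stage in (1, 2):
    frozen = [c for c in components if c not in trainable(stage)]
    print(f"stage {stage}: train={sorted(trainable(stage))} frozen={frozen}")
```

<p>Stage 1 fits a cheap map from graph features into token space against fully frozen backbones; Stage 2 then unlocks only the low-rank LoRA adapters, preserving the structural understanding gained earlier.</p>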
<h2 id="task-evaluation-in-drug-discovery">Task Evaluation in Drug Discovery</h2>
<p>The authors evaluated InstructMol across three distinct categories of drug discovery tasks, comparing it against generalist LLMs (Vicuna, LLaMA, Galactica) and specialist models (ChemBERTa, MolT5):</p>
<ol>
<li><strong>Property Prediction</strong>:
<ul>
<li><em>Regression</em>: Predicting quantum mechanical properties (HOMO, LUMO, Gap) using the QM9 dataset.</li>
<li><em>Classification</em>: Predicting biological activity (BACE, BBBP, HIV) using MoleculeNet.</li>
</ul>
</li>
<li><strong>Molecule Description Generation</strong>: Generating natural language descriptions of molecules using the ChEBI-20 dataset.</li>
<li><strong>Chemical Reaction Analysis</strong>:
<ul>
<li><em>Forward Reaction Prediction</em>: Predicting products from reactants.</li>
<li><em>Reagent Prediction</em>: Identifying necessary reagents.</li>
<li><em>Retrosynthesis</em>: Suggesting reactants for a given product.</li>
</ul>
</li>
</ol>
<p><strong>Ablation Studies</strong> tested the impact of the projector type (Linear vs. MLP), LLM scale (7B vs. 13B), and the necessity of the two-stage training approach.</p>
<h2 id="core-findings-and-limitations">Core Findings and Limitations</h2>
<ul>
<li><strong>Improvement Over Baseline Generalists</strong>: InstructMol significantly outperformed generalist LLMs (like LLaMA and Galactica) on all tasks, demonstrating the value of incorporating explicit graph modalities.</li>
<li><strong>Reducing the Gap with Specialists</strong>: While InstructMol brings versatile reasoning capabilities, it still trails highly optimized specialist models (such as Uni-Mol and MolT5) on tasks like molecule description generation. This remaining gap likely stems from its reliance on a relatively small alignment pre-training dataset (~264K PubChem pairs) and the information bottleneck of using a simple linear projector, compared to the millions of structures used to train expert foundational models.</li>
<li><strong>Importance of Alignment</strong>: Ablation studies confirmed that skipping Stage 1 (Alignment Pre-training) degraded performance, proving that a dedicated phase for projecting graph features into text space is crucial.</li>
<li><strong>Limitation</strong>: The model struggles with highly imbalanced datasets (e.g., HIV) and complex reaction mixtures where mapping multiple graph tokens to text becomes ambiguous.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training pipeline uses distinct datasets for the two stages. <strong>Note:</strong> As of the latest repository update, the fully processed instruction-tuning datasets (e.g., the filtered ~264K PubChem pairs and the instruction-formatted task subsets) are listed as &ldquo;coming soon&rdquo;, requiring manual recreation for full reproduction.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Stage 1</strong> (Alignment)</td>
          <td style="text-align: left"><strong>PubChem</strong></td>
          <td style="text-align: left">~264K pairs</td>
          <td style="text-align: left">Molecule-text pairs. Filtered from 330K for invalid descriptions and overlaps with ChEBI-20 test set.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Reg.)</td>
          <td style="text-align: left"><strong>QM9</strong></td>
          <td style="text-align: left">362K samples</td>
          <td style="text-align: left">Quantum mechanics properties (HOMO, LUMO, Gap).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Class.)</td>
          <td style="text-align: left"><strong>MoleculeNet</strong></td>
          <td style="text-align: left">35K samples</td>
          <td style="text-align: left">BACE, BBBP, HIV datasets. Converted to instruction format (Yes/No answer).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Generation)</td>
          <td style="text-align: left"><strong>ChEBI-20</strong></td>
          <td style="text-align: left">26.5K samples</td>
          <td style="text-align: left">Molecule description generation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Reactions)</td>
          <td style="text-align: left"><strong>USPTO</strong></td>
          <td style="text-align: left">~380K samples</td>
          <td style="text-align: left">Combined datasets for Forward (125K), Retrosynthesis (130K), and Reagent (125K) prediction.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Two-Stage Training</strong>:
<ol>
<li><strong>Alignment Pre-training</strong>: Updates only the Projector. The objective maximizes the probability of generating the target description token sequence $\mathbf{X}_A$ given the molecule input $\mathbf{X}_M$ and instruction $\mathbf{X}_I$:
$$p(\mathbf{X}_A | \mathbf{X}_M, \mathbf{X}_I) = \prod_{i=1}^L p_\theta(x_i | \mathbf{X}_G \parallel \mathbf{X}_S, \mathbf{X}_I, \mathbf{X}_{A,&lt;i})$$</li>
<li><strong>Instruction Tuning</strong>: Updates Projector + LLM (via LoRA) using standard autoregressive language modeling on task-specific instructions. The objective minimizes the negative log-likelihood of generating the target response $R$ of length $L$:
$$\mathcal{L}(\theta) = -\sum_{i=1}^L \log p(R_i | I, M, R_{&lt;i}; \theta)$$
where $I$ represents the instruction and $M$ is the multi-modal molecular input.</li>
</ol>
</li>
<li><strong>LoRA (Low-Rank Adaptation)</strong>: Applied to the LLM in Stage 2. Rank $r=64$, Scaling $\alpha=16$.</li>
<li><strong>Optimization</strong>: AdamW optimizer. Learning rate starts at 2e-3 (Stage 1) and 8e-5 (Stage 2) with cosine decay. Warm-up ratio 0.03.</li>
</ul>
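<p>The instruction-tuning objective above reduces to a token-level negative log-likelihood. A minimal stand-in, assuming the model has already produced the per-token probabilities $p(R_i | I, M, R_{&lt;i})$:</p>

```python
import math

def autoregressive_nll(token_probs: list[float]) -> float:
    """Negative log-likelihood of a target response, given the model's
    probability p(R_i | I, M, R_<i) for each target token R_i."""
    return -sum(math.log(p) for p in token_probs)

# A perfectly confident model incurs zero loss; each uncertain token
# adds a -log p penalty.
```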
<h3 id="models">Models</h3>
<p><strong>Note:</strong> The official repository currently lists the final fine-tuned <strong>InstructMol weights</strong> as &ldquo;coming soon.&rdquo; Consequently, one must fine-tune the components using the provided scripts. Base model weights (Vicuna-7B and MoleculeSTM) are publicly available via Hugging Face.</p>
<ul>
<li><strong>Graph Encoder ($f_g$)</strong>:
<ul>
<li>Architecture: Graph Isomorphism Network (GIN) with 5 layers.</li>
<li>Hidden Dimension: 300.</li>
<li>Initialization: <strong>MoleculeSTM</strong> checkpoint (pre-trained via contrastive learning).</li>
<li>Status: <strong>Frozen</strong> during Stage 2.</li>
</ul>
</li>
<li><strong>LLM</strong>:
<ul>
<li>Base: <strong>Vicuna-v1.3-7B</strong>.</li>
<li>Status: Frozen in Stage 1; LoRA fine-tuned in Stage 2.</li>
</ul>
</li>
<li><strong>Projector</strong>:
<ul>
<li>Architecture: Linear Layer.</li>
<li>Function: Maps the node-level graph representation $Z_G \in \mathbb{R}^{N \times d}$ into the LLM&rsquo;s word embedding space.</li>
</ul>
</li>
</ul>
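<p>The projector&rsquo;s job is a single shape change: $N$ node embeddings of width $d = 300$ become $N$ vectors of the LLM&rsquo;s embedding width. A pure-Python stand-in for that one linear map (bias omitted; the concrete LLM width is not restated here, so the test dimensions are arbitrary):</p>

```python
def project_graph_tokens(Z_G, W):
    """Map node-level graph features Z_G (N x d) into the LLM's word
    embedding space (N x d_llm) via one weight matrix W (d x d_llm).
    Pure-Python stand-in for a bias-free linear layer."""
    return [[sum(z * W[k][j] for k, z in enumerate(row))
             for j in range(len(W[0]))]
            for row in Z_G]
```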
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric Libraries</strong>: RDKit for validity/fingerprints, standard NLP libraries for BLEU/ROUGE.</li>
<li><strong>Reaction Metrics</strong>: Fingerprint Tanimoto Similarity (FTS), Exact Match, Levenshtein distance, and validity (via RDKit).</li>
<li><strong>Description Metrics</strong>: BLEU-2, BLEU-4, ROUGE-1, ROUGE-2, ROUGE-L, METEOR.</li>
</ul>
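<p>Of these metrics, FTS is the most chemistry-specific: it is Tanimoto similarity over fingerprint bits. In practice the bits come from RDKit fingerprints; plain Python sets stand in here:</p>

```python
def tanimoto(fp_a: set[int], fp_b: set[int]) -> float:
    """Tanimoto similarity |A & B| / |A | B| between two fingerprint
    bit sets, the basis of the Fingerprint Tanimoto Similarity (FTS)
    metric used for the reaction tasks."""
    if not fp_a and not fp_b:
        return 1.0  # convention: two empty fingerprints are identical
    return len(fp_a & fp_b) / len(fp_a | fp_b)
```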
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA RTX A6000 (48GB VRAM).</li>
<li><strong>Training Time</strong>:
<ul>
<li>Stage 1: 5 epochs.</li>
<li>Stage 2: 20-50 epochs (Description Generation), 10 epochs (Properties/Reactions).</li>
</ul>
</li>
<li><strong>Batch Size</strong>: 128 for both stages.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/IDEA-XL/InstructMol">InstructMol (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache 2.0 (code), CC BY-NC 4.0 (data)</td>
          <td style="text-align: left">Training/evaluation scripts provided; fine-tuned weights listed as &ldquo;coming soon&rdquo;</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/lmsys/vicuna-7b-v1.3">Vicuna-7B v1.3</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Non-commercial (LLaMA license)</td>
          <td style="text-align: left">Base LLM; must be downloaded separately</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/chao1224/MoleculeSTM">MoleculeSTM</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Pre-trained graph encoder checkpoint</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cao, H., Liu, Z., Lu, X., Yao, Y., &amp; Li, Y. (2025). InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery. <em>Proceedings of the 31st International Conference on Computational Linguistics</em>, 354-379.</p>
<p><strong>Publication</strong>: COLING 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{caoInstructMolMultiModalIntegration2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{InstructMol}}: {{Multi-Modal Integration}} for {{Building}} a {{Versatile}} and {{Reliable Molecular Assistant}} in {{Drug Discovery}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{InstructMol}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 31st {{International Conference}} on {{Computational Linguistics}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Cao, He and Liu, Zijing and Lu, Xingyu and Yao, Yuan and Li, Yu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{Rambow, Owen and Wanner, Leo and Apidianaki, Marianna and {Al-Khalifa}, Hend and Eugenio, Barbara Di and Schockaert, Steven}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{354--379}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Abu Dhabi, UAE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{The rapid evolution of artificial intelligence in drug discovery encounters challenges with generalization and extensive training, yet Large Language Models (LLMs) offer promise in reshaping interactions with complex molecular data. Our novel contribution, InstructMol, a multi-modal LLM, effectively aligns molecular structures with natural language via an instruction-tuning approach, utilizing a two-stage training strategy that adeptly combines limited domain-specific data with molecular and textual information. InstructMol showcases substantial performance improvements in drug discovery-related molecular tasks, surpassing leading LLMs and significantly reducing the gap with specialists, thereby establishing a robust foundation for a versatile and dependable drug discovery assistant.}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IDEA-XL/InstructMol">Official Repository</a></li>
</ul>
]]></content:encoded></item><item><title>Image-to-Sequence OCSR: A Comparative Analysis</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image-to-sequence-comparison/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image-to-sequence-comparison/</guid><description>Comparative analysis of image-to-sequence OCSR methods across architecture, output format, training data, and compute requirements.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>This note provides a comparative analysis of image-to-sequence methods for Optical Chemical Structure Recognition (OCSR). These methods treat molecular structure recognition as an image captioning task, using encoder-decoder architectures to generate sequential molecular representations (<a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>, <a href="/notes/computational-chemistry/molecular-representations/selfies/">SELFIES</a>, <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a>) directly from pixels.</p>
<p>For the full taxonomy of OCSR approaches including image-to-graph and rule-based methods, see the <a href="/notes/computational-chemistry/optical-structure-recognition/ocsr-methods/">OCSR Methods taxonomy</a>.</p>
<h2 id="architectural-evolution-2019-2025">Architectural Evolution (2019-2025)</h2>
<p>The field has undergone rapid architectural evolution, with clear generational shifts in both encoder and decoder design.</p>
<h3 id="timeline">Timeline</h3>
<table>
  <thead>
      <tr>
          <th>Era</th>
          <th>Encoder</th>
          <th>Decoder</th>
          <th>Representative Methods</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2019-2020</strong></td>
          <td>CNN (Inception V3, ResNet)</td>
          <td>LSTM/GRU with Attention</td>
          <td><a href="/notes/computational-chemistry/optical-structure-recognition/staker-deep-learning-2019/">Staker et al.</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/decimer/">DECIMER</a></td>
      </tr>
      <tr>
          <td><strong>2021</strong></td>
          <td>EfficientNet, ViT</td>
          <td>Transformer</td>
          <td><a href="/notes/computational-chemistry/optical-structure-recognition/decimer-1.0/">DECIMER 1.0</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/img2mol/">Img2Mol</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/vit-inchi-transformer/">ViT-InChI</a></td>
      </tr>
      <tr>
          <td><strong>2022</strong></td>
          <td>Swin Transformer, ResNet</td>
          <td>Transformer</td>
          <td><a href="/notes/computational-chemistry/optical-structure-recognition/swinocsr/">SwinOCSR</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/image2smiles/">Image2SMILES</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/micer/">MICER</a></td>
      </tr>
      <tr>
          <td><strong>2023-2024</strong></td>
          <td>EfficientNetV2, SwinV2</td>
          <td>Transformer</td>
          <td><a href="/notes/computational-chemistry/optical-structure-recognition/decimer-ai/">DECIMER.ai</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/image2inchi/">Image2InChI</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/mmssc-net/">MMSSC-Net</a></td>
      </tr>
      <tr>
          <td><strong>2025</strong></td>
          <td>EfficientViT, VLMs (Qwen2-VL)</td>
          <td>LLM decoders, RL fine-tuning</td>
          <td><a href="/notes/computational-chemistry/optical-structure-recognition/molsight/">MolSight</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/gtr-mol-vlm/">GTR-CoT</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/ocsu/">OCSU</a></td>
      </tr>
  </tbody>
</table>
<h3 id="encoder-architectures">Encoder Architectures</h3>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Methods Using It</th>
          <th>Key Characteristics</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Inception V3</strong></td>
          <td>DECIMER (2020)</td>
          <td>Early CNN approach, 299x299 input</td>
      </tr>
      <tr>
          <td><strong>ResNet-50/101</strong></td>
          <td>IMG2SMI, Image2SMILES, MICER, DGAT</td>
          <td>Strong baseline, well-understood</td>
      </tr>
      <tr>
          <td><strong>EfficientNet-B3</strong></td>
          <td>DECIMER 1.0</td>
          <td>Efficient scaling, compound coefficients</td>
      </tr>
      <tr>
          <td><strong>EfficientNet-V2-M</strong></td>
          <td>DECIMER.ai, DECIMER-Hand-Drawn</td>
          <td>Improved training efficiency</td>
      </tr>
      <tr>
          <td><strong>EfficientViT-L1</strong></td>
          <td>MolSight</td>
          <td>Optimized for deployment</td>
      </tr>
      <tr>
          <td><strong>Swin Transformer</strong></td>
          <td>SwinOCSR, MolParser</td>
          <td>Hierarchical vision transformer</td>
      </tr>
      <tr>
          <td><strong>SwinV2</strong></td>
          <td>MMSSC-Net, Image2InChI</td>
          <td>Improved training stability</td>
      </tr>
      <tr>
          <td><strong>Vision Transformer (ViT)</strong></td>
          <td>ViT-InChI</td>
          <td>Pure attention encoder</td>
      </tr>
      <tr>
          <td><strong>DenseNet</strong></td>
          <td>RFL, Hu et al. RCGD</td>
          <td>Dense connections, feature reuse</td>
      </tr>
      <tr>
          <td><strong>Deep TNT</strong></td>
          <td>ICMDT</td>
          <td>Transformer-in-Transformer</td>
      </tr>
      <tr>
          <td><strong>Qwen2-VL</strong></td>
          <td>OCSU, GTR-CoT</td>
          <td>Vision-language model encoder</td>
      </tr>
  </tbody>
</table>
<h3 id="decoder-architectures">Decoder Architectures</h3>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Methods Using It</th>
          <th>Output Format</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>GRU with Attention</strong></td>
          <td>DECIMER, RFL, Hu et al. RCGD</td>
          <td>SMILES, RFL, SSML</td>
      </tr>
      <tr>
          <td><strong>LSTM with Attention</strong></td>
          <td>Staker et al., ChemPix, MICER</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><strong>Transformer</strong></td>
          <td>Most 2021+ methods</td>
          <td>SMILES, SELFIES, InChI</td>
      </tr>
      <tr>
          <td><strong>GPT-2</strong></td>
          <td>MMSSC-Net</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><strong>BART</strong></td>
          <td>MolParser</td>
          <td>E-SMILES</td>
      </tr>
      <tr>
          <td><strong>Pre-trained CDDD</strong></td>
          <td>Img2Mol</td>
          <td>Continuous embedding → SMILES</td>
      </tr>
  </tbody>
</table>
<h2 id="output-representation-comparison">Output Representation Comparison</h2>
<p>The choice of molecular string representation significantly impacts model performance. Representations fall into three categories: core molecular formats for single structures, extended formats for molecular families and variable structures (primarily Markush structures in patents), and specialized representations optimizing for specific recognition challenges.</p>
<p>The <a href="/notes/computational-chemistry/optical-structure-recognition/rajan-string-representations-2022/">Rajan et al. 2022 ablation study</a> provides a comparison of core formats.</p>
<h3 id="core-molecular-formats">Core Molecular Formats</h3>
<p>These represent specific, concrete molecular structures.</p>
<table>
  <thead>
      <tr>
          <th>Format</th>
          <th>Validity Guarantee</th>
          <th>Sequence Length</th>
          <th>Key Characteristic</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>SMILES</strong></td>
          <td>No</td>
          <td>Shortest (baseline)</td>
          <td>Standard, highest accuracy</td>
          <td>DECIMER.ai, MolSight, DGAT, most 2023+</td>
      </tr>
      <tr>
          <td><strong>DeepSMILES</strong></td>
          <td>Partial</td>
          <td>~1.1x SMILES</td>
          <td>Reduces non-local dependencies</td>
          <td>SwinOCSR</td>
      </tr>
      <tr>
          <td><strong>SELFIES</strong></td>
          <td>Yes (100%)</td>
          <td>~1.5x SMILES</td>
          <td>Guaranteed valid molecules</td>
          <td>DECIMER 1.0, IMG2SMI</td>
      </tr>
      <tr>
          <td><strong>InChI</strong></td>
          <td>N/A (canonical)</td>
          <td>Variable (long)</td>
          <td>Unique identifiers, layered syntax</td>
          <td>ViT-InChI, ICMDT, Image2InChI</td>
      </tr>
      <tr>
          <td><strong>FG-SMILES</strong></td>
          <td>No</td>
          <td>Similar to SMILES</td>
          <td>Functional group-aware tokenization</td>
          <td>Image2SMILES</td>
      </tr>
  </tbody>
</table>
<h4 id="smiles-and-variants">SMILES and Variants</h4>
<p><strong>SMILES</strong> remains the dominant format due to its compactness and highest accuracy on clean data. Standard SMILES uses single characters for ring closures and branches that may appear far apart in the sequence, creating learning challenges for sequence models.</p>
<p><strong>DeepSMILES</strong> addresses these non-local syntax dependencies by modifying how branches and ring closures are encoded, making sequences more learnable for neural models. Despite the simpler syntax, DeepSMILES sequences are still ~1.1x longer than standard SMILES. The format offers partial validity improvements through regex-based tokenization with a compact 76-token vocabulary, providing a middle ground between SMILES accuracy and guaranteed validity.</p>
<p><strong>SELFIES</strong> guarantees 100% valid molecules by design through a context-free grammar, eliminating invalid outputs entirely. This comes at the cost of ~1.5x longer sequences and a typical 2-5% accuracy drop compared to SMILES on exact-match metrics. The validity guarantee makes SELFIES particularly attractive for generative modeling applications.</p>
<p><strong>InChI</strong> uses a layered canonical syntax fundamentally different from SMILES-based formats. While valuable for unique molecular identification, its complex multi-layer structure (formula, connectivity, stereochemistry, isotopes, etc.) and longer sequences make it less suitable for image-to-sequence learning, resulting in lower recognition accuracy.</p>
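<p>The non-local dependency that motivates DeepSMILES and SELFIES is easiest to see in ring closures: a digit opened early in a SMILES string must be matched much later. A toy check (real validity requires a parser like RDKit; this only counts closure digits and ignores two-digit <code>%nn</code> closures):</p>

```python
from collections import Counter

def ring_closures_paired(smiles: str) -> bool:
    """Toy illustration of SMILES's non-local syntax: each ring-closure
    digit must occur an even number of times (open + close). Ignores
    %nn closures and everything else a real parser would handle."""
    counts = Counter(ch for ch in smiles if ch.isdigit())
    return all(n % 2 == 0 for n in counts.values())
```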
<h4 id="key-findings-from-rajan-et-al-2022">Key Findings from Rajan et al. 2022</h4>
<ol>
<li><strong>SMILES achieves highest exact-match accuracy</strong> on clean synthetic data</li>
<li><strong>SELFIES guarantees 100% valid molecules</strong> but at cost of ~2-5% accuracy drop</li>
<li><strong>InChI is problematic</strong> due to complex layered syntax and longer sequences</li>
<li><strong>DeepSMILES offers middle ground</strong> with partial validity improvements through modified syntax</li>
</ol>
<h3 id="extended-formats-for-variable-structures">Extended Formats for Variable Structures</h3>
<p><strong>Markush structures</strong> represent families of molecules, using variable groups (R1, R2, etc.) with textual definitions. They are ubiquitous in patent documents for intellectual property protection. Standard SMILES cannot represent these variable structures.</p>
<table>
  <thead>
      <tr>
          <th>Format</th>
          <th>Base Format</th>
          <th>Key Feature</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>E-SMILES</strong></td>
          <td>SMILES + XML annotations</td>
          <td>Backward-compatible with separator token</td>
          <td>MolParser</td>
      </tr>
      <tr>
          <td><strong>CXSMILES</strong></td>
          <td>SMILES + extension block</td>
          <td>Substituent tables, compression</td>
          <td>MarkushGrapher</td>
      </tr>
  </tbody>
</table>
<p><strong>E-SMILES</strong> (Extended SMILES) maintains backward compatibility by using a <code>&lt;sep&gt;</code> token to separate core SMILES from XML-like annotations. Annotations encode Markush substituents (<code>&lt;a&gt;index:group&lt;/a&gt;</code>), polymer structures (<code>&lt;p&gt;polymer_info&lt;/p&gt;</code>), and abstract ring patterns (<code>&lt;r&gt;abstract_ring&lt;/r&gt;</code>). The core structure remains parseable by standard RDKit.</p>
<p><strong>CXSMILES</strong> optimizes representation by moving variable groups directly into the main SMILES string as special atoms with explicit atom indexing (e.g., <code>C:1</code>) to link to an extension block containing substituent tables. This handles both frequency variation and position variation in Markush structures.</p>
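<p>Because E-SMILES keeps the parseable core before a separator token, recovering both halves is trivial. A sketch assuming the literal token spelling <code>&lt;sep&gt;</code> from the description above:</p>

```python
def split_esmiles(esmiles: str) -> tuple[str, str]:
    """Split an E-SMILES string into its RDKit-parseable SMILES core and
    the XML-like annotation tail. Assumes the separator is the literal
    token "<sep>", as described for MolParser's format."""
    core, sep, annotations = esmiles.partition("<sep>")
    return core, annotations if sep else ""
```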
<h3 id="specialized-representations">Specialized Representations</h3>
<p>These formats optimize for specific recognition challenges beyond standard single-molecule tasks.</p>
<h4 id="rfl-ring-free-language">RFL: Ring-Free Language</h4>
<p><strong>RFL</strong> fundamentally restructures molecular serialization through hierarchical ring decomposition, addressing a core challenge: standard 1D formats (SMILES, SSML) flatten complex 2D molecular graphs, losing explicit spatial relationships.</p>
<p><strong>Mechanism</strong>: RFL decomposes molecules into three explicit components:</p>
<ul>
<li><strong>Molecular Skeleton (𝒮)</strong>: Main graph with rings &ldquo;collapsed&rdquo;</li>
<li><strong>Ring Structures (ℛ)</strong>: Individual ring components stored separately</li>
<li><strong>Branch Information (ℱ)</strong>: Connectivity between skeleton and rings</li>
</ul>
<p><strong>Technical approach</strong>:</p>
<ol>
<li>Detect all non-nested rings using DFS</li>
<li>Calculate adjacency ($\gamma$) between rings based on shared edges</li>
<li>Merge isolated rings ($\gamma=0$) into <strong>SuperAtoms</strong> (single node placeholders)</li>
<li>Merge adjacent rings ($\gamma&gt;0$) into <strong>SuperBonds</strong> (edge placeholders)</li>
<li>Progressive decoding: predict skeleton first, then conditionally decode rings using stored hidden states</li>
</ol>
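<p>Step 2&rsquo;s adjacency measure $\gamma$ is the count of edges two rings share. A minimal sketch with rings given as ordered atom-index cycles (an illustrative simplification of the paper&rsquo;s graph machinery):</p>

```python
def ring_adjacency(ring_a: list[int], ring_b: list[int]) -> int:
    """Gamma = number of shared edges between two rings, which drives
    RFL's merge rule: gamma == 0 -> SuperAtom, gamma > 0 -> SuperBond.
    Rings are ordered atom-index cycles (illustrative input format)."""
    def edges(ring):
        return {frozenset((ring[i], ring[(i + 1) % len(ring)]))
                for i in range(len(ring))}
    return len(edges(ring_a) & edges(ring_b))
```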
<p><strong>Performance</strong>: RFL achieves SOTA results on both handwritten (95.38% EM) and printed (95.58% EM) structures, with particular strength on high-complexity molecules where standard baselines fail completely (improving from 0% with standard baselines to ~30% on the hardest tier).</p>
<p><strong>Note</strong>: RFL does not preserve original drawing orientation; it&rsquo;s focused on computational efficiency through hierarchical decomposition.</p>
<h4 id="ssml-structure-specific-markup-language">SSML: Structure-Specific Markup Language</h4>
<p><strong>SSML</strong> is the primary orientation-preserving format in OCSR. Based on Chemfig (LaTeX chemical drawing package), it provides step-by-step drawing instructions.</p>
<p><strong>Key characteristics</strong>:</p>
<ul>
<li>Describes <em>how to draw</em> the molecule alongside its graph structure</li>
<li>Uses &ldquo;reconnection marks&rdquo; for cyclic structures</li>
<li>Preserves branch angles and spatial relationships</li>
<li>Significantly outperformed SMILES for handwritten recognition: 92.09% vs 81.89% EM (Hu et al. RCGD 2023)</li>
</ul>
<p><strong>Use case</strong>: Particularly valuable for hand-drawn structure recognition where visual alignment between image and reconstruction sequence aids model learning.</p>
<h2 id="training-data-comparison">Training Data Comparison</h2>
<p>Training data scale has grown dramatically, with a shift toward combining synthetic and real-world images.</p>
<h3 id="data-scale-evolution">Data Scale Evolution</h3>
<table>
  <thead>
      <tr>
          <th>Year</th>
          <th>Typical Scale</th>
          <th>Maximum Reported</th>
          <th>Primary Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2019-2020</td>
          <td>1-15M</td>
          <td>57M (Staker)</td>
          <td>Synthetic (RDKit, CDK)</td>
      </tr>
      <tr>
          <td>2021-2022</td>
          <td>5-35M</td>
          <td>35M (DECIMER 1.0)</td>
          <td>Synthetic with augmentation</td>
      </tr>
      <tr>
          <td>2023-2024</td>
          <td>100-150M</td>
          <td>450M+ (DECIMER.ai)</td>
          <td>Synthetic + real patents</td>
      </tr>
      <tr>
          <td>2025</td>
          <td>1-10M + real</td>
          <td>7.7M (MolParser)</td>
          <td>Curated real + synthetic</td>
      </tr>
  </tbody>
</table>
<h3 id="synthetic-vs-real-data">Synthetic vs Real Data</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Training Data</th>
          <th>Real-World Performance Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>450M+ synthetic (RanDepict)</td>
          <td>Strong generalization via domain randomization</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>7.7M with active learning</td>
          <td>Explicitly targets &ldquo;in the wild&rdquo; images</td>
      </tr>
      <tr>
          <td><strong>GTR-CoT</strong></td>
          <td>Real patent/paper images</td>
          <td>Chain-of-thought improves reasoning</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>Multi-stage curriculum</td>
          <td>RL fine-tuning for stereochemistry</td>
      </tr>
  </tbody>
</table>
<h3 id="data-augmentation-strategies">Data Augmentation Strategies</h3>
<p>Common augmentation techniques across methods:</p>
<table>
  <thead>
      <tr>
          <th>Technique</th>
          <th>Purpose</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Rotation</strong></td>
          <td>Orientation invariance</td>
          <td>Nearly all methods</td>
      </tr>
      <tr>
          <td><strong>Gaussian blur</strong></td>
          <td>Image quality variation</td>
          <td>DECIMER, MolParser</td>
      </tr>
      <tr>
          <td><strong>Salt-and-pepper noise</strong></td>
          <td>Scan artifact simulation</td>
          <td>DECIMER, Image2SMILES</td>
      </tr>
      <tr>
          <td><strong>Affine transforms</strong></td>
          <td>Perspective variation</td>
          <td>ChemPix, MolParser</td>
      </tr>
      <tr>
          <td><strong>Font/style variation</strong></td>
          <td>Rendering diversity</td>
          <td>RanDepict (DECIMER.ai)</td>
      </tr>
      <tr>
          <td><strong>Hand-drawn simulation</strong></td>
          <td>Sketch-like inputs</td>
          <td>ChemPix, ChemReco, DECIMER-Hand-Drawn</td>
      </tr>
      <tr>
          <td><strong>Background variation</strong></td>
          <td>Document context</td>
          <td>MolParser, DECIMER.ai</td>
      </tr>
  </tbody>
</table>
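<p>As an example of the simplest of these, salt-and-pepper noise flips a random fraction of pixels to pure black or white. A minimal stdlib sketch with a grayscale image as a 2D list (an illustration, not any method&rsquo;s actual pipeline code):</p>

```python
import random

def salt_and_pepper(img, p=0.05, seed=0):
    """Flip a fraction p of grayscale pixels to 0 (pepper) or 255 (salt),
    simulating scan artifacts as in DECIMER/Image2SMILES-style augmentation.
    img is a 2D list of ints in [0, 255]; seeded for reproducibility."""
    rng = random.Random(seed)
    return [[rng.choice((0, 255)) if rng.random() < p else px
             for px in row] for row in img]
```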
<h2 id="hardware-and-compute-requirements">Hardware and Compute Requirements</h2>
<p>Hardware requirements span several orders of magnitude, from consumer GPUs to TPU pods.</p>
<h3 id="training-hardware-comparison">Training Hardware Comparison</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Hardware</th>
          <th>Training Time</th>
          <th>Dataset Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Staker et al. (2019)</strong></td>
          <td>8x GPUs</td>
          <td>26 days</td>
          <td>57M</td>
      </tr>
      <tr>
          <td><strong>IMG2SMI (2021)</strong></td>
          <td>1x RTX 2080 Ti</td>
          <td>5 epochs</td>
          <td>~10M</td>
      </tr>
      <tr>
          <td><strong>Image2SMILES (2022)</strong></td>
          <td>4x V100</td>
          <td>2 weeks</td>
          <td>30M</td>
      </tr>
      <tr>
          <td><strong>MICER (2022)</strong></td>
          <td>4x V100</td>
          <td>42 hours</td>
          <td>10M</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0 (2021)</strong></td>
          <td>TPU v3-8</td>
          <td>Not reported</td>
          <td>35M</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai (2023)</strong></td>
          <td>TPU v3-256</td>
          <td>Not reported</td>
          <td>450M+</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR (2022)</strong></td>
          <td>4x RTX 3090</td>
          <td>5 days</td>
          <td>5M</td>
      </tr>
      <tr>
          <td><strong>MolParser (2025)</strong></td>
          <td>8x A100</td>
          <td>Curriculum learning</td>
          <td>7.7M</td>
      </tr>
      <tr>
          <td><strong>MolSight (2025)</strong></td>
          <td>Not specified</td>
          <td>RL fine-tuning (GRPO)</td>
          <td>Multi-stage</td>
      </tr>
  </tbody>
</table>
<h3 id="inference-considerations">Inference Considerations</h3>
<p>Few papers report inference speed consistently. Available data:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Inference Speed</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>4x faster than DECIMER</td>
          <td>TensorFlow Lite optimization</td>
      </tr>
      <tr>
          <td><strong>OSRA</strong> (baseline)</td>
          <td>~1 image/sec</td>
          <td>CPU-based rule system</td>
      </tr>
      <tr>
          <td><strong>MolScribe</strong></td>
          <td>Real-time capable</td>
          <td>Optimized Swin encoder</td>
      </tr>
  </tbody>
</table>
<h3 id="accessibility-tiers">Accessibility Tiers</h3>
<table>
  <thead>
      <tr>
          <th>Tier</th>
          <th>Hardware</th>
          <th>Representative Methods</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Consumer</strong></td>
          <td>1x RTX 2080/3090</td>
          <td>IMG2SMI, ChemPix</td>
      </tr>
      <tr>
          <td><strong>Workstation</strong></td>
          <td>4x V100/A100</td>
          <td>Image2SMILES, MICER, SwinOCSR</td>
      </tr>
      <tr>
          <td><strong>Cloud/HPC</strong></td>
          <td>TPU pods, 8+ A100</td>
          <td>DECIMER.ai, MolParser</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmark-performance">Benchmark Performance</h2>
<h3 id="common-evaluation-datasets">Common Evaluation Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Type</th>
          <th>Size</th>
          <th>Challenge</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>USPTO</strong></td>
          <td>Patent images</td>
          <td>~5K test</td>
          <td>Real-world complexity</td>
      </tr>
      <tr>
          <td><strong>UOB</strong></td>
          <td>Scanned images</td>
          <td>~5K test</td>
          <td>Scan artifacts</td>
      </tr>
      <tr>
          <td><strong>Staker</strong></td>
          <td>Synthetic</td>
          <td>Variable</td>
          <td>Baseline synthetic</td>
      </tr>
      <tr>
          <td><strong>CLEF</strong></td>
          <td>Patent images</td>
          <td>~1K test</td>
          <td>Markush structures</td>
      </tr>
      <tr>
          <td><strong>JPO</strong></td>
          <td>Japanese patents</td>
          <td>~1K test</td>
          <td>Different rendering styles</td>
      </tr>
  </tbody>
</table>
<h3 id="accuracy-comparison-exact-match-">Accuracy Comparison (Exact Match %)</h3>
<p>Methods are roughly grouped by evaluation era; direct comparison is complicated by different test sets.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>USPTO</th>
          <th>UOB</th>
          <th>Staker</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>OSRA</strong> (baseline)</td>
          <td>~70%</td>
          <td>~65%</td>
          <td>~80%</td>
          <td>Rule-based reference</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>~85%</td>
          <td>~80%</td>
          <td>~90%</td>
          <td>First transformer-based</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR</strong></td>
          <td>~88%</td>
          <td>~82%</td>
          <td>~92%</td>
          <td>Swin encoder advantage</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>~90%</td>
          <td>~85%</td>
          <td>~95%</td>
          <td>Scale + augmentation</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>~92%</td>
          <td>~88%</td>
          <td>~96%</td>
          <td>Real-world focus</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>~93%+</td>
          <td>~89%+</td>
          <td>~97%+</td>
          <td>RL fine-tuning boost</td>
      </tr>
  </tbody>
</table>
<p><em>Note: Numbers are approximate and may vary by specific test split. See individual paper notes for precise figures.</em></p>
<h3 id="stereochemistry-recognition">Stereochemistry Recognition</h3>
<p>Stereochemistry remains a persistent challenge across all methods:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Approach</th>
          <th>Stereo Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Most methods</strong></td>
          <td>Standard SMILES</td>
          <td>Lower than non-stereo</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>RL (GRPO) specifically for stereo</td>
          <td>Improved</td>
      </tr>
      <tr>
          <td><strong>MolNexTR</strong></td>
          <td>Graph-based explicit stereo</td>
          <td>Better handling</td>
      </tr>
      <tr>
          <td><strong>Image2InChI</strong></td>
          <td>InChI stereo layers</td>
          <td>Mixed results</td>
      </tr>
  </tbody>
</table>
<h2 id="hand-drawn-recognition">Hand-Drawn Recognition</h2>
<p>A distinct sub-lineage focuses on hand-drawn/sketched chemical structures.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Target Domain</th>
          <th>Key Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ChemPix (2021)</strong></td>
          <td>Hand-drawn hydrocarbons</td>
          <td>First deep learning for sketches</td>
      </tr>
      <tr>
          <td><strong>Hu et al. RCGD (2023)</strong></td>
          <td>Hand-drawn structures</td>
          <td>Random conditional guided decoder</td>
      </tr>
      <tr>
          <td><strong>ChemReco (2024)</strong></td>
          <td>Hand-drawn C-H-O structures</td>
          <td>EfficientNet + curriculum learning</td>
      </tr>
      <tr>
          <td><strong>DECIMER-Hand-Drawn (2024)</strong></td>
          <td>General hand-drawn</td>
          <td>Enhanced DECIMER architecture</td>
      </tr>
  </tbody>
</table>
<h3 id="hand-drawn-vs-printed-trade-offs">Hand-Drawn vs Printed Trade-offs</h3>
<ul>
<li>Hand-drawn methods sacrifice some accuracy on clean printed images</li>
<li>Require specialized training data (synthetic hand-drawn simulation)</li>
<li>Generally smaller training sets due to data collection difficulty</li>
<li>Better suited for educational and lab notebook applications</li>
</ul>
<h2 id="key-innovations-by-method">Key Innovations by Method</h2>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Primary Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Staker et al.</strong></td>
          <td>First end-to-end deep learning OCSR</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>Transformer decoder + SELFIES</td>
      </tr>
      <tr>
          <td><strong>Img2Mol</strong></td>
          <td>Continuous embedding space (CDDD)</td>
      </tr>
      <tr>
          <td><strong>Image2SMILES</strong></td>
          <td>Functional group-aware SMILES (FG-SMILES)</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR</strong></td>
          <td>Hierarchical vision transformer encoder</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>Massive scale + RanDepict augmentation</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>Extended SMILES + active learning</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>RL fine-tuning (GRPO) for accuracy</td>
      </tr>
      <tr>
          <td><strong>GTR-CoT</strong></td>
          <td>Chain-of-thought graph traversal</td>
      </tr>
      <tr>
          <td><strong>OCSU</strong></td>
          <td>Multi-task vision-language understanding</td>
      </tr>
      <tr>
          <td><strong>RFL</strong></td>
          <td>Hierarchical ring decomposition with SuperAtoms/SuperBonds</td>
      </tr>
  </tbody>
</table>
<h2 id="open-challenges">Open Challenges</h2>
<ol>
<li><strong>Stereochemistry</strong>: Consistent challenge across all methods; RL approaches (MolSight) show promise</li>
<li><strong>Abbreviations/R-groups</strong>: E-SMILES and Markush-specific methods emerging</li>
<li><strong>Real-world robustness</strong>: Gap between synthetic training and patent/paper images</li>
<li><strong>Inference speed</strong>: Rarely reported; important for production deployment</li>
<li><strong>Memory efficiency</strong>: Almost never documented; limits accessibility</li>
<li><strong>Multi-molecule images</strong>: Most methods assume single isolated structure</li>
</ol>
<h2 id="references">References</h2>
<p>Individual paper notes linked throughout. For the complete method listing, see the <a href="/notes/computational-chemistry/optical-structure-recognition/ocsr-methods/">OCSR Methods taxonomy</a>.</p>
]]></content:encoded></item><item><title>DynamicFlow: Integrating Protein Dynamics into Drug Design</title><link>https://hunterheidenreich.com/notes/interdisciplinary/computational-biology/dynamicflow/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/interdisciplinary/computational-biology/dynamicflow/</guid><description>Flow matching model that co-generates ligands and flexible protein pockets, addressing rigid-receptor limitations in structure-based drug design.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a strong <strong>Resource</strong> ($\Psi_{\text{Resource}}$) component.</p>
<ul>
<li><strong>Method</strong>: It proposes <strong>DynamicFlow</strong>, a novel multiscale architecture combining atom-level SE(3)-equivariant GNNs (SE(3) is the special Euclidean group in 3D: the set of all 3D rotations and translations, and equivariance means predictions transform consistently under those symmetries) and residue-level Transformers within a <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">flow matching</a> framework to model the joint distribution of ligand generation and protein conformational change.</li>
<li><strong>Resource</strong>: It curates a significant dataset derived from MISATO, pairing AlphaFold2-predicted apo structures with multiple MD-simulated holo states, specifically filtered for flow matching tasks.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Traditional Structure-Based Drug Design (SBDD) methods typically assume the protein target is rigid, which limits their applicability because proteins are dynamic and undergo conformational changes (induced fit) upon ligand binding.</p>
<ul>
<li><strong>Biological Reality</strong>: Proteins exist as ensembles of states; binding often involves <a href="/posts/geom-conformer-generation-dataset/">conformational changes</a> from the &ldquo;apo&rdquo; (unbound) to the &ldquo;holo&rdquo; (bound) state, sometimes revealing cryptic pockets.</li>
<li><strong>Computational Bottleneck</strong>: <a href="/notes/computational-chemistry/molecular-dynamics/">Molecular Dynamics (MD)</a> simulates these changes but incurs high computational costs due to energy barriers.</li>
<li><strong>Gap</strong>: <a href="/notes/machine-learning/generative-models/">Existing generative models</a> for SBDD mostly condition on a fixed pocket structure, ignoring the co-adaptation of the protein and ligand.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>simultaneous modeling of ligand generation and protein conformational dynamics</strong> using a unified flow matching framework.</p>
<ul>
<li><strong>DynamicFlow Architecture</strong>: A multiscale model that treats the protein as both full-atom (for interaction) and residue-level frames (for large-scale dynamics), utilizing separate flow matching objectives for backbone frames, side-chain torsions, and ligand atoms.</li>
<li><strong>Stochastic Flow (SDE)</strong>: Introduction of a <a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">stochastic variant</a> (DynamicFlow-SDE) that improves robustness and diversity compared to the deterministic ODE flow.</li>
<li><strong>Coupled Generation</strong>: The model learns to transport the <em>apo</em> pocket distribution to the <em>holo</em> pocket distribution while simultaneously denoising the ligand, advancing beyond rigid pocket docking methods.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method on a curated dataset of 5,692 protein-ligand complexes.</p>
<ul>
<li><strong>Baselines</strong>: Compared against rigid-pocket SBDD methods: Pocket2Mol, TargetDiff, and IPDiff (adapted as TargetDiff* and IPDiff* for fair comparison of atom numbers). Also compared against conformation sampling baselines (Str2Str).</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Ligand Quality</strong>: Vina Score (binding affinity), QED (drug-likeness), SA (synthesizability), Lipinski&rsquo;s rule of 5.</li>
<li><strong>Pocket Quality</strong>: RMSD between generated and ground-truth holo pockets, Cover Ratio (percentage of holo states successfully retrieved), and Pocket Volume distributions.</li>
<li><strong>Interaction</strong>: Protein-Ligand Interaction Profiler (PLIP) to measure specific non-covalent interactions.</li>
</ul>
</li>
<li><strong>Ablations</strong>: Tested the impact of the interaction loss, residue-level Transformer, and SDE vs. ODE formulations.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Improved Affinity</strong>: DynamicFlow-SDE achieved the best (lowest) Vina scores ($-7.65$) compared to baselines like TargetDiff ($-5.09$) and Pocket2Mol ($-5.50$). Note that Vina scores are a computational proxy and do not directly predict experimental binding affinity. Moreover, Vina score optimization is gameable: molecules can achieve strong computed binding energies while remaining synthetically inaccessible. QED and SA scores, which assess drug-likeness and synthesizability respectively, were reported but were not primary optimization targets in the paper, which limits the strength of this affinity claim.</li>
<li><strong>Realistic Dynamics</strong>: The model successfully generated holo-like pocket conformations with volume distributions and interaction profiles closer to ground-truth MD simulations than the initial apo structures.</li>
<li><strong>Enhancing Rigid Methods</strong>: Holo pockets generated by DynamicFlow served as better inputs for rigid-SBDD baselines (e.g., TargetDiff improved from $-5.09$ to $-9.00$ and IPDiff improved from $-5.69$ to $-11.04$ when using &ldquo;Our Pocket&rdquo;), suggesting the method can act as a &ldquo;pocket refiner&rdquo;.</li>
<li><strong>ODE vs. SDE Trade-off</strong>: The deterministic ODE variant achieves better pocket RMSD, while the stochastic SDE variant achieves better Cover Ratio (diversity of holo states captured) and binding affinity. Neither dominates uniformly.</li>
<li><strong>Conformation Sampling Baseline</strong>: Str2Str, a dedicated conformation sampling baseline, performed worse than simply perturbing the apo structure with noise. One interpretation is that this highlights the difficulty of the apo-to-holo prediction task; another is that Str2Str was not designed specifically for apo-to-holo prediction, making it a limited test of its capabilities.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset is derived from <strong>MISATO</strong>, which contains MD trajectories for PDBbind complexes.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td>Curated MISATO</td>
          <td>5,692 complexes</td>
          <td>Filtered for valid MD (<a href="/posts/kabsch-algorithm/">RMSD</a> $&lt; 3\text{\AA}$), clustered to remove redundancy. Contains 46,235 holo-ligand conformations total.</td>
      </tr>
      <tr>
          <td><strong>Apo Structures</strong></td>
          <td>AlphaFold2</td>
          <td>N/A</td>
          <td>Apo structures were obtained by mapping PDB IDs to UniProt and retrieving AlphaFold2 predictions, then aligning to MISATO structures.</td>
      </tr>
      <tr>
          <td><strong>Splits</strong></td>
          <td>Standard</td>
          <td>50 test complexes</td>
<td>50 complexes with no overlap with the training set (PM-score $&lt; 0.95$) were selected for testing. Note: 50 is a small held-out set; results should be interpreted cautiously.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Clustering</strong>: Holo-ligand conformations clustered with RMSD threshold $1.0\text{\AA}$; top 10 clusters kept per complex.</li>
<li><strong>Pocket Definition</strong>: Residues within $7\text{\AA}$ of the ligand.</li>
<li><strong>Alignment</strong>: AlphaFold predicted structures (apo) aligned to MISATO holo structures using sequence alignment (Smith-Waterman) to identify pocket residues.</li>
</ul>
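<p>The pocket definition above (residues within $7\text{\AA}$ of the ligand) reduces to a simple distance filter. A minimal stdlib sketch, with a hypothetical data layout (residue id mapped to atom coordinates; real code would parse PDB structures):</p>

```python
import math

def extract_pocket(residues, ligand_atoms, cutoff=7.0):
    """Keep residues that have any atom within `cutoff` angstroms of any ligand atom.

    `residues` maps residue id -> list of (x, y, z) atom coordinates;
    `ligand_atoms` is a list of (x, y, z) coordinates.
    """
    def close(a, b):
        return math.dist(a, b) <= cutoff

    return sorted(
        rid for rid, atoms in residues.items()
        if any(close(a, lig) for a in atoms for lig in ligand_atoms)
    )

# Toy example: residue 1 is 5 A from the ligand (kept), residue 2 is 15 A away (dropped).
residues = {
    1: [(0.0, 0.0, 0.0)],
    2: [(20.0, 0.0, 0.0)],
}
ligand = [(5.0, 0.0, 0.0)]
```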
<h3 id="algorithms">Algorithms</h3>
<p><strong>Flow Matching Framework</strong>:</p>
<ul>
<li><strong>Continuous Variables</strong> (Pocket translation/rotation/torsions, Ligand positions): Modeled using <strong>Conditional Flow Matching (CFM)</strong>.
<ul>
<li><em>Prior</em>: Apo state for pocket; Normal distribution for ligand positions.</li>
<li><em>Target</em>: Holo state from MD; Ground truth ligand.</li>
<li><em>Interpolant</em>: Linear interpolation for Euclidean variables; Geodesic for rotations ($SO(3)$, the rotation-only subgroup of SE(3) containing all 3D rotations but not translations); Wrapped linear interpolation for torsions (Torus).</li>
</ul>
</li>
<li><strong>Discrete Variables</strong> (Ligand atom/bond types): Modeled using <strong>Discrete Flow Matching</strong> based on Continuous-Time Markov Chains (CTMC).
<ul>
<li><em>Rate Matrix</em>: Interpolates between mask token and data distribution.</li>
</ul>
</li>
<li><strong>Loss Function</strong>: Weighted sum of 7 losses:
<ol>
<li>Translation CFM (Eq 5)</li>
<li>Rotation CFM (Eq 7)</li>
<li>Torsion CFM (Eq 11)</li>
<li>Ligand Position CFM</li>
<li>Ligand Atom Type CTMC (Eq 14)</li>
<li>Ligand Bond Type CTMC</li>
<li><strong>Interaction Loss</strong> (Eq 18): Explicitly penalizes deviations in pairwise distances between protein and ligand atoms for pairs $&lt; 3.5\text{\AA}$.</li>
</ol>
</li>
</ul>
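<p>The continuous interpolants above can be sketched in a few lines: linear interpolation for Euclidean variables and wrapped-linear interpolation for torsion angles on the torus (the $SO(3)$ geodesic case is omitted for brevity). This is an illustrative sketch of the CFM interpolant and its constant-velocity regression target, not the paper's implementation:</p>

```python
import math

def linear_interp(x0, x1, t):
    """Euclidean interpolant x_t = (1 - t) * x0 + t * x1 (translations, ligand positions)."""
    return [(1 - t) * a + t * b for a, b in zip(x0, x1)]

def target_velocity(x0, x1):
    """CFM regression target: the constant velocity x1 - x0 of the linear interpolant."""
    return [b - a for a, b in zip(x0, x1)]

def wrapped_interp(theta0, theta1, t):
    """Torsion (torus) interpolant: move along the shortest arc, wrapping at +/- pi."""
    # Shortest signed angular difference in (-pi, pi]
    delta = math.atan2(math.sin(theta1 - theta0), math.cos(theta1 - theta0))
    theta_t = theta0 + t * delta
    # Wrap the result back onto the torus
    return math.atan2(math.sin(theta_t), math.cos(theta_t))
```

<p>The network is trained to regress these velocities at randomly sampled times $t$; at inference the learned field is integrated from the apo/noise prior toward the holo/ligand target.</p>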
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: <strong>DynamicFlow</strong> is a multiscale model with 15.9M parameters.</p>
<ol>
<li><strong>Atom-Level SE(3)-Equivariant GNN</strong>:
<ul>
<li><em>Input</em>: Complex graph (k-NN, $k=32$) and Ligand graph (fully connected).</li>
<li><em>Layers</em>: 6 EGNN blocks modified to maintain node and edge hidden states.</li>
<li><em>Function</em>: Updates ligand positions and predicts ligand atom/bond types.</li>
</ul>
</li>
<li><strong>Residue-Level Transformer</strong>:
<ul>
<li><em>Input</em>: Aggregated atom features from the GNN + Residue frames/torsions.</li>
<li><em>Layers</em>: 4 Transformer blocks with <strong>Invariant Point Attention (IPA)</strong>.</li>
<li><em>Function</em>: Updates protein residue frames (translation/rotation) and predicts side-chain torsions.</li>
</ul>
</li>
</ol>
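<p>The k-NN complex graph ($k=32$ in the paper) amounts to selecting each atom's nearest neighbors by Euclidean distance. A brute-force stdlib sketch for illustration (a real implementation would use vectorized tensor operations):</p>

```python
import math

def knn_edges(coords, k=32):
    """Directed k-nearest-neighbor edges (i -> j) over a list of 3D coordinates."""
    edges = []
    for i, ci in enumerate(coords):
        # Sort all other nodes by distance to node i and keep the k closest.
        dists = sorted(
            (math.dist(ci, cj), j) for j, cj in enumerate(coords) if j != i
        )
        edges.extend((i, j) for _, j in dists[:k])
    return edges
```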
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Vina Score</strong>: <code>vina_minimize</code> mode used for binding affinity.</li>
<li><strong>RMSD</strong>: Minimum RMSD between generated pocket and ground-truth holo conformations.</li>
<li><strong>Cover Ratio</strong>: % of ground-truth holo conformations covered by at least one generated sample (threshold $1.42\text{\AA}$).</li>
<li><strong>POVME 3</strong>: For pocket volume calculation.</li>
</ul>
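<p>The Cover Ratio metric can be expressed generically as the fraction of ground-truth states matched by at least one generated sample. In this sketch <code>rmsd_fn</code> is a placeholder for the pocket RMSD computation (the paper uses a $1.42\text{\AA}$ threshold):</p>

```python
def cover_ratio(ground_truth, generated, rmsd_fn, threshold=1.42):
    """Fraction of ground-truth holo states covered by >= 1 generated sample."""
    covered = sum(
        1 for gt in ground_truth
        if any(rmsd_fn(gt, g) <= threshold for g in generated)
    )
    return covered / len(ground_truth)
```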
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Inference Benchmark</strong>: 1x Tesla V100-SXM2-32GB.</li>
<li><strong>Speed</strong>: Generates 10 ligands in ~35-36 seconds (100 NFE), significantly faster than baselines such as Pocket2Mol (980s) or TargetDiff (156s).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhou, X., Xiao, Y., Lin, H., He, X., Guan, J., Wang, Y., Liu, Q., Zhou, F., Wang, L., &amp; Ma, J. (2025). Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://doi.org/10.48550/arXiv.2503.03989">https://doi.org/10.48550/arXiv.2503.03989</a></p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhouIntegratingProteinDynamics2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhou, Xiangxin and Xiao, Yi and Lin, Haowei and He, Xinheng and Guan, Jiaqi and Wang, Yang and Liu, Qiang and Zhou, Feng and Wang, Liang and Ma, Jianzhu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=uMAujpVi9m}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=uMAujpVi9m">OpenReview Page</a></li>
<li>Code: no public repository available at time of writing</li>
</ul>
]]></content:encoded></item><item><title>ChemDFM-X: Multimodal Foundation Model for Chemistry</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemdfm-x/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemdfm-x/</guid><description>Multimodal chemical model integrating 5 modalities (2D graphs, 3D conformations, images, MS2/IR spectra) trained on 7.6M instructions.</description><content:encoded><![CDATA[<h2 id="chemdfm-x-contribution-and-architecture">ChemDFM-X Contribution and Architecture</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> contribution.</p>
<p><strong>Method</strong>: The paper proposes a novel &ldquo;Cross-modal Dialogue Foundation Model&rdquo; architecture that aligns five distinct chemical modalities (2D graphs, 3D conformations, images, MS2 spectra, IR spectra) to a single LLM decoder using separate encoders and projection modules. It establishes strong baseline performance across multiple modalities compared against current generalist models.</p>
<p><strong>Resource</strong>: The paper addresses the scarcity of multimodal chemical data by constructing a <strong>7.6M instruction-tuning dataset</strong>. This dataset is largely synthesized from seed SMILES strings using approximate calculations (MMFF94, CFM-ID, Chemprop-IR) and specialist model predictions.</p>
<h2 id="bridging-experimental-data-and-llms">Bridging Experimental Data and LLMs</h2>
<p>Existing chemical AI models generally fall into two distinct categories. Task-specific specialist models achieve high accuracy on singular objectives, such as property prediction or molecular generation, but require strict formatting and lack conversational flexibility. Conversely, early chemical large language models provide natural language interaction but are restricted to text and SMILES strings. ChemDFM-X addresses this gap by enabling large multimodal models to process the experimental characterization data (MS2 and IR spectra) and visual data routinely used in practical chemistry workflows.</p>
<h2 id="synthetic-data-scaling-for-modality-alignment">Synthetic Data Scaling for Modality Alignment</h2>
<p>The core novelty lies in the <strong>&ldquo;Any-to-Text&rdquo; alignment strategy via synthetic data scaling</strong>:</p>
<ol>
<li>
<p><strong>Comprehensive Modality Support</strong>: ChemDFM-X incorporates experimental characterization data (MS2 and IR spectra) alongside 2D graphs, 3D conformations, and images. Each modality is given a formal mathematical representation rather than being treated as raw input:</p>
<ul>
<li><strong>Molecular Graph</strong>: An undirected graph $G = (\textbf{V}, \textbf{E})$ with atom set $\textbf{V}$ and bond set $\textbf{E}$.</li>
<li><strong>Molecular Conformation</strong>: An undirected graph $G = (\textbf{V}', \textbf{E})$ whose vertices store spatial coordinates and atom types: $\textbf{v}_i = (x_i, y_i, z_i, a_i)$.</li>
<li><strong>MS2 Spectrum</strong>: Treated as a point sequence of discrete mass-to-charge ratios and intensities, tokenized via a discrete codebook: $\textbf{M} = ((r_1, I_1), (r_2, I_2), \dots, (r_n, I_n))$.</li>
<li><strong>IR Spectrum</strong>: Treated as a dense sequence of continuous wavelengths and absorption intensities, directly reshaped for feature extraction: $\textbf{R} = ((w_1, t_1), (w_2, t_2), \dots, (w_l, t_l))$.</li>
</ul>
<p>The authors trained new Sequence Transformer encoders from scratch for the MS2 and IR modalities since suitable pre-trained models did not exist.</p>
</li>
<li>
<p><strong>Synthetic Data Generation Pipeline</strong>: The authors generated a 7.6M sample dataset by starting with 1.3M seed SMILES and using &ldquo;approximate calculations&rdquo; to generate missing modalities:</p>
<ul>
<li>3D conformations via MMFF94 force field optimization</li>
<li>MS2 spectra via CFM-ID 4.0 (Competitive Fragmentation Modeling)</li>
<li>IR spectra via Chemprop-IR (Message Passing Neural Network)</li>
</ul>
</li>
<li>
<p><strong>Cross-Modal Synergy</strong>: The model demonstrates that training on reaction images improves recognition performance by leveraging semantic chemical knowledge (reaction rules) to correct visual recognition errors, an emergent capability from multimodal training.</p>
</li>
</ol>
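<p>The MS2 codebook tokenization described above (discrete m/z and intensity pairs mapped to tokens) can be illustrated with a toy binning scheme. The bin width and number of intensity levels here are hypothetical; the paper's actual codebook construction is not specified in this note:</p>

```python
def tokenize_ms2(peaks, mz_bin=1.0, n_intensity_levels=4, max_intensity=100.0):
    """Map (m/z, intensity) peaks to discrete (mz_bin_index, intensity_level) tokens."""
    tokens = []
    for mz, intensity in peaks:
        mz_token = int(mz // mz_bin)  # quantize m/z into fixed-width bins
        level = min(
            int(intensity / max_intensity * n_intensity_levels),
            n_intensity_levels - 1,   # clamp the maximum-intensity peak
        )
        tokens.append((mz_token, level))
    return tokens

# Toy spectrum: two fragment peaks at m/z 77.04 and 105.03
spectrum = [(77.04, 100.0), (105.03, 42.5)]
```

<p>A sequence of such discrete tokens is what a Sequence Transformer encoder, like the ones the authors trained from scratch, would consume.</p>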
<h2 id="multimodal-benchmarking-with-chemllmbench">Multimodal Benchmarking with ChemLLMBench</h2>
<p>The model was evaluated using a customized version of <strong>ChemLLMBench</strong> and <strong>MoleculeNet</strong> across three modality categories:</p>
<ol>
<li>
<p><strong>Structural Modalities</strong> (2D Graphs &amp; 3D Conformations):</p>
<ul>
<li>Molecule recognition and captioning</li>
<li>Property prediction (MoleculeNet: BACE, BBBP, ClinTox, HIV, Tox21)</li>
<li>Compared against specialist models (Mole-BERT, Uni-Mol, MolXPT, MolCA) and generalist models (3D-MoLM, ChemDFM, ChemLLM)</li>
</ul>
</li>
<li>
<p><strong>Visual Modalities</strong> (Images):</p>
<ul>
<li>Single molecule image recognition</li>
<li>Reaction image recognition</li>
<li>Compared against GPT-4O, Gemini 1.5 Pro, Qwen-VL, LLaVA, and specialist models MolNextr and MolScribe</li>
</ul>
</li>
<li>
<p><strong>Characterization Modalities</strong> (MS2 &amp; IR Spectra):</p>
<ul>
<li>Spectral analysis tasks (identifying molecules from spectra)</li>
<li>Contextualized spectral interpretation (combining spectra with reaction context)</li>
<li>Novel evaluation requiring integration of spectroscopic data with reaction knowledge</li>
</ul>
</li>
</ol>
<h2 id="cross-modal-synergy-and-generalist-performance">Cross-Modal Synergy and Generalist Performance</h2>
<p><strong>Key Findings</strong>:</p>
<ol>
<li>
<p><strong>Leading Generalist Performance</strong>: ChemDFM-X establishes a new benchmark among existing generalist models (such as 3D-MOLM and ChemLLM), achieving performance metrics that match dedicated specialist models across several multimodal tasks.</p>
</li>
<li>
<p><strong>Failure of General LMMs</strong>: General vision models (GPT-4O, Gemini 1.5 Pro, Qwen-VL, LLaVA, InternLM-XComposer2, DocOwl) failed significantly on chemical image recognition tasks (0% accuracy for most models on molecule and reaction recognition, Table 9), demonstrating that chemical domain knowledge cannot be assumed from general pre-training.</p>
</li>
<li>
<p><strong>Cross-Modal Error Correction</strong>: In reaction image recognition, ChemDFM-X achieved higher accuracy (53.0%) than on single molecules (46.0%) (Table 9). The authors conclude the model uses its internal knowledge of chemical reaction rules to correct recognition errors in the visual modality, an emergent capability from multimodal training.</p>
</li>
<li>
<p><strong>Reliance on Reaction Context for Spectra</strong>: In zero-shot scenarios, ChemDFM-X essentially fails at pure spectral recognition (achieving 0% and 1% top-1 accuracy on MS2 and IR spectra alone, Table 11). However, when SMILES-based reaction context is included, performance rises to 45% (MS2) and 64% (IR) on the reaction prediction task, and 29% (MS2) and 60% (IR) on retrosynthesis (Table 11). This indicates the model uses spectral data as a soft prior to constrain textual deductions. Furthermore, the paper compares ChemDFM-X’s spectral identification performance exclusively against text-only LLMs that cannot process spectra, omitting comparisons against established specialist tools.</p>
</li>
<li>
<p><strong>Surrogate Distillation Trade-offs</strong>: Because the spectral training data relies entirely on outputs from CFM-ID 4.0 and Chemprop-IR, ChemDFM-X effectively distills these surrogate models. Any inherent predictive biases or inaccuracies from these underlying tools are permanently embedded in the new ChemDFM-X encoders.</p>
</li>
</ol>
<p><strong>Main Conclusion</strong>: The &ldquo;separate encoders + unified decoder&rdquo; architecture with synthetic data generation enables effective multimodal chemical understanding, bridging the gap between specialist and generalist AI systems for chemistry.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors constructed a <strong>7.6M-sample instruction-tuning dataset</strong> derived from <strong>1.3M seed SMILES</strong> (sourced from PubChem and USPTO). <strong>Note</strong>: The final 7.6M-sample multimodal tuning dataset itself is not publicly available.</p>
<p><strong>Generation Pipeline</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Generation Method</th>
          <th>Tool/Model</th>
          <th>Sample Count</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2D Graphs</strong></td>
          <td>Direct extraction from SMILES</td>
          <td>RDKit</td>
          <td>1.1M</td>
      </tr>
      <tr>
          <td><strong>3D Conformations</strong></td>
          <td>Force field optimization</td>
          <td>RDKit + MMFF94</td>
          <td>1.3M (pseudo-optimal)</td>
      </tr>
      <tr>
          <td><strong>Molecule Images</strong></td>
          <td>Rendering with augmentation</td>
          <td>RDKit, Indigo, ChemPix</td>
          <td>~1M (including handwritten style)</td>
      </tr>
      <tr>
          <td><strong>Reaction Images</strong></td>
          <td>Rendering from reaction SMILES</td>
          <td>RDKit</td>
          <td>300K</td>
      </tr>
      <tr>
          <td><strong>MS2 Spectra</strong></td>
          <td>Computational prediction</td>
          <td>CFM-ID 4.0</td>
          <td>~700K</td>
      </tr>
      <tr>
          <td><strong>IR Spectra</strong></td>
          <td>Computational prediction</td>
          <td>Chemprop-IR</td>
          <td>~1M</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Augmentation</strong>:</p>
<ul>
<li>Molecule images augmented with &ldquo;handwritten&rdquo; style using the ChemPix pipeline</li>
<li>Multiple rendering styles (RDKit default, Indigo clean)</li>
<li>Spectra generated at multiple energy levels (10eV, 20eV, 40eV for MS2)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Architecture</strong>: &ldquo;Separate Encoders + Unified Decoder&rdquo;</p>
<p><strong>Code Availability</strong>: The authors have only released inference logic. The cross-modal projection training and synthetic data-generation scripts are closed.</p>
<p><strong>Modality Alignment</strong>:</p>
<ul>
<li>Each modality has a dedicated encoder (frozen pre-trained models where available)</li>
<li>For graph, conformation, MS2, and IR modalities: <strong>2-layer MLP projector</strong> (Linear, GELU, Linear) maps encoder features to LLM input space</li>
<li>For images: <strong>H-Reducer</strong> module compresses image tokens by factor of $n=8$ to handle high-resolution chemical images, then projects to LLM input space</li>
<li>All projected features are concatenated and fed to the unified LLM decoder</li>
</ul>
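<p>The projector and token-reduction steps above are simple enough to sketch in pure Python. This is an illustrative toy (random weights, tiny dimensions, and mean-pooling standing in for the H-Reducer's learned reduction), not the trained modules:</p>

```python
import math
import random

def linear(x, W, b):
    """Dense layer: y = W x + b, with W given as a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def gelu(x):
    """tanh approximation of GELU, applied elementwise."""
    return [0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v**3)))
            for v in x]

def project(feature, W1, b1, W2, b2):
    """2-layer MLP projector (Linear -> GELU -> Linear), as used for
    graph/conformation/MS2/IR features."""
    return linear(gelu(linear(feature, W1, b1)), W2, b2)

def h_reduce(tokens, n=8):
    """H-Reducer-style compression sketch: collapse every n consecutive image
    tokens into one. Mean-pooling is a placeholder for the learned module."""
    out = []
    for i in range(0, len(tokens), n):
        group = tokens[i:i + n]
        dim = len(group[0])
        out.append([sum(t[d] for t in group) / len(group) for d in range(dim)])
    return out

random.seed(0)
enc_dim, hidden, llm_dim = 4, 6, 5  # toy sizes; the paper maps 300/512/768-dim features to the LLM width
W1 = [[random.uniform(-0.1, 0.1) for _ in range(enc_dim)] for _ in range(hidden)]
b1 = [0.0] * hidden
W2 = [[random.uniform(-0.1, 0.1) for _ in range(hidden)] for _ in range(llm_dim)]
b2 = [0.0] * llm_dim

projected = project([1.0, -0.5, 0.25, 2.0], W1, b1, W2, b2)
reduced = h_reduce([[float(i)] * 3 for i in range(16)], n=8)
```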
<h3 id="models">Models</h3>
<p><strong>Base LLM</strong>:</p>
<ul>
<li><strong>ChemDFM (13B)</strong>: LLaMA-based model pre-trained on chemical text and SMILES</li>
</ul>
<p><strong>Modality Encoders</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Encoder</th>
          <th>Pre-training Data</th>
          <th>Parameter Count</th>
          <th>Status</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2D Graph</strong></td>
          <td>Mole-BERT</td>
          <td>2M molecules</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>3D Conformation</strong></td>
          <td>Uni-Mol</td>
          <td>209M conformations</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>Image</strong></td>
          <td>CLIP (ViT)</td>
          <td>General domain</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>MS2 Spectrum</strong></td>
          <td>Transformer (SeqT)</td>
          <td>Trained from scratch</td>
          <td>-</td>
          <td><strong>Trainable</strong></td>
      </tr>
      <tr>
          <td><strong>IR Spectrum</strong></td>
          <td>Transformer (SeqT)</td>
          <td>Trained from scratch</td>
          <td>-</td>
          <td><strong>Trainable</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Design Rationale</strong>: The MS2 and IR encoders were trained from scratch as sequence Transformers that treat spectral peaks as token sequences, since no suitable pre-trained models exist for chemical spectra.</p>
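<p>The paper does not publish the tokenizer itself, but "peaks as token sequences" can be sketched by binning each (m/z, intensity) peak into a discrete ID. The bin width, intensity levels, and m/z range below are illustrative assumptions:</p>

```python
def tokenize_spectrum(peaks, mz_bin=1.0, intensity_levels=10, max_mz=500.0):
    """Sketch of a spectral tokenizer: each (m/z, intensity) peak becomes one
    token ID combining an m/z bin and a quantized intensity level.
    Bin sizes and ranges here are illustrative, not the paper's settings."""
    n_mz_bins = int(max_mz / mz_bin)
    tokens = []
    for mz, intensity in sorted(peaks):
        mz_idx = min(int(mz / mz_bin), n_mz_bins - 1)
        level = min(int(intensity * intensity_levels), intensity_levels - 1)
        tokens.append(mz_idx * intensity_levels + level)
    return tokens

# Toy MS2 spectrum: (m/z, relative intensity in [0, 1]) pairs.
spectrum = [(77.04, 0.35), (105.03, 1.0), (51.02, 0.12)]
token_ids = tokenize_spectrum(spectrum)
```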
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Accuracy (Acc)</strong> for recognition tasks</li>
<li><strong>BLEU-2/4</strong> and <strong>METEOR</strong> for captioning tasks</li>
<li><strong>AUC-ROC</strong> for property prediction (classification)</li>
</ul>
<p><strong>Code Availability</strong>: The adapted code for evaluating on ChemLLMBench and their custom spectral recognition tasks is closed-source.</p>
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>ChemLLMBench</strong>: Adapted for multimodal inputs across molecule captioning, property prediction, and reaction understanding</li>
<li><strong>MoleculeNet</strong>: Standard molecular property prediction tasks (BACE, BBBP, ClinTox, HIV, Tox21)</li>
<li><strong>USPTO</strong>: Reaction prediction and retrosynthesis tasks</li>
<li><strong>Custom Spectral Tasks</strong>: Novel evaluations requiring spectral interpretation</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Note</strong>: The type and quantity of GPUs used, along with the total training wall-time, were not published.</p>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Total Batch Size</strong>: 256</li>
<li><strong>Epochs</strong>: 3</li>
<li><strong>Optimizer</strong>: AdamW</li>
</ul>
<p><strong>Modality-Specific Learning Rates (Peak)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Learning Rate</th>
          <th>Feature Dimension</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph</td>
          <td>1e-5</td>
          <td>300</td>
      </tr>
      <tr>
          <td>Conformation</td>
          <td>2e-4</td>
          <td>512</td>
      </tr>
      <tr>
          <td>Image</td>
          <td>2e-3</td>
          <td>1024</td>
      </tr>
      <tr>
          <td>MS2 / IR</td>
          <td>2e-4</td>
          <td>768</td>
      </tr>
  </tbody>
</table>
<p><strong>Note</strong>: The different learning rates reflect the varying degrees of domain adaptation required: image features (from general-domain CLIP) need more adaptation than graph features (from the chemistry-specific Mole-BERT).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/OpenDFM/ChemDFM-X">ChemDFM-X (GitHub)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Inference code only; training and data generation scripts are closed</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/OpenDFM/ChemDFM-X-v1.0-13B">ChemDFM-X-v1.0-13B (HuggingFace)</a></td>
          <td>Model</td>
          <td>AGPL-3.0</td>
          <td>13B parameter multimodal model weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhao, Z., Chen, B., Li, J., Chen, L., Wen, L., Wang, P., Zhu, Z., Zhang, D., Wan, Z., Li, Y., Dai, Z., Chen, X., &amp; Yu, K. (2024). ChemDFM-X: Towards Large Multimodal Model for Chemistry. <em>Science China Information Sciences</em>, 67(12), 220109. <a href="https://doi.org/10.1007/s11432-024-4243-0">https://doi.org/10.1007/s11432-024-4243-0</a></p>
<p><strong>Publication</strong>: Science China Information Sciences, December 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2409.13194">arXiv Version</a></li>
<li><a href="https://github.com/OpenDFM/ChemDFM-X">Code Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhaoChemDFMXLargeMultimodal2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemDFM-X}}: {{Towards Large Multimodal Model}} for {{Chemistry}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhao, Zihan and Chen, Bo and Li, Jingpiao and Chen, Lu and Wen, Liyang and Wang, Pengyu and Zhu, Zichen and Zhang, Danyang and Wan, Ziping and Li, Yansi and Dai, Zhongyang and Chen, Xin and Yu, Kai}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Science China Information Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{220109}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1007/s11432-024-4243-0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2409.13194}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RFL: Simplifying Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/rfl/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/rfl/</guid><description>Novel Ring-Free Language representation and Molecular Skeleton Decoder architecture for improved optical chemical structure recognition from images.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chang, Q., Chen, M., Pi, C., et al. (2024). RFL: Simplifying Chemical Structure Recognition with Ring-Free Language. <em>arXiv preprint arXiv:2412.07594</em>. <a href="https://doi.org/10.48550/arXiv.2412.07594">https://doi.org/10.48550/arXiv.2412.07594</a></p>
<p><strong>Publication</strong>: arXiv preprint (December 2024)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/JingMog/RFL-MSD">Official Code Repository</a></li>
</ul>
<h2 id="methodological-contribution">Methodological Contribution</h2>
<p>This is a <strong>Methodological</strong> paper ($\Psi_{\text{Method}}$). It introduces a novel representation system (Ring-Free Language) and a specialized neural architecture (Molecular Skeleton Decoder) designed to solve specific limitations in converting 2D images to 1D chemical strings. The paper validates this method through direct comparison with state-of-the-art baselines and ablation studies.</p>
<h2 id="motivation-limitations-of-1d-serialization">Motivation: Limitations of 1D Serialization</h2>
<p>Current Optical Chemical Structure Recognition (OCSR) methods typically rely on &ldquo;unstructured modeling,&rdquo; where 2D molecular graphs are flattened into 1D strings like SMILES or SSML. While simple, these linear formats struggle to explicitly capture complex spatial relationships, particularly in molecules with multiple rings and branches. End-to-end models often fail to &ldquo;understand&rdquo; the graph structure when forced to predict these implicit 1D sequences, leading to error accumulation in complex scenarios.</p>
<h2 id="innovation-ring-free-language-rfl-and-molecular-skeleton-decoder-msd">Innovation: Ring-Free Language (RFL) and Molecular Skeleton Decoder (MSD)</h2>
<p>The authors propose two primary contributions to decouple spatial complexity:</p>
<ol>
<li><strong>Ring-Free Language (RFL)</strong>: A divide-and-conquer representation that splits a molecular graph $G$ into three explicit components: a molecular skeleton $\mathcal{S}$, individual ring structures $\mathcal{R}$, and branch information $\mathcal{F}$. This allows rings to be collapsed into &ldquo;SuperAtoms&rdquo; or &ldquo;SuperBonds&rdquo; during initial parsing.</li>
<li><strong>Molecular Skeleton Decoder (MSD)</strong>: A hierarchical architecture that progressively predicts the skeleton first, then the individual rings (using SuperAtom features as conditions), and finally classifies the branch connections.</li>
</ol>
<h2 id="methodology-and-experiments">Methodology and Experiments</h2>
<p>The method was evaluated on both handwritten and printed chemical structures against two baselines: DenseWAP (Zhang et al. 2018) and RCGD (Hu et al. 2023).</p>
<ul>
<li><strong>Datasets</strong>:
<ul>
<li><strong>EDU-CHEMC</strong>: ~49k handwritten samples (challenging, diverse styles)</li>
<li><strong>Mini-CASIA-CSDB</strong>: ~89k printed samples (from ChEMBL)</li>
<li><strong>Synthetic Complexity Dataset</strong>: A custom split of ChEMBL data grouped by structural complexity (atoms + bonds + rings) to test generalization</li>
</ul>
</li>
<li><strong>Ablation Studies</strong>: Tested the necessity of the MSD architecture vs. standard decoders and the impact of the <code>[conn]</code> token for filtering branch candidates</li>
</ul>
<h2 id="outcomes-and-conclusions">Outcomes and Conclusions</h2>
<ul>
<li><strong>SOTA Performance</strong>: The proposed method (MSD-RCGD) outperformed the previous state-of-the-art (RCGD) on both handwritten (95.38% Exact Match) and printed (95.58% Exact Match) datasets.</li>
<li><strong>Universal Improvement</strong>: Applying MSD/RFL to an older baseline (DenseWAP) improved its accuracy significantly (e.g., from 81.9% to 87.6% on complex sets), showing the method is model-agnostic.</li>
<li><strong>Complexity Handling</strong>: The method showed robust generalization on unseen high-complexity molecules, where standard 1D baselines failed completely (0% accuracy for standard DenseWAP vs. ~30% for MSD version on the hardest tier).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors utilized one handwritten and one printed dataset, plus a synthetic set for stress-testing complexity.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td><strong>EDU-CHEMC</strong></td>
          <td>48,998 Train / 2,992 Test</td>
          <td>Handwritten images from educational scenarios</td>
      </tr>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td><strong>Mini-CASIA-CSDB</strong></td>
          <td>89,023 Train / 8,287 Test</td>
          <td>Printed images rendered from ChEMBL using RDKit</td>
      </tr>
      <tr>
          <td><strong>Generalization</strong></td>
          <td><strong>ChEMBL Subset</strong></td>
          <td>5 levels of complexity</td>
          <td>Custom split based on the complexity score $N_{atom} + N_{bond} + 12 \times N_{ring}$</td>
      </tr>
  </tbody>
</table>
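<p>The complexity split in the last row follows directly from the stated formula; a minimal sketch (the five-tier boundaries below are illustrative placeholders, not the paper's cut-offs):</p>

```python
def complexity(n_atoms, n_bonds, n_rings):
    """Structural complexity score from the table: atoms + bonds + 12 * rings."""
    return n_atoms + n_bonds + 12 * n_rings

def tier(score, boundaries=(30, 60, 90, 120)):
    """Assign one of 5 complexity levels; these cut-offs are illustrative."""
    for level, bound in enumerate(boundaries, start=1):
        if score < bound:
            return level
    return len(boundaries) + 1

# Benzene: 6 atoms, 6 bonds, 1 ring -> 6 + 6 + 12 = 24.
benzene_score = complexity(n_atoms=6, n_bonds=6, n_rings=1)
```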
<h3 id="algorithms">Algorithms</h3>
<p><strong>RFL Splitting (Encoding)</strong>:</p>
<ol>
<li><strong>Detect Rings</strong>: Use DFS to find all non-nested rings $\mathcal{R}$.</li>
<li><strong>Determine Adjacency ($\gamma$)</strong>: Calculate shared edges between rings.</li>
<li><strong>Merge</strong>:
<ul>
<li>If $\gamma(r_i) = 0$ (isolated), merge ring into a <strong>SuperAtom</strong> node.</li>
<li>If $\gamma(r_i) &gt; 0$ (adjacent), merge ring into a <strong>SuperBond</strong> edge.</li>
</ul>
</li>
<li><strong>Update</strong>: Record connection info in $\mathcal{F}$ and remove ring details from the main graph to form Skeleton $\mathcal{S}$.</li>
</ol>
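<p>Steps 2–3 can be sketched under the assumption that each detected ring is given as a set of edges over atom indices (a simplification, not the authors' implementation):</p>

```python
def shared_edges(ring, others):
    """gamma(r): number of edges this ring shares with any other detected ring."""
    other_edges = set().union(*others) if others else set()
    return len(ring & other_edges)

def classify_rings(rings):
    """Classify each ring (a set of frozenset edges) as a SuperAtom (isolated)
    or a SuperBond (edge-sharing with another ring), per the RFL merge rule."""
    labels = []
    for i, ring in enumerate(rings):
        others = [r for j, r in enumerate(rings) if j != i]
        gamma = shared_edges(ring, others)
        labels.append("SuperBond" if gamma > 0 else "SuperAtom")
    return labels

edge = lambda a, b: frozenset((a, b))
# A naphthalene-like fused pair sharing the (4, 5) edge, plus an isolated ring.
ring_a = {edge(0, 1), edge(1, 2), edge(2, 3), edge(3, 4), edge(4, 5), edge(5, 0)}
ring_b = {edge(4, 5), edge(5, 6), edge(6, 7), edge(7, 8), edge(8, 9), edge(9, 4)}
ring_c = {edge(10, 11), edge(11, 12), edge(12, 10)}
labels = classify_rings([ring_a, ring_b, ring_c])
```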
<p><strong>MSD Decoding</strong>:</p>
<ul>
<li><strong>Hierarchical Prediction</strong>: The model predicts the Skeleton $\mathcal{S}$ first.</li>
<li><strong>Contextual Ring Prediction</strong>: When a SuperAtom/Bond token is predicted, its hidden state $f^s$ is stored. After the skeleton is finished, $f^s$ is used as a condition to autoregressively decode the specific ring structure.</li>
<li><strong>Token <code>[conn]</code></strong>: A special token separates connected ring bonds from unconnected ones to sparsify the branch classification task.</li>
</ul>
<h3 id="models">Models</h3>
<p>The architecture follows a standard Image-to-Sequence pattern but with a forked decoder.</p>
<ul>
<li><strong>Encoder</strong>: DenseNet (Growth rate=24, Depth=32 per block)</li>
<li><strong>Decoder (MSD)</strong>:
<ul>
<li><strong>Core</strong>: GRU with Attention (Hidden dim=256, Embedding dim=256, Dropout=0.15)</li>
<li><strong>Skeleton Module</strong>: Autoregressively predicts sequence tokens. Uses Maxout activation.</li>
<li><strong>Branch Module</strong>: A binary classifier (MLP) taking concatenated features of skeleton bonds $f_{bs}$ and ring bonds $f_{br}$ to predict connectivity matrix $\mathcal{F}$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>: $\mathcal{L} = \lambda_1 \mathcal{L}_{ce} + \lambda_2 \mathcal{L}_{cls}$ (where $\lambda=1$)</li>
</ul>
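<p>A minimal numeric sketch of the joint objective on toy probabilities (the real losses operate on decoder logits over the RFL vocabulary):</p>

```python
import math

def cross_entropy(probs, target_idx):
    """L_ce: negative log-likelihood of the ground-truth skeleton token."""
    return -math.log(probs[target_idx])

def binary_ce(p, y):
    """L_cls: binary cross-entropy for one branch-connectivity decision."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def rfl_loss(token_probs, token_targets, branch_probs, branch_targets,
             lam1=1.0, lam2=1.0):
    """L = lambda_1 * L_ce + lambda_2 * L_cls (lambda = 1 in the paper),
    averaged over predictions here for illustration."""
    l_ce = sum(cross_entropy(p, t)
               for p, t in zip(token_probs, token_targets)) / len(token_targets)
    l_cls = sum(binary_ce(p, y)
                for p, y in zip(branch_probs, branch_targets)) / len(branch_targets)
    return lam1 * l_ce + lam2 * l_cls

loss = rfl_loss(
    token_probs=[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], token_targets=[0, 1],
    branch_probs=[0.9, 0.2], branch_targets=[1, 0],
)
```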
<h3 id="evaluation">Evaluation</h3>
<p>Metrics focus on exact image reconstruction and structural validity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>EM (Exact Match)</strong></td>
          <td>% of images where predicted graph exactly matches ground truth.</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td><strong>Struct-EM</strong></td>
          <td>% of correctly identified chemical structures (ignoring non-chemical text).</td>
          <td>Auxiliary metric</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA Tesla V100 (32GB RAM)</li>
<li><strong>Training Configuration</strong>:
<ul>
<li>Batch size: 8 (Handwritten), 32 (Printed)</li>
<li>Epochs: 50</li>
<li>Optimizer: Adam (learning rate $2\times10^{-4}$, decayed by 0.5 via MultiStepLR)</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{changRFLSimplifyingChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{RFL: Simplifying Chemical Structure Recognition with Ring-Free Language}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{RFL}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chang, Qikai and Chen, Mingjun and Pi, Changpeng and Hu, Pengfei and Zhang, Zhenrong and Ma, Jiefeng and Du, Jun and Yin, Baocai and Hu, Jinshui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2412.07594}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2412.07594}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2412.07594}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>OCSU: Optical Chemical Structure Understanding</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/ocsu/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/ocsu/</guid><description>OCSU task for translating molecular images into multi-level descriptions. Introduces Vis-CheBI20 dataset and DoubleCheck/Mol-VL architectures for chemical OCR.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fan, S., Xie, Y., Cai, B., Xie, A., Liu, G., Qiao, M., Xing, J., &amp; Nie, Z. (2025). OCSU: Optical Chemical Structure Understanding for Molecule-centric Scientific Discovery. <em>arXiv preprint arXiv:2501.15415</em>. <a href="https://doi.org/10.48550/arXiv.2501.15415">https://doi.org/10.48550/arXiv.2501.15415</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/PharMolix/OCSU">Code and Dataset (GitHub)</a></li>
</ul>
<h2 id="multi-level-chemical-understanding-method-and-resource">Multi-Level Chemical Understanding (Method and Resource)</h2>
<p>This is primarily a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a significant <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution.</p>
<ul>
<li><strong>Methodological</strong>: It proposes two novel architectures, <strong>DoubleCheck</strong> (an enhanced recognition model) and <strong>Mol-VL</strong> (an end-to-end vision-language model), to solve the newly formulated OCSU task.</li>
<li><strong>Resource</strong>: It constructs and releases <strong>Vis-CheBI20</strong>, the first large-scale dataset specifically designed for optical chemical structure understanding, containing 29.7K images and 117.7K image-text pairs.</li>
</ul>
<h2 id="the-motivation-for-ocsu-beyond-basic-graph-recognition">The Motivation for OCSU Beyond Basic Graph Recognition</h2>
<p>Existing methods for processing molecular images focus narrowly on <strong>Optical Chemical Structure Recognition (OCSR)</strong>, which translates an image solely into a machine-readable graph or SMILES string. However, SMILES strings are not chemist-friendly and lack high-level semantic context.</p>
<ul>
<li><strong>Gap</strong>: There is a lack of systems that can translate chemical diagrams into human-readable descriptions (e.g., functional groups, IUPAC names) alongside the graph structure.</li>
<li><strong>Goal</strong>: To enable <strong>Optical Chemical Structure Understanding (OCSU)</strong>, bridging the gap between visual representations and both machine/chemist-readable descriptions to support drug discovery and property prediction.</li>
</ul>
<h2 id="key-innovations-doublecheck-mol-vl-and-the-vis-chebi20-dataset">Key Innovations: DoubleCheck, Mol-VL, and the Vis-CheBI20 Dataset</h2>
<p>The paper introduces the <strong>OCSU task</strong>, enabling multi-level understanding (motif, molecule, and abstract levels). To solve this, it introduces two distinct paradigms:</p>
<ol>
<li><strong>DoubleCheck (OCSR-based)</strong>: An enhancement to standard OCSR models (like MolScribe) that performs a &ldquo;second look&rdquo; at locally ambiguous atoms. It uses attentive feature enhancement to fuse global molecular features with local features from ambiguous regions.</li>
<li><strong>Mol-VL (OCSR-free)</strong>: An end-to-end Vision-Language Model (VLM) based on Qwen2-VL. It uses multi-task learning to directly generate text descriptions from molecular images without an intermediate SMILES step.</li>
<li><strong>Vis-CheBI20 Dataset</strong>: A new benchmark specifically constructed for OCSU, deriving captions and functional group data from ChEBI-20 and PubChem.</li>
</ol>
<h2 id="methodology-and-experimental-evaluation">Methodology and Experimental Evaluation</h2>
<p>The authors evaluated both paradigms on <strong>Vis-CheBI20</strong> and existing benchmarks (USPTO, ACS) across four subtasks:</p>
<ol>
<li><strong>Functional Group Caption</strong>: Retrieval/F1 score evaluation.</li>
<li><strong>Molecule Description</strong>: Natural language generation metrics (BLEU, ROUGE, METEOR).</li>
<li><strong>IUPAC Naming</strong>: Text generation metrics (BLEU, ROUGE).</li>
<li><strong>SMILES Naming (OCSR)</strong>: Exact matching accuracy ($Acc_s$).</li>
</ol>
<p><strong>Baselines</strong>:</p>
<ul>
<li><strong>Task-Specific</strong>: MolScribe, MolVec, OSRA.</li>
<li><strong>LLM/VLM</strong>: Qwen2-VL, BioT5+, Mol-Instructions.</li>
<li><strong>Ablation</strong>: DoubleCheck vs. MolScribe backbone to test the &ldquo;feature enhancement&rdquo; mechanism.</li>
</ul>
<h2 id="results-and-conclusions-paradigm-trade-offs">Results and Conclusions: Paradigm Trade-Offs</h2>
<ul>
<li><strong>DoubleCheck Superiority</strong>: DoubleCheck outperformed the MolScribe baseline on OCSR tasks, achieving <strong>92.85%</strong> accuracy on Vis-CheBI20 (vs. 92.57%) and showing significant gains on chiral molecules in the ACS dataset (+3.12%).</li>
<li><strong>Paradigm Trade-offs</strong>:
<ul>
<li><strong>Mol-VL (OCSR-free)</strong> excelled at semantic tasks like <strong>Functional Group Captioning</strong>, outperforming the strongest baseline by <strong>7.7%</strong> in F1 score. It benefits from end-to-end learning of structural context.</li>
<li><strong>DoubleCheck (OCSR-based)</strong> performed better on <strong>IUPAC naming recall</strong> and exact SMILES recovery, as explicit graph reconstruction is more precise for rigid nomenclature than VLM generation.</li>
</ul>
</li>
<li><strong>Conclusion</strong>: Enhancing submodules improves OCSR-based paradigms, while end-to-end VLMs offer stronger semantic understanding but struggle with exact syntax generation (SMILES/IUPAC).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Vis-CheBI20 Dataset</strong></p>
<ul>
<li><strong>Source</strong>: Derived from ChEBI-20 and PubChem.</li>
<li><strong>Size</strong>: 29,700 molecular diagrams, 117,700 image-text pairs.</li>
<li><strong>Generation</strong>: Images generated from SMILES using RDKit to simulate real-world journal/patent styles.</li>
<li><strong>Splits</strong>:
<ul>
<li><strong>Training</strong>: ~26,000 images (varies slightly by task).</li>
<li><strong>Test</strong>: ~3,300 images.</li>
</ul>
</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Train Size</th>
          <th style="text-align: left">Test Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Functional Group</td>
          <td style="text-align: left">26,144</td>
          <td style="text-align: left">3,269</td>
      </tr>
      <tr>
          <td style="text-align: left">Description</td>
          <td style="text-align: left">26,407</td>
          <td style="text-align: left">3,300</td>
      </tr>
      <tr>
          <td style="text-align: left">IUPAC/SMILES</td>
          <td style="text-align: left">26,200</td>
          <td style="text-align: left">2,680</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>DoubleCheck (Attentive Feature Enhancement)</strong></p>
<ol>
<li><strong>Ambiguity Detection</strong>: Uses atom prediction confidence to identify &ldquo;ambiguous atoms&rdquo;.</li>
<li><strong>Masking</strong>: Applies a 2D Gaussian mask to the image centered on the ambiguous atom.</li>
<li><strong>Local Encoding</strong>: A Swin-B encoder ($\Phi_l$) encodes the masked image region.</li>
<li><strong>Fusion</strong>: Aligns local features ($\mathcal{F}_l$) with global features ($\mathcal{F}_g$) using a 2-layer MLP and fuses them via weighted summation.</li>
</ol>
<p>$$
\begin{aligned}
\mathcal{F}_e = \mathcal{F}_g + \text{MLP}(\mathcal{F}_g \oplus \hat{\mathcal{F}}_l) \cdot \hat{\mathcal{F}}_l
\end{aligned}
$$</p>
<ol start="5">
<li><strong>Two-Stage Training</strong>:
<ul>
<li>Stage 1: Train atom/bond predictors (30 epochs).</li>
<li>Stage 2: Train alignment/fusion modules with random Gaussian mask noise (10 epochs).</li>
</ul>
</li>
</ol>
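<p>The masking and fusion steps can be sketched numerically. Treating the MLP output as an elementwise gate is an assumption on our part (the paper only states the fusion formula), and the toy "MLP" below is an untrained stand-in:</p>

```python
import math

def gaussian_mask(h, w, cy, cx, sigma):
    """Step 2: 2D Gaussian weighting centred on the ambiguous atom's location."""
    return [[math.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * sigma ** 2))
             for x in range(w)] for y in range(h)]

def fuse(f_g, f_l_hat, mlp):
    """Step 4: F_e = F_g + MLP(F_g (+) F_l_hat) * F_l_hat, with * elementwise."""
    gate = mlp(f_g + f_l_hat)  # list concatenation stands in for feature concat
    return [g + a * l for g, a, l in zip(f_g, gate, f_l_hat)]

# Toy 4-dim features and a fixed stand-in "MLP" that halves the first 4 inputs.
f_global = [1.0, 2.0, 3.0, 4.0]
f_local = [0.5, 0.5, 0.5, 0.5]
toy_mlp = lambda v: [x * 0.5 for x in v[:4]]
f_enhanced = fuse(f_global, f_local, toy_mlp)
mask = gaussian_mask(5, 5, cy=2, cx=2, sigma=1.0)
```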
<p><strong>Mol-VL (Multi-Task VLM)</strong></p>
<ul>
<li><strong>Prompting</strong>: System prompt: &ldquo;You are working as an excellent assistant in chemistry&hellip;&rdquo;</li>
<li><strong>Tokens</strong>: Uses <code>&lt;image&gt;</code> and <code>&lt;/image&gt;</code> special tokens.</li>
<li><strong>Auxiliary Task</strong>: Functional group recognition (identifying highlighted groups) added to training to improve context learning.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>DoubleCheck</strong>:
<ul>
<li><strong>Backbone</strong>: MolScribe architecture.</li>
<li><strong>Encoders</strong>: Swin-B for both global and local atom encoding.</li>
</ul>
</li>
<li><strong>Mol-VL</strong>:
<ul>
<li><strong>Base Model</strong>: Qwen2-VL (2B and 7B versions).</li>
<li><strong>Vision Encoder</strong>: ViT with naive dynamic resolution and M-ROPE.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Key Metrics</strong>:</p>
<ul>
<li><strong>SMILES</strong>: Exact Match Accuracy ($Acc_s$), Chiral Accuracy ($Acc_c$).</li>
<li><strong>Functional Groups</strong>: F1 Score (Information Retrieval task).</li>
<li><strong>Text Generation</strong>: BLEU-2/4, METEOR, ROUGE-L.</li>
</ul>
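<p>Under stated assumptions about aggregation (the paper reports scores but not formulas), exact-match accuracy and set-level F1 can be sketched as:</p>

```python
def exact_match_accuracy(predictions, references):
    """Acc_s: fraction of predicted SMILES strings exactly matching the reference."""
    hits = sum(p == r for p, r in zip(predictions, references))
    return hits / len(references)

def functional_group_f1(predicted, reference):
    """Set-level F1 for functional-group captioning, treated as retrieval.
    Per-molecule scoring (rather than corpus-level) is an assumption here."""
    pred, ref = set(predicted), set(reference)
    tp = len(pred & ref)
    if tp == 0:
        return 0.0
    precision = tp / len(pred)
    recall = tp / len(ref)
    return 2 * precision * recall / (precision + recall)

acc = exact_match_accuracy(["CCO", "c1ccccc1", "CC(=O)O"],
                           ["CCO", "c1ccccc1", "CC(=O)N"])
f1 = functional_group_f1(["hydroxyl", "phenyl"], ["hydroxyl", "carboxyl"])
```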
<p><strong>Selected Results</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Model</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>DoubleCheck</strong></td>
          <td style="text-align: left">OCSR (Vis-CheBI20)</td>
          <td style="text-align: left">$Acc_s$</td>
          <td style="text-align: left"><strong>92.85%</strong></td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>MolScribe</strong></td>
          <td style="text-align: left">OCSR (Vis-CheBI20)</td>
          <td style="text-align: left">$Acc_s$</td>
          <td style="text-align: left">92.57%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Mol-VL-7B</strong></td>
          <td style="text-align: left">Func. Group Caption</td>
          <td style="text-align: left">F1</td>
          <td style="text-align: left"><strong>97.32%</strong></td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>DoubleCheck</strong></td>
          <td style="text-align: left">Func. Group Caption</td>
          <td style="text-align: left">F1</td>
          <td style="text-align: left">93.63%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>DoubleCheck</strong>: Trained on <strong>4 NVIDIA A100 GPUs</strong> for <strong>4 days</strong>.
<ul>
<li>Max LR: 4e-4.</li>
</ul>
</li>
<li><strong>Mol-VL</strong>: Trained on <strong>4 NVIDIA A100 GPUs</strong> for <strong>10 days</strong>.
<ul>
<li>Max LR: 1e-5, 50 epochs.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{fanOCSUOpticalChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{OCSU: Optical Chemical Structure Understanding for Molecule-centric Scientific Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{OCSU}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Fan, Siqi and Xie, Yuguang and Cai, Bowen and Xie, Ailin and Liu, Gaochao and Qiao, Mu and Xing, Jie and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolSight: OCSR with RL and Multi-Granularity Learning</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molsight/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molsight/</guid><description>A three-stage OCSR framework using SMILES pretraining, auxiliary bond/coordinate tasks, and reinforcement learning to master stereochemistry recognition.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, W., Wang, X., Feng, B., &amp; Liu, W. (2025). MolSight: Optical Chemical Structure Recognition with SMILES Pretraining, Multi-Granularity Learning and Reinforcement Learning. <em>arXiv preprint arXiv:2511.17300</em>. <a href="https://doi.org/10.48550/arXiv.2511.17300">https://doi.org/10.48550/arXiv.2511.17300</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/hustvl/MolSight">Official Repository</a></li>
</ul>
<h2 id="contribution-a-framework-for-optical-chemical-structure-recognition">Contribution: A Framework for Optical Chemical Structure Recognition</h2>
<p>This is primarily a <strong>Method</strong> paper. It proposes a novel three-stage training framework (Pretraining → Fine-tuning → RL Post-training) to improve Optical Chemical Structure Recognition (OCSR). Specifically, it introduces Group Relative Policy Optimization (GRPO) to optimize non-differentiable chemical-validity objectives.</p>
<p>It also has a <strong>Resource</strong> component, as the authors construct and release <em>Stereo-200k</em>, a dataset specifically designed to train models on challenging stereoisomeric molecules.</p>
<h2 id="motivation-resolving-stereochemical-cues">Motivation: Resolving Stereochemical Cues</h2>
<p>Existing OCSR systems struggle to accurately recognize stereochemical information (e.g., chirality, geometric isomerism) because the visual cues distinguishing stereoisomers (such as wedge and dash bonds) are subtle. Current methods often fail to capture the geometric relationships required to distinguish molecules with identical connectivity but different spatial arrangements. Accurate recognition is critical for downstream tasks like drug discovery where stereochemistry determines pharmacological effects.</p>
<h2 id="core-innovations-grpo-and-multi-granularity-learning">Core Innovations: GRPO and Multi-Granularity Learning</h2>
<p>MolSight introduces three key technical innovations:</p>
<ol>
<li><strong>Reinforcement Learning for OCSR</strong>: It is the first OCSR system to incorporate RL (specifically GRPO) to directly optimize for chemical semantic correctness.</li>
<li><strong>Multi-Granularity Learning</strong>: It employs auxiliary heads for chemical bond classification and atom localization. Unlike previous approaches that optimize these jointly, MolSight decouples the coordinate head to prevent interference with SMILES generation.</li>
<li><strong>SMILES-M Notation</strong>: A lightweight extension to SMILES to handle Markush structures (common in patents) without significant sequence length increase.</li>
</ol>
<h2 id="experimental-methodology">Experimental Methodology</h2>
<p>The authors evaluated MolSight using a rigorous mix of real and synthetic benchmarks:</p>
<ul>
<li><strong>Baselines</strong>: Compared against rule-based (OSRA, MolVec) and deep learning methods (MolScribe, MolGrapher, DECIMER).</li>
<li><strong>Benchmarks</strong>: Evaluated on real-world datasets (USPTO, Maybridge UoB, CLEF-2012, JPO) and synthetic datasets (Staker, ChemDraw, Indigo, Stereo-2K).</li>
<li><strong>Ablation Studies</strong>: Tested the impact of the bond head, coordinate head, and RL stages separately.</li>
<li><strong>Transfer Learning</strong>: Assessed the quality of learned representations by using the frozen encoder for molecular property prediction on MoleculeNet.</li>
</ul>
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>SOTA Performance</strong>: MolSight achieved 85.1% stereochemical accuracy on the USPTO dataset, significantly outperforming the previous SOTA (MolScribe) which achieved 69.0%.</li>
<li><strong>RL Effectiveness</strong>: Reinforcement learning post-training specifically improved performance on stereoisomers, raising Tanimoto similarity and exact match rates on the Stereo-2k test set.</li>
<li><strong>Robustness</strong>: The model maintained high performance even on perturbed (rotated/sheared) and low-resolution images, outperforming rule-based methods significantly in these scenarios.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training pipeline uses three distinct data sources:</p>
<ol>
<li><strong>Pre-training</strong>: <em>MolParser-7M</em>. Contains diverse images but requires the <strong>SMILES-M</strong> extension to handle Markush structures.</li>
<li><strong>Fine-tuning</strong>: <em>PubChem-1M</em> and <em>USPTO-680K</em>. Used for multi-granularity learning with bond and coordinate labels.</li>
<li><strong>RL Post-training</strong>: <em>Stereo-200k</em>. A self-collected dataset from the first 2M compounds in PubChem, filtered for chirality (&lsquo;@&rsquo;) and cis-trans isomerism (&lsquo;/&rsquo;, &lsquo;\&rsquo;). It uses 5 different RDKit drawing styles to ensure robustness.</li>
</ol>
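<p>The stereochemistry filter used to build Stereo-200k amounts to a marker scan over SMILES strings; a minimal sketch (the actual pipeline operates on PubChem records):</p>

```python
def has_stereo_markers(smiles: str) -> bool:
    """True if the SMILES carries chirality ('@') or cis/trans ('/', '\\') markers."""
    return any(marker in smiles for marker in ("@", "/", "\\"))


def filter_stereo(smiles_list: list[str]) -> list[str]:
    """Keep only molecules with explicit stereochemistry annotations."""
    return [s for s in smiles_list if has_stereo_markers(s)]
```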
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Reinforcement Learning</strong>: Uses <strong>GRPO (Group Relative Policy Optimization)</strong>.
<ul>
<li><strong>Reward Function</strong>: A linear combination of Tanimoto similarity and Stereochemical exact match.
$$ R = w_t \cdot \text{Tanimoto} + w_s \cdot \text{ExactMatch} $$
where $w_t=0.4$ and $w_s=0.6$.</li>
<li><strong>Sampling</strong>: Samples 4 completions per image with temperature 1.0 during RL training.</li>
</ul>
</li>
<li><strong>Auxiliary Tasks</strong>:
<ul>
<li><strong>Bond Classification</strong>: Concatenates hidden states of two atom queries to predict bond type via MLP.</li>
<li><strong>Atom Localization</strong>: Treated as a classification task (SimCC) but optimized using <strong>Maximum Likelihood Estimation (MLE)</strong> to account for uncertainty.</li>
</ul>
</li>
</ul>
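<p>The GRPO reward above combines the two terms linearly with $w_t=0.4$ and $w_s=0.6$. A minimal sketch: Tanimoto is computed here over generic bit sets (in practice these would be RDKit Morgan fingerprints), and <code>exact_stereo_match</code> is a hypothetical boolean stand-in for the paper&rsquo;s stereochemical exact-match check:</p>

```python
def tanimoto(fp_pred: set, fp_ref: set) -> float:
    """Tanimoto similarity between two fingerprint bit sets."""
    union = len(fp_pred | fp_ref)
    return len(fp_pred & fp_ref) / union if union else 1.0


def grpo_reward(fp_pred: set, fp_ref: set, exact_stereo_match: bool,
                w_t: float = 0.4, w_s: float = 0.6) -> float:
    """R = w_t * Tanimoto + w_s * ExactMatch, using the paper's weights."""
    return w_t * tanimoto(fp_pred, fp_ref) + w_s * float(exact_stereo_match)
```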
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder Transformer.
<ul>
<li><strong>Encoder</strong>: <strong>EfficientViT-L1</strong> (~53M params), chosen for linear attention efficiency.</li>
<li><strong>Decoder</strong>: 6-layer Transformer with <strong>RoPE</strong>, <strong>SwiGLU</strong>, and <strong>RMSNorm</strong>. Randomly initialized (no LLM weights) due to vocabulary mismatch.</li>
<li><strong>Coordinate Head</strong>: Separated from the main decoder. It adds 2 extra Transformer layers to process atom queries before prediction to improve accuracy.</li>
</ul>
</li>
<li><strong>Parameter Tuning</strong>:
<ul>
<li>Stage 3 (RL) uses <strong>LoRA</strong> (Rank=8, Alpha=16) to optimize the decoder.</li>
</ul>
</li>
</ul>
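<p>The stage-3 LoRA update adds a scaled low-rank delta to the frozen decoder weights, $W' = W + \frac{\alpha}{r} BA$ with $r=8$, $\alpha=16$. A minimal pure-Python sketch of that arithmetic over nested-list matrices (framework code would of course use tensors):</p>

```python
def lora_update(W, A, B, r: int = 8, alpha: int = 16):
    """Return W + (alpha / r) * (B @ A).

    W: d x k frozen weight; B: d x r and A: r x k are the trainable
    low-rank factors, so only r * (d + k) parameters are updated.
    """
    scale = alpha / r
    d, k = len(W), len(W[0])
    return [[W[i][j] + scale * sum(B[i][t] * A[t][j] for t in range(r))
             for j in range(k)] for i in range(d)]
```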
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Exact Match</strong>: String identity for standard SMILES.</li>
<li><strong>Tanimoto Coefficient</strong>: Fingerprint similarity for chemical semantics.</li>
<li><strong>OKS (Object Keypoint Similarity)</strong>: Used specifically for evaluating atom localization accuracy.</li>
</ul>
</li>
<li><strong>Perturbation</strong>: Robustness tested with random rotations [-5°, 5°] and xy-shearing [-0.1, 0.1].</li>
</ul>
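<p>OKS follows the COCO keypoint convention: each matched atom contributes $\exp(-d_i^2 / (2 s^2 \kappa^2))$ for distance $d_i$, object scale $s$, and tolerance constant $\kappa$. A minimal sketch for a single molecule, assuming all atoms are matched and share one tolerance constant (a simplifying assumption):</p>

```python
import math

def oks(pred_pts, gt_pts, scale: float, kappa: float = 0.1) -> float:
    """COCO-style Object Keypoint Similarity over matched atom coordinates.

    pred_pts / gt_pts: lists of (x, y) pairs in matched order;
    scale: object scale (e.g., sqrt of the molecule's bounding-box area).
    """
    terms = []
    for (px, py), (gx, gy) in zip(pred_pts, gt_pts):
        d2 = (px - gx) ** 2 + (py - gy) ** 2
        terms.append(math.exp(-d2 / (2 * scale ** 2 * kappa ** 2)))
    return sum(terms) / len(terms)
```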
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Training and inference performed on a single node.</li>
<li><strong>Processors</strong>: Intel Xeon Silver 4210R CPU.</li>
<li><strong>Accelerators</strong>: 4x <strong>NVIDIA GeForce RTX 3090/4090</strong> GPUs.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Stage 1: Batch size 512, LR $4 \times 10^{-4}$.</li>
<li>Stage 2: Batch size 256, Bond head LR $4 \times 10^{-4}$, Coord head LR $4 \times 10^{-5}$.</li>
<li>Stage 3 (RL): Batch size 64, Base LR $1 \times 10^{-4}$.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{zhang2025molsight,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolSight: Optical Chemical Structure Recognition with SMILES Pretraining, Multi-Granularity Learning and Reinforcement Learning}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wenrui Zhang and Xinggang Wang and Bin Feng and Wenyu Liu}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2511.17300}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.CV}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2511.17300}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolScribe: Image-to-Graph Molecular Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molscribe/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molscribe/</guid><description>Image-to-graph generation model for OCSR that predicts atoms, bonds, and coordinates jointly to better handle stereochemistry and abbreviations.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Qian, Y., Guo, J., Tu, Z., Li, Z., Coley, C. W., &amp; Barzilay, R. (2023). MolScribe: Robust Molecular Structure Recognition with Image-To-Graph Generation. <em>Journal of Chemical Information and Modeling</em>, 63(7), 1925-1934. <a href="https://doi.org/10.1021/acs.jcim.2c01480">https://doi.org/10.1021/acs.jcim.2c01480</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://huggingface.co/spaces/yujieq/MolScribe">Hugging Face Space</a></li>
</ul>
<h2 id="contribution-generative-image-to-graph-modelling">Contribution: Generative Image-to-Graph Modelling</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a secondary contribution to Resources ($\Psi_{\text{Resource}}$).</p>
<p>It proposes a novel architecture (image-to-graph generation) to solve the Optical Chemical Structure Recognition (OCSR) task, validating it through extensive ablation studies and comparisons against strong baselines like MolVec and DECIMER. It also contributes a new benchmark dataset of annotated images from ACS journals.</p>
<h2 id="motivation-limitations-in-existing-ocsr-pipelines">Motivation: Limitations in Existing OCSR Pipelines</h2>
<p>Translating molecular images into machine-readable graphs (OCSR) is challenging due to the high variance in drawing styles, stereochemistry conventions, and abbreviated structures found in literature.</p>
<p>Existing solutions face structural bottlenecks:</p>
<ul>
<li><strong>Rule-based systems</strong> (e.g., OSRA) rely on rigid heuristics that fail on diverse styles.</li>
<li><strong>Image-to-SMILES neural models</strong> treat the problem as captioning. They struggle with the geometric reasoning required for chirality, and because they omit explicit atom locations they cannot easily incorporate chemical constraints or verify correctness.</li>
</ul>
<h2 id="core-innovation-joint-graph-and-coordinate-prediction">Core Innovation: Joint Graph and Coordinate Prediction</h2>
<p>MolScribe introduces an <strong>Image-to-Graph</strong> generation paradigm that combines the flexibility of neural networks with the precision of symbolic constraints. It frames the task probabilistically as:</p>
<p>$$
P(G | I) = P(A | I) P(B | A, I)
$$</p>
<p>Where the model predicts a sequence of atoms $A$ given an image $I$, followed by the bonds $B$ given both the atoms and the image.</p>
<ol>
<li><strong>Explicit Graph Prediction</strong>: It predicts a sequence of atoms (with 2D coordinates) and then predicts bonds between them.</li>
<li><strong>Symbolic Constraints</strong>: It uses the predicted graph structure and coordinates to strictly determine chirality and cis/trans isomerism.</li>
<li><strong>Abbreviation Expansion</strong>: It employs a greedy algorithm to parse and expand &ldquo;superatoms&rdquo; (e.g., &ldquo;CO2Et&rdquo;) into their full atomic structure.</li>
<li><strong>Dynamic Augmentation</strong>: It introduces a data augmentation strategy that randomly substitutes functional groups with abbreviations and adds R-groups during training to improve generalization.</li>
</ol>
<h2 id="methodology-autoregressive-atoms-and-pairwise-bonds">Methodology: Autoregressive Atoms and Pairwise Bonds</h2>
<p>The authors evaluate MolScribe on synthetic and real-world datasets, focusing on <strong>Exact Match Accuracy</strong> of the canonical SMILES string. The model generates atom sequences autoregressively:</p>
<p>$$
P(A | I) = \prod_{i=1}^n P(a_i | A_{&lt;i}, I)
$$</p>
<p>To handle continuous spatial locations, atom coordinates are mapped to discrete bins (e.g., $\hat{x}_i = \lfloor \frac{x_i}{W} \times n_{\text{bins}} \rfloor$) and decoded alongside element labels. Bonds are predicted by a pairwise classifier over the hidden states of every atom pair:</p>
<p>$$
P(B | A, I) = \prod_{i=1}^n \prod_{j=1}^n P(b_{i,j} | A, I)
$$</p>
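<p>The coordinate discretization above can be sketched directly; $n_{\text{bins}} = 64$ follows the paper&rsquo;s reproducibility details, and clamping handles the boundary case $x_i = W$:</p>

```python
def discretize_coord(x: float, y: float, W: float, H: float, n_bins: int = 64):
    """Map continuous (x, y) inside a W x H image to integer bin tokens in [0, n_bins - 1]."""
    x_bin = min(int(x / W * n_bins), n_bins - 1)
    y_bin = min(int(y / H * n_bins), n_bins - 1)
    return x_bin, y_bin
```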
<ul>
<li><strong>Baselines</strong>: Compared against rule-based (MolVec, OSRA) and neural (Img2Mol, DECIMER, SwinOCSR) systems.</li>
<li><strong>Benchmarks</strong>:
<ul>
<li><strong>Synthetic</strong>: Indigo (in-domain) and ChemDraw (out-of-domain).</li>
<li><strong>Realistic</strong>: Five public benchmarks (CLEF, JPO, UOB, USPTO, Staker).</li>
<li><strong>New Dataset</strong>: 331 images from ACS Publications (journal articles).</li>
</ul>
</li>
<li><strong>Ablations</strong>: Tested performance without data augmentation, with continuous vs. discrete coordinates, and without non-atom tokens.</li>
<li><strong>Human Eval</strong>: Measured the time reduction for chemists using MolScribe to digitize molecules vs. drawing from scratch.</li>
</ul>
<h2 id="results-robust-exact-match-accuracy">Results: Robust Exact Match Accuracy</h2>
<ul>
<li><strong>State-of-the-Art Performance</strong>: MolScribe achieved <strong>76-93% accuracy</strong> across diverse benchmarks, significantly outperforming baselines (e.g., on the difficult Staker dataset, MolScribe achieved 71.9% compared to the next best 46.5%).</li>
<li><strong>Chirality Verification</strong>: Explicit geometric reasoning allowed MolScribe to predict chiral molecules significantly better than image-to-SMILES baselines. When chirality is ignored, the performance gap narrows (e.g., on Indigo, baseline accuracy rises from 94.1% to 96.3%), isolating MolScribe&rsquo;s primary advantage to geometric reasoning for stereochemistry.</li>
<li><strong>Hand-Drawn Generalization</strong>: The model achieved <strong>11.2% exact match accuracy</strong> on the DECIMER-HDM dataset, despite lacking hand-drawn images in the training set, with many errors limited to a few atom-level mismatches.</li>
<li><strong>Robustness</strong>: The model maintained high performance on perturbed images (rotation/shear), whereas rule-based systems degraded severely.</li>
<li><strong>Usability</strong>: The atom-level alignment allows for confidence visualization, and human evaluation showed it reduced digitization time from <strong>137s to 20s</strong> per molecule.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a mix of synthetic and patent data with extensive dynamic augmentation:</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td><strong>PubChem (Synthetic)</strong></td>
          <td>1M</td>
          <td>Molecules randomly sampled from PubChem and rendered via Indigo toolkit; includes atom coords.</td>
      </tr>
      <tr>
          <td>Training</td>
          <td><strong>USPTO (Patents)</strong></td>
          <td>680K</td>
          <td>Patent data lacks exact atom coordinates; relative coordinates normalized from MOLfiles to image dimensions (often introduces coordinate shifts).</td>
      </tr>
  </tbody>
</table>
<p><strong>Molecule Augmentation</strong>:</p>
<ul>
<li><strong>Functional Groups</strong>: Randomly substituted using 53 common substitution rules (e.g., replacing substructures with &ldquo;Et&rdquo; or &ldquo;Ph&rdquo;).</li>
<li><strong>R-Groups</strong>: Randomly added using vocabulary: <code>[R, R1...R12, Ra, Rb, Rc, Rd, X, Y, Z, A, Ar]</code>.</li>
<li><strong>Styles</strong>: Random variation of aromaticity (circle vs. bonds) and explicit hydrogens.</li>
</ul>
<p><strong>Image Augmentation</strong>:</p>
<ul>
<li><strong>Rendering</strong>: Randomized font (Arial, Times, Courier, Helvetica), line width, and label modes during synthetic generation.</li>
<li><strong>Perturbations</strong>: Applied rotation ($\pm 90^{\circ}$), cropping ($1\%$), padding ($40\%$), downscaling, blurring, and Salt-and-Pepper/Gaussian noise.</li>
</ul>
<p><strong>Preprocessing</strong>: Input images are resized to $384 \times 384$.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Atom Prediction (Pix2Seq-style)</strong>:
<ul>
<li>The model generates a sequence of tokens: $S^A = [l_1, \hat{x}_1, \hat{y}_1, \dots, l_n, \hat{x}_n, \hat{y}_n]$.</li>
<li><strong>Discretization</strong>: Coordinates are binned into integer tokens ($n_{bins} = 64$).</li>
<li><strong>Tokenizer</strong>: Atom-wise tokenizer splits SMILES into atoms; non-atom tokens (parentheses, digits) are kept to help structure learning.</li>
</ul>
</li>
<li><strong>Bond Prediction</strong>:
<ul>
<li>Format: Pairwise classification for every pair of predicted atoms.</li>
<li>Symmetry: For symmetric bonds (single/double), the probability is averaged as:
$$
\hat{P}(b_{i,j} = t) = \frac{1}{2} \big( P(b_{i,j} = t) + P(b_{j,i} = t) \big)
$$
For wedge bonds, which are directional, the two predictions are not averaged.</li>
</ul>
</li>
<li><strong>Abbreviation Expansion (Algorithm 1)</strong>:
<ul>
<li>A greedy algorithm connects atoms within an expanded abbreviation (e.g., &ldquo;COOH&rdquo;) until valences are full, avoiding the need for a fixed dictionary.</li>
<li><strong>Carbon Chains</strong>: Splits condensed chains such as $C_aX_b$ into explicit sequences of repeated $CX$ units.</li>
<li><strong>Nested Formulas</strong>: Recursively parses nested structures like $N(CH_3)_2$ by treating them as superatoms attached to the current backbone.</li>
<li><strong>Valence Handling</strong>: Iterates through common valences first to resolve ambiguities.</li>
</ul>
</li>
</ul>
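<p>The symmetric averaging above can be sketched over a pairwise probability table (plain nested lists here; a real implementation would operate on tensors):</p>

```python
def symmetrize_bond_probs(P):
    """Average P[i][j][t] with P[j][i][t] for symmetric bond types.

    P: n x n x T nested list of per-pair bond-type probabilities, where
    entry [i][j][t] is the probability that atoms i and j share bond type t.
    """
    n, T = len(P), len(P[0][0])
    return [[[0.5 * (P[i][j][t] + P[j][i][t]) for t in range(T)]
             for j in range(n)] for i in range(n)]
```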
<h3 id="models">Models</h3>
<p>The architecture is an encoder-decoder with a classification head:</p>
<ul>
<li><strong>Encoder</strong>: <strong>Swin Transformer (Swin-B)</strong>, pre-trained on ImageNet-22K (88M params).</li>
<li><strong>Decoder</strong>: 6-layer Transformer, 8 heads, hidden dimension 256.</li>
<li><strong>Bond Predictor</strong>: 2-layer MLP (Feedforward) with ReLU, taking concatenated atom hidden states as input.</li>
<li><strong>Training</strong>: Teacher forcing, Cross-Entropy Loss, Batch size 128, 30 epochs.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric</strong>: Exact Match of Canonical SMILES.</p>
<ul>
<li>Stereochemistry: Must match tetrahedral chirality; cis-trans ignored.</li>
<li>R-groups: Replaced with wildcards <code>*</code> or <code>[d*]</code> for evaluation.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Training performed on Linux server with <strong>96 CPUs</strong> and <strong>500GB RAM</strong>.</li>
<li><strong>GPUs</strong>: <strong>4x NVIDIA A100 GPUs</strong>.</li>
<li><strong>Training Time</strong>: Unspecified; comparative models on large datasets took &ldquo;more than one day&rdquo;.</li>
<li><strong>Inference</strong>: Requires autoregressive decoding for atoms, followed by a single forward pass for bonds.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{qianMolScribeRobustMolecular2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MolScribe}}: {{Robust Molecular Structure Recognition}} with {{Image-To-Graph Generation}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MolScribe}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Qian, Yujie and Guo, Jiang and Tu, Zhengkai and Li, Zhening and Coley, Connor W. and Barzilay, Regina}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1925--1934}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/acs.jcim.2c01480}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://pubs.acs.org/doi/10.1021/acs.jcim.2c01480}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolMole: Unified Vision Pipeline for Molecule Mining</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molmole/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molmole/</guid><description>A vision-based deep learning framework that unifies molecule detection, reaction parsing, and OCSR for page-level chemical data extraction.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chun, S., Kim, J., Jo, A., Jo, Y., &amp; Oh, S. (2025). MolMole: Molecule Mining from Scientific Literature. <em>arXiv preprint arXiv:2505.03777</em>. <a href="https://doi.org/10.48550/arXiv.2505.03777">https://doi.org/10.48550/arXiv.2505.03777</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://lgai-ddu.github.io/molmole/">Project Page</a></li>
</ul>
<h2 id="molmoles-dual-contribution-unified-ocsr-method-and-page-level-benchmarks">MolMole&rsquo;s Dual Contribution: Unified OCSR Method and Page-Level Benchmarks</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong <strong>Resource</strong> contribution.</p>
<p>It functions as a <strong>Method</strong> paper because it introduces &ldquo;MolMole,&rdquo; a unified deep learning framework that integrates molecule detection, reaction diagram parsing, and optical chemical structure recognition (OCSR) into a single pipeline. It validates this method through extensive comparisons against state-of-the-art baselines like DECIMER and OpenChemIE.</p>
<p>It also serves as a <strong>Resource</strong> paper because the authors construct and release a novel page-level benchmark dataset of 550 annotated pages (patents and articles) to address the lack of standardized evaluation metrics for full-page chemical extraction.</p>
<h2 id="addressing-the-limitations-of-fragmented-processing">Addressing the Limitations of Fragmented Processing</h2>
<p>The rapid accumulation of chemical literature has trapped valuable molecular and reaction data in unstructured formats like images and PDFs. Extracting this manually is time-consuming, while existing AI frameworks have significant limitations:</p>
<ul>
<li><strong>DECIMER</strong>: Lacks the ability to process reaction diagrams entirely.</li>
<li><strong>OpenChemIE</strong>: Relies on external layout parser models to crop elements before processing. This dependence often leads to detection failures in documents with complex layouts.</li>
<li><strong>Generative Hallucination</strong>: Existing generative OCSR models (like MolScribe) are prone to &ldquo;hallucinating&rdquo; structures or failing on complex notations like polymers.</li>
</ul>
<h2 id="a-unified-vision-pipeline-for-robust-detection">A Unified Vision Pipeline for Robust Detection</h2>
<p>MolMole introduces several architectural and workflow innovations:</p>
<ul>
<li><strong>Direct Page-Level Processing</strong>: Unlike OpenChemIE, MolMole processes full document pages directly without requiring an external layout parser, which improves robustness on complex layouts like two-column patents.</li>
<li><strong>Unified Vision Pipeline</strong>: It integrates three specialized vision models into one workflow:
<ul>
<li><strong>ViDetect</strong>: A DINO-based object detector for identifying molecular regions.</li>
<li><strong>ViReact</strong>: An RxnScribe-based model adapted for full-page reaction parsing.</li>
<li><strong>ViMore</strong>: A detection-based OCSR model that explicitly predicts atoms and bonds.</li>
</ul>
</li>
<li><strong>Hallucination Mitigation</strong>: By using a detection-based approach (ViMore), the model avoids hallucinating chemical structures and provides confidence scores.</li>
<li><strong>Advanced Notation Support</strong>: The system explicitly handles &ldquo;wavy bonds&rdquo; (variable attachments in patents) and polymer bracket notations, which confuse standard SMILES-based models.</li>
</ul>
<h2 id="page-level-benchmark-evaluation-and-unified-metrics">Page-Level Benchmark Evaluation and Unified Metrics</h2>
<p>The authors evaluated the framework on both a newly curated benchmark and existing public datasets:</p>
<ul>
<li><strong>New Benchmark Creation</strong>: They curated 550 pages (300 patents, 250 articles) fully annotated with bounding boxes, reaction roles (reactant, product, condition), and MOLfiles.</li>
<li><strong>Baselines</strong>: MolMole was compared against <strong>DECIMER 2.0</strong>, <strong>OpenChemIE</strong>, and <strong>ReactionDataExtractor 2.0</strong>.</li>
<li><strong>OCSR Benchmarking</strong>: ViMore was evaluated against DECIMER, MolScribe, and MolGrapher on four public datasets: <strong>USPTO</strong>, <strong>UOB</strong>, <strong>CLEF</strong>, and <strong>JPO</strong>.</li>
<li><strong>Metric Proposal</strong>: They introduced a combined &ldquo;End-to-End&rdquo; metric that modifies standard object detection Precision/Recall to strictly require correct SMILES conversion for a &ldquo;True Positive&rdquo;.</li>
</ul>
<p>$$ \text{True Positive (End-to-End)} = ( \text{IoU} \geq 0.5 ) \land ( \text{SMILES}_{\text{gt}} == \text{SMILES}_{\text{pred}} ) $$</p>
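<p>A minimal sketch of this combined criterion, with boxes as <code>(x1, y1, x2, y2)</code> tuples and both SMILES assumed to be pre-canonicalized:</p>

```python
def iou(a, b) -> float:
    """Intersection-over-union of two axis-aligned boxes (x1, y1, x2, y2)."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union else 0.0


def end_to_end_true_positive(box_pred, box_gt, smiles_pred, smiles_gt) -> bool:
    """A detection counts only if localization AND structure recognition are both right."""
    return iou(box_pred, box_gt) >= 0.5 and smiles_pred == smiles_gt
```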
<h2 id="state-of-the-art-extraction-outcomes">State-of-the-Art Extraction Outcomes</h2>
<ul>
<li><strong>SOTA Page-Level Performance</strong>: On the new benchmark, MolMole achieved F1 scores of <strong>89.1%</strong> (Patents) and <strong>86.8%</strong> (Articles) for the combined detection-to-conversion task, significantly outperforming DECIMER and OpenChemIE.</li>
<li><strong>Reaction Parsing</strong>: ViReact achieved an F1 score of <strong>98.0%</strong> (soft match) on patents, compared to 82.2% for the next best model (RxnScribe).</li>
<li><strong>Public Benchmarks</strong>: ViMore outperformed competitors on 3 out of 4 public OCSR datasets (CLEF, JPO, USPTO).</li>
<li><strong>Qualitative Superiority</strong>: The authors demonstrated that MolMole successfully handles multi-column reaction diagrams where cropping-based models fail and faithfully preserves layout geometry in generated MOLfiles.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training Data</strong>: The models (ViDetect and ViMore) were trained on <strong>private/proprietary datasets</strong>, which is a limitation for full reproducibility from scratch.</li>
<li><strong>Benchmark Data</strong>: The authors introduce a test set of <strong>550 pages</strong> (3,897 molecules, 1,022 reactions) derived from patents and scientific articles. This dataset is stated to be made &ldquo;publicly available&rdquo;.</li>
<li><strong>Public Evaluation Data</strong>: Standard OCSR datasets used include USPTO (5,719 images), UOB (5,740 images), CLEF (992 images), and JPO (450 images).</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pipeline Workflow</strong>: PDF → PNG Images → Parallel execution of <strong>ViDetect</strong> and <strong>ViReact</strong> → Cropping of molecular regions → <strong>ViMore</strong> conversion → Output (JSON/Excel).</li>
<li><strong>Post-Processing</strong>:
<ul>
<li><em>ViDetect</em>: Removes overlapping proposals based on confidence scores and size constraints.</li>
<li><em>ViReact</em>: Refines predictions by correcting duplicates and removing empty entities.</li>
<li><em>ViMore</em>: Assembles detected atom/bond information into structured representations (MOLfile).</li>
</ul>
</li>
</ul>
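<p>The pipeline workflow can be sketched as simple orchestration. Everything below is hypothetical: the real ViDetect, ViReact, and ViMore models are proprietary, so stub functions stand in for them and only the control flow matches the paper's description:</p>

```python
import json

# Stubs standing in for the proprietary models (hypothetical signatures).
def vi_detect(page):
    return [{"bbox": (10, 10, 50, 50), "score": 0.98}]

def vi_react(page):
    return [{"reactants": [0], "products": [], "conditions": []}]

def vi_more(region):
    return "CCO"  # placeholder for MOLfile/SMILES conversion

def crop(page, bbox):
    return ("crop", bbox)

def extract_page(page):
    # ViDetect and ViReact both run on the full page (in parallel in
    # the real system); detected regions are cropped and fed to ViMore.
    boxes = vi_detect(page)
    reactions = vi_react(page)
    molecules = [vi_more(crop(page, b["bbox"])) for b in boxes]
    return json.dumps({"molecules": molecules, "reactions": reactions})
```

<p>Running detection and reaction parsing on the full page, rather than on pre-cropped regions, is what lets the system handle multi-column diagrams that defeat cropping-based baselines.</p>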
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture Basis</th>
          <th>Task</th>
          <th>Key Feature</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ViDetect</strong></td>
          <td>DINO (DETR-based)</td>
          <td>Molecule Detection</td>
          <td>End-to-end training; avoids slow autoregressive methods.</td>
      </tr>
      <tr>
          <td><strong>ViReact</strong></td>
          <td>RxnScribe</td>
          <td>Reaction Parsing</td>
          <td>Operates on full pages; autoregressive decoder for structured sequence generation.</td>
      </tr>
      <tr>
          <td><strong>ViMore</strong></td>
          <td>Custom Vision Model</td>
          <td>OCSR</td>
          <td>Detection-based (predicts atom/bond regions).</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Molecule Detection</strong>: Evaluated using COCO metrics (AP, AR, F1) at IoU thresholds 0.50-0.95.</li>
<li><strong>Molecule Conversion</strong>: Evaluated using SMILES exact match accuracy and Tanimoto similarity.</li>
<li><strong>Combined Metric</strong>: A custom metric where a True Positive requires both IoU $\geq 0.5$ and an exact match $\text{SMILES}_{\text{gt}} = \text{SMILES}_{\text{pred}}$.</li>
<li><strong>Reaction Parsing</strong>: Evaluated using <strong>Hard Match</strong> (all components correct) and <strong>Soft Match</strong> (molecular entities only, ignoring text labels).</li>
</ul>
]]></content:encoded></item><item><title>MolGrapher: Graph-based Chemical Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molgrapher/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molgrapher/</guid><description>A graph-based deep learning approach for optical chemical structure recognition that outperforms image captioning methods.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Title</strong>: MolGrapher: Graph-based Visual Recognition of Chemical Structures</p>
<p><strong>Authors</strong>: Lucas Morin, Martin Danelljan, Maria Isabel Agea, Ahmed Nassar, Valery Weber, Ingmar Meijer, Peter Staar, Fisher Yu</p>
<p><strong>Citation</strong>: Morin, L., Danelljan, M., Agea, M. I., Nassar, A., Weber, V., Meijer, I., Staar, P., &amp; Yu, F. (2023). MolGrapher: Graph-based Visual Recognition of Chemical Structures. <em>Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)</em>, 19552-19561.</p>
<p><strong>Publication</strong>: ICCV 2023</p>
<p><strong>Links</strong>:</p>
<ul>
<li><a href="https://openaccess.thecvf.com/content/ICCV2023/html/Morin_MolGrapher_Graph-based_Visual_Recognition_of_Chemical_Structures_ICCV_2023_paper.html">Paper</a></li>
<li><a href="https://github.com/DS4SD/MolGrapher">GitHub Repository</a></li>
</ul>
<h2 id="1-contribution--type">1. Contribution / Type</h2>
<p>This is primarily a <strong>Methodological</strong> paper that proposes a novel neural architecture (MolGrapher), shifting the paradigm of Optical Chemical Structure Recognition (OCSR) from image captioning back to graph reconstruction. It also has a significant <strong>Resource</strong> component, releasing a synthetic data generation pipeline and a new large-scale benchmark (USPTO-30K) to address the scarcity of annotated real-world data.</p>
<h2 id="2-motivation">2. Motivation</h2>
<p>The automatic analysis of chemical literature is critical for accelerating drug and material discovery, but much of this information is locked in 2D images of molecular structures.</p>
<ul>
<li><strong>Problem:</strong> Existing rule-based methods are rigid, while recent deep learning methods based on &ldquo;image captioning&rdquo; (predicting <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a> strings) struggle with complex molecules and fail to exploit the natural graph structure of molecules.</li>
<li><strong>Gap:</strong> There is a lack of diverse, annotated real-world training data, and captioning models suffer from &ldquo;hallucinations&rdquo; where they predict valid SMILES that do not match the image.</li>
</ul>
<h2 id="3-novelty--core-innovation">3. Novelty / Core Innovation</h2>
<p>MolGrapher introduces a <strong>graph-based deep learning pipeline</strong> that explicitly models the molecule&rsquo;s geometry and topology.</p>
<ul>
<li><strong>Supergraph Concept:</strong> It first detects all atom keypoints and builds a &ldquo;supergraph&rdquo; of all plausible bonds.</li>
<li><strong>Hybrid Approach:</strong> It combines a ResNet-based keypoint detector with a Graph Neural Network (GNN) that classifies nodes (atoms) and edges (bonds) within the supergraph context.</li>
<li><strong>Synthetic Pipeline:</strong> A robust data generation pipeline that renders molecules with varying styles (fonts, bond widths) and augmentations to simulate real document noise.</li>
</ul>
<p>At the core of the Keypoint Detector&rsquo;s performance is the <strong>Weight-Adaptive Heatmap Regression (WAHR)</strong> loss. Since pixels without an atom drastically outnumber pixels containing an atom, WAHR loss is designed to counter the class imbalance. For ground-truth heatmap $y$ and prediction $p$:</p>
<p>$$ L_{WAHR}(p, y) = \sum_i \alpha_y (p_i - y_i)^2 $$</p>
<p>where $\alpha_y$ dynamically down-weights easily classified background pixels.</p>
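<p>A minimal sketch of the idea behind this loss. The paper's $\alpha_y$ is adaptive; here a fixed per-pixel weighting stands in as a simplified illustration of how background pixels are down-weighted:</p>

```python
import numpy as np

def wahr_loss(pred, target, background_weight=0.1):
    """Weighted MSE over heatmap pixels.

    alpha is 1.0 on atom pixels (target > 0) and a small constant on
    background pixels, so the rare positives dominate the gradient.
    The paper's actual weighting adapts to how easily each background
    pixel is classified; this fixed scheme is a simplified stand-in.
    """
    alpha = np.where(target > 0, 1.0, background_weight)
    return float(np.sum(alpha * (pred - target) ** 2))
```

<p>With a plain (unweighted) MSE, the sea of zero-valued background pixels would swamp the handful of atom peaks and the detector would learn to predict blank heatmaps.</p>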
<h2 id="4-methodology--experiments">4. Methodology &amp; Experiments</h2>
<p>The authors evaluated MolGrapher against both rule-based (OSRA, MolVec) and deep learning baselines (DECIMER, Img2Mol, Image2Graph).</p>
<ul>
<li><strong>Benchmarks:</strong> Evaluated on standard datasets: USPTO, Maybridge UoB, CLEF-2012, and JPO.</li>
<li><strong>New Benchmark:</strong> Introduced and tested on <strong>USPTO-30K</strong>, split into clean, abbreviated, and large molecule subsets.</li>
<li><strong>Ablations:</strong> Analyzed the impact of synthetic augmentations, keypoint loss functions, supergraph connectivity radius, and GNN layers.</li>
<li><strong>Robustness:</strong> Tested on perturbed images (rotations, shearing) to mimic scanned patent quality.</li>
</ul>
<p>The key mathematical formulation in the model involves the node context updates in the GNN mechanism. Messages $m_{j\to i}$ from neighboring nodes are aggregated to form a topological update $\Delta x_i$, which is then combined with visual updates.</p>
<h2 id="5-results--conclusions">5. Results &amp; Conclusions</h2>
<p>MolGrapher achieved state-of-the-art performance, significantly outperforming image captioning methods on standard benchmarks.</p>
<ul>
<li><strong>Accuracy:</strong> It achieved <strong>91.5%</strong> accuracy on USPTO compared to 61.0% for DECIMER 2.0 (the next best synthetic-only method).</li>
<li><strong>Large Molecules:</strong> It demonstrated superior scaling, correctly recognizing large molecules (USPTO-10K-L) where image captioning methods like Img2Mol failed completely (0.0% accuracy).</li>
<li><strong>Generalization:</strong> The method proved robust to image perturbations and style variations without requiring fine-tuning on real data. Important caveats: The paper acknowledges MolGrapher&rsquo;s sensitivity to disconnected structures and lack of support for stereochemistry mapping.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model relies on synthetic data for training due to the scarcity of annotated real-world images.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Synthetic Data</td>
          <td>300,000 images</td>
          <td>Generated from PubChem SMILES using RDKit. Augmentations include pepper patches, random lines, and variable bond styles.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>USPTO-30K</td>
          <td>30,000 images</td>
          <td>Created by authors from USPTO patents (2001-2020). Subsets: 10K clean, 10K abbreviated, 10K large (&gt;70 atoms).</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Standard Benchmarks</td>
          <td>Various</td>
          <td>USPTO (5,719), Maybridge UoB (5,740), CLEF-2012 (992), JPO (450).</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The pipeline consists of three distinct algorithmic stages:</p>
<ol>
<li>
<p><strong>Keypoint Detection</strong>:</p>
<ul>
<li>Predicts a heatmap of atom locations using a CNN.</li>
<li>Thresholds heatmaps at the bottom 10th percentile and uses a $5\times5$ window for local maxima.</li>
<li>Uses <strong>Weight-Adaptive Heatmap Regression (WAHR)</strong> loss to handle class imbalance (background vs. atoms).</li>
</ul>
</li>
<li>
<p><strong>Supergraph Construction</strong>:</p>
<ul>
<li>Connects every detected keypoint to neighbors within a radius of $3 \times$ the estimated bond length.</li>
<li>Prunes edges with no filled pixels or if obstructed by a third keypoint.</li>
<li>Keeps a maximum of 6 bond candidates per atom.</li>
</ul>
</li>
<li>
<p><strong>Superatom Recognition</strong>:</p>
<ul>
<li>Detects &ldquo;superatom&rdquo; nodes (abbreviations like <code>COOH</code>).</li>
<li>Uses <strong>PP-OCR</strong> to transcribe the text at these node locations.</li>
</ul>
</li>
</ol>
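<p>The supergraph construction step can be sketched geometrically. This sketch implements only the radius and max-candidate rules; the pixel-based pruning (empty edges, obstruction by a third keypoint) is omitted, and the function name is my own:</p>

```python
import math

def build_supergraph(keypoints, bond_length, radius_factor=3.0, max_neighbors=6):
    """Build candidate bond edges between detected atom keypoints.

    Each keypoint is connected to neighbors within radius_factor times
    the estimated bond length, keeping at most max_neighbors closest
    candidates per atom, as in the paper's construction rules.
    """
    edges = set()
    for i, (xi, yi) in enumerate(keypoints):
        candidates = []
        for j, (xj, yj) in enumerate(keypoints):
            if i == j:
                continue
            d = math.hypot(xi - xj, yi - yj)
            if d <= radius_factor * bond_length:
                candidates.append((d, j))
        # Keep only the closest candidates; store edges undirected.
        for _, j in sorted(candidates)[:max_neighbors]:
            edges.add((min(i, j), max(i, j)))
    return sorted(edges)
```

<p>The GNN then only has to classify each candidate edge (single, double, none, &hellip;) rather than propose bonds from scratch.</p>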
<h3 id="models">Models</h3>
<p>The architecture utilizes standard backbones tailored for specific sub-tasks:</p>
<ul>
<li><strong>Keypoint Detector</strong>: <strong>ResNet-18</strong> backbone with $8\times$ dilation to preserve spatial resolution.</li>
<li><strong>Node Classifier</strong>: <strong>ResNet-50</strong> backbone with $2\times$ dilation for extracting visual features at node locations.</li>
<li><strong>Graph Neural Network</strong>: A custom GNN that updates node embeddings based on visual features and neighborhood context. The initial node embedding combines the visual feature vector $v_i$ and a learnable type encoding $w_{t_i}$.</li>
<li><strong>Readout</strong>: MLPs classify nodes into atom types (e.g., C, O, N) and bond types (Single, Double, None).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Accuracy is defined strictly: the predicted molecule must have an identical <strong><a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a></strong> string to the ground truth. Stereochemistry and Markush structures are excluded from evaluation.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>MolGrapher Score</th>
          <th>SOTA Baseline (Synthetic)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>USPTO</td>
          <td><strong>91.5%</strong></td>
          <td>61.0% (DECIMER 2.0)</td>
          <td>Full USPTO benchmark</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>USPTO-10K-L</td>
          <td><strong>31.4%</strong></td>
          <td>0.0% (Img2Mol)</td>
          <td>Large molecules (&gt;70 atoms)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>JPO</td>
          <td><strong>67.5%</strong></td>
          <td>64.0% (DECIMER 2.0)</td>
          <td>Challenging, low-quality images</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: Trained on 3 NVIDIA A100 GPUs.</li>
<li><strong>Training Time</strong>: 20 epochs.</li>
<li><strong>Optimization</strong>: ADAM optimizer, learning rate 0.0001, decayed by 0.8 after 5000 iterations.</li>
<li><strong>Loss Weighting</strong>: Atom classifier loss weighted by 1; bond classifier loss weighted by 3.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{morinMolGrapherGraphbasedVisual2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MolGrapher}}: {{Graph-based Visual Recognition}} of {{Chemical Structures}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MolGrapher}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the {{IEEE}}/{{CVF International Conference}} on {{Computer Vision}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Danelljan, Martin and Agea, Maria Isabel and Nassar, Ahmed and Weber, Valery and Meijer, Ingmar and Staar, Peter and Yu, Fisher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{19552--19561}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICCV51070.2023.01793}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-10-18}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MMSSC-Net: Multi-Stage Sequence Cognitive Networks</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/mmssc-net/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/mmssc-net/</guid><description>A deep learning model for Optical Chemical Structure Recognition (OCSR) using SwinV2 and GPT-2 to convert molecular images to SMILES.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, D., Zhao, D., Wang, Z., Li, J., &amp; Li, J. (2024). MMSSC-Net: multi-stage sequence cognitive networks for drug molecule recognition. <em>RSC Advances</em>, 14(26), 18182-18191. <a href="https://doi.org/10.1039/D4RA02442G">https://doi.org/10.1039/D4RA02442G</a></p>
<p><strong>Publication</strong>: RSC Advances 2024</p>
<h2 id="contribution-a-multi-stage-architectural-pipeline">Contribution: A Multi-Stage Architectural Pipeline</h2>
<p><strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong>.
The paper proposes a novel deep learning architecture (<strong>MMSSC-Net</strong>) for Optical Chemical Structure Recognition (OCSR). It focuses on architectural innovation, specifically combining a SwinV2 visual encoder with a GPT-2 decoder, and validates this method through extensive benchmarking against existing rule-based and deep-learning baselines. It includes ablation studies to justify the choice of the visual encoder.</p>
<h2 id="motivation-addressing-noise-and-rigid-image-recognition">Motivation: Addressing Noise and Rigid Image Recognition</h2>
<ul>
<li><strong>Data Usage Gap</strong>: Drug discovery relies heavily on scientific literature, but molecular structures are often locked in vector graphics or images that computers cannot easily process.</li>
<li><strong>Limitations of Prior Work</strong>: Existing Rule-based methods are rigid and sensitive to noise. Previous Deep Learning approaches (Encoder-Decoder &ldquo;Image Captioning&rdquo; styles) often lack precision, interpretability, and struggle with varying image resolutions or large molecules.</li>
<li><strong>Need for &ldquo;Cognition&rdquo;</strong>: The authors argue that treating the image as a single isolated whole is insufficient; a model needs to &ldquo;perceive&rdquo; fine-grained details (atoms and bonds) to handle noise and varying pixel qualities effectively.</li>
</ul>
<h2 id="novelty-a-fine-grained-perception-pipeline">Novelty: A Fine-Grained Perception Pipeline</h2>
<ul>
<li><strong>Multi-Stage Cognitive Architecture</strong>: MMSSC-Net splits the task into stages:
<ol>
<li><strong>Fine-grained Perception</strong>: Detecting atom and bond sequences (including spatial coordinates) using SwinV2.</li>
<li><strong>Graph Construction</strong>: Assembling these into a molecular graph.</li>
<li><strong>Sequence Evolution</strong>: Converting the graph into a machine-readable format (SMILES).</li>
</ol>
</li>
<li><strong>Hybrid Transformer Model</strong>: It combines a hierarchical vision transformer (<strong>SwinV2</strong>) for encoding with a generative pre-trained transformer (<strong>GPT-2</strong>) and MLPs for decoding atomic and bond targets.</li>
<li><strong>Robustness Mechanisms</strong>: Random noise sequences are appended to the targets during training to improve generalization to new molecular targets.</li>
</ul>
<h2 id="methodology-and-benchmarks">Methodology and Benchmarks</h2>
<ul>
<li><strong>Baselines</strong>: Compared against eight other tools:
<ul>
<li><em>Rule-based</em>: MolVec, OSRA.</li>
<li><em>Image-Smiles (DL)</em>: ABC-Net, Img2Mol, MolMiner.</li>
<li><em>Image-Graph-Smiles (DL)</em>: Image-To-Graph, MolScribe, ChemGrapher.</li>
</ul>
</li>
<li><strong>Datasets</strong>: Evaluated on 5 diverse datasets: STAKER (synthetic), USPTO, CLEF, JPO, and UOB (real-world).</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Accuracy</strong>: Exact string match of the predicted SMILES.</li>
<li><strong>Tanimoto Similarity</strong>: Chemical similarity using Morgan fingerprints.</li>
</ul>
</li>
<li><strong>Ablation Study</strong>: Tested different visual encoders (Swin Transformer, ViT-B, ResNet-50) to validate the choice of SwinV2.</li>
<li><strong>Resolution Sensitivity</strong>: Tested model performance across image resolutions from 256px to 2048px.</li>
</ul>
<h2 id="results-and-core-outcomes">Results and Core Outcomes</h2>
<ul>
<li><strong>State-of-the-Art Performance</strong>: MMSSC-Net achieved 75-94% accuracy across datasets, outperforming baselines on most benchmarks.</li>
<li><strong>Resolution Robustness</strong>: The model maintained high accuracy (approximately 90&ndash;95%) across resolutions from 256px to 2048px, whereas baselines like Img2Mol dropped significantly at higher resolutions.</li>
<li><strong>Efficiency</strong>: The SwinV2 encoder was noted to be more efficient than ViT-B in this context.</li>
<li><strong>Limitations</strong>: The model struggles with stereochemistry (virtual vs. solid wedge bonds) and &ldquo;irrelevant text&rdquo; noise (e.g., in JPO/DECIMER datasets).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a combination of PubChem and USPTO data, augmented to handle visual variability.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>PubChem</strong></td>
          <td>1,000,000</td>
          <td>Converted from <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a> to SMILES; random sampling.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>USPTO</strong></td>
          <td>600,000</td>
          <td>Patent images; converted from MOL to SMILES.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>STAKER</strong></td>
          <td>40,000</td>
          <td>Synthetic; Avg res $256 \times 256$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>USPTO</strong></td>
          <td>4,862</td>
          <td>Real; Avg res $721 \times 432$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>CLEF</strong></td>
          <td>881</td>
          <td>Real; Avg res $1245 \times 412$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>JPO</strong></td>
          <td>380</td>
          <td>Real; Avg res $614 \times 367$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>UOB</strong></td>
          <td>5,720</td>
          <td>Real; Avg res $759 \times 416$.</td>
      </tr>
  </tbody>
</table>
<p><strong>Augmentation</strong>:</p>
<ul>
<li><strong>Image</strong>: Random perturbations using RDKit/Indigo (rotation, filling, cropping, bond thickness/length, font size, Gaussian noise).</li>
<li><strong>Molecular</strong>: Introduction of functional group abbreviations and R-substituents (dummy atoms) using SMARTS templates.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Target Sequence Formulation</strong>: The model predicts a sequence containing bounding box coordinates and type labels: $\{y_{\text{min}}, x_{\text{min}}, y_{\text{max}}, x_{\text{max}}, C_{\text{type}}\}$.</li>
<li><strong>Loss Function</strong>: Cross-entropy loss with maximum likelihood estimation.
$$ \max \sum_{i=1}^{N} \sum_{j=1}^{L} \omega_{j} \log P(t_{j}^{i}|x^{i}, t_{1}^{i}, \dots, t_{j-1}^{i}) $$</li>
<li><strong>Noise Injection</strong>: A random sequence $T_r$ is appended to the target sequence during training to improve generalization to new goals.</li>
<li><strong>Graph Construction</strong>: Atoms ($v$) and bonds ($e$) are recognized separately; bonds are defined by connecting spatial atomic coordinates.</li>
</ul>
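<p>The target-sequence formulation and noise injection above can be sketched as follows. The token layout and the noise scheme here are illustrative only; the paper does not specify the exact vocabulary or noise distribution:</p>

```python
import random

def make_target_sequence(atoms, noise_len=2, seed=0):
    """Flatten detections into {y_min, x_min, y_max, x_max, C_type}
    tokens, then append a random noise suffix T_r, as the paper does
    to improve generalization to new targets. Illustrative only."""
    seq = []
    for (y_min, x_min, y_max, x_max, c_type) in atoms:
        seq.extend([y_min, x_min, y_max, x_max, c_type])
    rng = random.Random(seed)
    seq.extend(rng.randrange(100) for _ in range(noise_len))
    return seq
```

<p>Because the decoder is trained autoregressively with cross-entropy over these tokens, predicting coordinates and type labels in one sequence couples localization and classification in a single objective.</p>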
<h3 id="models">Models</h3>
<ul>
<li><strong>Encoder</strong>: <strong>Swin Transformer V2</strong>.
<ul>
<li>Pre-trained on ImageNet-1K.</li>
<li>Window size: $16 \times 16$.</li>
<li>Parameters: 88M.</li>
<li>Input resolution: $256 \times 256$.</li>
<li>Features: Scaled cosine attention; log-space continuous position bias.</li>
</ul>
</li>
<li><strong>Decoder</strong>: <strong>GPT-2</strong> + <strong>MLP</strong>.
<ul>
<li><strong>GPT-2</strong>: Used for recognizing atom types.
<ul>
<li>Layers: 24.</li>
<li>Attention Heads: 12.</li>
<li>Hidden Dimension: 768.</li>
<li>Dropout: 0.1.</li>
</ul>
</li>
<li><strong>MLP</strong>: Used for classifying bond types (single, double, triple, aromatic, wedge).</li>
</ul>
</li>
<li><strong>Vocabulary</strong>:
<ul>
<li>Standard: 95 common numbers/characters ([0], [C], [=], etc.).</li>
<li>Extended: 2000 SMARTS-based characters for isomers/groups (e.g., &ldquo;[C2F5]&rdquo;, &ldquo;[halo]&rdquo;).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ol>
<li><strong>Accuracy</strong>: Exact match of the generated SMILES string.</li>
<li><strong>Tanimoto Similarity</strong>: Similarity of Morgan fingerprints between predicted and ground truth molecules.</li>
</ol>
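<p>Tanimoto similarity reduces to a set operation on fingerprint bits. In practice the bit sets would come from Morgan fingerprints (e.g., via RDKit); plain Python sets stand in here so the sketch is self-contained:</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto coefficient on fingerprint bit sets: |A & B| / |A | B|."""
    a, b = set(fp_a), set(fp_b)
    if not a and not b:
        return 1.0  # two empty fingerprints: conventionally identical
    return len(a & b) / len(a | b)
```

<p>Unlike exact-match accuracy, Tanimoto gives partial credit: a prediction that gets most of the scaffold right but misses one substituent still scores well above zero.</p>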
<p><strong>Key Results (Accuracy)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MMSSC-Net</th>
          <th>MolVec (Rule)</th>
          <th>ABC-Net (DL)</th>
          <th>MolScribe (DL)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Indigo</strong></td>
          <td>98.14</td>
          <td>95.63</td>
          <td>96.4</td>
          <td>99.0</td>
      </tr>
      <tr>
          <td><strong>USPTO</strong></td>
          <td>94.24</td>
          <td>88.47</td>
          <td>*</td>
          <td>51.7</td>
      </tr>
      <tr>
          <td><strong>CLEF</strong></td>
          <td>91.26</td>
          <td>81.61</td>
          <td>96.1</td>
          <td>82.9</td>
      </tr>
      <tr>
          <td><strong>UOB</strong></td>
          <td>92.71</td>
          <td>81.32</td>
          <td>*</td>
          <td>86.9</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Configuration</strong>:
<ul>
<li>Batch Size: 128.</li>
<li>Learning Rate: $4 \times 10^{-5}$.</li>
<li>Epochs: 40.</li>
</ul>
</li>
<li><strong>Inference Speed</strong>: The SwinV2 encoder demonstrated higher efficiency (faster inference time) compared to ViT-B and ResNet-50 baselines during ablation.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhangMMSSCNetMultistageSequence2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{MMSSC-Net: Multi-Stage Sequence Cognitive Networks for Drug Molecule Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{MMSSC-Net}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Dehai and Zhao, Di and Wang, Zhengwu and Li, Junhui and Li, Jin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{RSC Advances}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{26}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{18182--18191}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D4RA02442G}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://pubs.rsc.org/en/content/articlelanding/2024/ra/d4ra02442g}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MarkushGrapher: Multi-modal Markush Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/markushgrapher/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/markushgrapher/</guid><description>Multi-modal transformer combining vision, text, and layout encoding to extract complex Markush structures from patent documents with OCSR.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Morin, L., Weber, V., Nassar, A., Meijer, G. I., Van Gool, L., Li, Y., &amp; Staar, P. (2025). MarkushGrapher: Joint Visual and Textual Recognition of Markush Structures. <em>2025 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)</em>, 14505-14515. <a href="https://doi.org/10.1109/CVPR52734.2025.01352">https://doi.org/10.1109/CVPR52734.2025.01352</a></p>
<p><strong>Publication</strong>: CVPR 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/DS4SD/MarkushGrapher">GitHub Repository</a></li>
</ul>
<h2 id="overcoming-unimodal-limitations-for-markush-structures">Overcoming Unimodal Limitations for Markush Structures</h2>
<p>The automated analysis of chemical literature, particularly patents, is critical for drug discovery and material science. A major bottleneck is the extraction of <strong>Markush structures</strong>, which are complex chemical templates that represent families of molecules using a core backbone image and textual variable definitions. Existing methods are limited because they either rely solely on images (OCSR) and miss the textual context, or focus solely on text and miss the structural backbone. This creates a practical need for a unified, multi-modal approach that jointly interprets visual and textual data to accurately extract these structures for prior-art search and database construction. This paper proposes a <strong>Method</strong> and introduces a new <strong>Resource</strong> (M2S dataset) to bridge this gap.</p>
<h2 id="markushgrapher-the-multi-modal-architecture">MarkushGrapher: The Multi-Modal Architecture</h2>
<p>The core innovation is <strong>MarkushGrapher</strong>, a multi-modal architecture that jointly encodes image, text, and layout information. Key contributions include:</p>
<ul>
<li><strong>Dual-Encoder Architecture</strong>: Combines a Vision-Text-Layout (VTL) encoder (based on UDOP) with a specialized, pre-trained Optical Chemical Structure Recognition (OCSR) encoder (MolScribe). Let $E_{\text{VTL}}$ represent the combined sequence embedding and $E_{\text{OCSR}}$ represent the domain-specific visual embeddings.</li>
<li><strong>Joint Recognition</strong>: The model autoregressively generates a sequential graph representation (Optimized CXSMILES) and a substituent table simultaneously. It leverages cross-modal dependencies, allowing text to clarify ambiguous visual details like bond types.</li>
<li><strong>Synthetic Data Pipeline</strong>: A comprehensive pipeline generates realistic synthetic Markush structures (images and text) from PubChem data, overcoming the lack of labeled training data.</li>
<li><strong>Optimized Representation</strong>: A compacted version of CXSMILES moves variable groups into the SMILES string and adds explicit atom indexing to handle complex &ldquo;frequency&rdquo; and &ldquo;position&rdquo; variation indicators.</li>
</ul>
<h2 id="experimental-validation-on-the-new-m2s-benchmark">Experimental Validation on the New M2S Benchmark</h2>
<p>The authors validated their approach using the following setup:</p>
<ul>
<li><strong>Baselines</strong>: Compared against image-only chemistry models (DECIMER, MolScribe) and general-purpose multi-modal models (Uni-SMART, GPT-4o, Pixtral, Llama-3.2).</li>
<li><strong>Datasets</strong>: Evaluated on three benchmarks:
<ol>
<li><strong>MarkushGrapher-Synthetic</strong>: 1,000 generated samples.</li>
<li><strong>M2S</strong>: A new benchmark of 103 manually annotated real-world patent images.</li>
<li><strong>USPTO-Markush</strong>: 74 Markush backbone images from USPTO patents.</li>
</ol>
</li>
<li><strong>Ablation Studies</strong>: Analyzed the impact of the OCSR encoder, late fusion strategies, and the optimized CXSMILES format.</li>
</ul>
<h2 id="reaching-state-of-the-art-exact-matches">Reaching State-of-the-Art Exact Matches</h2>
<ul>
<li><strong>Overall Performance</strong>: MarkushGrapher significantly outperformed all baselines. On the M2S benchmark, it achieved a 38% Exact Match on CXSMILES (compared to 21% for MolScribe) and 29% Exact Match on tables.</li>
<li><strong>Complex Feature Handling</strong>: The model demonstrates a unique capability to reliably recognize complex Markush features like frequency variation (&lsquo;Sg&rsquo;) and position variation (&lsquo;m&rsquo;) indicators; the baselines scored near zero on these specific sub-tasks.</li>
<li><strong>Cross-Modal Reasoning</strong>: Qualitative analysis verified the model can correctly infer visual details (such as bond order) that appear ambiguous in the image but become apparent with the text description.</li>
<li><strong>Robustness</strong>: The model generalizes well to real-world data despite being trained purely on synthetic data.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong></p>
<ul>
<li><strong>Source</strong>: Synthetic dataset generated from PubChem SMILES.</li>
<li><strong>Size</strong>: 210,000 synthetic images.</li>
<li><strong>Pipeline</strong>:
<ol>
<li><strong>Selection</strong>: Sampled SMILES from PubChem based on substructure diversity.</li>
<li><strong>Augmentation</strong>: SMILES augmented to artificial CXSMILES using RDKit (inserting variable groups, frequency indicators).</li>
<li><strong>Rendering</strong>: Images rendered using Chemistry Development Kit (CDK) with randomized drawing parameters (font, bond width, spacing).</li>
<li><strong>Text Generation</strong>: Textual definitions generated using manual templates extracted from patents; 10% were paraphrased using Mistral-7B-Instruct-v0.3 to increase diversity.</li>
<li><strong>OCR</strong>: Bounding boxes extracted via a custom SVG parser aligned with MOL files.</li>
</ol>
</li>
</ul>
<p><strong>Evaluation Data</strong></p>
<ul>
<li><strong>M2S Dataset</strong>: 103 images from USPTO, EPO, and WIPO patents (1999-2023), manually annotated with CXSMILES and substituent tables.</li>
<li><strong>USPTO-Markush</strong>: 74 images from USPTO patents (2010-2016).</li>
<li><strong>MarkushGrapher-Synthetic</strong>: 1,000 samples generated via the pipeline.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimized CXSMILES</strong>:
<ul>
<li><strong>Compression</strong>: Variable groups moved from the extension block to the main SMILES string as special atoms to reduce sequence length.</li>
<li><strong>Indexing</strong>: Atom indices appended to each atom (e.g., <code>C:1</code>) to explicitly link the graph to the extension block (crucial for <code>m</code> and <code>Sg</code> sections).</li>
<li><strong>Vocabulary</strong>: Specific tokens used for atoms and bonds.</li>
</ul>
</li>
<li><strong>Augmentation</strong>: Standard image augmentations (shift, scale, blur, pepper noise, random lines) and OCR text augmentations (character substitution/insertion/deletion).</li>
</ul>
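<p>As a toy illustration of the indexing idea (an assumption-laden sketch, not the paper's exact grammar: real CXSMILES atom maps use bracket-atom syntax, and the regex below only covers a few organic-subset elements), one can append <code>:i</code> indices to each atom of a plain SMILES so the extension block can reference atoms unambiguously:</p>

```python
import re

# Toy atom matcher: two-letter elements first, then a few single-letter ones.
# Aromatic (lowercase) atoms and bracket atoms are deliberately ignored here.
ATOM = re.compile(r"Cl|Br|[BCNOPSFI]")

def index_atoms(smiles: str) -> str:
    """Append ':i' (1-based) to each matched atom, mimicking the explicit
    atom indexing used to link the graph to the CXSMILES extension block."""
    counter = 0
    def tag(match: re.Match) -> str:
        nonlocal counter
        counter += 1
        return f"{match.group(0)}:{counter}"
    return ATOM.sub(tag, smiles)

indexed = index_atoms("CCO")  # "C:1C:2O:3"
```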
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder Transformer.
<ul>
<li><strong>VTL Encoder</strong>: T5-large encoder (initialized from UDOP) that processes image patches, text tokens, and layout (bounding boxes).</li>
<li><strong>OCSR Encoder</strong>: Vision encoder from MolScribe (Swin Transformer), frozen during training.</li>
<li><strong>Text Decoder</strong>: T5-large decoder.</li>
</ul>
</li>
<li><strong>Fusion Strategy</strong>: <strong>Late Fusion</strong>. The multi-modal alignment explicitly combines the VTL text-layout features with the specialized chemical vision features: the fused representation $H_{\text{fused}}$ concatenates the VTL output $e_1$ with the projected OCSR output $e_2$ before decoding:
$$ H_{\text{fused}} = [e_1 \oplus \text{Proj}(e_2)] $$</li>
<li><strong>Parameters</strong>: 831M total (744M trainable).</li>
</ul>
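<p>A shape-level sketch of the fusion step in pure Python (plain lists standing in for tensors; reading $\oplus$ as sequence-wise concatenation and $\text{Proj}$ as a learned linear map is an interpretation of the formula above, not a detail confirmed by the paper):</p>

```python
def matvec(W, v):
    """Apply projection matrix W (one row per output dimension) to vector v."""
    return [sum(w_i * v_i for w_i, v_i in zip(row, v)) for row in W]

def late_fusion(e1, e2, W):
    """H_fused = [e1 (+) Proj(e2)]: project each OCSR token embedding into
    the VTL dimension, then concatenate along the sequence axis."""
    return e1 + [matvec(W, tok) for tok in e2]

# Toy example: two VTL tokens of dim 2, one OCSR token of dim 3, projection 3->2
e1 = [[1.0, 0.0], [0.0, 1.0]]
e2 = [[1.0, 2.0, 3.0]]
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]  # keeps the first two components
fused = late_fusion(e1, e2, W)  # 3 tokens, each of dim 2
```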
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>CXSMILES Exact Match (EM)</strong>: Requires perfect match of SMILES string, variable groups, <code>m</code> sections, and <code>Sg</code> sections (ignoring stereochemistry).</li>
<li><strong>Tanimoto Score</strong>: Similarity of RDKit DayLight fingerprints (Markush features removed).</li>
<li><strong>Table Exact Match</strong>: All variable groups and substituents must match.</li>
<li><strong>Table F1-Score</strong>: Aggregated recall and precision of substituents per variable group.</li>
</ul>
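<p>The paper computes this similarity over RDKit DayLight fingerprints; the metric itself, sketched over fingerprints represented as sets of on-bit indices:</p>

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto similarity between two fingerprints given as sets of
    on-bit indices: |A intersect B| / |A union B|."""
    if not fp_a and not fp_b:
        return 1.0  # convention: two empty fingerprints are identical
    return len(fp_a & fp_b) / len(fp_a | fp_b)

# Toy on-bit sets standing in for real hashed fingerprints
sim = tanimoto({1, 2, 3, 4}, {2, 3, 4, 5})  # 3 shared / 5 total = 0.6
```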
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Trained on a single NVIDIA H100 GPU.</li>
<li><strong>Training Config</strong>: 10 epochs, batch size of 10, learning rate 5e-4, 100 warmup steps, weight decay 1e-3.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{morinMarkushGrapherJointVisual2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{MarkushGrapher: Joint Visual and Textual Recognition of Markush Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{MarkushGrapher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2025 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Weber, Valéry and Nassar, Ahmed and Meijer, Gerhard Ingmar and Van Gool, Luc and Li, Yawei and Staar, Peter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jun,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{14505--14515}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/CVPR52734.2025.01352}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image2InChI: SwinTransformer for Molecular Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image2inchi/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image2inchi/</guid><description>Deep learning model using improved SwinTransformer encoder and attention-based feature fusion to convert molecular images to InChI strings.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, D., Xu, X., Pan, J., Gao, W., &amp; Zhang, S. (2024). Image2InChI: Automated Molecular Optical Image Recognition. <em>Journal of Chemical Information and Modeling</em>, 64(9), 3640-3649. <a href="https://doi.org/10.1021/acs.jcim.3c02082">https://doi.org/10.1021/acs.jcim.3c02082</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling (JCIM) 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.kaggle.com/c/bms-molecular-translation">BMS Dataset (Kaggle)</a></li>
</ul>
<p><strong>Note</strong>: These notes are based on the Abstract and Supporting Information files only.</p>
<h2 id="image2inchi-as-a-methodological-innovation">Image2InChI as a Methodological Innovation</h2>
<p>This is a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong>. It proposes a specific new deep learning architecture (&ldquo;Image2InChI&rdquo;) to solve the task of Optical Chemical Structure Recognition (OCSR). The rhetorical focus is on engineering a system that outperforms baselines on specific metrics (InChI accuracy, MCS accuracy) and providing a valuable reference for future algorithmic work.</p>
<h2 id="bottlenecks-in-chemical-literature-digitization">Bottlenecks in Chemical Literature Digitization</h2>
<p>The accurate digitization of chemical literature is a bottleneck in AI-driven drug discovery. Chemical structures in patents and papers exist as optical images (pixels), but machine learning models require machine-readable string representations (like <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a> or <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>). Efficiently and automatically bridging this gap is a prerequisite for large-scale data mining in chemistry.</p>
<h2 id="hierarchical-swintransformer-and-attention-integration">Hierarchical SwinTransformer and Attention Integration</h2>
<p>The core novelty is the <strong>Image2InChI</strong> architecture, which integrates:</p>
<ol>
<li><strong>Improved SwinTransformer Encoder</strong>: Uses a hierarchical vision transformer to capture image features.</li>
<li><strong>Feature Fusion with Attention</strong>: A novel network designed to integrate image patch features with InChI prediction steps.</li>
<li><strong>End-to-End InChI Prediction</strong>: The architecture frames the problem as a direct image-to-sequence translation targeting InChI strings, diverging from techniques that predict independent graph components. The model is optimized using a standard cross-entropy loss over the token vocabulary:
$$ \mathcal{L}_{\text{CE}} = - \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{X}) $$
where $\mathbf{X}$ represents the input image features, $y_t$ is the predicted token, and $T$ is the sequence length.</li>
</ol>
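<p>A minimal sketch of this token-level objective, with toy per-step distributions standing in for real decoder outputs:</p>

```python
import math

def sequence_cross_entropy(step_probs, target_ids):
    """Sum of -log P(y_t | y_<t, X) over the target sequence.
    step_probs[t] is the model's distribution over the vocabulary at
    decoding step t; target_ids[t] is the gold token id."""
    return -sum(math.log(probs[y]) for probs, y in zip(step_probs, target_ids))

# Two decoding steps over a 3-token vocabulary
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1]]
loss = sequence_cross_entropy(probs, [0, 1])  # -(log 0.7 + log 0.8)
```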
<h2 id="benchmarking-on-the-bms-dataset">Benchmarking on the BMS Dataset</h2>
<ul>
<li><strong>Benchmark Validation</strong>: The model was trained and tested on the <strong>BMS1000 (Bristol-Myers Squibb)</strong> dataset from a Kaggle competition.</li>
<li><strong>Ablation/Comparative Analysis</strong>: The authors compared their method against other models in the supplement.</li>
<li><strong>Preprocessing Validation</strong>: They justified their choice of denoising algorithms (8-neighborhood vs. Gaussian/Mean) to ensure preservation of bond lines while removing &ldquo;spiky point noise&rdquo;.</li>
</ul>
<h2 id="high-inchi-recognition-metrics">High InChI Recognition Metrics</h2>
<ul>
<li><strong>High Accuracy</strong>: The model achieved <strong>99.8% InChI accuracy</strong>, 94.8% Maximum Common Substructure (MCS) accuracy, and 96.2% Longest Common Subsequence (LCS) accuracy on the benchmarked dataset. It remains to be seen how well these models generalize to heavily degraded real-world patent images.</li>
<li><strong>Effective Denoising</strong>: The authors concluded that <strong>eight-neighborhood filtering</strong> is superior to mean or Gaussian filtering for this specific domain because it removes isolated noise points without blurring the fine edges of chemical bonds.</li>
<li><strong>Open Source</strong>: The authors committed to releasing the code to facilitate transparency and further research.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The primary dataset used is the <strong>BMS (Bristol-Myers Squibb) Dataset</strong>.</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Source</strong></td>
          <td>Kaggle Competition (BMS-Molecular-Translation)</td>
      </tr>
      <tr>
          <td><strong>Total Size</strong></td>
          <td>2.4 million images</td>
      </tr>
      <tr>
          <td><strong>Training Set</strong></td>
          <td>1.8 million images</td>
      </tr>
      <tr>
          <td><strong>Test Set</strong></td>
          <td>0.6 million images</td>
      </tr>
      <tr>
          <td><strong>Content</strong></td>
          <td>Each image corresponds to a unique International Chemical Identifier (<a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a>)</td>
      </tr>
  </tbody>
</table>
<p><strong>Other Datasets</strong>: The authors also utilized JPO (Japanese Patent Office), CLEF (CLEF-IP 2012), UOB (MolrecUOB), and USPTO datasets for broader benchmarking.</p>
<p><strong>Preprocessing Pipeline</strong>:</p>
<ol>
<li><strong>Denoising</strong>: <strong>Eight-neighborhood filtering</strong> (threshold &lt; 4 non-white pixels) is used to remove salt-and-pepper noise while preserving bond lines. Mean and Gaussian filtering were rejected due to blurring.</li>
<li><strong>Sequence Padding</strong>:
<ul>
<li>Analysis showed max InChI length &lt; 270.</li>
<li>Fixed sequence length set to <strong>300</strong>.</li>
<li>Tokens: <code>&lt;sos&gt;</code> (190), <code>&lt;eos&gt;</code> (191), <code>&lt;pad&gt;</code> (192) used for padding/framing.</li>
</ul>
</li>
<li><strong>Numerization</strong>: Characters are mapped to integers based on a fixed vocabulary (e.g., &lsquo;C&rsquo; -&gt; 178, &lsquo;H&rsquo; -&gt; 182).</li>
</ol>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Eight-Neighborhood Filtering (Denoising)</strong>:</p>
<p>Pseudocode logic:</p>
<ul>
<li>Iterate through every pixel.</li>
<li>Count non-white neighbors in the 3x3 grid (8 neighbors).</li>
<li>If count &lt; threshold (default 4), treat as noise and remove.</li>
</ul>
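<p>A pure-Python sketch of the filter as described above, assuming a binary image where 1 marks a non-white (ink) pixel:</p>

```python
def eight_neighborhood_filter(img, threshold=4):
    """Remove isolated noise: an ink pixel survives only if at least
    `threshold` of its 8 neighbors are also ink (1 = ink, 0 = white)."""
    h, w = len(img), len(img[0])
    out = [row[:] for row in img]
    for y in range(h):
        for x in range(w):
            if img[y][x] == 0:
                continue
            # Count ink pixels in the 3x3 neighborhood, excluding the center
            neighbors = sum(
                img[ny][nx]
                for ny in range(max(0, y - 1), min(h, y + 2))
                for nx in range(max(0, x - 1), min(w, x + 2))
                if (ny, nx) != (y, x)
            )
            if neighbors < threshold:  # too few ink neighbors -> noise
                out[y][x] = 0
    return out

img = [[1, 0, 0, 0, 0],
       [0, 1, 1, 1, 0],
       [0, 1, 1, 1, 0],
       [0, 1, 1, 1, 0],
       [0, 0, 0, 0, 0]]
clean = eight_neighborhood_filter(img)  # the lone (0,0) pixel is removed
```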
<p><strong>InChI Tokenization</strong>:</p>
<ul>
<li>InChI strings are split into character arrays.</li>
<li>Example: Vitamin C <code>InChI=1S/C6H8O6...</code> becomes <code>[&lt;sos&gt;, C, 6, H, 8, O, 6, ..., &lt;eos&gt;, &lt;pad&gt;...]</code>.</li>
<li>Mapped to integer tensor for model input.</li>
</ul>
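<p>A minimal sketch of the tokenize-numerize-pad pipeline. Only the integer codes quoted above (<code>&lt;sos&gt;</code> = 190, <code>&lt;eos&gt;</code> = 191, <code>&lt;pad&gt;</code> = 192, &lsquo;C&rsquo; &rarr; 178, &lsquo;H&rsquo; &rarr; 182) come from the paper; the remaining vocabulary entries are made up for illustration:</p>

```python
SOS, EOS, PAD, MAX_LEN = 190, 191, 192, 300

# 'C' -> 178 and 'H' -> 182 are from the paper; all other ids are invented.
vocab = {"C": 178, "H": 182, "O": 183, "I": 160, "n": 161, "h": 165,
         "S": 163, "=": 162, "/": 164, "1": 145, "6": 150, "8": 152}

def encode_inchi(inchi: str) -> list:
    """Split an InChI string into characters, map to integer ids, frame
    with <sos>/<eos>, and right-pad to the fixed length of 300."""
    ids = [SOS] + [vocab[ch] for ch in inchi] + [EOS]
    return ids + [PAD] * (MAX_LEN - len(ids))

seq = encode_inchi("InChI=1S/C6H8O6")  # truncated Vitamin C prefix
```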
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Image2InChI</p>
<ul>
<li><strong>Encoder</strong>: Improved SwinTransformer (Hierarchical Vision Transformer).</li>
<li><strong>Decoder</strong>: Transformer Decoder with patch embedding.</li>
<li><strong>Fusion</strong>: A novel &ldquo;feature fusion network with attention&rdquo; integrates the visual tokens with the sequence generation process.</li>
<li><strong>Framework</strong>: PyTorch 1.8.1.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>InChI Acc</strong>: Exact match accuracy of the predicted InChI string (Reported: 99.8%).</li>
<li><strong>MCS Acc</strong>: Maximum Common Substructure accuracy (structural similarity) (Reported: 94.8%).</li>
<li><strong>LCS Acc</strong>: Longest Common Subsequence accuracy (string similarity) (Reported: 96.2%).</li>
<li><strong>Morgan FP</strong>: Morgan Fingerprint similarity (Reported: 94.1%).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Specification</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>GPU</strong></td>
          <td>NVIDIA Tesla P100 (16GB VRAM)</td>
      </tr>
      <tr>
          <td><strong>Platform</strong></td>
          <td>MatPool cloud platform</td>
      </tr>
      <tr>
          <td><strong>CPU</strong></td>
          <td>Intel Xeon Gold 6271</td>
      </tr>
      <tr>
          <td><strong>RAM</strong></td>
          <td>32GB System Memory</td>
      </tr>
      <tr>
          <td><strong>Driver</strong></td>
          <td>NVIDIA-SMI 440.100</td>
      </tr>
      <tr>
          <td><strong>OS</strong></td>
          <td>Ubuntu 18.04</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2024image2inchi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Image2InChI: Automated Molecular Optical Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Da-zhou and Xu, Xin and Pan, Jia-heng and Gao, Wei and Zhang, Shi-rui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3640--3649}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.3c02082}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Enhanced DECIMER for Hand-Drawn Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-hand-drawn/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-hand-drawn/</guid><description>An improved encoder-decoder model (EfficientNetV2 + Transformer) converts hand-drawn chemical structures into SMILES strings using synthetic training data.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Brinkhaus, H.O., Zielesny, A. et al. (2024). Advancements in hand-drawn chemical structure recognition through an enhanced DECIMER architecture. <em>Journal of Cheminformatics</em>, 16(78). <a href="https://doi.org/10.1186/s13321-024-00872-7">https://doi.org/10.1186/s13321-024-00872-7</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://pypi.org/project/decimer/">PyPi Package</a></li>
<li><a href="https://doi.org/10.5281/zenodo.10781330">Model Weights (Zenodo)</a></li>
</ul>
<h2 id="method-contribution-architectural-optimization">Method Contribution: Architectural Optimization</h2>
<p>This is a <strong>Method</strong> paper. It proposes an enhanced neural network architecture (EfficientNetV2 + Transformer) specifically designed to solve the problem of recognizing hand-drawn chemical structures. The primary contribution is architectural optimization and a data-driven training strategy, validated through ablation studies (comparing encoders) and benchmarked against existing rule-based and deep learning tools.</p>
<h2 id="motivation-digitizing-dark-chemical-data">Motivation: Digitizing &ldquo;Dark&rdquo; Chemical Data</h2>
<p>Chemical information in legacy laboratory notebooks and modern tablet-based inputs often exists as hand-drawn sketches.</p>
<ul>
<li><strong>Gap:</strong> Existing Optical Chemical Structure Recognition (OCSR) tools (particularly rule-based ones) lack robustness and fail when images have variability in style, line thickness, or noise.</li>
<li><strong>Need:</strong> There is a critical need for automated tools to digitize this &ldquo;dark data&rdquo; effectively to preserve it and make it machine-readable and searchable.</li>
</ul>
<h2 id="core-innovation-decoder-only-design-and-synthetic-scaling">Core Innovation: Decoder-Only Design and Synthetic Scaling</h2>
<p>The core novelty is the <strong>architectural enhancement</strong> and <strong>synthetic training strategy</strong>:</p>
<ol>
<li><strong>Decoder-Only Transformer:</strong> Moving from a full Transformer (encoder-decoder) to a <strong>Decoder-only</strong> setup significantly improved performance.</li>
<li><strong>EfficientNetV2 Integration:</strong> Replacing standard CNNs or EfficientNetV1 with <strong>EfficientNetV2-M</strong> provided better feature extraction and 2x faster training speeds.</li>
<li><strong>Scale of Synthetic Data:</strong> The authors demonstrate that scaling synthetic training data (up to 152 million images generated by RanDepict) directly correlates with improved generalization to real-world hand-drawn images, without ever training on real hand-drawn data.</li>
</ol>
<h2 id="experimental-setup-ablation-and-real-world-baselines">Experimental Setup: Ablation and Real-World Baselines</h2>
<ul>
<li><strong>Model Selection (Ablation):</strong> Tested three architectures (EfficientNetV2-M + Full Transformer, EfficientNetV1-B7 + Decoder-only, EfficientNetV2-M + Decoder-only) on standard benchmarks (JPO, CLEF, USPTO, UOB).</li>
<li><strong>Data Scaling:</strong> Trained the best model on four progressively larger datasets (from 4M to 152M images) to measure performance gains.</li>
<li><strong>Real-World Benchmarking:</strong> Validated the final model on the <strong>DECIMER Hand-drawn dataset</strong> (5088 real images drawn by volunteers) and compared against 9 other tools (OSRA, MolVec, Img2Mol, MolScribe, etc.).</li>
</ul>
<h2 id="results-and-conclusions-strong-accuracy-on-hand-drawn-scans">Results and Conclusions: Strong Accuracy on Hand-Drawn Scans</h2>
<ul>
<li><strong>Strong Performance:</strong> The final DECIMER model achieved <strong>99.72% valid predictions</strong> and <strong>73.25% exact accuracy</strong> on the hand-drawn benchmark, outperforming the next best tool (MolScribe at ~8% accuracy on this specific dataset).</li>
<li><strong>Robustness:</strong> Deep learning methods outperform rule-based methods (which had &lt;3% accuracy) on hand-drawn data.</li>
<li><strong>Data Saturation:</strong> Quadrupling the dataset from 38M to 152M images yielded only marginal gains (~3% accuracy), suggesting current synthetic data strategies may be hitting a plateau.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained entirely on <strong>synthetic data</strong> generated using the <a href="https://github.com/OBrink/RanDepict">RanDepict</a> toolkit. No real hand-drawn images were used for training.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset Source</th>
          <th>Size (Molecules)</th>
          <th>Images Generated</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (Phase 1)</td>
          <td>ChEMBL-32</td>
          <td>~2.2M</td>
          <td>~4.4M - 13.1M</td>
          <td>Used for model selection</td>
      </tr>
      <tr>
          <td>Training (Phase 2)</td>
          <td>PubChem</td>
          <td>~9.5M</td>
          <td>~38M</td>
          <td>Scaling experiment</td>
      </tr>
      <tr>
          <td>Training (Final)</td>
          <td>PubChem</td>
          <td>~38M</td>
          <td><strong>152.16M</strong></td>
          <td>Final model training</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>DECIMER Hand-Drawn</td>
          <td>5,088 images</td>
          <td>N/A</td>
          <td>Real-world benchmark</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing:</strong></p>
<ul>
<li><a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a> strings length &lt; 300 characters.</li>
<li>Images resized to $512 \times 512$.</li>
<li>Images generated with and without &ldquo;hand-drawn style&rdquo; augmentations.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization:</strong> SMILES split by heavy atoms, brackets, bond symbols, and special characters. Start <code>&lt;start&gt;</code> and end <code>&lt;end&gt;</code> tokens added; padded with <code>&lt;pad&gt;</code>.</li>
<li><strong>Optimization:</strong> Adam optimizer with a custom learning rate schedule.</li>
<li><strong>Loss Function:</strong> Trained using focal loss to address class imbalance for rare tokens. The focal loss formulation reduces the relative loss for well-classified examples:
$$
\text{FL}(p_{\text{t}}) = -\alpha_{\text{t}} (1 - p_{\text{t}})^\gamma \log(p_{\text{t}})
$$</li>
<li><strong>Augmentations:</strong> RanDepict applied synthetic distortions to mimic handwriting (wobbly lines, variable thickness, etc.).</li>
</ul>
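<p>A minimal sketch of the focal loss term for a single token; $\alpha = 0.25$ and $\gamma = 2$ are the common defaults from the original focal loss paper, not necessarily DECIMER's settings:</p>

```python
import math

def focal_loss(p_t: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the
    predicted probability of the true token. The (1 - p_t)^gamma factor
    down-weights tokens the model already classifies confidently."""
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)

# A confidently-correct token contributes far less loss than an uncertain one
easy, hard = focal_loss(0.95), focal_loss(0.3)
```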
<h3 id="models">Models</h3>
<p>The final architecture (Model 3) is an Encoder-Decoder structure:</p>
<ul>
<li><strong>Encoder:</strong> <strong>EfficientNetV2-M</strong> (pretrained ImageNet backbone).
<ul>
<li>Input: $512 \times 512 \times 3$ image.</li>
<li>Output Features: $16 \times 16 \times 512$ (reshaped to sequence length 256, dimension 512).</li>
<li><em>Note:</em> The final fully connected layer of the CNN is removed.</li>
</ul>
</li>
<li><strong>Decoder:</strong> <strong>Transformer (Decoder-only)</strong>.
<ul>
<li>Layers: 6</li>
<li>Attention Heads: 8</li>
<li>Embedding Dimension: 512</li>
</ul>
</li>
<li><strong>Output:</strong> Predicted SMILES string token by token.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics used for evaluation:</p>
<ol>
<li><strong>Valid Predictions (%):</strong> Percentage of outputs that are syntactically valid SMILES.</li>
<li><strong>Exact Match Accuracy (%):</strong> Canonical SMILES string identity.</li>
<li><strong>Tanimoto Similarity:</strong> Fingerprint similarity (PubChem fingerprints) between ground truth and prediction.</li>
</ol>
<p><strong>Key Results (Hand-Drawn Dataset):</strong></p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>DECIMER (Ours)</th>
          <th>MolScribe</th>
          <th>Img2Mol</th>
          <th>OSRA (Rule-based)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid Predictions</td>
          <td><strong>99.72%</strong></td>
          <td>95.66%</td>
          <td>98.96%</td>
          <td>54.66%</td>
      </tr>
      <tr>
          <td>Exact Accuracy</td>
          <td><strong>73.25%</strong></td>
          <td>7.65%</td>
          <td>5.25%</td>
          <td>0.57%</td>
      </tr>
      <tr>
          <td>Tanimoto</td>
          <td><strong>0.94</strong></td>
          <td>0.59</td>
          <td>0.52</td>
          <td>0.17</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute:</strong> Google Cloud TPU v4-128 pod slice.</li>
<li><strong>Training Time:</strong>
<ul>
<li>EfficientNetV2-M model trained ~2x faster than EfficientNetV1-B7.</li>
<li>Average training time per epoch: 34 minutes (for Model 3 on 1M dataset subset).</li>
</ul>
</li>
<li><strong>Epochs:</strong> Models trained for 25 epochs.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanAdvancementsHanddrawnChemical2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Advancements in Hand-Drawn Chemical Structure Recognition through an Enhanced {{DECIMER}} Architecture}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Brinkhaus, Henning Otto and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{78}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00872-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Dual-Path Global Awareness Transformer (DGAT)</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/dgat/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/dgat/</guid><description>A Transformer-based OCSR model introducing dual-path modules (CGFE and SDGLA) to improve global context awareness and complex motif recognition.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, R., Ji, Y., Li, Y., &amp; Lee, S.-T. (2025). Dual-Path Global Awareness Transformer for Optical Chemical Structure Recognition. <em>The Journal of Physical Chemistry Letters</em>, 16(50), 12787-12795. <a href="https://doi.org/10.1021/acs.jpclett.5c03057">https://doi.org/10.1021/acs.jpclett.5c03057</a></p>
<p><strong>Publication</strong>: The Journal of Physical Chemistry Letters 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Drwr97/DGAT">GitHub Repository</a></li>
</ul>
<h2 id="contribution-type-deep-learning-method-for-ocsr">Contribution Type: Deep Learning Method for OCSR</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>The classification is based on the proposal of a novel deep learning architecture (DGAT) designed to address specific limitations in existing Optical Chemical Structure Recognition (OCSR) systems. The contribution is validated through rigorous benchmarking against external baselines (DeepOCSR, DECIMER, SwinOCSR) and ablation studies that isolate the impact of the new modules.</p>
<h2 id="motivation-addressing-global-context-loss">Motivation: Addressing Global Context Loss</h2>
<p>Existing multimodal fusion methods for OCSR suffer from limited awareness of global context.</p>
<ul>
<li><strong>Problem</strong>: Models often generate erroneous sequences when processing complex motifs, such as rings or long chains, due to a disconnect between local feature extraction and global structural understanding.</li>
<li><strong>Gap</strong>: Current architectures struggle to capture the &ldquo;fine-grained differences between global and local features,&rdquo; leading to topological errors.</li>
<li><strong>Practical Need</strong>: Accurate translation of chemical images to machine-readable sequences (SMILES/SELFIES) is critical for materials science and AI-guided chemical research.</li>
</ul>
<h2 id="core-innovation-dual-path-global-awareness-transformer">Core Innovation: Dual-Path Global Awareness Transformer</h2>
<p>The authors propose the <strong>Dual-Path Global Awareness Transformer (DGAT)</strong>, which redesigns the decoder with two novel mechanisms to better handle global context:</p>
<ol>
<li>
<p><strong>Cascaded Global Feature Enhancement (CGFE)</strong>: This module bridges cross-modal gaps by emphasizing global context. It concatenates global visual features with sequence features and processes them through a Cross-Modal Assimilation MLP and an Adaptive Alignment MLP to align multimodal representations. The feature enhancement conceptually computes:</p>
<p>$$ f_{\text{enhanced}} = \text{MLP}_{\text{align}}(\text{MLP}_{\text{assimilate}}([f_{\text{global}}, f_{\text{seq}}])) $$</p>
</li>
<li>
<p><strong>Sparse Differential Global-Local Attention (SDGLA)</strong>: A module that dynamically captures fine-grained differences between global and local features. It uses sequence features (embedded with global info) as queries, while utilizing local and global visual features as keys/values in parallel attention heads to generate initial multimodal features.</p>
</li>
</ol>
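<p>As a concrete, deliberately simplified illustration of the CGFE equation above, the sketch below concatenates a pooled global visual feature with per-token sequence features and cascades the result through two MLPs. All dimensions, initializations, and the ReLU nonlinearity are assumptions for illustration; the paper's actual module details may differ.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d = 64   # model width (illustrative; not the paper's value)
T = 10   # sequence length

def make_mlp(d_in, d_hidden, d_out):
    # random two-layer MLP parameters (weights and biases)
    return (rng.normal(scale=0.02, size=(d_in, d_hidden)), np.zeros(d_hidden),
            rng.normal(scale=0.02, size=(d_hidden, d_out)), np.zeros(d_out))

def mlp(x, w1, b1, w2, b2):
    # two-layer perceptron with ReLU, standing in for each MLP block
    return np.maximum(0.0, x @ w1 + b1) @ w2 + b2

def cgfe(f_global, f_seq, assimilate, align):
    # concatenate the global visual feature with every sequence position,
    # then cascade through the assimilation and alignment MLPs
    fused = np.concatenate([np.broadcast_to(f_global, f_seq.shape), f_seq], axis=-1)
    return mlp(mlp(fused, *assimilate), *align)

assimilate = make_mlp(2 * d, 2 * d, d)
align = make_mlp(d, 2 * d, d)
f_global = rng.normal(size=(d,))   # pooled global visual feature
f_seq = rng.normal(size=(T, d))    # per-token sequence features
f_enhanced = cgfe(f_global, f_seq, assimilate, align)
print(f_enhanced.shape)   # (10, 64)
```

<p>The enhanced features keep the sequence shape, so they can feed directly into the subsequent attention stage (SDGLA) as queries.</p>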
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The model was evaluated on a newly constructed dataset and compared against five major baselines.</p>
<ul>
<li><strong>Baselines</strong>: DeepOCSR, DECIMER 1.0, DECIMER V2, SwinOCSR, and MPOCSR.</li>
<li><strong>Ablation Studies</strong>:
<ul>
<li><strong>Layer Depth</strong>: Tested Transformer depths from 1 to 5 layers; 3 layers proved optimal for balancing gradient flow and parameter sufficiency.</li>
<li><strong>Beam Size</strong>: Tested inference beam sizes 1-5; size 3 achieved the best balance between search depth and redundancy.</li>
<li><strong>Module Contribution</strong>: Validated that removing CGFE results in a drop in structural similarity (Tanimoto), proving the need for pre-fusion alignment.</li>
</ul>
</li>
<li><strong>Robustness Analysis</strong>: Performance broken down by molecule complexity (atom count, ring count, bond count).</li>
<li><strong>Chirality Validation</strong>: Qualitative analysis of attention maps on chiral molecules to verify the model learns stereochemical cues implicitly.</li>
</ul>
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>Performance Over Baselines</strong>: DGAT outperformed the MPOCSR baseline across all metrics:
<ul>
<li><strong>BLEU-4</strong>: 84.0% (+5.3% improvement)</li>
<li><strong>ROUGE</strong>: 90.8% (+1.9% improvement)</li>
<li><strong>Tanimoto Similarity</strong>: 98.8% (+1.2% improvement)</li>
<li><strong>Exact Match Accuracy</strong>: 54.6% (+10.9% over SwinOCSR)</li>
</ul>
</li>
<li><strong>Chiral Recognition</strong>: The model explicitly recognizes chiral centers (e.g., generating <code>[C@@H1]</code> tokens correctly) based on 2D wedge cues without direct stereochemical supervision.</li>
<li><strong>Limitations</strong>: Performance drops for extreme cases, such as molecules with 4+ rings or 4+ double/triple bonds, due to dataset imbalance. The model still hallucinates branches in highly complex topologies.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is primarily drawn from PubChem and augmented to improve robustness.</p>
<ul>
<li><strong>Augmentation Strategy</strong>: Each sequence generates three images with random rendering parameters.
<ul>
<li><strong>Rotation</strong>: 0, 90, 180, 270, or random [0, 360)</li>
<li><strong>Bond Width</strong>: 1, 2, or 3 pixels</li>
<li><strong>Bond Offset</strong>: Sampled from 0.08-0.18 (inherited from Image2SMILES)</li>
<li><strong>CoordGen</strong>: Enabled with 20% probability</li>
</ul>
</li>
<li><strong>Evaluation Set</strong>: A newly constructed benchmark dataset was used for final reporting.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Training Configuration</strong>:
<ul>
<li><strong>Encoder LR</strong>: $5 \times 10^{-5}$ (Pretrained ResNet-101)</li>
<li><strong>Decoder LR</strong>: $1 \times 10^{-4}$ (Randomly initialized Transformer)</li>
<li><strong>Optimizer</strong>: Not stated explicitly; the reported momentum (0.9) and weight decay (0.0001) suggest SGD with momentum</li>
<li><strong>Batch Size</strong>: 256</li>
</ul>
</li>
<li><strong>Inference</strong>:
<ul>
<li><strong>Beam Search</strong>: A beam size of <strong>3</strong> is used. Larger beam sizes (4-5) degraded BLEU/ROUGE scores due to increased redundancy.</li>
</ul>
</li>
</ul>
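<p>For intuition, the decoding procedure can be sketched as a generic beam search that keeps the top-<em>k</em> partial sequences by cumulative log-probability. The toy model and all names below are illustrative assumptions, not the authors' implementation.</p>

```python
import math

def beam_search(step_probs, beam_size=3, max_len=10, eos=0):
    """Generic beam search: keep the `beam_size` best partial sequences by
    cumulative log-probability, moving any sequence that emits `eos` to the
    finished pool. Illustrative sketch only."""
    beams = [((), 0.0)]            # (token tuple, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, lp in beams:
            for tok, p in step_probs(seq).items():
                candidates.append((seq + (tok,), lp + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, lp in candidates[:beam_size]:
            (finished if seq[-1] == eos else beams).append((seq, lp))
        if not beams:
            break
    finished.extend(beams)         # sequences that never emitted eos
    return max(finished, key=lambda c: c[1])[0]

def toy_model(prefix):
    # hand-made next-token distributions: favor token 1 twice, then eos
    if len(prefix) in (0, 1):
        return {1: 0.6, 2: 0.3, 0: 0.1}
    return {0: 0.9, 1: 0.1}

print(beam_search(toy_model, beam_size=3))   # (1, 1, 0)
```

<p>Larger beams keep more low-probability continuations alive, which is consistent with the paper's observation that beam sizes above 3 add redundancy without improving sequence metrics.</p>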
<h3 id="models">Models</h3>
<ul>
<li><strong>Visual Encoder</strong>:
<ul>
<li><strong>Backbone</strong>: ResNet-101 initialized with ImageNet weights</li>
<li><strong>Structure</strong>: Convolutional layers preserved up to the final module. Classification head removed.</li>
<li><strong>Pooling</strong>: A $7 \times 7$ average pooling layer is used to extract global visual features.</li>
</ul>
</li>
<li><strong>Sequence Decoder</strong>:
<ul>
<li><strong>Architecture</strong>: Transformer-based with CGFE and SDGLA modules.</li>
<li><strong>Depth</strong>: 3 Transformer layers</li>
<li><strong>Dropout</strong>: Not utilized</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is reported using sequence-level and structure-level metrics.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">DGAT Score</th>
          <th style="text-align: left">Baseline (MPOCSR)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>BLEU-4</strong></td>
          <td style="text-align: left"><strong>84.0%</strong></td>
          <td style="text-align: left">78.7%</td>
          <td style="text-align: left">Measures n-gram precision</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>ROUGE</strong></td>
          <td style="text-align: left"><strong>90.8%</strong></td>
          <td style="text-align: left">88.9%</td>
          <td style="text-align: left">Sequence recall metric</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Tanimoto</strong></td>
          <td style="text-align: left"><strong>98.8%</strong></td>
          <td style="text-align: left">97.6%</td>
          <td style="text-align: left">Structural similarity fingerprint</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Accuracy</strong></td>
          <td style="text-align: left"><strong>54.6%</strong></td>
          <td style="text-align: left">35.7%</td>
          <td style="text-align: left">Exact structure match rate</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wang2025dgat,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Dual-Path Global Awareness Transformer for Optical Chemical Structure Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wang, Rui and Ji, Yujin and Li, Youyong and Lee, Shuit-Tong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{The Journal of Physical Chemistry Letters}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{50}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{12787--12795}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jpclett.5c03057}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DECIMER.ai: Optical Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-ai/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-ai/</guid><description>Open-source OCSR platform combining Mask R-CNN segmentation and Transformer recognition, trained on 450M+ synthetic images from RanDepict.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Brinkhaus, H. O., Agea, M. I., Zielesny, A., &amp; Steinbeck, C. (2023). DECIMER.ai: an open platform for automated optical chemical structure identification, segmentation and recognition in scientific publications. <em>Nature Communications</em>, 14(1), 5045. <a href="https://doi.org/10.1038/s41467-023-40782-0">https://doi.org/10.1038/s41467-023-40782-0</a></p>
<p><strong>Publication</strong>: Nature Communications 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://decimer.ai">Web Application</a></li>
<li><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER Image Transformer GitHub</a></li>
<li><a href="https://github.com/OBrink/RanDepict">RanDepict GitHub</a></li>
</ul>
<h2 id="project-scope-and-contribution-type">Project Scope and Contribution Type</h2>
<p>This is primarily a <strong>Resource</strong> paper (Infrastructure Basis) with a significant <strong>Method</strong> component.</p>
<p>The primary contribution is DECIMER.ai, a fully open-source platform (web app and Python packages) for the entire chemical structure mining pipeline, filling a gap where most tools were proprietary or fragmented. It also contributes the RanDepict toolkit for massive synthetic data generation.</p>
<p>The secondary methodological contribution proposes and validates a specific deep learning architecture (EfficientNet-V2 encoder + Transformer decoder) that treats chemical structure recognition as an image-to-text translation task (SMILES generation).</p>
<h2 id="the-scarcity-of-machine-readable-chemical-data">The Scarcity of Machine-Readable Chemical Data</h2>
<p><strong>Data Scarcity</strong>: While the number of chemical publications is increasing, most chemical information is locked in non-machine-readable formats (images in PDFs) and is not available in public databases.</p>
<p><strong>Limitations of Existing Tools</strong>: Prior OCSR (Optical Chemical Structure Recognition) tools were largely rule-based (fragile to noise) or proprietary.</p>
<p><strong>Lack of Integration</strong>: There was no existing open-source system that combined segmentation (finding the molecule on a page), classification (confirming it is a molecule), and recognition (translating it to SMILES) into a single workflow.</p>
<h2 id="decimer-architecture-and-novel-image-to-smiles-approach">DECIMER Architecture and Novel Image-to-SMILES Approach</h2>
<p><strong>Comprehensive Workflow</strong>: It is the first open-source platform to integrate segmentation (Mask R-CNN), classification (EfficientNet), and recognition (Transformer) into a unified pipeline.</p>
<p><strong>Data-Driven Approach</strong>: Unlike tools such as MolScribe, which use intermediate graph representations and rules, DECIMER uses a purely data-driven &ldquo;image-to-SMILES&rdquo; translation approach without hard-coded chemical rules. The core recognition model operates as a sequence-to-sequence generator, formalizing the task as maximizing the conditional probability of a SMILES sequence given an image.</p>
<p><strong>Massive Synthetic Training</strong>: The use of RanDepict to generate over 450 million synthetic images, covering diverse depiction styles and augmentations (including Markush structures), to train the model from scratch.</p>
<h2 id="benchmarking-and-evaluation-methodology">Benchmarking and Evaluation Methodology</h2>
<p><strong>Benchmarking</strong>: The system was tested against openly available tools (OSRA, MolVec, Imago, Img2Mol, SwinOCSR, MolScribe) on standard datasets: USPTO, UOB, CLEF, JPO, and a custom &ldquo;Hand-drawn&rdquo; dataset.</p>
<p><strong>Robustness Testing</strong>: Performance was evaluated on both clean images and images with added distortions (rotation, shearing) to test the fragility of rule-based systems vs. DECIMER.</p>
<p><strong>Markush Structure Analysis</strong>: Specific evaluation of the model&rsquo;s ability to interpret Markush structures (generic structures with R-groups).</p>
<p><strong>Comparison of Approaches</strong>: A direct comparison with MolScribe by training DECIMER on MolScribe&rsquo;s smaller training set to isolate the impact of architecture vs. data volume.</p>
<h2 id="performance-outcomes-and-key-findings">Performance Outcomes and Key Findings</h2>
<p><strong>Comparative Performance</strong>: DECIMER Image Transformer achieved the highest average Tanimoto similarity (&gt;0.95) across all benchmarks and extremely low rates of catastrophic failure. Tanimoto similarity is calculated based on molecular fingerprints $A$ and $B$ as:
$$ T(A, B) = \frac{A \cdot B}{|A|^2 + |B|^2 - A \cdot B} $$</p>
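<p>For binary fingerprints the formula above reduces to counting shared on-bits. A minimal sketch, treating each fingerprint as a set of on-bit indices (real pipelines would derive these from SMILES with RDKit or the CDK):</p>

```python
def tanimoto(a, b):
    """Tanimoto similarity between two binary fingerprints, each given as
    the set of its on-bit indices. For 0/1 vectors this is exactly the
    dot-product form above: A.B / (|A|^2 + |B|^2 - A.B)."""
    shared = len(a & b)
    return shared / (len(a) + len(b) - shared)

# Toy fingerprints (hypothetical bit indices, for illustration only):
fp_pred = {1, 2, 3, 4}
fp_true = {3, 4, 5}
print(tanimoto(fp_pred, fp_true))   # 0.4
```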
<p><strong>Data Volume Necessity</strong>: When trained on small datasets, MolScribe (graph/rule-based) outperformed DECIMER. DECIMER&rsquo;s performance advantage relies heavily on its massive training scale (&gt;400M images).</p>
<p><strong>Robustness</strong>: The model showed no performance degradation on distorted images, unlike rule-based legacy tools.</p>
<p><strong>Generalization</strong>: Despite having no hand-drawn images in the training set, the model performed competitively on hand-drawn benchmarks, suggesting strong generalization capabilities.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The models were trained on synthetic data generated from PubChem molecules.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Generation/Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_1</code></td>
          <td>~100M mols</td>
          <td>PubChem molecules (mass &lt; 1500 Da), processed with RanDepict (v1.0.5). Included image augmentations.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_2</code></td>
          <td>~126M mols</td>
          <td>Included Markush structures generated by pseudo-randomly replacing atoms with R-groups. Image size 299x299.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_3</code></td>
          <td>&gt;453M images</td>
          <td>Re-depicted <code>pubchem_2</code> molecules at <strong>512x512</strong> resolution. Used RanDepict v1.0.8.</td>
      </tr>
      <tr>
          <td><strong>Test</strong></td>
          <td>In-domain</td>
          <td>250,000</td>
          <td>Held-out set generated similarly to training data.</td>
      </tr>
      <tr>
          <td><strong>Benchmark</strong></td>
          <td>External</td>
          <td>Various</td>
          <td>USPTO (5719), UOB (5740), CLEF (992), JPO (450), Indigo (50k), Hand-drawn (5088).</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Generation</strong>:</p>
<ul>
<li><strong>Tool</strong>: RanDepict (uses CDK, RDKit, Indigo, PIKACHU)</li>
<li><strong>Augmentations</strong>: Rotation, shearing, noise, pixelation, curved arrows, text labels</li>
<li><strong>Format</strong>: Data saved as TFRecord files for TPU training</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES Tokenization</strong>: Regex-based splitting (atoms, brackets, bonds). Added <code>&lt;start&gt;</code>, <code>&lt;end&gt;</code>, and padded with <code>&lt;pad&gt;</code>. <code>&lt;unk&gt;</code> used for unknown tokens.</li>
<li><strong>Markush Token Handling</strong>: To avoid ambiguity, digits following &lsquo;R&rsquo; (e.g., R1) were replaced with unique non-digit characters during training to distinguish them from ring-closure numbers.</li>
<li><strong>Image Augmentation Pipeline</strong>: Albumentations and custom RanDepict features were used to simulate &ldquo;hand-drawn-like&rdquo; styles.</li>
</ul>
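<p>A minimal sketch of such regex-based tokenization is shown below. The pattern is an approximation modeled on common SMILES tokenizers, not DECIMER's exact regular expression:</p>

```python
import re

# One alternation per token class: bracket atoms, two-letter organic-subset
# atoms, ring-closure digits (including %nn), single atoms/aromatics, and
# bonds/branches. Illustrative approximation, not DECIMER's exact regex.
SMILES_TOKEN = re.compile(
    r"(\[[^\]]+\]|Br|Cl|%\d{2}|[BCNOPSFI]|[bcnops]|\d|[=#\-\+\(\)\.\\/@:~\*\$])"
)

def tokenize(smiles):
    tokens = SMILES_TOKEN.findall(smiles)
    # findall silently drops unmatched characters, so verify coverage:
    assert "".join(tokens) == smiles, "untokenizable characters present"
    return ["<start>"] + tokens + ["<end>"]

print(tokenize("CC(=O)Oc1ccccc1")[1:-1])
```

<p>The <code>&lt;pad&gt;</code> and <code>&lt;unk&gt;</code> handling described above would be applied after this splitting step, when mapping tokens to vocabulary ids.</p>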
<h3 id="models">Models</h3>
<p>The platform consists of three distinct models:</p>
<ol>
<li>
<p><strong>DECIMER Segmentation</strong>:</p>
<ul>
<li><strong>Architecture</strong>: Mask R-CNN (TensorFlow 2.10.0 implementation)</li>
<li><strong>Purpose</strong>: Detects and cuts chemical structures from full PDF pages</li>
</ul>
</li>
<li>
<p><strong>DECIMER Image Classifier</strong>:</p>
<ul>
<li><strong>Architecture</strong>: EfficientNet-B0 (Noisy Student weights)</li>
<li><strong>Input</strong>: 224x224 pixels</li>
<li><strong>Training</strong>: Fine-tuned on ~10M images (balanced chemical/non-chemical)</li>
<li><strong>Performance</strong>: AUC 0.99 on in-domain test set</li>
</ul>
</li>
<li>
<p><strong>DECIMER Image Transformer (OCSR Engine)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: EfficientNet-V2-M (CNN). Input size <strong>512x512</strong>. 52M parameters</li>
<li><strong>Decoder</strong>: Transformer. 4 encoder blocks, 4 decoder blocks, 8 attention heads. d_model=512, d_ff=2048. 59M parameters</li>
<li><strong>Total Params</strong>: ~111 Million</li>
</ul>
</li>
</ol>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Tanimoto Similarity (calculated on PubChem fingerprints of the predicted vs. ground truth SMILES)</li>
<li><strong>Secondary Metrics</strong>: Exact Match (Identity), BLEU score (for string similarity, esp. Markush)</li>
<li><strong>Failure Analysis</strong>: &ldquo;Catastrophic failure&rdquo; defined as Tanimoto similarity of 0 or invalid SMILES</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on Google Cloud TPUs due to the massive dataset size.</p>
<ul>
<li><strong><code>pubchem_1</code>/<code>pubchem_2</code></strong>: Trained on TPU v3-32 pod slice</li>
<li><strong><code>pubchem_3</code> (Final Model)</strong>: Trained on <strong>TPU v3-256</strong> pod slice</li>
<li><strong>Training Time</strong>:
<ul>
<li>Data generation (512x512): ~2 weeks on cluster (20 threads, 36 cores)</li>
<li>Model Training (EffNet-V2-M): <strong>1 day and 7 hours per epoch</strong> on TPU v3-256</li>
</ul>
</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMERaiOpenPlatform2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{DECIMER.ai: an open platform for automated optical chemical structure identification, segmentation and recognition in scientific publications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Brinkhaus, Henning Otto and Agea, M. Isabel and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{5045}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41467-023-40782-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemVLM: A Multimodal Large Language Model for Chemistry</title><link>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemvlm/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/chemical-language-models/chemvlm/</guid><description>A 26B parameter multimodal LLM for chemistry, combining InternViT-6B and ChemLLM-20B for molecular structure recognition, property prediction, and reasoning.</description><content:encoded><![CDATA[<h2 id="paper-classification-method-and-resource">Paper Classification: Method and Resource</h2>
<p>This paper is a combination of <strong>Method</strong> (primary) and <strong>Resource</strong> (secondary).</p>
<p>It is primarily a <strong>Method</strong> paper because it proposes <strong>ChemVLM</strong>, a novel multimodal architecture specifically tailored for the chemical domain, utilizing a &ldquo;ViT-MLP-LLM&rdquo; framework. The authors introduce a specific two-stage training strategy to align visual features with chemical text representations.</p>
<p>Secondarily, it is a <strong>Resource</strong> paper as it introduces a comprehensive suite of three new datasets: <strong>ChemOCR</strong>, <strong>MMCR-Bench</strong>, and <strong>MMChemBench</strong>, developed to rigorously evaluate multimodal capabilities in chemistry, covering OCR, reasoning, and property prediction.</p>
<h2 id="bridging-the-visual-gap-in-chemical-llms">Bridging the Visual Gap in Chemical LLMs</h2>
<p>The primary motivation is the limitation of existing models in handling the multimodal nature of chemistry.</p>
<ul>
<li><strong>Visual Data Gap</strong>: Chemical tasks heavily rely on visual information (molecular structures, reactions) which purely text-based chemical LLMs cannot process.</li>
<li><strong>Limitations of Generalist Models</strong>: General multimodal models (like GPT-4V or LLaVA) lack specialized chemical domain knowledge, leading to hallucinations or misinterpretations.</li>
<li><strong>Inadequacy of OCR Tools</strong>: Traditional chemical OCR tools (like MolScribe) excel at modality conversion (Image-to-SMILES) but fail at complex reasoning tasks.</li>
</ul>
<h2 id="domain-specific-data-curation-and-benchmarking">Domain-Specific Data Curation and Benchmarking</h2>
<ul>
<li><strong>Data-Driven Alignment</strong>: The underlying &ldquo;ViT-MLP-LLM&rdquo; framework is standard in multimodal modeling, paralleling architectures like LLaVA. The core innovation here is the rigorous creation of a bilingual multimodal dataset spanning hand-drawn molecules, reactions, and exam questions augmented with style transfers. The training data pipeline heavily relies on generating synthetic variance using tools like RanDepict and RDKit to introduce distortions, rotations, and handwritten styles, alongside GPT-4 generated prompts to ensure linguistic diversity.</li>
<li><strong>Model Integration</strong>: ChemVLM merges <strong>InternViT-6B</strong> (a large-scale vision transformer) with <strong>ChemLLM-20B</strong> (a chemical language model). Visual features $X_v$ are mapped into the linguistic embedding space via an MLP projector, producing aligned token sequences alongside text instructions $X_q$. The joint multimodal sequence is trained using standard autoregressive next-token prediction:
$$ \mathcal{L} = -\sum_{i} \log P(y_i \mid X_v, X_q, y_{&lt;i}) $$</li>
<li><strong>Three Custom Benchmarks</strong>: The authors introduce tailored benchmarks to assess distinct competencies:
<ul>
<li><strong>ChemOCR</strong>: For image-to-SMILES conversion.</li>
<li><strong>MMCR-Bench</strong>: College entrance exam questions testing complex logical reasoning.</li>
<li><strong>MMChemBench</strong>: For molecule captioning and zero-shot property prediction.</li>
</ul>
</li>
</ul>
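<p>The training objective above can be sketched directly: given the model's per-position vocabulary distributions (already conditioned on the image and prompt), the loss is the summed negative log-probability of the target tokens. This is an illustrative stand-in, not the authors' training code:</p>

```python
import math

def autoregressive_nll(target_ids, step_dists):
    """Loss from the equation above: L = -sum_i log P(y_i | X_v, X_q, y_<i).
    step_dists[i] is the model's vocabulary distribution at position i,
    already conditioned on the image tokens, the prompt, and y_<i."""
    return -sum(math.log(step_dists[i][y]) for i, y in enumerate(target_ids))

# Hypothetical two-step example with token ids 7 and 3 as targets:
loss = autoregressive_nll([7, 3], [{7: 0.5, 1: 0.5}, {3: 0.25, 1: 0.75}])
# equals -(log 0.5 + log 0.25) = 3 ln 2, about 2.079
```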
<h2 id="evaluating-chemical-ocr-and-reasoning">Evaluating Chemical OCR and Reasoning</h2>
<p>The authors benchmarked ChemVLM against both open-source (LLaVA, Qwen-VL, InternVL) and proprietary (GPT-4V) models across three primary domains:</p>
<ol>
<li><strong>Chemical OCR</strong>: Evaluated on 1,000 image-text pairs from ChemOCR. The primary metric is the Tanimoto similarity between the Morgan fingerprints of the generated structure ($A$) and the ground-truth SMILES ($B$):
$$ T(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|} $$
They report both the average Tanimoto similarity and the strict exact-match rate (<code>Tanimoto@1.0</code>).</li>
<li><strong>Multimodal Chemical Reasoning (MMCR)</strong>: Tested on MMCR-Bench (1,000 exam questions), ScienceQA, and CMMU. Performance was scored based on accuracy for multiple-choice and fill-in-the-blank questions.</li>
<li><strong>Multimodal Molecule Understanding</strong>: Evaluated on MMChemBench for molecule captioning and property prediction.</li>
<li><strong>Text-Only Reasoning</strong>: Tested on SciBench, a text-only benchmark for university-level science, to ensure the model retains fundamental linguistic reasoning.</li>
<li><strong>Generalization</strong>: Tested on non-chemistry subjects within the CMMU framework (Biology, Physics, Math) to assess cross-domain competence.</li>
</ol>
<h2 id="performance-gains-and-existing-limitations">Performance Gains and Existing Limitations</h2>
<ul>
<li><strong>Multimodal Reasoning Leadership</strong>: ChemVLM achieved state-of-the-art results on MMCR-Bench (41.7%), surpassing generalist models like GPT-4V (40.1%). However, scoring for portions of these benchmarks relied heavily on an LLM-as-a-judge (the Qwen-max API), which can introduce bias, since LLM judges tend to favor the phrasing patterns and verbosity produced by similar autoregressive models. Furthermore, the model was fine-tuned on 200,000 exam questions and tested on MMCR-Bench (also derived from Chinese college entrance exams). While the authors state the data was deduplicated, the potential for data leakage remains a significant unaddressed confounder.</li>
<li><strong>Superior Understanding</strong>: In molecule captioning and prediction, ChemVLM showed significant improvements over general baseline models, scoring 80.9% on prediction compared to GPT-4V&rsquo;s 38.6%. This is a natural consequence of testing a custom-trained model on domain-specific benchmarks.</li>
<li><strong>OCR Capabilities vs. Dedicated Tools</strong>: ChemVLM outperformed generalist MLLMs in chemical structure recognition, achieving an average Tanimoto similarity of 71.0% (vs. GPT-4V&rsquo;s 15.0%). However, it remains significantly inferior to pure structural OCR tools like MolScribe in strict modality conversion tasks, only achieving an exact structural match (<code>Tanimoto@1.0</code>) of 42.9% compared to MolScribe&rsquo;s 89.1%.</li>
<li><strong>Textual Retention and Generalization Claims</strong>: The authors claim the diverse training strategy imparts broad scientific reasoning, pointing to performance retention on non-chemistry subjects (Biology, Physics, Math) and strong results on the purely textual SciBench benchmark. However, this cross-domain generalization most likely stems from the underlying base model (ChemLLM-20B/InternLM2) or the inclusion of 1.3 million &ldquo;General&rdquo; visual QA pairs in their training blend, rather than from emergent general scientific skills acquired purely by learning chemistry representations.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training and evaluation data relied on a mix of open-source repositories and custom curation. Many of the curated datasets have been formally released by the authors on Hugging Face (<a href="https://huggingface.co/datasets/di-zhang-fdu/chemvlm-sft-datasets"><code>di-zhang-fdu/chemvlm-sft-datasets</code></a>).</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Source/Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong>DECIMER HDM</strong></td>
          <td>7,000+ hand-drawn molecular images.</td>
      </tr>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong>MolScribe Data</strong></td>
          <td>Scanned/photographed images from literature.</td>
      </tr>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong>Synthetic</strong></td>
          <td>Generated via ChemDraw, RDKit, and Indigo with style transfer (blurring, rotation, handwritten styles).</td>
      </tr>
      <tr>
          <td><strong>Training (Reaction)</strong></td>
          <td><strong>PEACE &amp; USPTO-50K</strong></td>
          <td>Inorganic and organic reaction schemes.</td>
      </tr>
      <tr>
          <td><strong>Training (Reasoning)</strong></td>
          <td><strong>Exam Questions</strong></td>
          <td>200,000 questions from OpenDataLab (Chinese education level). <a href="https://huggingface.co/collections/di-zhang-fdu/multi-corpus-datasets-for-chemllm-66e7f14fd683a3f4e51d737b">Available on Hugging Face</a>.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>ChemOCR</strong></td>
          <td>1,000 bilingual image-text pairs for SMILES recognition. Released via Google Drive link in repo.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>MMCR-Bench</strong></td>
          <td>1,000 multimodal chemistry exam questions. <strong>Requires emailing authors directly for access.</strong></td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>MMChemBench</strong></td>
          <td>Extension of ChemBench for captioning and property prediction. Released via Google Drive link in repo.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>: Images were augmented using <strong>RanDepict</strong> for style variation. Text data (SMILES) was validated and cleaned. Prompts were diversified using GPT-4 to generate different linguistic styles.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: &ldquo;ViT-MLP-LLM&rdquo; structure.
<ul>
<li><strong>Vision Encoder</strong>: InternViT-6B, processing images at $448 \times 448$ resolution. Images are segmented into tiles (max 12).</li>
<li><strong>Projector</strong>: Multi-Layer Perceptron (MLP) initialized randomly to map visual features to text embedding space.</li>
<li><strong>LLM</strong>: ChemLLM-20B, a domain-specific model.</li>
</ul>
</li>
<li><strong>Training Strategy</strong>: Two-stage supervised fine-tuning.
<ol>
<li><strong>Modal Alignment</strong>: Freeze LLM and base Vision Encoder weights. Train only the randomly initialized MLP projector and LoRA layers (rank 32) of the Vision Encoder. Uses diverse multimodal data.</li>
<li><strong>Supervised Fine-Tuning (SFT)</strong>: Keep LLM and Vision Encoder base weights frozen, but add LoRA (rank 16) to the LLM and retain LoRA (rank 32) on the Vision Encoder. The MLP projector is fully trained. Data includes specialized chemistry and general corpora.</li>
</ol>
</li>
<li><strong>Optimization</strong>:
<ul>
<li>Optimizer: AdamW</li>
<li>Context Length: 2048 tokens</li>
<li>Chat Template: InternLM2 dialogue schema</li>
</ul>
</li>
</ul>
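<p>The staged freezing recipe above can be sketched as a per-stage trainability map. This is a simplified abstraction of the setup described (the actual implementation lives in the InternVL-v1.5 codebase with PEFT-style LoRA adapters); module names here are illustrative:</p>

```python
def configure_stage(stage: int) -> dict:
    """Which parameter groups are trainable in each training stage.

    Mirrors the two-stage recipe: stage 1 trains only the projector
    and ViT LoRA; stage 2 additionally trains LLM LoRA, while the
    base vision-encoder and LLM weights stay frozen throughout.
    """
    return {
        "vit_base": False,           # InternViT-6B base: always frozen
        "llm_base": False,           # ChemLLM-20B base: always frozen
        "projector": True,           # MLP projector: trained in both stages
        "vit_lora_r32": True,        # LoRA (rank 32) on vision encoder
        "llm_lora_r16": stage >= 2,  # LoRA (rank 16) on LLM, stage 2 only
    }

print(configure_stage(1))
print(configure_stage(2))
```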
<h3 id="models">Models</h3>
<ul>
<li><strong>ChemVLM-26B</strong>: The primary model released. It combines the 6B parameter vision encoder and the 20B parameter language model. Weights are fully available at <a href="https://huggingface.co/AI4Chem/ChemVLM-26B-1-2"><code>AI4Chem/ChemVLM-26B-1-2</code></a>. An 8B version is also available.</li>
<li><strong>Baselines</strong>: Comparisons were made against <strong>GPT-4V</strong>, <strong>Qwen-VL-Chat</strong>, <strong>LLaVA-v1.5-13B</strong>, <strong>InternVL-v1.5</strong>, and <strong>Yi-VL-Plus</strong>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance was measured across three distinct task types. Exact <a href="https://github.com/lijunxian111/ChemVlm/tree/master/evaluation">evaluation scripts</a> have been released in the official repository.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Method</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto Similarity</strong></td>
          <td>ChemOCR</td>
          <td>Comparison of generated SMILES vs. ground truth using RDKit. Reports Average Similarity and <code>Tanimoto@1.0</code> (exact match).</td>
      </tr>
      <tr>
          <td><strong>Accuracy</strong></td>
          <td>MMCR (Reasoning)</td>
          <td>+1 point for correct multiple-choice/fill-in-the-blank; 0 otherwise. Scored via Qwen-max API prompting.</td>
      </tr>
      <tr>
          <td><strong>Prediction Score</strong></td>
          <td>Property Prediction</td>
          <td>Evaluated on MMChemBench subsets.</td>
      </tr>
  </tbody>
</table>
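<p>The Tanimoto metric above reduces to intersection-over-union on fingerprint bit sets. A minimal stdlib sketch (the paper's evaluation uses RDKit Morgan/ECFP fingerprints; the bit sets below are hypothetical):</p>

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto (Jaccard) similarity between two fingerprint bit sets."""
    if not fp_a and not fp_b:
        return 1.0  # convention: two empty fingerprints are identical
    return len(fp_a & fp_b) / len(fp_a | fp_b)

a = {1, 5, 9, 42}   # hypothetical on-bits of a predicted molecule
b = {1, 5, 9, 42}   # identical structure
c = {1, 5, 17}      # partially overlapping structure

print(tanimoto(a, b))  # 1.0 -> counts toward Tanimoto@1.0 (exact match)
print(tanimoto(a, c))  # 2/5 = 0.4
```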
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Compute</strong>: Training utilized <strong>16 NVIDIA A100 (80GB)</strong> GPUs.</li>
<li><strong>Configuration</strong>:
<ul>
<li>Batch size: 4 (per GPU, resulting in an effective global batch size of 256)</li>
<li>Gradient Accumulation: 4 iterations</li>
<li>Precision: <strong>Deepspeed bfloat16 (bf16)</strong> with <strong>ZeRO-3</strong> offloading strategy</li>
<li>Framework: Training runs on the InternVL-v1.5 codebase rather than standalone scripts.</li>
</ul>
</li>
<li><strong>Inference Compute</strong>: Evaluating the 26B model requires at least one 80GB A100 GPU (with Flash Attention + bfloat16). The 8B variant requires a GPU with at least 48GB of VRAM.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemVLM-26B">ChemVLM-26B</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Original 26B model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemVLM-26B-1-2">ChemVLM-26B-1-2</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Updated 26B model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/di-zhang-fdu/chemvlm-sft-datasets">chemvlm-sft-datasets</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>SFT training data (~51.7k rows)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/lijunxian111/ChemVlm">ChemVlm (GitHub)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training, evaluation, and inference code</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, J., et al. (2025). ChemVLM: Exploring the Power of Multimodal Large Language Models in Chemistry Area. <em>Proceedings of the AAAI Conference on Artificial Intelligence</em>, 39(1), 415-423. <a href="https://doi.org/10.1609/aaai.v39i1.32020">https://doi.org/10.1609/aaai.v39i1.32020</a></p>
<p><strong>Publication</strong>: AAAI 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{li2025chemvlm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemVLM: Exploring the Power of Multimodal Large Language Models in Chemistry Area}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Junxian and Zhang, Di and Wang, Xunzhi and Hao, Zeying and Lei, Jingdi and Tan, Qian and Zhou, Cai and Liu, Wei and Yang, Yaotian and Xiong, Xinrui and Wang, Weiyun and Chen, Zhe and Wang, Wenhai and Li, Wei and Su, Mao and Zhang, Shufei and Ouyang, Wanli and Li, Yuqiang and Zhou, Dongzhan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the AAAI Conference on Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{39}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{415--423}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://doi.org/10.1609/aaai.v39i1.32020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1609/aaai.v39i1.32020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/lijunxian111/ChemVlm">Official Repository</a></li>
</ul>
]]></content:encoded></item><item><title>ChemReco: Hand-Drawn Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemreco/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemreco/</guid><description>A deep learning method using EfficientNet and Transformer to convert hand-drawn chemical structures into SMILES codes, achieving 96.9% accuracy.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ouyang, H., Liu, W., Tao, J., et al. (2024). ChemReco: automated recognition of hand-drawn carbon-hydrogen-oxygen structures using deep learning. <em>Scientific Reports</em>, 14, 17126. <a href="https://doi.org/10.1038/s41598-024-67496-7">https://doi.org/10.1038/s41598-024-67496-7</a></p>
<p><strong>Publication</strong>: Scientific Reports 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/a-die/hdr-DeepLearning">Official Code Repository</a></li>
</ul>
<h2 id="research-contribution--classification">Research Contribution &amp; Classification</h2>
<p>This is a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a significant <strong>Resource ($\Psi_{\text{Resource}}$)</strong> component.</p>
<ul>
<li><strong>Method</strong>: The primary contribution is &ldquo;ChemReco,&rdquo; a specific deep learning pipeline (EfficientNet + Transformer) designed to solve the Optical Chemical Structure Recognition (OCSR) task for hand-drawn images. The authors conduct extensive ablation studies on architecture and data mixing ratios to validate performance.</li>
<li><strong>Resource</strong>: The authors explicitly state that &ldquo;the primary focus of this paper is constructing datasets&rdquo; due to the scarcity of hand-drawn molecular data. They introduce a comprehensive synthetic data generation pipeline involving RDKit modifications and image degradation to create training data.</li>
</ul>
<h2 id="motivation-digitizing-hand-drawn-chemical-sketches">Motivation: Digitizing Hand-Drawn Chemical Sketches</h2>
<p>Hand-drawing is the most intuitive method for chemists and students to record molecular structures. However, digitizing these drawings into machine-readable formats (like <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>) usually requires time-consuming manual entry or specialized software.</p>
<ul>
<li><strong>Gap</strong>: Existing OCSR tools and rule-based methods often fail on hand-drawn sketches due to diverse writing styles, poor image quality, and the absence of labeled data.</li>
<li><strong>Application</strong>: Automated recognition enables efficient chemical research and allows for automatic grading in educational settings.</li>
</ul>
<h2 id="core-innovation-synthetic-pipeline-and-hybrid-architecture">Core Innovation: Synthetic Pipeline and Hybrid Architecture</h2>
<p>The paper introduces <strong>ChemReco</strong>, an end-to-end system for recognizing C-H-O structures. Key novelties include:</p>
<ol>
<li><strong>Synthetic Data Pipeline</strong>: A multi-stage generation method that modifies RDKit source code to randomize bond/angle parameters, followed by OpenCV-based augmentation, degradation, and background addition to simulate realistic hand-drawn artifacts.</li>
<li><strong>Architectural Choice</strong>: The specific application of <strong>EfficientNet</strong> (encoder) combined with a <strong>Transformer</strong> (decoder) for this domain, which the authors demonstrate outperforms the more common ResNet+LSTM baselines.</li>
<li><strong>Hybrid Training Strategy</strong>: Finding that a mix of 90% synthetic and 10% real data yields optimal performance, superior to using either dataset alone.</li>
</ol>
<h2 id="methodology--ablation-studies">Methodology &amp; Ablation Studies</h2>
<p>The authors performed a series of ablation studies and comparisons:</p>
<ul>
<li><strong>Synthesis Ablation</strong>: Evaluated the impact of each step in the generation pipeline (RDKit only $\rightarrow$ Augmentation $\rightarrow$ Degradation $\rightarrow$ Background) on validation loss and accuracy.</li>
<li><strong>Dataset Size Ablation</strong>: Tested model performance when trained on synthetic datasets ranging from 100,000 to 1,000,000 images.</li>
<li><strong>Real/Synthetic Ratio</strong>: Investigated the optimal mixing ratio of real hand-drawn images to synthetic images (0:100, 10:90, 50:50, 90:10, 100:0).</li>
<li><strong>Architecture Comparison</strong>: Benchmarked four encoder-decoder combinations: ResNet vs. EfficientNet encoders paired with LSTM vs. Transformer decoders.</li>
<li><strong>Baseline Comparison</strong>: Compared results against a related study utilizing a CNN+LSTM framework.</li>
</ul>
<h2 id="results--interpretations">Results &amp; Interpretations</h2>
<ul>
<li><strong>Best Performance</strong>: The EfficientNet + Transformer model trained on a 90:10 synthetic-to-real ratio achieved a <strong>96.90% Exact Match</strong> rate on the test set.</li>
<li><strong>Background Robustness</strong>: Adding random backgrounds during training significantly improved accuracy on complex test images (approximately 53% on backgrounded test sets vs. 46% without this augmentation), preventing the model from overfitting to clean white backgrounds.</li>
<li><strong>Data Volume</strong>: Increasing the synthetic dataset size from 100k to 1M consistently improved accuracy.</li>
<li><strong>Superiority over Baselines</strong>: The model significantly outperformed the cited CNN+LSTM baseline (93% vs 76% on the provided test set).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study utilizes a combination of collected SMILES data, real hand-drawn images, and generated synthetic images.</p>
<ul>
<li><strong>Source Data</strong>: SMILES codes collected from PubChem, ZINC, GDB-11, and GDB-13. Filtered for C, H, O atoms and max 1 ring.</li>
<li><strong>Real Dataset</strong>: 670 selected SMILES codes drawn by multiple volunteers, totaling <strong>2,598 images</strong>.</li>
<li><strong>Synthetic Dataset</strong>: Generated up to <strong>1,000,000 images</strong> using the pipeline below.</li>
<li><strong>Training Mix</strong>: The optimal training set used 1 million images with a <strong>90:10 ratio</strong> of synthetic to real images.</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Dataset Type</th>
          <th style="text-align: left">Source</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Real</strong></td>
          <td style="text-align: left">Volunteer Drawings</td>
          <td style="text-align: left">2,598 images</td>
          <td style="text-align: left">Used for mixed training and testing</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Synthetic</strong></td>
          <td style="text-align: left">Generated</td>
          <td style="text-align: left">100k - 1M</td>
          <td style="text-align: left">Generated via RDKit + OpenCV + Stable Diffusion</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The <strong>Synthetic Image Generation Pipeline</strong> is critical for reproduction:</p>
<ol>
<li><strong>RDKit Modification</strong>: Modify the RDKit drawing source code to randomize rendering parameters (character width, bond length, and bond angles).</li>
<li><strong>Augmentation (OpenCV)</strong>: Apply sequence: Resize ($p=0.5$), Blur ($p=0.4$), Erode/Dilate ($p=0.2$), Distort ($p=0.8$), Flip ($p=0.5$), Affine ($p=0.7$).</li>
<li><strong>Degradation</strong>: Apply sequence: Salt+pepper noise ($p=0.1$), Contrast ($p=0.7$), Sharpness ($p=0.5$), Invert ($p=0.3$).</li>
<li><strong>Background Addition</strong>: Random backgrounds are augmented (Crop, Distort, Flip) and added to the molecular image to prevent background overfitting.</li>
<li><strong>Diffusion Enhancement</strong>: Stable Diffusion (v1-4) is used for image-to-image enhancement to better simulate hand-drawn styles (prompt: &ldquo;A pencil sketch of [Formula]&hellip; without charge distribution&rdquo;).</li>
</ol>
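<p>Steps 2 and 3 above apply each transform independently with its own probability. A minimal sketch of that dispatch logic (the real pipeline applies OpenCV transforms to pixel arrays; here the image is a stand-in list that records which operations fired):</p>

```python
import random

def apply_pipeline(image, ops, rng=None):
    """Apply each (name, prob, fn) transform independently with its probability."""
    rng = rng or random.Random()
    for name, prob, fn in ops:
        if rng.random() < prob:
            image = fn(image)
    return image

# Hypothetical stand-in transforms: each just records that it ran.
def tag(name):
    return lambda img: img + [name]

augment_ops = [
    ("resize", 0.5, tag("resize")),
    ("blur", 0.4, tag("blur")),
    ("erode_dilate", 0.2, tag("erode_dilate")),
    ("distort", 0.8, tag("distort")),
    ("flip", 0.5, tag("flip")),
    ("affine", 0.7, tag("affine")),
]

# Sanity check the dispatch: prob=1.0 fires every op, prob=0.0 fires none.
always = [(n, 1.0, f) for n, _, f in augment_ops]
never = [(n, 0.0, f) for n, _, f in augment_ops]
print(apply_pipeline([], always))  # all six op names, in order
print(apply_pipeline([], never))   # []
```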
<h3 id="models">Models</h3>
<p>The system uses an encoder-decoder architecture:</p>
<ul>
<li><strong>Encoder</strong>: <strong>EfficientNet</strong> (pre-trained on ImageNet). The last layer is removed, and features are extracted into a NumPy array.</li>
<li><strong>Decoder</strong>: <strong>Transformer</strong>. Utilizes self-attention to generate the SMILES sequence. Chosen over LSTM for better handling of long-range dependencies.</li>
<li><strong>Output</strong>: Canonical SMILES string.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The chosen metrics are the standard ones for sequence generation and recognition. Model predictions $\hat{y}_i$ are compared against ground-truth SMILES strings $y_i$ over a test set of $N$ examples.</p>
<ul>
<li><strong>Primary Metric</strong>: <strong>Exact Match (EM)</strong>. A strict binary evaluation that checks whether the complete generated SMILES exactly reproduces the target string:
$$ \text{EM} = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}[\hat{y}_i = y_i] $$</li>
<li><strong>Other Metrics</strong>: <strong>Levenshtein Distance</strong> measures character-level edit proximity, while the <strong>Tanimoto coefficient</strong> evaluates structural similarity via chemical fingerprints; both were tracked during validation and ablation runs.</li>
</ul>
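<p>Both string metrics are straightforward to implement; a stdlib sketch (the example SMILES are illustrative, and the Tanimoto fingerprint metric is omitted since it requires RDKit):</p>

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance between two strings via dynamic programming."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(
                prev[j] + 1,               # deletion
                cur[j - 1] + 1,            # insertion
                prev[j - 1] + (ca != cb),  # substitution
            ))
        prev = cur
    return prev[-1]

def exact_match(preds, targets) -> float:
    """Fraction of predictions identical to their target (the EM metric)."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

preds = ["CCO", "c1ccccc1", "CC(=O)O"]
targets = ["CCO", "c1ccccc1O", "CC(=O)O"]
print(exact_match(preds, targets))  # 2/3: one prediction misses an oxygen
print(levenshtein("c1ccccc1", "c1ccccc1O"))  # 1
```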
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Value</th>
          <th style="text-align: left">Baseline (CNN+LSTM)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Exact Match</strong></td>
          <td style="text-align: left"><strong>96.90%</strong></td>
          <td style="text-align: left">76%</td>
          <td style="text-align: left">Tested on the provided test set</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CPU</strong>: Intel(R) Xeon(R) Gold 6130 (40 GB RAM).</li>
<li><strong>GPU</strong>: NVIDIA Tesla V100 (32 GB video memory).</li>
<li><strong>Framework</strong>: PyTorch 1.9.1.</li>
<li><strong>Training Configuration</strong>:
<ul>
<li>Optimizer: Adam (learning rate 1e-4).</li>
<li>Batch size: 32.</li>
<li>Epochs: 100.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ouyangChemRecoAutomatedRecognition2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{ChemReco: Automated Recognition of Hand-Drawn Carbon--Hydrogen--Oxygen Structures Using Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ouyang, Hengjie and Liu, Wei and Tao, Jiajun and Luo, Yanghong and Zhang, Wanjia and Zhou, Jiayu and Geng, Shuqi and Zhang, Chengpeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{17126}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41598-024-67496-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>AtomLenz: Atom-Level OCSR with Limited Supervision</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/atomlenz/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/atomlenz/</guid><description>Weakly supervised OCSR framework combining object detection and graph construction to recognize chemical structures from hand-drawn images using SMILES.</description><content:encoded><![CDATA[<h2 id="dual-contribution-method-and-data-resource">Dual Contribution: Method and Data Resource</h2>
<p>The paper proposes a novel architecture (AtomLenz) and training framework (ProbKT* + Edit-Correction) to solve the problem of Optical Chemical Structure Recognition (OCSR) in data-sparse domains. Secondarily, it serves as a <strong>Resource</strong> paper by releasing a curated, relabeled dataset of hand-drawn molecules paired with atom-level bounding box annotations.</p>
<h2 id="overcoming-annotation-bottlenecks-in-ocsr">Overcoming Annotation Bottlenecks in OCSR</h2>
<p>Optical Chemical Structure Recognition (OCSR) is critical for digitizing chemical literature and lab notes. However, existing methods face three main limitations:</p>
<ol>
<li><strong>Generalization Limits:</strong> They struggle with sparse or stylistically unique domains, such as hand-drawn images, where massive datasets for pretraining are unavailable.</li>
<li><strong>Annotation Cost:</strong> &ldquo;Atom-level&rdquo; methods (which detect individual atoms and bonds) require expensive bounding box annotations, which are rarely available for real-world sketch data.</li>
<li><strong>Lack of Interpretability/Localization:</strong> Pure &ldquo;Image-to-SMILES&rdquo; models (like DECIMER) work well but fail to localize the atoms or bonds in the original image, limiting human-in-the-loop review and mechanistic interpretability.</li>
</ol>
<h2 id="atomlenz-probkt-and-graph-edit-correction">AtomLenz, ProbKT*, and Graph Edit-Correction</h2>
<p>The core contribution is <strong>AtomLenz</strong>, an OCSR framework that achieves atom-level entity detection using <strong>only SMILES supervision</strong> on target domains. The authors construct an explicit object detection pipeline using Faster R-CNN trained with a composite multi-task loss that combines a multi-class log loss $L_{cls}$ over the predicted class $\hat{c}$ with a regression loss $L_{reg}$ over the predicted bounding-box coordinates $\hat{b}$:</p>
<p>$$ \mathcal{L} = L_{cls}(c, \hat{c}) + L_{reg}(b, \hat{b}) $$</p>
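<p>As a toy numeric sketch of this objective, assuming a smooth-L1 box regression term as in standard Faster R-CNN (the class probabilities and box coordinates below are illustrative; the real model computes these over batches of anchors with per-class regressors):</p>

```python
import math

def smooth_l1(x: float, beta: float = 1.0) -> float:
    """Smooth L1 (Huber-style) loss, the usual Faster R-CNN box regression term."""
    ax = abs(x)
    return 0.5 * ax * ax / beta if ax < beta else ax - 0.5 * beta

def detection_loss(class_probs, true_class, pred_box, true_box):
    """Composite loss L = L_cls + L_reg for one detected object.

    class_probs: predicted class distribution (floats summing to 1)
    pred_box / true_box: (x1, y1, x2, y2) coordinates.
    """
    l_cls = -math.log(class_probs[true_class])  # multi-class log loss
    l_reg = sum(smooth_l1(p - t) for p, t in zip(pred_box, true_box))
    return l_cls + l_reg

probs = [0.1, 0.7, 0.2]  # hypothetical scores over 3 atom classes
loss = detection_loss(probs, 1,
                      (10.0, 10.0, 20.0, 20.0),   # predicted box
                      (10.5, 10.0, 20.0, 19.5))   # ground-truth box
print(round(loss, 4))  # -ln(0.7) + 2 * smooth_l1(0.5) ~ 0.6067
```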
<p>To bridge the gap between image inputs and the weakly supervised SMILES labels, the system leverages:</p>
<ul>
<li><strong>ProbKT* (Probabilistic Knowledge Transfer):</strong> Uses probabilistic logic and Hungarian matching to align predicted objects with the &ldquo;ground truth&rdquo; derived from the SMILES strings, enabling backpropagation without explicit bounding box labels.</li>
<li><strong>Graph Edit-Correction:</strong> Generates pseudo-labels by solving for the minimum graph edit distance between the model&rsquo;s predicted output graph and the ground-truth SMILES graph, which forces fine-tuning on less frequent atom types.</li>
<li><strong>ChemExpert:</strong> A chemically sound ensemble strategy that cascades predictions from multiple models (e.g., passing through DECIMER, then AtomLenz), halting at the first output that clears basic RDKit chemical validity checks.</li>
</ul>
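<p>The ChemExpert cascade can be sketched as a first-valid-wins loop. This is a simplified sketch: the stand-in model callables and the toy validity check below are assumptions, whereas the real system orders actual OCSR models and validates with RDKit parsing:</p>

```python
def chem_expert(image, models, is_valid):
    """Cascade ensemble: return the first prediction passing the validity
    check, falling back to the last attempt if none validate."""
    prediction = None
    for model in models:
        prediction = model(image)
        if is_valid(prediction):
            return prediction
    return prediction  # nothing validated; return the final attempt

# Hypothetical stand-ins for the real models and the RDKit check.
decimer = lambda img: "C1CC"     # invalid: unclosed ring
atomlenz = lambda img: "C1CCC1"  # valid cyclobutane
valid = lambda smi: smi.count("1") % 2 == 0  # toy check, not RDKit

print(chem_expert("sketch.png", [decimer, atomlenz], valid))  # "C1CCC1"
```

<p>The design choice matters: the cascade never averages predictions, so the chemically stronger model only overrides when the preferred model's output fails a hard validity gate.</p>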
<h2 id="data-efficiency-and-domain-adaptation-experiments">Data Efficiency and Domain Adaptation Experiments</h2>
<p>The authors evaluated the model specifically on domain adaptation and sample efficiency, treating hand-drawn molecules as the primary low-data target distribution:</p>
<ul>
<li><strong>Pretraining:</strong> Initially trained on ~214k synthetic images from ChEMBL explicitly labeled with bounding boxes (generated via RDKit).</li>
<li><strong>Target Domain Adaptation:</strong> Fine-tuned on the Brinkhaus hand-drawn dataset (4,070 images) using purely SMILES supervision.</li>
<li><strong>Evaluation Sets:</strong>
<ul>
<li><strong>Hand-drawn test set</strong>: 1,018 images.</li>
<li><strong>ChemPix</strong>: 614 out-of-domain hand-drawn images.</li>
<li><strong>Atom Localization set</strong>: 1,000 synthetic images to evaluate precise bounding box capabilities.</li>
</ul>
</li>
<li><strong>Baselines:</strong> Compared against leading OCSR methods, including DECIMER (v2.2), Img2Mol, MolScribe, ChemGrapher, and OSRA.</li>
</ul>
<h2 id="state-of-the-art-ensembles-vs-standalone-limitations">State-of-the-Art Ensembles vs. Standalone Limitations</h2>
<ul>
<li><strong>SOTA Ensemble Performance:</strong> The <strong>ChemExpert</strong> module (combining AtomLenz and DECIMER) achieved state-of-the-art accuracy on both hand-drawn (63.5%) and ChemPix (51.8%) test sets.</li>
<li><strong>Data Efficiency under Bottleneck Regimes:</strong> AtomLenz sidestepped the large-data requirements of competing models. When trained from scratch on a strictly limited set of 4,000 samples, it achieved 33.8% exact accuracy, far outperforming baselines such as Img2Mol (0.0%), MolScribe (1.3%), and DECIMER (0.1%), demonstrating strong sample efficiency.</li>
<li><strong>Localization Success:</strong> The base framework retained strong localization capability (mAP 0.801), which end-to-end transformers such as DECIMER lack.</li>
<li><strong>Methodological Tradeoffs:</strong> While AtomLenz is highly sample efficient, its standalone performance when fine-tuned on the target domain (33.8% accuracy) underperforms fine-tuned models trained on larger datasets like DECIMER (62.2% accuracy). AtomLenz achieves state-of-the-art results primarily when deployed as part of the ChemExpert ensemble alongside DECIMER, functioning effectively as a synergistic error-correction mechanism.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/molden/atomlenz">Official Repository (AtomLenz)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Complete pipeline for AtomLenz, ProbKT*, and Graph Edit-Correction.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/molden/atomlenz/tree/main/models">Pre-trained Models</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Downloadable weights for Faster R-CNN detection backbones.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://dx.doi.org/10.6084/m9.figshare.24599412">Hand-drawn Dataset (Brinkhaus)</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Images and annotations used for target domain fine-tuning and evaluation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/spaces/moldenhof/atomlenz">AtomLenz Web Demo</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Interactive Hugging Face space for testing model inference.</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The study utilizes a mix of large synthetic datasets and smaller curated hand-drawn datasets.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Pretraining</strong></td>
          <td>Synthetic ChEMBL</td>
          <td>~214,000</td>
          <td>Generated via RDKit/Indigo. Annotated with atoms, bonds, charges, stereocenters.</td>
      </tr>
      <tr>
          <td><strong>Fine-tuning</strong></td>
          <td>Hand-drawn (Brinkhaus)</td>
          <td>4,070</td>
          <td>Used for weakly supervised adaptation (SMILES only).</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Hand-drawn Test</td>
          <td>1,018</td>
          <td></td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>ChemPix</td>
          <td>614</td>
          <td>Out-of-distribution hand-drawn images.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Atom Localization</td>
          <td>1,000</td>
          <td>Synthetic images with ground truth bounding boxes.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Molecular Graph Constructor (Algorithm 1):</strong> A rule-based system to assemble the graph from detected objects:
<ol>
<li><strong>Filtering:</strong> Removes overlapping atom boxes (IoU threshold).</li>
<li><strong>Node Creation:</strong> Merges charge and stereocenter objects with the nearest atom objects.</li>
<li><strong>Edge Creation:</strong> Iterates over bond objects; if a bond overlaps with exactly two atoms, an edge is added. If &gt;2, it selects the most probable pair.</li>
<li><strong>Validation:</strong> Checks valency constraints; removes bonds iteratively if constraints are violated.</li>
</ol>
</li>
<li><strong>Weakly Supervised Training:</strong>
<ul>
<li><strong>ProbKT*:</strong> Uses Hungarian matching to align predicted objects with the &ldquo;ground truth&rdquo; implied by the SMILES string, allowing backpropagation without explicit boxes.</li>
<li><strong>Graph Edit-Correction:</strong> Computes the Minimum Graph Edit Distance between the predicted graph and the true SMILES graph to generate pseudo-labels for retraining.</li>
</ul>
</li>
</ul>
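<p>The overlap filtering in step 1 of the graph constructor can be sketched as a greedy, NMS-style pass over scored atom boxes. The threshold and coordinates below are illustrative assumptions, not the paper's exact values:</p>

```python
def iou(box_a, box_b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0

def filter_atom_boxes(detections, threshold=0.5):
    """Greedily keep the highest-scoring boxes, dropping any later box
    that overlaps a kept box above the IoU threshold.

    detections: list of (score, box) pairs.
    """
    kept = []
    for score, box in sorted(detections, key=lambda d: -d[0]):
        if all(iou(box, kb) < threshold for _, kb in kept):
            kept.append((score, box))
    return kept

dets = [(0.9, (0, 0, 10, 10)),   # kept: highest score
        (0.8, (1, 1, 10, 10)),   # dropped: IoU 0.81 with the first box
        (0.7, (20, 20, 30, 30))] # kept: disjoint
print(filter_atom_boxes(dets))
```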
<h3 id="models">Models</h3>
<ul>
<li><strong>Object Detection Backbone:</strong> <strong>Faster R-CNN</strong>.
<ul>
<li>Four distinct models are trained for different entity types: Atoms ($O^a$), Bonds ($O^b$), Charges ($O^c$), and Stereocenters ($O^s$).</li>
<li><strong>Loss Function:</strong> Multi-task loss combining Multi-class Log Loss ($L_{cls}$) and Regression Loss ($L_{reg}$).</li>
</ul>
</li>
<li><strong>ChemExpert:</strong> An ensemble wrapper that prioritizes models based on user preference (e.g., DECIMER first, then AtomLenz). It accepts the first prediction that passes RDKit chemical validity checks.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Primary metrics focused on structural correctness and localization accuracy.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (Hand-drawn)</th>
          <th>Baseline (DECIMER FT)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (T=1)</strong></td>
          <td>33.8% (AtomLenz+EditKT*)</td>
          <td>62.2%</td>
          <td>Exact ECFP6 fingerprint match.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto Sim.</strong></td>
          <td>0.484</td>
          <td>0.727</td>
          <td>Average similarity.</td>
      </tr>
      <tr>
          <td><strong>mAP</strong></td>
          <td>0.801</td>
          <td>N/A</td>
          <td>Localization accuracy (IoU 0.05-0.35).</td>
      </tr>
      <tr>
          <td><strong>Ensemble Acc.</strong></td>
          <td><strong>63.5%</strong></td>
          <td>62.2%</td>
          <td>ChemExpert (DECIMER + AtomLenz).</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute:</strong> Experiments utilized the Flemish Supercomputer Center (VSC) resources.</li>
<li><strong>Note:</strong> Specific GPU models (e.g., A100/V100) are not explicitly detailed in the text, but Faster R-CNN training is standard on consumer or enterprise GPUs.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Oldenhof, M., De Brouwer, E., Arany, A., &amp; Moreau, Y. (2024). Atom-Level Optical Chemical Structure Recognition with Limited Supervision. <em>arXiv preprint arXiv:2404.01743</em>. <a href="https://doi.org/10.48550/arXiv.2404.01743">https://doi.org/10.48550/arXiv.2404.01743</a></p>
<p><strong>Publication venue/year</strong>: arXiv 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/molden/atomlenz">Official Repository</a></li>
<li><a href="https://dx.doi.org/10.6084/m9.figshare.24599412">Hand-drawn Dataset on Figshare</a></li>
</ul>
<p><strong>BibTeX</strong>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{oldenhofAtomLevelOpticalChemical2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Atom-{{Level Optical Chemical Structure Recognition}} with {{Limited Supervision}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Oldenhof, Martijn and Brouwer, Edward De and Arany, Adam and Moreau, Yves}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2404.01743}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2404.01743}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2404.01743}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-10-25}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{Identifying the chemical structure from a graphical representation, or image, of a molecule is a challenging pattern recognition task that would greatly benefit drug development. Yet, existing methods for chemical structure recognition do not typically generalize well, and show diminished effectiveness when confronted with domains where data is sparse, or costly to generate, such as hand-drawn molecule images. To address this limitation, we propose a new chemical structure recognition tool that delivers state-of-the-art performance and can adapt to new domains with a limited number of data samples and supervision. Unlike previous approaches, our method provides atom-level localization, and can therefore segment the image into the different atoms and bonds. Our model is the first model to perform OCSR with atom-level entity detection with only SMILES supervision. Through rigorous and extensive benchmarking, we demonstrate the preeminence of our chemical structure recognition approach in terms of data efficiency, accuracy, and atom-level entity prediction.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{Computer Science - Computer Vision and Pattern Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">file</span> = <span style="color:#e6db74">{/Users/hunterheidenreich/Zotero/storage/4ILTDIFX/Oldenhof et al. - 2024 - Atom-Level Optical Chemical Structure Recognition with Limited Supervision.pdf}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SwinOCSR: Vision Transformers for Chemical OCR</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/swinocsr/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/swinocsr/</guid><description>Deep learning model using Swin Transformer and Focal Loss for OCSR, achieving 98.58% accuracy on synthetic benchmarks.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, Z., Li, J., Yang, Z. et al. (2022). SwinOCSR: end-to-end optical chemical structure recognition using a Swin Transformer. <em>Journal of Cheminformatics</em>, 14(41). <a href="https://doi.org/10.1186/s13321-022-00624-5">https://doi.org/10.1186/s13321-022-00624-5</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/suanfaxiaohuo/SwinOCSR">GitHub Repository</a></li>
</ul>
<h2 id="contribution-methodological-architecture-and-datasets">Contribution: Methodological Architecture and Datasets</h2>
<p>This is a <strong>Methodological Paper</strong> with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architecture (Swin Transformer backbone) and a specific loss function optimization (Focal Loss) for the task of Optical Chemical Structure Recognition (OCSR).</li>
<li><strong>Resource</strong>: It constructs a large-scale synthetic dataset of 5 million molecules, specifically designing it to cover complex cases like substituents and aromatic rings.</li>
</ul>
<h2 id="motivation-addressing-visual-context-and-data-imbalance">Motivation: Addressing Visual Context and Data Imbalance</h2>
<ul>
<li><strong>Problem</strong>: OCSR (converting images of chemical structures to <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>) is difficult due to complex chemical patterns and long sequences. Existing deep learning methods (often CNN-based) struggle to achieve satisfactory recognition rates.</li>
<li><strong>Technical Gap</strong>: Standard CNN backbones (like ResNet or EfficientNet) focus on local feature extraction and miss global dependencies required for interpreting complex molecular diagrams.</li>
<li><strong>Data Imbalance</strong>: Chemical strings suffer from severe class imbalance (e.g., &lsquo;C&rsquo; and &lsquo;H&rsquo; are frequent; &lsquo;Br&rsquo; or &lsquo;Cl&rsquo; are rare), which causes standard Cross Entropy loss to underperform.</li>
</ul>
<h2 id="core-innovation-swin-transformers-and-focal-loss">Core Innovation: Swin Transformers and Focal Loss</h2>
<ul>
<li><strong>Swin Transformer Backbone</strong>: SwinOCSR is the first approach to replace the standard CNN backbone with a <strong>Swin Transformer</strong>. This allows the model to leverage shifted window attention to capture both local and global image features more effectively.</li>
<li><strong>Multi-label Focal Loss (MFL)</strong>: The paper introduces a modified Focal Loss to OCSR to explicitly penalize the model for errors on rare tokens, addressing the &ldquo;long-tail&rdquo; distribution of chemical elements. The standard Focal Loss formulation heavily weights hard-to-classify examples:
$$
FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)
$$</li>
<li><strong>Structured Synthetic Dataset</strong>: Creation of a dataset explicitly balanced across four structural categories: Kekule rings, aromatic rings, Kekule rings with substituents, and aromatic rings with substituents.</li>
</ul>
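<p>The focal-loss weighting above can be sketched in a few lines of NumPy. The α and γ defaults below are the common values from the original Focal Loss paper, not values reported by SwinOCSR.</p>

```python
import numpy as np

def focal_loss(p, targets, alpha=0.25, gamma=2.0):
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

    p       -- predicted probability of the positive class, shape (N,)
    targets -- binary labels in {0, 1}, shape (N,)
    """
    p_t = np.where(targets == 1, p, 1.0 - p)            # prob of true class
    alpha_t = np.where(targets == 1, alpha, 1.0 - alpha)
    return -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)

# A confident correct prediction is barely penalized, while a confident
# wrong one dominates the batch loss -- the intended "long-tail" fix.
easy = focal_loss(np.array([0.95]), np.array([1]))
hard = focal_loss(np.array([0.05]), np.array([1]))
print(easy, hard)  # hard >> easy
```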
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<ul>
<li><strong>Backbone Comparison</strong>: The authors benchmarked SwinOCSR against the backbones of leading competitors: ResNet-50 (used in Image2SMILES) and EfficientNet-B3 (used in DECIMER 1.0).</li>
<li><strong>Loss Function Ablation</strong>: They compared the performance of standard Cross Entropy (CE) loss against their proposed Multi-label Focal Loss (MFL).</li>
<li><strong>Category Stress Test</strong>: Performance was evaluated separately on molecules with/without substituents and with/without aromaticity to test robustness.</li>
<li><strong>Real-world Evaluation</strong>: The model was tested on a small set of 100 images manually extracted from literature vs. 100 generated images to measure domain shift.</li>
</ul>
<h2 id="results-and-limitations">Results and Limitations</h2>
<ul>
<li><strong>SOTA Performance</strong>: SwinOCSR achieved <strong>98.58% accuracy</strong> on the synthetic test set, significantly outperforming ResNet-50 (89.17%) and EfficientNet-B3 (86.70%) backbones.</li>
<li><strong>Effective Handling of Length</strong>: The model maintained high accuracy (94.76%) even on very long DeepSMILES strings (76-100 characters), indicating superior global feature extraction.</li>
<li><strong>Domain Shift Issues</strong>: While performance on synthetic data was near-perfect, accuracy dropped to <strong>25%</strong> on real-world literature images. The authors attribute this to noise, low resolution, and stylistic variations (e.g., abbreviations) present in the real world.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: The first 8.5 million structures from <strong>PubChem</strong> were downloaded, yielding ~6.9 million unique SMILES.</li>
<li><strong>Generation Pipeline</strong>:
<ul>
<li><strong>Tools</strong>: <strong>CDK</strong> (Chemistry Development Kit) for image rendering; <strong>RDKit</strong> for SMILES canonicalization.</li>
<li><strong>Augmentation</strong>: To ensure diversity, the dataset was split into 4 categories (1.25M each): (1) Kekule, (2) Aromatic, (3) Kekule + Substituents, (4) Aromatic + Substituents. Substituents were randomly added from a list of 224 common patent substituents.</li>
<li><strong>Preprocessing</strong>: Images rendered as binary, resized to <strong>224x224</strong>, and copied to 3 channels (RGB simulation).</li>
</ul>
</li>
</ul>
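<p>The preprocessing step (binary rendering, 224×224 resize, replication to three channels) can be mimicked with NumPy. The threshold value and the nearest-neighbor resize below are stand-ins for whatever CDK/image-library operations the authors actually used.</p>

```python
import numpy as np

def preprocess(gray, size=224, threshold=128):
    """Binarize a grayscale drawing, resize to size x size
    (nearest-neighbor), and replicate the result to 3 channels."""
    binary = (gray >= threshold).astype(np.float32)
    h, w = binary.shape
    rows = np.arange(size) * h // size        # nearest-neighbor row picks
    cols = np.arange(size) * w // size
    resized = binary[np.ix_(rows, cols)]
    return np.repeat(resized[:, :, None], 3, axis=2)  # shape (224, 224, 3)

img = preprocess(np.random.randint(0, 256, (300, 300)))
print(img.shape)  # (224, 224, 3)
```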
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>4,500,000</td>
          <td>18:1:1 split (Train/Val/Test)</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>250,000</td>
          <td></td>
      </tr>
      <tr>
          <td>Test</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>250,000</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Loss Function</strong>: <strong>Multi-label Focal Loss (MFL)</strong>. The single-label classification task was cast as multi-label to apply Focal Loss, using a sigmoid activation on logits.</li>
<li><strong>Optimization</strong>:
<ul>
<li><strong>Optimizer</strong>: <strong>Adam</strong> with initial learning rate <code>5e-4</code>.</li>
<li><strong>Schedulers</strong>: Cosine decay for the Swin Transformer backbone; Step decay for the Transformer encoder/decoder.</li>
<li><strong>Regularization</strong>: Dropout rate of <code>0.1</code>.</li>
</ul>
</li>
</ul>
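<p>The two decay schedules can be written as plain functions. The cosine form below is the standard warmup-free variant, and the step-decay factor and interval are illustrative assumptions; the notes only specify the initial learning rate.</p>

```python
import math

BASE_LR = 5e-4  # initial learning rate reported for Adam

def cosine_decay(step, total_steps, base_lr=BASE_LR):
    """Cosine decay, as used for the Swin backbone."""
    return 0.5 * base_lr * (1 + math.cos(math.pi * step / total_steps))

def step_decay(epoch, base_lr=BASE_LR, factor=0.5, every=10):
    """Step decay for the Transformer encoder/decoder; `factor` and
    `every` are illustrative, not reported values."""
    return base_lr * factor ** (epoch // every)

print(cosine_decay(0, 1000), cosine_decay(1000, 1000))  # 5e-4 down to 0.0
```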
<h3 id="models">Models</h3>
<ul>
<li><strong>Backbone (Encoder 1)</strong>: <strong>Swin Transformer</strong>.
<ul>
<li>Patch size: $4 \times 4$.</li>
<li>Linear embedding dimension: 192.</li>
<li>Structure: 4 stages with Swin Transformer Blocks (Window MSA + Shifted Window MSA).</li>
<li>Output: Flattened patch sequence $S_b$.</li>
</ul>
</li>
<li><strong>Transformer Encoder (Encoder 2)</strong>: 6 standard Transformer encoder layers. Uses Positional Embedding + Multi-Head Attention + MLP.</li>
<li><strong>Transformer Decoder</strong>: 6 standard Transformer decoder layers. Uses Masked Multi-Head Attention (to prevent look-ahead) + Multi-Head Attention (connecting to encoder output $S_e$).</li>
<li><strong>Tokenization</strong>: <strong>DeepSMILES</strong> format used (syntactically more robust than SMILES). Vocabulary size: <strong>76 tokens</strong> (the unique tokens observed in the dataset). Embedding dimension: 256.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Accuracy (Exact Match), Tanimoto Similarity (PubChem fingerprints), BLEU, ROUGE.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SwinOCSR (MFL)</th>
          <th>ResNet-50</th>
          <th>EfficientNet-B3</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td><strong>98.58%</strong></td>
          <td>89.17%</td>
          <td>86.70%</td>
          <td>MFL loss provides ~1% boost over CE</td>
      </tr>
      <tr>
          <td>Tanimoto</td>
          <td><strong>99.77%</strong></td>
          <td>98.79%</td>
          <td>98.46%</td>
          <td>High similarity even when exact match fails</td>
      </tr>
      <tr>
          <td>BLEU</td>
          <td><strong>99.59%</strong></td>
          <td>98.62%</td>
          <td>98.37%</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Trained on <strong>NVIDIA Tesla V100-PCIE</strong>.</li>
<li><strong>Training Time</strong>: 30 epochs.</li>
<li><strong>Batch Size</strong>: 256 images ($224 \times 224$ pixels).</li>
</ul>
]]></content:encoded></item><item><title>String Representations for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/rajan-string-representations-2022/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/rajan-string-representations-2022/</guid><description>Ablation study comparing SMILES, DeepSMILES, SELFIES, and InChI for OCSR. SMILES achieves highest accuracy; SELFIES guarantees validity.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Steinbeck, C., &amp; Zielesny, A. (2022). Performance of chemical structure string representations for chemical image recognition using transformers. <em>Digital Discovery</em>, 1(2), 84-90. <a href="https://doi.org/10.1039/D1DD00013F">https://doi.org/10.1039/D1DD00013F</a></p>
<p><strong>Publication</strong>: Digital Discovery 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/6144553b4853d253c6a7adde/original/performance-of-chemical-structure-string-representations-for-chemical-image-recognition-using-transformers.pdf">ChemRxiv Preprint (PDF)</a></li>
<li><a href="https://github.com/Kohulan/DECIMER_Short_Communication">Official Code Repository</a></li>
<li><a href="https://doi.org/10.5281/zenodo.5155037">Data on Zenodo</a></li>
<li>Related work: <a href="/notes/computational-chemistry/optical-structure-recognition/decimer/">DECIMER</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/decimer-1.0/">DECIMER 1.0</a>, <a href="/notes/computational-chemistry/optical-structure-recognition/img2smi/">IMG2SMI</a></li>
</ul>
<h2 id="methodological-focus-and-resource-contributions">Methodological Focus and Resource Contributions</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a secondary contribution as a <strong>Resource Paper</strong> ($\Psi_{\text{Resource}}$).</p>
<p>It functions as a systematic ablation study, keeping the model architecture (EfficientNet-B3 + Transformer) constant while varying the input/output representation (<a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>, DeepSMILES, <a href="/notes/computational-chemistry/molecular-representations/selfies/">SELFIES</a>, <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a>) to determine which format yields the best performance for Optical Chemical Structure Recognition (OCSR). It also contributes large-scale benchmarking datasets derived from ChEMBL and PubChem.</p>
<h2 id="the-syntax-challenge-in-chemical-image-recognition">The Syntax Challenge in Chemical Image Recognition</h2>
<p>Optical Chemical Structure Recognition (OCSR) is essential for extracting chemical information buried in scientific literature and patents. While deep learning offers a promising alternative to rule-based approaches, neural networks struggle with the syntax of standard chemical representations like SMILES. Specifically, the tokenization of SMILES strings (where ring closures and branches are marked by single characters potentially far apart in the sequence) creates learning difficulties for sequence-to-sequence models. Newer representations like DeepSMILES and SELFIES were developed to address these syntax issues, but their comparative performance in image-to-text tasks had not been rigorously benchmarked.</p>
<h2 id="isolating-string-representation-variables">Isolating String Representation Variables</h2>
<p>The core novelty is the <strong>comparative isolation of the string representation variable</strong> in an OCSR context. Previous approaches often selected a representation (usually SMILES) without validating if it was optimal for the learning task. This study specifically tests the hypothesis that syntax-robust representations (like SELFIES) improve deep learning performance compared to standard SMILES. It provides empirical evidence on the trade-off between <em>validity</em> (guaranteed by SELFIES) and <em>accuracy</em> (highest with SMILES).</p>
<h2 id="large-scale-image-to-text-translation-experiments">Large-Scale Image-to-Text Translation Experiments</h2>
<p>The authors performed a large-scale image-to-text translation experiment:</p>
<ul>
<li><strong>Task</strong>: Converting 2D chemical structure images into text strings.</li>
<li><strong>Data</strong>:
<ul>
<li><strong>ChEMBL</strong>: ~1.6M molecules, split into two datasets (with and without stereochemistry).</li>
<li><strong>PubChem</strong>: ~3M molecules, split similarly, to test performance scaling with data size.</li>
</ul>
</li>
<li><strong>Representations</strong>: The same chemical structures were converted into four formats: SMILES, DeepSMILES, SELFIES, and InChI.</li>
<li><strong>Metric</strong>: The models were evaluated on:
<ul>
<li><strong>Validity</strong>: Can the predicted string be decoded back to a molecule?</li>
<li><strong>Exact Match</strong>: Is the predicted string identical to the ground truth?</li>
<li><strong>Tanimoto Similarity</strong>: How chemically similar is the prediction to the ground truth (using PubChem fingerprints)? The similarity $\mathcal{T}$ between two molecular fingerprints $A$ and $B$ is calculated as:
$$ \mathcal{T}(A, B) = \frac{A \cdot B}{||A||^2 + ||B||^2 - A \cdot B} $$</li>
</ul>
</li>
</ul>
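<p>For binary fingerprints, the continuous formula above reduces to counting bits: $c/(a+b-c)$, where $a$ and $b$ are the on-bit counts of each fingerprint and $c$ the count of shared on-bits. A minimal sketch (the fingerprints here are toy bit lists, not PubChem fingerprints):</p>

```python
def tanimoto(a, b):
    """Tanimoto similarity of two equal-length binary fingerprints:
    T = |A & B| / (|A| + |B| - |A & B|)."""
    common = sum(x & y for x, y in zip(a, b))
    return common / (sum(a) + sum(b) - common)

fp1 = [1, 1, 0, 1, 0, 0]
fp2 = [1, 0, 0, 1, 1, 0]
print(tanimoto(fp1, fp2))  # 2 common bits / (3 + 3 - 2) = 0.5
```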
<h2 id="comparative-performance-and-validity-trade-offs">Comparative Performance and Validity Trade-offs</h2>
<ul>
<li><strong>SMILES is the most accurate</strong>: Contrary to the hypothesis that syntax-robust formats would learn better, SMILES consistently achieved the highest exact match accuracy (up to 88.62% on PubChem data) and average Tanimoto similarity (0.98). This is likely due to SMILES having shorter string lengths and fewer unique tokens compared to SELFIES.</li>
<li><strong>SELFIES guarantees validity</strong>: While slightly less accurate in direct translation, SELFIES achieved 100% structural validity (every prediction could be decoded), whereas SMILES predictions occasionally contained syntax errors.</li>
<li><strong>InChI is unsuitable</strong>: InChI performed significantly worse (approx. 64% exact match) due to extreme string lengths (up to 273 tokens).</li>
<li><strong>Stereochemistry adds difficulty</strong>: Including stereochemistry reduced accuracy across all representations due to increased token count and visual complexity.</li>
<li><strong>Recommendation</strong>: Use SMILES for maximum accuracy; use SELFIES if generating valid structures is the priority (e.g., generative tasks).</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used curated subsets from ChEMBL and PubChem. Images were generated synthetically.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL (Dataset 1/2)</td>
          <td>~1.5M</td>
          <td>Filtered for MW &lt; 1500, specific elements (C,H,O,N,P,S,F,Cl,Br,I,Se,B).</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>PubChem (Dataset 3/4)</td>
          <td>~3.0M</td>
          <td>Same filtering rules, used to test scaling.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Test Split</td>
          <td>~120k - 250k</td>
          <td>Created using RDKit MaxMin algorithm to ensure chemical diversity.</td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation</strong>:</p>
<ul>
<li><strong>Tool</strong>: CDK Structure Diagram Generator (SDG).</li>
<li><strong>Specs</strong>: $300 \times 300$ pixels, rotated by random angles ($0-360^{\circ}$), saved as 8-bit PNG.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Tokenization Rules</strong> (Critical for replication):</p>
<ul>
<li><strong>SELFIES</strong>: Split at every <code>][</code> (e.g., <code>[C][N]</code> $\rightarrow$ <code>[C]</code>, <code>[N]</code>).</li>
<li><strong>SMILES / DeepSMILES</strong>: Regex-based splitting:
<ul>
<li>Every heavy atom (e.g., <code>C</code>, <code>N</code>).</li>
<li>Every parenthesis <code>(</code> and <code>)</code>.</li>
<li>Every bond symbol <code>=</code> and <code>#</code>.</li>
<li>Every single-digit number.</li>
<li>Everything inside square brackets <code>[]</code> is kept as a single token.</li>
</ul>
</li>
<li><strong>InChI</strong>: The prefix <code>InChI=1S/</code> was treated as a single token and removed during training, then re-added for evaluation.</li>
</ul>
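<p>The SMILES/DeepSMILES splitting rules can be realized with a single regex. The pattern below is an assumption that satisfies the stated rules (bracket atoms kept whole, two-letter halogens, parentheses, bond symbols, single digits), not the authors' exact implementation.</p>

```python
import re

# Alternation order matters: bracket atoms and two-letter elements
# must be tried before single letters.
TOKEN_RE = re.compile(r"\[[^\]]+\]|Br|Cl|[A-Za-z]|\(|\)|=|#|\d")

def tokenize(smiles):
    """Split a SMILES/DeepSMILES string into its tokens."""
    return TOKEN_RE.findall(smiles)

print(tokenize("CC(=O)O[C@H]1Br"))
# ['C', 'C', '(', '=', 'O', ')', 'O', '[C@H]', '1', 'Br']
```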
<h3 id="models">Models</h3>
<p>The model follows the <strong>DECIMER</strong> architecture.</p>
<ul>
<li><strong>Encoder</strong>: EfficientNet-B3 (pre-trained with &ldquo;Noisy Student&rdquo; weights).
<ul>
<li>Output: Image feature vectors of shape $10 \times 10 \times 1536$.</li>
</ul>
</li>
<li><strong>Decoder</strong>: Transformer (similar to the &ldquo;Base&rdquo; model from <em>Attention Is All You Need</em>).
<ul>
<li>Layers: 4 encoder-decoder layers.</li>
<li>Attention Heads: 8.</li>
<li>Dimension ($d_{\text{model}}$): 512.</li>
<li>Feed-forward ($d_{\text{ff}}$): 2048.</li>
<li>Dropout: 10%.</li>
</ul>
</li>
<li><strong>Loss</strong>: Sparse categorical cross-entropy.</li>
<li><strong>Optimizer</strong>: Adam with custom learning rate scheduler.</li>
</ul>
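<p>The notes do not specify the "custom learning rate scheduler." Given the stated <em>Attention Is All You Need</em> lineage, a plausible candidate (an assumption, not confirmed by the source) is the original Transformer warmup schedule:</p>

```python
def transformer_lr(step, d_model=512, warmup=4000):
    """Warmup schedule from 'Attention Is All You Need':
    lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).
    warmup=4000 is that paper's default, assumed here."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Ramps up linearly, peaks at `warmup` steps, then decays as 1/sqrt(step):
print(transformer_lr(1), transformer_lr(4000), transformer_lr(100000))
```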
<h3 id="evaluation">Evaluation</h3>
<p>Metrics were calculated after converting all predictions back to standard SMILES.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Baseline (SMILES)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Identical Match</strong></td>
          <td>88.62% (PubChem)</td>
          <td>Strict character-for-character equality.</td>
      </tr>
      <tr>
          <td><strong>Valid Structure</strong></td>
          <td>99.78%</td>
          <td>SMILES had rare syntax errors; SELFIES achieved 100%.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto (Avg)</strong></td>
          <td>0.98</td>
          <td>Calculated using PubChem fingerprints via CDK.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: Google Cloud TPUs (v3-8).</li>
<li><strong>Format</strong>: Data converted to TFRecords (128 image/text pairs per record) for TPU efficiency.</li>
<li><strong>Batch Size</strong>: 1024.</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanPerformanceChemicalStructure2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Performance of Chemical Structure String Representations for Chemical Image Recognition Using Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Steinbeck, Christoph and Zielesny, Achim}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{84--90}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D1DD00013F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>One Strike, You're Out: Detecting Markush Structures</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/jurriaans-markush-detection-2023/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/jurriaans-markush-detection-2023/</guid><description>Patch-based CNN method for detecting Markush structures in chemical images, addressing low signal-to-noise ratios in OCSR.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jurriaans, T., Szarkowska, K., Nalisnick, E., Schwörer, M., Thorne, C., &amp; Akhondi, S. (2023). One Strike, You&rsquo;re Out: Detecting Markush Structures in Low Signal-to-Noise Ratio Images. <em>arXiv preprint arXiv:2311.14633</em>. <a href="https://doi.org/10.48550/arXiv.2311.14633">https://doi.org/10.48550/arXiv.2311.14633</a></p>
<p><strong>Publication</strong>: arXiv 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Thomasjurriaans/markush-recognition-msc-thesis">GitHub Repository</a></li>
</ul>
<h2 id="methodology-and-classification">Methodology and Classification</h2>
<p>This is a <strong>Method</strong> paper (Classification: $\Psi_{\text{Method}}$).</p>
<p>It proposes a patch-based classification pipeline to solve a technical failure mode in Optical Chemical Structure Recognition (OCSR). Distinct rhetorical indicators include a strong baseline comparison (CNN vs. traditional ORB), rigorous ablation studies (architecture, pretraining), and a focus on evaluating the filtering efficacy against a known failure mode.</p>
<h2 id="the-markush-structure-challenge">The Markush Structure Challenge</h2>
<p><strong>The Problem</strong>: Optical Chemical Structure Recognition (OCSR) tools convert 2D images of molecules into machine-readable formats. These tools struggle with &ldquo;Markush structures,&rdquo; generic structural templates used frequently in patents that contain variables rather than specific atoms (e.g., $R$, $X$, $Y$).</p>
<p><strong>The Gap</strong>: Markush structures are difficult to detect because they often appear as small indicators (a single &ldquo;R&rdquo; or variable) within a large image, resulting in a very low Signal-to-Noise Ratio (SNR). Existing OCSR research pipelines typically bypass this by manually excluding these structures from their datasets.</p>
<p><strong>The Goal</strong>: To build an automated filter that can identify images containing Markush structures so they can be removed from OCSR pipelines, improving overall database quality without requiring manual data curation.</p>
<h2 id="patch-based-classification-pipeline">Patch-Based Classification Pipeline</h2>
<p>The core technical contribution is an end-to-end deep learning pipeline explicitly tailored for low-SNR chemical images where standard global resizing or cropping fails due to massive variations in image resolution and pixel scales.</p>
<ul>
<li><strong>Patch Generation</strong>: The system slices input images into overlapping patches generated from two offset grids, ensuring that variables falling on boundaries are fully captured in at least one crop.</li>
<li><strong>Targeted Annotation</strong>: The labels rely on pixel-level bounding boxes around Markush indicators, minimizing the noise that would otherwise overwhelm a full-image classification attempt.</li>
<li><strong>Inference Strategy</strong>: During inference, the query image is broken into patches; each patch is classified individually, and the patch scores are aggregated with a maximum-pooling rule, $X = \max_{i=1}^{n} \{ x_i \}$.</li>
<li><strong>Evaluation</strong>: Provides the first rigorous comparison between fixed-feature extraction (ORB + XGBoost) and end-to-end deep learning for this specific domain.</li>
</ul>
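<p>The two-offset-grid patching and the max-pooled aggregation can be sketched as follows; the patch size matches the ResNet input mentioned later, but the function names and grid-stepping details are illustrative.</p>

```python
def patch_origins(height, width, patch=224):
    """Top-left corners from two grids, the second offset by half a
    patch, so an indicator bisected by one grid is whole in the other."""
    origins = set()
    for off in (0, patch // 2):
        ys = range(off, height - patch + 1, patch)
        xs = range(off, width - patch + 1, patch)
        origins.update((y, x) for y in ys for x in xs)
    return sorted(origins)

def markush_score(patch_probs):
    """Image-level score is the max over its patch scores: X = max_i x_i."""
    return max(patch_probs)

print(len(patch_origins(600, 600)))            # 4 corners per grid -> 8
print(markush_score([0.02, 0.1, 0.91, 0.05]))  # 0.91
```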
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The authors compared two distinct paradigms on a manually annotated dataset:</p>
<ol>
<li>
<p><strong>Fixed-Feature Baseline</strong>: Used <strong>ORB</strong> (Oriented FAST and Rotated BRIEF) to detect keypoints and match them against a template bank of known Markush symbols. Features (match counts, Hamming distances) were fed into an <strong>XGBoost</strong> model.</p>
</li>
<li>
<p><strong>Deep Learning Method</strong>: Fine-tuned <strong>ResNet18</strong> and <strong>Inception V3</strong> models on the generated image patches.</p>
<ul>
<li><strong>Ablations</strong>: Contrasted pretraining sources, evaluating general domain (ImageNet) against chemistry-specific domain (USPTO images).</li>
<li><strong>Fine-tuning</strong>: Compared full-network fine-tuning against freezing all but the fully connected layers.</li>
</ul>
</li>
</ol>
<p>To handle significant class imbalance, the primary evaluation metric was the Macro F1 score, defined as:</p>
<p>$$ \text{Macro F1} = \frac{1}{N} \sum_{i=1}^{N} \frac{2 \cdot \text{precision}_i \cdot \text{recall}_i}{\text{precision}_i + \text{recall}_i} $$</p>
<h2 id="performance-outcomes-and-model-superiority">Performance Outcomes and Model Superiority</h2>
<ul>
<li>
<p><strong>CNN Superiority</strong>: Deep learning architectures fundamentally outperformed the structured fixed-feature baseline. The best model (<strong>Inception V3</strong> pretrained on ImageNet) achieved a patch-level Macro F1 of <strong>0.928</strong>, compared to <strong>0.701</strong> (image-level) for the ORB baseline.</p>
</li>
<li>
<p><strong>The Pretraining Surprise</strong>: Counterintuitively, ImageNet pretraining consistently outperformed the domain-specific USPTO pretraining. This suggests that the robust, varied features learned across millions of general images transfer better to low-level stroke recognition than features trained on specialized, but smaller, chemical datasets.</p>
</li>
<li>
<p><strong>Full Model Tuning</strong>: Unfreezing the entire network yielded significantly higher performance than tuning only the classifier head, indicating that standard low-level visual filters require substantial adaptation to reliably distinguish chemical line drawings.</p>
</li>
<li>
<p><strong>Limitations and Edge Cases</strong>: While the ROC AUC of <strong>0.97</strong> implies high reliability, the aggregation metric ($X = \max \{ x_i \}$) is naive. Furthermore, the patching approach creates inherent label noise when a Markush indicator is cleanly bisected by a patch edge, potentially forcing the network to learn incomplete visual features.</p>
</li>
</ul>
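<p>The image-level aggregation rule the paper criticizes can be sketched in a few lines (the 0.5 decision threshold is illustrative, not from the paper):</p>

```python
def aggregate_image_score(patch_scores):
    """Image-level Markush score: max over per-patch probabilities.

    A single high-confidence patch flags the whole image ("one strike,
    you're out"), which is why the rule is naive: one false-positive
    patch flips the entire image.
    """
    return max(patch_scores)

def classify_image(patch_scores, threshold=0.5):
    return aggregate_image_score(patch_scores) >= threshold

print(classify_image([0.05, 0.10, 0.93, 0.20]))  # True: one confident patch
print(classify_image([0.05, 0.10, 0.20]))        # False
```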
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used a primary dataset labeled by domain experts and a larger auxiliary dataset for evaluation.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Val</strong></td>
          <td><strong>Primary Dataset</strong></td>
          <td>272 Images</td>
          <td>Manually annotated with bounding boxes for Markush indicators. Split 60/20/20.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>Auxiliary Dataset</strong></td>
          <td>~5.4k Images</td>
          <td>5117 complete structures, 317 Markush. Used for image-level testing only (no bbox).</td>
      </tr>
  </tbody>
</table>
<p><strong>Patch Generation</strong>:</p>
<ul>
<li>Images are cropped into patches of size <strong>224x224</strong> (ResNet) or <strong>299x299</strong> (Inception).</li>
<li>Patches are generated from 2 grids offset by half the patch width/height to ensure annotations aren&rsquo;t lost on edges.</li>
<li><strong>Labeling Rule</strong>: A patch is labeled &ldquo;Markush&rdquo; if &gt;50% of an annotation&rsquo;s pixels fall inside it.</li>
</ul>
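<p>The offset-grid patching and the &gt;50% labeling rule can be sketched as follows (coordinate conventions here are assumptions, not the authors&rsquo; code):</p>

```python
def patch_origins(width, height, patch=224):
    """Top-left corners for two overlapping grids: the base grid plus a
    grid offset by half the patch size, so an annotation bisected by one
    grid's patch boundary lands whole inside a patch of the other grid."""
    origins = []
    for dx, dy in [(0, 0), (patch // 2, patch // 2)]:
        for y in range(dy, height - patch + 1, patch):
            for x in range(dx, width - patch + 1, patch):
                origins.append((x, y))
    return origins

def label_patch(x, y, patch, annotation_pixels):
    """Label a patch 'Markush' if >50% of an annotation's pixels fall inside."""
    inside = sum(1 for (px, py) in annotation_pixels
                 if x <= px < x + patch and y <= py < y + patch)
    return inside > 0.5 * len(annotation_pixels)
```

For a 448&times;448 image this yields four base-grid patches plus one half-offset patch covering the center.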
<h3 id="algorithms">Algorithms</h3>
<p><strong>ORB (Baseline)</strong>:</p>
<ul>
<li>Matches query images against a bank of template patches containing Markush indicators.</li>
<li><strong>Features</strong>: Number of keypoints, number of matches, Hamming distance of best 5 matches.</li>
<li><strong>Classifier</strong>: XGBoost trained on these features.</li>
<li><strong>Hyperparameters</strong>: Search over number of features (500-2000) and template patches (50-250).</li>
</ul>
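<p>A rough sketch of the baseline&rsquo;s feature construction, with toy integers standing in for real 256-bit ORB descriptors and an arbitrary match threshold (the actual pipeline feeds such features to XGBoost):</p>

```python
def hamming(d1, d2):
    """Bit-level Hamming distance between two binary descriptors (as ints)."""
    return bin(d1 ^ d2).count("1")

def orb_style_features(query_desc, template_desc, match_thresh=64):
    """Feature vector in the spirit of the baseline: keypoint count,
    number of template matches under a distance threshold, and the
    Hamming distances of the best 5 matches. The exact featurization
    is illustrative, not the paper's."""
    best = sorted(min(hamming(q, t) for t in template_desc)
                  for q in query_desc)
    n_matches = sum(1 for d in best if d < match_thresh)
    top5 = best[:5] + [256] * (5 - len(best[:5]))  # pad if fewer than 5 keypoints
    return [len(query_desc), n_matches] + top5
```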
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Framework</strong>: PyTorch with Optuna for optimization.</li>
<li><strong>Optimization</strong>: 25 trials per configuration.</li>
<li><strong>Augmentations</strong>: Random perspective shift, posterization, sharpness/blur.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main architectures were compared.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Input Size</th>
          <th>Parameters</th>
          <th>Pretraining Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ResNet18</strong></td>
          <td>224x224</td>
          <td>11.5M</td>
          <td>ImageNet</td>
      </tr>
      <tr>
          <td><strong>Inception V3</strong></td>
          <td>299x299</td>
          <td>23.8M</td>
          <td>ImageNet &amp; USPTO</td>
      </tr>
  </tbody>
</table>
<p><strong>Best Configuration</strong>: Inception V3, ImageNet weights, Full Model fine-tuning (all layers unfrozen).</p>
<h3 id="evaluation">Evaluation</h3>
<p>Primary metric was <strong>Macro F1</strong> due to class imbalance.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best CNN (Inception V3)</th>
          <th>Baseline (ORB)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Patch Test F1</strong></td>
          <td>$0.928 \pm 0.035$</td>
          <td>N/A</td>
          <td>ORB does not support patch-level</td>
      </tr>
      <tr>
          <td><strong>Image Test F1</strong></td>
          <td>$0.917 \pm 0.014$</td>
          <td>$0.701 \pm 0.052$</td>
          <td>CNN aggregates patch predictions</td>
      </tr>
      <tr>
          <td><strong>Aux Test F1</strong></td>
          <td>0.914</td>
          <td>0.533</td>
          <td>Evaluation on large secondary dataset</td>
      </tr>
      <tr>
          <td><strong>ROC AUC</strong></td>
          <td>0.97</td>
          <td>0.81</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Tesla V100-SXM2-16GB</li>
<li><strong>CPU</strong>: Intel Xeon E5-2686 @ 2.30GHz</li>
<li><strong>RAM</strong>: 64 GB</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{jurriaansOneStrikeYoure2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{One {{Strike}}, {{You}}&#39;re {{Out}}: {{Detecting Markush Structures}} in {{Low Signal-to-Noise Ratio Images}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{One {{Strike}}, {{You}}&#39;re {{Out}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Jurriaans, Thomas and Szarkowska, Kinga and Nalisnick, Eric and Schwoerer, Markus and Thorne, Camilo and Akhondi, Saber}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = nov,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MICER: Molecular Image Captioning with Transfer Learning</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/micer/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/micer/</guid><description>Encoder-decoder model using pre-trained ResNet and attention-based LSTM to translate molecular images into SMILES, achieving SOTA OCSR performance.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yi, J., Wu, C., Zhang, X., Xiao, X., Qiu, Y., Zhao, W., Hou, T., &amp; Cao, D. (2022). MICER: a pre-trained encoder-decoder architecture for molecular image captioning. <em>Bioinformatics</em>, 38(19), 4562-4572. <a href="https://doi.org/10.1093/bioinformatics/btac545">https://doi.org/10.1093/bioinformatics/btac545</a></p>
<p><strong>Publication</strong>: Bioinformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Jiacai-Yi/MICER">GitHub Repository</a></li>
</ul>
<h2 id="micers-contribution-to-optical-structure-recognition">MICER&rsquo;s Contribution to Optical Structure Recognition</h2>
<p>This is a <strong>Method</strong> paper according to the AI for Physical Sciences taxonomy. It proposes MICER, an encoder-decoder architecture that integrates transfer learning (fine-tuning pre-trained models) and attention mechanisms for Optical Chemical Structure Recognition (OCSR). The study includes rigorous benchmarking comparing MICER against three rule-based tools (OSRA, MolVec, Imago) and existing deep learning methods (DECIMER). The authors conduct extensive factor comparison experiments to isolate the effects of stereochemistry, molecular complexity, data volume, and encoder backbone choices.</p>
<h2 id="the-challenge-of-generalizing-in-ocsr">The Challenge of Generalizing in OCSR</h2>
<p>Chemical structures in scientific literature are valuable for drug discovery, but they are locked in image formats that are difficult to mine automatically. Traditional OCSR tools (like OSRA) rely on hand-crafted rules and expert knowledge. They are brittle, struggle with stylistic variations, and have low generalization ability. While deep learning has been applied (e.g., DECIMER), previous attempts often used frozen pre-trained feature extractors (without fine-tuning) or failed to fully exploit transfer learning, leading to suboptimal performance. The goal of this work is to build an end-to-end &ldquo;image captioning&rdquo; system that translates molecular images directly into <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a> strings without intermediate segmentation steps.</p>
<h2 id="integrating-fine-tuning-and-attention-for-chemistry">Integrating Fine-Tuning and Attention for Chemistry</h2>
<p>The core novelty lies in the specific architectural integration of transfer learning with fine-tuning for the chemical domain. Unlike DECIMER, which used a frozen network, MICER fine-tunes a pre-trained ResNet on molecular images. This allows the encoder to adapt from general object recognition to specific chemical feature extraction.</p>
<p>The model incorporates an attention mechanism into the LSTM decoder, allowing the model to focus on specific image regions (atoms and bonds) when generating each character of the SMILES string. The paper explicitly analyzes &ldquo;intrinsic features&rdquo; of molecular data (stereochemistry, complexity) to guide the design of a robust training dataset, combining multiple chemical toolkits (Indigo, RDKit) to generate diverse styles.</p>
<h2 id="experimental-setup-and-ablation-studies">Experimental Setup and Ablation Studies</h2>
<p>The authors performed two types of experiments: Factor Comparison (ablations) and Benchmarking.</p>
<p><strong>Factor Comparisons</strong>: They evaluated how performance is affected by:</p>
<ul>
<li><strong>Stereochemistry (SI)</strong>: Comparing models trained on data with and without stereochemical information.</li>
<li><strong>Molecular Complexity (MC)</strong>: Analyzing performance across 5 molecular weight intervals.</li>
<li><strong>Data Volume (DV)</strong>: Training on datasets ranging from 0.64 million to 10 million images.</li>
<li><strong>Pre-trained Models (PTMs)</strong>: Comparing 8 different backbones (e.g., ResNet, VGG, Inception, MobileNet) versus a base CNN.</li>
</ul>
<p><strong>Benchmarking</strong>:</p>
<ul>
<li><strong>Baselines</strong>: OSRA, MolVec, Imago (rule-based); Base CNN, DECIMER (deep learning).</li>
<li><strong>Datasets</strong>: Four test sets (100k images each, except UOB): Uni-style, Multi-style, Noisy, and Real-world (UOB dataset).</li>
<li><strong>Metrics</strong>: Sequence Accuracy (Exact Match), Levenshtein Distance (ALD), and Tanimoto Similarity (Fingerprint match).</li>
</ul>
<h2 id="results-and-core-insights">Results and Core Insights</h2>
<p>MICER achieved 97.54% Sequence Accuracy on uni-style data and 82.33% on the real-world UOB dataset, significantly outperforming the next best method (OSRA scored approximately 23% on uni-style; DECIMER scored roughly 21% on UOB). ResNet101 was identified as the most effective encoder (87.58% accuracy in preliminary tests), outperforming deeper (DenseNet) or lighter (MobileNet) networks. Performance saturates around 6 million training samples. Stereochemical information drops accuracy by nearly 6%, indicating wedge and dash bonds are harder to recognize. Visualizing attention maps showed the model correctly attends to specific atoms (e.g., focusing on &lsquo;S&rsquo; or &lsquo;Cl&rsquo; pixels) when generating the corresponding character.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data was curated from the <strong>ZINC20</strong> database.</p>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Filtering</strong>: Removed organometallics, mixtures, and invalid molecules.</li>
<li><strong>Standardization</strong>: SMILES were canonicalized and de-duplicated.</li>
<li><strong>Generation</strong>: Images generated using <strong>Indigo</strong> and <strong>RDKit</strong> toolkits to vary styles.</li>
</ul>
<p><strong>Dataset Size</strong>:</p>
<ul>
<li><strong>Total</strong>: 10 million images selected for the final model.</li>
<li><strong>Composition</strong>: 6 million &ldquo;default style&rdquo; (Indigo) + 4 million &ldquo;multi-style&rdquo; (Indigo + RDKit).</li>
<li><strong>Splits</strong>: 8:1:1 ratio for Training/Validation/Test.</li>
</ul>
<p><strong>Vocabulary</strong>: A token dictionary of 39 characters + special tokens: <code>[pad]</code>, <code>[sos]</code>, <code>[eos]</code>, <code>[0]-[9]</code>, <code>[C]</code>, <code>[1]</code>, <code>[c]</code>, <code>[O]</code>, <code>[N]</code>, <code>[n]</code>, <code>[F]</code>, <code>[H]</code>, <code>[o]</code>, <code>[S]</code>, <code>[s]</code>, <code>[B]</code>, <code>[r]</code>, <code>[I]</code>, <code>[i]</code>, <code>[P]</code>, <code>[p]</code>, <code>(</code>, <code>)</code>, <code>[</code>, <code>]</code>, <code>@</code>, <code>=</code>, <code>#</code>, <code>/</code>, <code>-</code>, <code>+</code>, <code>\</code>, <code>%</code>. Note that two-letter atoms like &lsquo;Br&rsquo; are tokenized as distinct characters <code>[B]</code>, <code>[r]</code>.</p>
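<p>A minimal sketch of this character-level tokenization (the <code>[sos]</code>/<code>[eos]</code> spellings are illustrative):</p>

```python
def tokenize(smiles):
    """Character-level tokenization as described above: every character,
    including both letters of two-letter atoms like 'Br', becomes its
    own token, wrapped in start/end markers."""
    return ["[sos]"] + list(smiles) + ["[eos]"]

print(tokenize("CCBr"))  # ['[sos]', 'C', 'C', 'B', 'r', '[eos]']
```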
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: Character-level tokenization (not atom-level); the model learns to assemble &lsquo;C&rsquo; and &rsquo;l&rsquo; into &lsquo;Cl&rsquo;.</li>
<li><strong>Attention Mechanism</strong>: Uses a soft attention mechanism where the decoder calculates an attention score between the encoder&rsquo;s feature map ($8 \times 8 \times 512$) and the current hidden vector $b_t$:
$$
\text{att\_score} = \text{softmax}\left(L_a\left(\tanh\left(L_f(F) + L_b(b_t)\right)\right)\right)
$$</li>
<li><strong>Training Configuration</strong>:
<ul>
<li><strong>Loss Function</strong>: Cross-entropy loss</li>
<li><strong>Optimizer</strong>: Adam optimizer</li>
<li><strong>Learning Rate</strong>: 2e-5</li>
<li><strong>Batch Size</strong>: 256</li>
<li><strong>Epochs</strong>: 15</li>
</ul>
</li>
</ul>
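<p>The attention step can be sketched in NumPy; the weight shapes and the 128-dimensional attention space are assumptions for illustration:</p>

```python
import numpy as np

def soft_attention(F, b_t, Wf, Wb, wa):
    """One step of additive (soft) attention over the flattened 64x512
    feature matrix. Wf, Wb, wa play the roles of the learned maps
    L_f, L_b, L_a in the formula above."""
    scores = np.tanh(F @ Wf + b_t @ Wb) @ wa   # (64,) one score per region
    att = np.exp(scores - scores.max())
    att /= att.sum()                           # softmax over the 64 regions
    context = att @ F                          # (512,) attention-weighted sum
    return att, context

rng = np.random.default_rng(0)
F = rng.normal(size=(64, 512))                 # encoder feature map, flattened
b_t = rng.normal(size=512)                     # decoder hidden state
att, ctx = soft_attention(F, b_t,
                          rng.normal(scale=0.05, size=(512, 128)),
                          rng.normal(scale=0.05, size=(512, 128)),
                          rng.normal(size=128))
```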
<h3 id="models">Models</h3>
<p><strong>Encoder</strong>:</p>
<ul>
<li><strong>Backbone</strong>: Pre-trained <strong>ResNet101</strong> (trained on ImageNet).</li>
<li><strong>Modifications</strong>: The final layer is removed to output a Feature Map of size $8 \times 8 \times 512$.</li>
<li><strong>Flattening</strong>: Reshaped to a $64 \times 512$ feature matrix for the decoder.</li>
</ul>
<p><strong>Decoder</strong>:</p>
<ul>
<li><strong>Type</strong>: Long Short-Term Memory (LSTM) with Attention.</li>
<li><strong>Dropout</strong>: 0.3 applied to minimize overfitting.</li>
</ul>
<p>The ResNet encoder consists of a stem (&ldquo;pilot&rdquo;) block, a max-pool, and four stages of convolutional blocks (CB), whose output feeds the attention LSTM decoder.</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>SA (Sequence Accuracy)</strong>: Strict exact match of SMILES strings.</li>
<li><strong>ALD (Average Levenshtein Distance)</strong>: Edit distance for character-level error analysis.</li>
<li><strong>AMFTS / MFTS@1.0</strong>: Tanimoto similarity of ECFP4 fingerprints to measure structural similarity.</li>
</ul>
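<p>The Levenshtein distance underlying ALD, in a standard two-row dynamic-programming form:</p>

```python
def levenshtein(a, b):
    """Edit distance between two SMILES strings (insert/delete/substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution
        prev = curr
    return prev[-1]

print(levenshtein("CCO", "CC=O"))  # 1: one inserted '='
```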
<p><strong>Test Sets</strong>:</p>
<ul>
<li><strong>Uni-style</strong>: 100,000 images (Indigo default).</li>
<li><strong>Multi-style</strong>: 100,000 images (&gt;10 styles).</li>
<li><strong>Noisy</strong>: 100,000 images with noise added.</li>
<li><strong>UOB</strong>: 5,575 real-world images from literature.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA Tesla V100 GPUs</li>
<li><strong>Training Time</strong>: Approximately 42 hours for the final model</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yiMICERPretrainedEncoder2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MICER}}: A Pre-Trained Encoder--Decoder Architecture for Molecular Image Captioning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MICER}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Yi, Jiacai and Wu, Chengkun and Zhang, Xiaochen and Xiao, Xinyi and Qiu, Yanlong and Zhao, Wentao and Hou, Tingjun and Cao, Dongsheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{38}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{19}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{4562--4572}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1367-4811}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1093/bioinformatics/btac545}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image2SMILES: Transformer OCSR with Synthetic Data Pipeline</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image2smiles/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image2smiles/</guid><description>Transformer-based OCSR using a novel synthetic data generation pipeline for robust molecular image interpretation across diverse drawing styles.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Khokhlov, I., Krasnov, L., Fedorov, M. V., &amp; Sosnin, S. (2022). Image2SMILES: Transformer-Based Molecular Optical Recognition Engine. <em>Chemistry-Methods</em>, 2(1), e202100069. <a href="https://doi.org/10.1002/cmtd.202100069">https://doi.org/10.1002/cmtd.202100069</a></p>
<p><strong>Publication</strong>: Chemistry-Methods 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/syntelly/img2smiles_generator">Official Code (Data Generator)</a></li>
<li><a href="https://app.syntelly.com/pdf2smiles">Syntelly Demo</a></li>
</ul>
<h2 id="contribution-image2smiles-as-a-method-and-resource">Contribution: Image2SMILES as a Method and Resource</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a specific neural architecture (ResNet backbone and Transformer Decoder) to solve the Optical Chemical Structure Recognition (OCSR) task, answering &ldquo;How well does this work?&rdquo; with extensive benchmarks against rule-based systems like OSRA.</li>
<li><strong>Resource</strong>: A core contribution is the &ldquo;Generate and Train!&rdquo; paradigm, where the authors release a comprehensive synthetic data generator to overcome the lack of labeled training data in the field.</li>
</ul>
<h2 id="motivation-bottlenecks-in-recognizing-trapped-chemical-structures">Motivation: Bottlenecks in Recognizing Trapped Chemical Structures</h2>
<p>Retrieving chemical structure data from legacy scientific literature is a major bottleneck in cheminformatics.</p>
<ul>
<li><strong>Problem</strong>: Chemical structures are often &ldquo;trapped&rdquo; in image formats (PDFs, scans). Manual extraction is slow, and existing rule-based tools (e.g., OSRA) are brittle when facing diverse drawing styles, &ldquo;Markush&rdquo; structures (templates), or visual contamination.</li>
<li><strong>Gap</strong>: Deep learning approaches require massive datasets, but no large-scale annotated dataset of chemical figures exists.</li>
<li><strong>Goal</strong>: To create a robust, data-driven recognition engine that can handle the messiness of real-world chemical publications (e.g., text overlays, arrows, partial overlaps).</li>
</ul>
<h2 id="core-innovation-the-generate-and-train-pipeline-and-fg-smiles">Core Innovation: The &ldquo;Generate and Train!&rdquo; Pipeline and FG-SMILES</h2>
<ul>
<li><strong>&ldquo;Generate and Train!&rdquo; Paradigm</strong>: The authors assert that architecture is secondary to data simulation. They developed an advanced augmentation pipeline that simulates geometry (rotation, bonds) alongside specific chemical drawing artifacts like &ldquo;Markush&rdquo; variables ($R_1$, $R_2$), functional group abbreviations (e.g., -OMe, -Ph), and visual &ldquo;contamination&rdquo; (stray text, arrows).</li>
<li><strong>FG-SMILES</strong>: A modified SMILES syntax designed to handle functional groups and Markush templates as single tokens (pseudo-atoms), allowing the model to predict generalized scaffolds.</li>
<li><strong>Encoder-Free Architecture</strong>: The authors found that a standard Transformer Encoder was unnecessary. They feed the flattened feature map from a ResNet backbone directly into the Transformer Decoder, which improved performance.</li>
</ul>
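<p>A toy illustration of the FG-SMILES idea, collapsing functional-group fragments into single pseudo-atom tokens. The abbreviation table and the plain string substitution are simplifications of my own; the paper resolves overlapping groups via SMARTS matching, which substring replacement cannot do:</p>

```python
# Hypothetical abbreviation table, ordered longest-first so larger
# fragments are collapsed before their sub-fragments.
ABBREVIATIONS = [
    ("C(=O)OC", "[CO2Me]"),
    ("OC", "[OMe]"),
]

def to_fg_smiles(smiles):
    """Replace known functional-group fragments with pseudo-atom tokens."""
    for fragment, token in ABBREVIATIONS:
        smiles = smiles.replace(fragment, token)
    return smiles

print(to_fg_smiles("c1ccccc1C(=O)OC"))  # c1ccccc1[CO2Me]
```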
<h2 id="methodology-and-benchmarking-against-osra">Methodology and Benchmarking Against OSRA</h2>
<ul>
<li><strong>Training</strong>: The model was trained on 10 million synthetically generated images derived from PubChem structures, selected via a complexity-biased sampling algorithm.</li>
<li><strong>Validation (Synthetic)</strong>: Evaluated on a hold-out set of 1M synthetic images.</li>
<li><strong>Validation (Real World)</strong>:
<ul>
<li><strong>Dataset A</strong>: 332 manually cropped structures from 10 specific articles, excluding reaction schemes.</li>
<li><strong>Dataset B</strong>: 296 structures systematically extracted from <em>Journal of Organic Chemistry</em> (one paper per issue from 2020) to reduce selection bias.</li>
</ul>
</li>
<li><strong>Comparison</strong>: Benchmarked against OSRA (v2.11), a widely used rule-based OCSR tool.</li>
</ul>
<h2 id="results-high-precision-extraction-and-key-limitations">Results: High-Precision Extraction and Key Limitations</h2>
<ul>
<li><strong>Performance</strong>:
<ul>
<li><strong>Synthetic</strong>: 90.7% exact match accuracy.</li>
<li><strong>Real Data (Dataset A)</strong>: Image2SMILES achieved <strong>79.2%</strong> accuracy compared to OSRA&rsquo;s <strong>62.1%</strong>.</li>
<li><strong>Real Data (Dataset B)</strong>: Image2SMILES achieved <strong>62.5%</strong> accuracy compared to OSRA&rsquo;s <strong>24.0%</strong>.</li>
</ul>
</li>
<li><strong>Confidence Correlation</strong>: There is a strong correlation between the model&rsquo;s confidence score and prediction validity. Thresholding at 0.995 yields 99.85% accuracy while ignoring 22% of data, enabling high-precision automated pipelines.</li>
<li><strong>Key Failures</strong>: The model struggles with functional groups absent from its training dictionary (e.g., $\text{NMe}_2$, Ms), confusion of R-group indices ($R'$ vs $R_1$), and explicit hydrogens rendered as groups.</li>
</ul>
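<p>The confidence-thresholding trade-off can be sketched with synthetic predictions (the numbers below are invented for illustration; the paper&rsquo;s 0.995 / 99.85% / 22% figures come from its own validation set):</p>

```python
def precision_coverage(preds, threshold):
    """Accuracy on retained predictions vs fraction retained, for a
    confidence cutoff. `preds` is a list of (confidence, is_correct)."""
    kept = [p for p in preds if p[0] >= threshold]
    if not kept:
        return 0.0, 0.0
    acc = sum(ok for _, ok in kept) / len(kept)
    coverage = len(kept) / len(preds)
    return acc, coverage

# Toy population: raising the cutoff discards the unreliable tail.
preds = [(0.999, True)] * 70 + [(0.99, True)] * 8 + [(0.6, False)] * 22
acc, cov = precision_coverage(preds, 0.995)
```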
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: A subset of 10 million molecules sampled from PubChem.</li>
<li><strong>Selection Logic</strong>: Bias towards complex/rare structures using a &ldquo;Full Coefficient&rdquo; (FC) probability metric based on molecule size and ring/atom rarity.
<ul>
<li>Base (size) coefficient: $BC=0.1+1.2\left(\frac{n_{\max}-n}{n_{\max}}\right)^{3}$, where $n_{\max}=60$.</li>
</ul>
</li>
<li><strong>Generation</strong>: Uses RDKit for rendering with augmentations: rotation, font size, line thickness, whitespace, and CoordGen (20% probability).</li>
<li><strong>Contamination</strong>: &ldquo;Visual noise&rdquo; is stochastically added, including parts of other structures, labels, and arrows cropped from real documents.</li>
<li><strong>Target Format</strong>: <strong>FG-SMILES</strong> (Functional Group SMILES). Replaces common functional groups with pseudo-atoms (e.g., [Me], [Ph], [NO2]) and supports variable R-group positions using a <code>v</code> token.</li>
</ul>
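<p>The size coefficient transcribed directly from the formula above (how it combines with the ring/atom rarity terms into the full FC metric is not reproduced here):</p>

```python
def size_coefficient(n, n_max=60):
    """BC = 0.1 + 1.2 * ((n_max - n) / n_max) ** 3, as stated above.
    n is taken to be the molecule's size measure, capped at n_max."""
    return 0.1 + 1.2 * ((n_max - n) / n_max) ** 3

print(size_coefficient(0))   # 1.3 (smallest molecules)
print(size_coefficient(60))  # 0.1 (largest in range)
```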
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Contamination Augmentation</strong>: A dedicated algorithm simulates visual noise (arrows, text) touching or overlapping the main molecule to force robustness.</li>
<li><strong>Functional Group Resolution</strong>: An algorithm identifies overlapping functional group templates (SMARTS) and resolves them to prevent nested group conflicts (e.g., resolving Methyl vs Methoxy).</li>
<li><strong>Markush Support</strong>: Stochastic replacement of substituents with R-group labels ($R_1$, $R'$, etc.) based on a defined probability table (e.g., $P(R)=0.2$, $P(R_1)=0.15$).</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: &ldquo;Image-to-Sequence&rdquo; hybrid model.
<ul>
<li><strong>Backbone</strong>: ResNet-50, but with the last two residual blocks removed. Output shape: $512 \times 48 \times 48$.</li>
<li><strong>Neck</strong>: No Transformer Encoder. CNN features are flattened and passed directly to the Decoder.</li>
<li><strong>Decoder</strong>: Standard Transformer Decoder (6 layers).</li>
</ul>
</li>
<li><strong>Input</strong>: Images resized to $384 \times 384 \times 3$.</li>
<li><strong>Output</strong>: Sequence of FG-SMILES tokens.</li>
</ul>
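<p>The flattening step from the CNN feature map to a decoder-ready sequence can be sketched shape-wise (NumPy stands in for the actual framework):</p>

```python
import numpy as np

# Truncated ResNet-50 output: 512 channels over a 48x48 spatial grid.
feat = np.zeros((512, 48, 48))

# Flatten the spatial dims: each of the 48*48 = 2304 cells becomes one
# 512-dimensional "token" for the Transformer decoder's cross-attention.
tokens = feat.reshape(512, -1).T
print(tokens.shape)  # (2304, 512)
```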
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric</strong>: Binary &ldquo;Exact Match&rdquo; (valid/invalid).
<ul>
<li>Strict criteria: Stereo and R-group indices must match exactly (e.g., $R'$ vs $R_1$ is a failure).</li>
</ul>
</li>
<li><strong>Datasets</strong>:
<ul>
<li><strong>Internal</strong>: 5% random split of generated data (500k samples).</li>
<li><strong>External (Dataset A &amp; B)</strong>: Manually cropped real-world images from specified journals.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4 $\times$ Nvidia V100 GPUs + 36 CPU cores.</li>
<li><strong>Duration</strong>: ~2 weeks for training (5 epochs, ~63 hours/epoch). Data generation took 3 days on 80 CPUs.</li>
<li><strong>Optimizer</strong>: RAdam with learning rate $3 \cdot 10^{-4}$.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{khokhlovImage2SMILESTransformerBasedMolecular2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Image2SMILES: Transformer-Based Molecular Optical Recognition Engine}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Image2SMILES}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Khokhlov, Ivan and Krasnov, Lev and Fedorov, Maxim V. and Sosnin, Sergey}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Chemistry-Methods}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{e202100069}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2628-9725}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1002/cmtd.202100069}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://chemistry-europe.onlinelibrary.wiley.com/doi/10.1002/cmtd.202100069}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image-to-Graph Transformers</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image-to-graph-transformers/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/image-to-graph-transformers/</guid><description>A deep learning model that converts molecular images directly into graph structures, enabling recognition of abbreviated non-atomic symbols.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yoo, S., Kwon, O., &amp; Lee, H. (2022). Image-to-Graph Transformers for Chemical Structure Recognition. <em>arXiv preprint arXiv:2202.09580</em>. <a href="https://doi.org/10.48550/arXiv.2202.09580">https://doi.org/10.48550/arXiv.2202.09580</a></p>
<p><strong>Publication</strong>: arXiv 2022</p>
<h2 id="contribution-and-taxonomic-classification">Contribution and Taxonomic Classification</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel deep learning architecture designed to extract molecular structures from images by directly predicting the graph topology. The paper validates this approach through ablation studies (comparing ResNet-only baselines to the Transformer-augmented model) and extensive benchmarking against existing tools.</p>
<h2 id="the-challenge-with-smiles-and-non-atomic-symbols">The Challenge with SMILES and Non-Atomic Symbols</h2>
<ul>
<li><strong>Handling Abbreviations:</strong> Chemical structures in scientific literature often use non-atomic symbols (superatoms like &ldquo;R&rdquo; or &ldquo;Ph&rdquo;) to reduce complexity. Standard tools that generate SMILES strings fail here because SMILES syntax does not support arbitrary non-atomic symbols.</li>
<li><strong>Robustness to Style:</strong> Existing rule-based tools are brittle to the diverse drawing styles found in literature.</li>
<li><strong>Data Utilization:</strong> Pixel-wise graph recognition tools (like ChemGrapher) require expensive pixel-level labeling. An end-to-end approach can utilize massive amounts of image-molecule pairs (like USPTO data) without needing exact coordinate labels.</li>
</ul>
<h2 id="the-image-to-graph-i2g-architecture">The Image-to-Graph (I2G) Architecture</h2>
<p>The core novelty is the <strong>Image-to-Graph (I2G)</strong> architecture that bypasses string representations entirely:</p>
<ul>
<li><strong>Hybrid Encoder:</strong> Combines a ResNet backbone (for locality) with a Transformer encoder (for global context), allowing the model to capture relationships between atoms that are far apart in the image.</li>
<li><strong>Graph Decoder (GRAT):</strong> A modified Transformer decoder that generates the graph auto-regressively. It uses feature-wise transformations to modulate attention weights based on edge information (bond types).</li>
<li><strong>Coordinate-Aware Training:</strong> The model is forced to predict the exact 2D coordinates of atoms in the source image, which significantly boosts accuracy by grounding the decoder&rsquo;s attention.</li>
</ul>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<ul>
<li><strong>Baselines:</strong> The model was compared against OSRA (rule-based), MolVec (rule-based), and ChemGrapher (deep learning pixel-wise).</li>
<li><strong>Benchmarks:</strong> Evaluated on four standard datasets: UoB, USPTO, CLEF, and JPO. Images were converted to PDF and back to simulate degradation.</li>
<li><strong>Large Molecule Test:</strong> A custom dataset (<strong>OLED</strong>) was created from 12 journal papers (434 images) to test performance on larger, more complex structures (average 52.8 atoms).</li>
<li><strong>Ablations:</strong> The authors tested the impact of the Transformer encoder, auxiliary losses, and coordinate prediction.</li>
</ul>
<h2 id="empirical-results-and-robustness">Empirical Results and Robustness</h2>
<ul>
<li><strong>SOTA Performance:</strong> The proposed model outperformed existing models with a 17.1% relative improvement on benchmark datasets.</li>
<li><strong>Robustness:</strong> On large molecules (OLED dataset), it achieved a 12.8% relative improvement.</li>
<li><strong>Data Scaling:</strong> Adding real-world USPTO data to the synthetic training set improved performance by 20.5%, demonstrating the model&rsquo;s ability to learn from noisy, unlabeled coordinates.</li>
<li><strong>Handling Superatoms:</strong> The model successfully recognized pseudo-atoms (e.g., $R_1$, $R_2$) as distinct nodes, whereas SMILES-based tools collapsed them into generic &ldquo;Any&rdquo; atoms.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors used a combination of synthetic and real-world data for training.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td><strong>PubChem</strong></td>
          <td>4.6M</td>
          <td>Synthetic images generated using RDKit. Random superatoms (e.g., $CF_3$, $NO_2$) were substituted to simulate abbreviations.</td>
      </tr>
      <tr>
          <td>Training</td>
          <td><strong>USPTO</strong></td>
          <td>2.5M</td>
          <td>Real image-molecule pairs from patents. Used for robustness; lacks coordinate labels.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td><strong>Benchmarks</strong></td>
          <td>~5.7k</td>
          <td>UoB, USPTO, CLEF, JPO. Average ~15.8 atoms per molecule.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td><strong>OLED</strong></td>
          <td>434</td>
          <td>Manually segmented from 12 journal papers. Large molecules (avg 52.8 atoms).</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing:</strong></p>
<ul>
<li>Input resolution is fixed at $800 \times 800$ pixels.</li>
<li>Images are virtually split into a $25 \times 25$ grid (625 patches total), where each patch is $32 \times 32$ pixels.</li>
</ul>
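<p>As a quick sanity check on the figures above (assuming the stated $800 \times 800$ input and $32 \times 32$ patches):</p>

```python
# Verify the patch-grid arithmetic: an 800x800 image split into
# 32x32 patches yields a 25x25 grid of 625 patches.
image_size = 800
patch_size = 32

grid_side = image_size // patch_size   # patches along each axis
num_patches = grid_side ** 2           # total patches fed to the Transformer

assert grid_side == 25
assert num_patches == 625
```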
<h3 id="algorithms">Algorithms</h3>
<p><strong>Encoder Logic:</strong></p>
<ul>
<li><strong>Grid Serialization:</strong> The $25 \times 25$ grid is flattened into a 1D sequence. 2D position information is concatenated to ResNet features before the Transformer.</li>
<li><strong>Auxiliary Losses:</strong> To aid convergence, classifiers on the encoder predict three things <em>per patch</em>: (1) number of atoms, (2) characters in atom labels, and (3) edge-sharing neighbors. These losses decrease to zero during training.</li>
</ul>
<p><strong>Decoder Logic:</strong></p>
<ul>
<li><strong>Auto-regressive Generation:</strong> At step $t$, the decoder generates a new node and connects it to existing nodes.</li>
<li><strong>Attention Modulation:</strong> Attention weights are transformed using bond information:
$$
\begin{aligned}
\text{Att}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d}} \odot \gamma(e_{ij}) + \beta(e_{ij}) \right) V
\end{aligned}
$$
where $e_{ij}$ is the edge type.</li>
<li><strong>Coordinate Prediction:</strong> The decoder outputs coordinates for each atom, which acts as a mechanism to track attention history.</li>
</ul>
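<p>A minimal NumPy sketch of the modulated attention above. The shapes and the mapping from edge types $e_{ij}$ to the scale $\gamma$ and shift $\beta$ are illustrative assumptions, not the authors' implementation:</p>

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def edge_modulated_attention(Q, K, V, gamma, beta):
    """Attention with feature-wise modulation by edge (bond) information.

    Q, K, V:      (n, d) query/key/value matrices
    gamma, beta:  (n, n) scale and shift derived from edge types e_ij
    """
    d = Q.shape[-1]
    scores = (Q @ K.T) / np.sqrt(d)   # standard scaled dot-product
    scores = scores * gamma + beta    # FiLM-style modulation by bond info
    return softmax(scores, axis=-1) @ V

rng = np.random.default_rng(0)
n, d = 4, 8
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
gamma = np.ones((n, n))   # identity modulation reduces to vanilla attention
beta = np.zeros((n, n))
out = edge_modulated_attention(Q, K, V, gamma, beta)
assert out.shape == (n, d)
```

With $\gamma = 1$ and $\beta = 0$ this reduces exactly to standard scaled dot-product attention, which makes the role of the edge modulation easy to isolate.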
<h3 id="models">Models</h3>
<ul>
<li><strong>Image Encoder:</strong> ResNet-34 backbone followed by a Transformer encoder.</li>
<li><strong>Graph Decoder:</strong> A &ldquo;Graph-Aware Transformer&rdquo; (GRAT) that outputs nodes (atom labels, coordinates) and edges (bond types).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics focus on structural identity, as standard string matching (SMILES) is insufficient for graphs with superatoms.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>SMI</strong></td>
          <td>Canonical SMILES Match</td>
          <td>Correct if predicted SMILES is identical to ground truth.</td>
      </tr>
      <tr>
          <td><strong>TS 1</strong></td>
          <td>Tanimoto Similarity = 1.0</td>
          <td>Ratio of predictions with perfect fingerprint overlap.</td>
      </tr>
      <tr>
          <td><strong>Sim.</strong></td>
          <td>Average Tanimoto Similarity</td>
          <td>Measures average structural overlap (cutoff usually &gt; 0.85).</td>
      </tr>
  </tbody>
</table>
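<p>The Tanimoto metrics in the table compare molecular fingerprints; on bit sets the computation reduces to a Jaccard index. The sketch below is a simplified stand-in for the circular fingerprints typically used (e.g., via RDKit):</p>

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto (Jaccard) similarity between two fingerprint bit sets."""
    if not fp_a and not fp_b:
        return 1.0
    return len(fp_a & fp_b) / len(fp_a | fp_b)

# Identical fingerprints -> TS = 1.0 (counts toward the "TS 1" metric)
assert tanimoto({1, 5, 9}, {1, 5, 9}) == 1.0
# Partial overlap: 2 shared bits out of 4 set bits total
assert tanimoto({1, 5, 9}, {1, 5, 7}) == 0.5
```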
]]></content:encoded></item><item><title>ICMDT: Automated Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/icmdt/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/icmdt/</guid><description>A Transformer-based model (ICMDT) for converting chemical structure images into InChI text strings using a novel Deep TNT block.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, Y., Chen, G., &amp; Li, X. (2022). Automated Recognition of Chemical Molecule Images Based on an Improved TNT Model. <em>Applied Sciences</em>, 12(2), 680. <a href="https://doi.org/10.3390/app12020680">https://doi.org/10.3390/app12020680</a></p>
<p><strong>Publication</strong>: MDPI Applied Sciences 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.kaggle.com/c/bms-molecular-translation">Kaggle Competition: BMS Molecular Translation</a></li>
</ul>
<h2 id="contribution-image-to-text-translation-for-chemical-structures">Contribution: Image-to-Text Translation for Chemical Structures</h2>
<p>This is a <strong>Method</strong> paper.</p>
<p>It proposes a novel neural network architecture, the <strong>Image Captioning Model based on Deep TNT (ICMDT)</strong>, to solve the specific problem of &ldquo;molecular translation&rdquo; (image-to-text). The classification is supported by the following rhetorical indicators:</p>
<ul>
<li><strong>Novel Mechanism:</strong> It introduces the &ldquo;Deep TNT block&rdquo; to improve upon the existing TNT architecture by fusing features at three levels (pixel, small patch, large patch).</li>
<li><strong>Baseline Comparison:</strong> The authors explicitly compare their model against four other architectures (CNN+RNN and CNN+Transformer variants).</li>
<li><strong>Ablation Study:</strong> Section 4.3 is dedicated to ablating specific components (position encoding, patch fusion) to prove their contribution to the performance gain.</li>
</ul>
<h2 id="motivation-digitizing-historical-chemical-literature">Motivation: Digitizing Historical Chemical Literature</h2>
<p>The primary motivation is to speed up chemical research by digitizing historical chemical literature.</p>
<ul>
<li><strong>Problem:</strong> Historical sources often contain corrupted or noisy images, making automated recognition difficult.</li>
<li><strong>Gap:</strong> Existing models like the standard TNT (Transformer in Transformer) function primarily as encoders for classification and fail to effectively integrate local pixel-level information required for precise structure generation.</li>
<li><strong>Goal:</strong> To build a dependable generative model that can accurately translate these noisy images into <strong><a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a></strong> (International Chemical Identifier) text strings.</li>
</ul>
<h2 id="novelty-multi-level-feature-fusion-with-deep-tnt">Novelty: Multi-Level Feature Fusion with Deep TNT</h2>
<p>The core contribution is the <strong>Deep TNT block</strong> and the resulting <strong>ICMDT</strong> architecture.</p>
<ul>
<li><strong>Deep TNT Block:</strong> The Deep TNT block expands upon standard local and global modeling by stacking three transformer blocks to process information at three granularities:
<ol>
<li><strong>Internal Transformer:</strong> Processes pixel embeddings.</li>
<li><strong>Middle Transformer:</strong> Processes small patch embeddings.</li>
<li><strong>Exterior Transformer:</strong> Processes large patch embeddings.</li>
</ol>
</li>
<li><strong>Multi-level Fusion:</strong> The model fuses pixel-level features into small patches, and small patches into large patches, allowing for finer integration of local details.</li>
<li><strong>Position Encoding:</strong> A specific strategy of applying shared position encodings to small patches and pixels, while using a learnable 1D encoding for large patches.</li>
</ul>
<h2 id="methodology-benchmarking-on-the-bms-dataset">Methodology: Benchmarking on the BMS Dataset</h2>
<p>The authors evaluated the model on the <strong>Bristol-Myers Squibb Molecular Translation</strong> dataset.</p>
<ul>
<li><strong>Baselines:</strong> They constructed four comparative models:
<ul>
<li>EfficientNetb0 + RNN (Bi-LSTM)</li>
<li>ResNet50d + RNN (Bi-LSTM)</li>
<li>EfficientNetb0 + Transformer</li>
<li>ResNet101d + Transformer</li>
</ul>
</li>
<li><strong>Ablation:</strong> They tested the impact of removing the large patch position encoding (ICMDT*) and reverting the encoder to a standard TNT-S (TNTD).</li>
<li><strong>Pre-processing Study:</strong> They experimented with denoising ratios and cropping strategies.</li>
</ul>
<h2 id="results--conclusions-state-of-the-art-inchi-translation">Results &amp; Conclusions: State-of-the-Art InChI Translation</h2>
<ul>
<li><strong>Performance:</strong> ICMDT achieved the lowest <strong>Levenshtein distance (0.69)</strong> compared to the best baseline (1.45 for ResNet101d+Transformer).</li>
<li><strong>Convergence:</strong> The model converged significantly faster than the baselines, surpassing all of them from epoch 6.7 onward.</li>
<li><strong>Ablation Results:</strong> The full Deep TNT block reduced error by nearly half compared to the standard TNT encoder (0.69 vs 1.29 Levenshtein distance).</li>
<li><strong>Limitations:</strong> The model struggles with <strong>stereochemical layers</strong> (e.g., identifying clockwise neighbors or +/- signs) compared to non-stereochemical layers.</li>
<li><strong>Future Work:</strong> Integrating full object detection to predict atom/bond coordinates to better resolve 3D stereochemical information.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The experiments used the <strong>Bristol-Myers Squibb Molecular Translation</strong> dataset from Kaggle.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>BMS Training Set</td>
          <td>2,424,186 images</td>
          <td>Supervised; contains noise and blur</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BMS Test Set</td>
          <td>1,616,107 images</td>
          <td>Higher noise variation than training set</td>
      </tr>
  </tbody>
</table>
<p><strong>Pre-processing Strategy</strong>:</p>
<ul>
<li><strong>Effective:</strong> Padding resizing (reshaping to square, padding with border pixels).</li>
<li><strong>Ineffective:</strong> Smart cropping (removing white borders degraded performance).</li>
<li><strong>Augmentation:</strong> GaussNoise, Blur, RandomRotate90, and PepperNoise ($SNR=0.996$).</li>
<li><strong>Denoising:</strong> Best results found by mixing denoised and original data (Ratio 2:13) during training.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer:</strong> Lookahead ($\alpha=0.5, k=5$) and RAdam ($\beta_1=0.9, \beta_2=0.99$).</li>
<li><strong>Loss Function:</strong> Anti-Focal loss ($\gamma=0.5$) combined with Label Smoothing. While standard Focal Loss adds a modulating factor $(1-p_t)^\gamma$ to cross-entropy to focus on hard negatives, Anti-Focal Loss modifies this to reduce the disparity between training and inference distributions:
$$ \text{Anti-Focal Loss}(p_t) = - (1 + p_t)^\gamma \log(p_t) $$</li>
<li><strong>Training Schedule:</strong>
<ul>
<li>Initial resolution: $224 \times 224$</li>
<li>Fine-tuning: resolution $384 \times 384$ for labels of length $&gt;150$.</li>
<li>Batch size: Dynamic, increasing from 16 to 1024.</li>
</ul>
</li>
<li><strong>Inference Strategy:</strong>
<ul>
<li>Beam Search ($k=16$ initially, $k=64$ if failing InChI validation).</li>
<li>Test Time Augmentation (TTA): Rotations of $90^\circ$.</li>
<li>Ensemble: Step-wise logit ensemble and voting based on Levenshtein distance scores.</li>
</ul>
</li>
</ul>
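<p>The Anti-Focal loss above can be contrasted with standard focal loss in a few lines (a scalar sketch of the stated formulas; the paper additionally combines it with label smoothing, omitted here):</p>

```python
import math

def focal_loss(p_t: float, gamma: float = 0.5) -> float:
    """Standard focal loss: down-weights easy examples via (1 - p_t)^gamma."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def anti_focal_loss(p_t: float, gamma: float = 0.5) -> float:
    """Anti-focal loss as given in the paper: -(1 + p_t)^gamma * log(p_t)."""
    return -((1.0 + p_t) ** gamma) * math.log(p_t)

# For a confident prediction, focal loss nearly vanishes while anti-focal
# keeps a gradient signal: it up-weights rather than down-weights p_t.
p = 0.9
assert anti_focal_loss(p) > focal_loss(p)
```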
<h3 id="models">Models</h3>
<p><strong>ICMDT Architecture:</strong></p>
<ul>
<li><strong>Encoder (Deep TNT)</strong>:
<ul>
<li><strong>Internal Block:</strong> Hidden size 160, Heads 4, MLP act GELU.</li>
<li><strong>Middle Block:</strong> Hidden size 640, Heads 10.</li>
<li><strong>Exterior Block:</strong> Hidden size 2560, Heads 16.</li>
<li><strong>Patch Sizes:</strong> Small patch $16 \times 16$, Pixel patch $4 \times 4$.</li>
</ul>
</li>
<li><strong>Decoder (Vanilla Transformer)</strong>:
<ul>
<li>Dimensions: 5120 (Hidden), 1024 (FFN).</li>
<li>Depth: 3 layers, Heads: 8.</li>
<li>Vocab Size: 193 (InChI tokens).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric:</strong> Levenshtein Distance (measures single-character edit operations between generated and ground truth InChI strings).</p>
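<p>A minimal implementation of the metric:</p>

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions, and
    substitutions needed to turn string a into string b (dynamic programming)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]

# One wrong character in a predicted InChI fragment costs one edit.
assert levenshtein("InChI=1S/CH4", "InChI=1S/CH4") == 0
assert levenshtein("InChI=1S/CH4", "InChI=1S/CH3") == 1
```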
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Levenshtein Distance</th>
          <th>Params (M)</th>
          <th>Convergence (Epochs)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ICMDT (Ours)</strong></td>
          <td><strong>0.69</strong></td>
          <td>138.16</td>
          <td>~9.76</td>
      </tr>
      <tr>
          <td>TNTD (Ablation)</td>
          <td>1.29</td>
          <td>114.36</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ResNet101d + Transformer</td>
          <td>1.45</td>
          <td>302.02</td>
          <td>14+</td>
      </tr>
      <tr>
          <td>ResNet50d + RNN</td>
          <td>3.86</td>
          <td>90.6</td>
          <td>14+</td>
      </tr>
      <tr>
          <td>EfficientNetb0 + RNN</td>
          <td>2.95</td>
          <td>46.3</td>
          <td>11+</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{liAutomatedRecognitionChemical2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Automated {{Recognition}} of {{Chemical Molecule Images Based}} on an {{Improved TNT Model}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Li, Yanchi and Chen, Guanyu and Li, Xiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Applied Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{680}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Multidisciplinary Digital Publishing Institute}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2076-3417}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.3390/app12020680}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Handwritten Chemical Structure Recognition with RCGD</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/hu-handwritten-rcgd-2023/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/hu-handwritten-rcgd-2023/</guid><description>An end-to-end framework (RCGD) and unambiguous markup language (SSML) for recognizing complex handwritten chemical structures with guided graph traversal.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hu, J., Wu, H., Chen, M., Liu, C., Wu, J., Yin, S., Yin, B., Liu, C., Du, J., &amp; Dai, L. (2023). Handwritten Chemical Structure Image to Structure-Specific Markup Using Random Conditional Guided Decoder. <em>Proceedings of the 31st ACM International Conference on Multimedia</em> (pp. 8114-8124). <a href="https://doi.org/10.1145/3581783.3612573">https://doi.org/10.1145/3581783.3612573</a></p>
<p><strong>Publication</strong>: ACM Multimedia 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/iFLYTEK-CV/EDU-CHEMC">GitHub Repository / EDU-CHEMC Dataset</a></li>
</ul>
<h2 id="contribution-and-methodological-framework">Contribution and Methodological Framework</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architectural framework (<strong>RCGD</strong>) and a new representation syntax (<strong>SSML</strong>) to solve the specific problem of handwritten chemical structure recognition.</li>
<li><strong>Resource</strong>: It introduces a new benchmark dataset, <strong>EDU-CHEMC</strong>, containing 50,000 handwritten images to address the lack of public data in this domain.</li>
</ul>
<h2 id="the-ambiguity-of-handwritten-chemical-structures">The Ambiguity of Handwritten Chemical Structures</h2>
<p>Recognizing handwritten chemical structures is significantly harder than printed ones due to:</p>
<ol>
<li><strong>Inherent Ambiguity</strong>: Handwritten atoms and bonds vary greatly in appearance.</li>
<li><strong>Projection Complexity</strong>: Converting 2D projected layouts (like Natta or Fischer projections) into linear strings is difficult.</li>
<li><strong>Limitations of Existing Formats</strong>: Standard formats like SMILES require domain knowledge (valence rules) and have a high semantic gap with the visual image. They often fail to represent &ldquo;invalid&rdquo; structures commonly found in educational/student work.</li>
</ol>
<h2 id="bridging-the-semantic-gap-with-ssml-and-rcgd">Bridging the Semantic Gap with SSML and RCGD</h2>
<p>The paper introduces two core contributions to bridge the semantic gap between image and markup:</p>
<ol>
<li>
<p><strong>Structure-Specific Markup Language (SSML)</strong>: An extension of Chemfig that provides an unambiguous, visual-based graph representation. Unlike SMILES, it describes <em>how to draw</em> the molecule step-by-step, making it easier for models to learn visual alignments. It supports &ldquo;reconnection marks&rdquo; to handle cyclic structures explicitly.</p>
</li>
<li>
<p><strong>Random Conditional Guided Decoder (RCGD)</strong>: A decoder that treats recognition as a graph traversal problem. It introduces three novel mechanisms:</p>
<ul>
<li><strong>Conditional Attention Guidance</strong>: Uses branch angle directions to guide the attention mechanism, preventing the model from getting lost in complex structures.</li>
<li><strong>Memory Classification</strong>: A module that explicitly stores and classifies &ldquo;unexplored&rdquo; branch points to handle ring closures (reconnections).</li>
<li><strong>Path Selection</strong>: A training strategy that randomly samples traversal paths to prevent overfitting to a specific serialization order.</li>
</ul>
</li>
</ol>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p><strong>Datasets</strong>:</p>
<ul>
<li><strong>Mini-CASIA-CSDB</strong> (Printed): A subset of 97,309 synthetic images, upscaled to $500 \times 500$ resolution.</li>
<li><strong>EDU-CHEMC</strong> (Handwritten): A new dataset of 52,987 images collected from educational settings (cameras, scanners, screens), including erroneous/non-existent structures.</li>
</ul>
<p><strong>Baselines</strong>:</p>
<ul>
<li>Compared against standard <strong>String Decoders (SD)</strong> (based on DenseWAP) trained on SMILES strings.</li>
<li>Compared against <strong>BTTR</strong> and <strong>ABM</strong> (recent mathematical expression recognition models) adapted for this task.</li>
</ul>
<p><strong>Ablation Studies</strong>:</p>
<ul>
<li>Evaluated the impact of removing Path Selection (PS) and Memory Classification (MC) mechanisms.</li>
<li>Tested robustness to image rotation ($180^\circ$).</li>
</ul>
<h2 id="recognition-performance-and-robustness">Recognition Performance and Robustness</h2>
<ul>
<li><strong>Superiority of SSML</strong>: Models trained with SSML significantly outperformed those trained with SMILES (92.09% vs 81.89% EM on printed data) due to reduced semantic gap.</li>
<li><strong>SOTA Performance</strong>: RCGD achieved the highest Exact Match (EM) scores on both datasets:
<ul>
<li><strong>Mini-CASIA-CSDB</strong>: 95.01% EM.</li>
<li><strong>EDU-CHEMC</strong>: 62.86% EM.</li>
</ul>
</li>
<li><strong>Robustness</strong>: RCGD showed minimal performance drop (0.85%) on rotated images compared to SMILES-based methods (10.36% drop).</li>
<li><strong>Educational Utility</strong>: The method can recognize and reconstruct chemically invalid structures (e.g., a Carbon atom with 5 bonds), making it suitable for automated grading systems.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>1. EDU-CHEMC (Handwritten)</strong></p>
<ul>
<li><strong>Total Size</strong>: 52,987 images.</li>
<li><strong>Splits</strong>: Training (48,998), Validation (999), Test (2,992).</li>
<li><strong>Characteristics</strong>: Real-world educational data, mixture of isolated molecules and reaction equations, includes invalid chemical structures.</li>
</ul>
<p><strong>2. Mini-CASIA-CSDB (Printed)</strong></p>
<ul>
<li><strong>Total Size</strong>: 97,309 images.</li>
<li><strong>Splits</strong>: Training (80,781), Validation (8,242), Test (8,286).</li>
<li><strong>Preprocessing</strong>: Original $300 \times 300$ images were upscaled to $500 \times 500$ RGB to resolve blurring issues.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. SSML Generation</strong></p>
<p>To convert a molecular graph to SSML:</p>
<ol>
<li><strong>Traverse</strong>: Start from the left-most atom.</li>
<li><strong>Bonds/Atoms</strong>: Output atom text and bond format <code>&lt;bond&gt;[:&lt;angle&gt;]</code>.</li>
<li><strong>Branches</strong>: At branch points, use phantom symbols <code>(</code> and <code>)</code> to enclose branches, ordered by ascending bond angle.</li>
<li><strong>Reconnections</strong>: Use <code>?[tag]</code> and <code>?[tag, bond]</code> to mark start/end of ring closures.</li>
</ol>
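<p>The traversal above can be sketched as a depth-first serialization. This is a hypothetical simplification: the helper below handles only trees (no reconnection marks for rings), and its branch-delimiting rule is a guess at the real SSML convention:</p>

```python
def to_ssml(graph, start, visited=None):
    """Hypothetical DFS serialization in the spirit of SSML.

    graph: {atom: [(neighbor, bond, angle), ...]} adjacency, with bond
    symbols (e.g. '-', '=') and branch angles in degrees.
    """
    if visited is None:
        visited = set()
    visited.add(start)
    out = start  # atom label
    # branches ordered by ascending bond angle, per the paper
    branches = sorted((n for n in graph.get(start, []) if n[0] not in visited),
                      key=lambda n: n[2])
    for neighbor, bond, angle in branches:
        piece = f"{bond}:{angle}" + to_ssml(graph, neighbor, visited)
        out += f"({piece})" if len(branches) > 1 else piece
    return out

# A toy 3-atom chain: C1 -(0 deg)- C2 =(60 deg)- O
graph = {"C1": [("C2", "-", 0)], "C2": [("C1", "-", 180), ("O", "=", 60)]}
print(to_ssml(graph, "C1"))  # prints C1-:0C2=:60O
```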
<p><strong>2. RCGD Specifics</strong></p>
<ul>
<li><strong>RCGD-SSML</strong>: Modified version of SSML for the decoder. Removes <code>(</code> <code>)</code> delimiters; adds <code>\eob</code> (end of branch). Maintains a dynamic <strong>Branch Angle Set ($M$)</strong>.</li>
<li><strong>Path Selection</strong>: During training, when multiple branches exist in $M$, the model randomly selects one to traverse next. During inference, it uses beam search to score candidate paths.</li>
<li><strong>Loss Function</strong>:
$$
\begin{aligned}
L_{\text{total}} = L_{\text{ce}} + L_{\text{bc}}
\end{aligned}
$$
<ul>
<li>$L_{\text{ce}}$: Cross-entropy loss for character sequence generation.</li>
<li>$L_{\text{bc}}$: Multi-label classification loss for the memory module (predicting reconnection bond types for stored branch states).</li>
</ul>
</li>
</ul>
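<p>A NumPy sketch of the combined objective $L_{\text{total}} = L_{\text{ce}} + L_{\text{bc}}$. Shapes are illustrative assumptions; the paper trains this end-to-end with the GRU decoder:</p>

```python
import numpy as np

def cross_entropy(logits, target):
    """Token-level CE over a sequence: logits (T, V), target (T,) int ids."""
    z = logits - logits.max(axis=1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -logp[np.arange(len(target)), target].mean()

def multilabel_bce(logits, targets):
    """Multi-label loss for the memory module: predicts reconnection bond
    types for each stored branch state. logits/targets: (N, num_bond_types)."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(targets * np.log(p) + (1 - targets) * np.log(1 - p)).mean()

def total_loss(seq_logits, seq_target, mem_logits, mem_target):
    # L_total = L_ce + L_bc, as in the paper
    return cross_entropy(seq_logits, seq_target) + multilabel_bce(mem_logits, mem_target)

rng = np.random.default_rng(0)
loss = total_loss(rng.standard_normal((5, 10)), np.array([1, 2, 3, 4, 5]),
                  rng.standard_normal((3, 4)), rng.integers(0, 2, (3, 4)))
assert loss > 0
```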
<h3 id="models">Models</h3>
<p><strong>Encoder</strong>: DenseNet</p>
<ul>
<li><strong>Structure</strong>: 3 dense blocks.</li>
<li><strong>Growth Rate</strong>: 24.</li>
<li><strong>Depth</strong>: 32 per block.</li>
<li><strong>Output</strong>: High-dimensional feature map $x \in \mathbb{R}^{d_x \times h \times w}$.</li>
</ul>
<p><strong>Decoder</strong>: GRU with Attention</p>
<ul>
<li><strong>Hidden State Dimension</strong>: 256.</li>
<li><strong>Embedding Dimension</strong>: 256.</li>
<li><strong>Attention Projection</strong>: 128.</li>
<li><strong>Memory Classification Projection</strong>: 256.</li>
</ul>
<p><strong>Training Config</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: Adam.</li>
<li><strong>Learning Rate</strong>: 2e-4 with multi-step decay (gamma 0.5).</li>
<li><strong>Dropout</strong>: 15%.</li>
<li><strong>Strategy</strong>: Teacher-forcing used for validation selection.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Exact Match (EM)</strong>: Percentage of samples where the predicted graph structure perfectly matches the label. For SMILES, string comparison; for SSML, converted to graph for isomorphism check.</li>
<li><strong>Structure EM</strong>: Auxiliary metric for samples with mixed content (text + molecules), counting samples where <em>all</em> molecular structures are correct.</li>
</ul>
<p><strong>Code Availability</strong>:</p>
<ul>
<li>The dataset is hosted at: <a href="https://github.com/iFLYTEK-CV/EDU-CHEMC">https://github.com/iFLYTEK-CV/EDU-CHEMC</a></li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{huHandwrittenChemicalStructure2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Handwritten Chemical Structure Image to Structure-Specific Markup Using Random Conditional Guided Decoder}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 31st ACM International Conference on Multimedia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Hu, Jinshui and Wu, Hao and Chen, Mingjun and Liu, Chenyu and Wu, Jiajia and Yin, Shi and Yin, Baocai and Yin, Bing and Liu, Cong and Du, Jun and Dai, Lirong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{8114--8124}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Ottawa ON Canada}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1145/3581783.3612573}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{979-8-4007-0108-5}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>End-to-End Transformer for Molecular Image Captioning</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/vit-inchi-transformer/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/vit-inchi-transformer/</guid><description>Vision Transformer encoder with Transformer decoder for molecular image-to-InChI translation, achieving state-of-the-art performance on noisy datasets.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sundaramoorthy, C., Kelvin, L. Z., Sarin, M., &amp; Gupta, S. (2021). End-to-End Attention-based Image Captioning. <em>arXiv preprint arXiv:2104.14721</em>. <a href="https://doi.org/10.48550/arXiv.2104.14721">https://doi.org/10.48550/arXiv.2104.14721</a></p>
<p><strong>Publication</strong>: arXiv 2021 (preprint)</p>
<p><strong>Note</strong>: This is an arXiv preprint and has not undergone formal peer review.</p>
<h2 id="methodological-contribution">Methodological Contribution</h2>
<p>This is a <strong>Methodological Paper</strong>. It proposes a novel architectural approach to molecular image translation by replacing the standard CNN encoder with a Vision Transformer (ViT). The authors validate this method through comparative benchmarking against standard CNN+RNN baselines (e.g., ResNet+LSTM) and provide optimizations for inference speed.</p>
<h2 id="motivation-and-problem-statement">Motivation and Problem Statement</h2>
<p>The core problem addressed is that existing molecular translation methods (which extract chemical structures from images into the computer-readable InChI format) rely heavily on rule-based systems or CNN+RNN architectures. These approaches often underperform on noisy images (common in scans of older journals) or images with few distinguishable features. There is a significant need in drug discovery to digitize and analyze legacy experimental data locked in image form within scientific publications.</p>
<h2 id="core-innovations-end-to-end-vit-encoder">Core Innovations: End-to-End ViT Encoder</h2>
<p>The primary contribution is the use of a completely convolution-free Vision Transformer (ViT) as the encoder, allowing the model to utilize long-range dependencies among image patches from the very beginning via self-attention:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
The architecture is a pure Transformer (Encoder-Decoder), treating the molecular image similarly to a sequence of tokens (patches). Furthermore, the authors implement a specific caching strategy for the decoder to avoid recomputing embeddings for previously decoded tokens, reducing the time complexity of the decoding step.</p>
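<p>The scaled dot-product attention above can be sketched in plain Python. This is a minimal illustration of the formula only, not the paper's implementation; the function names and list-of-lists matrix layout are illustrative.</p>

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of floats.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V.

    Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v); matrices are
    lists of rows. Returns an (n_q, d_v) matrix.
    """
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
                  for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out
```

<p>A query that closely matches one key attends almost entirely to that key's value, which is the mechanism that lets patches attend to distant patches from the first layer.</p>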
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The model was compared against a standard CNN + RNN and ResNet (18, 34, 50) + LSTM with attention. Ablation studies varied the number of transformer layers (3, 6, 12, 24) and the image resolution ($224 \times 224$ vs. $384 \times 384$). The model was trained on a large combined dataset, including Bristol Myers Squibb data, SMILES, GDB-13, and synthetically augmented images containing noise and artifacts. Performance was evaluated using the Levenshtein distance, which computes the minimum number of single-character edits needed to transform the predicted string into the ground truth.</p>
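<p>The Levenshtein metric can be sketched as a standard dynamic-programming routine (an illustrative implementation, not the authors' evaluation code):</p>

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions,
    and substitutions needed to turn string a into string b."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]
```

<p>For example, <code>levenshtein("kitten", "sitting")</code> is 3; applied to predicted vs. ground-truth InChI strings, lower is better and 0 means an exact match.</p>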
<h2 id="performance-outcomes-and-capabilities">Performance Outcomes and Capabilities</h2>
<p>The proposed 24-layer ViT model (input size 384) achieved the lowest Levenshtein distance of <strong>6.95</strong>, significantly outperforming the ResNet50+LSTM baseline (7.49) and the standard CNN+RNN (103.7). Increasing the number of layers had a strong positive impact, with the 24-layer model becoming competitive with current approaches. The model demonstrated high robustness on datasets with low distinguishable features and noise. The proposed caching optimization reduced the decoding time complexity per timestep from $O(MN^2 + N^3)$ to $O(MN + N^2)$.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a combined dataset randomly split into 70% training, 10% test, and 20% validation.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Bristol Myers Squibb</strong></td>
          <td>~2.4 million synthetic images with InChI labels.</td>
          <td>Provided by BMS global biopharmaceutical company.</td>
      </tr>
      <tr>
          <td><strong>SMILES</strong></td>
          <td>Kaggle contest data converted to InChI.</td>
          <td>Images generated using RDKit.</td>
      </tr>
      <tr>
          <td><strong>GDB-13</strong></td>
          <td>Subset of 977 million small organic molecules (up to 13 atoms).</td>
          <td>Converted from SMILES using RDKit.</td>
      </tr>
      <tr>
          <td><strong>Augmented Images</strong></td>
          <td>Synthetic images with salt/pepper noise, dropped atoms, and bond modifications.</td>
          <td>Used to improve robustness against noise.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Training Objective</strong>: Cross-entropy loss minimization.</li>
<li><strong>Inference Decoding</strong>: Autoregressive decoding predicting the next character of the InChI string.</li>
<li><strong>Positional Encoding</strong>: Standard sine and cosine functions of different frequencies.</li>
<li><strong>Optimization</strong>:
<ul>
<li><strong>Caching</strong>: Caches the output of each layer during decoding to avoid recomputing embeddings for already decoded tokens.</li>
<li><strong>JIT</strong>: PyTorch JIT compiler used for graph optimization.</li>
<li><strong>Self-Critical Training</strong>: Finetuning performed using self-critical sequence training (SCST).</li>
</ul>
</li>
</ul>
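<p>The sinusoidal positional encoding listed above follows the standard formulation of &ldquo;Attention is all you need&rdquo;: $PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d})$ and $PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d})$. A minimal sketch (function name and pure-Python layout are illustrative):</p>

```python
import math

def positional_encoding(max_len: int, d_model: int):
    """Sinusoidal positional encodings as a (max_len, d_model) matrix.

    Even dimensions use sine, odd dimensions use cosine, with the
    frequency decreasing geometrically across dimension pairs.
    """
    pe = []
    for pos in range(max_len):
        row = []
        for i in range(d_model):
            # Each (sin, cos) pair shares the exponent 2i / d_model.
            angle = pos / (10000 ** ((i // 2 * 2) / d_model))
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        pe.append(row)
    return pe
```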
<h3 id="models">Models</h3>
<ul>
<li><strong>Encoder (Vision Transformer)</strong>:
<ul>
<li>Input: Flattened 2D patches of the image. Patch size: $16 \times 16$.</li>
<li>Projection: Trainable linear projection to latent vector size $D$.</li>
<li>Structure: Alternating layers of Multi-Head Self-Attention (MHSA) and MLP blocks.</li>
</ul>
</li>
<li><strong>Decoder (Vanilla Transformer)</strong>:
<ul>
<li>Input: Tokenized InChI string + sinusoidal positional embedding.</li>
<li>Vocabulary: 275 tokens (including <code>&lt;SOS&gt;</code>, <code>&lt;PAD&gt;</code>, <code>&lt;EOS&gt;</code>).</li>
</ul>
</li>
<li><strong>Hyperparameters (Best Model)</strong>:
<ul>
<li>Image Size: $384 \times 384$.</li>
<li>Layers: 24.</li>
<li>Feature Dimension: 512.</li>
<li>Attention Heads: 12.</li>
<li>Optimizer: Adam.</li>
<li>Learning Rate: $3 \times 10^{-5}$ (decayed by 0.5 in the last 2 epochs).</li>
<li>Batch Size: Varied from 64 to 512.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Levenshtein Distance (lower is better).</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Image Size</th>
          <th>Layers</th>
          <th>Levenshtein Dist.</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Standard CNN+RNN</td>
          <td>224</td>
          <td>3</td>
          <td>103.7</td>
      </tr>
      <tr>
          <td>ResNet50 + LSTM</td>
          <td>224</td>
          <td>5</td>
          <td>7.49</td>
      </tr>
      <tr>
          <td>ViT Transformers (Best)</td>
          <td>384</td>
          <td>24</td>
          <td><strong>6.95</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>System</strong>: GPU system with 70 GB of memory.</li>
<li><strong>Framework</strong>: PyTorch and PyTorch Lightning.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{sundaramoorthyEndtoEndAttentionbasedImage2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{End-to-{{End Attention-based Image Captioning}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Sundaramoorthy, Carola and Kelvin, Lin Ziwen and Sarin, Mahak and Gupta, Shubham}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DECIMER 1.0: Transformers for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-1.0/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer-1.0/</guid><description>Transformer-based approach for Optical Chemical Structure Recognition converting chemical images to SELFIES strings with 96% accuracy.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A. &amp; Steinbeck, C. (2021). DECIMER 1.0: deep learning for chemical image recognition using transformers. <em>Journal of Cheminformatics</em>, 13(1), 61. <a href="https://doi.org/10.1186/s13321-021-00538-8">https://doi.org/10.1186/s13321-021-00538-8</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Kohulan/DECIMER-TPU">GitHub Repository</a></li>
<li><a href="https://decimer.ai/">DECIMER Project Page</a></li>
</ul>
<h2 id="evaluating-the-contribution-a-methodological-shift">Evaluating the Contribution: A Methodological Shift</h2>
<p><strong>Method (Dominant)</strong> with strong <strong>Resource</strong> elements.</p>
<p>This is primarily a <strong>Method</strong> paper because it proposes a specific architectural evolution. It replaces CNN-RNN/Encoder-Decoder models with a <strong>Transformer-based network</strong> to solve the problem of image-to-structure translation. It validates this methodological shift through rigorous ablation studies comparing feature extractors (InceptionV3 vs. EfficientNet) and decoder architectures.</p>
<p>It also serves as a <strong>Resource</strong> contribution by releasing the open-source software, trained models, and describing the curation of a massive synthetic training dataset (&gt;35 million molecules).</p>
<h2 id="motivation-unlocking-locked-chemical-knowledge">Motivation: Unlocking Locked Chemical Knowledge</h2>
<ul>
<li><strong>Data Inaccessibility</strong>: A vast amount of chemical knowledge (pre-1990s) is locked in printed or scanned literature and is not machine-readable.</li>
<li><strong>Manual Bottlenecks</strong>: Manual curation and extraction of this data is tedious, slow, and error-prone.</li>
<li><strong>Limitations of Prior Tools</strong>: Existing Optical Chemical Structure Recognition (OCSR) tools are often rule-based or struggle with the noise and variability of full-page scanned articles. Previous deep learning attempts were not publicly accessible or robust enough.</li>
</ul>
<h2 id="key-innovation-transformer-based-molecular-translation">Key Innovation: Transformer-Based Molecular Translation</h2>
<ul>
<li><strong>Transformer Architecture</strong>: Shifts from the standard CNN-RNN (Encoder-Decoder) approach to a <strong>Transformer-based decoder</strong>, significantly improving accuracy.</li>
<li><strong>EfficientNet Backbone</strong>: Replaces the standard InceptionV3 feature extractor with <strong>EfficientNet-B3</strong>, which improved feature extraction quality for chemical images.</li>
<li><strong>SELFIES Representation</strong>: Utilizes <a href="/notes/computational-chemistry/molecular-representations/selfies/"><strong>SELFIES</strong></a> (SELF-referencing Embedded Strings) as the target output. This guarantees 100% robust molecular strings and eliminates the &ldquo;invalid SMILES&rdquo; problem common in generative models.</li>
<li><strong>Massive Scaling</strong>: Trains on a synthetic dataset of <strong>35-39 million molecules</strong>, demonstrating that scaling data size directly correlates with improved model performance.</li>
</ul>
<h2 id="methodology-and-experimental-validation">Methodology and Experimental Validation</h2>
<ul>
<li><strong>Feature Extractor Ablation</strong>: Compared InceptionV3 vs. EfficientNet-B3 (and B7) on a 1-million molecule subset to determine the optimal image encoder.</li>
<li><strong>Architecture Comparison</strong>: Benchmarked the Encoder-Decoder (CNN+RNN) against the Transformer model using Tanimoto similarity metrics. The structural similarity between predicted and ground truth molecules was measured via Tanimoto similarity over molecular fingerprints:
$$ T(\mathbf{A}, \mathbf{B}) = \frac{\mathbf{A} \cdot \mathbf{B}}{|\mathbf{A}|^2 + |\mathbf{B}|^2 - \mathbf{A} \cdot \mathbf{B}} $$</li>
<li><strong>Data Scaling</strong>: Evaluated performance across increasing training set sizes (1M, 10M, 15M, 35M) to observe scaling laws.</li>
<li><strong>Stereochemistry &amp; Ions</strong>: Tested the model&rsquo;s ability to handle complex stereochemical information and charged groups (ions), creating separate datasets for these tasks.</li>
<li><strong>Augmentation Robustness</strong>: Evaluated the model on augmented images (blur, noise, varying contrast) to simulate real-world scanned document conditions.</li>
</ul>
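<p>For binary fingerprints, the dot product $\mathbf{A} \cdot \mathbf{B}$ is the popcount of the bitwise AND and $|\mathbf{A}|^2$ is the popcount of $\mathbf{A}$, so the Tanimoto formula reduces to a ratio of bit counts. The sketch below only illustrates the metric on fingerprints packed as Python integers; in the paper the fingerprints come from a cheminformatics toolkit.</p>

```python
def tanimoto(fp_a: int, fp_b: int) -> float:
    """Tanimoto similarity between binary fingerprints packed as ints:
    T = |A & B| / (|A| + |B| - |A & B|)."""
    both = bin(fp_a & fp_b).count("1")
    total = bin(fp_a).count("1") + bin(fp_b).count("1") - both
    return both / total if total else 1.0  # two empty fingerprints: identical
```

<p>A score of 1.0 indicates a perfect fingerprint match, which is why the paper separately checks isomorphism via InChI: distinct structures can occasionally share a fingerprint.</p>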
<h2 id="results-superior-accuracy-and-scaling-limit-observations">Results: Superior Accuracy and Scaling Limit Observations</h2>
<ul>
<li><strong>Superior Architecture</strong>: The Transformer model with EfficientNet-B3 features significantly outperformed the Encoder-Decoder baseline. On the 1M dataset, the Transformer achieved <strong>74.57%</strong> exact matches (Tanimoto 1.0) compared to only <strong>7.03%</strong> for the Encoder-Decoder.</li>
<li><strong>High Accuracy at Scale</strong>: With the full 35-million molecule training set (Dataset 1), the model achieved a <strong>Tanimoto 1.0 score of 96.47%</strong> and an average Tanimoto similarity of 0.99.</li>
<li><strong>Isomorphism</strong>: 99.75% of predictions with a Tanimoto score of 1.0 were confirmed to be structurally isomorphic to the ground truth (checked via <a href="/notes/computational-chemistry/molecular-representations/inchi-2013/">InChI</a>).</li>
<li><strong>Stereochemistry Costs</strong>: Including stereochemistry and ions increased the token count and difficulty, resulting in slightly lower accuracy (~89.87% exact match on Dataset 2).</li>
<li><strong>Hardware Efficiency</strong>: Training on TPUs (v3-8) was ~4x faster than Nvidia V100 GPUs, reducing training time for the largest models from months to under 14 days.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors generated synthetic data from PubChem.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 1 (Clean)</td>
          <td>35M molecules</td>
          <td>No stereo/ions. Filtered for MW &lt; 1500, bond count 3-40, SMILES len &lt; 40.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 2 (Complex)</td>
          <td>33M molecules</td>
          <td>Includes stereochemistry and charged groups (ions).</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 3 (Augmented)</td>
          <td>33M molecules</td>
          <td>Dataset 2 with image augmentations applied.</td>
      </tr>
      <tr>
          <td><strong>Preprocessing</strong></td>
          <td>N/A</td>
          <td>N/A</td>
          <td>Molecules converted to <strong>SELFIES</strong>. Images generated via CDK Structure Diagram Generator (SDG) as $299\times299$ 8-bit PNGs.</td>
      </tr>
      <tr>
          <td><strong>Format</strong></td>
          <td>TFRecords</td>
          <td>75 MB chunks</td>
          <td>128 Data points (image vector + tokenized string) per record.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Text Representation</strong>: <strong>SELFIES</strong> used to avoid invalid intermediate strings. Tokenized via Keras tokenizer.
<ul>
<li><em>Dataset 1 Tokens</em>: 27 unique tokens. Max length 47.</li>
<li><em>Dataset 2/3 Tokens</em>: 61 unique tokens (due to stereo/ion tokens).</li>
</ul>
</li>
<li><strong>Augmentation</strong>: Implemented using <code>imgaug</code> python package. Random application of:
<ul>
<li>Gaussian/Average Blur, Additive Gaussian Noise, Salt &amp; Pepper, Coarse Dropout, Gamma Contrast, Sharpen, Brightness.</li>
</ul>
</li>
<li><strong>Optimization</strong>: Adam optimizer with a custom learning rate scheduler (following the &ldquo;Attention is all you need&rdquo; paper).</li>
</ul>
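<p>The learning-rate schedule from &ldquo;Attention is all you need&rdquo; warms up linearly and then decays with the inverse square root of the step. A sketch of that schedule is below; the warmup value of 4000 is the original Transformer paper's default and is an assumption here, not a value reported by DECIMER.</p>

```python
def noam_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """lrate = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).

    Rises linearly for the first `warmup` steps, then decays as step^-0.5.
    """
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```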
<h3 id="models">Models</h3>
<p>The final architecture is an <strong>Image-to-SELFIES Transformer</strong>.</p>
<ul>
<li><strong>Encoder (Feature Extractor)</strong>:
<ul>
<li><strong>EfficientNet-B3</strong> (pre-trained on Noisy-student).</li>
<li>Input: $299 \times 299 \times 3$ images (normalized -1 to 1).</li>
<li>Output Feature Vector: $10 \times 10 \times 1536$.</li>
</ul>
</li>
<li><strong>Decoder (Transformer)</strong>:
<ul>
<li>4 Encoder-Decoder layers.</li>
<li>8 Parallel Attention Heads.</li>
<li>Dimension size: 512.</li>
<li>Feed-forward size: 2048.</li>
<li>Dropout: 0.1.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation was performed on a held-out test set (10% of total data) selected via RDKit MaxMin algorithm for diversity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto 1.0</strong></td>
          <td><strong>96.47%</strong></td>
          <td>74.57% (1M subset)</td>
          <td>Percentage of predictions with perfect fingerprint match (Dataset 1, 35M training).</td>
      </tr>
      <tr>
          <td><strong>Avg Tanimoto</strong></td>
          <td><strong>0.9923</strong></td>
          <td>0.9371 (1M subset)</td>
          <td>Average similarity score (Dataset 1, 35M training).</td>
      </tr>
      <tr>
          <td><strong>Isomorphism</strong></td>
          <td><strong>99.75%</strong></td>
          <td>-</td>
          <td>Percentage of Tanimoto 1.0 predictions that are structurally identical (checked via InChI).</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Hardware</strong>: TPU v3-8 (Google Cloud). TPU v3-32 was tested but v3-8 was chosen for cost-effectiveness.</li>
<li><strong>Comparison Hardware</strong>: Nvidia Tesla V100 (32GB GPU).</li>
<li><strong>Performance</strong>:
<ul>
<li>TPU v3-8 was ~4x faster than V100 GPU.</li>
<li>1 Million molecule model convergence: ~8.5 hours on TPU vs ~30 hours on GPU.</li>
<li>Largest model (35M) took &lt; 14 days on TPU.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMER10Deep2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{DECIMER 1.0: Deep Learning for Chemical Image Recognition Using Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{DECIMER 1.0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = <span style="color:#e6db74">{dec}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{61}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00538-8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1186/s13321-021-00538-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemPix: Hand-Drawn Hydrocarbon Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chempix/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chempix/</guid><description>Deep learning framework using CNN-LSTM image captioning to convert hand-drawn hydrocarbon structures into SMILES strings with 76% accuracy.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Weir, H., Thompson, K., Woodward, A., Choi, B., Braun, A., &amp; Martínez, T. J. (2021). ChemPix: Automated Recognition of Hand-Drawn Hydrocarbon Structures Using Deep Learning. <em>Chemical Science</em>, 12(31), 10622-10633. <a href="https://doi.org/10.1039/D1SC02957F">https://doi.org/10.1039/D1SC02957F</a></p>
<p><strong>Publication</strong>: Chemical Science 2021</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/mtzgroup/ChemPixCH">GitHub Repository</a></li>
</ul>
<h2 id="paper-classification-and-core-contribution">Paper Classification and Core Contribution</h2>
<p>This is primarily a <strong>Method</strong> paper, with a secondary contribution as a <strong>Resource</strong> paper.</p>
<p>The paper&rsquo;s core contribution is the <strong>ChemPix architecture and training strategy</strong> using neural image captioning (CNN-LSTM) to convert hand-drawn chemical structures to SMILES. The extensive ablation studies on synthetic data generation (augmentation, degradation, backgrounds) and ensemble learning strategies confirm the methodological focus. The secondary resource contribution includes releasing a curated dataset of hand-drawn hydrocarbons and code for generating synthetic training data.</p>
<h2 id="the-structural-input-bottleneck-in-computational-chemistry">The Structural Input Bottleneck in Computational Chemistry</h2>
<p>Inputting molecular structures into computational chemistry software for quantum calculations is often a bottleneck, requiring domain expertise and cumbersome manual entry in drawing software. While optical chemical structure recognition (OCSR) tools exist, they typically struggle with the noise and variability of hand-drawn sketches. There is a practical need for a tool that allows chemists to simply photograph a hand-drawn sketch and immediately convert it into a machine-readable format (<a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>), making computational workflows more accessible.</p>
<h2 id="cnn-lstm-image-captioning-and-robust-synthetic-generalization">CNN-LSTM Image Captioning and Robust Synthetic Generalization</h2>
<ol>
<li><strong>Image Captioning Paradigm</strong>: The authors treat the problem as <strong>neural image captioning</strong>, using an encoder-decoder (CNN-LSTM) framework to &ldquo;translate&rdquo; an image directly to a SMILES string. This avoids the complexity of explicit atom/bond detection and graph assembly.</li>
<li><strong>Synthetic Data Engineering</strong>: The paper introduces a rigorous synthetic data generation pipeline that transforms clean RDKit-generated images into &ldquo;pseudo-hand-drawn&rdquo; images via randomized backgrounds, degradation, and heavy augmentation. This allows the model to achieve &gt;50% accuracy on real hand-drawn data without ever seeing it during pre-training.</li>
<li><strong>Ensemble Uncertainty Estimation</strong>: The method utilizes a &ldquo;committee&rdquo; (ensemble) of networks to improve accuracy and estimate confidence based on vote agreement, providing users with reliability indicators for predictions.</li>
</ol>
<h2 id="extensive-ablation-and-real-world-evaluation">Extensive Ablation and Real-World Evaluation</h2>
<ol>
<li><strong>Ablation Studies on Data Pipeline</strong>: The authors trained models on datasets generated at different stages of the pipeline (Clean RDKit $\rightarrow$ Augmented $\rightarrow$ Backgrounds $\rightarrow$ Degraded) to quantify the value of each transformation in bridging the synthetic-to-real domain gap.</li>
<li><strong>Sample Size Scaling</strong>: They analyzed performance scaling by training on synthetic dataset sizes ranging from 50,000 to 500,000 images to understand data requirements.</li>
<li><strong>Real-world Validation</strong>: The model was evaluated on a held-out test set of hand-drawn images collected via a custom web app, providing genuine out-of-distribution testing.</li>
<li><strong>Fine-tuning Experiments</strong>: Comparisons of &ldquo;zero-shot&rdquo; performance (training only on synthetic data) versus fine-tuning with a small fraction of real hand-drawn data to assess the value of limited real-world supervision.</li>
</ol>
<h2 id="state-of-the-art-hand-drawn-ocsr-performance">State-of-the-Art Hand-Drawn OCSR Performance</h2>
<ol>
<li>
<p><strong>Pipeline Efficacy</strong>: Augmentation and image degradation were the most critical factors for generalization, improving accuracy from ~8% to nearly 50% on hand-drawn data. Adding backgrounds had a surprisingly negligible effect compared to degradation.</p>
</li>
<li>
<p><strong>State-of-the-Art Performance</strong>: The final ensemble model achieved <strong>76% accuracy</strong> (top-1) and <strong>86% accuracy</strong> (top-3) on the hand-drawn test set, demonstrating practical viability for real-world use.</p>
</li>
<li>
<p><strong>Synthetic Generalization</strong>: A model trained on 500,000 synthetic images achieved &gt;50% accuracy on real hand-drawn data without any fine-tuning, validating the synthetic data generation strategy as a viable alternative to expensive manual labeling.</p>
</li>
<li>
<p><strong>Ensemble Benefits</strong>: The voting committee approach improved accuracy and provided interpretable uncertainty estimates through vote distributions.</p>
</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study relies on two primary data sources: a massive synthetic dataset generated procedurally and a smaller collected dataset of real drawings.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Synthetic (RDKit)</td>
          <td>500,000 images</td>
          <td>Generated via RDKit with &ldquo;heavy&rdquo; augmentation: rotation ($0-360°$), blur, salt+pepper noise, and background texture addition.</td>
      </tr>
      <tr>
          <td><strong>Fine-tuning</strong></td>
          <td>Hand-Drawn (Real)</td>
          <td>~600 images</td>
          <td>Crowdsourced via a web app; used for fine-tuning and validation.</td>
      </tr>
      <tr>
          <td><strong>Backgrounds</strong></td>
          <td>Texture Images</td>
          <td>1,052 images</td>
          <td>A pool of unlabeled texture photos (paper, desks, shadows) used to generate synthetic backgrounds.</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Generation Parameters</strong>:</p>
<ul>
<li><strong>Augmentations</strong>: Rotation, Resize ($200-300px$), Blur, Dilate, Erode, Aspect Ratio, Affine transform ($\pm 20px$), Contrast, Quantize, Sharpness</li>
<li><strong>Backgrounds</strong>: Randomly translated $\pm 100$ pixels and reflected</li>
</ul>
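<p>As a concrete example of one degradation step, salt-and-pepper noise randomly flips pixels to black or white. This toy stand-in operates on a grayscale image as a list of rows; it is illustrative only and does not reproduce the paper's full augmentation pipeline.</p>

```python
import random

def salt_and_pepper(img, amount=0.05, rng=None):
    """Flip a fraction `amount` of pixels to pepper (0) or salt (255).

    img is a grayscale image as a list of rows of ints in [0, 255];
    a new image of the same shape is returned.
    """
    rng = rng or random.Random(0)
    out = []
    for row in img:
        new = []
        for px in row:
            r = rng.random()
            if r < amount / 2:
                new.append(0)      # pepper
            elif r < amount:
                new.append(255)    # salt
            else:
                new.append(px)
        out.append(new)
    return out
```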
<h3 id="algorithms">Algorithms</h3>
<p><strong>Ensemble Voting</strong><br>
A committee of networks casts votes for the predicted SMILES string. The final prediction is the one with the highest vote count. Validity of SMILES is checked using RDKit.</p>
<p><strong>Beam Search</strong><br>
Used in the decoding layer with a beam width of $k=5$ to explore multiple potential SMILES strings. It approximates the sequence $\mathbf{\hat{y}}$ that maximizes the joint probability:</p>
<p>$$ \mathbf{\hat{y}} = \arg\max_{\mathbf{y}} \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$</p>
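<p>Beam search keeps only the $k$ highest-scoring partial sequences at each decoding step, summing log-probabilities as in the objective above. The sketch below is a generic illustration, not the authors' decoder: <code>step_probs</code> is a hypothetical stand-in for the LSTM's next-character distribution.</p>

```python
import math

def beam_search(step_probs, sos="^", eos="$", k=5, max_len=10):
    """Generic beam search over a next-token distribution.

    step_probs(prefix) returns a dict token -> P(token | prefix).
    Returns up to k (sequence, log_prob) pairs, best first.
    """
    beams = [(sos, 0.0)]   # live partial sequences with their log-probs
    done = []              # sequences that have emitted the end token
    for _ in range(max_len):
        candidates = []
        for seq, lp in beams:
            for tok, p in step_probs(seq).items():
                if p > 0:
                    candidates.append((seq + tok, lp + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, lp in candidates[:k]:
            (done if seq.endswith(eos) else beams).append((seq, lp))
        if not beams:
            break
    return sorted(done + beams, key=lambda c: c[1], reverse=True)[:k]
```

<p>With $k=5$, as in the paper, the decoder can surface several candidate SMILES strings, which also feeds the top-3 accuracy reported below.</p>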
<p><strong>Optimization</strong>:</p>
<ul>
<li>
<p><strong>Optimizer</strong>: Adam</p>
</li>
<li>
<p><strong>Learning Rate</strong>: $1 \times 10^{-4}$</p>
</li>
<li>
<p><strong>Batch Size</strong>: 20</p>
</li>
<li>
<p><strong>Loss Function</strong>: Cross-entropy loss across the sequence of $T$ tokens, computed as:</p>
<p>$$ \mathcal{L} = -\sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$</p>
<p>where $\mathbf{x}$ is the image representation and $y_t$ is the predicted SMILES character. For validation, this loss is reported as perplexity.</p>
</li>
</ul>
<h3 id="models">Models</h3>
<p>The architecture is a standard image captioning model (Show, Attend and Tell style) adapted for chemical structures.</p>
<p><strong>Encoder (CNN)</strong>:</p>
<ul>
<li><strong>Input</strong>: $256 \times 256$ images (implicit from scaling operations)</li>
<li><strong>Structure</strong>: 4 blocks of Conv2D + MaxPool
<ul>
<li>Block 1: 64 filters, (3,3) kernel</li>
<li>Block 2: 128 filters, (3,3) kernel</li>
<li>Block 3: 256 filters, (3,3) kernel</li>
<li>Block 4: 512 filters, (3,3) kernel</li>
</ul>
</li>
<li><strong>Activation</strong>: ReLU throughout</li>
</ul>
<p><strong>Decoder (LSTM)</strong>:</p>
<ul>
<li><strong>Hidden Units</strong>: 512</li>
<li><strong>Embedding Dimension</strong>: 80</li>
<li><strong>Attention</strong>: Mechanism with intermediary vector dimension of 512</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Exact SMILES match accuracy (character-by-character identity between predicted and ground truth SMILES)</li>
<li><strong>Perplexity</strong>: Used for saving model checkpoints (minimizing uncertainty)</li>
<li><strong>Top-k Accuracy</strong>: Reported for $k=1$ (76%) and $k=3$ (86%)</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{weir2021chempix,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemPix: Automated Recognition of Hand-Drawn Hydrocarbon Structures Using Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Weir, Hayley and Thompson, Keiran and Woodward, Amelia and Choi, Benjamin and Braun, Augustin and Mart{\&#39;i}nez, Todd J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{31}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{10622--10633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D1SC02957F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ABC-Net: Divide-and-Conquer SMILES Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/abc-net/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/abc-net/</guid><description>Deep learning OCSR model using keypoint estimation to detect atom and bond centers for graph-based molecular structure recognition.</description><content:encoded><![CDATA[<h2 id="contribution-and-paper-type">Contribution and Paper Type</h2>
<p><strong>Method</strong>. The paper proposes a novel architectural framework (ABC-Net) for Optical Chemical Structure Recognition (OCSR). It reformulates the problem from image captioning (sequence generation) to keypoint estimation (pixel-wise detection), backed by ablation studies on noise and comparative benchmarks against state-of-the-art tools.</p>
<h2 id="motivation-for-keypoint-based-ocsr">Motivation for Keypoint-Based OCSR</h2>
<ul>
<li><strong>Inefficiency of Rule-Based Methods</strong>: Traditional tools (OSRA, MolVec) rely on hand-coded rules that are brittle, require domain expertise, and fail to handle the wide variance in molecular drawing styles.</li>
<li><strong>Data Inefficiency of Captioning Models</strong>: Recent Deep Learning approaches (like DECIMER, Img2mol) treat OCSR as image captioning (Image-to-SMILES). This is data-inefficient because canonical SMILES require learning traversal orders, necessitating millions of training examples.</li>
<li><strong>Goal</strong>: To create a scalable, data-efficient model that predicts graph structures directly by detecting atomic/bond primitives.</li>
</ul>
<h2 id="abc-nets-divide-and-conquer-architecture">ABC-Net&rsquo;s Divide-and-Conquer Architecture</h2>
<ul>
<li><strong>Divide-and-Conquer Strategy</strong>: ABC-Net breaks the problem down into detecting <strong>atom centers</strong> and <strong>bond centers</strong> as independent keypoints.</li>
<li><strong>Keypoint Estimation</strong>: It leverages a Fully Convolutional Network (FCN) to generate heatmaps for object centers. This is inspired by computer vision techniques like CornerNet and CenterNet.</li>
<li><strong>Angle-Based Bond Detection</strong>: To handle overlapping bonds, the model classifies bond angles into 60 distinct bins ($0-360°$) at detected bond centers, allowing separation of intersecting bonds.</li>
<li><strong>Implicit Hydrogen Prediction</strong>: The model explicitly predicts the number of implicit hydrogens for aromatic heteroatoms to resolve ambiguity in dearomatization.</li>
</ul>
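<p>The angle-binning scheme can be sketched in a few lines of Python (an illustrative reconstruction, not the authors&rsquo; released code): with 60 bins over $0-360°$, each bin spans 6°, and a non-stereo bond detected at angle $\theta$ also appears at $\theta + 180°$, so the two bins must be paired during decoding.</p>

```python
NUM_BINS = 60
BIN_WIDTH = 360 / NUM_BINS  # each bin spans 6 degrees

def angle_to_bin(angle_deg: float) -> int:
    """Map a bond direction in degrees to one of the 60 angle bins."""
    return int((angle_deg % 360) // BIN_WIDTH)

def opposite_bin(bin_idx: int) -> int:
    """Bin of the same non-stereo bond seen from the other endpoint (theta + 180)."""
    return (bin_idx + NUM_BINS // 2) % NUM_BINS
```

<p>Pairing opposite bins (e.g. $30°$ and $210°$) is what the non-maximum suppression step exploits to avoid double-counting a single bond.</p>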
<h2 id="experimental-setup-and-synthetic-data">Experimental Setup and Synthetic Data</h2>
<ul>
<li><strong>Dataset Construction</strong>: Synthetic dataset of 100,000 molecules from ChEMBL, rendered using two different engines (RDKit and Indigo) to ensure style diversity.</li>
<li><strong>Baselines</strong>: Compared against two rule-based methods (MolVec, OSRA) and one deep learning method (Img2mol).</li>
<li><strong>Robustness Testing</strong>: Evaluated on the external UOB dataset (real-world images) and synthetic images with varying levels of salt-and-pepper noise (up to $p=0.6$).</li>
<li><strong>Data Efficiency</strong>: Analyzed performance scaling with training set size (10k to 160k images).</li>
</ul>
<h2 id="results-generalization-and-noise-robustness">Results, Generalization, and Noise Robustness</h2>
<ul>
<li><strong>Superior Accuracy</strong>: ABC-Net achieved <strong>94-98% accuracy</strong> on synthetic test sets, significantly outperforming MolVec (&lt;50%), OSRA (~61%), and Img2mol (~89-93%).</li>
<li><strong>Generalization</strong>: On the external UOB benchmark, ABC-Net achieved <strong>&gt;95% accuracy</strong>, whereas the deep learning baseline (Img2mol) dropped to 78.2%, indicating better generalization.</li>
<li><strong>Data Efficiency</strong>: The model reached ~95% performance with only 80,000 training images, proving it requires orders of magnitude less data than captioning-based models (which often use millions).</li>
<li><strong>Noise Robustness</strong>: Performance remained stable (&lt;2% drop) with noise levels up to $p=0.1$. Even at extreme noise ($p=0.6$), Tanimoto similarity remained high, suggesting the model recovers most substructures even when exact matches fail.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/zhang-xuan1314/ABC-Net">ABC-Net Repository</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Official implementation. Missing requirements.txt and pre-trained weights.</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Status: Partially Reproducible</strong>. The code is provided, but key components like the pre-trained weights, exact training environment dependencies, and the generated synthetic datasets are missing from the open-source release, making exact reproduction difficult.</p>
<h3 id="data">Data</h3>
<p>The authors constructed a synthetic dataset because labeled pixel-wise OCSR data is unavailable.</p>
<ul>
<li><strong>Source</strong>: ChEMBL database</li>
<li><strong>Filtering</strong>: Excluded molecules with &gt;50 non-H atoms or rare atom types/charges (&lt;1000 occurrences).</li>
<li><strong>Sampling</strong>: 100,000 unique SMILES selected such that every atom type/charge appears in at least 1,000 compounds.</li>
<li><strong>Generation</strong>: Images generated via <strong>RDKit</strong> and <strong>Indigo</strong> libraries.
<ul>
<li><em>Augmentation</em>: Varied bond thickness, label mode, orientation, and aromaticity markers.</li>
<li><em>Resolution</em>: $512 \times 512$ pixels.</li>
<li><em>Noise</em>: Salt-and-pepper noise added during training ($P$ = prob of background flip, $Q = 50P$).</li>
</ul>
</li>
</ul>
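<p>The salt-and-pepper augmentation can be sketched as follows. This is a minimal reconstruction: the paper fixes only the relation $Q = 50P$; the binary convention below (1 = background, 0 = ink) is an assumption for illustration.</p>

```python
import numpy as np

def salt_and_pepper(img: np.ndarray, p: float, seed: int = 0) -> np.ndarray:
    """Flip background pixels with probability p and ink pixels with q = 50p.

    Assumes a binary image where 1 = background and 0 = ink (this 1/0
    convention is an assumption; the paper only states Q = 50P).
    """
    rng = np.random.default_rng(seed)
    q = min(50 * p, 1.0)
    flip_prob = np.where(img == 1, p, q)      # per-pixel flip probability
    flip = rng.random(img.shape) < flip_prob  # sample which pixels to flip
    return np.where(flip, 1 - img, img)
```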
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL (RDKit/Indigo)</td>
          <td>80k</td>
          <td>8:1:1 split (Train/Val/Test)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>UOB Dataset</td>
          <td>~5.7k images</td>
          <td>External benchmark from Univ. of Birmingham</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Keypoint Detection (Heatmaps)</strong></p>
<ul>
<li>
<p><strong>Down-sampling</strong>: Input $512 \times 512$ → Output $102 \times 102$ (stride 5).</p>
</li>
<li>
<p><strong>Label Softening</strong>: To handle discretization error, ground truth peaks are set to 1, first-order neighbors to 0.95, others to 0.</p>
</li>
<li>
<p><strong>Loss Function</strong>: Penalty-reduced pixel-wise binary focal loss (variants of CornerNet loss). The loss formulation is given as:</p>
<p>$$ L_{det} = - \frac{1}{N} \sum_{c,x,y} \begin{cases} (1 - \hat{Y}_{c,x,y})^{\alpha} \log(\hat{Y}_{c,x,y}) &amp; \text{if } Y_{c,x,y} = 1 \\ (1 - Y_{c,x,y})^{\beta} (\hat{Y}_{c,x,y})^{\alpha} \log(1 - \hat{Y}_{c,x,y}) &amp; \text{otherwise} \end{cases} $$</p>
<ul>
<li>$\alpha=2$ (focal parameter), $\beta=4$ (penalty reduction).</li>
<li>Weight balancing: Classes &lt;10% frequency weighted 10x.</li>
</ul>
</li>
</ul>
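<p>A NumPy sketch of this penalty-reduced focal loss, reconstructed directly from the formula (not the released implementation; class-frequency weighting is omitted for brevity):</p>

```python
import numpy as np

def detection_loss(Y: np.ndarray, Y_hat: np.ndarray,
                   alpha: float = 2.0, beta: float = 4.0,
                   eps: float = 1e-12) -> float:
    """Penalty-reduced pixel-wise binary focal loss (CornerNet/CenterNet style).

    Y holds the softened ground-truth heatmap (peaks = 1, first-order
    neighbours = 0.95), Y_hat the predicted probabilities in (0, 1).
    """
    pos = Y == 1.0
    # Positive (peak) pixels: standard focal term
    pos_term = ((1 - Y_hat[pos]) ** alpha) * np.log(Y_hat[pos] + eps)
    # Non-peak pixels: penalty reduced by (1 - Y)^beta near peaks
    neg_term = ((1 - Y[~pos]) ** beta) * (Y_hat[~pos] ** alpha) \
        * np.log(1 - Y_hat[~pos] + eps)
    n_pos = max(int(pos.sum()), 1)  # normalise by the number of keypoints
    return float(-(pos_term.sum() + neg_term.sum()) / n_pos)
```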
<p><strong>2. Bond Direction Classification</strong></p>
<ul>
<li><strong>Angle Binning</strong>: $360°$ divided into 60 intervals.</li>
<li><strong>Inference</strong>: A bond is detected if the angle probability is a local maximum and exceeds a threshold.</li>
<li><strong>Non-Maximum Suppression (NMS)</strong>: Required for opposite angles (e.g., $30°$ and $210°$) representing the same non-stereo bond.</li>
</ul>
<p><strong>3. Multi-Task Weighting</strong></p>
<ul>
<li>Uses Kendall&rsquo;s uncertainty weighting to balance 8 different loss terms (atom det, bond det, atom type, charge, H-count, bond angle, bond type, bond length).</li>
</ul>
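<p>Kendall-style homoscedastic uncertainty weighting reduces to a simple rule over learnable log-variances $s_i$; a minimal illustration (not the authors&rsquo; implementation):</p>

```python
import math

def uncertainty_weighted_loss(task_losses, log_vars):
    """Total loss L = sum_i exp(-s_i) * L_i + s_i, with s_i = log(sigma_i^2).

    Each s_i is a learnable scalar: a noisy task learns a large s_i, which
    down-weights its loss, while the +s_i term keeps s_i from growing
    without bound.
    """
    return sum(math.exp(-s) * L + s for L, s in zip(task_losses, log_vars))
```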
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: ABC-Net (Custom U-Net / FCN)</p>
<ul>
<li><strong>Input</strong>: $512 \times 512 \times 1$ (Grayscale).</li>
<li><strong>Contracting Path</strong>: 5 steps. Each step has conv-blocks + $2 \times 2$ MaxPool.</li>
<li><strong>Expansive Path</strong>: 3 steps. Transpose-Conv upsampling + Concatenation (Skip Connections).</li>
<li><strong>Heads</strong>: Separate $1 \times 1$ convs for each task map (Atom Heatmap, Bond Heatmap, Property Maps).</li>
<li><strong>Output Dimensions</strong>:
<ul>
<li>Heatmaps: $(1, 102, 102)$</li>
<li>Bond Angles: $(60, 102, 102)$</li>
</ul>
</li>
<li><strong>Pre-trained Weights</strong>: Stated as available in paper, but missing from the public <a href="https://github.com/zhang-xuan1314/ABC-Net">GitHub repository</a>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Detection</strong>: Precision &amp; Recall (Object detection level).</li>
<li><strong>Regression</strong>: Mean Absolute Error (MAE) for bond lengths.</li>
<li><strong>Structure Recovery</strong>:
<ul>
<li><em>Accuracy</em>: Exact SMILES match rate.</li>
<li><em>Tanimoto</em>: ECFP similarity (fingerprint overlap).</li>
</ul>
</li>
</ul>
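<p>Tanimoto similarity on ECFP-style fingerprints reduces to a set-overlap ratio. A minimal sketch (generating the fingerprints themselves would require a cheminformatics library such as RDKit):</p>

```python
def tanimoto(fp_a, fp_b) -> float:
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    a, b = set(fp_a), set(fp_b)
    if not a and not b:
        return 1.0  # convention: two empty fingerprints count as identical
    inter = len(a & b)
    return inter / (len(a) + len(b) - inter)
```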
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>ABC-Net</th>
          <th>Img2mol (Baseline)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (UOB)</strong></td>
          <td><strong>96.1%</strong></td>
          <td>78.2%</td>
          <td>Non-stereo subset</td>
      </tr>
      <tr>
          <td><strong>Accuracy (Indigo)</strong></td>
          <td><strong>96.4%</strong></td>
          <td>89.5%</td>
          <td>Non-stereo subset</td>
      </tr>
      <tr>
          <td><strong>Tanimoto (UOB)</strong></td>
          <td><strong>0.989</strong></td>
          <td>0.953</td>
          <td>Higher substructure recovery</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Configuration</strong>: 15 epochs, Batch size 64.</li>
<li><strong>Optimization</strong>: Adam Optimizer. LR $2.5 \times 10^{-4}$ (first 5 epochs) → $2.5 \times 10^{-5}$ (last 10).</li>
<li><strong>Compute</strong>: &ldquo;High-Performance Computing Center&rdquo; mentioned, specific GPU model not listed, but method described as &ldquo;efficient&rdquo; on GPU.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Yi, J.-C., Yang, G.-P., Wu, C.-K., Hou, T.-J., &amp; Cao, D.-S. (2022). ABC-Net: A divide-and-conquer based deep learning architecture for SMILES recognition from molecular images. <em>Briefings in Bioinformatics</em>, 23(2), bbac033. <a href="https://doi.org/10.1093/bib/bbac033">https://doi.org/10.1093/bib/bbac033</a></p>
<p><strong>Publication</strong>: Briefings in Bioinformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/zhang-xuan1314/ABC-Net">GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhangABCNetDivideandconquerBased2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{ABC-Net: A Divide-and-Conquer Based Deep Learning Architecture for {SMILES} Recognition from Molecular Images}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Xiao-Chen and Yi, Jia-Cai and Yang, Guo-Ping and Wu, Cheng-Kun and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{bbac033}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1093/bib/bbac033}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Img2Mol: Accurate SMILES from Molecular Depictions</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/img2mol/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/img2mol/</guid><description>Two-stage CNN approach for converting molecular images to SMILES using CDDD embeddings and extensive data augmentation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Clevert, D.-A., Le, T., Winter, R., &amp; Montanari, F. (2021). Img2Mol - accurate SMILES recognition from molecular graphical depictions. <em>Chemical Science</em>, 12(42), 14174-14181. <a href="https://doi.org/10.1039/D1SC01839F">https://doi.org/10.1039/D1SC01839F</a></p>
<p><strong>Publication</strong>: Chemical Science (2021)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/bayer-science-for-a-better-life/Img2Mol">GitHub Repository</a></li>
<li><a href="https://doi.org/10.1039/D1SC01839F">Paper on Royal Society of Chemistry</a></li>
</ul>
<h2 id="method-classification">Method Classification</h2>
<p>This is a <strong>method paper</strong> that introduces Img2Mol, a deep learning system for Optical Chemical Structure Recognition (OCSR). The work focuses on building a fast, accurate, and robust system for converting molecular structure depictions into machine-readable SMILES strings.</p>
<h2 id="systematization-and-motivation">Systematization and Motivation</h2>
<p>Vast amounts of chemical knowledge exist only as images in scientific literature and patents, making this data inaccessible for computational analysis, database searches, or machine learning pipelines. Manually extracting this information is slow and error-prone, creating a bottleneck for drug discovery and chemical research.</p>
<p>While rule-based OCSR systems like OSRA, MolVec, and Imago exist, they are brittle. Small variations in drawing style or image quality can cause them to fail. The authors argue that a deep learning approach, trained on diverse synthetic data, can generalize better across different depiction styles and handle the messiness of real-world images more reliably.</p>
<h2 id="two-stage-architecture-and-core-novelty">Two-Stage Architecture and Core Novelty</h2>
<p>The novelty lies in a two-stage architecture that separates perception from decoding, combined with aggressive data augmentation to ensure robustness. The key contributions are:</p>
<p><strong>1. Two-Stage Architecture with CDDD Embeddings</strong></p>
<p>Img2Mol uses an intermediate representation to predict SMILES from pixels. A <strong>custom CNN encoder</strong> maps the input image to a 512-dimensional <strong>Continuous and Data-Driven Molecular Descriptor (CDDD)</strong> embedding - a pre-trained, learned molecular representation that smoothly captures chemical similarity. A <strong>pre-trained decoder</strong> then converts this CDDD vector into the final canonical SMILES string.</p>
<p>This two-stage design has several advantages:</p>
<ul>
<li>The CDDD space is continuous and chemically meaningful, so nearby embeddings correspond to structurally similar molecules. This makes the regression task easier than learning discrete token sequences directly.</li>
<li>The decoder is pre-trained and fixed, so the CNN only needs to learn the image → CDDD mapping. This decouples the visual recognition problem from the sequence generation problem.</li>
<li>CDDD embeddings naturally enforce chemical validity constraints, reducing the risk of generating nonsensical structures.</li>
</ul>
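<p>The decoupling can be made concrete with a toy sketch. Both functions below are invented stand-ins for illustration only: a real pipeline would use the trained CNN encoder and the fixed, pre-trained CDDD decoder of Winter et al.</p>

```python
import numpy as np

def cnn_encoder(image: np.ndarray) -> np.ndarray:
    """Stand-in for the CNN: 224x224 grayscale image -> 512-d CDDD embedding."""
    assert image.shape == (224, 224)
    return np.zeros(512)  # a trained model would regress the real embedding

def cddd_decoder(embedding: np.ndarray) -> str:
    """Stand-in for the frozen pre-trained decoder: embedding -> SMILES."""
    assert embedding.shape == (512,)
    return "CCO"  # placeholder output

def img2mol_infer(image: np.ndarray) -> str:
    """Two-stage inference: perception (CNN) is decoupled from decoding."""
    return cddd_decoder(cnn_encoder(image))
```

<p>Because only <code>cnn_encoder</code> is trained, the learning problem is a 512-dimensional regression rather than discrete sequence generation.</p>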
<p><strong>2. Extensive Data Augmentation for Robustness</strong></p>
<p>The model was trained on 11.1 million unique molecules from ChEMBL and PubChem, but the critical insight is how the training images were generated. To expose the CNN to maximum variation in depiction styles, the authors:</p>
<ul>
<li>Used <strong>three different cheminformatics libraries</strong> (RDKit, OEChem, Indigo) to render images, each with its own drawing conventions</li>
<li>Applied <strong>wide-ranging augmentations</strong>: varying bond thickness, font size, rotation, resolution (190-2500 pixels), and other stylistic parameters</li>
<li><strong>Over-sampled larger molecules</strong> to improve performance on complex structures, which are underrepresented in chemical databases</li>
</ul>
<p>This ensures the network rarely sees the same depiction of a molecule twice, forcing it to learn invariant features.</p>
<p><strong>3. Fast Inference</strong></p>
<p>Because the architecture is a simple CNN followed by a fixed decoder, inference is very fast - especially compared to rule-based systems that rely on iterative graph construction algorithms. This makes Img2Mol practical for large-scale document mining.</p>
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The evaluation focused on demonstrating that Img2Mol is more accurate, robust, and generalizable than existing rule-based systems:</p>
<ol>
<li>
<p><strong>Benchmark Comparisons</strong>: Img2Mol was tested on several standard OCSR benchmarks - USPTO (patent images), University of Birmingham (UoB), and CLEF datasets - against three open-source baselines: <strong>OSRA, MolVec, and Imago</strong>. Notably, no deep learning baselines were available at the time for comparison.</p>
</li>
<li>
<p><strong>Resolution and Molecular Size Analysis</strong>: The initial model, <code>Img2Mol(no aug.)</code>, was evaluated across different image resolutions and molecule sizes (measured by number of atoms) to understand failure modes. This revealed that:</p>
<ul>
<li>Performance degraded for molecules with &gt;35 atoms</li>
<li>Very high-resolution images lost detail when downscaled to the fixed input size</li>
<li>Low-resolution images (where rule-based methods failed completely) were handled well</li>
</ul>
</li>
<li>
<p><strong>Data Augmentation Ablation</strong>: A final model, <strong>Img2Mol</strong>, was trained with the full augmentation pipeline (wider resolution range, over-sampling of large molecules). Performance was compared to the initial version to quantify the effect of augmentation.</p>
</li>
<li>
<p><strong>Depiction Library Robustness</strong>: The model was tested on images generated by each of the three rendering libraries separately to confirm that training on diverse styles improved generalization.</p>
</li>
<li>
<p><strong>Generalization Tests</strong>: Img2Mol was evaluated on real-world patent images from the <strong>STAKER</strong> dataset, which were not synthetically generated. This tested whether the model could transfer from synthetic training data to real documents.</p>
</li>
<li>
<p><strong>Hand-Drawn Molecule Recognition</strong>: As an exploratory test, the authors evaluated performance on hand-drawn molecular structures - a task the model was never trained for - to see if the learned features could generalize to completely different visual styles.</p>
</li>
<li>
<p><strong>Speed Benchmarking</strong>: Inference time was measured and compared to rule-based baselines to demonstrate the practical efficiency of the approach.</p>
</li>
</ol>
<h2 id="results-conclusions-and-limitations">Results, Conclusions, and Limitations</h2>
<ul>
<li>
<p><strong>Substantial Performance Gains</strong>: Img2Mol outperformed all three rule-based baselines on nearly every benchmark. Accuracy was measured both as exact SMILES match and as <strong>Tanimoto similarity</strong> (a chemical fingerprint-based metric that measures structural similarity). Even when Img2Mol didn&rsquo;t predict the exact molecule, it often predicted a chemically similar one.</p>
</li>
<li>
<p><strong>Robustness Across Conditions</strong>: The full Img2Mol model (with aggressive augmentation) showed consistent performance across all image resolutions and molecule sizes. In contrast, rule-based systems were &ldquo;brittle&rdquo; - performance dropped sharply with minor perturbations to image quality or style.</p>
</li>
<li>
<p><strong>Depiction Library Invariance</strong>: Img2Mol&rsquo;s performance was stable across all three rendering libraries (RDKit, OEChem, Indigo), validating the multi-library training strategy. Rule-based methods struggled particularly with RDKit-generated images.</p>
</li>
<li>
<p><strong>Strong Generalization to Real-World Data</strong>: Despite being trained exclusively on synthetic images, Img2Mol performed well on real patent images from the STAKER dataset. This suggests the augmentation strategy successfully captured the diversity of real-world depictions.</p>
</li>
<li>
<p><strong>Overfitting in Baselines</strong>: Rule-based methods performed surprisingly well on older benchmarks (USPTO, UoB, CLEF) but failed on newer datasets (Img2Mol&rsquo;s test set, STAKER). This suggests they may be implicitly tuned to specific drawing conventions in legacy datasets.</p>
</li>
<li>
<p><strong>Limited Hand-Drawn Recognition</strong>: Img2Mol could recognize simple hand-drawn structures but struggled with complex or large molecules. This is unsurprising given the lack of hand-drawn data in training, but it highlights a potential avenue for future work.</p>
</li>
<li>
<p><strong>Speed Advantage</strong>: Img2Mol was substantially faster than rule-based competitors, especially on high-resolution images. This makes it practical for large-scale literature mining.</p>
</li>
</ul>
<p>The work establishes that deep learning can outperform traditional rule-based OCSR systems when combined with a principled two-stage architecture and comprehensive data augmentation. The CDDD embedding acts as a bridge between visual perception and chemical structure, providing a chemically meaningful intermediate representation that improves both accuracy and robustness. The focus on synthetic data diversity proves to be an effective strategy for generalizing to real-world documents.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Custom 8-layer Convolutional Neural Network (CNN) encoder</p>
<ul>
<li><strong>Input</strong>: $224 \times 224$ pixel grayscale images</li>
<li><strong>Backbone Structure</strong>: 8 convolutional layers organized into 3 stacks, followed by 3 fully connected layers
<ul>
<li><strong>Stack 1</strong>: 3 Conv layers ($7 \times 7$ filters, stride 3, padding 4) + Max Pooling</li>
<li><strong>Stack 2</strong>: 2 Conv layers + Max Pooling</li>
<li><strong>Stack 3</strong>: 3 Conv layers + Max Pooling</li>
<li><strong>Head</strong>: 3 fully connected layers</li>
</ul>
</li>
<li><strong>Output</strong>: 512-dimensional CDDD embedding vector</li>
</ul>
<p><strong>Decoder</strong>: Pre-trained CDDD decoder (from Winter et al.) - fixed during training, not updated</p>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>: Mean Squared Error (MSE) regression minimizing the distance between the predicted and true embeddings:</p>
<p>$$
\mathcal{L}_{\text{MSE}} = \frac{1}{512} \left\lVert \text{cddd}_{\text{true}} - \text{cddd}_{\text{predicted}} \right\rVert_2^2
$$</p>
<p><strong>Optimizer</strong>: AdamW with initial learning rate $10^{-4}$</p>
<p><strong>Training Schedule</strong>:</p>
<ul>
<li>Batch size: 256</li>
<li>Training duration: 300 epochs</li>
<li>Plateau scheduler: Multiplies learning rate by 0.7 if validation loss plateaus for 10 epochs</li>
<li>Early stopping: Triggered if no improvement in validation loss for 50 epochs</li>
</ul>
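<p>The schedule can be replayed in plain Python, as an illustrative sketch of the stated rules (not the authors&rsquo; training loop):</p>

```python
def replay_schedule(val_losses, lr=1e-4, factor=0.7,
                    plateau_patience=10, stop_patience=50):
    """Replay a validation-loss history under plateau decay and early stopping.

    Returns the final learning rate and the index of the last epoch run.
    """
    best = float("inf")
    epochs_since_best = 0
    epochs_since_decay = 0
    last_epoch = -1
    for epoch, loss in enumerate(val_losses):
        last_epoch = epoch
        if loss < best:
            best = loss
            epochs_since_best = 0
            epochs_since_decay = 0
        else:
            epochs_since_best += 1
            epochs_since_decay += 1
        if epochs_since_decay >= plateau_patience:
            lr *= factor          # plateau: multiply the learning rate by 0.7
            epochs_since_decay = 0
        if epochs_since_best >= stop_patience:
            break                 # early stopping after 50 stale epochs
    return lr, last_epoch
```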
<p><strong>Noise Tolerance</strong>: The decoder requires the CNN to predict embeddings with noise level $\sigma \le 0.15$ to achieve &gt;90% accuracy</p>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong>: 11.1 million unique molecules from ChEMBL and PubChem</p>
<p><strong>Splits</strong>: Approximately 50,000 examples each for validation and test sets</p>
<p><strong>Synthetic Image Generation</strong>:</p>
<ul>
<li>Three cheminformatics libraries: RDKit, OEChem, and Indigo</li>
<li>Augmentations: Resolution (190-2500 pixels), rotation, bond thickness, font size</li>
<li>Salt stripping: Keep only the largest fragment</li>
<li>Over-sampling: Larger molecules (&gt;35 atoms) over-sampled to improve performance</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li>Exact SMILES match accuracy</li>
<li>Tanimoto similarity (chemical fingerprint-based structural similarity)</li>
</ul>
<p><strong>Benchmarks</strong>:</p>
<ul>
<li>USPTO (patent images)</li>
<li>University of Birmingham (UoB)</li>
<li>CLEF dataset</li>
<li>STAKER (real-world patent images)</li>
<li>Hand-drawn molecular structures (exploratory)</li>
</ul>
<p><strong>Baselines</strong>: OSRA, MolVec, Imago (rule-based systems)</p>
<h3 id="hardware">Hardware</h3>
<p>⚠️ <strong>Unspecified in paper or supplementary materials.</strong> Inference speed reported as ~4 minutes for 5000 images; training hardware (GPU model, count) is undocumented.</p>
<h3 id="known-limitations">Known Limitations</h3>
<p><strong>Molecular Size</strong>: Performance degrades for molecules with &gt;35 atoms. This is partly a property of the CDDD latent space itself - for larger molecules, the &ldquo;volume of decodable latent space&rdquo; shrinks, making the decoder more sensitive to small noise perturbations in the predicted embedding.</p>
]]></content:encoded></item><item><title>Handwritten Chemical Ring Recognition with NNs</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/hewahi-ring-recognition-2008/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/hewahi-ring-recognition-2008/</guid><description>A two-phase neural network approach for recognizing handwritten heterocyclic chemical rings with ~94% accuracy.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hewahi, N., Nounou, M. N., Nassar, M. S., Abu-Hamad, M. I., &amp; Abu-Hamad, H. I. (2008). Chemical Ring Handwritten Recognition Based on Neural Networks. <em>Ubiquitous Computing and Communication Journal</em>, 3(3).</p>
<p><strong>Publication</strong>: Ubiquitous Computing and Communication Journal 2008</p>
<h2 id="contribution-recognition-architecture-for-heterocyclic-rings">Contribution: Recognition Architecture for Heterocyclic Rings</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>It proposes a specific algorithmic architecture (the &ldquo;Classifier-Recognizer Approach&rdquo;) to solve a pattern recognition problem. The rhetorical structure centers on defining three variations of a method, performing ablation-like comparisons between them (Whole Image vs. Lower Part), and demonstrating superior performance metrics (~94% accuracy) for the proposed technique.</p>
<h2 id="motivation-enabling-sketch-based-chemical-search">Motivation: Enabling Sketch-Based Chemical Search</h2>
<p>The authors identify a gap in existing OCR and handwriting recognition research, which typically focuses on alphanumeric characters or whole words.</p>
<ul>
<li><strong>Missing Capability</strong>: Recognition of specific <em>heterocyclic chemical rings</em> (23 types) had not been performed previously.</li>
<li><strong>Practical Utility</strong>: Existing chemical search engines require text-based queries (names); this work enables &ldquo;backward&rdquo; search where a user can draw a ring to find its information.</li>
<li><strong>Educational/Professional Aid</strong>: Useful for chemistry departments and mobile applications where chemists can sketch formulas on screens.</li>
</ul>
<h2 id="innovation-the-classifier-recognizer-pipeline">Innovation: The Classifier-Recognizer Pipeline</h2>
<p>The core novelty is the <strong>two-phase &ldquo;Classifier-Recognizer&rdquo; architecture</strong> designed to handle the visual similarity of heterocyclic rings:</p>
<ol>
<li><strong>Phase 1 (Classifier)</strong>: A neural network classifies the ring into one of four broad categories (S, N, O, Others) based solely on the <em>upper part</em> of the image (40x15 pixels).</li>
<li><strong>Phase 2 (Recognizer)</strong>: A class-specific neural network identifies the exact ring.</li>
<li><strong>Optimization</strong>: The most successful variation (&ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo;) uses only the <em>lower part</em> of the image and <em>odd rows</em> (half-grid) to reduce input dimensionality and computation time while improving accuracy. This effectively subsamples the input grid matrix $M \in \mathbb{R}^{H \times W}$ to a reduced matrix $M_{\text{sub}}$:
$$ M_{\text{sub}} = \{ m_{i,j} \in M \mid i \text{ is odd} \} $$</li>
</ol>
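<p>The odd-row subsampling amounts to a stride-2 slice. A one-line NumPy sketch, assuming rows are 1-indexed as in the paper (so the &ldquo;odd rows&rdquo; are indices 0, 2, 4, &hellip; in 0-indexed NumPy):</p>

```python
import numpy as np

def half_size_grid(image: np.ndarray) -> np.ndarray:
    """Keep every other row (the paper's 'odd rows'), halving the input size."""
    return image[::2, :]
```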
<h2 id="experimental-setup-and-network-variations">Experimental Setup and Network Variations</h2>
<p>The authors conducted a comparative study of three methodological variations:</p>
<ol>
<li><strong>Whole Image Recognizer</strong>: Uses the full image.</li>
<li><strong>Whole Image (Half Size Grid)</strong>: Uses only odd rows ($20 \times 40$ pixels).</li>
<li><strong>Lower Part (Half Size Grid)</strong>: Uses the lower part of the image with odd rows (the proposed method).</li>
</ol>
<p><strong>Setup</strong>:</p>
<ul>
<li><strong>Dataset</strong>: 23 types of heterocyclic rings.</li>
<li><strong>Training</strong>: 1500 samples (distributed across S, N, O, and Others classes).</li>
<li><strong>Testing</strong>: 1150 samples.</li>
<li><strong>Metric</strong>: Recognition accuracy (Performance %) and Error %.</li>
</ul>
<h2 id="what-were-the-outcomes-and-conclusions-drawn">What were the outcomes and conclusions drawn?</h2>
<ul>
<li><strong>Superior Method</strong>: The &ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo; achieved the best performance (~94% overall).</li>
<li><strong>High Classifier Accuracy</strong>: The first phase (classification into S/N/O/Other) is highly robust, with 100% accuracy for class &lsquo;S&rsquo; and &gt;97% for others.</li>
<li><strong>Class &lsquo;Others&rsquo; Difficulty</strong>: The &lsquo;Others&rsquo; class showed lower performance (~90-93%) compared to S/N/O due to the higher complexity and similarity of rings in that category.</li>
<li><strong>Efficiency</strong>: The half-grid approach significantly reduced input size and training time without sacrificing accuracy.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset consists of handwritten samples of 23 specific heterocyclic rings.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1500 samples</td>
          <td style="text-align: left">Split: 300 (S), 400 (N), 400 (O), 400 (Others)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1150 samples</td>
          <td style="text-align: left">Split: 150 (S), 300 (O), 400 (N), 300 (Others)</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing Steps</strong>:</p>
<ol>
<li><strong>Monochrome Conversion</strong>: Convert image to monochrome bitmap.</li>
<li><strong>Grid Scaling</strong>: Convert drawing area (regardless of original size) to a fixed <strong>40x40</strong> grid.</li>
<li><strong>Bounding</strong>: Scale the ring shape itself to fit the 40x40 grid.</li>
</ol>
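<p>The grid-scaling step can be illustrated with a simple nearest-neighbor resampler (a sketch under the assumption that nearest-neighbor mapping is acceptable for monochrome bitmaps; the paper does not specify the interpolation method):</p>

```python
def scale_to_grid(bitmap, size=40):
    """Nearest-neighbor rescale of a monochrome bitmap (list of 0/1 rows)
    onto a fixed size x size grid, mirroring the step of fitting the
    drawn ring into a 40x40 grid regardless of its original resolution."""
    h, w = len(bitmap), len(bitmap[0])
    return [
        [bitmap[(r * h) // size][(c * w) // size] for c in range(size)]
        for r in range(size)
    ]

drawing = [[1] * 100 for _ in range(80)]  # an 80x100 input drawing
grid = scale_to_grid(drawing)
print(len(grid), len(grid[0]))            # 40 40
```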
<h3 id="algorithms">Algorithms</h3>
<p><strong>The &ldquo;Lower Part with Half Size&rdquo; Pipeline</strong>:</p>
<ol>
<li><strong>Cut Point</strong>: A horizontal midline is defined; the algorithm separates the &ldquo;Upper Part&rdquo; and &ldquo;Lower Part&rdquo;.</li>
<li><strong>Phase 1 Input</strong>: The <strong>Upper Part</strong> (rows 0-15 approx, scaled) is fed to the Classifier NN to determine the class (S, N, O, or Others).</li>
<li><strong>Phase 2 Input</strong>:
<ul>
<li>For classes <strong>S, N, O</strong>: The <strong>Lower Part</strong> of the image is used.</li>
<li>For class <strong>Others</strong>: The <strong>Whole Ring</strong> is used.</li>
</ul>
</li>
<li><strong>Dimensionality Reduction</strong>: For the recognizer networks, only <strong>odd rows</strong> are used (effectively a 20x40 input grid) to reduce inputs from 1600 to 800.</li>
</ol>
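<p>The routing logic of the two-phase pipeline can be sketched as below. The <code>classify</code> and <code>recognizers</code> callables stand in for the trained networks (hypothetical placeholders), and applying the half-grid to the &ldquo;Others&rdquo; whole-ring input is an assumption, since the paper describes odd-row reduction only for the recognizer networks generally:</p>

```python
def recognize_ring(image, classify, recognizers):
    """Two-phase pipeline sketch: the classifier sees the upper half of
    the 40x40 grid; the class-specific recognizer sees the lower half for
    S/N/O, or the whole ring for 'Others'. Only every other row is kept
    for the recognizer input (the half-size-grid reduction)."""
    upper, lower = image[:20], image[20:]
    ring_class = classify(upper)                 # Phase 1: S, N, O, or Others
    part = image if ring_class == "Others" else lower
    half_grid = part[::2]                        # keep every other row
    return ring_class, recognizers[ring_class](half_grid)

image = [[0] * 40 for _ in range(40)]
recognizers = {c: (lambda part, c=c: f"ring-from-{c}")
               for c in ("S", "N", "O", "Others")}
cls, ring = recognize_ring(image, lambda upper: "S", recognizers)
print(cls, ring)  # S ring-from-S
```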
<h3 id="models">Models</h3>
<p>The system uses multiple distinct Feed-Forward Neural Networks (Backpropagation is implied by &ldquo;training&rdquo; and &ldquo;epochs&rdquo; context, though not explicitly named as the algorithm):</p>
<ul>
<li><strong>Structure</strong>: 1 Classifier NN + 4 Recognizer NNs (one for each class).</li>
<li><strong>Hidden Layers</strong>: The preliminary experiments used 1600 hidden units; the final optimized method likely uses fewer given the reduced input size. The exact hidden-node count for the final method is not explicitly tabulated, though Table 2 reports trials with &ldquo;50&rdquo; and &ldquo;1000&rdquo; hidden nodes.</li>
<li><strong>Input Nodes</strong>:
<ul>
<li>Standard: 1600 (40x40).</li>
<li>Optimized: ~800 (20x40 via half-grid).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Testing Results (Lower Part Image Recognizer with Half Size Grid)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Class</th>
          <th style="text-align: left">Samples</th>
          <th style="text-align: left">Correct</th>
          <th style="text-align: left">Accuracy</th>
          <th style="text-align: left">Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>S</strong></td>
          <td style="text-align: left">150</td>
          <td style="text-align: left">147</td>
          <td style="text-align: left"><strong>98.00%</strong></td>
          <td style="text-align: left">2.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>O</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">289</td>
          <td style="text-align: left"><strong>96.33%</strong></td>
          <td style="text-align: left">3.67%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>N</strong></td>
          <td style="text-align: left">400</td>
          <td style="text-align: left">386</td>
          <td style="text-align: left"><strong>96.50%</strong></td>
          <td style="text-align: left">3.50%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Others</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">279</td>
          <td style="text-align: left"><strong>93.00%</strong></td>
          <td style="text-align: left">7.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Overall</strong></td>
          <td style="text-align: left"><strong>1150</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
          <td style="text-align: left"><strong>~94.0%</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{hewahiCHEMICALRINGHANDWRITTEN2008,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{CHEMICAL RING HANDWRITTEN RECOGNITION BASED ON NEURAL NETWORKS}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Hewahi, Nabil and Nounou, Mohamed N and Nassar, Mohamed S and Abu-Hamad, Mohamed I and Abu-Hamad, Husam I}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2008}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Ubiquitous Computing and Communication Journal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Deep Learning for Molecular Structure Extraction</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/staker-deep-learning-2019/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/staker-deep-learning-2019/</guid><description>An end-to-end deep learning approach using U-Net and CNN-LSTM to segment and predict chemical structures from document images.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Staker, J., Marshall, K., Abel, R., &amp; McQuaw, C. (2019). Molecular Structure Extraction From Documents Using Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 59(3), 1017-1029. <a href="https://doi.org/10.1021/acs.jcim.8b00669">https://doi.org/10.1021/acs.jcim.8b00669</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling (JCIM) 2019</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.schrodinger.com/publications">Schrödinger Publication Page</a></li>
</ul>
<h2 id="contribution-type-method-and-resource">Contribution Type: Method and Resource</h2>
<p>This is primarily a <strong>methodological</strong> paper with a secondary <strong>resource</strong> contribution.</p>
<p><strong>Method</strong>: It proposes a novel end-to-end deep learning architecture (Segmentation U-Net + Recognition Encoder-Decoder) to replace traditional rule-based optical chemical structure recognition (OCSR) systems.</p>
<p><strong>Resource</strong>: It details a pipeline for generating massive synthetic datasets (images overlaying patent/journal backgrounds) necessary to train these data-hungry models.</p>
<h2 id="motivation-overcoming-brittle-rule-based-systems">Motivation: Overcoming Brittle Rule-Based Systems</h2>
<p>Existing tools for extracting chemical structures from literature (e.g., OSRA, CLIDE) rely on complex, handcrafted rules and heuristics (edge detection, vectorization). These systems suffer from:</p>
<ol>
<li><strong>Brittleness</strong>: They fail when image quality is low (low resolution, noise) or when artistic styles vary (wavy bonds, crossing lines).</li>
<li><strong>Maintenance difficulty</strong>: Improvements require manual codification of new rules for every edge case, which is difficult to scale.</li>
<li><strong>Data volume</strong>: The explosion of published life science papers (2000+ per day in Medline) creates a need for automated, robust curation tools that humans cannot match.</li>
</ol>
<h2 id="core-innovation-end-to-end-pixel-to-smiles-recognition">Core Innovation: End-to-End Pixel-to-SMILES Recognition</h2>
<p>The authors present the first fully <strong>end-to-end deep learning approach</strong> for this task that operates directly on raw pixels without explicit subcomponent recognition (e.g., detecting atoms and bonds separately). Key innovations include:</p>
<ol>
<li><strong>Pixel-to-SMILES</strong>: Treating structure recognition as an image captioning problem using an encoder-decoder architecture with attention, generating SMILES directly.</li>
<li><strong>Resolution Invariance</strong>: The model is explicitly designed and trained to work on <strong>low-resolution images</strong> (downsampled to ~60 dpi), making it robust to the poor quality of legacy PDF extractions.</li>
<li><strong>Implicit Superatom Handling</strong>: The model learns to recognize and generate sequences for superatoms (e.g., &ldquo;OTBS&rdquo;) contextually.</li>
</ol>
<h2 id="experimental-setup-and-large-scale-synthetic-data">Experimental Setup and Large-Scale Synthetic Data</h2>
<p>The authors validated their approach using a mix of massive synthetic training sets and real-world test sets:</p>
<ol>
<li><strong>Synthetic Generation</strong>: They created a segmentation dataset by overlaying USPTO molecules onto &ldquo;whited-out&rdquo; journal pages.</li>
<li><strong>Ablation/Training</strong>: Metrics were tracked on Indigo (synthetic) and USPTO (real patent images) datasets.</li>
<li><strong>External Validation</strong>:
<ul>
<li><strong>Valko Dataset</strong>: A standard benchmark of 454 heterogeneous images from literature.</li>
<li><strong>Proprietary Dataset</strong>: A collection of images from 47 articles and 5 patents to simulate real-world drug discovery curation.</li>
</ul>
</li>
<li><strong>Stress Testing</strong>: They analyzed performance distributions across molecular weight, heavy atom count, and rare elements (e.g., Uranium, Vanadium).</li>
</ol>
<h2 id="results-and-limitations-in-complex-structures">Results and Limitations in Complex Structures</h2>
<ul>
<li><strong>High Accuracy on Standard Sets</strong>: The model achieved <strong>82% accuracy</strong> on the Indigo validation set and <strong>77%</strong> on the USPTO validation set.</li>
<li><strong>Real-World Viability</strong>: It achieved <strong>83% accuracy</strong> on the proprietary internal test set, suggesting it is ready for production curation workflows.</li>
<li><strong>Limitations on Complexity</strong>: Performance dropped to <strong>41% on the Valko test set</strong>. This was attributed to complex superatoms and explicit stereochemistry not present in the training distribution.</li>
<li><strong>Stereochemistry Challenges</strong>: The model struggled to learn correct chiral configurations (R vs S) purely from 2D images without broader context, often correctly identifying the stereocenter but assigning the wrong direction.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors utilized three primary sources for generating training data. All inputs were strictly downsampled to improve robustness.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>Indigo Set</strong></td>
          <td>57M</td>
          <td>PubChem molecules rendered via Indigo (256x256).</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>USPTO Set</strong></td>
          <td>1.7M</td>
          <td>Image/SMILES pairs from public patent data.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>OS X Indigo</strong></td>
          <td>10M</td>
          <td>Additional Indigo renders from Mac OS for style diversity.</td>
      </tr>
      <tr>
          <td><strong>Segmentation</strong></td>
          <td><strong>Synthetic Pages</strong></td>
          <td>N/A</td>
          <td>Generated by overlaying USPTO images on text-cleared PDF pages.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Segmentation Inputs</strong>: Grayscale, downsampled to ~60 dpi.</li>
<li><strong>Prediction Inputs</strong>: Resized to 256x256 such that bond lengths are 3-12 pixels.</li>
<li><strong>Normalization</strong>: Input pixels normalized using $\frac{\text{input} - 251.7392}{261.574}$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Segmentation Pipeline</strong>:</p>
<ul>
<li><strong>Multi-scale Inference</strong>: Masks generated at resolutions from 30 to 60 dpi (3 dpi increments) and averaged for the final mask.</li>
<li><strong>Post-processing</strong>: Hough transform used to remove long straight lines (table borders). Mask blobs filtered by pixel count thresholds.</li>
</ul>
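<p>The multi-scale inference step amounts to averaging per-pixel probabilities across dpi settings. A minimal sketch, assuming all masks have already been resampled to a common shape (the paper does not detail the resampling):</p>

```python
def multiscale_mask(masks):
    """Average segmentation masks produced at several dpi settings into
    one consensus mask. `masks` is a list of equal-shaped 2D float grids
    of per-pixel molecule probabilities."""
    n = len(masks)
    h, w = len(masks[0]), len(masks[0][0])
    return [[sum(m[r][c] for m in masks) / n for c in range(w)]
            for r in range(h)]

# e.g. masks from 30, 45, and 60 dpi runs, already resampled to 2x2
avg = multiscale_mask([[[0.9, 0.1], [0.0, 1.0]],
                       [[0.7, 0.3], [0.2, 0.8]],
                       [[0.8, 0.2], [0.1, 0.9]]])
print(round(avg[0][0], 3))  # 0.8
```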
<p><strong>Prediction Pipeline</strong>:</p>
<ul>
<li><strong>Sequence Generation</strong>: SMILES generated character-by-character; the paper&rsquo;s phrase &ldquo;sequences of highest confidence&rdquo; implies a beam search, with candidates scored as the product of per-step softmax outputs.</li>
<li><strong>Attention-based Verification</strong>: Attention weights used to re-project predicted atoms back into 2D space to visually verify alignment with the input image.</li>
</ul>
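<p>Scoring a candidate sequence as the product of per-character softmax probabilities can be sketched as follows (an illustration, not the paper&rsquo;s implementation; summing logs before exponentiating avoids numerical underflow on long strings):</p>

```python
import math

def sequence_confidence(step_probs, token_ids):
    """Score a generated string as the product of per-character softmax
    probabilities: step_probs[t][k] is the probability of token k at
    decode step t, and token_ids is the chosen token at each step."""
    return math.exp(sum(math.log(step_probs[t][tok])
                        for t, tok in enumerate(token_ids)))

# toy vocabulary {0: 'C', 1: 'O', 2: '='}; probabilities per decode step
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1],
         [0.6, 0.3, 0.1]]
print(round(sequence_confidence(probs, [0, 1, 0]), 3))  # 'C','O','C' -> 0.336
```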
<h3 id="models">Models</h3>
<p><strong>1. Segmentation Model (U-Net Variant)</strong>:</p>
<ul>
<li><strong>Architecture</strong>: U-Net style with skip connections.</li>
<li><strong>Input</strong>: 128x128x1 grayscale image.</li>
<li><strong>Layers</strong>: Alternating 3x3 Conv and 2x2 Max Pool.</li>
<li><strong>Activation</strong>: Parametric ReLU (pReLU).</li>
<li><strong>Parameters</strong>: ~380,000.</li>
</ul>
<p><strong>2. Structure Prediction Model (Encoder-Decoder)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: CNN with 5x5 convolutions, 2x2 Max Pooling, pReLU. No pooling in first layers to preserve fine features.</li>
<li><strong>Decoder</strong>: 3 layers of <strong>GridLSTM</strong> cells.</li>
<li><strong>Attention</strong>: Soft/Global attention mechanism conditioned on the encoder state.</li>
<li><strong>Input</strong>: 256x256x1 image.</li>
<li><strong>Output</strong>: Sequence of characters (vocab size 65).</li>
<li><strong>Parameters</strong>: ~46.3 million.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation required an exact string match of the Canonical SMILES (including stereochemistry) to the ground truth.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Dataset</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td><strong>82%</strong></td>
          <td>Indigo Val</td>
          <td>Synthetic validation set</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>77%</strong></td>
          <td>USPTO Val</td>
          <td>Real patent images</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>83%</strong></td>
          <td>Proprietary</td>
          <td>Internal pharma dataset (real world)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>41%</strong></td>
          <td>Valko Test</td>
          <td>External benchmark; difficult due to superatoms</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Segmentation Training</strong>: 1 GPU, ~4 days (650k steps).</li>
<li><strong>Prediction Training</strong>: 8 NVIDIA Pascal GPUs, ~26 days (1M steps).</li>
<li><strong>Framework</strong>: TensorFlow 1.x (Google).</li>
</ul>
]]></content:encoded></item><item><title>DECIMER: Deep Learning for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/decimer/</guid><description>Deep learning method for optical chemical structure recognition using image captioning networks trained on millions of synthetic molecular images.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A. &amp; Steinbeck, C. (2020). DECIMER: towards deep learning for chemical image recognition. <em>Journal of Cheminformatics</em>, 12(65). <a href="https://doi.org/10.1186/s13321-020-00469-w">https://doi.org/10.1186/s13321-020-00469-w</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2020</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Kohulan/DECIMER">Official GitHub Repository</a></li>
<li><a href="https://github.com/Kohulan/DECIMER-Image-to-SMILES">DECIMER Image-to-SMILES Repository</a></li>
</ul>
<h2 id="contribution-method-for-optical-chemical-entity-recognition">Contribution: Method for Optical Chemical Entity Recognition</h2>
<p>This is primarily a <strong>Method ($\Psi_{\text{Method}}$)</strong> paper with a strong <strong>Resource ($\Psi_{\text{Resource}}$)</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architecture (DECIMER) that repurposes &ldquo;show-and-tell&rdquo; image captioning networks for Optical Chemical Entity Recognition (OCER), providing an alternative to traditional rule-based segmentation pipelines.</li>
<li><strong>Resource</strong>: It establishes a framework for generating massively scalable synthetic training data using open-source cheminformatics tools (CDK) and databases (PubChem), circumventing the scarcity of manually annotated chemical images.</li>
</ul>
<h2 id="motivation-brittleness-of-heuristic-pipelines">Motivation: Brittleness of Heuristic Pipelines</h2>
<p>The extraction of chemical structures from scientific literature (OCER) is critical for populating open-access databases. Traditional OCER systems (like OSRA or CLiDE) rely on complex multi-step pipelines involving vectorization, character recognition, and graph compilation. These systems are brittle and incorporating new structural features requires laborious engineering. Inspired by deep reinforcement learning systems like AlphaGo Zero, the authors sought to formulate an end-to-end deep learning approach that learns directly from data with minimal prior assumptions.</p>
<h2 id="novelty-image-captioning-for-molecular-graphs">Novelty: Image Captioning for Molecular Graphs</h2>
<ul>
<li><strong>Image-to-Text Formulation</strong>: The paper frames chemical structure recognition as an image captioning problem, translating a bitmap image directly into a SMILES string using an encoder-decoder network. This bypasses explicit segmentation of atoms and bonds entirely.</li>
<li><strong>Synthetic Data Strategy</strong>: The authors generate synthetic images from PubChem using the CDK Structure Diagram Generator, scaling the dataset size to 15 million.</li>
<li><strong>Robust String Representations</strong>: The study performs key ablation experiments on string representations, comparing standard SMILES against DeepSMILES to evaluate how syntactic validity affects the network&rsquo;s learning capability.</li>
</ul>
<h2 id="experimental-setup-and-validation-strategies">Experimental Setup and Validation Strategies</h2>
<ul>
<li><strong>Data Scaling</strong>: Models were trained on dataset sizes ranging from 54,000 to 15 million synthetic images to observe empirical scaling laws regarding accuracy and compute time.</li>
<li><strong>Representation Comparison</strong>: The authors compared the validity of predicted strings and recognition accuracy when training on SMILES versus DeepSMILES. The cross-entropy loss formulation for sequence generation can be represented as:
$$ \mathcal{L} = -\sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$
where $\mathbf{x}$ is the image representation and $y_t$ are the tokens of the SMILES/DeepSMILES string.</li>
<li><strong>Metric Evaluation</strong>: Performance was measured using Validity (syntactic correctness) and Tanimoto Similarity $T$, computed on molecular fingerprints to capture partial correctness even if the exact string prediction failed:
$$ T(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|} $$</li>
</ul>
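<p>The Tanimoto formula above is straightforward to compute on fingerprints represented as sets of on-bit indices (a minimal sketch with toy bit sets, not real PubChem fingerprints):</p>

```python
def tanimoto(a, b):
    """Tanimoto (Jaccard) similarity between two fingerprints given as
    sets of on-bit indices: |A n B| / (|A| + |B| - |A n B|)."""
    inter = len(a & b)
    return inter / (len(a) + len(b) - inter)

fp_true = {3, 17, 42, 101}  # toy on-bits of a ground-truth fingerprint
fp_pred = {3, 17, 42, 256}  # prediction shares 3 of 5 distinct bits
print(tanimoto(fp_true, fp_pred))  # 0.6
```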
<h2 id="results-and-critical-conclusions">Results and Critical Conclusions</h2>
<ul>
<li><strong>Data Representation</strong>: DeepSMILES proved superior to standard SMILES for training stability and output validity. Preliminary tests suggested SELFIES performs even better (0.78 Tanimoto vs 0.53 for DeepSMILES at 6M images).</li>
<li><strong>Scaling Behavior</strong>: Accuracy improves linearly with dataset size. The authors extrapolate that near-perfect detection would require training on 50 to 100 million structures.</li>
<li><strong>Current Limitations</strong>: At the reported training scale (up to 15M), the model does not yet rival traditional heuristic approaches, but the learning curve suggests it is a viable trajectory given sufficient compute and data.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is synthetic, generated using the Chemistry Development Kit (CDK) Structure Diagram Generator (SDG) based on molecules from PubChem.</p>
<p><strong>Curation Rules</strong> (applied to PubChem data):</p>
<ul>
<li>Molecular weight &lt; 1500 Daltons.</li>
<li>Elements restricted to: C, H, O, N, P, S, F, Cl, Br, I, Se, B.</li>
<li>No counter ions or charged groups.</li>
<li>No isotopes (e.g., D, T).</li>
<li>Bond count between 5 and 40.</li>
<li>SMILES length &lt; 40 characters.</li>
<li>Implicit hydrogens only (except in functional groups).</li>
</ul>
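<p>The curation rules can be expressed as a single predicate. The sketch below operates on a plain dict summary of a molecule (the keys and the example property values are illustrative, not drawn from any real cheminformatics toolkit):</p>

```python
def passes_curation(mol):
    """Apply the paper's curation rules to a molecule summarized as a
    dict with keys: weight (Da), elements, charged, isotopes,
    bonds (count), smiles."""
    allowed = {"C", "H", "O", "N", "P", "S", "F", "Cl", "Br", "I", "Se", "B"}
    return (mol["weight"] < 1500
            and set(mol["elements"]) <= allowed   # allowed elements only
            and not mol["charged"]                # no counter ions/charges
            and not mol["isotopes"]               # no D, T, etc.
            and 5 <= mol["bonds"] <= 40
            and len(mol["smiles"]) < 40)

caffeine = {"weight": 194.19, "elements": ["C", "H", "N", "O"],
            "charged": False, "isotopes": False, "bonds": 15,
            "smiles": "CN1C=NC2=C1C(=O)N(C)C(=O)N2C"}
print(passes_curation(caffeine))  # True
```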
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Images</strong>: Generated as 299x299 bitmaps to match Inception V3 input requirements.</li>
<li><strong>Augmentation</strong>: One random rotation applied per molecule; no noise or blurring added in this iteration.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic (PubChem)</td>
          <td>54k - 15M</td>
          <td>Scaled across 12 experiments</td>
      </tr>
      <tr>
          <td>Testing</td>
          <td>Independent Set</td>
          <td>6k - 1.6M</td>
          <td>10% of training size</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: <code>&quot;Show, Attend and Tell&quot;</code> (Attention-based Image Captioning).</li>
<li><strong>Optimization</strong>: Adam optimizer with learning rate 0.0005.</li>
<li><strong>Loss Function</strong>: Sparse Categorical Crossentropy.</li>
<li><strong>Training Loop</strong>: Trained for 25 epochs per model. Batch size of 640 images.</li>
</ul>
<h3 id="models">Models</h3>
<p>The network is implemented in TensorFlow 2.0.</p>
<ul>
<li><strong>Encoder</strong>: Inception V3 (Convolutional NN), used unaltered. Extracts feature vectors saved as NumPy arrays.</li>
<li><strong>Decoder</strong>: Gated Recurrent Unit (GRU) based Recurrent Neural Network (RNN) with soft attention mechanism.</li>
<li><strong>Embeddings</strong>: Image embedding dimension size of 600.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The primary metric is Tanimoto similarity (Jaccard index) on PubChem fingerprints, which is robust for measuring structural similarity even when exact identity is not reached.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Definition</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto 1.0</strong></td>
          <td>Percentage of predictions that are chemically identical to ground truth (isomorphic).</td>
      </tr>
      <tr>
          <td><strong>Average Tanimoto</strong></td>
          <td>Mean similarity score across the test set (captures partial correctness).</td>
      </tr>
      <tr>
          <td><strong>Validity</strong></td>
          <td>Percentage of predicted strings that are valid DeepSMILES/SMILES.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on a single node.</p>
<ul>
<li><strong>GPU</strong>: 1x NVIDIA Tesla V100.</li>
<li><strong>CPU</strong>: 2x Intel Xeon Gold 6230.</li>
<li><strong>RAM</strong>: 384 GB.</li>
<li><strong>Compute Time</strong>:
<ul>
<li>Linear scaling with data size.</li>
<li>15 million structures took ~27 days (64,909s per epoch).</li>
<li>Projected time for 100M structures: ~4 months on single GPU.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMERDeepLearning2020,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{DECIMER}}: Towards Deep Learning for Chemical Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{DECIMER}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{65}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-020-00469-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemGrapher: Deep Learning for Chemical OCR</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemgrapher-2020/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemgrapher-2020/</guid><description>Deep learning OCSR method using semantic segmentation and classification CNNs to reconstruct chemical graphs with improved stereochemistry.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Oldenhof, M., Arany, A., Moreau, Y., &amp; Simm, J. (2020). ChemGrapher: Optical Graph Recognition of Chemical Compounds by Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 60(10), 4506-4517. <a href="https://doi.org/10.1021/acs.jcim.0c00459">https://doi.org/10.1021/acs.jcim.0c00459</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2020 (arXiv preprint Feb 2020)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2002.09914">arXiv Page</a></li>
</ul>
<h2 id="classifying-the-methodology">Classifying the Methodology</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel deep learning architecture and a specific graph-reconstruction algorithm to solve the problem of Optical Chemical Structure Recognition (OCSR). It validates this method by comparing it against the existing standard tool (OSRA), demonstrating superior performance on specific technical challenges like stereochemistry.</p>
<h2 id="the-ocr-stereochemistry-challenge">The OCR Stereochemistry Challenge</h2>
<p>Chemical knowledge is frequently locked in static images within scientific publications. Extracting this structure into machine-readable formats (graphs, SMILES) is essential for drug discovery and database querying. Existing tools, such as OSRA, rely on optical character recognition (OCR) and expert systems or hand-coded rules. These tools struggle with bond multiplicity and stereochemical information, often missing atoms or misinterpreting 3D cues (wedges and dashes). A machine learning approach allows for improvement via data scaling.</p>
<h2 id="decoupled-semantic-segmentation-and-classification-pipeline">Decoupled Semantic Segmentation and Classification Pipeline</h2>
<p>The core novelty is the <strong>segmentation-classification pipeline</strong> which decouples object detection from type assignment:</p>
<ol>
<li><strong>Semantic Segmentation</strong>: The model first predicts pixel-wise maps for atoms, bonds, and charges using a modified U-Net/dilated convolution architecture.</li>
<li><strong>Graph Building Algorithm</strong>: A specific algorithm iterates over the segmentation maps to generate candidate locations for atoms and bonds.</li>
<li><strong>Refinement via Classification</strong>: Dedicated classification networks take cutouts of the original image combined with the segmentation mask to verify and classify each candidate (e.g., distinguishing a single bond from a double bond, or a wedge from a dash).</li>
</ol>
<p>Additionally, the authors developed a novel method for <strong>synthetic data generation</strong> by modifying the source code of RDKit to output pixel-wise labels during the image drawing process. This solves the lack of labeled training data.</p>
<h2 id="evaluating-synthetics-and-benchmarks">Evaluating Synthetics and Benchmarks</h2>
<ul>
<li><strong>Synthetic Benchmarking</strong>: The authors generated test sets in 3 different stylistic variations. For each style, they tested on both stereo (complex 3D information) and non-stereo compounds.</li>
<li><strong>Baseline Comparison</strong>: They compared the error rates of ChemGrapher against <strong>OSRA</strong> (Optical Structure Recognition Application).</li>
<li><strong>Ablation Study</strong>: They analyzed the F1 scores of the segmentation networks versus the classification networks independently to understand where errors propagated.</li>
<li><strong>Real-world Case Study</strong>: They manually curated 61 images cut from journal articles to test performance on real, non-synthetic data.</li>
</ul>
<h2 id="advancements-over-osra">Advancements Over OSRA</h2>
<ul>
<li><strong>Superior Accuracy</strong>: ChemGrapher consistently achieved lower error rates than OSRA across all synthetic styles, particularly for stereochemical information (wedge and dash bonds).</li>
<li><strong>Component Performance</strong>: The classification networks showed significantly higher F1 scores than the segmentation networks. This suggests the two-stage approach allows the classifier to correct segmentation noise.</li>
<li><strong>Real-world Viability</strong>: In the manual case study, ChemGrapher correctly predicted 46 of 61 images, compared to 42 of 61 for OSRA.</li>
<li><strong>Limitations</strong>: The model struggles with thick bond lines and non-carbon atom labels in real-world images due to the bias in the synthetic training data.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors created a custom synthetic dataset using ChEMBL and RDKit, as no pixel-wise labeled dataset existed.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Source</strong></td>
          <td>ChEMBL</td>
          <td>1.9M</td>
          <td>Split into training pool (1.5M), val/train pool (300K), and test pools (35K each).</td>
      </tr>
      <tr>
          <td><strong>Segmentation Train</strong></td>
          <td>Synthetic</td>
          <td>~114K</td>
          <td>Sampled from ChEMBL pool such that every atom type appears in &gt;1000 compounds.</td>
      </tr>
      <tr>
          <td><strong>Labels</strong></td>
          <td>Pixel-wise</td>
          <td>N/A</td>
          <td>Generated by modifying <strong>RDKit</strong> source code to output label masks (atom type, bond type, charge) during drawing.</td>
      </tr>
      <tr>
          <td><strong>Candidates</strong></td>
          <td>Cutouts</td>
          <td>~27K (Atom)<br>~55K (Bond)</td>
          <td>Generated from the validation pool for training the classification networks.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Algorithm 1: Graph Building</strong></p>
<ol>
<li><strong>Segment</strong>: Apply segmentation network $s(x)$ to get maps $S^a$ (atoms), $S^b$ (bonds), $S^c$ (charges).</li>
<li><strong>Atom Candidates</strong>: Identify candidate blobs in $S^a$.</li>
<li><strong>Classify Atoms</strong>: For each candidate, crop the input image and segmentation map. Feed to $c_A$ and $c_C$ to predict Atom Type and Charge. Add to Vertex set $V$ if valid.</li>
<li><strong>Bond Candidates</strong>: Generate all pairs of nodes in $V$ within $2 \times$ bond length distance.</li>
<li><strong>Classify Bonds</strong>: For each pair, create a candidate mask (two rectangles meeting in the middle to encode directionality). Feed to $c_B$ to predict Bond Type (single, double, wedge, etc.). Add to Edge set $E$.</li>
</ol>
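<p>The five steps above can be sketched in Python as follows. This is a simplified illustration, not the authors&rsquo; code: <code>classify_atom</code> and <code>classify_bond</code> stand in for the trained classification networks, and blob detection on the segmentation maps is assumed to have already produced <code>atom_candidates</code>.</p>

```python
import itertools
import math

def build_graph(atom_candidates, classify_atom, classify_bond, bond_length):
    """Sketch of ChemGrapher-style graph building (names illustrative).

    atom_candidates: list of (x, y) blob centers from the segmentation map.
    classify_atom:   callable returning an atom label, or None if invalid.
    classify_bond:   callable returning a bond label, or None if no bond.
    """
    # Step 3: verify/classify each atom candidate; keep valid ones as vertices.
    vertices = []
    for (x, y) in atom_candidates:
        label = classify_atom(x, y)
        if label is not None:
            vertices.append(((x, y), label))

    # Steps 4-5: pair nodes within 2x the bond length, then classify each pair.
    edges = []
    for (p1, l1), (p2, l2) in itertools.combinations(vertices, 2):
        if math.dist(p1, p2) <= 2 * bond_length:
            bond = classify_bond(p1, p2)
            if bond is not None:
                edges.append((p1, p2, bond))
    return vertices, edges
```

<p>In the real pipeline the classifiers operate on image and segmentation-map cutouts plus a highlight mask; here they receive coordinates only, to keep the sketch self-contained.</p>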
<h3 id="models">Models</h3>
<p>The pipeline uses four distinct Convolutional Neural Networks (CNNs).</p>
<p><strong>1. Semantic Segmentation Network ($s$)</strong></p>
<ul>
<li><strong>Architecture</strong>: 8-layer fully convolutional network (Dense Prediction Convolutional Network).</li>
<li><strong>Kernels</strong>: $3\times3$ for all layers.</li>
<li><strong>Dilation</strong>: Uses dilated convolutions to expand receptive field without losing resolution. Dilations: 1, 2, 4, 8, 8, 4, 2, 1.</li>
<li><strong>Input</strong>: Binary B/W image.</li>
<li><strong>Output</strong>: Multi-channel probability maps for Atom Types ($S^a$), Bond Types ($S^b$), and Charges ($S^c$).</li>
</ul>
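<p>The effect of the dilation schedule can be checked with a one-line receptive-field calculation (a sketch under the stated architecture, assuming stride-1 $3\times3$ convolutions):</p>

```python
def receptive_field(dilations, kernel=3):
    """Receptive field of a stack of stride-1 dilated convolutions.

    Each layer with dilation d and kernel k grows the field by (k - 1) * d.
    """
    rf = 1
    for d in dilations:
        rf += (kernel - 1) * d
    return rf

# Dilation schedule reported for the segmentation network.
print(receptive_field([1, 2, 4, 8, 8, 4, 2, 1]))  # -> 61
```

<p>The stack sees a $61 \times 61$ pixel neighborhood per output pixel while keeping full spatial resolution, which is the point of dilation over pooling.</p>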
<p><strong>2. Classification Networks ($c_A, c_B, c_C$)</strong></p>
<ul>
<li><strong>Purpose</strong>: Refines predictions on small image patches.</li>
<li><strong>Architecture</strong>: 5 convolutional layers followed by a MaxPool and 2 Fully Connected layers.
<ul>
<li>Layer 1: <strong>Depthwise separable convolution</strong>.</li>
<li>Layers 2-4: Dilated convolutions (factors 2, 4, 8).</li>
<li>Layer 5: Standard convolution + MaxPool ($124 \times 124$).</li>
</ul>
</li>
<li><strong>Inputs</strong>:
<ul>
<li>Crop of the binary image ($x^{cut}$).</li>
<li>Crop of the segmentation map ($S^{cut}$).</li>
<li>&ldquo;Highlight&rdquo; mask ($h_L$) indicating the specific candidate location (e.g., a dot for atoms, two rectangles for bonds).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric</strong>: <strong>F1 Score</strong> for individual network performance (segmentation pixels and classification accuracy).</li>
<li><strong>Metric</strong>: <strong>Error Rate</strong> (percentage of incorrect graphs) for overall system. A graph is &ldquo;incorrect&rdquo; if there is at least one mistake in atoms or bonds.</li>
<li><strong>Baselines</strong>: Compared against <strong>OSRA</strong> (version not stated explicitly; v2.1.0 implied by the context of standard tools).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Training and inference performed on a single <strong>NVIDIA Titan Xp</strong>.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{oldenhof2020chemgrapher,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemGrapher: Optical Graph Recognition of Chemical Compounds by Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Oldenhof, Martijn and Arany, Adam and Moreau, Yves and Simm, Jaak}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{60}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{4506--4517}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACS Publications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.0c00459}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kekulé-1 System for Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/kekule-1996/</link><pubDate>Mon, 15 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/kekule-1996/</guid><description>Foundational OCSR method combining neural OCR with chemical rule-based post-processing for automated structure interpretation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: McDaniel, J. R., &amp; Balmuth, J. R. (1996). Automatic Interpretation of Chemical Structure Diagrams. <em>Graphics Recognition. Methods and Applications</em>, 148-158. <a href="https://doi.org/10.1007/3-540-61226-2_14">https://doi.org/10.1007/3-540-61226-2_14</a></p>
<p><strong>Publication</strong>: Lecture Notes in Computer Science (LNCS), Vol. 1072, Springer, 1996.</p>
<h2 id="system-architecture-and-contribution">System Architecture and Contribution</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel software architecture (&ldquo;Kekulé-1&rdquo;) designed to solve the specific technical problem of converting rasterized chemical diagrams into machine-readable connection tables. The paper is characterized by:</p>
<ul>
<li><strong>Algorithmic Specification</strong>: It details specific algorithms for vectorization, polygon approximation, and character recognition.</li>
<li><strong>Performance Metrics</strong>: It validates the method using quantitative accuracy (98.9%) and speed comparisons against manual entry.</li>
<li><strong>System Architecture</strong>: It describes the integration of typically disparate components (OCR, vectorization, chemical rules) into a cohesive pipeline.</li>
</ul>
<h2 id="motivation-the-chemical-data-entry-bottleneck">Motivation: The Chemical Data Entry Bottleneck</h2>
<p>Chemical structure diagrams are the primary medium for communication between chemists, but computers cannot natively &ldquo;read&rdquo; these raster images.</p>
<ul>
<li><strong>Efficiency Gap</strong>: Manual redrawing of structures into chemical databases takes 6 to 10 minutes per structure.</li>
<li><strong>Technical Challenge</strong>: Existing commercial OCR systems failed on chemical diagrams because they could not handle the mix of graphics (bonds) and text (atom labels), nor could they recognize small fonts (3-7 points) or chemical symbols accurately.</li>
<li><strong>Goal</strong>: To create an &ldquo;Optical Chemical Structure Recognition&rdquo; (OCSR) system that reduces processing time to seconds while handling complex notation like stereochemistry and group formulas.</li>
</ul>
<h2 id="core-innovations-in-chemical-ocr">Core Innovations in Chemical OCR</h2>
<p>Kekulé-1 represents the &ldquo;first successful attempt&rdquo; to integrate image processing, OCR, and structure editing into a single workflow. Key innovations include:</p>
<ul>
<li><strong>Context-Aware OCR</strong>: Unlike standard OCR, Kekulé-1 uses &ldquo;chemical spell checking&rdquo; by applying valence rules and chemical context to correct raw character recognition errors (e.g., distinguishing &lsquo;5&rsquo; from &lsquo;S&rsquo; based on bonding).</li>
<li><strong>Adaptive Polygon Approximation</strong>: A modified vectorization algorithm that partitions objects at the farthest node to prevent artifact nodes in U-shaped structures.</li>
<li><strong>Hybrid Parsing</strong>: It treats the diagram as a graph where nodes can be explicit atoms or geometric intersections, using rule-based logic to parse &ldquo;group formulas&rdquo; (like $COOH$) recursively.</li>
</ul>
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The authors evaluated the system on a private test set to validate robustness and speed.</p>
<ul>
<li><strong>Dataset</strong>: 524 chemical structures chosen from a &ldquo;wide variety of sources&rdquo; specifically to test the system&rsquo;s limits.</li>
<li><strong>Metrics</strong>: Success rate (percentage of structures processed with minimal editing) and processing time per structure.</li>
<li><strong>Comparators</strong>: Performance was compared against the &ldquo;manual redrawing&rdquo; baseline.</li>
</ul>
<h2 id="results-performance-and-conclusions">Results, Performance, and Conclusions</h2>
<ul>
<li><strong>High Accuracy</strong>: 98.9% of the test structures were successfully processed (with an average of 0.74 user prompts per structure).</li>
<li><strong>Speedup</strong>: Processing took 7 to 30 seconds per structure, a significant improvement over the 6 to 10 minute manual baseline.</li>
<li><strong>Robustness</strong>: The system successfully handled pathological cases like broken characters, skew (rotation), and touching characters.</li>
<li><strong>Impact</strong>: The authors conclude that the techniques are generalizable to other domains like electrical circuits and utility maps.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training/Test Data</strong>: The evaluation used 524 chemical structures. These were not released publicly but were selected to represent &ldquo;limit&rdquo; cases.</li>
<li><strong>Input format</strong>: Scanned images at 300-400 dpi. The authors note that higher resolutions do not add information due to ink wicking and paper limitations.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p>The paper details several specific algorithmic implementations:</p>
<p><strong>Vectorization (Polygon Approximation)</strong>:</p>
<ul>
<li>Standard thinning and raster-to-vector translation are used.</li>
<li><strong>Innovation</strong>: The algorithm searches for the node <em>farthest</em> from the current start node to partition the object. This prevents artifact nodes in curved lines.</li>
<li><strong>Threshold Formula</strong>: The allowed deviation ($dist$) from a straight line is adaptive based on segment length ($length$):</li>
</ul>
<p>$$dist = \max(1, \frac{length}{10.0} + 0.4)$$</p>
<p>(Units in pixels)</p>
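<p>A direct Python transcription of the formula shows its behavior at the extremes:</p>

```python
def allowed_deviation(length):
    """Adaptive deviation threshold (pixels) for polygon approximation:
    dist = max(1, length / 10.0 + 0.4)."""
    return max(1.0, length / 10.0 + 0.4)

# Short segments are clamped to a 1-pixel tolerance; longer segments
# tolerate proportionally more deviation from a straight line.
assert allowed_deviation(6) == 1.0
assert allowed_deviation(100) == 10.4
```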
<p><strong>Rotation Correction</strong>:</p>
<ul>
<li>The system computes the angle of all &ldquo;long&rdquo; line segments modulo 15 degrees.</li>
<li>It bins these angles; the bin with the highest count (representing &lt; 4 degrees rotation) is treated as the scan skew and corrected.</li>
</ul>
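<p>A rough Python rendering of this skew estimate (the 0.5&deg; bin width is an assumption; the paper only specifies folding the angles modulo 15&deg; and taking the fullest bin):</p>

```python
from collections import Counter

def estimate_skew(segment_angles_deg, bin_width=0.5):
    """Sketch of Kekule-1's skew estimate (bin width is an assumption).

    Angles of long segments are folded modulo 15 degrees, then mapped to
    (-7.5, 7.5] so small clockwise and counterclockwise skews are both
    representable; the fullest bin is taken as the scan skew.
    """
    bins = Counter()
    for a in segment_angles_deg:
        folded = a % 15.0
        if folded > 7.5:              # fold into a signed offset
            folded -= 15.0
        bins[round(folded / bin_width)] += 1
    best_bin, _ = bins.most_common(1)[0]
    return best_bin * bin_width
```

<p>Bonds in chemical diagrams cluster at multiples of roughly 15&deg;, which is why folding modulo 15 isolates the residual scanner skew.</p>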
<p><strong>Optical Character Recognition (OCR)</strong>:</p>
<ul>
<li>Uses a neural network with linked/shared weights (similar to Convolutional Neural Networks, though not named as such) acting as a feature detector.</li>
<li><strong>Training</strong>: Trained on specific chemical fonts.</li>
<li><strong>Inference</strong>: Outputs are ranked; if multiple characters (e.g., &lsquo;5&rsquo; and &lsquo;S&rsquo;) exceed a threshold, both are kept, and chemical context resolves the ambiguity later.</li>
</ul>
<p><strong>Chemical Parsing</strong>:</p>
<ul>
<li>Group formulas (e.g., $COOH$) are parsed left-to-right by subtracting valences.</li>
<li>Example: For $COOH$, the external bond reduces Carbon&rsquo;s valence to 3. The first Oxygen takes 2, leaving 1. The final Oxygen takes 1 (attaching to Carbon), and the Hydrogen takes 1 (attaching to Oxygen).</li>
</ul>
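<p>The $COOH$ walk-through can be reproduced with a small valence-subtraction parser. This is a simplified sketch of the described procedure, not Kekul&eacute;-1&rsquo;s implementation (it ignores digits, charges, and rings):</p>

```python
VALENCE = {"C": 4, "O": 2, "N": 3, "H": 1}

def parse_group(atoms, external_bonds=1):
    """Left-to-right valence-subtraction parse of a group formula.

    Each new atom bonds to the leftmost earlier atom that still has free
    valence; the bond order is the largest both atoms can support.
    Returns bonds as (left_index, right_index, order) tuples.
    """
    remaining = [VALENCE[atoms[0]] - external_bonds]  # external bond uses valence
    bonds = []
    for i, atom in enumerate(atoms[1:], start=1):
        remaining.append(VALENCE[atom])
        for j in range(i):
            if remaining[j] > 0:
                order = min(remaining[j], remaining[i])
                bonds.append((j, i, order))
                remaining[j] -= order
                remaining[i] -= order
                break
    return bonds

print(parse_group(["C", "O", "O", "H"]))
# C=O (double), C-O (single), O-H (single), matching the worked example.
```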
<h3 id="models">Models</h3>
<ul>
<li><strong>OCR Model</strong>: A neural network with a &ldquo;shared weights&rdquo; paradigm, effectively creating a learned convolution map. It achieves ~99.9% raw accuracy on isolated test sets of chemical fonts.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: The evaluation was performed on an <strong>80486 processor at 33 MHz</strong>.</li>
<li><strong>Time</strong>: Average processing time was 9 seconds per structure.</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{mcdanielAutomaticInterpretationChemical1996,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Automatic Interpretation of Chemical Structure Diagrams}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Graphics Recognition. Methods and Applications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{McDaniel, Joe R. and Balmuth, Jason R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{O&#39;Gorman, Lawrence and Kasturi, Rangachar}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span> = <span style="color:#e6db74">{Lecture Notes in Computer Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{1072}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{148--158}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{1996}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1007/3-540-61226-2_14}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Party Matters: Enhancing Legislative Embeddings</title><link>https://hunterheidenreich.com/notes/interdisciplinary/computational-social-science/party-matters-hiptm/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/interdisciplinary/computational-social-science/party-matters-hiptm/</guid><description>A method for improving legislative vote prediction across sessions by augmenting bill text embeddings with sponsor metadata.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel neural architecture that modifies how bill embeddings are constructed by explicitly incorporating sponsor metadata alongside text. The authors validate this method by comparing it against text-only baselines (MWE and CNN) and demonstrating superior performance in a newly defined &ldquo;out-of-session&rdquo; evaluation setting.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Existing models for predicting legislative roll-call votes rely heavily on text or voting history within a single session. However, these models fail to generalize across sessions because the underlying data generation process changes. Specifically, the ideological position of bills on similar topics shifts depending on which party is in power. A model trained on a single session learns an implicit ideological prior that becomes inaccurate when the political context changes in subsequent sessions.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is a neural architecture that augments bill text representations with sponsor ideology, specifically the percentage of Republican vs. Democrat sponsors.</p>
<ul>
<li><strong>Sponsor-Weighted Embeddings</strong>: They compute a composite embedding where the text representation is weighted by party sponsorship percentages ($p_{r}, p_{d}$) and party-specific influence vectors ($a_{r}, a_{d}$).</li>
<li><strong>Out-of-Session Evaluation</strong>: They introduce a rigorous evaluation setting where models trained on past sessions (e.g., 2005-2012) are tested on future sessions (e.g., 2013-2014) to test generalization, which previous work had ignored.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors evaluated their models using a dataset of U.S. Congressional bills from 2005 to 2016.</p>
<ul>
<li><strong>Models Tested</strong>: They compared text-only models (MWE (Mean Word Embedding), CNN) against metadata-augmented versions (MWE+Meta, CNN+Meta) and a &ldquo;Meta-Only&rdquo; baseline (using dummy text).</li>
<li><strong>Settings</strong>:
<ul>
<li><strong>In-Session</strong>: 5-fold cross-validation on 2005-2012 data.</li>
<li><strong>Out-of-Session</strong>: Training on 2005-2012 and testing on 2013-2014 and 2015-2016.</li>
</ul>
</li>
<li><strong>Baselines</strong>: Comparisons included a &ldquo;Guess Yes&rdquo; baseline and an SVM trained on bag-of-words summaries with sponsor indicators.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Metadata is Critical</strong>: Augmenting text with sponsor metadata consistently outperformed text-only models. The <code>CNN+Meta</code> model achieved the highest accuracy in most settings (e.g., 86.21% in-session vs 83.24% for CNN).</li>
<li><strong>Generalization</strong>: Text-only models degraded significantly in out-of-session testing (dropping from ~83% to ~77% or lower), confirming that text alone fails to capture shifting ideological contexts.</li>
<li><strong>Sponsor Signal</strong>: The <code>Meta-Only</code> model (using no text) outperformed text-only models in the 2013-2014 out-of-session test, suggesting that in some contexts, the author&rsquo;s identity provides a stronger predictive signal than the bill&rsquo;s content.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: Collected from GovTrack, covering sessions from 2005 to 2016 (the paper text says &ldquo;106th to 111th&rdquo;, but this is an error in the published paper: the data tables confirm coverage through 2016, corresponding to the 109th-114th Congressional sessions).</li>
<li><strong>Content</strong>: Non-unanimous roll call votes, full text of bills/resolutions, and Congressional Research Service (CRS) summaries.</li>
<li><strong>Filtering</strong>: Bills with unanimous votes (defined as &lt;1% &ldquo;no&rdquo; votes) were excluded.</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Text lowercased and stop-words removed.</li>
<li>Summaries truncated to $N=400$ words; full text truncated to $N=2000$ words (80th percentile lengths).</li>
</ul>
</li>
<li><strong>Splits</strong>:
<ul>
<li><em>Training</em>: Sessions 2005-2012 (1718 bills).</li>
<li><em>Testing</em>: Sessions 2013-2014 (360 bills) and 2015-2016 (382 bills).</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Bill Representation ($v_{B}$)</strong>:
$$v_{B}=((a_{r}p_{r})\cdot T_{r})+((a_{d}p_{d})\cdot T_{d})$$
Where $T$ is the text embedding (CNN or MWE), $p$ is the percentage of sponsors from a party, and $a$ is a learnable party influence vector. Specifically, $T_{r}$ and $T_{d}$ are the text embeddings computed over bills with Republican and Democrat sponsorship respectively.</li>
<li><strong>Vote Prediction</strong>:
<ul>
<li>Project bill embedding to legislator space: $v_{BL}=W_{B}v_{B}+b_{B}$.</li>
<li>Alignment score: $W_{v}(v_{BL}\odot v_{L})+b_{v}$ (using element-wise multiplication).</li>
<li>Output: Sigmoid activation.</li>
</ul>
</li>
<li><strong>Optimization</strong>: AdaMax algorithm with binary cross-entropy loss.</li>
</ul>
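<p>The forward pass can be sketched with NumPy. Everything here is illustrative: parameters are randomly initialized rather than learned, and <code>T_r</code>/<code>T_d</code> stand in for the CNN or MWE text encoders.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
d_text, d_leg = 50, 25  # text and legislator embedding sizes from the paper

# Learnable parameters, randomly initialized here purely for illustration.
a_r = rng.normal(size=d_text)            # Republican influence vector
a_d = rng.normal(size=d_text)            # Democrat influence vector
W_B = rng.normal(size=(d_leg, d_text))   # bill -> legislator-space projection
b_B = np.zeros(d_leg)
W_v = rng.normal(size=d_leg)             # alignment weights
b_v = 0.0

def predict_vote(T_r, T_d, p_r, p_d, v_L):
    """Forward-pass sketch for one (bill, legislator) pair."""
    v_B = (a_r * p_r) * T_r + (a_d * p_d) * T_d   # sponsor-weighted bill embedding
    v_BL = W_B @ v_B + b_B                        # project into legislator space
    score = W_v @ (v_BL * v_L) + b_v              # element-wise alignment
    return 1.0 / (1.0 + np.exp(-score))           # sigmoid -> P(yea)
```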
<h3 id="models">Models</h3>
<ul>
<li><strong>Text Encoders</strong>:
<ul>
<li><strong>CNN</strong>: 4-grams with 400 filter maps.</li>
<li><strong>MWE</strong>: <a href="/posts/intro-to-word-embeddings/">Mean Word Embedding</a>.</li>
</ul>
</li>
<li><strong>Embeddings</strong>:
<ul>
<li>Initialized with 50-dimensional GloVe vectors.</li>
<li>Embeddings are non-static (updated during training).</li>
<li>Legislator embedding size ($v_{L}$): 25 dimensions.</li>
</ul>
</li>
<li><strong>Initialization</strong>: Weights initialized with Glorot uniform distribution.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Accuracy.</li>
<li><strong>Comparison</strong>:
<ul>
<li><strong>In-session</strong>: 5-fold cross-validation.</li>
<li><strong>Out-of-session</strong>: Train on past sessions, predict future sessions.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Config</strong>: Models trained for 50 epochs with mini-batches of size 50.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kornilova, A., Argyle, D., &amp; Eidelman, V. (2018). Party Matters: Enhancing Legislative Embeddings with Author Attributes for Vote Prediction. <em>arXiv preprint arXiv:1805.08182</em>. <a href="https://arxiv.org/abs/1805.08182">https://arxiv.org/abs/1805.08182</a></p>
<p><strong>Publication</strong>: arXiv 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kornilovaPartyMattersEnhancing2018,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Party {{Matters}}: {{Enhancing Legislative Embeddings}} with {{Author Attributes}} for {{Vote Prediction}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Party {{Matters}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Kornilova, Anastassia and Argyle, Daniel and Eidelman, Vlad}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2018</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = may,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:1805.08182}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{1805.08182}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Mixture Density Networks: Modeling Multimodal Distributions</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</guid><description>Seminal 1994 paper introducing MDNs to model arbitrary conditional probability distributions using neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper.</p>
<p>It identifies a specific failure mode in existing neural network methodologies (least-squares regression on multi-valued inverse problems) and proposes a novel architecture (combining MLPs with Mixture Models) to solve it. It derives the mathematical framework for training this architecture via standard back-propagation and validates it against the established baseline.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard neural networks trained with sum-of-squares (MSE) or cross-entropy error functions approximate the <strong>conditional average</strong> of the target data, $\langle t|x \rangle$.</p>
<p>While optimal for single-valued functions or classification, this produces completely erroneous results for <strong>inverse problems</strong> where the mapping is multi-valued (one input has multiple valid outputs). For example, in robot inverse kinematics, &ldquo;elbow-up&rdquo; and &ldquo;elbow-down&rdquo; configurations can achieve the same hand position. An MSE-trained network will average these two valid angles, resulting in an invalid configuration that puts the robot arm through an obstacle or breaks the linkage.</p>















<figure class="post-figure center ">
    <img src="/img/notes/single-gaussian-mse-prediction.webp"
         alt="Single Gaussian MSE prediction averaging multimodal distribution"
         title="Single Gaussian MSE prediction averaging multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MSE-trained networks predict the mean, which averages across modes and produces invalid outputs for inverse problems.</figcaption>
    
</figure>

<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The introduction of the <strong>Mixture Density Network (MDN)</strong>.</p>
<p>The neural network predicts the <strong>parameters</strong> (mixing coefficients, means, and variances) of a kernel mixture distribution (typically Gaussian).</p>















<figure class="post-figure center ">
    <img src="/img/notes/gaussian-mixture-mdn-prediction.webp"
         alt="Gaussian mixture model prediction capturing multimodal distribution"
         title="Gaussian mixture model prediction capturing multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MDNs predict mixture parameters to capture the full conditional probability density, representing all modes.</figcaption>
    
</figure>

<p>Key technical contributions include:</p>
<ol>
<li><strong>Architecture</strong>: Mapping network outputs to mixture parameters using specific activation functions to satisfy constraints (Softmax for priors $\alpha$, Exponential for variances $\sigma$).</li>
<li><strong>Training</strong>: Deriving the error function as the negative log-likelihood of the mixture model.</li>
<li><strong>Optimization</strong>: Deriving the exact derivatives (gradients) of the error with respect to network outputs, allowing the mixture model parameters to be learned via standard back-propagation.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method on two tasks, comparing an MDN against a standard MLP trained with least-squares:</p>
<ol>
<li><strong>Toy Inverse Problem</strong>: A sinusoidal mapping $x = t + 0.3\sin(2\pi t) + \epsilon$. The forward problem ($t \to x$) is single-valued, but the inverse ($x \to t$) is multi-valued.</li>
<li><strong>Robot Kinematics</strong>: A 2-link robot arm simulation. The task is to map end-effector Cartesian coordinates $(x_1, x_2)$ back to joint angles $(\theta_1, \theta_2)$.</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Toy Problem</strong>: The standard least-squares network failed completely, drawing a smooth curve through the average of the multiple branches, which did not correspond to valid data. The MDN correctly modeled the tri-modal density and discontinuous jumps in the most probable solution.</li>
<li><strong>Robot Kinematics</strong>: The MDN reduced the RMS positioning error by an order of magnitude compared to the standard network (0.0053 vs 0.0578).</li>
<li><strong>Generality</strong>: The paper concludes that MDNs provide a complete description of the conditional probability density, allowing users to calculate any statistic (mean, mode, variance) needed for the application.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>1. Toy Inverse Problem</strong></p>
<ul>
<li><strong>Function</strong>: $x = t + 0.3\sin(2\pi t) + \epsilon$</li>
<li><strong>Noise</strong>: $\epsilon \sim U(-0.1, 0.1)$</li>
<li><strong>Sampling</strong>: 1,000 points generated by sampling $t$ uniformly on the interval $(0, 1)$.</li>
<li><strong>Task</strong>: Inverse mapping (predict $t$ given $x$).</li>
</ul>
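<p>Under this specification, the dataset can be generated in a few lines (illustrative code, not from the paper):</p>

```python
import math
import random

def make_toy_data(n=1000, seed=0):
    """Sample the toy mapping x = t + 0.3*sin(2*pi*t) + eps, eps ~ U(-0.1, 0.1).

    For the inverse problem the network is trained on (x, t) pairs:
    predict t given x, a mapping that is multi-valued.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        t = rng.uniform(0.0, 1.0)
        x = t + 0.3 * math.sin(2.0 * math.pi * t) + rng.uniform(-0.1, 0.1)
        data.append((x, t))
    return data
```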
<p><strong>2. Robot Kinematics</strong></p>
<ul>
<li><strong>System</strong>: 2-link arm with lengths $L_1=0.8, L_2=0.2$.</li>
<li><strong>Forward Kinematics</strong>:
<ul>
<li>$x_1 = L_1 \cos(\theta_1) - L_2 \cos(\theta_1 + \theta_2)$</li>
<li>$x_2 = L_1 \sin(\theta_1) - L_2 \sin(\theta_1 + \theta_2)$</li>
</ul>
</li>
<li><strong>Constraints</strong>: $\theta_1 \in (0.3, 1.2)$, $\theta_2 \in (\pi/2, 3\pi/2)$.</li>
<li><strong>Dataset</strong>: 1,000 training points, 1,000 test points.</li>
</ul>
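<p>The kinematics data can be reproduced directly from these equations and constraints (illustrative sketch, not the paper's code):</p>

```python
import math
import random

L1, L2 = 0.8, 0.2

def forward_kinematics(theta1, theta2):
    """End-effector position of the 2-link arm, using the paper's sign convention."""
    x1 = L1 * math.cos(theta1) - L2 * math.cos(theta1 + theta2)
    x2 = L1 * math.sin(theta1) - L2 * math.sin(theta1 + theta2)
    return x1, x2

def make_kinematics_data(n=1000, seed=0):
    """Sample joint angles within the stated constraints.

    The inverse task maps each (x1, x2) position back to (theta1, theta2).
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        t1 = rng.uniform(0.3, 1.2)
        t2 = rng.uniform(math.pi / 2, 3 * math.pi / 2)
        rows.append((forward_kinematics(t1, t2), (t1, t2)))
    return rows
```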
<h3 id="algorithms">Algorithms</h3>
<p><strong>Mixture Model Definition</strong></p>
<p>The conditional density is defined as:</p>
<p>$$p(t|x) = \sum_{i=1}^{m} \alpha_i(x) \phi_i(t|x)$$</p>
<p>Where kernels $\phi_i$ are Gaussians with centers $\mu_i(x)$ and variances $\sigma_i(x)$.</p>
<p><strong>Network Output Mappings</strong></p>
<p>If the network produces raw outputs $z$, they are mapped to parameters as follows to satisfy probability constraints:</p>
<ul>
<li><strong>Mixing Coefficients ($\alpha$)</strong>: Softmax. $\alpha_i = \frac{\exp(z_i^\alpha)}{\sum_j \exp(z_j^\alpha)}$</li>
<li><strong>Variances ($\sigma$)</strong>: Exponential. $\sigma_i = \exp(z_i^\sigma)$</li>
<li><strong>Means ($\mu$)</strong>: Linear/Identity. $\mu_{ik} = z_{ik}^\mu$</li>
</ul>
<p><strong>Loss Function</strong></p>
<p>Negative Log Likelihood:</p>
<p>$$E^q = - \ln \left\{ \sum_{i=1}^{m} \alpha_i(x^q) \phi_i(t^q|x^q) \right\}$$</p>
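<p>For scalar targets with Gaussian kernels, the per-pattern loss can be sketched as follows (illustrative, not Bishop's implementation):</p>

```python
import math

def gaussian(t, mu, sigma):
    """Gaussian kernel phi_i(t|x) for a scalar target t."""
    return math.exp(-((t - mu) ** 2) / (2.0 * sigma ** 2)) / (math.sqrt(2.0 * math.pi) * sigma)

def mdn_nll(t, alphas, mus, sigmas):
    """Negative log-likelihood of one target under the mixture:
    E = -ln sum_i alpha_i * phi_i(t|x)."""
    likelihood = sum(a * gaussian(t, m, s) for a, m, s in zip(alphas, mus, sigmas))
    return -math.log(likelihood)
```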
<h3 id="models">Models</h3>
<p><strong>1. Toy Problem Configuration</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 1 input ($x$), 1 hidden layer.</li>
<li><strong>Hidden Units</strong>: 20 units (tanh activation).</li>
<li><strong>Outputs</strong>: 9 units.
<ul>
<li>$m=3$ Gaussian kernels.</li>
<li>Parameters per kernel: 1 $\alpha$, 1 $\sigma$, 1 $\mu$. Total = $3 \times 3 = 9$.</li>
</ul>
</li>
<li><strong>Training</strong>: 1,000 cycles of BFGS.</li>
</ul>
<p><strong>2. Robot Kinematics Configuration</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 2 inputs ($x_1, x_2$).</li>
<li><strong>Hidden Units</strong>: 10 units (tanh activation).</li>
<li><strong>Outputs</strong>: 8 units.
<ul>
<li>$m=2$ Gaussian kernels.</li>
<li>Target dimension $c=2$ (predicting $\theta_1, \theta_2$).</li>
<li>Parameters per kernel: 1 $\alpha$ + 1 $\sigma$ (common variance) + 2 $\mu$ (means for $\theta_1, \theta_2$).</li>
<li>Total = $2 \times (1 + 1 + 2) = 8$.</li>
</ul>
</li>
</ul>
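<p>Both configurations follow the same counting rule — $m$ kernels times (1 prior + 1 common variance + $c$ means). A one-line check (illustrative):</p>

```python
def mdn_output_count(m, c):
    """Number of network outputs for m spherical Gaussian kernels and a
    c-dimensional target: one mixing coefficient and one common variance
    per kernel, plus c means per kernel."""
    return m * (1 + 1 + c)

assert mdn_output_count(3, 1) == 9  # toy problem
assert mdn_output_count(2, 2) == 8  # robot kinematics
```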
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric</strong>: RMS Euclidean distance between the desired end-effector position and the achieved position (calculated by plugging predicted angles back into forward kinematics).</p>
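<p>A minimal sketch of this metric, assuming the link lengths $L_1=0.8$, $L_2=0.2$ and forward kinematics stated above (illustrative, not from the paper):</p>

```python
import math

def rms_position_error(pred_angles, target_positions, l1=0.8, l2=0.2):
    """RMS Euclidean distance between desired end-effector positions and
    the positions reached by plugging predicted angles back into the
    forward kinematics."""
    sq = 0.0
    for (t1, t2), (x1, x2) in zip(pred_angles, target_positions):
        px1 = l1 * math.cos(t1) - l2 * math.cos(t1 + t2)
        px2 = l1 * math.sin(t1) - l2 * math.sin(t1 + t2)
        sq += (px1 - x1) ** 2 + (px2 - x2) ** 2
    return math.sqrt(sq / len(pred_angles))
```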
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Hidden Units</th>
          <th>Kernels</th>
          <th>RMS Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Least Squares</td>
          <td>20</td>
          <td>N/A</td>
          <td>0.0578</td>
      </tr>
      <tr>
          <td>MDN</td>
          <td>10</td>
          <td>2</td>
          <td>0.0053</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bishop, C. M. (1994). Mixture Density Networks. <em>Neural Computing Research Group Report: NCRG/94/004</em>, Aston University.</p>
<p><strong>Publication</strong>: Neural Computing Research Group Technical Report 1994</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{bishopMixtureDensityNetworks1994,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Mixture {{Density Networks}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Bishop, Christopher M.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1994</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{NCRG/94/004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span> = <span style="color:#e6db74">{Aston University}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kekulé: OCR-Optical Chemical Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/kekule-1992/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/kekule-1992/</guid><description>A seminal 1992 system for Optical Chemical Structure Recognition (OCSR) using neural networks and heuristic graph compilation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: McDaniel, J. R., &amp; Balmuth, J. R. (1992). Kekulé: OCR-Optical Chemical (Structure) Recognition. <em>Journal of Chemical Information and Computer Sciences</em>, 32(4), 373-378. <a href="https://doi.org/10.1021/ci00008a018">https://doi.org/10.1021/ci00008a018</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Computer Sciences, 1992</p>
<h2 id="system-architecture-and-methodological-approach">System Architecture and Methodological Approach</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$). It proposes a novel software architecture (&ldquo;Kekulé&rdquo;) designed to solve a specific technical problem: the automatic conversion of printed chemical structure diagrams into computer-readable connection tables. The paper focuses on the &ldquo;how&rdquo; of the system by detailing the seven-step pipeline from scanning to graph compilation, validating the method through performance testing on a specific dataset.</p>
<h2 id="motivation-bridging-visual-diagrams-and-connection-tables">Motivation: Bridging Visual Diagrams and Connection Tables</h2>
<p>The primary motivation is to bridge the gap between how chemists communicate (structural diagrams) and how chemical databases store information (connection tables like MOLfiles).</p>
<ul>
<li><strong>Inefficiency of Manual Entry</strong>: Manual compilation of structural descriptions is &ldquo;tedious and highly prone to error&rdquo;.</li>
<li><strong>Redrawing Costs</strong>: Even using drawing programs (like ChemDraw ancestors) to capture connectivity is inefficient; redrawing a complex molecule like vitamin $B_{12}$ takes ~20 minutes.</li>
<li><strong>Lack of Existing Solutions</strong>: Existing OCR systems at the time failed on chemical diagrams because they could not handle the mix of graphics (bonds) and text (atom labels), and struggled with small, mixed fonts.</li>
</ul>
<h2 id="novelty-a-hybrid-ocr-and-heuristic-approach">Novelty: A Hybrid OCR and Heuristic Approach</h2>
<p>Kekulé represents the first successful attempt to integrate all of the required elements of image processing, OCR, structure editing, and database communication into a complete system.</p>
<ul>
<li><strong>Hybrid OCR Approach</strong>: Unlike commercial OCR of the time, it used a custom implementation combining rotation correction (for skew) with a <strong>multilayer perceptron neural network</strong> trained specifically on small fonts (down to 3.2 points).</li>
<li><strong>Heuristic Feature Extraction</strong>: The authors developed specific heuristics to handle chemical artifacts, such as an exhaustive search for dashed lines, explicitly rejecting Hough transforms as unreliable for short segments.</li>
<li><strong>Contextual &ldquo;Spell Checking&rdquo;</strong>: The system uses chemical context to verify OCR results, such as checking atom symbols against a valid list and using bond connections to disambiguate characters.</li>
</ul>
<h2 id="experimental-setup-and-dataset-validation">Experimental Setup and Dataset Validation</h2>
<p>The authors performed a validation study on a diverse set of chemical structures to stress-test the system:</p>
<ul>
<li><strong>Dataset</strong>: 444 chemical structures were selected from a wide variety of sources, including the <em>Merck Index</em>, <em>Aldrich Handbook</em>, and <em>ACS Nomenclature Guide</em>, specifically chosen to &ldquo;test Kekulé&rsquo;s limits&rdquo;.</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Processing Success</strong>: Percentage of structures processed.</li>
<li><strong>User Intervention</strong>: Average number of prompts per structure for verification.</li>
<li><strong>Editing Time</strong>: Time required to correct interpretation errors (arbitrary &ldquo;good&rdquo; limit set at 30 seconds).</li>
</ul>
</li>
</ul>
<h2 id="results-and-system-performance">Results and System Performance</h2>
<ul>
<li><strong>High Success Rate</strong>: 98.9% of the 444 structures were processed successfully.</li>
<li><strong>Performance Speed</strong>: The average processing time was 9 seconds per structure on an 80486 (33 MHz) processor.</li>
<li><strong>Error Modes</strong>: The primary bottleneck was broken characters in scanned images (e.g., breaks in &lsquo;H&rsquo; or &lsquo;N&rsquo; crossbars), which slowed down the OCR significantly.</li>
<li><strong>Impact</strong>: The system demonstrated that automated interpretation was faster and less error-prone than manual redrawing.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The following details outline the specific technical implementation described in the 1992 paper.</p>
<h3 id="data">Data</h3>
<p>The authors did not release a public dataset but described their test set sources in detail.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>Mixed Chemical Sources</td>
          <td>444 structures</td>
          <td>Sourced from <em>Merck Index</em>, <em>Aldrich Handbook</em>, <em>ACS Nomenclature Guide</em>, etc.</td>
      </tr>
      <tr>
          <td>Training (OCR)</td>
          <td>Font Exemplars</td>
          <td>Unknown</td>
          <td>&ldquo;Exemplars of characters from numerous serif and sanserif fonts&rdquo;.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The paper details a 7-step pipeline. Key algorithmic choices include:</p>
<ul>
<li>
<p><strong>Vectorization</strong>:</p>
<ul>
<li>Images are reduced to 1-pixel width using <strong>thinning</strong> and <strong>raster-to-vector translation</strong>.</li>
<li>An <strong>adaptive smoothing algorithm</strong> is applied to remove pixel-level jitter.</li>
</ul>
</li>
<li>
<p><strong>Feature Extraction (Dashed Lines)</strong>:</p>
<ul>
<li><strong>Hough Transforms</strong> were rejected due to poor performance on short line segments.</li>
<li><strong>Slope sorting</strong> was rejected due to variance in short dashes.</li>
<li><strong>Chosen Method</strong>: Exhaustive search/testing of all features that <em>might</em> be dashed lines (subset of features).</li>
</ul>
</li>
<li>
<p><strong>Graph Compilation</strong>:</p>
<ul>
<li><strong>Character Grouping</strong>: Characters are assembled into strings based on XY adjacency.</li>
<li><strong>Node Creation</strong>: Character strings become nodes. Vectors with endpoints &ldquo;too far&rdquo; from strings create new nodes.</li>
<li><strong>Heuristics</strong>: Circles are converted to alternating single-double bonds; &ldquo;thick&rdquo; bonds between wedges are automatically generated.</li>
</ul>
</li>
</ul>
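<p>The paper gives no code, but the XY-adjacency grouping step can be illustrated with a toy sketch. The <code>gap</code> threshold and the baseline bucketing below are invented for illustration; the original heuristics are only described qualitatively:</p>

```python
def group_characters(boxes, gap=5):
    """Toy version of the XY-adjacency heuristic: characters that sit on
    roughly the same baseline and within `gap` pixels horizontally are
    merged into one atom-label string.

    `boxes` is a list of (char, x, y) positions in any order.
    """
    # sort by approximate baseline, then left to right
    boxes = sorted(boxes, key=lambda b: (round(b[2] / gap), b[1]))
    groups, current = [], []
    for ch, x, y in boxes:
        if current and (x - current[-1][1] > gap or abs(y - current[-1][2]) > gap):
            groups.append("".join(c for c, _, _ in current))
            current = []
        current.append((ch, x, y))
    if current:
        groups.append("".join(c for c, _, _ in current))
    return groups
```

<p>In the real system, each resulting string becomes a node in the molecular graph, with nearby vector endpoints attached as bonds.</p>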
<h3 id="models">Models</h3>
<p>The core machine learning component is the OCR engine.</p>
<ul>
<li><strong>Architecture</strong>: A <strong>multilayer perceptron neural network</strong> (fully connected).</li>
<li><strong>Input</strong>: Normalized characters. Normalization involves rotation (for skew), scaling, under-sampling, and contrast/density adjustments.</li>
<li><strong>Output</strong>: Ranked probability matches. Outputs above an experimental threshold are retained. If a character is ambiguous (e.g., &lsquo;5&rsquo; vs &lsquo;S&rsquo;), both are kept and resolved via chemical context.</li>
<li><strong>Performance</strong>: Raw accuracy ~96% on small fonts (compared to ~85% for commercial OCR of the era).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The system was developed and tested on hardware typical of the early 1990s.</p>
<ul>
<li><strong>Processor</strong>: Intel 80486 at 33 MHz.</li>
<li><strong>Scanners</strong>: Hewlett-Packard ScanJet (300 dpi) and Logitech ScanMan (400 dpi hand-held).</li>
<li><strong>Platform</strong>: Microsoft Windows.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{mcdanielKekuleOCRopticalChemical1992,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Kekulé: {{OCR-optical}} Chemical (Structure) Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Kekulé}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{McDaniel, Joe R. and Balmuth, Jason R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1992</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Computer Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{32}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{373--378}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{0095-2338, 1520-5142}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/ci00008a018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>IMG2SMI: Translating Molecular Structure Images to SMILES</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/img2smi/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/img2smi/</guid><description>Campos &amp; Ji's method for converting 2D molecular images to SMILES strings using Transformers and SELFIES representation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Campos, D., &amp; Ji, H. (2021). IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System (No. arXiv:2109.04202). arXiv. <a href="https://doi.org/10.48550/arXiv.2109.04202">https://doi.org/10.48550/arXiv.2109.04202</a></p>
<p><strong>Publication</strong>: arXiv preprint (2021)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2109.04202">Paper on arXiv</a></li>
</ul>
<h2 id="contributions--taxonomy">Contributions &amp; Taxonomy</h2>
<p>This is both a <strong>Method</strong> and <strong>Resource</strong> paper:</p>
<ul>
<li><strong>Method</strong>: It adapts standard image captioning architectures (encoder-decoder) to the domain of Optical Chemical Structure Recognition (OCSR), treating molecule recognition as a translation task.</li>
<li><strong>Resource</strong>: It introduces <strong>MOLCAP</strong>, a large-scale dataset of 81 million molecules aggregated from public chemical databases, addressing the data scarcity that previously hindered deep learning approaches to OCSR.</li>
</ul>
<h2 id="the-bottleneck-in-chemical-literature-translation">The Bottleneck in Chemical Literature Translation</h2>
<p>Chemical literature is &ldquo;full of recipes written in a language computers cannot understand&rdquo; because molecules are depicted as 2D images. This creates a fundamental bottleneck:</p>
<ul>
<li><strong>The Problem</strong>: Chemists must manually redraw molecular structures to search for related compounds or reactions. This is slow, error-prone, and makes large-scale literature mining impossible.</li>
<li><strong>Existing Tools</strong>: Legacy systems like OSRA (Optical Structure Recognition Application) rely on handcrafted rules and often require human correction, making them unfit for unsupervised, high-throughput processing.</li>
<li><strong>The Goal</strong>: An automated system that can translate structure images directly to machine-readable strings (SMILES/SELFIES) without human supervision, enabling large-scale knowledge extraction from decades of chemistry literature and patents.</li>
</ul>
<h2 id="core-innovation-selfies-and-image-captioning">Core Innovation: SELFIES and Image Captioning</h2>
<p>The core novelty is demonstrating that <strong>how you represent the output text is as important as the model architecture itself</strong>. Key contributions:</p>
<ol>
<li>
<p><strong>Image Captioning Framework</strong>: Applies modern encoder-decoder architectures (ResNet-101 + Transformer) to OCSR, treating it as an image-to-text translation problem with a standard cross-entropy loss objective over the generation sequence:
$$ \mathcal{L} = -\sum\limits_{t=1}^{T} \log P(y_t \mid y_1, \ldots, y_{t-1}, x) $$</p>
</li>
<li>
<p><strong>SELFIES as Target Representation</strong>: The key mechanism relies on using <strong>SELFIES</strong> (Self-Referencing Embedded Strings) as the output format. SELFIES is based on a formal grammar where every possible string corresponds to a valid molecule, eliminating the syntactic invalidity problems (unmatched parentheses, invalid characters) that plague SMILES generation.</p>
</li>
<li>
<p><strong>MOLCAP Dataset</strong>: Created a comprehensive dataset of 81 million unique molecules from PubChem, ChEMBL, GDB13, and other sources. Generated 256x256 pixel images using RDKit for 1 million training samples and 5,000 validation samples.</p>
</li>
<li>
<p><strong>Task-Specific Evaluation</strong>: Demonstrated that traditional NLP metrics (BLEU) are poor indicators of scientific utility. Introduced evaluation based on <strong>molecular fingerprints</strong> (MACCS, RDK, Morgan) and <strong>Tanimoto similarity</strong>:
$$ T(a, b) = \frac{c}{a + b - c} $$
where $c$ is the number of common fingerprint bits, and $a$ and $b$ are the number of set bits in each respective molecule&rsquo;s fingerprint. This formulation reliably measures functional chemical similarity.</p>
</li>
</ol>
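<p>The Tanimoto formula itself is easy to sketch over fingerprint bit-index sets; in practice a cheminformatics toolkit such as RDKit computes the fingerprints, so this helper is purely illustrative:</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto similarity T(a, b) = c / (a + b - c) over fingerprint bit
    sets, where c counts bits set in both fingerprints."""
    a, b = set(fp_a), set(fp_b)
    c = len(a & b)
    denom = len(a) + len(b) - c
    # convention: two empty fingerprints are treated as identical
    return c / denom if denom else 1.0
```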
<h2 id="experimental-setup-and-ablation-studies">Experimental Setup and Ablation Studies</h2>
<p>The evaluation focused on comparing IMG2SMI to existing systems and identifying which design choices matter most:</p>
<ol>
<li>
<p><strong>Baseline Comparisons</strong>: Benchmarked against OSRA (rule-based system) and DECIMER (first deep learning approach) on the MOLCAP dataset to establish whether modern architectures could surpass traditional methods.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Extensive ablations isolating key factors:</p>
<ul>
<li><strong>Decoder Architecture</strong>: Transformer vs. RNN/LSTM decoders</li>
<li><strong>Encoder Fine-tuning</strong>: Fine-tuned vs. frozen pre-trained ResNet weights</li>
<li><strong>Output Representation</strong>: SELFIES vs. character-level SMILES vs. BPE-tokenized SMILES (the most critical ablation)</li>
</ul>
</li>
</ol>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>MACCS FTS</th>
          <th>Valid Captions</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN Decoder</td>
          <td>~0.36</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Transformer Decoder</td>
          <td>0.94</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Fixed Encoder Weights</td>
          <td>0.76</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Fine-tuned Encoder</td>
          <td>0.94</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Character-level SMILES</td>
          <td>&lt;0.50</td>
          <td>~2%</td>
      </tr>
      <tr>
          <td>BPE SMILES (2000 vocab)</td>
          <td>0.85</td>
          <td>~40%</td>
      </tr>
      <tr>
          <td>SELFIES</td>
          <td>0.94</td>
          <td>99.4%</td>
      </tr>
  </tbody>
</table>
<ol start="3">
<li><strong>Metric Analysis</strong>: Systematic comparison of evaluation metrics including BLEU, ROUGE, Levenshtein distance, exact match accuracy, and molecular fingerprint-based similarity measures.</li>
</ol>
<h2 id="results-findings-and-limitations">Results, Findings, and Limitations</h2>
<p><strong>Performance Gains</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI</th>
          <th>OSRA</th>
          <th>DECIMER</th>
          <th>Random Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>N/A</td>
          <td>~0.20</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9238</td>
          <td>N/A</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8848</td>
          <td>N/A</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>N/A</td>
          <td>0%</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<ul>
<li>163% improvement over OSRA on MACCS Tanimoto similarity.</li>
<li>Approximately 10x improvement on ROUGE scores.</li>
<li>Average Tanimoto similarity exceeds 0.85 (functionally similar molecules even when not exact matches).</li>
</ul>
<p><strong>Key Findings</strong>:</p>
<ul>
<li><strong>SELFIES is Critical</strong>: Using SELFIES yields <strong>99.4% valid molecules</strong>, compared to only ~2% validity for character-level SMILES. This robustness is essential for practical deployment.</li>
<li><strong>Architecture Matters</strong>: Transformer decoder significantly outperforms RNN/LSTM approaches. Fine-tuning the ResNet encoder (vs. frozen weights) yields substantial performance gains (e.g., MACCS FTS: 0.76 to 0.94).</li>
<li><strong>Metric Insights</strong>: BLEU is a poor metric for this task. Molecular fingerprint-based Tanimoto similarity is most informative because it measures functional chemical similarity.</li>
</ul>
<p><strong>Limitations</strong>:</p>
<ul>
<li><strong>Low Exact Match</strong>: Only <strong>7.24%</strong> exact matches. The model captures the overarching functional groups and structure but misses fine details like exact double bond placement.</li>
<li><strong>Complexity Bias</strong>: Trained on large molecules (average length &gt;40 tokens), so it performs poorly on very simple structures where OSRA still excels.</li>
</ul>
<p><strong>Conclusion</strong>: The work establishes that modern architectures combined with robust molecular representations (SELFIES) can significantly outperform traditional rule-based systems. The system is already useful for literature mining where functional similarity is more important than exact matches, though low exact match accuracy and poor performance on simple molecules indicate clear directions for future work.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Image captioning system based on DETR (Detection Transformer) framework.</p>
<p><strong>Visual Encoder</strong>:</p>
<ul>
<li><strong>Backbone</strong>: ResNet-101 pre-trained on ImageNet</li>
<li><strong>Feature Extraction</strong>: 4th layer extraction (convolutions only)</li>
<li><strong>Output</strong>: 2048-dimensional dense feature vector</li>
</ul>
<p><strong>Caption Decoder</strong>:</p>
<ul>
<li><strong>Type</strong>: Transformer encoder-decoder</li>
<li><strong>Layers</strong>: 3 stacked encoder layers, 3 stacked decoder layers</li>
<li><strong>Attention Heads</strong>: 8</li>
<li><strong>Hidden Dimensions</strong>: 2048 (feed-forward networks)</li>
<li><strong>Dropout</strong>: 0.1</li>
<li><strong>Layer Normalization</strong>: 1e-12</li>
</ul>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate</strong>: 5e-5 (selected after sweep from 1e-4 to 1e-6)</li>
<li><strong>Weight Decay</strong>: 1e-4</li>
<li><strong>Batch Size</strong>: 32</li>
<li><strong>Epochs</strong>: 5</li>
<li><strong>Codebase</strong>: Built on open-source DETR implementation</li>
</ul>
<h3 id="data">Data</h3>
<p><strong>MOLCAP Dataset</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Total Size</td>
          <td>81,230,291 molecules</td>
          <td>Aggregated from PubChem, ChEMBL, GDB13</td>
      </tr>
      <tr>
          <td>Training Split</td>
          <td>1,000,000 molecules</td>
          <td>Randomly selected unique molecules</td>
      </tr>
      <tr>
          <td>Validation Split</td>
          <td>5,000 molecules</td>
          <td>Randomly selected for evaluation</td>
      </tr>
      <tr>
          <td>Image Resolution</td>
          <td>256x256 pixels</td>
          <td>Generated using RDKit</td>
      </tr>
      <tr>
          <td>Median SELFIES Length</td>
          <td>&gt;45 characters</td>
          <td>More complex than typical benchmarks</td>
      </tr>
      <tr>
          <td>Full Dataset Storage</td>
          <td>~16.24 TB</td>
          <td>Necessitated use of 1M subset</td>
      </tr>
      <tr>
          <td>Augmentation</td>
          <td>None</td>
          <td>No cropping, rotation, or other augmentation</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li>Images generated using RDKit at 256x256 resolution</li>
<li>Molecules converted to canonical representations</li>
<li>SELFIES tokenization for model output</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metrics</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI Value</th>
          <th>OSRA Baseline</th>
          <th>Purpose</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>Fingerprint Tanimoto Similarity (functional groups)</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9238</td>
          <td>N/A</td>
          <td>RDKit fingerprint similarity</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8848</td>
          <td>N/A</td>
          <td>Morgan fingerprint similarity (circular)</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>Text overlap metric</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>Structural identity (strict)</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>Syntactic validity (with SELFIES)</td>
      </tr>
      <tr>
          <td>Levenshtein Distance</td>
          <td>Low</td>
          <td>High</td>
          <td>String edit distance</td>
      </tr>
  </tbody>
</table>
<p><strong>Secondary Metrics</strong> (shown to be less informative for chemical tasks):</p>
<ul>
<li>BLEU, ROUGE (better suited for natural language)</li>
<li>Levenshtein distance (doesn&rsquo;t capture chemical similarity)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Single NVIDIA GeForce RTX 2080 Ti</li>
<li><strong>Training Time</strong>: ~5 hours per epoch, approximately 24 hours total for 5 epochs</li>
<li><strong>Memory</strong>: Sufficient for batch size 32 with ResNet-101 + Transformer architecture</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{campos2021img2smi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Campos, Daniel and Ji, Heng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2109.04202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2109.04202}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Distributed Representations: A Foundational Theory</title><link>https://hunterheidenreich.com/notes/machine-learning/classic-papers/distributed-representations/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/classic-papers/distributed-representations/</guid><description>Hinton's 1984 technical report establishing the theoretical efficiency of distributed representations over local encoding in neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Theory</strong> paper, with strong secondary elements of <strong>Method</strong> and <strong>Position</strong>.</p>
<p>It is a theoretical work because its core contribution is the formal mathematical derivation of the storage capacity and accuracy of distributed schemes (coarse coding) compared to local schemes. It serves as a position paper by challenging the &ldquo;grandmother cell&rdquo; (local representation) intuition prevalent in AI at the time and advocating for the &ldquo;constructive&rdquo; view of memory.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The motivation is to overcome the inefficiency of <strong>local representations</strong>, where one hardware unit corresponds to exactly one entity, and to challenge traditional metaphors of memory.</p>
<ul>
<li><strong>Inefficiency</strong>: In local representations, high accuracy requires an exponential number of units (accuracy $\propto \sqrt[k]{n}$ for $k$ dimensions).</li>
<li><strong>Brittleness</strong>: Local representations lack natural support for generalization; learning a fact about one concept (e.g., &ldquo;chimps like onions&rdquo;) requires extra machinery to transfer to similar concepts (e.g., &ldquo;gorillas&rdquo;).</li>
<li><strong>Hardware Mismatch</strong>: Massive parallelism is wasted if units are active rarely (1 bit of info per unit active 50% of the time vs. almost 0 for sparse local units).</li>
<li><strong>The &ldquo;Filing Cabinet&rdquo; Metaphor</strong>: The paper challenges the standard view of memory as a storage system of literal copies. It motivates a shift toward understanding memory as a reconstructive inference process.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The paper introduces formal mechanisms that explain <em>why</em> distributed representations are superior:</p>
<ol>
<li><strong>Coarse Coding Efficiency</strong>: Hinton proves that using broad, overlapping receptive fields (&ldquo;coarse coding&rdquo;) yields higher accuracy for a fixed number of units than non-overlapping local fields. He derives that accuracy scales linearly with the number of units ($a \propto n$, a simplification of the full $a \propto n \cdot r^{k-1}$ given in the outcomes section below).</li>
<li><strong>Automatic Generalization</strong>: It demonstrates that generalization is an emergent property of vector overlap. Modifying weights for one pattern automatically affects similar patterns (conspiracy effect).</li>
<li><strong>Memory as Reconstruction</strong>: It posits that memory is a reconstructive process where items are created afresh from fragments using plausible inference rules (connection strengths). This blurs the line between veridical recall and confabulation.</li>
<li><strong>Gradual Concept Formation</strong>: Distributed representations allow new concepts to emerge gradually through weight modifications that progressively differentiate existing concepts. This avoids the discrete decisions and spare hardware units required by local representations.</li>
<li><strong>Solution to the Binding Problem</strong>: It proposes that true part/whole hierarchies are formed by fusing the identity of a part with its role to produce a single, new subpattern. The representation of the whole is then the sum of these combined identity/role representations.</li>
</ol>















<figure class="post-figure center ">
    <img src="/img/notes/distributed-representations-binding.svg"
         alt="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         title="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The binding problem solution: true hierarchies require creating unique subpatterns that fuse an identity with its role, where the whole is represented as the sum of these combined representations.</figcaption>
    
</figure>

<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The paper performs analytical derivations and two specific computer simulations:</p>
<ol>
<li><strong>Arbitrary Mapping Simulation</strong>: A 3-layer network trained to map 20 grapheme strings (e.g., words) to 20 unrelated semantic vectors.</li>
<li><strong>Damage &amp; Recovery Analysis</strong>:
<ul>
<li><strong>Lesioning</strong>: Removing &ldquo;word-set&rdquo; units to observe error patterns.</li>
<li><strong>Noise Injection</strong>: Adding noise to weights to simulate &ldquo;Deep Dyslexia&rdquo; (semantic errors like reading &ldquo;PEACH&rdquo; as &ldquo;APRICOT&rdquo;).</li>
<li><strong>Retraining</strong>: Measuring the speed of relearning after damage (&ldquo;spontaneous recovery&rdquo;).</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ol>
<li><strong>Accuracy Scaling</strong>: For a $k$-dimensional feature space, the accuracy $a$ of a distributed representation scales as $a \propto n \cdot r^{k-1}$ (where $r$ is the receptive field radius), vastly outperforming local schemes.</li>
<li><strong>Reliability</strong>: Distributed systems exhibit graceful degradation. Removing units causes slight noise across many items.</li>
<li><strong>Spontaneous Recovery</strong>: When retraining a damaged network on a subset of items, the network &ldquo;spontaneously&rdquo; recovers unrehearsed items due to weight sharing, which is a qualitative signature of distributed representations.</li>
<li><strong>Limitations of Coarse Coding</strong>: The paper identifies that coarse coding requires relatively sparse features. Crowding too many feature-points together causes receptive fields to contain too many features, preventing the activity pattern from discriminating between combinations.</li>
</ol>
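<p>The coarse coding result can be illustrated with a toy Monte Carlo experiment (not from the paper; the 1D setup, radii, separation, and trial counts are all illustrative): nearby points are discriminated more reliably as the number of overlapping receptive fields grows.</p>

```python
import random

random.seed(1)

def code(x, centers, r):
    """Coarse code of a scalar x: which receptive fields (radius r) contain it."""
    return tuple(abs(x - c) <= r for c in centers)

def discrimination_rate(n_units, r, trials=2000, delta=0.01):
    """Fraction of nearby point pairs (separation delta) that receive
    distinct coarse codes -- a proxy for representational accuracy."""
    hits = 0
    for _ in range(trials):
        centers = [random.random() for _ in range(n_units)]
        x = random.uniform(0.1, 0.9)
        if code(x, centers, r) != code(x + delta, centers, r):
            hits += 1
    return hits / trials

# More overlapping units -> finer discrimination, as the scaling law predicts.
assert discrimination_rate(64, r=0.2) > discrimination_rate(4, r=0.2)
```

<p>With broad fields ($r = 0.2$ here), each extra unit adds another boundary that can separate the two points, which is the intuition behind the linear scaling in $n$.</p>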
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The following details are extracted from Section 5 (&ldquo;Implementing an Arbitrary Mapping&rdquo;) to facilitate reproduction of the &ldquo;Deep Dyslexia&rdquo; and &ldquo;Arbitrary Mapping&rdquo; simulation.</p>
<h3 id="data">Data</h3>
<p>The simulation uses synthetic data representing words and meanings.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic Grapheme/Sememe Pairs</td>
          <td>20 pairs</td>
          <td>20 different grapheme strings mapped to random semantic vectors.</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Input (Graphemes)</strong>: 30 total units.
<ul>
<li>Structure: Divided into 3 groups of 10 units each.</li>
<li>Encoding: Each &ldquo;word&rdquo; (3 letters) activates exactly 1 unit in each group (sparse binary).</li>
</ul>
</li>
<li><strong>Output (Sememes)</strong>: 30 total units.
<ul>
<li>Structure: Binary units.</li>
<li>Encoding: Meanings are random vectors where each unit is active with probability $p=0.2$.</li>
</ul>
</li>
</ul>
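<p>A minimal sketch of how this synthetic data could be regenerated (the function name and random seed are illustrative; the paper does not publish code):</p>

```python
import random

random.seed(0)

def make_dataset(n_words=20, groups=3, units_per_group=10, n_sememes=30, p=0.2):
    """Generate grapheme/sememe pairs in the style of Hinton's simulation."""
    data = []
    for _ in range(n_words):
        # Grapheme string: one active unit in each group of 10 (3-of-30 code).
        grapheme = [0] * (groups * units_per_group)
        for g in range(groups):
            grapheme[g * units_per_group + random.randrange(units_per_group)] = 1
        # Meaning: each of the 30 binary sememe units active with probability p.
        sememes = [1 if random.random() < p else 0 for _ in range(n_sememes)]
        data.append((grapheme, sememes))
    return data

pairs = make_dataset()
assert len(pairs) == 20
assert all(sum(g) == 3 for g, _ in pairs)  # exactly one unit per group
```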
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Learning Rule</strong>: The paper cites &ldquo;Hinton, Sejnowski &amp; Ackley (1984)&rdquo; (Boltzmann Machines) for the specific learning algorithm used to set weights.</li>
<li><strong>False Positive Analysis</strong>: The probability $f$ that a semantic feature is incorrectly activated is derived as:</li>
</ul>
<p>$$f = (1 - (1-p)^{(w-1)})^u$$</p>
<p>Where:</p>
<ul>
<li>$p$: Probability of a sememe being in a word meaning ($0.2$).</li>
<li>$w$: Number of words in a &ldquo;word-set&rdquo; (cluster).</li>
<li>$u$: Number of active &ldquo;word-set&rdquo; units per word.</li>
</ul>
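<p>The expression is a direct product of per-unit failure probabilities and can be transcribed literally (the sample values for $w$ and $u$ below are illustrative, not taken from the paper):</p>

```python
def false_positive_prob(p=0.2, w=5, u=3):
    """Probability f that a sememe outside a word's meaning is wrongly
    activated: f = (1 - (1 - p)^(w - 1))^u. A false positive requires
    every one of the u active word-set units to contain at least one of
    the other w - 1 words whose meaning includes that sememe."""
    return (1.0 - (1.0 - p) ** (w - 1)) ** u

# Larger word-sets (w) raise the false-positive rate; requiring more
# active word-set units (u) suppresses it.
assert false_positive_prob(w=10) > false_positive_prob(w=5)
assert false_positive_prob(u=6) < false_positive_prob(u=3)
```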
<h3 id="models">Models</h3>
<p>The simulation uses a specific three-layer architecture.</p>
<table>
  <thead>
      <tr>
          <th>Layer</th>
          <th>Type</th>
          <th>Count</th>
          <th>Connectivity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Input</strong></td>
          <td>Grapheme Units</td>
          <td>30</td>
          <td>Connected to all hidden (&ldquo;word-set&rdquo;) units (no direct link to Output).</td>
      </tr>
      <tr>
          <td><strong>Hidden</strong></td>
          <td>&ldquo;Word-Set&rdquo; Units</td>
          <td>20</td>
          <td>Fully connected to Input and Output.</td>
      </tr>
      <tr>
          <td><strong>Output</strong></td>
          <td>Sememe Units</td>
          <td>30</td>
          <td>Connected to all hidden (&ldquo;word-set&rdquo;) units. Includes lateral inhibition (implied for &ldquo;clean up&rdquo;).</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Weights</strong>: Binary/Integer logic in theoretical analysis, but &ldquo;stochastic&rdquo; weights in the Boltzmann simulation.</li>
<li><strong>Thresholds</strong>: Sememe units have variable thresholds dynamically adjusted to be slightly less than the number of active word-set units.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The simulation evaluated the robustness of the mapping.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (Clean)</strong></td>
          <td>99.9%</td>
          <td>N/A</td>
          <td>Correct pattern produced 99.9% of time after learning.</td>
      </tr>
      <tr>
          <td><strong>Lesion Error Rate</strong></td>
          <td>1.4%</td>
          <td>N/A</td>
          <td>140 errors in 10,000 tests after removing 1 word-set unit.</td>
      </tr>
      <tr>
          <td><strong>Semantic Errors</strong></td>
          <td>~60% of errors</td>
          <td>N/A</td>
          <td>83 of the 140 errors were &ldquo;Deep Dyslexia&rdquo; errors (producing a valid but wrong semantic pattern).</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Minimal. The original simulation ran on 1980s hardware (likely VAX-11 or similar).</li>
<li><strong>Replication</strong>: Reproducible on any modern CPU in milliseconds.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hinton, G. E. (1984). Distributed Representations. <em>Technical Report CMU-CS-84-157</em>. Carnegie-Mellon University.</p>
<p><strong>Publication</strong>: CMU Computer Science Department Technical Report, October 1984</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{hinton1984distributed,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Distributed representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hinton, Geoffrey E}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1984}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span>=<span style="color:#e6db74">{Carnegie-Mellon University}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemical Machine Vision</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemical-machine-vision/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/chemical-machine-vision/</guid><description>Machine vision approach using Gabor wavelets and Kohonen networks to classify chemical raster images and extract structural metadata.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Gkoutos, G. V., Rzepa, H., Clark, R. M., Adjei, O., &amp; Johal, H. (2003). Chemical Machine Vision: Automated Extraction of Chemical Metadata from Raster Images. <em>Journal of Chemical Information and Computer Sciences</em>, 43(5), 1342-1355. <a href="https://doi.org/10.1021/ci034017n">https://doi.org/10.1021/ci034017n</a></p>
<p><strong>Publication</strong>: J. Chem. Inf. Comput. Sci. 2003</p>
<h2 id="paper-classification-methodological-approach">Paper Classification: Methodological Approach</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel architectural pipeline applying &ldquo;machine vision&rdquo; techniques (Gabor wavelets and Kohonen networks) to the problem of identifying chemical diagrams in low-resolution raster images. The paper focuses on the &ldquo;how&rdquo; (the algorithm and its parameters) and validates the method through quantitative experiments optimizing feature vectors and masks.</p>
<h2 id="motivation-extracting-legacy-chemical-data">Motivation: Extracting Legacy Chemical Data</h2>
<p>The primary motivation is to unlock the &ldquo;large amount of data&rdquo; trapped in legacy raster images (GIF, JPEG) on the Web that lack semantic metadata.</p>
<ul>
<li><strong>Legacy Data Problem</strong>: Most chemical structural information on the Web is embedded in raster images, not machine-readable formats like Molfiles.</li>
<li><strong>Limitations of Existing Tools</strong>: Previous tools like Kekule and CLiDE acted as &ldquo;Chemical OCR,&rdquo; attempting to reconstruct exact atom-bond connections. This required high-resolution images (&gt;300 dpi) and human intervention, making them unsuitable for automated Web crawling of low-resolution (72-96 dpi) images.</li>
<li><strong>Goal</strong>: To create a low-cost, automated tool for a &ldquo;robot-based Internet resource discovery tool&rdquo; that can classify images (e.g., &ldquo;is this a molecule?&rdquo;).</li>
</ul>
<h2 id="core-innovation-texture-recognition-over-structural-ocr">Core Innovation: Texture Recognition over Structural OCR</h2>
<p>The core novelty is the shift from &ldquo;Optical Character Recognition&rdquo; (exact reconstruction) to <strong>&ldquo;Texture Recognition&rdquo;</strong> (classification).</p>
<ul>
<li><strong>Texture-Based Approach</strong>: The authors treat chemical diagrams as textures. They use <strong>Gabor wavelets</strong> to extract texture features. <strong>Crucially, this system does not recognize specific chemical structures</strong> (i.e., atom-bond connectivity tables, <a href="/notes/computational-chemistry/molecular-representations/smiles/">SMILES</a>, or Molfiles). It only classifies images into broad categories.</li>
<li><strong>Incremental Learning</strong>: The system uses a <strong>Kohonen Self-Organizing Feature Map (KSOFM)</strong> combined with Class Boundary Analysis (CBA). This allows for &ldquo;incremental learning,&rdquo; where new classes (e.g., aromatic vs. non-aromatic) can be added without retraining the entire system.</li>
<li><strong>Optimization for Chemistry</strong>: The authors identify specific parameters (frequency channels, mask sizes) that are optimal for the &ldquo;texture&rdquo; of chemical diagrams.</li>
<li><strong>Integration with ChemDig</strong>: The method was designed to feed into ChemDig, a robot-based index engine for automated web crawling and metadata generation.</li>
</ul>
<h2 id="experimental-setup-parameter-optimization">Experimental Setup: Parameter Optimization</h2>
<p>The authors performed optimization and validation experiments using a dataset of <strong>300 images</strong> divided into three classes: Ring Systems, Non-Ring Systems, and Non-Chemistry (textures, biological figures, etc.).</p>
<ol>
<li><strong>Parameter Optimization</strong>: They systematically varied hyperparameters to find the optimal configuration:
<ul>
<li><strong>Feature Vector Size</strong>: Tested sizes from 100 to 4000 elements.</li>
<li><strong>Energy Mask Size</strong>: Tested windows from $3 \times 3$ to $15 \times 15$ pixels.</li>
<li><strong>Frequency Channels</strong>: Tested seven spatial frequencies ($\sqrt{2}$ to $64\sqrt{2}$).</li>
</ul>
</li>
<li><strong>Classification Performance</strong>: Evaluated the system&rsquo;s ability to classify unseen test images using a 50:50 training/test split.</li>
<li><strong>Comparison</strong>: Qualitatively compared the approach against vectorization tools (Autotrace, CR2V).</li>
</ol>
<h2 id="results-robust-classification-of-low-resolution-images">Results: Robust Classification of Low-Resolution Images</h2>
<ul>
<li><strong>Optimal Configuration</strong>: The system performed best with a feature vector size of ~1500 elements, a $9 \times 9$ energy mask, and frequency channel $4\sqrt{2}$.</li>
<li><strong>High Accuracy</strong>: Achieved a recognition rate of <strong>91%</strong> with a 50:50 training/test split, and up to <strong>92%</strong> with a 70:30 split.</li>
<li><strong>Robustness</strong>: The system successfully distinguished between chemical and non-chemical images (zero false negatives for chemical images).</li>
<li><strong>Limitations</strong>: Misclassifications occurred between &ldquo;ring&rdquo; and &ldquo;non-ring&rdquo; systems when structures had similar visual &ldquo;textures&rdquo; (e.g., similar density or layout).</li>
<li><strong>Impact</strong>: The method is viable for automating metadata generation (e.g., <code>alt</code> tags) for web crawlers, functioning as a coarse-grained filter before more expensive processing.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used a custom dataset of raster images collected from the Web.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td><strong>Custom Web Dataset</strong></td>
          <td>300 images</td>
          <td>Split into 3 classes: Ring Systems, Non-Ring Systems, Non-Chemistry.</td>
      </tr>
      <tr>
          <td>Resolution</td>
          <td><strong>Low-Res Web Images</strong></td>
          <td>72-96 dpi</td>
          <td>Deliberately chosen to mimic Web conditions where OCR fails.</td>
      </tr>
      <tr>
          <td>Format</td>
          <td><strong>Raster</strong></td>
          <td>GIF, JPEG</td>
          <td>Typical web formats.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The core pipeline consists of a <strong>Gabor Transform Unit</strong> followed by a <strong>Training/Classification Unit</strong>.</p>
<ul>
<li><strong>Gabor Wavelets</strong>: Used for feature extraction. The 2D Gabor wavelet equation is:
$$h(x,y)=\exp\left\{-\frac{1}{2}\left[\frac{x^{2}}{\sigma_{x}^{2}}+\frac{y^{2}}{\sigma_{y}^{2}}\right]\right\}\cos(2\pi\mu_{\sigma}x+\phi)$$
<ul>
<li><strong>Bank Structure</strong>: 28 filters total (4 orientations $\times$ 7 radial frequencies).</li>
<li><strong>Orientations</strong>: $0^{\circ}, 45^{\circ}, 90^{\circ}, 135^{\circ}$.</li>
<li><strong>Frequencies</strong>: one octave apart, from $\sqrt{2}$ to $64\sqrt{2}$.</li>
<li><strong>Selected Frequency</strong>: $4\sqrt{2}$ was found to be optimal for chemistry.</li>
</ul>
</li>
<li><strong>Preprocessing</strong>:
<ul>
<li><strong>Buffer Mounting</strong>: Images are mounted in a buffer (set to 0) to handle edge artifacts.</li>
<li><strong>Look-Up-Tables (LUT/LUF)</strong>: A binary Look-Up-Frame (LUF) indicates Regions of Interest (ROI) to avoid computing empty space; values are stored in a Look-Up-Table (LUT) to prevent re-computation of overlapping windows.</li>
</ul>
</li>
<li><strong>Feature Extraction</strong>:
<ul>
<li><strong>Non-linear Thresholding</strong>: $\psi(t) = \tanh(\alpha t)$ with $\alpha = 0.25$.</li>
<li><strong>Energy Function</strong>: Calculated as average absolute deviation from the mean using a window $W_{xy}$.
$$e_{k}(x,y)=\frac{1}{M^{2}}\sum_{(a,b)\in W_{xy}}|\psi(r_{k}(a,b))|$$</li>
<li><strong>Optimal Window</strong>: $9 \times 9$ pixels.</li>
</ul>
</li>
</ul>
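<p>The equations above can be transcribed directly (a minimal sketch, assuming the energy window stays inside the response grid; grid size and filter parameters are illustrative):</p>

```python
import math

ALPHA = 0.25  # thresholding constant alpha from the paper

def gabor(x, y, sigma_x, sigma_y, freq, phase=0.0):
    """2D Gabor wavelet h(x, y): Gaussian envelope times a cosine carrier."""
    envelope = math.exp(-0.5 * (x ** 2 / sigma_x ** 2 + y ** 2 / sigma_y ** 2))
    return envelope * math.cos(2.0 * math.pi * freq * x + phase)

def psi(t, alpha=ALPHA):
    """Non-linear thresholding of a filter response."""
    return math.tanh(alpha * t)

def energy(response, cx, cy, m=9):
    """Mean of |psi(r)| over an m x m window centred at (cx, cy);
    m = 9 is the optimal window size found in the paper."""
    half = m // 2
    total = 0.0
    for x in range(cx - half, cx + half + 1):
        for y in range(cy - half, cy + half + 1):
            total += abs(psi(response[x][y]))
    return total / (m * m)

# Constant response of 1.0 -> energy is exactly |tanh(0.25)|.
flat = [[1.0] * 20 for _ in range(20)]
assert abs(energy(flat, 10, 10) - math.tanh(0.25)) < 1e-12
```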
<h3 id="models">Models</h3>
<p>The classification model relies on competitive learning.</p>
<ul>
<li><strong>Architecture</strong>: <strong>Kohonen Self-Organizing Feature Map (KSOFM)</strong>.</li>
<li><strong>Training</strong>:
<ul>
<li><strong>Learning Rate</strong>: Starts at 1.0, decreases to 0.1.</li>
<li><strong>Class Boundary Analysis (CBA)</strong>: Computes the centroid (mean) and variance of each cluster. The variance defines the class boundary.</li>
</ul>
</li>
<li><strong>Classification Metric</strong>: <strong>Euclidean Distance Norm</strong>. An unknown vector is classified based on the shortest distance to a cluster center, provided it falls within the variance boundary.
$$D_{ij}=||x_{i}-x_{j}||$$</li>
</ul>
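<p>For concreteness, the distance-plus-boundary decision rule can be sketched as follows (a minimal sketch in which a single scalar radius stands in for the variance-derived class boundary; the cluster labels, centroids, and test points are illustrative):</p>

```python
import math

def euclidean(a, b):
    """Euclidean distance norm D_ij = ||x_i - x_j||."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))

def classify(x, clusters):
    """Assign x to the nearest centroid whose boundary it falls inside;
    return None if it lies outside every class boundary (rejected)."""
    best_label, best_dist = None, float("inf")
    for label, (centroid, radius) in clusters.items():
        d = euclidean(x, centroid)
        if d <= radius and d < best_dist:
            best_label, best_dist = label, d
    return best_label

# Two toy clusters standing in for "ring" and "non-ring" feature centroids.
clusters = {
    "ring": ([0.0, 0.0], 1.0),
    "non-ring": ([3.0, 3.0], 1.0),
}
assert classify([0.2, 0.1], clusters) == "ring"
assert classify([10.0, 10.0], clusters) is None  # outside all boundaries
```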
<h3 id="evaluation">Evaluation</h3>
<p>Performance was measured using recognition rate ($R_s$) and misclassification error ($E_s$).</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Recognition Rate</td>
          <td><strong>91%</strong></td>
          <td>N/A</td>
          <td>Achieved with 50:50 split. 92% with 70:30 split.</td>
      </tr>
      <tr>
          <td>Feature Size</td>
          <td><strong>~1500</strong></td>
          <td>4000</td>
          <td>Reducing vector size from 4000 to 1500 maintained ~80% accuracy while improving speed.</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{gkoutosChemicalMachineVision2003,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Chemical {{Machine Vision}}: {{Automated Extraction}} of {{Chemical Metadata}} from {{Raster Images}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Chemical {{Machine Vision}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Gkoutos, Georgios V. and Rzepa, Henry and Clark, Richard M. and Adjei, Osei and Johal, Harpal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2003</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Computer Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{43}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1342--1355}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{0095-2338}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/ci034017n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>IWAE: Importance Weighted Autoencoders</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</guid><description>Summary of Burda, Grosse &amp; Salakhutdinov's ICLR 2016 paper introducing Importance Weighted Autoencoders for tighter variational bounds</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a generative model that shares the same architecture as the Variational Autoencoder (VAE) but uses a different, tighter objective function. The key innovation is using importance weighting to derive a strictly tighter log-likelihood lower bound than the standard VAE objective.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The standard VAE has several limitations that motivated this work:</p>
<ul>
<li><strong>Strong assumptions</strong>: VAEs typically assume the posterior distribution is simple (e.g., approximately factorial) and that its parameters can be easily approximated from observations.</li>
<li><strong>Simplified representations</strong>: The VAE objective can force models to learn overly simplified representations that underutilize the network&rsquo;s full modeling capacity.</li>
<li><strong>Harsh penalization</strong>: The VAE objective harshly penalizes approximate posterior samples that are poor explanations for the data, which can be overly restrictive.</li>
<li><strong>Inactive units</strong>: VAEs tend to learn latent spaces with effective dimensions far below their capacity, where many latent units are ignored (a phenomenon later termed <strong>posterior collapse</strong>, where the approximate posterior collapses to the prior and conveys no information). The authors wanted to investigate whether a new objective could address this issue.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>IWAE objective function</strong>, denoted as $\mathcal{L}_{k}$.</p>
<ul>
<li>
<p><strong>VAE ($\mathcal{L}_{1}$ Bound)</strong>: The standard VAE maximizes $\mathcal{L}(x)=\mathbb{E}_{q(h|x)}[\log\frac{p(x,h)}{q(h|x)}]$. This is equivalent to the new bound when $k=1$.</p>
</li>
<li>
<p><strong>IWAE ($\mathcal{L}_{k}$ Bound)</strong>: The IWAE maximizes a tighter bound that uses $k$ samples drawn from the recognition model $q(h|x)$:</p>
</li>
</ul>
<p>$$\mathcal{L}_{k}(x)=\mathbb{E}_{h_{1},\ldots,h_{k}\sim q(h|x)}\left[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,h_{i})}{q(h_{i}|x)}\right]$$</p>
<ul>
<li>
<p><strong>Tighter Bound</strong>: The authors prove that this bound is always tighter than or equal to the VAE bound ($\mathcal{L}_{k+1} \geq \mathcal{L}_{k}$) and that as $k$ approaches infinity, $\mathcal{L}_{k}$ approaches the true log-likelihood $\log p(x)$.</p>
</li>
<li>
<p><strong>Increased Flexibility</strong>: Using multiple samples gives the IWAE additional flexibility to learn generative models whose posterior distributions are complex and violate the VAE&rsquo;s simplifying assumptions.</p>
</li>
</ul>
<h3 id="key-concept-averaging-inside-vs-outside-the-log">Key Concept: Averaging Inside vs. Outside the Log</h3>
<p>A crucial distinction exists between how VAE and IWAE utilize $k$ samples: in IWAE, increasing $k$ tightens the bound itself, whereas in VAE it merely reduces the variance of the gradient estimator.</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-vae-vs-importance-weighted-autoencoder-iwae.webp"
         alt="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         title="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">VAE vs IWAE computation flow. The key difference is where the log operation occurs: VAE computes log(w_i) for each sample then averages. IWAE averages the weights first then applies log to the result.</figcaption>
    
</figure>

<p><strong>VAE (Average of Logs):</strong></p>
<p>For a VAE, using $k$ samples approximates:</p>
<p>$$\mathbb{E}\left[ \frac{1}{k} \sum_{i=1}^k \log w_i \right] \approx \text{ELBO}$$</p>
<p>where $w_i = p(x, h_i) / q(h_i | x)$. Increasing $k$ here only reduces the variance of the gradient estimator; the model still targets the same ELBO bound, so performance gains saturate quickly.</p>
<p><strong>IWAE (Log of Average):</strong></p>
<p>IWAE performs the averaging <em>inside</em> the logarithm:</p>
<p>$$\mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] = \mathcal{L}_k$$</p>
<p>By Jensen&rsquo;s Inequality ($\log(\mathbb{E}[X]) \geq \mathbb{E}[\log(X)]$ for concave functions), this bound is mathematically guaranteed to be at least as tight as the VAE bound. Each increase in $k$ defines a new, strictly tighter lower bound on the log-likelihood.</p>
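<p>The two estimators differ only in where the average is taken, which is easy to verify numerically; a minimal sketch (the log-weights are simulated draws, not outputs of an actual model):</p>

```python
import math
import random

random.seed(0)

def vae_bound(log_weights):
    """k-sample VAE estimate: the average of log-weights, (1/k) sum_i log w_i."""
    return sum(log_weights) / len(log_weights)

def iwae_bound(log_weights):
    """IWAE estimate L_k = log((1/k) sum_i w_i), computed stably in log space."""
    k = len(log_weights)
    m = max(log_weights)
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights)) - math.log(k)

# Jensen's inequality holds pointwise: the log of the average never falls
# below the average of the logs, whatever the sampled weights are.
log_w = [random.gauss(-90.0, 5.0) for _ in range(50)]
assert iwae_bound(log_w) >= vae_bound(log_w)
```

<p>The log-sum-exp trick matters in practice: raw weights $w_i = p(x, h_i)/q(h_i|x)$ underflow to zero at realistic log-likelihood scales, so both bounds are computed from log-weights.</p>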
<p><strong>Why This Matters for Gradients:</strong></p>
<p>In IWAE, the gradient weights are normalized importance weights $\tilde{w}_i = w_i / \sum_j w_j$. This means &ldquo;bad&rdquo; samples (those with low $w_i$) contribute very little to the gradient update since they vanish from the weighted sum. VAE uses unweighted samples, so a single sample with extremely low probability produces a massive negative log value that can dominate the loss and harshly penalize the model. IWAE&rsquo;s formulation allows the model to focus learning on the samples that explain the data well.</p>
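<p>The normalized weights $\tilde{w}_i$ amount to a softmax over the log-weights; a toy sketch (the log-weight values are illustrative) shows how poor samples drop out of the gradient:</p>

```python
import math

def normalized_weights(log_weights):
    """IWAE gradient weights: a softmax over the log importance weights."""
    m = max(log_weights)
    exps = [math.exp(lw - m) for lw in log_weights]
    total = sum(exps)
    return [e / total for e in exps]

# One plausible sample and four very poor ones: the poor samples all but
# vanish from the weighted gradient instead of dominating the loss.
w_tilde = normalized_weights([-10.0, -50.0, -50.0, -50.0, -50.0])
assert w_tilde[0] > 0.999
assert abs(sum(w_tilde) - 1.0) < 1e-9
```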
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors compared VAE and IWAE on density estimation tasks using the MNIST and Omniglot datasets. They evaluated two main network architectures: one with a single stochastic layer and another with two stochastic layers. The models were trained with varying numbers of importance samples ($k \in \{1, 5, 50\}$) to observe the effect on performance and latent space utilization. The primary metrics for evaluation were the test log-likelihood (estimated using 5000 samples) and the number of &ldquo;active&rdquo; latent units, which quantifies the richness of the learned representations.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-vs-vae-active-latent-units-comparison.webp"
         alt="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         title="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Active latent units for VAE vs IWAE (1 stochastic layer). VAE active units remain flat. IWAE increases with k. Data from Table 1 of Burda et al. (2016).</figcaption>
    
</figure>

<ul>
<li>
<p><strong>Better Performance</strong>: IWAE achieved significantly higher log-likelihoods than VAEs, and its performance improved as $k$ increased; VAE performance benefited only slightly from using more samples ($k&gt;1$).</p>
</li>
<li>
<p><strong>Richer Representations</strong>: In all experiments with $k&gt;1$, IWAE learned more active latent dimensions than VAE, suggesting richer latent representations.</p>
</li>
<li>
<p><strong>Objective Drives Representation</strong>: The authors found that latent dimension inactivation is driven by the objective function. They demonstrated this through an &ldquo;objective swap&rdquo; experiment:</p>
</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-objective-swap-experiment.webp"
         alt="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         title="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Objective swap experiment on MNIST (1 stochastic layer). Switching a trained VAE to the IWAE objective improves both metrics. Switching IWAE to VAE degrades them. Data from Table 2 of Burda et al. (2016).</figcaption>
    
</figure>

<p>This experiment provides causal evidence that the objective function itself determines latent utilization:</p>
<ul>
<li><strong>VAE → IWAE</strong>: A converged VAE model, when fine-tuned with the IWAE objective ($k=50$), gained 3 active units (19 → 22) and improved test NLL from 86.76 to 84.88.</li>
<li><strong>IWAE → VAE</strong>: A converged IWAE model fine-tuned with the VAE objective lost 2 active units (25 → 23) and worsened test NLL from 84.78 to 86.02.</li>
</ul>
<p>The symmetry of these results rules out explanations based on optimization dynamics, initialization, or architecture. The objective function is the determining factor.</p>
<ul>
<li><strong>Conclusion</strong>: IWAEs learn richer latent representations and achieve better generative performance than VAEs with equivalent architectures and training time.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: $28 \times 28$ binarized handwritten digits (60,000 training / 10,000 test).</li>
<li><strong>Omniglot</strong>: $28 \times 28$ binarized handwritten characters from various alphabets (24,345 training / 8,070 test).</li>
<li><strong>Binarization</strong>: Dynamic sampling where binary values are sampled with expectations equal to the real pixel intensities (following Salakhutdinov &amp; Murray, 2008).</li>
<li><strong>Fixed Binarization</strong>: Results on a fixed binarization of MNIST (Larochelle, 2011) confirm that IWAE outperforms VAE across preprocessing methods, though this setting exhibits notably more overfitting than dynamic sampling.</li>
</ul>
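<p>Dynamic binarization can be sketched as follows (pure Python for clarity; in practice one would call <code>torch.bernoulli</code> on a batch of intensities):</p>

```python
import random

def dynamic_binarize(pixels, rng=random):
    """Each time an image is used, resample binary pixels whose
    expectation equals the real-valued intensity
    (Salakhutdinov & Murray, 2008)."""
    return [1.0 if rng.random() < p else 0.0 for p in pixels]

random.seed(0)
x = [0.0, 0.25, 0.5, 1.0]   # toy pixel intensities in [0, 1]
x_bin = dynamic_binarize(x)

assert all(v in (0.0, 1.0) for v in x_bin)
# Intensities of exactly 0 and 1 binarize deterministically.
assert x_bin[0] == 0.0 and x_bin[-1] == 1.0
```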
<h3 id="models">Models</h3>
<p>Two main network architectures were tested:</p>
<ol>
<li>One stochastic layer (50 units) with two deterministic layers (200 units each).</li>
<li>Two stochastic layers (100 and 50 units) with deterministic layers in between.</li>
</ol>
<ul>
<li><strong>Activations</strong>: <code>tanh</code> for deterministic layers; <code>exp</code> applied to variance predictions to ensure positivity.</li>
<li><strong>Distributions</strong>: Gaussian latent layers with diagonal covariance; Bernoulli observation layer.</li>
<li><strong>Initialization</strong>: Glorot &amp; Bengio (2010) heuristic.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-4}$).</li>
<li><strong>Batch Size</strong>: 20.</li>
<li><strong>Learning Rate Schedule</strong>: Annealed rate of $0.001 \cdot 10^{-i/7}$ for $3^i$ epochs (where $i=0&hellip;7$), totaling 3,280 passes over the data.</li>
<li><strong>Variance Control</strong>: A common concern with importance sampling is high variance. The authors prove that the Mean Absolute Deviation of their estimator is bounded by $2 + 2\delta$, where $\delta$ is the gap between the bound and true log-likelihood. As the bound tightens, variance remains controlled.</li>
<li><strong>Computational trick</strong>: Despite using $k$ samples per update, IWAE requires only $k$ forward passes and a single backward pass over all $k$ samples (rather than $k$ separate backward passes). The normalized importance weights $\tilde{w}_i$ serve as per-sample multipliers, so the backward cost is still $O(k)$ but avoids the overhead of $k$ separate gradient accumulations.</li>
</ul>
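<p>The role of the normalized weights as per-sample multipliers can be checked directly with autograd (a minimal sketch with toy log-weights, not the authors' code): the gradient of $\log\text{-}\text{sum}\text{-}\text{exp}$ with respect to each log-weight is exactly the softmax, i.e. $\tilde{w}_i$.</p>

```python
import math
import torch

# One input, k = 4 samples: the IWAE bound is logsumexp(log_w) - log(k).
log_w = torch.tensor([-1.0, -2.0, -3.0, -10.0], requires_grad=True)
bound = torch.logsumexp(log_w, dim=0) - math.log(4)
bound.backward()   # a single backward pass over all k samples

# The gradient w.r.t. each log-weight equals the normalized
# importance weight w~_i = softmax(log_w)_i.
w_tilde = torch.softmax(log_w.detach(), dim=0)
assert torch.allclose(log_w.grad, w_tilde)
```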
<p><strong>Relationship to Reweighted Wake-Sleep (RWS):</strong> Both IWAE and Reweighted Wake-Sleep (Bornschein &amp; Bengio, 2015) use importance-weighted samples and have closely related generative model updates. The key difference is that IWAE derives a single unified lower bound $\mathcal{L}_k$ and uses the reparameterization trick to train the recognition network jointly. RWS instead uses separate wake and sleep phases for the recognition network, which are not derived from $\mathcal{L}_k$.</p>
<h3 id="evaluation">Evaluation</h3>
<ol>
<li><strong>Test Log-Likelihood</strong>: Primary measure of generative performance, estimated as the mean of $\mathcal{L}_{5000}$ (5000 samples) on the test set.</li>
<li><strong>Active Units</strong>: To quantify latent space richness, the authors measured &ldquo;active&rdquo; latent dimensions. A unit $u$ was defined as active if its activity statistic $A_{u}=\text{Cov}_{x}(\mathbb{E}_{u\sim q(u|x)}[u])$ exceeded $10^{-2}$. The $10^{-2}$ threshold is justified by a bimodal distribution of the log activity statistic, showing clear separation between active and inactive units.</li>
</ol>
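<p>The active-units metric can be sketched like this (assuming a matrix of posterior means, one row per test input; the function name is mine):</p>

```python
import torch

def count_active_units(mu: torch.Tensor, threshold: float = 1e-2) -> int:
    """mu: [N, latent_dim] posterior means E_{q(u|x)}[u] for N inputs.
    A unit is 'active' if the variance of its posterior mean across
    the dataset exceeds the threshold."""
    a_u = mu.var(dim=0)          # Cov_x of the posterior means, per unit
    return int((a_u > threshold).sum())

# Toy check: one unit whose mean varies with x, one collapsed near 0.
mu = torch.stack([torch.linspace(-1, 1, 100),
                  1e-3 * torch.randn(100)], dim=1)
assert count_active_units(mu) == 1
```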
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: GPU-based implementation using mini-batch replication to parallelize the $k$ samples.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Burda, Y., Grosse, R., &amp; Salakhutdinov, R. (2016). Importance Weighted Autoencoders. <em>International Conference on Learning Representations (ICLR) 2016</em>. <a href="https://arxiv.org/abs/1509.00519">https://arxiv.org/abs/1509.00519</a></p>
<p><strong>Publication</strong>: ICLR 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{burda2016importance,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Importance Weighted Autoencoders}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yuri Burda and Roger Grosse and Ruslan Salakhutdinov}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1509.00519}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1509.00519">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples does help, but only when paired with the right objective: simply averaging the standard loss over $k$ samples yields minimal gains, while changing the objective itself unlocks the real benefit. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a fundamentally more powerful objective.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and this mathematical difference is profound.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1..h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
<p>This means the IWAE is optimizing a <strong>strictly tighter lower bound</strong> on the true log-likelihood of the data.</p>
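<p>This inequality is easy to verify numerically (toy weights, chosen only for illustration):</p>

```python
import math

# Illustrative importance weights for k = 5 samples.
w = [0.5, 1.0, 2.0, 0.1, 4.0]

avg_of_logs = sum(math.log(wi) for wi in w) / len(w)   # multi-sample VAE
log_of_avg = math.log(sum(w) / len(w))                 # IWAE

# Jensen's inequality for the concave log.
assert log_of_avg >= avg_of_logs
```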
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two huge practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $p(x)$, optimizing it pushes the model to learn a much better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to the &ldquo;dead units&rdquo; problem.</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat <code>x</code> to get <code>x_repeated</code> of shape <code>[B*k, D]</code>. We do all our forward passes on this large tensor.</p>
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a Gaussian</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, torch.zeros_like(z_samples), torch.zeros_like(z_samples)) <span style="color:#75715e"># N(0, I) prior: mu=0, logvar=0</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code>, the average of individual losses</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code>, the log of averaged weights</li>
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To fix a &ldquo;dead unit&rdquo; problem or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is often &ldquo;yes&rdquo;. You get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The next time you reach for more samples in your VAE, switch to the IWAE objective so that the extra computational cost of $k &gt; 1$ actually pays off.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires computing exact log-probabilities to evaluate the specific samples. The result is a fundamentally more powerful model using the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>Auto-Encoding Variational Bayes: VAE Paper Summary</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</guid><description>Summary of Kingma &amp; Welling's foundational VAE paper introducing the reparameterization trick and variational autoencoders.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a foundational <strong>Method</strong> paper that introduces a novel generative mechanism (the VAE) and a new optimization technique (the reparameterization trick), supported by strong theoretical derivation. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the <strong>variational auto-encoder (VAE)</strong> when neural networks are used as the recognition model.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors address two central intractabilities in directed probabilistic models with continuous latent variables:</p>















<figure class="post-figure center ">
    <img src="/img/notes/autoencoding-variational-bayes-figure-1-model-diagram.webp"
         alt="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         title="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Figure 1 from the paper: The directed graphical model. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$, dashed lines denote the variational approximation $q_\phi(z|x)$. The variational parameters $\phi$ are learned jointly with the generative parameters $\theta$.</figcaption>
    
</figure>

<ol>
<li>
<p><strong>Intractable Posteriors</strong>: In models with continuous latent variables (like those with non-linear hidden layers), the true posterior $p_{\theta}(z|x)$ cannot be calculated analytically, preventing the use of standard EM algorithms.</p>
</li>
<li>
<p><strong>Large Datasets</strong>: Sampling-based solutions like Monte Carlo EM (MCEM) require expensive sampling loops per datapoint. This makes them too slow for large datasets where batch optimization is too costly and efficient minibatch updates are required.</p>
</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<h3 id="the-reparameterization-trick-sgvb-estimator">The Reparameterization Trick (SGVB Estimator)</h3>
<p>The core innovation is the <strong>Stochastic Gradient Variational Bayes (SGVB)</strong> estimator. The authors solve the high variance of standard gradient estimation by &ldquo;reparameterizing&rdquo; the random variable $\tilde{z}$.</p>
<p>They express $z$ as a deterministic function of the input $x$ and an auxiliary noise variable $\epsilon$:</p>
<p>$$\tilde{z} = g_{\phi}(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-reparameterization-trick.webp"
         alt="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         title="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reparameterization trick. (A) Standard stochastic nodes block gradient flow during backpropagation. (B) By expressing $z = \mu + \sigma \odot \epsilon$ with external noise $\epsilon \sim \mathcal{N}(0,1)$, gradients can flow through the deterministic path to the parameters $\phi$.</figcaption>
    
</figure>

<ul>
<li><strong>Mechanism</strong>: For a Gaussian posterior, $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.</li>
<li><strong>Impact</strong>: This makes the Monte Carlo estimate differentiable with respect to the variational parameters $\phi$, allowing the variational lower bound to be optimized via standard stochastic gradient ascent (like SGD or Adagrad).</li>
</ul>
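<p>A minimal PyTorch sketch of this mechanism (illustrative shapes, not the paper's code): the noise is drawn outside the computation graph, so the sample is a differentiable function of the variational parameters.</p>

```python
import torch

# Variational parameters of q(z|x) for one datapoint.
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

eps = torch.randn(3)                     # external noise, eps ~ N(0, I)
z = mu + torch.exp(0.5 * logvar) * eps   # deterministic in (mu, logvar)

z.sum().backward()
# Gradients flow through the sample back to the variational parameters.
assert mu.grad is not None and logvar.grad is not None
assert torch.allclose(mu.grad, torch.ones(3))   # dz_i/dmu_i = 1
```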
<h3 id="the-aevb-algorithm-the-vae">The AEVB Algorithm (The VAE)</h3>
<p>The <strong>Auto-Encoding VB (AEVB)</strong> algorithm amortizes inference by learning a global recognition model (encoder) $q_{\phi}(z|x)$ jointly with the generative model (decoder) $p_{\theta}(x|z)$.</p>
<p><strong>Objective Function</strong>: Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x^{(i)})$:</p>
<p>$$\mathcal{L} \simeq -D_{KL}(q_\phi(z|x^{(i)}) \| p_\theta(z)) + \frac{1}{L} \sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})$$</p>
<ul>
<li><strong>First Term (Regularizer)</strong>: Forces the approximate posterior to match the prior (integrated analytically for Gaussians).</li>
<li><strong>Second Term (Reconstruction Error)</strong>: The expected negative reconstruction error (estimated via sampling).</li>
</ul>
<p>This mirrors the standard auto-encoder objective, adding a variational regularizer.</p>
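<p>For Gaussian posterior and prior, the regularizer has the closed form $-D_{KL} = \frac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$ (Appendix B of the paper); a minimal sketch, with function naming mine:</p>

```python
import torch

def gaussian_kl_to_standard_normal(mu, logvar):
    """Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
    with mean mu and log-variance logvar (Appendix B of the paper)."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

# Sanity check: when q(z|x) = N(0, I), the KL term is exactly zero.
mu = torch.zeros(4, 20)
logvar = torch.zeros(4, 20)
kl = gaussian_kl_to_standard_normal(mu, logvar)
assert torch.allclose(kl, torch.zeros(4))
```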
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was benchmarked against the <strong>Wake-Sleep</strong> algorithm and <strong>Monte Carlo EM (MCEM)</strong> using the <strong>MNIST</strong> (digits) and <strong>Frey Face</strong> (continuous faces) datasets.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li>
<p><strong>Efficiency</strong>: AEVB converged faster and reached a better lower bound than Wake-Sleep (Figure 2). It scaled efficiently to the full MNIST dataset. MCEM&rsquo;s per-datapoint sampling cost made it impractical at full dataset scale, so comparisons were limited to small subsets (Figure 3).</p>
</li>
<li>
<p><strong>Regularization</strong>: The KL-divergence term provided a regularizing effect, preventing overfitting while increasing latent dimensions ($N_z$).</p>
</li>
<li>
<p><strong>Manifold Learning</strong>: The model successfully learned smooth 2D latent manifolds (visualized in Appendix A), grouping similar digits/faces together.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Evaluation Data</strong>: For the marginal likelihood comparison (Figure 3), the paper used MNIST with varying training set sizes (N_train = 100 and N_train = 5000) to compare data efficiency (marginal log-likelihood vs. training samples seen) across algorithms. A smaller network (100 hidden units, 3 latent variables) was used for this comparison because the marginal likelihood estimator only works reliably in low-dimensional latent spaces.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Algorithm</strong>: Stochastic gradient ascent with <strong>Adagrad</strong> (global stepsizes chosen from $\{0.01, 0.02, 0.1\}$).</li>
<li><strong>Regularization</strong>: The objective included a weight decay term corresponding to a prior $p(\theta)=\mathcal{N}(0,I)$.</li>
<li><strong>Minibatches</strong>: Size $M=100$ with $L=1$ sample per datapoint.</li>
<li><strong>Initialization</strong>: Parameters sampled from $\mathcal{N}(0, 0.01)$.</li>
</ul>
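<p>For concreteness, one Adagrad ascent step on the lower bound looks roughly like this (an illustrative sketch; the flat parameter layout and the stabilizing <code>eps</code> are assumptions, not the authors&rsquo; code):</p>

```python
def adagrad_ascent_step(params, grads, accum, stepsize=0.01, eps=1e-8):
    """One Adagrad step of gradient *ascent* on the lower bound.

    accum carries the running sum of squared gradients per parameter,
    so each parameter's effective stepsize shrinks over time.
    """
    new_params, new_accum = [], []
    for p, g, a in zip(params, grads, accum):
        a = a + g * g
        new_params.append(p + stepsize * g / (a ** 0.5 + eps))
        new_accum.append(a)
    return new_params, new_accum

params, accum = adagrad_ascent_step([0.0], [2.0], [0.0], stepsize=0.01)
# The first step moves by roughly stepsize * sign(gradient): 0.01 * 2 / sqrt(4)
assert abs(params[0] - 0.01) < 1e-6
```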
<h3 id="models">Models</h3>
<p>The original VAE used simple Multi-Layered Perceptrons (MLPs):</p>
<ul>
<li><strong>Symmetry</strong>: The encoder and decoder were symmetric, having an equal number of hidden units.</li>
<li><strong>Hidden Units</strong>: 500 units for MNIST, 200 for Frey Face (to prevent overfitting on the smaller dataset).</li>
<li><strong>Activations</strong>: <strong>Tanh</strong> activation functions for the hidden layers.</li>
<li><strong>Latent Space</strong>: Experimented with $N_z$ ranging from 2 to 200.</li>
<li><strong>Outputs</strong>:
<ul>
<li><em>MNIST</em>: <strong>Bernoulli</strong> MLP (sigmoid output).</li>
<li><em>Frey Face</em>: <strong>Gaussian</strong> MLP, with means constrained to $(0,1)$ via sigmoid.</li>
</ul>
</li>
<li><strong>Encoder Architecture</strong>: For the Gaussian encoder, the mean $\mu$ and log-variance $\log(\sigma^2)$ are linear outputs from the shared hidden layer (they share the hidden layer weights and have separate output weights).</li>
<li><strong>Log-Variance</strong>: The encoder predicted $\log(\sigma^2)$ for numerical stability.</li>
</ul>
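<p>The encoder head described above can be sketched as a pure-Python forward pass (weight names and shapes here are illustrative, not taken from the paper):</p>

```python
import math

def gaussian_encoder_forward(x, W_h, b_h, W_mu, b_mu, W_lv, b_lv):
    """Shared tanh hidden layer, then two separate linear heads producing
    mu and log(sigma^2). Weights are lists of lists: W_h is (H x D),
    the two heads are (Nz x H)."""
    dot = lambda row, vec: sum(w * v for w, v in zip(row, vec))
    h = [math.tanh(dot(row, x) + b) for row, b in zip(W_h, b_h)]
    mu = [dot(row, h) + b for row, b in zip(W_mu, b_mu)]
    log_var = [dot(row, h) + b for row, b in zip(W_lv, b_lv)]
    return mu, log_var

# With zero hidden weights, tanh(0) = 0, so each head returns its bias:
mu, log_var = gaussian_encoder_forward([2.0], [[0.0]], [0.0],
                                       [[1.0]], [0.5], [[1.0]], [-1.0])
assert mu == [0.5] and log_var == [-1.0]
```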
<h3 id="evaluation">Evaluation</h3>
<p>The paper distinguishes between two metrics:</p>
<ul>
<li><strong>Variational Lower Bound</strong>: Used as the training objective (what the model optimizes).</li>
<li><strong>Marginal Likelihood</strong>: Used for final evaluation (Figure 3). The true marginal likelihood $p_\theta(x)$ was estimated using an Importance Sampling estimator constructed from samples drawn via Hybrid Monte Carlo (HMC), as detailed in Appendix D. This estimator uses: $p_{\theta}(x^{(i)}) \simeq \left(\frac{1}{L}\sum_{l=1}^{L} \frac{q_\phi(z^{(l)})}{p_\theta(z^{(l)})\, p_\theta(x^{(i)} \mid z^{(l)})}\right)^{-1}$.</li>
</ul>
<p>This distinction is critical: the training metric (lower bound) differs from the evaluation metric (estimated marginal likelihood).</p>
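<p>The Appendix D estimator can be sanity-checked on a toy model where every density is Gaussian and the exact posterior is known in closed form (an illustrative check, not the paper&rsquo;s HMC protocol):</p>

```python
import math
import random

def normal_pdf(x, mu, var):
    return math.exp(-(x - mu) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def marginal_likelihood_estimate(x, posterior_samples, q_pdf, prior_pdf, lik_pdf):
    """Appendix D importance-sampling estimator:
        p(x) ~= ( (1/L) * sum_l q(z_l) / (p(z_l) p(x|z_l)) )^(-1)
    where z_l are (approximate) posterior samples (drawn via HMC in the paper).
    """
    weights = [q_pdf(z) / (prior_pdf(z) * lik_pdf(x, z)) for z in posterior_samples]
    return len(weights) / sum(weights)

# Toy model: z ~ N(0, 1), x|z ~ N(z, 1), hence p(x) = N(x; 0, 2) and the exact
# posterior is z|x ~ N(x/2, 1/2). With q equal to the true posterior, every
# weight equals 1/p(x), so the estimator is exact for any sample count.
rng = random.Random(0)
x = 0.7
zs = [rng.gauss(x / 2, math.sqrt(0.5)) for _ in range(100)]
est = marginal_likelihood_estimate(
    x, zs,
    q_pdf=lambda z: normal_pdf(z, x / 2, 0.5),
    prior_pdf=lambda z: normal_pdf(z, 0.0, 1.0),
    lik_pdf=lambda xx, z: normal_pdf(xx, z, 1.0),
)
assert abs(est - normal_pdf(x, 0.0, 2.0)) < 1e-9
```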
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: Trained on a standard Intel Xeon CPU (approx. 40 GFLOPS); no GPUs were used.</li>
<li><strong>Training Time</strong>: Approximately 20-40 minutes per million training samples.</li>
</ul>
<h3 id="key-implementation-details-from-appendices">Key Implementation Details from Appendices</h3>
<ul>
<li><strong>Appendix A</strong>: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets.</li>
<li><strong>Appendix B</strong>: Closed-form solution for the KL divergence of two Gaussians, essential for implementing the efficient version of the estimator (Equation 10).</li>
<li><strong>Appendix C</strong>: Exact MLP equations, including the use of tanh hidden layers and specific output layers for Bernoulli vs. Gaussian data. Includes specifications for <strong>Bernoulli MLPs</strong> (binary data) and <strong>Gaussian MLPs</strong> (real-valued data).</li>
<li><strong>Appendix D</strong>: Marginal likelihood estimation protocol using Hybrid Monte Carlo (HMC) and importance sampling for evaluation (Figure 3).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Diederik P. Kingma and Max Welling. &ldquo;Auto-Encoding Variational Bayes.&rdquo; arXiv:1312.6114 [stat.ML], 2013. <a href="https://doi.org/10.48550/arXiv.1312.6114">https://doi.org/10.48550/arXiv.1312.6114</a></p>
<p><strong>Publication</strong>: arXiv preprint (originally 2013)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kingma2022autoencodingvariationalbayes,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Variational Bayes}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Diederik P Kingma and Max Welling}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2013}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1312.6114}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{stat.ML}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1312.6114}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Variational_autoencoder">Wikipedia: Variational Autoencoder</a> - General overview</li>
<li><a href="https://openreview.net/forum?id=33X9fd2-9FyZd">OpenReview</a> - Original peer review with author responses</li>
<li><a href="/posts/modern-variational-autoencoder-in-pytorch/">Modern VAE in PyTorch</a> - Implementation tutorial on this site</li>
</ul>
]]></content:encoded></item><item><title>αExtractor: Chemical Info from Biomedical Literature</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/alpha-extractor/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/alpha-extractor/</guid><description>αExtractor uses ResNet-Transformer to extract chemical structures from literature images, including noisy and hand-drawn molecules.</description><content:encoded><![CDATA[<h2 id="methodological-contribution-a-robust-optical-recognition-system">Methodological Contribution: A Robust Optical Recognition System</h2>
<p>This is primarily a <strong>Method</strong> ($\Psi_{\text{Method}}$) paper with a significant secondary <strong>Resource</strong> ($\Psi_{\text{Resource}}$) contribution (see the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> for more on these categories).</p>
<p>The dominant methodological contribution is the ResNet-Transformer recognition architecture that achieves state-of-the-art performance through robustness engineering. It specifically focuses on training on 20 million synthetic images with aggressive augmentation to handle degraded image conditions. The work answers the core methodological question &ldquo;How well does this work?&rdquo; through extensive benchmarking against existing OCSR tools and ablation studies validating architectural choices.</p>
<p>The secondary resource contribution comes from releasing αExtractor as a freely available web service, correcting labeling errors in standard benchmarks (CLEF, UOB, JPO), and providing an end-to-end document processing pipeline for biomedical literature mining.</p>
<h2 id="motivation-unlocking-visual-chemical-knowledge-in-biomedical-literature">Motivation: Unlocking Visual Chemical Knowledge in Biomedical Literature</h2>
<p>The motivation addresses a familiar pain point in chemical informatics within a biomedical context. Vast amounts of chemical knowledge in biomedical literature exist only as images, such as molecular structures embedded in figures, chemical synthesis schemes, and compound diagrams. This visual knowledge remains effectively invisible to computational methods, which creates a massive bottleneck for drug discovery research, systematic reviews, and large-scale chemical database construction.</p>
<p>Existing OCSR tools face two critical problems when applied to biomedical literature:</p>
<ol>
<li>
<p><strong>Real-world image quality</strong>: Biomedical papers often contain low-resolution figures, images with complex backgrounds, noise from scanning/digitization, and inconsistent drawing styles across different journals and decades of publications.</p>
</li>
<li>
<p><strong>End-to-end extraction</strong>: Most OCSR systems assume the presence of clean, cropped molecular images. In practice, you need to first find the molecular structures within multi-panel figures, reaction schemes, and dense document layouts before you can recognize them.</p>
</li>
</ol>
<p>The authors argue that a practical literature mining system needs to solve both problems simultaneously via robust recognition under noisy conditions and automated detection of molecular images within complex documents.</p>
<h2 id="core-innovation-robust-resnet-transformer-architecture">Core Innovation: Robust ResNet-Transformer Architecture</h2>
<p>The core innovation lies in combining a competition-winning recognition architecture with extensive robustness engineering and end-to-end document processing. The key contributions include:</p>
<ol>
<li>
<p><strong>ResNet-Transformer Recognition Model</strong>: The core recognition system uses a <strong>Residual Neural Network (ResNet)</strong> encoder paired with a <strong>Transformer decoder</strong> in an image-captioning framework. This architecture won first place in a Kaggle molecular translation competition, which provided a strong foundation for the recognition task. Let the input image be $I$. The model maximizes the joint likelihood of the SMILES tokens $T$ and coordinate sequences $X, Y$:
$$
\begin{aligned}
\mathcal{L}_{\text{total}} = - \sum_{i=1}^{L} \log P(T_i \mid I, T_{&lt;i}) - \lambda \sum_{i=1}^{L} \big(\log P(X_i \mid I, X_{&lt;i}) + \log P(Y_i \mid I, Y_{&lt;i})\big)
\end{aligned}
$$
Here, the continuous $X$ and $Y$ atom coordinates are discretized into 200 bins, formulating the coordinate prediction as a standard classification task alongside SMILES generation.</p>
</li>
<li>
<p><strong>Enhanced Molecular Representation</strong>: The model produces an augmented representation that encompasses:</p>
<ul>
<li>Standard molecular connectivity information</li>
<li><strong>Bond type tokens</strong> (solid wedge bonds, dashed bonds, etc.) that preserve 3D stereochemical information</li>
<li><strong>Atom coordinate predictions</strong> that allow reconstruction of the exact molecular pose from the original image</li>
</ul>
<p>This dual prediction of discrete structure and continuous coordinates keeps the output faithful to the source depiction and enables better quality assessment.</p>
</li>
<li>
<p><strong>Massive Synthetic Training Dataset</strong>: The model was trained on approximately <strong>20 million synthetic molecular images</strong> generated from PubChem SMILES with aggressive data augmentation. The augmentation strategy randomized visual styles, image quality, and rendering parameters to create maximum diversity, ensuring the network rarely saw the same molecular depiction twice. This forces the model to learn robust, style-invariant features.</p>
</li>
<li>
<p><strong>End-to-End Document Processing Pipeline</strong>: αExtractor integrates <strong>object detection</strong> and <strong>structure recognition</strong> into a complete document mining system:</p>
<ul>
<li>An object detection model automatically locates molecular images within PDF documents</li>
<li>The recognition model converts detected images to structured representations</li>
<li>A web service interface makes the entire pipeline accessible to researchers without machine learning expertise</li>
</ul>
</li>
<li>
<p><strong>Robustness-First Design</strong>: The system was explicitly designed to handle degraded image conditions that break traditional OCSR tools, including low resolution, background interference, color variations, and scanning artifacts commonly found in legacy biomedical literature.</p>
</li>
</ol>
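<p>The joint objective from the first point above can be sketched as follows, assuming teacher forcing and taking only the probability the model assigns to the correct class at each step (the weight <code>lam</code> stands in for $\lambda$, whose value is not stated here):</p>

```python
import math

def joint_recognition_loss(token_probs, x_probs, y_probs, lam=1.0):
    """Cross-entropy over SMILES tokens plus a lambda-weighted cross-entropy
    over the discretized X/Y coordinate bins. Each list holds the predicted
    probability of the *correct* class at every sequence step.
    """
    nll = lambda probs: -sum(math.log(p) for p in probs)
    return nll(token_probs) + lam * (nll(x_probs) + nll(y_probs))

# A perfectly confident, correct model incurs zero loss:
assert joint_recognition_loss([1.0, 1.0], [1.0, 1.0], [1.0, 1.0]) == 0.0
```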
<h2 id="experimental-methodology-stress-testing-under-real-world-conditions">Experimental Methodology: Stress Testing under Real-World Conditions</h2>
<p>The evaluation focused on demonstrating robust performance across diverse image conditions, from pristine benchmarks to challenging real-world scenarios:</p>
<ol>
<li>
<p><strong>Benchmark Dataset Evaluation</strong>: αExtractor was tested on four standard OCSR benchmarks:</p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
</ul>
<p>Performance was measured using exact SMILES match accuracy.</p>
</li>
<li>
<p><strong>Error Analysis and Dataset Correction</strong>: During evaluation, the researchers discovered numerous labeling errors in the original benchmark datasets. They systematically identified and corrected these errors, then re-evaluated all methods on the cleaned datasets to get more accurate performance measurements.</p>
</li>
<li>
<p><strong>Robustness Stress Testing</strong>: The system was evaluated on two challenging datasets specifically designed to test robustness:</p>
<ul>
<li><strong>Color background images</strong> (200 samples): Molecular structures on complex, colorful backgrounds that simulate real figure conditions</li>
<li><strong>Low-quality images</strong> (200 samples): Degraded images with noise, blur, and artifacts typical of scanned documents</li>
</ul>
<p>These tests compared αExtractor against open-source alternatives under realistic degradation conditions.</p>
</li>
<li>
<p><strong>Generalization Testing</strong>: In the most challenging experiment, αExtractor was tested on <strong>hand-drawn molecular structures</strong>, representing a completely different visual domain not represented in the training data. This tested whether the learned features could generalize beyond digital rendering styles to human-drawn chemistry.</p>
</li>
<li>
<p><strong>End-to-End Document Extraction</strong>: The complete pipeline was evaluated on 50 PDF files containing 2,336 molecular images. This tested both the object detection component (finding molecules in complex documents) and the recognition component (converting them to SMILES) in a realistic literature mining scenario.</p>
</li>
<li>
<p><strong>Speed Benchmarking</strong>: Inference time was measured to demonstrate the practical efficiency needed for large-scale document processing.</p>
</li>
</ol>
<h2 id="results--conclusions-breakthrough-in-degraded-image-performance">Results &amp; Conclusions: Breakthrough in Degraded Image Performance</h2>
<ul>
<li>
<p><strong>Substantial Accuracy Gains</strong>: On the four benchmark datasets, αExtractor achieved accuracies of 91.83% (CLEF), 98.47% (UOB), 88.67% (JPO), and 93.64% (USPTO), significantly outperforming existing methods. After correcting dataset labeling errors, the true accuracies were even higher, reaching <strong>95.77% on CLEF, 99.86% on UOB, and 92.44% on JPO</strong>.</p>
</li>
<li>
<p><strong>Exceptional Robustness</strong>: Open-source competitors struggled on degraded images (achieving 5.5% accuracy at best). αExtractor maintained <strong>over 90% accuracy</strong> on both color background and low-quality image datasets, demonstrating the effectiveness of the massive synthetic training strategy.</p>
</li>
<li>
<p><strong>Remarkable Generalization</strong>: On hand-drawn molecules, a domain completely absent from training data, αExtractor achieved <strong>61.4% accuracy</strong> while other tools scored below 3%. This suggests the model learned genuinely transferable, style-invariant chemical features.</p>
</li>
<li>
<p><strong>Practical End-to-End Performance</strong>: In the complete document processing evaluation, αExtractor detected <strong>95.1% of molecular images</strong> (2,221 out of 2,336) and correctly recognized <strong>94.5% of detected structures</strong> (2,098 correct predictions). This demonstrates the system&rsquo;s readiness for real-world literature mining applications.</p>
</li>
<li>
<p><strong>Dataset Quality Issues</strong>: The systematic discovery of labeling errors in standard benchmarks highlights a broader problem in OCSR evaluation. The corrected datasets provide more reliable baselines for future method development.</p>
</li>
<li>
<p><strong>Spatial Layout Limitation</strong>: αExtractor correctly identifies molecular connectivity; however, the re-rendered structures may have different spatial layouts than the originals. This could complicate visual verification for complex molecules, even if the chemical information remains accurate.</p>
</li>
</ul>
<p>The work establishes αExtractor as a significant advance in practical OCSR for biomedical applications. The combination of robust recognition, end-to-end document processing, and exceptional generalization makes it suitable for large-scale literature mining tasks where previous tools would fail. The focus on real-world robustness over benchmark optimization represents a mature approach to deploying machine learning in scientific workflows.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This paper is <strong>Partially Reproducible</strong>. While the authors detail the model architectures and training techniques, the source code, training dataset (20M synthetic images), and pre-trained weights remain closed-source and proprietary. The authors released a sample of their test data and host an online web server for running inference.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/CLEF_corrected">Corrected CLEF Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the CLEF benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/UOB_corrected">Corrected UOB Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the UOB benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/JPO_corrected">Corrected JPO Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the JPO benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Colored_Background">Color Background Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of molecular structures on complex, colorful backgrounds.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Low_Quality">Low Quality Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of degraded images with noise, blur, and artifacts.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/PDF">PDF Test Set</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Sample PDF files for end-to-end document extraction evaluation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://extractor.alphama.com.cn/csr">αExtractor Web Server</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Online service for running inference using the proprietary system.</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Backbone:</strong> ResNet50 producing output of shape $2048 \times 19 \times 19$, projected to 512 channels via a feed-forward layer</li>
<li><strong>Transformer Architecture:</strong> 3 encoder layers and 3 decoder layers with hidden dimension of 512</li>
<li><strong>Output Format:</strong> Generates SMILES tokens plus two auxiliary coordinate sequences (X-axis and Y-axis) that are length-aligned with the SMILES tokens via padding</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Architecture:</strong> DETR (Detection Transformer) with ResNet101 backbone</li>
<li><strong>Transformer Architecture:</strong> 6 encoder layers and 6 decoder layers with hidden dimension of 256</li>
<li><strong>Purpose:</strong> Locates molecular images within PDF pages before recognition</li>
</ul>
<p><strong>Coordinate Prediction:</strong></p>
<ul>
<li>Continuous X/Y coordinates are discretized into <strong>200 discrete bins</strong></li>
<li>Padding tokens added to coordinate sequences to align perfectly with SMILES token sequence, enabling simultaneous structure and pose prediction</li>
</ul>
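<p>Assuming atom coordinates are normalized to $[0, 1]$ before binning (the exact normalization convention is an assumption here), the 200-bin discretization reduces to:</p>

```python
def discretize_coordinate(value, n_bins=200):
    """Map a continuous coordinate in [0, 1] to one of n_bins class indices,
    turning coordinate regression into classification. Out-of-range values
    clamp to the first/last bin.
    """
    return min(max(int(value * n_bins), 0), n_bins - 1)

assert discretize_coordinate(0.0) == 0
assert discretize_coordinate(0.5) == 100
assert discretize_coordinate(1.0) == 199  # clamped to the last bin
```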
<h3 id="data">Data</h3>
<p><strong>Training Data:</strong></p>
<ul>
<li><strong>Synthetic Generation:</strong> Python script rendering PubChem SMILES into 2D images</li>
<li><strong>Dataset Size:</strong> Approximately 20.3 million synthetic molecular images from PubChem</li>
<li><strong>Superatom Handling:</strong> 50% of molecules had functional groups replaced with superatoms (e.g., &ldquo;COOH&rdquo;) or generic labels (R1, X1) to match literature drawing conventions</li>
<li><strong>Rendering Augmentation:</strong> Randomized bond thickness, bond spacing, font size, font color, and padding size</li>
</ul>
<p><strong>Geometric Augmentation:</strong></p>
<ul>
<li>Shear along x-axis: $\pm 15^\circ$</li>
<li>Rotation: $\pm 15^\circ$</li>
<li>Piecewise affine scaling</li>
</ul>
<p><strong>Noise Injection:</strong></p>
<ul>
<li>Pepper noise: 0-2%</li>
<li>Salt noise: 0-40%</li>
<li>Gaussian noise: scale 0-0.16</li>
</ul>
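<p>A minimal sketch of the salt-and-pepper component, assuming grayscale pixel values in $[0, 1]$ (illustrative only; in the paper the per-image noise fractions would be drawn uniformly from the ranges listed above):</p>

```python
import random

def salt_and_pepper(image, pepper_frac=0.02, salt_frac=0.40, rng=random.Random(0)):
    """Apply salt-and-pepper noise to a grayscale image (nested lists in [0, 1]).

    Each pixel independently turns black with probability pepper_frac or
    white with probability salt_frac, matching the 0-2% pepper and 0-40%
    salt ranges used for augmentation.
    """
    out = []
    for row in image:
        new_row = []
        for px in row:
            r = rng.random()
            if r >= pepper_frac + salt_frac:
                new_row.append(px)   # pixel unchanged
            elif r >= pepper_frac:
                new_row.append(1.0)  # salt: flip to white
            else:
                new_row.append(0.0)  # pepper: flip to black
        out.append(new_row)
    return out

noisy = salt_and_pepper([[0.5] * 100 for _ in range(100)])
white = sum(row.count(1.0) for row in noisy)
# With salt_frac = 0.40, roughly 40% of the 10,000 pixels become white:
assert 5000 > white > 3000
```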
<p><strong>Destructive Augmentation:</strong></p>
<ul>
<li>JPEG compression: severity levels 2-5</li>
<li>Random masking</li>
</ul>
<p><strong>Evaluation Datasets:</strong></p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
<li><strong>Color background images</strong>: 200 samples</li>
<li><strong>Low-quality images</strong>: 200 samples</li>
<li><strong>Hand-drawn structures</strong>: Test set for generalization</li>
<li><strong>End-to-end document extraction</strong>: 50 PDFs (567 pages, 2,336 molecular images)</li>
</ul>
<h3 id="training">Training</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 100</li>
<li><strong>Epochs:</strong> 5</li>
<li><strong>Loss Function:</strong> Cross-entropy loss for both SMILES prediction and coordinate prediction</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 24</li>
<li><strong>Training Strategy:</strong> Pre-trained on synthetic &ldquo;Lower Quality&rdquo; data for 5 epochs, then fine-tuned on annotated real &ldquo;High Quality&rdquo; data for 30 epochs</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics:</strong></p>
<ul>
<li><strong>Recognition</strong>: SMILES accuracy (exact match)</li>
<li><strong>End-to-End Pipeline</strong>:
<ul>
<li><strong>Recall</strong>: 95.1% for detection</li>
<li><strong>Accuracy</strong>: 94.5% for recognition</li>
</ul>
</li>
</ul>
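<p>These percentages are consistent with the raw counts reported in the results section (2,221 of 2,336 images detected; 2,098 of the detections correctly recognized):</p>

```python
detected, total_images = 2221, 2336   # molecular images found by the detector
correct = 2098                        # detections recognized to the right structure

recall = detected / total_images      # detection recall over all images
accuracy = correct / detected         # recognition accuracy over detections
assert round(recall * 100, 1) == 95.1
assert round(accuracy * 100, 1) == 94.5
```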
<h3 id="hardware">Hardware</h3>
<p><strong>Inference Hardware:</strong></p>
<ul>
<li>Cloud CPU server (8 CPUs, 64 GB RAM)</li>
<li><strong>Throughput:</strong> Processed 50 PDFs (567 pages) in 40 minutes</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xiong, J., Liu, X., Li, Z., Xiao, H., Wang, G., Niu, Z., Fei, C., Zhong, F., Wang, G., Zhang, W., Fu, Z., Liu, Z., Chen, K., Jiang, H., &amp; Zheng, M. (2024). αExtractor: A system for automatic extraction of chemical information from biomedical literature. <em>Science China Life Sciences</em>, 67(3), 618-621. <a href="https://doi.org/10.1007/s11427-023-2388-x">https://doi.org/10.1007/s11427-023-2388-x</a></p>
<p><strong>Publication</strong>: Science China Life Sciences (2024)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.1007/s11427-023-2388-x">Paper on Springer</a></li>
</ul>
]]></content:encoded></item><item><title>SubGrapher: Visual Fingerprinting of Chemical Structures</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/subgrapher/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/subgrapher/</guid><description>Novel OCSR method creating molecular fingerprints from images through functional group segmentation for database retrieval.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Morin, L., Meijer, G. I., Weber, V., Van Gool, L., &amp; Staar, P. W. J. (2025). Subgrapher: Visual fingerprinting of chemical structures. Journal of Cheminformatics, 17(1), 149. <a href="https://doi.org/10.1186/s13321-025-01091-4">https://doi.org/10.1186/s13321-025-01091-4</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics (2025)</p>
<h2 id="paper-classification-and-taxonomy">Paper Classification and Taxonomy</h2>
<p>This is primarily a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution. Using the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> framework:</p>
<p><strong>Primary Classification: Method</strong></p>
<p>The dominant basis vector is Methodological because SubGrapher introduces a novel architecture that fundamentally changes the OCSR workflow from two-step reconstruction (image → structure → fingerprint) to single-step fingerprinting (image → visual fingerprint). The paper validates this approach through systematic comparison against state-of-the-art methods (MolGrapher, OSRA, DECIMER, MolScribe), demonstrating superior performance on specific tasks like retrieval and robustness to image quality degradation.</p>
<p><strong>Secondary Classification: Resource</strong></p>
<p>The paper makes non-negligible resource contributions by releasing:</p>
<ul>
<li>Code and model weights on <a href="https://github.com/DS4SD/SubGrapher">GitHub</a> and <a href="https://huggingface.co/ds4sd/SubGrapher">HuggingFace</a></li>
<li>Five new visual fingerprinting benchmark datasets for molecule retrieval tasks</li>
<li>Comprehensive functional group knowledge base (1,534 substructures)</li>
</ul>
<h2 id="motivation-extracting-complex-structures-from-noisy-images">Motivation: Extracting Complex Structures from Noisy Images</h2>
<p>The motivation tackles a fundamental challenge in chemical informatics: extracting molecular information from the vast amounts of unstructured scientific literature, particularly patents. Millions of molecular structures exist only as images in these documents, making them inaccessible for computational analysis, database searches, or machine learning applications.</p>
<p>Traditional Optical Chemical Structure Recognition (OCSR) tools attempt to fully reconstruct molecular graphs from images, converting them into machine-readable formats like SMILES. However, these approaches face two critical limitations:</p>
<ol>
<li><strong>Brittleness to image quality</strong>: Poor resolution, noise, or unconventional drawing styles frequently cause complete failure</li>
<li><strong>Inability to handle complex structures</strong>: Markush structures, generic molecular templates with variable R-groups commonly used in patents, cannot be processed by conventional OCSR methods</li>
</ol>
<p>The key insight driving SubGrapher is that full molecular reconstruction may be unnecessary for many applications. For tasks like database searching, similarity analysis, or document retrieval, a molecular fingerprint (a vectorized representation capturing structural features) is often sufficient. This realization opens up a new approach: bypass the fragile reconstruction step and create fingerprints directly from visual information.</p>
<h2 id="key-innovation-direct-visual-fingerprinting">Key Innovation: Direct Visual Fingerprinting</h2>
<p>SubGrapher introduces a fundamentally different paradigm for extracting chemical information from images. It creates &ldquo;visual fingerprints&rdquo; through functional group recognition. The key innovations are:</p>
<ol>
<li>
<p><strong>Direct Image-to-Fingerprint Pipeline</strong>: SubGrapher eliminates the traditional two-step process (image → structure → fingerprint) by generating fingerprints directly from pixels. This single-stage approach avoids error accumulation from failed structure reconstructions and can handle images that would completely break conventional OCSR tools.</p>
</li>
<li>
<p><strong>Dual Instance Segmentation Architecture</strong>: The system employs two specialized Mask-RCNN networks working in parallel:</p>
<ul>
<li><strong>Functional group detector</strong>: Trained to identify 1,534 expert-defined functional groups using pixel-level segmentation masks</li>
<li><strong>Carbon backbone detector</strong>: Recognizes 27 common carbon chain patterns to capture the molecular scaffold</li>
</ul>
<p>Using instance segmentation provides detailed spatial information and higher accuracy through richer supervision during training.</p>
</li>
<li>
<p><strong>Extensive Functional Group Knowledge Base</strong>: The method leverages one of the most comprehensive open-source collections of functional groups, encompassing 1,534 substructures. These were systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li>
<p><strong>Substructure-Graph Construction</strong>: After detecting functional groups and carbon backbones, SubGrapher builds a connectivity graph where:</p>
<ul>
<li>Each node represents an identified substructure</li>
<li>Edges connect substructures whose bounding boxes overlap (with 10% margin expansion)</li>
<li>This graph captures both the chemical components and their spatial relationships</li>
</ul>
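<p>The overlap test can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the box format, helper names, and the exact expansion rule (10% of the smaller box's diagonal, per the algorithm description later in this note) are assumptions.</p>

```python
import math
from itertools import combinations

def expand(box, margin):
    # box = (x1, y1, x2, y2); grow each side by `margin` pixels
    x1, y1, x2, y2 = box
    return (x1 - margin, y1 - margin, x2 + margin, y2 + margin)

def overlaps(a, b):
    # axis-aligned bounding-box intersection test
    return a[0] <= b[2] and b[0] <= a[2] and a[1] <= b[3] and b[1] <= a[3]

def build_substructure_graph(boxes):
    """Connect detected substructures whose (expanded) boxes overlap.

    `boxes` maps node id -> (x1, y1, x2, y2). Each pair of boxes is
    expanded by 10% of the smaller box's diagonal before the overlap
    test. Returns an adjacency dict (node id -> set of neighbours).
    """
    adj = {i: set() for i in boxes}
    for i, j in combinations(boxes, 2):
        diag = min(
            math.dist(boxes[i][:2], boxes[i][2:]),
            math.dist(boxes[j][:2], boxes[j][2:]),
        )
        m = 0.10 * diag
        if overlaps(expand(boxes[i], m), expand(boxes[j], m)):
            adj[i].add(j)
            adj[j].add(i)
    return adj
```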
</li>
<li>
<p><strong>Substructure-based Visual Molecular Fingerprint (SVMF)</strong>: The final output is a continuous, count-based fingerprint formally defined as a matrix $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$ (1,534 functional groups + 27 carbon backbones). The matrix is stored in compressed upper-triangular form:</p>
<p><strong>Diagonal elements</strong> ($i = j$): Weighted count of substructure $i$
$$SVMF_{ii}(m) = h_1 \cdot \text{count}(s_i)$$
where $h_1 = 10$ is the diagonal weight hyperparameter.</p>
<p><strong>Off-diagonal elements</strong> ($i \neq j$): Intersection coefficient based on shortest path distance $d$ in the substructure graph
$$SVMF_{ij}(m) = h_2(d) \cdot \text{intersection}(s_i, s_j)$$
where the distance decay function $h_2(d)$ is:</p>
<ul>
<li>$d \leq 1$: weight = 2</li>
<li>$d = 2$: weight = 2/4 = 0.5</li>
<li>$d = 3$: weight = 2/16 = 0.125</li>
<li>$d = 4$: weight = $2/256 \approx 0.0078$</li>
<li>$d &gt; 4$: weight = 0</li>
</ul>
<p><strong>Key properties</strong>:</p>
<ul>
<li>Functional groups receive higher weight than carbon chains in the graph construction</li>
<li>Similarity between fingerprints calculated using Euclidean distance</li>
<li>Resulting fingerprints are highly sparse (average 0.001% non-zero elements)</li>
<li>Compressed storage enables efficient database searches</li>
</ul>
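<p>Putting the pieces together, the fingerprint assembly might look like the sketch below. The function names, the <code>pair_info</code> input (pairwise graph distances and intersection coefficients, whose computation is not detailed here), and the dense-matrix representation are illustrative assumptions; the decay weights are transcribed from the table above.</p>

```python
import numpy as np

N = 1561  # 1,534 functional groups + 27 carbon backbones

# Distance-decay weights h2(d) as listed in the paper's description
_H2 = {0: 2.0, 1: 2.0, 2: 0.5, 3: 0.125, 4: 2 / 256}

def svmf(counts, pair_info, h1=10.0):
    """Assemble a (dense, for clarity) SVMF matrix.

    counts    : {type_index: count of that substructure in the image}
    pair_info : {(i, j): (graph_distance, intersection_coefficient)}
    """
    M = np.zeros((N, N))
    for i, c in counts.items():
        M[i, i] = h1 * c                              # diagonal: weighted counts
    for (i, j), (d, inter) in pair_info.items():
        M[i, j] = M[j, i] = _H2.get(d, 0.0) * inter   # off-diagonal: decayed overlap
    return M

def fingerprint_distance(a, b):
    # the paper compares fingerprints by Euclidean distance (lower = more similar)
    return float(np.linalg.norm(a - b))
```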
</li>
<li>
<p><strong>Markush Structure Compatibility</strong>: Unlike traditional OCSR methods that fail on generic structures, SubGrapher can process Markush structures by recognizing their constituent functional groups and creating meaningful fingerprints for similarity searches.</p>
</li>
</ol>
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The evaluation focused on demonstrating SubGrapher&rsquo;s effectiveness across two critical tasks: accurate substructure detection and robust molecule retrieval from diverse image collections.</p>
<h4 id="substructure-detection-performance">Substructure Detection Performance</h4>
<p>SubGrapher&rsquo;s ability to identify functional groups was tested on three challenging benchmarks that expose different failure modes of OCSR systems:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size</th>
          <th>Description</th>
          <th>Key Challenge</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>JPO</strong></td>
          <td>341 images</td>
          <td>Japanese Patent Office images</td>
          <td>Low quality, noise, artifacts, non-standard drawing styles</td>
      </tr>
      <tr>
          <td><strong>USPTO-10K-L</strong></td>
          <td>1,000 images</td>
          <td>Large molecules (&gt;70 atoms)</td>
          <td>Scale variation, structural complexity, many functional groups</td>
      </tr>
      <tr>
          <td><strong>USPTO-Markush</strong></td>
          <td>74 images</td>
          <td>Generic Markush structures</td>
          <td>Variable R-groups, abstract patterns, template representation</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings:</strong></p>
<ol>
<li>
<p><strong>JPO Dataset (Low-Quality Patent Images)</strong>: SubGrapher achieved the highest Molecule Exact Match rate (83%), demonstrating robustness to image quality degradation that breaks traditional rule-based methods.</p>
</li>
<li>
<p><strong>USPTO-10K-L (Large Molecules)</strong>: The object detection approach handled scale variation better than conventional OCSR tools, achieving superior Substructure F1-scores on these challenging targets.</p>
</li>
<li>
<p><strong>USPTO-Markush (Generic Structures)</strong>: SubGrapher was the only method capable of processing these images, as traditional OCSR tools cannot handle generic molecular templates.</p>
</li>
</ol>
<p>Qualitative analysis revealed that SubGrapher correctly identified functional groups in scenarios where other methods failed completely: images with captions, unconventional drawing styles, or significant quality degradation.</p>
<h4 id="visual-fingerprinting-for-molecule-retrieval">Visual Fingerprinting for Molecule Retrieval</h4>
<p>The core application was evaluated using a retrieval task designed to simulate real-world database searching:</p>
<ol>
<li>
<p><strong>Benchmark Creation</strong>: Five benchmark datasets were constructed around structurally similar molecules (adenosine, camphor, cholesterol, limonene, and pyridine), each containing 500 similar molecular images.</p>
</li>
<li>
<p><strong>Retrieval Task</strong>: Given a SMILES string as a query, the goal was to find the corresponding molecular image within the dataset of 500 visually similar structures. This tests whether the visual fingerprint can distinguish between closely related molecules.</p>
</li>
<li>
<p><strong>Performance Comparison</strong>: SubGrapher significantly outperformed baseline methods, retrieving the correct molecule at an average rank of 95 out of 500. The key advantage was robustness&mdash;SubGrapher generates a unique fingerprint for every image, even with partial or uncertain predictions. In contrast, OCSR-based methods frequently fail to produce valid SMILES, making them unable to generate fingerprints for comparison.</p>
</li>
<li>
<p><strong>Real-World Case Study</strong>: A practical demonstration involved searching a 54-page patent document containing 356 chemical images for a specific Markush structure. SubGrapher successfully located the target structure, highlighting its utility for large-scale document mining.</p>
</li>
</ol>
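<p>The retrieval metric itself is straightforward: rank all candidate image fingerprints by Euclidean distance to the query fingerprint and report the position of the ground-truth molecule. A minimal sketch (function name and array layout are assumptions):</p>

```python
import numpy as np

def retrieval_rank(query_fp, candidate_fps, target_idx):
    """1-based rank of the ground-truth molecule among the candidates.

    query_fp      : fingerprint computed from the query SMILES
    candidate_fps : (n_candidates, dim) array of image-derived fingerprints
    target_idx    : row index of the ground-truth molecule
    """
    dists = np.linalg.norm(candidate_fps - query_fp, axis=1)
    order = np.argsort(dists, kind="stable")          # closest first
    return int(np.where(order == target_idx)[0][0]) + 1
```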
<h4 id="training-data-generation">Training Data Generation</h4>
<p>Since no public datasets existed with the required pixel-level mask annotations for functional groups, the researchers developed a comprehensive synthetic data generation pipeline:</p>
<ol>
<li>
<p><strong>Extended MolDepictor</strong>: They enhanced existing molecular rendering tools to create images from SMILES strings and generate corresponding segmentation masks for all substructures present in each molecule.</p>
</li>
<li>
<p><strong>Markush Structure Rendering</strong>: The pipeline was extended to handle complex generic structures, creating training data for molecular templates that conventional tools cannot process.</p>
</li>
<li>
<p><strong>Diverse Molecular Sources</strong>: Training molecules were sourced from both ChEMBL and PubChem to ensure broad chemical diversity and coverage of different structural families.</p>
</li>
</ol>
<h2 id="results-impact-and-limitations">Results, Impact, and Limitations</h2>
<ul>
<li><strong>Superior Robustness to Image Quality</strong>: SubGrapher consistently outperformed traditional OCSR methods on degraded images, particularly the JPO patent dataset. While rule-based systems failed completely on poor-quality images, SubGrapher&rsquo;s learned representations proved resilient to noise, artifacts, and unconventional drawing styles.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SubGrapher</th>
          <th>MolScribe</th>
          <th>OSRA</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>S-F1</strong> (JPO)</td>
          <td>92</td>
          <td>94</td>
          <td>81</td>
          <td>SubGrapher comparable to SOTA on detection</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (JPO)</td>
          <td><strong>83</strong></td>
          <td>82</td>
          <td>67</td>
          <td>SubGrapher wins on exact match for noisy images</td>
      </tr>
      <tr>
          <td><strong>Avg Retrieval Rank</strong></td>
          <td><strong>~95/500</strong></td>
          <td>&gt;100/500</td>
          <td>&gt;100/500</td>
          <td>SubGrapher significantly better at finding correct molecule</td>
      </tr>
  </tbody>
</table>
<ul>
<li>
<p><strong>Effective Handling of Scale and Complexity</strong>: The instance segmentation approach successfully managed large molecules and complex structures where traditional graph-reconstruction methods struggled. The Substructure F1-scores on USPTO-10K-L and USPTO-Markush benchmarks demonstrated clear advantages for challenging molecular targets.</p>
</li>
<li>
<p><strong>Breakthrough for Markush Structure Processing</strong>: SubGrapher represents the first practical solution for extracting information from Markush structures. These generic molecular templates appear frequently in patents but cannot be processed by conventional OCSR tools. This capability significantly expands the scope of automatically extractable chemical information from patent literature.</p>
</li>
<li>
<p><strong>Robust Molecule Retrieval Performance</strong>: The visual fingerprinting approach achieved reliable retrieval performance (average rank 95/500) across diverse molecular families. The key advantage was consistency&mdash;SubGrapher generates meaningful fingerprints even from partial or uncertain predictions, while OCSR-based methods often fail to produce any usable output.</p>
</li>
<li>
<p><strong>Practical Document Mining Capability</strong>: The successful identification of specific Markush structures within large patent documents (54 pages, 356 images) demonstrates real-world applicability for large-scale literature mining and intellectual property analysis.</p>
</li>
<li>
<p><strong>Single-Stage Architecture Benefits</strong>: By eliminating the traditional image → structure → fingerprint pipeline, SubGrapher avoids error accumulation from failed molecular reconstructions. Every input image produces a fingerprint, making the system more reliable for batch processing of diverse document collections.</p>
</li>
<li>
<p><strong>Limitations and Scope</strong>: The method remains focused on common organic functional groups and may struggle with inorganic chemistry, organometallic complexes, or highly specialized molecular classes not well-represented in the training data. The 1,534 functional groups, while extensive, represent a curated subset of chemical space.</p>
</li>
</ul>
<p>The work establishes a new paradigm for chemical information extraction from images, demonstrating that direct fingerprint generation can be more robust and practical than traditional structure reconstruction approaches. SubGrapher&rsquo;s ability to handle Markush structures and degraded images makes it particularly valuable for patent analysis and large-scale document mining, where traditional OCSR methods frequently fail. The approach suggests that task-specific learning (fingerprints for retrieval) can outperform general-purpose reconstruction methods in many practical applications.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data Generation</strong>: The paper developed a custom synthetic data pipeline since no public datasets existed with pixel-level mask annotations for functional groups:</p>
<ul>
<li><strong>Extended MolDepictor</strong>: Enhanced molecular rendering tool to generate both images and corresponding segmentation masks for all substructures</li>
<li><strong>Markush Structure Rendering</strong>: Pipeline extended to handle complex generic structures</li>
<li><strong>Source Molecules</strong>: ChEMBL and PubChem for broad chemical diversity</li>
</ul>
<p><strong>Evaluation Benchmarks</strong>:</p>
<ul>
<li><strong>JPO Dataset</strong>: Real patent images with poor resolution, noise, and artifacts</li>
<li><strong>USPTO-10K-L</strong>: Large complex molecular structures</li>
<li><strong>USPTO-Markush</strong>: Generic structures with variable R-groups</li>
<li><strong>Retrieval Benchmarks</strong>: Five datasets (adenosine, camphor, cholesterol, limonene, pyridine), each with 500 similar molecular images</li>
</ul>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Dual instance segmentation system using Mask-RCNN</p>
<ul>
<li><strong>Functional Group Detector</strong>: Mask-RCNN trained to identify 1,534 expert-defined functional groups</li>
<li><strong>Carbon Backbone Detector</strong>: Mask-RCNN trained to recognize 27 common carbon chain patterns</li>
<li><strong>Backbone Network</strong>: Not specified in paper (typically ResNet-50/101 or Swin Transformer for Mask-RCNN)</li>
</ul>
<p><strong>Functional Group Knowledge Base</strong>: 1,534 substructures systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Functional Group Definition</strong>:</p>
<ul>
<li><strong>1,534 Functional Groups</strong>: Defined by manually curated SMARTS patterns
<ul>
<li>Must contain heteroatoms (O, N, S, P, B)</li>
<li>Frequency threshold: ~1,000+ occurrences in PubChem</li>
<li>Systematically constructed from chemically logical atom combinations</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li><strong>27 Carbon Backbones</strong>: Patterns of 3-6 carbon atoms (rings and chains) to capture molecular scaffolds</li>
</ul>
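<p>The frequency-filtering step can be sketched with RDKit's substructure machinery. This is an illustration under assumptions: RDKit is available, <code>min_count</code> stands in for the paper's ~1,000-occurrence PubChem threshold, and the function names are mine.</p>

```python
from rdkit import Chem

def count_matches(smarts, smiles_corpus):
    """Number of corpus molecules containing the SMARTS pattern."""
    patt = Chem.MolFromSmarts(smarts)
    mols = (Chem.MolFromSmiles(s) for s in smiles_corpus)
    return sum(1 for m in mols if m is not None and m.HasSubstructMatch(patt))

def filter_groups(smarts_list, smiles_corpus, min_count=2):
    """Keep candidate functional groups that occur often enough."""
    return [s for s in smarts_list
            if count_matches(s, smiles_corpus) >= min_count]
```

In the paper this filter runs against PubChem; here a toy corpus keeps the hydroxyl pattern and drops the rarer carbonyl.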
<p><strong>Substructure-Graph Construction</strong>:</p>
<ol>
<li>Detect functional groups and carbon backbones using Mask-RCNN models</li>
<li>Build connectivity graph:
<ul>
<li>Each node represents an identified substructure instance</li>
<li>Edges connect substructures whose bounding boxes overlap</li>
<li>Bounding boxes expanded by 10% of smallest box&rsquo;s diagonal to ensure connectivity between adjacent groups</li>
<li>Functional groups weighted higher than carbon chains in graph construction</li>
</ul>
</li>
</ol>
<p><strong>SVMF Fingerprint Generation</strong>:</p>
<ul>
<li>Matrix form: $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$</li>
<li>Stored as compressed sparse upper triangular matrix</li>
<li><strong>Diagonal elements</strong>: $SVMF_{ii} = h_1 \cdot \text{count}(s_i)$ where $h_1 = 10$</li>
<li><strong>Off-diagonal elements</strong>: $SVMF_{ij} = h_2(d) \cdot \text{intersection}(s_i, s_j)$ where:
<ul>
<li>$h_2(d) = 2$ for $d \leq 1$</li>
<li>$h_2(2) = 0.5$, $h_2(3) = 0.125$, $h_2(4) \approx 0.0078$</li>
<li>$h_2(d) = 0$ for $d &gt; 4$</li>
</ul>
</li>
<li>Average sparsity: 0.001% non-zero elements</li>
<li>Similarity metric: Euclidean distance between fingerprint vectors</li>
</ul>
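<p>Given the extreme sparsity, the compressed upper-triangular storage amounts to keeping only non-zero entries with $i \leq j$. The sketch below also shows a Euclidean distance over that storage; whether the paper double-counts mirrored off-diagonal entries is an assumption on my part (it follows from treating the stored triangle as a symmetric matrix).</p>

```python
import math

def to_sparse_upper(entries):
    """Compress a symmetric SVMF to its non-zero upper-triangular entries.

    entries: {(i, j): value}, possibly containing both (i, j) and (j, i).
    """
    return {(min(i, j), max(i, j)): v
            for (i, j), v in entries.items() if v != 0.0}

def sparse_distance(a, b):
    """Euclidean distance between two symmetric matrices stored as upper
    triangles; off-diagonal terms count twice (once per mirrored entry)."""
    total = 0.0
    for key in set(a) | set(b):
        diff = a.get(key, 0.0) - b.get(key, 0.0)
        mult = 1 if key[0] == key[1] else 2
        total += mult * diff * diff
    return math.sqrt(total)
```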
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Substructure F1-score (S-F1)</strong>: Harmonic mean of precision and recall for individual substructure detection across all molecules in the dataset</li>
<li><strong>Molecule Exact Match (M-EM)</strong>: Percentage of molecules where S-F1 = 1.0 (all substructures correctly identified)</li>
<li><strong>Retrieval Rank</strong>: Average rank of ground truth molecule in candidate list of 500 similar structures when querying with SMILES fingerprint</li>
</ul>
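<p>To make the metric definitions concrete, here is a per-molecule sketch: S-F1 from predicted vs. ground-truth substructure multisets, and M-EM as the fraction of molecules with a perfect S-F1. The paper aggregates S-F1 across the dataset, so treat this per-molecule counting convention as an assumption.</p>

```python
def s_f1(pred_counts, true_counts):
    """Substructure F1 for one molecule from predicted vs. true multisets."""
    tp = sum(min(pred_counts.get(k, 0), c) for k, c in true_counts.items())
    n_pred = sum(pred_counts.values())
    n_true = sum(true_counts.values())
    if tp == 0 or n_pred == 0 or n_true == 0:
        return 0.0
    prec, rec = tp / n_pred, tp / n_true
    return 2 * prec * rec / (prec + rec)

def molecule_exact_match(pairs):
    """Fraction of (pred, true) pairs whose substructure F1 equals 1.0."""
    return sum(s_f1(p, t) == 1.0 for p, t in pairs) / len(pairs)
```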
<p><strong>Baselines</strong>: Compared against SOTA OCSR methods:</p>
<ul>
<li>Deep learning: MolScribe, MolGrapher, DECIMER</li>
<li>Rule-based: OSRA</li>
<li>Fingerprint methods: RDKit Daylight, MHFP (applied to OCSR outputs)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Training and inference hardware details are not provided in the main text or would be found in supplementary materials.</p>
<h3 id="implementation-gaps">Implementation Gaps</h3>
<p>The following details are not available in the paper and would require access to the code repository or supplementary information:</p>
<ul>
<li>Specific backbone architecture for Mask-RCNN (ResNet variant, Swin Transformer, etc.)</li>
<li>Optimizer type (AdamW, SGD, etc.)</li>
<li>Learning rate and scheduler</li>
<li>Batch size and number of training epochs</li>
<li>Loss function weights (box loss vs. mask loss)</li>
<li>GPU/TPU specifications used for training</li>
<li>Training time and computational requirements</li>
</ul>
]]></content:encoded></item><item><title>GTR-CoT: Graph Traversal Chain-of-Thought for Molecules</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/gtr-mol-vlm/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/gtr-mol-vlm/</guid><description>GTR-CoT uses graph traversal chain-of-thought reasoning to improve optical chemical structure recognition accuracy.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, J., Yang, H., Wu, J., He, Y., Wei, X., Wang, Y., Liu, C., Ge, L., Wu, L., Wang, B., Lin, D., &amp; He, C. (2025). GTR-CoT: Graph Traversal as Visual Chain of Thought for Molecular Structure Recognition (No. arXiv:2506.07553; Version 2). arXiv. <a href="https://doi.org/10.48550/arXiv.2506.07553">https://doi.org/10.48550/arXiv.2506.07553</a></p>
<p><strong>Publication</strong>: arXiv preprint (2025)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2506.07553">Paper on arXiv</a></li>
</ul>
<h2 id="contribution-vision-language-modeling-for-ocsr">Contribution: Vision-Language Modeling for OCSR</h2>
<p>This is a <strong>method paper</strong> that introduces GTR-Mol-VLM, a Vision-Language Model for Optical Chemical Structure Recognition (OCSR). The work addresses the persistent challenge of converting molecular structure images into machine-readable formats, with a particular focus on handling chemical abbreviations that have plagued existing systems.</p>
<h2 id="motivation-the-abbreviation-bottleneck">Motivation: The Abbreviation Bottleneck</h2>
<p>The motivation tackles a long-standing bottleneck in chemical informatics: most existing OCSR systems fail catastrophically when they encounter abbreviated functional groups. When a chemist draws &ldquo;Ph&rdquo; for phenyl or &ldquo;Et&rdquo; for ethyl, current models either crash or produce wildly incorrect structures because they&rsquo;ve been trained on data where images contain abbreviations but the ground-truth labels contain fully expanded molecular graphs.</p>
<p>This creates a fundamental mismatch. The model sees &ldquo;Ph&rdquo; in the image but is told the &ldquo;correct&rdquo; answer is a full benzene ring. It&rsquo;s like teaching someone to read by showing them &ldquo;Dr.&rdquo; but insisting the right answer is &ldquo;Doctor&rdquo; - the supervision signal is inconsistent with what&rsquo;s actually visible.</p>
<p>Beyond this data problem, existing graph-parsing methods use a two-stage approach: predict all atoms first, then predict all bonds. This is inefficient and ignores the structural constraints that could help during prediction. The authors argue that mimicking how humans analyze molecular structures - following bonds from atom to atom in a connected traversal - would be more effective.</p>
<h2 id="novelty-graph-traversal-as-visual-chain-of-thought">Novelty: Graph Traversal as Visual Chain-of-Thought</h2>
<p>The novelty lies in combining two key insights about how to properly train and architect OCSR systems. The main contributions are:</p>
<ol>
<li>
<p><strong>Graph Traversal as Visual Chain of Thought</strong>: GTR-Mol-VLM generates molecular graphs by traversing them sequentially, predicting an atom, then its connected bond, then the next atom, and so on. This mimics how a human chemist would trace through a structure and allows the model to use previously predicted atoms and bonds as context for subsequent predictions.</p>
<p>Formally, the model output sequence for image $I_m$ is generated as:</p>
<p>$$ R_m = \text{concat}(CoT_m, S_m) $$</p>
<p>where $CoT_m$ represents the deterministic graph traversal steps (atoms and bonds) and $S_m$ is the final SMILES representation. This intermediate reasoning step makes the model more interpretable and helps it learn the structural logic of molecules.</p>
</li>
<li>
<p><strong>&ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; Principle</strong>: This addresses the abbreviation problem head-on. The authors correct the ground-truth annotations to match what&rsquo;s actually visible in the image.</p>
<p>They treat abbreviations like &ldquo;Ph&rdquo; as single &ldquo;superatoms&rdquo; and build a pipeline to automatically detect and correct training data. Using OCR to extract visible text from molecular images, they replace the corresponding expanded substructures in the ground-truth with the appropriate abbreviation tokens. This ensures the supervision signal is consistent with the visual input.</p>
</li>
<li>
<p><strong>GTR-CoT-1.3M Dataset</strong>: To support this approach, the authors created a large-scale dataset combining 1M synthetic molecules from PubChem with 351K corrected real-world patent images from USPTO. The key innovation is the correction pipeline that identifies abbreviations in patent images and fixes the inconsistent ground-truth labels.</p>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: Traditional SMILES-based evaluation fails for molecules with abbreviations because canonicalization breaks down. The authors created a new benchmark that evaluates graph structure directly, providing three metrics: direct SMILES generation, graph-derived SMILES, and exact graph matching.</p>
</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The evaluation focused on demonstrating that GTR-Mol-VLM&rsquo;s design principles solve real problems that plague existing OCSR systems:</p>
<ol>
<li>
<p><strong>Comprehensive Baseline Comparison</strong>: GTR-Mol-VLM was tested against three categories of models:</p>
<ul>
<li><strong>Specialist OCSR systems</strong>: MolScribe and MolNexTR (the current state-of-the-art)</li>
<li><strong>Chemistry-focused VLMs</strong>: ChemVLM</li>
<li><strong>General-purpose VLMs</strong>: GPT-4o, Claude-3.5-Sonnet, Qwen-VL-Max</li>
</ul>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: The new benchmark includes two subsets of patent images:</p>
<ul>
<li><strong>MolRec-Std</strong>: Standard patent images similar to existing benchmarks</li>
<li><strong>MolRec-Abb</strong>: Images specifically selected to contain many abbreviated functional groups</li>
</ul>
<p>This design directly tests whether models can handle the abbreviation problem that breaks existing systems.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Systematic experiments isolated the contribution of key design choices:</p>
<ul>
<li><strong>Chain-of-Thought vs. Direct</strong>: Comparing graph traversal CoT against direct SMILES prediction</li>
<li><strong>Traversal Strategy</strong>: Graph traversal vs. the traditional &ldquo;atoms-then-bonds&rdquo; approach</li>
<li><strong>Dataset Quality</strong>: Training on corrected vs. uncorrected data</li>
</ul>
</li>
<li>
<p><strong>Retraining Experiments</strong>: Existing specialist models (MolScribe, MolNexTR) were retrained from scratch on the corrected GTR-CoT-1.3M dataset to isolate the effect of data quality from architectural improvements.</p>
</li>
<li>
<p><strong>Qualitative Analysis</strong>: Visual inspection of predictions on challenging cases with heavy abbreviation usage, complex structures, and edge cases to understand failure modes.</p>
</li>
</ol>
<h2 id="results--conclusions-resolving-the-abbreviation-bottleneck">Results &amp; Conclusions: Resolving the Abbreviation Bottleneck</h2>
<ul>
<li>
<p><strong>Dramatic Performance Gains on Abbreviations</strong>: GTR-Mol-VLM achieves state-of-the-art performance across all metrics on both benchmark subsets, but the improvement is particularly striking on MolRec-Abb. Existing specialist models that perform well on standard images see their accuracy drop below 20% when abbreviations are present. GTR-Mol-VLM maintains high performance across both conditions.</p>
</li>
<li>
<p><strong>Data Correction is Critical</strong>: When existing models were retrained on the corrected GTR-CoT-1.3M dataset, their performance improved substantially, validating that the &ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; principle addresses a real problem in the training data. However, GTR-Mol-VLM still outperformed these retrained baselines, confirming that both the data correction and the architectural innovations contribute.</p>
</li>
<li>
<p><strong>Chain-of-Thought Helps</strong>: Ablation studies confirmed that generating the graph traversal sequence before the final SMILES string improves performance compared to direct prediction. The intermediate reasoning step provides valuable structure that helps the model learn chemical logic.</p>
</li>
<li>
<p><strong>Graph Traversal Beats Traditional Parsing</strong>: The sequential atom-bond traversal approach outperformed the traditional &ldquo;atoms-then-bonds&rdquo; method, supporting the hypothesis that mimicking human reasoning patterns is more effective.</p>
</li>
<li>
<p><strong>General VLMs Still Struggle</strong>: Despite their impressive capabilities in other domains, general-purpose VLMs like GPT-4o performed poorly on this task, highlighting the importance of domain-specific training and architectural considerations.</p>
</li>
<li>
<p><strong>Evaluation Methodology Matters</strong>: The new graph-based evaluation metrics revealed problems with traditional SMILES-based evaluation that previous work had missed. Many &ldquo;failures&rdquo; in existing benchmarks were actually correct graph predictions that got marked wrong due to canonicalization issues with abbreviations.</p>
</li>
</ul>
<p>The work establishes that addressing the abbreviation problem requires both correcting the training data and rethinking the model architecture. The combination of faithful data annotation and sequential graph generation represents a significant advance for making OCSR systems robust enough for real-world deployment on diverse chemical literature.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Base Model</strong>: GTR-Mol-VLM fine-tunes <strong>Qwen-VL 2.5 3B</strong>.</p>
<p><strong>Input/Output Mechanism</strong>:</p>
<ul>
<li><strong>Input</strong>: The model takes an image $I_m$ and a text prompt</li>
<li><strong>Output</strong>: The model generates $R_m = \text{concat}(CoT_m, S_m)$, where it first produces the Chain-of-Thought (the graph traversal steps) followed immediately by the final SMILES string</li>
<li><strong>Traversal Strategy</strong>: Uses <strong>depth-first traversal</strong> to alternately predict atoms and bonds</li>
</ul>
<p><strong>Prompt Structure</strong>: The model is prompted to &ldquo;list the types of atomic elements&hellip; the coordinates&hellip; and the chemical bonds&hellip; then&hellip; output a canonical SMILES&rdquo;. The CoT output is formatted as a JSON list of atoms (with coordinates) and bonds (with indices referring to previous atoms), interleaved.</p>
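<p>Consuming that output downstream requires splitting the interleaved records back into a graph. The schema below (keys <code>element</code>, <code>coords</code>, <code>atoms</code>, <code>type</code>) is a hypothetical stand-in for the paper's actual JSON format, which is not reproduced here:</p>

```python
import json

def parse_cot(cot_json):
    """Split an interleaved atom/bond CoT list into (atoms, bonds)."""
    atoms, bonds = [], []
    for rec in json.loads(cot_json):
        if "element" in rec:                       # atom record
            atoms.append((rec["element"], tuple(rec["coords"])))
        else:                                      # bond record
            bonds.append((rec["atoms"][0], rec["atoms"][1], rec["type"]))
    return atoms, bonds
```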
<h3 id="data">Data</h3>
<p><strong>Training Dataset (GTR-CoT-1.3M)</strong>:</p>
<ul>
<li><strong>Synthetic Component</strong>: 1 million molecular SMILES from PubChem, converted to images using Indigo</li>
<li><strong>Real Component</strong>: 351,000 samples from USPTO patents (filtered from an original 680,000)
<ul>
<li>Processed using an OCR pipeline to detect abbreviations (e.g., &ldquo;Ph&rdquo;, &ldquo;Et&rdquo;)</li>
<li>Ground truth expanded structures replaced with superatoms to match visible abbreviations in images</li>
<li>This &ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; correction ensures training supervision matches visual input</li>
</ul>
</li>
</ul>
<p><strong>Evaluation Dataset (MolRec-Bench)</strong>:</p>
<ul>
<li><strong>MolRec-USPTO</strong>: 5,423 standard molecular images from patents</li>
<li><strong>MolRec-Abb</strong>: 9,311 molecular images specifically containing abbreviated superatoms, derived from MolGrapher data</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Graph Traversal Algorithm</strong>:</p>
<ul>
<li>Depth-first traversal strategy</li>
<li>Alternating atom-bond prediction sequence</li>
<li>Each step uses previously predicted atoms and bonds as context</li>
</ul>
<p><strong>Training Procedure</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate</strong>: Peak learning rate of $1.6 \times 10^{-4}$ with cosine decay</li>
<li><strong>Warm-up</strong>: Linear warm-up for the first 10% of iterations</li>
<li><strong>Batch Size</strong>: 2 per GPU with gradient accumulation over 16 steps, yielding <strong>effective batch size of 1024</strong></li>
<li><strong>Efficiency</strong>: DeepSpeed ZeRO Stage 2 for memory reduction</li>
<li><strong>Duration</strong>: 3 epochs</li>
</ul>
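<p>The stated hyperparameters can be sanity-checked in a few lines. The warm-up/decay shape below is a common reconstruction of "linear warm-up, cosine decay"; decaying to exactly zero is my assumption, not a detail from the paper. The effective batch size follows from 2 per GPU × 32 GPUs × 16 accumulation steps.</p>

```python
import math

PEAK_LR, WARMUP_FRAC = 1.6e-4, 0.10

def lr_at(step, total_steps):
    """Learning rate at `step`: linear warm-up, then cosine decay."""
    warmup = max(1, int(WARMUP_FRAC * total_steps))
    if step < warmup:
        return PEAK_LR * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * PEAK_LR * (1 + math.cos(math.pi * progress))

# 2 per GPU x 32 GPUs x 16 gradient-accumulation steps
EFFECTIVE_BATCH = 2 * 32 * 16
```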
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong> (three complementary measures to handle abbreviation issues):</p>
<ul>
<li><strong><code>Gen_SMILES</code></strong>: Exact match ratio of SMILES strings directly generated by the VLM (image-captioning style)</li>
<li><strong><code>Gra_SMILES</code></strong>: Exact match ratio of SMILES strings derived from the predicted graph structure (graph-parsing style)</li>
<li><strong><code>Graph</code></strong>: Exact match ratio between ground truth and predicted graphs (node/edge comparison, bypassing SMILES canonicalization issues)</li>
</ul>
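<p>The <code>Graph</code> metric's appeal is that it sidesteps SMILES canonicalization entirely. A minimal sketch of a node/edge comparison is below; it assumes predicted and ground-truth atom indices already correspond (how the benchmark establishes that correspondence, e.g. via coordinates, is not detailed here).</p>

```python
def normalize(graph):
    """Canonical form: atom list plus an undirected edge set."""
    atoms, bonds = graph
    return list(atoms), {(min(i, j), max(i, j), t) for i, j, t in bonds}

def graph_exact_match(pred, true):
    """Exact node/edge comparison, bypassing SMILES canonicalization."""
    return normalize(pred) == normalize(true)
```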
<p><strong>Baselines Compared</strong>:</p>
<ul>
<li>Specialist OCSR systems: MolScribe, MolNexTR</li>
<li>Chemistry-focused VLMs: ChemVLM</li>
<li>General-purpose VLMs: GPT-4o, Claude-3.5-Sonnet, Qwen-VL-Max</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute</strong>: Training performed on <strong>32 NVIDIA A100 GPUs</strong></p>
]]></content:encoded></item><item><title>MolNexTR: Dual-Stream Molecular Image Recognition</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molnextr/</link><pubDate>Sat, 04 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molnextr/</guid><description>Dual-stream encoder combining ConvNext and ViT for robust optical chemical structure recognition across diverse molecular drawing styles.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, Y., Leung, C. T., Huang, Y., Sun, J., Chen, H., &amp; Gao, H. (2024). MolNexTR: a generalized deep learning model for molecular image recognition. <em>Journal of Cheminformatics</em>, 16(141). <a href="https://doi.org/10.1186/s13321-024-00926-w">https://doi.org/10.1186/s13321-024-00926-w</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/CYF2000127/MolNexTR">GitHub Repository</a></li>
<li><a href="https://huggingface.co/datasets/CYF200127/MolNexTR/tree/main">HuggingFace Dataset/Model</a></li>
</ul>
<h2 id="methodology-overview-and-taxonomic-classification">Methodology Overview and Taxonomic Classification</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$). It proposes a novel neural network architecture (MolNexTR) that integrates ConvNext and Vision Transformers to solve the Optical Chemical Structure Recognition (OCSR) task. The paper validates this method through ablation studies and extensive benchmarking against current state-of-the-art models like MolScribe and DECIMER.</p>
<h2 id="the-challenge-of-domain-specific-drawing-styles-in-ocsr">The Challenge of Domain-Specific Drawing Styles in OCSR</h2>
<p>Converting molecular images from chemical literature into machine-readable formats (SMILES) is critical but challenging due to the high variance in drawing styles, fonts, and conventions (e.g., Markush structures, abbreviations). Existing methods have limitations:</p>
<ul>
<li><strong>Style Robustness</strong>: CNN-based and ViT-based models often struggle to generalize across diverse, non-standard drawing styles found in real literature.</li>
<li><strong>Feature Extraction</strong>: Pure ViT methods lack translation invariance and local feature representation, while pure CNNs struggle with global dependencies.</li>
<li><strong>Chemical Knowledge</strong>: Many models predict SMILES strings directly, making it difficult to enforce chemical validity or resolve complex stereochemistry and abbreviations.</li>
</ul>
<h2 id="core-innovation-dual-stream-encoding-and-image-contamination">Core Innovation: Dual-Stream Encoding and Image Contamination</h2>
<p>MolNexTR introduces three main innovations:</p>
<ol>
<li><strong>Dual-Stream Encoder</strong>: A hybrid architecture processing images simultaneously through a ConvNext stream (for local features) and a Vision Transformer stream (for long-range dependencies), fusing them to capture multi-scale information.</li>
<li><strong>Image Contamination Augmentation</strong>: A specialized data augmentation algorithm that simulates real-world &ldquo;noise&rdquo; found in literature, such as overlapping text, arrows, and partial molecular fragments, to improve robustness.</li>
<li><strong>Graph-Based Decoding with Post-Processing</strong>: Unlike pure image-to-SMILES translation, it predicts atoms and bonds (graph generation) and uses a stereochemical discrimination and abbreviation self-correction module to enforce chemical rules (e.g., chirality) and resolve superatoms (e.g., &ldquo;Ph&rdquo;, &ldquo;Bn&rdquo;).</li>
</ol>
<p>The prediction of atom labels and coordinates is formulated as a conditional autoregressive generation task, optimized via a cross-entropy loss:
$$ \mathcal{L}_{\text{atom}} = -\sum_{t=1}^{T} \log P(x_t \mid \text{Image}, x_{&lt;t}) $$</p>
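<p>A dependency-light sketch of this objective: with teacher forcing, the loss is the summed negative log-probability the model assigns to each ground-truth token. The <code>(T, V)</code> probability-table layout is an assumption for illustration, not the paper's implementation:</p>

```python
import numpy as np

def atom_sequence_nll(step_probs, target_tokens):
    """Teacher-forced negative log-likelihood of an atom-token sequence.

    step_probs: (T, V) array; row t is the model's distribution over the
                token vocabulary given the image and tokens x_{<t}.
    target_tokens: length-T list of ground-truth token indices.
    """
    return -sum(
        np.log(step_probs[t, tok]) for t, tok in enumerate(target_tokens)
    )
```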
<h2 id="experimental-setup-benchmarking-on-synthetic-and-real-data">Experimental Setup: Benchmarking on Synthetic and Real Data</h2>
<p>The model was trained on synthetic data (PubChem) and real patent data (USPTO). It was evaluated on nine public test sets:</p>
<ul>
<li><strong>Synthetic</strong>: Indigo, ChemDraw, RDKit (rendered from 5,719 molecules)</li>
<li><strong>Real-World</strong>: CLEF, UOB, JPO, USPTO, Staker, and a newly curated ACS dataset (diverse styles)</li>
</ul>
<p><strong>Baselines</strong>: Compared against rule-based (OSRA, MolVec) and deep learning models (MolScribe, DECIMER, SwinOCSR, Img2Mol).</p>
<p><strong>Ablations</strong>: Tested the impact of the dual-stream encoder vs. single streams, and the contribution of individual augmentation strategies.</p>
<h2 id="empirical-results-and-robustness-findings">Empirical Results and Robustness Findings</h2>
<ul>
<li><strong>Performance</strong>: MolNexTR achieved 81-97% accuracy across test sets, outperforming the second-best method (often MolScribe) by margins of 0.3% to 10.0% (on the difficult ACS dataset).</li>
<li><strong>Robustness</strong>: The model showed superior resilience to image perturbations (rotation, noise) and &ldquo;curved arrow&rdquo; noise common in reaction mechanisms.</li>
<li><strong>Ablation Results</strong>: The dual-stream encoder consistently outperformed single CNN or ViT baselines, and the image contamination algorithm significantly boosted performance on noisy real-world data (ACS).</li>
<li><strong>Limitations</strong>: The model still struggles with extremely complex hand-drawn molecules and mechanism diagrams where arrows or text are conflated with structure.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong>:</p>
<ul>
<li><strong>Synthetic</strong>: ~1M molecules randomly selected from PubChem, rendered using RDKit and Indigo with varied styles (thickness, fonts, bond width)</li>
<li><strong>Real</strong>: 0.68M images from USPTO, with coordinates normalized from MOLfiles</li>
</ul>
<p><strong>Augmentation</strong>:</p>
<ul>
<li><strong>Render Augmentation</strong>: Randomized drawing styles (line width, font size, label modes)</li>
<li><strong>Image Augmentation</strong>: Rotation, cropping, blurring, noise (Gaussian, salt-and-pepper)</li>
<li><strong>Molecular Augmentation</strong>: Randomly replacing functional groups with abbreviations (from a list of &gt;100) or complex chains (e.g., CH3CH2NH2); adding R-groups</li>
<li><strong>Image Contamination</strong>: Adding &ldquo;noise&rdquo; objects (arrows, lines, text, partial structures) at a minimum distance from the main molecule to simulate literature artifacts</li>
</ul>
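<p>The placement constraint in the contamination step can be sketched as rejection sampling: candidate anchor points for noise objects are drawn uniformly and kept only if they sit at least a minimum distance from the molecule's bounding box. The parameter values and the rejection-sampling strategy are assumptions; the paper does not publish this routine:</p>

```python
import random

def place_contaminants(canvas_size, mol_bbox, n_objects=3, min_dist=20.0,
                       seed=0):
    """Sample anchor points for 'noise' objects (arrows, text, fragments)
    at least min_dist pixels away from the molecule's bounding box."""
    rng = random.Random(seed)
    width, height = canvas_size
    x0, y0, x1, y1 = mol_bbox
    points = []
    while len(points) < n_objects:
        px, py = rng.uniform(0, width), rng.uniform(0, height)
        # distance from the candidate point to the axis-aligned bbox
        dx = max(x0 - px, 0.0, px - x1)
        dy = max(y0 - py, 0.0, py - y1)
        if (dx * dx + dy * dy) ** 0.5 >= min_dist:
            points.append((px, py))
    return points
```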
<h3 id="algorithms">Algorithms</h3>
<p><strong>Dual-Stream Encoder</strong>:</p>
<ul>
<li><strong>CNN Stream</strong>: ConvNext backbone (pre-trained on ImageNet), generating feature maps at scales $H/4$ to $H/32$</li>
<li><strong>ViT Stream</strong>: Parallel transformer blocks receiving patches of sizes $p=4, 8, 16, 32$. Uses Multi-Head Self-Attention (MHSA) and Feed-Forward Networks (FFN)</li>
<li><strong>Fusion</strong>: Outputs from both streams are concatenated</li>
</ul>
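<p>At a single scale, the fusion step can be sketched as plain channel-wise concatenation of the two streams' feature maps. Shapes are chosen for illustration; the paper's fusion may also include projection or normalization layers:</p>

```python
import numpy as np

def fuse_streams(cnn_feats, vit_feats):
    """Concatenate one scale's (C, H, W) feature maps from the ConvNext
    and ViT streams along the channel axis."""
    assert cnn_feats.shape[1:] == vit_feats.shape[1:], "spatial dims differ"
    return np.concatenate([cnn_feats, vit_feats], axis=0)
```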
<p><strong>Decoder (Graph Generation)</strong>:</p>
<ul>
<li><strong>Transformer Decoder</strong>: 6 layers, 8 heads, hidden dim 256</li>
<li><strong>Task 1 (Atoms)</strong>: Autoregressive prediction of atom tokens $(l, x, y)$ (label + coordinates)</li>
<li><strong>Task 2 (Bonds)</strong>: Prediction of bond types between atom pairs (Single, Double, Triple, Aromatic, Wedge)</li>
</ul>
<p><strong>Post-Processing</strong>:</p>
<ul>
<li><strong>Stereochemistry</strong>: Uses predicted coordinates and bond types (wedge/dash) to resolve chirality using RDKit logic</li>
<li><strong>Abbreviation Correction</strong>: Matches superatoms to a dictionary; if unknown, attempts to greedily connect atoms based on valence or finds the nearest match ($\sigma=0.8$ similarity threshold)</li>
</ul>
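<p>The dictionary-plus-nearest-match fallback can be sketched as below. The dictionary entries are illustrative (the paper's covers over 100 groups), <code>SequenceMatcher</code> stands in for the paper's unspecified similarity measure, and the valence-based greedy connection fallback is omitted:</p>

```python
from difflib import SequenceMatcher

# Illustrative entries only; the paper's dictionary is much larger.
SUPERATOM_DICT = {"Ph": "c1ccccc1", "Bn": "Cc1ccccc1", "OMe": "OC", "Et": "CC"}

def resolve_superatom(label, sigma=0.8):
    """Resolve a superatom label: exact dictionary hit first, otherwise
    the closest entry whose string similarity clears the threshold sigma."""
    if label in SUPERATOM_DICT:
        return SUPERATOM_DICT[label]
    best, best_score = None, 0.0
    for known in SUPERATOM_DICT:
        score = SequenceMatcher(None, label, known).ratio()
        if score > best_score:
            best, best_score = known, score
    return SUPERATOM_DICT[best] if best_score >= sigma else None
```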
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder (ConvNext + ViT Encoder -&gt; Transformer Decoder)</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam (max LR 3e-4, linear warmup for the first 5% of steps)</li>
<li>Batch Size: 256</li>
<li>Image Size: $384 \times 384$</li>
<li>Dropout: 0.1</li>
</ul>
</li>
<li><strong>Training</strong>: Fine-tuned CNN backbone for 40 epochs on 10 NVIDIA RTX 3090 GPUs</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metric</strong>: SMILES sequence exact matching accuracy (canonicalized)</p>
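<p>This metric reduces to a simple exact-match count after canonicalization. A minimal sketch, with <code>canonicalize</code> left pluggable so the snippet stays dependency-free (with RDKit one would pass <code>lambda s: Chem.MolToSmiles(Chem.MolFromSmiles(s))</code>):</p>

```python
def exact_match_accuracy(predictions, references, canonicalize=lambda s: s):
    """Fraction of predictions whose canonical SMILES equals the
    reference's canonical SMILES; identity canonicalizer by default."""
    hits = sum(
        canonicalize(p) == canonicalize(r)
        for p, r in zip(predictions, references)
    )
    return hits / len(references)
```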
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>Synthetic</strong>: Indigo (5,719), ChemDraw (5,719), RDKit (5,719)</li>
<li><strong>Real</strong>: CLEF (992), UOB (5,740), JPO (450), USPTO (5,719), Staker (50,000), ACS (331)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: 10 NVIDIA RTX 3090 GPUs</li>
<li><strong>Cluster</strong>: HPG Cluster at HKUST</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chenMolNexTRGeneralizedDeep2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{MolNexTR: A Generalized Deep Learning Model for Molecular Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chen, Yufan and Leung, Ching Ting and Huang, Yong and Sun, Jianwei and Chen, Hao and Gao, Hanyu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{141}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00926-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolParser-7M &amp; WildMol: Large-Scale OCSR Datasets</title><link>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molparser_7m-wildmol/</link><pubDate>Fri, 03 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/optical-structure-recognition/molparser_7m-wildmol/</guid><description>MolParser-7M is the largest OCSR dataset with 7.7M image-text pairs of molecules and E-SMILES, including 400k real-world samples.</description><content:encoded><![CDATA[<h2 id="dataset-examples">Dataset Examples</h2>
<figure class="post-figure center ">
    <img src="/img/molparser-markush-example.webp"
         alt="Example of a complex Markush structure"
         title="Example of a complex Markush structure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example of a complex Markush structure that can be represented by the E-SMILES format but not by standard SMILES or FG-SMILES.</figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/molparser-low-quality-example.webp"
         alt="Sample from the WildMol benchmark"
         title="Sample from the WildMol benchmark"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A sample from the WildMol benchmark, showing a low-quality, noisy molecular image cropped from real-world literature that challenges OCSR systems.</figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/molparser-colored-example.webp"
         alt="Colored molecule with annotations"
         title="Colored molecule with annotations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A colored molecule with annotations, representing the diverse drawing styles found in scientific papers that OCSR models must handle.</figcaption>
    
</figure>

<h2 id="dataset-subsets">Dataset Subsets</h2>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Count</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>MolParser-7M (Training Set)</strong></td>
          <td>7,740,871</td>
          <td>A large-scale dataset for training OCSR models, split into pre-training and fine-tuning stages.</td>
      </tr>
      <tr>
          <td><strong>WildMol (Test Set)</strong></td>
          <td>20,000</td>
          <td>A benchmark of 20,000 human-annotated samples cropped from real PDF files to evaluate OCSR models in &lsquo;in-the-wild&rsquo; scenarios.</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmarks">Benchmarks</h2>

<div class="benchmarks-content">
  <div class="benchmark-section">
    <h3 id="wildmol-10k-accuracy">WildMol-10K Accuracy<a hidden class="anchor" aria-hidden="true" href="#wildmol-10k-accuracy">#</a></h3>
    <p class="benchmark-description">Evaluation of OCSR models on 10,000 real-world molecular images cropped from scientific literature and patents</p>
    <table class="benchmark-table">
      <thead>
        <tr>
          <th>Rank</th>
          <th>Model</th>
          <th>Accuracy (%)</th>
        </tr>
      </thead>
      <tbody>
        <tr class="top-result">
          <td>🥇 1</td>
          <td>
            <strong>MolParser-Base</strong><br><small>End-to-end visual recognition trained on MolParser-7M</small>
          </td>
          <td>76.9</td>
        </tr>
        <tr class="top-result">
          <td>🥈 2</td>
          <td>
            <strong>MolScribe</strong><br><small>Transformer-based OCSR system</small>
          </td>
          <td>66.4</td>
        </tr>
        <tr class="top-result">
          <td>🥉 3</td>
          <td>
            <strong>DECIMER 2.7</strong><br><small>Deep learning for chemical image recognition</small>
          </td>
          <td>56.0</td>
        </tr>
        <tr>
          <td>4</td>
          <td>
            <strong>MolGrapher</strong><br><small>Graph-based molecular structure recognition</small>
          </td>
          <td>45.5</td>
        </tr>
        <tr>
          <td>5</td>
          <td>
            <strong>MolVec 0.9.7</strong><br><small>Vector-based structure recognition</small>
          </td>
          <td>26.4</td>
        </tr>
        <tr>
          <td>6</td>
          <td>
            <strong>OSRA 2.1</strong><br><small>Optical Structure Recognition Application</small>
          </td>
          <td>26.3</td>
        </tr>
        <tr>
          <td>7</td>
          <td>
            <strong>Img2Mol</strong><br><small>Image-to-molecule translation</small>
          </td>
          <td>24.4</td>
        </tr>
        <tr>
          <td>8</td>
          <td>
            <strong>Imago 2.0</strong><br><small>Chemical structure recognition toolkit</small>
          </td>
          <td>6.9</td>
        </tr>
      </tbody>
    </table>
  </div>
</div>

<h2 id="key-contribution">Key Contribution</h2>
<p>Introduces MolParser-7M, the largest Optical Chemical Structure Recognition (OCSR) dataset, uniquely combining diverse synthetic data with a large volume of manually-annotated, &ldquo;in-the-wild&rdquo; images from real scientific documents to improve model robustness. Also introduces WildMol, a new challenging benchmark for evaluating OCSR performance on real-world data, including Markush structures.</p>
<h2 id="overview">Overview</h2>
<p>The MolParser project addresses the challenge of recognizing molecular structures from images found in real-world scientific documents. Unlike existing OCSR datasets that rely primarily on synthetically generated images, MolParser-7M incorporates 400,000 manually annotated images cropped from actual patents and scientific papers, making it the first large-scale dataset to bridge the gap between synthetic training data and real-world deployment scenarios.</p>
<h2 id="strengths">Strengths</h2>
<ul>
<li>Largest open-source OCSR dataset with over 7.7 million pairs</li>
<li>The only large-scale OCSR training set that includes a significant amount (400k) of &ldquo;in-the-wild&rdquo; data cropped from real patents and literature</li>
<li>High diversity of molecular structures from numerous sources (PubChem, ChEMBL, polymers, etc.)</li>
<li>Introduces the WildMol benchmark for evaluating performance on challenging, real-world data, including Markush structures</li>
<li>The &ldquo;in-the-wild&rdquo; fine-tuning data (MolParser-SFT-400k) was curated via an efficient active learning data engine with human-in-the-loop validation</li>
</ul>
<h2 id="limitations">Limitations</h2>
<ul>
<li>The E-SMILES format cannot represent certain complex cases, such as coordination bonds, dashed abstract rings, and Markush structures depicted with special patterns</li>
<li>The model and data do not yet fully exploit molecular chirality, which is critical for chemical properties</li>
<li>Performance could be further improved by scaling up the amount of real annotated training data</li>
</ul>
<h2 id="technical-notes">Technical Notes</h2>
<h3 id="synthetic-data-generation">Synthetic Data Generation</h3>
<p>To ensure diversity, molecular structures were collected from databases like ChEMBL, PubChem, and Kaggle BMS. A significant number of Markush, polymer, and fused-ring structures were also randomly generated. Images were rendered using RDKit and epam.indigo with randomized parameters (e.g., bond width, font size, rotation) to increase visual diversity.</p>
<h3 id="in-the-wild-data-engine-molparser-sft-400k">In-the-Wild Data Engine (MolParser-SFT-400k)</h3>
<p>A YOLOv11 object detection model (MolDet) located and cropped over 20 million molecule images from 1.22 million real PDFs (patents and papers). After de-duplication via p-hash similarity, 4 million unique images remained.</p>
<p>An active learning algorithm was used to select the most informative samples for annotation, targeting images where an ensemble of 5-fold models showed moderate confidence (0.6-0.9 Tanimoto similarity), indicating they were challenging but learnable.</p>
<p>This active learning approach with model pre-annotations reduced manual annotation time per molecule to 30 seconds, a 90% savings compared to annotating from scratch. In the final fine-tuning dataset, 56.04% of annotations directly utilized raw model pre-annotations, 20.97% passed review after a single manual correction, and just 9.13% required three or more rounds of annotation.</p>
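<p>The selection criterion amounts to a simple confidence-band filter over ensemble agreement. A sketch, assuming per-sample lists of Tanimoto similarities from the 5-fold ensemble; the band follows the note, but whether the endpoints are inclusive is a guess:</p>

```python
def select_for_annotation(ensemble_scores, low=0.6, high=0.9):
    """Pick sample ids whose mean Tanimoto similarity across the ensemble
    falls in the moderate-confidence band: hard enough to be informative,
    easy enough to be learnable."""
    selected = []
    for sample_id, sims in ensemble_scores.items():
        mean_sim = sum(sims) / len(sims)
        if low <= mean_sim <= high:
            selected.append(sample_id)
    return selected
```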
<h3 id="e-smiles-specification">E-SMILES Specification</h3>
<p>To accommodate complex patent structures that standard SMILES cannot support, the authors introduced an Extended SMILES format (<code>SMILES&lt;sep&gt;EXTENSION</code>). The <code>EXTENSION</code> component uses XML-like tokens to manage complexities:</p>
<ul>
<li><code>&lt;a&gt;...&lt;/a&gt;</code> encapsulates Markush R-groups and abbreviation groups.</li>
<li><code>&lt;r&gt;...&lt;/r&gt;</code> denotes ring attachments with uncertainty positions.</li>
<li><code>&lt;c&gt;...&lt;/c&gt;</code> defines abstract rings.</li>
<li><code>&lt;dum&gt;</code> identifies a connection point.</li>
</ul>
<p>This format provides a crucial advancement for Markush-molecule matching and LLM integration, while retaining RDKit compatibility for the standard SMILES portion.</p>
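<p>Splitting the two components apart is straightforward. A sketch of a parser for the tag grammar described above — the example input is invented, and the real MolParser tokenization may differ in detail:</p>

```python
import re

def parse_esmiles(esmiles):
    """Split an E-SMILES string into its RDKit-compatible SMILES part and
    the XML-like extension tokens (<a>, <r>, <c> contents plus a count of
    <dum> connection-point markers)."""
    smiles, _, extension = esmiles.partition("<sep>")
    tags = {
        tag: re.findall(rf"<{tag}>(.*?)</{tag}>", extension)
        for tag in ("a", "r", "c")
    }
    tags["dum"] = extension.count("<dum>")  # connection-point markers
    return smiles, tags
```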
]]></content:encoded></item><item><title>DenoiseVAE: Adaptive Noise for Molecular Pre-training</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/denoise-vae/</link><pubDate>Sun, 24 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/denoise-vae/</guid><description>Liu et al.'s ICLR 2025 paper introducing DenoiseVAE, which learns adaptive, atom-specific noise distributions for better molecular force fields.</description><content:encoded><![CDATA[<h2 id="paper-contribution-type">Paper Contribution Type</h2>
<p>This is a <strong>method paper</strong> with a supporting theoretical component. It introduces a new pre-training framework, DenoiseVAE, that challenges the standard practice of using fixed, hand-crafted noise distributions in denoising-based molecular representation learning.</p>
<h2 id="motivation-the-inter--and-intra-molecular-variations-problem">Motivation: The Inter- and Intra-molecular Variations Problem</h2>
<p>The motivation is to create a more physically principled denoising pre-training task for 3D molecules. The core idea of denoising is to learn molecular force fields by corrupting an equilibrium conformation with noise and then learning to recover it. However, existing methods use a single, hand-crafted noise strategy (e.g., Gaussian noise of a fixed scale) for all atoms across all molecules. This is physically unrealistic for two main reasons:</p>
<ol>
<li><strong>Inter-molecular differences</strong>: Different molecules have unique Potential Energy Surfaces (PES), meaning the space of low-energy (i.e., physically plausible) conformations is highly molecule-specific.</li>
<li><strong>Intra-molecular differences (Anisotropy)</strong>: Within a single molecule, different atoms have different degrees of freedom. For instance, an atom in a rigid functional group can move much less than one connected by a single, rotatable bond.</li>
</ol>
<p>The authors argue that this &ldquo;one-size-fits-all&rdquo; noise approach leads to inaccurate force field learning because it samples many physically improbable conformations.</p>
<h2 id="novelty-a-learnable-atom-specific-noise-generator">Novelty: A Learnable, Atom-Specific Noise Generator</h2>
<p>The core novelty is a framework that learns to generate noise tailored to each specific molecule and atom. This is achieved through three key innovations:</p>
<ol>
<li><strong>Learnable Noise Generator</strong>: The authors introduce a Noise Generator module (a 4-layer Equivariant Graph Neural Network) that takes a molecule&rsquo;s equilibrium conformation $X$ as input and outputs a unique, atom-specific Gaussian noise distribution (i.e., a different variance $\sigma_i^2$ for each atom $i$). This directly addresses the issues of PES specificity and force field anisotropy.</li>
<li><strong>Variational Autoencoder (VAE) Framework</strong>: The Noise Generator (encoder) and a Denoising Module (a 7-layer EGNN decoder) are trained jointly within a VAE paradigm. The noisy conformation is sampled using the reparameterization trick:
$$
\begin{aligned}
\tilde{x}_i &amp;= x_i + \epsilon \sigma_i
\end{aligned}
$$</li>
<li><strong>Principled Optimization Objective</strong>: The training loss balances two competing goals:
$$
\begin{aligned}
\mathcal{L}_{DenoiseVAE} &amp;= \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}
\end{aligned}
$$
<ul>
<li>A denoising reconstruction loss ($\mathcal{L}_{Denoise}$) encourages the Noise Generator to produce physically plausible perturbations from which the original conformation can be recovered. This implicitly constrains the noise to respect the molecule&rsquo;s underlying force fields.</li>
<li>A KL divergence regularization term ($\mathcal{L}_{KL}$) pushes the generated noise distributions towards a predefined prior. This prevents the trivial solution of generating zero noise and encourages the model to explore a diverse set of low-energy conformations.</li>
</ul>
</li>
</ol>
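<p>The sampling step and the role of the KL regularizer can be sketched in a few lines of NumPy. Zero-mean Gaussians and an isotropic prior are simplifying assumptions here, and the per-atom standard deviations are passed in directly rather than produced by the noise-generator EGNN:</p>

```python
import numpy as np

def sample_noisy_conformation(X, sigma, rng):
    """Reparameterized corruption: x_tilde_i = x_i + eps * sigma_i with
    eps ~ N(0, I). X: (N, 3) equilibrium coordinates; sigma: (N,)
    per-atom standard deviations."""
    eps = rng.standard_normal(X.shape)
    return X + eps * sigma[:, None]

def kl_to_isotropic_prior(sigma, sigma_prior=0.1):
    """KL(N(0, sigma_i^2 I_3) || N(0, sigma_prior^2 I_3)) summed over
    atoms; zero when every sigma_i matches the prior, which is what rules
    out the trivial near-zero-noise solution."""
    ratio = (sigma / sigma_prior) ** 2
    return 1.5 * np.sum(ratio - np.log(ratio) - 1.0)
```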
<p>The authors also provide a theoretical analysis showing that optimizing their objective is equivalent to maximizing the Evidence Lower Bound (ELBO) on the log-likelihood of observing physically realistic conformations.</p>
<h2 id="methodology--experimental-baselines">Methodology &amp; Experimental Baselines</h2>
<p>The model was pretrained on the PCQM4Mv2 dataset (approximately 3.4 million organic molecules) and then evaluated on a comprehensive suite of downstream tasks to test the quality of the learned representations:</p>
<ol>
<li><strong>Molecular Property Prediction (QM9)</strong>: The model was evaluated on 12 quantum chemical property prediction tasks for small molecules (134k molecules; 100k train, 18k val, 13k test split). DenoiseVAE achieved state-of-the-art or second-best performance on 11 of the 12 tasks, with particularly significant gains on $C_v$ (heat capacity), indicating better capture of vibrational modes.</li>
<li><strong>Force Prediction (MD17)</strong>: The task was to predict atomic forces from molecular dynamics trajectories for 8 different small molecules (9,500 train, 500 val split). DenoiseVAE was the top performer on 5 of the 8 molecules (Aspirin, Benzene, Ethanol, Naphthalene, Toluene), though it underperformed Frad on Malonaldehyde, Salicylic Acid, and Uracil by significant margins.</li>
<li><strong>Ligand Binding Affinity (PDBBind v2019)</strong>: On the PDBBind dataset with 30% and 60% protein sequence identity splits, the model showed strong generalization, outperforming baselines like Uni-Mol particularly on the more stringent 30% split across RMSE, Pearson correlation, and Spearman correlation.</li>
<li><strong>PCQM4Mv2 Validation</strong>: DenoiseVAE achieved a validation MAE of 0.0777 on the PCQM4Mv2 HOMO-LUMO gap prediction task with only 1.44M parameters, competitive with models 10-40x larger (e.g., GPS++ at 44.3M params achieves 0.0778).</li>
<li><strong>Ablation Studies</strong>: The authors analyzed the sensitivity to key hyperparameters, namely the prior&rsquo;s standard deviation ($\sigma$) and the KL-divergence weight ($\lambda$), confirming that $\lambda=1$ and $\sigma=0.1$ are optimal. Removing the KL term leads to trivial solutions (near-zero noise). An additional ablation on the Noise Generator depth found 4 EGNN layers optimal over 2 layers. A comparison of independent (diagonal) versus non-independent (full covariance) noise sampling showed comparable results, suggesting the EGNN already captures inter-atomic dependencies implicitly.</li>
<li><strong>Case Studies</strong>: Visualizations of the learned noise variances for different molecules confirmed that the model learns chemically intuitive noise patterns. For example, it applies smaller perturbations to atoms in a rigid bicyclic norcamphor derivative and larger ones to atoms in flexible functional groups of a cyclopropane derivative. Even identical functional groups (e.g., hydroxyl) receive different noise scales in different molecular contexts.</li>
</ol>
<h2 id="key-findings-on-force-field-learning">Key Findings on Force Field Learning</h2>
<ul>
<li><strong>Primary Conclusion</strong>: Learning a <strong>molecule-adaptive and atom-specific</strong> noise distribution is a superior strategy for denoising-based pre-training compared to using fixed, hand-crafted heuristics. This more physically-grounded approach leads to representations that better capture molecular force fields.</li>
<li><strong>Strong Benchmark Performance</strong>: DenoiseVAE achieves best or second-best results on 11 of 12 QM9 tasks, 5 of 8 MD17 molecules, and leads on the stringent 30% LBA split. Performance is mixed on some MD17 molecules (Malonaldehyde, Salicylic Acid, Uracil), where it trails Frad.</li>
<li><strong>Effective Framework</strong>: The proposed VAE-based framework, which jointly trains a Noise Generator and a Denoising Module, is an effective and theoretically sound method for implementing this adaptive noise strategy. The interplay between the reconstruction loss and the KL-divergence regularization is key to its success.</li>
<li><strong>Limitation and Future Direction</strong>: The method is based on classical force field assumptions. The authors note that integrating more accurate force fields represents a promising direction for future work.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<ul>
<li><strong>Source Code &amp; Weights</strong>: The authors have released their code at <a href="https://github.com/Serendipity-r/DenoiseVAE">Serendipity-r/DenoiseVAE</a> on GitHub.</li>
<li><strong>Implementation</strong>: Hyperparameters and architectures are detailed in the paper&rsquo;s appendix (A.14), and the repository provides reference implementations.</li>
</ul>
<h3 id="data">Data</h3>
<ul>
<li><strong>Pre-training Dataset</strong>: <a href="https://ogb.stanford.edu/docs/lsc/pcqm4mv2/">PCQM4Mv2</a> (approximately 3.4 million organic molecules)</li>
<li><strong>Property Prediction</strong>: <a href="https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.QM9.html">QM9 dataset</a> (134k molecules; 100k train, 18k val, 13k test split) for 12 quantum chemical properties</li>
<li><strong>Force Prediction</strong>: <a href="http://www.sgdml.org/#datasets">MD17 dataset</a> (9,500 train, 500 val split) for 8 different small molecules</li>
<li><strong>Ligand Binding Affinity</strong>: PDBBind v2019 (4,463 protein-ligand complexes) with 30% and 60% sequence identity splits</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise Generator</strong>: 4-layer Equivariant Graph Neural Network (EGNN) that outputs atom-specific Gaussian noise distributions</li>
<li><strong>Denoising Module</strong>: 7-layer EGNN decoder</li>
<li><strong>Training Objective</strong>: $\mathcal{L}_{DenoiseVAE} = \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}$ with $\lambda=1$</li>
<li><strong>Noise Sampling</strong>: Reparameterization trick with $\tilde{x}_i = x_i + \epsilon \sigma_i$</li>
<li><strong>Prior Distribution</strong>: Standard deviation $\sigma=0.1$</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Model Size</strong>: 1.44M parameters total</li>
<li><strong>Fine-tuning Protocol</strong>: Noise Generator discarded after pre-training; only the pre-trained Denoising Module (7-layer EGNN) is retained for downstream fine-tuning</li>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate decay (max LR of 0.0005)</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>System Training</strong>: Fine-tuned end-to-end for specific tasks; force prediction involves computing the gradient of the predicted energy</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Ablation Studies</strong>: Sensitivity analysis confirmed $\lambda=1$ and $\sigma=0.1$ as optimal hyperparameters; removing the KL term leads to trivial solutions (near-zero noise)</li>
<li><strong>Noise Generator Depth</strong>: 4 EGNN layers outperformed 2 layers across both QM9 and MD17 benchmarks</li>
<li><strong>Covariance Structure</strong>: Full covariance matrix (non-independent noise sampling) yielded comparable results to diagonal variance (independent sampling), likely because the EGNN already integrates neighboring atom information</li>
<li><strong>O(3) Invariance</strong>: The method satisfies O(3) probabilistic invariance, meaning the noise distribution is unchanged under rotations and reflections</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU Configuration</strong>: Experiments conducted on a single NVIDIA RTX 3090 GPU; 6 such GPUs (144GB total memory) are sufficient for full reproduction</li>
<li><strong>CPU</strong>: Intel Xeon Gold 5318Y @ 2.10GHz</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, Y., Chen, J., Jiao, R., Li, J., Huang, W., &amp; Su, B. (2025). DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training. <em>The Thirteenth International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liu2025denoisevae,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yurou Liu and Jiahao Chen and Rui Jiao and Jiangmeng Li and Wenbing Huang and Bing Su}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Thirteenth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=ym7pr83XQr}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://iclr.cc/virtual/2025/poster/27701">ICLR 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=ym7pr83XQr">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=ym7pr83XQr">PDF on OpenReview</a></li>
</ul>
]]></content:encoded></item><item><title>Learning Smooth Interatomic Potentials with eSEN</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/learning-smooth-interatomic-potentials/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/learning-smooth-interatomic-potentials/</guid><description>Fu et al. propose energy conservation as a key MLIP diagnostic and introduce eSEN, bridging test accuracy and real performance.</description><content:encoded><![CDATA[<h2 id="paper-overview">Paper Overview</h2>
<p>This is a <strong>method paper</strong>. It addresses a critical disconnect in the evaluation of Machine Learning Interatomic Potentials (MLIPs) and introduces a novel architecture, <strong>eSEN</strong>, designed based on insights from this analysis. The paper proposes a new standard for evaluating MLIPs beyond simple test-set errors.</p>
<h2 id="the-energy-conservation-gap-in-mlip-evaluation">The Energy Conservation Gap in MLIP Evaluation</h2>
<p>The paper is motivated by a well-known but under-addressed problem in the field: improvements in standard MLIP metrics (lower energy/force MAE on static test sets) do not reliably translate to better performance on complex downstream tasks like molecular dynamics (MD) simulations, materials stability prediction, or phonon calculations. The authors seek to understand why this gap exists and how to design models that are both accurate on test sets and physically reliable in practical scientific workflows.</p>
<h2 id="the-esen-architecture-and-continuous-representation">The eSEN Architecture and Continuous Representation</h2>
<p>The novelty is twofold, spanning both a conceptual framework for evaluation and a new model architecture:</p>
<ol>
<li>
<p><strong>Energy Conservation as a Diagnostic Test</strong>: The core conceptual contribution is using an MLIP&rsquo;s ability to conserve energy in out-of-distribution MD simulations as a crucial diagnostic test. The authors demonstrate that for models passing this test, a strong correlation between test-set error and downstream task performance is restored.</p>
</li>
<li>
<p><strong>The eSEN Architecture</strong>: The paper introduces the <strong>equivariant Smooth Energy Network (eSEN)</strong>, designed with specific choices to ensure a smooth and well-behaved Potential Energy Surface (PES):</p>
<ul>
<li><strong>Strictly Conservative Forces</strong>: Forces are computed exclusively as the negative gradient of energy ($F = -\nabla E$), using conservative force prediction instead of faster direct-force prediction heads.</li>
<li><strong>Continuous Representations</strong>: Maintains strict equivariance and smoothness by using equivariant gated non-linearities instead of discretizing spherical harmonic representations during nodewise processing.</li>
<li><strong>Smooth PES Construction</strong>: Critical design choices include using distance cutoffs, polynomial envelope functions that ensure derivatives go to zero at the cutoff, and a limited number of radial basis functions to avoid an overly sensitive PES.</li>
</ul>
</li>
<li>
<p><strong>Efficient Training Strategy</strong>: A two-stage training regimen with fast pre-training using a non-conservative direct-force model, followed by fine-tuning to enforce energy conservation. This captures the efficiency of direct-force training while ensuring physical robustness.</p>
</li>
</ol>
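<p>The strictly conservative force computation ($F = -\nabla E$) can be illustrated on a toy energy. The sketch below (not the paper's implementation; a real model differentiates the energy head with autograd) uses central differences on a pairwise distance-based energy and checks a consequence of translation invariance: the forces sum to zero.</p>

```python
import numpy as np

def energy(pos):
    # Toy pairwise energy standing in for an MLIP energy head;
    # it depends only on interatomic distances.
    d = np.linalg.norm(pos[:, None] - pos[None, :], axis=-1)
    iu = np.triu_indices(len(pos), k=1)
    return np.sum(1.0 / (1.0 + d[iu]))

def conservative_forces(pos, h=1e-5):
    # F = -dE/dr via central differences (a real model uses autograd),
    # so the forces are the exact negative gradient of a scalar field.
    F = np.zeros_like(pos)
    for i in range(pos.shape[0]):
        for k in range(3):
            dp = np.zeros_like(pos)
            dp[i, k] = h
            F[i, k] = -(energy(pos + dp) - energy(pos - dp)) / (2 * h)
    return F

pos = np.random.default_rng(1).normal(size=(4, 3))
F = conservative_forces(pos)
# Translation invariance of the energy implies zero net force.
```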
<h2 id="evaluating-ood-energy-conservation-and-physical-properties">Evaluating OOD Energy Conservation and Physical Properties</h2>
<p>The paper presents a comprehensive experimental validation:</p>
<ol>
<li>
<p><strong>Ablation Studies on Energy Conservation</strong>: MD simulations on out-of-distribution systems (TM23 and MD22 datasets) systematically tested key design choices (direct-force vs. conservative, representation discretization, neighbor limits, envelope functions). This empirically demonstrated which choices lead to energy drift despite negligible impact on test-set MAE.</p>
</li>
<li>
<p><strong>Physical Property Prediction Benchmarks</strong>: The eSEN model was evaluated on challenging downstream tasks:</p>
<ul>
<li><strong>Matbench-Discovery</strong>: Materials stability and thermal conductivity prediction, where eSEN achieved the highest F1 score among compliant models and excelled at both metrics simultaneously.</li>
<li><strong>MDR Phonon Benchmark</strong>: Predicting phonon properties that test accurate second and third-order derivatives of the PES. eSEN achieved state-of-the-art results, particularly outperforming direct-force models.</li>
<li><strong>SPICE-MACE-OFF</strong>: Standard energy and force prediction on organic molecules, demonstrating that physical plausibility design choices enhanced raw accuracy.</li>
</ul>
</li>
<li>
<p><strong>Correlation Analysis</strong>: Explicit plots of test-set energy MAE versus performance on downstream benchmarks showed weak overall correlation that becomes strong and predictive when restricted to models passing the energy conservation test.</p>
</li>
</ol>
<h2 id="outcomes-and-conclusions">Outcomes and Conclusions</h2>
<ul>
<li>
<p><strong>Primary Conclusion</strong>: Energy conservation is a critical, practical property for MLIPs. Using it as a filter re-establishes test-set error as a reliable proxy for model development, dramatically accelerating the innovation cycle. Models that are not conservative, even with low test error, are unreliable for many critical scientific applications.</p>
</li>
<li>
<p><strong>Model Performance</strong>: The eSEN architecture outperforms base models across diverse tasks, from energy/force prediction to geometry optimization, phonon calculations, and thermal conductivity prediction.</p>
</li>
<li>
<p><strong>Actionable Design Principles</strong>: The paper provides experimentally-validated architectural choices that promote physical plausibility. Seemingly minor details, like how atomic neighbors are selected, can have profound impacts on a model&rsquo;s utility in simulations.</p>
</li>
<li>
<p><strong>Efficient Path to Robust Models</strong>: The direct-force pre-training plus conservative fine-tuning strategy offers a practical method for developing physically robust models without incurring the full computational cost of conservative training from scratch.</p>
</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p>The eSEN architecture builds on components from <strong>eSCN</strong> (Equivariant Spherical Channel Network) and <strong>Equiformer</strong>, combining them with design choices that prioritize smoothness and energy conservation. The implementation integrates into the standard <code>fairchem</code> Open Catalyst experimental framework.</p>
<h4 id="layer-structure">Layer Structure</h4>
<ul>
<li><strong>Edgewise Convolution</strong>: Uses <code>SO2</code> convolution layers (from eSCN) with an envelope function applied. Source and target embeddings are concatenated before convolution.</li>
<li><strong>Nodewise Feed-Forward</strong>: Two equivariant linear layers with an intermediate <strong>SiLU-based gated non-linearity</strong> (from Equiformer).</li>
<li><strong>Normalization</strong>: Equivariant Layer Normalization (from Equiformer).</li>
</ul>
<h4 id="smoothness-design-choices">Smoothness Design Choices</h4>
<p>Several architectural decisions distinguish eSEN from prior work:</p>
<ul>
<li><strong>No Grid Projection</strong>: eSEN performs operations directly in the spherical harmonic space to maintain equivariance and energy conservation, bypassing the projection of spherical harmonics to spatial grids for non-linearity.</li>
<li><strong>Distance Cutoff for Graph Construction</strong>: Uses a strict distance cutoff (6 Å for MPTrj models, 5 Å for SPICE models) rather than a max-neighbor limit, since neighbor limits introduce discontinuities that break energy conservation.</li>
<li><strong>Polynomial Envelope Functions</strong>: Ensures derivatives go to zero smoothly at the cutoff radius.</li>
</ul>
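<p>One common polynomial envelope with this property is the DimeNet-style cutoff (shown below as an illustration; the paper's exact envelope may differ). It equals 1 at $r = 0$ and goes to zero at the cutoff with vanishing first and second derivatives:</p>

```python
import numpy as np

def poly_envelope(r, cutoff, p=5):
    # DimeNet-style polynomial envelope (an assumption; eSEN's exact
    # form may differ). Value, first, and second derivatives all
    # vanish smoothly at r = cutoff.
    x = r / cutoff
    out = (1.0
           - (p + 1) * (p + 2) / 2 * x**p
           + p * (p + 2) * x**(p + 1)
           - p * (p + 1) / 2 * x**(p + 2))
    return np.where(x < 1.0, out, 0.0)

r = np.linspace(0.0, 6.0, 7)
env = poly_envelope(r, cutoff=6.0)
# env starts at 1 for r = 0 and decays smoothly to 0 at the 6 Å cutoff.
```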
<h3 id="algorithms">Algorithms</h3>
<h4 id="two-stage-training-esen-30m-mp">Two-Stage Training (eSEN-30M-MP)</h4>
<ol>
<li><strong>Direct-Force Pre-training</strong> (60 epochs): Uses <strong>DeNS</strong> (Denoising Non-equilibrium Structures) to reduce overfitting. This stage is fast because it does not require backpropagation through energy gradients.</li>
<li><strong>Conservative Fine-tuning</strong> (40 epochs): The direct-force head is removed, and forces are calculated via gradients ($F = -\nabla E$). This enforces energy conservation.</li>
</ol>
<p><strong>Important</strong>: DeNS is used exclusively during the direct-force pre-training stage, with a noising probability of 0.5, a standard deviation of 0.1 Å for the added Gaussian noise, and a DeNS loss coefficient of 10. The fine-tuning strategy reduces the wall-clock time for model training by 40% compared to training a conservative model from scratch for the same number of total epochs.</p>
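<p>The DeNS corruption step described above can be sketched as follows (a minimal reading of the stated hyperparameters, not the fairchem implementation): with probability 0.5, every atom is displaced by isotropic Gaussian noise of std 0.1 Å, and the denoising head is trained to recover the noise.</p>

```python
import numpy as np

def dens_corrupt(coords, rng, p_noise=0.5, sigma=0.1):
    # DeNS-style corruption with the hyperparameters quoted above:
    # with probability p_noise, displace every atom by isotropic
    # Gaussian noise of std sigma (in Å). The returned noise is the
    # regression target for the denoising head (zeros if not noised).
    if rng.random() < p_noise:
        noise = rng.normal(scale=sigma, size=coords.shape)
        return coords + noise, noise
    return coords, np.zeros_like(coords)

rng = np.random.default_rng(0)
coords = np.zeros((8, 3))
noised, noise = dens_corrupt(coords, rng)
```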
<h4 id="optimization">Optimization</h4>
<ul>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate scheduler</li>
<li><strong>Max Learning Rate</strong>: $4 \times 10^{-4}$</li>
<li><strong>Batch Size</strong>: 512 (for MPTrj models)</li>
<li><strong>Weight Decay</strong>: $1 \times 10^{-3}$</li>
<li><strong>Gradient Clipping</strong>: Norm of 100</li>
<li><strong>Warmup</strong>: 0.1 epochs with a factor of 0.2</li>
</ul>
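<p>One plausible reading of this schedule (reading &ldquo;warmup factor 0.2&rdquo; as starting at $0.2 \times$ the max learning rate; the exact fairchem scheduler may differ) is:</p>

```python
import math

def lr_at(epoch, total_epochs=100, max_lr=4e-4,
          warmup_epochs=0.1, warmup_factor=0.2):
    # Cosine decay with a short linear warmup, using the hyperparameters
    # listed above. "Factor 0.2" is read as starting at 0.2 * max_lr
    # (an assumption -- the exact fairchem schedule may differ).
    if epoch < warmup_epochs:
        frac = epoch / warmup_epochs
        return max_lr * (warmup_factor + (1.0 - warmup_factor) * frac)
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))
```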
<h4 id="loss-function">Loss Function</h4>
<p>A composite loss combining per-atom energy MAE, force $L_2$ loss, and stress MAE:</p>
<p>$$
\begin{aligned}
\mathcal{L} = \lambda_{\text{e}} \frac{1}{N} \sum_{i=1}^N \lvert E_{i} - \hat{E}_{i} \rvert + \lambda_{\text{f}} \frac{1}{3N} \sum_{i=1}^N \lVert \mathbf{F}_{i} - \hat{\mathbf{F}}_{i} \rVert_2^2 + \lambda_{\text{s}} \lVert \mathbf{S} - \hat{\mathbf{S}} \rVert_1
\end{aligned}
$$</p>
<p>For MPTrj-30M, the weighting coefficients are set to $\lambda_{\text{e}} = 20$, $\lambda_{\text{f}} = 20$, and $\lambda_{\text{s}} = 5$.</p>
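<p>A literal transcription of the loss above, with the MPTrj-30M weights as defaults:</p>

```python
import numpy as np

def mlip_loss(E, E_hat, F, F_hat, S, S_hat,
              lam_e=20.0, lam_f=20.0, lam_s=5.0):
    # Composite loss as displayed above: per-atom energy MAE, force L2
    # averaged over the 3N components, and stress L1. E holds the N
    # per-atom energy terms, F the N x 3 forces, S the 3 x 3 stress.
    N = len(E)
    loss_e = lam_e * np.mean(np.abs(E - E_hat))
    loss_f = lam_f * np.sum((F - F_hat) ** 2) / (3 * N)
    loss_s = lam_s * np.sum(np.abs(S - S_hat))
    return loss_e + loss_f + loss_s
```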
<h3 id="data">Data</h3>
<h4 id="training-data">Training Data</h4>
<ul>
<li><strong>Inorganic</strong>: MPTrj (Materials Project Trajectory) dataset</li>
<li><strong>Organic</strong>: SPICE-MACE-OFF dataset</li>
</ul>
<h4 id="test-data-construction">Test Data Construction</h4>
<ul>
<li><strong>MPTrj Testing</strong>: Since MPTrj lacks an official test split, the authors created a test set using 5,000 random samples from the <strong>subsampled Alexandria (sAlex)</strong> dataset to ensure fair comparison.</li>
<li><strong>Out-of-Distribution Conservation Testing</strong>:
<ul>
<li><em>Inorganic</em>: <strong>TM23</strong> dataset (transition metal defects). Simulation: 100 ps, 5 fs timestep.</li>
<li><em>Organic</em>: <strong>MD22</strong> dataset (large molecules). Simulation: 100 ps, 1 fs timestep.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Compute for training operations predominantly utilizes <strong>80GB NVIDIA A100 GPUs</strong>.</p>
<h4 id="inference-efficiency">Inference Efficiency</h4>
<p>For a periodic system of <strong>216 atoms</strong> on a single A100 (PyTorch 2.4.0, CUDA 12.1, no compile/torchscript), the 2-layer eSEN models achieve approximately <strong>0.4 million steps per day</strong> (3.2M parameters) and <strong>0.8 million steps per day</strong> (6.5M parameters), comparable to MACE-OFF-L at 0.7 million steps per day.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The paper evaluated eSEN across three major benchmark tasks. Key evaluation metrics included energy MAE (meV/atom), force MAE (meV/Å), stress MAE (meV/Å/atom), F1 score for stability prediction, $\kappa_{\text{SRME}}$ for thermal conductivity, and phonon frequency accuracy.</p>
<h4 id="ablation-test-set-mae-table-1">Ablation Test-Set MAE (Table 1)</h4>
<p>Design choices that dramatically affect energy conservation have negligible impact on static test-set MAE, which is precisely why test-set error alone is misleading. All models are 2-layer with 3.2M parameters, $L_{\text{max}} = 2$, $M_{\text{max}} = 2$:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Energy MAE</th>
          <th>Force MAE</th>
          <th>Stress MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>eSEN (default)</td>
          <td>17.02</td>
          <td>43.96</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, direct-force</td>
          <td>18.66</td>
          <td>43.62</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>eSEN, neighbor limit</td>
          <td>17.30</td>
          <td>44.11</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, no envelope</td>
          <td>17.60</td>
          <td>44.69</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, $N_{\text{basis}} = 512$</td>
          <td>19.87</td>
          <td>48.29</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, Bessel</td>
          <td>17.65</td>
          <td>44.83</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=6</td>
          <td>17.05</td>
          <td>43.10</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=10</td>
          <td>17.11</td>
          <td>43.13</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=14</td>
          <td>17.12</td>
          <td>43.09</td>
          <td>0.14</td>
      </tr>
  </tbody>
</table>
<p>Energy MAE in meV/atom. Force MAE in meV/Å. Stress MAE in meV/Å/atom.</p>
<h4 id="matbench-discovery-tables-2-and-3">Matbench-Discovery (Tables 2 and 3)</h4>
<p><strong>Compliant models</strong> (trained only on MPTrj or its subset), unique prototype split:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>DAF</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>0.831</strong></td>
          <td><strong>5.260</strong></td>
          <td><strong>0.340</strong></td>
          <td><strong>0.0752</strong></td>
      </tr>
      <tr>
          <td>eqV2-S-DeNS</td>
          <td>0.815</td>
          <td>5.042</td>
          <td>1.676</td>
          <td>0.0757</td>
      </tr>
      <tr>
          <td>MatRIS-MP</td>
          <td>0.809</td>
          <td>5.049</td>
          <td>0.861</td>
          <td>0.0773</td>
      </tr>
      <tr>
          <td>AlphaNet-MP</td>
          <td>0.799</td>
          <td>4.863</td>
          <td>1.31</td>
          <td>0.1067</td>
      </tr>
      <tr>
          <td>DPA3-v2-MP</td>
          <td>0.786</td>
          <td>4.822</td>
          <td>0.959</td>
          <td>0.0823</td>
      </tr>
      <tr>
          <td>ORB v2 MPtrj</td>
          <td>0.765</td>
          <td>4.702</td>
          <td>1.725</td>
          <td>0.1007</td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>0.760</td>
          <td>4.629</td>
          <td>0.550</td>
          <td>0.0847</td>
      </tr>
      <tr>
          <td>GRACE-2L-MPtrj</td>
          <td>0.691</td>
          <td>4.163</td>
          <td>0.525</td>
          <td>0.0897</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>0.669</td>
          <td>3.777</td>
          <td>0.647</td>
          <td>0.0915</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>0.613</td>
          <td>3.361</td>
          <td>1.717</td>
          <td>0.0949</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>0.569</td>
          <td>2.882</td>
          <td>1.412</td>
          <td>0.1117</td>
      </tr>
  </tbody>
</table>
<p>eSEN-30M-MP excels at both F1 and $\kappa_{\text{SRME}}$ simultaneously, while all previous models only achieve SOTA on one or the other.</p>
<p><strong>Non-compliant models</strong> (trained on additional datasets):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-OAM</strong></td>
          <td><strong>0.925</strong></td>
          <td><strong>0.170</strong></td>
          <td><strong>0.0608</strong></td>
      </tr>
      <tr>
          <td>eqV2-M-OAM</td>
          <td>0.917</td>
          <td>1.771</td>
          <td>0.0691</td>
      </tr>
      <tr>
          <td>ORB v3</td>
          <td>0.905</td>
          <td>0.210</td>
          <td>0.0750</td>
      </tr>
      <tr>
          <td>SevenNet-MF-ompa</td>
          <td>0.901</td>
          <td>0.317</td>
          <td>0.0639</td>
      </tr>
      <tr>
          <td>DPA3-v2-OpenLAM</td>
          <td>0.890</td>
          <td>0.687</td>
          <td>0.0679</td>
      </tr>
      <tr>
          <td>GRACE-2L-OAM</td>
          <td>0.880</td>
          <td>0.294</td>
          <td>0.0666</td>
      </tr>
      <tr>
          <td>MatterSim-v1-5M</td>
          <td>0.862</td>
          <td>0.574</td>
          <td>0.0733</td>
      </tr>
      <tr>
          <td>MACE-MPA-0</td>
          <td>0.852</td>
          <td>0.412</td>
          <td>0.0731</td>
      </tr>
  </tbody>
</table>
<p>The eSEN-30M-OAM model starts from eSEN-30M-OMat (trained on OMat24), then is fine-tuned for 1 epoch on a dataset combining sAlex and 8 copies of MPTrj.</p>
<h4 id="mdr-phonon-benchmark-table-4">MDR Phonon Benchmark (Table 4)</h4>
<p>Metrics: maximum phonon frequency MAE($\omega_{\text{max}}$) in K, vibrational entropy MAE($S$) in J/K/mol, Helmholtz free energy MAE($F$) in kJ/mol, heat capacity MAE($C_V$) in J/K/mol.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>MAE($\omega_{\text{max}}$)</th>
          <th>MAE($S$)</th>
          <th>MAE($F$)</th>
          <th>MAE($C_V$)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>21</strong></td>
          <td><strong>13</strong></td>
          <td><strong>5</strong></td>
          <td><strong>4</strong></td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>26</td>
          <td>28</td>
          <td>10</td>
          <td>5</td>
      </tr>
      <tr>
          <td>GRACE-2L (r6)</td>
          <td>40</td>
          <td>25</td>
          <td>9</td>
          <td>5</td>
      </tr>
      <tr>
          <td>SevenNet-0</td>
          <td>40</td>
          <td>48</td>
          <td>19</td>
          <td>9</td>
      </tr>
      <tr>
          <td>MACE</td>
          <td>61</td>
          <td>60</td>
          <td>24</td>
          <td>13</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>89</td>
          <td>114</td>
          <td>45</td>
          <td>21</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>98</td>
          <td>150</td>
          <td>56</td>
          <td>22</td>
      </tr>
  </tbody>
</table>
<p>Direct-force models show dramatically worse performance at the standard 0.01 Å displacement (e.g., eqV2-S-DeNS: 280/224/54/94) but improve at larger displacements (0.2 Å: 58/26/8/8), revealing that their PES is rough near energy minima.</p>
<h4 id="spice-mace-off-table-5">SPICE-MACE-OFF (Table 5)</h4>
<p>Test set MAE for organic molecule energy/force prediction. Energy MAE in meV/atom, force MAE in meV/Å:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MACE-4.7M (E/F)</th>
          <th>EscAIP-45M* (E/F)</th>
          <th>eSEN-3.2M (E/F)</th>
          <th>eSEN-6.5M (E/F)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PubChem</td>
          <td>0.88 / 14.75</td>
          <td>0.53 / 5.86</td>
          <td>0.22 / 6.10</td>
          <td><strong>0.15</strong> / <strong>4.21</strong></td>
      </tr>
      <tr>
          <td>DES370K M.</td>
          <td>0.59 / 6.58</td>
          <td>0.41 / 3.48</td>
          <td>0.17 / 1.85</td>
          <td><strong>0.13</strong> / <strong>1.24</strong></td>
      </tr>
      <tr>
          <td>DES370K D.</td>
          <td>0.54 / 6.62</td>
          <td>0.38 / 2.18</td>
          <td>0.20 / 2.77</td>
          <td><strong>0.15</strong> / <strong>2.12</strong></td>
      </tr>
      <tr>
          <td>Dipeptides</td>
          <td>0.42 / 10.19</td>
          <td>0.31 / 5.21</td>
          <td>0.10 / 3.04</td>
          <td><strong>0.07</strong> / <strong>2.00</strong></td>
      </tr>
      <tr>
          <td>Sol. AA</td>
          <td>0.98 / 19.43</td>
          <td>0.61 / 11.52</td>
          <td>0.30 / 5.76</td>
          <td><strong>0.25</strong> / <strong>3.68</strong></td>
      </tr>
      <tr>
          <td>Water</td>
          <td>0.83 / 13.57</td>
          <td>0.72 / 10.31</td>
          <td>0.24 / 3.88</td>
          <td><strong>0.15</strong> / <strong>2.50</strong></td>
      </tr>
      <tr>
          <td>QMugs</td>
          <td>0.45 / 16.93</td>
          <td>0.41 / 8.74</td>
          <td>0.16 / 5.70</td>
          <td><strong>0.12</strong> / <strong>3.78</strong></td>
      </tr>
  </tbody>
</table>
<p>*EscAIP-45M is a direct-force model. eSEN-6.5M outperforms MACE-OFF-L and EscAIP on all test splits. The smaller eSEN-3.2M has inference efficiency comparable to MACE-4.7M while achieving lower MAE.</p>
<hr>
<h2 id="why-these-design-choices-matter">Why These Design Choices Matter</h2>
<h3 id="bounded-energy-derivatives-and-the-verlet-integrator">Bounded Energy Derivatives and the Verlet Integrator</h3>
<p>The theoretical foundation for why smoothness matters comes from Theorem 5.1 of Hairer et al. (2003). For the Verlet integrator (the standard NVE integrator), the total energy drift satisfies:</p>
<p>$$
|E(\mathbf{r}_T, \mathbf{a}) - E(\mathbf{r}_0, \mathbf{a})| \leq C \Delta t^2 + C_N \Delta t^N T
$$</p>
<p>where $T$ is the total simulation time ($T \leq \Delta t^{-N}$), $N$ is the highest order for which the $N$th derivative of $E$ is continuously differentiable with bounded derivative, and $C$, $C_N$ are constants independent of $T$ and $\Delta t$. The first term is a time-independent fluctuation of $O(\Delta t^2)$; the second term governs long-term conservation. This means the PES must be continuously differentiable to high order, with bounded derivatives, for energy conservation in long-time simulations.</p>
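<p>The $O(\Delta t^2)$ fluctuation term is easy to observe directly. The sketch below runs velocity Verlet on a 1D harmonic oscillator (a smooth PES, so the bound applies): halving the timestep cuts the worst-case energy error by roughly a factor of four, with no long-term drift.</p>

```python
def simulate(dt, n_steps):
    # Velocity Verlet on E = 0.5 x^2 + 0.5 v^2 (unit mass and frequency);
    # returns the worst-case energy error over the trajectory.
    x, v = 1.0, 0.0
    a = -x
    e0 = 0.5 * x**2 + 0.5 * v**2
    max_err = 0.0
    for _ in range(n_steps):
        x += v * dt + 0.5 * a * dt**2
        a_new = -x
        v += 0.5 * (a + a_new) * dt
        a = a_new
        max_err = max(max_err, abs(0.5 * x**2 + 0.5 * v**2 - e0))
    return max_err

# Same total simulated time, half the timestep: error drops ~4x.
err_coarse = simulate(0.05, 20000)
err_fine = simulate(0.025, 40000)
```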
<h3 id="architectural-choices-that-break-conservation">Architectural Choices That Break Conservation</h3>
<p>The authors provide theoretical justification for why specific architectural choices break energy conservation:</p>
<ul>
<li><strong>Max Neighbor Limit (KNN)</strong>: Introduces discontinuity in the PES. If a neighbor at distance $r$ moves to $r + \epsilon$ and drops out of the top-$K$, the energy changes discontinuously.</li>
<li><strong>Grid Discretization</strong>: Projecting spherical harmonics to a spatial grid introduces discretization errors in energy gradients that break conservation. This can be mitigated with higher-resolution grids but not eliminated.</li>
<li><strong>Direct-Force Prediction</strong>: Imposes no mathematical constraint that forces must be the gradient of an energy scalar field. In other words, $\nabla \times \mathbf{F} \neq 0$ is permitted, violating the requirement for a conservative force field.</li>
</ul>
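<p>The KNN discontinuity is easy to demonstrate on a toy message sum. Below, two neighbors are nearly tied at the $K$-th shell but carry different features: an $\epsilon$-sized displacement swaps which one survives the top-$K$ truncation, producing a finite jump in the output, while a smooth distance cutoff changes by only $O(\epsilon)$ (illustrative toy functions, not the paper's model).</p>

```python
import numpy as np

def knn_message(dists, feats, k=2):
    # Hard top-k neighbor selection, as under a max-neighbor limit.
    idx = np.argsort(dists)[:k]
    return np.sum(feats[idx] / dists[idx])

def cutoff_message(dists, feats, rc=4.0):
    # Smooth distance cutoff: contributions fade continuously to zero.
    x = np.clip(dists / rc, 0.0, 1.0)
    return np.sum((1.0 - x) ** 2 * feats / dists)

eps = 1e-8
feats = np.array([1.0, 2.0, 5.0])
d_before = np.array([1.0, 2.0, 2.0 + eps])
d_after = np.array([1.0, 2.0 + 2 * eps, 2.0 + eps])  # neighbors swap rank

jump_knn = abs(knn_message(d_after, feats) - knn_message(d_before, feats))
jump_cut = abs(cutoff_message(d_after, feats) - cutoff_message(d_before, feats))
# jump_knn stays finite as eps -> 0 (a PES discontinuity);
# jump_cut shrinks in proportion to eps.
```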
<h3 id="displacement-sensitivity-in-phonon-calculations">Displacement Sensitivity in Phonon Calculations</h3>
<p>An important empirical finding concerns how displacement values affect phonon predictions. Conservative models (eSEN, MACE) show convergent phonon band structures as displacement decreases toward zero. In contrast, direct-force models (eqV2-S-DeNS) fail to converge, exhibiting missing acoustic branches and spurious imaginary frequencies at small displacements. While direct-force models achieve competitive thermodynamic property accuracy at large displacements (0.2 Å), this is deceptive: the underlying phonon band structures remain inaccurate, and the apparent accuracy comes from Boltzmann-weighted integrals smoothing over errors.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fu, X., Wood, B. M., Barroso-Luque, L., Levine, D. S., Gao, M., Dzamba, M., &amp; Zitnick, C. L. (2025). Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{fu2025learning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fu, Xiang and Wood, Brandon M. and Barroso-Luque, Luis and Levine, Daniel S. and Gao, Meng and Dzamba, Misko and Zitnick, C. Lawrence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45302">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=R0PBjxIbgm">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=R0PBjxIbgm">PDF on OpenReview</a></li>
<li><a href="https://huggingface.co/facebook/OMAT24">OMAT24 model on Hugging Face</a></li>
<li><a href="https://github.com/facebookresearch/fairchem">Code on GitHub (fairchem)</a></li>
</ul>
]]></content:encoded></item><item><title>Efficient DFT Hamiltonian Prediction via Adaptive Sparsity</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/efficient-dft-hamiltonian-predicton-sphnet/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/efficient-dft-hamiltonian-predicton-sphnet/</guid><description>Luo et al. introduce SPHNet, using adaptive sparsity to dramatically improve SE(3)-equivariant Hamiltonian prediction efficiency.</description><content:encoded><![CDATA[<h2 id="core-innovation-adaptive-sparsity-in-se3-networks">Core Innovation: Adaptive Sparsity in SE(3) Networks</h2>
<p>This is a <strong>methodological paper</strong> introducing a novel architecture and training curriculum to solve efficiency bottlenecks in Geometric Deep Learning. It directly tackles the primary computational bottleneck in modern SE(3)-equivariant graph neural networks (the tensor product operation) and proposes a generalizable solution through adaptive network sparsification.</p>
<h2 id="the-computational-bottleneck-in-dft-hamiltonian-prediction">The Computational Bottleneck in DFT Hamiltonian Prediction</h2>
<p>SE(3)-equivariant networks are accurate but unscalable for DFT Hamiltonian prediction due to two key bottlenecks:</p>
<ul>
<li><strong>Atom Scaling</strong>: Tensor Product (TP) operations grow quadratically with atoms ($N^2$).</li>
<li><strong>Basis Set Scaling</strong>: Computational complexity grows with the sixth power of the angular momentum order ($L^6$). Larger basis sets (e.g., def2-TZVP) require higher orders ($L=6$), making them prohibitively slow.</li>
</ul>
<p>Existing SE(3)-equivariant models cannot handle large molecules (40-100 atoms) with high-quality basis sets, limiting their practical applicability in computational chemistry.</p>
<h2 id="sphnet-architecture-and-the-three-phase-sparsity-scheduler">SPHNet Architecture and the Three-Phase Sparsity Scheduler</h2>
<p><strong>SPHNet</strong> introduces <strong>Adaptive Sparsity</strong> to prune redundant computations at two levels:</p>
<ol>
<li><strong>Sparse Pair Gate</strong>: Learns which atom pairs to include in message passing, adapting the interaction graph based on importance.</li>
<li><strong>Sparse TP Gate</strong>: Filters which spherical harmonic triplets $(l_1, l_2, l_3)$ are computed in tensor product operations, pruning higher-order combinations that contribute less to accuracy.</li>
<li><strong>Three-Phase Sparsity Scheduler</strong>: A training curriculum (Random → Adaptive → Fixed) that enables stable convergence to high-performing sparse subnetworks.</li>
</ol>
<p>Key insight: The Sparse Pair Gate learns to preserve long-range interactions (16&ndash;25 Å) at higher rates than short-range ones. Short-range pairs are abundant and easier to learn, while rare long-range interactions require more samples for accurate representation, making them more critical to retain.</p>
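<p>As a concrete sketch of the gating idea, the snippet below enumerates the valid $(l_1, l_2, l_3)$ tensor-product triplets for a small maximum order and keeps only the top-scoring fraction, mirroring the Sparse TP Gate at a 70% sparsity rate. The importance scores are random placeholders for what SPHNet learns during the adaptive phase.</p>

```python
import numpy as np

def sparse_tp_gate(scores, sparsity=0.7):
    # Keep the top (1 - sparsity) fraction of tensor-product paths by
    # importance score; the rest are pruned from the computation.
    n_keep = max(1, int(len(scores) * (1.0 - sparsity)))
    keep = np.argsort(scores)[-n_keep:]
    mask = np.zeros(len(scores), dtype=bool)
    mask[keep] = True
    return mask

# All valid (l1, l2, l3) triplets with l1, l2, l3 <= 2 and the
# selection rule |l1 - l2| <= l3 <= l1 + l2 for tensor products.
triplets = [(l1, l2, l3)
            for l1 in range(3) for l2 in range(3)
            for l3 in range(abs(l1 - l2), min(l1 + l2, 2) + 1)]
scores = np.random.default_rng(0).random(len(triplets))
mask = sparse_tp_gate(scores, sparsity=0.7)  # boolean keep-mask over paths
```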
<h2 id="benchmarks-and-ablation-studies">Benchmarks and Ablation Studies</h2>
<p>The authors evaluated SPHNet on three datasets (MD17, QH9, and PubChemQH) with varying molecule sizes and basis set complexities. Baselines include SchNOrb, PhiSNet, QHNet, and WANet. SchNOrb and PhiSNet results are limited to MD17, as those models are designed for trajectory datasets. WANet was not open-sourced, so only partial metrics from its paper are reported.</p>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<ul>
<li><strong>Hamiltonian MAE ($H$)</strong>: Mean absolute error between predicted and DFT-computed Hamiltonian matrices, in Hartrees ($E_h$)</li>
<li><strong>Occupied Orbital Energy MAE ($\epsilon$)</strong>: Mean absolute error of all occupied molecular orbital energies derived from the predicted Hamiltonian</li>
<li><strong>Orbital Coefficient Similarity ($\psi$)</strong>: Cosine similarity of occupied molecular orbital coefficients between predicted and reference wavefunctions</li>
</ul>
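<p>One plausible reading of the orbital coefficient similarity metric (an assumption about the exact definition; the absolute value accounts for the arbitrary sign of eigenvectors) is a mean absolute cosine similarity over occupied-orbital coefficient columns:</p>

```python
import numpy as np

def orbital_similarity(C_pred, C_ref):
    # Mean absolute cosine similarity over occupied-orbital coefficient
    # vectors (one column per orbital). The absolute value handles the
    # sign ambiguity of eigenvectors.
    num = np.abs(np.sum(C_pred * C_ref, axis=0))
    den = np.linalg.norm(C_pred, axis=0) * np.linalg.norm(C_ref, axis=0)
    return float(np.mean(num / den))
```

Identical coefficients (up to a global sign flip) score 1.0, while orthogonal coefficient vectors score 0.0.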
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Sparse Gates</strong> (on PubChemQH):</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Both gates</td>
          <td>97.31</td>
          <td>5.62</td>
          <td>7.09x</td>
      </tr>
      <tr>
          <td>Pair Gate only</td>
          <td>87.70</td>
          <td>6.98</td>
          <td>2.73x</td>
      </tr>
      <tr>
          <td>TP Gate only</td>
          <td>94.31</td>
          <td>8.04</td>
          <td>3.98x</td>
      </tr>
      <tr>
          <td>Neither gate</td>
          <td>86.35</td>
          <td>10.91</td>
          <td>1.73x</td>
      </tr>
  </tbody>
</table>
<p>The Sparse Pair Gate contributes a 78% speedup with 30% memory reduction. The Sparse TP Gate (pruning 70% of combinations) yields a 160% speedup. Both gates together achieve the highest speedup, though accuracy slightly decreases compared to no gating.</p>
<p><strong>Three-Phase Scheduler</strong>: Removing the random phase causes convergence to local optima ($112.68 \pm 10.75$ vs $97.31 \pm 0.52$). Removing the adaptive phase increases variance and lowers accuracy ($122.79 \pm 19.02$). Removing the fixed phase has minimal accuracy impact but reduces speedup from 7.09x to 5.45x due to dynamic graph overhead.</p>
<p><strong>Sparsity Rate</strong>: The critical sparsity threshold scales with system complexity: 30% for MD17 (small molecules), 40% for QH9 (medium), and 70% for PubChemQH (large). Beyond the threshold, MAE increases sharply. Computational cost decreases approximately linearly with sparsity rate.</p>
<h3 id="transferability-to-other-models">Transferability to Other Models</h3>
<p>To demonstrate the speedup is architecture-agnostic, the authors applied the Sparse Pair Gate and Sparse TP Gate to the QHNet baseline on PubChemQH:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet baseline</td>
          <td>123.74</td>
          <td>22.50</td>
          <td>1.00x</td>
      </tr>
      <tr>
          <td>+ TP Gate</td>
          <td>128.16</td>
          <td>12.68</td>
          <td>2.04x</td>
      </tr>
      <tr>
          <td>+ Pair Gate</td>
          <td>126.27</td>
          <td>10.07</td>
          <td>1.66x</td>
      </tr>
      <tr>
          <td>+ Both gates</td>
          <td>128.89</td>
          <td>8.46</td>
          <td>3.30x</td>
      </tr>
  </tbody>
</table>
<p>The gates reduced QHNet&rsquo;s memory by 62% and improved speed by 3.3x with modest accuracy trade-off, confirming the gates are portable modules applicable to other SE(3)-equivariant architectures.</p>
<h2 id="performance-results">Performance Results</h2>
<h3 id="qh9-134k-molecules-leq-20-atoms">QH9 (134k molecules, $\leq$ 20 atoms)</h3>
<p>SPHNet achieves 3.3x to 4.0x speedup over QHNet across all four QH9 splits, with improved Hamiltonian MAE and orbital energy MAE. Memory drops to 0.23 GB/sample (33% of QHNet&rsquo;s 0.70 GB). On the stable-iid split, Hamiltonian MAE improves from 76.31 to 45.48 ($10^{-6} E_h$).</p>
<h3 id="pubchemqh-50k-molecules-40-100-atoms">PubChemQH (50k molecules, 40-100 atoms)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>$\epsilon$ [$E_h$] $\downarrow$</th>
          <th>$\psi$ [$10^{-2}$] $\uparrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet</td>
          <td>123.74</td>
          <td>3.33</td>
          <td>2.32</td>
          <td>22.5</td>
          <td>1.0x</td>
      </tr>
      <tr>
          <td>WANet</td>
          <td>99.98</td>
          <td><strong>1.17</strong></td>
          <td><strong>3.13</strong></td>
          <td>15.0</td>
          <td>2.4x</td>
      </tr>
      <tr>
          <td>SPHNet</td>
          <td><strong>97.31</strong></td>
          <td>2.16</td>
          <td>2.97</td>
          <td><strong>5.62</strong></td>
          <td><strong>7.1x</strong></td>
      </tr>
  </tbody>
</table>
<p>SPHNet achieves the best Hamiltonian MAE and efficiency, though WANet outperforms on orbital energy MAE and coefficient similarity. The higher speedup on PubChemQH (vs QH9) reflects greater computational redundancy in larger systems with higher-order basis sets ($L_{max} = 6$ for def2-TZVP vs $L_{max} = 4$ for def2-SVP).</p>
<h3 id="md17-small-molecule-trajectories">MD17 (Small Molecule Trajectories)</h3>
<p>SPHNet achieves accuracy comparable to QHNet and PhiSNet on four MD17 molecules (water, ethanol, malondialdehyde, uracil; 3-12 atoms). MD17 represents a simpler task where baseline models already perform well, leaving limited room for improvement. For water (3 atoms), the number of interaction combinations is inherently small, limiting the benefit of adaptive sparsification.</p>
<h3 id="scaling-limit">Scaling Limit</h3>
<p>SPHNet can train on systems with approximately 3000 atomic orbitals on a single A6000 GPU; the QHNet baseline runs out of memory at approximately 1800 orbitals. Memory consumption scales more favorably as molecule size increases.</p>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li><strong>Adaptive sparsity scales with system complexity</strong>: The method is most effective for large systems where redundancy is high. For small molecules (e.g., water with only 3 atoms), every interaction is critical, so pruning hurts accuracy and yields negligible speedup.</li>
<li><strong>Long-range pair preservation</strong>: The Sparse Pair Gate selects long-range pairs (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are numerous and easier to learn, while rare long-range interactions are harder to represent and thus more critical to retain.</li>
<li><strong>Generalizable components</strong>: The sparsification techniques are portable modules, demonstrated by successful integration into QHNet with 3.3x speedup.</li>
<li><strong>Architecture ablation</strong>: Removing one Vectorial Node Interaction block or Spherical Node Interaction block significantly hurts accuracy, confirming the importance of the progressive order-increase design. Removing one Pair Construction block has less impact, suggesting room for further speedup.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The experiments evaluated SPHNet on three datasets with different molecular sizes and basis set complexities. All datasets use DFT calculations as ground truth, with MD17 using the PBE exchange-correlation functional and QH9/PubChemQH using B3LYP.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Molecule Size</th>
          <th>Basis Set</th>
          <th>$L_{max}$</th>
          <th>Functional</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MD17</td>
          <td>4 systems</td>
          <td>3-12 atoms (water, ethanol, malondialdehyde, uracil)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>PBE</td>
      </tr>
      <tr>
          <td>QH9</td>
          <td>134k</td>
          <td>$\leq$ 20 atoms (Stable/Dynamic splits)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>B3LYP</td>
      </tr>
      <tr>
          <td>PubChemQH</td>
          <td>50k</td>
          <td>40-100 atoms</td>
          <td>def2-TZVP</td>
          <td>6</td>
          <td>B3LYP</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Availability</strong>:</p>
<ul>
<li><strong>MD17 &amp; QH9</strong>: Publicly available</li>
<li><strong>PubChemQH</strong>: Publicly available on Hugging Face (<a href="https://huggingface.co/datasets/EperLuo/PubChemQH">EperLuo/PubChemQH</a>)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>:</p>
<p>The model learns the <strong>residual</strong> $\Delta H$:</p>
<p>$$
\begin{aligned}
\Delta H &amp;= H_{\text{ref}} - H_{\text{init}} \\
\mathcal{L} &amp;= \text{MAE}(H_{\text{ref}}, H_{\text{pred}}) + \text{MSE}(H_{\text{ref}}, H_{\text{pred}})
\end{aligned}
$$</p>
<p>where $H_{\text{init}}$ is a computationally inexpensive initial guess computed via PySCF, and $H_{\text{pred}} = H_{\text{init}} + \Delta H_{\text{pred}}$ is reconstructed by adding the predicted residual back onto that guess.</p>
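A minimal numpy sketch of this objective (assuming the network outputs the residual and the full prediction is reconstructed by adding back the initial guess; the `hamiltonian_loss` name is illustrative):

```python
import numpy as np

def hamiltonian_loss(delta_pred, H_ref, H_init):
    """Combined MAE + MSE loss on the reconstructed Hamiltonian."""
    H_pred = H_init + delta_pred   # add the cheap PySCF initial guess back in
    diff = H_ref - H_pred
    return np.abs(diff).mean() + (diff ** 2).mean()
```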
<p><strong>Hyperparameters</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>PubChemQH</th>
          <th>QH9</th>
          <th>MD17</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Batch Size</td>
          <td>8</td>
          <td>32</td>
          <td>10 (uracil: 5)</td>
      </tr>
      <tr>
          <td>Training Steps</td>
          <td>300k</td>
          <td>260k</td>
          <td>200k</td>
      </tr>
      <tr>
          <td>Warmup Steps</td>
          <td>1k</td>
          <td>1k</td>
          <td>1k</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>1e-3</td>
          <td>1e-3</td>
          <td>5e-4</td>
      </tr>
      <tr>
          <td>Sparsity Rate</td>
          <td>0.7</td>
          <td>0.4</td>
          <td>0.1-0.3</td>
      </tr>
      <tr>
          <td>TSS Epoch $t$</td>
          <td>3</td>
          <td>3</td>
          <td>3</td>
      </tr>
  </tbody>
</table>
<p><strong>Sparse Pair Gate</strong>: Sparsifies the atom-pair interaction graph. It concatenates zero-order features and inner products of atom pairs, then passes them through a linear layer $F_p$ with sigmoid activation to learn a weight $W_p^{ij}$ for every pair. Pairs are kept only if selected by the scheduler ($U_p^{TSS}$). The overhead comes primarily from the linear layer $F_p$.</p>
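A schematic numpy version of the gate (the exact pair featurization and the `sparse_pair_gate` signature are assumptions; here the "inner product" term is taken elementwise before the linear layer):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def sparse_pair_gate(x0_i, x0_j, Fp_w, Fp_b, k):
    """x0_i, x0_j: zero-order features of the two atoms in each candidate
    pair, shape (P, d). Returns per-pair weights and the top-(1-k) keep mask."""
    feats = np.concatenate([x0_i, x0_j, x0_i * x0_j], axis=-1)  # (P, 3d)
    W_pair = sigmoid(feats @ Fp_w + Fp_b).squeeze(-1)           # one weight per pair
    n_keep = max(1, int(round((1.0 - k) * W_pair.size)))
    keep = np.zeros(W_pair.size, dtype=bool)
    keep[np.argsort(W_pair)[-n_keep:]] = True                   # keep largest weights
    return W_pair, keep
```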
<p><strong>Sparse TP Gate</strong>: Filters triplets $(l_1, l_2, l_3)$ inside the TP operation. Higher-order combinations are more likely to be pruned. Complexity: $\mathcal{O}(L^3)$.</p>
<p><strong>Three-Phase Sparsity Scheduler</strong>: Training curriculum designed to optimize the sparse gates effectively:</p>
<ul>
<li><strong>Phase 1 (Random)</strong>: Random selection with keep probability $1-k$ (where $k$ is the sparsity rate) to ensure unbiased weight updates. Complexity: $\mathcal{O}(|U|)$.</li>
<li><strong>Phase 2 (Adaptive)</strong>: Selects the top $(1-k)$ fraction of candidates by learned weight magnitude. Complexity: $\mathcal{O}(|U|\log|U|)$.</li>
<li><strong>Phase 3 (Fixed)</strong>: Freezes the connectivity mask for maximum inference speed. No overhead.</li>
</ul>
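Putting the three phases together, the selection logic amounts to a small masking routine. A numpy sketch (the `tss_mask` helper is hypothetical; the real gates operate on atom pairs and TP triplets inside the network):

```python
import numpy as np

def tss_mask(W, k, phase, rng=None, frozen=None):
    """Boolean keep-mask over |U| candidates for sparsity rate k, per TSS phase."""
    n_keep = max(1, int(round((1.0 - k) * W.size)))
    if phase == "random":        # Phase 1: unbiased random selection
        idx = rng.choice(W.size, size=n_keep, replace=False)
    elif phase == "adaptive":    # Phase 2: top-(1-k) by learned magnitude
        idx = np.argsort(np.abs(W))[-n_keep:]
    else:                        # Phase 3: reuse the frozen connectivity mask
        return frozen
    mask = np.zeros(W.size, dtype=bool)
    mask[idx] = True
    return mask
```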
<p><strong>Weight Initialization</strong>: Learnable sparsity weights ($W$) initialized as all-ones vector.</p>
<h3 id="models">Models</h3>
<p>The model predicts the Hamiltonian matrix $H$ from atomic numbers $Z$ and coordinates $r$.</p>
<p><strong>Inputs</strong>: Atomic numbers ($Z$) and 3D coordinates.</p>
<p><strong>Backbone Structure</strong>:</p>
<ol>
<li><strong>Vectorial Node Interaction (x4)</strong>: Uses long-short range message passing. Extracts vectorial representations ($l=1$) without high-order TPs to save cost.</li>
<li><strong>Spherical Node Interaction (x2)</strong>: Projects features to high-order spherical harmonics (up to $L_{max}$). The first block increases the maximum order from 0 to $L_{max}$ without the Sparse Pair Gate; the second block applies the <strong>Sparse Pair Gate</strong> to filter node pairs.</li>
<li><strong>Pair Construction Block (x2)</strong>: Splits into <strong>Diagonal</strong> (self-interaction) and <strong>Non-Diagonal</strong> (cross-interaction) blocks. Both use the <strong>Sparse TP Gate</strong> to prune cross-order combinations $(l_1, l_2, l_3)$. The Non-Diagonal block also uses the <strong>Sparse Pair Gate</strong>. The first Pair Construction block does not use the Sparse Pair Gate, to ensure complete information flow.</li>
<li><strong>Expansion Block</strong>: Reconstructs the full Hamiltonian matrix from the sparse irreducible representations, exploiting symmetry ($H_{ji} = H_{ij}^T$) to halve computations.</li>
</ol>
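The symmetry trick in the Expansion block can be illustrated by assembling a full matrix from per-atom diagonal blocks and only the upper-triangular pair blocks, mirroring each one via $H_{ji} = H_{ij}^T$. A schematic sketch (block layout and the `assemble_hamiltonian` helper are assumptions, not the authors' code):

```python
import numpy as np

def assemble_hamiltonian(diag_blocks, offdiag_blocks, sizes):
    """diag_blocks: {i: (n_i, n_i) array}; offdiag_blocks: {(i, j): (n_i, n_j)}
    with i < j only; sizes: orbitals per atom. Returns the full symmetric H."""
    offsets = np.concatenate([[0], np.cumsum(sizes)])
    H = np.zeros((offsets[-1], offsets[-1]))
    for i, B in diag_blocks.items():
        H[offsets[i]:offsets[i+1], offsets[i]:offsets[i+1]] = B
    for (i, j), B in offdiag_blocks.items():
        H[offsets[i]:offsets[i+1], offsets[j]:offsets[j+1]] = B
        # Symmetry H_ji = H_ij^T halves the number of pair blocks to predict.
        H[offsets[j]:offsets[j+1], offsets[i]:offsets[i+1]] = B.T
    return H
```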
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4x NVIDIA A100 (80GB)</li>
<li><strong>Benchmarking</strong>: Single NVIDIA RTX A6000 (48GB)</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, E., Wei, X., Huang, L., Li, Y., Yang, H., Xia, Z., Wang, Z., Liu, C., Shao, B., &amp; Zhang, J. (2025). Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>, Vancouver, Canada.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{luo2025efficient,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Erpai and Wei, Xinran and Huang, Lin and Li, Yunyang and Yang, Han and Xia, Zaishuo and Wang, Zun and Liu, Chang and Shao, Bin and Zhang, Jia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45656">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=K3lykWhXON">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=K3lykWhXON">PDF on OpenReview</a></li>
<li><a href="https://github.com/microsoft/SPHNet">GitHub Repository</a> <em>(Note: The official repository was archived by Microsoft in December 2025. It is available for reference but no longer actively maintained.)</em></li>
</ul>
]]></content:encoded></item><item><title>Dark Side of Forces: Non-Conservative ML Force Models</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/dark-side-of-forces/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/dark-side-of-forces/</guid><description>Bigi et al. critique non-conservative force models in ML potentials, showing their simulation failures and proposing hybrid solutions.</description><content:encoded><![CDATA[<h2 id="contribution-systematic-assessment-of-non-conservative-ml-force-models">Contribution: Systematic Assessment of Non-Conservative ML Force Models</h2>
<p>This is a <strong>Systematization</strong> paper. It systematically catalogs the exact failure modes of existing non-conservative force approaches, quantifies them with a new diagnostic metric, and proposes a hybrid Multiple Time-Stepping solution combining the speed benefits of direct force prediction with the physical correctness of conservative models.</p>
<h2 id="motivation-the-speed-accuracy-trade-off-in-ml-force-fields">Motivation: The Speed-Accuracy Trade-off in ML Force Fields</h2>
<p>Many recent machine learning interatomic potential (MLIP) architectures predict forces directly ($F_\theta(r)$). This &ldquo;non-conservative&rdquo; approach avoids the computational overhead of automatic differentiation, yielding faster inference (typically 2-3x speedup) and faster training (up to 3x). However, it sacrifices energy conservation and rotational constraints, potentially destabilizing molecular dynamics simulations. The field lacks rigorous quantification of when this trade-off breaks down and how to mitigate the failures.</p>
<h2 id="novelty-jacobian-asymmetry-and-hybrid-architectures">Novelty: Jacobian Asymmetry and Hybrid Architectures</h2>
<p>Four key contributions:</p>
<ol>
<li>
<p><strong>Jacobian Asymmetry Metric ($\lambda$):</strong> A quantitative diagnostic for non-conservation. Since conservative forces derive from a scalar field, their Jacobian (the Hessian of energy) must be symmetric. The normalized norm of the antisymmetric part quantifies the degree of violation:
$$ \lambda = \frac{|| \mathbf{J}_{\text{anti}} ||_F}{|| \mathbf{J} ||_F} $$
where $\mathbf{J}_{\text{anti}} = (\mathbf{J} - \mathbf{J}^\top)/2$. Measured values range from $\lambda \approx 0.004$ (PET-NC) to $\lambda \approx 0.032$ (SOAP-BPNN-NC), with ORB at 0.015 and EquiformerV2 at 0.017. Notably, the pairwise $\lambda_{ij}$ approaches 1 at large interatomic distances, meaning non-conservative artifacts disproportionately affect long-range and collective interactions.</p>
</li>
<li>
<p><strong>Systematic Failure Mode Catalog:</strong> First comprehensive demonstration that non-conservative models cause runaway heating in NVE ensembles (temperature drifts of ~7,000 billion K/s for PET-NC and ~10x larger for ORB) and equipartition violations in NVT ensembles where different atom types equilibrate to different temperatures, a physical impossibility.</p>
</li>
<li>
<p><strong>Theoretical Analysis of Force vs. Energy Training:</strong> Force-only training overemphasizes high-frequency vibrational modes because force labels carry per-atom gradients that are dominated by stiff, short-range interactions. Energy labels provide a more balanced representation across the frequency spectrum. Additionally, conservative models benefit from backpropagation extending the effective receptive field to approximately 2x the interaction cutoff, while direct-force models are limited to the nominal cutoff radius.</p>
</li>
<li>
<p><strong>Hybrid Training and Inference Protocol:</strong> A practical workflow that combines fast direct-force prediction with conservative corrections:</p>
<ul>
<li><strong>Training:</strong> Pre-train on direct forces, then fine-tune on energy gradients (2-4x faster than training conservative models from scratch)</li>
<li><strong>Inference:</strong> Multiple Time-Stepping (MTS) where fast non-conservative forces are periodically corrected by slower conservative forces</li>
</ul>
</li>
</ol>
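The $\lambda$ metric is simple to estimate numerically. A hedged sketch using central finite differences (the paper obtains the force Jacobian via automatic differentiation; `jacobian_asymmetry` is an illustrative name):

```python
import numpy as np

def jacobian_asymmetry(force_fn, r, h=1e-5):
    """lambda = ||J_anti||_F / ||J||_F with J estimated by central differences."""
    r = r.ravel()
    n = r.size
    J = np.zeros((n, n))
    for k in range(n):
        dp, dm = r.copy(), r.copy()
        dp[k] += h
        dm[k] -= h
        J[:, k] = (force_fn(dp) - force_fn(dm)) / (2 * h)
    J_anti = 0.5 * (J - J.T)   # conservative forces have a symmetric Jacobian
    return np.linalg.norm(J_anti) / np.linalg.norm(J)
```

A conservative force field (e.g. $F = -\nabla \tfrac{1}{2}\lVert r\rVert^2 = -r$) gives $\lambda = 0$, while a pure curl field gives $\lambda = 1$.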
<h2 id="methodology-systematic-failure-mode-analysis">Methodology: Systematic Failure Mode Analysis</h2>
<p>The evaluation systematically tests multiple state-of-the-art models across diverse simulation scenarios:</p>
<p><strong>Models tested:</strong></p>
<ul>
<li><strong>PET-C/PET-NC</strong> (Point Edge Transformer, conservative and non-conservative variants)</li>
<li><strong>PET-M</strong> (hybrid variant jointly predicting both conservative and non-conservative forces)</li>
<li><strong>ORB-v2</strong> (non-conservative, trained on Alexandria/MPtrj)</li>
<li><strong>EquiformerV2</strong> (non-conservative equivariant Transformer)</li>
<li><strong>MACE-MP-0</strong> (conservative message-passing)</li>
<li><strong>SevenNet</strong> (conservative message-passing)</li>
<li><strong>SOAP-BPNN-C/SOAP-BPNN-NC</strong> (descriptor-based baseline, both conservative and non-conservative variants)</li>
</ul>
<p><strong>Test scenarios:</strong></p>
<ol>
<li><strong>NVE stability tests</strong> on bulk liquid water, graphene, amorphous carbon, and FCC aluminum</li>
<li><strong>Thermostat artifact analysis</strong> with Langevin and GLE thermostats</li>
<li><strong>Geometry optimization</strong> on water snapshots and QM9 molecules using FIRE and L-BFGS</li>
<li><strong>MTS validation</strong> on OC20 catalysis dataset</li>
<li><strong>Species-resolved temperature measurements</strong> for equipartition testing</li>
</ol>
<p><strong>Key metrics:</strong></p>
<ul>
<li>Jacobian asymmetry ($\lambda$)</li>
<li>Kinetic temperature drift in NVE</li>
<li>Velocity-velocity correlations</li>
<li>Radial distribution functions</li>
<li>Species-resolved temperatures</li>
<li>Inference speed benchmarks</li>
</ul>
<h2 id="results-simulation-instability-and-hybrid-solutions">Results: Simulation Instability and Hybrid Solutions</h2>
<p>Purely non-conservative models are <strong>unsuitable for production simulations</strong> due to uncontrollable unphysical artifacts that no thermostat can correct. Key findings:</p>
<p><strong>Performance failures:</strong></p>
<ul>
<li>Non-conservative models exhibited catastrophic temperature drift in NVE simulations: ~7,000 billion K/s for PET-NC and ~70,000 billion K/s for ORB, with EquiformerV2 comparable to PET-NC</li>
<li>Strong Langevin thermostats ($\tau=10$ fs) damped diffusion by ~5x, negating the speed benefits of non-conservative models</li>
<li>Advanced GLE thermostats also failed to control non-conservative drift (ORB reached 1181 K vs. 300 K target)</li>
<li>Equipartition violations: under stochastic velocity rescaling, O and H atoms equilibrated at different temperatures. For ORB, H atoms reached 336 K and O atoms 230 K against a 300 K target. For PET-NC, deviations were smaller but still significant (H at 296 K, O at 310 K).</li>
<li>Geometry optimization was more fragile with non-conservative forces: inaccurate NC models (SOAP-BPNN-NC) failed catastrophically, while more accurate ones (PET-NC) could converge with FIRE but showed large force fluctuations with L-BFGS. Non-conservative models consistently had lower success rates across water and QM9 benchmarks.</li>
</ul>
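The equipartition check above is straightforward to reproduce: compute a kinetic temperature per atom type and compare across species. A numpy sketch (the `species_temperatures` helper and its unit conventions are illustrative assumptions):

```python
import numpy as np

KB = 8.617333262e-5  # Boltzmann constant in eV/K (masses/velocities in matching units)

def species_temperatures(masses, velocities, species):
    """Kinetic temperature per atom type: T_s = 2<KE>_s / (3 N_s k_B).
    Under equipartition all species should report the same temperature."""
    temps = {}
    for s in np.unique(species):
        sel = species == s
        ke = 0.5 * (masses[sel, None] * velocities[sel] ** 2).sum()
        temps[s] = 2.0 * ke / (3.0 * sel.sum() * KB)
    return temps
```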
<p><strong>Hybrid solution success:</strong></p>
<ul>
<li>MTS with non-conservative forces corrected every 8 steps ($M=8$) achieved conservative stability with only ~20% overhead compared to a purely non-conservative trajectory. Results were essentially indistinguishable from fully conservative simulations. Higher stride values ($M=16$) became unstable due to resonances between fast degrees of freedom and integration errors.</li>
<li>Conservative fine-tuning achieved the accuracy of from-scratch training in about 1/3 the total training time (2-4x resource reduction)</li>
<li>Validated on OC20 catalysis benchmark</li>
</ul>
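An MTS integrator of this kind can be sketched as an r-RESPA-style loop in which the cheap non-conservative force runs every inner step and the expensive conservative force only supplies a periodic correction kick every $M$ steps. A simplified numpy sketch (illustrative; the paper's i-PI implementation differs in detail):

```python
import numpy as np

def mts_step_block(r, v, m, f_fast, f_slow, dt, M):
    """One outer MTS block: the slow correction f_slow - f_fast is applied
    as half-kicks around M inner velocity-Verlet steps of the fast force."""
    corr = f_slow(r) - f_fast(r)
    v = v + 0.5 * (dt * M) * corr / m      # opening half-kick with slow correction
    for _ in range(M):                     # M inner steps with the cheap force
        v = v + 0.5 * dt * f_fast(r) / m
        r = r + dt * v
        v = v + 0.5 * dt * f_fast(r) / m
    corr = f_slow(r) - f_fast(r)
    v = v + 0.5 * (dt * M) * corr / m      # closing half-kick
    return r, v
```

When `f_fast` equals `f_slow`, the correction vanishes and the scheme reduces to plain velocity Verlet; the paper's $M=8$ means one conservative evaluation per eight cheap ones.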
<p><strong>Scaling caveat:</strong> The authors note that as training datasets grow and models become more expressive, non-conservative artifacts should diminish because accurate models naturally exhibit less non-conservative behavior. However, they argue the best path forward is hybrid approaches rather than waiting for scale to solve the problem.</p>
<p><strong>Recommendation:</strong> The optimal production path is hybrid architectures using direct forces for acceleration (via MTS and pre-training) while anchoring models in conservative energy surfaces. This captures computational benefits without sacrificing physical reliability.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Primary training/evaluation:</strong></p>
<ul>
<li><strong>Bulk Liquid Water</strong> (Cheng et al., 2019): revPBE0-D3 calculations with over 250,000 force/energy targets, chosen for rigorous thermodynamic testing</li>
</ul>
<p><strong>Generalization tests:</strong></p>
<ul>
<li>Graphene, amorphous carbon, FCC aluminum (tested with general-purpose foundation models)</li>
</ul>
<p><strong>Benchmarks:</strong></p>
<ul>
<li><strong>QM9</strong>: Geometry optimization tests</li>
<li><strong>OC20</strong> (Open Catalyst): Oxygen on alloy surfaces for MTS validation</li>
</ul>
<p>All datasets publicly available through cited sources.</p>
<h3 id="models">Models</h3>
<p><strong>Point Edge Transformer (PET)</strong> variants:</p>
<ul>
<li><strong>PET-C (Conservative)</strong>: Forces via energy backpropagation</li>
<li><strong>PET-NC (Non-Conservative)</strong>: Direct force prediction head, slightly higher parameter count</li>
<li><strong>PET-M (Hybrid)</strong>: Jointly predicts both conservative and non-conservative forces, accuracy within ~10% of the best single-task models</li>
</ul>
<p><strong>Baseline comparisons:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Training Data</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ORB-v2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Rotationally unconstrained</td>
      </tr>
      <tr>
          <td>EquiformerV2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Equivariant Transformer</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-C</td>
          <td>Conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-NC</td>
          <td>Non-conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
  </tbody>
</table>
<p><strong>Training details:</strong></p>
<ul>
<li><strong>Loss functions</strong>: PET-C uses joint Energy + Force $L^2$ loss; PET-NC uses Force-only $L^2$ loss</li>
<li><strong>Fine-tuning protocol</strong>: PET-NC converted to conservative via energy head fine-tuning</li>
<li><strong>MTS configuration</strong>: Non-conservative forces with conservative corrections every 8 steps ($M=8$)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics &amp; Software:</strong>
Molecular dynamics evaluations were performed using <strong>i-PI</strong>, while geometry optimizations used <strong>ASE (Atomic Simulation Environment)</strong>. Note that primary code reproducibility is provided via an archived Zenodo snapshot; the authors did not link a live, public GitHub repository.</p>
<ol>
<li><strong>Jacobian asymmetry</strong> ($\lambda$): Quantifies non-conservation via antisymmetric component</li>
<li><strong>Temperature drift</strong>: NVE ensemble stability</li>
<li><strong>Velocity-velocity correlation</strong> ($\hat{c}_{vv}(\omega)$): Thermostat artifact detection</li>
<li><strong>Radial distribution functions</strong> ($g(r)$): Structural accuracy</li>
<li><strong>Species-resolved temperature</strong>: Equipartition testing</li>
<li><strong>Inference speed</strong>: Wall-clock time per MD step</li>
</ol>
<p><strong>Key results:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Speed (ms/step)</th>
          <th>NVE Stability</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PET-NC</td>
          <td>8.58</td>
          <td>Failed</td>
          <td>~7,000 billion K/s drift</td>
      </tr>
      <tr>
          <td>PET-C</td>
          <td>19.4</td>
          <td>Stable</td>
          <td>2.3x slower than PET-NC</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>52.8</td>
          <td>Stable</td>
          <td>Conservative baseline</td>
      </tr>
      <tr>
          <td><strong>PET Hybrid (MTS)</strong></td>
          <td><strong>~10.3</strong></td>
          <td><strong>Stable</strong></td>
          <td><strong>~20% overhead vs. pure NC</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Thermostat artifacts:</strong></p>
<ul>
<li>Langevin ($\tau=10$ fs) damped diffusion by ~5x (weaker coupling at $\tau=100$ fs reduced diffusion by ~1.5x)</li>
<li>GLE thermostats also failed to control non-conservative drift</li>
<li>Equipartition violations under SVR: ORB showed H at 336 K and O at 230 K (target 300 K); PET-NC showed smaller but significant species-resolved deviations</li>
</ul>
<p><strong>Optimization failures:</strong></p>
<ul>
<li>Non-conservative models showed lower geometry optimization success rates across water and QM9 benchmarks, with inaccurate NC models failing catastrophically</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute resources:</strong></p>
<ul>
<li><strong>Training</strong>: From-scratch baseline models were trained using 4x Nvidia H100 GPUs (over a duration of around two days).</li>
<li><strong>Fine-Tuning</strong>: Conservative fine-tuning was performed using a single (1x) Nvidia H100 GPU for a duration of one day.</li>
<li>This hybrid fine-tuning approach achieved a 2-4x reduction in computational resources compared to training conservative models from scratch.</li>
</ul>
<p><strong>Reproduction resources:</strong></p>
<ul>
<li><a href="https://zenodo.org/records/14778891">Zenodo repository</a> (code and data)</li>
<li><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS inference tutorial</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative fine-tuning tutorial</a></li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bigi, F., Langer, M. F., &amp; Ceriotti, M. (2025). The dark side of the forces: assessing non-conservative force models for atomistic machine learning. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{bigi2025dark,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{The dark side of the forces: assessing non-conservative force models for atomistic machine learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bigi, Filippo and Langer, Marcel F and Ceriotti, Michele}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45458">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/pdf?id=OEl3L8osas">PDF on OpenReview</a></li>
<li><a href="https://zenodo.org/records/14778891">Zenodo repository</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS Inference Tutorial</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative Fine-Tuning Tutorial</a></li>
</ul>
]]></content:encoded></item><item><title>Beyond Atoms: 3D Space Modeling for Molecular Pretraining</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/beyond-atoms/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-modeling/beyond-atoms/</guid><description>Lu et al. introduce SpaceFormer, a Transformer that models entire 3D molecular space including atoms for superior representations.</description><content:encoded><![CDATA[<h2 id="paper-typology-and-contribution">Paper Typology and Contribution</h2>
<p>This is a <strong>Method</strong> paper. It challenges the atom-centric paradigm of molecular representation learning by proposing a novel framework that models the continuous 3D space surrounding atoms. The core contribution is <strong>SpaceFormer</strong>, a Transformer-based architecture that discretizes molecular space into grids to capture physical phenomena (electron density, electromagnetic fields) often missed by traditional point-cloud models.</p>
<h2 id="the-physical-intuition-modeling-empty-space">The Physical Intuition: Modeling &ldquo;Empty&rdquo; Space</h2>
<p><strong>The Gap</strong>: Prior 3D molecular representation models, such as Uni-Mol, treat molecules as discrete sets of atoms, essentially point clouds in 3D space. However, from a quantum physics perspective, the &ldquo;empty&rdquo; space between atoms is far from empty. It is permeated by electron density distributions and electromagnetic fields that determine molecular properties.</p>
<p><strong>The Hypothesis</strong>: Explicitly modeling this continuous 3D space alongside discrete atom positions yields superior representations for downstream tasks, particularly for computational properties that depend on electronic structure, such as HOMO/LUMO energies and energy gaps.</p>
<h2 id="a-surprising-observation-virtual-points-improve-representations">A Surprising Observation: Virtual Points Improve Representations</h2>
<p>Before proposing SpaceFormer, the authors present a simple yet revealing experiment. They augment Uni-Mol by adding randomly sampled virtual points (VPs) from the 3D space within the circumscribed cuboid of each molecule. These VPs carry no chemical information whatsoever: they are purely random noise points.</p>
<p>The result is surprising: adding just 10 random VPs already yields a noticeable improvement in validation loss. The improvement remains consistent and gradually increases as the number of VPs grows, eventually reaching a plateau. This observation holds across downstream tasks as well, with Uni-Mol + VPs improving on several quantum property predictions (LUMO, E1-CC2, E2-CC2) compared to vanilla Uni-Mol.</p>
<p>The implication is that even uninformative spatial context helps the model learn better representations, motivating a principled framework for modeling the full 3D molecular space.</p>
<h2 id="spaceformer-voxelization-and-3d-positional-encodings">SpaceFormer: Voxelization and 3D Positional Encodings</h2>
<p>The key innovation is treating the molecular representation problem as <strong>3D space modeling</strong>. SpaceFormer follows these core steps:</p>
<ol>
<li><strong>Voxelizes the entire 3D space</strong> into a grid with cells of $0.49\text{\AA}$ (based on O-H bond length to ensure at most one atom per cell).</li>
<li><strong>Uses adaptive multi-resolution grids</strong> to efficiently handle empty space, keeping it fine-grained near atoms and coarse-grained far away.</li>
<li><strong>Applies Transformers to 3D spatial tokens</strong> with custom positional encodings that achieve linear complexity.</li>
</ol>
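<p>As a concrete illustration of the first step, here is a minimal voxelization sketch (my own toy code, not from the paper; coordinates in &Aring;, function names illustrative):</p>

```python
import math

CELL = 0.49  # grid cell edge length in Angstroms (O-H bond length, per the paper)

def voxelize(coords):
    """Map each 3D atom coordinate to an integer cell index plus the
    continuous offset inside that cell (toy version of the voxelization step)."""
    tokens = []
    for atom in coords:
        cell = tuple(math.floor(c / CELL) for c in atom)
        offset = tuple(c - i * CELL for c, i in zip(atom, cell))
        tokens.append((cell, offset))
    return tokens

# A rough water geometry (coordinates in Angstroms)
water = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
tokens = voxelize(water)
# The 0.49 A cell size guarantees each atom lands in its own cell
assert len({cell for cell, _ in tokens}) == 3
```

<p>Each token thus carries a discrete cell index (where in space) and a continuous offset (where inside the cell), matching the $(t_i, c_i)$ tokenization described later.</p>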
<p>Specifically, the model utilizes two forms of 3D Positional Encoding:</p>
<p><strong>3D Directional PE (RoPE Extension)</strong>
They extend Rotary Positional Encoding (RoPE) to 3D continuous space by splitting the Query and Key vectors into three blocks (one for each spatial axis). The directional attention mechanism takes the form:</p>
<p>$$
\begin{aligned}
\mathbf{q}_{i}^{\top} \mathbf{k}_{j} = \sum_{s=1}^{3} \mathbf{q}_{i,s}^{\top} \mathbf{R}(c_{j,s} - c_{i,s}) \mathbf{k}_{j,s}
\end{aligned}
$$</p>
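<p>The property that makes this work, namely that rotating queries and keys by their own coordinates yields a score depending only on the coordinate <em>difference</em>, can be checked per axis in a few lines (an illustrative sketch; the frequency schedule is a standard RoPE choice, not taken from the paper):</p>

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive feature pairs by angles proportional to a scalar
    coordinate `pos` (one spatial axis). Frequency schedule is illustrative."""
    out = []
    for k in range(0, len(vec), 2):
        theta = pos / (base ** (k / len(vec)))
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[k], vec[k + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.8, 0.5]
kv = [1.0, 0.4, -0.7, 0.9]

# Attention scores depend only on the relative coordinate c_j - c_i:
s1 = dot(rope_rotate(q, 1.0), rope_rotate(kv, 3.0))  # separation 2.0
s2 = dot(rope_rotate(q, 5.0), rope_rotate(kv, 7.0))  # separation 2.0
assert abs(s1 - s2) < 1e-9
```

<p>In the 3D extension, the Q/K vectors are split into three such blocks, one per spatial axis, and the scores sum over axes as in the equation above.</p>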
<p><strong>3D Distance PE (RFF Approximation)</strong>
To compute invariant geometric distance without incurring quadratic memory overhead, they use Random Fourier Features (RFF) to approximate a Gaussian kernel of pairwise distances:</p>
<p>$$
\begin{aligned}
\exp \left( - \frac{| \mathbf{c}_i - \mathbf{c}_j |_2^2}{2\sigma^2} \right) &amp;\approx z(\mathbf{c}_i)^\top z(\mathbf{c}_j) \\
z(\mathbf{c}_i) &amp;= \sqrt{\frac{2}{d}} \cos(\sigma^{-1} \mathbf{c}_i^\top \boldsymbol{\omega} + \mathbf{b})
\end{aligned}
$$</p>
<p>This approach enables the model to natively encode complex field-like phenomena without computing exhaustive $O(N^2)$ distance matrices.</p>
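<p>A minimal sketch of the RFF approximation (toy code with $\sigma = 1$ and an illustrative feature count; not the paper&rsquo;s implementation):</p>

```python
import math, random
random.seed(0)

d = 4096       # number of random features (accuracy improves as d grows)
sigma = 1.0    # Gaussian kernel bandwidth
dim = 3        # spatial dimension

# Shared random projections: omega ~ N(0, I), phase ~ Uniform(0, 2*pi)
omega = [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(d)]
phase = [random.uniform(0.0, 2.0 * math.pi) for _ in range(d)]

def z(c):
    """Random Fourier features with z(c_i)^T z(c_j) ~= exp(-||c_i-c_j||^2 / (2 sigma^2))."""
    scale = math.sqrt(2.0 / d)
    return [scale * math.cos(sum(w * x for w, x in zip(row, c)) / sigma + p)
            for row, p in zip(omega, phase)]

ci, cj = (0.0, 0.0, 0.0), (1.0, 0.5, 0.0)
approx = sum(a * b for a, b in zip(z(ci), z(cj)))
exact = math.exp(-sum((a - b) ** 2 for a, b in zip(ci, cj)) / (2 * sigma ** 2))
assert abs(approx - exact) < 0.1  # Monte Carlo error shrinks as d grows
```

<p>Because each point is mapped once to a $d$-dimensional feature vector, all pairwise kernel values reduce to plain dot products inside attention, which is what removes the explicit $O(N^2)$ distance matrix.</p>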
<h2 id="experimental-setup-and-downstream-tasks">Experimental Setup and Downstream Tasks</h2>
<p><strong>Pretraining Data</strong>: 19 million unlabeled molecules from the same dataset used by Uni-Mol.</p>
<p><strong>Downstream Benchmarks</strong>: The authors propose a new benchmark of 15 tasks, motivated by known limitations of MoleculeNet: invalid structures, inconsistent chemical representations, data curation errors, and an inability to adequately distinguish model performance. The tasks split into two categories:</p>
<ol>
<li>
<p><strong>Computational Properties (Quantum Mechanics)</strong></p>
<ul>
<li>Subsets of <a href="/notes/computational-chemistry/datasets/gdb-17/">GDB-17</a> (HOMO, LUMO, GAP energy prediction, 20K samples; E1-CC2, E2-CC2, f1-CC2, f2-CC2, 21.7K samples)</li>
<li>Polybenzenoid hydrocarbons (Dipole moment, adiabatic ionization potential, D3 dispersion correction)</li>
<li>Metric: Mean Absolute Error (MAE)</li>
</ul>
</li>
<li>
<p><strong>Experimental Properties (Pharma/Bio)</strong></p>
<ul>
<li>MoleculeNet tasks (BBBP, BACE for drug discovery)</li>
<li>Biogen ADME tasks (HLM, MME, Solubility)</li>
<li>Metrics: AUC for classification, MAE for regression</li>
</ul>
</li>
</ol>
<p><strong>Splitting Strategy</strong>: All datasets use an 8:1:1 train/validation/test split with <strong>scaffold splitting</strong> to test out-of-distribution generalization.</p>
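<p>A minimal sketch of what scaffold splitting means in practice: molecules are grouped by a scaffold key and whole groups are assigned to splits, so no scaffold spans two splits (my own illustration; real pipelines typically derive the key as a Bemis-Murcko scaffold SMILES via RDKit):</p>

```python
def scaffold_split(records, frac=(0.8, 0.1, 0.1)):
    """Assign whole scaffold groups to train/val/test, largest groups first.
    `records` is a list of (molecule_id, scaffold_key) pairs."""
    groups = {}
    for mol_id, scaf in records:
        groups.setdefault(scaf, []).append(mol_id)
    splits = [[], [], []]
    caps = [f * len(records) for f in frac]
    for members in sorted(groups.values(), key=len, reverse=True):
        # place the group into the first split that still has room (else train)
        idx = next((i for i in range(3)
                    if len(splits[i]) + len(members) <= caps[i] + 1e-9), 0)
        splits[idx].extend(members)
    return splits

# Hypothetical dataset: 100 molecules over 25 scaffolds
records = [(i, f"scaf{i % 25}") for i in range(100)]
train, val, test = scaffold_split(records)
scafs = [{f"scaf{i % 25}" for i in s} for s in (train, val, test)]
# No scaffold appears in more than one split
assert not (scafs[0] & scafs[1]) and not (scafs[0] & scafs[2]) and not (scafs[1] & scafs[2])
```

<p>Because entire scaffolds are held out, the test set contains core structures the model never saw during training, a harder and more realistic generalization test than a random split.</p>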
<p><strong>Training Setup</strong>:</p>
<ul>
<li><strong>Objective</strong>: Masked Auto-Encoder (MAE) with 30% random masking. The model predicts whether a cell contains an atom and, if so, predicts the atom type and regresses the precise offset position within the cell.</li>
<li><strong>Hardware</strong>: ~50 hours on 8 NVIDIA A100 GPUs</li>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9, \beta_2=0.99$)</li>
<li><strong>Learning Rate</strong>: Peak 1e-4 with linear decay and 0.01 warmup ratio</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>Total Updates</strong>: 1 million</li>
</ul>
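<p>One common reading of &ldquo;peak 1e-4 with linear decay and 0.01 warmup ratio&rdquo; over 1M updates is the schedule below (a sketch of my interpretation; the paper does not spell out the exact shape):</p>

```python
PEAK_LR = 1e-4
TOTAL_STEPS = 1_000_000
WARMUP_STEPS = int(0.01 * TOTAL_STEPS)  # 0.01 warmup ratio -> 10k steps

def lr_at(step):
    """Linear warmup to the peak, then linear decay to zero."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    frac = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return PEAK_LR * (1.0 - frac)

assert lr_at(0) == 0.0
assert lr_at(WARMUP_STEPS) == PEAK_LR
assert abs(lr_at(TOTAL_STEPS)) < 1e-12
```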
<p><strong>Baseline Comparisons</strong>: GROVER (2D graph-based MPR), GEM (2D graph enhanced with 3D information), 3D Infomax (GNN with 3D information), Uni-Mol (3D MPR, primary baseline using the same pretraining dataset), and Mol-AE (extends Uni-Mol with atom-based MAE pretraining).</p>
<h2 id="results-and-analysis">Results and Analysis</h2>
<p><strong>Overall Performance</strong>: SpaceFormer ranked 1st on 10 of 15 tasks and in the top 2 on 14 of 15 tasks. It surpassed the runner-up models by approximately 20% on quantum property tasks (HOMO, LUMO, GAP, E1-CC2, Dipmom), validating that modeling non-atom space captures electronic structure better than atom-only regimes.</p>
<h3 id="key-results-on-quantum-properties">Key Results on Quantum Properties</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>GROVER</th>
          <th>GEM</th>
          <th>3D Infomax</th>
          <th>Uni-Mol</th>
          <th>Mol-AE</th>
          <th><strong>SpaceFormer</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HOMO (Ha)</td>
          <td>0.0075</td>
          <td>0.0068</td>
          <td>0.0065</td>
          <td>0.0052</td>
          <td>0.0050</td>
          <td><strong>0.0042</strong></td>
      </tr>
      <tr>
          <td>LUMO (Ha)</td>
          <td>0.0086</td>
          <td>0.0080</td>
          <td>0.0070</td>
          <td>0.0060</td>
          <td>0.0057</td>
          <td><strong>0.0040</strong></td>
      </tr>
      <tr>
          <td>GAP (Ha)</td>
          <td>0.0109</td>
          <td>0.0107</td>
          <td>0.0095</td>
          <td>0.0081</td>
          <td>0.0080</td>
          <td><strong>0.0064</strong></td>
      </tr>
      <tr>
          <td>E1-CC2 (eV)</td>
          <td>0.0101</td>
          <td>0.0090</td>
          <td>0.0089</td>
          <td>0.0067</td>
          <td>0.0070</td>
          <td><strong>0.0058</strong></td>
      </tr>
      <tr>
          <td>Dipmom (Debye)</td>
          <td>0.0752</td>
          <td>0.0289</td>
          <td>0.0291</td>
          <td>0.0106</td>
          <td>0.0113</td>
          <td><strong>0.0083</strong></td>
      </tr>
  </tbody>
</table>
<p>SpaceFormer&rsquo;s advantage is most pronounced on computational properties that depend on electronic structure. On experimental biological tasks (e.g., BBBP), where measurements are noisy, the advantage narrows or reverses: Uni-Mol achieves 0.9066 AUC on BBBP compared to SpaceFormer&rsquo;s 0.8605.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>The authors present several ablations that isolate the source of SpaceFormer&rsquo;s improvements:</p>
<p><strong>MAE vs. Denoising</strong>: SpaceFormer with MAE pretraining outperforms SpaceFormer with denoising on all four ablation tasks. The MAE objective requires predicting <em>whether</em> an atom exists in a masked voxel, which forces the model to learn global structural dependencies. In the denoising variant, only atom cells are masked so the model never needs to predict atom existence, reducing the task to coordinate regression.</p>
<p><strong>FLOPs Control</strong>: A SpaceFormer-Large model (4x width, atom-only) trained with comparable FLOPs still falls short of SpaceFormer with 1000 non-atom cells on most downstream tasks. This confirms the improvement comes from modeling 3D space, not from additional compute.</p>
<p><strong>Virtual Points vs. SpaceFormer</strong>: Adding up to 200 random virtual points to Uni-Mol improves some tasks but leaves a significant gap compared to SpaceFormer, demonstrating that principled space discretization outperforms naive point augmentation.</p>
<p><strong>Efficiency Validation</strong>: The Adaptive Grid Merging method reduces the number of cells by roughly 10x with virtually no performance degradation. The 3D positional encodings scale linearly with the number of cells, while Uni-Mol&rsquo;s pretraining cost scales quadratically.</p>
<h3 id="scope-and-future-directions">Scope and Future Directions</h3>
<p>SpaceFormer does not incorporate built-in SE(3) equivariance, relying instead on data augmentation (random rotations and random boundary padding) during training. The authors identify extending SpaceFormer to force field tasks and larger systems such as proteins and complexes as promising future directions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="code-and-data-availability">Code and Data Availability</h3>
<ul>
<li><strong>Source Code</strong>: As of this writing, the authors have not released the official source code or pre-trained weights.</li>
<li><strong>Datasets</strong>: Pretraining utilized the same 19M unlabeled molecule dataset as Uni-Mol. Downstream tasks use a newly curated internal benchmark built from subsets of GDB-17, MoleculeNet, and Biogen ADME. The exact customized scaffold splits for these evaluations are pending the official code release.</li>
<li><strong>Compute</strong>: Pretraining the base SpaceFormer encoder (~67.8M parameters, configured to merge level 3) required approximately 50 hours on 8 NVIDIA A100 GPUs.</li>
</ul>
<h3 id="models">Models</h3>
<p>The model treats a molecule as a 3D &ldquo;image&rdquo; via voxelization, processed by a Transformer.</p>
<p><strong>Input Representation</strong>:</p>
<ul>
<li><strong>Discretization</strong>: 3D space divided into grid cells with length <strong>$0.49\text{\AA}$</strong> (based on O-H bond length to ensure at most one atom per cell)</li>
<li><strong>Tokenization</strong>: Tokens are pairs $(t_i, c_i)$ where $t_i$ is atom type (or NULL) and $c_i$ is the coordinate</li>
<li><strong>Embeddings</strong>: Continuous embeddings with dimension 512. Inner-cell positions discretized with $0.01\text{\AA}$ precision</li>
</ul>
<p><strong>Transformer Specifications</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Layers</th>
          <th>Attention Heads</th>
          <th>Embedding Dim</th>
          <th>FFN Dim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Encoder</strong></td>
          <td>16</td>
          <td>8</td>
          <td>512</td>
          <td>2048</td>
      </tr>
      <tr>
          <td><strong>Decoder</strong> (MAE)</td>
          <td>4</td>
          <td>4</td>
          <td>256</td>
          <td>1024</td>
      </tr>
  </tbody>
</table>
<p><strong>Attention Mechanism</strong>: FlashAttention for efficient handling of large sequence lengths.</p>
<p><strong>Positional Encodings</strong>:</p>
<ol>
<li><strong>3D Directional PE</strong>: Extension of Rotary Positional Embedding (RoPE) to 3D continuous space, capturing relative directionality</li>
<li><strong>3D Distance PE</strong>: Random Fourier Features (RFF) to approximate Gaussian kernel of pairwise distances with linear complexity</li>
</ol>
<h4 id="visualizing-rff-and-rope">Visualizing RFF and RoPE</h4>















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-rff-rope-visualization.webp"
         alt="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         title="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visual intuition for SpaceFormer&rsquo;s positional encodings: Top row shows RFF distance encoding (Gaussian-like attention decay and high-frequency feature fingerprints). Bottom row shows RoPE directional encoding (vector rotation fields and resulting attention patterns).</figcaption>
    
</figure>

<p><strong>Top Row (Distance / RFF):</strong> Shows how the model learns &ldquo;closeness.&rdquo; Distance is represented by a complex &ldquo;fingerprint&rdquo; of waves that creates a Gaussian-like force field.</p>
<ul>
<li><strong>Top Left (The Force Field):</strong> The attention score (dot product) naturally forms a Gaussian curve. It is high when atoms are close and decays to zero as they move apart. This mimics physical forces without the model needing to learn that math from scratch.</li>
<li><strong>Top Right (The Fingerprint):</strong> Each dimension oscillates at a different frequency. A specific distance (e.g., $d=2$) has a unique combination of high and low values across these dimensions, creating a unique &ldquo;fingerprint&rdquo; for that exact distance.</li>
</ul>
<p><strong>Bottom Row (Direction / RoPE):</strong> Shows how the model learns &ldquo;relative position.&rdquo; It visualizes the vector rotation and how that creates a grid-like attention pattern.</p>
<ul>
<li><strong>Bottom Left (The Rotation):</strong> This visualizes the &ldquo;X-axis chunk&rdquo; of the vector. As you move from left ($x=-3$) to right ($x=3$), the arrows rotate. The model compares angles between atoms to determine relative positions.</li>
<li><strong>Bottom Right (The Grid):</strong> The resulting attention pattern when combining X-rotations and Y-rotations. The red/blue regions show where the model pays attention relative to the center, forming a grid-like interference pattern that distinguishes relative positions (e.g., &ldquo;top-right&rdquo; vs &ldquo;bottom-left&rdquo;).</li>
</ul>
<h4 id="adaptive-grid-merging">Adaptive Grid Merging</h4>
<p>To make the 3D grid approach computationally tractable, two key strategies are employed:</p>
<ol>
<li><strong>Grid Sampling</strong>: Randomly selecting 10-20% of empty cells during training</li>
<li><strong>Adaptive Grid Merging</strong>: Recursively merging $2 \times 2 \times 2$ blocks of empty cells into larger &ldquo;coarse&rdquo; cells, creating a multi-resolution view that is fine-grained near atoms and coarse-grained in empty space (merging set to Level 3)</li>
</ol>
<p><strong>Visualizing Adaptive Grid Merging</strong>:</p>















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-merging.webp"
         alt="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         title="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging demonstrated on H₂O. Red cells (Level 0) contain atoms and remain at full resolution. Progressively darker blue cells represent merged empty regions at higher levels, covering the same volume with fewer tokens.</figcaption>
    
</figure>

<p>The adaptive grid process compresses empty space around molecules while maintaining high resolution near atoms:</p>
<ul>
<li><strong>Red Cells (Level 0):</strong> The smallest squares ($0.49$Å) containing atoms. These are kept at highest resolution because electron density changes rapidly here.</li>
<li><strong>Light Blue Cells (Level 0/1):</strong> Small empty regions close to atoms.</li>
<li><strong>Darker Blue Cells (Level 2/3):</strong> Large blocks of empty space further away.</li>
</ul>
<p>If we used a naive uniform grid, we would have to process thousands of empty &ldquo;Level 0&rdquo; cells containing almost zero information. By merging them into larger blocks (the dark blue squares), the model covers the same volume with significantly fewer input tokens, reducing the number of tokens by roughly <strong>10x</strong> compared to a dense grid.</p>
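<p>The merging procedure can be sketched as a recursive subdivision. The 2D quadtree analogue below (my own illustration; the paper merges $2 \times 2 \times 2$ blocks in 3D) shows how a single occupied cell keeps its neighborhood fine-grained while empty regions collapse into a handful of coarse tokens:</p>

```python
def merge(occupied, levels):
    """Recursively merge 2x2 blocks of empty cells (2D analogue of the
    paper's 2x2x2 merging). `occupied` is a set of (row, col) level-0
    cells containing atoms; returns (level, row, col) output tokens."""
    tokens = []
    def recurse(level, r, c):
        span = 2 ** (levels - level)            # level-0 cells covered per side
        cells = {(r + i, c + j) for i in range(span) for j in range(span)}
        if not (cells & occupied):              # fully empty: one coarse token
            tokens.append((level, r, c))
        elif level == levels:                   # single occupied level-0 cell
            tokens.append((level, r, c))
        else:                                   # mixed: split into 4 children
            half = span // 2
            for dr in (0, half):
                for dc in (0, half):
                    recurse(level + 1, r + dr, c + dc)
    recurse(0, 0, 0)
    return tokens

# One atom in the corner of an 8x8 grid (levels=3)
tokens = merge({(0, 0)}, levels=3)
assert len(tokens) < 8 * 8  # 10 multi-resolution tokens replace 64 dense cells
```

<p>Here 64 dense cells collapse to 10 tokens: four fine cells around the atom plus six coarse empty blocks, mirroring the roughly 10x token reduction reported for the 3D case.</p>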















<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-benzene.webp"
         alt="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         title="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging for benzene (C₆H₆). The model maintains maximum resolution (red Level 0 cells) only where atoms exist, while merging vast empty regions into large blocks (dark blue L3/L4 cells). This allows the model to focus computational power on chemically active zones.</figcaption>
    
</figure>

<p>The benzene example above demonstrates how this scales to larger molecules. The characteristic hexagonal ring of 6 carbon atoms (black) and 6 hydrogen atoms (white) occupies a small fraction of the total grid. The dark blue corners (L3, L4) represent massive merged blocks of empty space, so the vast majority of input tokens, and therefore compute, is concentrated on the red &ldquo;active&rdquo; zones where the chemistry actually happens.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lu, S., Ji, X., Zhang, B., Yao, L., Liu, S., Gao, Z., Zhang, L., &amp; Ke, G. (2025). Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>. <a href="https://proceedings.mlr.press/v267/lu25e.html">https://proceedings.mlr.press/v267/lu25e.html</a></p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lu2025beyond,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lu, Shuqi and Ji, Xiaohong and Zhang, Bohang and Yao, Lin and Liu, Siyuan and Gao, Zhifeng and Zhang, Linfeng and Ke, Guolin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=Wd9KPQCKwq">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=Wd9KPQCKwq">PDF on OpenReview</a></li>
<li><a href="https://icml.cc/virtual/2025/poster/45004">ICML 2025 poster page</a></li>
</ul>
]]></content:encoded></item><item><title>Contrastive Learning for Variational Autoencoder Priors</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</link><pubDate>Sun, 17 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</guid><description>Aneja et al.'s NeurIPS 2021 paper introducing Noise Contrastive Priors (NCPs) to address VAE's 'prior hole' problem with energy-based priors.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a novel training approach for Variational Autoencoders (VAEs) to address fundamental limitations in their generative quality through improved prior learning.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by a critical limitation in Variational Autoencoders known as the <strong>&ldquo;prior hole&rdquo; problem</strong>, where the prior distribution p(z) fails to match the aggregate approximate posterior q(z). This mismatch leads to areas in the latent space with high density under the prior that don&rsquo;t map to realistic data samples, resulting in poor generative quality compared to GANs and other generative models.</p>















<figure class="post-figure center ">
    <img src="/img/notes/vae-prior-hole-problem-illustrated.webp"
         alt="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         title="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;prior hole&rsquo; problem: the standard Gaussian prior (red dashed contours) assigns highest probability to the center, but the aggregate posterior (blue dots) forms a ring with no data in that region.</figcaption>
    
</figure>

<p>The figure above illustrates this mismatch. The blue dots represent where a trained encoder actually places data in the latent space (the aggregate posterior $q(z)$), which often forms complex, non-Gaussian shapes. The red dashed contours show the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$, which assumes data is centered at the origin. When generating new samples, we draw from this prior, making it likely to sample from the empty &ldquo;hole&rdquo; where the decoder has never seen training data, producing unrealistic outputs.</p>
<p>A natural question arises: the prior $p(z)$ is used for <em>sampling</em> at inference time, so why does learning a better prior also improve <em>likelihood</em> (NLL)? The answer lies in the VAE objective. VAEs maximize the Evidence Lower Bound (ELBO):</p>
<p>$$ \log p(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(z|x) \parallel p(z))}_{\text{Regularization}} $$</p>
<p>The KL divergence term penalizes the mismatch between each data point&rsquo;s approximate posterior $q(z|x)$ and the prior $p(z)$. When the prior is a simple Gaussian but the aggregate posterior forms a complex shape (as in the figure above), this KL term remains unnecessarily high for every data point.</p>
<p>By replacing the simple prior with a learned $p_{\text{NCP}}(z)$ that matches the aggregate posterior, the KL penalty decreases, tightening the ELBO and improving NLL. The learned prior thus provides a <strong>unified solution</strong>: better likelihood during training (tighter bound) and better sampling at inference (no &ldquo;holes&rdquo;).</p>
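<p>A toy 1D check of this argument (my own illustration, not an experiment from the paper): when the per-datapoint posteriors sit away from the origin, a prior matched to the aggregate posterior yields a much smaller KL term than the standard Gaussian.</p>

```python
import math, random
random.seed(0)

def log_normal(z, mu, std):
    return -0.5 * ((z - mu) / std) ** 2 - math.log(std * math.sqrt(2 * math.pi))

def kl_mc(mu, std, log_prior, n=20000):
    """Monte Carlo estimate of KL( N(mu, std) || prior )."""
    total = 0.0
    for _ in range(n):
        z = random.gauss(mu, std)
        total += log_normal(z, mu, std) - log_prior(z)
    return total / n

# Two per-datapoint posteriors q(z|x) sitting away from the origin
posteriors = [(-2.0, 0.5), (2.0, 0.5)]

def log_base(z):     # standard Gaussian prior p(z) = N(0, 1)
    return log_normal(z, 0.0, 1.0)

def log_matched(z):  # mixture prior matched to the aggregate posterior
    return math.log(0.5 * math.exp(log_normal(z, -2.0, 0.5))
                    + 0.5 * math.exp(log_normal(z, 2.0, 0.5)))

kl_base = sum(kl_mc(mu, s, log_base) for mu, s in posteriors) / 2
kl_matched = sum(kl_mc(mu, s, log_matched) for mu, s in posteriors) / 2
assert kl_matched < kl_base  # the matched prior shrinks the ELBO's KL penalty
```

<p>The same KL gap that hurts the training bound is what makes sampling from the mismatched prior produce unrealistic outputs, which is why one learned prior addresses both symptoms.</p>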
<p>The OpenReview discussion contains a significant theoretical debate regarding the paper&rsquo;s core premise. Reviewers argued that the &ldquo;prior hole&rdquo; problem is actually a failure of the posterior to match the prior, or a failure of the encoder. The authors defended their approach by noting that even with a perfect posterior, a simple Normal prior might fail because the decoder lacks capacity to map a simple distribution to complex data without dropping modes. On this view, the remedy is to make the prior learned and complex rather than to further constrain the encoder.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The authors propose a novel <strong>energy-based model (EBM) prior</strong> that is trained using <strong>Noise Contrastive Estimation (NCE)</strong>, which they term a <strong>Noise Contrastive Prior (NCP)</strong>. The key innovations are:</p>
<ul>
<li><strong>Two-Stage Training Process</strong>: First, a standard VAE is trained with a simple base prior. Then, the VAE weights are frozen and a binary classifier learns to distinguish between samples from the aggregate posterior q(z) and the base prior p(z).</li>
<li><strong>Reweighting Strategy</strong>: The core idea is to reweight a base prior distribution p(z) with a learned reweighting factor r(z) to make the resulting prior $p_{\text{NCP}}(z)$ better match the aggregate posterior q(z).</li>
<li><strong>NCE for EBM Training</strong>: The method frames EBM training as a binary classification task to avoid computationally expensive MCMC sampling.</li>
<li><strong>Scalability to Hierarchical Models</strong>: For hierarchical VAEs with multiple latent groups, the NCP approach can be applied independently and in parallel to each group&rsquo;s conditional prior.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling</li>
<li><strong>CelebA 64x64</strong>: Applied to both standard VAE architectures and more advanced VAEs with GMM priors (RAE model)</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>The experiments compared FID scores, likelihood metrics, and qualitative sample quality between baseline VAEs and NCP-enhanced versions, with particular focus on state-of-the-art hierarchical VAEs (NVAE).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The proposed NCP method demonstrated improvements in generative quality across evaluated datasets, with modest gains on standard VAEs and particularly large gains on hierarchical models like NVAE:</p>
<ul>
<li><strong>CelebA-64</strong>: NCP improved FID scores from 48.12 to 41.28 for standard VAEs, and from 40.95 to 39.00 for RAE models with GMM priors.</li>
<li><strong>Hierarchical Models (NVAE)</strong>: The impact was particularly pronounced on state-of-the-art hierarchical VAEs:
<ul>
<li><strong>CIFAR-10</strong>: FID improved from 51.71 to 24.08</li>
<li><strong>CelebA-64</strong>: FID improved dramatically from 13.48 to 5.25, making it competitive with GANs</li>
<li><strong>CelebA HQ 256x256</strong>: FID reduced from 40.26 to 24.79</li>
</ul>
</li>
<li><strong>Likelihood Performance</strong>: On MNIST, NCP-VAE achieved 78.10 nats NLL vs. baseline NVAE&rsquo;s 78.67 nats</li>
</ul>
<p>The key conclusions are that <strong>two-stage training with noise contrastive estimation</strong> provides an effective framework for learning expressive energy-based priors that addresses the prior hole problem while scaling efficiently to hierarchical models.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="algorithms">Algorithms</h3>
<h4 id="the-reweighting-mechanism">The Reweighting Mechanism</h4>
<p>The core innovation is defining the NCP prior as $p_{\text{NCP}}(z) \propto p(z)r(z)$. The reweighting factor $r(z)$ is derived from the binary classifier $D(z)$ using the <strong>likelihood ratio trick</strong>:</p>
<p>$$ r(z) \approx \frac{D(z)}{1 - D(z)} $$</p>
<p>Here, $D(z)$ is the sigmoid output of the trained discriminator, representing the probability that sample $z$ came from the aggregate posterior $q(z)$ (&ldquo;real&rdquo;). For an optimal discriminator $D^*(z)$, this ratio exactly equals $\frac{q(z)}{p(z)}$, allowing the model to approximate the density ratio without explicit density estimation.</p>
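<p>A minimal end-to-end sketch of the density-ratio trick (toy 1D distributions chosen so the optimal classifier is exactly linear; the architecture and training loop are illustrative, not the paper&rsquo;s):</p>

```python
import math, random
random.seed(0)

# "Aggregate posterior" q(z) = N(1, 1) and base prior p(z) = N(0, 1).
# Their exact log density ratio is log q(z)/p(z) = z - 1/2 (linear in z),
# so a logistic classifier on the raw feature z can recover it.
q_samples = [random.gauss(1.0, 1.0) for _ in range(2000)]  # label 1 (posterior)
p_samples = [random.gauss(0.0, 1.0) for _ in range(2000)]  # label 0 (prior)
data = [(z, 1.0) for z in q_samples] + [(z, 0.0) for z in p_samples]

w, b = 0.0, 0.0
for _ in range(300):  # full-batch gradient ascent on the log-likelihood
    gw = gb = 0.0
    for z, y in data:
        d = 1.0 / (1.0 + math.exp(-(w * z + b)))  # D(z)
        gw += (y - d) * z
        gb += (y - d)
    w += gw / len(data)
    b += gb / len(data)

# Likelihood-ratio trick: r(z) = D(z)/(1 - D(z)) = exp(w z + b) ~ q(z)/p(z)
ratio = math.exp(w * 0.5 + b)  # r(0.5); the true ratio is exp(0.5 - 0.5) = 1
assert abs(ratio - 1.0) < 0.3
```

<p>The classifier never estimates either density explicitly; it only learns their ratio, which is all the reweighted prior $p_{\text{NCP}}(z) \propto p(z)r(z)$ needs.</p>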















<figure class="post-figure center ">
    <img src="/img/notes/ncp-vae-reweighting-the-prior-posterior.webp"
         alt="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         title="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reweighting mechanism: the learned factor $r(z)$ (bottom) reweights the simple Gaussian prior $p(z)$ (middle) to approximate the complex aggregate posterior $q(z)$ (top). Where $q(z)$ has high density but $p(z)$ is low, $r(z)$ compensates with high values.</figcaption>
    
</figure>

<h4 id="hierarchical-architecture-strategy">Hierarchical Architecture Strategy</h4>
<p>For hierarchical models (like NVAE), the method trains $K$ binary classifiers in parallel (one for each latent group). Crucially, to ensure efficiency, the classifiers reuse the <strong>context feature</strong> $c(z_{&lt;k})$ extracted by the frozen VAE&rsquo;s prior network. This architectural choice provides significant computational savings.</p>
<h4 id="test-time-sampling-inference">Test-Time Sampling (Inference)</h4>
<p>Since $p_{\text{NCP}}(z)$ is an energy-based model with an unnormalized density, direct sampling is intractable. The paper employs two methods to generate samples:</p>
<ul>
<li><strong>Sampling-Importance-Resampling (SIR):</strong> Used for most results. It draws $M$ samples (e.g., $M=5000$) from the base prior $p(z)$ and resamples them based on weights $w^{(m)} = r(z^{(m)})$.</li>
<li><strong>Langevin Dynamics (LD):</strong> An iterative refinement method using the gradient of the energy function $E(z) = -\log r(z) - \log p(z)$.</li>
</ul>
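<p>The SIR procedure can be sketched in a toy 1D setting where the reweighting factor is assumed known in closed form, so the resulting distribution can be checked (my own illustration; $r(z)$ here is a hypothetical stand-in for the learned classifier ratio):</p>

```python
import math, random
random.seed(0)

# Base prior p(z) = N(0, 1); suppose the learned reweighting factor is
# r(z) = exp(z - 1/2), so p(z) * r(z) is proportional to a N(1, 1) density.
def r(z):
    return math.exp(z - 0.5)

def sir_sample(m=1000):
    """Sampling-importance-resampling: draw m proposals from the base
    prior, then resample one with probability proportional to r(z)."""
    proposals = [random.gauss(0.0, 1.0) for _ in range(m)]
    weights = [r(z) for z in proposals]
    return random.choices(proposals, weights=weights, k=1)[0]

draws = [sir_sample() for _ in range(2000)]
mean = sum(draws) / len(draws)
assert abs(mean - 1.0) < 0.15  # draws follow N(1, 1), not the base N(0, 1)
```

<p>The quality/cost trade-off reported below comes directly from $M$: more proposals per image give a better approximation of $p_{\text{NCP}}$ but linearly more decoder-free prior evaluations.</p>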
<h3 id="models">Models</h3>
<h4 id="decoder-architecture">Decoder Architecture</h4>
<p>For RGB datasets (CIFAR-10, CelebA), the output likelihood must be changed from <strong>Discretized Logistic</strong> (standard NVAE) to a <strong>Normal distribution</strong>. The authors note this change alone led to &ldquo;significant improvements in the base model performance.&rdquo; Using the standard NVAE decoder will result in a weaker baseline than reported.</p>
<h4 id="discriminator-architecture">Discriminator Architecture</h4>
<p>The binary classifier uses a ResNet-style architecture with <strong>Squeeze-and-Excitation (SE)</strong> blocks:</p>
<ul>
<li><strong>Activation:</strong> Swish</li>
<li><strong>Normalization:</strong> Batch Normalization</li>
<li><strong>Optimization:</strong> Adam with Cosine Annealing (learning rate: $10^{-3} \to 10^{-7}$)</li>
</ul>
<p>The SE blocks help the model focus on channel-wise feature recalibration, which is important for distinguishing subtle differences between prior and aggregate posterior in high-dimensional latent spaces.</p>
<h3 id="hardware">Hardware</h3>
<p>The main paper is vague on training time, but the OpenReview rebuttal explicitly lists hardware costs:</p>
<ul>
<li><strong>Hardware:</strong> NVIDIA Tesla V100 (32GB) GPUs</li>
<li><strong>Per-Discriminator Training:</strong> ~13 hours for 100 epochs</li>
<li><strong>Parallelization:</strong> Because latent groups are independent, all discriminators can train in parallel</li>
<li><strong>Total Cost (CelebA-64):</strong> ~8.1 GPU-days</li>
<li><strong>Comparison:</strong> The authors argue this is efficient compared to VDVAE, which requires ~560 GPU-days</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<h4 id="inference-speed-vs-quality-trade-off">Inference Speed vs. Quality Trade-off</h4>
<p>Reviewers flagged that SIR sampling can be prohibitively slow. The authors clarified the exact trade-off:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Proposal Samples ($M$)</th>
          <th style="text-align: left">Time per Image</th>
          <th style="text-align: left">FID (CelebA-64)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">5,000 (paper default)</td>
          <td style="text-align: left">~10.11 seconds</td>
          <td style="text-align: left">5.25</td>
      </tr>
      <tr>
          <td style="text-align: left">500 (practical)</td>
          <td style="text-align: left">~1.25 seconds</td>
          <td style="text-align: left">6.76</td>
      </tr>
  </tbody>
</table>
<p>The quality gain from 500 to 5,000 samples is marginal. <strong>For practical applications, $M=500$ is recommended.</strong></p>
<h4 id="hyperparameters">Hyperparameters</h4>
<ul>
<li><strong>FID Calculation:</strong> 50,000 samples</li>
<li><strong>SIR Proposals:</strong> 5,000 samples (paper default) or 500 (practical)</li>
<li><strong>MNIST:</strong> Dynamically binarized version used for likelihood evaluation</li>
<li><strong>Optimizers:</strong> The study largely adopts hyperparameters from baseline papers (e.g., Lawson et al. for MNIST, Ghosh et al. for RAE)</li>
</ul>
<h4 id="debugging-benchmark-25-gaussians">Debugging Benchmark: 25-Gaussians</h4>
<p>The supplement provides a toy experiment ideal for verifying a new implementation before running on expensive image datasets:</p>
<ul>
<li><strong>Setup:</strong> Synthetic dataset of 25 2D-Gaussians arranged on a grid</li>
<li><strong>Target NLL:</strong> ~-0.954 nats (NCP) vs. ~-2.753 nats (Vanilla VAE)</li>
<li><strong>Success Criterion:</strong> Samples should avoid low-density regions between grid points. A standard VAE will generate samples in these &ldquo;prior holes,&rdquo; while a working NCP implementation should cleanly remove these artifacts.</li>
</ul>
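<p>Generating this toy dataset takes only a few lines. A numpy sketch, with the grid spacing and mode standard deviation as assumed values (check the supplement for the exact settings):</p>

```python
import numpy as np

def sample_25_gaussians(n, spacing=2.0, std=0.05, seed=None):
    """Mixture of 25 Gaussians on a 5x5 grid (spacing/std are assumed values)."""
    rng = np.random.default_rng(seed)
    grid = np.arange(-2, 3) * spacing            # 5 positions per axis
    means = np.array([(gx, gy) for gx in grid for gy in grid])
    idx = rng.integers(len(means), size=n)
    return means[idx] + std * rng.standard_normal((n, 2))

data = sample_25_gaussians(1000, seed=0)
```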
<h4 id="implementation-warnings">Implementation Warnings</h4>
<ul>
<li><strong>SIR Failure Mode:</strong> If the learned prior $p_{\text{NCP}}$ deviates too far from the base prior, SIR sampling collapses (low effective sample size). The authors suggest Hamiltonian Monte Carlo (HMC) as an alternative.</li>
<li><strong>Temperature Scaling:</strong> The quantitative results (FID) use temperature $T=1.0$, but qualitative images in the paper likely use reduced temperature for improved visual sharpness.</li>
</ul>
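<p>A practical way to detect the SIR failure mode is to monitor the effective sample size of the importance weights. A sketch of the standard estimator:</p>

```python
import numpy as np

def effective_sample_size(log_w):
    """ESS of importance weights: (sum w)^2 / sum w^2, in (0, M]."""
    log_w = np.asarray(log_w, dtype=float)
    w = np.exp(log_w - log_w.max())          # stabilize; ESS is scale-invariant
    return w.sum() ** 2 / (w ** 2).sum()

# Well-matched prior: near-uniform weights -> ESS close to M
ess_good = effective_sample_size(np.zeros(5000))
# NCP prior far from the base prior: one proposal dominates -> ESS collapses
ess_bad = effective_sample_size(np.array([0.0] + [-50.0] * 4999))
```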
<h3 id="data">Data</h3>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling (32x32 RGB images)</li>
<li><strong>CelebA 64x64</strong>: Face generation task with moderate resolution</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>All datasets use standard train/test splits from the computer vision literature.</p>
<h4 id="additional-metrics">Additional Metrics</h4>
<p>Beyond FID and NLL, the paper uses:</p>
<ul>
<li><strong>Effective Sample Size (ESS):</strong> Validates reliability of the SIR sampling procedure</li>
<li><strong>Maximum Mean Discrepancy (MMD):</strong> Measures distance between aggregate posterior and NCP prior distributions</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Aneja, J., Schwing, A. G., Kautz, J., &amp; Vahdat, A. (2021). A contrastive learning approach for training variational autoencoder priors. <em>Advances in Neural Information Processing Systems</em>, 34, 29604-29616. <a href="https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html">https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{aneja2021contrastive,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Contrastive Learning Approach for Training Variational Autoencoder Priors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Aneja, Jyoti and Schwing, Alexander G and Kautz, Jan and Vahdat, Arash}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{29604--29616}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview Discussion</a></li>
<li><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code Repository</a> (Google Drive; link may become inaccessible over time)</li>
</ul>
]]></content:encoded></item><item><title>3D Steerable CNNs: Rotationally Equivariant Features</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/3d-steerable-cnns/</link><pubDate>Thu, 16 Jan 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/3d-steerable-cnns/</guid><description>Weiler et al.'s NeurIPS 2018 paper introducing SE(3)-equivariant CNNs for volumetric data using group theory and spherical harmonics.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a novel neural network architecture, the 3D Steerable CNN. It provides a comprehensive theoretical derivation for the architecture grounded in group representation theory and demonstrates its practical application.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by the prevalence of <strong>symmetry</strong> in problems from the natural sciences. Standard 3D CNNs lack inherent equivariance to 3D rotations, a fundamental symmetry governed by the SE(3) group in many scientific datasets like molecular or protein structures. Building this symmetry directly into the model architecture as an <strong>inductive bias</strong> is expected to yield more data-efficient, generalizable, and physically meaningful models.</p>















<figure class="post-figure center ">
    <img src="/img/notes/3d-cnn-versus-3d-steerable-cnn.webp"
         alt="Comparison of standard 3D CNN versus 3D Steerable CNN for handling rotational symmetry. Panel A shows how standard CNNs produce distorted outputs when inputs are rotated, requiring data augmentation. Panel B shows how Steerable CNNs use spherical harmonic kernel bases to produce equivariant geometric field outputs that transform predictably under rotation."
         title="Comparison of standard 3D CNN versus 3D Steerable CNN for handling rotational symmetry. Panel A shows how standard CNNs produce distorted outputs when inputs are rotated, requiring data augmentation. Panel B shows how Steerable CNNs use spherical harmonic kernel bases to produce equivariant geometric field outputs that transform predictably under rotation."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Standard 3D CNNs (Panel A) produce inconsistent feature maps when inputs are rotated, requiring expensive data augmentation. 3D Steerable CNNs (Panel B) use analytically-derived spherical harmonic kernels to produce geometric field outputs that transform equivariantly under rotation.</figcaption>
    
</figure>

<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the rigorous and practical construction of a CNN architecture that is equivariant to 3D rigid body motions (SE(3) group). The key contributions are:</p>
<ul>
<li><strong>Geometric Feature Representation</strong>: Features are modeled as geometric <strong>fields</strong> (collections of scalars, vectors, and higher-order tensors) defined over $\mathbb{R}^{3}$. Each type of feature transforms according to an <strong>irreducible representation (irrep)</strong> of the rotation group SO(3).</li>
<li><strong>General Equivariant Convolution</strong>: The paper proves that the most general form of an SE(3)-equivariant linear map between these fields is a convolution with a <strong>rotation-steerable kernel</strong>.</li>
<li><strong>Analytical Kernel Basis</strong>: The main theoretical breakthrough is the analytical derivation of a complete basis for these steerable kernels. They solve the kernel&rsquo;s equivariance constraint, $\kappa(rx) = D^{j}(r)\kappa(x)D^{l}(r)^{-1}$, showing the solutions are functions whose angular components are <strong>spherical harmonics</strong>. The network&rsquo;s kernels are then parameterized as a learnable linear combination of these pre-computed basis functions, making the implementation a minor modification to standard 3D convolutions.</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/notes/spherical-harmonics.webp"
         alt="Spherical harmonics visualization showing the angular basis functions organized by degree l (rows) and order m (columns). Row 0 shows the single s-type orbital (l=0), row 1 shows three p-type orbitals (l=1), row 2 shows five d-type orbitals (l=2), and row 3 shows seven f-type orbitals (l=3)."
         title="Spherical harmonics visualization showing the angular basis functions organized by degree l (rows) and order m (columns). Row 0 shows the single s-type orbital (l=0), row 1 shows three p-type orbitals (l=1), row 2 shows five d-type orbitals (l=2), and row 3 shows seven f-type orbitals (l=3)."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spherical harmonics $Y_l^m$ organized by degree $l$ (rows) and order $m$ (columns). These functions form the angular basis for steerable kernels: $l=0$ (scalar), $l=1$ (vector/p-orbital), $l=2$ (rank-2 tensor/d-orbital), $l=3$ (rank-3 tensor/f-orbital). Each degree $l$ has $2l+1$ components.</figcaption>
    
</figure>

<ul>
<li><strong>Equivariant Nonlinearity</strong>: A novel <strong>gated nonlinearity</strong> is proposed for non-scalar features. It preserves equivariance by multiplying a feature field by a separately computed, learned scalar field (the gate).</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The model&rsquo;s performance was evaluated on a series of tasks with inherent rotational symmetry:</p>
<ol>
<li><strong>Tetris Classification</strong>: A toy problem to empirically validate the model&rsquo;s rotational equivariance by training on aligned blocks and testing on randomly rotated ones.</li>
<li><strong>SHREC17 3D Model Classification</strong>: A benchmark for classifying complex 3D shapes that are arbitrarily rotated.</li>
<li><strong>Amino Acid Propensity Prediction</strong>: A scientific application to predict amino acid types from their 3D atomic environments.</li>
<li><strong>CATH Protein Structure Classification</strong>: A challenging task on a new dataset introduced by the authors, requiring classification of global protein architecture, a problem with full SE(3) invariance.</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The 3D Steerable CNN demonstrated significant advantages due to its built-in equivariance:</p>
<ul>
<li>It was empirically confirmed to be <strong>rotationally equivariant</strong>, achieving 99% test accuracy on the rotated Tetris dataset, compared to a standard 3D CNN&rsquo;s 27% accuracy.</li>
<li>On the amino acid prediction task the model achieves 0.58 accuracy, compared to 0.50 (regular-grid) and 0.56 (concentric-grid) baselines, using roughly half the parameters. On SHREC17 it scores 1.11 in combined micro+macro MAP, slightly behind the 1.13 of the leading contemporary system.</li>
<li>On the CATH protein classification task, it <strong>outperformed a deep 3D CNN baseline</strong> while using ~110x fewer parameters. This performance gap widened as the training data was reduced, highlighting the model&rsquo;s superior <strong>data efficiency</strong>.</li>
</ul>
<p>The paper concludes that 3D Steerable CNNs provide a universal and effective framework for incorporating SE(3) symmetry into deep learning models, leading to improved accuracy and efficiency for tasks involving volumetric data, particularly in scientific domains.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Format</strong>: All inputs must be voxelized. Point clouds require voxelization before use.
<ul>
<li><strong>Proteins (CATH)</strong>: $50^3$ grid, 0.2 nm voxel size. Gaussian density placed at atom centers.</li>
<li><strong>3D Objects (SHREC17)</strong>: $64^3$ voxel grids.</li>
<li><strong>Tetris</strong>: $36^3$ input grid.</li>
</ul>
</li>
<li><strong>Splitting Strategy</strong>: CATH used a 10-fold split (7 train, 1 val, 2 test) strictly separated by &ldquo;superfamily&rdquo; level to prevent data leakage (&lt;40% sequence identity).</li>
</ul>
<h3 id="models">Models</h3>
<p><strong>Kernel Basis Construction</strong>:</p>
<ul>
<li>Constructed from <strong>Spherical Harmonics</strong> multiplied by <strong>Gaussian Radial Shells</strong>: $\exp\left(-\frac{1}{2}(|x|-m)^{2}/\sigma^{2}\right)$</li>
<li><strong>Anti-aliasing</strong>: A radially dependent angular frequency cutoff ($J_{\max}$) is applied to prevent aliasing near the origin.</li>
</ul>
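<p>The construction can be sketched for the simplest non-trivial case, an $l=1 \to l=0$ kernel: a degree-1 real spherical harmonic for the angular part times a Gaussian radial shell. The shell width, shell radius, and the $(y, z, x)$ component ordering below are illustrative choices; the final lines numerically verify the steerability constraint $\kappa(rx) = D(r)\kappa(x)$ for a rotation about $z$:</p>

```python
import numpy as np

def radial_shell(r, m, sigma=0.6):
    """Gaussian radial shell centered at radius m (width is an assumption)."""
    return np.exp(-0.5 * (r - m) ** 2 / sigma ** 2)

def y1(x):
    """Real degree-1 spherical harmonics, ordered (m=-1,0,1) ~ (y, z, x)/|x|."""
    r = np.linalg.norm(x, axis=-1, keepdims=True)
    return np.stack([x[..., 1], x[..., 2], x[..., 0]], axis=-1) / np.clip(r, 1e-9, None)

def kernel_l1(x, shell_radius=1.0):
    """A steerable l=1 -> l=0 kernel: Y_1 angular part times a radial shell."""
    r = np.linalg.norm(x, axis=-1)
    return radial_shell(r, shell_radius)[..., None] * y1(x)

# Check kappa(R x) = D(R) kappa(x) for a rotation about the z-axis.
theta = 0.3
c, s = np.cos(theta), np.sin(theta)
R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])   # acts on (x, y, z)
D = np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])   # same rotation on (y, z, x)

x = np.array([0.4, -0.7, 0.2])
lhs = kernel_l1(R @ x)
rhs = D @ kernel_l1(x)
```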
<p><strong>Normalization</strong>: Uses <strong>Equivariant Batch Norm</strong>. Non-scalar fields are normalized by the average of their norms.</p>
<p><strong>Downsampling</strong>: Standard strided convolution breaks equivariance. The architecture uses <strong>low-pass filtering</strong> (Gaussian blur) before downsampling to maintain equivariance.</p>
<p><strong>Exact Architecture Configurations</strong>:</p>
<p><strong>Tetris Architecture</strong> (4 layers):</p>
<table>
  <thead>
      <tr>
          <th>Layer</th>
          <th>Field Types</th>
          <th>Spatial Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Input</td>
          <td>-</td>
          <td>$36^3$</td>
      </tr>
      <tr>
          <td>Layer 1</td>
          <td>4 scalars, 4 vectors ($l=1$), 4 tensors ($l=2$)</td>
          <td>$40^3$</td>
      </tr>
      <tr>
          <td>Layer 2</td>
          <td>16 scalars, 16 vectors, 16 tensors</td>
          <td>$22^3$ (stride 2)</td>
      </tr>
      <tr>
          <td>Layer 3</td>
          <td>32 scalars, 16 vectors, 16 tensors</td>
          <td>$13^3$ (stride 2)</td>
      </tr>
      <tr>
          <td>Output</td>
          <td>8 scalars (global average pool)</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p><strong>SHREC17 Architecture</strong> (8 layers):</p>
<table>
  <thead>
      <tr>
          <th>Layers</th>
          <th>Field Types</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1-2</td>
          <td>8 scalars, 4 vectors, 2 tensors ($l=2$)</td>
      </tr>
      <tr>
          <td>3-4</td>
          <td>16 scalars, 8 vectors, 4 tensors</td>
      </tr>
      <tr>
          <td>5-7</td>
          <td>32 scalars, 16 vectors, 8 tensors</td>
      </tr>
      <tr>
          <td>8</td>
          <td>512 scalars</td>
      </tr>
      <tr>
          <td>Output</td>
          <td>55 scalars (classes)</td>
      </tr>
  </tbody>
</table>
<p><strong>CATH Architecture</strong> (ResNet34-inspired):</p>
<p>Block progression: <code>(2,2,2,2)</code> → <code>(4,4,4)</code> → <code>(8,8,8,8)</code> → <code>(16,16,16,16)</code></p>
<p>Notation: <code>(a,b,c,d)</code> = $a$ scalars ($l=0$), $b$ vectors ($l=1$), $c$ rank-2 tensors ($l=2$), $d$ rank-3 tensors ($l=3$).</p>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Parameter Counts</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Model</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CATH</td>
          <td>3D Steerable CNN</td>
          <td>143,560</td>
      </tr>
      <tr>
          <td>CATH</td>
          <td>Baseline (ResNet34-style)</td>
          <td>15,878,764</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>3D Steerable CNN</td>
          <td>~32,600,000</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>Regular grid baseline</td>
          <td>~61,100,000</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>Concentric grid baseline</td>
          <td>~75,300,000</td>
      </tr>
  </tbody>
</table>
<p>Note: The concentric grid baseline groups voxels by distance from the molecular center, reflecting that atomic interactions are primarily distance-dependent (Torng, W., &amp; Altman, R. B. (2017). 3D deep convolutional neural networks for amino acid environment similarity analysis. <em>BMC Bioinformatics</em>, 18, 302). Amino acid parameter counts are rounded figures as reported in the paper.</p>
<p><strong>Hyperparameters &amp; Training</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Optimizer</strong></td>
          <td>Adam</td>
      </tr>
      <tr>
          <td><strong>LR Scheduler</strong></td>
          <td>Exponential decay (0.94/epoch) after 40 epoch burn-in</td>
      </tr>
      <tr>
          <td><strong>Dropout</strong> (CATH)</td>
          <td>0.1 (Capsule-wide convolutional dropout)</td>
      </tr>
      <tr>
          <td><strong>Weight Decay</strong> (CATH)</td>
          <td>L1 &amp; L2 regularization: $10^{-8.5}$</td>
      </tr>
  </tbody>
</table>
<p><strong>Mathematical Formulations for Equivariance</strong>:</p>
<p>Standard operations like Batch Normalization and ReLU break rotational equivariance. The paper derives equivariant alternatives:</p>
<p><strong>Equivariant Batch Normalization</strong>:</p>
<p>Standard BN subtracts a mean, which introduces a preferred direction and breaks symmetry. <strong>Norm-based normalization</strong> scales feature fields by the average of their squared norms to preserve symmetry:</p>
<p>$$f_{i}(x) \mapsto f_{i}(x) \left( \frac{1}{|B|} \sum_{j \in B} \frac{1}{V} \int dx |f_{j}(x)|^{2} + \epsilon \right)^{-1/2}$$</p>
<p>This scales vector lengths to unit variance on average while avoiding mean subtraction, preserving directional information and symmetry.</p>
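<p>A minimal numpy sketch of this norm-based normalization for a single feature-field type (omitting any learnable scale the real layer may carry):</p>

```python
import numpy as np

def equivariant_norm_bn(f, eps=1e-5):
    """Norm-based normalization of a batch of feature fields.

    f: (batch, components, *spatial), e.g. components = 3 for an l=1 field.
    Divides by the root of the batch- and spatially-averaged squared norm;
    nothing is subtracted, so no preferred direction is introduced.
    """
    B = f.shape[0]
    sq_norm = (f ** 2).sum(axis=1)               # |f_j(x)|^2 at each voxel
    scale = sq_norm.reshape(B, -1).mean() + eps  # average over batch and space
    return f / np.sqrt(scale)

rng = np.random.default_rng(0)
f = 3.0 * rng.standard_normal((8, 3, 4, 4, 4))   # batch of 3-vector fields
g = equivariant_norm_bn(f)
```

Because the scale depends only on norms, rotating the field commutes with the normalization, which is exactly the equivariance property the paper needs.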
<p><strong>Equivariant Nonlinearities</strong>:</p>
<p>Applying ReLU to vector components independently breaks equivariance, since the result depends on the choice of coordinate frame. Two equivariant alternatives:</p>
<ol>
<li>
<p><strong>Norm Nonlinearity</strong> (geometric shrinking): Acts on magnitude, preserves direction. Shrinks vectors shorter than learned bias $\beta$ to zero:
$$f(x) \mapsto \text{ReLU}(|f(x)| - \beta) \frac{f(x)}{|f(x)|}$$
<em>Note: Found to converge slowly; omitted from final models.</em></p>
</li>
<li>
<p><strong>Gated Nonlinearity</strong> (used in practice): A separate scalar field $s(x)$ passes through sigmoid to create a gate $\sigma(s(x))$, which multiplies the geometric field:
$$f_{\text{out}}(x) = f_{\text{in}}(x) \cdot \sigma(s(x))$$
<em>Architecture implication: Requires extra scalar channels ($l=0$) specifically for gating higher-order channels ($l&gt;0$).</em></p>
</li>
</ol>
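<p>The gated nonlinearity is a one-liner in practice. A sketch for a single vector field, where the scalar gate channel is assumed to come from dedicated $l=0$ outputs of the previous layer:</p>

```python
import numpy as np

def gated_nonlinearity(f, s):
    """Gate a non-scalar field f by a scalar field s.

    f: (components, *spatial) geometric field (components = 3 for l=1);
    s: (*spatial,) scalar field from dedicated l=0 gate channels.
    The sigmoid rescales magnitudes only, so directions are untouched
    and equivariance is preserved.
    """
    gate = 1.0 / (1.0 + np.exp(-s))
    return f * gate[None]

rng = np.random.default_rng(0)
f = rng.standard_normal((3, 4, 4, 4))   # one vector field
s = rng.standard_normal((4, 4, 4))      # its scalar gate channel
out = gated_nonlinearity(f, s)
```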
<p><strong>Voxelization Details</strong>:</p>
<p>For CATH protein inputs, Gaussian density is placed at each atom position with standard deviation equal to <strong>half the voxel width</strong> ($0.5 \times 0.2\text{ nm} = 0.1\text{ nm}$).</p>
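<p>The voxelization step can be sketched as below; centering the grid at the origin is an assumption here (in practice the grid is placed around the structure of interest):</p>

```python
import numpy as np

def voxelize(atom_positions_nm, grid_size=50, voxel_nm=0.2):
    """Gaussian-density voxelization of atom coordinates (in nm).

    Each atom contributes a Gaussian with sigma = half the voxel width,
    accumulated on a grid_size^3 grid centered at the origin.
    """
    sigma = 0.5 * voxel_nm
    half = grid_size * voxel_nm / 2.0
    axis = np.linspace(-half + voxel_nm / 2.0, half - voxel_nm / 2.0, grid_size)
    X, Y, Z = np.meshgrid(axis, axis, axis, indexing="ij")
    grid = np.zeros((grid_size,) * 3)
    for p in np.atleast_2d(atom_positions_nm):
        d2 = (X - p[0]) ** 2 + (Y - p[1]) ** 2 + (Z - p[2]) ** 2
        grid += np.exp(-0.5 * d2 / sigma ** 2)
    return grid

density = voxelize([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]])
```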
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>Steerable CNN</th>
          <th>Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tetris (rotated test)</td>
          <td>Accuracy</td>
          <td>99%</td>
          <td>27% (standard 3D CNN)</td>
      </tr>
      <tr>
          <td>Amino Acid Propensity</td>
          <td>Accuracy</td>
          <td><strong>0.58</strong> (32.6M params)</td>
          <td>0.50 (regular grid, 61.1M params); 0.56 (concentric grid, 75.3M params)</td>
      </tr>
      <tr>
          <td>SHREC17</td>
          <td>micro + macro MAP (higher is better)</td>
          <td>1.11</td>
          <td>1.13 (SOTA)</td>
      </tr>
      <tr>
          <td>CATH</td>
          <td>Accuracy</td>
          <td>Higher across all training set sizes (see Figure 4; not reported as a single value) (143,560 params)</td>
          <td>Deep 3D CNN (15,878,764 params; ~110x more)</td>
      </tr>
  </tbody>
</table>
<p>Note: On SHREC17, the metric is micro MAP + macro MAP combined (higher is better); Steerable CNN: 0.717 micro (from Table 4) + ~0.394 macro (back-calculated: 1.11 - 0.717) = 1.11. On CATH, the steerable CNN outperformed the baseline with ~110x fewer parameters, a gap that widened as training data was reduced.</p>
<h2 id="historical-context-from-peer-reviews">Historical Context (From Peer Reviews)</h2>
<p>The NeurIPS peer reviews reveal important context about the paper&rsquo;s structure and claims:</p>
<ul>
<li>
<p><strong>Evolution of Experiments</strong>: The <strong>SHREC17</strong> experiment and the <strong>arbitrary rotation</strong> test in Tetris were added during the rebuttal phase to address reviewer concerns about the lack of standard computer vision benchmarks. This explains why SHREC17 feels somewhat disconnected from the paper&rsquo;s &ldquo;AI for Science&rdquo; narrative.</p>
</li>
<li>
<p><strong>Continuous vs. Discrete Rotations</strong>: The Tetris experiment validates equivariance to <strong>continuous</strong> ($SO(3)$) rotations alongside discrete 90-degree turns. This distinction is crucial and separates Steerable CNNs from earlier Group CNNs that handled discrete rotation groups exclusively.</p>
</li>
<li>
<p><strong>Terminology Warning</strong>: The use of terms like &ldquo;fiber&rdquo; and &ldquo;induced representation&rdquo; was critiqued by reviewers as being denser than necessary and inconsistent with related work (e.g., Tensor Field Networks). If you find Section 3 difficult, this is a known barrier of this paper. Focus on the resulting kernel constraints.</p>
</li>
<li>
<p><strong>Parameter Efficiency Quantified</strong>: Reviewers highlighted that the main practical win is <strong>parameter efficiency</strong>. Standard 3D CNNs hit diminishing returns around $10^7$ parameters, while Steerable CNNs achieve better results with ~110x fewer parameters ($10^5$).</p>
</li>
</ul>
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/ENLJACPHSEA?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Weiler, M., Geiger, M., Welling, M., Boomsma, W., &amp; Cohen, T. S. (2018). 3D steerable CNNs: Learning rotationally equivariant features in volumetric data. <em>Advances in Neural Information Processing Systems</em>, 31. <a href="https://proceedings.neurips.cc/paper/2018/hash/488e4104520c6aab692863cc1dba45af-Abstract.html">https://proceedings.neurips.cc/paper/2018/hash/488e4104520c6aab692863cc1dba45af-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{weiler20183d,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Weiler, Maurice and Geiger, Mario and Welling, Max and Boomsma, Wouter and Cohen, Taco S}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{31}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/mariogeiger/se3cnn">GitHub Repository</a></li>
<li><a href="https://www.youtube.com/watch?v=ENLJACPHSEA">YouTube Video</a></li>
<li><a href="https://github.com/wouterboomsma/cath_datasets">CATH Dataset</a></li>
</ul>
]]></content:encoded></item><item><title>Optimizing Sequence Models for Dynamical Systems</title><link>https://hunterheidenreich.com/research/deconstructing-recurrence-attention-gating/</link><pubDate>Tue, 01 Oct 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/research/deconstructing-recurrence-attention-gating/</guid><description>Ablation study deconstructing sequence models. Attention-augmented Recurrent Highway Networks outperform Transformers on chaotic systems.</description><content:encoded><![CDATA[<h2 id="abstract">Abstract</h2>
<p>Advanced neural network architectures developed for tasks like natural language processing are often transferred to spatiotemporal forecasting without a deep understanding of which components drive their performance. This can lead to suboptimal results and reinforce the view of these models as &ldquo;black boxes&rdquo;. In this work, we deconstruct the core mechanisms of Transformers and Recurrent Neural Networks (RNNs): attention, gating, and recurrence. We then build and test novel hybrid architectures to identify which components are most effective. A key finding is that while adding recurrence is detrimental to Transformers, augmenting RNNs with attention and neural gating consistently improves their forecasting accuracy. Our study reveals that a seldom-used architecture, the Recurrent Highway Network (RHN) enhanced with these mechanisms, emerges as the top-performing model for forecasting high-dimensional chaotic systems.</p>
<h2 id="key-contributions">Key Contributions</h2>
<ul>
<li><strong>Systematic Ablation</strong>: Deconstructed Transformers and RNNs into core mechanisms (attention, gating, recurrence) to isolate performance drivers</li>
<li><strong>Novel Hybrid Architectures</strong>: Synthesized and tested new combinations of neural primitives for spatiotemporal forecasting</li>
<li><strong>RHN Superiority</strong>: Demonstrated that attention-augmented Recurrent Highway Networks outperform standard Transformers on high-dimensional chaotic systems</li>
<li><strong>Robustness Analysis</strong>: Validated models across both clean physics simulations and noisy real-world industrial datasets</li>
</ul>
<h2 id="the-engineering-problem">The Engineering Problem</h2>
<p>In modern ML, a common anti-pattern is the blind transfer of architectures from one domain (like NLP) to another (like physical forecasting) without understanding the underlying mechanics. This &ldquo;black box&rdquo; approach leads to suboptimal compute usage and performance ceilings.</p>
<p>My goal was to break these architectures down. I treated the core mechanisms of <strong>Transformers</strong> and <strong>RNNs</strong> (<strong>Gating, Attention, and Recurrence</strong>) as orthogonal basis vectors. By decoupling these components, we could synthesize and test hybrid architectures to find the optimal configuration for spatiotemporal forecasting.</p>
<h2 id="methodological-approach">Methodological Approach</h2>
<p>We engineered a modular framework to mix and match neural primitives. We systematically evaluated:</p>
<ol>
<li><strong>Gating Mechanisms:</strong> Testing Additive, Learned Rate, and Input-Dependent variants</li>
<li><strong>Attention:</strong> Implementing multi-headed attention with relative positional biases</li>
<li><strong>Recurrence:</strong> Testing standard cells (LSTM, GRU) against deeper transition cells like Recurrent Highway Networks (RHN)</li>
</ol>
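<p>The RHN transition at the heart of the study (Zilly et al., 2017) can be sketched in a few lines of numpy. The dimensions, coupled-gate choice, and initialization below are illustrative, not the paper's exact configuration:</p>

```python
import numpy as np

def rhn_step(x, s_prev, Wx, Ws, b, depth=3):
    """One timestep of a coupled-gate Recurrent Highway Network (a sketch).

    The state passes through `depth` highway layers per timestep; the
    input x enters only at the first layer. Each layer computes a
    candidate h and a transform gate t, with carry gate c = 1 - t.
    """
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    s = s_prev
    for l in range(depth):
        in_h = x @ Wx["h"] if l == 0 else 0.0
        in_t = x @ Wx["t"] if l == 0 else 0.0
        h = np.tanh(in_h + s @ Ws[l]["h"] + b[l]["h"])
        t = sigmoid(in_t + s @ Ws[l]["t"] + b[l]["t"])
        s = h * t + s * (1.0 - t)      # highway update
    return s

rng = np.random.default_rng(0)
dx, dh, L = 4, 8, 3
Wx = {k: 0.1 * rng.standard_normal((dx, dh)) for k in ("h", "t")}
Ws = [{k: 0.1 * rng.standard_normal((dh, dh)) for k in ("h", "t")} for _ in range(L)]
b = [{k: np.zeros(dh) for k in ("h", "t")} for _ in range(L)]
s = rhn_step(rng.standard_normal(dx), np.zeros(dh), Wx, Ws, b, depth=L)
```

This extra per-timestep depth is what the findings below credit for the RHN's edge on chaotic dynamics.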















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/neural-gates.webp"
         alt="Neural gating mechanisms: Additive, Learned Rate, Dependent-Coupled, and Dependent variants"
         title="Neural gating mechanisms: Additive, Learned Rate, Dependent-Coupled, and Dependent variants"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The hierarchy of neural gating mechanisms we tested, from simple additive to fully input-dependent.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/rnn-cell-types.webp"
         alt="RNN cell architectures: Elman, LSTM, GRU, and RHN cells"
         title="RNN cell architectures: Elman, LSTM, GRU, and RHN cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Recurrent cell types compared in our study. The RHN (d) extends processing depth within each timestep.</figcaption>
    
</figure>

<p>This rigorous ablation study allowed us to isolate exactly <em>which</em> mathematical operation was driving performance gain.</p>
<h2 id="key-findings">Key Findings</h2>
<h3 id="the-rhn-is-a-sleeping-giant">The RHN is a Sleeping Giant</h3>
<p>The industry has pivoted hard to Transformers. To understand why this might be suboptimal for physics, one must look at the systems we are modeling.</p>
<p>For high-dimensional chaotic systems like the Multiscale Lorenz-96 shown below, we found that a <strong>Recurrent Highway Network (RHN)</strong> augmented with <strong>Attention and Neural Gating</strong> was the top-performing architecture. This novel hybrid exceeded the forecasting accuracy of standard Transformers, suggesting that deeper recurrence (processing depth per timestep) is crucial for complex dynamics.</p>















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/multiscale-lorenz.webp"
         alt="Forecasting comparison on Multiscale Lorenz-96 system"
         title="Forecasting comparison on Multiscale Lorenz-96 system"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Forecasting the Multiscale Lorenz-96 system. The top row shows the &lsquo;texture&rsquo; of the chaotic evolution. Notice how the RHN (far right) maintains the coherent wave-like structures for nearly 2 full Lyapunov times, whereas the Transformer variants blur into noise much earlier.</figcaption>
    
</figure>

<h3 id="transformers-recurrence-hurts-gating-helps">Transformers: Recurrence Hurts, Gating Helps</h3>
<p>We attempted to force recurrence into Transformers to give them &ldquo;memory,&rdquo; but it consistently hurt performance. However, <strong>Neural Gating</strong> significantly improved Transformer robustness. For real-world, noisy data (traffic, weather), the <strong>Pre-Layer Normalization (PreLN) Transformer</strong> with added gating proved to be the most robust model.</p>
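<p>To make "gating" concrete, here is a minimal NumPy sketch of an input-dependent gated residual connection, the kind of mechanism that can replace the plain residual <code>x + sublayer(x)</code>. The single-sigmoid-projection gate shown here is illustrative, not our exact implementation.</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_residual(x, sublayer_out, Wg, bg):
    """Input-dependent gate blending a sublayer output with its residual.

    Replaces the plain residual `x + sublayer_out` with the convex
    combination g * sublayer_out + (1 - g) * x, where the gate g is
    computed from the input. Wg and bg are hypothetical gate parameters.
    """
    g = sigmoid(x @ Wg + bg)
    return g * sublayer_out + (1.0 - g) * x
```

<p>When the gate saturates low, the layer passes its input through untouched; when it saturates high, the sublayer fully overwrites the residual. That adaptivity is what improves robustness on noisy data.</p>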
<h3 id="augmenting-the-old-guard">Augmenting the Old Guard</h3>
<p>We tested on the Kuramoto-Sivashinsky equation, a model of turbulence and flame fronts. We found that legacy architectures (LSTMs, GRUs) are under-optimized. By adding modern <strong>Attention mechanisms</strong> to these older cells, we improved their performance by over 40% in some chaotic regimes.</p>
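<p>As an illustration of how attention can be bolted onto a recurrent cell, here is a minimal dot-product attention sketch over a cell's past hidden states; the exact attention variant used in the paper may differ.</p>

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def attend_over_states(query, past_states):
    """Scaled dot-product attention over a cell's past hidden states.

    The current hidden state queries the history, and the weighted
    summary (context vector) can be concatenated to, or added into,
    the cell's output before prediction.
    """
    scores = past_states @ query / np.sqrt(query.shape[0])
    weights = softmax(scores)
    return weights @ past_states  # context vector, shape (d,)
```
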















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/kuramoto-sivashinksy.webp"
         alt="Forecasting comparison on Kuramoto-Sivashinsky system"
         title="Forecasting comparison on Kuramoto-Sivashinsky system"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Forecasting the Kuramoto-Sivashinsky system. The error heatmaps (bottom row) show how prediction quality degrades over time (lighter means larger error). The RHN maintains structural fidelity longer than competing architectures.</figcaption>
    
</figure>

<h3 id="real-world-robustness-beyond-the-lab">Real-World Robustness: Beyond the Lab</h3>
<p>While chaotic systems test the limits of theory, we also validated our models on seven standard industrial datasets, including <strong>Electricity Transformer Temperature (ETT)</strong>, <strong>Traffic Flow</strong>, and <strong>Weather</strong> data.</p>
<p>Unlike the clean physics simulations, these datasets contain real-world noise and irregularities. In this environment, the <strong>Pre-Layer Normalization (PreLN) Transformer</strong> proved to be the most robust architecture. While it didn&rsquo;t always beat the RHN on pure chaos, its stability makes it a strong default choice for general time-series forecasting tasks where training speed and reliability are paramount.</p>
<h2 id="why-this-matters">Why This Matters</h2>
<p>This work demonstrates a move away from &ldquo;state-of-the-art chasing&rdquo; toward first-principles AI engineering.</p>
<ul>
<li><strong>For Production:</strong> We identified that while Transformers train 25-50% faster, optimized RNNs offer superior inference accuracy for physical systems. This allows for informed trade-offs between training budget and deployment precision.</li>
<li><strong>For Research:</strong> We established that architectural components should be treated as tunable hyperparameters, not fixed constraints. By carefully selecting these mechanisms, practitioners can design models better suited for the specific challenges of dynamical systems forecasting.</li>
</ul>
<p>The ablation framework here, treating architectural components as independently tunable factors and measuring their marginal contribution, shaped how later evaluation work is structured. The same principle of isolating variables rather than comparing end-to-end black boxes appears in the document processing research, from benchmark construction in page stream segmentation to grounded evaluation in GutenOCR.</p>
<h2 id="related-work">Related Work</h2>
<p>The methodology here shares a design philosophy with <a href="/research/eigennoise-contrastive-prior/">EigenNoise</a>,
which similarly decomposes a neural mechanism (word vector initialization) into theoretically
grounded components to isolate what drives performance. Both papers treat model components as
testable hypotheses rather than fixed architectural choices.</p>
<p>For broader context on where this fits in the portfolio&rsquo;s Scientific Machine Learning arc,
see the <a href="/research/">Research</a> overview.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{heidenreich2024deconstructingrecurrenceattentiongating,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Deconstructing Recurrence, Attention, and Gating: Investigating the transferability of Transformers and Gated Recurrent Neural Networks in forecasting of dynamical systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hunter S. Heidenreich and Pantelis R. Vlachas and Petros Koumoutsakos}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2410.02654}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.LG}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2410.02654}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Invalid SMILES Benefit Chemical Language Models: A Study</title><link>https://hunterheidenreich.com/notes/computational-chemistry/molecular-representations/invalid-smiles-help/</link><pubDate>Mon, 15 Apr 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/computational-chemistry/molecular-representations/invalid-smiles-help/</guid><description>Skinnider (2024) shows that generating invalid SMILES actually improves chemical language model performance through quality filtering.</description><content:encoded><![CDATA[<h2 id="core-contribution-repurposing-invalid-smiles">Core Contribution: Repurposing Invalid SMILES</h2>
<p>This is a <strong>Method</strong> and <strong>Theory</strong> paper that challenges a fundamental assumption in the field of chemical language models. Skinnider provides both empirical evidence and mechanistic explanations for why the ability to generate &ldquo;invalid&rdquo; SMILES strings is actually beneficial for model performance.</p>
<h2 id="the-problem-with-absolute-validity-in-chemical-lms">The Problem with Absolute Validity in Chemical LMs</h2>
<p>Prior research attempted to eliminate invalid generations using constrained representations like SELFIES. This paper shifts the paradigm by demonstrating that invalid outputs provide a useful signal for distribution learning, enabling implicit uncertainty estimation.</p>
<h2 id="invalid-generation-as-an-implicit-quality-filter">Invalid Generation as an Implicit Quality Filter</h2>
<p>The central insight is counterintuitive: <strong>invalid SMILES generation acts as a built-in quality control mechanism</strong>. The key contributions are:</p>
<ol>
<li>
<p><strong>Empirical Evidence</strong>: Direct comparisons showing that SMILES-based models consistently outperform SELFIES-based models across multiple metrics, with performance gains strongly correlated with the proportion of invalid outputs generated.</p>
</li>
<li>
<p><strong>Mechanistic Explanation</strong>: Invalid SMILES are demonstrated to be low-likelihood samples from the model&rsquo;s probability distribution. When these are filtered out, it&rsquo;s equivalent to removing the model&rsquo;s least confident predictions, a form of automatic quality control.</p>
</li>
<li>
<p><strong>Causal Evidence</strong>: By modifying SELFIES to allow invalid generation (through relaxed constraints), the author shows that performance improves when models can generate and discard invalid outputs, directly proving the causal relationship.</p>
</li>
<li>
<p><strong>Bias Analysis</strong>: SELFIES models are shown to introduce systematic structural biases (fewer aromatic rings, more aliphatic rings) due to their validity constraints, limiting their ability to explore chemical space naturally.</p>
</li>
</ol>
<h2 id="experimental-design-and-causal-interventions">Experimental Design and Causal Interventions</h2>
<p>The paper uses a multi-pronged approach to establish both correlation and causation:</p>
<p><strong>Performance Comparisons</strong>: SMILES and SELFIES models were trained on identical datasets and evaluated using distribution-learning metrics like Fréchet ChemNet distance. The comparison was robust across different architectures, training set sizes, and chemical databases.</p>
<p><strong>Loss Analysis</strong>: The relationship between SMILES validity and model confidence was examined by analyzing the sequence loss. For a given SMILES string $S$ composed of tokens $t_1, t_2, \dots, t_N$, the negative log-likelihood acts as a proxy for the model&rsquo;s uncertainty:</p>
<p>$$ \text{NLL}(S) = -\sum_{i=1}^N \log P(t_i \mid t_1, \dots, t_{i-1}) $$</p>
<p>Invalid SMILES strings consistently register higher $\text{NLL}$ scores, meaning they represent the model&rsquo;s least confident predictions. Filtering them effectively acts as automatic quality control, providing the mechanistic explanation for why invalid filtering improves performance.</p>
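<p>As a sketch of this mechanism (illustrative, not the paper's code): compute the NLL of each generated sequence from its per-token conditional probabilities, then drop the least confident fraction.</p>

```python
import numpy as np

def sequence_nll(token_probs):
    """NLL of one sequence from its per-token conditional
    probabilities P(t_i | t_1, ..., t_{i-1})."""
    return -np.sum(np.log(token_probs))

def filter_most_confident(samples, per_token_probs, keep_frac=0.9):
    """Drop the highest-NLL (least confident) fraction of samples,
    mimicking the effect of discarding invalid SMILES."""
    nlls = np.array([sequence_nll(p) for p in per_token_probs])
    cutoff = np.quantile(nlls, keep_frac)
    return [s for s, nll in zip(samples, nlls) if nll <= cutoff]
```
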
<p><strong>Causal Intervention</strong>: The most clever experiment involved creating &ldquo;unconstrained SELFIES&rdquo; models by relaxing or removing the valency constraints that normally ensure validity. This allowed direct testing of whether the ability to generate invalid outputs (which are then discarded) causally improves performance.</p>
<p><strong>Structural Bias Analysis</strong>: Generated molecules were analyzed for chemical features like ring types and bond patterns to quantify how validity constraints systematically distort the model&rsquo;s exploration of chemical space.</p>
<p><strong>Generalization Testing</strong>: Models were trained on subsets of chemical databases and tested on their ability to reproduce the broader chemical space, measuring how validity constraints affect generalization.</p>
<p><strong>Practical Application</strong>: The approach was tested on structure elucidation, using models to identify unknown molecules from minimal experimental data like mass spectrometry.</p>
<h2 id="key-findings-on-validity-constraints-and-bias">Key Findings on Validity Constraints and Bias</h2>
<p><strong>Superior Performance Across the Board</strong>: SMILES-based models consistently outperformed SELFIES models on distribution-learning tasks. Using metrics like Fréchet ChemNet distance, SMILES models generated molecules that more closely matched the statistical properties of their training data. Remarkably, this performance advantage was directly correlated with the proportion of invalid SMILES generated. Models that produced more invalid outputs performed better after filtering.</p>
<p><strong>Invalid SMILES Are Low-Confidence Predictions</strong>: The analysis revealed that invalid SMILES consistently have higher loss values than valid ones, meaning they represent the model&rsquo;s least confident predictions. This suggests that validity checking acts as an automatic confidence filter, removing low-quality samples without requiring explicit uncertainty estimation.</p>
<p><strong>Causal Evidence Through Unconstrained SELFIES</strong>: The most compelling evidence came from modifying SELFIES to allow invalid generation. When &ldquo;unconstrained SELFIES&rdquo; models could generate and discard invalid molecules, their performance improved dramatically, approaching that of SMILES models. This provides direct causal evidence that the ability to generate invalid outputs is what drives the performance gains.</p>
<p><strong>Validity Constraints Introduce Systematic Bias</strong>: SELFIES models showed clear structural biases compared to both training data and SMILES outputs. They generated fewer aromatic rings and more aliphatic structures, systematic distortions caused by the valency constraints used to ensure validity. These biases limit the model&rsquo;s ability to faithfully represent chemical space.</p>
<p><strong>Reduced Generalization</strong>: When trained on subsets of chemical databases, SMILES models could reproduce a larger portion of the complete chemical space compared to SELFIES models. Although SELFIES generated more valid molecules in absolute terms, their structural biases constrained exploration and limited generalization beyond the training set.</p>
<p><strong>Real-World Application Benefits</strong>: In structure elucidation tasks, identifying unknown molecules from experimental data like mass spectrometry, SMILES-based models significantly outperformed SELFIES models. This demonstrates that the benefits extend beyond academic benchmarks to practical applications.</p>
<p><strong>Computational Efficiency</strong>: Filtering invalid SMILES is computationally trivial. Parsing one million SMILES strings with RDKit takes only minutes on a single CPU, making the post-processing overhead negligible compared to model training and inference costs.</p>
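<p>A minimal validity-filtering sketch. In the actual pipeline the parser would be RDKit&rsquo;s <code>Chem.MolFromSmiles</code>; to keep this snippet self-contained, a toy stand-in parser is used that only checks ring-bond digit pairing.</p>

```python
def filter_valid(smiles_list, parse_fn):
    """Keep only strings that parse_fn accepts (returns non-None).

    In the paper's pipeline parse_fn would be RDKit's
    Chem.MolFromSmiles; any parser with that contract works here.
    """
    return [s for s in smiles_list if parse_fn(s) is not None]

def toy_parser(smiles):
    """Toy stand-in: rejects strings whose ring-bond digits are
    unpaired. Real SMILES validity checking is far more involved."""
    ok = all(smiles.count(d) % 2 == 0 for d in "123456789")
    return smiles if ok else None
```
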
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Primary Architecture (LSTM):</strong> The main results rely on a Recurrent Neural Network (RNN) using Long Short-Term Memory (LSTM) units.</p>
<ul>
<li><strong>Structure:</strong> Three-layer LSTM with a hidden layer size of 1,024 dimensions</li>
<li><strong>Embedding:</strong> An embedding layer of 128 dimensions</li>
<li><strong>Decoder:</strong> A linear decoder layer outputs token probabilities</li>
</ul>
<p><strong>Secondary Architecture (Transformer/GPT):</strong> To prove robustness, the author also used a Generative Pretrained Transformer (GPT) architecture adapted from MolGPT.</p>
<ul>
<li><strong>Structure:</strong> Eight transformer blocks</li>
<li><strong>Internals:</strong> Each block contains eight masked self-attention heads and a feed-forward network (1,024 dimensions) using GELU activation</li>
<li><strong>Embedding:</strong> 256 dimensions, concatenated with learned positional encodings</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Optimizer:</strong> Adam optimizer for both architectures with $\beta_1=0.9$ and $\beta_2=0.999$.</p>
<p><strong>Learning Rate:</strong></p>
<ul>
<li>LSTM: 0.001</li>
<li>Transformer: 0.0005</li>
</ul>
<p><strong>Batch Size:</strong> 64</p>
<p><strong>Loss Function:</strong> Cross-entropy loss of next-token prediction.</p>
<p><strong>Stopping Criteria:</strong> Early stopping using a validation set (10% of training data) with patience of 50,000 minibatches.</p>
<h3 id="data">Data</h3>
<p><strong>Primary Source:</strong> ChEMBL database (version 28).</p>
<p><strong>Preprocessing Pipeline:</strong></p>
<ul>
<li><strong>Cleaning:</strong> Removal of duplicate SMILES, salts, and solvents (retaining heavy fragments with $\geq 3$ heavy atoms)</li>
<li><strong>Filtering:</strong> Molecules with atoms other than {Br, C, Cl, F, H, I, N, O, P, S} were removed</li>
<li><strong>Normalization:</strong> Charged molecules were neutralized and converted to canonical SMILES</li>
</ul>
<p><strong>Training Subsets:</strong> Models were trained on random samples of 30,000, 100,000, and 300,000 molecules to test scalability.</p>
<p><strong>Generalization Data:</strong> To test generalization, models were also trained on the <a href="/notes/computational-chemistry/datasets/gdb-13/">GDB-13</a> database (enumerating drug-like molecules up to 13 heavy atoms).</p>
<p><strong>Structure Elucidation Data:</strong> For practical application tasks, models were trained on natural products (LOTUS, COCONUT), food compounds (FooDB), and environmental contaminants (NORMAN).</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metric:</strong> Fréchet ChemNet Distance (FCD), measuring chemical similarity between generated molecules and the training set (lower is better).</p>
<p><strong>Secondary Metrics:</strong></p>
<ul>
<li><strong>Validity:</strong> Percentage of outputs parseable by RDKit</li>
<li><strong>Scaffold Similarity:</strong> Jensen-Shannon distances between Murcko scaffold compositions</li>
<li><strong>Physical Properties:</strong> Comparisons of molecular weight, LogP, topological polar surface area (TPSA), and ring counts (aromatic vs. aliphatic)</li>
<li><strong>Structure Elucidation:</strong> &ldquo;Top-k accuracy,&rdquo; the proportion of held-out molecules where the correct structure appeared in the model&rsquo;s top $k$ ranked outputs</li>
</ul>
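<p>The top-k accuracy metric can be sketched as follows (an illustrative implementation, not the paper&rsquo;s code):</p>

```python
def top_k_accuracy(ranked_outputs, targets, k=10):
    """Fraction of held-out molecules whose correct structure appears
    among the model's top-k ranked generations."""
    hits = sum(1 for ranks, target in zip(ranked_outputs, targets)
               if target in ranks[:k])
    return hits / len(targets)
```
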
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute Nodes:</strong> Dell EMC C4140 GPU compute nodes</li>
<li><strong>GPUs:</strong> NVIDIA Tesla V100</li>
<li><strong>Compute Time:</strong> Parsing 1 million SMILES took ~7.5 minutes on a single CPU; SELFIES models required an average of 0.6 hours longer to train than SMILES models</li>
</ul>
<h3 id="replicability">Replicability</h3>
<p><strong>Code Availability:</strong> Source code and intermediate data are available via <a href="https://doi.org/10.5281/zenodo.10680855">Zenodo</a>. Note that pre-trained model weights are not explicitly provided in the archive, requiring researchers to train models from scratch using the included scripts to fully replicate the study.</p>
<p><strong>Software Libraries:</strong></p>
<ul>
<li><strong>PyTorch:</strong> LSTM and Transformer implementations</li>
<li><strong>RDKit:</strong> SMILES parsing, validity checking, and property calculation</li>
<li><strong>SELFIES:</strong> Version 2.1.1 for conversion</li>
</ul>
<h2 id="implications-and-takeaways">Implications and Takeaways</h2>
<p>This work fundamentally challenges how we think about &ldquo;errors&rdquo; in generative models. The key insight is that model outputs that appear incorrect often represent useful uncertainty signals that improve overall performance when properly handled.</p>
<p>The findings suggest that the field&rsquo;s drive toward guaranteed validity leads to systematic biases. It proves advantageous to let models fail informatively and use those failures as quality signals. This applies particularly as the field moves toward larger, more capable models where such self-correction mechanisms become increasingly valuable.</p>
<p>For practitioners, the message is clear: pause before &ldquo;fixing&rdquo; models that generate invalid outputs. That apparent flaw acts as a feature in disguise, providing automatic quality control that improves final results. Success often stems from learning to recognize and filter mistakes effectively.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Skinnider, M. A. (2024). Invalid SMILES are beneficial rather than detrimental to chemical language models. Nature Machine Intelligence, 6(4), 437-448. <a href="https://doi.org/10.1038/s42256-024-00821-x">https://doi.org/10.1038/s42256-024-00821-x</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence (2024)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{skinnider2024invalid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Invalid SMILES are beneficial rather than detrimental to chemical language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Skinnider, Michael A}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{437--448}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group UK London}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are powerful tools for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Intelligently filling in missing data (e.g., denoising images).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (think: optimization).</p>
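<p>To see why a continuous latent space matters, here is a minimal sketch of linear interpolation between two latent codes; decoding each intermediate point yields a smooth morph between the corresponding samples. (Spherical interpolation is sometimes preferred for Gaussian latents, but linear is the simplest illustration.)</p>

```python
import numpy as np

def interpolate_latents(z_a, z_b, steps=5):
    """Linearly interpolate between two latent codes.

    Returns `steps` points from z_a to z_b inclusive; passing each
    through the decoder produces a smooth morph between samples.
    """
    alphas = np.linspace(0.0, 1.0, steps)
    return [(1.0 - a) * z_a + a * z_b for a in alphas]
```
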
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total Loss (ELBO)</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(x_logits<span style="color:#f92672">=</span>x_recon, z<span style="color:#f92672">=</span>z, mu<span style="color:#f92672">=</span>mu, std<span style="color:#f92672">=</span>std, loss<span style="color:#f92672">=</span>loss, loss_recon<span style="color:#f92672">=</span>recon_loss, loss_kl<span style="color:#f92672">=</span>kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has a position, a velocity, and interactions with its neighbors, making the full system computationally intractable to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. Temperature and pressure are &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. This probabilistic step is what keeps the latent space continuous and structured: similar inputs are forced to map to nearby regions in latent space, enabling smooth interpolation and generation.</p>
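<p>As a minimal sketch of this output layer (names like <code>head</code>, <code>hidden_dim</code>, and <code>latent_dim</code> are illustrative, not from any specific library), a single linear layer can emit both vectors, with a softplus keeping the standard deviation positive:</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, latent_dim = 64, 2

# One linear layer emits 2 * latent_dim values per input:
# the first half is mu, the second half parameterizes std.
head = nn.Linear(hidden_dim, latent_dim * 2)

h = torch.randn(8, hidden_dim)         # encoder hidden features for a batch of 8
mu, raw_std = head(h).chunk(2, dim=1)  # split into the two parameter vectors
std = F.softplus(raw_std) + 1e-8       # softplus keeps std strictly positive
```

<p>Exponentiating a predicted log-variance is an equally common parameterization; the only requirement is that the standard deviation stays positive.</p>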
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. The goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful for searching for data points with desired properties, like in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
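<p>To make the generation use concrete, here is a hedged sketch (the decoder below is a random stand-in; a trained VAE&rsquo;s decoder would replace it): sample latent vectors from a standard normal and decode them into pixel probabilities.</p>

```python
import torch
import torch.nn as nn

latent_dim = 2
# Stand-in decoder with illustrative sizes; a trained VAE's decoder replaces this.
decoder = nn.Sequential(nn.Linear(latent_dim, 64), nn.Tanh(), nn.Linear(64, 784))

# Generation: draw latent vectors from the standard-normal prior and decode.
z = torch.randn(16, latent_dim)
with torch.no_grad():
    images = torch.sigmoid(decoder(z)).view(16, 1, 28, 28)  # logits -> pixel probabilities
```

<p>With an untrained decoder the outputs are noise, but the mechanics are identical after training: every draw of <code>z</code> yields a new image.</p>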
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because it involves integrating over a high-dimensional latent space with a complex likelihood function. No closed-form solution exists, and numerical integration is computationally prohibitive.</p>
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It forces the VAE to be good at its job. The process goes:</p>
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}&rsquo;$ from $\mathbf{z}$, $\mathbf{x}&rsquo; \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}&rsquo;$.</li>
</ol>
<p>The reconstruction loss measures the difference between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE).</li>
<li><strong>For discrete inputs</strong> (like MNIST, where pixels are 0 or 1), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli random variable (either on or off). The decoder outputs the <em>logits</em> (unnormalized log-odds) for each pixel, and the BCE loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
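<p>A short sketch of the first two choices on stand-in tensors (shapes assume flattened $28 \times 28$ images; the tensors here are random, purely for illustration):</p>

```python
import torch
import torch.nn.functional as F

x = torch.rand(8, 784)           # targets in [0, 1]; binarized MNIST would be 0/1
x_logits = torch.randn(8, 784)   # raw decoder outputs (logits)

# Bernoulli likelihood: per-pixel BCE, summed over pixels, averaged over the batch.
bce = F.binary_cross_entropy_with_logits(
    x_logits, x, reduction="none"
).sum(dim=1).mean()

# Gaussian likelihood with fixed variance: MSE plays the same role for continuous data.
mse = F.mse_loss(torch.sigmoid(x_logits), x, reduction="none").sum(dim=1).mean()
```

<p>The sum-over-pixels, mean-over-batch reduction matters: it keeps the reconstruction term on the same per-example scale as the KL term.</p>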
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We force the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to be close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. This prior is almost always a standard normal distribution because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>
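<p>For this choice of prior, the KL term has a closed form. With $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$ and a standard normal prior over a $d$-dimensional latent space, it reduces to the expression computed directly in the loss code:</p>
<p>$$D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || \mathcal{N}(\mathbf{0}, \mathbf{I})) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$</p>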















<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}&rsquo;$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. We generate $\mathbf{z}$ deterministically by sampling a random noise vector and transforming it:</p>
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
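<p>A minimal standalone sketch of the trick (the small gradient check at the end is illustrative, not part of any training loop):</p>

```python
import torch

def reparameterize(mu, std):
    """z = mu + std * eps with eps ~ N(0, I); differentiable in mu and std."""
    eps = torch.randn_like(std)  # randomness is sampled outside the graph
    return mu + std * eps

mu = torch.zeros(4, 2, requires_grad=True)
std = torch.ones(4, 2, requires_grad=True)
z = reparameterize(mu, std)
z.sum().backward()  # gradients flow back through the sampling step
```

<p>After <code>backward()</code>, both <code>mu.grad</code> and <code>std.grad</code> are populated, which is exactly what naive sampling from $\mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$ would not give us.</p>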
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. This is non-negative, and unfortunately we cannot compute it.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (which is the same as maximizing the log-likelihood).</li>
<li>It <strong>minimizes the KL Divergence</strong>, forcing the encoder to match the prior.</li>
</ol>
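<p>Putting the two terms together, the negative ELBO for a Bernoulli decoder and diagonal-Gaussian encoder can be sketched as a standalone function (the tensor shapes and random inputs are illustrative):</p>

```python
import torch
import torch.nn.functional as F

def negative_elbo(x, x_logits, mu, std):
    """Reconstruction NLL plus KL(q(z|x) || N(0, I)), averaged over the batch."""
    recon = F.binary_cross_entropy_with_logits(
        x_logits, x, reduction="none"
    ).sum(dim=1).mean()
    kl = -0.5 * torch.sum(1 + torch.log(std**2) - mu**2 - std**2, dim=1).mean()
    return recon + kl

loss = negative_elbo(
    x=torch.rand(8, 784),          # targets in [0, 1]
    x_logits=torch.randn(8, 784),  # decoder logits
    mu=torch.randn(8, 2),          # encoder means
    std=torch.rand(8, 2) + 0.5,    # encoder standard deviations (positive)
)
```

<p>Since the KL term is non-negative and the BCE term is strictly positive, this loss is always positive; training drives it down toward the (unknown) negative log-likelihood floor.</p>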
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
<p>My VAE implementation uses an output <code>dataclass</code> and a VAE class extending <code>nn.Module</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, as used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to return reconstructions or distributions
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is just the raw output, need to use std directly: σ</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over the data dimensions and averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over all elements in the tensor rather than just the batch.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalization term. Each dimension of the latent space has the potential to add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data, and it should scale with input dimensionality (this can bias the model toward better reconstruction for higher-dimensional data).</li>
</ul>
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it tiny compared to the KL loss. The KL term would dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
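To make the scale difference concrete, here is a small sketch with shapes chosen to mimic flattened MNIST (the tensors are random placeholders, not outputs of the model above):

```python
import torch
import torch.nn.functional as F

batch_size, d = 32, 784  # e.g. flattened MNIST digits
x = torch.rand(batch_size, d)        # targets in [0, 1]
logits = torch.randn(batch_size, d)  # raw decoder outputs

# Recommended: sum over pixels, then average over the batch
loss_sum = F.binary_cross_entropy_with_logits(logits, x, reduction="sum") / batch_size

# Common mistake: mean over all elements, smaller by a factor of d
loss_mean = F.binary_cross_entropy_with_logits(logits, x, reduction="mean")

ratio = (loss_sum / loss_mean).item()  # roughly 784, the input dimensionality
```

The two reductions differ by exactly the input dimensionality, which is why the mean-reduced reconstruction term gets drowned out by the KL term.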
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This happens when the KL term dominates the loss early in training. The encoder quickly learns to match the prior ($q(z|x) \approx p(z)$) to drive the KL loss to zero, effectively ignoring the input, and the decoder is left to model the data on its own, treating the latent code $z$ as uninformative noise.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This allows the model to focus purely on reconstruction first (using the full latent capacity), and then slowly adds the regularization pressure.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop:</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> kl_loss
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since $\mu \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). This type of constrained optimization is challenging for neural networks.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
<p>The KL divergence formula simplifies nicely with this parameterization:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2)
$$</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
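As a minimal sketch of this head (assuming a PyTorch setup; <code>raw</code> and <code>eps</code> are hypothetical names for the unconstrained network output and the small stability constant):

```python
import torch
import torch.nn.functional as F

def std_from_raw(raw: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Map an unconstrained network output to a strictly positive std."""
    # softplus(x) = log(1 + exp(x)) > 0 for all x; eps keeps sigma away from 0
    return F.softplus(raw) + eps
```

Even for very negative <code>raw</code>, the resulting $\sigma$ stays at least <code>eps</code>, so sampling never degenerates to a point mass.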
<p>The KL divergence formula becomes slightly more complex:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces the mapping to $(0, \infty)$ with a mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, at the cost of limiting the expressiveness of the latent distribution. As with <code>softplus</code>, adding a small epsilon ensures numerical stability by preventing $\sigma$ from being exactly zero or approaching it too closely.</p>
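A sketch of a bounded head (again with hypothetical names; <code>bound</code> mirrors the maximum-value hyperparameter described above):

```python
import torch

def bounded_std(raw: torch.Tensor, bound: float = 10.0, eps: float = 1e-6) -> torch.Tensor:
    """Map an unconstrained output into (eps, bound + eps)."""
    # sigmoid squashes to (0, 1); scaling by `bound` caps the std
    return bound * torch.sigmoid(raw) + eps
```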
<p><strong>Gradient Behavior</strong>
All parameterizations can work well in practice and have different gradient behaviors. Think of $g(s)$ as a transformation function from the network output to the proper domain of $\sigma$ (or $\sigma^2$); in the log-variance case, $g(s) = \exp(s)$, while in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$.</p>
<p>The gradient of the loss with respect to these outputs can be written using the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial \sigma} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with its exponential transformation that is its own derivative, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. Gradients are always bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the gradient approaches zero, leading to vanishing gradients. Adding a small epsilon helps mitigate this, ensuring that $\sigma$ never gets too close to zero. Nonetheless, learning can still slow down.</p>
<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. (The gradient of <code>sigmoid</code> is defined in terms of itself: $\text{sig}&rsquo;(x) = \text{sig}(x)(1 - \text{sig}(x))$; its maximum value is $0.25$ at $x=0$.)</p>
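These limiting behaviors are easy to check numerically. A small standalone sketch (plain Python; the function names are mine, not from any accompanying codebase) evaluates $\frac{\partial g(s)}{\partial s}$ for each transform at a few values of $s$:

```python
import math

def d_exp(s: float) -> float:
    """Derivative of g(s) = exp(s), the log-variance transform."""
    return math.exp(s)  # unbounded: explodes for large s, vanishes for very negative s

def d_softplus(s: float) -> float:
    """Derivative of softplus(s), which is sigmoid(s) -- bounded by 1."""
    return 1.0 / (1.0 + math.exp(-s))

def d_sigmoid(s: float) -> float:
    """Derivative of sigmoid(s) = sig(s) * (1 - sig(s)) -- bounded by 0.25."""
    sig = 1.0 / (1.0 + math.exp(-s))
    return sig * (1.0 - sig)

for s in (-10.0, 0.0, 10.0):
    print(f"s={s:+.0f}  exp: {d_exp(s):12.4f}  softplus: {d_softplus(s):.4f}  sigmoid: {d_sigmoid(s):.4f}")
```

At $s = 10$ the exponential derivative already exceeds $2 \times 10^4$, while the softplus and sigmoid derivatives stay below $1$ and $0.25$ respectively.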















<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying similar weight decay, though implemented quite differently due to the optimizer change</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
<p>This results in a network with 807,700 parameters.
I train each model for 150 epochs at most and highlight the best based on the reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>
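As a reference point, here is a minimal sketch of how these two terms combine into the (negative) ELBO, assuming BCE reconstruction and the log-variance head; <code>x</code>, <code>x_hat</code>, <code>mu</code>, and <code>logvar</code> are hypothetical batch tensors, not names from a specific codebase:

```python
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar):
    """Negative ELBO = reconstruction loss (BCE) + KL to the N(0, I) prior."""
    # Sum over pixels, average over the batch
    bce = F.binary_cross_entropy(x_hat, x, reduction="sum") / x.size(0)
    # Closed-form KL for a diagonal Gaussian posterior vs. N(0, I)
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1 - logvar, dim=1).mean()
    return bce + kl, bce, kl
```

Tracking <code>bce</code> and <code>kl</code> separately, as in the plots below, is what lets us see the two components pulling against each other.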















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> Performance skyrockets in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly, driven by the model finding the optimal trade-off between reconstruction and regularization. Notice that the Test and Train KL curves track closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Combining them as the ELBO is common and effective, though it can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
<p>The training process is a search for the best compromise.</p>
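We can put rough numbers on this tension using the closed-form KL from earlier (the example values below are arbitrary, chosen only to illustrate the two extremes):

```python
import math

def kl_to_standard_normal(mu: float, sigma: float) -> float:
    """KL( N(mu, sigma^2) || N(0, 1) ) for a single latent dimension."""
    return 0.5 * (mu**2 + sigma**2 - 1 - math.log(sigma**2))

# A posterior that matches the prior: zero KL cost, but z carries no information
print(kl_to_standard_normal(0.0, 1.0))   # 0.0
# A "pointy", memorizing posterior: tiny sigma, offset mean -> large KL cost
print(kl_to_standard_normal(3.0, 0.01))  # ~8.6 nats for a single dimension
```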















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also random. Its output distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ are a random mess. On average, this &ldquo;mess&rdquo; is coincidentally close to the &ldquo;mess&rdquo; of the prior $p_{\theta}(\mathbf{z})$, so the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey along the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, 1)</code> and decode those samples, we get a variety of digit-like images. From this, we get a pretty rich representation of digits. Almost all digits appear to be featured in this random sampling. Again, we see the standard blurriness.</p>
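Generating these samples is just a matter of drawing from the prior and decoding; a sketch assuming a trained <code>decoder</code> module (a hypothetical name for any module mapping latent codes to images):

```python
import torch

@torch.no_grad()
def sample_from_prior(decoder, n: int, latent_dim: int = 2) -> torch.Tensor:
    """Draw z ~ N(0, I) and decode into image space."""
    z = torch.randn(n, latent_dim)  # the prior is a standard normal
    return decoder(z)
```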
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two data points at random (here, two zeros), embed them into our latent space, and then walk across that space to interpolate between them.
Here, we see a walk that takes us from a zero that is askew to one that is more upright.</p>
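A sketch of such an interpolation, assuming a trained <code>decoder</code> module and latent codes <code>z0</code>, <code>z1</code> obtained by encoding the two images (all names hypothetical):

```python
import torch

@torch.no_grad()
def interpolate(decoder, z0: torch.Tensor, z1: torch.Tensor, steps: int = 10):
    """Decode points along the straight line between two latent codes."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    zs = (1 - alphas) * z0 + alphas * z1                   # (steps, latent_dim)
    return decoder(zs)
```

The endpoints decode the two original codes exactly, and intermediate rows trace the path between them.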















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why 4s and 9s are so confused by the model. That region is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, we see similar concentration. Some digits are more concentrated than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? We must do dimensionality reduction to visualize latent spaces, giving us an approximate sense of how the latent space is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
<p>Something odd happens when we jump from 16 to 32 latent dimensions: some of our latent dimensions become degenerate and stop encoding useful information.
This could be an indication that we need to choose our hyperparameters more cautiously. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality this dataset needs, beyond which scaling the latent dimension stops helping.</p>
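The "Active Dims" column in the table above can be computed from per-dimension KL contributions; a sketch assuming the log-variance head (<code>mu</code> and <code>logvar</code> are hypothetical encoder outputs over a batch):

```python
import torch

def count_active_dims(mu: torch.Tensor, logvar: torch.Tensor, threshold: float = 0.1) -> int:
    """Count latent dims whose average KL contribution exceeds a threshold."""
    # Per-dimension KL, averaged over the batch -> shape (latent_dim,)
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar).mean(dim=0)
    return int((kl_per_dim > threshold).sum().item())
```

A dimension whose posterior collapses to the prior for every input contributes essentially zero KL and falls below the threshold.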
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. The KL term drops in staircase-like steps without corresponding gains in reconstruction.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce their dimensionality to 2D. PCA captures progressively less of the variance as the latent dimensionality grows. The 4D and 8D plots suggest increasingly better separation of the numeric classes. However, the 16D and 32D plots capture only 10-20% of the variance and give a misleading impression of overlap.</p>
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>IQCRNN: Certified Stability for Neural Networks</title><link>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</link><pubDate>Wed, 11 May 2022 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</guid><description>PyTorch IQCRNN enforcing stability guarantees on RNNs via Integral Quadratic Constraints and semidefinite programming.</description><content:encoded><![CDATA[<p>This project is a PyTorch re-implementation of <strong>IQCRNN</strong>, a method that enforces strict stability guarantees on Recurrent Neural Networks used in control systems.</p>
<h2 id="overview">Overview</h2>
<p>Standard Reinforcement Learning agents can behave unpredictably in unseen states. This approach forces the agent&rsquo;s weights to satisfy <strong>Integral Quadratic Constraints (IQC)</strong> via a projection step. Effectively, it solves a convex optimization problem (Semidefinite Program) inside the gradient descent loop to ensure the controller never violates Lyapunov stability criteria.</p>
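<p>The actual projection in this project solves an LMI with <code>cvxpy</code> and <code>MOSEK</code>; as a self-contained illustration of the core idea (snap the weights back onto a provably stable set after each gradient step), here is a deliberately simplified NumPy surrogate. For a linear update $x_{t+1} = A x_t$, bounding the spectral norm of $A$ below 1 certifies stability, and the Frobenius-nearest such matrix is obtained by clipping singular values. This is a toy stand-in, not the paper&rsquo;s IQC machinery.</p>

```python
import numpy as np

def project_spectral(A, radius=0.99):
    """Frobenius-nearest matrix to A with spectral norm <= radius.

    Toy surrogate for IQCRNN's SDP projection: for x_{t+1} = A x_t,
    ||A||_2 < 1 certifies stability via a quadratic Lyapunov function.
    The real method instead solves an LMI over the full RNN dynamics.
    """
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return U @ np.diag(np.minimum(s, radius)) @ Vt

# After each gradient step, project the recurrent weights back to the safe set:
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 8))   # almost surely unstable: ||A||_2 > 1
A_safe = project_spectral(A)
```

<p>The SVD here plays the role of the SDP solver: it is the closed-form answer to a convex projection that, in the general IQC setting, has no closed form and must be solved numerically.</p>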
<p>The method bridges classic <strong>Robust Control Theory</strong> (1990s) with <strong>Deep Reinforcement Learning</strong> (2020s), providing mathematical certificates of safety for neural network controllers.</p>
<h2 id="features">Features</h2>
<ul>
<li><strong>Hybrid Optimization:</strong> Interleaved standard Gradient Descent (PyTorch) with Convex Optimization (<code>cvxpy</code> + <code>MOSEK</code>) to project weights onto the &ldquo;safe&rdquo; manifold after each training step.</li>
<li><strong>Complex Constraints:</strong> Implemented the &ldquo;Tilde&rdquo; parametrization from the original paper to convexify the non-convex stability conditions of the RNN dynamics, transforming an intractable problem into a solvable Linear Matrix Inequality (LMI).</li>
<li><strong>Safety-Critical Domains:</strong> Applied the controller to unstable environments like Power Grids and Inverted Pendulums where &ldquo;crashing&rdquo; during training is unacceptable.</li>
</ul>
<h2 id="usage">Usage</h2>
<p>The repository includes training scripts for the inverted pendulum and power grid environments, demonstrating the stability guarantees in practice.</p>
<h2 id="results">Results</h2>
<p>This project was a deep dive into the tension between <strong>Safety</strong> and <strong>Speed</strong>.</p>
<ul>
<li><strong>The Bottleneck:</strong> Solving an SDP every few training steps is computationally exorbitant ($O(n^6)$ complexity in the training loop). While it provided mathematical certificates of safety, it highlighted why these methods haven&rsquo;t yet overtaken standard PPO/SAC in production: the &ldquo;safety tax&rdquo; on training time is massive.</li>
<li><strong>The Lesson:</strong> It taught me that &ldquo;theoretical guarantees&rdquo; often come with &ldquo;engineering fine print.&rdquo; If I were to redo this today, I would look into <strong>differentiable convex optimization layers</strong> (like <code>cvxpylayers</code>) to make the projection end-to-end differentiable.</li>
<li><strong>The &ldquo;Rough Edges&rdquo;:</strong> The codebase has artifacts of its research origins (e.g., the <code>reqs.txt</code> dependency dump). Reading a dense control theory paper (Gu et al., 2021) and implementing the math correctly was the primary focus.</li>
</ul>
<h2 id="citation">Citation</h2>
<p>Credit to the original authors:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{gu2021recurrentneuralnetworkcontrollers,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Recurrent Neural Network Controllers Synthesis with Stability Guarantees for Partially Observed Systems}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fangda Gu and He Yin and Laurent El Ghaoui and Murat Arcak and Peter Seiler and Ming Jin}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2109.03861}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{eess.SY}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2109.03861}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="related-work">Related Work</h2>
<ul>
<li><a href="/research/deconstructing-recurrence-attention-gating/">Deconstructing Recurrence and Attention Gating</a>: research on recurrent network architectures, providing context for why stability guarantees on RNNs matter</li>
</ul>
]]></content:encoded></item><item><title>GPT-2 Susceptibility to Universal Adversarial Triggers</title><link>https://hunterheidenreich.com/research/gpt2-adversarial-triggers/</link><pubDate>Sat, 01 May 2021 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/research/gpt2-adversarial-triggers/</guid><description>Investigation into whether universal adversarial triggers can control both topic and stance of GPT-2's generated text and security implications.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Historical context:</strong> This paper was published in 2021, predating the modern red-teaming practices and adversarial robustness benchmarks that emerged with instruction-tuned and RLHF-trained models. GPT-2 is now a historical baseline, but the core methodology and findings remain a relevant foundation for current adversarial robustness work.</p></blockquote>
<h2 id="abstract">Abstract</h2>
<p>This work investigates universal adversarial triggers (UATs): input-agnostic token sequences that disrupt language models. We asked whether such triggers can control both the <strong>topic</strong> and the <strong>stance</strong> of text generated by GPT-2. Across four controversial topics, we demonstrated success in identifying triggers that guide the model to produce text on a targeted subject and influence the position it takes. Our goal is to raise awareness that even deployed models are susceptible to this influence and to advocate for immediate safeguards.</p>
<h2 id="key-findings--contributions">Key Findings &amp; Contributions</h2>
<ul>
<li><strong>Topic and Stance Control</strong>: We were the first to systematically explore using UATs to control both the topic and the stance of a language model&rsquo;s output. We found that controlling the topic is highly feasible, and controlling the stance is also possible.</li>
<li><strong>The &ldquo;Filter Bubble&rdquo; Hypothesis</strong>: We observed that triggers for fringe topics (e.g., Flat Earth) were harder to find but offered a higher degree of stance control than broader topics. We posit this may reflect &ldquo;filter bubbles&rdquo; in the training data, where fringe viewpoints use distinct linguistic patterns.</li>
<li><strong>Ethical &amp; Security Analysis</strong>: We highlighted the security risks of deployed models being manipulated by external adversaries without internal model access. To be responsible, we withheld the most sensitive triggers we discovered.</li>
<li><strong>Constructive Applications</strong>: Beyond a security flaw, we proposed that UATs could be used constructively as a <strong>diagnostic tool</strong> to audit models for bias or as a method for <strong>bot detection</strong> on social media.</li>
</ul>
<h2 id="significance--why-this-matters">Significance &amp; Why This Matters</h2>
<p>This work extended early research on UATs by moving beyond single-issue attacks (like generating toxic content) to a nuanced analysis of topic and stance control. It demonstrated that a <strong>gradient-based search process (adapting HotFlip)</strong> is effective at manipulating model outputs, emphasizing a critical vulnerability for any organization deploying large language models.</p>
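<p>To give a flavor of how a HotFlip-style search scores candidate trigger tokens (a simplified, hypothetical sketch, not the code used in the paper), the first-order change in loss from swapping the current token&rsquo;s embedding $e_{\text{cur}}$ for a candidate $e_w$ is approximated by $(e_w - e_{\text{cur}})^\top \nabla_e \mathcal{L}$, and the search keeps the tokens that decrease it most:</p>

```python
import numpy as np

def hotflip_candidates(grad, emb_matrix, cur_token, k=3):
    """First-order HotFlip-style scoring (illustrative sketch).

    grad: gradient of the loss w.r.t. the current trigger token's embedding.
    Swapping token `cur_token` for token w changes the loss by roughly
    (e_w - e_cur) . grad; return the k tokens with the largest predicted drop.
    """
    delta = (emb_matrix - emb_matrix[cur_token]) @ grad
    return np.argsort(delta)[:k]   # most negative first

rng = np.random.default_rng(0)
E = rng.standard_normal((100, 16))  # toy vocabulary of 100 embeddings
g = rng.standard_normal(16)
best = hotflip_candidates(g, E, cur_token=0)
```

<p>In practice each candidate from this linear approximation is re-evaluated with a true forward pass, since the first-order estimate is only locally accurate.</p>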
<p>For ML practitioners and security researchers, this highlights the importance of robust safeguards against input-agnostic attacks. It also opens the door to using these same adversarial techniques constructively: as diagnostic tools to audit models for hidden biases or to detect automated bot activity on social media platforms.</p>
<h2 id="related-work">Related Work</h2>
<p>The constructive bot-detection application proposed here connects directly to empirical work on coordinated inauthentic behavior. <a href="/research/coordinated-social-targeting/">Coordinated Social Targeting on Twitter</a> documents real-world follower-manipulation patterns on high-profile accounts, illustrating the kind of automated adversarial activity that UAT-based detection methods could help identify.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{10.1145/3461702.3462578,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Heidenreich, Hunter Scott and Williams, Jake Ryland}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The Earth Is Flat and the Sun Is Not a Star: The Susceptibility of GPT-2 to Universal Adversarial Triggers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{9781450384735}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computing Machinery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{New York, NY, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1145/3461702.3462578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1145/3461702.3462578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 2021 AAAI/ACM Conference on AI, Ethics, and Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{566--573}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">numpages</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{adversarial attacks, bias, language modeling, natural language processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">location</span> = <span style="color:#e6db74">{Virtual Event, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span> = <span style="color:#e6db74">{AIES &#39;21}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>A Guide to Neuroevolution: NEAT and HyperNEAT</title><link>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</link><pubDate>Wed, 02 Jan 2019 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</guid><description>Explore the evolution of neural network topologies with NEAT and how HyperNEAT scales this approach using geometric patterns and indirect encoding.</description><content:encoded><![CDATA[<h2 id="automating-neural-architecture-design">Automating Neural Architecture Design</h2>
<p>Designing neural network architectures is typically a manual, iterative process. Researchers experiment with different layer configurations, activation functions, and connection patterns, often guided by intuition and empirical results. Evolution offers an automated alternative to this design process.</p>
<p><a href="https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (NeuroEvolution of Augmenting Topologies)</a> introduced a compelling answer in 2002. This algorithm optimizes network weights and evolves the network structure itself, starting from minimal topologies and growing complexity only when beneficial.</p>
<p>NEAT&rsquo;s core innovations solved fundamental problems that had plagued earlier attempts at topology evolution. Its solutions for genetic encoding, structural crossover, and innovation protection remain influential today, especially as neural architecture search and automated ML gain prominence.</p>
<h2 id="the-core-challenges-of-neat">The Core Challenges of NEAT</h2>
<p>Evolving neural network topologies presents several fundamental challenges that NEAT elegantly addressed. Understanding these problems helps explain why NEAT&rsquo;s solutions were so influential.</p>
<h3 id="genetic-encoding-how-to-represent-networks">Genetic Encoding: How to Represent Networks</h3>
<p>Evolutionary algorithms require a genetic representation, a way to encode individuals that enables meaningful selection, mutation, and crossover. For neural networks, this choice is critical.</p>
<p><strong>Direct encoding</strong> explicitly represents each network component. Genes directly correspond to nodes and connections. This approach is intuitive and readable, and it works well for smaller networks.</p>
<p><strong>Indirect encoding</strong> specifies construction rules or processes. These encodings are more compact and can generate highly complex structures from simple rules.</p>
<p>NEAT chose direct encoding with a simple two-part structure: separate gene lists for nodes and connections. This balances simplicity with the flexibility needed for evolutionary operations.</p>















<figure class="post-figure center ">
    <img src="/img/neat_genomes.webp"
         alt="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         title="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT&rsquo;s direct encoding: node genes (top) and connection genes (bottom) with historical markings</figcaption>
    
</figure>

<p>Connection genes specify the source and target nodes, weight, enabled status, and an innovation number for historical tracking. Input and output nodes are fixed; only hidden nodes evolve.</p>
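<p>In code, this two-part direct encoding can be sketched as follows (a hypothetical minimal version for illustration, not an actual NEAT library):</p>

```python
from dataclasses import dataclass

@dataclass
class NodeGene:
    id: int
    kind: str  # "input", "hidden", or "output"

@dataclass
class ConnectionGene:
    in_node: int    # source node id
    out_node: int   # target node id
    weight: float
    enabled: bool
    innovation: int # historical marking, assigned globally across the population

# A genome is simply the two gene lists:
genome = {
    "nodes": [NodeGene(0, "input"), NodeGene(1, "output")],
    "connections": [ConnectionGene(0, 1, 0.7, True, innovation=0)],
}
```

<p>Everything NEAT does downstream (mutation, crossover, speciation) operates on these two flat lists, which is what keeps the evolutionary operators simple.</p>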
<h3 id="structural-mutations-growing-complexity">Structural Mutations: Growing Complexity</h3>
<p>NEAT employs two categories of mutations to evolve both weights and structure:</p>
<p><strong>Weight mutations</strong> adjust existing connection strengths using standard perturbation methods, the familiar approach from traditional neuroevolution.</p>
<p><strong>Structural mutations</strong> add new network components:</p>
<ul>
<li><strong>Add connection</strong>: Creates a new link between existing nodes with a random initial weight</li>
<li><strong>Add node</strong>: Splits an existing connection by inserting a new node. The original connection is disabled and replaced by two new ones: the connection into the new node starts at weight 1.0, while the connection out of it inherits the original weight</li>
</ul>
<p>This node-splitting approach minimizes disruption. The new node initially acts as an identity function, giving it time to prove useful before natural selection pressure intensifies.</p>
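<p>The add-node mutation is small enough to sketch directly (a dict-based illustration with made-up field names, not library code):</p>

```python
import itertools

def add_node_mutation(genome, conn, next_node_id, innovation_counter):
    """Split connection `conn` by inserting a new hidden node (sketch)."""
    conn["enabled"] = False  # disable the original link rather than deleting it
    genome["nodes"].append({"id": next_node_id, "kind": "hidden"})
    # Incoming link gets weight 1.0, outgoing inherits the old weight,
    # so the new node initially behaves like a pass-through.
    genome["connections"].append({
        "in": conn["in"], "out": next_node_id, "weight": 1.0,
        "enabled": True, "innovation": next(innovation_counter)})
    genome["connections"].append({
        "in": next_node_id, "out": conn["out"], "weight": conn["weight"],
        "enabled": True, "innovation": next(innovation_counter)})

genome = {"nodes": [{"id": 0, "kind": "input"}, {"id": 1, "kind": "output"}],
          "connections": [{"in": 0, "out": 1, "weight": 0.7,
                           "enabled": True, "innovation": 0}]}
counter = itertools.count(1)
add_node_mutation(genome, genome["connections"][0],
                  next_node_id=2, innovation_counter=counter)
```

<p>Note that the disabled gene is kept in the genome: it preserves the historical record that later crossover relies on, and a future mutation may re-enable it.</p>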
<h3 id="solving-the-competing-conventions-problem">Solving the Competing Conventions Problem</h3>
<p>Performing crossover between networks with different structures presents a fundamental challenge. Consider two networks that solve the same problem using different internal organizations. Naive crossover between them typically produces broken offspring.</p>















<figure class="post-figure center ">
    <img src="/img/competing_conventions.webp"
         alt="Two neural networks performing the same function but with different internal structures"
         title="Two neural networks performing the same function but with different internal structures"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The competing conventions problem: identical functions, different implementations</figcaption>
    
</figure>

<p>NEAT&rsquo;s solution draws inspiration from biology through <strong>historical markings</strong>. Each structural innovation (adding a node or connection) receives a unique innovation number, a timestamp of when that change first appeared in the population.</p>
<p>During crossover, genes with matching innovation numbers are aligned and combined. This biological concept of homology enables meaningful recombination between networks of different sizes and structures.</p>
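<p>The alignment step reduces to a dictionary lookup keyed on innovation numbers. The sketch below is a simplified, hypothetical version: matching genes are inherited from either parent at random, while disjoint and excess genes come from the fitter parent, as in the original algorithm.</p>

```python
import random

def crossover(fitter, other, rng=None):
    """Recombine connection-gene lists aligned by innovation number (sketch)."""
    rng = rng or random.Random(0)
    by_innov = {g["innovation"]: g for g in other}
    child = []
    for gene in fitter:
        mate = by_innov.get(gene["innovation"])
        # Matching gene: pick a parent at random; otherwise keep the
        # fitter parent's disjoint/excess gene.
        child.append(dict(rng.choice([gene, mate]) if mate else gene))
    return child

parent_a = [{"innovation": i, "weight": 0.1 * i} for i in (0, 1, 2)]
parent_b = [{"innovation": i, "weight": -0.1 * i} for i in (0, 1, 3)]
child = crossover(parent_a, parent_b)  # innovations 0, 1, 2
```
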















<figure class="post-figure center ">
    <img src="/img/neat_crossover.webp"
         alt="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         title="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT crossover using historical markings for gene alignment</figcaption>
    
</figure>

<h3 id="protecting-innovation-through-speciation">Protecting Innovation Through Speciation</h3>
<p>New structural innovations face a harsh reality: they usually perform worse initially. Adding nodes or connections typically decreases performance before optimization can improve the new structure. Without protection, these innovations disappear before realizing their potential.</p>
<p>NEAT addresses this through <strong>speciation</strong>: dividing the population into species based on structural and weight similarity. The historical markings that enable crossover also measure compatibility between individuals.</p>
<p>Crucially, individuals only compete within their species. This gives new structural innovations time to optimize without immediately competing against established, well-tuned networks.</p>
<p><strong>Explicit fitness sharing</strong> enhances this protection: species divide their collective fitness among members, preventing any single species from dominating the population while maintaining diversity for continued exploration.</p>
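<p>Species membership is decided by the paper&rsquo;s compatibility distance, $\delta = c_1 E / N + c_2 D / N + c_3 \bar{W}$, where $E$ and $D$ count excess and disjoint genes, $\bar{W}$ is the average weight difference of matching genes, and $N$ normalizes by genome size. A minimal sketch (the coefficients shown are typical defaults; implementations often also set $N = 1$ for small genomes):</p>

```python
def compatibility(conns_a, conns_b, c1=1.0, c2=1.0, c3=0.4):
    """NEAT compatibility distance: delta = c1*E/N + c2*D/N + c3*Wbar."""
    a = {g["innovation"]: g["weight"] for g in conns_a}
    b = {g["innovation"]: g["weight"] for g in conns_b}
    matching = a.keys() & b.keys()
    cutoff = min(max(a), max(b))   # innovations beyond this are "excess"
    mismatched = a.keys() ^ b.keys()
    excess = sum(1 for i in mismatched if i > cutoff)
    disjoint = len(mismatched) - excess
    n = max(len(a), len(b))
    w_bar = (sum(abs(a[i] - b[i]) for i in matching) / len(matching)
             if matching else 0.0)
    return c1 * excess / n + c2 * disjoint / n + c3 * w_bar

a = [{"innovation": 0, "weight": 1.0}, {"innovation": 1, "weight": -0.5}]
b = a + [{"innovation": 2, "weight": 0.3}]
delta = compatibility(a, b)  # one excess gene, identical shared weights
```

<p>Fitness sharing then simply divides each individual&rsquo;s fitness by its species size, so large species do not automatically crowd out small, innovative ones.</p>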
<h3 id="complexification-starting-minimal">Complexification: Starting Minimal</h3>
<p>NEAT begins with the simplest possible networks (just input and output nodes connected by random weights). No hidden layers exist initially. Complexity emerges only when mutations that add structure prove beneficial.</p>
<p>This complexification approach builds efficient solutions that solve problems with minimal structure. Combined with speciation, it tends to produce highly optimized architectures.</p>
<h2 id="scaling-up-hyperneat">Scaling Up: HyperNEAT</h2>
<p>NEAT evolved networks through direct encoding, where each gene explicitly specifies nodes and connections. Scaling this approach to larger architectures demands a fundamentally different method: evolving networks that approach the brain&rsquo;s scale of connectivity calls for indirect encoding.</p>
<p><a href="https://doi.org/10.1162/artl.2009.15.2.15202">HyperNEAT</a> introduces <strong>indirect encoding</strong> through geometric principles. HyperNEAT evolves geometric patterns that generate connections based on spatial relationships. This enables the evolution of large networks with biological regularities: symmetry, repetition, and locality.</p>
<p>The key insight is leveraging Compositional Pattern Producing Networks (CPPNs) to map coordinates to connection weights, exploiting the geometric organization found in natural neural networks.</p>
<h3 id="biological-motivation">Biological Motivation</h3>
<p>The human brain exhibits remarkable organizational principles:</p>
<ul>
<li><strong>Scale</strong>: ~86 billion neurons with ~100 trillion connections</li>
<li><strong>Repetition</strong>: Structural patterns reused across regions</li>
<li><strong>Symmetry</strong>: Mirrored structures like bilateral visual processing</li>
<li><strong>Locality</strong>: Spatial proximity influences connectivity and function</li>
</ul>
<p>HyperNEAT aims to evolve networks that capture these biological regularities, leading to more efficient and interpretable architectures.</p>
<h3 id="compositional-pattern-producing-networks">Compositional Pattern Producing Networks</h3>
<p>CPPNs are the foundation of HyperNEAT&rsquo;s indirect encoding. Think of them as pattern generators that create complex spatial structures from simple coordinate inputs.</p>
<p>DNA exemplifies indirect encoding (roughly 30,000 genes specify a brain with trillions of connections). This massive compression ratio suggests that simple rules can generate complex structures through developmental processes.</p>
<p>CPPNs abstract this concept, using compositions of mathematical functions to create patterns in coordinate space. The same genetic program (function composition) can be reused across different locations and scales, just like how developmental genes control pattern formation throughout an organism.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppns.webp"
         alt="Various symmetric and repetitive patterns created by CPPNs"
         title="Various symmetric and repetitive patterns created by CPPNs"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Complex patterns generated by CPPNs through function composition</figcaption>
    
</figure>

<h3 id="pattern-generation-through-function-composition">Pattern Generation Through Function Composition</h3>
<p>CPPNs generate patterns by composing simple mathematical functions. Key function types include:</p>
<ul>
<li><strong>Gaussian functions</strong>: Create symmetric patterns and gradients</li>
<li><strong>Trigonometric functions</strong>: Generate periodic/repetitive structures</li>
<li><strong>Linear functions</strong>: Produce gradients and asymmetric patterns</li>
<li><strong>Sigmoid functions</strong>: Create sharp transitions and boundaries</li>
</ul>
<p>By combining these functions, CPPNs can encode complex regularities that would require many explicit rules in direct encoding.</p>
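<p>A tiny hand-built example makes the composition idea concrete (an illustrative toy, not an evolved CPPN): composing a symmetric function (Gaussian) with a periodic one (sine) yields a pattern that mirrors in one coordinate and repeats in the other.</p>

```python
import math

def gaussian(x):
    return math.exp(-x * x)

def cppn(x, y):
    # Symmetry in x comes from the Gaussian; repetition in y from the sine.
    return gaussian(x) * math.sin(4.0 * math.pi * y)

left, right = cppn(-0.5, 0.25), cppn(0.5, 0.25)  # equal by x-symmetry
```

<p>Evolving the CPPN amounts to rearranging and reweighting exactly this kind of function composition.</p>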
<h3 id="evolution-of-cppns">Evolution of CPPNs</h3>
<p>HyperNEAT uses NEAT to evolve the CPPN structure. This brings several advantages:</p>
<ul>
<li><strong>Complexification</strong>: CPPNs start simple and grow more complex only when beneficial</li>
<li><strong>Historical markings</strong>: Enable proper crossover between different CPPN topologies</li>
<li><strong>Speciation</strong>: Protects innovative CPPN patterns during evolution</li>
</ul>
<p>Additional activation functions beyond standard neural networks are crucial:</p>
<ul>
<li>Gaussian functions for symmetry</li>
<li>Sine/cosine for repetition</li>
<li>Specialized functions for specific geometric patterns</li>
</ul>
<h2 id="the-hyperneat-process">The HyperNEAT Process</h2>
<h3 id="substrates-geometric-organization">Substrates: Geometric Organization</h3>
<p>A <strong>substrate</strong> defines the spatial arrangement of neurons. Substrates embed neurons in geometric space (2D grids, 3D volumes, etc.), providing an alternative to layer-based connectivity rules.</p>
<p>The CPPN maps from coordinates to connection weights:</p>
<p>$$\text{CPPN}(x_1, y_1, x_2, y_2) = w$$</p>
<p>Where $(x_1, y_1)$ and $(x_2, y_2)$ are the coordinates of two neurons, and $w$ determines their connection weight.</p>
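<p>Generating the substrate network is then just a loop over neuron pairs, querying the CPPN for each. The sketch below is hypothetical (<code>toy_cppn</code> stands in for an evolved CPPN) and applies the common HyperNEAT convention of pruning connections whose magnitude falls below a threshold:</p>

```python
import itertools, math

def build_substrate_weights(cppn, coords, threshold=0.2):
    """Query a CPPN for every neuron pair on a substrate (sketch)."""
    weights = {}
    for (x1, y1), (x2, y2) in itertools.product(coords, repeat=2):
        w = cppn(x1, y1, x2, y2)
        if abs(w) > threshold:  # prune weak links to keep the network sparse
            weights[((x1, y1), (x2, y2))] = w
    return weights

# Placeholder CPPN favoring local connections (distance-based Gaussian):
toy_cppn = lambda x1, y1, x2, y2: math.exp(-((x1 - x2) ** 2 + (y1 - y2) ** 2))
grid = [(x / 2, y / 2) for x in range(3) for y in range(3)]  # 3x3 substrate
weights = build_substrate_weights(toy_cppn, grid)
```

<p>Because the CPPN is queried per coordinate pair, the same evolved function works unchanged if <code>grid</code> is replaced by a finer substrate, which is the mechanism behind the resolution independence discussed below.</p>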















<figure class="post-figure center ">
    <img src="/img/hyperneat_cppn_basics.webp"
         alt="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         title="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basic CPPN architecture mapping coordinates to connection weights</figcaption>
    
</figure>

<p>This geometric approach enables several key properties:</p>
<ul>
<li><strong>Locality</strong>: Nearby neurons tend to have similar connectivity patterns</li>
<li><strong>Symmetry</strong>: Patterns can be mirrored across spatial axes</li>
<li><strong>Repetition</strong>: Periodic functions create repeating motifs</li>
<li><strong>Scalability</strong>: The same pattern can be applied at different resolutions</li>
</ul>
<h3 id="emergent-regularities">Emergent Regularities</h3>
<p>The geometric encoding naturally produces the desired biological patterns:</p>
<p><strong>Symmetry</strong> emerges from symmetric functions. A Gaussian centered at the origin creates identical patterns when $(x_1, y_1)$ and $(x_2, y_2)$ are equidistant from the center.</p>
<p><strong>Repetition</strong> arises from periodic functions like sine and cosine. These create repeating connectivity motifs across the substrate.</p>
<p><strong>Locality</strong> results from functions that vary smoothly across space. Nearby coordinates produce similar outputs, leading to local connectivity patterns.</p>
<p><strong>Imperfect regularity</strong> occurs when these patterns are modulated by additional coordinate dependencies, creating biological-like variation within the basic structure.</p>
<h3 id="substrate-configurations">Substrate Configurations</h3>
<p>The choice of substrate geometry critically influences network behavior. Several standard configurations exist:</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_substrate_configurations.webp"
         alt="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         title="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Common substrate geometries for different problem types</figcaption>
    
</figure>

<p><strong>2D Grid</strong>: Simple planar arrangement, CPPN takes four coordinates $(x_1, y_1, x_2, y_2)$</p>
<p><strong>3D Volume</strong>: Extends to three dimensions, CPPN becomes six-dimensional $(x_1, y_1, z_1, x_2, y_2, z_2)$</p>
<p><strong>Sandwich</strong>: Input layer connects only to output layer, useful for sensory-motor tasks</p>
<p><strong>Circular</strong>: Radial geometry enables rotation-invariant patterns and cyclic behaviors</p>
<p>The substrate must be chosen before evolution begins, making domain knowledge important for success.</p>
<h3 id="exploiting-input-output-geometry">Exploiting Input-Output Geometry</h3>
<p>HyperNEAT exploits the spatial organization of inputs and outputs. For visual tasks, pixel coordinates provide meaningful geometric information. For control problems, sensor and actuator layouts can guide connectivity patterns.</p>















<figure class="post-figure center ">
    <img src="/img/hyperneat_inputs_outputs.webp"
         alt="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         title="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spatial organization of inputs and outputs enables geometric exploitation</figcaption>
    
</figure>

<p>This spatial awareness allows HyperNEAT to:</p>
<ul>
<li>Develop receptive fields similar to biological vision systems</li>
<li>Create locally connected patterns for spatial processing</li>
<li>Generate symmetric motor control patterns</li>
<li>Scale across different input resolutions</li>
</ul>
<h3 id="resolution-independence">Resolution Independence</h3>
<p>A unique advantage of HyperNEAT is <strong>substrate resolution independence</strong>. Networks evolved on low-resolution substrates can be deployed on higher-resolution versions without retraining. The CPPN&rsquo;s coordinate-based mapping scales naturally across different granularities.</p>
<p>This property suggests that evolved patterns capture fundamental spatial relationships, providing a key insight for scalable neural architecture design.</p>
<h2 id="impact-and-future-directions">Impact and Future Directions</h2>
<p>NEAT and HyperNEAT demonstrated that evolution could design neural network topologies and scale them through indirect encoding. The algorithms&rsquo; key insights, exploiting geometry, generating patterns through function composition, and scaling across resolutions, continue to influence modern research.</p>
<p>Extensions like ES-HyperNEAT add even more sophisticated capabilities by evolving the substrate itself. As neural architecture search becomes increasingly important, these principles find new applications in hybrid approaches that combine evolutionary pattern generation with gradient-based optimization.</p>
<p>The emphasis on spatial organization and regularity also connects to contemporary work on geometric deep learning and equivariant networks, suggesting that evolution and hand-design converge on similar organizing principles for building structured, efficient neural architectures.</p>
]]></content:encoded></item><item><title>QuAC: Question Answering in Context Dataset</title><link>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</link><pubDate>Wed, 31 Oct 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</guid><description>Analysis of QuAC's conversational QA through student-teacher interactions, featuring 100K+ context-dependent questions and coreference challenges.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://aclanthology.org/D18-1241/">QuAC dataset</a> (Question Answering in Context) presents a conversational question answering approach that models student-teacher interactions. Published at EMNLP 2018, this work by Choi et al. addresses how systems can understand dialogue context, resolve references across conversation turns, and handle natural conversation ambiguity, challenges that earlier datasets, which treated each question independently, left unaddressed.</p>
<p>The dataset addresses limitations in question answering research by incorporating real-world information-seeking dialogue complexities, where questions build upon previous exchanges and context drives understanding.</p>
<p>For comparison with related work, see my analysis of <a href="/posts/coqa-conversation-question-answering/">CoQA</a>.</p>
<h2 id="the-student-teacher-framework">The Student-Teacher Framework</h2>
<p>QuAC models information-seeking dialogue through a student-teacher setup:</p>
<ul>
<li><strong>Teacher</strong>: Has complete access to information (Wikipedia passage)</li>
<li><strong>Student</strong>: Seeks knowledge through questioning with limited initial context</li>
<li><strong>Interaction</strong>: Handles context-dependent questions, abstract inquiries, and unanswerable requests</li>
</ul>
<p>This framework mirrors real-world scenarios where one party has expertise while another seeks to learn through dialogue. AI systems must act as effective teachers, using available information to provide helpful responses despite ambiguous or incomplete questions.</p>
<p>The dataset contains over 100,000 questions across 14,000+ dialogues, providing substantial scale for training and evaluation.</p>















<figure class="post-figure center ">
    <img src="/img/quac_stats.webp"
         alt="QuAC dataset statistics and scale"
         title="QuAC dataset statistics and scale"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">QuAC dataset statistics and scale</figcaption>
    
</figure>

<h2 id="dataset-construction">Dataset Construction</h2>
<p>QuAC was built using Amazon Mechanical Turk with a two-person dialogue setup:</p>
<p><strong>Teacher role</strong>: Has access to the complete Wikipedia passage and provides answers extracted directly from the text</p>
<p><strong>Student role</strong>: Sees only the article title, introduction paragraph, and section heading, then asks questions to learn about the content</p>
<p>This asymmetric information design ensures student questions naturally differ from the passage content, creating realistic information-seeking scenarios. The extractive answer requirement maintains objective evaluation while simplifying scoring.</p>
<p><strong>Dialogue termination</strong>: A dialogue ends when any of the following occurs:</p>
<ul>
<li>12 questions have been answered</li>
<li>Either participant manually ends the dialogue</li>
<li>Two consecutive questions are unanswerable</li>
</ul>
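<p>These stopping rules are simple enough to sketch in code. The following helper is illustrative only (my naming, not from the QuAC release):</p>

```python
def dialogue_should_end(num_answered: int,
                        consecutive_unanswerable: int,
                        manual_stop: bool = False) -> bool:
    """Return True once any QuAC termination condition is met."""
    return (
        num_answered >= 12                # question budget exhausted
        or manual_stop                    # either participant chose to stop
        or consecutive_unanswerable >= 2  # two "cannot answer" turns in a row
    )
```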















<figure class="post-figure center ">
    <img src="/img/quac_convo.webp"
         alt="Example QuAC conversation showing student-teacher interaction"
         title="Example QuAC conversation showing student-teacher interaction"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Example QuAC conversation showing student-teacher interaction</figcaption>
    
</figure>

<h3 id="content-selection">Content Selection</h3>
<p>QuAC focuses on Wikipedia biographical articles for several practical reasons:</p>
<ul>
<li><strong>Reduced complexity</strong>: People-focused content requires less specialized domain knowledge</li>
<li><strong>Natural question flow</strong>: Biographical information lends itself to sequential questioning</li>
<li><strong>Quality control</strong>: Articles filtered to include only subjects with 100+ incoming links, ensuring content depth</li>
</ul>
<p>This focused scope enables consistent evaluation while maintaining broad coverage through diverse biographical subjects across fields and time periods.</p>
<h2 id="key-dataset-characteristics">Key Dataset Characteristics</h2>
<p>QuAC introduces several features that distinguish it from existing question answering benchmarks:</p>















<figure class="post-figure center ">
    <img src="/img/quac_comparison.webp"
         alt="Comparative analysis of QuAC against other QA datasets"
         title="Comparative analysis of QuAC against other QA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Comparative analysis of QuAC against other QA datasets</figcaption>
    
</figure>

<p><strong>Notable features</strong>:</p>
<ul>
<li><strong>High contextual dependency</strong>: 86% of questions require coreference resolution</li>
<li><strong>Non-factoid focus</strong>: 54% of questions go beyond simple fact retrieval</li>
<li><strong>Extended answers</strong>: Responses are longer and more detailed</li>
<li><strong>Unanswerable questions</strong>: Realistic scenarios where information isn&rsquo;t available</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/quac_dist.webp"
         alt="Distribution of question types in QuAC"
         title="Distribution of question types in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of question types in QuAC</figcaption>
    
</figure>

<h3 id="the-coreference-resolution-challenge">The Coreference Resolution Challenge</h3>
<p>QuAC&rsquo;s complexity stems from its heavy reliance on coreference resolution across multiple contexts:</p>
<p><strong>Reference types</strong>:</p>
<ul>
<li><strong>Passage references</strong>: Pronouns and references to entities in the source text</li>
<li><strong>Dialogue references</strong>: References to previously discussed topics</li>
<li><strong>Abstract references</strong>: Challenging cases like &ldquo;what else?&rdquo; that require inferring the inquiry scope</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/quac_coref.webp"
         alt="Types and distribution of coreferences in QuAC"
         title="Types and distribution of coreferences in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Types and distribution of coreferences in QuAC</figcaption>
    
</figure>

<p>The prevalence of coreference resolution makes QuAC particularly challenging, as this remains an active research problem in NLP. Models must understand passage content, track dialogue history, and resolve complex referential expressions simultaneously.</p>
<h2 id="performance-results">Performance Results</h2>
<p>Models face substantial challenges on QuAC, with significant gaps between human and machine performance:</p>















<figure class="post-figure center ">
    <img src="/img/quac_performance.webp"
         alt="Baseline model performance comparison on QuAC"
         title="Baseline model performance comparison on QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Baseline model performance comparison on QuAC</figcaption>
    
</figure>

<p><strong>Performance summary</strong>:</p>
<ul>
<li><strong>Human performance</strong>: 81.1% F1 score</li>
<li><strong>Best baseline</strong>: BiDAF++ with context achieves 60.2% F1</li>
<li><strong>Performance gap</strong>: The 20+ point difference leaves substantial room for improvement</li>
</ul>
<h3 id="human-equivalence-metrics">Human Equivalence Metrics</h3>
<p>QuAC introduces evaluation metrics beyond traditional F1 scores:</p>
<p><strong>HEQ-Q (Human Equivalence Question-level)</strong>: Percentage of questions where the model achieves human-level or better performance</p>
<p><strong>HEQ-D (Human Equivalence Dialogue-level)</strong>: Percentage of complete dialogues where the model matches human performance across all questions</p>
<p><strong>Current results</strong>:</p>
<ul>
<li>Human baseline: 100% HEQ-Q, 100% HEQ-D (by definition)</li>
<li>Best model: 55.1% HEQ-Q, 5.2% HEQ-D</li>
</ul>
<p>These metrics capture not only average performance but also consistency across questions and entire conversations, which matters for practical dialogue systems.</p>
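<p>Both HEQ metrics are straightforward to compute from per-question F1 scores. A minimal sketch (hypothetical function names, plain Python):</p>

```python
def heq_q(model_f1, human_f1):
    """HEQ-Q: fraction of questions where the model's F1 matches or
    exceeds the human reference F1."""
    pairs = list(zip(model_f1, human_f1))
    return sum(m >= h for m, h in pairs) / len(pairs)

def heq_d(dialogues):
    """HEQ-D: fraction of dialogues where the model matches or exceeds
    human F1 on *every* question in the dialogue.

    `dialogues` is a list of dialogues, each a list of
    (model_f1, human_f1) pairs.
    """
    return sum(all(m >= h for m, h in d) for d in dialogues) / len(dialogues)
```

<p>HEQ-D is much stricter than HEQ-Q, which explains the 55.1% vs. 5.2% gap: a single sub-human answer anywhere in a dialogue fails the entire dialogue.</p>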
<h2 id="research-impact">Research Impact</h2>
<p>QuAC represents an important step in question answering research by introducing realistic conversational dynamics that existing datasets lack. The student-teacher framework captures natural information-seeking behavior while maintaining extractive evaluation for objective assessment.</p>
<p><strong>Key contributions</strong>:</p>
<ul>
<li><strong>Conversational realism</strong>: Context-dependent questions that mirror dialogue patterns</li>
<li><strong>Coreference complexity</strong>: Integration of challenging NLP problems into QA evaluation</li>
<li><strong>Evaluation metrics</strong>: HEQ scores that measure consistency alongside average performance</li>
<li><strong>Large-scale framework</strong>: Substantial dataset enabling robust model training and evaluation</li>
</ul>
<p>The dataset&rsquo;s <a href="https://quac.ai/">leaderboard</a> provides researchers with a challenging benchmark for developing conversational AI systems. As models improve on QuAC, we can expect progress in dialogue agents, virtual assistants, and educational AI systems that engage in more natural, context-aware conversations.</p>
<p>QuAC&rsquo;s focus on dialogue context and reference resolution pushes the field toward AI systems that can engage in genuine conversation and understand complex dialogue flows.</p>
<h2 id="a-builders-perspective-quac-and-modern-instruction-tuning">A Builder&rsquo;s Perspective: QuAC and Modern Instruction Tuning</h2>
<p>Looking at QuAC through the lens of modern production ML, the student-teacher framework is incredibly relevant. Today, we train foundation models using Reinforcement Learning from Human Feedback (RLHF) and instruction tuning, which rely heavily on multi-turn, context-aware interactions.</p>
<p>When building systems like GutenOCR or enterprise document processing pipelines, users rarely ask perfectly formulated, context-free questions. They ask follow-ups, use pronouns, and expect the system to act as a knowledgeable &ldquo;teacher&rdquo; guiding them through the document. QuAC was one of the first datasets to formalize this asymmetric information dynamic. It highlighted the necessity of handling unanswerable questions gracefully, a critical feature for preventing hallucinations in today&rsquo;s production LLMs.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{choi-etal-2018-quac,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">&#34;{Q}u{AC}: Question Answering in Context&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">&#34;Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">&#34;Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">month</span> = oct # <span style="color:#e6db74">&#34;-&#34;</span> # nov,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">&#34;2018&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">&#34;Brussels, Belgium&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">&#34;Association for Computational Linguistics&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">&#34;https://aclanthology.org/D18-1241/&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">&#34;10.18653/v1/D18-1241&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">&#34;2174--2184&#34;</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CoQA Dataset: Advancing Conversational Question Answering</title><link>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</link><pubDate>Thu, 23 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</guid><description>Analysis of CoQA, a conversational QA dataset with multi-turn dialogue, coreference resolution, and natural answers for QA research.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://doi.org/10.1162/tacl_a_00266">CoQA dataset</a> (Reddy et al., 2019) introduces conversational dynamics to question answering research. CoQA requires models to maintain context across multi-turn conversations while reading and reasoning about text passages. Previous datasets focused on isolated question-answer pairs.</p>
<p>This dataset addresses a gap in conversational AI research by providing a benchmark for systems that must understand dialogue flow and implicit references. These are key components of natural human conversation.</p>
<p>For related work on conversational question answering, see my analysis of <a href="/posts/quac-question-answering-in-context/">QuAC</a>.</p>
<h2 id="what-makes-conversational-qa-different">What Makes Conversational QA Different</h2>
<p>Conversational question answering introduces challenges beyond traditional reading comprehension:</p>
<ol>
<li><strong>Context dependency</strong>: Questions rely on previous dialogue turns for meaning</li>
<li><strong>Coreference resolution</strong>: Understanding pronouns and implicit references</li>
<li><strong>Abstractive answering</strong>: Rephrasing information to generate natural responses</li>
<li><strong>Multi-turn reasoning</strong>: Maintaining coherent dialogue across multiple exchanges</li>
</ol>
<p>These requirements differentiate CoQA from existing question answering datasets that treat each question independently.</p>
<h2 id="why-coqa-matters">Why CoQA Matters</h2>
<p>Question answering systems typically excel at finding specific information in text. However, they often struggle with natural conversation. Human communication involves building on previous exchanges, using pronouns and implicit references, and expressing ideas in varied ways.</p>
<p>CoQA addresses this by creating a large-scale dataset for conversational question answering with three primary characteristics:</p>
<ol>
<li>
<p><strong>Conversation-dependent questions</strong>: Every question after the first depends on the dialogue history; the dataset comprises 127,000 questions spanning 8,000 conversations</p>
</li>
<li>
<p><strong>Natural, abstractive answers</strong>: CoQA requires rephrased responses that sound natural in conversation rather than verbatim text spans</p>
</li>
<li>
<p><strong>Domain diversity</strong>: Training covers 5 domains with testing on 7 domains, including 2 unseen during training</p>
</li>
</ol>
<p>The performance gap is notable: humans achieve 88.8% F1 score while the best models at the time reached 65.1% F1, indicating substantial room for improvement.</p>
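<p>The F1 here is the standard token-overlap score used in extractive QA. A simplified sketch follows; note that CoQA's official scorer additionally normalizes text and averages over multiple human references:</p>

```python
from collections import Counter

def token_f1(prediction: str, gold: str) -> float:
    """Token-overlap F1 between a predicted and a gold answer string."""
    pred_toks = prediction.lower().split()
    gold_toks = gold.lower().split()
    # Multiset intersection counts each shared token at most as often
    # as it appears in both answers.
    common = Counter(pred_toks) & Counter(gold_toks)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
```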
<h2 id="dataset-construction">Dataset Construction</h2>
<p>CoQA was constructed using Amazon Mechanical Turk, pairing workers in a question-answer dialogue setup. One worker asked questions about a given passage while another provided answers. The answerer first highlighted the relevant text span, then rephrased the information using different words to create natural, abstractive responses.</p>
<p>This methodology produces answers that sound conversational, making the dataset highly realistic for dialogue applications.</p>
<h3 id="domain-coverage">Domain Coverage</h3>
<p>CoQA spans diverse text types to ensure evaluation across different writing styles and topics:</p>
<p><strong>Training domains (5):</strong></p>
<ul>
<li>Children&rsquo;s stories from <a href="https://web.archive.org/web/20180829214346/https://uclmr.github.io/ai4exams/data.html#mctest">MCTest</a></li>
<li>Literature from <a href="https://www.gutenberg.org/">Project Gutenberg</a></li>
<li>Educational content from <a href="https://www.cs.cmu.edu/~glai1/data/race/">RACE</a> (middle/high school English)</li>
<li>CNN news articles</li>
<li>Wikipedia articles</li>
</ul>
<p><strong>Test-only domains (2):</strong></p>
<ul>
<li>Science articles from <a href="http://data.allenai.org/ai2-science-questions/">AI2 Science Questions</a></li>
<li>Creative writing from <a href="https://www.reddit.com/r/WritingPrompts/">Reddit WritingPrompts</a></li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_domains.webp"
         alt="Domain distribution in the CoQA dataset"
         title="Domain distribution in the CoQA dataset"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Domain distribution in the CoQA dataset</figcaption>
    
</figure>

<p>The inclusion of test-only domains provides a rigorous evaluation of model generalization to unseen text types.</p>
<h2 id="comparison-with-existing-datasets">Comparison with Existing Datasets</h2>
<p>Prior to CoQA, the dominant question answering benchmark was <a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD (Stanford Question Answering Dataset)</a>. SQuAD established foundations for reading comprehension and presented specific constraints:</p>
<ul>
<li><strong>SQuAD 1.0</strong>: 100,000+ questions requiring exact text extraction from Wikipedia passages</li>
<li><strong>SQuAD 2.0</strong>: Added 50,000+ unanswerable questions to test when no answer exists</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_size.webp"
         alt="Scale comparison between SQuAD and CoQA datasets"
         title="Scale comparison between SQuAD and CoQA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Scale comparison between SQuAD and CoQA datasets</figcaption>
    
</figure>

<p>SQuAD treats each question independently and requires only extractive answers. CoQA addresses these constraints through conversational context and abstractive responses.</p>
<h3 id="question-and-answer-analysis">Question and Answer Analysis</h3>
<p>The differences between SQuAD and CoQA extend beyond conversational context:</p>
<p><strong>Question diversity</strong>: SQuAD heavily favors &ldquo;what&rdquo; questions (~50%). CoQA shows a more balanced distribution across question types, reflecting natural conversation patterns.</p>















<figure class="post-figure center ">
    <img src="/img/squad_v_coqa.webp"
         alt="Question type distribution comparison between SQuAD and CoQA"
         title="Question type distribution comparison between SQuAD and CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Question type distribution comparison between SQuAD and CoQA</figcaption>
    
</figure>

<p><strong>Context dependence</strong>: CoQA includes challenging single-word questions like &ldquo;who?&rdquo;, &ldquo;where?&rdquo;, or &ldquo;why?&rdquo; that depend entirely on dialogue history.</p>
<p><strong>Answer characteristics</strong>: CoQA answers vary significantly in length and style, whereas SQuAD primarily features short extractive spans.</p>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_answers.webp"
         alt="Answer length distribution in SQuAD vs CoQA"
         title="Answer length distribution in SQuAD vs CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Answer length distribution in SQuAD vs CoQA</figcaption>
    
</figure>

<h2 id="the-coreference-challenge">The Coreference Challenge</h2>
<p>CoQA&rsquo;s difficulty stems largely from its reliance on coreference resolution (determining when different expressions refer to the same entity). This remains a challenging research problem in NLP.</p>
<p><strong>Coreference types in CoQA</strong>:</p>
<ul>
<li><strong>Explicit coreferences</strong> (~50% of questions): Clear indicators like pronouns (&ldquo;him,&rdquo; &ldquo;it,&rdquo; &ldquo;her,&rdquo; &ldquo;that&rdquo;)</li>
<li><strong>Implicit coreferences</strong> (~20% of questions): Context-dependent references requiring inference (e.g., asking &ldquo;where?&rdquo; without specifying what)</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_coreferences.webp"
         alt="Distribution of coreference types in CoQA questions"
         title="Distribution of coreference types in CoQA questions"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of coreference types in CoQA questions</figcaption>
    
</figure>

<p>These linguistic phenomena make CoQA more difficult than traditional reading comprehension, as models must resolve references across dialogue turns while maintaining conversational coherence.</p>
<h2 id="performance-benchmarks">Performance Benchmarks</h2>
<p>Models faced significant challenges on CoQA, with substantial room for improvement:</p>















<figure class="post-figure center ">
    <img src="/img/coqa_scores.webp"
         alt="Performance comparison on CoQA across different model types"
         title="Performance comparison on CoQA across different model types"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Performance comparison on CoQA across different model types</figcaption>
    
</figure>

<p>The performance gap between human and machine capabilities highlighted conversational question answering as a challenging frontier in NLP research.</p>
<h2 id="research-impact-and-future-directions">Research Impact and Future Directions</h2>
<p>CoQA represents a step toward more natural conversational AI systems. By requiring models to handle dialogue context, coreference resolution, and abstractive reasoning simultaneously, it challenges current NLP system capabilities.</p>
<p>The dataset&rsquo;s <a href="https://stanfordnlp.github.io/coqa/">leaderboard</a> provides a benchmark for measuring progress on this task. As models improve on CoQA, we can expect advances in conversational AI applications, from chatbots to virtual assistants that engage in more natural, context-aware dialogue.</p>
<p>CoQA&rsquo;s contribution to the field aims to parallel ImageNet&rsquo;s impact on computer vision, providing a challenging, well-constructed benchmark that drives research toward more capable AI systems.</p>
<h2 id="a-builders-perspective-coqa-in-the-era-of-llms">A Builder&rsquo;s Perspective: CoQA in the Era of LLMs</h2>
<p>Looking back at CoQA from the perspective of modern production systems, this dataset was highly prescient. The challenges it introduced, such as multi-turn reasoning, coreference resolution, and abstractive answering, are the exact capabilities we now expect from instruction-tuned Large Language Models (LLMs).</p>
<p>When building document processing pipelines at scale, we rarely extract isolated facts. Users want to chat with their documents, asking follow-up questions like, &ldquo;What does that mean for the Q3 budget?&rdquo; Resolving &ldquo;that&rdquo; to a previous turn&rsquo;s context is exactly what CoQA formalized. Datasets like CoQA laid the groundwork for the conversational interfaces we build today, shifting the field&rsquo;s focus from simple extraction to genuine dialogue comprehension.</p>
<h2 id="references">References</h2>
<p>Reddy, S., Chen, D., &amp; Manning, C. D. (2019). CoQA: A conversational question answering challenge. <em>Transactions of the Association for Computational Linguistics</em>, 7, 249-266.</p>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers, yet the adversarial training dynamics and objective functions pioneered by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how robust loss functions are designed and stabilized today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>This is equivalent to minimizing the distance between our estimated distribution and the true data distribution. A common distance measure is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>. Maximizing log-likelihood equals minimizing KL divergence.</p>
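<p>To see the equivalence, note that $D_{\mathrm{KL}}(p_{\text{data}} \,\|\, p_\theta) = \mathbb{E}_{x \sim p_{\text{data}}}[\log p_{\text{data}}(x)] - \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)]$, where the first term does not depend on $\theta$. With the empirical average standing in for the expectation:</p>
<p>$$
\arg\max_\theta \frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i)
\;\approx\; \arg\max_\theta \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)]
\;=\; \arg\min_\theta D_{\mathrm{KL}}(p_{\text{data}} \,\|\, p_\theta)
$$</p>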
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; to evaluate performance</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models learn to generate data by indirectly sampling from a learned distribution. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>















<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>















<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent space $z$ is continuous, meaning small changes in $z$ produce smooth, meaningful changes in the generated output.</p>
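<p>That continuity is what makes latent-space interpolation work: walking a straight line between two latent vectors and decoding each point with $G$ traces a smooth morph between the two generated samples. A tiny sketch (plain Python, illustrative only):</p>

```python
def lerp(z0, z1, t):
    """Linearly interpolate between two latent vectors at fraction t in [0, 1].

    Feeding each interpolated z to the generator G yields a smooth path
    between the two corresponding generated samples."""
    return [(1.0 - t) * a + t * b for a, b in zip(z0, z1)]

# Midpoint between two 2-D latent codes
z_mid = lerp([0.0, 2.0], [2.0, 4.0], 0.5)  # [1.0, 3.0]
```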
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(\text{x is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>An art forger tries to create convincing fake masterpieces, while an art critic must separate authentic works from forgeries. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: The counterfeiter brings a crayon drawing of a dollar bill. Even a new teller spots this fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability the discriminator assigns to real samples being real.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
<p><strong>The Solution</strong>: Practitioners typically train the generator to <strong>maximize</strong> $\log(D(G(z)))$ to provide stronger gradients early in training. This non-saturating heuristic prevents the learning process from stalling.</p>
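The saturation effect is easy to see numerically. The sketch below (illustrative, not from the original GAN paper) compares the gradient of the saturating loss $\log(1 - D(G(z)))$ with the non-saturating heuristic $-\log(D(G(z)))$, both taken with respect to the discriminator's output:

```python
import numpy as np

# Gradient of the saturating generator loss log(1 - d) w.r.t. the
# discriminator output d, versus the non-saturating heuristic -log(d).
def saturating_grad(d):
    return -1.0 / (1.0 - d)

def non_saturating_grad(d):
    return -1.0 / d

# Early in training the discriminator confidently rejects fakes: d ≈ 0.
d_early = 0.01
print(abs(saturating_grad(d_early)))      # ≈ 1.01 — weak learning signal
print(abs(non_saturating_grad(d_early)))  # ≈ 100 — strong learning signal
```

When the discriminator is winning ($d \approx 0$), the saturating loss provides almost no gradient, while the non-saturating version provides a gradient inversely proportional to $d$, which is exactly when the generator needs the most help.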
<h3 id="the-training-process">The Training Process</h3>
<p>The beauty of GANs lies in their alternating optimization:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Make the discriminator optimal for the current generator</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue until reaching Nash equilibrium</li>
</ol>
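The alternating scheme can be illustrated with a deliberately tiny 1-D example (an assumed toy setup, not from the post): real data is Gaussian around 4, the "generator" $g(z) = \theta + z$ learns only a mean shift, and the "discriminator" is a logistic regression. The variable names `theta`, `w`, `b` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
theta, w, b = 0.0, 1.0, 0.0   # generator mean; discriminator D(x) = sigmoid(w*x + b)
lr, batch = 0.05, 64

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

for step in range(500):
    # 1. Fix G, train D: gradient ascent on log D(x) + log(1 - D(G(z)))
    x_real = rng.normal(4.0, 0.5, batch)
    x_fake = theta + rng.normal(0.0, 0.5, batch)
    d_real, d_fake = sigmoid(w * x_real + b), sigmoid(w * x_fake + b)
    w += lr * (np.mean((1 - d_real) * x_real) - np.mean(d_fake * x_fake))
    b += lr * (np.mean(1 - d_real) - np.mean(d_fake))
    # 2. Fix D, train G: ascend log D(G(z)) (non-saturating heuristic)
    x_fake = theta + rng.normal(0.0, 0.5, batch)
    d_fake = sigmoid(w * x_fake + b)
    theta += lr * np.mean(1 - d_fake) * w

print(theta)  # drifts toward the data mean of 4
```

Even in this toy, the characteristic GAN dynamic appears: the generator's parameter oscillates around the data mean as the discriminator's decision boundary chases it, rather than converging monotonically.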
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At convergence, the discriminator outputs $D(x) = 0.5$ for all samples, meaning it can&rsquo;t distinguish between real and fake data. This indicates that $p_{\text{generator}} = p_{\text{data}}$. Our generator has learned the true data distribution.</p>
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
<p>The foundational GAN minimizes the Jensen-Shannon Divergence:</p>
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \| M) + \frac{1}{2} \text{KL}(Q \| M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
<h3 id="wasserstein-gan-wgan-a-mathematical-revolution">Wasserstein GAN (WGAN): A Mathematical Revolution</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> revolutionized GAN training by replacing Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass × distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
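The contrast in the table is concrete for distributions with disjoint support. In the sketch below (a minimal NumPy illustration, using the fact that for 1-D distributions $W_1$ equals the L1 distance between CDFs), the JSD between two non-overlapping point masses is always $\log 2$ no matter how far apart they are, while the Earth-Mover distance grows with their separation — giving the generator a usable gradient:

```python
import numpy as np

def kl(p, q):
    mask = p > 0
    return np.sum(p[mask] * np.log(p[mask] / q[mask]))

def jsd(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q, grid):
    # For 1-D distributions, W1 is the L1 distance between the CDFs.
    cdf_gap = np.cumsum(p) - np.cumsum(q)
    return np.sum(np.abs(cdf_gap)) * (grid[1] - grid[0])

grid = np.linspace(0, 10, 101)
p = np.zeros(101); p[10] = 1.0        # point mass at x = 1
q_near = np.zeros(101); q_near[30] = 1.0  # point mass at x = 3
q_far = np.zeros(101); q_far[80] = 1.0    # point mass at x = 8

# JSD saturates at log 2 for any pair of disjoint distributions...
print(jsd(p, q_near), jsd(p, q_far))  # both ≈ 0.693
# ...while W1 tracks the actual separation: ≈ 2.0 and ≈ 7.0.
print(wasserstein_1d(p, q_near, grid), wasserstein_1d(p, q_far, grid))
```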
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Since we can&rsquo;t compute Wasserstein distance directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be 1-Lipschitz (using weight clipping)</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>















<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to maintain the 1-Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = E_{\tilde{x} \sim P_g}[D(\tilde{x})] - E_{x \sim P_r}[D(x)] + \lambda E_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]
$$</p>
<p>Where $\hat{x}$ are points sampled uniformly along straight lines between real and generated data points.</p>
<p><strong>Advantages</strong>:</p>
<ul>
<li>No capacity limitations</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
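To make the penalty term concrete without pulling in an autograd framework, consider a linear critic $f(x) = w \cdot x$ (an assumed simplification for illustration): its input gradient is $w$ at every point, including the interpolates $\hat{x}$, so the penalty reduces to $\lambda(\|w\| - 1)^2$:

```python
import numpy as np

rng = np.random.default_rng(1)

def gradient_penalty_linear(w, x_real, x_fake, lam=10.0):
    """WGAN-GP penalty for a linear critic f(x) = w @ x, whose input
    gradient equals w everywhere, including at the interpolates x_hat."""
    eps = rng.uniform(size=(x_real.shape[0], 1))
    x_hat = eps * x_real + (1 - eps) * x_fake   # points on connecting lines
    grad_norms = np.full(x_hat.shape[0], np.linalg.norm(w))  # constant for linear f
    return lam * np.mean((grad_norms - 1.0) ** 2)

w = np.array([3.0, 4.0])   # ||w|| = 5, so the critic is far from 1-Lipschitz
x_real = rng.normal(0, 1, (8, 2))
x_fake = rng.normal(2, 1, (8, 2))
print(gradient_penalty_linear(w, x_real, x_fake))  # 10 * (5 - 1)^2 = 160.0
```

In a real implementation the critic is a deep network and the gradient at each $\hat{x}$ is obtained via automatic differentiation, but the structure of the penalty is exactly this.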
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes proportionally to distance</strong>:</p>
<p>$$
\min_D V_{LSGAN}(D) = \frac{1}{2}E_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{LSGAN}(G) = \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
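With the common 0–1 coding, the two objectives translate directly into a few lines of NumPy (a minimal sketch of the loss functions, not a full training loop):

```python
import numpy as np

# LSGAN losses with the usual coding: a = 0 (fake label), b = c = 1 (real label).
def d_loss_lsgan(d_real, d_fake, a=0.0, b=1.0):
    return 0.5 * np.mean((d_real - b) ** 2) + 0.5 * np.mean((d_fake - a) ** 2)

def g_loss_lsgan(d_fake, c=1.0):
    return 0.5 * np.mean((d_fake - c) ** 2)

# A fake scored d_fake = 0.9 is penalized by its distance from the fake
# label, not merely by being on the wrong side of a decision boundary.
print(d_loss_lsgan(np.array([0.8]), np.array([0.9])))  # 0.5*0.04 + 0.5*0.81 = 0.425
print(g_loss_lsgan(np.array([0.9])))                   # 0.5*0.01 = 0.005
```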
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>















<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

<p><strong>Key insight</strong>: LSGAN minimizes the Pearson χ² divergence, providing smoother optimization landscape than JSD.</p>
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: RWGAN introduces an asymmetric approach that provides better balance.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with state-of-the-art methods</li>
</ul>
<p><strong>Key insight</strong>: RWGAN parameterized with KL divergence shows excellent performance while maintaining the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>This approach is particularly relevant in scientific simulation, where matching the statistical moments of a generated distribution to the true physical distribution (e.g., molecular conformations) is critical for physical validity.</p>
<p><strong>Limitation</strong>: Relies on weight clipping like original WGAN.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> eliminates the discriminator entirely, directly minimizing <strong>Maximum Mean Discrepancy (MMD)</strong>.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by their means in a high-dimensional feature space:</p>
<p>$$
\text{MMD}^2(X, Y) = ||E[\phi(x)] - E[\phi(y)]||^2
$$</p>
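A biased MMD² estimate with a Gaussian kernel is only a few lines (an illustrative sketch; kernel bandwidth and the biased estimator are assumptions, and production uses would tune both):

```python
import numpy as np

def mmd2_gaussian(x, y, sigma=1.0):
    """Biased MMD^2 estimate with a Gaussian kernel: compares two sample
    sets through their mean embeddings E[phi(x)] in kernel feature space."""
    def k(a, b):
        d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-d2 / (2 * sigma ** 2))
    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()

rng = np.random.default_rng(0)
same = mmd2_gaussian(rng.normal(0, 1, (200, 2)), rng.normal(0, 1, (200, 2)))
diff = mmd2_gaussian(rng.normal(0, 1, (200, 2)), rng.normal(3, 1, (200, 2)))
print(same, diff)  # near zero for matching distributions, clearly larger for shifted ones
```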
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves GMMN by <strong>learning optimal kernels</strong> adversarially to improve upon fixed Gaussian kernels.</p>
<p><strong>Innovation</strong>: Combine GAN adversarial training with MMD objective for the best of both worlds.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
<p><strong>The Problem</strong>: WGAN&rsquo;s Wasserstein distance lacks three important properties:</p>
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties:</p>
<p>$$
d_C^2(\mu, \nu) = \int ||E_{X \sim \mu}[X - x] - E_{Y \sim \nu}[Y - x]||^2 d\pi(x)
$$</p>
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
<p><strong>Distance</strong>: Approximates the <strong>Chi-square distance</strong> as critic capacity increases:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{Q(x)} dx
$$</p>
<p>The Fisher objective is closely related to the Mahalanobis distance, which measures distance relative to a distribution&rsquo;s covariance structure. The data-dependent constraint keeps the critic bounded, and as the critic&rsquo;s capacity increases, the objective estimates the Chi-square distance.</p>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>Architecture</strong>:</p>
<ol>
<li>Train autoencoder on real data</li>
<li>Generator creates samples</li>
<li>Poor generated samples have high reconstruction loss</li>
<li>This loss drives generator improvement</li>
</ol>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
<p><strong>Equilibrium equation</strong>:</p>
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
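The update rule for $k_t$ is simple enough to sketch directly (an illustrative implementation; the clamp to $[0, 1]$ and the constants follow the paper's recommended defaults of $\gamma = 0.5$ and $\lambda_k = 0.001$):

```python
# BEGAN equilibrium update: k_t throttles how much the discriminator
# focuses on generated samples; gamma sets the target ratio between
# generated-data and real-data reconstruction losses.
def began_update_k(k, loss_real, loss_fake, gamma=0.5, lam=0.001):
    k = k + lam * (gamma * loss_real - loss_fake)
    return min(max(k, 0.0), 1.0)   # k_t is kept in [0, 1]

k = 0.0
# While gamma * L(x) exceeds L(G(z)), k grows, telling the discriminator
# to push harder against generated samples.
for _ in range(100):
    k = began_update_k(k, loss_real=0.8, loss_fake=0.1)
print(k)  # 100 * 0.001 * (0.4 - 0.1) = 0.03
```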
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: Best balance of stability and performance for most applications. However, tuning the gradient penalty $\lambda$ can be sensitive in practice.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work building terabyte-scale VLMs and training models on chaotic physical systems, understanding these foundational dynamics is critical. While we often use diffusion models or autoregressive transformers today, the adversarial training paradigms pioneered by GANs still inform how we stabilize distributed training and design robust loss functions. The choice of objective function fundamentally dictates generation quality, training stability, and computational constraints.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item><item><title>Word Embeddings in NLP: An Introduction</title><link>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</link><pubDate>Sun, 05 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</guid><description>Learn about word embeddings in NLP: from basic one-hot encoding to contextual models like ELMo. Guide with examples.</description><content:encoded><![CDATA[<h2 id="understanding-word-embeddings">Understanding Word Embeddings</h2>
<p>A word embedding maps words to real-valued vectors:</p>
<p>$$
\text{word} \rightarrow \mathbb{R}^n
$$</p>
<p>where $n$ represents the dimensionality of the embedding space.</p>
<p>The goal is simple: position semantically similar words close together in vector space. This dense representation typically uses hundreds of dimensions, a massive reduction from the millions required by one-hot encoding.</p>
<p>Word embeddings are grounded in <a href="https://en.wikipedia.org/wiki/Distributional_semantics">Zellig Harris&rsquo; distributional hypothesis</a>: words appearing in similar contexts tend to have similar meanings. This forms the foundation of distributional semantics.</p>















<figure class="post-figure center ">
    <img src="/img/distributional_semantics-50.webp"
         alt="Distributional semantics visualization"
         title="Distributional semantics visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Words embedded in three-dimensional space, organized by semantic similarity</figcaption>
    
</figure>

<p>Different embedding algorithms capture various aspects of this distributional principle. This post explores the main methods for creating word embeddings and their applications in natural language processing.</p>
<p>While modern foundation models and terabyte-scale Vision-Language Models (VLMs) rely on advanced subword tokenizers (like BPE) and massive Transformer embedding layers, the fundamental goal remains exactly the same: mapping discrete text to a continuous vector space where math can capture meaning. Understanding these foundational techniques provides the necessary intuition for debugging and scaling today&rsquo;s production ML systems.</p>
<h2 id="why-word-embeddings-matter-in-nlp">Why Word Embeddings Matter in NLP</h2>
<p>Computers require numerical representations to apply machine learning algorithms to text. Word embeddings bridge this gap by converting text into dense vectors that preserve semantic and syntactic relationships.</p>
<p><strong>Key advantages:</strong></p>
<ol>
<li><strong>Dense representation</strong>: Hundreds of dimensions provide a compact alternative to vocabulary-sized sparse vectors.</li>
<li><strong>Semantic preservation</strong>: Similar words cluster together in vector space.</li>
<li><strong>Mathematical operations</strong>: Enable analogical reasoning ($\text{king} - \text{man} + \text{woman} \approx \text{queen}$).</li>
<li><strong>Transfer learning</strong>: Pre-trained embeddings work across multiple tasks and domains.</li>
</ol>
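The analogical-reasoning property can be demonstrated with hand-crafted toy vectors (purely illustrative — real embeddings are learned and typically have hundreds of dimensions; the four dimensions here are an assumed encoding of royalty, maleness, femaleness, and plurality):

```python
import numpy as np

# Toy 4-dim "embeddings"; dimensions loosely encode royalty, male, female, plural.
emb = {
    "king":  np.array([0.9, 0.8, 0.1, 0.0]),
    "queen": np.array([0.9, 0.1, 0.8, 0.0]),
    "man":   np.array([0.1, 0.8, 0.1, 0.0]),
    "woman": np.array([0.1, 0.1, 0.8, 0.0]),
}

def cosine(a, b):
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# king - man + woman should land closest to queen.
target = emb["king"] - emb["man"] + emb["woman"]
best = max(emb, key=lambda word: cosine(emb[word], target))
print(best)  # queen
```

With real pre-trained embeddings the arithmetic is noisier (and the query words are usually excluded from the candidate set), but the geometric intuition is identical.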
<p>Modern deep learning architectures leverage these properties extensively. The development of universal, pre-trained embeddings was a significant step forward. We can use versatile embeddings that generalize across applications, eliminating the need to train task-specific representations from scratch.</p>
<h2 id="word-embedding-approaches">Word Embedding Approaches</h2>
<h3 id="one-hot-encoding-and-count-vectorization">One-Hot Encoding and Count Vectorization</h3>
<p>One-hot encoding represents the simplest approach to word vectorization. Each word gets a unique dimension in a vocabulary-sized vector, marked with 1 for presence and 0 elsewhere. Count vectorization extends this by counting the occurrences of each word in a document.</p>















<figure class="post-figure center ">
    <img src="/img/word_vector_onehot-50.webp"
         alt="One-hot encoding visualization"
         title="One-hot encoding visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">One-hot encoding creates sparse vectors with single active dimensions</figcaption>
    
</figure>

<p><strong>Characteristics:</strong></p>
<ul>
<li><strong>High dimensionality</strong>: Vector length equals vocabulary size.</li>
<li><strong>Extreme sparsity</strong>: Most dimensions contain zeros.</li>
<li><strong>No relationships</strong>: Treats all words as equally distant.</li>
<li><strong>Computational efficiency</strong>: Simple to implement and understand.</li>
</ul>
<p>While lacking semantic information, count vectorization serves as a foundation for more complex methods. Let&rsquo;s look at a practical implementation using scikit-learn&rsquo;s <code>CountVectorizer</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize the vectorizer</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Sample text for demonstration</span>
</span></span><span style="display:flex;"><span>sample_text <span style="color:#f92672">=</span> [<span style="color:#e6db74">&#34;One of the most basic ways we can numerically represent words &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;is through the one-hot encoding method (also sometimes called &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;count vectorizing).&#34;</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Fit the vectorizer to our text data</span>
</span></span><span style="display:flex;"><span>vectorizer<span style="color:#f92672">.</span>fit(sample_text)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Examine the vocabulary and word indices</span>
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Vocabulary:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vectorizer<span style="color:#f92672">.</span>vocabulary_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform text to vectors</span>
</span></span><span style="display:flex;"><span>vector <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(sample_text)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Full vector:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vector<span style="color:#f92672">.</span>toarray())
</span></span></code></pre></div><p>In a production environment, count vectorization introduces significant engineering challenges. When processing millions of documents, the vocabulary size explodes. Storing and computing on these massive sparse matrices quickly leads to memory exhaustion. In these scaling scenarios, practitioners often turn to the <strong>Hashing Trick</strong> (via <code>HashingVectorizer</code>) to bound the dimensionality, or they move entirely to the dense embeddings discussed later in this post.</p>
<p>We can see count vectorization in action with a real dataset, building a simple text classifier for the <a href="https://www.kaggle.com/datasets/crawford/20-newsgroups">20 Newsgroups dataset</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.datasets <span style="color:#f92672">import</span> fetch_20newsgroups
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.naive_bayes <span style="color:#f92672">import</span> MultinomialNB
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn <span style="color:#f92672">import</span> metrics
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load train and test splits, removing metadata for a cleaner signal</span>
</span></span><span style="display:flex;"><span>newsgroups_train <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;train&#39;</span>,
</span></span><span style="display:flex;"><span>                                      remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>newsgroups_test <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;test&#39;</span>,
</span></span><span style="display:flex;"><span>                                     remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize and fit vectorizer on training data</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>X_train <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>fit_transform(newsgroups_train<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Build and train classifier</span>
</span></span><span style="display:flex;"><span>classifier <span style="color:#f92672">=</span> MultinomialNB(alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.01</span>)
</span></span><span style="display:flex;"><span>classifier<span style="color:#f92672">.</span>fit(X_train, newsgroups_train<span style="color:#f92672">.</span>target)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform test data and make predictions</span>
</span></span><span style="display:flex;"><span>X_test <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(newsgroups_test<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>y_pred <span style="color:#f92672">=</span> classifier<span style="color:#f92672">.</span>predict(X_test)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Evaluate performance</span>
</span></span><span style="display:flex;"><span>accuracy <span style="color:#f92672">=</span> metrics<span style="color:#f92672">.</span>accuracy_score(newsgroups_test<span style="color:#f92672">.</span>target, y_pred)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Accuracy: </span><span style="color:#e6db74">{</span>accuracy<span style="color:#e6db74">:</span><span style="color:#e6db74">.3f</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>This provides a solid baseline. To capture actual semantic meaning and reduce dimensionality, we must move beyond simple counting.</p>
<h3 id="tf-idf-term-frequency-inverse-document-frequency">TF-IDF (Term Frequency-Inverse Document Frequency)</h3>
<p><a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html">TF-IDF</a> extends one-hot encoding by weighting terms based on their importance across a document collection. TF-IDF combines:</p>
<ul>
<li><strong>Term Frequency (TF)</strong>: How often a word appears in a document</li>
<li><strong>Inverse Document Frequency (IDF)</strong>: How rare a word is across all documents</li>
</ul>
<p>This weighting scheme reduces the impact of common words (like &ldquo;the&rdquo; or &ldquo;and&rdquo;) while emphasizing distinctive terms that appear frequently in specific documents but rarely elsewhere.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li>Captures document-level importance</li>
<li>Reduces impact of stop words</li>
<li>Effective for information retrieval tasks</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>Still high-dimensional and sparse</li>
<li>No semantic relationships between terms</li>
<li>Context-independent representation</li>
</ul>
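<p>The weighting itself is simple enough to sketch from scratch. The snippet below uses a smoothed IDF similar to scikit-learn's default, <code>idf = ln((1 + N) / (1 + df)) + 1</code>; it is illustrative only, since <code>TfidfVectorizer</code> additionally L2-normalizes each document vector:</p>

```python
import math
from collections import Counter

def tf_idf(docs):
    """Compute smoothed TF-IDF weights for a list of tokenized documents.

    tf is the raw count of a term in a document;
    idf = ln((1 + N) / (1 + df)) + 1, where N is the number of
    documents and df is how many documents contain the term.
    """
    n_docs = len(docs)
    # Document frequency: number of documents containing each term
    df = Counter()
    for doc in docs:
        df.update(set(doc))
    weights = []
    for doc in docs:
        tf = Counter(doc)
        weights.append({
            term: count * (math.log((1 + n_docs) / (1 + df[term])) + 1)
            for term, count in tf.items()
        })
    return weights

docs = [["the", "cat", "sat"], ["the", "dog", "ran"], ["the", "cat", "ran"]]
weights = tf_idf(docs)
print(weights[0])
```

<p>In the toy corpus, "the" appears in every document and so receives the minimum possible weight, while "sat" appears in only one and is weighted highest; this is exactly the stop-word suppression described above.</p>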
<h3 id="co-occurrence-matrices">Co-Occurrence Matrices</h3>
<p>Co-occurrence matrices capture word relationships by recording which terms appear together within defined contexts (sentences, paragraphs, or fixed windows). The resulting matrix has dimensions equal to vocabulary size squared, with entries showing co-occurrence frequency.</p>















<figure class="post-figure center ">
    <img src="/img/Word_co-occurrence_network_%28range_3_words%29_-_ENG-50.webp"
         alt="Co-occurrence network visualization"
         title="Co-occurrence network visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Co-occurrence relationships within a three-word window</figcaption>
    
</figure>

<p><strong>Key properties:</strong></p>
<ul>
<li><strong>Global statistics</strong>: Captures corpus-wide word relationships</li>
<li><strong>Symmetric relationships</strong>: Mutual co-occurrence patterns</li>
<li><strong>Extreme dimensionality</strong>: Vocabulary size squared creates storage challenges</li>
<li><strong>Sparse representation</strong>: Most word pairs never co-occur</li>
</ul>
<p>While computationally expensive to store and process, co-occurrence matrices form the foundation for advanced methods like GloVe that compress this information into dense representations.</p>
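<p>The windowed counting described above can be sketched in a few lines; the window size and sentence here are illustrative, and the "matrix" is stored sparsely as a dictionary, which is how such matrices are handled in practice given that most entries are zero:</p>

```python
from collections import defaultdict

def cooccurrence_counts(tokens, window=3):
    """Count symmetric word co-occurrences within a sliding window.

    Every pair of words at most `window` positions apart is counted
    once in each direction, producing the symmetric co-occurrence
    matrix (stored sparsely as a dict keyed by word pairs).
    """
    counts = defaultdict(int)
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if i != j:
                counts[(center, tokens[j])] += 1
    return counts

tokens = "the quick brown fox jumps over the lazy dog".split()
counts = cooccurrence_counts(tokens)
print(counts[("quick", "fox")])
```

<p>Note the symmetry: <code>counts[("quick", "fox")]</code> equals <code>counts[("fox", "quick")]</code>, and pairs farther apart than the window never appear.</p>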
<h2 id="neural-network-based-embeddings">Neural Network-Based Embeddings</h2>
<h3 id="neural-probabilistic-language-models">Neural Probabilistic Language Models</h3>
<p><a href="https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf">Neural probabilistic models</a> pioneered the use of neural networks for learning word embeddings. These models learn dense representations as a byproduct of language modeling, predicting the next word in a sequence.</p>















<figure class="post-figure center ">
    <img src="/img/bengio-npm-50.webp"
         alt="Neural probabilistic model diagram"
         title="Neural probabilistic model diagram"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Architecture of neural probabilistic language models</figcaption>
    
</figure>

<p><strong>Training process:</strong></p>
<ol>
<li>Initialize random dense embeddings for each vocabulary word</li>
<li>Use embeddings as inputs to predict language modeling objectives</li>
<li>Update embeddings through backpropagation based on prediction errors</li>
<li>Resulting embeddings capture patterns useful for the training task</li>
</ol>
<p>This approach demonstrated that task-specific embeddings could be learned jointly with model objectives, establishing the foundation for modern embedding methods.</p>
<h3 id="word2vec">Word2Vec</h3>
<p><a href="https://code.google.com/archive/p/word2vec/">Word2Vec</a> revolutionized word embeddings by introducing efficient training algorithms for massive corpora. It became the first method to demonstrate compelling vector arithmetic properties, enabling analogical reasoning like the famous &ldquo;$\text{king} - \text{man} + \text{woman} \approx \text{queen}$&rdquo; example.</p>















<figure class="post-figure center ">
    <img src="/img/Word_vector_illustration.webp"
         alt="Word2Vec vector arithmetic visualization"
         title="Word2Vec vector arithmetic visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Word2Vec demonstrates analogical relationships through vector arithmetic</figcaption>
    
</figure>

<p><strong>Two training architectures:</strong></p>
<h4 id="continuous-bag-of-words-cbow">Continuous Bag-of-Words (CBOW)</h4>
<p>Predicts target words from surrounding context words. Given a window of context words, the model learns to predict the central word.</p>
<h4 id="skip-gram">Skip-Gram</h4>
<p>Predicts context words from target words. Given a central word, the model learns to predict surrounding words within a defined window.</p>
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>Computational efficiency</strong>: Much faster than neural probabilistic models</li>
<li><strong>Scalable training</strong>: Can process billion-word corpora effectively</li>
<li><strong>Quality embeddings</strong>: Captures semantic and syntactic relationships</li>
<li><strong>Flexible context</strong>: Window size controls topical vs. functional similarity</li>
</ul>
<p>The choice of window size significantly impacts learned relationships. Larger windows capture topical associations, while smaller windows focus on syntactic and functional similarities.</p>
<h3 id="glove-global-vectors">GloVe (Global Vectors)</h3>
<p><a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> combines the best aspects of matrix factorization methods (which capture global corpus statistics) and local context window approaches like Word2Vec. Matrix factorization methods excel at global patterns but struggle with analogical reasoning, while Word2Vec captures local relationships but may miss global structure.</p>
<p><strong>Key innovation:</strong>
GloVe trains on a global word-context co-occurrence matrix, incorporating corpus-wide statistical information while maintaining the analogical reasoning capabilities that made Word2Vec successful.</p>
<p><strong>Advantages over Word2Vec:</strong></p>
<ul>
<li><strong>Global optimization</strong>: Leverages entire corpus statistics</li>
<li><strong>Better performance</strong>: Often outperforms Word2Vec on word similarity and analogy tasks</li>
<li><strong>Stable training</strong>: More consistent convergence due to global objective function</li>
</ul>
<p>The result is embeddings that capture both local syntactic patterns and global semantic relationships more effectively.</p>
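<p>One concrete piece of that global objective is GloVe's weighting function <code>f(X_ij)</code>, which the paper applies to each squared error over log co-occurrence counts: rare pairs (which are noisy) are down-weighted, and frequent pairs are capped so they cannot dominate. A sketch using the paper's defaults (<code>x_max = 100</code>, <code>alpha = 3/4</code>):</p>

```python
def glove_weight(x, x_max=100.0, alpha=0.75):
    """GloVe's weighting function f(X_ij) for its least-squares objective.

    Rare co-occurrences get small weights; anything at or above x_max
    is capped at 1.0. Defaults follow the GloVe paper.
    """
    return (x / x_max) ** alpha if x < x_max else 1.0

print(glove_weight(1))    # rare pair: small weight
print(glove_weight(500))  # frequent pair: capped at 1.0
```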
<h2 id="contextual-embedding-methods">Contextual Embedding Methods</h2>
<h3 id="fasttext">FastText</h3>
<p><a href="https://github.com/facebookresearch/fastText">FastText</a> addresses a critical limitation of previous methods: handling out-of-vocabulary (OOV) words. By incorporating subword information, FastText can generate meaningful representations for previously unseen words.</p>
<p><strong>Subword approach:</strong></p>
<ul>
<li>Decomposes words into character n-grams (typically 3-6 characters)</li>
<li>Represents words as sums of their component n-grams</li>
<li>Trains using skip-gram objective with negative sampling</li>
</ul>
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>OOV handling</strong>: Can embed unseen words using known subword components</li>
<li><strong>Morphological awareness</strong>: Captures relationships between related word forms</li>
<li><strong>Multilingual support</strong>: Facebook released pre-trained embeddings for 294 languages</li>
<li><strong>Robust performance</strong>: Particularly effective for morphologically rich languages</li>
</ul>
<p>For example, if the model knows &ldquo;navigate,&rdquo; it can provide meaningful representation for &ldquo;circumnavigate&rdquo; by leveraging shared subword components, even if &ldquo;circumnavigate&rdquo; wasn&rsquo;t in the training data.</p>
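<p>The subword decomposition is easy to sketch. Following the FastText paper, words are wrapped in boundary markers <code>&lt;</code> and <code>&gt;</code> so prefixes and suffixes are distinguishable, and the full bracketed word is kept as its own feature; the word's vector is then the sum of its n-gram vectors:</p>

```python
def char_ngrams(word, n_min=3, n_max=6):
    """Decompose a word into FastText-style character n-grams.

    Boundary markers '<' and '>' distinguish prefixes and suffixes;
    the whole bracketed word is also included as a feature.
    """
    wrapped = f"<{word}>"
    grams = {wrapped[i:i + n]
             for n in range(n_min, n_max + 1)
             for i in range(len(wrapped) - n + 1)}
    grams.add(wrapped)
    return grams

# "circumnavigate" shares many n-grams with "navigate", so it can be
# embedded even if it never appeared in training
shared = char_ngrams("circumnavigate") & char_ngrams("navigate")
print(sorted(shared))
```

<p>Note that the boundary markers matter: <code>&lt;nav</code> (a prefix) belongs only to "navigate", while interior n-grams like <code>nav</code> and the suffix <code>ate&gt;</code> are shared.</p>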
<h3 id="poincaré-embeddings">Poincaré Embeddings</h3>
<p><a href="https://radimrehurek.com/gensim/models/poincare.html">Poincaré embeddings</a> introduce a novel approach by learning representations in hyperbolic space. This geometric innovation specifically targets hierarchical relationships in data.</p>
<p><strong>Hyperbolic geometry advantages:</strong></p>
<ul>
<li><strong>Natural hierarchy encoding</strong>: Distance represents similarity, while norm encodes hierarchical level</li>
<li><strong>Efficient representation</strong>: Requires fewer dimensions for hierarchical data</li>
<li><strong>Mathematical elegance</strong>: Leverages properties of hyperbolic space for embedding optimization</li>
</ul>
<p><strong>Applications:</strong>
Particularly effective for data with inherent hierarchical structure, such as:</p>
<ul>
<li>WordNet taxonomies</li>
<li>Organizational charts</li>
<li>Computer network topologies</li>
<li>Knowledge graphs</li>
</ul>
<p>The <a href="https://arxiv.org/abs/1705.08039">original paper</a> demonstrates good efficiency in reproducing WordNet relationships with significantly lower dimensionality compared to traditional embedding methods.</p>
<h2 id="contextual-embeddings">Contextual Embeddings</h2>
<h3 id="elmo-embeddings-from-language-models">ELMo (Embeddings from Language Models)</h3>
<p><a href="https://github.com/allenai/allennlp-models">ELMo</a> represents a paradigm shift toward contextual word representations. ELMo generates dynamic representations based on sentence context, adapting to word usage patterns.</p>
<p><strong>Architecture:</strong></p>
<ul>
<li><strong>Bidirectional LSTM</strong>: Processes text in both forward and backward directions</li>
<li><strong>Character-level input</strong>: Handles OOV words and captures morphological patterns</li>
<li><strong>Multi-layer representations</strong>: Combines different abstraction levels</li>
</ul>
<p><strong>Layer specialization:</strong></p>
<ul>
<li><strong>Lower layers</strong>: Excel at syntactic tasks (POS tagging, parsing)</li>
<li><strong>Higher layers</strong>: Capture semantic relationships (word sense disambiguation)</li>
<li><strong>Combined layers</strong>: A learned, task-specific weighted combination of all layers typically performs best</li>
</ul>
<p><strong>Key innovation:</strong>
ELMo embeddings vary by context. The word &ldquo;bank&rdquo; receives different representations in &ldquo;river bank&rdquo; versus &ldquo;financial bank,&rdquo; addressing polysemy directly through contextual awareness.</p>
<p>This approach achieved strong performance across numerous NLP tasks by providing context-sensitive representations that adapt to word usage patterns.</p>
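<p>Downstream task models consume ELMo through a learned "scalar mix": a softmax-normalized weighted sum of the biLM layers, scaled by a global factor. The sketch below uses toy two-dimensional vectors in place of real (high-dimensional) biLM layer outputs:</p>

```python
import math

def scalar_mix(layers, scores, gamma=1.0):
    """Combine per-layer representations the way ELMo task models do.

    ELMo_task = gamma * sum_j softmax(scores)_j * h_j, where h_j is
    the j-th biLM layer's vector for a token. The task model learns
    the scores (and gamma), so it can emphasize syntactic lower
    layers or semantic higher layers as needed.
    """
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(layers[0])
    return [gamma * sum(w * layer[d] for w, layer in zip(weights, layers))
            for d in range(dim)]

# Three toy 2-d "layer" vectors for a single token
layers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(scalar_mix(layers, scores=[0.0, 0.0, 0.0]))  # uniform mix of layers
```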
<h3 id="probabilistic-fasttext">Probabilistic FastText</h3>
<p><a href="https://github.com/benathi/multisense-prob-fasttext">Probabilistic FastText</a> addresses polysemy (words with multiple meanings) through probabilistic modeling. Traditional embeddings conflate different word senses into single representations, limiting their precision.</p>
<p><strong>The polysemy problem:</strong>
Consider &ldquo;rock&rdquo; which can mean:</p>
<ul>
<li>Rock music (genre)</li>
<li>A stone (geological object)</li>
<li>Rocking motion (verb)</li>
</ul>
<p>Standard embeddings average these meanings, producing representations that may not capture any sense precisely.</p>
<p><strong>Probabilistic approach:</strong>
Probabilistic FastText represents words as Gaussian mixture models: probability distributions that can capture multiple distinct meanings as separate components.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li><strong>Multi-sense representation</strong>: Each word sense gets its own distribution</li>
<li><strong>Context sensitivity</strong>: Can select appropriate sense based on usage context</li>
<li><strong>Uncertainty quantification</strong>: Probabilistic framework captures embedding confidence</li>
</ul>
<p>This approach provides a more nuanced treatment of lexical ambiguity, particularly valuable for words with distinct, context-dependent meanings.</p>
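<p>A one-dimensional toy mixture illustrates the idea. The component weights, means, and the two "rock" senses below are entirely hypothetical, and real models use high-dimensional Gaussian components, but the structure is the same: one component per sense instead of a single point vector:</p>

```python
import math

def mixture_density(x, components):
    """Density of a 1-d Gaussian mixture at point x.

    components is a list of (weight, mean, std) triples:
    p(x) = sum_k weight_k * N(x; mean_k, std_k^2).
    """
    density = 0.0
    for weight, mean, std in components:
        z = (x - mean) / std
        density += weight * math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))
    return density

# Hypothetical "rock": one sense near -2 (music), one near +2 (stone)
rock = [(0.5, -2.0, 0.5), (0.5, 2.0, 0.5)]
print(mixture_density(-2.0, rock), mixture_density(0.0, rock))
```

<p>The density is high near either sense and low in between, whereas a single-Gaussian (or single-vector) representation would be forced to put its mass at the uninformative midpoint.</p>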
<h2 id="summary-and-future-directions">Summary and Future Directions</h2>
<p>Word embeddings have evolved from simple one-hot encodings to contextual representations that capture nuanced linguistic relationships. Each approach offers distinct advantages:</p>
<p><strong>Static embeddings</strong> (Word2Vec, GloVe, FastText) provide:</p>
<ul>
<li>Computational efficiency for large-scale applications</li>
<li>Pre-trained models available for numerous languages</li>
<li>Clear analogical reasoning capabilities</li>
<li>Good performance on many downstream tasks</li>
</ul>
<p><strong>Contextual embeddings</strong> (ELMo, BERT, GPT) offer:</p>
<ul>
<li>Dynamic representations based on sentence context</li>
<li>Better handling of polysemy and word sense disambiguation</li>
<li>Strong performance on complex NLP tasks</li>
<li>Ability to capture subtle contextual nuances</li>
</ul>
<p><strong>Choosing the right approach</strong> depends on:</p>
<ul>
<li><strong>Task requirements</strong>: Static embeddings for efficiency, contextual for accuracy</li>
<li><strong>Data availability</strong>: Pre-trained models vs. domain-specific training</li>
<li><strong>Computational constraints</strong>: Static embeddings require less processing power</li>
<li><strong>Language coverage</strong>: Consider availability of pre-trained models for target languages</li>
</ul>
<p>The field continues advancing toward more efficient contextual models, better multilingual representations, and embeddings that capture increasingly complex linguistic phenomena.</p>
<p>For a production-grade Word2Vec implementation in PyTorch that takes these concepts further, see the <a href="/projects/modern-word2vec/">High-Performance Word2Vec project</a>.</p>
]]></content:encoded></item></channel></rss>