<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Neural-Networks on Hunter Heidenreich | Senior AI Research Scientist</title><link>https://hunterheidenreich.com/tags/neural-networks/</link><description>Recent content in Neural-Networks on Hunter Heidenreich | Senior AI Research Scientist</description><image><title>Hunter Heidenreich | Senior AI Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Sun, 31 May 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/tags/neural-networks/index.xml" rel="self" type="application/rss+xml"/><item><title>SpeechT5: Unified Speech-Text Pre-Training Framework</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/speecht5-unified-speech-text-pretraining/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/speecht5-unified-speech-text-pretraining/</guid><description>SpeechT5 introduces a shared encoder-decoder framework with cross-modal vector quantization for joint speech and text pre-training across six tasks.</description><content:encoded><![CDATA[<h2 id="a-unified-encoder-decoder-for-spoken-language-processing">A Unified Encoder-Decoder for Spoken Language Processing</h2>
<p>SpeechT5 is a <strong>Method</strong> paper that introduces a shared encoder-decoder pre-training framework for spoken language processing. Inspired by <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5&rsquo;s</a> text-to-text paradigm, SpeechT5 reformulates all spoken language tasks as &ldquo;speech/text to speech/text&rdquo; problems. The framework uses modal-specific pre-nets and post-nets to interface between raw speech or text and a shared Transformer encoder-decoder, enabling a single pre-trained model to handle six downstream tasks: automatic speech recognition (ASR), text-to-speech synthesis (TTS), speech translation (ST), voice conversion (VC), speech enhancement (SE), and speaker identification (SID).</p>
<h2 id="bridging-the-gap-between-speech-and-text-pre-training">Bridging the Gap Between Speech and Text Pre-Training</h2>
<p>Prior speech pre-training work (wav2vec 2.0, HuBERT) suffered from two key limitations. First, these models learned speech representations from unlabeled audio alone, ignoring the complementary information in text data that is critical for cross-modal tasks like ASR and TTS. Second, they relied on encoder-only architectures with task-specific prediction heads, leaving the decoder un-pretrained for sequence-to-sequence generation tasks.</p>
<p>SpeechT5 addresses both gaps by (1) jointly pre-training on unlabeled speech and text data, and (2) using a full encoder-decoder architecture that benefits generation tasks directly. The approach builds on the observation that speech and text, despite their surface differences, share underlying semantic structure that a unified representation can capture.</p>
<h2 id="cross-modal-vector-quantization-for-alignment">Cross-Modal Vector Quantization for Alignment</h2>
<p>The core innovation in SpeechT5 is a cross-modal <a href="https://en.wikipedia.org/wiki/Vector_quantization">vector quantization</a> (VQ) mechanism that aligns speech and text representations into a shared semantic space. The architecture consists of three components:</p>
<p><strong>Shared encoder-decoder backbone.</strong> A Transformer with 12 encoder blocks and 6 decoder blocks (768-dim, 12 heads), using relative position embeddings.</p>
<p><strong>Modal-specific pre/post-nets.</strong> Six specialized networks handle the conversion between raw modalities and the shared representation space:</p>
<ul>
<li>Speech-encoder pre-net: a convolutional feature extractor (from wav2vec 2.0) downsampling raw waveforms</li>
<li>Speech-decoder pre-net: three FC layers with ReLU, processing 80-dimensional log Mel-filterbank features</li>
<li>Speech-decoder post-net: a linear layer predicting Mel features plus five 1D conv layers (256 channels) for residual refinement, with an x-vector speaker embedding concatenated for multi-speaker support</li>
<li>Text pre/post-nets: shared embedding layers mapping between character-level token indices and hidden states (768-dim)</li>
</ul>
<p><strong>Cross-modal vector quantization.</strong> A shared codebook $\mathbf{C}^{K}$ with $K$ learnable embeddings bridges the two modalities. Encoder outputs $\mathbf{u}_i$ are quantized via nearest-neighbor lookup:</p>
<p>$$
\mathbf{c}_i = \arg\min_{j \in [K]} \lVert \mathbf{u}_i - \mathbf{c}_j \rVert_2
$$</p>
<p>A proportion (10%) of contextual representations are randomly replaced with these quantized latent units before being fed to the decoder&rsquo;s cross-attention. This mixing forces the quantizer to capture cross-modal features. A diversity loss encourages full codebook utilization:</p>
<p>$$
\mathcal{L}_d = \frac{1}{K} \sum_{k=1}^{K} p_k \log p_k
$$</p>
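<p>The quantization, mixing, and diversity terms above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the official Fairseq implementation: the function names are hypothetical, the 10% of positions to replace is drawn uniformly at random, and $p_k$ is passed in as a ready-made probability vector (how it is estimated is left open here).</p>

```python
import math
import random

def quantize(u, codebook):
    """Return the index of the codebook entry nearest to u (L2 distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda j: sq_dist(u, codebook[j]))

def mix_quantized(states, codebook, ratio=0.1, rng=None):
    """Replace a random `ratio` of encoder states with their nearest codes."""
    rng = rng or random.Random(0)  # fixed seed only to keep the sketch repeatable
    n_replace = round(ratio * len(states))
    chosen = set(rng.sample(range(len(states)), n_replace))
    return [codebook[quantize(u, codebook)] if i in chosen else u
            for i, u in enumerate(states)]

def diversity_loss(probs):
    """L_d = (1/K) * sum_k p_k log p_k, with 0 log 0 taken as 0."""
    K = len(probs)
    return sum(p * math.log(p) for p in probs if p > 0) / K

# Toy codebook with K = 2 entries in a 2-dimensional space (hypothetical values).
codebook = [[0.0, 0.0], [1.0, 1.0]]
print(quantize([0.9, 0.8], codebook))  # index of the nearest code
print(diversity_loss([0.5, 0.5]), diversity_loss([1.0, 0.0]))
```

<p>Because $\mathcal{L}_d$ is the negative entropy of the code distribution scaled by $1/K$, a uniform distribution over codes gives the lowest value, which is why minimizing it pushes toward full codebook use.</p>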
<h3 id="pre-training-objectives">Pre-Training Objectives</h3>
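<p>SpeechT5 combines three pre-training objectives: speech pre-training, text pre-training, and the joint pre-training through cross-modal vector quantization described above.</p>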
<p><strong>Speech pre-training</strong> uses two tasks. A bidirectional masked prediction loss $\mathcal{L}_{mlm}^{s}$ follows HuBERT&rsquo;s approach, masking 8% of timesteps in 10-step spans and predicting frame-level targets from an acoustic unit discovery model:</p>
<p>$$
\mathcal{L}_{mlm}^{s} = \sum_{n \in \mathcal{M}} \log p(\mathbf{z}_n \mid \hat{\mathbf{H}}, n)
$$</p>
<p>A reconstruction loss $\mathcal{L}_{1}^{s}$ minimizes the $L_1$ distance between predicted and original Mel-filterbank features, plus a binary cross-entropy stop-token loss $\mathcal{L}_{bce}^{s}$.</p>
<p><strong>Text pre-training</strong> uses BART-style denoising, masking 30% of text spans (Poisson $\lambda = 3.5$) and training with maximum likelihood estimation:</p>
<p>$$
\mathcal{L}_{mle}^{t} = \sum_{n=1}^{N^t} \log p(\mathbf{y}_n^t \mid \mathbf{y}_{&lt; n}^t, \hat{\mathbf{X}}^t)
$$</p>
<p>The full pre-training loss combines all components:</p>
<p>$$
\mathcal{L} = \mathcal{L}_{mlm}^{s} + \mathcal{L}_{1}^{s} + \mathcal{L}_{bce}^{s} + \mathcal{L}_{mle}^{t} + \gamma \mathcal{L}_d
$$</p>
<p>where $\gamma = 0.1$.</p>
<h2 id="evaluation-across-six-spoken-language-tasks">Evaluation Across Six Spoken Language Tasks</h2>
<p>SpeechT5 was evaluated on six downstream tasks, each using a different combination of the shared encoder-decoder and task-appropriate pre/post-nets:</p>
<h3 id="automatic-speech-recognition-asr">Automatic Speech Recognition (ASR)</h3>
<p>Fine-tuned on LibriSpeech 100h with joint <a href="https://en.wikipedia.org/wiki/Connectionist_temporal_classification">CTC</a>/attention decoding. The decoding objective maximizes a combination of decoder, CTC, and language model log-probabilities:</p>
<p>$$
\alpha \log P_{Dec} + (1 - \alpha) \log P_{CTC} + \beta \log P_{LM}
$$</p>
<p>where $\alpha = 0.5$ and $\beta = 1.0$ for the 100h setting (beam size 30). Word error rates (WER, %) on the test sets:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>LM</th>
          <th>test-clean</th>
          <th>test-other</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>wav2vec 2.0 BASE</td>
          <td>-</td>
          <td>6.1</td>
          <td>13.3</td>
      </tr>
      <tr>
          <td>HuBERT BASE</td>
          <td>-</td>
          <td>5.8</td>
          <td>13.3</td>
      </tr>
      <tr>
          <td><strong>SpeechT5</strong></td>
          <td><strong>-</strong></td>
          <td><strong>4.4</strong></td>
          <td><strong>10.4</strong></td>
      </tr>
      <tr>
          <td>wav2vec 2.0 BASE</td>
          <td>Transf.</td>
          <td>2.6</td>
          <td>6.3</td>
      </tr>
      <tr>
          <td><strong>SpeechT5</strong></td>
          <td><strong>Transf.</strong></td>
          <td><strong>2.4</strong></td>
          <td><strong>5.8</strong></td>
      </tr>
  </tbody>
</table>
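<p>The joint decoding score behind these ASR results reduces to a weighted sum per beam hypothesis. The sketch below assumes each hypothesis already carries its three log-probabilities; the function names and the example scores are hypothetical, and the rows without a language model correspond to leaving the LM term at zero.</p>

```python
def joint_score(logp_dec, logp_ctc, logp_lm=0.0, alpha=0.5, beta=1.0):
    """alpha * log P_Dec + (1 - alpha) * log P_CTC + beta * log P_LM."""
    return alpha * logp_dec + (1 - alpha) * logp_ctc + beta * logp_lm

def best_hypothesis(hyps, alpha=0.5, beta=1.0):
    """Pick the hypothesis with the highest joint score.

    `hyps` maps a transcript to (log P_Dec, log P_CTC, log P_LM).
    """
    return max(hyps, key=lambda h: joint_score(*hyps[h], alpha=alpha, beta=beta))

# Made-up log-probabilities for two competing beam entries.
beam = {
    "the cat sat": (-2.0, -4.0, -1.0),
    "the cat sad": (-3.0, -2.5, -0.5),
}
print(best_hypothesis(beam))
```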
<h3 id="text-to-speech-synthesis-tts">Text-to-Speech Synthesis (TTS)</h3>
<p>Fine-tuned on the LibriTTS 460h clean sets with a HiFi-GAN vocoder:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Naturalness</th>
          <th>MOS</th>
          <th>CMOS</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Ground Truth</td>
          <td>-</td>
          <td>3.87 ± 0.04</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Baseline</td>
          <td>2.76</td>
          <td>3.56 ± 0.05</td>
          <td>0</td>
      </tr>
      <tr>
          <td><strong>SpeechT5</strong></td>
          <td><strong>2.91</strong></td>
          <td><strong>3.65 ± 0.04</strong></td>
          <td><strong>+0.290</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="speech-translation-st">Speech Translation (ST)</h3>
<p>Evaluated on MUST-C English-to-German and English-to-French (BLEU):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>EN-DE</th>
          <th>EN-FR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fairseq ST</td>
          <td>22.70</td>
          <td>32.90</td>
      </tr>
      <tr>
          <td>Adapter Tuning</td>
          <td>24.63</td>
          <td>34.98</td>
      </tr>
      <tr>
          <td>Baseline (HuBERT init)</td>
          <td>23.43</td>
          <td>33.76</td>
      </tr>
      <tr>
          <td><strong>SpeechT5</strong></td>
          <td><strong>25.18</strong></td>
          <td><strong>35.30</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="voice-conversion-vc">Voice Conversion (VC)</h3>
<p>Evaluated on CMU Arctic:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>WER (bdl→slt)</th>
          <th>MCD (bdl→slt)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>VTN w/ TTS</td>
          <td>7.6%</td>
          <td>6.33</td>
      </tr>
      <tr>
          <td>Many-to-many VTN</td>
          <td>-</td>
          <td>6.13</td>
      </tr>
      <tr>
          <td><strong>SpeechT5</strong></td>
          <td><strong>7.8%</strong></td>
          <td><strong>5.93</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="speech-enhancement-se">Speech Enhancement (SE)</h3>
<p>On the WHAM! dataset, SpeechT5 reduced WER from 76.1% (noisy) to 8.9%, roughly an 18% relative reduction from the baseline&rsquo;s 10.9%.</p>
<h3 id="speaker-identification-sid">Speaker Identification (SID)</h3>
<p>On VoxCeleb1, SpeechT5 achieved 96.49% accuracy, outperforming HuBERT LARGE at 90.33% (from SUPERB) and SpeechNet multi-task at 87.90%.</p>
<h2 id="ablation-study-and-key-findings">Ablation Study and Key Findings</h2>
<p>The ablation study reveals the contribution of each pre-training component:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>ASR (clean)</th>
          <th>ASR (other)</th>
          <th>VC (MCD)</th>
          <th>SID (ACC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SpeechT5</td>
          <td>4.4</td>
          <td>10.7</td>
          <td>5.93</td>
          <td>96.49%</td>
      </tr>
      <tr>
          <td>w/o Speech PT</td>
          <td>-</td>
          <td>-</td>
          <td>6.49</td>
          <td>38.61%</td>
      </tr>
      <tr>
          <td>w/o Text PT</td>
          <td>5.4</td>
          <td>12.8</td>
          <td>6.03</td>
          <td>95.60%</td>
      </tr>
      <tr>
          <td>w/o Joint PT</td>
          <td>4.6</td>
          <td>11.3</td>
          <td>6.18</td>
          <td>95.54%</td>
      </tr>
      <tr>
          <td>w/o $\mathcal{L}_{mlm}^{s}$</td>
          <td>7.6</td>
          <td>22.4</td>
          <td>6.29</td>
          <td>90.91%</td>
      </tr>
  </tbody>
</table>
<p>Key findings:</p>
<ol>
<li><strong>Speech pre-training is critical</strong>: without it, ASR fails to converge entirely, and SID accuracy drops to 38.61%.</li>
<li><strong>Text pre-training complements speech</strong>: removing it degrades ASR by ~20% relative, confirming that textual knowledge transfers to speech tasks.</li>
<li><strong>Joint pre-training enables cross-modal transfer</strong>: the vector quantization approach is essential for modality-bridging tasks like ASR.</li>
<li><strong>The masked prediction loss $\mathcal{L}_{mlm}^{s}$ is the most important single component</strong>, responsible for learning strong acoustic features.</li>
</ol>
<p>The authors note limitations in the current scope (English-only, BASE model size) and propose scaling to larger models and multilingual settings as future work.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Speech pre-training</td>
          <td>LibriSpeech</td>
          <td>960 hours</td>
          <td>Full training set</td>
      </tr>
      <tr>
          <td>Text pre-training</td>
          <td>LibriSpeech LM text</td>
          <td>400M sentences</td>
          <td>Normalized language model text</td>
      </tr>
      <tr>
          <td>ASR fine-tuning</td>
          <td>LibriSpeech</td>
          <td>100h / 960h subsets</td>
          <td></td>
      </tr>
      <tr>
          <td>TTS fine-tuning</td>
          <td>LibriTTS</td>
          <td>460h clean sets</td>
          <td></td>
      </tr>
      <tr>
          <td>ST fine-tuning</td>
          <td>MUST-C</td>
          <td>EN-DE, EN-FR</td>
          <td></td>
      </tr>
      <tr>
          <td>VC fine-tuning</td>
          <td>CMU Arctic</td>
          <td>4 speakers</td>
          <td>bdl, clb, slt, rms</td>
      </tr>
      <tr>
          <td>SE fine-tuning</td>
          <td>WHAM!</td>
          <td>16 kHz max</td>
          <td>enhance-single task</td>
      </tr>
      <tr>
          <td>SID fine-tuning</td>
          <td>VoxCeleb1</td>
          <td>100k+ utterances</td>
          <td>1,251 speakers</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam with warmup (8% of steps) to peak LR $2 \times 10^{-4}$, then linear decay</li>
<li>Speech masking: 8% of timesteps, 10-step spans</li>
<li>Text masking: 30% of spans, Poisson $\lambda = 3.5$</li>
<li>Vector quantization: 2 codebooks with 100 entries each, giving $100^2 = 10^4$ theoretical maximum codes</li>
<li>CTC/attention joint decoding for ASR (beam size 30)</li>
<li>HiFi-GAN vocoder for TTS and SE waveform generation</li>
<li>Parallel WaveGAN vocoder for VC</li>
</ul>
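<p>The pre-training learning-rate schedule from the first bullet can be sketched as a small function. It assumes the 500k-step budget listed under Hardware and a linear decay that reaches zero at the final step; the end value is an assumption, and the function name is illustrative.</p>

```python
def pretrain_lr(step, total_steps=500_000, peak_lr=2e-4, warmup_frac=0.08):
    """Linear warmup to peak_lr over the first 8% of steps, then linear decay."""
    warmup_steps = round(warmup_frac * total_steps)
    if warmup_steps > step:
        # Warmup phase: ramp from 0 up to the peak learning rate.
        return peak_lr * step / warmup_steps
    # Decay phase: assumed to fall linearly to 0 at total_steps.
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)

for s in (0, 20_000, 40_000, 270_000, 500_000):
    print(s, pretrain_lr(s))
```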
<h3 id="fine-tuning-hyperparameters">Fine-Tuning Hyperparameters</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>GPUs</th>
          <th>Steps</th>
          <th>Peak LR</th>
          <th>Batch (per GPU)</th>
          <th>Schedule</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ASR (100h)</td>
          <td>8×V100</td>
          <td>80k</td>
          <td>6e-5</td>
          <td>256k audio samples</td>
          <td>Warmup 10%, hold 40%, linear decay</td>
      </tr>
      <tr>
          <td>ASR (960h)</td>
          <td>8×V100</td>
          <td>320k</td>
          <td>1.3e-4</td>
          <td>256k audio samples</td>
          <td>Warmup 10%, hold 40%, linear decay</td>
      </tr>
      <tr>
          <td>TTS</td>
          <td>8×V100</td>
          <td>120k</td>
          <td>4e-4</td>
          <td>45k tokens</td>
          <td>Warmup 10k steps, inv. sqrt decay</td>
      </tr>
      <tr>
          <td>ST</td>
          <td>8×V100</td>
          <td>80k</td>
          <td>-</td>
          <td>-</td>
          <td>Warmup 10k steps</td>
      </tr>
      <tr>
          <td>VC</td>
          <td>8×V100</td>
          <td>60k</td>
          <td>1e-4</td>
          <td>20k tokens</td>
          <td>6k warmup, inv. sqrt decay</td>
      </tr>
      <tr>
          <td>SE</td>
          <td>8×V100</td>
          <td>100k</td>
          <td>1e-4</td>
          <td>16k tokens</td>
          <td>10k warmup, inv. sqrt decay</td>
      </tr>
      <tr>
          <td>SID</td>
          <td>8×V100</td>
          <td>60k</td>
          <td>5e-4</td>
          <td>64 segments (3s each)</td>
          <td>Triangular cyclical (1e-8 to 5e-4)</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<ul>
<li>Encoder: 12 Transformer blocks (768-dim, 3072 FFN, 12 heads)</li>
<li>Decoder: 6 Transformer blocks (same dimensions)</li>
<li>Speech-encoder pre-net: 7 conv blocks (512 channels, strides [5,2,2,2,2,2,2], kernels [10,3,3,3,3,2,2])</li>
<li>Code and pre-trained models available at <a href="https://github.com/microsoft/SpeechT5">github.com/microsoft/SpeechT5</a> (MIT license)</li>
</ul>
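<p>The stride and kernel lists of the speech-encoder pre-net determine how strongly raw audio is downsampled. The short calculation below is a sketch that assumes 16 kHz input and unpadded convolutions; the helper names are illustrative.</p>

```python
import math

STRIDES = [5, 2, 2, 2, 2, 2, 2]
KERNELS = [10, 3, 3, 3, 3, 2, 2]

def total_stride(strides=STRIDES):
    """Overall downsampling factor of the stacked convolutions."""
    return math.prod(strides)

def receptive_field(kernels=KERNELS, strides=STRIDES):
    """Number of input samples seen by one output frame."""
    rf, jump = 1, 1
    for k, s in zip(kernels, strides):
        rf += (k - 1) * jump  # each layer widens the field by (k - 1) input hops
        jump *= s             # hop between neighbouring outputs, in input samples
    return rf

def num_frames(n_samples, kernels=KERNELS, strides=STRIDES):
    """Output frames for an unpadded input (no padding is an assumption)."""
    n = n_samples
    for k, s in zip(kernels, strides):
        n = (n - k) // s + 1
    return n

SAMPLE_RATE = 16_000  # assumed 16 kHz audio
print(total_stride())           # 320 samples between consecutive frames
print(receptive_field())        # 400 samples covered by each frame
print(num_frames(SAMPLE_RATE))  # frames produced for one second of audio
```

<p>At 16 kHz these values correspond to a 20 ms hop and a 25 ms receptive field per encoder frame.</p>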
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/microsoft/SpeechT5">microsoft/SpeechT5</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official Fairseq-based implementation</td>
      </tr>
      <tr>
          <td>Pre-trained models (via repo)</td>
          <td>Model</td>
          <td>MIT</td>
          <td>SpeechT5 BASE encoder-decoder checkpoints</td>
      </tr>
      <tr>
          <td><a href="https://www.openslr.org/12">LibriSpeech</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>960h speech pre-training and ASR fine-tuning</td>
      </tr>
      <tr>
          <td><a href="https://www.openslr.org/60">LibriTTS</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>460h TTS fine-tuning</td>
      </tr>
      <tr>
          <td><a href="https://ict.fbk.eu/must-c/">MUST-C</a></td>
          <td>Dataset</td>
          <td>CC-BY-NC-ND-4.0</td>
          <td>Speech translation fine-tuning</td>
      </tr>
      <tr>
          <td><a href="http://www.festvox.org/cmu_arctic/">CMU Arctic</a></td>
          <td>Dataset</td>
          <td>Free</td>
          <td>Voice conversion fine-tuning</td>
      </tr>
      <tr>
          <td><a href="http://wham.whisper.ai/">WHAM!</a></td>
          <td>Dataset</td>
          <td>CC-BY-NC-4.0</td>
          <td>Speech enhancement fine-tuning</td>
      </tr>
      <tr>
          <td><a href="https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html">VoxCeleb1</a></td>
          <td>Dataset</td>
          <td>CC-BY-SA-4.0</td>
          <td>Speaker identification fine-tuning</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 32 NVIDIA V100 GPUs</li>
<li>Batch: ~90s speech per GPU + 12k text tokens per GPU, gradient accumulation 2</li>
<li>Pre-training steps: 500k</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ao, J., Wang, R., Zhou, L., Wang, C., Ren, S., Wu, Y., Liu, S., Ko, T., Li, Q., Zhang, Y., Wei, Z., Qian, Y., Li, J., &amp; Wei, F. (2022). SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing. <em>Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)</em>, 5723-5738.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ao2022speecht,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ao, Junyi and Wang, Rui and Zhou, Long and Wang, Chengyi and Ren, Shuo and Wu, Yu and Liu, Shujie and Ko, Tom and Li, Qing and Zhang, Yu and Wei, Zhihua and Qian, Yao and Li, Jinyu and Wei, Furu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5723--5738}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2022.acl-long.393}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LSTNet: Long- and Short-Term Time Series Network</title><link>https://hunterheidenreich.com/notes/time-series/lstnet-multivariate-time-series/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/time-series/lstnet-multivariate-time-series/</guid><description>LSTNet combines CNNs, recurrent-skip connections, and autoregressive models to capture both short-term and long-term patterns in multivariate time series.</description><content:encoded><![CDATA[<h2 id="a-deep-learning-framework-for-multivariate-forecasting">A Deep Learning Framework for Multivariate Forecasting</h2>
<p>This is a <strong>Method</strong> paper that introduces the Long- and Short-term Time-series Network (LSTNet), a deep learning architecture specifically designed for multivariate time series forecasting. LSTNet combines convolutional neural networks (CNNs), recurrent neural networks (RNNs) with a novel skip-connection structure, and a traditional autoregressive (AR) component into a unified framework. The architecture targets the challenge of simultaneously capturing both short-term local dependencies and long-term periodic patterns in temporal data.</p>
<h2 id="why-short-term-and-long-term-patterns-need-separate-treatment">Why Short-Term and Long-Term Patterns Need Separate Treatment</h2>
<p>Real-world multivariate time series often exhibit a mixture of repeating patterns at different time scales. Highway traffic, for example, shows daily peaks (morning vs. evening commutes) alongside weekly patterns (weekday vs. weekend behavior). Solar energy output varies with cloud movements on short time scales and with seasonal daylight changes on longer ones. Electricity consumption follows similar daily and weekly cycles.</p>
<p>Traditional autoregressive methods (<a href="https://en.wikipedia.org/wiki/Vector_autoregression">VAR</a>, <a href="https://en.wikipedia.org/wiki/Autoregressive_integrated_moving_average">ARIMA</a>) and <a href="https://en.wikipedia.org/wiki/Gaussian_process">Gaussian Process</a> models struggle to distinguish and jointly model these two kinds of recurring patterns. Standard RNNs, including LSTM and <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> variants, theoretically handle long-range dependencies but in practice suffer from <a href="https://en.wikipedia.org/wiki/Vanishing_gradient_problem">gradient vanishing</a> when the period length is large (e.g., 24 time steps for daily patterns at hourly resolution, or 168 time steps for weekly patterns). The authors also identify a scale insensitivity problem: the output scale of a neural network does not follow changes in the input scale, so the model can fail when the magnitude of the input signal changes in non-periodic ways, such as sudden shifts in electricity consumption due to holidays or weather events.</p>
<h2 id="combining-cnns-recurrent-skip-connections-and-autoregression">Combining CNNs, Recurrent-Skip Connections, and Autoregression</h2>
<p>The LSTNet architecture consists of four main components that work together.</p>
<h3 id="convolutional-component">Convolutional Component</h3>
<p>The first layer applies 1D convolution without pooling across the multivariate input. Each filter has width $\omega$ (in the time dimension) and height $n$ (spanning all variables), producing feature maps that capture short-term local dependency patterns among variables:</p>
<p>$$h_k = \text{RELU}(W_k * X + b_k)$$</p>
<p>where $*$ denotes convolution and the input is zero-padded so each output vector has length $T$. The output is a $d_c \times T$ matrix where $d_c$ is the number of filters.</p>
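<p>A single filter of the convolutional component can be written out directly. The sketch below uses plain Python lists, computes the sliding-window product the way deep learning frameworks do (cross-correlation), and pads with zeros on the left so the output has length $T$; the padding side and the function name are assumptions.</p>

```python
def conv_feature(X, W, b):
    """One feature map h_k = RELU(W_k * X + b_k).

    X: T rows of n variables; W: omega rows of n weights; b: scalar bias.
    """
    T, n = len(X), len(X[0])
    omega = len(W)
    # Zero-pad on the left (assumed) so that every time step gets an output.
    padded = [[0.0] * n for _ in range(omega - 1)] + X
    out = []
    for t in range(T):
        window = padded[t:t + omega]
        s = b + sum(W[i][j] * window[i][j] for i in range(omega) for j in range(n))
        out.append(max(0.0, s))  # RELU
    return out

X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # T = 3 steps, n = 2 variables
W = [[1.0, 0.0], [0.0, 1.0]]              # omega = 2, spans both variables
print(conv_feature(X, W, 0.0))
```

<p>Stacking $d_c$ such filters row-wise gives the $d_c \times T$ output matrix described above.</p>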
<h3 id="recurrent-component">Recurrent Component</h3>
<p>The CNN output feeds into a GRU-based recurrent layer that uses RELU (rather than the standard tanh) as the hidden update activation:</p>
<p>$$\begin{aligned}
r_t &amp;= \sigma(x_t W_{xr} + h_{t-1} W_{hr} + b_r) \\
u_t &amp;= \sigma(x_t W_{xu} + h_{t-1} W_{hu} + b_u) \\
c_t &amp;= \text{RELU}(x_t W_{xc} + r_t \odot (h_{t-1} W_{hc}) + b_c) \\
h_t &amp;= (1 - u_t) \odot h_{t-1} + u_t \odot c_t
\end{aligned}$$</p>
<h3 id="recurrent-skip-component">Recurrent-Skip Component</h3>
<p>The key architectural innovation is a recurrent structure with temporal skip connections. Instead of connecting to the immediately preceding hidden state $h_{t-1}$, skip links connect to the hidden state from $p$ steps ago ($h_{t-p}$), where $p$ corresponds to the period length of the data (e.g., $p = 24$ for hourly data with daily periodicity):</p>
<p>$$\begin{aligned}
r_t &amp;= \sigma(x_t W_{xr} + h_{t-p} W_{hr} + b_r) \\
u_t &amp;= \sigma(x_t W_{xu} + h_{t-p} W_{hu} + b_u) \\
c_t &amp;= \text{RELU}(x_t W_{xc} + r_t \odot (h_{t-p} W_{hc}) + b_c) \\
h_t &amp;= (1 - u_t) \odot h_{t-p} + u_t \odot c_t
\end{aligned}$$</p>
<p>This design shortens the effective path length for learning periodic dependencies, making optimization easier. A dense layer combines outputs from both recurrent components:</p>
<p>$$h_t^D = W^R h_t^R + \sum_{i=0}^{p-1} W_i^S h_{t-i}^S + b$$</p>
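<p>The recurrent-skip update differs from the plain GRU only in which past hidden state it reads. The sketch below implements the four equations above with plain Python lists; setting <code>p = 1</code> recovers the standard recurrent component. The parameter dictionary, the zero initial state, and the function names are assumptions, and the dense combination layer is not included.</p>

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def vecmat(v, M):
    """Row vector times matrix, returned as a list."""
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]

def skip_gru_step(x, h_skip, P):
    """One update using the hidden state from p steps ago (h_skip)."""
    xr, hr = vecmat(x, P["Wxr"]), vecmat(h_skip, P["Whr"])
    xu, hu = vecmat(x, P["Wxu"]), vecmat(h_skip, P["Whu"])
    xc, hc = vecmat(x, P["Wxc"]), vecmat(h_skip, P["Whc"])
    r = [sigmoid(a + c + b) for a, c, b in zip(xr, hr, P["br"])]
    u = [sigmoid(a + c + b) for a, c, b in zip(xu, hu, P["bu"])]
    # RELU replaces tanh in the candidate state, as in the equations above.
    c = [max(0.0, a + ri * hci + b) for a, ri, hci, b in zip(xc, r, hc, P["bc"])]
    return [(1 - ui) * hi + ui * ci for ui, hi, ci in zip(u, h_skip, c)]

def run_skip_gru(xs, P, p, hidden):
    """Run over a sequence; states before the start are taken as zeros."""
    hs = []
    for t, x in enumerate(xs):
        h_skip = hs[t - p] if t >= p else [0.0] * hidden
        hs.append(skip_gru_step(x, h_skip, P))
    return hs

# Toy parameters: 1 input feature, 1 hidden unit, all weights zero, b_c = 1.
P = {k: [[0.0]] for k in ("Wxr", "Whr", "Wxu", "Whu", "Wxc", "Whc")}
P.update(br=[0.0], bu=[0.0], bc=[1.0])
xs = [[0.0]] * 4
print(run_skip_gru(xs, P, p=2, hidden=1))
print(run_skip_gru(xs, P, p=1, hidden=1))
```

<p>With $p = 2$ the sequence splits into two interleaved chains, so a state reaches $h_{t-2}$ in one update instead of two; with $p = 24$ on hourly data the same shortening applies to the daily cycle.</p>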
<h3 id="temporal-attention-alternative">Temporal Attention Alternative</h3>
<p>For datasets without clear periodicity, LSTNet offers an attention-based variant (LSTNet-Attn) as an alternative to the recurrent-skip component. The attention mechanism learns to weight hidden representations across the input window adaptively. The attention weights $\alpha_t \in \mathbb{R}^q$ at time $t$ are computed as:</p>
<p>$$\alpha_t = \text{AttnScore}(H_t^R, h_{t-1}^R)$$</p>
<p>where $H_t^R = [h_{t-q}^R, \dots, h_{t-1}^R]$ stacks the RNN hidden representations column-wise and AttnScore is a similarity function (dot product, cosine, or a parameterized MLP). The weighted context vector and final output are:</p>
<p>$$\begin{aligned}
c_t &amp;= H_t^R \alpha_t \\
h_t^D &amp;= W[c_t; h_{t-1}^R] + b
\end{aligned}$$</p>
<h3 id="autoregressive-component">Autoregressive Component</h3>
<p>To address the scale insensitivity of neural networks, LSTNet adds a classical autoregressive model in parallel:</p>
<p>$$h_{t,i}^L = \sum_{k=0}^{q^{ar}-1} W_k^{ar} y_{t-k,i} + b^{ar}$$</p>
<p>The final prediction integrates both the neural network and AR outputs:</p>
<p>$$\hat{Y}_t = h_t^D + h_t^L$$</p>
<p>This decomposition separates the prediction into a linear part (handling local scale changes) and a non-linear part (capturing recurring patterns).</p>
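<p>The linear part is simple enough to write out. The sketch below applies one set of AR weights to every variable, as the formula above does, and adds the result to a neural output; the function names and the toy numbers are hypothetical.</p>

```python
def ar_component(Y, t, W_ar, b_ar=0.0):
    """h_{t,i}^L for each variable i, using the last len(W_ar) observations."""
    if len(W_ar) - 1 > t:
        raise ValueError("not enough history for the AR window")
    n = len(Y[0])
    return [sum(W_ar[k] * Y[t - k][i] for k in range(len(W_ar))) + b_ar
            for i in range(n)]

def final_prediction(h_D, h_L):
    """Y_hat_t = h_t^D + h_t^L, element-wise over variables."""
    return [d + l for d, l in zip(h_D, h_L)]

Y = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # 3 time steps, 2 variables
W_ar = [0.5, 0.3, 0.2]                       # window size q^ar = 3
h_L = ar_component(Y, t=2, W_ar=W_ar)
print(h_L)

# Doubling the inputs doubles the AR output (with zero bias), which is the
# scale-tracking behaviour the non-linear part lacks.
Y2 = [[2 * v for v in row] for row in Y]
print(ar_component(Y2, t=2, W_ar=W_ar))
```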
<h3 id="objective-function">Objective Function</h3>
<p>LSTNet supports two loss functions, selected via validation performance. The default is the squared (L2) loss:</p>
<p>$$\underset{\Theta}{\text{minimize}} \sum_{t \in \Omega_{\text{Train}}} \left\lVert Y_t - \hat{Y}_{t-h} \right\rVert_F^2$$</p>
<p>Motivated by the strong performance of Linear SVR baselines, LSTNet also supports the absolute (L1) loss, which is more robust to anomalies in real time series data:</p>
<p>$$\underset{\Theta}{\text{minimize}} \sum_{t \in \Omega_{\text{Train}}} \sum_{i=0}^{n-1} \left| Y_{t,i} - \hat{Y}_{t-h,i} \right|$$</p>
<p>where $\Theta$ is the full parameter set, $\Omega_{\text{Train}}$ is the set of training time stamps, $\lVert \cdot \rVert_F$ is the Frobenius norm, and $h$ is the forecast horizon.</p>
<h2 id="evaluation-on-four-benchmark-datasets">Evaluation on Four Benchmark Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Length</th>
          <th>Variables</th>
          <th>Sample Rate</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Traffic</td>
          <td>17,544</td>
          <td>862</td>
          <td>1 hour</td>
      </tr>
      <tr>
          <td>Solar-Energy</td>
          <td>52,560</td>
          <td>137</td>
          <td>10 minutes</td>
      </tr>
      <tr>
          <td>Electricity</td>
          <td>26,304</td>
          <td>321</td>
          <td>1 hour</td>
      </tr>
      <tr>
          <td>Exchange-Rate</td>
          <td>7,588</td>
          <td>8</td>
          <td>1 day</td>
      </tr>
  </tbody>
</table>
<p>All datasets are split 60/20/20 (train/validation/test) in chronological order. Traffic, Solar-Energy, and Electricity exhibit clear periodic patterns (daily and weekly), while Exchange-Rate shows only short-term local continuity.</p>
<h3 id="baselines">Baselines</h3>
<p>The authors compare against seven methods: AR (univariate autoregression), LRidge (VAR with L2 regularization), LSVR (VAR with SVR objective), TRMF (temporal regularized matrix factorization), GP (Gaussian Process), VAR-MLP (hybrid MLP-autoregressive), and RNN-GRU (standard GRU).</p>
<h3 id="metrics">Metrics</h3>
<p>Two evaluation metrics are used:</p>
<ul>
<li><strong>Root Relative Squared Error (RSE)</strong> (lower is better): A scaled RMSE that normalizes by the standard deviation of the test data, making comparison across datasets readable regardless of data scale:</li>
</ul>
<p>$$\text{RSE} = \frac{\sqrt{\sum_{(i,t) \in \Omega_{\text{Test}}} (Y_{it} - \hat{Y}_{it})^2}}{\sqrt{\sum_{(i,t) \in \Omega_{\text{Test}}} (Y_{it} - \text{mean}(Y))^2}}$$</p>
<ul>
<li><strong>Empirical Correlation Coefficient (CORR)</strong> (higher is better): The average Pearson correlation between predicted and true time series across all $n$ variables:</li>
</ul>
<p>$$\text{CORR} = \frac{1}{n} \sum_{i=1}^{n} \frac{\sum_t (Y_{it} - \text{mean}(Y_i))(\hat{Y}_{it} - \text{mean}(\hat{Y}_i))}{\sqrt{\sum_t (Y_{it} - \text{mean}(Y_i))^2 \sum_t (\hat{Y}_{it} - \text{mean}(\hat{Y}_i))^2}}$$</p>
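<p>Both metrics follow directly from the formulas above. The sketch below takes the test targets and predictions as lists of rows (time steps by variables); the function names are illustrative, and it assumes no variable is constant over the test window, since CORR is undefined in that case.</p>

```python
import math

def rse(Y, Y_hat):
    """Root relative squared error over all (variable, time) test entries."""
    flat = [y for row in Y for y in row]
    mean = sum(flat) / len(flat)
    num = sum((y - yh) ** 2 for r, rh in zip(Y, Y_hat) for y, yh in zip(r, rh))
    den = sum((y - mean) ** 2 for y in flat)
    return math.sqrt(num) / math.sqrt(den)

def corr(Y, Y_hat):
    """Mean Pearson correlation between true and predicted series per variable."""
    n = len(Y[0])
    total = 0.0
    for i in range(n):
        y = [row[i] for row in Y]
        yh = [row[i] for row in Y_hat]
        my, mh = sum(y) / len(y), sum(yh) / len(yh)
        cov = sum((a - my) * (b - mh) for a, b in zip(y, yh))
        var = sum((a - my) ** 2 for a in y) * sum((b - mh) ** 2 for b in yh)
        total += cov / math.sqrt(var)
    return total / n

Y = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
Y_hat = [[1.1, 2.1], [1.9, 4.2], [3.2, 5.8]]
print(rse(Y, Y_hat), corr(Y, Y_hat))
```

<p>A predictor that always outputs the global test mean scores an RSE of 1, so values below 1 indicate an improvement over that trivial baseline.</p>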
<h3 id="main-results">Main Results</h3>
<p>The models are evaluated at horizons $h \in \{3, 6, 12, 24\}$, corresponding to 3-24 hours for Traffic and Electricity, 30-240 minutes for Solar-Energy, and 3-24 days for Exchange-Rate.</p>
<p>LSTNet-Skip achieved the best result in 17 out of 32 (dataset, metric, horizon) combinations, and LSTNet-Attn won 7 more. No other method won more than 3. At horizon 24, the best LSTNet variant improved over RNN-GRU by 9.2% RSE on Solar-Energy (LSTNet-Attn), 11.7% on Traffic (LSTNet-Skip), and 22.2% on Electricity (LSTNet-Skip). On the Exchange-Rate dataset, which lacks periodic patterns, LSTNet performed comparably to AR and LRidge, as expected.</p>
<h3 id="ablation-study">Ablation Study</h3>
<p>Removing each component individually revealed:</p>
<ul>
<li><strong>Without AR</strong>: The largest performance drops across most datasets, confirming the AR component&rsquo;s role in handling scale changes. Visualization showed that LSTNet-Skip successfully tracks sudden magnitude shifts in electricity consumption around the 1000th hour, while the model without AR fails.</li>
<li><strong>Without Skip/CNN</strong>: Significant drops on datasets with periodic patterns, though less consistent than removing AR.</li>
<li><strong>Full LSTNet</strong>: The most robust configuration across all datasets and horizons.</li>
</ul>
<p>A simulation experiment with synthetic autoregressive data confirmed that standard RNN-GRU fails to track non-periodic scale changes, while LSTNet with its AR component adapts properly.</p>
<h2 id="robust-performance-through-architectural-complementarity">Robust Performance Through Architectural Complementarity</h2>
<p>LSTNet&rsquo;s main strength is the complementarity of its components. The CNN captures short-term local patterns, the recurrent-skip layer captures long-term periodic dependencies, and the AR component provides robustness to scale changes. On datasets with strong periodicity (Traffic, Solar-Energy, Electricity), the skip connections provide large gains. On datasets without periodicity (Exchange-Rate), the AR component prevents degradation below competitive baselines.</p>
<p>The primary limitation is that the skip length $p$ in the recurrent-skip component must be manually specified or tuned. For datasets with known periodicity (e.g., hourly data with daily cycles), $p$ is straightforward to set. For datasets without clear periodicity, $p$ must be tuned as a hyperparameter, and the attention-based variant (LSTNet-Attn) offers an alternative that avoids this requirement. Future work directions include automatic period detection and incorporating variable-level attribute information into the convolutional layer.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>Traffic</td>
          <td>17,544 x 862</td>
          <td>California DoT highway occupancy, hourly, 2015-2016</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Solar-Energy</td>
          <td>52,560 x 137</td>
          <td>Solar power from 137 PV plants in Alabama, 10-min intervals, 2006</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Electricity</td>
          <td>26,304 x 321</td>
          <td>kWh consumption for 321 clients, hourly, 2012-2014</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Exchange-Rate</td>
          <td>7,588 x 8</td>
          <td>Daily exchange rates for 8 countries, 1990-2016</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available via the <a href="https://github.com/laiguokun/LSTNet">GitHub repository</a>.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam</li>
<li>Dropout: 0.1 or 0.2 after each layer except input and output</li>
<li>Window size $q$: grid search over $\lbrace 2^0, 2^1, \ldots, 2^9 \rbrace$</li>
<li>Skip length $p$: set to 24 for Traffic/Electricity; tuned from $2^1$ to $2^6$ for Solar-Energy/Exchange-Rate</li>
<li>Objective: L2 loss (Eq. 7) or L1 loss (Eq. 9), selected via validation</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Hidden dimensions (Recurrent/CNN): $\lbrace 50, 100, 200 \rbrace$</li>
<li>Hidden dimensions (Recurrent-skip): $\lbrace 20, 50, 100 \rbrace$</li>
<li>AR regularization: $\lbrace 0.1, 1, 10 \rbrace$</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Dataset (Horizon)</th>
          <th>Best LSTNet RSE</th>
          <th>RNN-GRU RSE (Baseline)</th>
          <th>Relative RSE Reduction</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Solar-Energy (h=24)</td>
          <td>0.4403 (Attn)</td>
          <td>0.4852</td>
          <td>9.2%</td>
      </tr>
      <tr>
          <td>Traffic (h=24)</td>
          <td>0.4973 (Skip)</td>
          <td>0.5633</td>
          <td>11.7%</td>
      </tr>
      <tr>
          <td>Electricity (h=24)</td>
          <td>0.1007 (Skip)</td>
          <td>0.1295</td>
          <td>22.2%</td>
      </tr>
  </tbody>
</table>
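<p>The last column is a relative reduction in RSE with respect to the RNN-GRU baseline. A short sketch of that arithmetic, using only the RSE values from the table (the helper name <code>relative_reduction</code> is illustrative):</p>

```python
def relative_reduction(baseline, model):
    """Relative RSE reduction in percent (lower RSE is better)."""
    return 100.0 * (baseline - model) / baseline

# (RNN-GRU RSE, best LSTNet RSE) at horizon 24, copied from the table.
rows = {
    "Solar-Energy": (0.4852, 0.4403),
    "Traffic": (0.5633, 0.4973),
    "Electricity": (0.1295, 0.1007),
}
reductions = {name: relative_reduction(b, m) for name, (b, m) in rows.items()}
for name, value in reductions.items():
    print(f"{name}: {value:.2f}%")
```

<p>The computed values agree with the reported 9.2%, 11.7%, and 22.2% to within 0.1 percentage point.</p>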
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/laiguokun/LSTNet">LSTNet (laiguokun/LSTNet)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation (Python 2.7, PyTorch 0.3.0)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/laiguokun/multivariate-time-series-data">Multivariate Time Series Data (laiguokun/multivariate-time-series-data)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Preprocessed benchmark datasets (Traffic, Solar-Energy, Electricity, Exchange-Rate)</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Code and all four benchmark datasets are publicly available. Hyperparameter search ranges are fully specified.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lai, G., Chang, W.-C., Yang, Y., &amp; Liu, H. (2018). Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks. <em>The 41st International ACM SIGIR Conference on Research &amp; Development in Information Retrieval (SIGIR &rsquo;18)</em>, 95-104. <a href="https://doi.org/10.1145/3209978.3210006">https://doi.org/10.1145/3209978.3210006</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lai2018modeling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Modeling Long- and Short-Term Temporal Patterns with Deep Neural Networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lai, Guokun and Chang, Wei-Cheng and Yang, Yiming and Liu, Hanxiao}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The 41st International ACM SIGIR Conference on Research \&amp; Development in Information Retrieval}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{95--104}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3209978.3210006}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DGCNN: Dynamic Graph CNN for Point Cloud Learning</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/dgcnn-dynamic-graph-point-clouds/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/dgcnn-dynamic-graph-point-clouds/</guid><description>EdgeConv module learns point cloud features on dynamically recomputed k-NN graphs in feature space, achieving strong classification and segmentation results.</description><content:encoded><![CDATA[<h2 id="a-general-purpose-edge-convolution-module-for-point-cloud-learning">A General-Purpose Edge Convolution Module for Point Cloud Learning</h2>
<p>This is a <strong>Method</strong> paper that introduces EdgeConv, a neural network module for learning on point clouds. The key idea is to construct a local graph structure and define convolution-like operations over edges connecting neighboring points. Unlike prior <a href="/notes/machine-learning/model-architectures/relational-inductive-biases-deep-learning-graph-networks/">graph neural network approaches</a> that operate on a fixed graph, DGCNN (Dynamic Graph CNN) recomputes the graph at each layer using k-nearest neighbors in feature space. This dynamic graph update allows the network to learn semantic groupings that differ from spatial proximity, enabling information propagation across long distances in the original point cloud. The model achieves strong results on classification (ModelNet40), part segmentation (ShapeNetPart), and semantic segmentation (S3DIS) benchmarks.</p>
<h2 id="why-point-clouds-need-topology-recovery">Why Point Clouds Need Topology Recovery</h2>
<p>Point clouds are the raw output of most 3D acquisition devices (<a href="https://en.wikipedia.org/wiki/Lidar">LiDAR</a>, stereo reconstruction) and serve as the simplest geometric representation for countless applications in graphics, robotics, and autonomous driving. However, point clouds inherently lack topological information: they are unordered sets of points with no connectivity structure.</p>
<p>Standard CNNs require grid-structured input, making them incompatible with irregular point cloud data. Volumetric approaches that discretize point clouds onto 3D grids introduce quantization artifacts and excessive memory usage. PointNet addressed this by operating on each point independently and aggregating with a symmetric function (max pooling), achieving permutation invariance. However, this independence means PointNet cannot capture local geometric structure.</p>
<p>PointNet++ partially addresses this by applying PointNet hierarchically in local neighborhoods, but it constructs neighborhoods based on Euclidean distances in the input space and does not update the graph structure during processing. The fundamental limitation is that treating points independently, even locally, prevents the model from learning the geometric relationships between points that carry important structural and semantic information.</p>
<h2 id="edgeconv-combining-local-geometry-with-global-structure">EdgeConv: Combining Local Geometry with Global Structure</h2>
<p>Given an $F$-dimensional point cloud $\mathbf{X} = \lbrace \mathbf{x}_1, \ldots, \mathbf{x}_n \rbrace \subseteq \mathbb{R}^F$, DGCNN constructs a directed graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ as the $k$-nearest neighbor graph in $\mathbb{R}^F$, including self-loops so each node also points to itself. Each edge carries a feature $\mathbf{e}_{ij} = h_\Theta(\mathbf{x}_i, \mathbf{x}_j)$, and the EdgeConv output at point $i$ aggregates these edge features over its neighbors:</p>
<p>$$
\mathbf{x}_i^{\prime} = \square_{j:(i,j) \in \mathcal{E}} h_\Theta(\mathbf{x}_i, \mathbf{x}_j)
$$</p>
<p>where $h_\Theta$ is a learnable nonlinear function and $\square$ denotes a channel-wise symmetric aggregation operation (e.g., max or sum).</p>
<p>The choice of edge function $h_\Theta$ determines the model&rsquo;s properties. The authors analyze several options:</p>
<table>
  <thead>
      <tr>
          <th>Choice</th>
          <th>Edge function</th>
          <th>Properties</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Standard convolution</td>
          <td>$\theta_m \cdot \mathbf{x}_j$</td>
          <td>Requires fixed grid structure</td>
      </tr>
      <tr>
          <td>PointNet</td>
          <td>$h_\Theta(\mathbf{x}_i)$</td>
          <td>Global only, ignores local structure</td>
      </tr>
      <tr>
          <td>PointNet++</td>
          <td>$h_\Theta(\mathbf{x}_j)$</td>
          <td>Local only, loses global context</td>
      </tr>
      <tr>
          <td>Local difference</td>
          <td>$h_\Theta(\mathbf{x}_j - \mathbf{x}_i)$</td>
          <td>Local patches without global positioning</td>
      </tr>
      <tr>
          <td><strong>EdgeConv (this work)</strong></td>
          <td>$\bar{h}_\Theta(\mathbf{x}_i, \mathbf{x}_j - \mathbf{x}_i)$</td>
          <td><strong>Both local geometry and global structure</strong></td>
      </tr>
  </tbody>
</table>
<p>The concrete EdgeConv operation uses an asymmetric edge function that combines the point&rsquo;s own features $\mathbf{x}_i$ (global shape structure) with the relative difference $\mathbf{x}_j - \mathbf{x}_i$ (local neighborhood information):</p>
<p>$$
e^{\prime}_{ijm} = \text{ReLU}(\boldsymbol{\theta}_m \cdot (\mathbf{x}_j - \mathbf{x}_i) + \boldsymbol{\phi}_m \cdot \mathbf{x}_i)
$$</p>
<p>$$
x^{\prime}_{im} = \max_{j:(i,j) \in \mathcal{E}} e^{\prime}_{ijm}
$$</p>
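<p>where $\boldsymbol{\Theta} = (\boldsymbol{\theta}_1, \ldots, \boldsymbol{\theta}_M, \boldsymbol{\phi}_1, \ldots, \boldsymbol{\phi}_M)$ are learnable parameters and $M$ is the number of output channels. This formulation can be implemented as a shared MLP applied to each edge, followed by max pooling over neighbors.</p>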
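<p>The two equations above can be sketched in NumPy as a single layer. This is a minimal illustration, not the official implementation: the names <code>knn_indices</code> and <code>edge_conv</code>, the random weights, and the sizes are assumptions, and the layer omits batch normalization and uses plain ReLU as in the formula.</p>

```python
import numpy as np

def knn_indices(x, k):
    """Indices of the k nearest rows for each row of x (self included)."""
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T   # pairwise squared distances
    return np.argsort(d2, axis=1)[:, :k]

def edge_conv(x, theta, phi, k):
    """One EdgeConv layer: x is (n, F); theta and phi are (F, M)."""
    idx = knn_indices(x, k)                    # (n, k) neighbor indices
    diff = x[idx] - x[:, None, :]              # (n, k, F): x_j - x_i
    # ReLU(theta_m . (x_j - x_i) + phi_m . x_i) for every edge and channel
    e = np.maximum(diff @ theta + (x @ phi)[:, None, :], 0.0)
    return e.max(axis=1)                       # (n, M): max over neighbors

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))                   # hypothetical 50-point cloud
theta = rng.normal(size=(3, 16))
phi = rng.normal(size=(3, 16))
out = edge_conv(x, theta, phi, k=10)
print(out.shape)
```

<p>Because the distance from a point to itself is zero, each point appears in its own neighbor list, matching the self-loops in the graph definition.</p>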
<h3 id="dynamic-graph-recomputation">Dynamic Graph Recomputation</h3>
<p>The defining feature of DGCNN is that the graph $\mathcal{G}^{(l)}$ is recomputed at each layer $l$ using k-NN in the current feature space, rather than being fixed based on input coordinates. This means:</p>
<ul>
<li>The receptive field grows to be as large as the diameter of the point cloud while remaining sparse.</li>
<li>Points that are far apart in Euclidean space but semantically similar (e.g., the two wings of an airplane) become neighbors in deeper feature spaces.</li>
<li>The model learns to construct the graph itself, rather than taking it as a fixed input.</li>
</ul>
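<p>A minimal sketch of the dynamic update, assuming a two-layer stack with random weights: the k-NN graph is rebuilt from the current features before every layer, so the first graph lives in input coordinates and later graphs live in learned feature space. The helper names, layer widths, and $k$ used here are illustrative assumptions.</p>

```python
import numpy as np

def knn_indices(feats, k):
    """k-NN indices computed in whatever space feats currently lives in."""
    sq = np.sum(feats ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * feats @ feats.T
    return np.argsort(d2, axis=1)[:, :k]

def edge_conv(feats, idx, theta, phi):
    """EdgeConv on a given neighbor index array idx of shape (n, k)."""
    diff = feats[idx] - feats[:, None, :]
    e = np.maximum(diff @ theta + (feats @ phi)[:, None, :], 0.0)
    return e.max(axis=1)

rng = np.random.default_rng(0)
points = rng.normal(size=(64, 3))     # hypothetical input point cloud
dims = [3, 16, 32]                    # hypothetical layer widths
k = 8

feats = points
graphs = []
for f_in, f_out in zip(dims[:-1], dims[1:]):
    idx = knn_indices(feats, k)       # graph rebuilt from the current features
    graphs.append(idx)
    theta = rng.normal(size=(f_in, f_out))
    phi = rng.normal(size=(f_in, f_out))
    feats = edge_conv(feats, idx, theta, phi)

print(feats.shape, len(graphs))
```

<p>A fixed-graph baseline would instead compute <code>idx</code> once from <code>points</code> and reuse it in every layer.</p>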
<h3 id="permutation-and-translation-invariance">Permutation and Translation Invariance</h3>
<p>EdgeConv is permutation invariant because the max aggregation is a symmetric function. It has a &ldquo;partial&rdquo; translation invariance property: the local difference term $\mathbf{x}_j - \mathbf{x}_i$ is fully translation invariant, while the global term $\boldsymbol{\phi}_m \cdot \mathbf{x}_i$ is translation-dependent. Setting $\boldsymbol{\phi}_m = 0$ yields full translation invariance but loses global positioning information.</p>
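<p>These properties can be checked numerically. The sketch below assumes a single EdgeConv layer with random weights (all names and sizes are illustrative) and tests three things: reordering the points only reorders the per-point outputs, a global shift leaves the output unchanged when $\boldsymbol{\phi}_m = 0$, and the same shift changes the output when $\boldsymbol{\phi}_m$ is nonzero.</p>

```python
import numpy as np

def edge_conv(x, theta, phi, k):
    """Single EdgeConv layer with k-NN computed from x itself."""
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    idx = np.argsort(d2, axis=1)[:, :k]
    diff = x[idx] - x[:, None, :]
    e = np.maximum(diff @ theta + (x @ phi)[:, None, :], 0.0)
    return e.max(axis=1)

rng = np.random.default_rng(1)
x = rng.normal(size=(32, 3))
theta = rng.normal(size=(3, 8))
phi = rng.normal(size=(3, 8))
zero_phi = np.zeros_like(phi)
shift = np.array([5.0, -2.0, 1.0])
perm = rng.permutation(len(x))

out = edge_conv(x, theta, phi, k=6)

# Reordering the input points reorders the outputs in the same way.
permutation_ok = np.allclose(edge_conv(x[perm], theta, phi, k=6), out[perm])

# With phi = 0 only x_j - x_i is used, so a global shift has no effect.
shift_invariant_without_phi = np.allclose(
    edge_conv(x + shift, theta, zero_phi, k=6),
    edge_conv(x, theta, zero_phi, k=6),
)

# With nonzero phi the term phi_m . x_i depends on absolute position.
shift_changes_with_phi = not np.allclose(edge_conv(x + shift, theta, phi, k=6), out)

print(permutation_ok, shift_invariant_without_phi, shift_changes_with_phi)
```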
<h2 id="benchmarks-classification-part-segmentation-and-scene-segmentation">Benchmarks: Classification, Part Segmentation, and Scene Segmentation</h2>
<h3 id="classification-on-modelnet40">Classification on ModelNet40</h3>
<p>The classification architecture uses four EdgeConv layers with output dimensions (64, 64, 128, 256), $k = 20$ nearest neighbors, and shortcut connections that concatenate all EdgeConv outputs into a $64 + 64 + 128 + 256 = 512$-dimensional per-point feature. A shared fully-connected layer (1024) aggregates these multi-scale features. Global max and sum pooling produce a 1D descriptor, followed by two fully-connected layers (512, 256) with dropout (probability 0.5). All layers use LeakyReLU and batch normalization. Input point clouds are rescaled to fit into the unit sphere.</p>
<p>Training uses SGD with momentum 0.9, initial learning rate 0.1, cosine annealing to 0.001, and batch size 32. Batch normalization momentum is 0.9 with no BN decay. Data augmentation includes random scaling and perturbation of object and point locations. The value of $k$ is selected using an 80/20 train/validation split, then the model is retrained on the full training set.</p>
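<p>The learning-rate schedule can be sketched with the standard cosine annealing formula (no restarts). Whether the authors used exactly this form is an assumption, and the epoch count <code>total_epochs=250</code> is a hypothetical value that the text above does not specify.</p>

```python
import math

def cosine_lr(epoch, total_epochs, lr_max=0.1, lr_min=0.001):
    """Cosine annealing from lr_max at epoch 0 to lr_min at total_epochs."""
    cos = math.cos(math.pi * epoch / total_epochs)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + cos)

# Hypothetical run length; only the endpoints 0.1 and 0.001 come from the text.
schedule = [cosine_lr(e, total_epochs=250) for e in range(251)]
print(schedule[0], schedule[125], schedule[-1])
```

<p>The schedule starts at 0.1, passes through the midpoint 0.0505 halfway through training, and ends at 0.001.</p>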
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Mean Class Acc. (%)</th>
          <th>Overall Acc. (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PointNet</td>
          <td>86.0</td>
          <td>89.2</td>
      </tr>
      <tr>
          <td>PointNet++</td>
          <td>&ndash;</td>
          <td>90.7</td>
      </tr>
      <tr>
          <td>PointCNN</td>
          <td>88.1</td>
          <td>92.2</td>
      </tr>
      <tr>
          <td>PCNN</td>
          <td>&ndash;</td>
          <td>92.3</td>
      </tr>
      <tr>
          <td><strong>DGCNN (baseline, fixed graph)</strong></td>
          <td><strong>88.9</strong></td>
          <td><strong>91.7</strong></td>
      </tr>
      <tr>
          <td><strong>DGCNN</strong></td>
          <td><strong>90.2</strong></td>
          <td><strong>92.9</strong></td>
      </tr>
      <tr>
          <td><strong>DGCNN (2048 points)</strong></td>
          <td><strong>90.7</strong></td>
          <td><strong>93.5</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="model-complexity">Model Complexity</h3>
<p>DGCNN achieves a favorable tradeoff between model size, inference speed, and accuracy:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Model Size (MB)</th>
          <th>Time (ms)</th>
          <th>Accuracy (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PointNet (baseline)</td>
          <td>9.4</td>
          <td>6.8</td>
          <td>87.1</td>
      </tr>
      <tr>
          <td>PointNet</td>
          <td>40</td>
          <td>16.6</td>
          <td>89.2</td>
      </tr>
      <tr>
          <td>PointNet++</td>
          <td>12</td>
          <td>163.2</td>
          <td>90.7</td>
      </tr>
      <tr>
          <td>PCNN</td>
          <td>94</td>
          <td>117.0</td>
          <td>92.3</td>
      </tr>
      <tr>
          <td>DGCNN (baseline)</td>
          <td>11</td>
          <td>19.7</td>
          <td>91.7</td>
      </tr>
      <tr>
          <td>DGCNN</td>
          <td>21</td>
          <td>27.2</td>
          <td>92.9</td>
      </tr>
  </tbody>
</table>
<p>The DGCNN baseline outperforms PointNet++ by 1.0 percentage point in accuracy while being 7x faster. The full DGCNN outperforms PCNN by 0.6 percentage points while being roughly 4x faster with a 4.5x smaller model (21 MB vs. 94 MB).</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: center">Centralization</th>
          <th style="text-align: center">Dynamic Graph</th>
          <th style="text-align: center">2048 Points</th>
          <th>Mean Class (%)</th>
          <th>Overall (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td>88.9</td>
          <td>91.7</td>
      </tr>
      <tr>
          <td style="text-align: center">x</td>
          <td style="text-align: center"></td>
          <td style="text-align: center"></td>
          <td>89.3</td>
          <td>92.2</td>
      </tr>
      <tr>
          <td style="text-align: center">x</td>
          <td style="text-align: center">x</td>
          <td style="text-align: center"></td>
          <td>90.2</td>
          <td>92.9</td>
      </tr>
      <tr>
          <td style="text-align: center">x</td>
          <td style="text-align: center">x</td>
          <td style="text-align: center">x</td>
          <td>90.7</td>
          <td>93.5</td>
      </tr>
  </tbody>
</table>
<p>The choice of $k$ also matters:</p>
<table>
  <thead>
      <tr>
          <th>$k$</th>
          <th>Mean Class Acc. (%)</th>
          <th>Overall Acc. (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>5</td>
          <td>88.0</td>
          <td>90.5</td>
      </tr>
      <tr>
          <td>10</td>
          <td>88.9</td>
          <td>91.4</td>
      </tr>
      <tr>
          <td>20</td>
          <td>90.2</td>
          <td>92.9</td>
      </tr>
      <tr>
          <td>40</td>
          <td>89.4</td>
          <td>92.4</td>
      </tr>
  </tbody>
</table>
<p>$k = 20$ performs best on 1024 points. Larger $k$ (e.g., 40) degrades performance because Euclidean distance poorly approximates geodesic distance at larger scales for a given point density.</p>
<h3 id="part-segmentation-on-shapenetpart">Part Segmentation on ShapeNetPart</h3>
<p>On the ShapeNetPart dataset (16,881 shapes, 16 categories, 50 part labels), DGCNN achieves 85.2% mean IoU, comparable to PointNet++ (85.1%) and PointCNN (86.1%). The model also demonstrates robustness to partial data, maintaining reasonable segmentation quality even when half of the points are removed.</p>
<h3 id="indoor-scene-segmentation-on-s3dis">Indoor Scene Segmentation on S3DIS</h3>
<p>On the Stanford Large-Scale 3D Indoor Spaces Dataset (6 indoor areas, 272 rooms, 13 semantic categories), DGCNN achieves 56.1% mean IoU and 84.1% overall accuracy using 6-fold cross-validation over the areas, outperforming PointNet (47.6% / 78.5%) and producing smoother segmentation boundaries. Each point is represented as a 9D vector (XYZ, RGB, and normalized spatial coordinates), with 4,096 points sampled per $1\text{m} \times 1\text{m}$ block during training.</p>
<h2 id="semantic-feature-spaces-and-future-directions">Semantic Feature Spaces and Future Directions</h2>
<p>A key qualitative finding is that the feature spaces learned by DGCNN in deeper layers capture semantic similarity rather than spatial proximity. Visualizations show that semantically similar structures (e.g., all legs of a table, or all wings of an airplane) are brought close together in feature space, even when they are far apart in the original 3D embedding. This property also transfers across shapes: features from one airplane&rsquo;s wing are close to the wing features of a different airplane in the learned feature space.</p>
<p>The authors identify several directions for future work:</p>
<ul>
<li><strong>Efficiency</strong>: Incorporating fast data structures (e.g., KD-trees) instead of computing pairwise distances for k-NN queries.</li>
<li><strong>Higher-order relationships</strong>: Considering tuples of points rather than only pairwise relationships.</li>
<li><strong>Non-shared transformations</strong>: Applying different transformations to different local patches rather than using shared weights.</li>
<li><strong>Abstract point clouds</strong>: Extending the approach to non-geometric applications like document retrieval and image processing, where the role of geometry in abstract feature spaces may provide new insights.</li>
</ul>
<p>The model has some limitations. On S3DIS, PointCNN achieves notably higher mean IoU (65.39% vs. 56.1%), suggesting room for improvement on large-scale scene segmentation. The dynamic k-NN computation adds overhead relative to fixed-graph approaches, though the overall model remains efficient.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification</td>
          <td>ModelNet40</td>
          <td>12,311 CAD models (40 categories)</td>
          <td>1,024 points uniformly sampled per model</td>
      </tr>
      <tr>
          <td>Part Segmentation</td>
          <td>ShapeNetPart</td>
          <td>16,881 shapes (16 categories, 50 parts)</td>
          <td>2,048 points per shape</td>
      </tr>
      <tr>
          <td>Scene Segmentation</td>
          <td>S3DIS</td>
          <td>272 rooms (13 categories)</td>
          <td>4,096 points per 1m x 1m block</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>k-NN graph construction</strong>: Pairwise distance matrix in feature space, $k = 20$ (classification) or $k = 40$ (2048 points).</li>
<li><strong>EdgeConv</strong>: Shared MLP on concatenated $[\mathbf{x}_i, \mathbf{x}_j - \mathbf{x}_i]$ features, followed by channel-wise max pooling over neighbors.</li>
<li><strong>Dynamic graph update</strong>: Graph recomputed from k-NN in feature space at each EdgeConv layer.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Classification</strong>: 4 EdgeConv layers (64, 64, 128, 256) + shortcut concatenation (512-dim) + shared FC (1024) + global max/sum pooling + FC (512, 256). 21 MB.</li>
<li><strong>Segmentation</strong>: Spatial transformer + 3 EdgeConv layers + shared FC (1024) aggregation + shortcut connections + FC (256, 256, 128).</li>
<li>All layers use LeakyReLU and batch normalization. Dropout 0.5 in final FC layers.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>DGCNN</th>
          <th>Best Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ModelNet40 Classification</td>
          <td>Overall Accuracy</td>
          <td>92.9%</td>
          <td>92.3% (PCNN)</td>
      </tr>
      <tr>
          <td>ShapeNetPart Segmentation</td>
          <td>Mean IoU</td>
          <td>85.2%</td>
          <td>86.1% (PointCNN)</td>
      </tr>
      <tr>
          <td>S3DIS Scene Segmentation</td>
          <td>Mean IoU</td>
          <td>56.1%</td>
          <td>65.39% (PointCNN)</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/WangYueFt/dgcnn">WangYueFt/dgcnn</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official TensorFlow and PyTorch implementations</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training used NVIDIA TITAN X GPUs. Distributed training (2 GPUs) for part segmentation.</li>
<li>Forward pass time: 27.2 ms per sample (1,024 points) on a single GPU.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, Y., Sun, Y., Liu, Z., Sarma, S. E., Bronstein, M. M., &amp; Solomon, J. M. (2019). Dynamic Graph CNN for Learning on Point Clouds. <em>ACM Transactions on Graphics</em>, 38(5), Article 146. <a href="https://doi.org/10.1145/3326362">https://doi.org/10.1145/3326362</a></p>
<p><strong>Code</strong>: <a href="https://github.com/WangYueFt/dgcnn">github.com/WangYueFt/dgcnn</a> (MIT License)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wang2019dynamic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Dynamic Graph CNN for Learning on Point Clouds}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wang, Yue and Sun, Yongbin and Liu, Ziwei and Sarma, Sanjay E. and Bronstein, Michael M. and Solomon, Justin M.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ACM Transactions on Graphics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{38}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">articleno</span>=<span style="color:#e6db74">{146}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3326362}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Defining Disentangled Representations via Group Theory</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/defining-disentangled-representations/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/defining-disentangled-representations/</guid><description>First formal definition of disentangled representations using group theory, connecting symmetry transformations to vector space decompositions.</description><content:encoded><![CDATA[<h2 id="a-theory-paper-grounding-disentanglement-in-symmetry">A Theory Paper Grounding Disentanglement in Symmetry</h2>
<p>This is a <strong>Theory</strong> paper that provides the first formal mathematical definition of disentangled representations. Rather than proposing a new learning algorithm or evaluating existing methods, the paper uses group theory and representation theory to define precisely what it means for a representation to be disentangled. The authors argue that the relevant structure of the world is captured by symmetry transformations, and that a disentangled representation must decompose into independent subspaces aligned with the decomposition of the corresponding symmetry group.</p>
<h2 id="why-disentangling-lacks-a-formal-foundation">Why Disentangling Lacks a Formal Foundation</h2>
<p>Disentangled representation learning aims to learn representations where distinct factors of variation in the data are separated into independent components. This idea has driven significant research, particularly through models like $\beta$-VAE and InfoGAN. Despite this progress, the field has lacked agreement on several fundamental questions: what constitutes the &ldquo;data generative factors,&rdquo; whether each factor should correspond to a single latent dimension or multiple dimensions, and whether a disentangled representation should have a unique axis alignment.</p>
<p>Without a formal definition, evaluating disentanglement methods remains subjective, relying on human intuition or metrics that encode different (sometimes contradictory) assumptions. For example, some metrics penalize multi-dimensional subspaces while others allow them. The lack of formal grounding also means there is no principled way to determine whether certain factors of variation (such as 3D rotations) can even be disentangled in principle.</p>
<p>The authors draw inspiration from physics, where symmetry transformations have been central to understanding world structure since <a href="https://en.wikipedia.org/wiki/Noether%27s_theorem">Noether&rsquo;s theorem</a> connected conservation laws to continuous symmetries. Gell-Mann&rsquo;s prediction of the $\Omega^{-}$ particle from symmetry-based classification of hadrons, and the unification of electricity and magnetism through shared symmetry transformations, illustrate the power of the symmetry perspective for generalization to new domains.</p>
<h2 id="symmetry-groups-as-the-foundation-for-disentanglement">Symmetry Groups as the Foundation for Disentanglement</h2>
<p>The core insight is that the &ldquo;data generative factors&rdquo; previously used to discuss disentanglement should be replaced by symmetry transformations of the world. The paper defines a disentangled representation through three key concepts.</p>
<h3 id="disentangled-group-action">Disentangled Group Action</h3>
<p>Given a group $G$ that decomposes as a <a href="https://en.wikipedia.org/wiki/Direct_product_of_groups">direct product</a> $G = G_1 \times G_2 \times \ldots \times G_n$, an action of $G$ on a set $X$ is <strong>disentangled</strong> if there exists a decomposition $X = X_1 \times X_2 \times \ldots \times X_n$ such that each subgroup $G_i$ acts only on $X_i$ and leaves all other components fixed. For $n = 2$, with $g_i \in G_i$ and $v_i \in X_i$, this reads:</p>
<p>$$(g_1, g_2) \cdot (v_1, v_2) = (g_1 \cdot_1 v_1, g_2 \cdot_2 v_2)$$</p>
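<p>As a toy instance of this definition, assume $G = C_4 \times C_3$ acting on $X = \mathbb{Z}_4 \times \mathbb{Z}_3$ by cyclic shifts in each coordinate. The group sizes and the function name <code>act</code> are illustrative assumptions. The sketch checks that each subgroup leaves the other component fixed and that the map is a valid group action.</p>

```python
from itertools import product

N1, N2 = 4, 3   # hypothetical orders of the two cyclic subgroups

def act(g, v):
    """(g1, g2) . (v1, v2) = (g1 .1 v1, g2 .2 v2), each a cyclic shift."""
    return ((v[0] + g[0]) % N1, (v[1] + g[1]) % N2)

G = list(product(range(N1), range(N2)))
X = list(product(range(N1), range(N2)))

# Elements of G_1 have the form (a, 0) and must not move the second component.
g1_fixes_x2 = all(act((a, 0), v)[1] == v[1] for a in range(N1) for v in X)
# Elements of G_2 have the form (0, b) and must not move the first component.
g2_fixes_x1 = all(act((0, b), v)[0] == v[0] for b in range(N2) for v in X)

# Group action axiom: g . (h . v) equals (g h) . v for all g, h, v.
is_action = all(
    act(g, act(h, v)) == act(((g[0] + h[0]) % N1, (g[1] + h[1]) % N2), v)
    for g in G for h in G for v in X
)

print(g1_fixes_x2, g2_fixes_x1, is_action)
```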
<h3 id="disentangled-representation">Disentangled Representation</h3>
<p>Let $W$ be the set of world states with symmetry group $G$ acting on it. A generative process $b: W \to O$ produces observations, and an inference process $h: O \to Z$ produces representations. The composition $f = h \circ b$ maps world states to representations. The representation is <strong>disentangled</strong> if:</p>
<ol>
<li>There exists an action $\cdot: G \times Z \to Z$</li>
<li>The map $f: W \to Z$ is <strong><a href="https://en.wikipedia.org/wiki/Equivariant_map">equivariant</a></strong>: $g \cdot f(w) = f(g \cdot w)$ for all $g \in G, w \in W$</li>
<li>There exists a decomposition $Z = Z_1 \oplus Z_2 \oplus \ldots \oplus Z_n$ such that each $Z_i$ is affected only by $G_i$ and fixed by all other subgroups</li>
</ol>
<p>The equivariance condition ensures that the symmetry structure of the world is faithfully reflected in the representation space.</p>
<h3 id="linear-disentangled-representation">Linear Disentangled Representation</h3>
<p>When the group action on $Z$ is additionally constrained to be linear, the representation becomes a <strong>linear disentangled representation</strong>. This leverages <a href="https://en.wikipedia.org/wiki/Group_representation">group representation theory</a>, where the action is described by a homomorphism $\rho: G \to GL(Z)$. The representation is linearly disentangled if it decomposes as a direct sum $\rho = \rho_1 \oplus \rho_2 \oplus \ldots \oplus \rho_n$, where each $\rho_i$ acts only on $Z_i$. In matrix terms, this means $\rho(g)$ takes a block-diagonal form.</p>
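<p>As an illustration for the two-factor case $G = G_1 \times G_2$ (a sketch of the block-diagonal statement above, with block sizes left unspecified), the matrix form can be written as:</p>
<p>$$
\rho(g_1, g_2) = \begin{pmatrix} \rho_1(g_1) &amp; 0 \\ 0 &amp; \rho_2(g_2) \end{pmatrix}, \qquad (g_1, g_2) \in G_1 \times G_2
$$</p>
<p>Here $\rho_1(g_1)$ acts on $Z_1$ and $\rho_2(g_2)$ acts on $Z_2$; the zero off-diagonal blocks express that neither subgroup touches the other subspace.</p>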
<p>For the irreducible representations of a direct product group $G = G_1 \times G_2$, disentanglement requires that each irreducible component $\rho_1 \otimes \rho_2$ has at most one non-trivial factor. This prevents any subspace from being jointly affected by multiple subgroups.</p>
<h2 id="grid-world-example-and-the-so3-counterexample">Grid World Example and the SO(3) Counterexample</h2>
<p>Since this is a theory paper, the &ldquo;experiments&rdquo; consist of worked examples that illustrate the definition.</p>
<h3 id="grid-world-verification">Grid World Verification</h3>
<p>The authors consider a grid world where an object can translate horizontally, vertically, and change color, with wraparound boundaries. The symmetry group decomposes as $G = G_x \times G_y \times G_c$, where each subgroup is isomorphic to the <a href="https://en.wikipedia.org/wiki/Cyclic_group">cyclic group</a> $C_N$.</p>
<p>A CCI-VAE model trained on observations from this world learns a representation of the approximate form $f(x, y, c) \approx (\lambda_x x, \lambda_y y, \lambda_c c)$, in which each subgroup acts independently on its corresponding subspace. Commutativity of the group actions is approximately preserved, but the learned action is a translation (an additive shift) and not a linear map, and the cyclic wraparound structure is lost.</p>
<p>For a linear disentangled representation, the map $f(x, y, c) = (e^{2\pi i x / N}, e^{2\pi i y / N}, e^{2\pi i c / N})$ over $\mathbb{C}^3$ provides an exact solution. The generator of each subgroup acts as multiplication by $e^{2\pi i / N}$ on its corresponding coordinate, yielding a truly linear and disentangled action. Equivalently, viewing $\rho$ as a representation over $\mathbb{R}^6$ (since $\mathbb{C}^3 \cong \mathbb{R}^6$), the group action is expressed using block-diagonal matrices of $2 \times 2$ rotation matrices, and each invariant subspace becomes two-dimensional.</p>
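<p>The equivariance of this map can be checked numerically. The following is a minimal sketch, assuming NumPy and a hypothetical grid size <code>N = 5</code>; the helper names <code>f</code>, <code>act_world</code>, and <code>rho</code> are illustrative and do not come from the paper.</p>

```python
import numpy as np

N = 5  # hypothetical grid size and number of colours


def f(x, y, c):
    """Map a world state (x, y, c) to a point in C^3 via N-th roots of unity."""
    return np.exp(2j * np.pi * np.array([x, y, c]) / N)


def act_world(g, w):
    """Group element g = (dx, dy, dc) shifts the world state with wraparound."""
    return tuple((wi + gi) % N for wi, gi in zip(w, g))


def rho(g):
    """Linear action on C^3: a diagonal matrix with one phase per subgroup."""
    return np.diag(np.exp(2j * np.pi * np.array(g) / N))


w = (1, 3, 4)   # an arbitrary world state
g = (2, 0, 0)   # an element of G_x only

lhs = rho(g) @ f(*w)           # act in representation space
rhs = f(*act_world(g, w))      # act in the world, then represent
equivariant = np.allclose(lhs, rhs)

# G_x must leave the y and c coordinates of the representation untouched.
others_fixed = np.allclose(lhs[1:], f(*w)[1:])
print(equivariant, others_fixed)
```

<p>Because <code>rho(g)</code> is diagonal, each subgroup only rescales its own complex coordinate, which is the disentangled, linear behaviour the definition asks for.</p>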
<h3 id="3d-rotations-cannot-be-disentangled">3D Rotations Cannot Be Disentangled</h3>
<p>The group of 3D rotations <a href="https://en.wikipedia.org/wiki/3D_rotation_group">$SO(3)$</a> has subgroups for rotations about the $x$, $y$, and $z$ axes. Intuitively, one might expect to disentangle these three rotation axes. However, rotations about different axes do not commute (rotating $90^\circ$ about $x$ then $y$ gives a different result from $y$ then $x$), so $SO(3)$ cannot be written as a direct product of these subgroups. The definition correctly rules out disentangling along these lines.</p>
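<p>The non-commutativity is easy to verify with explicit rotation matrices. This is a small sketch assuming NumPy and the usual right-handed rotation matrices; the function names are illustrative.</p>

```python
import numpy as np


def rot_x(theta):
    """Rotation by theta radians about the x axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])


def rot_y(theta):
    """Rotation by theta radians about the y axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])


quarter = np.pi / 2  # 90 degrees

x_then_y = rot_y(quarter) @ rot_x(quarter)  # apply x first, then y
y_then_x = rot_x(quarter) @ rot_y(quarter)  # apply y first, then x

different_axes_commute = np.allclose(x_then_y, y_then_x)
same_axis_commutes = np.allclose(rot_x(0.3) @ rot_x(1.1), rot_x(1.1) @ rot_x(0.3))
print(different_axes_commute, same_axis_commutes)
```

<p>Rotations about a single axis commute with each other, but rotations about different axes do not, which is why the per-axis subgroups cannot be the factors of a direct product.</p>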
<p>Rotations can still be disentangled from other independent symmetries. For an object that can rotate and change color, the relevant group $G = SO(3) \times G_c$ is a valid direct product, so rotation and color form two disentangled subspaces (even though the rotation subspace is itself multi-dimensional and internally entangled).</p>
<h2 id="resolving-disagreements-and-defining-the-path-forward">Resolving Disagreements and Defining the Path Forward</h2>
<h3 id="backward-compatibility-with-existing-intuitions">Backward Compatibility with Existing Intuitions</h3>
<p>The paper evaluates its definition against three established dimensions of disentanglement:</p>
<p><strong>Modularity</strong> (each latent dimension encodes at most one factor): Satisfied by the new definition, with &ldquo;data generative factors&rdquo; replaced by &ldquo;disentangled actions of the symmetry group.&rdquo; The $SO(3)$ case shows where the new definition disagrees with naive intuition, correctly identifying that non-commuting factors cannot be disentangled.</p>
<p><strong>Compactness</strong> (each factor encoded by a single dimension): The new definition allows multi-dimensional subspaces, siding with approaches that permit distributed representations of individual factors. The dimensionality of each subspace is determined by the structure of the corresponding group representation.</p>
<p><strong>Explicitness</strong> (factors linearly decodable): The general definition does not require linearity. Linear disentangled representations are a strictly stronger condition, and the paper provides a separate formal definition for this case.</p>
<h3 id="key-consequences">Key Consequences</h3>
<p>The definition is relative to a particular decomposition of the symmetry group into subgroups. This has two implications. First, the same group may admit multiple decompositions, and different decompositions yield different disentangled representations (potentially useful for different downstream tasks). Second, identifying the &ldquo;natural&rdquo; decomposition is a separate problem that the authors leave to future work, suggesting that active perception and causal interventions may play a role.</p>
<p>The paper connects to Locatello et al. (2018), who proved that unsupervised learning of disentangled representations is impossible without inductive biases. The symmetry-based framework suggests that such biases could come from an agent&rsquo;s ability to interact with the world and discover which aspects remain invariant under various transformations.</p>
<h3 id="limitations">Limitations</h3>
<p>The paper explicitly focuses on defining disentanglement rather than solving the learning problem. It assumes that the symmetry group decomposes as a direct product of subgroups and that a useful decomposition is known. The authors acknowledge that relaxing these assumptions (e.g., discovering useful decompositions automatically) is important future work. The worked examples use toy environments, and bridging the gap to realistic data remains an open challenge.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a purely theoretical paper. The only empirical element is a qualitative demonstration using a CCI-VAE model on a grid world environment, where an object translates on a grid with wraparound and changes color through discrete steps on a circular hue axis.</p>
<h3 id="algorithms">Algorithms</h3>
<p>No new algorithms are proposed. The CCI-VAE model from Burgess et al. (2018) is used for the grid world demonstration. The paper&rsquo;s contribution is a set of formal definitions, not an algorithmic procedure.</p>
<h3 id="evaluation">Evaluation</h3>
<p>No quantitative evaluation is performed. The paper discusses how existing disentanglement metrics relate to the proposed definition, noting that they each capture different subsets of the three dimensions (modularity, compactness, explicitness) and that the formal definition provides a principled way to evaluate their relative merits.</p>
<h3 id="reproducibility-status-closed">Reproducibility Status: Closed</h3>
<p>This is a theory paper whose primary contribution is a set of formal definitions. The theoretical content (definitions, proofs, worked examples) is self-contained in the paper. No code, data, or models are released. The CCI-VAE demonstration uses a model from Burgess et al. (2018), but no implementation or training details specific to the grid world experiment are provided.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Higgins, I., Amos, D., Pfau, D., Racanière, S., Matthey, L., Rezende, D., &amp; Lerchner, A. (2018). Towards a Definition of Disentangled Representations. <em>arXiv preprint arXiv:1812.02230</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{higgins2018towards,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Towards a Definition of Disentangled Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Higgins, Irina and Amos, David and Pfau, David and Racani\`{e}re, S\&#39;{e}bastien and Matthey, Lo\&#34;{i}c and Rezende, Danilo and Lerchner, Alexander}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1812.02230}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Conformation Autoencoder for 3D Molecules</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/conformation-autoencoder/</link><pubDate>Sat, 11 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/conformation-autoencoder/</guid><description>An autoencoder that maps 3D molecular conformations to a continuous latent space using internal coordinates and graph attention networks.</description><content:encoded><![CDATA[<h2 id="a-method-for-learning-conformation-embeddings">A Method for Learning Conformation Embeddings</h2>
<p>This is a <strong>Method</strong> paper that introduces an autoencoder architecture for molecular conformations. The model converts the discrete 3D spatial arrangement of atoms (a conformation) in a given molecular graph into a continuous, fixed-size latent representation and back. The approach uses <a href="https://en.wikipedia.org/wiki/Z-matrix_(chemistry)">internal coordinates</a> (bond lengths, bond angles, dihedral angles) as input rather than Cartesian coordinates, making the representation inherently invariant to rigid translations and rotations.</p>
<h2 id="why-3d-structure-matters-for-molecular-modeling">Why 3D Structure Matters for Molecular Modeling</h2>
<p>Most deep learning methods for molecules operate on 2D representations: molecular graphs (atoms as nodes, bonds as edges) or <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. These representations capture connectivity and atom types but do not encode the 3D spatial arrangement of atoms. Many important molecular properties, such as the ability to fit inside a protein binding pocket or the shape-dependent pharmacological effect, depend on the molecule&rsquo;s possible energetically stable spatial arrangements (conformations).</p>
<p>Prior work has addressed either property prediction from fixed conformations (SchNet, Schütt et al., 2018) or conformation generation for a given molecular graph (Mansimov et al., 2019; Simm and Hernández-Lobato, 2019). This paper addresses a different gap: learning a continuous, fixed-size embedding of a conformation that is independent of molecule size and atom ordering, enabling both reconstruction and generation.</p>
<h2 id="internal-coordinates-and-set-based-encoding">Internal Coordinates and Set-Based Encoding</h2>
<p>The core innovation is a two-part architecture: a conformation-independent graph neural network and a conformation-dependent encoder/decoder that operates on internal coordinates.</p>
<h3 id="internal-coordinate-representation">Internal Coordinate Representation</h3>
<p>Instead of Cartesian coordinates, conformations are represented as a set of internal coordinates:</p>
<p>$$
\Xi = (\mathcal{D}, \Phi, \Psi)
$$</p>
<p>where $\mathcal{D} = \{d_1, \ldots, d_{N_\mathcal{D}}\}$ are bond lengths, $\Phi = \{\phi_1, \ldots, \phi_{N_\Phi}\}$ are bond angles, and $\Psi = \{\psi_1, \ldots, \psi_{N_\Psi}\}$ are dihedral angles. This representation is invariant to rotations and rigid translations and can always be converted to and from Cartesian coordinates.</p>
<h3 id="molecular-graph-encoder">Molecular Graph Encoder</h3>
<p>A Graph Neural Network extracts conformation-independent node embeddings from the molecular graph. The molecular graph $\mathcal{G} = (\mathcal{V}, \mathcal{E})$ uses node features $v_i \in \mathbb{R}^{F_v}$ encoding atom properties (element type, charge) and edge features $\mathbf{e}_{i,j} \in \mathbb{R}^{F_e}$ encoding bond type (single, double, triple, or aromatic). The architecture combines an edge-conditioned convolution (EConv) layer to encode bond-type information with multiple Graph Attention Network (GAT) layers:</p>
<p>$$
\mathbf{h}_i^l = \mathbf{GAT}^{l-1} \circ \cdots \circ \mathbf{GAT}^1 \circ \text{EConv}(\mathbf{h}_i^0)
$$</p>
<p>where $\mathbf{h}_i^0 = v_i \in \mathbb{R}^{F_v}$ are the initial atom features. The GAT attention coefficients are:</p>
<p>$$
\alpha_{i,j} = \frac{\exp\left(\sigma\left(\mathbf{a}^T [\boldsymbol{\Theta}\mathbf{h}_i \Vert \boldsymbol{\Theta}\mathbf{h}_j]\right)\right)}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp\left(\sigma\left(\mathbf{a}^T [\boldsymbol{\Theta}\mathbf{h}_i \Vert \boldsymbol{\Theta}\mathbf{h}_k]\right)\right)}
$$</p>
<p>Each GAT layer updates node embeddings using the attention weights:</p>
<p>$$
\mathbf{h}_i^{\prime} = \alpha_{i,i}\boldsymbol{\Theta}\mathbf{h}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\boldsymbol{\Theta}\mathbf{h}_j
$$</p>
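<p>The attention coefficients and the node update can be sketched in a few lines of NumPy. This is a single-head illustration under stated assumptions: $\sigma$ is taken to be a LeakyReLU with slope 0.2 (the usual GAT choice, not confirmed by this note), features are stored as row vectors, and the toy graph, dimensions, and function names are hypothetical.</p>

```python
import numpy as np


def leaky_relu(x, slope=0.2):
    """LeakyReLU written with a maximum; valid for slopes between 0 and 1."""
    return np.maximum(x, slope * x)


def gat_layer(h, neighbors, theta, a):
    """One single-head GAT update.

    h: (num_nodes, F_in) node features
    neighbors: dict mapping node index to a list of neighbour indices
    theta: (F_in, F_out) weight matrix
    a: (2 * F_out,) attention vector
    """
    z = h @ theta                      # transformed features for every node
    out = np.zeros_like(z)
    alphas = {}
    for i, nbrs in neighbors.items():
        idx = [i] + list(nbrs)         # neighbourhood plus the self-loop
        scores = np.array(
            [leaky_relu(a @ np.concatenate([z[i], z[k]])) for k in idx]
        )
        weights = np.exp(scores - scores.max())  # numerically stable softmax
        weights /= weights.sum()
        alphas[i] = dict(zip(idx, weights))
        out[i] = weights @ z[idx]      # attention-weighted sum of neighbours
    return out, alphas


rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))            # 3 atoms, 4 input features
theta = rng.normal(size=(4, 2))
a = rng.normal(size=4)
neighbors = {0: [1], 1: [0, 2], 2: [1]}  # a three-atom chain

h_new, alphas = gat_layer(h, neighbors, theta, a)
print(h_new.shape)
```

<p>The softmax runs over each node&rsquo;s neighbourhood including the node itself, which matches the $\mathcal{N}(i) \cup \{i\}$ normalisation in the formula.</p>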
<p>The EConv layer incorporates edge (bond-type) information via a learned filter:</p>
<p>$$
\mathbf{h}_i^{\prime} = \boldsymbol{\Theta}\mathbf{h}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j \cdot \mathrm{f}_{\boldsymbol{\Theta}}(\mathbf{e}_{i,j})
$$</p>
<p>where $\mathrm{f}_{\boldsymbol{\Theta}}$ is a multi-layer perceptron.</p>
<h3 id="permutation-invariant-conformation-encoder">Permutation-Invariant Conformation Encoder</h3>
<p>The conformation encoder uses a Deep Sets-style architecture (Zaheer et al., 2017) to achieve permutation invariance. Three separate neural networks encode each type of internal coordinate, conditioned on the corresponding node embeddings:</p>
<p>$$
z_\Xi = \frac{1}{N_\mathcal{D} + N_\Phi + N_\Psi} \left(\sum_{d \in \mathcal{D}} \rho_\Theta^{(\mathcal{D})}(\mathcal{H}, d) + \sum_{\phi \in \Phi} \rho_\Theta^{(\Phi)}(\mathcal{H}, \phi) + \sum_{\psi \in \Psi} \rho_\Theta^{(\Psi)}(\mathcal{H}, \psi)\right)
$$</p>
<p>Each encoding function $\rho_\Theta$ takes both the internal coordinate value and the node embeddings of the involved atoms as input. The resulting conformation embedding $z_\Xi \in \mathbb{R}^{F_z}$ has a fixed dimensionality regardless of molecule size.</p>
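<p>The mean-pooling construction is what makes the embedding independent of coordinate ordering. The sketch below illustrates this with NumPy; the per-coordinate encoders are random single-layer stand-ins (hypothetical, not the trained networks from the paper), and all sizes and coordinate values are made up.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
F_h, F_z = 4, 3  # hypothetical node-embedding and latent sizes


def make_rho(num_atoms):
    """Stand-in for a learned encoder: one tanh layer over the coordinate
    value concatenated with the embeddings of the atoms it involves."""
    W = rng.normal(size=(1 + num_atoms * F_h, F_z))
    return lambda value, atom_embs: np.tanh(
        np.concatenate([[value], atom_embs.ravel()]) @ W
    )


# Bonds involve 2 atoms, bond angles 3, dihedral angles 4.
rho_d, rho_phi, rho_psi = make_rho(2), make_rho(3), make_rho(4)


def encode(H, bonds, angles, dihedrals):
    """Each item is (atom index tuple, coordinate value). Returns z of size F_z."""
    terms = [rho_d(v, H[list(atoms)]) for atoms, v in bonds]
    terms += [rho_phi(v, H[list(atoms)]) for atoms, v in angles]
    terms += [rho_psi(v, H[list(atoms)]) for atoms, v in dihedrals]
    return np.mean(terms, axis=0)  # average over all internal coordinates


H = rng.normal(size=(4, F_h))  # node embeddings for a four-atom chain
bonds = [((0, 1), 1.5), ((1, 2), 1.4), ((2, 3), 1.5)]
angles = [((0, 1, 2), 1.9), ((1, 2, 3), 1.9)]
dihedrals = [((0, 1, 2, 3), 3.0)]

z = encode(H, bonds, angles, dihedrals)
z_shuffled = encode(H, bonds[::-1], angles[::-1], dihedrals)
print(z.shape, np.allclose(z, z_shuffled))
```

<p>Reordering the coordinate sets leaves the embedding unchanged, and the output size depends only on <code>F_z</code>, not on the number of atoms or coordinates.</p>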
<h3 id="conformation-decoder-and-loss">Conformation Decoder and Loss</h3>
<p>Three decoder networks $\delta_\Theta^{(\mathcal{D})}$, $\delta_\Theta^{(\Phi)}$, and $\delta_\Theta^{(\Psi)}$ reconstruct internal coordinates from the conformation embedding, conditioned on the node embeddings. The reconstruction loss is:</p>
<p>$$
\mathcal{C}_\Xi = \frac{1}{N_\mathcal{D}} \sum_{d \in \mathcal{D}} \lVert d - \hat{d} \rVert_2^2 + \frac{1}{N_\Phi} \sum_{\phi \in \Phi} \lVert \phi - \hat{\phi} \rVert_2^2 + \frac{1}{N_\Psi} \sum_{\psi \in \Psi} \min\left(\lVert \psi - \hat{\psi} \rVert_2^2, 2\pi - \lVert \psi - \hat{\psi} \rVert_2^2\right)
$$</p>
<p>The dihedral angle loss uses a periodic distance to account for angular periodicity. The model can be extended to a <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoder (VAE)</a> by applying the reparameterization trick from Kingma and Welling (2013).</p>
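<p>To see why a periodic distance is needed for dihedral angles, consider two angles on either side of the $0 / 2\pi$ boundary. The sketch below uses the plain (unsquared) wrapped difference, which is an assumption about the intent of the dihedral term and not a transcription of the formula above; the function name is hypothetical.</p>

```python
import numpy as np


def periodic_abs_diff(psi, psi_hat):
    """Shortest angular distance between two angles in radians, in [0, pi]."""
    diff = np.abs(psi - psi_hat) % (2 * np.pi)
    return np.minimum(diff, 2 * np.pi - diff)


psi = 0.1                    # just above 0
psi_hat = 2 * np.pi - 0.1    # just below 2*pi, geometrically very close to psi

naive = abs(psi - psi_hat)              # treats the angles as far apart
wrapped = periodic_abs_diff(psi, psi_hat)  # recognises they are 0.2 rad apart
print(naive, wrapped)
```

<p>A non-periodic loss would penalise this nearly perfect reconstruction as if it were almost a full turn off, which is the failure mode the periodic term avoids.</p>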
<h2 id="conformer-generation-and-spatial-optimization-experiments">Conformer Generation and Spatial Optimization Experiments</h2>
<h3 id="dataset-and-training">Dataset and Training</h3>
<p>The model was trained on the PubChem3D dataset (Bolton et al., 2011), which contains organic molecules with up to 50 heavy atoms, each with multiple conformations generated by the OMEGA conformer generation software (Hawkins et al., 2010).</p>
<h3 id="reconstruction-quality">Reconstruction Quality</h3>
<p>Upon convergence, the model reconstructs conformations with low RMSD to the input. The median energetic difference between input and reconstructed conformations is approximately 80 kcal/mol (evaluated using the <a href="https://en.wikipedia.org/wiki/Merck_molecular_force_field">MMFF94</a> forcefield via <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>), corresponding to small deviations from local minima without atom clashes.</p>
<h3 id="latent-space-structure">Latent Space Structure</h3>
<p>The learned latent space exhibits meaningful clustering: similar conformations map to nearby points, while distinct conformations separate. Principal component analysis of 200 conformations of a small molecule reveals clear conformational clusters in the first two principal components.</p>
<h3 id="conformer-generation-via-vae">Conformer Generation via VAE</h3>
<p>The variational autoencoder variant can sample diverse conformers from the learned distribution. Comparing the average inter-conformer RMSD (icRMSD) of 200 sampled conformers per molecule against the ETKDG algorithm (Riniker and Landrum, 2015) as implemented in RDKit, the model achieves comparable diversity: its average icRMSD is slightly higher, by 0.07 Angstrom.</p>
<h3 id="multi-objective-molecular-optimization">Multi-Objective Molecular Optimization</h3>
<p>By combining the conformation embedding with a continuous molecular structure embedding (<a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a>, Winter et al., 2019), the model enables joint optimization over both molecular graph and conformation. Using <a href="https://en.wikipedia.org/wiki/Particle_swarm_optimization">particle swarm optimization</a> (Kennedy and Eberhart, 1995) to maximize QED (drug-likeness, values between 0 and 1) and asphericity (deviation from spherical shape, values between 0 and 1), starting from aspirin (combined score 0.76), the method finds molecules with a combined score of 1.82 after 50 iterations.</p>
<h2 id="compact-conformation-encoding-with-practical-applications">Compact Conformation Encoding with Practical Applications</h2>
<p>The conformation autoencoder produces fixed-size latent representations of molecular 3D structures whose dimensionality is independent of molecule size and which are invariant to atom ordering and rigid transformations. The key findings are:</p>
<ol>
<li><strong>Meaningful latent space</strong>: Conformational similarity is preserved in the embedding space, enabling clustering and interpolation.</li>
<li><strong>Diverse conformer generation</strong>: The VAE variant generates conformer ensembles with diversity comparable to the established ETKDG method in RDKit.</li>
<li><strong>Joint optimization</strong>: Combining conformation and structure embeddings enables multi-objective optimization over both molecular graph and spatial arrangement.</li>
</ol>
<p>Limitations include the narrow energy evaluation (MMFF94 only, with no comparison against quantum mechanical energies) and the proof-of-concept nature of the spatial optimization experiments. The approach also relies on the quality of the internal coordinate representation, which may lose information about ring conformations and other constrained geometries.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>PubChem3D</td>
          <td>Multiple conformations per molecule</td>
          <td>Organic molecules, up to 50 heavy atoms</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>PubChem3D holdout</td>
          <td>Subset</td>
          <td>Same distribution as training</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Graph Neural Network: EConv + multiple GAT layers</li>
<li>Conformation encoder: Deep Sets architecture with three coordinate-specific encoders</li>
<li>VAE: Reparameterization trick for probabilistic sampling</li>
<li>Optimization: Particle Swarm Optimization for multi-objective design</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Conformation-independent: EConv + GAT layers for node embeddings</li>
<li>Conformation-dependent: One feed-forward encoder and one feed-forward decoder per coordinate type (bond lengths, bond angles, dihedral angles)</li>
<li>Latent dimension $F_z$ is fixed (exact value not specified in the workshop paper)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Median energy difference</td>
          <td>~80 kcal/mol</td>
          <td>Input conformations</td>
          <td>MMFF94 forcefield</td>
      </tr>
      <tr>
          <td>icRMSD difference vs ETKDG</td>
          <td>+0.07 Angstrom</td>
          <td>ETKDG (RDKit)</td>
          <td>200 conformers per molecule</td>
      </tr>
      <tr>
          <td>Combined QED+asphericity</td>
          <td>1.82</td>
          <td>0.76 (aspirin)</td>
          <td>After 50 optimization iterations</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware details are not specified in the workshop paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem3D</a></td>
          <td>Dataset</td>
          <td>Public domain</td>
          <td>NIH public database; conformations generated by OMEGA (Hawkins et al., 2010)</td>
      </tr>
      <tr>
          <td><a href="https://arxiv.org/abs/2101.01618">arXiv preprint</a></td>
          <td>Paper</td>
          <td>arXiv license</td>
          <td>6-page workshop paper, open access</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status: Partially Reproducible.</strong> The training dataset (PubChem3D) is publicly available, and the architecture is described in sufficient detail for reimplementation. No source code, pre-trained weights, or exact hyperparameters (latent dimension $F_z$, learning rate, number of GAT layers) are released. The workshop paper format (6 pages) limits the level of experimental detail provided.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Winter, R., Noé, F., &amp; Clevert, D.-A. (2020). Auto-Encoding Molecular Conformations. <em>Machine Learning for Molecules Workshop, NeurIPS 2020</em>.</p>
<p><strong>Publication</strong>: Machine Learning for Molecules Workshop at NeurIPS 2020</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{winter2021auto,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Molecular Conformations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Winter, Robin and No\&#39;{e}, Frank and Clevert, Djork-Arn\&#39;{e}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2101.01618}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>T5: Exploring Transfer Learning Limits</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</guid><description>Raffel et al. systematically study transfer learning for NLP with a text-to-text framework, ablating architectures, objectives, data, and multi-task mixing.</description><content:encoded><![CDATA[<h2 id="a-systematic-study-of-nlp-transfer-learning">A systematic study of NLP transfer learning</h2>
<p>This is a <strong>systematization paper</strong> that provides a comprehensive empirical survey of transfer learning techniques for NLP. Rather than proposing a single new method, T5 introduces a unified text-to-text framework and uses it as a testbed to systematically compare pre-training objectives, architectures, unlabeled data sources, transfer approaches, and multi-task mixing strategies. The scale of the ablation study (covering dozens of configurations) and the release of C4, pre-trained models, and code make it both a reference guide and a resource.</p>
<h2 id="unifying-nlp-tasks-as-text-to-text">Unifying NLP tasks as text-to-text</h2>
<p>The core design decision is to cast every NLP task as a text-to-text problem: both the input and output are text strings, with a task-specific prefix. Classification, regression, summarization, translation, and question answering all use the same model, loss function (cross-entropy on output tokens), and decoding procedure. This simplicity enables fair comparison across tasks and training strategies.</p>
<p>The model architecture is a standard encoder-decoder Transformer. The paper finds that this form outperforms decoder-only language model and prefix language model variants in the text-to-text setting, while having similar computational cost to a decoder-only model despite roughly twice the parameters (the encoder is applied only to the input sequence and the decoder only to the output sequence).</p>
<h2 id="multi-task-mixing-strategies-and-findings">Multi-task mixing: strategies and findings</h2>
<p>The most thesis-relevant contribution is the systematic ablation of multi-task mixing strategies (Section 3.5.2). When training on multiple tasks simultaneously (which in the text-to-text framework simply means mixing data from different sources), the central question is how to set the proportion of data from each task.</p>
<h3 id="three-mixing-strategies">Three mixing strategies</h3>
<p><strong>Examples-proportional mixing.</strong> Sample in proportion to each dataset&rsquo;s size, with an artificial cap $K$ on the maximum dataset size. Without the cap, the unsupervised pre-training data (orders of magnitude larger) would dominate all batches. The mixing rate for task $m$ is:</p>
<p>$$
r_{m} = \frac{\min(e_{m}, K)}{\sum_{n} \min(e_{n}, K)}
$$</p>
<p>where $e_{m}$ is the number of examples in task $m$&rsquo;s dataset.</p>
<p><strong>Temperature-scaled mixing.</strong> Raise each mixing rate $r_{m}$ to the power $1/T$ and renormalize so the rates sum to 1. At $T=1$ this equals examples-proportional mixing; as $T$ increases, proportions approach equal mixing. The rates $r_{m}$ are still computed with a cap, but the paper sets it to a large value, $K = 2^{21}$.</p>
<p><strong>Equal mixing.</strong> Sample uniformly from all tasks. Included as a negative reference: the model overfits on low-resource tasks and underfits on high-resource tasks.</p>
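<p>The three strategies differ only in how mixing rates are computed from dataset sizes. A minimal sketch in plain Python follows; the function names and the dataset sizes are hypothetical and chosen only to show the effect of the cap and the temperature.</p>

```python
def examples_proportional(sizes, K):
    """r_m = min(e_m, K) / sum_n min(e_n, K)."""
    capped = {m: min(e, K) for m, e in sizes.items()}
    total = sum(capped.values())
    return {m: c / total for m, c in capped.items()}


def temperature_scaled(sizes, K, T):
    """Raise each capped rate to the power 1/T, then renormalise."""
    rates = examples_proportional(sizes, K)
    scaled = {m: r ** (1.0 / T) for m, r in rates.items()}
    total = sum(scaled.values())
    return {m: s / total for m, s in scaled.items()}


def equal_mixing(sizes):
    """Sample every task with the same probability."""
    return {m: 1.0 / len(sizes) for m in sizes}


# Hypothetical dataset sizes (number of examples per task).
sizes = {"small": 1_000, "medium": 100_000, "huge": 10_000_000}

proportional = examples_proportional(sizes, K=2 ** 18)
tempered = temperature_scaled(sizes, K=2 ** 21, T=2)
uniform = equal_mixing(sizes)

for name, rates in [("proportional", proportional), ("T=2", tempered), ("equal", uniform)]:
    print(name, {m: round(r, 4) for m, r in rates.items()})
```

<p>Raising $T$ shifts probability mass from the largest dataset toward the smallest, interpolating between the capped proportional rates and equal mixing.</p>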
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Mixing strategy</th>
          <th>GLUE</th>
          <th>CNN/DM</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
          <th>EnDe</th>
          <th>EnFr</th>
          <th>EnRo</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Baseline (pre-train/fine-tune)</td>
          <td>83.28</td>
          <td>19.24</td>
          <td>80.88</td>
          <td>71.36</td>
          <td>26.98</td>
          <td>39.82</td>
          <td>27.65</td>
      </tr>
      <tr>
          <td>Equal</td>
          <td>76.13</td>
          <td>19.02</td>
          <td>76.51</td>
          <td>63.37</td>
          <td>23.89</td>
          <td>34.31</td>
          <td>26.78</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{18}$</td>
          <td>81.67</td>
          <td>19.07</td>
          <td>78.17</td>
          <td>67.94</td>
          <td>24.57</td>
          <td>35.19</td>
          <td>27.39</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{19}$</td>
          <td>81.42</td>
          <td>19.24</td>
          <td>79.78</td>
          <td>67.30</td>
          <td>25.21</td>
          <td>36.30</td>
          <td>27.76</td>
      </tr>
      <tr>
          <td>Temperature-scaled, $T=2$</td>
          <td>81.90</td>
          <td>19.28</td>
          <td>79.42</td>
          <td>69.92</td>
          <td>25.42</td>
          <td>36.72</td>
          <td>27.20</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings on mixing:</strong></p>
<ol>
<li>
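<p><strong>Multi-task training underperforms pre-train-then-fine-tune on most tasks.</strong> No mixing strategy matches the baseline of unsupervised pre-training followed by task-specific fine-tuning across the board; the only cases in the table where a mixture edges ahead (CNN/DM at $T=2$, EnRo at $K=2^{19}$) are marginal.</p>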
</li>
<li>
<p><strong>Equal mixing is worst.</strong> It dramatically degrades performance, confirming that proportions matter.</p>
</li>
<li>
<p><strong>There exists a task-specific sweet spot for the cap $K$.</strong> Most tasks have an optimal $K$ value; larger or smaller values hurt. The exception is very high-resource tasks (WMT English-French) that always benefit from higher mixing proportions.</p>
</li>
<li>
<p><strong>Temperature scaling at $T=2$ provides the best single compromise.</strong> It achieves reasonable performance across all tasks without requiring per-task tuning of $K$.</p>
</li>
<li>
<p><strong>Multi-task pre-training followed by fine-tuning closes the gap.</strong> When multi-task training is used as pre-training (not as the final training stage), followed by task-specific fine-tuning, performance becomes comparable to the baseline of unsupervised pre-training followed by fine-tuning. This suggests that multi-task exposure during pre-training provides useful early signal without the negative effects of forcing a single model to perform all tasks simultaneously.</p>
</li>
<li>
<p><strong>&ldquo;Leave-one-out&rdquo; training works.</strong> Pre-training on a multi-task mixture that excludes a target task, then fine-tuning on it, produces only slightly worse results. This indicates that multi-task pre-training builds general capabilities that transfer to unseen tasks without dramatic task interference.</p>
</li>
</ol>
<h2 id="data-repetition-degrades-performance">Data repetition degrades performance</h2>
<p>The paper also systematically tests the effect of pre-training data set size by truncating C4 and training over repeated data:</p>
<table>
  <thead>
      <tr>
          <th>Unique tokens</th>
          <th>Repeats</th>
          <th>GLUE</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Full dataset</td>
          <td>0</td>
          <td>83.28</td>
          <td>80.88</td>
          <td>71.36</td>
      </tr>
      <tr>
          <td>$2^{29}$</td>
          <td>64</td>
          <td>82.87</td>
          <td>80.97</td>
          <td>72.03</td>
      </tr>
      <tr>
          <td>$2^{27}$</td>
          <td>256</td>
          <td>82.62</td>
          <td>79.78</td>
          <td>69.97</td>
      </tr>
      <tr>
          <td>$2^{25}$</td>
          <td>1,024</td>
          <td>79.55</td>
          <td>76.27</td>
          <td>64.76</td>
      </tr>
      <tr>
          <td>$2^{23}$</td>
          <td>4,096</td>
          <td>76.34</td>
          <td>70.92</td>
          <td>59.29</td>
      </tr>
  </tbody>
</table>
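<p>Performance degrades as the unique data shrinks. At 64 repeats the effect is negligible (SQuAD and SuperGLUE are even marginally above the full-dataset baseline), at 256 repeats SQuAD and SuperGLUE drop by roughly a point, and at 1,024 or more repeats all three benchmarks degrade significantly. Training loss curves confirm memorization at high repetition counts. The paper recommends using large, diverse pre-training datasets whenever possible.</p>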
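<p>One detail worth making explicit is that every truncated row in the table corresponds to the same total number of training tokens, so only the amount of unique data varies. A short check using the table's own values:</p>

```python
# (log2 of unique tokens, number of repeats) for each truncated-C4 row of the table.
rows = [(29, 64), (27, 256), (25, 1024), (23, 4096)]

# Total tokens seen during pre-training = unique tokens x repeats.
totals = [(2 ** exponent) * repeats for exponent, repeats in rows]

for (exponent, repeats), total in zip(rows, totals):
    print(f"2^{exponent} unique tokens x {repeats:>5} repeats = {total:,} tokens")
```

<p>Each product equals $2^{35}$ (about 34B tokens), so the comparison isolates the effect of repetition from the effect of training length.</p>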
<h2 id="scaling-and-final-configuration">Scaling and final configuration</h2>
<p>The paper compares scaling strategies: more data, larger models, and ensembles. Training a larger model for fewer steps generally outperforms training a smaller model on more data. Ensembles of independently pre-trained and fine-tuned models provide orthogonal gains.</p>
<p>The final T5-11B model combines the best choices from all ablations: encoder-decoder architecture, span corruption objective, C4 pre-training data, multi-task pre-training followed by fine-tuning, and scaling to 11B parameters trained on over 1 trillion tokens. It achieves state-of-the-art results on GLUE (90.3 average), SuperGLUE (88.9, near human performance of 89.8), SQuAD, and CNN/Daily Mail. It does not achieve state-of-the-art on WMT translation tasks, where methods using backtranslation and cross-lingual pre-training retain the lead.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The T5 paper&rsquo;s multi-task mixing findings are its most enduring contribution beyond the model itself. The core lessons: proportions matter enormously (equal mixing fails), examples-proportional mixing with a cap is a reasonable default, temperature scaling provides a single-knob alternative, and multi-task pre-training followed by fine-tuning can match pure unsupervised pre-training.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>All ablations use the same encoder-decoder architecture. Findings may not transfer to decoder-only models that dominate current practice.</li>
<li>The multi-task mixing experiments treat each task as a separate &ldquo;domain.&rdquo; Interactions between similar tasks (e.g., multiple classification tasks) are not isolated.</li>
<li>The paper does not provide a principled method for choosing $K$ or $T$; both require empirical search.</li>
<li>C4 has known quality issues (templated text, noisy content) that have been addressed in later datasets.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, pre-trained models, and the C4 dataset are all publicly released.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>C4 (Colossal Clean Crawled Corpus)</td>
          <td>~750 GB</td>
          <td>Heuristically cleaned Common Crawl</td>
      </tr>
      <tr>
          <td>Downstream</td>
          <td>GLUE, SuperGLUE, SQuAD, CNN/DM, WMT (EnDe, EnFr, EnRo)</td>
          <td>Standard splits</td>
          <td>Text-to-text format</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>Encoder-decoder Transformer. Sizes: Small (60M), Base (220M), Large (770M), 3B, 11B. The baseline uses the Base size. SentencePiece vocabulary with 32K tokens. Pre-trained for $2^{19}$ steps, fine-tuned for $2^{18}$ steps on individual tasks.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Multi-task mixing: examples-proportional with cap $K \in \{2^{16}, \ldots, 2^{21}\}$, temperature-scaled with $T \in \{2, 4, 8\}$, and equal mixing. Unsupervised objective: span corruption (mean span length 3, 15% corruption rate). Training uses the Adafactor optimizer with an inverse square root learning rate schedule.</p>
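<p>For reference, an inverse square root schedule holds the learning rate constant during warm-up and then decays it as $1/\sqrt{n}$. The sketch below assumes the form $1/\sqrt{\max(n, k)}$ with $k = 10^{4}$ warm-up steps; the warm-up value is not stated in this note and should be treated as an assumption.</p>

```python
import math


def inverse_sqrt_lr(step, warmup_steps=10_000):
    """Learning rate 1 / sqrt(max(step, warmup_steps)).

    warmup_steps=10_000 is an assumed default, not a value taken from this note.
    """
    return 1.0 / math.sqrt(max(step, warmup_steps))


# Constant during warm-up, then decaying with the square root of the step count.
for step in (1, 10_000, 40_000, 2 ** 19):
    print(f"step {step:>7}: lr = {inverse_sqrt_lr(step):.5f}")
```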
<h3 id="hardware">Hardware</h3>
<p>All models trained using Mesh TensorFlow on TPU slices. T5-11B pre-trained for 1M steps with batch size $2^{11}$ sequences of length 512 (~1 trillion tokens total). Exact TPU pod configurations per experiment not detailed.</p>
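<p>As a quick check on the token count quoted above (a back-of-the-envelope calculation, reading &ldquo;1M steps&rdquo; as $10^{6}$):</p>
<p>$$
10^{6} \text{ steps} \times 2^{11} \text{ sequences} \times 512 \text{ tokens} \approx 1.05 \times 10^{12} \text{ tokens}
$$</p>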
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer">T5 Code</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official TensorFlow implementation (JAX successor: T5X)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints">T5 Models</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained checkpoints (Small through 11B)</td>
      </tr>
      <tr>
          <td><a href="https://www.tensorflow.org/datasets/catalog/c4">C4 Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>~750 GB cleaned Common Crawl, via TensorFlow Datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{raffel2020exploring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{21}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{140}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Scaling Data-Constrained Language Models</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</guid><description>Muennighoff et al. extend Chinchilla scaling laws to repeated data, finding up to 4 epochs cause negligible loss and 16 epochs mark diminishing returns.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-scaling-under-data-constraints">An empirical study of scaling under data constraints</h2>
<p>This is a <strong>discovery paper</strong> that systematically investigates what happens when language models are trained for multiple epochs on repeated data. It extends the Chinchilla scaling laws to the data-constrained regime by proposing a new scaling formula that accounts for the diminishing value of repeated tokens, validated across 400+ training runs ranging from 10M to 9B parameters and up to 1500 epochs.</p>
<h2 id="running-out-of-unique-training-data">Running out of unique training data</h2>
<p>The Chinchilla scaling laws assume unlimited unique data: for a given compute budget, there exists an optimal balance of model parameters and training tokens. But extrapolating these laws to larger models implies data requirements that exceed what is available. Villalobos et al. estimated that high-quality English text would be exhausted by 2024 under Chinchilla-optimal scaling. Most prior large language models trained for a single epoch, and some work explicitly warned against data reuse. The Galactica models (trained for 4.25 epochs) showed that multi-epoch training could work, but no systematic study had quantified the tradeoff between repeated data and fresh data, or how to allocate compute optimally when data is finite.</p>
<h2 id="effective-data-with-exponential-decay-for-repetition">Effective data with exponential decay for repetition</h2>
<p>The paper generalizes the Chinchilla scaling law by replacing raw token count $D$ with an effective data term $D'$ that accounts for the diminishing value of repeated tokens:</p>
<p>$$
L(N, D) = \frac{A}{N'^{\alpha}} + \frac{B}{D'^{\beta}} + E
$$</p>
<p>where the effective data is:</p>
<p>$$
D' = U_{D} + U_{D} R_{D}^{*} \left(1 - e^{-R_{D}/R_{D}^{*}}\right)
$$</p>
<p>Here $U_{D}$ is the number of unique tokens, $R_{D}$ is the number of repetitions (epochs minus 1), and $R_{D}^{*}$ is a learned constant representing the &ldquo;half-life&rdquo; of data repetition. When $R_{D} = 0$ (single epoch), $D' = U_{D} = D$ and the formula reduces to standard Chinchilla. When $R_{D} \ll R_{D}^{*}$, repeated data is worth almost the same as fresh data. As $R_{D}$ grows large, the value of repeated tokens decays to zero, and $D'$ saturates at $U_{D}(1 + R_{D}^{*})$, meaning no amount of repetition can substitute for more than $R_{D}^{*}$ epochs&rsquo; worth of fresh data.</p>
<p>A symmetric formula handles excess parameters:</p>
<p>$$
N' = U_{N} + U_{N} R_{N}^{*} \left(1 - e^{-R_{N}/R_{N}^{*}}\right)
$$</p>
<p>where $U_{N}$ is the compute-optimal parameter count for $U_{D}$ unique tokens and $R_{N}$ measures how much the model exceeds that count. The fitted values are $R_{D}^{*} \approx 15.0$ (data repetition half-life at ~16 epochs) and $R_{N}^{*} \approx 5.3$ (excess parameters decay faster than repeated data).</p>
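<p>The effective-data formula is easy to evaluate directly. The following is a minimal sketch that plugs the fitted $R_{D}^{*} \approx 15.0$ into the expression for the effective data; the function name and the 100M-token budget are illustrative choices, not part of the paper's code.</p>

```python
import math

R_D_STAR = 15.0  # fitted repetition constant quoted in the text


def effective_data(unique_tokens, repetitions, r_star=R_D_STAR):
    """Effective data: U_D + U_D * R* * (1 - exp(-R_D / R*))."""
    decay = 1.0 - math.exp(-repetitions / r_star)
    return unique_tokens + unique_tokens * r_star * decay


unique = 100e6  # illustrative budget of 100M unique tokens
for epochs in (1, 4, 16, 64, 1000):
    seen = epochs * unique
    d_eff = effective_data(unique, repetitions=epochs - 1)
    print(f"{epochs:>5} epochs: {seen / 1e6:>9.0f}M seen, {d_eff / 1e6:>7.1f}M effective")
```

<p>The effective count tracks the raw count closely for the first few epochs and then flattens toward the ceiling of $U_{D}(1 + R_{D}^{*})$, which is 16 times the unique data for this value of the constant.</p>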
<h2 id="experiments-across-400-models">Experiments across 400+ models</h2>
<p><strong>Scale.</strong> Models from 10M to 9B parameters, trained for up to 1500 epochs. Three experimental protocols: fixed unique data (100M, 400M, 1.5B tokens), fixed FLOPs, and parametric fitting across all runs. Training on C4 (English web text) with GPT-2 architecture decoder-only transformers.</p>
<h3 id="resource-allocation-epochs-scale-faster-than-parameters">Resource allocation: epochs scale faster than parameters</h3>
<p>With fixed unique data, results show that more than 50% loss reduction is possible by training beyond one epoch and increasing model size beyond the single-epoch optimum. The data-constrained efficient frontier recommends allocating most additional compute to more epochs rather than more parameters, because excess parameters decay faster ($R_{N}^{*} &lt; R_{D}^{*}$). This contrasts with Chinchilla, which recommends scaling both equally.</p>
<p>A concrete validation: with a budget of $9.3 \times 10^{21}$ FLOPs and 25B unique tokens, the data-constrained recommendation (27% fewer parameters, more epochs) achieves better loss and downstream performance than the Chinchilla-optimal allocation.</p>
<h3 id="resource-return-the-4-epoch-safe-zone-and-16-epoch-half-life">Resource return: the 4-epoch safe zone and 16-epoch half-life</h3>
<table>
  <thead>
      <tr>
          <th>Epochs</th>
          <th>Loss impact</th>
          <th>Downstream impact</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1 (baseline)</td>
          <td>Optimal</td>
          <td>Optimal</td>
      </tr>
      <tr>
          <td>Up to 4</td>
          <td>Negligible (+0.5% loss)</td>
          <td>No significant difference</td>
      </tr>
      <tr>
          <td>~16 ($R_{D}^{*}$)</td>
          <td>Diminishing returns begin sharply</td>
          <td>Measurable degradation</td>
      </tr>
      <tr>
          <td>Beyond 16</td>
          <td>Returns decay to near zero</td>
          <td>Significant degradation</td>
      </tr>
      <tr>
          <td>Extreme (44+)</td>
          <td>Training can diverge</td>
          <td>Failure</td>
      </tr>
  </tbody>
</table>
<p>The 8.7B parameter model trained for 4 epochs ($D_{C} = 44$B unique tokens) finishes with only 0.5% higher validation loss than the single-epoch model ($D_{C} = 178$B unique tokens). At the half-life point ($R_{D} = R_{D}^{*}$, roughly 16 epochs), the repeated tokens seen so far have been worth on average $1 - 1/e \approx 63\%$ of fresh tokens, while the marginal value of one more repeated token has fallen to $1/e \approx 37\%$ and keeps decaying beyond that.</p>
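<p>The effective-data formula implies two different notions of what a repeated token is worth: the average value over all repeats so far, and the marginal value of the next repeat. A minimal sketch of both, derived from the formula with $R_{D}^{*} \approx 15.0$ (the function names are illustrative):</p>

```python
import math

R_D_STAR = 15.0  # fitted repetition constant quoted in the text


def average_repeat_value(repetitions, r_star=R_D_STAR):
    """Effective tokens gained from repeats divided by repeated tokens actually seen."""
    return r_star * (1.0 - math.exp(-repetitions / r_star)) / repetitions


def marginal_repeat_value(repetitions, r_star=R_D_STAR):
    """Derivative of effective data with respect to repeated tokens: exp(-R_D / R*)."""
    return math.exp(-repetitions / r_star)


# Repetitions = epochs - 1, so 4 epochs is 3 repetitions and 16 epochs is 15.
for epochs in (4, 16, 44):
    reps = epochs - 1
    print(f"{epochs:>2} epochs: average={average_repeat_value(reps):.2f}, "
          f"marginal={marginal_repeat_value(reps):.2f}")
```

<p>At 4 epochs both numbers remain high, which is consistent with the negligible loss penalty in the table; by 16 epochs the marginal value has dropped to about a third of a fresh token.</p>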
<h3 id="complementary-strategies-code-augmentation-and-filtering">Complementary strategies: code augmentation and filtering</h3>
<p>When data is limited, two strategies can extend the effective dataset:</p>
<p><strong>Code augmentation.</strong> Mixing Python code from The Stack with natural language data. Up to 50% code (42B tokens) shows no degradation on natural language benchmarks, effectively providing a 2x increase in useful training data. Some tasks (WebNLG generation, bAbI reasoning) actually improve with code, possibly because code trains long-range state-tracking capabilities.</p>
<p><strong>Filtering relaxation.</strong> Perplexity filtering (keeping the 25% lowest-perplexity samples) is effective on noisy datasets, but deduplication filtering does not improve downstream performance (though it may reduce memorization). The recommendation: reserve aggressive filtering for noisy data sources; for clean datasets, more data through reduced filtering is better than less data through strict filtering.</p>
<p><strong>Combined strategy</strong>: doubling available data with code and then repeating for 4 epochs yields 8x more training tokens with performance expected to match 8x more unique data.</p>
<h2 id="key-findings-and-limitations">Key findings and limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Multi-epoch training is beneficial, not harmful, up to moderate repetition counts.</li>
<li>The data-constrained scaling law accurately predicts loss under repetition using an exponential decay formulation.</li>
<li>Compute should be allocated to epochs faster than parameters when data is constrained.</li>
<li>Code augmentation and selective filtering extend effective data without quality degradation.</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>All experiments use the GPT-2 transformer architecture; applicability to other architectures or modalities is untested.</li>
<li>Only the entire dataset is repeated uniformly. Selectively repeating subsets (e.g., high-value data for more epochs) is not modeled.</li>
<li>Hyperparameter sensitivity (learning rate, dropout) to epoch count is unexplored. Higher learning rates may cause earlier onset of diminishing returns.</li>
<li>Focused on English text. Cross-lingual augmentation effects are not studied.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, models, datasets, and hyperparameters are all publicly released under Apache 2.0.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>C4 (English)</td>
          <td>Varies by experiment</td>
          <td>Fixed unique data: 100M, 400M, 1.5B tokens</td>
      </tr>
      <tr>
          <td>Code augmentation</td>
          <td>The Stack (Python)</td>
          <td>Up to 42B tokens</td>
          <td>Mixed with natural language</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>19 NL tasks</td>
          <td>Standard splits</td>
          <td>Zero to five-shot, 114 scores per model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data-constrained scaling law: $D' = U_{D} + U_{D} R_{D}^{*}(1 - e^{-R_{D}/R_{D}^{*}})$ with $R_{D}^{*} \approx 15.0$, $R_{N}^{*} \approx 5.3$. Fitted using the methodology of Hoffmann et al. (2022) adapted for the repetition terms. 400+ training runs used for fitting.</p>
<h3 id="models">Models</h3>
<p>GPT-2 architecture decoder-only transformers with GPT-2 tokenizer. Sizes: 10M to 8.7B parameters. Cosine learning rate schedule (max 2e-4, decay to 2e-5), Adam optimizer ($\beta_2 = 0.999$), dropout 0.1, weight decay 0.1, gradient clipping at 1.0. bfloat16 precision. Trained using Megatron-DeepSpeed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Data-Constrained Optimal</th>
          <th>Chinchilla Optimal</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validation loss (9.3e21 FLOPs, 25B unique)</td>
          <td>Lower</td>
          <td>Higher</td>
          <td>27% fewer parameters</td>
      </tr>
      <tr>
          <td>Downstream (4 epochs vs 1)</td>
          <td>No significant difference</td>
          <td>Baseline</td>
          <td>8.7B params, 44B unique tokens</td>
      </tr>
      <tr>
          <td>Code augmentation (50% code)</td>
          <td>No NL degradation</td>
          <td>Baseline</td>
          <td>Some tasks improve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Trained on the LUMI supercomputer (Finland) using AMD Instinct MI250X GPUs with data, tensor, and pipeline parallelism. Up to 256 GPUs (64 nodes) per run, with up to 2,200 nodes (~8,800 GPUs) used in parallel across all concurrent runs. Total compute: approximately 3 million GPU hours. The cluster runs on 100% renewable hydroelectric energy.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/huggingface/datablations">datablations</a></td>
          <td>Code + Models + Data</td>
          <td>Apache 2.0</td>
          <td>All 400+ models, datasets, and training code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/TurkuNLP/Megatron-DeepSpeed">Megatron-DeepSpeed fork</a></td>
          <td>Code</td>
          <td>-</td>
          <td>Training framework adapted for AMD ROCm</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{muennighoff2023scaling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scaling Data-Constrained Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Muennighoff, Niklas and Rush, Alexander M. and Barak, Boaz and Le Scao, Teven and Piktus, Aleksandra and Tazi, Nouamane and Pyysalo, Sampo and Wolf, Thomas and Raffel, Colin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DoReMi: Optimizing Data Mixtures for LM Pretraining</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</guid><description>DoReMi uses a small proxy model with distributionally robust optimization to learn domain weights that speed up large-scale language model pretraining by 2.6x.</description><content:encoded><![CDATA[<h2 id="a-method-for-automatic-domain-reweighting">A method for automatic domain reweighting</h2>
<p>This is a <strong>method paper</strong> that introduces Domain Reweighting with Minimax Optimization (DoReMi), an algorithm for automatically tuning the mixture proportions of pretraining data domains. Rather than relying on heuristics or expensive downstream-task-based tuning, DoReMi uses a small proxy model trained with <a href="https://en.wikipedia.org/wiki/Robust_optimization">group distributionally robust optimization (Group DRO)</a> to produce domain weights that transfer to much larger models.</p>
<h2 id="why-data-mixture-proportions-matter">Why data mixture proportions matter</h2>
<p>Language model pretraining datasets combine text from many domains: web crawls, Wikipedia, books, code, academic papers, and others. The mixture proportions (how much of each domain to include) significantly affect downstream performance, but existing approaches either set them by hand (<a href="https://en.wikipedia.org/wiki/The_Pile_(dataset)">The Pile</a> uses heuristic weights) or tune them against downstream tasks (GLaM/PaLM), which is expensive and risks overfitting to a specific evaluation set. No principled, task-agnostic method existed for determining mixture proportions.</p>
<h2 id="minimax-optimization-over-domain-excess-loss">Minimax optimization over domain excess loss</h2>
<p>DoReMi&rsquo;s core insight is to frame data mixture optimization as a minimax problem: find domain weights that minimize the worst-case excess loss across all domains. The algorithm has three steps.</p>
<p><strong>Step 1</strong>: Train a small reference model (280M parameters) on some default domain weights $\alpha_{\text{ref}}$ (e.g., proportional to raw token count).</p>
<p><strong>Step 2</strong>: Train a small proxy model $p_{\theta}$ using Group DRO, which solves the minimax objective:</p>
<p>$$
\min_{\theta} \max_{\alpha \in \Delta^{k}} \sum_{i=1}^{k} \alpha_{i} \cdot \left[ \frac{1}{\sum_{x \in D_{i}} |x|} \sum_{x \in D_{i}} \left( \ell_{\theta}(x) - \ell_{\text{ref}}(x) \right) \right]
$$</p>
<p>where $\ell_{\theta}(x) = -\log p_{\theta}(x)$ and $\ell_{\text{ref}}(x) = -\log p_{\text{ref}}(x)$. The excess loss $\ell_{\theta}(x) - \ell_{\text{ref}}(x)$ measures how much headroom the proxy has to improve on each example relative to the reference. The inner maximization upweights domains with high excess loss via exponentiated gradient ascent, while the outer minimization trains the proxy on those upweighted domains.</p>
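<p>The bracketed per-domain term can be sketched as follows. This is a simplified illustration that takes per-token losses as plain Python lists and clips each per-token excess loss at zero, as this note describes for the weight update; the function name and input layout are assumptions made for the example.</p>

```python
def domain_excess_loss(proxy_token_losses, ref_token_losses):
    """Sum of clipped per-token excess losses, normalized by the domain's token count.

    Each argument is a list of examples; each example is a list of per-token losses.
    """
    total_excess = 0.0
    total_tokens = 0
    for proxy_example, ref_example in zip(proxy_token_losses, ref_token_losses):
        for proxy_loss, ref_loss in zip(proxy_example, ref_example):
            # Tokens where the proxy already beats the reference contribute nothing.
            total_excess += max(proxy_loss - ref_loss, 0.0)
            total_tokens += 1
    return total_excess / total_tokens


# Made-up losses for one domain: two examples with 2 and 1 tokens respectively.
proxy = [[2.0, 1.0], [3.0]]
reference = [[1.5, 1.2], [2.0]]
excess = domain_excess_loss(proxy, reference)
print(excess)
```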
<p>At each training step, the domain weights update as:</p>
<p>$$
\alpha_{t}' \leftarrow \alpha_{t-1} \exp(\eta \lambda_{t})
$$</p>
<p>where $\lambda_{t}[i]$ is the per-domain excess loss (with per-token excess losses clipped at zero), followed by renormalization and smoothing with a uniform component: $\alpha_{t} \leftarrow (1-c)\frac{\alpha_{t}'}{\sum_{i} \alpha_{t}'[i]} + cu$, where $u$ is the uniform distribution over domains and $c = 10^{-3}$.</p>
<p>The final domain weights are the average over all training steps: $\bar{\alpha} = \frac{1}{T}\sum_{t=1}^{T} \alpha_{t}$.</p>
<p><strong>Step 3</strong>: Resample data according to $\bar{\alpha}$ and train the full-scale model using standard procedures.</p>
<p><strong>Iterated DoReMi</strong> extends this by running multiple rounds, using the previous round&rsquo;s optimized weights as the next round&rsquo;s reference weights. This converges within 3 rounds on the GLaM dataset.</p>
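<p>The domain weight update from Step 2 can be sketched in plain Python. This is a toy illustration under stated assumptions: the per-domain excess losses are made-up numbers supplied directly (assumed already clipped) instead of being measured from a proxy model, and the function names are hypothetical.</p>

```python
import math


def update_domain_weights(prev_weights, excess_losses, step_size=1.0, smoothing=1e-3):
    """One exponentiated-gradient step, then renormalization and uniform smoothing."""
    k = len(prev_weights)
    unnormalized = [a * math.exp(step_size * lam)
                    for a, lam in zip(prev_weights, excess_losses)]
    total = sum(unnormalized)
    # Mix the normalized weights with the uniform distribution (1/k per domain).
    return [(1.0 - smoothing) * w / total + smoothing / k for w in unnormalized]


def average_weights(history):
    """Final domain weights: the mean of the per-step weights."""
    steps = len(history)
    return [sum(column) / steps for column in zip(*history)]


# Hypothetical 3-domain run with made-up excess losses, for illustration only.
weights = [1 / 3, 1 / 3, 1 / 3]
history = []
for excess in ([0.4, 0.1, 0.0], [0.3, 0.2, 0.0], [0.2, 0.2, 0.1]):
    weights = update_domain_weights(weights, excess)
    history.append(weights)

final_weights = average_weights(history)
print([round(w, 3) for w in final_weights])
```

<p>Domains with persistently high excess loss accumulate weight multiplicatively, while the uniform smoothing term keeps every domain's weight strictly positive.</p>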
<h2 id="experiments-across-the-pile-and-glam-datasets">Experiments across The Pile and GLaM datasets</h2>
<p><strong>Datasets.</strong> The Pile (22 domains, 800GB) and the GLaM dataset (8 domains, also used for PaLM). On The Pile, baseline weights come from the dataset defaults. On GLaM, baseline weights are uniform, with downstream-tuned oracle weights available for comparison.</p>
<p><strong>Setup.</strong> Transformer decoder-only LMs trained with next-token prediction. All models use batch size 512 and sequence length 1024. Proxy and reference models are 280M parameters. Main models are 8B parameters (30x larger). Training runs: 200K steps (Pile) or 300K steps (GLaM). The domain weight optimization cost (training two 280M models) is 8% of the compute for the 8B main model.</p>
<p><strong>Evaluation.</strong> Per-domain held-out perplexity and one-shot generative accuracy on five tasks: TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, and LAMBADA.</p>
<h3 id="key-domain-weight-shifts">Key domain weight shifts</h3>
<p>On The Pile, DoReMi (280M) dramatically upweights diverse web text (Pile-CC: 0.112 to 0.606) while downweighting specialized domains like ArXiv (0.105 to 0.004), PubMed Central (0.107 to 0.005), and StackExchange (0.093 to 0.015). Smaller, underrepresented domains like YouTubeSubtitles and PhilPapers receive proportionally large increases.</p>
<h3 id="scaling-behavior">Scaling behavior</h3>
<p>DoReMi was tested with matched proxy/main model sizes (280M through 1B) and with varying proxy sizes (70M through 1B) feeding into an 8B main model.</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Speedup to baseline accuracy</th>
          <th>Downstream improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DoReMi (280M to 280M)</td>
          <td>4x</td>
          <td>+2% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (280M to 8B)</td>
          <td>2.6x</td>
          <td>+6.5% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (150M to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
      <tr>
          <td>DoReMi (1B to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
  </tbody>
</table>
<p>Improvements are consistent across all tested model scales (280M to 1B matched), with no sign of diminishing returns at larger sizes.</p>
<h2 id="perplexity-improves-everywhere-even-on-downweighted-domains">Perplexity improves everywhere, even on downweighted domains</h2>
<p>The most striking finding is that DoReMi improves perplexity on all 22 domains in The Pile, including domains it downweights. The proposed explanation: the lowest-entropy domains need few samples to learn (they&rsquo;re statistically simple), while the highest-entropy domains have token distributions close to the uniform initialization and also need fewer samples. Reallocating weight to medium-entropy domains generates positive transfer that lifts all domains.</p>
<p>On The Pile, DoReMi reaches the baseline&rsquo;s downstream accuracy in 75K steps versus 200K for the baseline (2.6x speedup) and achieves a 6.5% absolute improvement in average one-shot accuracy at 200K steps.</p>
<p>On the GLaM dataset, iterated DoReMi (round 2) matches the performance of domain weights that were tuned directly on downstream task performance, despite having no knowledge of downstream tasks. Domain weights converge within 3 iterations.</p>
<h3 id="ablations">Ablations</h3>
<p>Using only the proxy model&rsquo;s loss (prefer hardest domains) or only the negative reference loss (prefer easiest domains) both underperform the full excess loss formulation. Both components are necessary: the excess loss identifies domains where the proxy has room to improve relative to what is learnable.</p>
<p>The proxy model itself typically underperforms the main model trained on its weights, and this gap grows at larger proxy scales. A 1B proxy model underperforms the 1B baseline, yet its domain weights still improve 1B main model training by over 2x. This suggests the domain weight signal is robust even when the proxy model itself is not well-trained.</p>
<h3 id="limitations">Limitations</h3>
<p>The domain weight landscape may have multiple local optima: a 280M proxy puts most weight on Pile-CC, while a 1B proxy favors OpenWebText2 instead. Both configurations improve over baseline, but the optimal weights are not unique.</p>
<p>The granularity of &ldquo;domains&rdquo; matters. DoReMi works better with more domains (22 on The Pile versus 8 on GLaM). Domains are defined by data provenance, which is coarse-grained. Fine-grained domain definitions (e.g., via clustering) could improve results but also risk DRO putting all weight on a small set of worst-case examples.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>800 GB, 22 domains</td>
          <td>Default heuristic weights as baseline</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>GLaM dataset</td>
          <td>8 domains</td>
          <td>Uniform weights as baseline; downstream-tuned oracle available</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, LAMBADA</td>
          <td>Standard splits</td>
          <td>One-shot generative evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Group DRO with exponentiated gradient ascent for domain weight updates. Step size $\eta = 1$, smoothing $c = 10^{-3}$. Per-token excess loss clipped at zero. Domain weights averaged over all training steps. Iterated DoReMi converges when $\lVert \bar{\alpha} - \alpha_{\text{ref}} \rVert_{\infty} &lt; 10^{-3}$.</p>
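<p>The stopping rule for iterated DoReMi is an infinity-norm check on how far the averaged weights moved from the reference weights. A minimal sketch, with a hypothetical function name and made-up weight vectors:</p>

```python
def has_converged(avg_weights, ref_weights, tol=1e-3):
    """True when the largest per-domain change (infinity norm) is below tol."""
    largest_change = max(abs(a - r) for a, r in zip(avg_weights, ref_weights))
    return largest_change < tol


# Made-up weight vectors for two domains, for illustration only.
print(has_converged([0.5000, 0.5000], [0.5004, 0.4996]))
print(has_converged([0.6, 0.4], [0.5, 0.5]))
```

<p>When the check fails, the averaged weights become the reference weights for the next round.</p>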
<h3 id="models">Models</h3>
<p>Vanilla Transformer decoder-only models with 256K vocabulary. Sizes: 70M (3 layers), 150M (6 layers), 280M (12 layers), 510M (12 layers), 760M (12 layers), 1B (16 layers), 8B (32 layers). All use 64-dim attention heads except 8B (128-dim).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>DoReMi (280M to 8B)</th>
          <th>Baseline (8B)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Avg one-shot accuracy</td>
          <td>+6.5% over baseline</td>
          <td>Reference</td>
          <td>5 generative tasks</td>
      </tr>
      <tr>
          <td>Worst-case log-perplexity</td>
          <td>1.46</td>
          <td>1.71</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Avg log-perplexity</td>
          <td>1.40</td>
          <td>1.64</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Domains beating baseline</td>
          <td>22/22</td>
          <td>0/22</td>
          <td>Per-domain perplexity</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Proxy and reference models (under 1B) trained on TPUv3. Models at 1B and 8B trained on TPUv4. Domain weight optimization (two 280M runs) costs 8% of 8B training FLOPs.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xie2023doremi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xie, Sang Michael and Pham, Hieu and Dong, Xuanyi and Du, Nan and Liu, Hanxiao and Lu, Yifeng and Liang, Percy and Le, Quoc V. and Ma, Tengyu and Yu, Adams Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Data Mixing Laws for LM Pretraining Optimization</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</guid><description>Ye et al. discover that LM loss follows an exponential law over domain mixture proportions, enabling cheap prediction and optimization of data mixtures.</description><content:encoded><![CDATA[<h2 id="an-empirical-discovery-of-predictable-mixture-loss-relationships">An empirical discovery of predictable mixture-loss relationships</h2>
<p>This is a <strong>discovery paper</strong> that identifies a quantitative, functional relationship between pretraining data mixture proportions and language model loss. The key finding is that domain-specific validation loss follows an exponential law over the linear combination of training domain proportions, and this law composes with standard scaling laws to enable cheap prediction of large-model performance under arbitrary mixtures.</p>
<h2 id="the-missing-quantitative-link-between-data-mixtures-and-performance">The missing quantitative link between data mixtures and performance</h2>
<p>Pretraining data for large language models combines text from many domains (web, code, academic, books, etc.), and mixture proportions significantly affect model quality. Existing approaches either set proportions by hand without disclosed criteria (LLaMA, Baichuan) or use algorithmic methods like <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> that optimize qualitatively but cannot predict the quantitative effect of a specific mixture before training. Scaling laws exist for model size and data quantity, but no equivalent existed for mixture proportions. This paper fills that gap.</p>
<h2 id="the-exponential-data-mixing-law">The exponential data mixing law</h2>
<p>The core finding: for a model of fixed size trained for a fixed number of steps, the validation loss on domain $i$ as a function of the training mixture proportions $r_{1 \dots M}$ follows:</p>
<p>$$
L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right)
$$</p>
<p>where $c_{i}$, $k_{i}$, and $t_{ij}$ are fitted parameters. The constant $c_{i}$ represents the irreducible loss (not affected by mixture changes). The interaction coefficients $t_{ij}$ capture how training domain $j$ affects validation loss on domain $i$: negative $t_{ij}$ means domain $j$ helps domain $i$, positive means it hurts.</p>
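<p>To make the sign convention concrete, the law can be evaluated directly. The coefficients below are invented for illustration (they are not fitted values from the paper); the sketch only shows that shifting proportion toward a domain with negative $t_{ij}$ lowers the predicted loss, while shifting toward one with positive $t_{ij}$ raises it.</p>

```python
import math

def domain_loss(c_i, k_i, t_i, r):
    """L_i(r) = c_i + k_i * exp(sum_j t_ij * r_j)."""
    return c_i + k_i * math.exp(sum(t * p for t, p in zip(t_i, r)))

# Hypothetical coefficients for one validation domain over three training domains.
c_i, k_i = 1.5, 1.0
t_i = [-2.0, 0.0, 0.5]  # domain 0 helps, domain 1 is unrelated, domain 2 hurts

base = domain_loss(c_i, k_i, t_i, [0.2, 0.4, 0.4])
more_helpful = domain_loss(c_i, k_i, t_i, [0.4, 0.2, 0.4])  # move mass from 1 to 0
more_harmful = domain_loss(c_i, k_i, t_i, [0.2, 0.2, 0.6])  # move mass from 1 to 2
```

<p>Whatever the mixture, the predicted loss stays above $c_{i}$ when $k_{i}$ is positive, which matches the reading of $c_{i}$ as the irreducible part.</p>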
<p>This was discovered progressively:</p>
<ol>
<li><strong>Two domains</strong>: Log-reducible-loss is linear in domain proportion (univariate exponential).</li>
<li><strong>Three domains</strong>: The exponential generalizes to a linear combination over all domain proportions (Eq. above), outperforming alternatives with comparable parameter count.</li>
<li><strong>General validation</strong>: For a validation set composed of $K$ domains with proportions $s_{1 \dots K}$, the overall loss is:</li>
</ol>
<p>$$
L(r_{1 \dots M}) = \sum_{i=1}^{K} s_{i} \left[ c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right) \right]
$$</p>
<p>When the validation set composition is unknown, implicit domain aggregation treats $s_{i}$ as learnable parameters. Setting the number of implicit domains larger than the true number works well and is robust to overestimation.</p>
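<p>Once the parameters are fitted, the aggregated law is cheap to evaluate, so candidate mixtures can be compared without training. The following sketch assumes known validation proportions $s_{i}$ and uses made-up parameters; the brute-force grid over the simplex is a stand-in for whatever search procedure is used in practice, not the method from the paper.</p>

```python
import math
from itertools import product

def mixture_loss(s, c, k, t, r):
    """Overall loss: sum_i s_i * (c_i + k_i * exp(sum_j t_ij * r_j))."""
    total = 0.0
    for s_i, c_i, k_i, t_i in zip(s, c, k, t):
        total += s_i * (c_i + k_i * math.exp(sum(a * b for a, b in zip(t_i, r))))
    return total

def simplex_grid(m, steps):
    """All m-way mixtures whose proportions are multiples of 1/steps."""
    for counts in product(range(steps + 1), repeat=m - 1):
        rest = steps - sum(counts)
        if rest >= 0:
            yield [n / steps for n in counts] + [rest / steps]

# Hypothetical fitted parameters: K = 2 validation domains, M = 3 training domains.
s = [0.5, 0.5]
c = [1.5, 2.0]
k = [1.0, 0.8]
t = [[-2.0, 0.0, 0.3],
     [0.2, -1.5, 0.0]]

# Pick the grid mixture with the lowest predicted overall loss.
best = min(simplex_grid(3, 20), key=lambda r: mixture_loss(s, c, k, t, r))
```

<p>In this toy setup the third training domain never helps either validation domain, so the search assigns it zero weight.</p>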
<h3 id="domain-interaction-patterns">Domain interaction patterns</h3>
<p>Visualizing the fitted $t_{ij}$ coefficients across 5 coarse Pile domains reveals three relationship types: most domain pairs are <strong>unrelated</strong> (sparse interaction matrix where each domain&rsquo;s loss is dominated by its own training proportion), some show <strong>facilitation</strong> (e.g., dialogue data helps internet text), and some show <strong>conflict</strong> (e.g., symbolic data hurts prose). This sparsity explains why the law can be fitted with fewer samples than the quadratic parameter count would suggest.</p>
<h2 id="nested-scaling-pipeline-for-cheap-prediction">Nested scaling pipeline for cheap prediction</h2>
<p>Fitting data mixing laws directly at target scale is too expensive (requires many full training runs at different mixtures). The paper proposes nesting three scaling laws:</p>
<p><strong>Step 1</strong>: For each mixture $r_{i}$ and each small model size $N_{j}$, train for $S_{0}$ steps. Fit a <a href="https://en.wikipedia.org/wiki/Power_law">power law</a> $L(S) = E_{1} + B/S^{\beta}$ over steps to extrapolate to the target step count $S_{\text{target}}$.</p>
<p><strong>Step 2</strong>: With the step-extrapolated losses for each mixture, fit a power law $L(N) = E_{2} + A/N^{\alpha}$ over model sizes to extrapolate to the target model size $N_{\text{target}}$.</p>
<p><strong>Step 3</strong>: With the predicted losses at $(N_{\text{target}}, S_{\text{target}})$ for all sampled mixtures, fit the data mixing law and search for the optimal mixture.</p>
<p>This pipeline requires only training small models (70M to 410M) for short runs (30B tokens) to predict performance of a 1B model trained for 100B tokens.</p>
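<p>Steps 1 and 2 can be illustrated for a single mixture with a small sketch. Two caveats apply: the losses are synthetic, generated from a made-up ground truth so the extrapolation can be checked, and the fitting routine is a simple substitute (a scan over the irreducible term with a log-space linear regression) rather than the Huber-loss fit the paper reports. Step 3, fitting the mixing law across mixtures, is omitted.</p>

```python
import math

def fit_power_law(xs, ys, grid=2000):
    """Fit y = E + B / x**beta by scanning E and regressing log(y - E) on log(x)."""
    lx = [math.log(x) for x in xs]
    n = len(xs)
    mx = sum(lx) / n
    sxx = sum((a - mx) ** 2 for a in lx)
    y_min = min(ys)
    best = None
    for g in range(grid):
        e = y_min * g / grid  # candidate irreducible loss, always below min(ys)
        ly = [math.log(y - e) for y in ys]
        my = sum(ly) / n
        slope = sum((a - mx) * (b - my) for a, b in zip(lx, ly)) / sxx
        intercept = my - slope * mx
        err = sum((intercept + slope * a - b) ** 2 for a, b in zip(lx, ly))
        if best is None or err < best[0]:
            best = (err, e, math.exp(intercept), -slope)
    _, e, b, beta = best
    return e, b, beta

def predict(params, x):
    e, b, beta = params
    return e + b / x ** beta

def true_loss(n, s):
    """Synthetic ground truth (hypothetical): 1.8 + 200 / N**0.3 + 50 / S**0.5."""
    return 1.8 + 200.0 / n ** 0.3 + 50.0 / s ** 0.5

sizes = [70e6, 160e6, 305e6, 410e6]
steps = [5_000, 10_000, 20_000, 30_000]
S_TARGET, N_TARGET = 100_000, 1e9

# Step 1: per model size, extrapolate the loss curve to the target step count.
at_target_steps = []
for n in sizes:
    params = fit_power_law(steps, [true_loss(n, s) for s in steps])
    at_target_steps.append(predict(params, S_TARGET))

# Step 2: extrapolate the step-extrapolated losses to the target model size.
predicted = predict(fit_power_law(sizes, at_target_steps), N_TARGET)
```

<p>Running the same two steps once per sampled mixture would give the set of predicted target-scale losses that Step 3 consumes.</p>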
<h3 id="mixture-sampling-strategy">Mixture sampling strategy</h3>
<p>To get informative samples efficiently, the paper uses double-diminishing proportions: for each domain, enumerate proportions by halving from the maximum available. This distributes losses evenly across the exponential law&rsquo;s range. From 40 candidate mixtures trained at the smallest scale (70M), 20 are selected based on which subset minimizes data mixing law fitting error.</p>
<h2 id="experiments-on-redpajama-and-continual-pretraining">Experiments on RedPajama and continual pretraining</h2>
<p><strong>Main experiment.</strong> Models trained on RedPajama, validated on the Pile (mimicking the common scenario where validation data comes from a different distribution than training). Small models: 70M, 160M, 305M, 410M trained for 30B tokens. Target: 1B model for 100B tokens.</p>
<p>The optimized mixture dramatically redistributes weight compared to RedPajama defaults:</p>
<table>
  <thead>
      <tr>
          <th>Domain</th>
          <th>Default</th>
          <th>Optimized</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>0.670</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>0.150</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>0.045</td>
          <td>0.141</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>0.045</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>0.045</td>
          <td>0.094</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>0.025</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>0.020</td>
          <td>0.016</td>
      </tr>
  </tbody>
</table>
<p>The optimized mixture reaches the default mixture&rsquo;s final performance in 73% of the training steps and eventually achieves performance equivalent to 48% more training on the default mixture.</p>
<p><strong>Comparison to DoReMi and DoGE.</strong> Data mixing laws outperform both: the predicted-optimal mixture achieves lower validation loss than DoReMi and DoGE (both universal and OOD settings) for 1B models trained for 100B tokens on RedPajama.</p>
<p><strong>Continual pretraining.</strong> The law extends to continual pretraining (Pythia-70M on Pile + Python code). It accurately predicts the critical mixture proportion that avoids <a href="https://en.wikipedia.org/wiki/Catastrophic_interference">catastrophic forgetting</a> on the original domain while improving the target domain. This suggests data mixing laws could guide dynamic data schedules across multi-stage pretraining.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The data mixing law provides a predictive framework rather than just an optimization algorithm. Key implications:</p>
<ul>
<li>The interaction coefficients $t_{ij}$ make domain relationships quantitatively observable before full-scale training, identifying facilitation and conflict pairs.</li>
<li>The nested pipeline&rsquo;s cost is dominated by the small-model training runs (40 mixtures at 70M scale), which is orders of magnitude cheaper than even a single target-scale run.</li>
<li>The continual pretraining application opens the door to optimizing dynamic data schedules, where mixture proportions change across training stages.</li>
</ul>
<p><strong>Limitations</strong>: The &ldquo;domain&rdquo; concept remains loosely defined (provenance-based). The nested scaling laws introduce compounding errors at each step, and predictions tend to slightly underestimate actual loss. The number of required fitting samples, while subquadratic in practice due to sparsity, still scales with the number of domains. No theoretical justification for the exponential form is provided; it is a purely empirical finding.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (pilot)</td>
          <td>The Pile (GitHub, Pile-CC, Books3)</td>
          <td>30B tokens</td>
          <td>2-domain and 3-domain experiments</td>
      </tr>
      <tr>
          <td>Training (main)</td>
          <td>RedPajama</td>
          <td>100B tokens</td>
          <td>7 domains</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>The Pile validation set</td>
          <td>Standard split</td>
          <td>Out-of-distribution relative to RedPajama</td>
      </tr>
      <tr>
          <td>Continual pretraining</td>
          <td>Pile + Python code</td>
          <td>10B tokens</td>
          <td>Pythia-70M base model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data mixing law: $L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp(\sum_{j} t_{ij} r_{j})$, fitted with an AdaBoost Regressor on the sampled mixtures. Step scaling law: $L(S) = E_{1} + B/S^{\beta}$. Model size scaling law: $L(N) = E_{2} + A/N^{\alpha}$. Both scaling laws are fitted by minimizing Huber loss with L-BFGS, as separate fits (a decomposed Chinchilla-style form) for stability. 40 candidate mixtures sampled via double-diminishing proportions, 20 selected for the final pipeline.</p>
<h3 id="models">Models</h3>
<p>Transformer decoder-only LMs. Pilot: 70M, 160M. Main pipeline: 70M, 160M, 305M, 410M (for fitting), 1B (target). Batch size: 1M tokens. Cosine learning rate decay with 2K step warmup, decaying to 0.1x at 100K steps.</p>
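<p>The learning rate schedule can be sketched as follows. The warmup shape (linear) and the peak learning rate are assumptions, since the notes give only the warmup length, the cosine decay, and the 0.1x floor at 100K steps.</p>

```python
import math

def learning_rate(step, peak_lr, warmup_steps=2_000, total_steps=100_000, final_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay to final_ratio * peak_lr."""
    if step < warmup_steps:
        # Assumed linear ramp; the notes do not specify the warmup shape.
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * (final_ratio + (1.0 - final_ratio) * cosine)

PEAK_LR = 1e-3  # hypothetical value, used only for illustration

lr_after_warmup = learning_rate(2_000, PEAK_LR)
lr_at_end = learning_rate(100_000, PEAK_LR)
```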
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Optimized Mixture</th>
          <th>Default Mixture</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Steps to match default final loss</td>
          <td>73K (73%)</td>
          <td>100K (100%)</td>
          <td>27% training reduction</td>
      </tr>
      <tr>
          <td>Equivalent extra training</td>
          <td>+48%</td>
          <td>Baseline</td>
          <td>Estimated via step scaling law</td>
      </tr>
      <tr>
          <td>Validation loss (1B, 100B)</td>
          <td>Lowest</td>
          <td>Higher than optimized</td>
          <td>Also beats DoReMi and DoGE</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>8 A100 GPUs. Training times per 30B-token run: 3.5 hours (70M), 8 hours (160M), 16 hours (305M), 21 hours (410M).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Pilot and validation data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/togethercomputer/RedPajama-Data">RedPajama</a></td>
          <td>Dataset</td>
          <td>Apache 2.0</td>
          <td>Main training data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/EleutherAI/pythia">Pythia Suite</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Model architecture configs; Pythia-70M checkpoint for continual pretraining</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status: Partially Reproducible.</strong> Datasets and base model checkpoints are public. No official code release for the data mixing law fitting pipeline, mixture sampling, or the nested scaling law prediction workflow.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ye2025datamixinglaws,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data Mixing Laws: Optimizing Data Mixtures by Predicting Language Modeling Performance}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Jiasheng and Liu, Peiju and Sun, Tianxiang and Zhan, Jun and Zhou, Yunhua and Qiu, Xipeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RWKV: Linear-Cost RNN with Transformer Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</guid><description>RWKV combines parallelizable transformer training with constant-cost RNN inference using linear attention and channel-wise decay.</description><content:encoded><![CDATA[<h2 id="a-new-architecture-bridging-rnns-and-transformers">A New Architecture Bridging RNNs and Transformers</h2>
<p>This is a <strong>Method</strong> paper that introduces RWKV (Receptance Weighted Key Value), a novel sequence model architecture that combines the parallelizable training of Transformers with the efficient $O(Td)$ inference of RNNs. RWKV can be formulated equivalently as either a Transformer (for parallel training) or an RNN (for sequential inference), achieving the lowest computational and memory complexity among comparable architectures while matching Transformer-level performance. The authors scale RWKV to 14 billion parameters, making it the largest dense RNN ever trained at the time of publication.</p>
<h2 id="the-quadratic-cost-of-self-attention">The Quadratic Cost of Self-Attention</h2>
<p>Transformers have become the dominant architecture for NLP, powering models like GPT-3, LLaMA, and Chinchilla. Their self-attention mechanism captures both local and long-range dependencies while supporting parallelized training. However, self-attention scales quadratically with sequence length in both time ($O(T^2d)$) and space ($O(T^2 + Td)$), making it computationally and memory intensive for long sequences and resource-constrained deployment.</p>
<p>RNNs, by contrast, offer linear scaling in memory and computation, but suffer from the vanishing gradient problem and cannot parallelize across the time dimension during training. This limits their scalability and makes them unable to match Transformer performance in practice.</p>
<p>Prior work on efficient Transformers (Reformer, Performer, Linformer, AFT, MEGA) has attempted to reduce this quadratic cost, often at the expense of model expressivity. RWKV aims to combine the best of both worlds: Transformer-grade training efficiency with RNN-grade inference cost, without any approximation to the attention mechanism.</p>
<h2 id="linear-attention-via-channel-wise-decay">Linear Attention via Channel-Wise Decay</h2>
<p>RWKV is built on four core vectors that interact multiplicatively at each timestep:</p>
<ul>
<li><strong>R</strong> (Receptance): receives past information, acting as a gating signal</li>
<li><strong>W</strong> (Weight): a trainable positional weight decay vector</li>
<li><strong>K</strong> (Key): analogous to keys in standard attention</li>
<li><strong>V</strong> (Value): analogous to values in standard attention</li>
</ul>
<p>The architecture consists of stacked residual blocks, each containing a <strong>time-mixing</strong> sub-block and a <strong>channel-mixing</strong> sub-block.</p>
<h3 id="token-shift">Token Shift</h3>
<p>All linear projection vectors are produced by interpolating between the current input $x_t$ and the previous input $x_{t-1}$, creating a token shift mechanism:</p>
<p>$$
r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})
$$</p>
<p>$$
k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})
$$</p>
<p>$$
v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1})
$$</p>
<p>where $\mu_r$, $\mu_k$, $\mu_v$ are learnable interpolation parameters. This is implemented efficiently as a simple offset in the temporal dimension.</p>
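<p>A minimal sketch of token shift for the receptance projection, using plain lists. The zero vector used as the predecessor of the first token, the helper names, and the tiny example weights are assumptions made for illustration.</p>

```python
def token_shift(xs, mu):
    """Mix each input with its predecessor: mu * x_t + (1 - mu) * x_{t-1}."""
    prev = [0.0] * len(mu)  # assumption: the first token sees a zero vector
    mixed = []
    for x in xs:
        mixed.append([m * a + (1.0 - m) * b for m, a, b in zip(mu, x, prev)])
        prev = x
    return mixed

def matvec(w, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

# Hypothetical 2-channel sequence of three tokens.
xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
mu_r = [1.0, 0.5]  # channel 0 ignores the past, channel 1 averages with it
W_r = [[1.0, 0.0], [0.0, 2.0]]

r = [matvec(W_r, m) for m in token_shift(xs, mu_r)]
```

<p>The same <code>token_shift</code> helper would be reused with $\mu_k$ and $\mu_v$ to produce $k_t$ and $v_t$.</p>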
<h3 id="the-wkv-operator">The WKV Operator</h3>
<p>The core attention-like computation replaces standard dot-product attention with a channel-wise weighted sum using exponential decay:</p>
<p>$$
wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}
$$</p>
<p>Here $w$ is the channel-wise time decay vector and $u$ is a separate bonus vector that attends specifically to the current token. Unlike AFT where $W$ is a pairwise matrix, RWKV treats $W$ as a channel-wise vector modified by relative position, enabling the recurrent formulation.</p>
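<p>The formula can be transcribed directly as a reference implementation. This sketch sums over the full history at every step, so it costs $O(T^2 d)$ and is meant only to pin down the indexing; the input values are made up, and no numerical stabilization of the exponentials is attempted.</p>

```python
import math

def wkv_direct(w, u, ks, vs):
    """Evaluate wkv_t for every step by summing over the whole history."""
    d = len(w)
    out = []
    for t in range(len(ks)):  # 0-based t; the formula in the text is 1-based
        row = []
        for c in range(d):
            # Current token uses the bonus u instead of a decay term.
            cur = math.exp(u[c] + ks[t][c])
            num = cur * vs[t][c]
            den = cur
            # Earlier tokens are down-weighted by their distance times w.
            for i in range(t):
                weight = math.exp(-(t - 1 - i) * w[c] + ks[i][c])
                num += weight * vs[i][c]
                den += weight
            row.append(num / den)
        out.append(row)
    return out

# Hypothetical 2-channel, 4-step inputs.
w = [0.5, 1.0]
u = [0.3, -0.2]
ks = [[0.1, -0.4], [0.7, 0.2], [-0.3, 0.5], [0.0, 0.1]]
vs = [[1.0, 2.0], [-1.0, 0.5], [3.0, -2.0], [0.5, 4.0]]

out = wkv_direct(w, u, ks, vs)
```

<p>Each output is a weighted average of the values seen so far, computed independently per channel, which is why a vector $w$ suffices where AFT needs a pairwise matrix.</p>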
<h3 id="output-gating">Output Gating</h3>
<p>The receptance vector gates the WKV output through a sigmoid:</p>
<p>$$
o_t = W_o \cdot (\sigma(r_t) \odot wkv_t)
$$</p>
<p>The channel-mixing block uses a similar gating mechanism with squared ReLU activation:</p>
<p>$$
o'_t = \sigma(r'_t) \odot (W'_v \cdot \max(k'_t, 0)^2)
$$</p>
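<p>Both gating equations reduce to a sigmoid gate applied element-wise. The sketch below uses plain lists and hypothetical helper names and weights; it allows the channel-mixing key vector to be wider than the output, which is an assumption about shapes that the equations themselves do not state.</p>

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(w, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

def time_mix_output(W_o, r_t, wkv_t):
    """o_t = W_o . (sigmoid(r_t) * wkv_t), with the gate applied element-wise."""
    gated = [sigmoid(r) * a for r, a in zip(r_t, wkv_t)]
    return matvec(W_o, gated)

def channel_mix_output(W_v, r_t, k_t):
    """Channel mixing: sigmoid(r_t) * (W_v . max(k_t, 0)^2)."""
    hidden = [max(k, 0.0) ** 2 for k in k_t]  # squared ReLU
    projected = matvec(W_v, hidden)
    return [sigmoid(r) * p for r, p in zip(r_t, projected)]

# Hypothetical tiny examples.
identity = [[1.0, 0.0], [0.0, 1.0]]
o_time = time_mix_output(identity, [0.0, 0.0], [2.0, 4.0])
o_channel = channel_mix_output([[1.0, 1.0]], [0.0], [2.0, -3.0])
```

<p>With a receptance of zero the gate is exactly one half, so the outputs here are half of the ungated values.</p>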
<h3 id="dual-mode-operation">Dual-Mode Operation</h3>
<p>During <strong>training</strong>, RWKV operates in time-parallel mode. The matrix multiplications ($W_\lambda$ for $\lambda \in \lbrace r, k, v, o \rbrace$) dominate at $O(BTd^2)$ and parallelize identically to standard Transformers. The element-wise WKV computation is $O(BTd)$ and parallelizes along batch and channel dimensions.</p>
<p>During <strong>inference</strong>, RWKV switches to time-sequential mode. Each timestep updates a fixed-size state vector, giving constant $O(d)$ memory and $O(Td)$ total time for generating $T$ tokens, compared to $O(T^2d)$ for standard Transformers.</p>
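<p>The time-sequential mode follows from the WKV formula: the two sums over the history can be carried as running totals that decay by $e^{-w}$ each step. The sketch below is one way to write that recurrence, checked against a direct evaluation of the sum on made-up inputs. It is the plain, numerically naive form; a practical implementation would also need to guard against overflow in the exponentials.</p>

```python
import math

def wkv_recurrent(w, u, ks, vs):
    """WKV outputs using O(d) state and O(Td) total work."""
    d = len(w)
    a = [0.0] * d  # running decayed sum of e^{k_i} * v_i over past tokens
    b = [0.0] * d  # running decayed sum of e^{k_i} over past tokens
    out = []
    for k_t, v_t in zip(ks, vs):
        row = []
        for c in range(d):
            cur = math.exp(u[c] + k_t[c])
            row.append((a[c] + cur * v_t[c]) / (b[c] + cur))
            # Fold the current token into the state for the next step.
            decay = math.exp(-w[c])
            a[c] = decay * a[c] + math.exp(k_t[c]) * v_t[c]
            b[c] = decay * b[c] + math.exp(k_t[c])
        out.append(row)
    return out

def wkv_direct_at(w, u, ks, vs, t, c):
    """Reference value for channel c at 0-based step t, summed over the history."""
    cur = math.exp(u[c] + ks[t][c])
    num, den = cur * vs[t][c], cur
    for i in range(t):
        weight = math.exp(-(t - 1 - i) * w[c] + ks[i][c])
        num += weight * vs[i][c]
        den += weight
    return num / den

# Hypothetical 2-channel, 4-step inputs.
w = [0.5, 1.0]
u = [0.3, -0.2]
ks = [[0.1, -0.4], [0.7, 0.2], [-0.3, 0.5], [0.0, 0.1]]
vs = [[1.0, 2.0], [-1.0, 0.5], [3.0, -2.0], [0.5, 4.0]]

recurrent = wkv_recurrent(w, u, ks, vs)
max_gap = max(
    abs(recurrent[t][c] - wkv_direct_at(w, u, ks, vs, t, c))
    for t in range(len(ks))
    for c in range(len(w))
)
```

<p>The state consists of two vectors of length $d$ regardless of how many tokens have been processed, which is the source of the constant-memory claim.</p>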
<h3 id="optimizations">Optimizations</h3>
<p>Three additional design choices improve training:</p>
<ol>
<li><strong>Custom CUDA kernels</strong> for the sequential WKV computation, fusing it into a single kernel on training accelerators</li>
<li><strong>Small init embedding</strong>: initializing the embedding matrix with small values plus an additional LayerNorm, accelerating convergence</li>
<li><strong>Custom initialization</strong>: most weights initialized to zero with no biases, following identity-mapping principles from residual network design</li>
</ol>
<h2 id="scaling-to-14b-parameters-and-benchmark-evaluation">Scaling to 14B Parameters and Benchmark Evaluation</h2>
<h3 id="model-scaling">Model Scaling</h3>
<p>The authors train six RWKV models from 169M to 14B parameters, all for one epoch (330B tokens) on the Pile:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>Dimension</th>
          <th>Parameters</th>
          <th>FLOP/Token</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>12</td>
          <td>768</td>
          <td>$1.69 \times 10^8$</td>
          <td>$2.61 \times 10^8$</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>24</td>
          <td>1024</td>
          <td>$4.30 \times 10^8$</td>
          <td>$7.57 \times 10^8$</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>24</td>
          <td>2048</td>
          <td>$1.52 \times 10^9$</td>
          <td>$2.82 \times 10^9$</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>32</td>
          <td>2560</td>
          <td>$2.99 \times 10^9$</td>
          <td>$5.71 \times 10^9$</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>32</td>
          <td>4096</td>
          <td>$7.39 \times 10^9$</td>
          <td>$1.44 \times 10^{10}$</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>40</td>
          <td>5120</td>
          <td>$1.42 \times 10^{10}$</td>
          <td>$2.78 \times 10^{10}$</td>
      </tr>
  </tbody>
</table>
<p>The parameter count follows: $\text{params} = 2VD + 13D^2L + D(11L + 4)$, where $V = 50277$ is vocabulary size, $D$ is model dimension, and $L$ is layers. Training compute follows the standard Transformer approximation $\text{FLOP} = 6 \cdot [\text{tokens}] \cdot [\text{params}]$; the FLOP/Token column above reports the smaller per-token cost of a forward pass.</p>
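<p>The parameter formula can be checked against the table with a few lines of arithmetic. The second function is an inference from the numbers and not a formula stated in these notes: the FLOP/Token column is reproduced by $2(VD + 13D^2L)$, which reads as a forward-pass cost per token.</p>

```python
VOCAB = 50277

def rwkv_params(dim, layers, vocab=VOCAB):
    """params = 2VD + 13 D^2 L + D(11L + 4)."""
    return 2 * vocab * dim + 13 * dim ** 2 * layers + dim * (11 * layers + 4)

def forward_flop_per_token(dim, layers, vocab=VOCAB):
    """Inferred from the table values: 2 * (VD + 13 D^2 L)."""
    return 2 * (vocab * dim + 13 * dim ** 2 * layers)

# (layers, dimension) per model, copied from the table.
CONFIGS = {
    "169M": (12, 768),
    "430M": (24, 1024),
    "1.5B": (24, 2048),
    "3B": (32, 2560),
    "7B": (32, 4096),
    "14B": (40, 5120),
}

params = {name: rwkv_params(d, l) for name, (l, d) in CONFIGS.items()}
flops = {name: forward_flop_per_token(d, l) for name, (l, d) in CONFIGS.items()}

for name in CONFIGS:
    print(f"{name:>5}  params={params[name]:.3e}  flop/token={flops[name]:.3e}")
```

<p>The computed values agree with the table to well under one percent; the small differences for the 3B and 14B rows are at the level of rounding.</p>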
<h3 id="scaling-laws">Scaling Laws</h3>
<p>Training 45 RWKV models across varied (dataset, parameters) pairs, the authors find that RWKV follows the same log-log linear scaling law established for Transformers. The linear fit to Pareto-optimal points achieves $r^2 = 0.994$, and extrapolation an additional order of magnitude still yields $r^2 = 0.875$. This contrasts with prior claims that LSTMs do not follow transformer-like scaling.</p>
<h3 id="nlp-benchmarks">NLP Benchmarks</h3>
<p>RWKV is compared against similarly-sized models trained on comparable token budgets: Pythia, OPT, and BLOOM (all FLOP-matched). Results span twelve benchmarks: ARC (Easy/Challenge), BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, and Winogrande.</p>
<p>RWKV performs competitively with Transformers across all model sizes. On average across benchmarks, RWKV tracks closely with Pythia and outperforms OPT and BLOOM at comparable scales.</p>
<h3 id="long-context-and-extended-finetuning">Long Context and Extended Finetuning</h3>
<p>RWKV can extend its context length after pretraining through progressive finetuning: doubling from 1024 to 2048 (10B tokens), then to 4096 (100B tokens), and finally to 8192 (100B tokens). Each doubling reduces test loss on the Pile, indicating effective use of longer context.</p>
<p>On the Long Range Arena (LRA) benchmark, which tests sequences from 1,000 to 16,000 tokens, RWKV performs second only to S4 across the five datasets.</p>
<h3 id="inference-efficiency">Inference Efficiency</h3>
<p>Benchmarking text generation on CPU (x86) and GPU (NVIDIA A100 80GB) at float32 precision shows that RWKV exhibits linear scaling in generation time, while Transformers scale quadratically. This advantage grows with sequence length: for long outputs, RWKV completes generation substantially faster at equivalent model sizes.</p>
<h2 id="competitive-performance-with-key-caveats">Competitive Performance with Key Caveats</h2>
<p>RWKV demonstrates that RNN-class models can match Transformer performance at scale, while maintaining $O(Td)$ time and $O(d)$ memory during inference. The key findings are:</p>
<ol>
<li><strong>Scaling laws hold</strong>: RWKV follows the same compute-optimal scaling as Transformers ($r^2 = 0.994$), contradicting earlier claims about RNN scaling behavior</li>
<li><strong>Competitive NLP performance</strong>: Across twelve benchmarks, RWKV matches similarly-sized Transformers trained on comparable data</li>
<li><strong>Linear inference cost</strong>: Generation time scales linearly rather than quadratically, with constant memory regardless of sequence length</li>
<li><strong>Context extension</strong>: Progressive finetuning effectively extends the context window post-training</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors identify two primary limitations:</p>
<p><strong>Information compression</strong>: Linear attention funnels all past information through a single fixed-size state vector. For tasks requiring recall of specific details over very long contexts, this is mechanistically more constrained than full self-attention, which maintains direct access to all previous tokens.</p>
<p><strong>Prompt sensitivity</strong>: RWKV is more sensitive to prompt engineering than standard Transformers. The linear attention mechanism limits how much prompt information carries forward, making the order of information in the prompt particularly important. Reordering prompts improved F1 from 44.2% to 74.8% on one task.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest several avenues: applying parallel scan to reduce WKV cost to $O(B \log(T) d)$, extending RWKV to encoder-decoder and multimodal architectures, leveraging hidden states for interpretability and safety, and increasing internal state size to improve long-range recall.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch training and inference implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/BlinkDL/rwkv-4-pile-14b">Pre-trained weights (169M to 14B)</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>All six Pile-trained sizes on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>)</td>
      </tr>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>Mixed</td>
          <td>825 GiB pretraining corpus; component licenses vary by source</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Highly Reproducible. Training code (Apache-2.0), pre-trained weights for all six model sizes, the full training corpus, and complete hyperparameters (Appendix G) are all publicly available. The only missing detail is the specific GPU cluster configuration used for pretraining.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>330B tokens</td>
          <td>One full epoch for all model sizes</td>
      </tr>
      <tr>
          <td>Context extension</td>
          <td>The Pile</td>
          <td>210B additional tokens</td>
          <td>Progressive doubling: 1024 to 8192</td>
      </tr>
      <tr>
          <td>NLP evaluation</td>
          <td>ARC, BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, Winogrande</td>
          <td>Various</td>
          <td>Zero-shot evaluation</td>
      </tr>
      <tr>
          <td>Long-range evaluation</td>
          <td>Long Range Arena (LRA)</td>
          <td>1K-16K tokens</td>
          <td>Five sub-tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam ($\beta = (0.9, 0.99)$), no weight decay</li>
<li>Precision: bfloat16</li>
<li>Training context length: 1024 tokens</li>
<li>Learning rate: held constant during warmup, then exponential decay</li>
<li>Auxiliary loss from PaLM (softmax normalizer regularization)</li>
<li>Batch size: 128 or 256 sequences (dynamically switched)</li>
<li>Training organized into mini-epochs of 40,320 samples each (8,043 mini-epochs per Pile epoch)</li>
</ul>
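<p>As a quick consistency check on the mini-epoch bookkeeping, the token count per Pile epoch can be recomputed from the figures in the list above, assuming every sample is a full 1024-token sequence:</p>

```python
SAMPLES_PER_MINI_EPOCH = 40_320
CONTEXT_LENGTH = 1_024  # assumption: each sample fills the training context
MINI_EPOCHS_PER_EPOCH = 8_043

tokens_per_mini_epoch = SAMPLES_PER_MINI_EPOCH * CONTEXT_LENGTH
tokens_per_epoch = tokens_per_mini_epoch * MINI_EPOCHS_PER_EPOCH

print(f"{tokens_per_mini_epoch:,} tokens per mini-epoch")
print(f"{tokens_per_epoch / 1e9:.1f}B tokens per epoch")
```

<p>The result is about 332B tokens, in line with the roughly 330B tokens quoted for one epoch of the Pile.</p>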
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Init LR</th>
          <th>Warmup Mini-Epochs</th>
          <th>End LR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>6e-4</td>
          <td>361</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>4e-4</td>
          <td>411</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>3e-4</td>
          <td>443</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>1.5e-4</td>
          <td>451</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>1.5e-4</td>
          <td>465</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>1e-4</td>
          <td>544</td>
          <td>7e-6</td>
      </tr>
  </tbody>
</table>
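<p>One way to read the table is as the three parameters of a per-model schedule. The sketch below assumes the rate stays at the initial value for the listed number of warmup mini-epochs and then decays geometrically to the end value over the rest of the 8,043 mini-epochs. The exact decay form is an assumption; the notes state only that the decay is exponential.</p>

```python
import math

TOTAL_MINI_EPOCHS = 8_043

def rwkv_lr(mini_epoch, init_lr, end_lr, warmup, total=TOTAL_MINI_EPOCHS):
    """Constant at init_lr through warmup, then geometric decay to end_lr."""
    if mini_epoch <= warmup:
        return init_lr
    progress = min(1.0, (mini_epoch - warmup) / (total - warmup))
    # Interpolate linearly in log space between init_lr and end_lr.
    return init_lr * math.exp(progress * math.log(end_lr / init_lr))

# (init LR, end LR, warmup mini-epochs), copied from the table.
SCHEDULES = {
    "169M": (6e-4, 1e-5, 361),
    "430M": (4e-4, 1e-5, 411),
    "1.5B": (3e-4, 1e-5, 443),
    "3B": (1.5e-4, 1e-5, 451),
    "7B": (1.5e-4, 1e-5, 465),
    "14B": (1e-4, 7e-6, 544),
}

init_lr, end_lr, warmup = SCHEDULES["169M"]
midpoint_lr = rwkv_lr(4_000, init_lr, end_lr, warmup)
final_lr = rwkv_lr(TOTAL_MINI_EPOCHS, init_lr, end_lr, warmup)
```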
<p>All pretrained models (169M to 14B) are publicly released on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>) under Apache-2.0. Training code is at <a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a> (Apache-2.0).</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>All NLP benchmarks evaluated in zero-shot setting</li>
<li>FLOP-matched comparison against Pythia, OPT, BLOOM</li>
<li>Inference benchmarked on CPU (x86) and GPU (NVIDIA A100 80GB) at float32</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Inference experiments: NVIDIA A100 80GB GPU</li>
<li>Training hardware details not fully specified; FLOP budgets reported per model</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., &hellip; &amp; Zhu, R.-J. (2023). RWKV: Reinventing RNNs for the Transformer Era. In <em>Findings of the Association for Computational Linguistics: EMNLP 2023</em>, pp. 14048-14077.</p>
<p><strong>Publication</strong>: Findings of EMNLP 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/BlinkDL/RWKV-LM">GitHub Repository (Apache-2.0)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{peng2023rwkv,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{RWKV: Reinventing RNNs for the Transformer Era}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Peng, Bo and Alcaide, Eric and Anthony, Quentin and Albalak, Alon and Arcadinho, Samuel and Biderman, Stella and Cao, Huanqi and Cheng, Xin and Chung, Michael and Derczynski, Leon and Du, Xingjian and Grella, Matteo and GV, Kranthi Kiran and He, Xuzheng and Hou, Haowen and Kazienko, Przemys{\l}aw and Koco{\&#39;n}, Jan and Kong, Jiaming and Koptyra, Bart{\l}omiej and Lau, Hayden and Lin, Jiaju and Mantri, Krishna Sri Ipsit and Mom, Ferdinand and Saito, Atsushi and Song, Guangyu and Tang, Xiangru and Wind, Johan S. and Wo{\&#39;z}niak, Stanis{\l}aw and Zhang, Zhenyuan and Zhou, Qinghua and Zhu, Jian and Zhu, Rui-Jie}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Findings of the Association for Computational Linguistics: EMNLP 2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{14048--14077}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2023.findings-emnlp.936}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Liquid-S4: Input-Dependent State-Space Models</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/liquid-s4-state-space-models/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/liquid-s4-state-space-models/</guid><description>Liquid-S4 combines liquid time-constant networks with structured state-space models, adding input-dependent kernels for long-range sequence modeling.</description><content:encoded><![CDATA[<h2 id="a-method-for-input-adaptive-sequence-modeling">A Method for Input-Adaptive Sequence Modeling</h2>
<p>This is a <strong>Method</strong> paper that introduces Liquid-S4, a new state-space model combining the structured state-space framework (S4) with liquid time-constant (LTC) networks. The primary contribution is an input-dependent state transition mechanism that allows the model to adapt its dynamics based on incoming inputs, while retaining the efficient convolutional kernel computation of S4.</p>
<h2 id="scaling-liquid-networks-to-long-sequences">Scaling Liquid Networks to Long Sequences</h2>
<p>Liquid time-constant (LTC) networks are continuous-time neural networks with input-dependent state transitions, giving them strong generalization and causal modeling properties. However, LTCs rely on ODE solvers that limit their scalability to long sequences. Structured state-space models (S4) solve this scalability problem through HiPPO initialization, diagonal plus low-rank (DPLR) parameterization, and efficient Cauchy kernel computation in the frequency domain, but they use fixed (input-independent) state transitions.</p>
<p>The key question this paper addresses: can the expressivity of LTC networks be combined with the efficiency and scalability of S4 to improve long-range sequence modeling?</p>
<h2 id="the-liquid-kernel-input-dependent-convolutions">The Liquid Kernel: Input-Dependent Convolutions</h2>
<p>The core innovation is a linearized LTC state-space model that replaces the standard SSM dynamics:</p>
<p>$$\dot{x}(t) = \mathbf{A}x(t) + \mathbf{B}u(t)$$</p>
<p>with an input-dependent formulation:</p>
<p>$$\dot{x}(t) = \left[\mathbf{A} + \mathbf{B}u(t)\right]x(t) + \mathbf{B}u(t)$$</p>
<p>where $u(t)$ now modulates the state transition matrix itself. After discretization via the <a href="https://en.wikipedia.org/wiki/Bilinear_transform">bilinear transform</a>, the recurrence becomes:</p>
<p>$$x_{k} = \left(\overline{\mathbf{A}} + \overline{\mathbf{B}}u_{k}\right)x_{k-1} + \overline{\mathbf{B}}u_{k}$$</p>
<p>Unrolling this recurrence reveals that the output $y_{k}$ decomposes into two parts:</p>
<p>$$y = \overline{\mathbf{K}} * u + \overline{\mathbf{K}}_{\text{liquid}} * u_{\text{correlations}}$$</p>
<p>The first term is the standard S4 convolutional kernel $\overline{\mathbf{K}}$, mapping individual input time steps independently. The second term is a new &ldquo;liquid kernel&rdquo; $\overline{\mathbf{K}}_{\text{liquid}}$ that operates on <a href="https://en.wikipedia.org/wiki/Autocorrelation">auto-correlation</a> terms of the input signal (products $u_{i}u_{j}$, $u_{i}u_{j}u_{k}$, etc., up to a chosen order $\mathcal{P}$).</p>
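<p>A minimal sketch of this decomposition, assuming a scalar state so that $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, and the output map reduce to hypothetical numbers <code>a</code>, <code>b</code>, <code>c</code> (illustrative values, not taken from the paper) and a zero initial state. It runs the liquid recurrence, subtracts the ordinary convolution with the kernel $c\,a^{j}b$, and prints what is left, which should consist only of products of inputs:</p>

```python
def liquid_recurrence(a, b, c, u):
    """Run x_k = (a + b*u_k) * x_{k-1} + b*u_k from a zero state; return y_k = c*x_k."""
    x, ys = 0.0, []
    for u_k in u:
        x = (a + b * u_k) * x + b * u_k
        ys.append(c * x)
    return ys


def linear_kernel_output(a, b, c, u):
    """Causal convolution of u with the S4-style kernel K_j = c * a**j * b."""
    kernel = [c * a**j * b for j in range(len(u))]
    return [sum(kernel[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]


# Hypothetical scalar parameters and a short input sequence.
a, b, c = 0.9, 0.5, 2.0
u = [1.0, -0.4, 0.7]

y = liquid_recurrence(a, b, c, u)
y_linear = linear_kernel_output(a, b, c, u)

# Whatever the linear kernel does not explain comes from input correlations.
liquid_part = [yk - lk for yk, lk in zip(y, y_linear)]
print(liquid_part)
```

<p>In this scalar setting the residual at step 1 is $c\,b^{2}u_{0}u_{1}$, and at step 2 it also contains the third-order product $c\,b^{3}u_{0}u_{1}u_{2}$. The powers of $b$ grow with the correlation order, which is the pattern the liquid kernels encode.</p>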
<p><strong>Proposition 1</strong> shows that each liquid kernel of order $p$ can be computed from the precomputed S4 kernel via a <a href="https://en.wikipedia.org/wiki/Hadamard_product_(matrices)">Hadamard product</a> with $\overline{\mathbf{B}}^{p-1}$ followed by an anti-diagonal transformation (flip):</p>
<p>$$\overline{\mathbf{K}}_{\text{liquid}=p} = \left[\overline{\mathbf{K}}_{(L-\tilde{L},L)} \odot \overline{\mathbf{B}}_{(L-\tilde{L},L)}^{p-1}\right] * \mathbf{J}_{\tilde{L}}$$</p>
<p>This is the KB (Kernel $\times$ B) mode. The authors also propose a simplified PB (Powers of B) mode that sets the transition matrix $\overline{\mathbf{A}}$ to identity for the correlation terms:</p>
<p>$$\overline{\mathbf{K}}_{\text{liquid}=p} = \overline{\mathbf{C}} \odot \overline{\mathbf{B}}^{p-1}$$</p>
<p>The PB kernel is cheaper to compute and performs equally well or better in practice.</p>
<p>The computational complexity is $\tilde{\mathcal{O}}(N + L + p_{\text{max}}\tilde{L})$, where $N$ is the state size, $L$ the sequence length, $p_{\text{max}}$ the maximum liquid order, and $\tilde{L}$ the liquid kernel length (typically two orders of magnitude smaller than $L$).</p>
<h2 id="benchmarks-across-long-range-sequence-tasks">Benchmarks Across Long-Range Sequence Tasks</h2>
<p>Liquid-S4 is evaluated on four benchmark suites with the PB kernel using the S4-LegS (scaled <a href="https://en.wikipedia.org/wiki/Legendre_polynomials">Legendre</a>) parameterization.</p>
<h3 id="long-range-arena-lra">Long Range Arena (LRA)</h3>
<p>The LRA benchmark contains six tasks with sequence lengths from 1K to 16K. Liquid-S4 achieves state-of-the-art results on all six tasks, with an average accuracy of 87.32% (the Improvement column is in percentage points):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Input Length</th>
          <th>Liquid-S4</th>
          <th>S4-LegS</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ListOps</td>
          <td>2048</td>
          <td>62.75%</td>
          <td>59.60%</td>
          <td>+3.15%</td>
      </tr>
      <tr>
          <td>Text (IMDB)</td>
          <td>2048</td>
          <td>89.02%</td>
          <td>86.82%</td>
          <td>+2.20%</td>
      </tr>
      <tr>
          <td>Retrieval (AAN)</td>
          <td>4000</td>
          <td>91.20%</td>
          <td>90.90%</td>
          <td>+0.30%</td>
      </tr>
      <tr>
          <td>Image (CIFAR)</td>
          <td>1024</td>
          <td>89.50%</td>
          <td>88.65%</td>
          <td>+0.85%</td>
      </tr>
      <tr>
          <td>Pathfinder</td>
          <td>1024</td>
          <td>94.80%</td>
          <td>94.20%</td>
          <td>+0.60%</td>
      </tr>
      <tr>
          <td>Path-X</td>
          <td>16384</td>
          <td>96.66%</td>
          <td>96.35%</td>
          <td>+0.31%</td>
      </tr>
      <tr>
          <td><strong>Average</strong></td>
          <td></td>
          <td><strong>87.32%</strong></td>
          <td><strong>86.09%</strong></td>
          <td><strong>+1.23%</strong></td>
      </tr>
  </tbody>
</table>
<p>Liquid orders $p$ range from 2 to 6 across tasks.</p>
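<p>The averages in the LRA table can be re-derived from the per-task rows. The short check below copies the six accuracies for each model from the table; the variable names are illustrative only:</p>

```python
# Accuracies (%) copied from the LRA table, in row order:
# ListOps, Text (IMDB), Retrieval (AAN), Image (CIFAR), Pathfinder, Path-X.
liquid_s4 = [62.75, 89.02, 91.20, 89.50, 94.80, 96.66]
s4_legs = [59.60, 86.82, 90.90, 88.65, 94.20, 96.35]


def mean(values):
    return sum(values) / len(values)


avg_liquid = round(mean(liquid_s4), 2)
avg_s4 = round(mean(s4_legs), 2)
gain_points = round(avg_liquid - avg_s4, 2)
print(avg_liquid, avg_s4, gain_points)
```

<p>The three printed values match the table&rsquo;s Average row: 87.32, 86.09, and a gain of 1.23 percentage points.</p>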
<h3 id="bidmc-vital-signs">BIDMC Vital Signs</h3>
<p>On medical time-series regression (heart rate, respiratory rate, <a href="https://en.wikipedia.org/wiki/Oxygen_saturation_(medicine)">SpO2</a> prediction from length-4000 biomarker signals):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Liquid-S4 (RMSE)</th>
          <th>S4-LegS (RMSE)</th>
          <th>Relative RMSE Reduction</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Heart Rate</td>
          <td>0.303</td>
          <td>0.332</td>
          <td>8.7%</td>
      </tr>
      <tr>
          <td>Respiratory Rate</td>
          <td>0.158</td>
          <td>0.247</td>
          <td>36.0%</td>
      </tr>
      <tr>
          <td>SpO2</td>
          <td>0.066</td>
          <td>0.090</td>
          <td>26.7%</td>
      </tr>
  </tbody>
</table>
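<p>The percentages in the last column of the BIDMC table are consistent with a relative reduction in RMSE, computed as (baseline &minus; new) / baseline. A quick check under that assumption, with values copied from the table and hypothetical names:</p>

```python
# (Liquid-S4 RMSE, S4-LegS RMSE) pairs copied from the BIDMC table.
rmse = {
    "Heart Rate": (0.303, 0.332),
    "Respiratory Rate": (0.158, 0.247),
    "SpO2": (0.066, 0.090),
}


def relative_reduction(new, baseline):
    """Percentage by which `new` undercuts `baseline` (lower RMSE is better)."""
    return 100.0 * (baseline - new) / baseline


reductions = {
    task: round(relative_reduction(new, base), 1)
    for task, (new, base) in rmse.items()
}
print(reductions)
```

<p>Because these are relative reductions, they are not directly comparable to the percentage-point accuracy gains reported for LRA.</p>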
<h3 id="sequential-cifar-scifar">Sequential CIFAR (sCIFAR)</h3>
<p>Liquid-S4 with $p=3$ achieves 92.02% accuracy on 1-D pixel-level image classification, improving over S4-LegS (91.80%).</p>
<h3 id="speech-commands-full-35-labels">Speech Commands (Full 35 Labels)</h3>
<p>On the raw 16kHz speech recognition task, Liquid-S4 achieves 96.78% accuracy with only 224K parameters, about 27% fewer than S4&rsquo;s 307K (the paper rounds this to a 30% reduction). In the zero-shot 8kHz experiment, accuracy drops to 90.00% (vs. 91.32% for S4-LegS), which the authors attribute to the liquid kernel&rsquo;s sensitivity to input covariance structure at different sampling rates.</p>
<h2 id="consistent-improvements-with-smaller-models">Consistent Improvements with Smaller Models</h2>
<p>Liquid-S4 achieves state-of-the-art performance on every benchmark evaluated at its native sampling rate: all six LRA tasks (87.32% average), all three BIDMC vital signs tasks, sCIFAR, and full Speech Commands recognition. The gains are largest on tasks where input correlation structure matters (ListOps +3.15 points, IMDB +2.20 points, and a 36% reduction in respiratory rate RMSE).</p>
<p>A practical advantage is that Liquid-S4 works well with smaller state sizes (as low as 7 units for some tasks), reducing parameter counts. The PB kernel is recommended over KB for its simplicity and competitive performance. Higher liquid orders ($p$) consistently improve performance, though $p=3$ is recommended as a default.</p>
<p>Limitations include degraded performance in zero-shot frequency transfer (8kHz Speech Commands), suggesting the liquid kernel&rsquo;s input covariance terms may not generalize well across sampling rate changes. The paper also does not compare against non-SSM approaches beyond the LRA benchmark. The causal (unidirectional) configuration works better than bidirectional for Liquid-S4, which may limit applicability to tasks that benefit from bidirectional context.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Classification: Partially Reproducible.</strong> Code and all benchmark datasets are publicly available, with complete hyperparameters documented. No pre-trained weights are released and hardware requirements are not specified.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/raminmh/liquid-s4">raminmh/liquid-s4</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch implementation; fork of the S4 repo with KB and PB kernels added</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>Long Range Arena (LRA)</td>
          <td>6 tasks, 1K-16K seq length</td>
          <td>ListOps, IMDB, AAN, CIFAR, Pathfinder, Path-X</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BIDMC Vital Signs</td>
          <td>4000-length biomarker signals</td>
          <td>Heart rate, respiratory rate, SpO2</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>sCIFAR</td>
          <td>1024-length flattened images</td>
          <td>10-class classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Speech Commands</td>
          <td>16kHz raw audio, 35 labels</td>
          <td>Full dataset with zero-shot 8kHz test</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The Liquid-S4 kernel computation builds on the S4 kernel pipeline:</p>
<ol>
<li>Initialize $\mathbf{A}$ with HiPPO (scaled Legendre) matrix in DPLR form</li>
<li>Compute S4 kernel $\overline{\mathbf{K}}$ via Cauchy kernel and iFFT</li>
<li>For each liquid order $p \in \{2, \ldots, \mathcal{P}\}$, compute $\overline{\mathbf{K}}_{\text{liquid}=p}$ using either KB or PB mode</li>
<li>Convolve $\overline{\mathbf{K}}_{\text{liquid}}$ with input correlation vector $u_{\text{correlations}}$</li>
</ol>
<p>The PB kernel mode is used in all reported experiments. The PyKeops package is used for large tensor computations.</p>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Depth</th>
          <th>Features</th>
          <th>State Size</th>
          <th>Norm</th>
          <th>LR</th>
          <th>Epochs</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ListOps</td>
          <td>9</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.002</td>
          <td>30</td>
      </tr>
      <tr>
          <td>IMDB</td>
          <td>4</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.003</td>
          <td>50</td>
      </tr>
      <tr>
          <td>AAN</td>
          <td>6</td>
          <td>256</td>
          <td>64</td>
          <td>BN</td>
          <td>0.005</td>
          <td>20</td>
      </tr>
      <tr>
          <td>CIFAR (LRA)</td>
          <td>6</td>
          <td>512</td>
          <td>512</td>
          <td>LN</td>
          <td>0.01</td>
          <td>200</td>
      </tr>
      <tr>
          <td>Pathfinder</td>
          <td>6</td>
          <td>256</td>
          <td>64</td>
          <td>BN</td>
          <td>0.0004</td>
          <td>200</td>
      </tr>
      <tr>
          <td>Path-X</td>
          <td>6</td>
          <td>320</td>
          <td>64</td>
          <td>BN</td>
          <td>0.001</td>
          <td>60</td>
      </tr>
      <tr>
          <td>Speech Commands</td>
          <td>6</td>
          <td>128</td>
          <td>7</td>
          <td>BN</td>
          <td>0.008</td>
          <td>50</td>
      </tr>
      <tr>
          <td>BIDMC (HR)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.005</td>
          <td>500</td>
      </tr>
      <tr>
          <td>BIDMC (RR)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.01</td>
          <td>500</td>
      </tr>
      <tr>
          <td>BIDMC (SpO2)</td>
          <td>6</td>
          <td>128</td>
          <td>256</td>
          <td>LN</td>
          <td>0.01</td>
          <td>500</td>
      </tr>
      <tr>
          <td>sCIFAR</td>
          <td>6</td>
          <td>512</td>
          <td>512</td>
          <td>LN</td>
          <td>0.01</td>
          <td>200</td>
      </tr>
  </tbody>
</table>
<p>Liquid-S4 generally requires smaller learning rates than S4/S4D. $\Delta t_{\text{max}} = 0.2$ for all experiments; $\Delta t_{\text{min}} \propto 1/\text{seq_length}$.</p>
<h3 id="evaluation">Evaluation</h3>
<p>All results report validation accuracy (except BIDMC, which reports test RMSE). Experiments use 2-3 random seeds with standard deviations reported.</p>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hasani, R., Lechner, M., Wang, T.-H., Chahine, M., Amini, A., &amp; Rus, D. (2022). Liquid Structural State-Space Models. <em>arXiv preprint arXiv:2209.12951</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{hasani2022liquid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Liquid Structural State-Space Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hasani, Ramin and Lechner, Mathias and Wang, Tsun-Hsuan and Chahine, Makram and Amini, Alexander and Rus, Daniela}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2209.12951}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Lagrangian Neural Networks for Physics</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/lagrangian-neural-networks/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/lagrangian-neural-networks/</guid><description>LNNs parameterize arbitrary Lagrangians with neural networks, learning energy-conserving dynamics without requiring canonical coordinates.</description><content:encoded><![CDATA[<h2 id="a-method-for-learning-arbitrary-lagrangians">A Method for Learning Arbitrary Lagrangians</h2>
<p>This is a <strong>Method</strong> paper that introduces Lagrangian Neural Networks (LNNs), a neural network architecture that parameterizes arbitrary Lagrangians to learn energy-conserving dynamics from data. The key contribution is showing that neural networks can learn Lagrangian functions directly, and that the Euler-Lagrange equation can be solved numerically using automatic differentiation to produce physically consistent dynamics. The approach is strictly more general than prior methods: it does not require canonical coordinates (unlike Hamiltonian Neural Networks) and does not restrict the functional form of kinetic energy (unlike Deep Lagrangian Networks).</p>
<h2 id="why-standard-neural-networks-fail-at-conservation-laws">Why Standard Neural Networks Fail at Conservation Laws</h2>
<p>Neural networks struggle to learn fundamental symmetries and conservation laws from data. A standard neural network trained on trajectories of a <a href="https://en.wikipedia.org/wiki/Double_pendulum">double pendulum</a> will gradually dissipate energy over long rollouts, producing physically implausible behavior. This happens because unconstrained function approximators have no inductive bias toward conservation.</p>
<p>Hamiltonian Neural Networks (HNNs) addressed this by learning a Hamiltonian function, which automatically enforces energy conservation. However, the <a href="https://en.wikipedia.org/wiki/Hamiltonian_mechanics">Hamiltonian formalism</a> requires inputs in <a href="https://en.wikipedia.org/wiki/Canonical_coordinates">canonical coordinates</a> $(q, p)$ satisfying strict <a href="https://en.wikipedia.org/wiki/Poisson_bracket">Poisson bracket</a> relations:</p>
<p>$$
p_i \equiv \frac{\partial \mathcal{L}}{\partial \dot{q}_i} \quad \Longleftrightarrow \quad \{q_i, q_j\} = 0, \quad \{p_i, p_j\} = 0, \quad \{q_i, p_j\} = \delta_{ij}
$$</p>
<p>In many real-world settings, the canonical momenta are unknown or difficult to compute. For example, in special relativity the canonical momentum $\dot{q}(1 - \dot{q}^2)^{-3/2}$ is a complex nonlinear function of velocity. Deep Lagrangian Networks (DeLaNs) partially addressed this by learning Lagrangians, but they assumed kinetic energy takes the rigid-body form $T = \dot{q}^T M \dot{q}$, which excludes relativistic and other non-standard systems.</p>
<h2 id="solving-euler-lagrange-for-a-black-box-lagrangian">Solving Euler-Lagrange for a Black-Box Lagrangian</h2>
<p>The core innovation of LNNs is a method for computing accelerations from a neural network that represents an arbitrary Lagrangian $\mathcal{L}(q, \dot{q})$. Starting from the <a href="https://en.wikipedia.org/wiki/Euler%E2%80%93Lagrange_equation">Euler-Lagrange equation</a>:</p>
<p>$$
\frac{d}{dt} \nabla_{\dot{q}} \mathcal{L} = \nabla_{q} \mathcal{L}
$$</p>
<p>The authors expand the time derivative using the chain rule, yielding:</p>
<p>$$
\left(\nabla_{\dot{q}} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \ddot{q} + \left(\nabla_{q} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \dot{q} = \nabla_{q} \mathcal{L}
$$</p>
<p>Solving for the accelerations gives:</p>
<p>$$
\ddot{q} = \left(\nabla_{\dot{q}} \nabla_{\dot{q}}^{\top} \mathcal{L}\right)^{-1} \left[ \nabla_{q} \mathcal{L} - \left(\nabla_{q} \nabla_{\dot{q}}^{\top} \mathcal{L}\right) \dot{q} \right]
$$</p>
<p>This requires computing the Hessian of the neural network with respect to $\dot{q}$ and then inverting it (using a pseudoinverse for numerical stability). JAX&rsquo;s automatic differentiation makes this feasible in just a few lines of code, despite the seemingly complex chain of second-order derivatives. The matrix inverse scales as $\mathcal{O}(d^3)$ with the number of coordinates $d$.</p>
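<p>A dependency-free sketch of the acceleration formula, restricted to a single coordinate so that the Hessian is a scalar and the inverse is a division. It swaps the paper&rsquo;s JAX autodiff for central finite differences, purely for illustration. The function names, the step size <code>h</code>, and the harmonic-oscillator test case are assumptions of this sketch; the second Lagrangian is the paper&rsquo;s relativistic example with $c = 1$ and unit mass.</p>

```python
def acceleration(lagrangian, q, v, h=1e-4):
    """Solve the Euler-Lagrange equation for one coordinate.

    Uses central finite differences in place of autodiff (an assumption of
    this sketch); with a single coordinate the Hessian is a scalar, so the
    matrix inverse reduces to a division.
    """
    L = lagrangian
    dL_dq = (L(q + h, v) - L(q - h, v)) / (2 * h)
    d2L_dv2 = (L(q, v + h) - 2 * L(q, v) + L(q, v - h)) / h**2
    d2L_dqdv = (L(q + h, v + h) - L(q + h, v - h)
                - L(q - h, v + h) + L(q - h, v - h)) / (4 * h**2)
    return (dL_dq - d2L_dqdv * v) / d2L_dv2


def harmonic_oscillator(q, v):
    # Unit mass and spring constant; the exact acceleration is -q.
    return 0.5 * v**2 - 0.5 * q**2


def relativistic_particle(q, v, g=1.0):
    # Relativistic example Lagrangian with c = 1 and mass = 1.
    return ((1 - v**2) ** -0.5 - 1) + g * q


print(acceleration(harmonic_oscillator, 0.3, -1.2))
print(acceleration(relativistic_particle, 0.0, 0.5))
```

<p>For the relativistic case the velocity Hessian is $(1 + 2\dot{q}^2)(1 - \dot{q}^2)^{-5/2}$, so the acceleration shrinks as $|\dot{q}|$ approaches 1. With $d$ coordinates the division becomes a $d \times d$ solve, which is where the $\mathcal{O}(d^3)$ cost comes from.</p>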
<p>A critical implementation detail is the choice of activation function. Since the method takes second-order derivatives of the network, ReLU is unsuitable (its second derivative is zero almost everywhere, so the velocity Hessian of a ReLU network carries no usable signal). After a hyperparameter search over ReLU$^2$, ReLU$^3$, tanh, sigmoid, and softplus, the authors found that <a href="https://en.wikipedia.org/wiki/Softplus">softplus</a> performed best.</p>
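<p>To make the activation argument concrete (a standard calculus fact added here, not a derivation quoted from the paper), the first two derivatives of softplus are:</p>
<p>$$
\text{softplus}(x) = \log\left(1 + e^{x}\right), \quad \frac{d}{dx}\,\text{softplus}(x) = \frac{1}{1 + e^{-x}}, \quad \frac{d^{2}}{dx^{2}}\,\text{softplus}(x) = \frac{e^{-x}}{\left(1 + e^{-x}\right)^{2}} &gt; 0
$$</p>
<p>By contrast, $\frac{d^{2}}{dx^{2}}\,\text{ReLU}(x) = 0$ for every $x \neq 0$. A ReLU MLP is piecewise linear in its inputs, so its Hessian with respect to $\dot{q}$ vanishes almost everywhere and cannot be inverted in the acceleration formula.</p>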
<p>The authors also developed a custom initialization scheme, using symbolic regression to find initialization variances that maintain well-conditioned gradients through the Hessian computation:</p>
<p>$$
\sigma = \frac{1}{\sqrt{n}} \begin{cases} 2.2 &amp; \text{First layer} \\ 0.58i &amp; \text{Hidden layer } i \\ n &amp; \text{Output layer} \end{cases}
$$</p>
<h2 id="extension-to-graphs-and-continuous-systems">Extension to Graphs and Continuous Systems</h2>
<p>LNNs extend naturally to graph-structured and continuous systems via Lagrangian <a href="/notes/machine-learning/model-architectures/relational-inductive-biases-deep-learning-graph-networks/">Graph Networks</a>. For a system with $n$ gridpoints, the total Lagrangian is decomposed into local densities:</p>
<p>$$
\mathcal{L} = \sum_{i=1}^{n} \mathcal{L}_i, \quad \text{where} \quad \mathcal{L}_i = \mathcal{L}_{\text{density}}\left(\{\phi_j, \dot{\phi}_j\}_{j \in \mathcal{I}_i}\right)
$$</p>
<p>Here $\mathcal{I}_i$ defines the neighborhood of node $i$ (e.g., $\{i-1, i, i+1\}$ for a 1D grid). The Lagrangian density is modeled as an MLP. The resulting Hessian matrix is sparse, with non-zero entries only at &ldquo;neighbor of neighbor&rdquo; positions, enabling efficient computation: in 1D, only 5 forward-over-backward autodiff passes are needed, and the tridiagonal inverse runs in linear time.</p>
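<p>A small sketch of the local decomposition on a periodic 1D grid. The <code>density</code> function is a hypothetical hand-written stand-in for the learned MLP (a unit-spacing discretization of $\dot{\phi}^2 - (\partial \phi / \partial x)^2$), and all names are illustrative. It shows that perturbing one node changes only the terms of that node and its two neighbors, which is the locality behind the sparse Hessian:</p>

```python
def density(phi_left, phi_mid, phi_right, dphi_mid):
    """Hypothetical hand-written local density standing in for the MLP.

    Discretizes dphi**2 - (d phi / dx)**2 on a unit-spacing grid, splitting
    each squared difference evenly between the two nodes it touches.
    """
    kinetic = dphi_mid**2
    potential = 0.5 * ((phi_right - phi_mid) ** 2 + (phi_mid - phi_left) ** 2)
    return kinetic - potential


def local_terms(phi, dphi):
    """Per-node terms L_i on a periodic 1D grid with neighborhood {i-1, i, i+1}."""
    n = len(phi)
    return [
        density(phi[(i - 1) % n], phi[i], phi[(i + 1) % n], dphi[i])
        for i in range(n)
    ]


def total_lagrangian(phi, dphi):
    """Total Lagrangian as the sum of the local terms."""
    return sum(local_terms(phi, dphi))


# Perturbing one node only changes the terms of that node and its neighbors.
phi = [0.0, 0.1, 0.4, 0.2, -0.3, 0.0, 0.5, 0.1]
dphi = [0.0] * len(phi)
before = local_terms(phi, dphi)
phi[5] += 2.0
after = local_terms(phi, dphi)
changed = [i for i in range(len(phi)) if before[i] != after[i]]
print(changed)
```

<p>Only indices 4, 5, and 6 change. Two nodes interact in the Hessian only when they share a local term, which limits the non-zero entries to nearby grid positions.</p>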
<h2 id="experiments-double-pendulum-relativity-and-waves">Experiments: Double Pendulum, Relativity, and Waves</h2>
<p>All models used 4-layer MLPs with 500 hidden units, softplus activations, a decaying learning rate starting at $10^{-3}$, and batch size 32.</p>
<h3 id="double-pendulum">Double Pendulum</h3>
<p>The LNN and baseline achieved similar instantaneous acceleration losses ($7.3 \times 10^{-2}$ for the LNN vs. $7.4 \times 10^{-2}$ for the baseline). The key difference appeared in long-term energy conservation: averaged over 40 random initial conditions with 100 time steps each, the mean energy discrepancy was 8% of max potential energy for the baseline but only 0.4% for the LNN.</p>
<h3 id="relativistic-particle">Relativistic Particle</h3>
<p>For a particle with Lagrangian $\mathcal{L} = ((1 - \dot{q}^2)^{-1/2} - 1) + gq$, the canonical momenta $\dot{q}(1 - \dot{q}^2)^{-3/2}$ are non-trivial. An HNN trained on non-canonical coordinates $(q, \dot{q})$ failed to learn the dynamics. The LNN succeeded using the same non-canonical coordinates, matching the performance of an HNN given the correct canonical coordinates.</p>
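<p>The momentum expression follows from differentiating this Lagrangian with respect to velocity (a one-line check added here, not quoted from the paper):</p>
<p>$$
p = \frac{\partial \mathcal{L}}{\partial \dot{q}} = \frac{\partial}{\partial \dot{q}}\left(1 - \dot{q}^2\right)^{-1/2} = -\frac{1}{2}\left(1 - \dot{q}^2\right)^{-3/2}\left(-2\dot{q}\right) = \dot{q}\left(1 - \dot{q}^2\right)^{-3/2}
$$</p>
<p>The potential term $gq$ does not depend on $\dot{q}$ and drops out. An HNN needs this nonlinear map from $\dot{q}$ to $p$ supplied in advance, whereas the LNN works directly with $(q, \dot{q})$.</p>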
<h3 id="1d-wave-equation">1D Wave Equation</h3>
<p>The Lagrangian Graph Network learned the wave equation dynamics ($\ddot{\phi} = \frac{\partial^2 \phi}{\partial x^2}$ with $c = 1$) on a 100-gridpoint domain with periodic boundary conditions. The network learned the Lagrangian density corresponding to the continuum form $\mathcal{L} = \int (\dot{\phi}^2 - (\partial \phi / \partial x)^2) dx$, accurately modeling wave propagation and conserving energy across the material.</p>
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Model</th>
          <th>Outcome (energy error as % of max PE where reported)</th>
          <th>Canonical Coords Required</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Double Pendulum</td>
          <td>Baseline</td>
          <td>8%</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Double Pendulum</td>
          <td>LNN</td>
          <td>0.4%</td>
          <td>No</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>HNN (non-canonical)</td>
          <td>Failed</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>HNN (canonical)</td>
          <td>Succeeded</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Relativistic Particle</td>
          <td>LNN</td>
          <td>Succeeded</td>
          <td>No</td>
      </tr>
      <tr>
          <td>1D Wave Equation</td>
          <td>Lagrangian Graph Network</td>
          <td>Energy conserved</td>
          <td>No</td>
      </tr>
  </tbody>
</table>
<h2 id="findings-and-comparison-to-prior-approaches">Findings and Comparison to Prior Approaches</h2>
<p>LNNs combine several desirable properties that no single prior method offers:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Neural Net</th>
          <th>Neural ODE</th>
          <th>HNN</th>
          <th>DeLaN</th>
          <th>LNN</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Models dynamical systems</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns differential equations</td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns exact conservation laws</td>
          <td></td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns from arbitrary coordinates</td>
          <td>Yes</td>
          <td>Yes</td>
          <td></td>
          <td>Yes</td>
          <td>Yes</td>
      </tr>
      <tr>
          <td>Learns arbitrary Lagrangians</td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td>Yes</td>
      </tr>
  </tbody>
</table>
<p>The main limitation is computational cost: the Hessian computation and inversion scale as $\mathcal{O}(d^3)$ in the number of coordinates. The Lagrangian Graph Network partially mitigates this for spatially extended systems through the sparsity of the resulting Hessian. The method also assumes access to state derivatives ($\dot{q}$) during training, which may not always be directly available from observations.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Double pendulum</td>
          <td>600,000 random initial conditions</td>
          <td>Simulated with masses and lengths set to 1</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>Relativistic particle</td>
          <td>Random initial conditions and $g$ values</td>
          <td>$c = 1$, mass = 1, uniform potential</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>1D wave equation</td>
          <td>100 gridpoints</td>
          <td>Periodic boundary conditions, $c = 1$</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Forward model: Euler-Lagrange equation solved for the accelerations (Equation 6 in the paper) using JAX autodiff</li>
<li>Pseudoinverse used for Hessian inversion to handle potentially singular matrices</li>
<li>Custom initialization scheme (Equation 16 in the paper) derived via symbolic regression with Eureqa</li>
<li>Softplus activation selected via hyperparameter search</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>4-layer MLP with 500 hidden units for all experiments</li>
<li>Softplus activation function</li>
<li>Code: <a href="https://github.com/MilesCranmer/lagrangian_nns">github.com/MilesCranmer/lagrangian_nns</a> (Apache-2.0)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>LNN</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Acceleration loss (double pendulum)</td>
          <td>$7.3 \times 10^{-2}$</td>
          <td>$7.4 \times 10^{-2}$</td>
          <td>Similar short-term accuracy</td>
      </tr>
      <tr>
          <td>Energy error (double pendulum)</td>
          <td>0.4%</td>
          <td>8%</td>
          <td>Percentage of max potential energy</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. JAX-based implementation supports CPU and GPU execution.</p>
<hr>
<p><strong>Reproducibility Status</strong>: Highly Reproducible</p>
<h2 id="artifacts">Artifacts</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MilesCranmer/lagrangian_nns">lagrangian_nns</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX implementation with notebooks for all experiments</td>
      </tr>
      <tr>
          <td>Training data</td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Generated procedurally; simulation code included in repository</td>
      </tr>
      <tr>
          <td>Trained models</td>
          <td>Model</td>
          <td>N/A</td>
          <td>Not provided</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cranmer, M., Greydanus, S., Hoyer, S., Battaglia, P., Spergel, D., &amp; Ho, S. (2020). Lagrangian Neural Networks. <em>ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations</em>. arXiv: <a href="https://arxiv.org/abs/2003.04630">2003.04630</a></p>
<p><strong>Publication</strong>: ICLR 2020 Workshop on Integration of Deep Neural Models and Differential Equations</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{cranmer2020lagrangian,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Lagrangian Neural Networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cranmer, Miles and Greydanus, Sam and Hoyer, Stephan and Battaglia, Peter and Spergel, David and Ho, Shirley}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2003.04630}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Ewald Message Passing for Molecular Graphs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/ewald-message-passing-molecular-graphs/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/ewald-message-passing-molecular-graphs/</guid><description>Ewald message passing augments GNNs with Fourier-space long-range interactions, improving energy predictions by 10-16% on OC20 and OE62 benchmarks.</description><content:encoded><![CDATA[<h2 id="a-fourier-space-long-range-correction-for-molecular-gnns">A Fourier-Space Long-Range Correction for Molecular GNNs</h2>
<p>This is a <strong>Method</strong> paper that introduces Ewald message passing (Ewald MP), a general framework for incorporating long-range interactions into message passing neural networks (MPNNs) for molecular <a href="/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/">potential energy surface</a> prediction. The key contribution is a nonlocal Fourier-space message passing scheme, grounded in the classical <a href="https://en.wikipedia.org/wiki/Ewald_summation">Ewald summation</a> technique from computational physics, that complements the short-range message passing of existing GNN architectures.</p>
<h2 id="the-long-range-interaction-problem-in-molecular-gnns">The Long-Range Interaction Problem in Molecular GNNs</h2>
<p>Standard MPNNs for molecular property prediction rely on a spatial distance cutoff to define atomic neighborhoods. While this locality assumption enables favorable scaling with system size and provides a useful inductive bias, it fundamentally limits the model&rsquo;s ability to capture long-range interactions such as electrostatic forces and van der Waals (<a href="https://en.wikipedia.org/wiki/London_dispersion_force">London dispersion</a>) interactions. These interactions decay slowly with distance (e.g., electrostatic energy follows a $1/r$ power law), and truncating them with a distance cutoff can introduce severe artifacts in thermochemical predictions.</p>
<p>This problem is well known in molecular dynamics, where empirical force fields explicitly separate bonded (short-range) and non-bonded (long-range) energy terms. The Ewald summation technique addresses this by decomposing interactions into a short-range part that converges quickly with a distance cutoff and a long-range part whose Fourier transform converges quickly with a frequency cutoff. The authors propose bringing this same strategy into the GNN paradigm.</p>
<h2 id="from-ewald-summation-to-learnable-fourier-space-messages">From Ewald Summation to Learnable Fourier-Space Messages</h2>
<p>The core insight is a formal analogy between the continuous-filter convolution used in MPNNs and the electrostatic potential computation in Ewald summation. In a standard continuous-filter convolution, the message sum for atom $i$ is:</p>
<p>$$
M_i^{(l+1)} = \sum_{j \in \mathcal{N}(i)} h_j^{(l)} \cdot \Phi^{(l)}(| \mathbf{x}_i - \mathbf{x}_j |)
$$</p>
<p>where $h_j^{(l)}$ are atom embeddings and $\Phi^{(l)}$ is a learned radial filter. Comparing this to the electrostatic potential $V_i^{\text{es}}(\mathbf{x}_i) = \sum_{j \neq i} q_j \cdot \Phi^{\text{es}}(| \mathbf{x}_i - \mathbf{x}_j |)$ reveals a direct correspondence: atom embeddings play the role of partial charges, and learned filters replace the $1/r$ kernel.</p>
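<p>As a minimal sketch of this message sum, the pure-Python snippet below uses scalar embeddings and a Gaussian <code>radial_filter</code> as a hypothetical stand-in for the learned filter $\Phi^{(l)}$. The function names, the cutoff handling, and the example values are illustrative assumptions, not the authors&rsquo; implementation.</p>

```python
import math

def radial_filter(distance, width=1.0):
    # Hypothetical stand-in for the learned radial filter Phi (assumption).
    return math.exp(-(distance / width) ** 2)

def short_range_messages(positions, embeddings, cutoff):
    """For each atom i, sum h_j * Phi(|x_i - x_j|) over neighbors j within the cutoff."""
    messages = []
    for i, x_i in enumerate(positions):
        total = 0.0
        for j, x_j in enumerate(positions):
            if i == j:
                continue
            d = math.dist(x_i, x_j)
            if cutoff >= d:  # atoms beyond the cutoff are not neighbors
                total += embeddings[j] * radial_filter(d)
        messages.append(total)
    return messages

# Toy example: the third atom sits 9-10 length units away from the other two.
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
embeddings = [1.0, 2.0, 3.0]
messages = short_range_messages(positions, embeddings, cutoff=6.0)
```

<p>In this toy setup the distant atom receives a zero message and contributes nothing to the others, which is the locality limitation that the long-range term is meant to address.</p>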
<p>Ewald MP decomposes the learned filter into short-range and long-range components. The short-range part is handled by any existing GNN architecture with a distance cutoff. The long-range part is computed as a sum over Fourier frequencies:</p>
<p>$$
M^{\text{lr}}(\mathbf{x}_i) = \sum_{\mathbf{k}} \exp(i \mathbf{k}^T \mathbf{x}_i) \cdot s_{\mathbf{k}} \cdot \hat{\Phi}^{\text{lr}}(| \mathbf{k} |)
$$</p>
<p>where $s_{\mathbf{k}}$ are <strong><a href="https://en.wikipedia.org/wiki/Structure_factor">structure factor</a> embeddings</strong>, computed as:</p>
<p>$$
s_{\mathbf{k}} = \sum_{j \in \mathcal{S}} h_j \exp(-i \mathbf{k}^T \mathbf{x}_j)
$$</p>
<p>These structure factor embeddings are a Fourier-space representation of the atom embedding distribution, and truncating to low frequencies effectively coarse-grains the hidden model state while preserving long-range information. The frequency filters $\hat{\Phi}^{\text{lr}}$ are learned, making the entire scheme data-driven rather than tied to a fixed physical functional form.</p>
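<p>The two sums above can be sketched in plain Python as follows. This is an illustration under stated assumptions: embeddings are scalars, <code>frequency_filter</code> is a hypothetical Gaussian standing in for the learned $\hat{\Phi}^{\text{lr}}$, and the frequency set is chosen symmetric (it contains $-\mathbf{k}$ for every $\mathbf{k}$) so that the resulting messages are real.</p>

```python
import cmath
import math

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def structure_factors(positions, embeddings, k_vectors):
    """s_k = sum_j h_j * exp(-i k.x_j): one complex value per frequency vector."""
    return [
        sum(h * cmath.exp(-1j * dot(k, x)) for h, x in zip(embeddings, positions))
        for k in k_vectors
    ]

def frequency_filter(k_norm, width=1.0):
    # Hypothetical stand-in for the learned frequency filter (assumption).
    return math.exp(-(k_norm / width) ** 2)

def long_range_messages(positions, embeddings, k_vectors):
    """M_i = sum_k exp(i k.x_i) * s_k * filter(|k|), evaluated for every atom i."""
    s = structure_factors(positions, embeddings, k_vectors)  # computed once, reused per atom
    messages = []
    for x_i in positions:
        total = 0j
        for k, s_k in zip(k_vectors, s):
            k_norm = math.sqrt(dot(k, k))
            total += cmath.exp(1j * dot(k, x_i)) * s_k * frequency_filter(k_norm)
        # Keep the real part; the imaginary parts cancel for a symmetric k set.
        messages.append(total.real)
    return messages

positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
embeddings = [1.0, 2.0, 3.0]
k_vectors = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (-0.5, 0.0, 0.0)]
messages = long_range_messages(positions, embeddings, k_vectors)
```

<p>Because the structure factors are computed once and then reused for every atom, the work in this sketch grows with the number of atoms times the number of frequency vectors. Every atom contributes to every message regardless of distance.</p>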
<p>The method handles both <strong>periodic</strong> systems (where the <a href="https://en.wikipedia.org/wiki/Reciprocal_lattice">reciprocal lattice</a> provides a natural frequency discretization) and <strong>aperiodic</strong> systems (where the Fourier domain is discretized using a cubic voxel grid with SVD-based rotation alignment to preserve rotation invariance). The combined embedding update becomes:</p>
<p>$$
h_i^{(l+1)} = \frac{1}{\sqrt{3}} \left[ h_i^{(l)} + f_{\text{upd}}^{\text{sr}}(M_i^{\text{sr}}) + f_{\text{upd}}^{\text{lr}}(M_i^{\text{lr}}) \right]
$$</p>
<p>The computational complexity is $\mathcal{O}(N_{\text{at}} N_{\text{k}})$, and by fixing the number of frequency vectors $N_{\text{k}}$, linear scaling $\mathcal{O}(N_{\text{at}})$ is achievable.</p>
<h2 id="experiments-across-four-gnn-architectures-and-two-datasets">Experiments Across Four GNN Architectures and Two Datasets</h2>
<p>The authors test Ewald MP as an augmentation on four baseline architectures: <a href="/notes/chemistry/datasets/marcel/">SchNet, PaiNN, DimeNet++, and GemNet-T</a>. Two datasets are used:</p>
<ul>
<li><strong>OC20</strong> (Chanussot et al., 2021): ~265M periodic structures of adsorbate-catalyst systems with DFT-computed energies and forces. The OC20-2M subsplit is used for training.</li>
<li><strong>OE62</strong> (Stuke et al., 2020): ~62,000 large aperiodic organic molecules with DFT-computed energies that include a DFT-D3 dispersion correction for London dispersion interactions.</li>
</ul>
<p>All baselines use a 6 Å distance cutoff and 50 maximum neighbors. The Ewald modification is minimal: the long-range message sum is added as an additional skip connection term in each interaction block. Comparison studies include: (1) increasing the distance cutoff to match the computational cost of Ewald MP, (2) replacing the Ewald block with a SchNet interaction block at increased cutoff, and (3) increasing atom embedding dimensions to match Ewald MP&rsquo;s parameter count.</p>
<h3 id="key-energy-mae-results-on-oe62">Key Energy MAE Results on OE62</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Baseline (meV)</th>
          <th>Ewald MP (meV)</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>133.5</td>
          <td>79.2</td>
          <td>40.7%</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>61.4</td>
          <td>57.9</td>
          <td>5.7%</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>51.2</td>
          <td>46.5</td>
          <td>9.2%</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>51.5</td>
          <td>47.4</td>
          <td>8.0%</td>
      </tr>
  </tbody>
</table>
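<p>The improvement column is the relative reduction in MAE with respect to the baseline. A small sketch that recomputes it from the rounded OE62 table entries (the helper name <code>relative_improvement</code> is an illustrative choice):</p>

```python
def relative_improvement(baseline_mae, ewald_mae):
    """Percentage reduction in MAE relative to the baseline."""
    return 100.0 * (baseline_mae - ewald_mae) / baseline_mae

# (baseline, Ewald MP) energy MAEs in meV, copied from the OE62 table.
oe62_mae = {
    "SchNet": (133.5, 79.2),
    "PaiNN": (61.4, 57.9),
    "DimeNet++": (51.2, 46.5),
    "GemNet-T": (51.5, 47.4),
}

improvements = {
    name: round(relative_improvement(baseline, ewald), 1)
    for name, (baseline, ewald) in oe62_mae.items()
}
```

<p>Since the inputs are already rounded to one decimal place, values recomputed this way can differ from the reported figures in the last digit.</p>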
<h3 id="key-energy-mae-results-on-oc20-averaged-across-test-splits">Key Energy MAE Results on OC20 (Averaged Across Test Splits)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Baseline (meV)</th>
          <th>Ewald MP (meV)</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>895</td>
          <td>830</td>
          <td>7.3%</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>448</td>
          <td>393</td>
          <td>12.3%</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>496</td>
          <td>445</td>
          <td>10.4%</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>346</td>
          <td>307</td>
          <td>11.3%</td>
      </tr>
  </tbody>
</table>
<h2 id="robust-long-range-improvements-and-dispersion-recovery">Robust Long-Range Improvements and Dispersion Recovery</h2>
<p>Ewald MP achieves consistent improvements across all models and both datasets, averaging 16.1% on OE62 and 10.3% on OC20. Several findings stand out:</p>
<ol>
<li>
<p><strong>Robustness</strong>: Unlike the increased-cutoff and SchNet-LR alternatives, Ewald MP never produces detrimental effects in any tested configuration. The increased cutoff setting hurts SchNet and PaiNN on OE62, and the SchNet-LR block fails to improve DimeNet++ and GemNet-T.</p>
</li>
<li>
<p><strong>Long-range specificity</strong>: A binning analysis on OE62 groups molecules by the magnitude of their DFT-D3 dispersion correction. Ewald MP shows an outsize improvement for structures with large long-range energy contributions. It recovers or surpasses a &ldquo;cheating&rdquo; baseline that receives the exact DFT-D3 ground truth as an additional input.</p>
</li>
<li>
<p><strong>Efficiency on periodic systems</strong>: Ewald MP achieves similar relative improvements on OC20 at roughly half the relative computational overhead it incurs on OE62, which makes periodic structures a particularly attractive application domain.</p>
</li>
<li>
<p><strong>Force predictions</strong>: Improvements in <a href="/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/">force MAEs</a> are consistent but small, which is expected since the frequency truncation removes high-frequency contributions to the potential energy surface.</p>
</li>
<li>
<p><strong>Ablation studies</strong>: Results are robust across different frequency cutoffs, voxel resolutions, and filtering strategies, with the non-radial periodic filtering scheme outperforming radial alternatives on out-of-distribution generalization.</p>
</li>
</ol>
<p>Limitations include the current focus on scalar (invariant) embeddings only (PaiNN&rsquo;s equivariant vector embeddings are not augmented), and the potential for a &ldquo;gap&rdquo; of medium-range interactions when $N_{\text{k}}$ is fixed for linear scaling. The authors suggest adapting more efficient Ewald summation variants (e.g., particle mesh Ewald with $\mathcal{O}(N \log N)$ scaling) as future work.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (periodic)</td>
          <td>OC20-2M</td>
          <td>~2M structures</td>
          <td>Subsplit of OC20; PBC; DFT energies and forces</td>
      </tr>
      <tr>
          <td>Training (aperiodic)</td>
          <td>OE62</td>
          <td>~62,000 molecules</td>
          <td>Large organic molecules; DFT energies with D3 correction</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>OC20-test (4 splits: ID, OOD-ads, OOD-cat, OOD-both)</td>
          <td>Varies</td>
          <td>Evaluated via submission to OC20 evaluation server</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>OE62-val, OE62-test</td>
          <td>~6,000 each</td>
          <td>Direct evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Ewald message passing is integrated as an additional skip connection term in each interaction block</li>
<li>For periodic systems: non-radial filtering with fixed reciprocal lattice positions ($N_x, N_y, N_z$ hyperparameters)</li>
<li>For aperiodic systems: radial Gaussian basis function filtering with frequency cutoff $c_k$ and voxel resolution $\Delta = 0.2$ Å$^{-1}$</li>
<li>SVD-based coordinate alignment for rotation invariance in the aperiodic case</li>
<li>Bottleneck dimension $N_\downarrow = 16$ (GemNet-T) or $N_\downarrow = 8$ (others)</li>
<li>Update function: dense layer + $N_{\text{hidden}}$ residual layers ($N_{\text{hidden}} = 3$, except PaiNN with $N_{\text{hidden}} = 0$)</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Embedding Size (OE62)</th>
          <th>Interaction Blocks</th>
          <th>Ewald Params (OE62)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>512</td>
          <td>4</td>
          <td>12.2M total</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>512</td>
          <td>4</td>
          <td>15.7M total</td>
      </tr>
      <tr>
          <td>DimeNet++</td>
          <td>256</td>
          <td>3</td>
          <td>4.8M total</td>
      </tr>
      <tr>
          <td>GemNet-T</td>
          <td>256</td>
          <td>3</td>
          <td>16.1M total</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Primary metric: Energy mean absolute error (EMAE) in meV</li>
<li>Secondary metric: Force MAE in meV/Å (OC20 only)</li>
<li>Loss: Linear combination of energy and force MAEs (Eq. 15) with model-specific force multipliers</li>
<li>Optimizer: Adam with weight decay ($\lambda = 0.01$)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All runtime measurements on NVIDIA A100 GPUs</li>
<li>Runtimes measured after 50 warmup batches, averaged over 500 batches, minimum of 3 repetitions</li>
<li>Code: <a href="https://github.com/arthurkosmala/EwaldMP">EwaldMP</a> (Hippocratic License 3.0)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/arthurkosmala/EwaldMP">EwaldMP</a></td>
          <td>Code</td>
          <td>Hippocratic License 3.0 (new files) / MIT (OC20 base)</td>
          <td>Official implementation built on the Open Catalyst Project codebase</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md">OC20</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>~265M periodic adsorbate-catalyst structures with DFT energies and forces</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.1038/s41597-020-0385-y">OE62</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>~62,000 large organic molecules with DFT energies including D3 correction</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Source code, both datasets, and detailed hyperparameters (including per-model learning rates, batch sizes, and Ewald-specific settings) are all publicly available. Pre-trained model weights are not provided.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kosmala, A., Gasteiger, J., Gao, N., &amp; Günnemann, S. (2023). Ewald-based Long-Range Message Passing for Molecular Graphs. In <em>Proceedings of the 40th International Conference on Machine Learning (ICML 2023)</em>.</p>
<p><strong>Publication</strong>: ICML 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kosmala2023ewald,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Ewald-based Long-Range Message Passing for Molecular Graphs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kosmala, Arthur and Gasteiger, Johannes and Gao, Nicholas and G{\&#34;u}nnemann, Stephan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 40th International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{202}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Block-Recurrent Transformers for Long Sequences</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/block-recurrent-transformers/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/block-recurrent-transformers/</guid><description>Block-Recurrent Transformers combine attention and recurrence for linear-complexity language modeling on long documents like books and code.</description><content:encoded><![CDATA[<h2 id="a-method-for-combining-attention-with-block-level-recurrence">A Method for Combining Attention with Block-Level Recurrence</h2>
<p>This is a <strong>Method</strong> paper that introduces the Block-Recurrent Transformer, a model architecture that integrates recurrence into the transformer framework at the block level. Rather than processing tokens one at a time (as in traditional RNNs) or attending over entire sequences (as in standard transformers), this approach applies a transformer layer recurrently across blocks of tokens. The result is a model with linear complexity in sequence length that maintains the parallelism benefits of transformers during training. A related approach, <a href="/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/">RWKV</a>, later explored similar ideas using linear attention with channel-wise decay.</p>
<h2 id="why-transformers-struggle-with-long-documents">Why Transformers Struggle with Long Documents</h2>
<p>Transformers have largely replaced RNNs for sequence modeling tasks, but their quadratic self-attention cost limits the length of sequences they can process. A transformer with a window size of 512 tokens cannot see information beyond that window, making it blind to long-range dependencies in books, technical papers, or source code repositories.</p>
<p>Prior approaches to this problem fall into several categories: sparse attention patterns (BigBird, Routing Transformers, Reformer), sequence compression (Linformer, Funnel Transformers), and linearized attention approximations. These methods either sacrifice the expressiveness of full softmax attention or introduce implementation complexity.</p>
<p>Traditional RNNs like LSTMs offer linear complexity but suffer from three key limitations: sequential processing prevents parallelism on modern hardware, a single state vector bottlenecks information capacity, and vanishing gradients limit effective memory to a few hundred tokens.</p>
<h2 id="block-level-recurrence-with-lstm-style-gates">Block-Level Recurrence with LSTM-Style Gates</h2>
<p>The core innovation is applying a standard transformer layer in a recurrent fashion along the sequence, operating on blocks of $W$ tokens rather than individual tokens. The recurrent cell maintains $S$ state vectors (typically $S = W = 512$) that are updated at each block boundary.</p>
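<p>The control flow of block-level recurrence can be sketched as a loop over blocks that carries a state across block boundaries. In the snippet below, <code>toy_cell</code> is a deliberately trivial, hypothetical stand-in for the recurrent transformer layer (it works on integers, not embeddings), and all names are illustrative assumptions.</p>

```python
def split_into_blocks(tokens, block_size):
    """Split a segment into consecutive blocks of block_size tokens."""
    return [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]

def run_block_recurrence(tokens, block_size, initial_state, cell):
    """Apply `cell` once per block, carrying the state across block boundaries."""
    state = initial_state
    outputs = []
    for block in split_into_blocks(tokens, block_size):
        block_outputs, state = cell(block, state)
        outputs.extend(block_outputs)
    return outputs, state

def toy_cell(block, state):
    # Hypothetical stand-in for the transformer layer (assumption):
    # every output in the block reads the carried state, and the new
    # state summarizes everything seen so far.
    block_outputs = [token + state for token in block]
    new_state = state + sum(block)
    return block_outputs, new_state

outputs, final_state = run_block_recurrence(list(range(8)), 4, 0, toy_cell)

# A segment of 4096 tokens with 512-token blocks needs 8 cell applications.
num_blocks = len(split_into_blocks(list(range(4096)), 512))
```

<p>The number of cell applications grows linearly with the segment length, while the tokens within each block can still be processed in parallel by the real layer.</p>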
<h3 id="the-recurrent-cell">The Recurrent Cell</h3>
<p>The cell has two processing directions:</p>
<ul>
<li><strong>Vertical direction</strong>: An ordinary transformer layer with self-attention over input tokens and cross-attention to recurrent states, producing output embeddings.</li>
<li><strong>Horizontal direction</strong>: Self-attention over current state vectors and cross-attention to input tokens, producing updated state vectors. Residual connections are replaced with gates.</li>
</ul>
<p>Self-attention and cross-attention are computed in parallel (not sequentially), with results concatenated and fed into a linear projection. Keys and values are shared between directions, while queries are separate, yielding four query sets: $Q_e^v$, $Q_s^v$ (vertical) and $Q_s^h$, $Q_e^h$ (horizontal).</p>
<h3 id="gating-mechanisms">Gating Mechanisms</h3>
<p>Two gate types are explored. The <strong>fixed gate</strong> uses a learned convex combination:</p>
<p>$$
g = \sigma(b_g)
$$</p>
<p>$$
c_{t+1} = c_t \odot g + z_t \odot (1 - g)
$$</p>
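<p>where $c_t$ is the current state vector, $z_t$ is the candidate update computed from the layer output, and $g$ is constant after training, so the update implements an <a href="https://en.wikipedia.org/wiki/Moving_average">exponential moving average</a>.</p>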
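<p>A minimal elementwise sketch of the fixed gate, assuming plain Python lists for the state, the candidate update, and the gate bias (the function name and example values are illustrative):</p>

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def fixed_gate_update(c, z, b_g):
    """c_next = c * g + z * (1 - g) with g = sigmoid(b_g), applied elementwise."""
    gates = [sigmoid(b) for b in b_g]
    return [c_i * g + z_i * (1.0 - g) for c_i, z_i, g in zip(c, z, gates)]

# With a zero bias the gate is 0.5, so the old state and the new value are averaged.
c_next = fixed_gate_update(c=[1.0, 0.0], z=[0.0, 2.0], b_g=[0.0, 0.0])
```

<p>A larger bias pushes the gate toward 1, so that state dimension changes more slowly across blocks.</p>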
<p>The <strong>LSTM gate</strong> uses input and forget gates:</p>
<p>$$
i_t = \sigma(W_i h_t + b_i - 1)
$$</p>
<p>$$
f_t = \sigma(W_f h_t + b_f + 1)
$$</p>
<p>$$
c_{t+1} = c_t \odot f_t + z_t \odot i_t
$$</p>
<p>The bias offsets ($-1$ for input, $+1$ for forget) initialize the model to &ldquo;remember&rdquo; by default, which is critical for training stability. Without careful initialization, the model can fall into a local optimum where it ignores the recurrent state entirely. This echoes the <a href="/notes/machine-learning/model-architectures/can-recurrent-neural-networks-warp-time/">gate initialization challenges studied by Tallec and Ollivier</a>, who derived chrono initialization for LSTMs from time-warping invariance.</p>
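<p>The effect of the bias offsets can be checked with a scalar sketch of the LSTM gate. Here the weights are plain floats instead of matrices, which is a simplifying assumption, and the weights and learned biases are set to zero to mimic an untrained gate.</p>

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_gate_update(c, z, h, w_i, b_i, w_f, b_f):
    """Scalar version of the gate equations; returns the new state and both gates."""
    i_t = sigmoid(w_i * h + b_i - 1.0)  # input gate, offset by -1
    f_t = sigmoid(w_f * h + b_f + 1.0)  # forget gate, offset by +1
    return c * f_t + z * i_t, i_t, f_t

# Zero weights and biases: only the fixed offsets determine the gates.
c_next, i_t, f_t = lstm_gate_update(
    c=1.0, z=0.0, h=0.5, w_i=0.0, b_i=0.0, w_f=0.0, b_f=0.0
)
```

<p>With these settings the forget gate is about 0.73 and the input gate about 0.27, so most of the previous state is retained at each block boundary.</p>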
<h3 id="gate-configurations">Gate Configurations</h3>
<p>Three configurations are tested: <strong>dual</strong> (gates on both attention and MLP outputs), <strong>single</strong> (gate only on MLP output), and <strong>skip</strong> (gate only on attention output, no MLP). The skip configuration removes the large MLP from the recurrent layer entirely.</p>
<h3 id="learned-state-ids">Learned State IDs</h3>
<p>Since the same weights are applied to all state vectors, learned &ldquo;state IDs&rdquo; (analogous to position embeddings) are added so each state vector can issue distinct queries. <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style relative position bias is used for token self-attention, with no position bias for state-token cross-attention.</p>
<h2 id="language-modeling-on-pg19-arxiv-and-github">Language Modeling on PG19, arXiv, and GitHub</h2>
<h3 id="experimental-setup">Experimental Setup</h3>
<p>The base model is a 12-layer transformer with 150M parameters (8 heads of size 128, embedding dimension 1024, MLP hidden size 4096). The recurrent layer is placed at layer 10 with segment length $N = 4096$ and window size $W = 512$. The architecture is evaluated on three long-document datasets:</p>
<ul>
<li><strong>PG19</strong>: Full-length books from <a href="https://en.wikipedia.org/wiki/Project_Gutenberg">Project Gutenberg</a> (pre-1919)</li>
<li><strong>arXiv</strong>: Mathematics papers in LaTeX</li>
<li><strong>GitHub</strong>: Concatenated source code from open-source repositories</li>
</ul>
<p>All results are reported in bits-per-token ($\log_2$ of perplexity; lower is better).</p>
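<p>The relationship between a loss in nats, bits-per-token, and per-token perplexity can be written as two small conversions. The helper names are illustrative, and 3.53 is used only as an example score.</p>

```python
import math

def bits_per_token(nll_nats):
    """Convert a mean negative log-likelihood in nats to bits per token."""
    return nll_nats / math.log(2.0)

def perplexity_from_bits(bits):
    """Per-token perplexity implied by a bits-per-token score."""
    return 2.0 ** bits

# Example: a score of 3.53 bits-per-token.
token_perplexity = perplexity_from_bits(3.53)
```

<p>Perplexity computed this way is per token, so it depends on the tokenizer and is not directly comparable to word-level perplexity.</p>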
<h3 id="baselines">Baselines</h3>
<p>Five baselines are compared: Transformer-XL with window sizes of 512, 1024, and 2048, plus 12-layer and 13-layer sliding window models. The 13-layer sliding window (Slide:13L) is the primary comparison, having equivalent computation cost and parameter count to the recurrent models.</p>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Step Time</th>
          <th>PG19 (bytes)</th>
          <th>PG19 (tokens)</th>
          <th>arXiv</th>
          <th>GitHub</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>XL:512</td>
          <td>0.88</td>
          <td>1.01</td>
          <td>3.62</td>
          <td>1.45</td>
          <td>1.21</td>
      </tr>
      <tr>
          <td>XL:2048</td>
          <td>2.11</td>
          <td>0.990</td>
          <td>3.58</td>
          <td>1.31</td>
          <td>1.01</td>
      </tr>
      <tr>
          <td>Slide:13L</td>
          <td>1.00</td>
          <td>0.989</td>
          <td>3.58</td>
          <td>1.42</td>
          <td>1.17</td>
      </tr>
      <tr>
          <td>Rec:fixed:skip</td>
          <td>0.99</td>
          <td>0.952</td>
          <td>3.53</td>
          <td>1.24</td>
          <td>0.976</td>
      </tr>
      <tr>
          <td>Rec:fixed:dual</td>
          <td>1.01</td>
          <td>0.957</td>
          <td>3.52</td>
          <td>1.27</td>
          <td>0.991</td>
      </tr>
      <tr>
          <td>Feedback:fixed:skip</td>
          <td>1.35</td>
          <td>0.935</td>
          <td>3.49</td>
          <td>1.24</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Memorizing Trans. 64k</td>
          <td>1.94</td>
          <td>0.950</td>
          <td>3.53</td>
          <td>1.22</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>The Rec:fixed:skip configuration achieves the best overall results while being slightly faster than the 13-layer baseline. It outperforms XL:2048, which runs over 2x slower. The block feedback variant (allowing all layers to cross-attend to recurrent states) improves perplexity further at ~35-40% higher step time.</p>
<h3 id="scaling-behavior">Scaling Behavior</h3>
<p>Models from 40M to 1.3B parameters show that the benefit of recurrence is <a href="/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/">consistent across scales</a> and increases with model size. At larger sizes, adding recurrence provides a benefit greater than doubling the number of parameters. The 1.3B parameter model achieves 26.50 word-level perplexity on PG19, setting a new state of the art at the time of publication.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>PG19 Perplexity</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Compressive Transformer</td>
          <td>36</td>
          <td>33.6</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Routing Transformer</td>
          <td>22</td>
          <td>33.2</td>
          <td>490M</td>
      </tr>
      <tr>
          <td>Perceiver AR</td>
          <td>60</td>
          <td>28.9</td>
          <td>974.6M</td>
      </tr>
      <tr>
          <td>Block-Recurrent Transformer</td>
          <td>24</td>
          <td>26.50</td>
          <td>1.3B</td>
      </tr>
  </tbody>
</table>
<h3 id="ablations">Ablations</h3>
<ul>
<li><strong>Multiple recurrent layers</strong>: Two adjacent layers (9, 10) provide no benefit. Two separated layers (4, 10) help but no more than adding another non-recurrent layer.</li>
<li><strong>Number of states</strong>: Performance improves up to 1024 states and degrades at 2048.</li>
<li><strong>Window size reduction</strong>: Reducing the sliding window hurts Transformer-XL dramatically but has smaller impact on the recurrent model, which compensates via recurrence.</li>
<li><strong>Gate type</strong>: The fixed gate consistently outperforms the LSTM gate despite being theoretically less expressive.</li>
</ul>
<h3 id="qualitative-analysis">Qualitative Analysis</h3>
<p>Comparing per-token predictions against Transformer-XL on PG19 books, the recurrent model&rsquo;s advantage comes overwhelmingly from predicting proper names (17/20 top-improvement tokens). In 19/20 cases, the predicted word was outside the attention window, confirming it was stored in recurrent state. The model can remember book titles and authors across 60,000+ tokens.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>The Block-Recurrent Transformer demonstrates that recurrence at the block level is a cost-effective way to improve language modeling on long sequences. The fixed:skip configuration (the simplest variant) performs best, suggesting the model primarily uses recurrence for long-range name lookup rather than complex reasoning. The fact that removing the MLP from the recurrent layer has minimal impact further supports this interpretation.</p>
<p>Key limitations include: the model was only evaluated on language modeling perplexity (no downstream tasks), the LSTM gate underperforms the simpler fixed gate (suggesting untapped potential for more expressive recurrence), and the authors acknowledge that training the recurrent layer to fully exploit its capacity for knowledge extraction will require further advances.</p>
<p>The authors note that evaluating on downstream tasks requiring long-range context (book summarization, long-document QA, code completion) is an important direction for future work.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>PG19</td>
          <td>~29k books</td>
          <td>Public domain, freely available</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>arXiv</td>
          <td>Mathematics papers</td>
          <td>Obtained via private channels, not redistributable</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>GitHub</td>
          <td>Open-source repos</td>
          <td>Obtained via private channels, not redistributable</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adafactor</li>
<li>Learning rate: 1.0 with inverse square root decay (initial experiments), cosine decay with max 0.01 (scaling experiments)</li>
<li>Warmup: 1000 steps</li>
<li>Dropout: 0.05</li>
<li>Vocabulary: 32k SentencePiece (pretrained T5 vocabulary for the initial experiments, custom vocabulary for the scaling experiments)</li>
<li>Gate initialization: bias of $+1$ for forget gate, $-1$ for input gate to ensure initial &ldquo;remember&rdquo; behavior</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Variant</th>
          <th>Layers</th>
          <th>Parameters</th>
          <th>Recurrent Layers</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Base</td>
          <td>12 (+1 recurrent)</td>
          <td>~151-164M</td>
          <td>Layer 10</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>24 (+2 recurrent)</td>
          <td>650M</td>
          <td>Layers 10, 20</td>
      </tr>
      <tr>
          <td>XL</td>
          <td>24 (+2 recurrent)</td>
          <td>1.3B</td>
          <td>Layers 10, 20</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Model</th>
          <th>PG19 (tokens)</th>
          <th>arXiv</th>
          <th>GitHub</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Bits-per-token</td>
          <td>Rec:fixed:skip</td>
          <td>3.53</td>
          <td>1.24</td>
          <td>0.976</td>
      </tr>
      <tr>
          <td>Word-level PPL</td>
          <td>1.3B model</td>
          <td>26.50</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Error bars on PG19 are between 0.002 and 0.007 (3 runs with different seeds).</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: 32 Google V4 TPU replicas</li>
<li>Training time: ~48 hours for 500k steps on PG19</li>
<li>Batch size: 32 (segment length 4096) or 256 (segment length 512), adjusted so each model sees the same tokens per step</li>
</ul>
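<p>The two batch configurations can be checked with a few lines of arithmetic. The variable names are illustrative, and the total assumes the 500k-step PG19 schedule listed above.</p>

```python
# (batch size, segment length) pairs from the list above.
configs = [(32, 4096), (256, 512)]

# Tokens processed per training step for each configuration.
tokens_per_step = [batch * segment for batch, segment in configs]

# Tokens seen over a 500k-step run, assuming every step uses a full batch.
total_tokens = tokens_per_step[0] * 500_000
```

<p>Both configurations come to 131,072 tokens per step, or about 65.5 billion tokens over 500k steps.</p>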
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Available</th>
          <th>License</th>
          <th>URL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Code (Meliad)</td>
          <td>Yes</td>
          <td>Apache 2.0</td>
          <td><a href="https://github.com/google-research/meliad">github.com/google-research/meliad</a></td>
      </tr>
      <tr>
          <td>PG19 Dataset</td>
          <td>Yes</td>
          <td>Public Domain</td>
          <td>Public</td>
      </tr>
      <tr>
          <td>arXiv Dataset</td>
          <td>No</td>
          <td>Not redistributable</td>
          <td>Private</td>
      </tr>
      <tr>
          <td>GitHub Dataset</td>
          <td>No</td>
          <td>Not redistributable</td>
          <td>Private</td>
      </tr>
      <tr>
          <td>Pretrained Models</td>
          <td>No</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Assessment</strong>: Partially Reproducible. Source code is available under Apache 2.0 and the PG19 dataset is public. However, two of three evaluation datasets (arXiv, GitHub) were obtained via private channels and are not redistributable. No pretrained model checkpoints are released.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hutchins, D., Schlag, I., Wu, Y., Dyer, E., &amp; Neyshabur, B. (2022). Block-Recurrent Transformers. <em>Advances in Neural Information Processing Systems 35 (NeurIPS 2022)</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{hutchins2022block,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Block-Recurrent Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hutchins, DeLesley and Schlag, Imanol and Wu, Yuhuai and Dyer, Ethan and Neyshabur, Behnam}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2203.07852}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>NaViT: Native Resolution Vision Transformer</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/navit-native-resolution-vit/</link><pubDate>Mon, 06 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/navit-native-resolution-vit/</guid><description>NaViT uses sequence packing to train Vision Transformers on images at native resolution and aspect ratio, improving efficiency and flexibility.</description><content:encoded><![CDATA[<h2 id="a-method-for-flexible-resolution-vision-transformers">A Method for Flexible-Resolution Vision Transformers</h2>
<p>This is a <strong>Method</strong> paper that introduces NaViT (Native Resolution ViT), a Vision Transformer trained using sequence packing to handle images of arbitrary resolution and aspect ratio. The core idea, called &ldquo;Patch n&rsquo; Pack,&rdquo; borrows example packing from NLP and applies it to vision: patches from multiple images of different sizes are concatenated into a single sequence, enabling native-resolution processing without resizing or padding.</p>
<h2 id="why-fixed-resolution-pipelines-are-suboptimal">Why Fixed-Resolution Pipelines Are Suboptimal</h2>
<p>Standard computer vision pipelines resize all images to a fixed square resolution before processing. This practice originates from convolutional neural network constraints, where fixed spatial dimensions were architecturally required. Even with Vision Transformers, which operate on sequences of patches and could in principle handle variable lengths, the convention of fixed-resolution input persists.</p>
<p>This approach has clear drawbacks. Most images are not square: analysis of ImageNet, LVIS, and WebLI shows that most images deviate more than 20% from a 1:1 aspect ratio. Resizing distorts content and discards information, while padding wastes computation. Prior work like FlexiViT addressed variable patch sizes and Pix2Struct introduced aspect-ratio-preserving patching, but neither fully solved the problem of training efficiently on images at their original resolution.</p>
<h2 id="patch-n-pack-sequence-packing-for-vision">Patch n&rsquo; Pack: Sequence Packing for Vision</h2>
<p>The key insight is that ViT already processes images as sequences of patch tokens, and NLP has long used example packing to handle variable-length sequences efficiently. NaViT applies this directly: patches from multiple images (each at its native resolution and aspect ratio) are packed into a single fixed-length sequence.</p>
<h3 id="architectural-modifications">Architectural Modifications</h3>
<p>Three changes enable Patch n&rsquo; Pack:</p>
<ol>
<li>
<p><strong>Masked self-attention and masked pooling</strong>: Attention masks prevent patches from different images from attending to each other. Masked pooling extracts a single representation per image from the packed sequence.</p>
</li>
<li>
<p><strong>Factorized positional embeddings</strong>: Standard 1D positional embeddings cannot handle arbitrary resolutions. NaViT decomposes position into separate $x$ and $y$ embeddings $\phi_{x}$ and $\phi_{y}$, which are summed together. Two schemes are considered:</p>
<ul>
<li>Absolute embeddings: $\phi(p): [0, \text{maxLen}] \to \mathbb{R}^{D}$, a function of the absolute patch index</li>
<li>Fractional embeddings: $\phi(r): [0, 1] \to \mathbb{R}^{D}$, where $r = p / \text{side-length}$ is the relative position along the image</li>
</ul>
</li>
<li>
<p><strong>Chunked contrastive loss</strong>: For contrastive pretraining, the $\mathcal{O}(n^{2})$ loss computation is handled via chunked computation across device subsets to support the high number of examples per sequence.</p>
</li>
</ol>
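<p>The masking and pooling in the first item can be sketched in plain Python. This is a minimal illustration, not the paper&rsquo;s implementation: the <code>PAD</code> marker and function names are hypothetical, and mean pooling stands in for whatever pooling head the model actually uses.</p>

```python
from typing import Dict, List

PAD = -1  # hypothetical image id marking padding positions


def attention_mask(image_ids: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when token q may attend to token k (same image, not padding)."""
    return [[q == k and q != PAD for k in image_ids] for q in image_ids]


def masked_mean_pool(tokens: List[List[float]], image_ids: List[int]) -> Dict[int, List[float]]:
    """Average the token vectors of each image separately, skipping padding."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for vec, img in zip(tokens, image_ids):
        if img == PAD:
            continue
        if img not in sums:
            sums[img] = [0.0] * len(vec)
            counts[img] = 0
        sums[img] = [s + v for s, v in zip(sums[img], vec)]
        counts[img] += 1
    return {img: [s / counts[img] for s in total] for img, total in sums.items()}


# One packed sequence: three tokens from image 0, two from image 1, one padding slot.
image_ids = [0, 0, 0, 1, 1, PAD]
tokens = [[1.0], [2.0], [3.0], [10.0], [20.0], [99.0]]
mask = attention_mask(image_ids)
pooled = masked_mean_pool(tokens, image_ids)
```

<p>The mask is block-diagonal over images, so each image is processed as if it were alone in the sequence, and pooling yields one vector per image.</p>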
<h3 id="training-innovations">Training Innovations</h3>
<p>Packing enables two techniques that were previously impractical:</p>
<ul>
<li>
<p><strong>Continuous token dropping</strong>: Instead of dropping the same proportion of tokens from every image, the drop rate varies per image. Some images keep all tokens while others have aggressive dropping, reducing the train/inference discrepancy. The drop rate can follow a schedule that decreases over training.</p>
</li>
<li>
<p><strong>Resolution sampling</strong>: Each image&rsquo;s resolution is sampled from a distribution (e.g., $R \sim \mathcal{U}(64, R_{\text{max}})$) while preserving aspect ratio. This mixes the throughput benefits of small images with the detail of large ones.</p>
</li>
</ul>
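<p>Resolution sampling can be sketched as follows. The sketch assumes that the resolution $R$ of a non-square image means the side length of a square with the same area, and that sides are rounded to a multiple of the patch size; both are assumptions of this example, and the function names are hypothetical.</p>

```python
import math
import random


def sample_size(height, width, r_max, patch=16, r_min=64, rng=random):
    """Sample R ~ U(r_min, r_max) and rescale (height, width), preserving aspect ratio."""
    r = rng.uniform(r_min, r_max)
    # Assumption: R is the side of an equal-area square, so new_h * new_w is about R**2.
    scale = r / math.sqrt(height * width)
    new_h = max(patch, round(height * scale / patch) * patch)
    new_w = max(patch, round(width * scale / patch) * patch)
    return new_h, new_w


def num_tokens(height, width, patch=16):
    """Number of patch tokens for an image whose sides are multiples of the patch size."""
    return (height // patch) * (width // patch)


rng = random.Random(0)
h, w = sample_size(480, 640, r_max=384, rng=rng)
tokens = num_tokens(h, w)
```

<p>Smaller samples produce fewer tokens, so more images fit in each packed sequence, which is where the throughput gain comes from.</p>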
<h3 id="computational-overhead">Computational Overhead</h3>
<p>A natural concern is the $\mathcal{O}(n^{2})$ attention cost for longer packed sequences. In practice, as the transformer hidden dimension scales, attention becomes an increasingly small fraction of total compute (the MLP dominates). Packing itself adds little overhead: with a simple greedy bin-packing algorithm, padding tokens typically make up less than 2% of the packed sequences.</p>
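<p>A greedy packer of this kind might look like the sketch below. It uses a first-fit rule (place each image in the first sequence with enough room, otherwise open a new one), which is an assumption about the exact greedy strategy; the function names are hypothetical.</p>

```python
def greedy_pack(lengths, max_len):
    """First-fit packing of per-image token counts into sequences of max_len tokens."""
    bins = []  # bins[b] holds the indices of the images placed in sequence b
    free = []  # free[b] is the number of unused (padding) slots in sequence b
    for idx, n in enumerate(lengths):
        if n > max_len:
            raise ValueError(f"image {idx} has {n} tokens, more than max_len={max_len}")
        for b, room in enumerate(free):
            if n <= room:
                bins[b].append(idx)
                free[b] -= n
                break
        else:
            # No existing sequence has room, so start a new one.
            bins.append([idx])
            free.append(max_len - n)
    return bins, free


def padding_fraction(free, max_len):
    """Share of all sequence slots that end up as padding."""
    return sum(free) / (len(free) * max_len)


lengths = [196, 64, 100, 256, 49, 144]  # toy token counts, not from the paper
bins, free = greedy_pack(lengths, max_len=256)
overhead = padding_fraction(free, max_len=256)
```

<p>This toy input leaves far more padding than the figure reported in the paper. Low overhead depends on having many images of varied size to choose from, plus token dropping to fill the remaining gaps.</p>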
<h2 id="pretraining-and-downstream-evaluation">Pretraining and Downstream Evaluation</h2>
<p>NaViT is evaluated in two pretraining setups:</p>
<ul>
<li><strong>Classification pretraining</strong> on JFT-4B with sigmoid cross-entropy loss, evaluated via linear probing (10 examples per class)</li>
<li><strong>Contrastive pretraining</strong> on WebLI using image-text contrastive loss, evaluated on zero-shot ImageNet classification and COCO retrieval</li>
</ul>
<h3 id="training-efficiency">Training Efficiency</h3>
<p>At fixed compute budget, NaViT consistently outperforms ViT across model scales. The top-performing ViT can be matched by NaViT with 4x less compute. The primary driver is throughput: packing with variable resolution and token dropping enables NaViT-L/16 to process approximately 5x more images during training.</p>
<h3 id="variable-resolution-results">Variable Resolution Results</h3>
<p>Models trained with variable resolution ($R \sim \mathcal{U}(64, R_{\text{max}})$) outperform fixed-resolution models even when evaluated at the resolution the fixed-resolution model was trained on. Sampling side lengths from a truncated normal biased toward lower values gives the best cost-performance trade-off.</p>
<p>For fine-tuning on ImageNet-1k, a single NaViT fine-tuned with variable resolutions (64 to 512) matches the performance of models fine-tuned at each specific resolution individually.</p>
<h3 id="positional-embedding-comparison">Positional Embedding Comparison</h3>
<p>Factorized embeddings outperform both standard ViT 1D embeddings (with interpolation) and Pix2Struct&rsquo;s learned 2D embeddings. The factorized approach generalizes to resolutions outside the training range, while 2D embeddings fail because they require seeing all $(x, y)$ coordinate pairs during training. Additive combination of $\phi_{x}$ and $\phi_{y}$ works best.</p>
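<p>A minimal sketch of the additive factorized scheme, assuming learned lookup tables for the absolute variant. The table size, initialization scale, and function names are illustrative choices rather than values from the paper.</p>

```python
import random


def init_table(max_len, dim, rng):
    """One embedding vector per patch index along a single axis."""
    return [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(max_len)]


def factorized_embedding(phi_x, phi_y, x, y):
    """Additive combination: phi_x[x] + phi_y[y]."""
    return [a + b for a, b in zip(phi_x[x], phi_y[y])]


def fractional_coords(side_len):
    """Fractional positions r = p / side_len for patch indices p along one side."""
    return [p / side_len for p in range(side_len)]


rng = random.Random(0)
max_len, dim = 8, 4
phi_x = init_table(max_len, dim, rng)
phi_y = init_table(max_len, dim, rng)

# Any (x, y) pair has an embedding once each index has been seen on its own axis.
emb = factorized_embedding(phi_x, phi_y, x=7, y=2)

# Parameter counts: two 1D tables versus one table entry per (x, y) pair.
factorized_params = 2 * max_len * dim
full_2d_params = max_len * max_len * dim
```

<p>Because the tables are indexed per axis, a coordinate pair never seen together in training still maps to a sum of two trained vectors, which is consistent with the generalization result above.</p>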
<h3 id="token-dropping-strategies">Token Dropping Strategies</h3>
<p>Variable token dropping with Beta-distributed rates consistently outperforms constant rates. Resolution-dependent dropping (higher rates for higher-resolution images) further improves performance. Scheduling the drop rate to decrease over training provides additional gains.</p>
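<p>Per-image drop rates can be drawn with the standard library as sketched below. The Beta parameters and the rule of always keeping at least one token are assumptions made for the example, not settings from the paper.</p>

```python
import random


def sample_drop_rate(alpha, beta, rng):
    """Draw a per-image drop rate in [0, 1] from a Beta(alpha, beta) distribution."""
    return rng.betavariate(alpha, beta)


def drop_tokens(tokens, rate, rng):
    """Keep a random subset of tokens, preserving their original order."""
    keep = max(1, round(len(tokens) * (1.0 - rate)))
    kept_idx = sorted(rng.sample(range(len(tokens)), keep))
    return [tokens[i] for i in kept_idx]


rng = random.Random(0)
# Each image in a batch gets its own rate instead of one shared constant.
rates = [sample_drop_rate(2.0, 2.0, rng) for _ in range(5)]
kept = drop_tokens(list(range(100)), rate=0.5, rng=rng)
```

<p>A resolution-dependent variant would shift the Beta parameters toward higher rates for larger images, and a schedule would shift them toward lower rates as training progresses.</p>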
<h3 id="downstream-tasks">Downstream Tasks</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Setup</th>
          <th>Result</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Semantic segmentation</td>
          <td>ADE20k, L/16, linear decoder</td>
          <td>NaViT at $R_{384}$ beats ViT at $R_{512}$ while being 2x faster</td>
      </tr>
      <tr>
          <td>Object detection</td>
          <td>OWL-ViT-L/14 backbone</td>
          <td>NaViT: 28.3% LVIS AP vs. ViT: 23.3%</td>
      </tr>
      <tr>
          <td>Video classification</td>
          <td>Kinetics-400, tubelet extraction</td>
          <td>NaViT-L matches ViViT-L (80.4%) in ~6x fewer epochs</td>
      </tr>
      <tr>
          <td>Fairness annotation</td>
          <td>FairFace, CelebA linear probes</td>
          <td>Statistically significant accuracy improvements ($p = 3 \times 10^{-4}$)</td>
      </tr>
  </tbody>
</table>
<h3 id="out-of-distribution-robustness">Out-of-Distribution Robustness</h3>
<p>NaViT shows strong gains on ImageNet-A (which contains many extreme aspect ratios) when evaluated without center cropping. Performance on ObjectNet is also competitive. The model maintains stable calibration (ECE between 0.045 and 0.047) across a wide range of token counts per image (128 to 1024).</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>NaViT demonstrates that sequence packing, when applied to Vision Transformers, yields substantial improvements in training efficiency, inference flexibility, and downstream performance. The approach processes images at their native resolution without the information loss from resizing or the waste from padding.</p>
<p>Key takeaways:</p>
<ul>
<li>4x compute reduction to match top ViT performance</li>
<li>A single model works across a continuous range of resolutions at inference time</li>
<li>Variable-resolution training and token dropping provide complementary efficiency gains</li>
<li>Factorized positional embeddings generalize to unseen resolutions</li>
<li>Benefits transfer to detection, segmentation, video, and fairness tasks</li>
</ul>
<p>Limitations: The paper does not release model weights or code. All experiments use Google-internal datasets (JFT-4B, WebLI) and infrastructure (TPUs, JAX/Scenic), making direct reproduction difficult. The attention masking approach for packing assumes that cross-image attention is undesirable, which may not hold for all tasks.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification pretraining</td>
          <td>JFT-4B</td>
          <td>~4B labeled images</td>
          <td>Google-internal, not publicly available</td>
      </tr>
      <tr>
          <td>Contrastive pretraining</td>
          <td>WebLI</td>
          <td>Large-scale web data</td>
          <td>Google-internal, not publicly available</td>
      </tr>
      <tr>
          <td>Classification fine-tuning</td>
          <td>ImageNet-1k</td>
          <td>1.28M images</td>
          <td>Publicly available</td>
      </tr>
      <tr>
          <td>Segmentation</td>
          <td>ADE20k</td>
          <td>20K images</td>
          <td>Publicly available</td>
      </tr>
      <tr>
          <td>Detection</td>
          <td>LVIS</td>
          <td>164K images</td>
          <td>Publicly available</td>
      </tr>
      <tr>
          <td>Video</td>
          <td>Kinetics-400</td>
          <td>~240K videos</td>
          <td>Publicly available (partial)</td>
      </tr>
      <tr>
          <td>Fairness</td>
          <td>FairFace, CelebA</td>
          <td>108K / 200K images</td>
          <td>Publicly available</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Greedy bin-packing for sequence construction (less than 2% padding tokens)</li>
<li>Resolution sampling: side length from truncated normal $\mathcal{N}_{t}(-0.5, 1)$ mapped to $[64, R_{\text{max}}]$</li>
<li>Token dropping: Beta-distributed per-image rates, optionally resolution-dependent</li>
<li>Factorized positional embeddings with additive combination</li>
</ul>
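<p>The truncated-normal sampler in the second item can be approximated by rejection sampling. The sketch assumes the normal is truncated to $[-1, 1]$ and mapped linearly onto $[64, R_{\text{max}}]$; the summary above does not spell out the truncation bounds or the mapping, so treat both as assumptions.</p>

```python
import random


def sample_side_length(r_max, r_min=64.0, mean=-0.5, std=1.0, rng=random):
    """Sample a side length from N(mean, std) truncated to [-1, 1], mapped to [r_min, r_max]."""
    while True:
        z = rng.gauss(mean, std)
        if -1.0 <= z <= 1.0:  # reject draws outside the assumed truncation range
            break
    return r_min + (z + 1.0) / 2.0 * (r_max - r_min)


rng = random.Random(0)
samples = [sample_side_length(512.0, rng=rng) for _ in range(2000)]
mean_side = sum(samples) / len(samples)
```

<p>With a negative mean, the average sampled side length falls below the midpoint of the range, which biases training toward cheaper low-resolution images.</p>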
<h3 id="models">Models</h3>
<ul>
<li>NaViT variants: B/16, L/16, L/14</li>
<li>Based on vanilla ViT with query-key normalization, no biases, attention pooling</li>
<li>Implemented in JAX/FLAX within the Scenic framework</li>
<li>No public model checkpoints available</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>NaViT</th>
          <th>ViT Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>JFT linear probe (L/16)</td>
          <td>Matches top ViT</td>
          <td>4x more compute</td>
          <td>Compute-matched comparison</td>
      </tr>
      <tr>
          <td>ImageNet zero-shot (L/14)</td>
          <td>72.9%</td>
          <td>68.3%</td>
          <td>Contrastive pretraining</td>
      </tr>
      <tr>
          <td>LVIS AP (L/14)</td>
          <td>28.3%</td>
          <td>23.3%</td>
          <td>OWL-ViT detection</td>
      </tr>
      <tr>
          <td>LVIS AP rare (L/14)</td>
          <td>24.3%</td>
          <td>17.2%</td>
          <td>OWL-ViT detection</td>
      </tr>
      <tr>
          <td>ADE20k mIoU (L/16, 384)</td>
          <td>Beats ViT@512</td>
          <td>At 2x cost</td>
          <td>Segmenter linear decoder</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training on Cloud TPUs (specific configuration not detailed)</li>
<li>Inference latency measured on Cloud TPUv3</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Dehghani, M., Mustafa, B., Djolonga, J., Heek, J., Minderer, M., Caron, M., Steiner, A., Puigcerver, J., Geirhos, R., Alabdulmohsin, I., Oliver, A., Padlewski, P., Gritsenko, A., Lučić, M., &amp; Houlsby, N. (2023). Patch n&rsquo; Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution. <em>Advances in Neural Information Processing Systems 36 (NeurIPS 2023)</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{dehghani2023patch,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Patch n&#39; Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Dehghani, Mostafa and Mustafa, Basil and Djolonga, Josip and Heek, Jonathan and Minderer, Matthias and Caron, Mathilde and Steiner, Andreas and Puigcerver, Joan and Geirhos, Robert and Alabdulmohsin, Ibrahim and Oliver, Avital and Padlewski, Piotr and Gritsenko, Alexey and Lučić, Mario and Houlsby, Neil}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2307.06304}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.CV}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Materials Representations for ML Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/materials-representations-ml-review/</link><pubDate>Mon, 06 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/materials-representations-ml-review/</guid><description>Review of representation strategies for encoding solid-state materials as ML inputs, covering structural descriptors, crystal graphs, and generative models.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-material-representations">A Systematization of Material Representations</h2>
<p>This paper is a <strong>Systematization</strong> that organizes and categorizes the strategies researchers use to convert solid-state materials into numerical representations suitable for machine learning models. Rather than proposing a new method, the review provides a structured taxonomy of existing approaches, connecting each to the practical constraints of data availability, computational cost, and prediction targets. It covers structural descriptors, graph-based learned representations, compositional features, transfer learning, and generative models for inverse design.</p>
<h2 id="why-material-representations-matter">Why Material Representations Matter</h2>
<p>Machine learning has enabled rapid property prediction for materials, but every ML pipeline depends on how the material is encoded as a numerical input. The authors identify three guiding principles for effective representations:</p>
<ol>
<li><strong>Similarity preservation</strong>: Similar materials should have similar representations, and dissimilar materials should diverge in representation space.</li>
<li><strong>Domain coverage</strong>: The representation should be constructable for every material in the target domain.</li>
<li><strong>Cost efficiency</strong>: Computing the representation should be cheaper than computing the target property directly (e.g., via <a href="https://en.wikipedia.org/wiki/Density_functional_theory">DFT</a>).</li>
</ol>
<p>In practice, materials scientists face several barriers. Atomistic structures span diverse space groups, supercell sizes, and disorder parameters. Real material performance depends on defects, microstructure, and interfaces. Structural information often requires expensive experimental or computational effort to obtain. Datasets in materials science tend to be small, sparse, and biased toward well-studied systems.</p>
<h2 id="structural-descriptors-local-global-and-topological">Structural Descriptors: Local, Global, and Topological</h2>
<p>The review covers three families of hand-crafted structural descriptors that encode atomic positions and types.</p>
<h3 id="local-descriptors">Local Descriptors</h3>
<p>Local descriptors characterize the environment around each atom. Atom-centered symmetry functions (ACSF), introduced by Behler and Parrinello, define radial and angular functions:</p>
<p>$$
G_{i}^{1} = \sum_{j \neq i}^{\text{neighbors}} e^{-\eta(R_{ij} - R_{s})^{2}} f_{c}(R_{ij})
$$</p>
<p>$$
G_{i}^{2} = 2^{1-\zeta} \sum_{j,k \neq i}^{\text{neighbors}} (1 + \lambda \cos \theta_{ijk})^{\zeta} e^{-\eta(R_{ij}^{2} + R_{ik}^{2} + R_{jk}^{2})} f_{c}(R_{ij}) f_{c}(R_{ik}) f_{c}(R_{jk})
$$</p>
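<p>A minimal sketch of the radial function $G_{i}^{1}$ follows. The cutoff $f_{c}$ is not defined above, so the code assumes the cosine form commonly used with ACSFs, $f_{c}(R) = \tfrac{1}{2}\left[\cos(\pi R / R_{c}) + 1\right]$ for $R \le R_{c}$ and zero otherwise. Function and parameter names are illustrative.</p>

```python
import math


def cutoff(r, r_c):
    """Assumed cosine cutoff: decays smoothly from 1 at r = 0 to 0 at r = r_c."""
    if r > r_c:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / r_c) + 1.0)


def radial_symmetry_function(center, neighbors, eta, r_s, r_c):
    """G^1 for one atom: Gaussians in the neighbor distance, damped by the cutoff."""
    total = 0.0
    for pos in neighbors:
        r_ij = math.dist(center, pos)
        total += math.exp(-eta * (r_ij - r_s) ** 2) * cutoff(r_ij, r_c)
    return total


# One neighbor at distance 3 along x, then the same neighbor rotated onto y.
g_x = radial_symmetry_function((0.0, 0.0, 0.0), [(3.0, 0.0, 0.0)], eta=1.0, r_s=3.0, r_c=6.0)
g_y = radial_symmetry_function((0.0, 0.0, 0.0), [(0.0, 3.0, 0.0)], eta=1.0, r_s=3.0, r_c=6.0)
```

<p>Since the function depends only on interatomic distances, rotating the neighborhood leaves the value unchanged, which is the invariance the descriptor is designed to provide.</p>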
<p>The Smooth Overlap of Atomic Positions (SOAP), proposed by Bartók et al., defines atomic neighborhood density as a sum of Gaussians and computes a rotationally invariant kernel through expansion in radial functions and <a href="https://en.wikipedia.org/wiki/Spherical_harmonics">spherical harmonics</a>:</p>
<p>$$
\rho_{i}(\mathbf{r}) = \sum_{j} \exp\left(-\frac{|\mathbf{r} - \mathbf{r}_{ij}|^{2}}{2\sigma^{2}}\right) = \sum_{nlm} c_{nlm} g_{n}(\mathbf{r}) Y_{lm}(\hat{\mathbf{r}})
$$</p>
<p>The power spectrum $\mathbf{p}(\mathbf{r}) \equiv \sum_{m} c_{nlm}(c_{n^{\prime}lm})^{*}$ serves as a vector descriptor of the local environment. SOAP has seen wide adoption both as a similarity metric and as input to ML models.</p>
<p><a href="https://en.wikipedia.org/wiki/Voronoi_diagram">Voronoi tessellation</a> provides another local approach, segmenting space into cells and extracting features like effective coordination numbers, cell volumes, and neighbor properties.</p>
<h3 id="global-descriptors">Global Descriptors</h3>
<p>Global descriptors encode the full structure. The Coulomb matrix models electrostatic interactions between atoms:</p>
<p>$$
M_{i,j} = \begin{cases} \frac{1}{2} Z_{i}^{2.4} &amp; \text{for } i = j \\ \frac{Z_{i}Z_{j}}{|r_{i} - r_{j}|} &amp; \text{for } i \neq j \end{cases}
$$</p>
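<p>The Coulomb matrix is short enough to write out directly. The sketch uses the $\tfrac{1}{2} Z_{i}^{2.4}$ diagonal of the standard Coulomb matrix definition and a hydrogen molecule as a toy input; the bond length of 1.4 (atomic units) is an approximate, illustrative value.</p>

```python
import math


def coulomb_matrix(charges, positions):
    """Coulomb matrix for a finite set of atoms (nuclear charges and Cartesian positions)."""
    n = len(charges)
    m = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                # Diagonal term of the standard definition: 0.5 * Z_i ** 2.4
                m[i][j] = 0.5 * charges[i] ** 2.4
            else:
                # Off-diagonal: pairwise nuclear Coulomb repulsion
                m[i][j] = charges[i] * charges[j] / math.dist(positions[i], positions[j])
    return m


# Toy example: H2 with an approximate bond length of 1.4 in atomic units.
h2 = coulomb_matrix([1, 1], [(0.0, 0.0, 0.0), (1.4, 0.0, 0.0)])
```

<p>The matrix depends on atom ordering, so practical uses typically sort rows or work with eigenvalues. For periodic solids, the pairwise term also needs a periodic generalization.</p>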
<p>Other global methods include partial radial distribution functions (PRDF), the many-body tensor representation (MBTR), and cluster expansions. The Atomic Cluster Expansion (ACE) framework generalizes cluster expansions to continuous environments and has become a foundation for modern deep learning potentials.</p>
<h3 id="topological-descriptors">Topological Descriptors</h3>
<p><a href="https://en.wikipedia.org/wiki/Persistent_homology">Persistent homology</a> from topological data analysis (TDA) identifies geometric features at multiple length scales. Topological descriptors capture pore geometries in porous materials and have outperformed traditional structural descriptors for predicting CO$_{2}$ adsorption in metal-organic frameworks and methane storage in <a href="https://en.wikipedia.org/wiki/Zeolite">zeolites</a>. A caveat is the $O(N^{3})$ worst-case computational cost per filtration.</p>
<h2 id="crystal-graph-neural-networks">Crystal Graph Neural Networks</h2>
<p>Graph neural networks bypass manual feature engineering by learning representations directly from structural data. Materials are converted to graphs $G(V, E)$ where nodes represent atoms and edges connect neighbors within a cutoff radius, with periodic boundary conditions.</p>
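<p>Edge construction under periodic boundary conditions can be sketched as below. To stay short, the example assumes an orthorhombic cell, fractional coordinates in $[0, 1)$, and a cutoff no larger than the shortest cell edge, so that checking the 27 nearest periodic images is sufficient. The function name is hypothetical.</p>

```python
import itertools
import math


def crystal_graph_edges(frac_coords, cell_lengths, cutoff):
    """Return (i, j, image, distance) for every atom pair within the cutoff.

    Assumes an orthorhombic cell and cutoff <= min(cell_lengths).
    """
    edges = []
    for i, fi in enumerate(frac_coords):
        for j, fj in enumerate(frac_coords):
            for image in itertools.product((-1, 0, 1), repeat=3):
                if i == j and image == (0, 0, 0):
                    continue  # skip the atom itself, but keep its periodic copies
                delta = [(fj[k] + image[k] - fi[k]) * cell_lengths[k] for k in range(3)]
                dist = math.sqrt(sum(d * d for d in delta))
                if dist <= cutoff:
                    edges.append((i, j, image, dist))
    return edges


# Simple cubic lattice: one atom per cell, lattice constant 3.0.
edges = crystal_graph_edges([(0.0, 0.0, 0.0)], (3.0, 3.0, 3.0), cutoff=3.0)
```

<p>The single atom connects to six periodic copies of itself, reproducing the coordination number of a simple cubic lattice. General cells need the full lattice matrix and possibly more images.</p>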
<p>Key architectures discussed include:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Key Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN</td>
          <td>Crystal graph convolutions for broad property prediction</td>
      </tr>
      <tr>
          <td>MEGNet</td>
          <td>Materials graph networks with global state attributes</td>
      </tr>
      <tr>
          <td>ALIGNN</td>
          <td>Line graph neural networks incorporating three-body angular features</td>
      </tr>
      <tr>
          <td>Equivariant GNNs</td>
          <td>E(3)-equivariant message passing for tensorial properties</td>
      </tr>
  </tbody>
</table>
<p>The review identifies several limitations. Graph convolutions based on local neighborhoods can fail to capture long-range interactions or periodicity-dependent properties (e.g., lattice parameters, phonon spectra). Strategies to address this include concatenation with hand-tuned descriptors, plane-wave periodic basis modulation, and reciprocal-space features.</p>
<p>A major practical restriction is the requirement for relaxed atomic positions. Graphs built from unrelaxed crystal prototypes lose information about geometric distortions, degrading accuracy. Approaches to mitigate this include data augmentation with perturbed structures, Bayesian optimization of prototypes, and surrogate force-field relaxation.</p>
<p>Equivariant models that introduce higher-order tensors to node and edge features, constrained to transform correctly under E(3) operations, achieve state-of-the-art accuracy and can match structural descriptor performance even in low-data (~100 datapoints) regimes.</p>
<h2 id="compositional-descriptors-without-structure">Compositional Descriptors Without Structure</h2>
<p>When crystal structures are unavailable, representations can be built purely from stoichiometry and tabulated atomic properties (radii, electronegativity, valence electrons). Despite their simplicity, these methods have distinct advantages: zero computational overhead, accessibility to non-experts, and robustness for high-throughput screening.</p>
<p>Key methods include:</p>
<ul>
<li><strong>MagPie</strong>: 145 input features derived from elemental properties</li>
<li><strong>SISSO</strong>: Compressive sensing over algebraic combinations of atomic properties, capable of discovering interpretable descriptors (e.g., a new tolerance factor $\tau$ for perovskite stability)</li>
<li><strong>ElemNet</strong>: Deep neural network using only fractional stoichiometry as input, outperforming MagPie with &gt;3,000 training points</li>
<li><strong>ROOST</strong>: Fully-connected compositional graph with attention-based message passing, achieving strong performance with only hundreds of examples</li>
<li><strong>CrabNet</strong>: Self-attention on element embeddings with fractional encoding, handling dopant-level concentrations via log-scale inputs</li>
</ul>
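<p>As a rough illustration of composition-only featurization, the sketch below computes fraction-weighted statistics of a single elemental property. It is a simplified stand-in for Magpie-style features rather than any library&rsquo;s API; the lookup table holds Pauling electronegativities for only a handful of elements.</p>

```python
# Pauling electronegativities for a few elements (illustrative lookup table).
ELECTRONEGATIVITY = {"Na": 0.93, "Cl": 3.16, "Ba": 0.89, "Ti": 1.54, "O": 3.44}


def fractions(composition):
    """Convert element counts, e.g. {"Ba": 1, "Ti": 1, "O": 3}, to atomic fractions."""
    total = sum(composition.values())
    return {el: n / total for el, n in composition.items()}


def composition_features(composition, table=ELECTRONEGATIVITY):
    """Fraction-weighted mean plus min, max, and range of one elemental property."""
    frac = fractions(composition)
    values = [table[el] for el in frac]
    mean = sum(frac[el] * table[el] for el in frac)
    return {
        "mean": mean,
        "min": min(values),
        "max": max(values),
        "range": max(values) - min(values),
    }


batio3 = composition_features({"Ba": 1, "Ti": 1, "O": 3})
```

<p>Nothing in this feature vector depends on atomic positions, so any two structures with the same formula receive identical features.</p>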
<p>Compositional models cannot distinguish polymorphs and generally underperform structural approaches. They are most valuable when atomistic resolution is unavailable.</p>
<h2 id="defects-surfaces-and-grain-boundaries">Defects, Surfaces, and Grain Boundaries</h2>
<p>The review extends beyond idealized unit cells to practical materials challenges:</p>
<p><strong>Point defects</strong>: Representations of the pristine bulk can predict vacancy formation energies through linear relationships with band structure descriptors. Frey et al. proposed using relative differences between defect and parent structure properties, requiring no DFT on the defect itself.</p>
<p><strong>Surfaces and catalysis</strong>: Binding energy prediction for catalysis requires representations beyond the bulk unit cell. The d-band center for metals and oxygen 2p-band center for metal oxides serve as simple electronic descriptors, following the <a href="https://en.wikipedia.org/wiki/Sabatier_principle">Sabatier principle</a> that optimal catalytic activity requires intermediate binding strength. Graph neural networks trained on the Open Catalyst 2020 dataset (&gt;1 million DFT energies) have enabled broader screening, though errors remain high for certain adsorbates and non-metallic surfaces.</p>
<p><strong>Grain boundaries</strong>: SOAP descriptors computed for atoms near grain boundaries and clustered into local environment classes can predict grain boundary energy, mobility, and shear coupling. This approach provides interpretable structure-property relationships.</p>
<h2 id="transfer-learning-across-representations">Transfer Learning Across Representations</h2>
<p>When target datasets are small, transfer learning leverages representations learned from large, related datasets. The standard procedure involves: (1) pretraining on a large dataset (e.g., all Materials Project formation energies), (2) freezing parameters up to a chosen depth, and (3) either fine-tuning remaining layers or extracting features for a separate model.</p>
<p>Key findings from the review:</p>
<ul>
<li>Transfer learning is most effective when the source dataset is orders of magnitude larger than the target</li>
<li>Physically related tasks transfer better (e.g., Open Catalyst adsorption energies transfer well to new adsorbates, less so to unrelated small molecules)</li>
<li>Earlier neural network layers learn more general representations and transfer better across properties</li>
<li>Multi-depth feature extraction, combining activations from multiple layers, can improve transfer</li>
<li>Predictions from surrogate models can serve as additional descriptors, expanding screening domains by orders of magnitude</li>
</ul>
<h2 id="generative-models-for-crystal-inverse-design">Generative Models for Crystal Inverse Design</h2>
<p>Generative models for solid-state materials face challenges beyond molecular generation: more diverse atomic species, the need to specify both positions and lattice parameters, non-unique definitions (rotations, translations, supercell scaling), and large unit cells (&gt;100 atoms for zeolites and MOFs).</p>
<p>The review traces the progression of approaches:</p>
<ol>
<li><strong>Voxel representations</strong>: Discretize unit cells into volume elements. Early work (iMatGen, Court et al.) demonstrated feasibility but was restricted to specific chemistries or cubic systems.</li>
<li><strong>Continuous coordinate models</strong>: Point cloud and invertible representations allowed broader chemical spaces but lacked symmetry invariances.</li>
<li><strong>Symmetry-aware models</strong>: Crystal Diffusion <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">VAE</a> (CDVAE) uses periodic graphs and SE(3)-equivariant message passing for translationally and rotationally invariant generation, establishing benchmark tasks for the field.</li>
<li><strong>Constrained models for porous materials</strong>: Approaches like SmVAE represent MOFs through their topological building blocks (RFcodes), ensuring all generated structures are physically valid.</li>
</ol>
<h2 id="open-problems-and-future-directions">Open Problems and Future Directions</h2>
<p>The review highlights four high-impact open questions:</p>
<ol>
<li><strong>Local vs. global descriptor trade-offs</strong>: Local descriptors (SOAP) excel for short-range interactions but struggle with long-range physics. Global descriptors model periodicity but lack generality across space groups. Combining local and long-range features could provide more universal models.</li>
<li><strong>Prediction from unrelaxed prototypes</strong>: ML force fields can relax structures at a fraction of DFT cost, potentially expanding screening domains. Key questions remain about required training data scale and generalizability.</li>
<li><strong>Applicability of compositional descriptors</strong>: The performance gap between compositional and structural models may be property-dependent. It appears smaller for properties like band gap, which depend on global features, than for properties driven by local site energies.</li>
<li><strong>Extensions of generative models</strong>: Diffusion-based architectures have improved on voxel approaches for small unit cells, but extending to microstructure, dimensionality, and surface generation remains open.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This paper is a review and does not present new experimental results or release any novel code, data, or models. The paper is open-access (hybrid OA at Annual Reviews) and the arXiv preprint is freely available. The following artifacts table covers key publicly available resources discussed in the review.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://arxiv.org/abs/2301.08813">arXiv preprint (2301.08813)</a></td>
          <td>Other</td>
          <td>arXiv (open access)</td>
          <td>Free preprint version</td>
      </tr>
      <tr>
          <td><a href="https://materialsproject.org">Materials Project</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>DFT energies, band gaps, structures for &gt;100,000 compounds</td>
      </tr>
      <tr>
          <td><a href="https://oqmd.org">OQMD</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Open Quantum Materials Database, &gt;600,000 DFT entries</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Open-Catalyst-Project/ocp">Open Catalyst 2020 (OC20)</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>&gt;1,000,000 DFT surface adsorption energies</td>
      </tr>
      <tr>
          <td><a href="https://aflowlib.org">AFLOW</a></td>
          <td>Dataset</td>
          <td>Public</td>
          <td>High-throughput ab initio library, &gt;3,000,000 entries</td>
      </tr>
      <tr>
          <td><a href="https://github.com/hackingmaterials/matminer">Matminer</a></td>
          <td>Code</td>
          <td>BSD</td>
          <td>Open-source toolkit for materials data mining and featurization</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The review covers: ACSF, SOAP, Voronoi tessellation, Coulomb matrices, PRDF, MBTR, cluster expansions, ACE, persistent homology, CGCNN, MEGNet, ALIGNN, E(3)-equivariant GNNs, MagPie, SISSO, ElemNet, ROOST, CrabNet, VAE, GAN, and diffusion-based crystal generators.</p>
<h3 id="hardware">Hardware</h3>
<p>No new experiments are conducted. Hardware requirements vary by the referenced methods (DFT calculations require HPC; GNN training typically requires 1-8 GPUs).</p>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<p><strong>Partially Reproducible</strong>: The review paper itself is open-access. All major datasets discussed (Materials Project, OQMD, OC20, AFLOW) are publicly available under permissive licenses. Most referenced model implementations (CGCNN, MEGNet, ALIGNN, ROOST, CDVAE) have open-source code. No novel artifacts are released by the authors.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Damewood, J., Karaguesian, J., Lunger, J. R., Tan, A. R., Xie, M., Peng, J., &amp; Gómez-Bombarelli, R. (2023). Representations of Materials for Machine Learning. <em>Annual Review of Materials Research</em>, 53. <a href="https://doi.org/10.1146/annurev-matsci-080921-085947">https://doi.org/10.1146/annurev-matsci-080921-085947</a></p>
<p><strong>Publication</strong>: Annual Review of Materials Research, 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{damewood2023representations,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Representations of Materials for Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Damewood, James and Karaguesian, Jessica and Lunger, Jaclyn R. and Tan, Aik Rui and Xie, Mingrou and Peng, Jiayu and G{\&#39;o}mez-Bombarelli, Rafael}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Annual Review of Materials Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{53}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1146/annurev-matsci-080921-085947}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MarkushGrapher-2: End-to-End Markush Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/markushgrapher-2-multimodal-recognition/</link><pubDate>Mon, 06 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/markushgrapher-2-multimodal-recognition/</guid><description>MarkushGrapher-2 fuses vision, text, and layout encoders with a dedicated OCR module for end-to-end Markush structure recognition from patent images.</description><content:encoded><![CDATA[<h2 id="a-multimodal-method-for-markush-structure-recognition">A Multimodal Method for Markush Structure Recognition</h2>
<p>This is a <strong>Method</strong> paper that introduces MarkushGrapher-2, a universal encoder-decoder model for recognizing both standard molecular structures and multimodal Markush structures from chemical images. The primary contribution is a dual-encoder architecture that fuses a pretrained OCSR (Optical Chemical Structure Recognition) vision encoder with a Vision-Text-Layout (VTL) encoder, connected through a dedicated ChemicalOCR module for end-to-end processing. The paper also introduces two new resources: a large-scale training dataset (USPTO-MOL-M) of real-world Markush structures extracted from USPTO patent MOL files, and IP5-M, a manually annotated benchmark of 1,000 Markush structures from five major patent offices.</p>
<h2 id="why-markush-structure-recognition-remains-challenging">Why Markush Structure Recognition Remains Challenging</h2>
<p><a href="https://en.wikipedia.org/wiki/Markush_structure">Markush structures</a> are compact representations used in patent documents to describe families of related molecules. They combine a visual backbone (atoms, bonds, variable regions) with textual definitions of substituents that can replace those variable regions. This multimodal nature makes them harder to parse than standard molecular diagrams.</p>
<p>Three factors limit automatic Markush recognition. First, visual styles vary across patent offices and publication years. Second, textual definitions lack standardization and often contain conditional or recursive descriptions. Third, real-world training data with comprehensive annotations is scarce. As a result, Markush structures are currently indexed only in two proprietary, manually curated databases: MARPAT and DWPIM.</p>
<p>Prior work, including the original <a href="/notes/chemistry/optical-structure-recognition/markush/markushgrapher/">MarkushGrapher</a>, required pre-annotated OCR outputs at inference time, limiting practical deployment. General-purpose models like GPT-5 and DeepSeek-OCR produce mostly chemically invalid outputs on Markush images, suggesting these lie outside their training distribution.</p>
<h2 id="dual-encoder-architecture-with-dedicated-chemicalocr">Dual-Encoder Architecture with Dedicated ChemicalOCR</h2>
<p>MarkushGrapher-2 uses two complementary encoding pipelines:</p>
<ol>
<li>
<p><strong>Vision encoder pipeline</strong>: The input image passes through a Swin-B Vision Transformer (taken from <a href="/notes/chemistry/optical-structure-recognition/image-to-graph/molscribe/">MolScribe</a>) pretrained for OCSR. This encoder extracts visual features representing molecular structures and remains frozen during training.</p>
</li>
<li>
<p><strong>Vision-Text-Layout (VTL) pipeline</strong>: The same image goes through ChemicalOCR, a compact 256M-parameter vision-language model fine-tuned from SmolDocling for OCR on chemical images. ChemicalOCR extracts character-level text and bounding boxes. These, combined with image patches, feed into a T5-base VTL encoder following the UDOP fusion paradigm, where visual and textual tokens are spatially aligned by bounding box overlap.</p>
</li>
</ol>
<p>The VTL encoder output is concatenated with projected embeddings from the vision encoder. This joint representation feeds a text decoder that auto-regressively generates a CXSMILES (ChemAxon Extended <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) string describing the backbone structure and a substituent table listing variable group definitions.</p>
<h3 id="two-stage-training-strategy">Two-Stage Training Strategy</h3>
<p>Training proceeds in two phases:</p>
<ul>
<li>
<p><strong>Phase 1 (Adaptation)</strong>: The vision encoder is frozen. The MLP projector and text decoder train on 243K real-world image-SMILES pairs from MolScribe&rsquo;s USPTO dataset (3 epochs). This aligns the decoder to the pretrained OCSR feature space.</p>
</li>
<li>
<p><strong>Phase 2 (Fusion)</strong>: The vision encoder, projector, and ChemicalOCR are all frozen. The VTL encoder and text decoder train on a mix of 235K synthetic and 145K real-world Markush samples (2 epochs). The VTL encoder learns the features needed for CXSMILES and substituent table prediction without disrupting the established OCSR representations.</p>
</li>
</ul>
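<p>The freezing schedule above can be summarized as a small configuration sketch. The component keys (<code>vision_encoder</code>, <code>mlp_projector</code>, and so on) and the helper <code>is_trainable</code> are hypothetical names chosen for illustration, not identifiers from the MarkushGrapher-2 code base. Only components that the description explicitly marks as trained or frozen are listed.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Assumed component names; epochs and frozen/trainable sets follow the text above.
PHASES = {
    "phase1_adaptation": {
        "epochs": 3,
        "trainable": {"mlp_projector", "text_decoder"},
        "frozen": {"vision_encoder"},
    },
    "phase2_fusion": {
        "epochs": 2,
        "trainable": {"vtl_encoder", "text_decoder"},
        "frozen": {"vision_encoder", "mlp_projector", "chemical_ocr"},
    },
}


def is_trainable(phase, component):
    """Return True if the component receives gradient updates in this phase."""
    return component in PHASES[phase]["trainable"]


for name, spec in PHASES.items():
    print(name, "trains", sorted(spec["trainable"]))
</code></pre></div>
<p>In a deep learning framework, a table like this would typically be applied by disabling gradient computation for every parameter of a frozen component before each phase starts.</p>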
<p>The total model has 831M parameters, of which 744M are trainable.</p>
<h2 id="datasets-and-evaluation-benchmarks">Datasets and Evaluation Benchmarks</h2>
<h3 id="training-data">Training Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OCR pretraining</td>
          <td>Synthetic chemical structures</td>
          <td>235K</td>
          <td><a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> SMILES augmented to CXSMILES, rendered with annotations</td>
      </tr>
      <tr>
          <td>OCR fine-tuning</td>
          <td>Manual OCR annotations</td>
          <td>7K</td>
          <td>IP5 patent document crops</td>
      </tr>
      <tr>
          <td>Phase 1 (OCSR)</td>
          <td>MolScribe USPTO</td>
          <td>243K</td>
          <td>Real image-SMILES pairs</td>
      </tr>
      <tr>
          <td>Phase 2 (MMSR)</td>
          <td>Synthetic CXSMILES</td>
          <td>235K</td>
          <td>Same as OCR pretraining set</td>
      </tr>
      <tr>
          <td>Phase 2 (MMSR)</td>
          <td>MolParser dataset</td>
          <td>91K</td>
          <td>Real-world Markush, converted to CXSMILES</td>
      </tr>
      <tr>
          <td>Phase 2 (MMSR)</td>
          <td>USPTO-MOL-M</td>
          <td>54K</td>
          <td>Real-world, auto-extracted from USPTO MOL files (2010-2025)</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation-benchmarks">Evaluation Benchmarks</h3>
<p><strong>Markush benchmarks</strong>: M2S (103 samples), USPTO-M (74), WildMol-M (10K, semi-manual), and the new IP5-M (1,000 manually annotated from USPTO, JPO, KIPO, CNIPA, and EPO patents, 1980-2025).</p>
<p><strong>OCSR benchmarks</strong>: USPTO (5,719), JPO (450), UOB (5,740), WildMol (10K).</p>
<p>The primary metric is <strong>CXSMILES Accuracy (A)</strong>: a prediction is correct when (1) the predicted SMILES matches the ground truth by <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChIKey</a> equivalence, and (2) all Markush features (variable groups, positional and frequency variation indicators) are correctly represented. Stereochemistry is ignored during evaluation.</p>
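<p>The two-part correctness rule can be sketched as follows. The <code>Prediction</code> container and its field names are hypothetical, and the sketch assumes the InChIKey has already been computed with stereochemistry removed and that Markush features have been normalized into comparable strings; the paper&rsquo;s actual evaluation code may differ.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from dataclasses import dataclass


@dataclass(frozen=True)
class Prediction:
    """Hypothetical container for one parsed CXSMILES string."""
    inchikey: str                # backbone identity, stereochemistry ignored
    markush_features: frozenset  # variable groups, positional/frequency indicators


def is_correct(pred, truth):
    """Both conditions must hold: same backbone AND same Markush features."""
    return (
        pred.inchikey == truth.inchikey
        and pred.markush_features == truth.markush_features
    )


def cxsmiles_accuracy(preds, truths):
    """Percentage of samples that satisfy both conditions."""
    pairs = list(zip(preds, truths))
    hits = sum(is_correct(p, t) for p, t in pairs)
    return 100.0 * hits / len(pairs)
</code></pre></div>
<p>Because the two conditions are combined with a logical AND, a prediction with a perfect backbone but one missing variable group counts as wrong. This is why CXSMILES accuracy is always at most the backbone-only accuracy.</p>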
<h3 id="results-markush-structure-recognition">Results: Markush Structure Recognition</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>M2S</th>
          <th>USPTO-M</th>
          <th>WildMol-M</th>
          <th>IP5-M</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolParser-Base</td>
          <td>39</td>
          <td>30</td>
          <td>38.1</td>
          <td>47.7</td>
      </tr>
      <tr>
          <td>MolScribe</td>
          <td>21</td>
          <td>7</td>
          <td>28.1</td>
          <td>22.3</td>
      </tr>
      <tr>
          <td>GPT-5</td>
          <td>3</td>
          <td>0</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DeepSeek-OCR</td>
          <td>0</td>
          <td>0</td>
          <td>1.9</td>
          <td>0.0</td>
      </tr>
      <tr>
          <td>MarkushGrapher-1</td>
          <td>38</td>
          <td>10</td>
          <td>32</td>
          <td>-</td>
      </tr>
      <tr>
          <td><strong>MarkushGrapher-2</strong></td>
          <td><strong>56</strong></td>
          <td>13</td>
          <td><strong>55</strong></td>
          <td><strong>48.0</strong></td>
      </tr>
  </tbody>
</table>
<p>On M2S, MarkushGrapher-2 achieves 56% CXSMILES accuracy vs. 38% for MarkushGrapher-1, a relative improvement of 47%. On WildMol-M (the largest benchmark at 10K samples), MarkushGrapher-2 reaches 55% vs. 38.1% for MolParser-Base and 32% for MarkushGrapher-1. GPT-5 and DeepSeek-OCR generate mostly chemically invalid outputs on Markush images: only 30% and 15% of their predictions are valid CXSMILES on M2S, respectively.</p>
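<p>The 47% figure is a relative gain, not a difference in percentage points. A minimal check of the arithmetic, using only the M2S numbers from the table (the function name is an illustrative choice):</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def relative_improvement(new, old):
    """Relative gain of new over old, in percent."""
    return 100.0 * (new - old) / old


# M2S CXSMILES accuracy: MarkushGrapher-2 (56) vs. MarkushGrapher-1 (38).
print(f"absolute gain: {56 - 38} points")
print(f"relative gain: {relative_improvement(56, 38):.1f}%")  # 47.4%
</code></pre></div>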
<h3 id="results-standard-molecular-structure-recognition">Results: Standard Molecular Structure Recognition</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>WildMol</th>
          <th>JPO</th>
          <th>UOB</th>
          <th>USPTO</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolParser-Base</td>
          <td>76.9</td>
          <td>78.9</td>
          <td>91.8</td>
          <td>93.0</td>
      </tr>
      <tr>
          <td>MolScribe</td>
          <td>66.4</td>
          <td>76.2</td>
          <td>87.4</td>
          <td>93.1</td>
      </tr>
      <tr>
          <td>DECIMER 2.7</td>
          <td>56.0</td>
          <td>64.0</td>
          <td>88.3</td>
          <td>59.9</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-graph/molgrapher/">MolGrapher</a></td>
          <td>45.5</td>
          <td>67.5</td>
          <td>94.9</td>
          <td>91.5</td>
      </tr>
      <tr>
          <td>DeepSeek-OCR</td>
          <td>25.8</td>
          <td>31.6</td>
          <td>78.7</td>
          <td>36.9</td>
      </tr>
      <tr>
          <td><strong>MarkushGrapher-2</strong></td>
          <td>68.4</td>
          <td>71.0</td>
          <td><strong>96.6</strong></td>
          <td>89.8</td>
      </tr>
  </tbody>
</table>
<p>MarkushGrapher-2 achieves the highest score on UOB (96.6%) and trails the best model by roughly 3 to 9 points on WildMol, JPO, and USPTO, despite being primarily optimized for Markush recognition.</p>
<h3 id="chemicalocr-vs-general-ocr">ChemicalOCR vs. General OCR</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>M2S F1</th>
          <th>USPTO-M F1</th>
          <th>IP5-M F1</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PaddleOCR v5</td>
          <td>7.7</td>
          <td>1.2</td>
          <td>1.9</td>
      </tr>
      <tr>
          <td>EasyOCR</td>
          <td>10.2</td>
          <td>18.0</td>
          <td>18.4</td>
      </tr>
      <tr>
          <td><strong>ChemicalOCR</strong></td>
          <td><strong>87.2</strong></td>
          <td><strong>93.0</strong></td>
          <td><strong>86.5</strong></td>
      </tr>
  </tbody>
</table>
<p>General-purpose OCR tools fail on chemical images because they misinterpret bonds as characters and cannot parse chemical abbreviations. ChemicalOCR reaches F1 scores of 86.5 to 93.0 across the three benchmarks, while neither general-purpose tool exceeds 18.4.</p>
<h2 id="ablation-results-and-key-findings">Ablation Results and Key Findings</h2>
<p><strong>OCR input is critical for Markush features.</strong> Without OCR, CXSMILES accuracy drops from 56% to 4% on M2S, and from 53.7% to 15.4% on IP5-M. The backbone structure accuracy ($A_{\text{InChIKey}}$) also drops substantially (from 80% to 39% on M2S), though the vision encoder alone can still recover some structural information. This confirms that textual cues (brackets, indices, variable definitions) are essential for Markush feature prediction.</p>
<p><strong>Two-phase training improves both tasks.</strong> Compared to single-phase (fusion only) training, the two-phase strategy improves CXSMILES accuracy from 44% to 50% on M2S and from 53.0% to 61.5% on JPO after the same number of epochs. Adapting the decoder to OCSR features before introducing the VTL encoder prevents the fusion process from degrading learned visual representations.</p>
<p><strong>Frequency variation indicators remain the hardest feature.</strong> On IP5-M, the per-feature breakdown shows 73.3% accuracy for backbone InChI, 74.8% for variable groups, 78.8% for positional variation, but only 30.7% for frequency variation (Sg groups). These repeating structural units are particularly challenging to represent and predict.</p>
<p><strong>Limitations</strong>: The model relies on accurate OCR as a prerequisite. Performance on USPTO-M (13% CXSMILES accuracy) lags behind other benchmarks, likely due to the older patent styles in that dataset. The paper does not report inference latency.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OCR pretraining</td>
          <td>Synthetic chemical images</td>
          <td>235K</td>
          <td>Generated from PubChem SMILES, augmented to CXSMILES</td>
      </tr>
      <tr>
          <td>OCR fine-tuning</td>
          <td>IP5 patent crops</td>
          <td>7K</td>
          <td>Manually annotated</td>
      </tr>
      <tr>
          <td>Phase 1 training</td>
          <td>MolScribe USPTO</td>
          <td>243K</td>
          <td>Public, real image-SMILES pairs</td>
      </tr>
      <tr>
          <td>Phase 2 training</td>
          <td>Synthetic + MolParser + USPTO-MOL-M</td>
          <td>380K</td>
          <td>Mix of synthetic (235K), MolParser (91K), USPTO-MOL-M (54K)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>M2S, USPTO-M, WildMol-M, IP5-M</td>
          <td>103 to 10K</td>
          <td>Markush benchmarks</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>WildMol, JPO, UOB, USPTO</td>
          <td>450 to 10K</td>
          <td>OCSR benchmarks</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
          <th>Status</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Vision encoder</td>
          <td>Swin-B ViT (from MolScribe)</td>
          <td>~87M</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td>VTL encoder + decoder</td>
          <td>T5-base</td>
          <td>~744M trainable</td>
          <td>Trained</td>
      </tr>
      <tr>
          <td>ChemicalOCR</td>
          <td>SmolDocling-based VLM</td>
          <td>256M</td>
          <td>Fine-tuned, frozen in Phase 2</td>
      </tr>
      <tr>
          <td>MLP projector</td>
          <td>Linear projection</td>
          <td>-</td>
          <td>Trained in Phase 1, frozen in Phase 2</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td></td>
          <td><strong>831M</strong></td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Definition</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CXSMILES Accuracy (A)</td>
          <td>Percentage of samples where InChIKey matches AND all Markush features correct</td>
      </tr>
      <tr>
          <td>$A_{\text{InChIKey}}$</td>
          <td>Backbone structure accuracy only (ignoring Markush features)</td>
      </tr>
      <tr>
          <td>Table Accuracy</td>
          <td>Percentage of correctly predicted substituent tables</td>
      </tr>
      <tr>
          <td>Markush Accuracy</td>
          <td>Joint CXSMILES + Table accuracy</td>
      </tr>
      <tr>
          <td>OCR F1</td>
          <td>Bounding-box-level precision/recall at IoU &gt; 0.5</td>
      </tr>
  </tbody>
</table>
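<p>The OCR F1 row can be illustrated with a small box-matching sketch. The greedy one-to-one matching and the choice to compare boxes only (ignoring the recognized text) are assumptions made here for illustration; the table specifies only that matching happens at the bounding-box level with IoU above 0.5.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def iou(a, b):
    """Intersection over union of two (x0, y0, x1, y1) boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union else 0.0


def box_f1(pred, truth, threshold=0.5):
    """F1 over boxes, greedily matching each prediction to one ground-truth box."""
    unmatched = list(truth)
    true_positives = 0
    for box in pred:
        best = max(unmatched, key=lambda t: iou(box, t), default=None)
        if best is not None and iou(box, best) &gt; threshold:
            unmatched.remove(best)  # each ground-truth box is matched at most once
            true_positives += 1
    precision = true_positives / len(pred) if pred else 0.0
    recall = true_positives / len(truth) if truth else 0.0
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0
</code></pre></div>
<p>For example, with two predicted boxes of which one overlaps a ground-truth box at IoU 0.8 and the other overlaps nothing, and two ground-truth boxes in total, precision and recall are both 0.5, giving an F1 of 0.5.</p>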
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: NVIDIA A100 GPU</li>
<li>Phase 1: 3 epochs, Adam optimizer, lr 5e-4, 1000 warmup steps, batch size 10, weight decay 1e-3</li>
<li>Phase 2: 2 epochs, batch size 8</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DS4SD/MarkushGrapher">MarkushGrapher GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation of MarkushGrapher-2 with models and datasets</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Highly Reproducible. Code, models, and datasets are all publicly released under an MIT license with documented training hyperparameters and a single A100 GPU requirement.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Strohmeyer, T., Morin, L., Meijer, G. I., Weber, V., Nassar, A., &amp; Staar, P. (2026). MarkushGrapher-2: End-to-end Multimodal Recognition of Chemical Structures. In <em>Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)</em>.</p>
<p><strong>Publication</strong>: CVPR 2026</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/DS4SD/MarkushGrapher">GitHub Repository (MIT License)</a></li>
<li><a href="https://arxiv.org/abs/2603.28550">arXiv Preprint</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{strohmeyer2026markushgrapher,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MarkushGrapher-2: End-to-end Multimodal Recognition of Chemical Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Strohmeyer, Tim and Morin, Lucas and Meijer, Gerhard Ingmar and Weber, Val\&#39;{e}ry and Nassar, Ahmed and Staar, Peter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2026}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2603.28550}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.CV}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformers and LLMs for Chemistry Drug Discovery</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/transformers-llms-chemistry-drug-discovery/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/transformers-llms-chemistry-drug-discovery/</guid><description>Bran and Schwaller review transformer architectures for chemistry, from task-specific SMILES models to multimodal LLMs and chemistry agents.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformers-in-chemistry">A Systematization of Transformers in Chemistry</h2>
<p>This book chapter by Bran and Schwaller is a <strong>Systematization</strong> paper that organizes the growing body of work applying transformer architectures to chemistry and drug discovery. Rather than proposing a new method, the authors trace a three-stage evolution: (1) task-specific single-modality models operating on SMILES and reaction strings, (2) multimodal models bridging molecular representations with spectra, synthesis actions, and natural language, and (3) large language models and LLM-powered agents capable of general chemical reasoning.</p>
<h2 id="why-transformers-for-chemistry">Why Transformers for Chemistry?</h2>
<p>The authors motivate the review by drawing analogies between natural language and chemical language. Just as text can be decomposed into subwords and tokens, molecules can be linearized into <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> or <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> strings, and chemical reactions can be encoded as reaction SMILES. This structural parallel enabled direct transfer of transformer architectures, originally designed for machine translation, to chemical prediction tasks.</p>
<p>Several factors accelerated this adoption:</p>
<ul>
<li>The publication of open chemical databases and benchmarks (e.g., <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>, Open Reaction Database, Therapeutics Data Commons)</li>
<li>Improvements in compute infrastructure and training algorithms</li>
<li>The success of attention mechanisms at capturing context-dependent relationships, which proved effective for learning chemical grammar and atom-level correspondences</li>
</ul>
<p>The review positions the transformer revolution in chemistry as a natural extension of NLP advances, noting that the gap between chemical and natural language is progressively closing.</p>
<h2 id="molecular-representations-as-language">Molecular Representations as Language</h2>
<p>A key section of the review covers text-based molecular representations that make transformer applications possible:</p>
<ul>
<li><strong>SMILES</strong> (Simplified Molecular Input Line Entry System): The dominant linearization scheme since the 1980s, encoding molecular graphs as character sequences with special symbols for bonds, branches, and rings.</li>
<li><strong>SELFIES</strong> (Self-Referencing Embedded Strings): A newer representation that guarantees every string maps to a valid molecule, addressing the robustness issues of SMILES in generative settings.</li>
<li><strong>Reaction SMILES</strong>: Extends molecular representations to encode full chemical reactions in the format &ldquo;A.B &gt; catalyst.reagent &gt; C.D&rdquo;, enabling reaction prediction as a sequence-to-sequence task.</li>
</ul>
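<p>To make the reaction SMILES layout concrete, the sketch below splits a reaction string into its three fields. The function name <code>parse_reaction_smiles</code> and the returned dictionary keys are illustrative choices, and the example reaction (acid-catalyzed esterification of acetic acid with ethanol) is not taken from the review. The sketch performs string splitting only and does not validate the chemistry.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def parse_reaction_smiles(rxn):
    """Split 'reactants&gt;agents&gt;products' into three lists of molecule SMILES."""
    parts = [part.strip() for part in rxn.split("&gt;")]
    if len(parts) != 3:
        raise ValueError("expected exactly two '&gt;' separators")
    # Within each field, individual molecules are separated by dots.
    reactants, agents, products = (
        [mol for mol in part.split(".") if mol] for part in parts
    )
    return {"reactants": reactants, "agents": agents, "products": products}


parsed = parse_reaction_smiles("CC(=O)O.OCC&gt;[H+]&gt;CC(=O)OCC.O")
print(parsed["reactants"], parsed["agents"], parsed["products"])
</code></pre></div>
<p>A sequence-to-sequence model for forward prediction would receive the first two fields as its input and be trained to emit the third.</p>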
<p>The authors note that while IUPAC names, InChI, and <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> exist as alternatives, SMILES and SELFIES dominate practical applications.</p>
<h2 id="stage-1-task-specific-transformer-models">Stage 1: Task-Specific Transformer Models</h2>
<p>The first stage of transformer adoption focused on clearly defined chemical tasks, with models trained on a single data modality (molecular strings).</p>
<h3 id="chemical-translation-tasks">Chemical Translation Tasks</h3>
<p>The encoder-decoder architecture was directly applied to tasks framed as translation:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a></strong> (Schwaller et al.): Treated reaction prediction as translation from reactant SMILES to product SMILES, becoming a leading method for forward synthesis prediction.</li>
<li><strong>Retrosynthetic planning</strong>: The reverse task, predicting reactants from products, with iterative application to construct full retrosynthetic trees mapping to commercially available building blocks.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a></strong> (Irwin et al.): A single pre-trained model fine-tuned for multiple chemical tasks, transferring to new applications with improved performance.</li>
<li><strong>Graph-to-sequence models</strong> (Tu and Coley): Used a custom graph encoder with a transformer decoder, achieving improvements through permutation-invariant molecular graph encoding.</li>
</ul>
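<p>The iterative retrosynthesis idea in the second bullet can be sketched as a recursive expansion. Here a plain dictionary (<code>ONE_STEP</code>) stands in for a single-step retrosynthesis model and <code>STOCK</code> for a catalogue of purchasable building blocks. Both, along with the function name <code>expand</code>, are hypothetical placeholders, and a real planner would rank several candidate disconnections per molecule instead of following a single one.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Toy stand-ins: a real system would query a trained model and a vendor catalogue.
ONE_STEP = {"CC(=O)OCC": ["CC(=O)O", "OCC"]}  # ethyl acetate from acid + alcohol
STOCK = {"CC(=O)O", "OCC"}


def expand(target, one_step, stock, max_depth=5):
    """Return a nested route for target, or None if no route is found."""
    if target in stock:
        return {"molecule": target, "children": []}
    if max_depth == 0 or target not in one_step:
        return None
    children = []
    for precursor in one_step[target]:
        subtree = expand(precursor, one_step, stock, max_depth - 1)
        if subtree is None:
            return None  # one unsolved precursor invalidates this route
        children.append(subtree)
    return {"molecule": target, "children": children}


route = expand("CC(=O)OCC", ONE_STEP, STOCK)
print([child["molecule"] for child in route["children"]])
</code></pre></div>
<p>The recursion stops at molecules found in the stock set, which mirrors the requirement that every leaf of the retrosynthetic tree be a commercially available building block.</p>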
<h3 id="representation-learning-and-feature-extraction">Representation Learning and Feature Extraction</h3>
<p>Encoder-only transformers proved valuable for generating molecular and reaction embeddings:</p>
<ul>
<li><strong>Reaction representations</strong> (Wang et al., SMILES-BERT): Trained models to generate reaction vectors that outperformed hand-engineered features on downstream regression tasks.</li>
<li><strong>Reaction classification</strong> (Schwaller et al.): Replaced the decoder with a classification layer to map chemical reactions by class, revealing clustering patterns by reaction type, data source, and molecular properties.</li>
<li><strong>Yield prediction</strong>: Regression heads attached to encoders achieved strong results on high-throughput experimentation datasets.</li>
<li><strong>Protein language models</strong> (Rives et al., ESM): Trained on 250 million protein sequences using unsupervised learning, achieving strong performance on protein property prediction and structure prediction.</li>
<li><strong>RXNMapper</strong> (Schwaller et al.): A notable application where attention weight analysis revealed that transformers internally learn atom-to-atom mappings in chemical reactions, leading to an open-source atom mapping algorithm that outperformed existing approaches.</li>
</ul>
<h2 id="stage-2-multimodal-chemical-models">Stage 2: Multimodal Chemical Models</h2>
<p>The second stage extended transformers beyond molecular strings to incorporate additional data types:</p>
<ul>
<li><strong>Molecular captioning</strong>: Describing molecules in natural language, covering scaffolds, sources, drug interactions, and other features (Edwards et al.).</li>
<li><strong>Bidirectional molecule-text conversion</strong>: Models capable of generating molecules from text queries and performing molecule-to-molecule tasks (Christofidellis et al.).</li>
<li><strong>Experimental procedure prediction</strong>: Generating actionable synthesis steps from reaction SMILES (Vaucher et al.), bridging the gap between retrosynthetic planning and laboratory execution.</li>
<li><strong>Structural elucidation from IR spectra</strong>: Encoding IR spectra as text sequences alongside chemical formulas, then predicting SMILES from these inputs (Alberts et al.), achieving 45% accuracy in structure prediction and surpassing prior approaches for functional group identification.</li>
</ul>
<h2 id="stage-3-large-language-models-and-chemistry-agents">Stage 3: Large Language Models and Chemistry Agents</h2>
<p>The most recent stage builds on foundation models pre-trained on vast text corpora, adapted for chemistry through fine-tuning and in-context learning.</p>
<h3 id="scaling-laws-and-emergent-capabilities">Scaling Laws and Emergent Capabilities</h3>
<p>The authors discuss how model scaling leads to emergent capabilities relevant to chemistry:</p>
<ul>
<li>Below certain compute thresholds, model performance on chemistry tasks appears random.</li>
<li>Above critical sizes, sudden improvements emerge, along with capabilities like chain-of-thought (CoT) reasoning and instruction following.</li>
<li>These emergent abilities enable chemistry tasks that require multi-step reasoning without explicit training on chemical data.</li>
</ul>
<h3 id="llms-as-chemistry-tools">LLMs as Chemistry Tools</h3>
<p>Key applications of LLMs in chemistry include:</p>
<ul>
<li><strong><a href="/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/">Fine-tuning for low-data chemistry</a></strong> (Jablonka et al.): GPT-3 fine-tuned on limited chemistry datasets performed comparably to, and sometimes exceeded, specialized models with engineered features for tasks like predicting transition wavelengths and phase classification.</li>
<li><strong>In-context learning</strong>: Providing LLMs with a few examples enables prediction on chemistry tasks without any parameter updates, particularly valuable when data is scarce.</li>
<li><strong>Bayesian optimization with LLMs</strong> (Ramos et al.): Using GPT models for uncertainty-calibrated regression, enabling catalyst and molecular optimization directly from synthesis procedures without feature engineering.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/">3D structure generation</a></strong> (Flam-Shepherd and Aspuru-Guzik): Using language models to generate molecular structures with three-dimensional atomic positions in XYZ, CIF, and PDB formats, matching graph-based algorithms while overcoming representation limitations.</li>
</ul>
<h3 id="llm-powered-chemistry-agents">LLM-Powered Chemistry Agents</h3>
<p>The review highlights the agent paradigm as the most impactful recent development:</p>
<ul>
<li><strong>14 LLM use-cases</strong> (Jablonka et al.): A large-scale collaborative effort demonstrating applications from computational tool wrappers to reaction optimization assistants and scientific question answering.</li>
<li><strong><a href="/notes/chemistry/llm-applications/chemcrow-augmenting-llms-chemistry-tools/">ChemCrow</a></strong> (Bran, Cox et al.): An LLM-powered agent equipped with curated computational chemistry tools, capable of planning and executing tasks across drug design, materials design, and synthesis. ChemCrow demonstrated that tool integration overcomes LLM hallucination issues by grounding responses in reliable data sources.</li>
<li><strong>Autonomous scientific research</strong> (Boiko et al.): Agent systems for autonomous experimentation, with a focus on cloud laboratory operability.</li>
</ul>
<p>The agent paradigm offers tool composability through natural language interfaces, allowing users to chain multiple computational tools into custom pipelines.</p>
<h2 id="outlook-and-limitations">Outlook and Limitations</h2>
<p>The authors identify several themes for the future:</p>
<ul>
<li>The three stages represent increasing generality, from task-specific single-modality models to open-ended agents.</li>
<li>Natural language interfaces are progressively closing the gap between chemical and human language.</li>
<li>Tool integration through agents provides grounding that mitigates hallucination, a known limitation of direct LLM application to chemistry.</li>
<li>The review acknowledges that LLMs have a &ldquo;high propensity to generate false and inaccurate content&rdquo; on chemical tasks, making tool-augmented approaches preferable to direct application.</li>
</ul>
<p>The chapter does not provide quantitative benchmarks or systematic comparisons across the methods discussed, as its goal is to organize the landscape rather than evaluate individual methods.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This is a review/survey chapter and does not introduce new models, datasets, or experiments. The reproducibility assessment applies to the referenced works rather than the review itself.</p>
<h3 id="key-referenced-resources">Key Referenced Resources</h3>
<p>Several open-source tools and datasets discussed in the review are publicly available:</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/rxn4chemistry/rxnmapper">RXNMapper</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Attention-based atom mapping</td>
      </tr>
      <tr>
          <td><a href="https://github.com/ur-whitelab/chemcrow-public">ChemCrow</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>LLM-powered chemistry agent</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Dataset</td>
          <td>Various</td>
          <td>Molecular ML benchmarks</td>
      </tr>
      <tr>
          <td><a href="https://open-reaction-database.org/">Open Reaction Database</a></td>
          <td>Dataset</td>
          <td>CC-BY-SA-4.0</td>
          <td>Curated reaction data</td>
      </tr>
      <tr>
          <td><a href="https://tdcommons.ai/">Therapeutics Data Commons</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Drug discovery ML datasets</td>
      </tr>
  </tbody>
</table>
<h3 id="reproducibility-classification">Reproducibility Classification</h3>
<p><strong>Not applicable</strong> (review paper). Individual referenced works range from Highly Reproducible (open-source models like RXNMapper, ChemCrow) to Partially Reproducible (some models without released code) to Closed (proprietary LLMs like GPT-3/GPT-4 used in fine-tuning studies).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bran, A. M., &amp; Schwaller, P. (2024). Transformers and Large Language Models for Chemistry and Drug Discovery. In <em>Drug Development Supported by Informatics</em> (pp. 143-163). Springer Nature Singapore. <a href="https://doi.org/10.1007/978-981-97-4828-0_8">https://doi.org/10.1007/978-981-97-4828-0_8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@incollection</span>{bran2024transformers,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformers and Large Language Models for Chemistry and Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bran, Andres M. and Schwaller, Philippe}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Drug Development Supported by Informatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{143--163}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer Nature Singapore}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1007/978-981-97-4828-0_8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>REINVENT: Reinforcement Learning for Mol. Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/</guid><description>REINVENT uses augmented episodic likelihood to fine-tune a SMILES-based RNN via reinforcement learning for goal-directed molecular generation.</description><content:encoded><![CDATA[<h2 id="augmented-episodic-likelihood-for-goal-directed-generation">Augmented Episodic Likelihood for Goal-Directed Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces REINVENT, a policy-based reinforcement learning framework for molecular de novo design. The primary contribution is a novel cost function, the <a href="/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/">augmented episodic likelihood</a>, that fine-tunes a <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>-based recurrent neural network (RNN) pre-trained on ChEMBL toward generating molecules satisfying user-defined property objectives. The method anchors the agent to the prior distribution of valid drug-like molecules, addressing failure modes of standard REINFORCE algorithms (reward exploitation and <a href="/notes/chemistry/molecular-design/generation/evaluation/failure-modes-molecule-generation/">mode collapse</a> to trivially simple structures).</p>
<h2 id="de-novo-design-needs-flexible-data-driven-approaches">De Novo Design Needs Flexible, Data-Driven Approaches</h2>
<p>Traditional de novo design methods fall into three categories, each with limitations:</p>
<ol>
<li><strong>Structure-based approaches</strong> grow ligands to fit binding pockets but often produce molecules with poor DMPK profiles and synthetic intractability.</li>
<li><strong>Ligand-based virtual library</strong> approaches generate large libraries and score them, but are constrained by pre-defined reaction rules or transformation rules that limit chemical diversity.</li>
<li><strong><a href="/notes/chemistry/molecular-design/property-prediction/">Inverse QSAR</a></strong> methods attempt to map favorable activity regions back to molecular structures, but require descriptors suitable for both forward prediction and inverse mapping.</li>
</ol>
<p>RNN-based generative models trained on SMILES offer a data-driven alternative that can learn the underlying distribution of drug-like chemical space without rigid rules. Segler et al. (2017) showed that fine-tuning a pre-trained RNN on focused actives yields high fractions of predicted actives. However, this maximum likelihood fine-tuning cannot use negative or continuous scores and risks catastrophic forgetting.</p>
<p>Prior RL approaches had significant issues. Jaques et al. (2016) used Deep Q-learning with prior likelihood regularization for sequence generation, but reported dependence on hand-written rules to penalize undesirable sequences and still observed reward exploitation producing unrealistically simple molecules. Standard REINFORCE algorithms tend to converge on trivial solutions (e.g., generating only &ldquo;C&rdquo; to satisfy a scoring function).</p>
<h2 id="the-augmented-episodic-likelihood-framework">The Augmented Episodic Likelihood Framework</h2>
<p>The core innovation is a formulation where the agent learns a policy that minimizes the squared difference between its own log-likelihood and an augmented target likelihood.</p>
<p>The RNN is first pre-trained on 1.5 million canonical SMILES from ChEMBL via maximum likelihood estimation:</p>
<p>$$
J(\Theta) = -\sum_{t=1}^{T} \log P(x^{t} \mid x^{t-1}, \dots, x^{1})
$$</p>
<p>The pre-trained model (the Prior) is then used as the starting point for the Agent. For a generated SMILES sequence $A = a_1, a_2, \dots, a_T$, the model likelihood is $P(A) = \prod_{t=1}^{T} \pi(a_t \mid s_t)$, and a scoring function $S(A) \in [-1, 1]$ rates desirability.</p>
<p>The augmented likelihood combines prior likelihood with the score:</p>
<p>$$
\log P(A)_{\mathbb{U}} = \log P(A)_{\text{Prior}} + \sigma S(A)
$$</p>
<p>where $\sigma$ is a scalar coefficient controlling the trade-off between prior fidelity and score optimization.</p>
<p>The return is defined as the negative squared difference between the augmented likelihood and the agent&rsquo;s likelihood:</p>
<p>$$
G(A) = -\left[\log P(A)_{\mathbb{U}} - \log P(A)_{\mathbb{A}}\right]^{2}
$$</p>
<p>The agent minimizes $J(\Theta) = -G$, effectively learning a policy whose sequence likelihoods match the prior modulated by the scoring function. The authors show in supplementary material that this is equivalent to a REINFORCE algorithm with a specific final-step reward formulation.</p>
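<p>The three equations above can be collected into a small numerical sketch. This is an illustration, not the authors' implementation: the function name, the toy log-likelihoods, the value of <code>sigma</code>, and the choice to average over the batch are all assumptions made for the example.</p>

```python
def augmented_likelihood_loss(prior_loglik, agent_loglik, score, sigma):
    """Loss J = -G(A) for one sampled sequence A.

    prior_loglik / agent_loglik: summed token log-probabilities of the
    sequence under the Prior and the Agent. score: S(A) in [-1, 1].
    """
    # Augmented log-likelihood: log P(A)_U = log P(A)_Prior + sigma * S(A)
    augmented_loglik = prior_loglik + sigma * score
    # G(A) is the negative squared gap, so the loss -G(A) is the squared gap.
    return (augmented_loglik - agent_loglik) ** 2


# Toy batch of (log P_Prior, log P_Agent, S(A)); the values are made up.
batch = [(-30.0, -30.0, 1.0), (-25.0, -20.0, -1.0), (-40.0, -45.0, 0.0)]
sigma = 10.0  # hypothetical trade-off coefficient
losses = [augmented_likelihood_loss(p, a, s, sigma) for p, a, s in batch]
batch_loss = sum(losses) / len(losses)  # batch averaging is an assumption
```

<p>The loss is zero exactly when the Agent's log-likelihood equals the augmented target, so a well-scored sequence is pushed above its Prior likelihood and a poorly scored one below it, by an amount set by $\sigma$.</p>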
<p>This design has three key advantages over standard REINFORCE:</p>
<ul>
<li>The target policy is explicitly stochastic, preserving diversity in generated molecules</li>
<li>The prior anchoring prevents catastrophic forgetting of SMILES syntax and chemical space coverage</li>
<li>No hand-written rules are needed to penalize degenerate solutions</li>
</ul>
<p>The Agent is trained on-policy with batches of 128 generated sequences, using SGD with learning rate 0.0005 and gradient clipping to $[-3, 3]$.</p>
<h2 id="three-experiments-sulphur-avoidance-celecoxib-analogues-and-drd2-activity">Three Experiments: Sulphur Avoidance, Celecoxib Analogues, and DRD2 Activity</h2>
<h3 id="prior-network-architecture">Prior Network Architecture</h3>
<p>The Prior is a 3-layer RNN with 1024 Gated Recurrent Units per layer, trained on RDKit canonical SMILES from ChEMBL (molecules with 10-50 heavy atoms, elements from $\{H, B, C, N, O, F, Si, P, S, Cl, Br, I\}$). Training used Adam ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$) for 50,000 steps with batch size 128 and learning rate decay of 0.02 every 100 steps. The Prior generates 94% valid SMILES, of which 90% are novel.</p>
<h3 id="experiment-1-learning-to-avoid-sulphur">Experiment 1: Learning to Avoid Sulphur</h3>
<p>A proof-of-principle task where the scoring function assigns $S(A) = 1$ for valid sulphur-free molecules, $S(A) = 0$ for invalid SMILES, and $S(A) = -1$ for sulphur-containing molecules.</p>
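<p>A minimal sketch of this three-valued scoring function follows. It assumes a validity check is supplied by the caller (in practice something like testing whether RDKit's <code>Chem.MolFromSmiles</code> returns <code>None</code>); the <code>toy_is_valid</code> helper and the regex-based sulphur test are hypothetical stand-ins, not the paper's code.</p>

```python
import re

# Matches aliphatic "S" or aromatic "s", but not the "S" in "Si".
# Other S-initial symbols (Se, Sn, ...) lie outside the element set used
# for the Prior and are deliberately not handled in this sketch.
_SULPHUR = re.compile(r"S(?!i)|s")


def sulphur_score(smiles, is_valid):
    """Return 0 for invalid SMILES, -1 if sulphur is present, else 1."""
    if not is_valid(smiles):
        return 0.0
    if _SULPHUR.search(smiles):
        return -1.0
    return 1.0


def toy_is_valid(smiles):
    """Hypothetical stand-in for a real parser: balanced parentheses only."""
    return len(smiles) > 0 and smiles.count("(") == smiles.count(")")


example_scores = [
    sulphur_score(s, toy_is_valid)
    for s in ["CCO", "CCS", "c1ccsc1", "C[Si](C)(C)C", "CC("]
]
```

<p>Passing the validity check in as an argument keeps the sketch free of a hard dependency on any particular cheminformatics toolkit.</p>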
<p>The Agent method was compared against three alternatives:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Fraction Valid</th>
          <th>Fraction No S</th>
          <th>Avg MW</th>
          <th>Avg cLogP</th>
          <th>Avg RotBonds</th>
          <th>Avg AromRings</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior</td>
          <td>0.94</td>
          <td>0.66</td>
          <td>371</td>
          <td>3.36</td>
          <td>5.39</td>
          <td>2.26</td>
      </tr>
      <tr>
          <td>Agent</td>
          <td>0.95</td>
          <td>0.98</td>
          <td>367</td>
          <td>3.37</td>
          <td>5.41</td>
          <td>2.26</td>
      </tr>
      <tr>
          <td>Action basis</td>
          <td>0.95</td>
          <td>0.92</td>
          <td>372</td>
          <td>3.39</td>
          <td>6.08</td>
          <td>2.09</td>
      </tr>
      <tr>
          <td>REINFORCE</td>
          <td>0.98</td>
          <td>0.98</td>
          <td>585</td>
          <td>11.3</td>
          <td>30.0</td>
          <td>0.57</td>
      </tr>
      <tr>
          <td>REINFORCE + Prior</td>
          <td>0.98</td>
          <td>0.92</td>
          <td>232</td>
          <td>3.05</td>
          <td>2.8</td>
          <td>2.11</td>
      </tr>
  </tbody>
</table>
<p>Standard REINFORCE exploited the reward by generating sequences of predominantly &ldquo;C&rdquo; (average MW 585, cLogP 11.3). REINFORCE + Prior avoided this but collapsed to small, simplistic structures (MW 232). The Agent achieved 98% sulphur-free structures while maintaining molecular properties nearly identical to the Prior, demonstrating that augmented episodic likelihood preserves the prior distribution.</p>
<h3 id="experiment-2-similarity-guided-generation-celecoxib-analogues">Experiment 2: Similarity-Guided Generation (Celecoxib Analogues)</h3>
<p>The scoring function uses <a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard similarity</a> on FCFP4 fingerprints:</p>
<p>$$
S(A) = -1 + 2 \times \frac{\min\{J_{i,j}, k\}}{k}
$$</p>
<p>where $k$ caps the rewarded similarity. With $k = 1$ and $\sigma = 15$, the Agent recovers <a href="https://en.wikipedia.org/wiki/Celecoxib">Celecoxib</a> itself within 200 training steps. Even when all structures with $J &gt; 0.5$ to Celecoxib (1,804 molecules) were removed from the Prior training set, the Agent still found Celecoxib after 400 steps, despite a 700-fold reduction in prior likelihood ($\log_e P$ from $-12.7$ to $-19.2$).</p>
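<p>The capped similarity score can be sketched as below. Fingerprints are represented here as plain sets of on-bit indices, which is an assumption for illustration (the paper uses FCFP4 fingerprints); the function names and toy bit sets are hypothetical.</p>

```python
def jaccard(bits_a, bits_b):
    """Jaccard similarity between two collections of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 0.0  # convention chosen for this sketch when both are empty
    return len(a.intersection(b)) / len(union)


def similarity_score(j, k):
    """S(A) = -1 + 2 * min(j, k) / k, mapping similarity to [-1, 1]."""
    return -1.0 + 2.0 * min(j, k) / k


# Toy fingerprints: 2 shared bits out of 6 distinct bits, so J = 1/3.
query_bits = {1, 2, 3, 4}
candidate_bits = {3, 4, 5, 6}
j = jaccard(query_bits, candidate_bits)
score_capped = similarity_score(j, k=0.7)
score_saturated = similarity_score(0.9, k=0.7)  # any J at or above k scores 1
```

<p>Because similarities at or above $k$ all receive the maximum score, lowering $k$ removes the incentive to reproduce the query molecule exactly. As a side check on the likelihood figures quoted above, $e^{-12.7 - (-19.2)} = e^{6.5} \approx 665$, which is the roughly 700-fold reduction.</p>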
<p>With moderate similarity targets ($k = 0.7$, $\sigma = 12$), the Agent generates diverse analogues including scaffold hops where functional groups are rearranged.</p>
<h3 id="experiment-3-target-activity-drd2">Experiment 3: Target Activity (DRD2)</h3>
<p>The most drug-discovery-relevant task: generating molecules predicted active against the <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">dopamine receptor type 2 (DRD2)</a>. An SVM classifier (Gaussian kernel, $C = 2^7$, $\gamma = 2^{-6}$) was trained on bioactivity data from ExCAPE-DB (7,218 actives with pIC50 &gt; 5, 100,000 sampled inactives). The actives were split by Butina clustering (ECFP6, cutoff 0.4) to decrease nearest-neighbor similarity between train and test sets.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Prior</th>
          <th>Agent</th>
          <th>Prior (reduced)</th>
          <th>Agent (reduced)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fraction valid SMILES</td>
          <td>0.94</td>
          <td>0.99</td>
          <td>0.94</td>
          <td>0.99</td>
      </tr>
      <tr>
          <td>Fraction predicted actives</td>
          <td>0.03</td>
          <td>0.97</td>
          <td>0.02</td>
          <td>0.96</td>
      </tr>
      <tr>
          <td>Fraction similar to train active</td>
          <td>0.02</td>
          <td>0.79</td>
          <td>0.02</td>
          <td>0.75</td>
      </tr>
      <tr>
          <td>Fraction similar to test active</td>
          <td>0.01</td>
          <td>0.46</td>
          <td>0.01</td>
          <td>0.38</td>
      </tr>
      <tr>
          <td>Fraction of test actives recovered ($\times 10^{-3}$)</td>
          <td>13.5</td>
          <td>126</td>
          <td>2.85</td>
          <td>72.6</td>
      </tr>
  </tbody>
</table>
<p>The Agent increased the fraction of predicted actives from 2-3% (Prior) to 96-97%. The authors also report a 250-fold enrichment in the probability of generating a test set active. The Agent based on the reduced Prior (DRD2 actives removed from ChEMBL) still recovered 7% of test actives, meaning it generated experimentally confirmed actives that appeared in the training data of neither the generative model nor the activity prediction model.</p>
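<p>As a quick check on units, the last row of the table is expressed in multiples of $10^{-3}$, so the reduced-Prior Agent's entry converts to the 7% figure as follows. The ratios between Agent and Prior in that row are computed only from the table values and describe the fraction of test actives recovered, which is a different quantity from the probability of generating a test set active:</p>
<p>$$
72.6 \times 10^{-3} = 0.0726 \approx 7\%, \qquad \frac{126}{13.5} \approx 9.3, \qquad \frac{72.6}{2.85} \approx 25.5
$$</p>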
<h2 id="anchored-policy-learning-prevents-reward-exploitation">Anchored Policy Learning Prevents Reward Exploitation</h2>
<p>The key finding is that augmented episodic likelihood successfully balances score optimization with prior distribution preservation. The Agent achieves task objectives (sulphur avoidance, similarity targets, activity prediction) while maintaining the molecular property distributions learned from ChEMBL. This is a significant improvement over standard REINFORCE, which either exploits rewards trivially or collapses to simple structures.</p>
<p>Analysis of the conditional probability distributions between the Prior and Agent (for DRD2 active generation) shows that the policy changes are not drastic: most trends learned by the Prior carry over, with targeted modifications at specific steps that substantially alter sequence likelihoods and generated structure types.</p>
<p>Limitations acknowledged by the authors:</p>
<ul>
<li>All experiments use single-parameter scoring functions; multi-parametric optimization (activity + DMPK + synthetic accessibility) is left for future work</li>
<li>The quality of generated structures depends heavily on the Prior&rsquo;s coverage of chemical space</li>
<li>The activity model (SVM) has a limited domain of applicability, and structures outside this domain may be falsely scored</li>
<li>No exhaustive study of how Prior training set size, model size, and regularization affect generation quality</li>
</ul>
<p>Future directions include multi-parametric scoring functions, exploration of token embeddings, and adversarial training where the scoring function is replaced by a discriminator network (GAN-style training).</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior training</td>
          <td>ChEMBL</td>
          <td>1.5M structures</td>
          <td>10-50 heavy atoms, filtered elements</td>
      </tr>
      <tr>
          <td>DRD2 activity model</td>
          <td>ExCAPE-DB</td>
          <td>7,218 actives + 100K inactives</td>
          <td>Butina clustering split (ECFP6, cutoff 0.4)</td>
      </tr>
      <tr>
          <td>Similarity target</td>
          <td>Celecoxib</td>
          <td>1 query structure</td>
          <td>FCFP4 fingerprints for Jaccard similarity</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Prior</strong>: 3-layer GRU RNN (1024 units/layer), Adam optimizer, 50K steps, batch size 128, LR 0.001 with 0.02 decay/100 steps</li>
<li><strong>Agent</strong>: Same architecture, SGD with LR 0.0005, gradient clipping [-3, 3], on-policy batches of 128</li>
<li><strong>DRD2 model</strong>: SVM with Gaussian kernel ($C = 2^7$, $\gamma = 2^{-6}$), grid search on validation set</li>
</ul>
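<p>The Prior's learning-rate schedule listed above can be sketched as a step function. The notes give only "0.02 decay/100 steps", so treating it as a multiplicative 2% reduction applied every 100 steps is an assumption, as are the function and parameter names.</p>

```python
def prior_learning_rate(step, base_lr=0.001, decay=0.02, every=100):
    """Learning rate at a given training step.

    Assumes the rate is multiplied by (1 - decay) once per `every` steps;
    the exact decay rule is not spelled out in these notes.
    """
    n_decays = step // every
    return base_lr * (1.0 - decay) ** n_decays


# Learning rate at a few checkpoints of the 50,000-step Prior training run.
schedule = {step: prior_learning_rate(step) for step in (0, 100, 1000, 50000)}
```

<p>Under this reading the rate stays constant within each block of 100 steps and shrinks geometrically across blocks.</p>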
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MarcusOlivecrona/REINVENT">REINVENT</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Original implementation in TensorFlow/Python 2.7</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.572576">Archived version</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Zenodo archive (DOI: 10.5281/zenodo.572576)</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>SMILES validity rate (RDKit parsing)</li>
<li>Fraction of structures satisfying scoring function</li>
<li>Molecular property distributions (MW, cLogP, rotatable bonds, aromatic rings)</li>
<li>Jaccard similarity on ECFP6/FCFP4 fingerprints</li>
<li>Recovery rate of known actives from test set</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. The implementation uses TensorFlow 1.0.1 with Python 2.7, RDKit, and Scikit-learn.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Olivecrona, M., Blaschke, T., Engkvist, O., &amp; Chen, H. (2017). Molecular de-novo design through deep reinforcement learning. <em>Journal of Cheminformatics</em>, 9(1), 48.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{olivecrona2017molecular,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecular de-novo design through deep reinforcement learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Olivecrona, Marcus and Blaschke, Thomas and Engkvist, Ola and Chen, Hongming}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{48}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-017-0235-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ReactionT5: Pre-trained T5 for Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/reactiont5-pretrained-limited-reaction-data/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/reactiont5-pretrained-limited-reaction-data/</guid><description>ReactionT5 uses two-stage pretraining on ZINC and the Open Reaction Database to enable competitive reaction and yield prediction with minimal fine-tuning data.</description><content:encoded><![CDATA[<h2 id="a-two-stage-pre-trained-transformer-for-chemical-reactions">A Two-Stage Pre-trained Transformer for Chemical Reactions</h2>
<p>ReactionT5 is a <strong>Method</strong> paper that proposes a T5-based pre-trained model for chemical reaction tasks, specifically product prediction and yield prediction. The primary contribution is a two-stage pretraining pipeline: first on a compound library (ZINC, 23M molecules) to learn molecular representations, then on a large-scale reaction database (the Open Reaction Database, 1.5M reactions) to learn reaction-level patterns. The key result is that this pre-trained model can be fine-tuned with very limited target-domain data (as few as 30 reactions) and still achieve competitive performance against models trained on full datasets.</p>
<h2 id="bridging-the-gap-between-single-molecule-and-multi-molecule-pretraining">Bridging the Gap Between Single-Molecule and Multi-Molecule Pretraining</h2>
<p>While transformer-based models pre-trained on compound libraries (e.g., <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, MolGPT) have seen substantial development, most focus on single-molecule inputs and outputs. Pretraining for multi-molecule contexts, such as chemical reactions involving reactants, reagents, catalysts, and products, remains underexplored. T5Chem supports multi-task reaction prediction but focuses on building a single multi-task model rather than investigating the effectiveness of pre-trained models for fine-tuning on limited in-house data.</p>
<p>The authors identify two key gaps:</p>
<ol>
<li>Most pre-trained chemical models do not account for reaction-level interactions between multiple molecules.</li>
<li>In practical settings, target-domain reaction data is often scarce, making transfer learning from large public datasets essential.</li>
</ol>
<h2 id="two-stage-pretraining-with-compound-restoration">Two-Stage Pretraining with Compound Restoration</h2>
<p>The core innovation is a two-stage pretraining procedure built on the <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5 (text-to-text transfer transformer)</a> architecture:</p>
<p><strong>Stage 1: Compound Pretraining (CompoundT5)</strong>. An initialized T5 model is trained on 23M <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> from the ZINC database using span-masked language modeling. The model learns to predict masked subsequences of SMILES tokens. A SentencePiece unigram tokenizer is trained on this compound library, allowing more compact representations than character-level or atom-level tokenizers. After this stage, new tokens are added to the tokenizer to cover metal atoms and other characters present in the reaction database but absent from ZINC.</p>
<p><strong>Stage 2: Reaction Pretraining (ReactionT5)</strong>. CompoundT5 is further pretrained on 1.5M reactions from the Open Reaction Database (ORD) on both product prediction and yield prediction tasks. Reactions are formulated as text-to-text tasks using special tokens:</p>
<ul>
<li><code>REACTANT:</code>, <code>REAGENT:</code>, and <code>PRODUCT:</code> tokens delimit the role of each molecule in the reaction string.</li>
<li>For product prediction, the model takes reactants and reagents as input and generates product SMILES.</li>
<li>For yield prediction, the model takes the full reaction (including products) and outputs a numerical yield value.</li>
</ul>
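<p>The text-to-text formulation above might be assembled as in the following sketch. The role tokens come from the description above; joining molecules with the SMILES "." separator, omitting whitespace, and the helper names are assumptions made for illustration and may differ from the released preprocessing.</p>

```python
def format_product_input(reactants, reagents):
    """Build the input string for product prediction."""
    # "." is the standard SMILES separator for disconnected components.
    return "REACTANT:" + ".".join(reactants) + "REAGENT:" + ".".join(reagents)


def format_yield_input(reactants, reagents, products):
    """Build the input string for yield prediction (products included)."""
    return (
        format_product_input(reactants, reagents)
        + "PRODUCT:"
        + ".".join(products)
    )


# Toy example: acid-catalysed esterification of acetic acid with ethanol.
reactants = ["CC(=O)O", "CCO"]
reagents = ["OS(=O)(=O)O"]
products = ["CCOC(C)=O"]

product_input = format_product_input(reactants, reagents)
yield_input = format_yield_input(reactants, reagents, products)
```

<p>For product prediction the target sequence would be the product SMILES; for yield prediction it would be the numerical yield.</p>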
<p><strong>Compound Restoration</strong>. A notable methodological detail is the handling of uncategorized compounds in the ORD. About 31.8% of ORD reactions contain compounds with unknown roles. Simply discarding these reactions introduces severe product bias (only 447 unique products remain vs. 439,898 with uncategorized data included). The authors develop RestorationT5, a binary classifier built from CompoundT5, that assigns uncategorized compounds to either reactant or reagent roles. This classifier uses a sigmoid output layer and achieves an F1 score of 0.1564 at a threshold of 0.97, outperforming a random forest baseline (F1 = 0.1136). The restored dataset (&ldquo;ORD(restored)&rdquo;) is then used for reaction pretraining.</p>
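<p>To make the threshold and F1 figures concrete, the sketch below shows how an F1 score is computed from sigmoid outputs at a fixed decision threshold. Which role counts as the positive class is not stated here, so the labels, probabilities, and function name are hypothetical.</p>

```python
def f1_at_threshold(probabilities, labels, threshold=0.97):
    """F1 of a binary classifier whose positives are p >= threshold."""
    predicted = [p >= threshold for p in probabilities]
    pairs = list(zip(predicted, labels))
    tp = sum(1 for pred, y in pairs if pred and y)
    fp = sum(1 for pred, y in pairs if pred and not y)
    fn = sum(1 for pred, y in pairs if not pred and y)
    if tp == 0:
        return 0.0  # avoids division by zero when nothing is recovered
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Toy sigmoid outputs and ground-truth labels (made up for illustration).
probs = [0.99, 0.98, 0.50, 0.975, 0.20]
labels = [True, False, True, True, False]
f1 = f1_at_threshold(probs, labels)
```

<p>A high threshold such as 0.97 favours precision over recall, since only very confident predictions are assigned to the positive class.</p>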
<p>For yield prediction, the loss function is mean squared error:</p>
<p>$$L = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2$$</p>
<p>where $y_i$ is the true yield (normalized to [0, 1]) and $\hat{y}_i$ is the predicted yield.</p>
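<p>A minimal sketch of this loss in plain Python follows. It assumes yields arrive as percentages and are divided by 100 to reach the [0, 1] range; that normalization step and the function name are assumptions for illustration.</p>

```python
def yield_mse(true_yields_percent, predicted_yields):
    """Mean squared error between normalized true and predicted yields."""
    # Normalize percentages to [0, 1] (assumed preprocessing).
    targets = [y / 100.0 for y in true_yields_percent]
    squared_errors = [
        (y - y_hat) ** 2 for y, y_hat in zip(targets, predicted_yields)
    ]
    return sum(squared_errors) / len(squared_errors)


# Two toy reactions: true yields 80% and 50%, predictions 0.7 and 0.6.
loss = yield_mse([80.0, 50.0], [0.7, 0.6])
```

<p>Each prediction here is off by 0.1, so the loss is the mean of two squared errors of 0.01.</p>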
<h2 id="experimental-setup-product-and-yield-prediction-benchmarks">Experimental Setup: Product and Yield Prediction Benchmarks</h2>
<h3 id="product-prediction">Product Prediction</h3>
<p>The USPTO dataset (479K reactions) is used for evaluation, with standard train/val/test splits (409K/30K/40K). Reactions overlapping with the ORD (18%) are removed during evaluation. Beam search with beam size 10 is used for decoding, and minimum/maximum output length constraints are set based on the training data distribution. Top-k accuracy (k = 1, 2, 3, 5) and invalidity rate are reported.</p>
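<p>Top-k accuracy over beam-search outputs can be sketched as below. It assumes predictions and targets are already canonicalized so that exact string comparison is meaningful; the function name and toy data are hypothetical.</p>

```python
def top_k_accuracy(ranked_predictions, targets, k):
    """Fraction of examples whose target is among the first k candidates.

    ranked_predictions: one list of candidate SMILES per example, ordered
    from most to least likely (e.g. beam-search output).
    """
    hits = sum(
        1
        for candidates, target in zip(ranked_predictions, targets)
        if target in candidates[:k]
    )
    return hits / len(targets)


# Toy beam outputs for three reactions (strings assumed canonical).
beams = [["CCO", "CCN"], ["CCC", "CCO"], ["c1ccccc1", "CC"]]
truth = ["CCO", "CCO", "CCN"]
top1 = top_k_accuracy(beams, truth, k=1)
top2 = top_k_accuracy(beams, truth, k=2)
```

<p>Top-k accuracy is non-decreasing in k, which is why the Top-5 columns in the tables are never below the Top-1 columns.</p>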
<p>Baselines include Seq-to-seq, WLDN (graph neural network), <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a>, and T5Chem.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Train</th>
          <th>Top-1</th>
          <th>Top-2</th>
          <th>Top-3</th>
          <th>Top-5</th>
          <th>Invalidity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Seq-to-seq</td>
          <td>USPTO</td>
          <td>80.3</td>
          <td>84.7</td>
          <td>86.2</td>
          <td>87.5</td>
          <td>-</td>
      </tr>
      <tr>
          <td>WLDN</td>
          <td>USPTO</td>
          <td>85.6</td>
          <td>90.5</td>
          <td>92.8</td>
          <td>93.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Molecular Transformer</td>
          <td>USPTO</td>
          <td>88.8</td>
          <td>92.6</td>
          <td>-</td>
          <td>94.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>T5Chem</td>
          <td>USPTO</td>
          <td>90.4</td>
          <td>94.2</td>
          <td>-</td>
          <td>96.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>CompoundT5</td>
          <td>USPTO</td>
          <td>88.0</td>
          <td>92.4</td>
          <td>93.9</td>
          <td>95.0</td>
          <td>7.5</td>
      </tr>
      <tr>
          <td>ReactionT5 (restored ORD)</td>
          <td>USPTO200</td>
          <td>85.5</td>
          <td>91.7</td>
          <td>93.5</td>
          <td>94.9</td>
          <td>12.0</td>
      </tr>
  </tbody>
</table>
<p>A critical finding: ReactionT5 pre-trained on ORD achieves 0% accuracy on USPTO without fine-tuning due to domain mismatch (ORD includes byproducts; USPTO lists only the main product). Fine-tuning on just 200 USPTO reactions with the restored ORD model produces competitive results.</p>
<p>The few-shot fine-tuning analysis shows rapid performance scaling:</p>
<table>
  <thead>
      <tr>
          <th>Samples</th>
          <th>Top-1</th>
          <th>Top-2</th>
          <th>Top-3</th>
          <th>Top-5</th>
          <th>Invalidity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>10</td>
          <td>9.0</td>
          <td>12.5</td>
          <td>15.3</td>
          <td>19.1</td>
          <td>12.4</td>
      </tr>
      <tr>
          <td>30</td>
          <td>80.5</td>
          <td>87.3</td>
          <td>89.8</td>
          <td>92.0</td>
          <td>17.2</td>
      </tr>
      <tr>
          <td>50</td>
          <td>83.7</td>
          <td>89.9</td>
          <td>92.2</td>
          <td>94.0</td>
          <td>14.8</td>
      </tr>
      <tr>
          <td>100</td>
          <td>85.1</td>
          <td>91.0</td>
          <td>92.8</td>
          <td>94.4</td>
          <td>14.0</td>
      </tr>
      <tr>
          <td>200</td>
          <td>85.5</td>
          <td>91.7</td>
          <td>93.5</td>
          <td>94.9</td>
          <td>12.0</td>
      </tr>
  </tbody>
</table>
<h3 id="yield-prediction">Yield Prediction</h3>
<p>The <a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> C-N cross-coupling dataset (3,955 reactions) is used with random 7:3 splits (repeated 10 times) plus four out-of-sample test sets (Tests 1-4) designed so that similar reactions do not appear in both train and test.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Random 7:3</th>
          <th>Test 1</th>
          <th>Test 2</th>
          <th>Test 3</th>
          <th>Test 4</th>
          <th>Avg. Tests 1-4</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DFT</td>
          <td>0.92</td>
          <td>0.80</td>
          <td>0.77</td>
          <td>0.64</td>
          <td>0.54</td>
          <td>0.69</td>
      </tr>
      <tr>
          <td>MFF</td>
          <td>0.927</td>
          <td>0.851</td>
          <td>0.713</td>
          <td>0.635</td>
          <td>0.184</td>
          <td>0.596</td>
      </tr>
      <tr>
          <td>Yield-BERT</td>
          <td>0.951</td>
          <td>0.838</td>
          <td>0.836</td>
          <td>0.738</td>
          <td>0.538</td>
          <td>0.738</td>
      </tr>
      <tr>
          <td>T5Chem</td>
          <td>0.970</td>
          <td>0.811</td>
          <td>0.907</td>
          <td>0.789</td>
          <td>0.627</td>
          <td>0.785</td>
      </tr>
      <tr>
          <td>CompoundT5</td>
          <td>0.971</td>
          <td>0.855</td>
          <td>0.852</td>
          <td>0.712</td>
          <td>0.547</td>
          <td>0.741</td>
      </tr>
      <tr>
          <td>ReactionT5</td>
          <td>0.966</td>
          <td>0.914</td>
          <td>0.940</td>
          <td>0.819</td>
          <td>0.896</td>
          <td>0.892</td>
      </tr>
      <tr>
          <td>ReactionT5 (zero-shot)</td>
          <td>0.904</td>
          <td>0.919</td>
          <td>0.927</td>
          <td>0.847</td>
          <td>0.909</td>
          <td>0.900</td>
      </tr>
  </tbody>
</table>
<p>Fine-tuned ReactionT5 achieves the highest average $R^2$ across Tests 1-4 among fine-tuned models (0.892), and the zero-shot variant scores slightly higher still (0.900). The improvement is most dramatic on Test 4, the hardest split, where ReactionT5 achieves $R^2 = 0.896$ versus T5Chem&rsquo;s 0.627 and Yield-BERT&rsquo;s 0.538.</p>
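<p>For reference, the sketch below shows the usual definition of the coefficient of determination and re-derives the two ReactionT5 averages from the per-test values in the table above. The toy yields are made up, and the paper does not state which $R^2$ implementation it uses, so treat the function as an assumption.</p>

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Per-test R^2 values copied from the table above (Tests 1-4).
fine_tuned = [0.914, 0.940, 0.819, 0.896]
zero_shot = [0.919, 0.927, 0.847, 0.909]
avg_fine_tuned = sum(fine_tuned) / len(fine_tuned)
avg_zero_shot = sum(zero_shot) / len(zero_shot)

# Toy check of r_squared on made-up normalized yields.
toy_r2 = r_squared([0.2, 0.4, 0.6, 0.8], [0.25, 0.35, 0.65, 0.75])
```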
<p>In a low-data regime (30% train / 70% test), ReactionT5 ($R^2 = 0.927$) substantially outperforms a random forest baseline ($R^2 = 0.853$), and even zero-shot ReactionT5 ($R^2 = 0.898$) exceeds the random forest.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Two-stage pretraining is effective</strong>: Compound pretraining followed by reaction pretraining produces models with strong generalization, particularly on out-of-distribution test sets.</li>
<li><strong>Few-shot transfer works</strong>: With as few as 30 fine-tuning reactions, ReactionT5 achieves over 80% Top-1 accuracy on product prediction, competitive with models trained on the full USPTO dataset.</li>
<li><strong>Compound restoration matters</strong>: Restoring uncategorized compounds in the ORD is essential for product prediction. Without restoration, fine-tuning on 200 USPTO reactions yields 0% accuracy; with restoration, the same fine-tuning yields 85.5% Top-1.</li>
<li><strong>Zero-shot yield prediction is surprisingly effective</strong>: ReactionT5 achieves $R^2 = 0.900$ on the out-of-sample yield tests without any task-specific fine-tuning, outperforming all fine-tuned baselines.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Product prediction shows a high invalidity rate (12.0% for the best ReactionT5 variant) compared to CompoundT5 (7.5%), suggesting the reaction pretraining may introduce some noise.</li>
<li>The 0% accuracy without fine-tuning on product prediction reveals a significant domain gap between ORD and USPTO annotation conventions (byproducts vs. main products).</li>
<li>The RestorationT5 classifier has low precision (0.0878) despite high recall (0.7212), meaning many compounds are incorrectly assigned roles. The paper does not investigate how this impacts downstream performance.</li>
<li>The paper does not report training times, computational costs, or model sizes, making resource requirements unclear.</li>
<li>Only two downstream tasks (product prediction on USPTO, yield prediction on Buchwald-Hartwig) are evaluated.</li>
</ul>
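<p>The RestorationT5 precision and recall quoted above are consistent with its reported F1 score. A quick check, using the standard harmonic-mean definition, recovers 0.1564 up to rounding of the inputs:</p>

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2.0 * precision * recall / (precision + recall)


# Precision and recall as reported for RestorationT5.
f1 = f1_score(0.0878, 0.7212)
```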
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Compound pretraining</td>
          <td>ZINC</td>
          <td>22,992,522 compounds</td>
          <td>SMILES canonicalized with <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a></td>
      </tr>
      <tr>
          <td>Reaction pretraining</td>
          <td>ORD (restored)</td>
          <td>1,505,916 reactions</td>
          <td>Atom mapping removed, compounds canonicalized</td>
      </tr>
      <tr>
          <td>Product prediction eval</td>
          <td>USPTO</td>
          <td>479,035 reactions</td>
          <td>409K/30K/40K train/val/test split</td>
      </tr>
      <tr>
          <td>Yield prediction eval</td>
          <td>Buchwald-Hartwig C-N</td>
          <td>3,955 reactions</td>
          <td>Random 7:3 split (10 repeats) + 4 OOS tests</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Base architecture: T5 (text-to-text transfer transformer)</li>
<li>Tokenizer: SentencePiece unigram, trained on ZINC, extended with special reaction tokens</li>
<li>Compound pretraining: Span-masked language modeling (15% masking rate, average span length 3)</li>
<li>Beam search: size 10 for product prediction</li>
<li>Output length constraints: min/max from training data distribution</li>
<li>Yield normalization: clipped to [0, 100], then scaled to [0, 1]</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>CompoundT5: T5 pretrained on ZINC</li>
<li>RestorationT5: CompoundT5 fine-tuned for binary classification (reactant vs. reagent)</li>
<li>ReactionT5: CompoundT5 pretrained on ORD for product and yield prediction</li>
<li>Pre-trained weights available on Hugging Face</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Best Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Top-1 accuracy</td>
          <td>Product prediction</td>
          <td>85.5%</td>
          <td>ReactionT5 with 200 fine-tuning reactions</td>
      </tr>
      <tr>
          <td>Top-5 accuracy</td>
          <td>Product prediction</td>
          <td>94.9%</td>
          <td>ReactionT5 with 200 fine-tuning reactions</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Yield prediction (random)</td>
          <td>0.966</td>
          <td>ReactionT5 fine-tuned</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Yield prediction (OOS avg.)</td>
          <td>0.900</td>
          <td>ReactionT5 zero-shot</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Training times and GPU requirements are not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/sagawatatsuya/ReactionT5v2">ReactionT5v2 (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/sagawa">ReactionT5 models (Hugging Face)</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Pre-trained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sagawa, T. &amp; Kojima, R. (2023). ReactionT5: a large-scale pre-trained model towards application of limited reaction data. <em>arXiv preprint arXiv:2311.06708</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{sagawa2023reactiont5,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ReactionT5: a large-scale pre-trained model towards application of limited reaction data}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Sagawa, Tatsuya and Kojima, Ryosuke}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2311.06708}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.2311.06708}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PharMolixFM: Multi-Modal All-Atom Molecular Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/pharmolixfm-all-atom-foundation-models/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/pharmolixfm-all-atom-foundation-models/</guid><description>PharMolixFM unifies diffusion, flow matching, and Bayesian flow networks for all-atom molecular modeling and generation with task-specific denoising priors.</description><content:encoded><![CDATA[<h2 id="a-unified-framework-for-all-atom-molecular-foundation-models">A Unified Framework for All-Atom Molecular Foundation Models</h2>
<p>PharMolixFM is a <strong>Method</strong> paper that introduces a unified framework for constructing all-atom foundation models for molecular modeling and generation. The primary contribution is the systematic implementation of three multi-modal generative model variants (diffusion, flow matching, and Bayesian flow networks) within a single architecture, along with a task-unifying denoising formulation that enables training on multiple structural biology tasks simultaneously. The framework achieves competitive performance on protein-small-molecule docking and structure-based drug design while providing the first empirical analysis of inference scaling laws for molecular generative models.</p>
<h2 id="challenges-in-multi-modal-atomic-modeling">Challenges in Multi-Modal Atomic Modeling</h2>
<p>Existing all-atom foundation models such as AlphaFold3, RoseTTAFold All-Atom, and ESM-AA face two core challenges that limit their generalization across molecular modeling and generation tasks.</p>
<p>First, atomic data is inherently multi-modal: each atom comprises both a discrete atom type and continuous 3D coordinates. This poses challenges for structure models that need to jointly capture and predict both modalities. Unlike text or image data that exhibit a single modality, molecular structures require generative models that can handle discrete categorical variables (atom types, bond types) and continuous variables (coordinates) simultaneously.</p>
<p>Second, there has been no comprehensive analysis of how different training objectives and sampling strategies impact the performance of all-atom foundation models. Prior work has focused on individual model architectures without systematically comparing generative frameworks or studying how inference-time compute scaling affects prediction quality.</p>
<p>PharMolixFM addresses both challenges by providing a unified framework that implements three state-of-the-art multi-modal generative models and formulates all downstream tasks as a generalized denoising process with task-specific priors.</p>
<h2 id="multi-modal-denoising-with-task-specific-priors">Multi-Modal Denoising with Task-Specific Priors</h2>
<p>The core innovation of PharMolixFM is the formulation of molecular tasks as a generalized denoising process where task-specific priors control which parts of the molecular system are noised during training. The framework decomposes a biomolecular system into $N$ atoms represented as a triplet $\bar{\mathbf{S}}_0 = \langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle$, where $\mathbf{X}_0 \in \mathbb{R}^{N \times 3}$ are atom coordinates, $\mathbf{A}_0 \in \mathbb{Z}^{N \times D_1}$ are one-hot atom types, and $\mathbf{E}_0 \in \mathbb{Z}^{N \times N \times D_2}$ are one-hot bond types.</p>
<p>The generative model estimates the density $p_\theta(\langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle)$ subject to SE(3) invariance:</p>
<p>$$
p_\theta(\langle \mathbf{R}\mathbf{X}_0 + \mathbf{t}, \mathbf{A}_0, \mathbf{E}_0 \rangle) = p_\theta(\langle \mathbf{X}_0, \mathbf{A}_0, \mathbf{E}_0 \rangle)
$$</p>
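<p>To make the constraint concrete, the sketch below applies a rigid motion (a rotation about the z-axis plus a translation) to a made-up three-atom system and confirms that interatomic distances are unchanged. Any density that depends on coordinates only through such quantities satisfies the invariance above. This illustrates the property only; it is not how the PharMolixFM network enforces it.</p>

```python
import math


def rigid_transform(coords, angle, shift):
    """Rotate points about the z-axis by `angle` radians, then translate."""
    c, s = math.cos(angle), math.sin(angle)
    return [
        (c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])
        for x, y, z in coords
    ]


def pairwise_distances(coords):
    """All interatomic distances in a fixed pair order."""
    n = len(coords)
    return [
        math.dist(coords[i], coords[j]) for i in range(n) for j in range(i + 1, n)
    ]


# Hypothetical coordinates; any values would do.
atoms = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (1.5, 1.2, 0.3)]
moved = rigid_transform(atoms, angle=0.7, shift=(2.0, -1.0, 4.0))
before = pairwise_distances(atoms)
after = pairwise_distances(moved)
max_gap = max(abs(a - b) for a, b in zip(before, after))
```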
<p>The variational lower bound is optimized over latent variables $S_1, \ldots, S_T$ obtained by adding independent noise to different modalities and atoms:</p>
<p>$$
q(S_{1:T} \mid S_0) = \prod_{i=1}^{T} \prod_{j=1}^{N} q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}, \sigma_{i,j}^{(\mathbf{X})}) \, q(\mathbf{A}_{i,j} \mid \mathbf{A}_{0,j}, \sigma_{i,j}^{(\mathbf{A})}) \, q(\mathbf{E}_{i,j} \mid \mathbf{E}_{0,j}, \sigma_{i,j}^{(\mathbf{E})})
$$</p>
<p>A key design choice is the noise schedule $\sigma_{i,j}^{(\mathcal{M})} = \frac{i}{T} \cdot \text{fix}_j^{(\mathcal{M})}$, where $\text{fix}_j^{(\mathcal{M})}$ is a scaling factor between 0 and 1 that controls which atoms and modalities receive noise. This &ldquo;Fix&rdquo; mechanism enables multiple training tasks:</p>
<ul>
<li><strong>Docking</strong> ($\text{Fix} = 1$ for protein and molecular graph, $\text{Fix} = 0$ for molecule coordinates): predicts binding pose given known atom/bond types.</li>
<li><strong>Structure-based drug design</strong> ($\text{Fix} = 1$ for protein, $\text{Fix} = 0$ for all molecule properties): generates novel molecules for a given pocket.</li>
<li><strong>Robustness augmentation</strong> ($\text{Fix} = 0.7$ for 15% randomly selected atoms, $\text{Fix} = 0$ for rest): simulates partial structure determination.</li>
</ul>
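<p>The schedule formula can be evaluated directly. The sketch below computes $\sigma_{i,j}$ for a hypothetical three-atom system whose atoms carry the three Fix values mentioned above. It implements the formula exactly as written and makes no claim about how the released code stores or applies these values.</p>

```python
def noise_scale(step, total_steps, fix):
    """sigma_{i,j} = (i / T) * fix_j for one atom and one modality."""
    if fix != min(max(fix, 0.0), 1.0):
        raise ValueError("fix must lie in [0, 1]")
    return step / total_steps * fix


def schedule(total_steps, fix_per_atom):
    """Rows are steps i = 1..T, columns are atoms j."""
    return [
        [noise_scale(i, total_steps, f) for f in fix_per_atom]
        for i in range(1, total_steps + 1)
    ]


# Hypothetical 3-atom system using the three Fix values named above.
sigma = schedule(total_steps=4, fix_per_atom=[0.0, 0.7, 1.0])
halfway = sigma[1]
final = sigma[3]
```

<p>An atom with a scaling factor of 0 keeps $\sigma = 0$ at every step, while the others grow linearly with $i$ up to their own factor at $i = T$.</p>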
<h3 id="three-generative-model-variants">Three Generative Model Variants</h3>
<p><strong>Multi-modal diffusion (PharMolixFM-Diff)</strong> uses a Markovian forward process. Continuous coordinates follow Gaussian diffusion while discrete variables use a D3PM categorical transition:</p>
<p>$$
q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}) = \mathcal{N}(\sqrt{\alpha_{i,j}} \, \mathbf{X}_{0,j}, (1 - \alpha_{i,j}) \mathbf{I}), \quad \alpha_{i,j} = \prod_{k=1}^{i}(1 - \sigma_{k,j}^{(\mathbf{X})})
$$</p>
<p>$$
q(\mathbf{A}_{i,j} \mid \mathbf{A}_{0,j}) = \text{Cat}(\mathbf{A}_{0,j} \bar{Q}_{i,j}^{(\mathbf{A})}), \quad Q_{i,j}^{(\mathbf{A})} = (1 - \sigma_{i,j}^{(\mathbf{A})}) \mathbf{I} + \frac{\sigma_{i,j}^{(\mathbf{A})}}{D_1} \mathbb{1}\mathbb{1}^T
$$</p>
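<p>The categorical transition can be sketched in a few lines. The code below builds the single-step matrix $Q$ from the formula above and, assuming (as in the standard D3PM construction) that $\bar{Q}$ is the cumulative product of the single-step matrices, derives the marginal distribution of a noised atom type. The number of atom types and the noise values are hypothetical.</p>

```python
def uniform_transition(sigma, num_types):
    """Q = (1 - sigma) * I + (sigma / D) * 1 1^T as a list of rows."""
    off = sigma / num_types
    return [
        [(1.0 - sigma) * (r == c) + off for c in range(num_types)]
        for r in range(num_types)
    ]


def matmul(a, b):
    """Plain matrix product for small square matrices."""
    n = len(a)
    return [
        [sum(a[r][k] * b[k][c] for k in range(n)) for c in range(n)]
        for r in range(n)
    ]


def cumulative_transition(sigmas, num_types):
    """Q-bar = Q_1 Q_2 ... Q_i (assumed definition of the barred matrix)."""
    q_bar = uniform_transition(0.0, num_types)  # identity matrix
    for sigma in sigmas:
        q_bar = matmul(q_bar, uniform_transition(sigma, num_types))
    return q_bar


# Hypothetical setting: 4 atom types, three steps of increasing noise.
q_bar = cumulative_transition([0.1, 0.2, 0.3], num_types=4)
one_hot = [1.0, 0.0, 0.0, 0.0]
probs = [sum(one_hot[k] * q_bar[k][c] for k in range(4)) for c in range(4)]
```

<p>Because each step mixes the identity with a uniform matrix, the probability of keeping the original type after these three steps is $0.9 \times 0.8 \times 0.7 = 0.504$ plus an equal share of the remaining mass.</p>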
<p>The training loss combines coordinate MSE with cross-entropy for discrete variables:</p>
<p>$$
\mathcal{L} = \mathbb{E}_{S_0, i, S_i} \left[ \lambda_i^{(\mathbf{X})} \| \tilde{\mathbf{X}}_0 - \mathbf{X}_0 \|_2^2 + \lambda_i^{(\mathbf{A})} \mathcal{L}_{CE}(\tilde{\mathbf{A}}_0, \mathbf{A}_0) + \lambda_i^{(\mathbf{E})} \mathcal{L}_{CE}(\tilde{\mathbf{E}}_0, \mathbf{E}_0) \right]
$$</p>
<p><strong>Multi-modal flow matching (PharMolixFM-Flow)</strong> constructs a direct mapping between data and prior distributions using conditional vector fields. For coordinates, the conditional flow uses a Gaussian path $q(\mathbf{X}_{i,j} \mid \mathbf{X}_{0,j}) = \mathcal{N}((1 - \sigma_{i,j}^{(\mathbf{X})}) \mathbf{X}_{0,j}, (\sigma_{i,j}^{(\mathbf{X})})^2 \mathbf{I})$, while discrete variables use the same D3PM Markov chain. Sampling proceeds by solving an ODE via Euler integration.</p>
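<p>A minimal sketch of the coordinate path and Euler sampling follows. It assumes the network predicts the clean coordinates and that the velocity is derived from that prediction along a straight-line path; the paper may parameterize the vector field differently. The <code>oracle</code> function is a stand-in for a trained model and simply returns the known answer.</p>

```python
import random


def sample_on_path(x0, sigma, rng):
    """Draw x ~ N((1 - sigma) * x0, sigma^2 I) for one 3D coordinate."""
    return [(1.0 - sigma) * v + sigma * rng.gauss(0.0, 1.0) for v in x0]


def euler_denoise(x_start, predict_x0, num_steps):
    """Integrate from sigma = 1 down to sigma = 0 with fixed Euler steps."""
    x = list(x_start)
    h = 1.0 / num_steps
    for k in range(num_steps):
        sigma = 1.0 - k * h
        x0_hat = predict_x0(x, sigma)
        # Velocity of the straight-line path implied by the current x0 estimate.
        velocity = [(xi - x0i) / sigma for xi, x0i in zip(x, x0_hat)]
        x = [xi - h * vi for xi, vi in zip(x, velocity)]
    return x


rng = random.Random(0)
target = [1.0, -2.0, 0.5]
noisy = sample_on_path(target, sigma=1.0, rng=rng)


def oracle(x, sigma):
    # Stand-in for the trained network; it simply returns the known answer.
    return target


recovered = euler_denoise(noisy, oracle, num_steps=10)
```

<p>With a perfect predictor every Euler step stays on the straight line between noise and data, so the final state matches the target. A learned predictor would introduce error at each step, which is one place where deterministic integration can differ from stochastic diffusion sampling.</p>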
<p><strong>Bayesian flow networks (PharMolixFM-BFN)</strong> perform generative modeling in the parameter space of the data distribution rather than the data space. The Bayesian flow distribution for coordinates is:</p>
<p>$$
p_F(\tilde{\mathbf{X}}_{i,j}^{(\theta)} \mid \mathbf{X}_{0,j}) = \mathcal{N}(\gamma_{i,j} \mathbf{X}_{0,j}, \gamma_{i,j}(1 - \gamma_{i,j}) \mathbf{I}), \quad \gamma_{i,j} = 1 - \alpha^{2(1 - \sigma_{i,j}^{(\mathbf{X})})}
$$</p>
<h3 id="network-architecture">Network Architecture</h3>
<p>The architecture follows PocketXMol with a dual-branch SE(3)-equivariant graph neural network. A protein branch (4-layer GNN with kNN graph) processes pocket atoms, then representations are passed to a molecule branch (6-layer GNN) that captures protein-molecule interactions. Independent prediction heads reconstruct atom coordinates, atom types, and bond types, with additional confidence heads for self-ranking during inference.</p>
<h2 id="docking-and-drug-design-experiments">Docking and Drug Design Experiments</h2>
<h3 id="protein-small-molecule-docking">Protein-Small-Molecule Docking</h3>
<p>PharMolixFM is evaluated on the PoseBusters benchmark (428 protein-small-molecule complexes) using the holo docking setting with a known protein structure and 10 Angstrom binding pocket. The metric is the ratio of predictions with RMSD &lt; 2 Angstrom.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Self-Ranking (%)</th>
          <th>Oracle-Ranking (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DiffDock</td>
          <td>38.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>RFAA</td>
          <td>42.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Vina</td>
          <td>52.3</td>
          <td>-</td>
      </tr>
      <tr>
          <td>UniMol-Docking V2</td>
          <td>77.6</td>
          <td>-</td>
      </tr>
      <tr>
          <td>SurfDock</td>
          <td>78.0</td>
          <td>-</td>
      </tr>
      <tr>
          <td>AlphaFold3</td>
          <td>90.4</td>
          <td>-</td>
      </tr>
      <tr>
          <td>PocketXMol (50 repeats)</td>
          <td>82.2</td>
          <td>95.3</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff (50 repeats)</td>
          <td>83.4</td>
          <td>96.0</td>
      </tr>
      <tr>
          <td>PharMolixFM-Flow (50 repeats)</td>
          <td>73.4</td>
          <td>93.7</td>
      </tr>
      <tr>
          <td>PharMolixFM-BFN (50 repeats)</td>
          <td>78.5</td>
          <td>93.5</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff (500 repeats)</td>
          <td>83.9</td>
          <td>98.1</td>
      </tr>
  </tbody>
</table>
<p>PharMolixFM-Diff achieves the second-best self-ranking result (83.4%), outperforming PocketXMol (82.2%) by 1.2 percentage points at the same 50-repeat budget but trailing AlphaFold3 (90.4%). The key advantage is inference speed: approximately 4.6 seconds per complex on a single A800 GPU compared to approximately 249.0 seconds for AlphaFold3 (a 54x speedup). Under oracle-ranking with 500 repeats, PharMolixFM-Diff reaches 98.1%, suggesting that better ranking strategies could further improve practical performance.</p>
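<p>The difference between the two ranking columns can be sketched as follows: self-ranking scores the pose the model is most confident in, while oracle-ranking scores the best pose in hindsight. The poses and confidence values are made up, and the function name is hypothetical.</p>

```python
def success_rates(samples_per_complex, threshold=2.0):
    """Return (self_ranking, oracle_ranking) success fractions.

    Each complex is a list of (rmsd, confidence) pairs, one per sampled pose.
    """
    self_hits = 0
    oracle_hits = 0
    for poses in samples_per_complex:
        chosen_rmsd = max(poses, key=lambda pose: pose[1])[0]  # most confident
        best_rmsd = min(pose[0] for pose in poses)  # best in hindsight
        self_hits += threshold > chosen_rmsd
        oracle_hits += threshold > best_rmsd
    n = len(samples_per_complex)
    return self_hits / n, oracle_hits / n


# Made-up poses for three complexes: (RMSD in Angstrom, confidence score).
toy = [
    [(1.1, 0.9), (3.0, 0.4)],
    [(4.2, 0.8), (1.7, 0.6)],
    [(5.0, 0.7), (6.1, 0.2)],
]
self_rate, oracle_rate = success_rates(toy)

# Speedup implied by the reported per-complex inference times.
speedup = 249.0 / 4.6
```

<p>In the second toy complex a good pose exists but is not the most confident one, which is exactly the kind of case that separates the two columns.</p>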
<h3 id="structure-based-drug-design">Structure-Based Drug Design</h3>
<p>Evaluation uses the CrossDocked test set (100 protein pockets, 100 molecules generated per pocket), measuring Vina binding affinity scores and drug-likeness properties (QED and SA).</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Vina Score (Avg/Med)</th>
          <th>QED</th>
          <th>SA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pocket2Mol</td>
          <td>-5.14 / -4.70</td>
          <td>0.57</td>
          <td>0.76</td>
      </tr>
      <tr>
          <td>TargetDiff</td>
          <td>-5.47 / -6.30</td>
          <td>0.48</td>
          <td>0.58</td>
      </tr>
      <tr>
          <td>DecompDiff</td>
          <td>-5.67 / -6.04</td>
          <td>0.45</td>
          <td>0.61</td>
      </tr>
      <tr>
          <td>MolCRAFT</td>
          <td>-6.61 / -8.14</td>
          <td>0.46</td>
          <td>0.62</td>
      </tr>
      <tr>
          <td>PharMolixFM-Diff</td>
          <td>-6.18 / -6.44</td>
          <td>0.50</td>
          <td>0.73</td>
      </tr>
      <tr>
          <td>PharMolixFM-Flow</td>
          <td>-6.34 / -6.47</td>
          <td>0.49</td>
          <td>0.74</td>
      </tr>
      <tr>
          <td>PharMolixFM-BFN</td>
          <td>-6.38 / -6.45</td>
          <td>0.48</td>
          <td>0.64</td>
      </tr>
  </tbody>
</table>
<p>PharMolixFM achieves a better balance between binding affinity and drug-like properties than most baselines. MolCRAFT achieves the best Vina scores, but the PharMolixFM-Diff and PharMolixFM-Flow variants show higher QED (0.49-0.50 vs. 0.45-0.48) and SA (0.73-0.74 vs. 0.58-0.62) than TargetDiff, DecompDiff, and MolCRAFT, properties that are important for downstream validation and in-vivo application. Pocket2Mol scores higher on QED (0.57) and SA (0.76) but has the weakest average Vina score (-5.14).</p>
<h3 id="inference-scaling-law">Inference Scaling Law</h3>
<p>The paper explores whether inference-time scaling holds for molecular generative models, fitting the relationship:</p>
<p>$$
\text{Acc} = a \log(bR + c) + d
$$</p>
<p>where $R$ is the number of sampling repeats. All three PharMolixFM variants exhibit logarithmic improvement in docking accuracy with increased sampling repeats, analogous to inference scaling laws observed in NLP. Performance plateaus eventually due to distributional differences between training and test sets.</p>
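<p>The shape of this curve is easy to inspect numerically. The sketch below evaluates the functional form with hypothetical coefficients (the fitted values are not given here) and assumes a natural logarithm, which the formula above does not specify.</p>

```python
import math


def scaling_law(repeats, a, b, c, d):
    """Acc = a * log(b * R + c) + d with natural logarithm."""
    return a * math.log(b * repeats + c) + d


# Hypothetical coefficients chosen only to illustrate the shape of the curve.
params = dict(a=2.0, b=1.0, c=1.0, d=70.0)
budgets = [1, 10, 100, 1000]
accuracies = [scaling_law(r, **params) for r in budgets]
gains = [later - earlier for earlier, later in zip(accuracies, accuracies[1:])]
```

<p>Under this form each tenfold increase in repeats adds a roughly constant amount of accuracy (close to $a \ln 10$ once $bR$ dominates $c$), which is the diminishing-returns behavior described above.</p>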
<h2 id="competitive-docking-with-faster-inference-but-limited-task-scope">Competitive Docking with Faster Inference, but Limited Task Scope</h2>
<p>PharMolixFM demonstrates that multi-modal generative models can achieve competitive all-atom molecular modeling with substantial inference speed advantages over AlphaFold3. The key findings are:</p>
<ol>
<li><strong>Diffusion outperforms flow matching and BFN</strong> for docking under standard sampling budgets. The stochastic nature of diffusion sampling appears beneficial compared to the deterministic ODE integration of flow matching.</li>
<li><strong>Oracle-ranking reveals untapped potential</strong>: the gap between self-ranking (83.9%) and oracle-ranking (98.1%) at 500 repeats indicates that confidence-based ranking is a bottleneck. Better ranking methods could close the gap with AlphaFold3.</li>
<li><strong>The three variants show similar performance for drug design</strong>, suggesting that model architecture and training data may matter more than the generative framework for generation tasks.</li>
<li><strong>Inference scaling laws hold</strong> for molecular generative models, paralleling findings in NLP.</li>
</ol>
<p>The framework is evaluated on only two tasks (docking and SBDD), and the paper does not address protein structure prediction, protein-protein interactions, or nucleic acid modeling, all of which fall within AlphaFold3&rsquo;s scope. The BFN variant underperforms the diffusion model, which the authors attribute to smaller noise scales at early sampling steps making training less challenging. The paper also does not compare against concurrent work on inference-time scaling for molecular models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>PDBBind, Binding MOAD, CrossDocked2020, PepBDB</td>
          <td>Not specified</td>
          <td>Filtered by PocketXMol criteria</td>
      </tr>
      <tr>
          <td>Docking eval</td>
          <td>PoseBusters benchmark</td>
          <td>428 complexes</td>
          <td>Holo docking with known protein</td>
      </tr>
      <tr>
          <td>SBDD eval</td>
          <td>CrossDocked test set</td>
          <td>100 pockets</td>
          <td>100 molecules per pocket</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Three generative variants: multi-modal diffusion (D3PM), flow matching, Bayesian flow networks</li>
<li>Task-specific noise via Fix mechanism (0, 0.7, or 1.0)</li>
<li>Training tasks selected with equal probability per sample</li>
<li>AdamW optimizer: weight decay 0.001, $\beta_1 = 0.99$, $\beta_2 = 0.999$</li>
<li>Linear warmup to learning rate 0.001 over 1000 steps</li>
<li>180K training steps with batch size 40</li>
</ul>
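<p>The warmup schedule listed above can be written as a small function. This sketch assumes the learning rate is held constant after warmup, which the notes do not state; the function name is hypothetical.</p>

```python
def learning_rate(step, peak=1e-3, warmup_steps=1000):
    """Linear warmup from 0 to `peak`, then hold (post-warmup shape assumed)."""
    return peak * min(step, warmup_steps) / warmup_steps


# Total samples processed: training steps times batch size.
samples_seen = 180_000 * 40
lr_at = {s: learning_rate(s) for s in (0, 500, 1000, 180_000)}
```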
<h3 id="models">Models</h3>
<ul>
<li>Dual-branch SE(3)-equivariant GNN (protein: 4-layer, molecule: 6-layer)</li>
<li>kNN graph construction for protein and protein-molecule interactions</li>
<li>Independent prediction heads for coordinates, atom types, bond types</li>
<li>Confidence heads for self-ranking during inference</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>PharMolixFM-Diff</th>
          <th>AlphaFold3</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSD &lt; 2 Å self-ranking</td>
          <td>83.4% (50 rep)</td>
          <td>90.4%</td>
          <td>PoseBusters docking</td>
      </tr>
      <tr>
          <td>RMSD &lt; 2 Å oracle-ranking</td>
          <td>98.1% (500 rep)</td>
          <td>-</td>
          <td>PoseBusters docking</td>
      </tr>
      <tr>
          <td>Inference time (per complex)</td>
          <td>~4.6s</td>
          <td>~249.0s</td>
          <td>Single A800 GPU</td>
      </tr>
      <tr>
          <td>Vina score (avg)</td>
          <td>-6.18</td>
          <td>-</td>
          <td>CrossDocked SBDD</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: 4x 80GB A800 GPUs</li>
<li>Inference benchmarked on single A800 GPU</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/PharMolix/OpenBioMed">OpenBioMed (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, Y., Wang, J., Fan, S., &amp; Nie, Z. (2025). PharMolixFM: All-Atom Foundation Models for Molecular Modeling and Generation. <em>arXiv preprint arXiv:2503.21788</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{luo2025pharmolixfm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{PharMolixFM: All-Atom Foundation Models for Molecular Modeling and Generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Yizhen and Wang, Jiashuo and Fan, Siqi and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2503.21788}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PharmaGPT: Domain-Specific LLMs for Pharma and Chem</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/pharmagpt-domain-specific-llms-biopharmaceutical/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/pharmagpt-domain-specific-llms-biopharmaceutical/</guid><description>PharmaGPT introduces 13B and 70B parameter LLMs trained on biopharmaceutical and chemical corpora, outperforming GPT-3.5 and rivaling GPT-4 on pharmacy exams.</description><content:encoded><![CDATA[<h2 id="a-domain-specific-llm-suite-for-biopharmaceuticals-and-chemistry">A Domain-Specific LLM Suite for Biopharmaceuticals and Chemistry</h2>
<p>This is a <strong>Method</strong> paper that introduces PharmaGPT, a suite of domain-specific large language models with 13 billion and 70 billion parameters. The models are built on the LLaMA architecture and undergo continued pretraining on a curated corpus of biopharmaceutical and chemical literature, followed by instruction fine-tuning and reinforcement learning from human feedback (RLHF). The primary contribution is demonstrating that domain-specific continued pretraining on a general-purpose LLM backbone can produce models that outperform much larger general-purpose models on pharmaceutical knowledge tasks, using only a fraction of the parameters.</p>
<h2 id="bridging-the-gap-between-general-purpose-llms-and-specialized-pharmaceutical-knowledge">Bridging the Gap Between General-Purpose LLMs and Specialized Pharmaceutical Knowledge</h2>
<p>General-purpose LLMs like GPT-3.5 and GPT-4 show impressive broad capabilities but often fall short in specialized domains requiring precise terminology, deep domain knowledge, and high accuracy. The biopharmaceutical and chemical sectors present particular challenges: intricate terminologies, specialized regulatory knowledge, and a demand for precision that general models cannot consistently deliver. Most state-of-the-art LLMs are proprietary, English-centric, and lack depth in vertical domains. The authors identify a gap in the availability of domain-specific LLMs for biomedicine and chemistry, particularly multilingual models that can handle both English and Chinese pharmaceutical content.</p>
<h2 id="continued-pretraining-with-domain-specific-data-and-weighted-instruction-tuning">Continued Pretraining with Domain-Specific Data and Weighted Instruction Tuning</h2>
<p>PharmaGPT&rsquo;s core innovation lies in its training pipeline, which adapts the LLaMA backbone through four components:</p>
<p><strong>Extended Tokenizer</strong>: The authors develop a new tokenizer using <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">byte-pair encoding (BPE)</a> from SentencePiece, trained on their pretraining data and merged with the LLaMA2 tokenizer. This extends the vocabulary from 32,000 to 55,296 tokens, improving compression efficiency for Chinese text and specialized domain terminology. The embedding and output layers are resized from $V \times H$ to $V' \times H$, where $V = 32{,}000$, $V' = 55{,}296$, and $H$ is the hidden size.</p>
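<p>As a rough worked example of what this resize costs: the number of added rows is fixed by the two vocabulary sizes, while the hidden size depends on the backbone. The value $H = 8{,}192$ used below is the hidden size of the public LLaMA 2 70B configuration and is an assumption here, as is the assumption that the input embedding and output projection are untied (hence the factor of 2).</p>
<p>$$V' - V = 55{,}296 - 32{,}000 = 23{,}296, \qquad 2\,(V' - V)\,H = 2 \times 23{,}296 \times 8{,}192 \approx 3.8 \times 10^{8}$$</p>
<p>Under those assumptions the extension adds roughly 0.38 billion newly initialized parameters, about half a percent of a 70B model.</p>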
<p><strong>Two-Stage Continued Pretraining</strong>: The models consume 153 billion tokens in Stage 1 (primarily web, news, patents, and papers) and 43 billion tokens in Stage 2 (research reports, exams, books, chats, code, and supervised data). The data distribution shifts between stages to move from general domain knowledge toward specialized biopharmaceutical tasks.</p>
<p><strong>Weighted Instruction Fine-tuning</strong>: Inspired by OpenChat, the authors use a weighted autoregressive objective that zeros out loss on user instruction tokens. The loss function is:</p>
<p>$$\mathcal{L}_{SFT}(\Theta) = \mathbb{E}_{x \sim \mathcal{D}_{SFT}} \left[ -\alpha \sum_{i \in \text{output}} \log p(x_i \mid x_0, x_1, \dots, x_{i-1}; \Theta) \right]$$</p>
<p>where the sum runs over response tokens only and the weight $\alpha$ is set to 1 for expert-curated domain-specific instructions ($\mathcal{D}_{\text{exp}}$) and 0.1 for generic instructions ($\mathcal{D}_{\text{gen}}$). This differential weighting gives domain-relevant instructions ten times the influence of generic ones during training.</p>
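<p>A minimal sketch of this objective for a single training example, in plain Python. The function name <code>weighted_sft_loss</code>, the boolean mask representation, and the toy probabilities are illustrative assumptions, not the authors&rsquo; implementation.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

def weighted_sft_loss(token_log_probs, is_output, alpha):
    """Negative log-likelihood over response tokens, scaled by alpha.

    token_log_probs: log p(x_i | x_0..x_{i-1}) for every token in one example.
    is_output: True where the token belongs to the response, False for the
               user instruction (those tokens contribute no loss).
    alpha: per-example weight (1.0 for expert data, 0.1 for generic data).
    """
    total = sum(lp for lp, keep in zip(token_log_probs, is_output) if keep)
    return -alpha * total

# Hypothetical example: two instruction tokens followed by two response tokens.
log_probs = [math.log(0.9), math.log(0.8), math.log(0.5), math.log(0.25)]
mask = [False, False, True, True]

expert_loss = weighted_sft_loss(log_probs, mask, alpha=1.0)
generic_loss = weighted_sft_loss(log_probs, mask, alpha=0.1)
</code></pre></div>
<p>For identical tokens, the generic-weighted loss is one tenth of the expert-weighted loss, which is how the weighting shifts gradient emphasis toward the expert-curated data.</p>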
<p><strong>RLHF with PPO</strong>: A reward model is initialized from the pretrained PharmaGPT-70B and enhanced with two MLPs to output a scalar preference score. The reward model is trained with a binary ranking loss:</p>
<p>$$\mathcal{L}_{\text{ranking}} = -\log\left(\sigma\left(r_\theta(x, y_c) - r_\theta(x, y_r)\right)\right)$$</p>
<p>where $r_\theta(x, y_c)$ is the score for the preferred response and $r_\theta(x, y_r)$ is the score for the rejected response. The RLHF dataset consists of 50,000 instructions with expert-annotated human preferences over responses from PharmaGPT variants and commercial LLMs (GPT-4, ChatGPT-3.5). <a href="https://en.wikipedia.org/wiki/Proximal_policy_optimization">Proximal Policy Optimization (PPO)</a> is used for the RL training, selecting the highest-scoring response from four generated candidates at each step.</p>
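<p>The ranking loss above can be sketched for a single preference pair as follows. The function name and the example scores are hypothetical, and the sketch uses plain floats in place of a reward model&rsquo;s outputs.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

def ranking_loss(score_chosen, score_rejected):
    """Binary ranking loss: -log(sigmoid(r_chosen - r_rejected))."""
    margin = score_chosen - score_rejected
    # -log(sigmoid(m)) equals log(1 + exp(-m)); log1p keeps small values accurate.
    return math.log1p(math.exp(-margin))

tie = ranking_loss(0.0, 0.0)       # equal scores: the loss is log(2)
correct = ranking_loss(2.0, -1.0)  # preferred response scored higher
wrong = ranking_loss(-1.0, 2.0)    # preferred response scored lower
</code></pre></div>
<p>The loss shrinks as the preferred response&rsquo;s score pulls ahead of the rejected one and grows when the ordering is reversed, so minimizing it trains the reward model to rank responses the way the annotators did.</p>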
<h2 id="evaluation-on-pharmacy-licensing-exams-translation-and-mmlu">Evaluation on Pharmacy Licensing Exams, Translation, and MMLU</h2>
<p>The evaluation covers four main benchmarks:</p>
<p><strong>NAPLEX (North American Pharmacist Licensure Examination)</strong>: PharmaGPT is tested across three NAPLEX sections. Results show consistent improvement across model iterations:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>NAPLEX I</th>
          <th>NAPLEX II</th>
          <th>NAPLEX III</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT 0.1</td>
          <td>5.0</td>
          <td>2.5</td>
          <td>3.5</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.3</td>
          <td>42.0</td>
          <td>48.0</td>
          <td>46.5</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.5</td>
          <td>57.0</td>
          <td>59.0</td>
          <td>58.0</td>
      </tr>
      <tr>
          <td>PharmaGPT 0.7</td>
          <td>66.0</td>
          <td>68.0</td>
          <td>76.0</td>
      </tr>
  </tbody>
</table>
<p>PharmaGPT 0.7 scores in the 66-76% range across all three NAPLEX sections, outperforming GPT-3.5-turbo by considerable margins.</p>
<p><strong>Chinese Pharmacist Examination</strong>: PharmaGPT achieves scores in the 70% range across all four exam categories, outperforming both GPT-3.5-turbo and GPT-4 in all categories. This result is notable given GPT-4&rsquo;s much larger scale.</p>
<p><strong>Biomedical Translation</strong>: PharmaGPT 0.7 outperforms GPT-3.5, Claude 3, and Google Translate on biomedical paper translation (English-Chinese), achieving <a href="https://en.wikipedia.org/wiki/BLEU">BLEU</a> scores of 30 (paragraph-level), 18 (sentence-level), and 10 (word-level).</p>
<p><strong>MMLU</strong>: On the general Massive Multitask Language Understanding benchmark, PharmaGPT achieves scores in the 80% range across most biomedical and life science tasks, surpassing GPT-3.5-turbo and performing comparably to GPT-4 in areas such as physiology, health sciences, and biology.</p>
<h2 id="strong-domain-performance-with-smaller-scale-but-limited-reproducibility">Strong Domain Performance with Smaller Scale, but Limited Reproducibility</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Domain-specific continued pretraining enables a 70B parameter model to match or exceed GPT-4 on pharmaceutical knowledge tasks, despite having a fraction of GPT-4&rsquo;s parameters</li>
<li>Iterative post-training (versions 0.1 through 0.7) shows consistent improvement, with the largest gains occurring between versions 0.1 and 0.3 (37 to 45.5 points per NAPLEX section)</li>
<li>The two-stage pretraining strategy, shifting from general domain data to more specialized exam and report data, appears effective for building domain expertise</li>
<li>Scaling laws hold within the PharmaGPT family: larger parameter counts consistently produce better performance on both NAPLEX and Chinese pharmaceutical exams</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>Potential biases in the training data</li>
<li>Model dependency on the quality and diversity of input prompts</li>
<li>Challenges in accurately assessing performance on highly specialized tasks without domain expert evaluation</li>
<li>Interpretability concerns for use in sensitive healthcare and pharmaceutical applications</li>
<li>The 3B model is trained from scratch while the 13B and 70B models use LLaMA as a backbone, making direct comparison across model sizes less straightforward</li>
</ul>
<p><strong>Missing details</strong>: The paper does not release model weights, training code, or the proprietary training dataset. No ablation studies isolate the contribution of each training stage (continued pretraining vs. instruction tuning vs. RLHF). The evaluation is limited to multiple-choice exams and translation, without testing on molecular property prediction, reaction prediction, or other computational chemistry tasks common in this domain.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining Stage 1</td>
          <td>Web, News, Patents, Papers</td>
          <td>153B tokens</td>
          <td>Proprietary corpus; not publicly available</td>
      </tr>
      <tr>
          <td>Pretraining Stage 2</td>
          <td>Research Reports, Exams, Books, Chats, Code</td>
          <td>43B tokens</td>
          <td>Proprietary corpus; not publicly available</td>
      </tr>
      <tr>
          <td>Instruction Tuning</td>
          <td>Manually labeled + synthesized data</td>
          <td>Several hundred thousand instructions</td>
          <td>Includes expert Q&amp;A, patent data, ShareGPT</td>
      </tr>
      <tr>
          <td>RLHF</td>
          <td>Human preference annotations</td>
          <td>50,000 annotated instructions</td>
          <td>Expert annotators ranked responses</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>NAPLEX, Chinese Pharmacist Exam, MMLU, MT</td>
          <td>Not specified</td>
          <td>Exam datasets sourced from public exams</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Base architecture</strong>: LLaMA (13B and 70B variants); 3B model trained from scratch</li>
<li><strong>Tokenizer</strong>: Extended BPE tokenizer (55,296 vocab size) merged with LLaMA2 tokenizer</li>
<li><strong>Training objective</strong>: Standard autoregressive LM (pretraining), weighted autoregressive with $\alpha \in \{0.1, 1.0\}$ (SFT), PPO (RLHF)</li>
<li><strong>Reward model</strong>: Initialized from PharmaGPT-70B with two additional MLPs</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Base</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT-3B</td>
          <td>3B</td>
          <td>Trained from scratch</td>
          <td>Not evaluated in main results</td>
      </tr>
      <tr>
          <td>PharmaGPT-13B</td>
          <td>13B</td>
          <td>LLaMA-13B</td>
          <td>Post-trained</td>
      </tr>
      <tr>
          <td>PharmaGPT-70B</td>
          <td>70B</td>
          <td>LLaMA-70B</td>
          <td>Primary model; versions 0.1-0.7 reported</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>PharmaGPT 0.7</th>
          <th>GPT-3.5</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NAPLEX I</td>
          <td>66%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>NAPLEX II</td>
          <td>68%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>NAPLEX III</td>
          <td>76%</td>
          <td>~50%</td>
          <td>Estimated from figures</td>
      </tr>
      <tr>
          <td>Chinese Pharmacist Exam</td>
          <td>~70% range</td>
          <td>Lower</td>
          <td>Outperforms GPT-4</td>
      </tr>
      <tr>
          <td>Biomedical Translation (paragraph BLEU)</td>
          <td>30</td>
          <td>27</td>
          <td>English-Chinese</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify the hardware used for training. Training hyperparameters for the 70B model include tensor parallelism (TP=8) and pipeline parallelism (PP=16) during pretraining, suggesting multi-node GPU training, likely on at least 128 GPUs.</p>
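<p>The 128-GPU figure follows from the fact that one model replica under combined tensor and pipeline parallelism occupies $\text{TP} \times \text{PP}$ devices. The data-parallel degree $\text{DP}$ is not given in this summary, so it is left symbolic in this sketch:</p>
<p>$$N_{\text{GPU}} = \text{TP} \times \text{PP} \times \text{DP} = 8 \times 16 \times \text{DP} = 128 \cdot \text{DP}, \qquad \text{DP} \geq 1$$</p>
<p>Setting $\text{DP} = 1$ gives the lower bound of 128; any data parallelism multiplies that count.</p>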
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PharmaGPT models</td>
          <td>Model</td>
          <td>Not released</td>
          <td>No public weights or API access</td>
      </tr>
      <tr>
          <td>Training data</td>
          <td>Dataset</td>
          <td>Proprietary</td>
          <td>PatSnap internal data</td>
      </tr>
      <tr>
          <td>Training code</td>
          <td>Code</td>
          <td>Not released</td>
          <td>No public repository</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: <strong>Closed</strong>. The model weights, training data, and training code are all unavailable to the public. The proprietary nature of both the data pipeline and the models makes independent reproduction infeasible.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, L., Wang, W., Bai, Z., Xu, P., Fang, Y., Fang, J., &hellip; &amp; Tu, C. (2024). PharmaGPT: Domain-Specific Large Language Models for Bio-Pharmaceutical and Chemistry. <em>arXiv preprint arXiv:2406.18045</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chen2024pharmagpt,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{PharmaGPT: Domain-Specific Large Language Models for Bio-Pharmaceutical and Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Linqing and Wang, Weilei and Bai, Zilong and Xu, Peng and Fang, Yan and Fang, Jie and Wu, Wentao and Zhou, Lizhi and Zhang, Ruiji and Xia, Yubin and Xu, Chaobo and Hu, Ran and Xu, Licong and Cai, Qijun and Hua, Haoran and Sun, Jing and Liu, Jin and Qiu, Tian and Liu, Haowen and Hu, Meng and Li, Xiuwen and Gao, Fei and Wang, Yufu and Tie, Lin and Wang, Chaochao and Lu, Jianping and Sun, Cheng and Wang, Yixin and Yang, Shengjie and Li, Yuancheng and Jin, Lu and Zhang, Lisha and Bian, Fu and Ye, Zhongkai and Pei, Lidong and Tu, Changyang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2406.18045}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2406.18045}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ORGAN: Objective-Reinforced GANs for Molecule Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/</guid><description>ORGAN combines GANs with reinforcement learning to steer SMILES-based molecular generation toward drug-likeness, solubility, and synthesizability objectives.</description><content:encoded><![CDATA[<h2 id="combining-gans-and-reinforcement-learning-for-goal-directed-sequence-generation">Combining GANs and Reinforcement Learning for Goal-Directed Sequence Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces ORGAN (Objective-Reinforced Generative Adversarial Network), a framework for generating sequences that are both realistic (close to the training distribution) and optimized for domain-specific objectives. ORGAN extends SeqGAN by adding external reward functions to the reinforcement learning signal, with a tunable parameter $\lambda$ controlling the balance between adversarial (discriminator) and objective-based rewards. The authors demonstrate ORGAN on two domains: molecular generation using <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings (optimizing druglikeness, solubility, and synthesizability) and musical melody generation (optimizing tonality and step ratios).</p>
<h2 id="exposure-bias-and-mode-collapse-in-discrete-sequence-generation">Exposure Bias and Mode Collapse in Discrete Sequence Generation</h2>
<p>Generating discrete sequences with desirable properties presents two intertwined challenges. First, RNNs trained via maximum likelihood estimation (MLE) suffer from exposure bias, where the model sees only ground-truth prefixes during training but must condition on its own (potentially erroneous) outputs at generation time. Second, while <a href="/posts/what-is-a-gan/">GANs</a> can address some of these issues through adversarial training, they were not initially applicable to discrete data due to non-differentiability of the sampling step. SeqGAN resolved this by framing the generator as an RL agent, but it optimizes only for distributional fidelity (fooling the discriminator) without any mechanism to steer generation toward specific property targets.</p>
<p>In drug discovery, simply generating valid, drug-like molecules is insufficient. Practitioners need to optimize for particular pharmaceutical properties (e.g., solubility, synthesizability, druglikeness) while maintaining structural diversity. Naive RL approaches can optimize properties effectively but tend to collapse onto trivial solutions (e.g., repeating &ldquo;CCCCCCC&rdquo; to maximize solubility). The challenge is to combine the distributional regularization of adversarial training with the goal-directedness of RL.</p>
<h2 id="mixed-reward-interpolating-between-adversarial-and-objective-signals">Mixed Reward: Interpolating Between Adversarial and Objective Signals</h2>
<p>ORGAN&rsquo;s core innovation is a reward function that linearly interpolates between the discriminator score and domain-specific objectives:</p>
<p>$$R(Y_{1:T}) = \lambda \cdot D_{\phi}(Y_{1:T}) + (1 - \lambda) \cdot O_{i}(Y_{1:T})$$</p>
<p>When $\lambda = 1$, the model reduces to SeqGAN (pure adversarial training). When $\lambda = 0$, it becomes naive RL optimizing only the objective. Intermediate values allow the adversarial component to regularize the generator, keeping samples within the distribution while the objective component steers toward desired properties.</p>
<p>The generator $G_{\theta}$ is an LSTM-based RNN that produces sequences token-by-token. Training follows the REINFORCE algorithm, where the expected long-term reward is:</p>
<p>$$J(\theta) = \mathbb{E}\left[R(Y_{1:T}) \mid s_{0}, \theta\right] = \sum_{y_{1} \in Y} G_{\theta}(y_{1} \mid s_{0}) \cdot Q(s_{0}, y_{1})$$</p>
<p>For intermediate timesteps (partial sequences), the action-value function $Q$ is estimated via $N$-time Monte Carlo rollouts:</p>
<p>$$Q(Y_{1:t-1}, y_{t}) = \begin{cases} \frac{1}{N} \sum_{n=1}^{N} R(Y_{1:T}^{n}), &amp; \text{if } t &lt; T \\ R(Y_{1:T}), &amp; \text{if } t = T \end{cases}$$</p>
<p>where $Y_{1:T}^{n}$ are completions sampled by rolling out the current policy $G_{\theta}$ from state $Y_{1:t}$.</p>
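<p>A minimal sketch of this rollout estimate, assuming a generic sampling policy and reward function. The names <code>estimate_q</code>, <code>uniform_policy</code>, and <code>carbon_fraction</code> are hypothetical stand-ins for the LSTM generator and the mixed reward, chosen only to keep the example self-contained.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random

def estimate_q(prefix, max_len, rollout_policy, reward_fn, n_rollouts, rng):
    """Monte Carlo estimate of Q for a partial sequence `prefix`.

    A full-length prefix is scored directly; otherwise the prefix is completed
    n_rollouts times with rollout_policy and the rewards are averaged.
    """
    if len(prefix) == max_len:
        return reward_fn(prefix)
    total = 0.0
    for _ in range(n_rollouts):
        seq = list(prefix)
        for _ in range(max_len - len(prefix)):
            seq.append(rollout_policy(seq, rng))
        total += reward_fn(seq)
    return total / n_rollouts

# Hypothetical stand-ins for the generator and the reward.
def uniform_policy(seq, rng):
    return rng.choice(["C", "N", "O"])

def carbon_fraction(seq):
    return seq.count("C") / len(seq)

rng = random.Random(0)
q_partial = estimate_q(["C", "C"], 4, uniform_policy, carbon_fraction, 200, rng)
q_full = estimate_q(["C", "C", "N", "O"], 4, uniform_policy, carbon_fraction, 200, rng)
</code></pre></div>
<p>The complete sequence is scored once, while the partial sequence receives the average reward of its sampled completions. Increasing the number of rollouts reduces the variance of that average at a proportional compute cost.</p>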
<p>The policy gradient is:</p>
<p>$$\nabla_{\theta} J(\theta) \simeq \frac{1}{T} \sum_{t=1}^{T} \mathbb{E}_{y_{t} \sim G_{\theta}(y_{t} \mid Y_{1:t-1})} \left[\nabla_{\theta} \log G_{\theta}(y_{t} \mid Y_{1:t-1}) \cdot Q(Y_{1:t-1}, y_{t})\right]$$</p>
<p>Two additional mechanisms improve training:</p>
<ol>
<li><strong>Diversity penalty</strong>: Repeated sequences have their reward divided by their copy count, providing diminishing returns for non-unique outputs.</li>
<li><strong>Wasserstein distance</strong>: The authors also implement a variant (OR(W)GAN) that replaces the standard GAN discriminator loss with the Wasserstein-1 distance via Kantorovich-Rubinstein duality, which can improve training stability and diversity.</li>
</ol>
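<p>The mixed reward and the diversity penalty can be sketched together for one batch of generated sequences. This is an illustration under stated assumptions: the function name and the scores are made up, and dividing the whole mixed reward (as opposed to only the objective term) by the copy count is an assumption of this sketch.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from collections import Counter

def mixed_rewards(sequences, disc_scores, objective_scores, lam=0.5):
    """Per-sequence reward: lam * D + (1 - lam) * O, divided by copy count."""
    copies = Counter(sequences)
    rewards = []
    for seq, d, o in zip(sequences, disc_scores, objective_scores):
        mixed = lam * d + (1.0 - lam) * o
        # Duplicates within the batch share the reward among their copies.
        rewards.append(mixed / copies[seq])
    return rewards

# Hypothetical batch of SMILES strings with made-up scores in [0, 1].
batch = ["CCO", "CCO", "c1ccccc1"]
disc = [0.8, 0.8, 0.6]
obj = [0.4, 0.4, 1.0]
rewards = mixed_rewards(batch, disc, obj, lam=0.5)
</code></pre></div>
<p>Setting <code>lam=1.0</code> leaves only the discriminator score (the SeqGAN case) and <code>lam=0.0</code> leaves only the objective (the naive RL case). In this batch the duplicated string earns half the reward it would have earned as a unique sample.</p>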
<h2 id="molecular-and-musical-melody-generation-experiments">Molecular and Musical Melody Generation Experiments</h2>
<h3 id="architecture">Architecture</h3>
<p>The generator $G_{\theta}$ is an RNN with LSTM cells. The discriminator $D_{\phi}$ is a CNN for text classification following Kim (2014), with 75% dropout and L2 regularization. All optimization uses Adam. Molecular metrics are computed with RDKit.</p>
<h3 id="molecular-generation-setup">Molecular Generation Setup</h3>
<p>Training data consists of 5,000 random molecules from the <a href="/notes/chemistry/datasets/qm9/">QM9</a> dataset (134k stable small molecules with up to 9 heavy atoms), encoded as SMILES strings with maximum sequence length 51 and alphabet size 43. Each generator is pre-trained for 250 MLE epochs, with the discriminator trained for 10 epochs. Adversarial/RL training then proceeds for up to 100 additional epochs. The default $\lambda$ is 0.5.</p>
<p>Three molecular objectives are evaluated:</p>
<ul>
<li><strong>Solubility (LogP)</strong>: water-octanol partition coefficient via RDKit&rsquo;s Crippen function</li>
<li><strong>Synthesizability</strong>: SA score estimating ease of synthesis (0 = hard, 1 = easy)</li>
<li><strong>Druglikeness</strong>: QED score capturing medicinal chemistry aesthetics</li>
</ul>
<p>Diversity is measured using average Jaccard distance of molecular fingerprints relative to a random training subset.</p>
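<p>A small sketch of a Jaccard-based diversity score, with fingerprints represented as sets of on-bit positions. The fingerprint type and the exact averaging scheme are not stated in this summary, so the all-pairs mean used here and the toy bit sets are assumptions.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def jaccard_distance(bits_a, bits_b):
    """1 - |A intersect B| / |A union B| for two sets of on-bit indices."""
    union = bits_a.union(bits_b)
    if not union:
        return 0.0
    return 1.0 - len(bits_a.intersection(bits_b)) / len(union)

def diversity(generated, reference):
    """Mean Jaccard distance from each generated fingerprint to a reference set."""
    pairs = [(g, r) for g in generated for r in reference]
    return sum(jaccard_distance(g, r) for g, r in pairs) / len(pairs)

# Hypothetical fingerprints written as sets of on-bit positions.
generated = [{1, 2, 3}, {7, 8}]
reference = [{1, 2, 3}, {2, 3, 4}]
score = diversity(generated, reference)
</code></pre></div>
<p>A score near 0 means the generated molecules largely replicate the reference fingerprints, while a score near 1 means they share almost no substructure bits with them.</p>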
<h3 id="molecular-generation-results">Molecular Generation Results</h3>
<table>
  <thead>
      <tr>
          <th>Objective</th>
          <th>Algorithm</th>
          <th>Validity (%)</th>
          <th>Diversity</th>
          <th>Druglikeness</th>
          <th>Synthesizability</th>
          <th>Solubility</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>None</td>
          <td>MLE</td>
          <td>75.9</td>
          <td>0.64</td>
          <td>0.48 (0%)</td>
          <td>0.23 (0%)</td>
          <td>0.30 (0%)</td>
      </tr>
      <tr>
          <td>None</td>
          <td>SeqGAN</td>
          <td>80.3</td>
          <td>0.61</td>
          <td>0.49 (+2%)</td>
          <td>0.25 (+6%)</td>
          <td>0.31 (+3%)</td>
      </tr>
      <tr>
          <td>Druglikeness</td>
          <td>ORGAN</td>
          <td>88.2</td>
          <td>0.55</td>
          <td>0.52 (+8%)</td>
          <td>0.32 (+38%)</td>
          <td>0.35 (+18%)</td>
      </tr>
      <tr>
          <td>Druglikeness</td>
          <td>OR(W)GAN</td>
          <td>85.0</td>
          <td>0.95</td>
          <td>0.60 (+25%)</td>
          <td>0.54 (+130%)</td>
          <td>0.47 (+57%)</td>
      </tr>
      <tr>
          <td>Druglikeness</td>
          <td>Naive RL</td>
          <td>97.1</td>
          <td>0.80</td>
          <td>0.57 (+19%)</td>
          <td>0.53 (+126%)</td>
          <td>0.50 (+67%)</td>
      </tr>
      <tr>
          <td>Synthesizability</td>
          <td>ORGAN</td>
          <td>96.5</td>
          <td>0.92</td>
          <td>0.51 (+6%)</td>
          <td>0.83 (+255%)</td>
          <td>0.45 (+52%)</td>
      </tr>
      <tr>
          <td>Synthesizability</td>
          <td>OR(W)GAN</td>
          <td>97.6</td>
          <td>1.00</td>
          <td>0.20 (-59%)</td>
          <td>0.75 (+223%)</td>
          <td>0.84 (+184%)</td>
      </tr>
      <tr>
          <td>Solubility</td>
          <td>ORGAN</td>
          <td>94.7</td>
          <td>0.76</td>
          <td>0.50 (+4%)</td>
          <td>0.63 (+171%)</td>
          <td>0.55 (+85%)</td>
      </tr>
      <tr>
          <td>Solubility</td>
          <td>OR(W)GAN</td>
          <td>94.1</td>
          <td>0.90</td>
          <td>0.42 (-12%)</td>
          <td>0.66 (+185%)</td>
          <td>0.54 (+81%)</td>
      </tr>
      <tr>
          <td>Solubility</td>
          <td>Naive RL</td>
          <td>92.7</td>
          <td>0.75</td>
          <td>0.49 (+3%)</td>
          <td>0.70 (+200%)</td>
          <td>0.78 (+162%)</td>
      </tr>
      <tr>
          <td>All (alternated)</td>
          <td>ORGAN</td>
          <td>96.1</td>
          <td>92.3</td>
          <td>0.52 (+9%)</td>
          <td>0.71 (+206%)</td>
          <td>0.53 (+79%)</td>
      </tr>
  </tbody>
</table>
<p>Key observations: OR(W)GAN consistently achieves higher diversity than standard ORGAN on every single-objective task. Naive RL often achieves higher raw objective scores but at the cost of generating trivial solutions (e.g., simple atom chains for solubility). Multi-objective training via alternating objectives across epochs achieves gains comparable to individually optimized models.</p>
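<p>The parenthesized percentages in the table read as relative changes against the MLE baseline row, which is the row marked 0%. As a worked check using the rounded table entries for the druglikeness-optimized ORGAN model:</p>
<p>$$\frac{0.52 - 0.48}{0.48} \approx 0.083 \approx +8\%$$</p>
<p>Recomputing other cells from the rounded values can differ from the table by a few points (for example, $0.25 / 0.23 - 1 \approx +9\%$ against the listed +6%), presumably because the listed percentages were computed from unrounded scores.</p>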
<h3 id="music-generation-setup">Music Generation Setup</h3>
<p>Training uses 1,000 melodies from the EsAC folk dataset, each encoded as a 36-token sequence in which tokens represent sixteenth-note events across three octaves (C3-B5). Two metrics are optimized: tonality (proportion of perfect fifths) and ratio of steps (conjunct melodic motion). Diversity is measured as average pairwise edit distance.</p>
<h3 id="music-results">Music Results</h3>
<table>
  <thead>
      <tr>
          <th>Objective</th>
          <th>Algorithm</th>
          <th>Diversity</th>
          <th>Tonality</th>
          <th>Ratio of Steps</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>None</td>
          <td>MLE</td>
          <td>0.221</td>
          <td>0.007</td>
          <td>0.010</td>
      </tr>
      <tr>
          <td>None</td>
          <td>SeqGAN</td>
          <td>0.187</td>
          <td>0.005</td>
          <td>0.010</td>
      </tr>
      <tr>
          <td>Tonality</td>
          <td>Naive RL</td>
          <td>0.100</td>
          <td>0.478</td>
          <td>2.9E-05</td>
      </tr>
      <tr>
          <td>Tonality</td>
          <td>ORGAN</td>
          <td>0.268</td>
          <td>0.372</td>
          <td>1.78E-04</td>
      </tr>
      <tr>
          <td>Tonality</td>
          <td>OR(W)GAN</td>
          <td>0.268</td>
          <td>0.177</td>
          <td>2.4E-04</td>
      </tr>
      <tr>
          <td>Ratio of Steps</td>
          <td>Naive RL</td>
          <td>0.321</td>
          <td>0.001</td>
          <td>0.829</td>
      </tr>
      <tr>
          <td>Ratio of Steps</td>
          <td>ORGAN</td>
          <td>0.433</td>
          <td>0.001</td>
          <td>0.632</td>
      </tr>
      <tr>
          <td>Ratio of Steps</td>
          <td>OR(W)GAN</td>
          <td>0.134</td>
          <td>5.95E-05</td>
          <td>0.622</td>
      </tr>
  </tbody>
</table>
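<p>ORGAN outperforms SeqGAN and MLE on diversity and on whichever metric it is trained to optimize, though the non-targeted metric falls below the MLE baseline. Naive RL achieves higher raw scores on the targeted metric but with lower diversity, producing simpler, less interesting outputs.</p>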
<h2 id="capacity-ceilings-trade-offs-and-future-directions">Capacity Ceilings, Trade-offs, and Future Directions</h2>
<p>The authors identify several limitations and findings:</p>
<p><strong>Capacity ceiling</strong>: GAN-based models tend to generate sequences matching the training set&rsquo;s average length (15.42 characters). RL-only approaches can break this constraint, generating shorter (9.4) or longer (21.3) sequences depending on the objective. The upper bound of optimized properties also matches the training data&rsquo;s maximum, suggesting dataset-dependent limits.</p>
<p><strong>Lambda trade-off</strong>: Varying $\lambda$ reveals an optimal balance between objective optimization and distributional fidelity. This optimum depends on the model, dataset, and metric, suggesting that hyperparameter search over $\lambda$ is important in practice.</p>
<p><strong>Tonality vs. steps inverse relationship</strong>: In the music task, optimizing for tonality (perfect fifths) inherently conflicts with optimizing for step ratios (consecutive notes), since consecutive scale notes do not form perfect fifths.</p>
<p><strong>Limitations</strong>: The paper evaluates on relatively small datasets (5k molecules, 1k melodies) and short sequences. The molecular experiments use QM9 (small molecules with up to 9 heavy atoms), which limits the scope of conclusions for drug-like chemical space. The Wasserstein variant sometimes lags behind the standard GAN loss in raw metric scores, though it offers better diversity.</p>
<p><strong>Future directions</strong>: The authors propose extending ORGAN to non-sequential data (images, audio) by framing GANs as RL problems more broadly, and investigating how different heuristic choices affect performance. They also suggest exploring other discrete GAN formulations (MaliGAN, BGAN) with RL extensions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Molecular training</td>
          <td>QM9 subset</td>
          <td>5,000 molecules</td>
          <td>Random subset from 134k stable small molecules with up to 9 heavy atoms</td>
      </tr>
      <tr>
          <td>Music training</td>
          <td>EsAC folk dataset</td>
          <td>1,000 melodies</td>
          <td>36-token sequences, processed following Chen et al. (2017)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Generator pre-trained for 250 epochs via MLE; discriminator for 10 epochs</li>
<li>Adversarial/RL training for up to 100 epochs</li>
<li>Default $\lambda = 0.5$ for reward mixing</li>
<li>Monte Carlo rollouts for intermediate reward estimation</li>
<li>Duplicate penalty: reward divided by copy count</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Generator</strong>: RNN with LSTM cells</li>
<li><strong>Discriminator</strong>: CNN for text classification (Kim, 2014) with 75% dropout, L2 regularization</li>
<li><strong>Optimizer</strong>: Adam for all gradient descent steps</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Domain</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity (%)</td>
          <td>Fraction of generated SMILES that decode to valid molecules</td>
          <td>Molecules</td>
      </tr>
      <tr>
          <td>Diversity</td>
          <td>Average Jaccard distance of fingerprints to training subset</td>
          <td>Molecules</td>
      </tr>
      <tr>
          <td>Druglikeness (QED)</td>
          <td>Quantitative Estimate of Drug-likeness</td>
          <td>Molecules</td>
      </tr>
      <tr>
          <td>Synthesizability (SA)</td>
          <td>Synthetic accessibility score</td>
          <td>Molecules</td>
      </tr>
      <tr>
          <td>Solubility (LogP)</td>
          <td>Water-octanol partition coefficient</td>
          <td>Molecules</td>
      </tr>
      <tr>
          <td>Tonality</td>
          <td>Proportion of perfect fifths</td>
          <td>Music</td>
      </tr>
      <tr>
          <td>Ratio of Steps</td>
          <td>Proportion of conjunct melodic intervals</td>
          <td>Music</td>
      </tr>
      <tr>
          <td>Diversity (edit)</td>
          <td>Average pairwise edit distance</td>
          <td>Music</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gablg1/ORGAN">ORGAN</a></td>
          <td>Code</td>
          <td>GPL-2.0</td>
          <td>Official implementation including metrics for molecules and music</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Guimaraes, G. L., Sánchez-Lengeling, B., Outeiral, C., Farias, P. L. C., &amp; Aspuru-Guzik, A. (2017). Objective-Reinforced Generative Adversarial Networks (ORGAN) for Sequence Generation Models. <em>arXiv preprint arXiv:1705.10843</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{guimaraes2017organ,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Objective-Reinforced Generative Adversarial Networks (ORGAN) for Sequence Generation Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Guimaraes, Gabriel Lima and Sanchez-Lengeling, Benjamin and Outeiral, Carlos and Farias, Pedro Luis Cunha and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1705.10843}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Machine Translation for Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/</guid><description>Nam and Kim apply a GRU-based seq2seq model with attention to predict organic reaction products from SMILES, pioneering the NMT approach to chemistry.</description><content:encoded><![CDATA[<h2 id="pioneering-seq2seq-translation-for-reaction-prediction">Pioneering Seq2Seq Translation for Reaction Prediction</h2>
<p>This is a <strong>Method</strong> paper. It introduces the idea of applying neural machine translation (NMT) to organic chemistry reaction prediction by framing product prediction as a sequence-to-sequence translation problem from reactant/reagent <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> to product SMILES. This was one of the earliest works to demonstrate that a data-driven encoder-decoder model could predict reaction products without any hand-coded reaction rules or SMARTS transformations.</p>
<h2 id="limitations-of-existing-reaction-prediction-methods">Limitations of Existing Reaction Prediction Methods</h2>
<p>Prior computational approaches to reaction prediction fell into three categories, each with significant drawbacks:</p>
<ol>
<li>
<p><strong>Rule-based methods</strong> (e.g., CAMEO, EROS) relied on manually encoded reaction rules. They performed well on reactions covered by the rules but required continuous manual encoding as new reaction types were discovered. Many older systems became outdated for this reason.</p>
</li>
<li>
<p><strong>Physical calculation methods</strong> computed energies of transition states from plausible reaction pathways using quantum mechanics. While principled, these approaches carried high computational cost. Simplified approaches (ToyChem, ROBIA) traded accuracy for speed.</p>
</li>
<li>
<p><strong>Machine learning methods</strong> at the time either predicted individual mechanistic steps (requiring tree search for multi-step reactions) or classified reaction types and applied SMARTS transformations to generate products. The classification-based approach of Wei et al. still required manual encoding of SMARTS transformations for new reaction types and struggled with ambiguous reaction classes.</p>
</li>
</ol>
<p>The key gap was the absence of a method that could predict reaction products directly from input molecules, learn from data alone, and generalize to new reaction types without manual rule encoding.</p>
<h2 id="core-innovation-reactions-as-machine-translation">Core Innovation: Reactions as Machine Translation</h2>
<p>The central insight is that SMILES strings can be treated as a language with grammatical specifications. Predicting reaction products then becomes a problem of translating &ldquo;reactant and reagent&rdquo; sentences into &ldquo;product&rdquo; sentences.</p>
<p>The model uses a <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a>-based encoder-decoder architecture with attention:</p>
<ul>
<li><strong>Encoder</strong>: 3 layers of GRU cells that process the reversed, tokenized SMILES string of reactants and reagents</li>
<li><strong>Decoder</strong>: 3 layers of GRU cells that generate product SMILES tokens autoregressively</li>
<li><strong>Attention mechanism</strong>: allows the decoder to attend to relevant encoder states at each generation step</li>
<li><strong>Embedding dimension</strong>: 600</li>
<li><strong>Vocabulary</strong>: 311 input tokens (reactants/reagents), 180 output tokens (products)</li>
<li><strong>Bucketed sequences</strong>: four bucket sizes handle variable-length inputs and outputs: (54, 54), (70, 60), (90, 65), (150, 80)</li>
</ul>
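<p>The bucketing scheme in the list above can be sketched as follows. This is a minimal illustration, assuming each pair goes to the smallest bucket that fits both its input and output token counts; the inclusive boundary, the helper names <code>pick_bucket</code> and <code>pad</code>, and the <code>PAD</code> token are assumptions, not details confirmed by the paper.</p>

```python
# Bucket sizes quoted in the text: (max input tokens, max output tokens).
BUCKETS = [(54, 54), (70, 60), (90, 65), (150, 80)]


def pick_bucket(n_input, n_output, buckets=BUCKETS):
    """Return the index of the first bucket that fits both lengths, or None."""
    for index, (input_cap, output_cap) in enumerate(buckets):
        if input_cap >= n_input and output_cap >= n_output:
            return index
    return None  # the pair is too long for every bucket


def pad(tokens, length, pad_token="PAD"):
    """Right-pad a token list to the bucket length (pad_token is hypothetical)."""
    return tokens + [pad_token] * (length - len(tokens))
```

<p>Grouping sequences of similar length this way keeps the amount of padding per batch small, which is the usual motivation for bucketed training.</p>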
<p>The SMILES tokenization uses a <a href="https://en.wikipedia.org/wiki/Parsing_expression_grammar">PEG</a>-based parser that splits SMILES strings into atoms, bonds, branching symbols, and ring closure numbers. Input sequences are reversed before feeding to the encoder, following standard practice in NMT at the time.</p>
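<p>As a rough illustration of tokenization followed by input reversal, the sketch below swaps the paper&rsquo;s PEG parser for a simplified regular expression. The pattern covers bracket atoms, the two-letter halogens, and two-digit ring closures only, so it is an assumption-laden stand-in and not the authors&rsquo; grammar.</p>

```python
import re

# Simplified SMILES token pattern (an assumption, not the PEG grammar):
# bracket atoms, two-letter halogens, two-digit ring closures,
# then any single remaining character.
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize(smiles):
    """Split a SMILES string into atom, bond, branch, and ring-closure tokens."""
    return TOKEN_PATTERN.findall(smiles)


def encoder_input(smiles):
    """Tokenize, then reverse the token order before feeding the encoder."""
    return tokenize(smiles)[::-1]
```

<p>For example, <code>tokenize("CC(=O)Cl")</code> keeps <code>Cl</code> as a single token, and <code>encoder_input("CCO")</code> returns the tokens in the order O, C, C.</p>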
<p>The translation objective finds the product sequence $\mathbf{y}$ that maximizes the conditional probability:</p>
<p>$$p(\mathbf{y} \mid \mathbf{x}) = \prod_{t=1}^{T} p(y_t \mid y_1, \ldots, y_{t-1}, \mathbf{x})$$</p>
<p>where $\mathbf{x}$ is the tokenized reactant/reagent sequence and $T$ is the product sequence length.</p>
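<p>Taking logarithms turns the product into a sum over decoding steps, which is the form typically optimized during training and scored during decoding. As a sketch consistent with the factorization above (the symbol $\hat{\mathbf{y}}$ for the predicted product sequence is introduced here for illustration):</p>
<p>$$\hat{\mathbf{y}} = \arg\max_{\mathbf{y}} \sum_{t=1}^{T} \log p(y_t \mid y_1, \ldots, y_{t-1}, \mathbf{x})$$</p>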
<h2 id="training-data-and-experimental-evaluation">Training Data and Experimental Evaluation</h2>
<h3 id="training-sets">Training Sets</h3>
<p>Two training sets were constructed:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Source</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Patent reactions (&ldquo;real&rdquo;)</td>
          <td style="text-align: left">1,094,235</td>
          <td style="text-align: left">USPTO patent applications (2001-2013), filtered by length</td>
      </tr>
      <tr>
          <td style="text-align: left">Generated reactions (&ldquo;gen&rdquo;)</td>
          <td style="text-align: left">865,118</td>
          <td style="text-align: left">75 reaction types from Wade&rsquo;s organic chemistry textbook, applied to <a href="/notes/chemistry/datasets/gdb-11/">GDB-11</a> molecules (1-10 atoms)</td>
      </tr>
  </tbody>
</table>
<p>The &ldquo;real&rdquo; set was filtered to exclude reactions with reactant/reagent strings longer than 150 characters, product strings longer than 80 characters, or more than four products. The &ldquo;gen&rdquo; set was composed by iterating reaction templates (as SMARTS) over small molecules from GDB-11, covering five substrate types: acid derivatives, alcohols, aldehydes/ketones, alkenes, and alkynes.</p>
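<p>The three filters on the &ldquo;real&rdquo; set can be written as a single predicate. This is a sketch under one stated assumption: products are stored as one dot-separated SMILES string, so the product count is the number of dot-separated parts. The function name <code>keep_reaction</code> is hypothetical.</p>

```python
MAX_INPUT_CHARS = 150   # reactant/reagent string limit from the text
MAX_PRODUCT_CHARS = 80  # product string limit from the text
MAX_PRODUCTS = 4        # reactions with more products are dropped


def keep_reaction(reactants_reagents, products):
    """Return True if a reaction passes all three filters."""
    # Assumption: products are written as one dot-separated SMILES string.
    n_products = len(products.split("."))
    return (
        MAX_INPUT_CHARS >= len(reactants_reagents)
        and MAX_PRODUCT_CHARS >= len(products)
        and MAX_PRODUCTS >= n_products
    )
```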
<p>Two models were compared: a &ldquo;gen&rdquo; model (trained only on generated reactions) and a &ldquo;real+gen&rdquo; model (trained on both sets).</p>
<h3 id="textbook-problem-evaluation">Textbook Problem Evaluation</h3>
<p>The models were tested on 10 problem sets from Wade&rsquo;s textbook, following the evaluation approach of Wei et al. Each problem set contained 6-15 reactions. Evaluation metrics included the ratio of fully correct predictions and the average <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> between Morgan fingerprints of predicted and actual products.</p>
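<p>To make the similarity metric concrete, here is a minimal sketch of Tanimoto similarity and its average over a problem set. It assumes each Morgan fingerprint is already available as a set of on-bit indices (computing the fingerprint itself would need a cheminformatics toolkit and is not shown), and it scores an invalid prediction, represented here as <code>None</code>, as zero.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    if not union:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


def average_tanimoto(pairs):
    """Mean similarity over (predicted_bits, actual_bits) pairs.

    A prediction of None stands for an invalid SMILES and scores zero.
    """
    scores = [
        0.0 if predicted is None else tanimoto(predicted, actual)
        for predicted, actual in pairs
    ]
    return sum(scores) / len(scores)
```

<p>Two fingerprints that share two of four distinct bits score 0.5, so the metric gives partial credit for predictions that are structurally close to the true product without matching it exactly.</p>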
<p>The &ldquo;real+gen&rdquo; model outperformed the &ldquo;gen&rdquo; model on most problem sets. On problem set 17-44 (aromatic compound reactions, only present in the &ldquo;real&rdquo; training set), the &ldquo;real+gen&rdquo; model correctly answered 4 out of 11 problems while the &ldquo;gen&rdquo; model answered 2. The &ldquo;gen&rdquo; model&rsquo;s ability to correctly predict some aromatic reactions despite never being trained on them suggests the model can extrapolate to unseen reaction patterns.</p>
<p>For <a href="https://en.wikipedia.org/wiki/Diels%E2%80%93Alder_reaction">Diels-Alder reactions</a> (problem set 15-30), neither model achieved fully correct predictions for all problems, though the &ldquo;real+gen&rdquo; model showed better Tanimoto scores, indicating partially correct structural predictions even when the exact product was missed.</p>
<h3 id="scalability-testing">Scalability Testing</h3>
<p>A scalability test used generated reactions with substrate molecules containing 11-16 atoms (larger than the training set molecules with fewer than 11 atoms). Results showed:</p>
<ul>
<li>The &ldquo;real+gen&rdquo; model maintained Tanimoto scores around 0.7 and error rates around 0.4 as substrate atom count increased</li>
<li>The ratio of fully correct predictions decreased as atom count increased, revealing that the recurrent network struggled with longer input sequences</li>
<li>The &ldquo;real+gen&rdquo; model produced fewer invalid SMILES strings than the &ldquo;gen&rdquo; model, likely because training on more reactions improved the decoder&rsquo;s ability to generate syntactically valid SMILES</li>
</ul>
<h3 id="attention-analysis">Attention Analysis</h3>
<p>Visualization of attention weights revealed a limitation: the decoder cells predominantly attended to the first few encoder cells rather than distributing attention across the full input sequence. This means the attention mechanism was not learning meaningful &ldquo;alignment&rdquo; between reactant atoms and product atoms. The authors note that if decoder cells generating tokens for unreactive sites could attend to the corresponding encoder cells (analogous to atom mapping), prediction quality on longer sequences could improve.</p>
<h3 id="token-embedding-analysis">Token Embedding Analysis</h3>
<p>t-SNE visualization of the learned token embeddings showed that encoder and decoder tokens clustered primarily by syntactic similarity rather than chemical properties. The model did not learn chemically meaningful embeddings, which the authors identify as an area for future improvement.</p>
<h2 id="key-findings-limitations-and-impact">Key Findings, Limitations, and Impact</h2>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li>Treating reaction prediction as NMT is viable: the seq2seq model can predict products without any hand-coded rules</li>
<li>Adding real patent data to the generated training set improves prediction on most textbook problem sets</li>
<li>The model can extrapolate to reaction types not seen during training (e.g., the &ldquo;gen&rdquo; model predicting aromatic reactions)</li>
<li>Compared to the fingerprint-based approach of Wei et al., this method performed better on textbook problems and eliminated the need for manual SMARTS encoding</li>
</ul>
<h3 id="limitations">Limitations</h3>
<ul>
<li><strong>Invalid SMILES generation</strong>: the token-by-token generation process can produce syntactically invalid SMILES (e.g., mismatched parentheses), which the authors scored as zero</li>
<li><strong>Sequence length degradation</strong>: prediction accuracy dropped for longer SMILES strings, a known limitation of RNN-based seq2seq models at the time</li>
<li><strong>Poor attention alignment</strong>: attention weights collapsed to the first encoder positions rather than learning meaningful reactant-product correspondences</li>
<li><strong>Chemically naive embeddings</strong>: token embeddings did not capture chemical properties</li>
<li><strong>Multiple reaction pathways</strong>: reactions with competing pathways (e.g., substitution vs. elimination) were difficult for the model to handle</li>
</ul>
<h3 id="historical-significance">Historical Significance</h3>
<p>This paper is historically significant as one of the first (alongside concurrent work) to propose the NMT framing for reaction prediction. This framing was later adopted and refined by the <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a> (Schwaller et al., 2019), which replaced GRUs with the Transformer architecture and achieved over 90% top-1 accuracy on standard benchmarks. The conceptual contribution of treating SMILES-to-SMILES translation as machine translation became the foundation of an entire subfield.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Training (real)</td>
          <td style="text-align: left">USPTO patent reactions</td>
          <td style="text-align: left">1,094,235</td>
          <td style="text-align: left">2001-2013 applications, filtered by length</td>
      </tr>
      <tr>
          <td style="text-align: left">Training (gen)</td>
          <td style="text-align: left">Generated from Wade textbook templates</td>
          <td style="text-align: left">865,118</td>
          <td style="text-align: left">75 reaction types, GDB-11 substrates</td>
      </tr>
      <tr>
          <td style="text-align: left">Testing (textbook)</td>
          <td style="text-align: left">Wade textbook problems</td>
          <td style="text-align: left">~100</td>
          <td style="text-align: left">10 problem sets, 6-15 reactions each</td>
      </tr>
      <tr>
          <td style="text-align: left">Testing (scalability)</td>
          <td style="text-align: left">Generated from <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a></td>
          <td style="text-align: left">2,400</td>
          <td style="text-align: left">400 per atom count (11-16)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>GRU-based encoder-decoder with attention mechanism</li>
<li>PEG-based SMILES tokenizer</li>
<li>Input sequence reversal</li>
<li>Bucketed training with four bucket sizes</li>
<li>TensorFlow seq2seq tutorial implementation with default learning rate</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Parameter</th>
          <th style="text-align: left">Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">GRU layers</td>
          <td style="text-align: left">3</td>
      </tr>
      <tr>
          <td style="text-align: left">Embedding size</td>
          <td style="text-align: left">600</td>
      </tr>
      <tr>
          <td style="text-align: left">Input vocabulary</td>
          <td style="text-align: left">311 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left">Output vocabulary</td>
          <td style="text-align: left">180 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left">Buckets</td>
          <td style="text-align: left">(54,54), (70,60), (90,65), (150,80)</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">gen Model</th>
          <th style="text-align: left">real+gen Model</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Textbook correct ratio</td>
          <td style="text-align: left">Variable by set</td>
          <td style="text-align: left">Higher on most sets</td>
          <td style="text-align: left">10 problem sets</td>
      </tr>
      <tr>
          <td style="text-align: left">Average Tanimoto similarity</td>
          <td style="text-align: left">Variable</td>
          <td style="text-align: left">~0.7 on scalability test</td>
          <td style="text-align: left">Morgan fingerprint based</td>
      </tr>
      <tr>
          <td style="text-align: left">Invalid SMILES ratio</td>
          <td style="text-align: left">Higher</td>
          <td style="text-align: left">~0.4 on scalability test</td>
          <td style="text-align: left">Decreases with more training data</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Nam, J. &amp; Kim, J. (2016). Linking the Neural Machine Translation and the Prediction of Organic Chemistry Reactions. <em>arXiv preprint</em>, arXiv:1612.09529. <a href="https://arxiv.org/abs/1612.09529">https://arxiv.org/abs/1612.09529</a></p>
<p><strong>Publication</strong>: arXiv preprint 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{nam2016linking,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Linking the Neural Machine Translation and the Prediction of Organic Chemistry Reactions}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Nam, Juno and Kim, Jurae}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1612.09529}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.1612.09529}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MoMu: Bridging Molecular Graphs and Natural Language</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/</guid><description>MoMu bridges molecular graphs and natural language via contrastive pre-training, enabling cross-modal retrieval, captioning, and property prediction.</description><content:encoded><![CDATA[<h2 id="bridging-molecular-graphs-and-natural-language-through-contrastive-learning">Bridging Molecular Graphs and Natural Language Through Contrastive Learning</h2>
<p>MoMu (Molecular Multimodal foundation model) is a <strong>Method</strong> paper that proposes a multimodal pre-training approach to associate molecular graphs with natural language descriptions. The primary contribution is a dual-encoder architecture, consisting of a Graph Isomorphism Network (GIN) for molecular graphs and a BERT-based text encoder, jointly trained through contrastive learning on weakly correlated graph-text pairs collected from scientific literature. The pre-trained model supports four downstream capabilities: cross-modal retrieval (graph-to-text and text-to-graph), molecule captioning, zero-shot text-to-graph molecule generation, and molecular property prediction.</p>
<h2 id="why-single-modality-models-are-insufficient-for-molecular-understanding">Why Single-Modality Models Are Insufficient for Molecular Understanding</h2>
<p>Existing AI models for molecular tasks generally operate on a single modality and learn a single cognitive ability. Language-based models process <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings or natural language texts and handle tasks like property prediction from strings, literature comprehension, or SMILES-based generation. Graph-based models use molecular graph representations and handle graph-level property prediction or graph generation. Neither category connects structural information from molecular graphs with the rich semantic knowledge encoded in scientific texts.</p>
<p>Prior work by Zeng et al. (KV-PLM) jointly modeled molecule-related texts and SMILES strings, but SMILES representations have inherent drawbacks: they are one-dimensional and may lose structural information, they cannot capture structural similarities between molecules, and a single molecule can have multiple valid SMILES representations. Molecular graphs, by contrast, are more intuitive and better reveal functional structures. Human experts learn molecular knowledge by associating both graphical representations and textual descriptions, yet no prior model bridged these two modalities directly.</p>
<p>The key challenge is the scarcity of paired molecular graph-text data compared to general image-text datasets. Additionally, learning specialized molecular knowledge requires foundational cognitive abilities in both the graph and text domains, making training from scratch infeasible with limited data.</p>
<h2 id="contrastive-pre-training-with-inter-modal-and-intra-modal-objectives">Contrastive Pre-Training with Inter-Modal and Intra-Modal Objectives</h2>
<p>MoMu consists of two encoders initialized from pre-trained unimodal models: a GIN graph encoder initialized from GraphCL self-supervised weights, and a BERT text encoder initialized from either Sci-BERT (yielding MoMu-S) or KV-PLM (yielding MoMu-K).</p>
<h3 id="data-collection">Data Collection</h3>
<p>The authors collect 15,613 molecular graph-document pairs by:</p>
<ol>
<li>Gathering names, synonyms, and SMILES for the top 50K compounds in <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a></li>
<li>Converting SMILES to molecular graphs using the OGB <code>smiles2graph</code> function</li>
<li>Retrieving related text from the S2ORC corpus (136M+ papers) by querying with molecule names, filtering to Medicine, Biology, Chemistry, and Computer Science fields</li>
<li>Restricting retrieval to abstract, introduction, and conclusion sections to avoid experimental data artifacts</li>
</ol>
<h3 id="contrastive-training-objective">Contrastive Training Objective</h3>
<p>For each graph-text pair in a mini-batch of $N$ pairs, MoMu applies two graph augmentations (node dropping and subgraph extraction) to create two augmented graphs, and randomly samples two sentences from the document. This produces $2N$ graph representations $\{z_1^G, \tilde{z}_1^G, \ldots, z_N^G, \tilde{z}_N^G\}$ and $2N$ text representations $\{z_1^T, \tilde{z}_1^T, \ldots, z_N^T, \tilde{z}_N^T\}$.</p>
<p>The cross-modal contrastive loss for a pair $(z_i^G, z_i^T)$ is:</p>
<p>$$
\ell_i^{(z_i^G, z_i^T)} = -\log \frac{\exp(\text{sim}(z_i^G, z_i^T) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(z_i^G, z_j^T) / \tau)}
$$</p>
<p>where $\tau$ is the temperature parameter and $\text{sim}(\cdot, \cdot)$ projects both representations into a shared 256-dimensional space before computing cosine similarity. The total cross-modal loss includes four contrastive terms for each pair: $(z_i^G, z_i^T)$, $(\tilde{z}_i^G, z_i^T)$, $(z_i^G, \tilde{z}_i^T)$, and $(\tilde{z}_i^G, \tilde{z}_i^T)$.</p>
<p>An intra-modal graph contrastive loss further strengthens the graph encoder:</p>
<p>$$
\ell_i^{(z_i^G, \tilde{z}_i^G)} = -\log \frac{\exp(\text{sim}(z_i^G, \tilde{z}_i^G) / \tau)}{\sum_{j=1}^{N} \exp(\text{sim}(z_i^G, \tilde{z}_j^G) / \tau)}
$$</p>
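<p>Both losses above share the same InfoNCE form, which the following pure-Python sketch illustrates. It assumes the vectors have already been projected into the shared space, averages the per-pair loss over the batch, and sums the four cross-modal terms and the intra-modal term with equal weight; the weighting, the batch averaging, and the function names are assumptions, and the default temperature of 0.1 is the value reported in the reproducibility notes.</p>

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchors, candidates, tau=0.1):
    """Mean over i of -log softmax_j(sim(anchor_i, candidate_j) / tau) at j = i."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [cosine(anchor, candidate) / tau for candidate in candidates]
        peak = max(logits)  # subtract the max for numerical stability
        log_denominator = peak + math.log(sum(math.exp(x - peak) for x in logits))
        total += log_denominator - logits[i]
    return total / len(anchors)


def momu_style_loss(g, g_aug, t, t_aug, tau=0.1):
    """Four cross-modal terms plus one intra-modal graph term, summed unweighted."""
    cross = sum(info_nce(a, b, tau) for a in (g, g_aug) for b in (t, t_aug))
    intra = info_nce(g, g_aug, tau)
    return cross + intra
```

<p>When every graph vector points in the same direction as its paired text vector and away from the others, each term approaches zero; mismatched pairs push the loss up by roughly the similarity gap divided by the temperature.</p>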
<h3 id="zero-shot-text-to-graph-generation">Zero-Shot Text-to-Graph Generation</h3>
<p>MoMu enables a zero-shot generation pipeline by combining the pre-trained MoMu encoders with MoFlow, a flow-based molecular generator. Given an input text description $x^T$, the method:</p>
<ol>
<li>Samples a latent variable $q$ from MoFlow&rsquo;s Gaussian prior $P(q)$</li>
<li>Generates a molecular graph through MoFlow&rsquo;s reverse flows: $\hat{E} = f_g^{-1}(q_e)$ and $\hat{V} = f_c^{-1}(q_v \mid GN(\hat{E}))$</li>
<li>Feeds $\hat{V}$ (using soft atom type probabilities instead of hard assignments) into MoMu&rsquo;s graph encoder</li>
<li>Optimizes $q$ to maximize the cosine similarity between the resulting graph and text representations:</li>
</ol>
<p>$$
\ell_q = -\text{sim}(z^G, z^T) / \tau
$$</p>
<p>All MoMu and MoFlow parameters are frozen; only $q$ is updated via Adam for up to 500 iterations. The final molecule is obtained by applying argmax to the optimized probability matrices $\hat{V}$ and $\hat{E}$.</p>
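<p>The latent optimization loop can be illustrated with a toy stand-in. In the sketch below, <code>toy_graph_embedding</code> is a hypothetical frozen function that replaces the whole MoFlow-plus-graph-encoder pipeline, gradients are estimated by finite differences instead of backpropagation, and the learning rate is an arbitrary choice. Only the structure carries over from the text: a frozen model, Adam updates on $q$ alone, and a loss equal to the negative cosine similarity divided by the temperature.</p>

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def toy_graph_embedding(q):
    """Hypothetical frozen stand-in for MoFlow decoding plus the graph encoder."""
    return [math.tanh(q[0] + 0.5 * q[1]), math.tanh(q[1] - 0.3 * q[0])]


def loss(q, z_text, tau=0.1):
    """Negative cosine similarity between graph and text embeddings, over tau."""
    return -cosine(toy_graph_embedding(q), z_text) / tau


def numeric_grad(f, q, eps=1e-5):
    """Central-difference gradient (used here in place of backpropagation)."""
    grad = []
    for k in range(len(q)):
        up = q[:k] + [q[k] + eps] + q[k + 1:]
        down = q[:k] + [q[k] - eps] + q[k + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def optimize_latent(q, z_text, steps=500, lr=0.05, b1=0.9, b2=0.999, eps=1e-8):
    """Run Adam on the latent q only; the embedding function stays fixed."""
    q = list(q)
    m = [0.0] * len(q)
    v = [0.0] * len(q)
    for step in range(1, steps + 1):
        g = numeric_grad(lambda x: loss(x, z_text), q)
        m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
        v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1 - b1 ** step) for mi in m]
        v_hat = [vi / (1 - b2 ** step) for vi in v]
        q = [qi - lr * mh / (math.sqrt(vh) + eps) for qi, mh, vh in zip(q, m_hat, v_hat)]
    return q
```

<p>Calling <code>optimize_latent([0.5, -0.4], [1.0, 1.0])</code> moves the latent toward a point whose toy embedding aligns with the text vector. In the real pipeline the final step would decode the optimized latent and apply argmax, which this toy omits.</p>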
<h2 id="evaluation-across-four-downstream-tasks">Evaluation Across Four Downstream Tasks</h2>
<h3 id="cross-modal-retrieval">Cross-Modal Retrieval</h3>
<p>MoMu is evaluated on the PCdes dataset (15K SMILES-description pairs from PubChem, split 10,500/1,500/3,000 for train/val/test). Retrieval is performed in mini-batches of 64 pairs, reporting top-1 accuracy and Recall@20.</p>
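<p>For a single mini-batch, the two reported metrics can be computed from a similarity matrix as sketched below. The sketch assumes the correct candidate for query $i$ sits at index $i$ and that batch-level scores are later averaged over the test set; both are assumptions about the evaluation protocol, and <code>retrieval_metrics</code> is a hypothetical helper name.</p>

```python
def retrieval_metrics(similarity, k=20):
    """Top-1 accuracy and Recall@k for one batch.

    similarity[i][j] scores query i against candidate j; the correct
    candidate for query i is assumed to sit at index i.
    """
    hits_at_1 = 0
    hits_at_k = 0
    for i, row in enumerate(similarity):
        # Candidate indices sorted from most to least similar.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        hits_at_1 += ranked[0] == i
        hits_at_k += i in ranked[:k]
    n = len(similarity)
    return hits_at_1 / n, hits_at_k / n
```

<p>With batches of 64 candidates, Recall@20 counts a query as correct whenever its true match ranks among the 20 most similar candidates in the batch.</p>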
<p><strong>Graph-to-Text Retrieval (PCdes, fine-tuned)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Sentence Acc</th>
          <th>Sentence R@20</th>
          <th>Paragraph Acc</th>
          <th>Paragraph R@20</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sci-BERT</td>
          <td>50.38</td>
          <td>62.11</td>
          <td>62.57</td>
          <td>60.67</td>
      </tr>
      <tr>
          <td>KV-PLM</td>
          <td>53.79</td>
          <td>66.63</td>
          <td>64.81</td>
          <td>63.87</td>
      </tr>
      <tr>
          <td>KV-PLM*</td>
          <td>55.92</td>
          <td>68.59</td>
          <td>77.92</td>
          <td>75.93</td>
      </tr>
      <tr>
          <td>MoMu-S</td>
          <td>58.64</td>
          <td>80.59</td>
          <td>80.62</td>
          <td>79.11</td>
      </tr>
      <tr>
          <td>MoMu-K</td>
          <td>58.74</td>
          <td>81.29</td>
          <td>81.09</td>
          <td>80.15</td>
      </tr>
  </tbody>
</table>
<p><strong>Text-to-Graph Retrieval (PCdes, fine-tuned)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Sentence Acc</th>
          <th>Sentence R@20</th>
          <th>Paragraph Acc</th>
          <th>Paragraph R@20</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sci-BERT</td>
          <td>50.12</td>
          <td>68.02</td>
          <td>61.75</td>
          <td>60.77</td>
      </tr>
      <tr>
          <td>KV-PLM</td>
          <td>54.22</td>
          <td>71.80</td>
          <td>64.95</td>
          <td>64.27</td>
      </tr>
      <tr>
          <td>KV-PLM*</td>
          <td>55.61</td>
          <td>74.77</td>
          <td>77.03</td>
          <td>75.47</td>
      </tr>
      <tr>
          <td>MoMu-S</td>
          <td>55.44</td>
          <td>76.92</td>
          <td>80.22</td>
          <td>79.02</td>
      </tr>
      <tr>
          <td>MoMu-K</td>
          <td>54.94</td>
          <td>78.29</td>
          <td>81.45</td>
          <td>80.62</td>
      </tr>
  </tbody>
</table>
<p>In zero-shot retrieval (on a separate test set of 5,562 pairs not seen during pre-training), MoMu achieves approximately 39-46% accuracy compared to below 2% for Sci-BERT and KV-PLM, demonstrating strong generalization.</p>
<h3 id="molecule-captioning">Molecule Captioning</h3>
<p>For molecule captioning on the ChEBI-20 dataset, MoMu&rsquo;s graph features are appended to MolT5&rsquo;s encoder inputs through a learned MLP mapping module. Results show improvements in BLEU, METEOR, and Text2Mol scores when incorporating graph features, though ROUGE-L drops slightly. The graph structural information leads to more accurate captions for complex molecular structures.</p>
<h3 id="molecular-property-prediction">Molecular Property Prediction</h3>
<p>The pre-trained graph encoder from MoMu is fine-tuned on eight <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> datasets using scaffold splitting and ROC-AUC evaluation (10 runs).</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>No Pre-Train</th>
          <th>GraphCL</th>
          <th>MoMu-S</th>
          <th>MoMu-K</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>65.8</td>
          <td>69.7</td>
          <td><strong>70.5</strong></td>
          <td>70.1</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>74.0</td>
          <td>73.9</td>
          <td>75.6</td>
          <td>75.6</td>
      </tr>
      <tr>
          <td>ToxCast</td>
          <td>63.4</td>
          <td>62.4</td>
          <td>63.4</td>
          <td>63.0</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>57.3</td>
          <td>60.5</td>
          <td>60.5</td>
          <td>60.4</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>58.0</td>
          <td>76.0</td>
          <td><strong>79.9</strong></td>
          <td>77.4</td>
      </tr>
      <tr>
          <td>MUV</td>
          <td>71.8</td>
          <td>69.8</td>
          <td>70.5</td>
          <td>71.1</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>75.3</td>
          <td><strong>78.5</strong></td>
          <td>75.9</td>
          <td>76.2</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>70.1</td>
          <td>75.4</td>
          <td>76.7</td>
          <td>77.1</td>
      </tr>
      <tr>
          <td><strong>Average</strong></td>
          <td>66.96</td>
          <td>70.78</td>
          <td><strong>71.63</strong></td>
          <td>71.36</td>
      </tr>
  </tbody>
</table>
<p>MoMu-S achieves the best average ROC-AUC (71.63%) across all eight datasets, outperforming GraphCL (70.78%), the self-supervised method used to initialize MoMu&rsquo;s graph encoder. MoMu outperforms GraphCL on six of eight datasets. Notably, MoMu-S and MoMu-K perform comparably, indicating that KV-PLM&rsquo;s SMILES-based knowledge does not transfer well to graph-based representations.</p>
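<p>The averages and the six-of-eight count can be rechecked directly from the table. The short script below copies the per-dataset values as listed above; the helper names are illustrative.</p>

```python
# Per-dataset ROC-AUC values copied from the table above, in the order
# BBBP, Tox21, ToxCast, SIDER, ClinTox, MUV, HIV, BACE.
SCORES = {
    "No Pre-Train": [65.8, 74.0, 63.4, 57.3, 58.0, 71.8, 75.3, 70.1],
    "GraphCL": [69.7, 73.9, 62.4, 60.5, 76.0, 69.8, 78.5, 75.4],
    "MoMu-S": [70.5, 75.6, 63.4, 60.5, 79.9, 70.5, 75.9, 76.7],
    "MoMu-K": [70.1, 75.6, 63.0, 60.4, 77.4, 71.1, 76.2, 77.1],
}


def mean(values):
    """Arithmetic mean of a list of scores."""
    return sum(values) / len(values)


def wins_over(method, baseline):
    """Count datasets where method scores strictly higher than baseline."""
    return sum(a > b for a, b in zip(SCORES[method], SCORES[baseline]))
```

<p>The means come out at about 71.625 for MoMu-S and 70.775 for GraphCL, matching the rounded average row. Both MoMu variants score strictly higher than GraphCL on six datasets: MoMu-S ties on SIDER and trails on HIV, while MoMu-K trails on SIDER and HIV.</p>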
<h3 id="zero-shot-text-to-graph-generation-1">Zero-Shot Text-to-Graph Generation</h3>
<p>The method generates molecules from three types of text descriptions:</p>
<ol>
<li><strong>High-level vague descriptions</strong> (e.g., &ldquo;The molecule is beautiful&rdquo;): MoMu generates diverse, interpretable molecules where &ldquo;beautiful&rdquo; tends to produce locally symmetric and stretched graphs, &ldquo;versatile&rdquo; produces molecules with varied elements and functional groups, and &ldquo;strange&rdquo; produces cluttered, irregular structures.</li>
<li><strong>Functional descriptions</strong> (e.g., &ldquo;fluorescent molecules&rdquo;, &ldquo;high water solubility and barrier permeability with low toxicity&rdquo;): MoMu successfully generates molecules with appropriate functional groups and properties. For the solubility/permeability/toxicity query, MoMu generates molecules that satisfy all three evaluable properties.</li>
<li><strong>Structural descriptions</strong> (e.g., &ldquo;molecules containing <a href="https://en.wikipedia.org/wiki/Nucleophile">nucleophilic</a> groups&rdquo;): MoMu generates diverse molecules with appropriate functional groups (amino, hydroxyl, carbonyl, halogen atoms).</li>
</ol>
<h2 id="promising-multimodal-transfer-with-clear-data-limitations">Promising Multimodal Transfer with Clear Data Limitations</h2>
<p>MoMu demonstrates that contrastive pre-training on weakly correlated graph-text data can bridge molecular graphs and natural language in a shared representation space. The key findings are:</p>
<ol>
<li><strong>Cross-modal alignment works with limited data</strong>: With only 15K graph-text pairs (far fewer than the roughly 400 million image-text pairs used to train CLIP), MoMu achieves meaningful cross-modal retrieval and enables zero-shot generation.</li>
<li><strong>Multimodal supervision improves graph representations</strong>: The graph encoder supervised by text descriptions outperforms self-supervised methods (GraphCL, AttrMasking, ContextPred) on average across molecular property prediction benchmarks.</li>
<li><strong>SMILES knowledge does not transfer to graphs</strong>: MoMu-S and MoMu-K perform comparably across all tasks, suggesting that structural information learned from one-dimensional SMILES strings does not readily generalize to graph neural networks.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several important limitations:</p>
<ul>
<li><strong>Data scarcity</strong>: 15K graph-text pairs is substantially smaller than general image-text datasets, potentially leaving the common space insufficiently aligned.</li>
<li><strong>Noisy supervision</strong>: Retrieved texts may mention a molecule by name without describing its properties or structure, leading to spurious correlations.</li>
<li><strong>Generator constraints</strong>: The zero-shot generation method is limited by MoFlow&rsquo;s capacity (maximum 38 atoms, 9 element types from ZINC250K training).</li>
<li><strong>Property coverage</strong>: Generation quality degrades for molecular properties that appear infrequently or not at all in the training texts.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose four avenues: (1) collecting larger-scale multimodal molecular data including 3D conformations, (2) using strongly correlated paired data with more advanced generators, (3) developing interpretable tools for the learned cross-modal space, and (4) wet-lab validation of generated molecules.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>Collected graph-text pairs (PubChem + S2ORC)</td>
          <td>15,613 pairs</td>
          <td>~37M paragraphs total; top 50K PubChem compounds</td>
      </tr>
      <tr>
          <td>Cross-modal retrieval</td>
          <td>PCdes</td>
          <td>15K pairs (10.5K/1.5K/3K split)</td>
          <td>SMILES-description pairs from PubChem</td>
      </tr>
      <tr>
          <td>Molecule captioning</td>
          <td>ChEBI-20</td>
          <td>~33K pairs</td>
          <td>Used with MolT5</td>
      </tr>
      <tr>
          <td>Text-to-graph generation</td>
          <td><a href="/notes/chemistry/datasets/zinc-22/">ZINC250K</a> (MoFlow)</td>
          <td>250K molecules</td>
          <td>Pre-trained generator, max 38 atoms</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet (8 datasets)</td>
          <td>Varies</td>
          <td>BBBP, Tox21, ToxCast, SIDER, ClinTox, MUV, HIV, BACE</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Graph augmentations</strong>: Node dropping (10% ratio) and subgraph extraction (80% of original size via random walk)</li>
<li><strong>Contrastive learning</strong>: InfoNCE loss with temperature $\tau = 0.1$, following the DeClip paradigm with both inter-modal and intra-modal objectives</li>
<li><strong>Zero-shot generation</strong>: Adam optimizer on latent variable $q$ for up to 500 iterations; formal charges prohibited in output</li>
</ul>
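<p>The node-dropping augmentation listed above can be sketched in a few lines. This is a minimal illustration and not the authors&rsquo; implementation: the edge-list graph representation and the function name <code>drop_nodes</code> are assumptions, and only the 10% ratio comes from the settings above.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import random


def drop_nodes(num_nodes, edges, ratio=0.1, seed=None):
    """Randomly remove a fraction of nodes and every edge touching them.

    num_nodes: number of atoms, indexed from 0 (assumed representation)
    edges: list of (u, v) index pairs
    ratio: fraction of nodes to drop (10% in the listed settings)
    """
    rng = random.Random(seed)
    num_drop = int(num_nodes * ratio)
    dropped = set(rng.sample(range(num_nodes), num_drop))
    kept_nodes = [n for n in range(num_nodes) if n not in dropped]
    kept_edges = [(u, v) for (u, v) in edges if u not in dropped and v not in dropped]
    return kept_nodes, kept_edges


# A 20-atom chain: dropping 10% removes exactly 2 atoms.
chain_edges = [(i, i + 1) for i in range(19)]
nodes, edges = drop_nodes(20, chain_edges, ratio=0.1, seed=0)
print(len(nodes), len(edges))
</code></pre>
<p>How many edges survive depends on which atoms are sampled, so only the node count is fixed by the ratio.</p>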
<h3 id="models">Models</h3>
<ul>
<li><strong>Graph encoder</strong>: GIN with 5 layers, 300-dimensional hidden size, initialized from GraphCL checkpoint</li>
<li><strong>Text encoder</strong>: BERT-base (768 hidden size), initialized from Sci-BERT or KV-PLM</li>
<li><strong>Projection heads</strong>: Two MLPs projecting graph (300-dim) and text (768-dim) features to 256-dimensional shared space</li>
<li><strong>Optimizer</strong>: AdamW, learning rate 0.0001, weight decay 1e-5, 300 epochs, batch size 256</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>Best Result</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>G-T Retrieval (PCdes)</td>
          <td>Accuracy / R@20</td>
          <td>81.09 / 80.15 (paragraph)</td>
          <td>MoMu-K, fine-tuned</td>
      </tr>
      <tr>
          <td>T-G Retrieval (PCdes)</td>
          <td>Accuracy / R@20</td>
          <td>81.45 / 80.62 (paragraph)</td>
          <td>MoMu-K, fine-tuned</td>
      </tr>
      <tr>
          <td>Zero-shot G-T Retrieval</td>
          <td>Accuracy</td>
          <td>~46%</td>
          <td>vs. ~1.4% for baselines</td>
      </tr>
      <tr>
          <td>Property Prediction</td>
          <td>ROC-AUC (avg)</td>
          <td>71.63%</td>
          <td>MoMu-S, 8 MoleculeNet datasets</td>
      </tr>
      <tr>
          <td>Molecule Captioning</td>
          <td>Text2Mol</td>
          <td>Improved over MolT5</td>
          <td>MoMu + MolT5-large</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 8x NVIDIA Tesla V100 PCIe 32GB GPUs</li>
<li>Framework: PyTorch</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BingSu12/MoMu">MoMu code</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Pre-training and downstream task code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/yangzhao1230/GraphTextRetrieval">GraphTextRetrieval</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Data collection and cross-modal retrieval code</td>
      </tr>
      <tr>
          <td><a href="https://pan.baidu.com/s/1aHJoYTTZWDHPCcRuu9I7Fg">Pre-training dataset</a></td>
          <td>Dataset</td>
          <td>Not specified</td>
          <td>Hosted on Baidu Pan (Chinese cloud storage)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Su, B., Du, D., Yang, Z., Zhou, Y., Li, J., Rao, A., Sun, H., Lu, Z., &amp; Wen, J.-R. (2022). A Molecular Multimodal Foundation Model Associating Molecule Graphs with Natural Language. arXiv preprint arXiv:2209.05481.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{su2022momu,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Molecular Multimodal Foundation Model Associating Molecule Graphs with Natural Language}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Su, Bing and Du, Dazhao and Yang, Zhao and Zhou, Yujie and Li, Jiangmeng and Rao, Anyi and Sun, Hao and Lu, Zhiwu and Wen, Ji-Rong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2209.05481}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolFM: Trimodal Molecular Foundation Pre-training</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/</guid><description>MolFM fuses molecular graphs, biomedical text, and knowledge graphs via cross-modal attention for joint molecular representation learning.</description><content:encoded><![CDATA[<h2 id="trimodal-pre-training-for-molecular-understanding">Trimodal Pre-training for Molecular Understanding</h2>
<p>MolFM is a <strong>Method</strong> paper that introduces a multimodal molecular foundation model integrating three distinct sources of molecular knowledge: 2D molecular graphs, biomedical text, and knowledge graphs. The primary contribution is a pre-training framework that uses fine-grained cross-modal attention to fuse information across all three modalities, combined with theoretical justification from a deep metric learning perspective. MolFM achieves the best reported results (at time of publication) on cross-modal retrieval, molecule captioning, text-based molecule generation, and molecular property prediction.</p>
<h2 id="why-existing-molecular-models-fall-short">Why Existing Molecular Models Fall Short</h2>
<p>Prior multimodal molecular foundation models operate on at most two modalities (structures and text) and suffer from two key limitations. First, their handling of molecular structure is limited. Generative approaches like KV-PLM and MolT5 rely on 1D <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, which cannot capture complex topological and spatial molecular properties such as macrocycles, while contrastive approaches like <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a> and MoleculeSTM learn global alignment between molecule graphs and text but overlook fine-grained connections between specific substructures and textual descriptions.</p>
<p>Second, and more fundamentally, no prior model incorporates <a href="https://en.wikipedia.org/wiki/Knowledge_graph">knowledge graphs</a> as a third modality. Knowledge graphs encode global-level relationships among molecules, target ligands, diseases, and other biomedical entities. These relationships capture functional and structural similarity patterns that cannot be learned from individual molecule-text pairs alone. MolFM addresses both gaps by introducing cross-modal attention across all three modalities and providing theoretical guarantees about what the pre-training objectives learn.</p>
<h2 id="cross-modal-attention-and-metric-learning-guarantees">Cross-Modal Attention and Metric Learning Guarantees</h2>
<h3 id="architecture">Architecture</h3>
<p>MolFM uses three pre-trained single-modal encoders:</p>
<ul>
<li><strong>Molecular graph encoder</strong>: A 5-layer GIN (1.8M parameters) initialized from GraphMVP, producing atom-level features $h_{SA}$ and a graph-level feature $h_{SM}$</li>
<li><strong>Text encoder</strong>: A 6-layer transformer (61.8M parameters) initialized from KV-PLM&rsquo;s first 6 layers, producing token features $h_T$</li>
<li><strong>Knowledge graph encoder</strong>: A TransE model (12.6M parameters) trained on the knowledge graph for 500 epochs, producing entity features $h_K$</li>
</ul>
<p>A multimodal encoder (61.8M parameters, 6 transformer layers with cross-attention) fuses the three modalities. The cross-attention uses text token features as queries and the concatenation of atom features and knowledge graph neighbor features as keys and values. For each molecule, the knowledge graph input is the molecule&rsquo;s entity and $N=4$ randomly sampled one-hop neighbors.</p>
<h3 id="pre-training-objectives">Pre-training Objectives</h3>
<p>MolFM combines four losses:</p>
<p><strong>Structure-text contrastive (STC)</strong> aligns the global feature spaces of structure and text encoders using a symmetric InfoNCE loss:</p>
<p>$$\mathcal{L}_{stc} = -\frac{1}{2} \left[ \log \frac{\exp(s(z_S, z_T) / \tau)}{\sum_{S' \in B} \exp(s(z_{S'}, z_T) / \tau)} + \log \frac{\exp(s(z_S, z_T) / \tau)}{\sum_{T' \in B} \exp(s(z_S, z_{T'}) / \tau)} \right]$$</p>
<p>where $s(\cdot, \cdot)$ is cosine similarity and $\tau = 0.1$ is a temperature parameter.</p>
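<p>As a rough illustration of the symmetric InfoNCE form above, the following sketch computes the loss for a small batch of already-projected embeddings. It uses plain Python lists instead of a tensor library, and the names <code>stc_loss</code> and <code>log_softmax_at</code> are made up for this example; only the cosine similarity and $\tau = 0.1$ are taken from the text.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math


def cosine(a, b):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def log_softmax_at(scores, index):
    """Log of the softmax probability at `index`, computed stably."""
    top = max(scores)
    log_sum = top + math.log(sum(math.exp(s - top) for s in scores))
    return scores[index] - log_sum


def stc_loss(z_struct, z_text, tau=0.1):
    """Symmetric InfoNCE over a batch; entry i of each list is a matched pair."""
    n = len(z_struct)
    sim = [[cosine(z_struct[i], z_text[j]) / tau for j in range(n)] for i in range(n)]
    total = 0.0
    for i in range(n):
        column = [sim[k][i] for k in range(n)]  # every structure against text i
        row = sim[i]                            # structure i against every text
        total += -0.5 * (log_softmax_at(column, i) + log_softmax_at(row, i))
    return total / n


aligned = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
print(stc_loss(aligned, aligned))  # near 0: each matched pair is the most similar
print(stc_loss(aligned, swapped))  # near 10: each matched pair is the least similar
</code></pre>
<p>With $\tau = 0.1$ the similarity gap between a match and a mismatch is scaled by a factor of ten, which is why the mismatched batch lands near 10.</p>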
<p><strong>Cross-modal matching (CMM)</strong> predicts whether a structure-text-knowledge triplet corresponds to the same molecule, using cross-entropy over the multimodal encoder&rsquo;s CLS token:</p>
<p>$$\mathcal{L}_{cmm} = \sum_{(\tilde{S}, \tilde{T}, \tilde{K}) \in \tilde{B}} H\left[y_{cmm}(\tilde{S}, \tilde{T}, \tilde{K}),\; p_{cmm}\left(\mathcal{M}_\theta(h_{\tilde{S}}, h_{\tilde{T}}, h_{\tilde{K}})\right)\right]$$</p>
<p><strong>Masked language modeling (MLM)</strong> predicts masked text tokens conditioned on all three modalities:</p>
<p>$$\mathcal{L}_{mlm} = H\left[y_{mlm}(\hat{T}),\; p_{mlm}\left(\mathcal{M}_\theta(h_S, h_{\hat{T}}, h_K)\right)\right]$$</p>
<p><strong>Knowledge graph embedding (KGE)</strong> regularizes entity embeddings with a max-margin TransE loss:</p>
<p>$$\mathcal{L}_{kge} = \sum_{h \in K} \left[\max(0, d(h,r,t) - d(h,r,\tilde{t}) + \Delta) + \max(0, d(h,r,t) - d(\tilde{h},r,t) + \Delta)\right]$$</p>
<p>where $d(h,r,t) = \lVert f(h) + g(r) - f(t) \rVert_2$ and $\Delta = 0.2$.</p>
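<p>The max-margin term can be made concrete with a small sketch for a single triplet. The function names and the toy two-dimensional embeddings are assumptions for illustration; the distance and the margin $\Delta = 0.2$ follow the definitions above.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math


def transe_distance(head, relation, tail):
    """d(h, r, t): L2 norm of f(h) + g(r) - f(t)."""
    return math.sqrt(sum((h + r - t) ** 2 for h, r, t in zip(head, relation, tail)))


def kge_loss(head, relation, tail, corrupt_tail, corrupt_head, margin=0.2):
    """Max-margin loss for one triplet with one corrupted tail and one corrupted head."""
    positive = transe_distance(head, relation, tail)
    tail_term = max(0.0, positive - transe_distance(head, relation, corrupt_tail) + margin)
    head_term = max(0.0, positive - transe_distance(corrupt_head, relation, tail) + margin)
    return tail_term + head_term


head, relation, tail = [0.0, 0.0], [1.0, 0.0], [1.0, 0.0]
# Negatives far from the true triplet incur no loss.
easy = kge_loss(head, relation, tail, corrupt_tail=[5.0, 5.0], corrupt_head=[4.0, 0.0])
# Negatives within the margin of the true triplet are penalized.
hard = kge_loss(head, relation, tail, corrupt_tail=[1.0, 0.1], corrupt_head=[0.0, 0.1])
print(easy, hard)
</code></pre>
<p>Only negatives that sit within the margin of the positive triplet contribute, so the loss concentrates on hard corruptions.</p>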
<p>The total pre-training loss is:</p>
<p>$$\mathcal{L} = \mathbb{E}_{(S,T,K)}\left[\mathcal{L}_{stc} + \mathcal{L}_{cmm} + \mathcal{L}_{mlm} + \mathcal{L}_{kge}\right]$$</p>
<h3 id="theoretical-justifications">Theoretical Justifications</h3>
<p>The authors provide metric learning interpretations for each objective. For CMM, they show that minimizing the loss amounts to assigning higher scores to matched triplets and lower scores to unmatched ones, which aligns the feature space across all three modalities.</p>
<p>For KGE, two lemmas provide guarantees about structurally and functionally similar molecules:</p>
<p><strong>Lemma 1</strong> (Structural similarity): For a symmetric structural-similarity relation $r_s$, the KGE loss satisfies:</p>
<p>$$\mathcal{L}_{kge}(h, r_s, t) \propto 2\lVert f(h) - f(t)\rVert - \mathbb{E}_{\tilde{t}}\lVert f(h) - f(\tilde{t})\rVert - \mathbb{E}_{\tilde{h}}\lVert f(\tilde{h}) - f(t)\rVert$$</p>
<p>This shows KGE pulls structurally similar molecules closer while pushing dissimilar ones apart.</p>
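<p>One way to read this proportionality, sketched under two assumptions that are not spelled out above: both hinge terms are active, and the embedding of a symmetric relation is close to zero, since $f(h) + g(r_s) \approx f(t)$ and $f(t) + g(r_s) \approx f(h)$ can only hold together if $g(r_s) \approx 0$. Under those assumptions $d(h, r_s, t) \approx \lVert f(h) - f(t) \rVert$, and the per-triplet loss for one sampled pair of negatives becomes:</p>
<p>$$\mathcal{L}_{kge}(h, r_s, t) \approx 2\lVert f(h) - f(t)\rVert - \lVert f(h) - f(\tilde{t})\rVert - \lVert f(\tilde{h}) - f(t)\rVert + 2\Delta$$</p>
<p>Averaging over the sampled negatives $\tilde{t}$ and $\tilde{h}$ and dropping the constant $2\Delta$ recovers the expression in the lemma.</p>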
<p><strong>Lemma 2</strong> (Functional similarity): For molecules $h$ and $t$ that interact with a common entity $o$, the distance between their embeddings is upper-bounded:</p>
<p>$$\lVert f(h) - f(t)\rVert \leq \alpha\,\mathbb{E}_{(e_1, r, e_2) \sim \mathcal{I}}\left[\mathcal{L}_{kge}(e_1, r, e_2)\right] + C$$</p>
<p>where $\alpha \approx 1$ and $C \approx 0$. This guarantees that minimizing KGE also brings functionally similar molecules closer in the embedding space.</p>
<h2 id="experiments-across-four-downstream-tasks">Experiments Across Four Downstream Tasks</h2>
<h3 id="pre-training-data">Pre-training Data</h3>
<p>MolFM pre-trains on 15K molecules from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> paired with 37M paragraphs from S2ORC. The knowledge graph contains 49K entities and 3.2M relations, constructed from <a href="https://en.wikipedia.org/wiki/DrugBank">DrugBank</a>, <a href="https://en.wikipedia.org/wiki/BindingDB">BindingDB</a>, and additional public databases with heuristic augmentation.</p>
<h3 id="cross-modal-retrieval">Cross-Modal Retrieval</h3>
<p>Evaluated on PCdes (paragraph-level) in zero-shot and fine-tuning settings. MolFM uses a re-ranking strategy that linearly combines cosine similarity with CMM logits over the top-$k$ retrieved candidates.</p>
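<p>The re-ranking step can be sketched as follows. This is an illustrative version only: the mixing weight, the value of $k$, and the tuple layout are hypothetical, and a real pipeline would presumably compute CMM logits only for the top-$k$ candidates rather than receive them for all.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">def rerank(candidates, k=3, weight=0.5):
    """Re-rank the top-k candidates by a linear mix of cosine similarity and CMM logit.

    candidates: list of (name, cosine_similarity, cmm_logit) tuples
    k and weight are illustrative values, not settings from the paper.
    """
    by_cosine = sorted(candidates, key=lambda c: c[1], reverse=True)
    head, tail = by_cosine[:k], by_cosine[k:]
    rescored = sorted(
        head,
        key=lambda c: weight * c[1] + (1.0 - weight) * c[2],
        reverse=True,
    )
    # Candidates outside the top-k keep their cosine-based order.
    return [c[0] for c in rescored + tail]


candidates = [("a", 0.9, 0.1), ("b", 0.8, 0.9), ("c", 0.7, 0.5), ("d", 0.2, 1.0)]
print(rerank(candidates))
</code></pre>
<p>Candidate <code>d</code> has the highest CMM logit but is never promoted, because it falls outside the cosine top-$k$ and is not rescored.</p>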
<table>
  <thead>
      <tr>
          <th>Mode</th>
          <th>Model</th>
          <th>S-T MRR</th>
          <th>S-T R@1</th>
          <th>S-T R@10</th>
          <th>T-S MRR</th>
          <th>T-S R@1</th>
          <th>T-S R@10</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Zero-shot</td>
          <td>MoMu</td>
          <td>9.89</td>
          <td>5.08</td>
          <td>18.93</td>
          <td>10.33</td>
          <td>4.90</td>
          <td>20.69</td>
      </tr>
      <tr>
          <td>Zero-shot</td>
          <td>MolFM</td>
          <td>21.42</td>
          <td>13.90</td>
          <td>36.21</td>
          <td>23.63</td>
          <td>16.14</td>
          <td>39.54</td>
      </tr>
      <tr>
          <td>Fine-tune</td>
          <td>MoMu</td>
          <td>34.29</td>
          <td>24.47</td>
          <td>53.84</td>
          <td>34.53</td>
          <td>24.87</td>
          <td>54.25</td>
      </tr>
      <tr>
          <td>Fine-tune</td>
          <td>MolFM</td>
          <td>39.56</td>
          <td>29.76</td>
          <td>58.63</td>
          <td>39.34</td>
          <td>29.39</td>
          <td>58.49</td>
      </tr>
  </tbody>
</table>
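<p>The authors report absolute gains of 12.13% and 5.04% over MoMu under zero-shot and fine-tuning settings, respectively. These are aggregate figures across the retrieval metrics rather than the difference in any single column above.</p>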
<h3 id="molecule-captioning">Molecule Captioning</h3>
<p>Evaluated on ChEBI-20 using MolT5 decoders. MolFM&rsquo;s structure encoder features are concatenated with the MolT5 encoder outputs.</p>
<table>
  <thead>
      <tr>
          <th>Decoder</th>
          <th>Encoder</th>
          <th>BLEU-4</th>
          <th>ROUGE-L</th>
          <th>METEOR</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-base</td>
          <td>MolT5-base</td>
          <td>0.457</td>
          <td>0.578</td>
          <td>0.569</td>
          <td>0.547</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MoMu</td>
          <td>0.462</td>
          <td>0.575</td>
          <td>0.576</td>
          <td>0.558</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>GraphMVP</td>
          <td>0.491</td>
          <td>0.592</td>
          <td>0.599</td>
          <td>0.570</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MolFM</td>
          <td>0.498</td>
          <td>0.594</td>
          <td>0.607</td>
          <td>0.576</td>
      </tr>
  </tbody>
</table>
<h3 id="text-based-molecule-generation">Text-Based Molecule Generation</h3>
<p>Also on ChEBI-20 with MolT5 decoders. MolFM&rsquo;s text features are projected and fed to the decoder.</p>
<table>
  <thead>
      <tr>
          <th>Decoder</th>
          <th>Encoder</th>
          <th>Exact</th>
          <th>Valid</th>
          <th>Morgan FTS</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-base</td>
          <td>MolT5-base</td>
          <td>0.082</td>
          <td>0.786</td>
          <td>0.601</td>
          <td>0.543</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MoMu</td>
          <td>0.183</td>
          <td>0.863</td>
          <td>0.678</td>
          <td>0.580</td>
      </tr>
      <tr>
          <td>MolT5-base</td>
          <td>MolFM</td>
          <td>0.210</td>
          <td>0.892</td>
          <td>0.697</td>
          <td>0.583</td>
      </tr>
  </tbody>
</table>
<h3 id="molecular-property-prediction">Molecular Property Prediction</h3>
<p>On <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> (8 classification datasets), MolFM concatenates the structure feature and the multimodal encoder&rsquo;s CLS feature to predict properties.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP</th>
          <th>Tox21</th>
          <th>ClinTox</th>
          <th>HIV</th>
          <th>BACE</th>
          <th>Avg (8 datasets)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GraphMVP</td>
          <td>72.4</td>
          <td>74.4</td>
          <td>77.5</td>
          <td>77.0</td>
          <td>81.2</td>
          <td>73.07</td>
      </tr>
      <tr>
          <td>DeepEIK</td>
          <td>72.1</td>
          <td>72.4</td>
          <td>89.7</td>
          <td>75.0</td>
          <td>80.5</td>
          <td>73.27</td>
      </tr>
      <tr>
          <td>MolFM (w/o T+K)</td>
          <td>72.2</td>
          <td>76.6</td>
          <td>78.6</td>
          <td>78.2</td>
          <td>82.6</td>
          <td>73.95</td>
      </tr>
      <tr>
          <td>MolFM (w/ T+K)</td>
          <td>72.9</td>
          <td>77.2</td>
          <td>79.7</td>
          <td>78.8</td>
          <td>83.9</td>
          <td>74.62</td>
      </tr>
  </tbody>
</table>
<p>With multimodal inputs, MolFM averages 74.62% ROC-AUC, a 1.55% absolute gain over GraphMVP.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>Zero-shot retrieval ablations reveal that cross-modal attention to atoms and CMM are the most critical components. Removing either causes a sharp drop (approximately 3% on S-T retrieval). Knowledge graph incorporation yields a 1.5% average improvement, with both attention to neighbors and KGE contributing marginally.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p>MolFM demonstrates that incorporating knowledge graphs as a third modality provides consistent improvements across all evaluated tasks. The theoretical analysis connecting pre-training objectives to deep metric learning provides interpretability for why the model works: STC and CMM align representations of the same molecule across modalities, while KGE pulls structurally and functionally similar molecules closer in the embedding space.</p>
<p>The cross-modal attention visualizations show that MolFM learns to associate specific atom substructures with relevant text tokens and knowledge graph entities. For example, the model correctly attends to functional groups mentioned in textual descriptions.</p>
<p>The authors acknowledge several limitations:</p>
<ol>
<li><strong>Data quality</strong>: The pre-training dataset (15K molecules) is small and may introduce biases</li>
<li><strong>Cold-start problem</strong>: MolFM provides limited benefit for newly emerged molecules lacking text and knowledge graph information</li>
<li><strong>Entity scope</strong>: The model focuses on molecules and does not incorporate proteins, genes, or cell lines, which could further improve biomedical understanding</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training (molecules)</td>
          <td>PubChem</td>
          <td>15K molecules</td>
          <td>Follows MoMu&rsquo;s pre-training data</td>
      </tr>
      <tr>
          <td>Pre-training (text)</td>
          <td>S2ORC</td>
          <td>37M paragraphs</td>
          <td>Biomedical literature paragraphs</td>
      </tr>
      <tr>
          <td>Knowledge graph</td>
          <td>DrugBank, BindingDB, public DBs</td>
          <td>49K entities, 3.2M relations</td>
          <td>Constructed with heuristics from MoCL</td>
      </tr>
      <tr>
          <td>Cross-modal retrieval</td>
          <td>PCdes</td>
          <td>Paragraph-level</td>
          <td>Test split</td>
      </tr>
      <tr>
          <td>Captioning/Generation</td>
          <td>ChEBI-20</td>
          <td>-</td>
          <td>Following MolT5 splits</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet</td>
          <td>8 datasets</td>
          <td>Classification tasks, ROC-AUC metric</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: AdamW with weight decay $1 \times 10^{-4}$</li>
<li>Learning rate: linear warmup to $1 \times 10^{-4}$ over 2,000 iterations, cosine annealing to $1 \times 10^{-5}$</li>
<li>Batch size: 128</li>
<li>Pre-training epochs: 300</li>
<li>Knowledge graph neighbors per molecule: $N = 4$</li>
<li>Temperature: $\tau = 0.1$</li>
<li>Margin: $\Delta = 0.2$</li>
</ul>
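<p>The learning-rate schedule listed above (linear warmup followed by cosine annealing) can be written as a single function of the step index. This is a sketch: <code>total_steps</code> is a hypothetical value that depends on dataset and batch size, and the exact annealing formula used by the authors is assumed to be the standard half-cosine.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math


def learning_rate(step, total_steps, warmup_steps=2000, peak=1e-4, floor=1e-5):
    """Linear warmup to `peak`, then cosine annealing down to `floor`."""
    warmup_lr = peak * step / warmup_steps
    progress = max(0.0, (step - warmup_steps) / (total_steps - warmup_steps))
    cosine_lr = floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * min(1.0, progress)))
    # During warmup the linear ramp is the smaller value; afterwards the cosine curve is.
    return min(warmup_lr, cosine_lr)


for step in (0, 1000, 2000, 6000, 10000):
    print(step, learning_rate(step, total_steps=10000))
</code></pre>
<p>Taking the minimum of the two curves avoids an explicit branch on the warmup boundary and keeps the schedule continuous at step 2,000.</p>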
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
          <th>Initialization</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph encoder</td>
          <td>5-layer GIN</td>
          <td>1.8M</td>
          <td>GraphMVP</td>
      </tr>
      <tr>
          <td>Text encoder</td>
          <td>6-layer Transformer</td>
          <td>61.8M</td>
          <td>KV-PLM (first 6 layers)</td>
      </tr>
      <tr>
          <td>Knowledge encoder</td>
          <td>TransE</td>
          <td>12.6M</td>
          <td>Trained 500 epochs on KG</td>
      </tr>
      <tr>
          <td>Multimodal encoder</td>
          <td>6-layer Transformer + cross-attention</td>
          <td>61.8M</td>
          <td>KV-PLM (last 6 layers)</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td></td>
          <td><strong>~138M</strong></td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metrics</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Cross-modal retrieval</td>
          <td>MRR, Recall@1/5/10</td>
      </tr>
      <tr>
          <td>Molecule captioning</td>
          <td>BLEU-2/4, ROUGE-1/2/L, METEOR, Text2Mol</td>
      </tr>
      <tr>
          <td>Text-to-molecule generation</td>
          <td>BLEU, Exact ratio, Validity, Levenshtein, Fingerprint Tanimoto (MACCS/RDKit/Morgan), Text2Mol</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>ROC-AUC per dataset</td>
      </tr>
  </tbody>
</table>
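<p>For reference, the retrieval metrics in the table can be computed from the rank of the correct item for each query. The sketch below assumes 1-based integer ranks; the function names are chosen for this example.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">def mean_reciprocal_rank(ranks):
    """ranks: 1-based rank of the correct item for each query."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def recall_at_k(ranks, k):
    """Fraction of queries whose correct item appears in the top k."""
    hits = sum(1 for r in ranks if r in range(1, k + 1))
    return hits / len(ranks)


ranks = [1, 3, 12, 2]
print(mean_reciprocal_rank(ranks))
print(recall_at_k(ranks, 1), recall_at_k(ranks, 10))
</code></pre>
<p>MRR rewards near misses (a rank of 2 still contributes 0.5), while Recall@$k$ only counts whether the correct item clears the cutoff.</p>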
<h3 id="hardware">Hardware</h3>
<ul>
<li>4 NVIDIA A100 GPUs for pre-training</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BioFM/OpenBioMed">OpenBioMed</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation including MolFM</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, Y., Yang, K., Hong, M., Liu, X. Y., &amp; Nie, Z. (2023). MolFM: A Multimodal Molecular Foundation Model. <em>arXiv preprint arXiv:2307.09484</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{luo2023molfm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolFM: A Multimodal Molecular Foundation Model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Yizhen and Yang, Kai and Hong, Massimo and Liu, Xing Yi and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2307.09484}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolecularRNN: Graph-Based Molecular Generation and RL</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/molecularrnn-graph-generation-optimized-properties/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/molecularrnn-graph-generation-optimized-properties/</guid><description>MolecularRNN extends GraphRNN with atom and bond type predictions, valency-based rejection sampling, and policy gradient optimization for molecular generation.</description><content:encoded><![CDATA[<h2 id="a-graph-recurrent-model-for-molecular-generation-with-property-optimization">A Graph Recurrent Model for Molecular Generation with Property Optimization</h2>
<p>This is a <strong>Method</strong> paper that introduces MolecularRNN, a graph-based recurrent generative model for molecular structures. The model extends GraphRNN to handle typed nodes (atom types) and typed edges (bond types), enabling direct generation of molecular graphs rather than working through string representations like SMILES. Three key contributions are combined: (1) the MolecularRNN architecture for autoregressive graph generation, (2) valency-based rejection sampling for guaranteed 100% validity at inference, and (3) policy gradient reinforcement learning for shifting molecular property distributions toward desired ranges.</p>
<h2 id="why-generate-molecules-as-graphs-rather-than-strings">Why Generate Molecules as Graphs Rather Than Strings</h2>
<p>Computational de novo molecular design aims to create novel molecules with desired properties, a task central to drug discovery. At the time of this work, most deep generative models for molecules operated on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, inheriting the complications of SMILES grammar and the problem that structurally similar molecules can have very different string representations. Graph-based representations are more natural for molecules, with atoms mapping to nodes and bonds to edges, and they allow direct enforcement of chemical constraints during generation.</p>
<p>Existing graph-based methods had their own limitations. Junction tree VAE (JT-VAE) generates molecules from structural fragments, which introduces ambiguity when converting junction trees back to molecules, particularly problematic during property optimization since molecules sharing a junction tree can have very different property values. The GCPN model uses graph convolutional networks with reinforcement learning but was evaluated only on top-3 generated molecules, making it difficult to assess overall distribution quality. Prior atom-level graph generation models like Li et al. (2018a) were restricted to molecules with at most 20 heavy atoms, limiting practical applicability.</p>
<h2 id="core-innovation-extending-graphrnn-with-chemical-constraints-and-rl">Core Innovation: Extending GraphRNN with Chemical Constraints and RL</h2>
<p>MolecularRNN builds on the GraphRNN architecture by introducing atom type predictions alongside edge type predictions. The model generates molecular graphs sequentially: at each step, a NodeRNN predicts the type of the next atom, then an EdgeRNN predicts bond types to all preceding atoms within a BFS-ordered window.</p>
<h3 id="autoregressive-graph-generation">Autoregressive Graph Generation</h3>
<p>The joint likelihood over atom types $C^{\pi}$ and adjacency vectors $S^{\pi}$ under BFS ordering $\pi$ is factorized as:</p>
<p>$$
p\left(S^{\pi}, C^{\pi}\right) = \prod_{i=1}^{n+1} p\left(C_{i}^{\pi} \mid S_{&lt;i}^{\pi}, C_{&lt;i}^{\pi}\right) p\left(S_{i}^{\pi} \mid C_{i}^{\pi}, S_{&lt;i}^{\pi}, C_{&lt;i}^{\pi}\right)
$$</p>
<p>NodeRNN processes embeddings of previous atom types and adjacency vectors to produce a hidden state, from which a two-layer MLP with softmax predicts the next atom type $\psi_{i}$:</p>
<p>$$
h_{i}^{\text{node}} = \text{NodeRNN}\left(h_{i-1}^{\text{node}}, \left[\text{emb}(S_{i-1}^{\pi}), \text{emb}(C_{i-1}^{\pi})\right]\right)
$$</p>
<p>$$
\psi_{i} = \text{NodeMLP}\left(h_{i}^{\text{node}}\right)
$$</p>
<p>EdgeRNN then unrolls across preceding atoms to predict bond types $\phi_{i,j}$, initialized with the NodeRNN hidden state:</p>
<p>$$
h_{i,j}^{\text{edge}} = \text{EdgeRNN}\left(h_{i,j-1}^{\text{edge}}, \text{emb}(S_{i,j-1}^{\pi})\right), \quad h_{i,0}^{\text{edge}} = h_{i}^{\text{node}}
$$</p>
<p>$$
\phi_{i,j} = \text{EdgeMLP}\left(h_{i,j}^{\text{edge}}\right)
$$</p>
<p>Bond types are categorical over {no bond, single, double, triple}, and molecules are represented in kekulized form. BFS ordering limits the EdgeRNN window to $M = 12$ preceding atoms.</p>
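<p>To make the sequence representation concrete, the sketch below builds a BFS ordering for a small molecular graph and the per-atom vectors of bond orders to preceding atoms, truncated to a window of $M = 12$. The dictionary-based graph format, the function names, and the ordering of entries within each vector are assumptions for illustration.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">from collections import deque


def bfs_order(adjacency, start=0):
    """Return node indices in breadth-first order for a connected graph."""
    order, seen, queue = [], {start}, deque([start])
    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in sorted(adjacency[node]):
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return order


def adjacency_vectors(adjacency, order, window=12):
    """For each atom in BFS order, bond orders to the previous `window` atoms.

    adjacency: dict mapping each node to {neighbor: bond_order}; 0 means no bond.
    """
    vectors = []
    for i, node in enumerate(order):
        previous = order[max(0, i - window):i]
        # Nearest predecessor first; this ordering is an arbitrary choice here.
        vectors.append([adjacency[node].get(p, 0) for p in reversed(previous)])
    return vectors


# Heavy atoms of acetic acid, CC(=O)O: C0-C1, C1=O2, C1-O3.
acetic_acid = {0: {1: 1}, 1: {0: 1, 2: 2, 3: 1}, 2: {1: 2}, 3: {1: 1}}
order = bfs_order(acetic_acid)
vectors = adjacency_vectors(acetic_acid, order)
print(order, vectors)
</code></pre>
<p>Under BFS ordering, bonded atoms stay close together in the sequence, which is what allows a fixed window to cover every bond that needs predicting.</p>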
<h3 id="valency-based-rejection-sampling">Valency-Based Rejection Sampling</h3>
<p>During inference, each proposed bond of order $k$ between atoms $i$ and $j$ is accepted only if both atoms remain within their allowed valencies:</p>
<p>$$
\sum_{j} A_{i,j}^{\pi} + k \leq \text{valency}_{C_{i}^{\pi}} \quad \text{and} \quad \sum_{i} A_{i,j}^{\pi} + k \leq \text{valency}_{C_{j}^{\pi}}
$$</p>
<p>Atoms that do not fill their valencies are complemented with hydrogens. This constraint can be enforced directly on graphs (unlike SMILES, where intermediate substrings are not chemically meaningful), yielding 100% valid molecules.</p>
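<p>The valency check can be sketched as a small rejection loop. The valency table below lists typical maximum valencies for a subset of elements and is an assumption (the paper&rsquo;s table may differ, for example for P and S); <code>propose</code> stands in for sampling a bond order from the EdgeRNN output and is hypothetical.</p>
<pre tabindex="0"><code class="language-python" data-lang="python"># Typical maximum valencies; an assumption for illustration only.
MAX_VALENCY = {"C": 4, "N": 3, "O": 2, "F": 1, "Cl": 1, "Br": 1, "I": 1}


def used_valency(bonds, atom):
    """Sum of bond orders already attached to `atom`."""
    return sum(order for pair, order in bonds.items() if atom in pair)


def max_bond_order(atom_types, bonds, i, j):
    """Largest bond order that keeps both atoms within their valency."""
    free_i = MAX_VALENCY[atom_types[i]] - used_valency(bonds, i)
    free_j = MAX_VALENCY[atom_types[j]] - used_valency(bonds, j)
    return max(0, min(free_i, free_j, 3))


def sample_bond(atom_types, bonds, i, j, propose):
    """Rejection sampling: redraw from `propose` until the bond order fits.

    Order 0 (no bond) always fits, so the loop ends once it is proposed.
    """
    limit = max_bond_order(atom_types, bonds, i, j)
    order = propose()
    while order not in range(limit + 1):
        order = propose()
    return order


atoms = ["C", "O", "F"]
draws = iter([3, 3, 2])
first = sample_bond(atoms, {}, 0, 1, lambda: next(draws))  # triple rejected twice, double accepted
bonds = {frozenset({0, 1}): first}
draws = iter([1, 2, 0])
second = sample_bond(atoms, bonds, 1, 2, lambda: next(draws))  # oxygen is saturated, only 0 fits
print(first, second)
</code></pre>
<p>Because the check only needs the running bond totals of the two atoms involved, it can be applied at every generation step without sanitizing the whole molecule.</p>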
<h3 id="property-optimization-via-policy-gradient">Property Optimization via Policy Gradient</h3>
<p>For property optimization, MolecularRNN is formulated as a policy network in a Markov Decision Process. The loss function uses REINFORCE with a discounted final reward:</p>
<p>$$
L(\theta) = -\sum_{i=1}^{N} r(s_{N}) \cdot \gamma^{i} \cdot \log p(s_{i} \mid s_{i-1}; \theta)
$$</p>
<p>where $r(s_{N})$ is the reward from a property critic and $\gamma$ is a discount factor. The authors also introduce a structural penalty during RL training that assigns a penalty of $-10$ to atoms violating valency constraints, providing a learning signal from invalid intermediate molecules.</p>
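<p>Written out for a single generated molecule, the loss above is a discounted sum of step log-probabilities scaled by the final reward. The sketch below follows the equation as stated; the discount value and the toy log-probabilities are illustrative, not the paper&rsquo;s settings.</p>
<pre tabindex="0"><code class="language-python" data-lang="python">import math


def reinforce_loss(log_probs, final_reward, gamma):
    """L = -sum_i r(s_N) * gamma**i * log p(s_i | s_{i-1}), with i starting at 1."""
    return -sum(
        final_reward * gamma ** i * lp
        for i, lp in enumerate(log_probs, start=1)
    )


# Three generation steps, each taken with probability 0.5, and a final reward of 1.
log_probs = [math.log(0.5)] * 3
loss = reinforce_loss(log_probs, final_reward=1.0, gamma=0.5)
print(loss)
</code></pre>
<p>A positive reward makes the loss decrease as the log-probabilities of the sampled steps increase, so gradient descent raises the likelihood of high-reward trajectories.</p>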
<h2 id="experimental-setup-pretraining-and-property-optimization">Experimental Setup: Pretraining and Property Optimization</h2>
<h3 id="pretraining">Pretraining</h3>
<p>MolecularRNN is pretrained on three datasets: ChEMBL (~1.5M bioactive molecules), <a href="/notes/chemistry/datasets/zinc-22/">ZINC 250k</a> (250K randomly selected commercially available compounds), and <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> (~1.9M drug-like molecules from ZINC). The model considers 9 atom types (C, N, O, F, P, S, Cl, Br, I), 3 bond types (single, double, triple), and molecules with 10-50 heavy atoms. Architecture: NodeRNN with 4 GRU layers (hidden size 256), EdgeRNN with 4 GRU layers (hidden size 128), node embedding size 128, edge embedding size 16. Training uses Adam with learning rate 0.001 and multiplicative decay on 4 GPUs with batch size 512 per GPU for 250 epochs.</p>
<h3 id="generation-quality-at-scale">Generation Quality at Scale</h3>
<p>The pretrained model generates 1 million molecules per dataset (far larger than prior work: JT-VAE used 5K samples, Li et al. used 100K). Results with valency-based rejection sampling:</p>
<table>
  <thead>
      <tr>
          <th>Training Set</th>
          <th>Valid</th>
          <th>Unique</th>
          <th>Novel</th>
          <th>IntDiv (p=1)</th>
          <th>IntDiv (p=2)</th>
          <th>SA Score</th>
          <th>QED</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChEMBL</td>
          <td>100%</td>
          <td>99.2%</td>
          <td>99.3%</td>
          <td>0.895</td>
          <td>0.890</td>
          <td>3.67 +/- 1.20</td>
          <td>0.56 +/- 0.20</td>
      </tr>
      <tr>
          <td>ZINC 250k</td>
          <td>100%</td>
          <td>99.8%</td>
          <td>100%</td>
          <td>0.892</td>
          <td>0.887</td>
          <td>3.60 +/- 1.01</td>
          <td>0.68 +/- 0.16</td>
      </tr>
      <tr>
          <td>MOSES</td>
          <td>100%</td>
          <td>99.4%</td>
          <td>100%</td>
          <td>0.881</td>
          <td>0.876</td>
          <td>3.24 +/- 0.97</td>
          <td>0.74 +/- 0.14</td>
      </tr>
  </tbody>
</table>
<p>Comparison with baselines on ZINC 250k (30K samples):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Valid</th>
          <th>Unique</th>
          <th>Novel</th>
          <th>SA Score</th>
          <th>QED</th>
          <th>IntDiv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>JT-VAE</td>
          <td>99.8%</td>
          <td>100%</td>
          <td>100%</td>
          <td>3.37</td>
          <td>0.76</td>
          <td>0.85</td>
      </tr>
      <tr>
          <td>GCPN</td>
          <td>100%</td>
          <td>99.97%</td>
          <td>100%</td>
          <td>4.62</td>
          <td>0.61</td>
          <td>0.90</td>
      </tr>
      <tr>
          <td>MolecularRNN</td>
          <td>100%</td>
          <td>99.89%</td>
          <td>100%</td>
          <td>3.59</td>
          <td>0.68</td>
          <td>0.89</td>
      </tr>
  </tbody>
</table>
<p>GCPN generates overly complex molecules (SA score of 4.62, versus 3.59 for MolecularRNN and 3.37 for JT-VAE), while MolecularRNN produces more realistic structures with higher internal diversity than JT-VAE (0.89 vs. 0.85).</p>
<h3 id="property-optimization-results">Property Optimization Results</h3>
<p>Policy gradient optimization is run for 300 iterations with batch size 512 and constant learning rate $10^{-5}$, discount factor $\gamma = 0.97$. Top-3 scores for penalized logP and QED:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>logP 1st</th>
          <th>logP 2nd</th>
          <th>logP 3rd</th>
          <th>QED 1st</th>
          <th>QED 2nd</th>
          <th>QED 3rd</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a></td>
          <td>3.63</td>
          <td>3.49</td>
          <td>3.44</td>
          <td>0.896</td>
          <td>0.824</td>
          <td>0.820</td>
      </tr>
      <tr>
          <td>JT-VAE</td>
          <td>5.30</td>
          <td>4.93</td>
          <td>4.49</td>
          <td>0.925</td>
          <td>0.911</td>
          <td>0.910</td>
      </tr>
      <tr>
          <td>GCPN</td>
          <td>7.98</td>
          <td>7.85</td>
          <td>7.80</td>
          <td>0.948</td>
          <td>0.947</td>
          <td>0.946</td>
      </tr>
      <tr>
          <td>MolecularRNN</td>
          <td>10.34</td>
          <td>10.19</td>
          <td>10.14</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.947</td>
      </tr>
  </tbody>
</table>
<p>MolecularRNN achieves the highest penalized logP scores (10.34 vs. GCPN&rsquo;s 7.98) while matching GCPN on QED. The authors also demonstrate melting temperature optimization using a GCN-based property predictor as the critic (RMSE 39.5 degrees C), showing that the RL framework generalizes to properties that cannot be computed directly from molecular graphs.</p>
<h2 id="distribution-level-evaluation-and-learned-chemical-patterns">Distribution-Level Evaluation and Learned Chemical Patterns</h2>
<p>The authors emphasize that reporting only top-3 scores is not informative, and they compare full property distributions. MolecularRNN shifts the QED distribution further toward maximum values compared to GCPN. They also note that during melting temperature optimization, the model rediscovered two chemical phenomena: fusing aromatic rings increases melting point, and the presence of polar groups (C=O, OH, NH2, heterocyclic nitrogens) enhances dipole-dipole interactions and raises melting temperature.</p>
<p>Without valency-based rejection sampling, the pretrained model achieves 65% validity. After structural penalty training (assigning -10 to valency-violating atoms and optimizing with policy gradient), validity increases to 90%. Enabling rejection sampling then achieves 100%.</p>
<p>Several limitations are worth noting. The BFS ordering introduces an arbitrary sequencing over equivalent graph traversals (the node order permutation problem is not addressed). The evaluation uses top-3 scores for property optimization, though the authors do advocate for distributional evaluation. The molecule size is capped at 50 heavy atoms. The paper does not report training time or wall-clock generation speed. Future directions mentioned include multi-objective property optimization and scaffold completion (graph completion from a given core structure).</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL</td>
          <td>~1.5M molecules</td>
          <td>Bioactive molecules with experimental measurements</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>ZINC 250k</td>
          <td>250K molecules</td>
          <td>Random subset of ZINC database</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>MOSES</td>
          <td>~1.9M molecules</td>
          <td>Drug-like subset of ZINC</td>
      </tr>
      <tr>
          <td>Melting point critic</td>
          <td>Custom split</td>
          <td>37,940 train / 9,458 test</td>
          <td>Melting temperatures from -196 to 517 degrees C</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining</strong>: Maximum likelihood with Adam optimizer, learning rate 0.001 with multiplicative decay to $10^{-5}$, 250 epochs</li>
<li><strong>Structural penalty</strong>: Policy gradient with -10 penalty per valency-violating atom</li>
<li><strong>Property optimization</strong>: REINFORCE (policy gradient), 300 iterations, batch size 512, learning rate $10^{-5}$, discount factor $\gamma = 0.97$</li>
<li><strong>Melting point critic</strong>: GCN regression (4 layers, hidden size 128), Adam with learning rate 0.001, exponential decay $\gamma = 0.8$, 30 epochs, batch size 32</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>NodeRNN</strong>: 4 GRU layers, hidden size 256, node embedding 128</li>
<li><strong>EdgeRNN</strong>: 4 GRU layers, hidden size 128, edge embedding 16</li>
<li><strong>NodeMLP/EdgeMLP</strong>: 2-layer MLP with 128 hidden units, ReLU activation, softmax output</li>
<li><strong>BFS window</strong>: $M = 12$ preceding atoms</li>
<li><strong>Atom types</strong>: 9 (C, N, O, F, P, S, Cl, Br, I)</li>
<li><strong>Bond types</strong>: 3 (single, double, triple) + no bond</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>% chemically valid molecules (RDKit)</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>% unique in generated pool (up to 1M)</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>% not in training set</td>
      </tr>
      <tr>
          <td>Internal Diversity</td>
          <td>Average pairwise Tanimoto distance</td>
      </tr>
      <tr>
          <td>SA Score</td>
          <td>Synthetic accessibility (2-4 optimal range)</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>Drug-likeness score (0-1)</td>
      </tr>
      <tr>
          <td>Penalized logP</td>
          <td>Lipophilicity with ring and SA penalties</td>
      </tr>
  </tbody>
</table>
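<p>As an illustration of the internal diversity entry, the sketch below averages Tanimoto distance over all distinct pairs of fingerprints, represented here as sets of on-bits. Averaging over distinct pairs is an assumption; benchmark suites differ in the exact convention, and the names are hypothetical.</p>

```python
from itertools import combinations


def tanimoto(a, b):
    """Tanimoto similarity between two sets of fingerprint bits."""
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 1.0


def internal_diversity(fingerprints):
    """Mean Tanimoto distance over all distinct pairs."""
    pairs = list(combinations(fingerprints, 2))
    return sum(1.0 - tanimoto(a, b) for a, b in pairs) / len(pairs)


# Toy fingerprints; real ones would come from a cheminformatics toolkit.
fps = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
print(internal_diversity(fps))
```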
<h3 id="hardware">Hardware</h3>
<ul>
<li>4 GPUs (NVIDIA, specific model not stated)</li>
<li>Per-GPU batch size of 512 for pretraining</li>
<li>Training time not reported</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Popova, M., Shvets, M., Oliva, J., &amp; Isayev, O. (2019). MolecularRNN: Generating realistic molecular graphs with optimized properties. <em>arXiv preprint arXiv:1905.13372</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{popova2019molecularrnn,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolecularRNN: Generating realistic molecular graphs with optimized properties}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Popova, Mariya and Shvets, Mykhailo and Oliva, Junier and Isayev, Olexandr}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1905.13372}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Memory-Assisted RL for Diverse De Novo Mol. Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/memory-assisted-rl-diverse-molecular-design/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/memory-assisted-rl-diverse-molecular-design/</guid><description>A memory unit for REINVENT-based RL that tracks generated scaffolds and penalizes repeated solutions, increasing molecular diversity up to fourfold.</description><content:encoded><![CDATA[<h2 id="a-memory-module-for-diverse-molecular-generation-via-rl">A Memory Module for Diverse Molecular Generation via RL</h2>
<p>This is a <strong>Method</strong> paper that introduces a memory unit for reinforcement learning (RL)-based molecular generation. The primary contribution is a hash-table-based memory mechanism that integrates into the REINVENT framework&rsquo;s scoring function. By tracking previously generated high-scoring molecules and penalizing the reward when new molecules are too similar to those already stored, the memory unit forces the generative model to explore different regions of chemical space rather than collapsing onto a single scaffold family.</p>
<h2 id="policy-collapse-limits-rl-based-de-novo-design">Policy Collapse Limits RL-Based De Novo Design</h2>
<p>Recurrent neural networks (RNNs) trained with reinforcement learning can generate novel molecules optimized for desired properties. The <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> algorithm and related approaches (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGANIC</a>, GENTRL) demonstrated the viability of coupling a pretrained SMILES-based generative model with a scoring function via RL. However, a persistent problem is <strong>policy collapse</strong> (also called mode collapse): once the model discovers a high-scoring region of chemical space, it continues to exploit that region, producing structurally similar compounds with minor substitution differences. This severely limits the practical utility of RL-based generation in drug design, where medicinal chemists need diverse scaffolds to explore structure-activity relationships and manage intellectual property concerns.</p>
<p>Prior work by Liu et al. [31] attempted to address this by engineering an explorative RNN alongside the standard generative RNN, but it did not substantially increase diversity compared to standard REINVENT. Other approaches like Generative Examination Networks (GEN) performed statistical analysis during training but were not evaluated in optimization scenarios.</p>
<h2 id="core-innovation-hash-table-memory-unit-for-reward-modification">Core Innovation: Hash-Table Memory Unit for Reward Modification</h2>
<p>The key insight is to dynamically modify the reward surface during RL by maintaining a memory of previously explored chemical space. The memory unit is a hash table of index-bucket pairs. Each bucket stores up to a fixed number of high-scoring molecules (default: 25) that are chemically similar to a seed molecule (the index).</p>
<h3 id="integration-with-reinvent">Integration with REINVENT</h3>
<p>The memory unit modifies the augmented likelihood used in REINVENT. For a generated compound $c$, the augmented log-likelihood becomes:</p>
<p>$$
\log P(c)_{Aug} = \log P(c)_{PriorNetwork} + \sigma \times S(c) \times M(c)
$$</p>
<p>where $\sigma$ is a scalar coefficient, $S(c)$ is the scoring function output, and $M(c)$ is the memory unit output (either 0 or 1). The reward is:</p>
<p>$$
R(c) = \left(\log P(c)_{Aug} - \log P(c)_{AgentNetwork}\right)^2
$$</p>
<p>and the loss is $\text{loss} = -R(c)$.</p>
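<p>The two formulas above can be traced numerically. This is a sketch with made-up log-likelihoods; the function names and the value <code>sigma = 20</code> are assumptions chosen for illustration only.</p>

```python
def augmented_log_likelihood(prior_log_p, score, memory_output, sigma):
    """log P(c)_Aug = log P(c)_PriorNetwork + sigma * S(c) * M(c)."""
    return prior_log_p + sigma * score * memory_output


def reward(prior_log_p, agent_log_p, score, memory_output, sigma):
    """R(c): squared gap between augmented and agent log-likelihoods."""
    aug = augmented_log_likelihood(prior_log_p, score, memory_output, sigma)
    return (aug - agent_log_p) ** 2


# Same hypothetical compound, with the memory output at 1 and then at 0.
print(reward(-30.0, -28.0, score=0.9, memory_output=1, sigma=20.0))
print(reward(-30.0, -28.0, score=0.9, memory_output=0, sigma=20.0))
```

<p>With $M(c) = 0$ the score term vanishes, so the augmented likelihood reduces to the prior likelihood for that compound.</p>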
<h3 id="memory-unit-operation">Memory Unit Operation</h3>
<p>When a high-scoring molecule is generated:</p>
<ol>
<li>Its fingerprint or scaffold is compared against all index structures in the memory</li>
<li>If it is similar to an index (above a Tanimoto cutoff, default 0.6) and the corresponding bucket is not full, $M(c) = 1$ and the molecule is added to the bucket</li>
<li>If the bucket is full, $M(c) = 0$, effectively zeroing the reward contribution and discouraging the model from generating similar molecules</li>
<li>If no similar index exists, a new index-bucket pair is created</li>
</ol>
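<p>The steps above can be sketched as a small class. This is not the reference implementation: it uses a plain list instead of a hash table, treats fingerprints as sets of on-bits, takes the first matching index, counts the seed molecule toward its bucket, and returns 1 when a new bucket is created. All of those choices, and the class name <code>MemoryUnit</code>, are assumptions.</p>

```python
def tanimoto(a, b):
    """Tanimoto similarity between two sets of fingerprint bits."""
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 1.0


class MemoryUnit:
    """Sketch of an index-bucket memory with binary output."""

    def __init__(self, bucket_size=25, cutoff=0.6):
        self.bucket_size = bucket_size
        self.cutoff = cutoff
        self.buckets = []  # list of (index_fingerprint, members)

    def score(self, fingerprint):
        """Return M(c) for a high-scoring molecule and update the memory."""
        for index_fp, members in self.buckets:
            if tanimoto(fingerprint, index_fp) >= self.cutoff:
                if len(members) >= self.bucket_size:
                    return 0  # bucket full: zero the reward contribution
                members.append(fingerprint)
                return 1
        # No similar index: this molecule seeds a new bucket.
        # Returning 1 here is an assumption; the text does not state it.
        self.buckets.append((fingerprint, [fingerprint]))
        return 1


memory = MemoryUnit(bucket_size=2, cutoff=0.6)
toy = [{1, 2, 3, 4, 5}, {1, 2, 3, 4, 6}, {1, 2, 3, 4, 7}, {10, 11}]
print([memory.score(fp) for fp in toy])
```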
<h3 id="four-similarity-criteria">Four Similarity Criteria</h3>
<p>The authors evaluate four criteria for grouping molecules in the memory:</p>
<ol>
<li><strong>Compound similarity</strong>: ECFP4 Tanimoto similarity at the whole-molecule level</li>
<li><strong>Identical Bemis-Murcko (BM) scaffold</strong>: exact match of Bemis-Murcko frameworks</li>
<li><strong>Identical carbon skeleton</strong>: exact match of carbon skeletons (BM scaffolds with all heteroatoms replaced by carbon and bonds set to single)</li>
<li><strong>Scaffold similarity</strong>: atom pair fingerprint Tanimoto similarity between carbon skeletons (fuzzy matching)</li>
</ol>
<h3 id="alternative-output-modes">Alternative Output Modes</h3>
<p>Beyond the binary output ($M(c) \in \{0, 1\}$), the authors also explored smooth output functions. The linear mode:</p>
<p>$$
M(c) = 1 - \frac{\text{compounds in bucket}}{\text{bucket size}}
$$</p>
<p>And the sigmoid mode:</p>
<p>$$
M(c) = 1 - \frac{1}{1 + e^{-\left(\frac{\frac{\text{compounds in bucket}}{\text{bucket size}} \times 2 - 1}{0.15}\right)}}
$$</p>
<p>Both smooth modes yielded slightly fewer analogs than the binary mode and were not pursued further.</p>
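<p>For intuition, the two smooth modes can be evaluated at a few fill levels. This sketch transcribes the formulas above directly; the function names are hypothetical.</p>

```python
import math


def linear_output(n_in_bucket, bucket_size):
    """Linear mode: 1 minus the fraction of the bucket already filled."""
    return 1.0 - n_in_bucket / bucket_size


def sigmoid_output(n_in_bucket, bucket_size):
    """Sigmoid mode, centred on a half-full bucket."""
    fill = n_in_bucket / bucket_size
    return 1.0 - 1.0 / (1.0 + math.exp(-((fill * 2.0 - 1.0) / 0.15)))


# Empty, roughly half-full, and full bucket at the default size of 25.
for n in (0, 12, 25):
    print(n, round(linear_output(n, 25), 3), round(sigmoid_output(n, 25), 3))
```

<p>The linear mode decays at a constant rate as the bucket fills, whereas the sigmoid mode stays close to 1 for a nearly empty bucket and drops steeply around the half-full point.</p>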
<h2 id="experimental-setup-logp-optimization-and-target-activity-prediction">Experimental Setup: LogP Optimization and Target Activity Prediction</h2>
<h3 id="case-study-1-logp-optimization">Case Study 1: LogP Optimization</h3>
<p>As a proof of concept, the authors optimized LogP values for known DRD2 inhibitors. Starting from 487 DRD2 compounds with LogP &gt;= 5 (from ExCAPE-DB), they applied transfer learning to the prior model for 20 epochs, then ran RL for 150 iterations (100 compounds per iteration, 15,000 total). The scoring function was:</p>
<p>$$
S = 1 - \tanh\left(\min\left(|2 - \text{AlogP}|, |3 - \text{AlogP}|\right)\right)
$$</p>
<p>targeting LogP values between 2.0 and 3.0.</p>
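<p>A short sketch that evaluates the scoring function exactly as written above. The function name <code>logp_score</code> is hypothetical, and the AlogP values are arbitrary probes rather than data from the paper.</p>

```python
import math


def logp_score(alogp):
    """S = 1 - tanh(min(|2 - AlogP|, |3 - AlogP|))."""
    return 1.0 - math.tanh(min(abs(2.0 - alogp), abs(3.0 - alogp)))


for value in (2.0, 2.5, 3.0, 5.0):
    print(value, round(logp_score(value), 3))
```

<p>As transcribed, the score is exactly 1 at AlogP values of 2 and 3, about 0.54 at the midpoint 2.5, and falls to about 0.04 at AlogP 5, the lower bound of the starting compounds.</p>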
<h3 id="case-study-2-htr1a-and-drd2-activity-prediction">Case Study 2: HTR1A and DRD2 Activity Prediction</h3>
<p>For a more complex scenario, the authors trained SVM classifiers (with <a href="https://en.wikipedia.org/wiki/Platt_scaling">Platt scaling</a> for probabilistic output) on bioactivity data from ExCAPE-DB to predict activity against two neurotransmitter receptors:</p>
<ul>
<li><strong><a href="https://en.wikipedia.org/wiki/5-HT1A_receptor">HTR1A</a></strong>: 3,599 actives (pIC50 &gt;= 7) and 66,684 inactives</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a></strong>: 2,981 actives (pIC50 &gt;= 7) and 346,206 inactives (100,000 sampled)</li>
</ul>
<p>Data were split using Butina clustering on ECFP6 at a 0.4 Tanimoto cutoff (60/20/20 train/val/test). On the held-out test sets, both SVM models reached a ROC AUC of 0.99 and a balanced accuracy of at least 0.95:</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Set</th>
          <th>Balanced Accuracy</th>
          <th>ROC AUC</th>
          <th>F1</th>
          <th>MCC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HTR1A</td>
          <td>Test</td>
          <td>0.96</td>
          <td>0.99</td>
          <td>0.75</td>
          <td>0.75</td>
      </tr>
      <tr>
          <td>DRD2</td>
          <td>Test</td>
          <td>0.95</td>
          <td>0.99</td>
          <td>0.71</td>
          <td>0.72</td>
      </tr>
  </tbody>
</table>
<p>RL was run for 300 iterations (100 compounds each, 30,000 total). Compounds with predicted activity &gt;= 0.7 were considered active.</p>
<h3 id="generative-model-architecture">Generative Model Architecture</h3>
<p>The RNN prior model followed the REINVENT architecture: an embedding layer, three GRU layers with 256 dimensions, and a linear output layer. It was pretrained on ~1.5 million ChEMBL 25 compounds (filtered to remove known HTR1A actives and DRD2 analogs) for 10 epochs using Adam with a learning rate of 0.01.</p>
<h3 id="comparisons">Comparisons</h3>
<p>The authors compared memory-assisted RL against:</p>
<ul>
<li>Standard REINVENT RL (no memory)</li>
<li>Experience replay (re-presenting 8 high-scoring compounds per iteration)</li>
<li>Temperature scaling (values from 1.0 to 10.0)</li>
<li>Memory + experience replay combined</li>
</ul>
<h2 id="results-up-to-fourfold-increase-in-diverse-active-compounds">Results: Up to Fourfold Increase in Diverse Active Compounds</h2>
<h3 id="logp-optimization-results">LogP Optimization Results</h3>
<p>Memory-assisted RL increased the number of optimized compounds (LogP 2-3) roughly 3.5- to 3.8-fold, from 938 without memory to between 3,315 and 3,591 depending on the memory type:</p>
<table>
  <thead>
      <tr>
          <th>Memory Type</th>
          <th>Optimized Compounds</th>
          <th>Unique BM Scaffolds</th>
          <th>Unique Carbon Skeletons</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>No memory</td>
          <td>938</td>
          <td>727</td>
          <td>396</td>
      </tr>
      <tr>
          <td>Compound similarity</td>
          <td>3,451</td>
          <td>2,963</td>
          <td>1,472</td>
      </tr>
      <tr>
          <td>Identical BM Scaffold</td>
          <td>3,428</td>
          <td>2,865</td>
          <td>1,398</td>
      </tr>
      <tr>
          <td>Identical Carbon Skeleton</td>
          <td>3,315</td>
          <td>3,002</td>
          <td>1,799</td>
      </tr>
      <tr>
          <td>Scaffold Similarity</td>
          <td>3,591</td>
          <td>3,056</td>
          <td>1,538</td>
      </tr>
  </tbody>
</table>
<p>The memory unit also increased the generation of relevant analogs. ECFP6 analogs (Tanimoto &gt;= 0.4 to training set) increased from 145 to up to 549, and shared MMP cores increased from 5 to up to 19, confirming that the memory unit promoted exploration of chemically relevant space rather than random drift.</p>
<h3 id="htr1a-and-drd2-activity-optimization-results">HTR1A and DRD2 Activity Optimization Results</h3>
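<p>For target activity optimization, the size of the gain depended on the target. Among the configurations listed, active compounds rose by up to roughly 1.9-fold for HTR1A and 4.4-fold for DRD2:</p>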
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Memory Type</th>
          <th>Active Compounds</th>
          <th>Unique BM Scaffolds</th>
          <th>Unique Carbon Skeletons</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HTR1A</td>
          <td>No memory</td>
          <td>9,323</td>
          <td>7,312</td>
          <td>5,446</td>
      </tr>
      <tr>
          <td>HTR1A</td>
          <td>Compound similarity</td>
          <td>16,779</td>
          <td>13,304</td>
          <td>9,887</td>
      </tr>
      <tr>
          <td>HTR1A</td>
          <td>Identical Carbon Skeleton</td>
          <td>17,597</td>
          <td>15,531</td>
          <td>12,408</td>
      </tr>
      <tr>
          <td>DRD2</td>
          <td>No memory</td>
          <td>5,143</td>
          <td>2,635</td>
          <td>1,949</td>
      </tr>
      <tr>
          <td>DRD2</td>
          <td>Compound similarity</td>
          <td>21,486</td>
          <td>17,844</td>
          <td>12,749</td>
      </tr>
      <tr>
          <td>DRD2</td>
          <td>Scaffold Similarity</td>
          <td>22,784</td>
          <td>20,712</td>
          <td>16,434</td>
      </tr>
  </tbody>
</table>
<p>For DRD2, the effect was particularly striking: standard RL showed clear policy collapse with only 576 ECFP6 analogs to the training set, while memory-assisted RL generated up to 6,315. The compound similarity memory unit produced the most MMP analogs (217 to the training set vs. 7 without memory).</p>
<h3 id="parameter-sensitivity">Parameter Sensitivity</h3>
<p>Bucket size had a modest effect: larger buckets (allowing more compounds before penalization) slightly increased analog generation. The Tanimoto similarity threshold of 0.6 was near-optimal for the scaffold similarity memory; higher thresholds reduced diversity gains. The compound similarity memory showed increasing analogs with higher thresholds, but BM scaffold and carbon skeleton counts plateaued above 0.6.</p>
<h3 id="comparison-with-experience-replay-and-temperature-scaling">Comparison with Experience Replay and Temperature Scaling</h3>
<ul>
<li><strong>Experience replay alone</strong> increased diversity compared to vanilla RL but was less effective than the memory unit alone</li>
<li><strong>Memory + experience replay</strong> achieved the best results overall, as experience replay provided the model with diverse starting points for exploration after the memory unit altered the reward landscape</li>
<li><strong>Temperature scaling</strong> was largely ineffective: only a value of 1.25 showed improvement, and even then it achieved only about 50% of the analogs generated by memory-assisted RL. Temperatures above 2.0 degraded SMILES validity, and above 4.0 prevented valid molecule generation entirely</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li>All evaluations are retrospective; no synthesized compounds were experimentally tested</li>
<li>The SVM activity models, while accurate, may have applicability domain limitations for highly novel scaffolds</li>
<li>The binary memory output mode was found to work best, but the transition from exploration to exploitation is abrupt</li>
<li>The method was only tested with two biological targets and one physicochemical property</li>
<li>Computational overhead of the memory unit is not discussed</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior model training</td>
          <td>ChEMBL 25</td>
          <td>~1.5M compounds</td>
          <td>Filtered: max 50 heavy atoms, no stereochemistry, removed HTR1A actives and DRD2 analogs</td>
      </tr>
      <tr>
          <td>HTR1A activity data</td>
          <td>ExCAPE-DB</td>
          <td>3,599 actives + 66,684 inactives</td>
          <td>pIC50 &gt;= 7 threshold for actives</td>
      </tr>
      <tr>
          <td>DRD2 activity data</td>
          <td>ExCAPE-DB</td>
          <td>2,981 actives + 100,000 inactives (sampled)</td>
          <td>pIC50 &gt;= 7 threshold for actives</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Generative model</strong>: RNN with embedding + 3 GRU layers (256 dim) + linear output (REINVENT architecture)</li>
<li><strong>RL</strong>: Augmented likelihood formulation with sigma scaling coefficient</li>
<li><strong>SVM classifiers</strong>: Non-linear SVM with MinMax kernel, Platt scaling, ECFP6 count-based fingerprints (2048 dim)</li>
<li><strong>Butina clustering</strong>: ECFP6 Tanimoto cutoff 0.4 for train/val/test splitting</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unique compounds</td>
          <td>Number of distinct valid SMILES generated</td>
      </tr>
      <tr>
          <td>Unique BM scaffolds</td>
          <td>Bemis-Murcko framework diversity</td>
      </tr>
      <tr>
          <td>Unique carbon skeletons</td>
          <td>Carbon skeleton diversity (stripped BM scaffolds)</td>
      </tr>
      <tr>
          <td>ECFP6 analogs</td>
          <td>Compounds with Tanimoto &gt;= 0.4 to known actives</td>
      </tr>
      <tr>
          <td>MMP analogs</td>
          <td>Matched molecular pair relationships with known actives</td>
      </tr>
      <tr>
          <td>Shared MMP cores</td>
          <td>Scaffold cores shared between generated and known compounds</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/tblaschke/reinvent-memory">reinvent-memory</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with prepared datasets</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Blaschke, T., Engkvist, O., Bajorath, J., &amp; Chen, H. (2020). Memory-assisted reinforcement learning for diverse molecular de novo design. <em>Journal of Cheminformatics</em>, 12, 68. <a href="https://doi.org/10.1186/s13321-020-00473-0">https://doi.org/10.1186/s13321-020-00473-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{blaschke2020memory,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Memory-assisted reinforcement learning for diverse molecular de novo design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Blaschke, Thomas and Engkvist, Ola and Bajorath, J{\&#34;u}rgen and Chen, Hongming}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{68}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00473-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LSTM Neural Network for Drug-Like Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/</guid><description>An LSTM neural network trained on 509K ChEMBL SMILES generates one million novel drug-like molecules with realistic substructures and bioactivity profiles.</description><content:encoded><![CDATA[<h2 id="an-early-method-for-lstm-based-molecular-generation">An Early Method for LSTM-Based Molecular Generation</h2>
<p>This is a <strong>Method</strong> paper that applies character-level LSTM networks to the task of de novo drug-like molecule generation. The primary contribution is demonstrating that an LSTM trained on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings from a large bioactive compound database (ChEMBL) can produce novel, diverse molecules whose chemical properties closely match those of known drug-like compounds. The paper also validates the generated molecules through virtual screening with profile QSAR models, showing comparable predicted bioactivity to the training set.</p>
<h2 id="the-challenge-of-exploring-drug-like-chemical-space">The Challenge of Exploring Drug-Like Chemical Space</h2>
<p>The theoretical space of drug-like molecules is astronomically large. Brute-force enumeration approaches such as <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a> (which catalogued 166 billion molecules) are feasible only for small molecules, and full enumeration of molecules with 25-30 heavy atoms (the typical size of drug molecules) remains computationally intractable. Traditional cheminformatics approaches to sampling this space rely on fragment combination, evolutionary algorithms, or particle swarm optimization.</p>
<p>The authors position LSTM networks as a viable alternative. LSTMs had already demonstrated the ability to learn sequential structure in domains like text and music generation, making them natural candidates for learning SMILES grammar and generating novel valid molecular strings. At the time of writing (late 2017), several groups were exploring this direction, including Bjerrum and Threlfall (ZINC-based generation), <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al.</a> (VAE-based latent space design), <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">Olivecrona et al.</a> (RL-guided generation), and Segler et al. (focused library design). This paper contributes a large-scale empirical study with detailed analysis of the generated molecules&rsquo; chemical quality.</p>
<h2 id="character-level-lstm-with-temperature-based-sampling">Character-Level LSTM with Temperature-Based Sampling</h2>
<p>The core approach is straightforward: train an LSTM to predict the next character in a SMILES string, then sample from the trained model to generate new molecules character by character.</p>
<p>The network architecture consists of:</p>
<ul>
<li>Two stacked LSTM layers (which learn the SMILES grammar)</li>
<li>A dropout layer for regularization</li>
<li>A dense output layer with 23 neurons (one per character in the reduced SMILES alphabet) and softmax activation</li>
</ul>
<p>Training used the RMSProp optimizer, with the learning rate gradually decreased from 0.01 to 0.0002. At generation time, a temperature parameter controls the randomness of character sampling, so that the model produces more diverse structures rather than reproducing training molecules too closely.</p>
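<p>As a minimal sketch of temperature-based sampling (not the authors&rsquo; implementation, which was not publicly released): the raw character scores are divided by the temperature before the softmax, so a low temperature sharpens the distribution and a high temperature flattens it. The function names and the three-value toy scores are illustrative assumptions.</p>

```python
import math
import random


def temperature_softmax(logits, temperature):
    """Turn raw scores into probabilities, rescaled by a sampling temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_with_temperature(logits, temperature, rng=random):
    """Draw one character index from the temperature-adjusted distribution."""
    probs = temperature_softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Toy scores for a hypothetical three-character alphabet (not from the paper).
toy_logits = [2.0, 1.0, 0.1]
cold = temperature_softmax(toy_logits, 0.5)  # sharper: favors the top character
hot = temperature_softmax(toy_logits, 2.0)   # flatter: more diverse samples
```

<p>In the model described above, the scores would be the 23 outputs of the dense layer; the temperature value used by the authors is not stated in these notes.</p>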
<p>A key preprocessing step reduces the SMILES alphabet to 23 characters. Multi-character atom tokens are replaced with single characters (<code>Cl</code> → <code>L</code>, <code>Br</code> → <code>R</code>, <code>[nH]</code> → <code>A</code>). Only the organic atom subset (<code>H</code>, <code>C</code>, <code>N</code>, <code>O</code>, <code>S</code>, <code>P</code>, <code>F</code>, <code>Cl</code>, <code>Br</code>, <code>I</code>) is retained. Charged molecules, stereo information, and molecules with more than 5 ring closures are excluded. The training corpus totals 23,664,668 characters, with 40-character windows used as input sequences during training.</p>
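<p>The token replacement and windowing can be sketched as follows. The three substitutions are the ones listed above; the helper names, the stride of one character, and the way the corpus is assembled into a single string are assumptions made for illustration.</p>

```python
# Multi-character tokens mapped to single characters, as listed in the text.
REPLACEMENTS = [("[nH]", "A"), ("Cl", "L"), ("Br", "R")]


def encode_smiles(smiles):
    """Rewrite a SMILES string so that every token is a single character."""
    for token, char in REPLACEMENTS:
        smiles = smiles.replace(token, char)
    return smiles


def decode_smiles(encoded):
    """Invert encode_smiles (safe because A, L, R are otherwise unused here)."""
    for token, char in REPLACEMENTS:
        encoded = encoded.replace(char, token)
    return encoded


def training_windows(corpus, window=40):
    """Yield (input window, next character) pairs for next-character prediction.

    A stride of one character is an assumption; the notes only give the
    window length.
    """
    for start in range(len(corpus) - window):
        yield corpus[start:start + window], corpus[start + window]


example = encode_smiles("Clc1ccc(Br)cc1")  # 4-bromochlorobenzene
```

<p>Mapping each token to one character keeps the output layer small and lets the network treat every position in the string as exactly one prediction step.</p>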
<h2 id="training-on-chembl-and-generating-one-million-molecules">Training on ChEMBL and Generating One Million Molecules</h2>
<h3 id="training-data">Training Data</h3>
<p>The training set consists of 509,000 bioactive molecules from ChEMBL with reported activity below 10 micromolar on any target.</p>
<h3 id="generation-and-filtering">Generation and Filtering</h3>
<p>The LSTM generates SMILES strings character by character. The generated strings then pass through a two-stage validation:</p>
<ol>
<li><strong>Bracket and ring closure check</strong> (fast text-based): 54% of generated SMILES are discarded for unpaired brackets or ring closures</li>
<li><strong>Full chemical parsing with RDKit</strong>: An additional 14% fail due to unrealistic aromatic systems or incorrect valences</li>
</ol>
<p><strong>Final yield</strong>: the remaining 32% of generated SMILES correspond to valid molecules.</p>
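<p>A fast text-based check of the kind described in the first stage might look like the sketch below. It is an assumption about how such a filter could work, not the authors&rsquo; code: it handles only single-digit ring closures and does not attempt any chemical validation, which is left to the second stage.</p>

```python
def passes_text_check(smiles):
    """Cheap syntactic filter: balanced brackets and paired ring-closure digits."""
    depth = 0            # current nesting depth of "(" ... ")"
    in_atom = False      # True while inside a "[...]" bracket atom
    open_rings = set()   # ring-closure digits opened but not yet closed
    for ch in smiles:
        if in_atom:
            if ch == "]":
                in_atom = False
            elif ch == "[":
                return False  # nested bracket atom
            continue
        if ch == "[":
            in_atom = True
        elif ch == "]":
            return False      # closing bracket without an opening one
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False  # ")" before its "("
        elif ch.isdigit():
            # A digit opens a ring on first sight and closes it on the second.
            open_rings.symmetric_difference_update({ch})
    return depth == 0 and not in_atom and not open_rings


# Percentages reported above: 54% and 14% rejected, leaving the valid share.
valid_percent = 100 - 54 - 14
```

<p>If the percentages are read as fractions of all generated strings, a 32% yield implies sampling roughly 3.1 million strings (1,000,000 / 0.32) to obtain one million valid molecules.</p>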
<p>One million valid molecules were generated in under 2 hours on 300 CPUs.</p>
<h3 id="novelty-and-diversity">Novelty and Diversity</h3>
<p>Out of one million generated molecules, only 2,774 (0.28%) were identical to molecules in the ChEMBL training set. The generated set contained 627,000 unique scaffolds compared to 172,000 in ChEMBL, with an overlap of only 18,000 scaffolds. Fewer than 3% of the generated scaffolds therefore also occur in ChEMBL, indicating substantial novelty and diversity at the scaffold level.</p>
<h3 id="physicochemical-properties">Physicochemical Properties</h3>
<p>Calculated molecular descriptors (molecular weight, logP, and topological polar surface area) for the generated molecules closely matched the distributions of the ChEMBL training set. The synthetic accessibility score distributions were also practically identical, indicating comparable molecular complexity.</p>
<h3 id="substructure-feature-comparison">Substructure Feature Comparison</h3>
<p>The paper compares substructure features across three molecule sets: ChEMBL training data, LSTM-generated molecules, and a naive SMILES baseline generator. The naive generator uses only character frequency statistics and basic SMILES syntax rules, producing primarily macrocycles with very few fused aromatic systems.</p>
<table>
  <thead>
      <tr>
          <th>Feature</th>
          <th>ChEMBL (%)</th>
          <th>LSTM Generated (%)</th>
          <th>Naive Baseline (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>No rings</td>
          <td>0.4</td>
          <td>0.4</td>
          <td>0.1</td>
      </tr>
      <tr>
          <td>1 ring</td>
          <td>2.8</td>
          <td>4.3</td>
          <td>13.2</td>
      </tr>
      <tr>
          <td>2 rings</td>
          <td>14.8</td>
          <td>23.1</td>
          <td>17.7</td>
      </tr>
      <tr>
          <td>3 rings</td>
          <td>32.2</td>
          <td>43.5</td>
          <td>27.3</td>
      </tr>
      <tr>
          <td>4 rings</td>
          <td>32.7</td>
          <td>23.9</td>
          <td>25.2</td>
      </tr>
      <tr>
          <td>&gt;4 rings</td>
          <td>17.2</td>
          <td>4.8</td>
          <td>16.5</td>
      </tr>
      <tr>
          <td>Fused aromatic rings</td>
          <td>38.8</td>
          <td>30.9</td>
          <td>0.2</td>
      </tr>
      <tr>
          <td>Large rings (&gt;8)</td>
          <td>0.4</td>
          <td>1.8</td>
          <td>75.9</td>
      </tr>
      <tr>
          <td>Spiro rings</td>
          <td>1.9</td>
          <td>0.6</td>
          <td>0.6</td>
      </tr>
      <tr>
          <td>Contains N</td>
          <td>96.5</td>
          <td>96.1</td>
          <td>92.3</td>
      </tr>
      <tr>
          <td>Contains O</td>
          <td>93.0</td>
          <td>92.0</td>
          <td>85.5</td>
      </tr>
      <tr>
          <td>Contains S</td>
          <td>35.6</td>
          <td>27.9</td>
          <td>39.6</td>
      </tr>
      <tr>
          <td>Contains halogen</td>
          <td>40.7</td>
          <td>38.8</td>
          <td>49.4</td>
      </tr>
  </tbody>
</table>
<p>The LSTM-generated molecules broadly track the ChEMBL distributions, while the naive generator fails to capture drug-like structural patterns. The main deviation is in ring count: the LSTM over-represents 2-3 ring systems (66.6% vs. 47.0% in ChEMBL) and under-represents systems with four or more rings (28.7% vs. 49.9%). Functional group distributions also closely matched between ChEMBL and the LSTM output.</p>
<h3 id="virtual-screening-validation">Virtual Screening Validation</h3>
<p>The generated molecules were evaluated using profile QSAR models for 159 ChEMBL kinase assays. The six best models (with realistic test set R-squared &gt; 0.75) were used to predict pIC50 values for both actual ChEMBL compounds and generated compounds. The cumulative frequency distributions of predicted activity were nearly identical between the two sets.</p>
<p>Kolmogorov-Smirnov (KS) tests on random samples of 1,000 compounds confirmed this quantitatively:</p>
<table>
  <thead>
      <tr>
          <th>Assay</th>
          <th>KS D</th>
          <th>Distributions Differ?</th>
          <th>Mean (Real)</th>
          <th>Mean (Gen)</th>
          <th>Stdev (Real)</th>
          <th>Stdev (Gen)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>688395</td>
          <td>6.01%</td>
          <td>No</td>
          <td>4.66</td>
          <td>4.69</td>
          <td>0.25</td>
          <td>0.24</td>
      </tr>
      <tr>
          <td>668624</td>
          <td>3.60%</td>
          <td>No</td>
          <td>4.86</td>
          <td>4.86</td>
          <td>0.25</td>
          <td>0.24</td>
      </tr>
      <tr>
          <td>809226</td>
          <td>9.90%</td>
          <td>Yes</td>
          <td>5.33</td>
          <td>5.26</td>
          <td>0.34</td>
          <td>0.30</td>
      </tr>
      <tr>
          <td>809226</td>
          <td>4.30%</td>
          <td>No</td>
          <td>5.18</td>
          <td>5.13</td>
          <td>0.47</td>
          <td>0.43</td>
      </tr>
      <tr>
          <td>688781</td>
          <td>2.20%</td>
          <td>No</td>
          <td>4.83</td>
          <td>4.82</td>
          <td>0.26</td>
          <td>0.25</td>
      </tr>
      <tr>
          <td>809170</td>
          <td>8.70%</td>
          <td>Yes</td>
          <td>5.12</td>
          <td>5.07</td>
          <td>0.51</td>
          <td>0.46</td>
      </tr>
  </tbody>
</table>
<p>For 4 of 6 models, the null hypothesis that the distributions are the same could not be rejected at the 95% confidence level (critical D = 6.04%). Even for the two assays where the KS test rejected the null hypothesis, the maximum vertical distance between distributions was below 10%.</p>
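<p>For readers who want to see where a critical D of this size comes from, the sketch below computes the two-sample KS statistic and the standard large-sample critical value. It is a generic illustration, not the authors&rsquo; procedure; the function names are assumptions. With two samples of 1,000 compounds, the asymptotic formula gives a value of about 6.1%, close to but not identical to the 6.04% quoted above.</p>

```python
import math
from bisect import bisect_right


def ks_statistic(sample_a, sample_b):
    """Two-sample KS statistic: largest gap between the two empirical CDFs."""
    a, b = sorted(sample_a), sorted(sample_b)
    d = 0.0
    for x in a + b:
        cdf_a = bisect_right(a, x) / len(a)
        cdf_b = bisect_right(b, x) / len(b)
        d = max(d, abs(cdf_a - cdf_b))
    return d


def ks_critical_value(n, m, alpha=0.05):
    """Large-sample critical D for sample sizes n and m at significance alpha."""
    c_alpha = math.sqrt(-0.5 * math.log(alpha / 2))  # about 1.358 for alpha = 0.05
    return c_alpha * math.sqrt((n + m) / (n * m))


critical_d = ks_critical_value(1000, 1000)
```

<p>The null hypothesis is rejected when the observed D exceeds the critical value, which is why the assays with D of 9.90% and 8.70% are flagged in the table while the others are not.</p>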
<h2 id="generated-molecules-are-novel-drug-like-and-potentially-bioactive">Generated Molecules Are Novel, Drug-Like, and Potentially Bioactive</h2>
<p>The key findings of this study are:</p>
<ol>
<li><strong>High novelty</strong>: Only 0.28% of generated molecules match training compounds; the generated set contains 627K unique scaffolds versus 172K in ChEMBL, with only 18K shared</li>
<li><strong>Drug-like quality</strong>: Physicochemical properties, substructure features, functional group distributions, and synthetic accessibility scores all closely match the ChEMBL training distribution, without these being explicit constraints</li>
<li><strong>Predicted bioactivity</strong>: Virtual screening with profile QSAR models shows the generated molecules have comparable predicted activity profiles to known bioactive compounds</li>
<li><strong>Scalability</strong>: One million valid molecules in under 2 hours on 300 CPUs, with the potential to scale to billions with GPU acceleration</li>
<li><strong>LSTM superiority over naive baselines</strong>: A simple statistical SMILES generator using only character frequencies produces chemically unrealistic molecules (mostly macrocycles), demonstrating that the LSTM genuinely learns drug-like chemical patterns</li>
</ol>
<p>The main limitations are the 32% validity rate (68% of generated SMILES are invalid), the exclusion of stereochemistry and charged molecules from the training set, and the lack of any goal-directed generation capability (the model produces unconditional samples from the training distribution). The code was described as &ldquo;available on request&rdquo; from the corresponding author rather than publicly released.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL bioactive molecules</td>
          <td>509,000 molecules</td>
          <td>Activity &lt; 10 µM on any target; organic atoms only; no charges or stereo</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Double-stacked LSTM layers with dropout</li>
<li>Softmax output over 23-character reduced SMILES alphabet</li>
<li>RMSProp optimizer with learning rate annealed from 0.01 to 0.0002</li>
<li>Temperature-based sampling at generation time</li>
<li>40-character input windows during training</li>
</ul>
<h3 id="models">Models</h3>
<p>The architecture consists of two LSTM layers, a dropout layer, and a 23-neuron dense output layer. Exact hidden unit counts and dropout rates are not specified in the paper.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid SMILES rate</td>
          <td>32%</td>
          <td>After bracket check and RDKit parsing</td>
      </tr>
      <tr>
          <td>Novelty (vs. training)</td>
          <td>99.72%</td>
          <td>Only 2,774 of 1M match ChEMBL</td>
      </tr>
      <tr>
          <td>Unique scaffolds</td>
          <td>627,000</td>
          <td>vs. 172,000 in ChEMBL</td>
      </tr>
      <tr>
          <td>KS test (4/6 assays)</td>
          <td>Not significantly different</td>
          <td>At 95% confidence</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Generation: 300 CPUs for under 2 hours (1 million valid molecules)</li>
<li>Training hardware not specified</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ertl, P., Lewis, R., Martin, E., &amp; Polyakov, V. (2017). In silico generation of novel, drug-like chemical matter using the LSTM neural network. <em>arXiv preprint</em>, arXiv:1712.07449.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ertl2017silico,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{In silico generation of novel, drug-like chemical matter using the LSTM neural network}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ertl, Peter and Lewis, Richard and Martin, Eric and Polyakov, Valery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1712.07449}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LatentGAN: Latent-Space GAN for Molecular Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/</guid><description>LatentGAN combines a SMILES heteroencoder with a Wasserstein GAN to generate novel drug-like molecules in latent space, avoiding SMILES syntax issues.</description><content:encoded><![CDATA[<h2 id="a-gan-operating-in-learned-latent-space-for-molecular-design">A GAN Operating in Learned Latent Space for Molecular Design</h2>
<p>LatentGAN is a <strong>Method</strong> paper that introduces a two-stage architecture for de novo molecular generation. The first stage trains a heteroencoder to map SMILES strings into a continuous latent vector space. The second stage trains a Wasserstein GAN with gradient penalty (WGAN-GP) to generate new latent vectors that, when decoded, produce valid and novel molecular structures. The key contribution is decoupling the GAN from direct SMILES string generation, allowing the adversarial training to focus on learning the distribution of molecular latent representations rather than character-level sequence generation.</p>
<h2 id="limitations-of-direct-smiles-generation-with-gans">Limitations of Direct SMILES Generation with GANs</h2>
<p>Prior GAN-based molecular generation methods such as ORGAN and ORGANIC operated directly on SMILES strings. This created a fundamental challenge: the generator had to simultaneously learn valid SMILES syntax and the distribution of chemically meaningful molecules. ORGAN struggled with optimizing discrete molecular properties like Lipinski&rsquo;s Rule of Five, while ORGANIC showed limited success beyond the QED drug-likeness score. Other approaches (RANC, ATNC) substituted more advanced recurrent architectures but still operated in the discrete SMILES space.</p>
<p>Meanwhile, variational autoencoders (VAEs) demonstrated that working in continuous latent space could enable molecular generation, but they relied on forcing the latent distribution to match a Gaussian prior through KL divergence. This assumption is not necessarily appropriate for chemical space, which is inherently discontinuous.</p>
<p>RNN-based methods with transfer learning offered an alternative for target-biased generation, but the authors hypothesized that combining GANs with learned latent representations could produce complementary chemical space coverage.</p>
<h2 id="heteroencoder-plus-wasserstein-gan-architecture">Heteroencoder Plus Wasserstein GAN Architecture</h2>
<p>The core innovation of LatentGAN is separating molecular representation learning from adversarial generation through a two-component pipeline.</p>
<h3 id="heteroencoder">Heteroencoder</h3>
<p>The heteroencoder is an autoencoder trained on pairs of different non-canonical (randomized) SMILES representations of the same molecule. This is distinct from a standard autoencoder because the input and target SMILES are different representations of the same structure.</p>
<p>The encoder uses a two-layer bidirectional LSTM with 512 units per layer (256 forward, 256 backward). The concatenated output feeds into a 512-dimensional feed-forward layer. During training, zero-centered Gaussian noise with $\sigma = 0.1$ is added to the latent vector as regularization. The decoder is a four-layer unidirectional LSTM with a softmax output layer. Batch normalization with momentum 0.9 is applied to all hidden layers except the noise layer.</p>
<p>Training uses teacher forcing with categorical cross-entropy loss for 100 epochs. The learning rate starts at $10^{-3}$ for the first 50 epochs and decays exponentially to $10^{-6}$ by the final epoch. After training, the noise layer is deactivated for deterministic encoding and decoding.</p>
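<p>The learning rate schedule described above can be sketched as a function of the epoch. The 1-indexed epoch convention and the exact interpolation (geometric, reaching the final rate exactly at the last epoch) are assumptions; the notes only state the start value, the 50-epoch hold, and the end value.</p>

```python
def heteroencoder_lr(epoch, total_epochs=100, hold_epochs=50,
                     lr_start=1e-3, lr_end=1e-6):
    """Learning rate for a 1-indexed epoch: constant hold, then exponential decay."""
    if epoch <= hold_epochs:
        return lr_start
    # Fraction of the decay phase completed, reaching 1.0 at the final epoch.
    progress = (epoch - hold_epochs) / (total_epochs - hold_epochs)
    return lr_start * (lr_end / lr_start) ** progress


schedule = [heteroencoder_lr(epoch) for epoch in range(1, 101)]
```

<p>Under these assumptions the rate falls by three orders of magnitude over the last 50 epochs, i.e. by a constant factor each epoch.</p>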
<p>An important design choice is that the heteroencoder makes no assumption about the latent space distribution (unlike VAEs with their KL divergence term). The latent space is shaped purely by reconstruction loss, and the GAN later learns to sample from this unconstrained distribution.</p>
<h3 id="wasserstein-gan-with-gradient-penalty">Wasserstein GAN with Gradient Penalty</h3>
<p>The GAN uses the WGAN-GP formulation. The critic (discriminator) consists of three feed-forward layers of 256 dimensions each with leaky ReLU activations (no activation on the final layer). The generator has five feed-forward layers of 256 dimensions each with batch normalization and leaky ReLU between layers.</p>
<p>The training ratio is 5:1, with five critic updates for every generator update. The generator takes random vectors sampled from a uniform distribution and learns to produce latent vectors indistinguishable from the real encoded molecular latent vectors.</p>
<p>The WGAN-GP loss for the critic is:</p>
<p>$$L_{\text{critic}} = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}[D(\tilde{x})] - \mathbb{E}_{x \sim \mathbb{P}_r}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2]$$</p>
<p>where $\lambda$ is the gradient penalty coefficient, $\mathbb{P}_r$ is the real data distribution (encoded latent vectors), $\mathbb{P}_g$ is the generator distribution, and $\mathbb{P}_{\hat{x}}$ samples uniformly along straight lines between pairs of real and generated points.</p>
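<p>The three terms of the critic loss can be made concrete with a small dependency-free sketch. This is an illustration under stated assumptions, not the LatentGAN code: the gradient is approximated by finite differences (a real implementation would use automatic differentiation), the default penalty coefficient of 10 is a common choice and is not confirmed for this paper, and real and generated vectors are paired one-to-one.</p>

```python
import math
import random


def grad_norm(critic, x, h=1e-5):
    """L2 norm of the critic's gradient at x, via central finite differences."""
    total = 0.0
    for i in range(len(x)):
        up = x[:i] + [x[i] + h] + x[i + 1:]
        down = x[:i] + [x[i] - h] + x[i + 1:]
        total += ((critic(up) - critic(down)) / (2 * h)) ** 2
    return math.sqrt(total)


def critic_loss(critic, real, fake, lam=10.0, rng=random):
    """WGAN-GP critic loss for equally sized batches of real and generated vectors."""
    n = len(real)
    wasserstein = sum(critic(f) for f in fake) / n - sum(critic(r) for r in real) / n
    penalty = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()  # uniform position on the line between r and f
        x_hat = [eps * ri + (1 - eps) * fi for ri, fi in zip(r, f)]
        penalty += (grad_norm(critic, x_hat) - 1.0) ** 2
    return wasserstein + lam * penalty / n


# Toy linear critic with unit gradient norm, so the penalty term vanishes.
def unit_critic(x):
    return 0.6 * x[0] + 0.8 * x[1]


toy_real = [[1.0, 0.0], [0.0, 1.0]]
toy_fake = [[0.0, 0.0], [0.0, 0.0]]
toy_loss = critic_loss(unit_critic, toy_real, toy_fake)
```

<p>The penalty pushes the critic toward a gradient norm of 1 along lines between real and generated points, which is how WGAN-GP enforces the Lipschitz constraint without weight clipping.</p>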
<h3 id="generation-pipeline">Generation Pipeline</h3>
<p>At inference time, the full pipeline operates as: (1) sample a random vector, (2) pass through the trained generator to produce a latent vector, (3) decode the latent vector into a SMILES string using the pretrained heteroencoder decoder.</p>
<h2 id="experiments-on-drug-like-and-target-biased-generation">Experiments on Drug-Like and Target-Biased Generation</h2>
<h3 id="datasets">Datasets</h3>
<p>The heteroencoder was trained on 1,347,173 SMILES from ChEMBL 25, standardized with MolVS and restricted to molecules with atoms from {H, C, N, O, S, Cl, Br} and at most 50 heavy atoms.</p>
<p>For general drug-like generation, a random subset of 100,000 ChEMBL compounds was used to train the GAN model for 30,000 epochs.</p>
<p>For target-biased generation, three datasets were extracted from ExCAPE-DB for EGFR, HTR1A, and S1PR1 targets. These were clustered into training and test sets to ensure chemical series were not split across sets.</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Training Set</th>
          <th>Test Set</th>
          <th>SVM ROC-AUC</th>
          <th>SVM Kappa</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>EGFR</td>
          <td>2,949</td>
          <td>2,326</td>
          <td>0.850</td>
          <td>0.56</td>
      </tr>
      <tr>
          <td>HTR1A</td>
          <td>48,283</td>
          <td>23,048</td>
          <td>0.993</td>
          <td>0.90</td>
      </tr>
      <tr>
          <td>S1PR1</td>
          <td>49,381</td>
          <td>23,745</td>
          <td>0.995</td>
          <td>0.91</td>
      </tr>
  </tbody>
</table>
<p>SVM target prediction models using 2048-bit FCFP6 fingerprints were built with scikit-learn to evaluate generated compounds.</p>
<h3 id="baselines">Baselines</h3>
<p>RNN-based generative models with transfer learning served as the primary baseline. A prior RNN model was trained on the same ChEMBL set, then fine-tuned on each target dataset. The LatentGAN was also benchmarked on the MOSES platform against VAE, JTN-VAE, and AAE architectures.</p>
<h3 id="heteroencoder-performance">Heteroencoder Performance</h3>
<p>The heteroencoder achieved 99% valid SMILES on the training set and 98% on the test set. Reconstruction error (decoding to a different molecule) was 18% on training and 20% on test. Notably, decoding to a different valid SMILES of the same molecule is not counted as an error.</p>
<h3 id="target-biased-generation-results">Target-Biased Generation Results</h3>
<p>From 50,000 sampled SMILES per target model:</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>Arch.</th>
          <th>Valid (%)</th>
          <th>Unique (%)</th>
          <th>Novel (%)</th>
          <th>Active (%)</th>
          <th>Recovered Actives (%)</th>
          <th>Recovered Neighbors</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>EGFR</td>
          <td>GAN</td>
          <td>86</td>
          <td>56</td>
          <td>97</td>
          <td>71</td>
          <td>5.26</td>
          <td>196</td>
      </tr>
      <tr>
          <td>EGFR</td>
          <td>RNN</td>
          <td>96</td>
          <td>46</td>
          <td>95</td>
          <td>65</td>
          <td>7.74</td>
          <td>238</td>
      </tr>
      <tr>
          <td>HTR1A</td>
          <td>GAN</td>
          <td>86</td>
          <td>66</td>
          <td>95</td>
          <td>71</td>
          <td>5.05</td>
          <td>284</td>
      </tr>
      <tr>
          <td>HTR1A</td>
          <td>RNN</td>
          <td>96</td>
          <td>50</td>
          <td>90</td>
          <td>81</td>
          <td>7.28</td>
          <td>384</td>
      </tr>
      <tr>
          <td>S1PR1</td>
          <td>GAN</td>
          <td>89</td>
          <td>31</td>
          <td>98</td>
          <td>44</td>
          <td>0.93</td>
          <td>24</td>
      </tr>
      <tr>
          <td>S1PR1</td>
          <td>RNN</td>
          <td>97</td>
          <td>35</td>
          <td>97</td>
          <td>65</td>
          <td>3.72</td>
          <td>43</td>
      </tr>
  </tbody>
</table>
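<p>The validity, uniqueness, and novelty columns can be computed along the lines of the sketch below. The denominators are assumptions (uniqueness among valid strings, novelty among unique valid strings), and <code>toy_canonicalize</code> is a hypothetical stand-in for a real cheminformatics parser that returns a canonical SMILES or <code>None</code>.</p>

```python
def generation_metrics(sampled, training_set, canonicalize):
    """Percent valid, unique-among-valid, and novel-among-unique for a sample.

    `canonicalize` maps a SMILES string to a canonical form, or None if invalid.
    """
    canonical = [canonicalize(s) for s in sampled]
    valid = [c for c in canonical if c is not None]
    unique = set(valid)
    novel = unique - set(training_set)
    return {
        "valid": 100.0 * len(valid) / len(sampled),
        "unique": 100.0 * len(unique) / len(valid) if valid else 0.0,
        "novel": 100.0 * len(novel) / len(unique) if unique else 0.0,
    }


def toy_canonicalize(smiles):
    """Hypothetical validity check: only tests that parentheses are balanced."""
    return smiles if smiles.count("(") == smiles.count(")") else None


metrics = generation_metrics(
    sampled=["CCO", "CCO", "CC(", "CCN"],
    training_set={"CCO"},
    canonicalize=toy_canonicalize,
)
```

<p>Canonicalizing before deduplication matters here because the heteroencoder is trained on randomized SMILES, so the same molecule can be decoded as several different strings.</p>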
<h3 id="moses-benchmark">MOSES Benchmark</h3>
<p>On the MOSES benchmark (trained on a ZINC subset of 1,584,663 compounds, sampled 30,000 SMILES), LatentGAN showed comparable or better results than JTN-VAE and AAE on Fréchet ChemNet Distance (FCD), Fragment similarity, and Scaffold similarity, while producing slightly worse nearest-neighbor cosine similarity (SNN). The standard VAE showed signs of mode collapse with high test metric overlap and low novelty.</p>
<h2 id="complementary-generation-and-drug-likeness-preservation">Complementary Generation and Drug-Likeness Preservation</h2>
<h3 id="key-findings">Key Findings</h3>
<p><strong>Validity and novelty</strong>: LatentGAN achieved 86-89% validity on target-biased tasks (lower than RNN&rsquo;s 96-97%) but produced higher uniqueness on two of three targets and comparable or higher novelty (95-98%).</p>
<p><strong>Complementary chemical space</strong>: The overlap between LatentGAN-generated and RNN-generated active compounds was very small at both compound and scaffold levels. A probabilistic analysis showed that the RNN model would be very unlikely to eventually cover the LatentGAN output space. This suggests the two architectures can work complementarily in de novo design campaigns.</p>
<p><strong>Drug-likeness</strong>: QED score distributions of LatentGAN-generated compounds closely matched training set distributions across all three targets, with training compounds showing only slightly higher drug-likeness. SA score distributions were similarly well-preserved.</p>
<p><strong>Chemical space coverage</strong>: PCA analysis using MQN fingerprints confirmed that generated compounds occupy most of the chemical space of the training sets. Some regions of the PCA plots contained compounds predicted as inactive, which corresponded to non-drug-like outliers in the training data.</p>
<p><strong>Novel scaffolds</strong>: About 14% of scaffolds in the sampled sets had similarity below 0.4 to the training set across all three targets, indicating LatentGAN can generate genuinely novel chemical scaffolds. Around 5% of generated compounds were identical to training set compounds, while 21-25% had Tanimoto similarity below 0.4.</p>
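<p>The similarity thresholds above refer to Tanimoto similarity between fingerprints. A minimal sketch, representing each fingerprint as a set of on-bit indices, is shown below; comparing each generated structure to its most similar training structure is an assumption about how the 0.4 cutoff was applied, and the toy bit sets are invented for illustration.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    if not a and not b:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a & b) / len(a | b)


def max_similarity_to_training(query, training):
    """Highest Tanimoto similarity between a query and any training fingerprint."""
    return max(tanimoto(query, t) for t in training)


# Toy fingerprints (hypothetical bit indices, not real molecules).
toy_training = [{1, 9}, {7, 8}]
toy_query = {1, 2, 3, 4, 5}
is_dissimilar = max_similarity_to_training(toy_query, toy_training) < 0.4
```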
<h3 id="limitations">Limitations</h3>
<p>The paper acknowledges several limitations. The 18-20% heteroencoder reconstruction error means a non-trivial fraction of encoded molecules decode to different structures. Validity rates (86-89%) are lower than RNN baselines (96-97%). The S1PR1 target showed notably lower uniqueness (31%) and predicted activity (44%) compared to the other targets, possibly due to the smaller effective training set of active compounds. The paper does not report specific hardware requirements or training times. No wet-lab experimental validation of generated compounds was performed.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors envision LatentGAN as a complementary tool to existing RNN-based generative models, with the two architectures covering different regions of chemical space. The approach of operating in learned latent space rather than directly on SMILES strings offers a general framework that could be extended to other molecular representations or generation objectives.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Heteroencoder training</td>
          <td>ChEMBL 25 (subset)</td>
          <td>1,347,173 SMILES</td>
          <td>Standardized with MolVS; atoms restricted to H, C, N, O, S, Cl, Br; max 50 heavy atoms</td>
      </tr>
      <tr>
          <td>General GAN training</td>
          <td>ChEMBL 25 (random subset)</td>
          <td>100,000</td>
          <td>Subset of heteroencoder training set</td>
      </tr>
      <tr>
          <td>Target-biased training</td>
          <td>ExCAPE-DB (EGFR)</td>
          <td>2,949 actives</td>
          <td>Clustered train/test split</td>
      </tr>
      <tr>
          <td>Target-biased training</td>
          <td>ExCAPE-DB (HTR1A)</td>
          <td>48,283 actives</td>
          <td>Clustered train/test split</td>
      </tr>
      <tr>
          <td>Target-biased training</td>
          <td>ExCAPE-DB (S1PR1)</td>
          <td>49,381 actives</td>
          <td>Clustered train/test split</td>
      </tr>
      <tr>
          <td>Benchmarking</td>
          <td>ZINC (MOSES subset)</td>
          <td>1,584,663</td>
          <td>Canonical SMILES</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Heteroencoder</strong>: Bidirectional LSTM encoder (2 layers, 512 units) + unidirectional LSTM decoder (4 layers), trained with teacher forcing and categorical cross-entropy for 100 epochs</li>
<li><strong>GAN</strong>: WGAN-GP with 5:1 critic-to-generator training ratio. General model trained 30,000 epochs; target models trained 10,000 epochs</li>
<li><strong>Evaluation</strong>: SVM classifiers with FCFP6 fingerprints (2048 bits) for activity prediction; MQN fingerprints for PCA-based chemical space analysis; Murcko scaffolds for scaffold-level analysis</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Heteroencoder: 512-dim latent space, bidirectional LSTM encoder, unidirectional LSTM decoder</li>
<li>Generator: 5 feed-forward layers of 256 dims with batch norm and leaky ReLU</li>
<li>Critic: 3 feed-forward layers of 256 dims with leaky ReLU</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>LatentGAN (EGFR)</th>
          <th>RNN Baseline (EGFR)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>86%</td>
          <td>96%</td>
          <td>Percent valid SMILES</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>56%</td>
          <td>46%</td>
          <td>Percent unique among valid</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>97%</td>
          <td>95%</td>
          <td>Not in training set</td>
      </tr>
      <tr>
          <td>Predicted active</td>
          <td>71%</td>
          <td>65%</td>
          <td>By SVM model</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Dierme/latent-gan">LatentGAN source code</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Includes trained heteroencoder model and training sets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Prykhodko, O., Johansson, S.V., Kotsias, P.-C., Arús-Pous, J., Bjerrum, E.J., Engkvist, O., &amp; Chen, H. (2019). A de novo molecular generation method using latent vector based generative adversarial network. <em>Journal of Cheminformatics</em>, 11(1), 74. <a href="https://doi.org/10.1186/s13321-019-0397-9">https://doi.org/10.1186/s13321-019-0397-9</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{prykhodko2019latentgan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A de novo molecular generation method using latent vector based generative adversarial network}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Prykhodko, Oleksii and Johansson, Simon Viet and Kotsias, Panagiotis-Christos and Ar{\&#39;u}s-Pous, Josep and Bjerrum, Esben Jannik and Engkvist, Ola and Chen, Hongming}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{74}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-019-0397-9}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Grammar VAE: Generating Valid Molecules via CFGs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/</guid><description>The Grammar VAE encodes and decodes molecular parse trees from context-free grammars, guaranteeing syntactically valid SMILES outputs during generation.</description><content:encoded><![CDATA[<h2 id="a-grammar-constrained-vae-for-discrete-data-generation">A Grammar-Constrained VAE for Discrete Data Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces the Grammar Variational Autoencoder (GVAE), a variant of the <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoder</a> that operates directly on parse trees from context-free grammars (CFGs) rather than on raw character sequences. The primary contribution is a decoding mechanism that uses a stack and grammar-derived masks to restrict the output at every timestep to only syntactically valid production rules. This guarantees that every decoded output is a valid string under the grammar, addressing a fundamental limitation of character-level VAEs when applied to structured discrete data such as <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> molecular strings and arithmetic expressions.</p>
<h2 id="why-character-level-vaes-fail-on-structured-discrete-data">Why Character-Level VAEs Fail on Structured Discrete Data</h2>
<p>Generative models for continuous data (images, audio) had achieved impressive results by 2017, but generating structured discrete data remained difficult. The key challenge is that string representations of molecules and mathematical expressions are brittle: small perturbations to a character sequence often produce invalid outputs. <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a> demonstrated a character-level VAE (CVAE) for SMILES strings that could encode molecules into a continuous latent space and decode them back, enabling latent-space optimization for molecular design. However, the CVAE frequently decoded latent points into strings that were not valid SMILES, particularly when exploring regions of latent space far from training data.</p>
<p>The fundamental issue is that character-level decoders must implicitly learn the syntactic rules of the target language from data alone. For SMILES, this includes matching parentheses, valid atom types, proper bonding, and ring closure notation. The GVAE addresses this by giving the decoder explicit knowledge of the grammar, so it can focus entirely on learning the semantic structure of the data.</p>
<h2 id="core-innovation-stack-based-grammar-masking-in-the-decoder">Core Innovation: Stack-Based Grammar Masking in the Decoder</h2>
<p>The GVAE encodes and decodes sequences of production rules from a context-free grammar rather than sequences of characters.</p>
<p><strong>Encoding.</strong> Given an input string (e.g., a SMILES molecule), the encoder first parses it into a parse tree using the CFG, then performs a left-to-right pre-order traversal of the tree to extract an ordered sequence of production rules. Each rule is represented as a one-hot vector of dimension $K$ (total number of production rules in the grammar). The resulting $T(\mathbf{X}) \times K$ matrix is processed by a convolutional neural network to produce the mean and variance of a Gaussian posterior $q_{\phi}(\mathbf{z} \mid \mathbf{X})$.</p>
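<p>The encoding step can be sketched as follows. This is a minimal illustration, assuming a toy four-rule arithmetic grammar and a hand-built parse tree; the names <code>RULES</code>, <code>preorder_rules</code>, and <code>one_hot</code> are hypothetical and the grammar is not the one used in the paper.</p>

```python
# Toy grammar (an assumption for illustration): each rule is (lhs, rhs).
RULES = [
    ("S", ("S", "+", "T")),  # rule 0
    ("S", ("T",)),           # rule 1
    ("T", ("x",)),           # rule 2
    ("T", ("1",)),           # rule 3
]


def preorder_rules(tree):
    """Return rule indices from a left-to-right pre-order traversal.

    A tree node is (rule_index, [one subtree per non-terminal on the RHS]).
    """
    rule, children = tree
    seq = [rule]
    for child in children:
        seq.extend(preorder_rules(child))
    return seq


def one_hot(seq, num_rules):
    """Build the T x K matrix with one one-hot row per production rule."""
    return [[1.0 if k == rule else 0.0 for k in range(num_rules)] for rule in seq]


# Parse tree for "x + 1": S -> S + T, then S -> T, T -> x, and finally T -> 1.
tree = (0, [(1, [(2, [])]), (3, [])])
seq = preorder_rules(tree)
X = one_hot(seq, len(RULES))
```

<p>In this sketch <code>X</code> plays the role of the $T(\mathbf{X}) \times K$ input matrix; the convolutional encoder that consumes it is omitted.</p>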
<p><strong>Decoding with grammar masks.</strong> The decoder maps a latent vector $\mathbf{z}$ through an RNN to produce a matrix of logits $\mathbf{F} \in \mathbb{R}^{T_{max} \times K}$. The key innovation is a last-in first-out (LIFO) stack that tracks the current parsing state. At each timestep $t$, the decoder:</p>
<ol>
<li>Pops the top non-terminal $\alpha$ from the stack</li>
<li>Applies a fixed binary mask $\mathbf{m}_{\alpha} \in \{0, 1\}^K$ that zeros out all production rules whose left-hand side is not $\alpha$</li>
<li>Samples a production rule from the masked softmax distribution:</li>
</ol>
<p>$$
p(\mathbf{x}_{t} = k \mid \alpha, \mathbf{z}) = \frac{m_{\alpha,k} \exp(f_{tk})}{\sum_{j=1}^{K} m_{\alpha,j} \exp(f_{tj})}
$$</p>
<ol start="4">
<li>Pushes the right-hand-side non-terminals of the selected rule onto the stack (right-to-left, so the leftmost is on top)</li>
</ol>
<p>This process continues until the stack is empty or $T_{max}$ timesteps are reached. Because the mask restricts selection to only those rules applicable to the current non-terminal, every generated sequence of production rules is guaranteed to be a valid derivation under the grammar.</p>
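<p>The four decoding steps can be sketched in a few lines. This is a minimal sketch, assuming a toy grammar and random logits standing in for the RNN output; <code>LHS</code>, <code>RHS</code>, <code>MASKS</code>, and <code>decode</code> are hypothetical names, not the authors&rsquo; implementation.</p>

```python
import numpy as np

# Toy grammar (hypothetical): rule k rewrites LHS[k] into RHS[k].
LHS = ["S", "S", "T", "T"]
RHS = [["S", "+", "T"], ["T"], ["x"], ["1"]]
NONTERMINALS = {"S", "T"}

# One fixed binary mask per non-terminal: 1 where the rule's left-hand side matches.
MASKS = {nt: np.array([lhs == nt for lhs in LHS]) for nt in NONTERMINALS}


def decode(logits, rng, start="S"):
    """Sample a rule sequence from logits of shape (T_max, K) using a LIFO stack."""
    stack, rules = [start], []
    for t in range(logits.shape[0]):
        if not stack:
            break
        alpha = stack.pop()                        # step 1: pop the top non-terminal
        masked = np.where(MASKS[alpha], logits[t], -np.inf)  # step 2: apply the mask
        weights = np.exp(masked - masked.max())    # masked-out rules get weight 0
        probs = weights / weights.sum()            # step 3: masked softmax
        k = int(rng.choice(len(LHS), p=probs))
        rules.append(k)
        # step 4: push RHS non-terminals right-to-left so the leftmost ends on top
        for sym in reversed(RHS[k]):
            if sym in NONTERMINALS:
                stack.append(sym)
    # The second value is True only if the derivation finished within T_max steps.
    return rules, not stack
```

<p>A call such as <code>decode(np.random.default_rng(0).normal(size=(30, 4)), np.random.default_rng(1))</code> returns a rule sequence plus a flag marking whether the stack emptied before $T_{max}$.</p>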
<p><strong>Training.</strong> The model is trained by maximizing the ELBO:</p>
<p>$$
\mathcal{L}(\phi, \theta; \mathbf{X}) = \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{X})} \left[ \log p_{\theta}(\mathbf{X}, \mathbf{z}) - \log q_{\phi}(\mathbf{z} \mid \mathbf{X}) \right]
$$</p>
<p>where the likelihood factorizes as:</p>
<p>$$
p_{\theta}(\mathbf{X} \mid \mathbf{z}) = \prod_{t=1}^{T(\mathbf{X})} p_{\theta}(\mathbf{x}_{t} \mid \mathbf{z})
$$</p>
<p>During training, the masks at each timestep are determined by the ground-truth production rule sequence, so no stack simulation is needed. The stack-based decoding is only required at generation time.</p>
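<p>Because the ground-truth rule at each timestep already names its left-hand side, the reconstruction term can be computed without a stack. The following is a minimal sketch under that reading; <code>masked_log_likelihood</code> and its arguments are hypothetical, and the KL term of the ELBO is omitted.</p>

```python
import numpy as np


def masked_log_likelihood(logits, rule_seq, lhs_of_rule):
    """Sum over t of log p(x_t | z) under the masked softmax.

    logits: (T, K) decoder outputs; rule_seq: ground-truth rule indices;
    lhs_of_rule: left-hand-side symbol of each of the K rules.
    """
    lhs = np.array(lhs_of_rule)
    total = 0.0
    for t, k in enumerate(rule_seq):
        mask = lhs == lhs[k]  # the true rule's LHS fixes the mask; no stack needed
        allowed = logits[t][mask]
        # Log of the masked softmax denominator, computed stably.
        log_norm = allowed.max() + np.log(np.exp(allowed - allowed.max()).sum())
        total += logits[t, k] - log_norm
    return float(total)
```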
<p><strong>Syntactic vs. semantic validity.</strong> The grammar guarantees syntactic validity but not semantic validity. The GVAE can still produce chemically implausible molecules (e.g., an oxygen atom with three bonds) because such constraints are not context-free. SMILES ring-bond digit matching is also not context-free, so the grammar cannot enforce it. Additionally, sequences that have not emptied the stack by $T_{max}$ are marked invalid.</p>
<h2 id="experiments-on-symbolic-regression-and-molecular-optimization">Experiments on Symbolic Regression and Molecular Optimization</h2>
<p>The authors evaluate the GVAE on two domains: arithmetic expressions and molecules. Both use Bayesian optimization (BO) over the learned latent space.</p>
<p><strong>Setup.</strong> After training each VAE, the authors encode training data into latent vectors and train a sparse Gaussian process (SGP) with 500 inducing points to predict properties from latent representations. They then run batch BO with expected improvement, selecting 50 candidates per iteration.</p>
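<p>The acquisition step can be illustrated with the standard closed form of expected improvement for a Gaussian predictive distribution. This is a sketch under assumptions: it maximizes the objective, takes predictive means and standard deviations as given (the sparse GP itself is not implemented), and picks the batch greedily by EI, which is a simplification of the batch selection used in the paper. The function names are hypothetical.</p>

```python
import math


def expected_improvement(mu, sigma, best):
    """Closed-form EI for maximization under a Gaussian predictive N(mu, sigma^2)."""
    if sigma <= 0.0:
        return max(mu - best, 0.0)
    z = (mu - best) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return (mu - best) * cdf + sigma * pdf


def top_candidates(means, stds, best, batch_size=50):
    """Return the indices of the batch_size candidates with the highest EI."""
    scores = [expected_improvement(m, s, best) for m, s in zip(means, stds)]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:batch_size]
```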
<h3 id="arithmetic-expressions">Arithmetic Expressions</h3>
<ul>
<li><strong>Data</strong>: 100,000 randomly generated univariate expressions from a simple grammar (3 binary operators, 2 unary operators, 3 constants), each with at most 15 production rules</li>
<li><strong>Target</strong>: Find an expression minimizing $\log(1 + \text{MSE})$ against the true function $1/3 + x + \sin(x \cdot x)$</li>
<li><strong>BO iterations</strong>: 5, averaged over 10 repetitions</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Fraction Valid</th>
          <th>Average Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GVAE</td>
          <td>0.99 +/- 0.01</td>
          <td>3.47 +/- 0.24</td>
      </tr>
      <tr>
          <td>CVAE</td>
          <td>0.86 +/- 0.06</td>
          <td>4.75 +/- 0.25</td>
      </tr>
  </tbody>
</table>
<p>The GVAE&rsquo;s best expression ($x/1 + \sin(3) + \sin(x \cdot x)$, score 0.04) differs from the true function only in its constant term ($\sin(3) \approx 0.14$ in place of $1/3$), while the CVAE&rsquo;s best ($x \cdot 1 + \sin(3) + \sin(3/1)$, score 0.39) misses the sinusoidal component.</p>
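<p>The scoring rule $\log(1 + \text{MSE})$ can be checked directly on the two reported expressions. The sketch below assumes an evaluation grid of 1000 evenly spaced points in $[-10, 10]$; that grid is an assumption of this example, and the helper names are hypothetical.</p>

```python
import math


def target(x):
    return 1.0 / 3.0 + x + math.sin(x * x)


def gvae_best(x):
    return x / 1.0 + math.sin(3.0) + math.sin(x * x)


def cvae_best(x):
    return x * 1.0 + math.sin(3.0) + math.sin(3.0 / 1.0)


def score(candidate, xs):
    """log(1 + MSE) between a candidate expression and the target on grid xs."""
    mse = sum((candidate(x) - target(x)) ** 2 for x in xs) / len(xs)
    return math.log(1.0 + mse)


# Assumed evaluation grid: 1000 evenly spaced points in [-10, 10].
xs = [-10.0 + 20.0 * i / 999 for i in range(1000)]
gvae_score = score(gvae_best, xs)
cvae_score = score(cvae_best, xs)
```

<p>The GVAE expression is off by a constant, so its score does not depend on the grid: it equals $\log(1 + (\sin 3 - 1/3)^2) \approx 0.036$, consistent with the reported 0.04.</p>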
<h3 id="molecular-optimization">Molecular Optimization</h3>
<ul>
<li><strong>Data</strong>: 250,000 SMILES strings from the ZINC database</li>
<li><strong>Target</strong>: Maximize penalized logP (water-octanol partition coefficient penalized for ring size and synthetic accessibility)</li>
<li><strong>BO iterations</strong>: 10, averaged over 5 trials</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Fraction Valid</th>
          <th>Average Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GVAE</td>
          <td>0.31 +/- 0.07</td>
          <td>-9.57 +/- 1.77</td>
      </tr>
      <tr>
          <td>CVAE</td>
          <td>0.17 +/- 0.05</td>
          <td>-54.66 +/- 2.66</td>
      </tr>
  </tbody>
</table>
<p>The GVAE produces roughly twice as many valid molecules as the CVAE and finds molecules with substantially better penalized logP scores (best: 2.94 vs. 1.98).</p>
<h3 id="latent-space-quality">Latent Space Quality</h3>
<p>Interpolation experiments show that the GVAE produces valid outputs at every intermediate point when linearly interpolating between two encoded expressions, while the CVAE passes through invalid strings. Grid searches around encoded molecules in the GVAE latent space show smooth transitions where neighboring points differ by single atoms.</p>
<h3 id="predictive-performance">Predictive Performance</h3>
<p>Sparse GP models trained on GVAE latent features achieve better test RMSE and log-likelihood than those trained on CVAE features for both expressions and molecules:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GVAE (Expressions)</th>
          <th>CVAE (Expressions)</th>
          <th>GVAE (Molecules)</th>
          <th>CVAE (Molecules)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Test LL</td>
          <td>-1.320 +/- 0.001</td>
          <td>-1.397 +/- 0.003</td>
          <td>-1.739 +/- 0.004</td>
          <td>-1.812 +/- 0.004</td>
      </tr>
      <tr>
          <td>Test RMSE</td>
          <td>0.884 +/- 0.002</td>
          <td>0.975 +/- 0.004</td>
          <td>1.404 +/- 0.006</td>
          <td>1.504 +/- 0.006</td>
      </tr>
  </tbody>
</table>
<h3 id="reconstruction-and-prior-sampling">Reconstruction and Prior Sampling</h3>
<p>On held-out molecules, the GVAE achieves 53.7% reconstruction accuracy vs. 44.6% for the CVAE. When sampling from the prior $p(\mathbf{z}) = \mathcal{N}(0, \mathbf{I})$, 7.2% of GVAE samples are valid molecules vs. 0.7% for the CVAE.</p>
<h2 id="key-findings-limitations-and-impact">Key Findings, Limitations, and Impact</h2>
<p><strong>Key findings.</strong> Incorporating grammar structure into the VAE decoder consistently improves validity rates, latent space smoothness, downstream predictive performance, and Bayesian optimization outcomes across both domains. The approach is general: any domain with a context-free grammar can benefit.</p>
<p><strong>Limitations acknowledged by the authors.</strong></p>
<ul>
<li>The GVAE guarantees syntactic but not semantic validity. For molecules, invalid ring-bond patterns and chemically implausible structures can still be generated.</li>
<li>The GVAE&rsquo;s molecular validity rate during BO (31%) is substantially higher than the CVAE&rsquo;s (17%), but most decoded molecules are still invalid, largely due to non-context-free constraints in SMILES.</li>
<li>The approach requires a context-free grammar for the target domain, which limits applicability to well-defined formal languages.</li>
<li>Sequences that do not complete parsing within $T_{max}$ timesteps are discarded as invalid.</li>
</ul>
<p><strong>Impact.</strong> The GVAE was an influential early contribution to constrained molecular generation. It directly inspired the Syntax-Directed VAE (SD-VAE) by Dai et al. (2018), which uses attribute grammars for tighter semantic constraints, and contributed to the broader movement toward structured molecular generation methods including graph-based approaches. The paper demonstrated that encoding domain knowledge into the decoder architecture is more effective than relying on the model to learn structural constraints from data alone.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (expressions)</td>
          <td>Generated arithmetic expressions</td>
          <td>100,000</td>
          <td>Up to 15 production rules each</td>
      </tr>
      <tr>
          <td>Training (molecules)</td>
          <td>ZINC database subset</td>
          <td>250,000 SMILES</td>
          <td>Same subset as <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a></td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder: 1D convolutional neural network over one-hot rule sequences</li>
<li>Decoder: RNN with stack-based grammar masking</li>
<li>Latent space: 56 dimensions (molecules), isotropic Gaussian prior</li>
<li>Property predictor: Sparse Gaussian process with 500 inducing points</li>
<li>Optimization: Batch Bayesian optimization with expected improvement, 50 candidates per iteration, Kriging Believer for batch selection</li>
</ul>
<h3 id="models">Models</h3>
<p>Architecture details follow <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016)</a> with modifications for grammar-based encoding/decoding. Specific layer sizes and hyperparameters are described in the supplementary material.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>GVAE</th>
          <th>CVAE</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fraction valid (expressions)</td>
          <td>0.99</td>
          <td>0.86</td>
          <td>During BO</td>
      </tr>
      <tr>
          <td>Fraction valid (molecules)</td>
          <td>0.31</td>
          <td>0.17</td>
          <td>During BO</td>
      </tr>
      <tr>
          <td>Best penalized logP</td>
          <td>2.94</td>
          <td>1.98</td>
          <td>Best molecule found</td>
      </tr>
      <tr>
          <td>Reconstruction accuracy</td>
          <td>53.7%</td>
          <td>44.6%</td>
          <td>On held-out molecules</td>
      </tr>
      <tr>
          <td>Prior validity</td>
          <td>7.2%</td>
          <td>0.7%</td>
          <td>Sampling from N(0,I)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mkusner/grammarVAE">grammarVAE</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kusner, M. J., Paige, B., &amp; Hernández-Lobato, J. M. (2017). Grammar Variational Autoencoder. <em>Proceedings of the 34th International Conference on Machine Learning (ICML)</em>, 1945-1954.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kusner2017grammar,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Grammar Variational Autoencoder}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kusner, Matt J. and Paige, Brooks and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 34th International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1945--1954}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DrugEx v2: Pareto Multi-Objective RL for Drug Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/drugex-v2-pareto-multi-objective-rl/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/drugex-v2-pareto-multi-objective-rl/</guid><description>DrugEx v2 extends RNN-based de novo drug design with Pareto ranking and evolutionary exploration for multi-objective molecule generation.</description><content:encoded><![CDATA[<h2 id="multi-objective-de-novo-drug-design-with-pareto-optimization">Multi-Objective De Novo Drug Design with Pareto Optimization</h2>
<p>This is a <strong>Method</strong> paper that extends the DrugEx framework (v1) to handle multi-objective optimization in de novo drug design. The primary contribution is integrating Pareto-based ranking with evolutionary algorithm concepts (crossover and mutation) into an RNN-based reinforcement learning pipeline. The system generates <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>-based molecules optimized simultaneously for activity toward multiple protein targets while avoiding off-targets, addressing polypharmacology scenarios where drugs must bind multiple specific receptors.</p>
<h2 id="polypharmacology-and-the-limits-of-single-objective-generation">Polypharmacology and the Limits of Single-Objective Generation</h2>
<p>Traditional drug discovery follows the &ldquo;one drug, one target, one disease&rdquo; paradigm, but drug molecules interact with an average of six protein targets. Off-target binding causes side effects that remain a leading cause of clinical failure and post-approval drug withdrawals (over 500 drugs withdrawn due to fatal toxicity). Complex diseases often require modulating multiple targets simultaneously, making polypharmacology an important design objective.</p>
<p>Prior deep learning approaches for de novo design, including DrugEx v1, focused on generating molecules active against a single target. Extending these methods to multiple objectives introduces fundamental challenges: objectives are often contradictory (high affinity for one target may correlate with high affinity for an undesired off-target), and naive weighted-sum approaches can collapse diversity by over-optimizing a single dominant objective. The authors specifically target the <a href="https://en.wikipedia.org/wiki/Adenosine_receptor">adenosine receptor</a> system, where $A_1AR$ and $A_{2A}AR$ selectivity profiles matter for therapeutic efficacy, and <a href="https://en.wikipedia.org/wiki/HERG">hERG</a> channel binding must be avoided to prevent cardiac toxicity.</p>
<h2 id="evolutionary-exploration-and-pareto-ranking-in-rl">Evolutionary Exploration and Pareto Ranking in RL</h2>
<p>The core innovation of DrugEx v2 has two components: an evolutionary exploration strategy and Pareto-based reward assignment.</p>
<h3 id="evolutionary-exploration-strategy">Evolutionary Exploration Strategy</h3>
<p>The generation process uses three RNN networks with identical LSTM architectures:</p>
<ul>
<li><strong>Agent net</strong> ($G_A$): the primary generator, updated at each training epoch via policy gradient</li>
<li><strong>Crossover net</strong> ($G_C$): initialized from the fine-tuned model, updated iteratively from $G_A$ after each convergence period</li>
<li><strong>Mutation net</strong> ($G_M$): initialized from the pre-trained model, parameters fixed throughout training</li>
</ul>
<p>At each token-generation step, a random number determines whether the token probability comes from the combination of $G_A$ and $G_C$ (with probability $1 - \varepsilon$) or from $G_M$ (with probability $\varepsilon$). This mirrors crossover and mutation operations from evolutionary algorithms, maintaining diversity while steering toward desired properties.</p>
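<p>The per-token choice between networks can be sketched as follows. This is a minimal sketch: the three probability vectors stand in for the outputs of $G_A$, $G_C$, and $G_M$, and averaging $G_A$ with $G_C$ is an assumption of this example, since the text above only says the two are combined. <code>sample_token</code> is a hypothetical name.</p>

```python
import random


def sample_token(p_agent, p_cross, p_mut, epsilon, rng):
    """Pick the next token index from three per-token probability vectors.

    With probability epsilon the mutation net is used alone; otherwise the
    agent and crossover nets are combined. Averaging them is an assumption.
    """
    if rng.random() < epsilon:
        probs = p_mut
    else:
        probs = [(a + c) / 2.0 for a, c in zip(p_agent, p_cross)]
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

<p>Setting <code>epsilon=0</code> disables the mutation net entirely, which corresponds to the zero mutation rate among the values the authors tested.</p>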
<h3 id="pareto-front-reward-scheme">Pareto Front Reward Scheme</h3>
<p>For $n$ objectives (three in this study: $A_1AR$, $A_{2A}AR$, hERG), each molecule receives a score $R_i$ based on its predicted bioactivity:</p>
<p>$$
R_{i} = \begin{cases} \text{minmax}(pX_{i}), &amp; \text{if high affinity required} \\ 1 - \text{minmax}(pX_{i}), &amp; \text{if low affinity required} \\ 0, &amp; \text{if SMILES invalid} \end{cases}
$$</p>
<p>where $pX_i$ is the predicted bioactivity (range 3.0 to 10.0), normalized to [0, 1].</p>
<p>For the multi-target case, high affinity is required for both $A_1AR$ and $A_{2A}AR$ while low affinity is required for hERG. For the target-specific case, high affinity is required only for $A_{2A}AR$ while low affinity is required for both $A_1AR$ and hERG.</p>
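<p>The per-objective score $R_i$ can be written out as a short function. This sketch assumes min-max scaling over the stated pX range of 3.0 to 10.0 with clipping of out-of-range predictions; the clipping and the function names are assumptions of this example.</p>

```python
def minmax(px, low=3.0, high=10.0):
    """Scale a predicted pX value from [low, high] to [0, 1], clipping outliers."""
    return min(max((px - low) / (high - low), 0.0), 1.0)


def objective_score(px, want_high, valid=True):
    """R_i: scaled pX if high affinity is wanted, its complement otherwise."""
    if not valid:
        return 0.0  # invalid SMILES score zero on every objective
    scaled = minmax(px)
    return scaled if want_high else 1.0 - scaled


def multi_target_scores(px_a1, px_a2a, px_herg, valid=True):
    """Multi-target case: high A1AR, high A2AAR, low hERG."""
    return (
        objective_score(px_a1, True, valid),
        objective_score(px_a2a, True, valid),
        objective_score(px_herg, False, valid),
    )
```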
<p>Molecules are ranked using a <a href="https://en.wikipedia.org/wiki/Multi-objective_optimization">non-dominated sorting</a> algorithm to construct Pareto fronts. Within each front, molecules are ranked by average Tanimoto distance (using ECFP6 fingerprints) rather than crowding distance, favoring chemically diverse solutions. The final reward is:</p>
<p>$$
R_i^{*} = \begin{cases} 0.5 + \frac{k - N_{undesired}}{2N_{desired}}, &amp; \text{if desired} \\ \frac{k}{2N_{undesired}}, &amp; \text{if undesired} \end{cases}
$$</p>
<p>where $k$ is the molecule&rsquo;s index in the Pareto rank. Rewards for undesired and desired solutions are distributed in $(0, 0.5]$ and $(0.5, 1.0]$, respectively.</p>
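<p>The ranking and reward assignment can be sketched in plain Python. This is a simplified sketch, assuming all objectives are maximized, that $k$ counts from the worst-ranked molecule upward, and that undesired molecules occupy the lowest ranks; ordering within a front is left as input order, so the Tanimoto-distance step is omitted. All function names are hypothetical.</p>

```python
def dominates(a, b):
    """True if a is at least as good as b on every objective and better on one."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def pareto_fronts(scores):
    """Peel off successive non-dominated fronts, best front first."""
    remaining = list(range(len(scores)))
    fronts = []
    while remaining:
        front = [
            i for i in remaining
            if not any(dominates(scores[j], scores[i]) for j in remaining if j != i)
        ]
        fronts.append(front)
        remaining = [i for i in remaining if i not in front]
    return fronts


def pareto_rewards(scores, desired):
    """Rewards in (0, 0.5] for undesired molecules and (0.5, 1.0] for desired ones."""
    # Order molecules from the worst front to the best front.
    worst_to_best = [i for front in reversed(pareto_fronts(scores)) for i in front]
    # Undesired molecules take the low ranks, desired ones the high ranks.
    order = [i for i in worst_to_best if not desired[i]]
    order += [i for i in worst_to_best if desired[i]]
    n_und = sum(1 for d in desired if not d)
    n_des = len(desired) - n_und
    rewards = [0.0] * len(scores)
    for k, i in enumerate(order, start=1):
        if desired[i]:
            rewards[i] = 0.5 + (k - n_und) / (2 * n_des)
        else:
            rewards[i] = k / (2 * n_und)
    return rewards
```

<p>For four molecules with scores <code>[(0.9, 0.9), (0.2, 0.1), (0.5, 0.4), (0.8, 0.95)]</code> and desired flags <code>[True, False, False, True]</code>, the two undesired molecules receive 0.25 and 0.5 and the two desired ones 0.75 and 1.0.</p>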
<p>The agent is trained via policy gradient:</p>
<p>$$
J(\theta) = \mathbb{E}\left[R^{*}(y_{1:T}) \middle|\theta\right] = \sum_{t=1}^{T} \log G(y_t | y_{1:t-1}) \cdot R^{*}(y_{1:T})
$$</p>
<h3 id="weighted-sum-alternative">Weighted Sum Alternative</h3>
<p>The authors also implement a weighted sum (WS) scheme with dynamic weights proportional to the ratio of undesired to desired molecules per objective:</p>
<p>$$
w_i = \frac{r_i}{\sum_{k=1}^{n} r_k}, \quad R^{*} = \sum_{i=1}^{n} w_i R_i
$$</p>
<p>This auto-adjusts importance toward under-performing objectives during training.</p>
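<p>A small numeric sketch of the dynamic weighting, assuming $r_i$ is the count of undesired molecules divided by the count of desired molecules for objective $i$, as described above. The guard against a zero desired count and the function names are assumptions of this example.</p>

```python
def dynamic_weights(n_undesired, n_desired):
    """w_i proportional to r_i = undesired / desired count for objective i."""
    # max(d, 1) avoids division by zero; this guard is an assumption.
    ratios = [u / max(d, 1) for u, d in zip(n_undesired, n_desired)]
    total = sum(ratios)
    return [r / total for r in ratios]


def weighted_reward(weights, scores):
    """R* as the weighted sum of the per-objective scores R_i."""
    return sum(w * s for w, s in zip(weights, scores))
```

<p>With undesired counts <code>[30, 10, 10]</code> and desired counts <code>[10, 10, 10]</code>, the ratios are 3, 1, and 1, so the first objective receives weight 0.6 and the other two 0.2 each.</p>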
<h3 id="molecular-diversity-metric">Molecular Diversity Metric</h3>
<p>Diversity is measured using the Solow-Polasky metric adapted from ecological biodiversity:</p>
<p>$$
I(A) = \frac{1}{|A|} \mathbf{e}^{\top} F(\mathbf{s})^{-1} \mathbf{e}
$$</p>
<p>where $F(\mathbf{s})$ is a distance matrix with entries $f(d_{ij}) = e^{-\theta d_{ij}}$ and $d_{ij}$ is the Tanimoto distance between ECFP6 fingerprints of molecules $s_i$ and $s_j$.</p>
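<p>The diversity formula can be evaluated with a linear solve. This is a minimal sketch that takes a precomputed pairwise distance matrix as input (fingerprint computation is omitted); the default <code>theta=0.5</code> is an assumption of this example and <code>solow_polasky</code> is a hypothetical name.</p>

```python
import numpy as np


def solow_polasky(dist, theta=0.5):
    """I(A) = (1/|A|) e^T F^{-1} e with F_ij = exp(-theta * d_ij)."""
    dist = np.asarray(dist, dtype=float)
    n = dist.shape[0]
    F = np.exp(-theta * dist)
    ones = np.ones(n)
    # Solve F y = e instead of forming the inverse explicitly.
    return float(ones @ np.linalg.solve(F, ones)) / n
```

<p>For two molecules at distance $d$ the value reduces to $1 / (1 + e^{-\theta d})$, which approaches 1 as the pair becomes more dissimilar. Exact duplicates make $F$ singular, so they need to be removed before calling this sketch.</p>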
<h2 id="multi-target-and-target-specific-experiments">Multi-Target and Target-Specific Experiments</h2>
<h3 id="qsar-environment">QSAR Environment</h3>
<p>Four ML algorithms were benchmarked for the bioactivity prediction environment: Random Forest (RF), SVM, PLS, and Multi-task DNN (MT-DNN). Input features combined 2048-bit ECFP6 fingerprints with 19 physicochemical descriptors (2067D total). The training data came from ChEMBL v26: 25,731 ligands with bioactivity measurements toward $A_1AR$, $A_{2A}AR$, and hERG. RF was selected as the final predictor based on superior performance in temporal-split independent testing ($R^2$ and RMSE), prioritizing robustness over cross-validation metrics.</p>
<h3 id="generative-model-architecture">Generative Model Architecture</h3>
<p>The RNN generator uses six layers: input, embedding (128D), three LSTM recurrent layers (512 hidden units), and output. LSTM was chosen over GRU based on higher valid SMILES rates (97.5% vs. 93.1% for pre-trained, 97.9% vs. 95.7% for fine-tuned). Pre-training used 1.7M molecules from ChEMBL; fine-tuning used the 25,731 LIGAND set molecules.</p>
<h3 id="baselines">Baselines</h3>
<p>DrugEx v2 was compared against DrugEx v1, <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a>, and <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGANIC</a>, all using the same RNN architecture and pre-trained/fine-tuned models, with only the RL framework differing. Both Pareto front (PF) and weighted sum (WS) reward schemes were tested.</p>
<h3 id="multi-target-results">Multi-Target Results</h3>
<p>In the multi-target case (high affinity for $A_1AR$ and $A_{2A}AR$, low affinity for hERG):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Scheme</th>
          <th>Validity</th>
          <th>Desirability</th>
          <th>Uniqueness</th>
          <th>Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DrugEx v2</td>
          <td>PF</td>
          <td>99.57%</td>
          <td>80.81%</td>
          <td>87.29%</td>
          <td>0.70</td>
      </tr>
      <tr>
          <td>DrugEx v2</td>
          <td>WS</td>
          <td>99.80%</td>
          <td><strong>97.45%</strong></td>
          <td>89.08%</td>
          <td>0.49</td>
      </tr>
      <tr>
          <td>REINVENT</td>
          <td>PF</td>
          <td>99.54%</td>
          <td>57.43%</td>
          <td><strong>98.84%</strong></td>
          <td><strong>0.77</strong></td>
      </tr>
      <tr>
          <td>ORGANIC</td>
          <td>PF</td>
          <td>98.84%</td>
          <td>66.01%</td>
          <td>82.67%</td>
          <td>0.65</td>
      </tr>
      <tr>
          <td>DrugEx v1</td>
          <td>PF</td>
          <td>98.28%</td>
          <td>43.27%</td>
          <td>88.96%</td>
          <td>0.71</td>
      </tr>
  </tbody>
</table>
<p>DrugEx v2 achieved the highest desirability under both schemes. The WS scheme maximized desirability (97.45%) but at the cost of diversity (0.49). The PF scheme maintained higher diversity (0.70) with still-strong desirability (80.81%).</p>
<h3 id="target-specific-results">Target-Specific Results</h3>
<p>In the target-specific case (high $A_{2A}AR$, low $A_1AR$ and hERG):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Scheme</th>
          <th>Validity</th>
          <th>Desirability</th>
          <th>Uniqueness</th>
          <th>Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DrugEx v2</td>
          <td>PF</td>
          <td>99.53%</td>
          <td><strong>89.49%</strong></td>
          <td>90.55%</td>
          <td>0.73</td>
      </tr>
      <tr>
          <td>DrugEx v2</td>
          <td>WS</td>
          <td>99.62%</td>
          <td><strong>97.86%</strong></td>
          <td>90.54%</td>
          <td>0.31</td>
      </tr>
      <tr>
          <td>REINVENT</td>
          <td>WS</td>
          <td>99.55%</td>
          <td>81.27%</td>
          <td>98.87%</td>
          <td>0.34</td>
      </tr>
      <tr>
          <td>ORGANIC</td>
          <td>PF</td>
          <td>98.29%</td>
          <td>86.98%</td>
          <td>80.30%</td>
          <td>0.64</td>
      </tr>
  </tbody>
</table>
<p>DrugEx v2 with PF achieved high desirability (89.49%) while maintaining diversity (0.73). The WS scheme reached higher desirability (97.86%), but its diversity collapsed to 0.31. DrugEx v2 with PF also exceeded REINVENT and ORGANIC on both desirability and diversity.</p>
<h3 id="chemical-space-coverage">Chemical Space Coverage</h3>
<p>t-SNE visualization with ECFP6 descriptors showed that the PF scheme guided generators to cover chemical space more broadly than the WS scheme. DrugEx v1 and v2 covered nearly all of the chemical space occupied by known active ligands, while REINVENT and ORGANIC covered only partial regions in the target-specific case.</p>
<h3 id="substructure-distribution">Substructure Distribution</h3>
<p>Generated molecules were evaluated for purine ring, furan ring, and benzene ring frequencies. DrugEx v2 with PF produced substructure distributions closest to the LIGAND set, suggesting it better preserves the chemical characteristics of known active molecules compared to REINVENT (which over-represented benzene rings) and ORGANIC.</p>
<h3 id="guacamol-benchmark">GuacaMol Benchmark</h3>
<p>DrugEx v2 was tested on 20 goal-directed tasks from the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> benchmark, achieving the best score in 12 of 20 tasks and an overall second place. The method struggled with tasks requiring contradictory objectives in narrow chemical spaces (e.g., the Sitagliptin MPO task), reflecting its emphasis on diverse feasible molecules rather than optimal individual solutions.</p>
<h2 id="diversity-desirability-trade-off-and-limitations">Diversity-Desirability Trade-off and Limitations</h2>
<p>The key finding is that the Pareto front scheme and weighted sum scheme offer complementary strengths: PF produces molecules with higher diversity and more realistic substructure distributions, while WS achieves higher raw desirability scores. The Pareto front scheme is preferred for polypharmacology applications where chemical diversity matters for lead optimization.</p>
<p>The mutation rate $\varepsilon$ controls the diversity-desirability trade-off. Higher $\varepsilon$ increases diversity at the cost of desirability. The authors tested $\varepsilon \in \{10^{-2}, 10^{-3}, 10^{-4}, 0\}$ and found that appropriate tuning is important.</p>
<p>Limitations acknowledged by the authors include:</p>
<ul>
<li>The method is less effective for tasks with contradictory objectives in narrow chemical spaces</li>
<li>Emphasis is on generating diverse feasible molecules rather than individual optimal solutions</li>
<li>REINVENT 2.0 did not converge with the PF scheme, suggesting the Pareto approach may not be universally compatible with all RL frameworks</li>
<li>Bioactivity predictions rely on QSAR models (RF), which may not generalize perfectly to novel chemical scaffolds</li>
</ul>
<p>Future directions mentioned include adopting newer architectures (BERT, Transformer, GPT-2), handling graph and fragment representations, and integrating additional objectives like stability and synthesizability.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL v26 (ChEMBL set)</td>
          <td>1.7M molecules</td>
          <td>SMILES syntax learning, drug-like molecules</td>
      </tr>
      <tr>
          <td>Fine-tuning / Environment</td>
          <td>LIGAND set</td>
          <td>25,731 ligands</td>
          <td>Bioactivities for $A_1AR$, $A_{2A}AR$, hERG from ChEMBL</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>GuacaMol</td>
          <td>20 tasks</td>
          <td>Goal-directed generation tasks</td>
      </tr>
  </tbody>
</table>
<p>Active/inactive thresholds: $pX \geq 6.5$ (active), $pX &lt; 6.5$ (inactive). Low-quality data points without an exact pX value were assigned $pX = 3.99$ with a sample weight of 0.1.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>QSAR predictor</strong>: Random Forest, 1000 trees, Gini criterion. Input: 2048-bit ECFP6 + 19 physicochemical properties (2067D). MinMax normalization.</li>
<li><strong>Generator</strong>: 6-layer RNN with LSTM cells (512 hidden units), embedding dim 128, vocabulary 84 tokens. Adam optimizer, lr $10^{-3}$, batch size 512, 1000 epochs.</li>
<li><strong>RL training</strong>: Policy gradient with Pareto-based or weighted-sum reward. Mutation rates tested: $\varepsilon \in \{10^{-2}, 10^{-3}, 10^{-4}, 0\}$.</li>
<li><strong>Pareto ranking</strong>: GPU-accelerated non-dominated sorting via PyTorch. Tanimoto-based crowding distance with ECFP6 fingerprints.</li>
</ul>
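<p>The Pareto ranking step listed above can be sketched in plain Python. This is a minimal, unoptimized illustration of non-dominated sorting, not the GPU-accelerated PyTorch routine the paper describes; it assumes every objective has been scaled so that higher is better, and it omits the Tanimoto-based crowding distance. The function names are hypothetical.</p>

```python
def dominates(a, b):
    """True if a is at least as good as b on every objective and better on one."""
    return all(x >= y for x, y in zip(a, b)) and any(x > y for x, y in zip(a, b))


def pareto_fronts(scores):
    """Group indices into successive non-dominated fronts (front 0 is the best)."""
    remaining = set(range(len(scores)))
    fronts = []
    while remaining:
        # A point belongs to the current front if nothing left dominates it.
        front = sorted(
            i for i in remaining
            if not any(dominates(scores[j], scores[i]) for j in remaining if j != i)
        )
        fronts.append(front)
        remaining -= set(front)
    return fronts


# Toy two-objective scores, both to be maximized.
scores = [(0.9, 0.2), (0.5, 0.5), (0.4, 0.4), (0.2, 0.9), (0.1, 0.1)]
fronts = pareto_fronts(scores)
```

<p>In this toy example, molecules 0, 1, and 3 trade off the two objectives without any one dominating another, so they share the first front; a reward scheme can then rank molecules by front index.</p>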
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Generator</td>
          <td>LSTM (3 layers, 512 hidden)</td>
          <td>Embedding 128D, vocab 84</td>
      </tr>
      <tr>
          <td>Predictor</td>
          <td>Random Forest</td>
          <td>1000 trees, 2067D input</td>
      </tr>
      <tr>
          <td>MT-DNN (alternative)</td>
          <td>3 hidden layers (4000, 2000, 1000)</td>
          <td>ReLU, 20% dropout</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>Fraction of generated SMILES that parse to valid molecules</td>
      </tr>
      <tr>
          <td>Desirability</td>
          <td>Fraction of molecules meeting all activity thresholds ($pX \geq 6.5$ on-targets, $pX &lt; 6.5$ off-targets)</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>Fraction of non-duplicate molecules</td>
      </tr>
      <tr>
          <td>Diversity</td>
          <td>Solow-Polasky metric on ECFP6 Tanimoto distances</td>
      </tr>
      <tr>
          <td>SA score</td>
          <td>Synthetic accessibility (1-10, lower is easier)</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>Quantitative estimate of drug-likeness (0-1, higher is better)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>GPU acceleration was used for Pareto optimization via PyTorch. Specific hardware details (GPU model, training time) are not reported in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XuhanLiu/DrugEx">DrugEx GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation (Python, PyTorch)</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL v26</a></td>
          <td>Dataset</td>
          <td>CC BY-SA 3.0</td>
          <td>Source of training molecules and bioactivity data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, X., Ye, K., van Vlijmen, H. W. T., Emmerich, M. T. M., IJzerman, A. P., &amp; van Westen, G. J. P. (2021). DrugEx v2: de novo design of drug molecules by Pareto-based multi-objective reinforcement learning in polypharmacology. <em>Journal of Cheminformatics</em>, 13(1), 85. <a href="https://doi.org/10.1186/s13321-021-00561-9">https://doi.org/10.1186/s13321-021-00561-9</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{liu2021drugex,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DrugEx v2: de novo design of drug molecules by Pareto-based multi-objective reinforcement learning in polypharmacology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Liu, Xuhan and Ye, Kai and van Vlijmen, Herman W. T. and Emmerich, Michael T. M. and IJzerman, Adriaan P. and van Westen, Gerard J. P.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{85}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-021-00561-9}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DrugChat: Conversational QA on Drug Molecule Graphs</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugchat-chatgpt-drug-molecule-graphs/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugchat-chatgpt-drug-molecule-graphs/</guid><description>DrugChat connects a GNN molecular encoder with Vicuna-13B via a linear adaptor, enabling multi-turn conversational QA about drug compound graphs.</description><content:encoded><![CDATA[<h2 id="a-prototype-for-conversational-drug-compound-analysis">A Prototype for Conversational Drug Compound Analysis</h2>
<p><strong>Method ($\Psi_{\text{Method}}$)</strong></p>
<p>DrugChat is a prototype system that enables ChatGPT-like conversational interaction with drug molecule graphs. Users upload a compound&rsquo;s molecular graph and ask free-form, multi-turn questions about its properties, mechanism of action, or therapeutic applications. The system generates natural language answers by combining a graph neural network (GNN) encoder, a large language model (LLM), and a lightweight linear adaptor that bridges the two modalities. The primary contribution is the architecture and the accompanying instruction tuning datasets (10,834 drug compounds, 143,517 QA pairs) that make this graph-to-language interaction possible.</p>
<h2 id="why-conversational-interfaces-for-drug-molecules">Why Conversational Interfaces for Drug Molecules?</h2>
<p>Drug discovery is time-intensive and expensive, often requiring years and billions of dollars to bring a single compound to market. Traditional computational chemistry tools provide specialized outputs but lack the ability to support open-ended, interactive exploration of molecular properties. Researchers working with drug compound data frequently need quick answers to diverse questions: What is the mechanism of action? Are there known drug interactions? What structural modifications could improve efficacy?</p>
<p>At the time of this work, large language models had demonstrated strong conversational capabilities for text, and multimodal extensions (MiniGPT-4, LLaVA) had connected vision encoders to LLMs. However, no system had bridged graph-structured molecular data with LLMs for interactive dialogue. DrugChat addresses this gap by proposing the first system (to the authors&rsquo; knowledge) that connects molecular graph representations directly to an LLM for multi-turn question answering.</p>
<h2 id="architecture-gnn-adaptor-llm-pipeline">Architecture: GNN-Adaptor-LLM Pipeline</h2>
<p>The core innovation is the three-component architecture and its training strategy:</p>
<p><strong>Graph Neural Network (GNN)</strong>: A pre-trained GNN from Hu et al. (2020) processes the compound&rsquo;s molecular graph. At each layer $k$, node representations are updated by aggregating features from neighboring nodes:</p>
<p>$$
h_{v}^{k} = \sigma\left(h_{v}^{k-1}, \text{AGG}\left(\left\{h_{u}^{k-1}, u \in \mathcal{N}(v)\right\}\right)\right)
$$</p>
<p>A permutation-invariant pooling function produces the graph-level representation:</p>
<p>$$
h_{G} = f\left(\left\{h_{v}^{K}, v \in G\right\}\right)
$$</p>
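<p>The two equations above can be made concrete with a toy example. The sketch below assumes mean aggregation for AGG, an additive combine step followed by ReLU for $\sigma$, and mean pooling for $f$; these specific choices are illustrative assumptions and are not necessarily those of the pre-trained GNN used in DrugChat.</p>

```python
def message_passing_layer(h, neighbors):
    """One update round: combine each node state with the mean of its neighbors."""
    updated = {}
    for v, h_v in h.items():
        nbrs = neighbors.get(v, [])
        dim = len(h_v)
        if nbrs:
            agg = [sum(h[u][d] for u in nbrs) / len(nbrs) for d in range(dim)]
        else:
            agg = [0.0] * dim
        # Illustrative sigma: add self state and aggregate, then apply ReLU.
        updated[v] = [max(0.0, a + b) for a, b in zip(h_v, agg)]
    return updated


def readout(h):
    """Permutation-invariant pooling f: mean over all node vectors."""
    nodes = list(h.values())
    dim = len(nodes[0])
    return [sum(n[d] for n in nodes) / len(nodes) for d in range(dim)]


# Toy 3-node path graph 0 - 1 - 2 with 2-dimensional node features.
h0 = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h1 = message_passing_layer(h0, neighbors)
h_G = readout(h1)
```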
<p><strong>Linear Adaptor</strong>: A single linear transformation matrix converts the GNN graph representation into a soft prompt vector compatible with the LLM&rsquo;s input space. This is the only component whose weights are updated during training.</p>
<p><strong>Large Language Model (Vicuna-13B)</strong>: The pre-trained Vicuna-13B model takes the transformed graph prompt vector along with user questions and generates answers. Both the GNN and LLM weights remain frozen during training.</p>
<p>The prompt template follows the Vicuna conversational format:</p>
<p>$$
\mathbf{Q}: \langle\text{Graph}\rangle\langle\text{GraphFeature}\rangle\langle/\text{Graph}\rangle\langle\text{Instruction}\rangle \quad \mathbf{A}: \langle\text{Desc}\rangle
$$</p>
<p>During training, the system minimizes a negative log-likelihood loss between generated and ground-truth answers. The entire training procedure updates only the adaptor&rsquo;s parameters, making the approach computationally lightweight compared to full fine-tuning.</p>
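<p>To show how small the trainable part is, here is a minimal sketch of a linear adaptor as a single weight matrix with no bias. The class name, the toy dimensions (4 and 8), and the random initialization are assumptions for illustration only; the paper does not report the actual dimensions.</p>

```python
import random


class LinearAdaptor:
    """Maps a graph embedding to one soft-prompt vector in the LLM input space."""

    def __init__(self, graph_dim, llm_dim, seed=0):
        rng = random.Random(seed)
        # The only trainable parameters in this sketch: an llm_dim x graph_dim matrix.
        self.weight = [[rng.uniform(-0.1, 0.1) for _ in range(graph_dim)]
                       for _ in range(llm_dim)]

    def __call__(self, h_graph):
        # Plain matrix-vector product; the GNN and LLM would stay frozen around it.
        return [sum(w * x for w, x in zip(row, h_graph)) for row in self.weight]

    def num_parameters(self):
        return sum(len(row) for row in self.weight)


adaptor = LinearAdaptor(graph_dim=4, llm_dim=8)
soft_prompt = adaptor([1.0, 0.0, 0.0, 0.0])
```

<p>In a real system the resulting vector would be placed at the graph-feature position of the prompt, and gradients from the loss would flow only into the adaptor matrix.</p>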
<h2 id="instruction-tuning-datasets-from-chembl-and-pubchem">Instruction Tuning Datasets from ChEMBL and PubChem</h2>
<p>The authors constructed two instruction tuning datasets:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Drug Compounds</th>
          <th>QA Pairs</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChEMBL</td>
          <td>3,892</td>
          <td>129,699</td>
          <td>ChEMBL database (Feb 2023)</td>
      </tr>
      <tr>
          <td>PubChem</td>
          <td>6,942</td>
          <td>13,818</td>
          <td>PubChem (May 2023)</td>
      </tr>
      <tr>
          <td><strong>Total</strong></td>
          <td><strong>10,834</strong></td>
          <td><strong>143,517</strong></td>
          <td></td>
      </tr>
  </tbody>
</table>
<p><strong>ChEMBL Dataset</strong>: Starting from 2,354,965 compounds in <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, the authors identified 14,816 with drug information and filtered to 3,892 with sufficient descriptive content. For each drug, they gathered <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, molecular features (formula, acid/base classification), and drug-specific properties (mechanism of action, therapeutic applications). They manually crafted QA pairs covering topics like rotatable bond count, <a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski rule</a> violations, <a href="https://en.wikipedia.org/wiki/Chirality_(chemistry)">chirality</a>, <a href="https://en.wikipedia.org/wiki/Polar_surface_area">polar surface area</a>, development stage, approval year, and <a href="https://en.wikipedia.org/wiki/United_States_Adopted_Name">USAN</a> classification.</p>
<p><strong>PubChem Dataset</strong>: From 66,469,244 compounds in <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>, 19,319 had drug information, and 6,942 were retained after filtering for detailed descriptions. Descriptions were sourced from <a href="https://en.wikipedia.org/wiki/ChEBI">ChEBI</a>, LOTUS, and YMDB databases, yielding 13,818 QA pairs primarily asking for drug descriptions.</p>
<p>The QA pairs are formulaic: the ChEMBL set covers up to 34 question types per drug (an example drug in the paper shows all 34), while PubChem questions ask for descriptive summaries from different source databases.</p>
<h2 id="qualitative-demonstrations-only">Qualitative Demonstrations Only</h2>
<p>The paper presents only qualitative results. Two demonstration examples show DrugChat answering multi-turn questions about test compounds not seen during training. Questions like &ldquo;what makes this compound unique?&rdquo; and &ldquo;what diseases can this compound potentially treat?&rdquo; are answered in natural language.</p>
<p>No systematic quantitative evaluation is reported. The authors state they &ldquo;will perform a systematic quantitative evaluation by collaborating with pharmaceutical scientists,&rdquo; but this evaluation is not included in the technical report.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>The authors identify <strong>language hallucination</strong> as the primary limitation. Since DrugChat incorporates an LLM, it may produce convincing but incorrect text descriptions about drugs, which could mislead decision-makers in real drug discovery pipelines.</p>
<p>Proposed mitigations include:</p>
<ul>
<li>Higher-quality training data and filtering strategies</li>
<li>More advanced GNN encoders and LLMs</li>
<li>Reinforcement learning from human feedback (RLHF) as the user base grows</li>
</ul>
<p>Several additional limitations are worth noting:</p>
<ul>
<li>The QA pairs are largely factoid-style questions with short, formulaic answers, which may not capture the nuanced reasoning needed for real drug discovery tasks</li>
<li>The evaluation is entirely qualitative, with no comparison to baselines or quantitative metrics</li>
<li>The linear adaptor is a minimal alignment mechanism; it remains unclear how much molecular structural information is preserved through this single linear transformation</li>
<li>The training data covers only a small fraction of known chemical space (10,834 compounds out of millions)</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL Drug Instruction Tuning</td>
          <td>3,892 drugs, 129,699 QA pairs</td>
          <td>From ChEMBL (Feb 2023 dump)</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>PubChem Drug Instruction Tuning</td>
          <td>6,942 drugs, 13,818 QA pairs</td>
          <td>From PubChem (May 2023)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>GNN</strong>: Pre-trained model from Hu et al. (2020), &ldquo;Strategies for Pre-training Graph Neural Networks&rdquo;</li>
<li><strong>Adaptor</strong>: Single linear transformation matrix (only trainable component)</li>
<li><strong>Loss</strong>: Negative log-likelihood between generated and ground-truth answers</li>
<li><strong>Training</strong>: Only adaptor weights updated; GNN and LLM weights frozen</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Model</th>
          <th>Parameters</th>
          <th>Status</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GNN Encoder</td>
          <td>Pre-trained GNN (Hu et al., 2020)</td>
          <td>Not specified</td>
          <td>Frozen during training</td>
      </tr>
      <tr>
          <td>LLM</td>
          <td>Vicuna-13B</td>
          <td>~13B</td>
          <td>Frozen during training</td>
      </tr>
      <tr>
          <td>Adaptor</td>
          <td>Linear projection</td>
          <td>Not specified</td>
          <td>Trained</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>No quantitative evaluation metrics are reported. The paper provides only qualitative demonstrations on unseen compounds.</p>
<h3 id="hardware">Hardware</h3>
<p>No hardware specifications are reported for training or inference.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/UCSD-AI4H/drugchat">DrugChat Code</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation (repository returned 404 as of March 2026)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liang, Y., Zhang, R., Zhang, L., &amp; Xie, P. (2023). DrugChat: Towards Enabling ChatGPT-Like Capabilities on Drug Molecule Graphs. <em>arXiv preprint arXiv:2309.03907</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{liang2023drugchat,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DrugChat: Towards Enabling ChatGPT-Like Capabilities on Drug Molecule Graphs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Liang, Youwei and Zhang, Ruiyi and Zhang, Li and Xie, Pengtao}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2309.03907}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DrugAssist: Interactive LLM Molecule Optimization</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugassist-llm-molecule-optimization/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/drugassist-llm-molecule-optimization/</guid><description>DrugAssist fine-tunes Llama2-7B-Chat for interactive molecule optimization via natural language dialogue, releasing the MolOpt-Instructions dataset.</description><content:encoded><![CDATA[<h2 id="an-interactive-llm-for-molecule-optimization">An Interactive LLM for Molecule Optimization</h2>
<p>DrugAssist is a <strong>Method</strong> paper that proposes an interactive molecule optimization model built by fine-tuning Llama2-7B-Chat with LoRA on a newly constructed instruction dataset. The primary contribution is twofold: (1) the MolOpt-Instructions dataset containing over one million molecule pairs with six molecular properties and three optimization task categories, and (2) a dialogue-based molecule optimization system that allows domain experts to iteratively refine molecular modifications through multi-turn natural language conversations.</p>
<h2 id="why-interactive-molecule-optimization-matters">Why Interactive Molecule Optimization Matters</h2>
<p>Molecule optimization is a core step in the drug discovery pipeline, where lead compounds must be modified to improve specific pharmacological properties while maintaining structural similarity. Existing approaches fall into sequence-based methods (treating <a href="/notes/chemistry/molecular-representations/">SMILES</a> optimization as machine translation) and graph-based methods (graph-to-graph translation), but they share a critical limitation: they are non-interactive. These models learn patterns from chemical structure data without incorporating expert feedback.</p>
<p>The drug discovery process is inherently iterative and requires integrating domain expertise. Medicinal chemists typically refine candidates through repeated cycles of suggestion, evaluation, and adjustment. Prior LLM-based approaches like <a href="/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/">ChatDrug</a> relied on prompt engineering with general-purpose models (GPT-3.5-turbo) rather than fine-tuning, limiting their optimization accuracy. Additionally, most existing molecule optimization benchmarks focus on single-property optimization with vague objectives (e.g., &ldquo;maximize QED&rdquo;), while real-world drug design requires optimizing property values within specific ranges across multiple properties simultaneously.</p>
<h2 id="instruction-based-fine-tuning-with-molopt-instructions">Instruction-Based Fine-Tuning with MolOpt-Instructions</h2>
<p>The core innovation has two components: the MolOpt-Instructions dataset construction pipeline and the multi-task instruction tuning strategy.</p>
<h3 id="dataset-construction">Dataset Construction</h3>
<p>MolOpt-Instructions is built from one million molecules randomly sampled from the <a href="/notes/chemistry/datasets/zinc-22/">ZINC database</a>. The construction workflow uses mmpdb (an open-source Matched Molecular Pair platform) to generate structurally similar molecule pairs through <a href="https://en.wikipedia.org/wiki/Matched_molecular_pair_analysis">Matched Molecular Pair Analysis (MMPA)</a>. Pairs are filtered to satisfy two criteria: <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> greater than 0.65 and <a href="https://en.wikipedia.org/wiki/Partition_coefficient">logP</a> difference greater than 2.5. Values are computed for six properties: Solubility, BBBP, and <a href="https://en.wikipedia.org/wiki/KCNH2">hERG</a> inhibition with Tencent&rsquo;s iDrug platform, and QED, hydrogen bond donor count, and hydrogen bond acceptor count with RDKit. The final dataset contains 1,029,949 unique pairs covering 1,595,839 unique molecules, with a mean similarity of 0.69 and a mean logP difference of 2.82.</p>
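<p>The pair-filtering criteria can be sketched as follows. Fingerprints are represented here as Python sets of on-bit indices rather than real molecular fingerprints, and the function names are hypothetical; the two thresholds are the ones quoted above.</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = len(fp_a | fp_b)
    return len(fp_a & fp_b) / union if union else 0.0


def keep_pair(fp_a, fp_b, logp_a, logp_b, min_sim=0.65, min_logp_diff=2.5):
    """Apply both filtering criteria to one candidate molecule pair."""
    similar_enough = tanimoto(fp_a, fp_b) > min_sim
    different_enough = abs(logp_a - logp_b) > min_logp_diff
    return similar_enough and different_enough


# Toy fingerprints sharing 8 of 11 distinct bits (similarity 8/11, about 0.73).
fp_a = set(range(10))
fp_b = set(range(2, 11))
kept = keep_pair(fp_a, fp_b, logp_a=1.0, logp_b=4.0)
rejected = keep_pair(fp_a, fp_b, logp_a=1.0, logp_b=3.0)
```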
<p>Three categories of optimization tasks are defined:</p>
<ul>
<li><strong>Loose</strong>: Increase or decrease a given property value (no threshold)</li>
<li><strong>Strict</strong>: Increase or decrease by at least a specified threshold</li>
<li><strong>Range</strong>: Optimize the property value to fall within a given interval</li>
</ul>
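<p>The three task categories differ only in how success is checked. The sketch below is one possible reading of those definitions; the function signature and the convention of passing +1 or -1 for the direction are assumptions for illustration.</p>

```python
def meets_objective(task, before, after, direction=1, threshold=0.0, interval=None):
    """Check one property value against a loose, strict, or range objective.

    direction: +1 to increase the property, -1 to decrease it.
    """
    change = direction * (after - before)
    if task == "loose":
        return change > 0                 # any movement in the right direction
    if task == "strict":
        return change >= threshold        # movement of at least the threshold
    if task == "range":
        low, high = interval
        return high >= after >= low       # final value inside the interval
    raise ValueError("unknown task: " + task)


loose_ok = meets_objective("loose", before=0.50, after=0.55)
strict_ok = meets_objective("strict", before=0.50, after=0.55, threshold=0.1)
range_ok = meets_objective("range", before=0.50, after=0.80, interval=(0.7, 0.9))
```

<p>The same small change (0.50 to 0.55) passes the loose check but fails a strict check with threshold 0.1, which is why strict success rates tend to be lower than loose ones.</p>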
<p>Instruction templates are generated with ChatGPT assistance and manually refined. To ensure balance, source and target molecules are swapped for some pairs to maintain a roughly 1:1 ratio of property increases to decreases.</p>
<p>Murcko scaffold analysis confirms chemical diversity: the average molecules per scaffold is 2.95, and over 93.7% of scaffolds contain no more than five molecules.</p>
<h3 id="multi-task-instruction-tuning">Multi-Task Instruction Tuning</h3>
<p>DrugAssist is obtained by fine-tuning Llama2-7B-Chat with LoRA (rank 64, alpha 128). To prevent catastrophic forgetting of general language capabilities, the training data combines MolOpt-Instructions with the Stanford Alpaca dataset (52k instruction-following examples, replicated 5x to balance the mixture). The training objective minimizes the negative log-likelihood over the response tokens:</p>
<p>$$L(R; \boldsymbol{\theta}) = -\sum_{u_i \in R} \log \Phi(u_i \mid u_{&lt;i}, I)$$</p>
<p>where $I$ is the instruction, $R$ is the response, and $\Phi$ is the model&rsquo;s conditional probability.</p>
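<p>A small numeric sketch of this objective: the loss sums $-\log$ of the probability the model assigns to each correct response token, while instruction tokens contribute nothing. The probabilities and the mask below are made-up values for illustration.</p>

```python
import math


def response_nll(token_probs, is_response):
    """Sum of -log p over response tokens; instruction tokens are masked out."""
    return -sum(math.log(p) for p, keep in zip(token_probs, is_response) if keep)


# Hypothetical probabilities the model assigns to the correct token at each position.
probs = [0.9, 0.8, 0.5, 0.25]
mask = [False, False, True, True]   # first two positions belong to the instruction I
loss = response_nll(probs, mask)    # -(log 0.5 + log 0.25) = log 8
```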
<p>Training runs for 10 epochs with batch size 512, using AdamW ($\beta = (0.9, 0.999)$), learning rate 1e-4, 3% warm-up steps with cosine decay, and no weight decay. The data is split 90/5/5 for train/validation/test.</p>
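<p>The learning-rate schedule can be sketched as a function of the training step. The summary above gives only the warm-up fraction, the peak rate, and the cosine decay; the linear warm-up shape and the decay to exactly zero are assumptions of this sketch.</p>

```python
import math


def learning_rate(step, total_steps, peak_lr=1e-4, warmup_frac=0.03):
    """Linear warm-up to peak_lr, then cosine decay to zero (assumed shapes)."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if warmup_steps > step:
        # Warm-up phase: ramp up linearly and reach peak_lr on the last warm-up step.
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


total = 1000
schedule = [learning_rate(s, total) for s in range(total + 1)]
```

<p>With 1,000 total steps, the first 30 steps ramp up to 1e-4, after which the rate follows half a cosine wave down to zero at the final step.</p>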
<h2 id="experimental-setup-and-multi-property-optimization-results">Experimental Setup and Multi-Property Optimization Results</h2>
<h3 id="comparison-with-traditional-approaches">Comparison with Traditional Approaches</h3>
<p>DrugAssist is compared against Mol-Seq2Seq and Mol-Transformer (He et al., 2021) on simultaneous Solubility and BBBP optimization with range constraints. The evaluation prompt asks the model to generate an optimized molecule with solubility within a given range and BBBP category changed from one level to another.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Solubility</th>
          <th>BBBP</th>
          <th>Both</th>
          <th>Valid Rate</th>
          <th>Similarity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mol-Seq2Seq</td>
          <td>0.46</td>
          <td>0.55</td>
          <td>0.35</td>
          <td>0.76</td>
          <td>0.61</td>
      </tr>
      <tr>
          <td>Mol-Transformer</td>
          <td>0.70</td>
          <td>0.78</td>
          <td>0.59</td>
          <td>0.96</td>
          <td>0.70</td>
      </tr>
      <tr>
          <td>DrugAssist</td>
          <td>0.74</td>
          <td>0.80</td>
          <td>0.62</td>
          <td>0.98</td>
          <td>0.69</td>
      </tr>
  </tbody>
</table>
<p>DrugAssist achieves the highest success rates in both single-property and multi-property optimization while maintaining high validity (0.98) and comparable structural similarity (0.69).</p>
<h3 id="comparison-with-llms">Comparison with LLMs</h3>
<p>DrugAssist is compared against Llama2-7B-Chat, GPT-3.5-turbo (via ChatDrug), and BioMedGPT-LM-7B on 16 tasks covering all three optimization categories. These comparisons use multi-turn dialogues following the ChatDrug protocol: if the model&rsquo;s output fails to meet requirements, a database-retrieved molecule meeting the criteria and similar to the model&rsquo;s output is provided as a hint for iterative refinement.</p>
<p>Selected results on single-property tasks (each cell reports the success rate under loose / strict criteria):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Llama2-7B-Chat</th>
          <th>GPT-3.5-turbo</th>
          <th>BioMedGPT-LM</th>
          <th>DrugAssist</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QED+</td>
          <td>0.17 / 0.16</td>
          <td>0.15 / 0.15</td>
          <td>0.15 / 0.09</td>
          <td>0.76 / 0.63</td>
      </tr>
      <tr>
          <td>Acceptor+</td>
          <td>0.08 / 0.08</td>
          <td>0.04 / 0.06</td>
          <td>0.18 / 0.13</td>
          <td>0.71 / 0.67</td>
      </tr>
      <tr>
          <td>Donor+</td>
          <td>0.15 / 0.08</td>
          <td>0.10 / 0.04</td>
          <td>0.17 / 0.09</td>
          <td>0.72 / 0.76</td>
      </tr>
      <tr>
          <td>Solubility+</td>
          <td>0.36 / 0.20</td>
          <td>0.16 / 0.05</td>
          <td>0.18 / 0.09</td>
          <td>0.80 / 0.41</td>
      </tr>
      <tr>
          <td>BBBP+</td>
          <td>0.19 / 0.14</td>
          <td>0.10 / 0.10</td>
          <td>0.16 / 0.07</td>
          <td>0.82 / 0.61</td>
      </tr>
      <tr>
          <td>hERG-</td>
          <td>0.39 / 0.31</td>
          <td>0.13 / 0.15</td>
          <td>0.13 / 0.12</td>
          <td>0.71 / 0.67</td>
      </tr>
  </tbody>
</table>
<p>Multi-property tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Llama2-7B-Chat</th>
          <th>GPT-3.5-turbo</th>
          <th>BioMedGPT-LM</th>
          <th>DrugAssist</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sol+ &amp; Acc+</td>
          <td>0.15 / 0.04</td>
          <td>0.09 / 0.02</td>
          <td>0.10 / 0.07</td>
          <td>0.50 / 0.27</td>
      </tr>
      <tr>
          <td>QED+ &amp; BBBP+</td>
          <td>0.14 / 0.09</td>
          <td>0.09 / 0.06</td>
          <td>0.16 / 0.11</td>
          <td>0.65 / 0.41</td>
      </tr>
  </tbody>
</table>
<p>DrugAssist outperforms all baselines across every task. BioMedGPT-LM frequently misunderstands the task, generating guidance text rather than molecules. GPT-3.5-turbo achieves high validity but often outputs the input molecule unchanged.</p>
<h2 id="transferability-iterative-refinement-and-limitations">Transferability, Iterative Refinement, and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<p><strong>Zero-shot transferability</strong>: Although DrugAssist trains on single-property optimization data, it successfully handles multi-property optimization requests at inference time. In a case study, the model simultaneously increased both BBBP and QED by at least 0.1 while maintaining structural similarity, without any multi-property training examples.</p>
<p><strong>Few-shot generalization</strong>: DrugAssist optimizes properties not seen during training (e.g., logP) when provided with a few in-context examples of successful optimizations, a capability that traditional sequence-based or graph-based models cannot achieve without retraining.</p>
<p><strong>Iterative optimization</strong>: When an initial optimization fails to meet requirements, DrugAssist can incorporate feedback (a database-retrieved hint molecule) and modify different functional groups in a second attempt to produce a compliant molecule.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge that DrugAssist has a lower success rate in the most challenging setting, solubility optimization under strict criteria (0.41 success rate vs. 0.80 under loose criteria). The model also relies on iDrug for property prediction of Solubility, BBBP, and hERG inhibition, so its optimization quality is bounded by the accuracy of these property predictors. The LLM comparisons use only 500 test molecules, a relatively small evaluation set. The paper does not report statistical significance tests or confidence intervals for any results.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors plan to improve multimodal data handling to reduce hallucination problems and to further enhance DrugAssist&rsquo;s interactive capabilities for better understanding of user needs and feedback.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>MolOpt-Instructions</td>
          <td>1,029,949 molecule pairs</td>
          <td>Sourced from ZINC via mmpdb; 6 properties</td>
      </tr>
      <tr>
          <td>Training (auxiliary)</td>
          <td>Stanford Alpaca</td>
          <td>52k instructions (5x replicated)</td>
          <td>Mitigates catastrophic forgetting</td>
      </tr>
      <tr>
          <td>Evaluation (traditional)</td>
          <td>From He et al. (2021)</td>
          <td>Not specified</td>
          <td>Multi-property optimization test</td>
      </tr>
      <tr>
          <td>Evaluation (LLM)</td>
          <td>ZINC subset</td>
          <td>500 molecules</td>
          <td>Randomly selected</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Base model</strong>: Llama2-7B-Chat</li>
<li><strong>Fine-tuning</strong>: LoRA with rank 64, alpha 128</li>
<li><strong>Optimizer</strong>: AdamW, $\beta = (0.9, 0.999)$, lr = 1e-4, no weight decay</li>
<li><strong>Schedule</strong>: 3% warm-up, cosine decay</li>
<li><strong>Epochs</strong>: 10</li>
<li><strong>Batch size</strong>: 512</li>
<li><strong>Property calculation</strong>: iDrug (Solubility, BBBP, hERG); RDKit (H-bond donors/acceptors, QED)</li>
<li><strong>Molecular pairs</strong>: mmpdb for Matched Molecular Pair Analysis</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Fine-tuned Llama2-7B-Chat with LoRA adapters</li>
<li>No pre-trained weights released (code and data available)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success rate</td>
          <td>Fraction of molecules meeting optimization criteria</td>
      </tr>
      <tr>
          <td>Valid rate</td>
          <td>Fraction of generated SMILES that parse as valid molecules</td>
      </tr>
      <tr>
          <td>Similarity</td>
          <td>Tanimoto similarity between input and optimized molecules</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>8 NVIDIA Tesla A100-SXM4-40GB GPUs</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/blazerye/DrugAssist">DrugAssist Code</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Training and inference code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/blazerye/DrugAssist">MolOpt-Instructions</a></td>
          <td>Dataset</td>
          <td>Not specified</td>
          <td>1M+ molecule pairs, 6 properties</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ye, G., Cai, X., Lai, H., Wang, X., Huang, J., Wang, L., Liu, W., &amp; Zeng, X. (2024). DrugAssist: A Large Language Model for Molecule Optimization. <em>Briefings in Bioinformatics</em>, 26(1), bbae693.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ye2024drugassist,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DrugAssist: A Large Language Model for Molecule Optimization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Geyan and Cai, Xibao and Lai, Houtim and Wang, Xing and Huang, Junhong and Wang, Longyue and Liu, Wei and Zeng, Xiangxiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{26}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{bbae693}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bib/bbae693}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Data Transfer Approaches for Seq-to-Seq Retrosynthesis</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/data-transfer-seq-to-seq-retrosynthesis/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/data-transfer-seq-to-seq-retrosynthesis/</guid><description>Systematic comparison of joint training, self-training, and pre-training plus fine-tuning for Transformer-based retrosynthesis on USPTO-50K.</description><content:encoded><![CDATA[<h2 id="systematic-study-of-data-transfer-for-retrosynthesis">Systematic Study of Data Transfer for Retrosynthesis</h2>
<p>This is an <strong>Empirical</strong> paper that systematically compares three standard data transfer methods (joint training, self-training, and pre-training plus fine-tuning) applied to a Transformer-based sequence-to-sequence model for single-step retrosynthesis. The primary contribution is demonstrating that pre-training on a large augment dataset (USPTO-Full, 877K reactions) followed by fine-tuning on the smaller target dataset (USPTO-50K) produces substantial accuracy improvements over the baseline Transformer, achieving results competitive with or superior to contemporaneous state-of-the-art graph-based models at larger values of n in n-best accuracy.</p>
<h2 id="bridging-the-data-gap-in-retrosynthesis-prediction">Bridging the Data Gap in Retrosynthesis Prediction</h2>
<p><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a>, the problem of predicting reactant compounds needed to synthesize a target product, has seen rapid progress through increasingly sophisticated model architectures: <a href="/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/">LSTM seq-to-seq models</a>, <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Transformer models</a>, and graph-to-graph approaches. However, the authors identify a gap in this research trajectory. While model architecture has received extensive attention, the role of training data strategies has been largely neglected in the retrosynthesis literature.</p>
<p>The core practical problem is that high-quality supervised datasets for retrosynthesis (like USPTO-50K) tend to be small and distribution-skewed, with all samples pre-classified into ten major reaction classes. Meanwhile, larger datasets (USPTO-Full with 877K samples, USPTO-MIT with 479K samples) exist but have different distributional properties. Data transfer techniques are standard practice in computer vision, NLP, and machine translation for exactly this scenario, yet they had not been systematically evaluated for retrosynthesis at the time of this work.</p>
<p>The authors also note a contrast with Zoph et al. (2020), who found that self-training outperforms pre-training in image recognition. They hypothesize that chemical compound strings may have more universal representations than images, making pre-training more effective in the chemistry domain.</p>
<h2 id="three-data-transfer-methods-for-retrosynthesis">Three Data Transfer Methods for Retrosynthesis</h2>
<p>The paper formalizes retrosynthesis as a seq-to-seq problem where both the product $x$ and reactant set $y$ are represented as <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. A retrosynthesis model defines a likelihood $p_{\mathcal{M}}(y \mid x; \theta)$ optimized via maximum log-likelihood:</p>
<p>$$
\theta^{*} = \arg\max_{\theta} \sum_{(x_{i}, y_{i}) \in \mathcal{D}^{T}_{\text{Train}}} \log p(y_{i} \mid x_{i})
$$</p>
<p>Given a target dataset $\mathcal{D}^{T}$ and an augment dataset $\mathcal{D}^{A}$, three transfer methods are examined:</p>
<p><strong>Joint Training</strong> concatenates the training sets and optimizes over the union:</p>
<p>$$
\theta^{*}_{\text{joint}} = \arg\max_{\theta} \sum_{(x_{i}, y_{i}) \in \mathcal{D}_{\text{joint}}} \log p(y_{i} \mid x_{i}), \quad \mathcal{D}_{\text{joint}} = \mathcal{D}^{T}_{\text{Train}} \cup \mathcal{D}^{A}_{\text{Train}}
$$</p>
<p>This requires that both datasets share the same input/output domain (same SMILES canonicalization rules).</p>
<p><strong>Self-Training</strong> (pseudo labeling) first trains a base model on $\mathcal{D}^{T}$ alone, yielding parameters $\theta^{*}_{\text{single}}$, then uses this model to relabel the augment dataset products:</p>
<p>$$
\hat{y}_{i} = \arg\max_{y} \log p(y \mid x_{i}; \theta^{*}_{\text{single}}) \quad \text{for } x_{i} \in \mathcal{D}^{A}_{\text{Train}}
$$</p>
<p>The pseudo-labeled augment set is then combined with $\mathcal{D}^{T}_{\text{Train}}$ for joint training. This approach does not require consistent label domains between datasets.</p>
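<p>The difference between joint training and self-training lies entirely in how the training set is assembled. The sketch below makes that explicit under stated assumptions: reactions are plain <code>(product, reactants)</code> string pairs, and <code>predict_top1</code> is a hypothetical callable standing in for top-1 decoding with the base model trained on the target set alone.</p>

```python
def build_joint_set(target_train, augment_train):
    """Joint training: concatenate target and augment (product, reactants) pairs."""
    return list(target_train) + list(augment_train)


def build_self_training_set(target_train, augment_products, predict_top1):
    """Self-training: discard augment labels and relabel with the base model.

    `predict_top1` is a hypothetical stand-in for the base model's top-1
    prediction; only the augment *products* are used as inputs.
    """
    pseudo_labeled = [(x, predict_top1(x)) for x in augment_products]
    return list(target_train) + pseudo_labeled


# Toy usage with placeholder strings (not real reactions).
target = [("P1", "R1"), ("P2", "R2")]
augment = [("P3", "R3-original")]
base_model = {"P3": "R3-pseudo"}.get  # toy lookup acting as the base model

joint = build_joint_set(target, augment)
self_trained = build_self_training_set(target, [p for p, _ in augment], base_model)
```

<p>In the joint set the augment pair keeps its original label, while in the self-training set the same product carries the base model's prediction instead.</p>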
<p><strong>Pre-training plus Fine-tuning</strong> trains first on the augment dataset to obtain $\theta^{*}_{\text{pretrain}}$, then initializes fine-tuning from this checkpoint:</p>
<p>$$
\theta^{0}_{\text{finetune}} \leftarrow \theta^{*}_{\text{pretrain}}, \quad \theta^{\ell+1}_{\text{finetune}} \leftarrow \theta^{\ell}_{\text{finetune}} - \gamma^{\ell} \nabla \mathcal{L}(\mathcal{D}^{T}_{\text{Train}}) \big|_{{\theta^{\ell}_{\text{finetune}}}}
$$</p>
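<p>The update rule above is ordinary gradient descent whose starting point is the pre-trained checkpoint. A toy sketch follows, assuming parameters are a flat list of floats and that <code>grad_fn</code> (a hypothetical callable) returns the gradient of the target-set loss. The quadratic loss is purely illustrative and has nothing to do with the actual Transformer objective.</p>

```python
def fine_tune(theta_pretrain, grad_fn, learning_rates):
    """Run one gradient step per entry in `learning_rates`.

    The parameters are initialized from the pre-trained values, mirroring
    theta^0_finetune = theta^*_pretrain in the equation above.
    """
    theta = list(theta_pretrain)
    for gamma in learning_rates:
        grad = grad_fn(theta)
        theta = [t - gamma * g for t, g in zip(theta, grad)]
    return theta


# Toy loss: sum((theta_j - c_j)^2), with gradient 2 * (theta_j - c_j).
optimum = [1.0, 1.0]


def toy_grad(theta):
    return [2.0 * (t - c) for t, c in zip(theta, optimum)]


result = fine_tune([0.0, 1.0], toy_grad, learning_rates=[0.25, 0.25])
```

<p>The second coordinate already sits at the optimum after "pre-training" and stays put, while the first moves from 0.0 to 0.5 to 0.75 over the two steps.</p>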
<h2 id="experimental-setup-on-uspto-benchmarks">Experimental Setup on USPTO Benchmarks</h2>
<p>The experiments use a fixed Transformer architecture (3 self-attention layers, 500-dimensional latent vectors) implemented in OpenNMT-py, evaluated across all three transfer methods.</p>
<p><strong>Datasets:</strong></p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Target</td>
          <td>USPTO-50K</td>
          <td>40K/5K/5K (train/val/test)</td>
          <td>10 reaction classes, curated by Lowe (2012)</td>
      </tr>
      <tr>
          <td>Augment (main)</td>
          <td>USPTO-Full</td>
          <td>844K train (after cleansing)</td>
          <td>Curated by Lowe (2017)</td>
      </tr>
      <tr>
          <td>Augment (smaller)</td>
          <td>USPTO-MIT</td>
          <td>384K train (after cleansing)</td>
          <td>Curated by Jin et al. (2017)</td>
      </tr>
  </tbody>
</table>
<p>Data cleansing removed all augment dataset samples whose product SMILES appeared in any USPTO-50K subset, preventing data leakage. All datasets were re-canonicalized with a unified <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> version.</p>
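<p>The leakage filter described above amounts to a set-membership test on canonical product strings. The following is a minimal sketch, not the authors' script: the function name <code>cleanse_augment</code> is hypothetical, and <code>canonicalize</code> defaults to the identity as a placeholder for a real canonicalization step (for example, a round trip through RDKit).</p>

```python
def cleanse_augment(augment_pairs, target_splits, canonicalize=lambda s: s):
    """Drop augment reactions whose product occurs in any target split.

    `augment_pairs` and each element of `target_splits` are iterables of
    (product, reactants) string pairs. `canonicalize` is a placeholder.
    """
    target_products = {
        canonicalize(product)
        for split in target_splits
        for product, _ in split
    }
    return [
        (product, reactants)
        for product, reactants in augment_pairs
        if canonicalize(product) not in target_products
    ]


# Toy usage with placeholder strings.
train, val, test = [("P1", "R1")], [("P2", "R2")], [("P3", "R3")]
augment = [("P2", "Rx"), ("P4", "R4"), ("P3", "Ry"), ("P5", "R5")]
cleaned = cleanse_augment(augment, [train, val, test])
```

<p>Filtering against the train, validation, and test splits together, as the paper describes, removes any augment reaction whose product also occurs in the target data.</p>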
<p><strong>Evaluation</strong> uses n-best accuracy with k=50 beam search, computing accuracy at n=1, 3, 5, 10, 20, 50. Models are selected by best validation perplexity. All experiments report averages and standard deviations over 5 runs.</p>
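<p>As a rough sketch of the metric, n-best accuracy counts a test reaction as correct when the gold reactant string appears among the first n beam candidates. The code below assumes exact string matching on already-canonicalized SMILES, which is a simplification. The function name is hypothetical.</p>

```python
def n_best_accuracy(ranked_predictions, gold, ns=(1, 3, 5, 10, 20, 50)):
    """Fraction of examples whose gold answer is within the top-n candidates.

    `ranked_predictions[i]` is the beam output for example i, best first.
    """
    hits = {n: 0 for n in ns}
    for candidates, truth in zip(ranked_predictions, gold):
        for n in ns:
            if truth in candidates[:n]:
                hits[n] += 1
    total = len(gold)
    return {n: hits[n] / total for n in ns}


# Toy usage: the first example is solved at rank 2, the second never.
beams = [["a", "b", "c"], ["x", "y", "z"]]
truths = ["b", "q"]
scores = n_best_accuracy(beams, truths, ns=(1, 3))
```

<p>Because every top-n list contains the top-(n-1) list, the resulting accuracies can only grow with n, which matches the monotone rows in the tables that follow.</p>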
<p><strong>Optimization</strong> uses Adam with cyclic learning rate scheduling (warm-up) for all methods except fine-tuning, which uses a standard non-cyclic scheduler.</p>
<p><strong>Results comparing data transfer methods (USPTO-Full augment):</strong></p>
<table>
  <thead>
      <tr>
          <th>Training Method</th>
          <th>n=1</th>
          <th>n=3</th>
          <th>n=5</th>
          <th>n=10</th>
          <th>n=20</th>
          <th>n=50</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Single model (No Transfer)</td>
          <td>35.3 +/- 1.4</td>
          <td>52.8 +/- 1.4</td>
          <td>58.9 +/- 1.3</td>
          <td>64.5 +/- 1.2</td>
          <td>68.8 +/- 1.2</td>
          <td>72.1 +/- 1.3</td>
      </tr>
      <tr>
          <td>Joint Training</td>
          <td>39.1 +/- 1.3</td>
          <td>63.4 +/- 0.9</td>
          <td>71.9 +/- 0.5</td>
          <td>80.1 +/- 0.2</td>
          <td>85.4 +/- 0.3</td>
          <td>89.4 +/- 0.2</td>
      </tr>
      <tr>
          <td>Self-Training</td>
          <td>41.5 +/- 1.0</td>
          <td>60.4 +/- 0.7</td>
          <td>66.1 +/- 0.7</td>
          <td>71.8 +/- 0.6</td>
          <td>75.3 +/- 0.5</td>
          <td>78.0 +/- 0.3</td>
      </tr>
      <tr>
          <td>Pre-training + Fine-Tune</td>
          <td>57.4 +/- 0.4</td>
          <td>77.6 +/- 0.4</td>
          <td>83.1 +/- 0.2</td>
          <td>87.4 +/- 0.4</td>
          <td>89.6 +/- 0.3</td>
          <td>90.9 +/- 0.2</td>
      </tr>
  </tbody>
</table>
<p><strong>Comparison with state-of-the-art models:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture</th>
          <th>n=1</th>
          <th>n=3</th>
          <th>n=5</th>
          <th>n=10</th>
          <th>n=20</th>
          <th>n=50</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GLN (Dai et al., 2019)</td>
          <td>Logic Network</td>
          <td>52.5</td>
          <td>69.0</td>
          <td>75.6</td>
          <td>83.7</td>
          <td>88.5</td>
          <td>92.4</td>
      </tr>
      <tr>
          <td>G2Gs (Shi et al., 2020)</td>
          <td>Graph-to-Graph</td>
          <td>48.9</td>
          <td>67.6</td>
          <td>72.5</td>
          <td>75.5</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>RetroXpert (Yan et al., 2020)</td>
          <td>Graph-to-Graph</td>
          <td>65.6</td>
          <td>78.7</td>
          <td>80.8</td>
          <td>83.3</td>
          <td>84.6</td>
          <td>86.0</td>
      </tr>
      <tr>
          <td>GraphRetro (Somnath et al., 2020)</td>
          <td>Graph-to-Graph</td>
          <td>63.8</td>
          <td>80.5</td>
          <td>84.1</td>
          <td>85.9</td>
          <td>N/A</td>
          <td>87.2</td>
      </tr>
      <tr>
          <td>Pre-training + Fine-Tune (ours)</td>
          <td>Seq-to-Seq</td>
          <td>57.4</td>
          <td>77.6</td>
          <td>83.1</td>
          <td>87.4</td>
          <td>89.6</td>
          <td>90.9</td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Primary findings:</strong></p>
<ol>
<li>All three data transfer methods improve over the no-transfer baseline across all n-best accuracy levels.</li>
<li>Pre-training plus fine-tuning provides the largest gains, improving top-1 accuracy by 22.1 absolute percentage points (from 35.3% to 57.4%) and achieving the best n=10 and n=20 accuracy among all compared models, including graph-based approaches.</li>
<li>Augment dataset size matters: using USPTO-Full (844K) yields substantially better results than USPTO-MIT (384K) for joint training and pre-training plus fine-tuning, though self-training gains are surprisingly robust to augment dataset size.</li>
<li>Manual inspection of erroneous predictions shows that over 99% of top-1 predictions from the pre-trained/fine-tuned model are chemically appropriate or sensible, even when they do not exactly match the gold-standard reactants.</li>
<li>Pre-training plus fine-tuning shows a distinct advantage in training dynamics: the 1-best and n-best accuracy curves evolve similarly during fine-tuning, unlike the single model where these curves can diverge significantly. This makes early stopping more reliable.</li>
</ol>
<p><strong>Class-wise improvements</strong> are observed across all 10 reaction classes, with the largest gains in heterocycle formation (0.40 to 0.86 at 50-best) and functional group interconversion (0.57 to 0.90).</p>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ul>
<li>The model struggles with compounds containing multiple similar substituents (e.g., long-chain hydrocarbons), occasionally selecting the wrong one.</li>
<li>Some reactions involving rare chemical groups (<a href="https://en.wikipedia.org/wiki/Polycyclic_aromatic_hydrocarbon">polycyclic aromatic hydrocarbons</a>) still produce invalid SMILES, suggesting the augment dataset lacks sufficient examples of these structures.</li>
<li>Top-1 accuracy (57.4%) lags behind the best graph-based models (RetroXpert at 65.6%), though the gap narrows at higher n values.</li>
<li>The study uses a fixed Transformer architecture without architecture-specific optimization for each transfer method.</li>
</ul>
<p><strong>Future directions</strong> proposed include freezing parts of the network during fine-tuning, applying data transfer to graph-to-graph models, and testing transferability to other retrosynthesis datasets beyond USPTO-50K.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Target</td>
          <td>USPTO-50K</td>
          <td>50K reactions</td>
          <td>Curated by Lowe (2012), 10 reaction classes</td>
      </tr>
      <tr>
          <td>Augment</td>
          <td>USPTO-Full</td>
          <td>877K reactions (844K after cleansing)</td>
          <td>Curated by Lowe (2017), available via Figshare</td>
      </tr>
      <tr>
          <td>Augment (alt)</td>
          <td>USPTO-MIT</td>
          <td>479K reactions (384K after cleansing)</td>
          <td>Curated by Jin et al. (2017)</td>
      </tr>
  </tbody>
</table>
<p>Data cleansing removes augment samples whose products appear in any USPTO-50K subset. Unified RDKit canonicalization is applied to all datasets.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer seq-to-seq model (3 self-attention layers, 500-dim latent vectors)</li>
<li>Positional encoding enabled</li>
<li>Maximum sequence length: 200 tokens</li>
<li>Adam optimizer</li>
<li>Cyclic learning rate scheduler with warm-up (all methods except fine-tuning)</li>
<li>Non-cyclic scheduler for fine-tuning phase (Klein et al., 2017)</li>
<li>Beam search with k=50 for inference</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Implementation: OpenNMT-py</li>
<li>No pre-trained weights or model checkpoints released</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Top-1 accuracy</td>
          <td>57.4%</td>
          <td>35.3% (no transfer)</td>
          <td>Pre-train + fine-tune, USPTO-Full augment</td>
      </tr>
      <tr>
          <td>Top-10 accuracy</td>
          <td>87.4%</td>
          <td>64.5% (no transfer)</td>
          <td>Best among all compared models</td>
      </tr>
      <tr>
          <td>Top-20 accuracy</td>
          <td>89.6%</td>
          <td>68.8% (no transfer)</td>
          <td>Best among all compared models</td>
      </tr>
      <tr>
          <td>Top-50 accuracy</td>
          <td>90.9%</td>
          <td>72.1% (no transfer)</td>
          <td>Competitive with GLN (92.4%)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware details are not specified in the paper. The authors mention GPU memory constraints motivating the 200-token sequence length limit.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ishiguro, K., Ujihara, K., Sawada, R., Akita, H., &amp; Kotera, M. (2020). Data Transfer Approaches to Improve Seq-to-Seq Retrosynthesis. <em>arXiv preprint arXiv:2010.00792</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ishiguro2020data,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data Transfer Approaches to Improve Seq-to-Seq Retrosynthesis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ishiguro, Katsuhiko and Ujihara, Kazuya and Sawada, Ryohto and Akita, Hirotaka and Kotera, Masaaki}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2010.00792}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemLLM: A Chemical Large Language Model Framework</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/</guid><description>ChemLLM introduces the first LLM dedicated to chemistry, with ChemData for instruction tuning and ChemBench for evaluation across nine chemical tasks.</description><content:encoded><![CDATA[<h2 id="a-resource-for-chemistry-specific-language-modeling">A Resource for Chemistry-Specific Language Modeling</h2>
<p>ChemLLM is a <strong>Resource</strong> paper that delivers three interconnected artifacts: ChemData (a 7M-sample instruction tuning dataset for chemistry), ChemBench (a 4,100-question multiple-choice benchmark spanning nine chemistry tasks), and ChemLLM itself (a 7B-parameter language model fine-tuned from InternLM2-Base-7B). Together, these components form the first comprehensive framework for building and evaluating LLMs dedicated to the chemical domain. The primary contribution is not a novel architecture but rather the data curation pipeline, evaluation benchmark, and training methodology that together convert structured chemical knowledge into dialogue-formatted instruction data.</p>
<h2 id="bridging-structured-chemical-databases-and-conversational-llms">Bridging Structured Chemical Databases and Conversational LLMs</h2>
<p>While general-purpose LLMs like GPT-4 have shown promise on chemistry tasks, they are not specifically designed for the chemical domain. Several challenges motivate ChemLLM:</p>
<ol>
<li>
<p><strong>Structured data incompatibility</strong>: Most chemical information resides in structured databases (<a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>, <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, <a href="https://en.wikipedia.org/wiki/ChEBI">ChEBI</a>, <a href="/notes/chemistry/datasets/zinc-22/">ZINC</a>, USPTO) that are not naturally suited for training conversational language models. Using this data directly can degrade natural language processing capabilities.</p>
</li>
<li>
<p><strong>Molecular notation understanding</strong>: Molecules are represented in specialized notations like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, which differ from natural language and require explicit alignment during training.</p>
</li>
<li>
<p><strong>Task diversity</strong>: Chemical tasks span name conversion, property prediction, molecular captioning, <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthesis</a>, product prediction, yield prediction, and more. A uniform training pipeline must handle this diversity without task-specific adaptation.</p>
</li>
<li>
<p><strong>Evaluation gaps</strong>: Existing chemical benchmarks (e.g., <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>) are designed for specialist models, not LLMs. Text-based evaluation metrics like <a href="https://en.wikipedia.org/wiki/BLEU">BLEU</a> and <a href="https://en.wikipedia.org/wiki/ROUGE_(metric)">ROUGE</a> are sensitive to output style rather than factual correctness, making them unreliable for scientific accuracy assessment.</p>
</li>
</ol>
<p>Prior work focused on developing specialist models for individual downstream tasks while neglecting instruction-following and dialogue capabilities that are essential for broader reasoning and generalization.</p>
<h2 id="template-based-instruction-construction-from-structured-data">Template-Based Instruction Construction from Structured Data</h2>
<p>The core innovation is a systematic approach for converting structured chemical data into instruction-tuning format through two techniques:</p>
<h3 id="seed-template-prompt-technique">Seed Template Prompt Technique</h3>
<p>For each task type, the authors design a foundational seed template and use GPT-4 to generate variations that differ in expression but maintain semantic consistency. For each structured data entry, one template is randomly selected to create a single-turn dialogue sample. For example, converting <a href="https://en.wikipedia.org/wiki/IUPAC_nomenclature_of_organic_chemistry">IUPAC</a>-to-SMILES entries:</p>
<ul>
<li>&ldquo;Convert the IUPAC name [name] to its corresponding SMILES representation.&rdquo;</li>
<li>&ldquo;What&rsquo;s the SMILES notation for the chemical known as [name]?&rdquo;</li>
<li>&ldquo;Show me the SMILES sequence for [name], please.&rdquo;</li>
</ul>
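<p>The template sampling step can be sketched as follows. The three templates are the ones quoted above, with <code>[name]</code> written as a format placeholder. The output dictionary layout and the function name <code>make_dialogue</code> are assumptions, not the authors' data format.</p>

```python
import random

# Templates quoted in the text, with [name] written as {name}.
TEMPLATES = [
    "Convert the IUPAC name {name} to its corresponding SMILES representation.",
    "What's the SMILES notation for the chemical known as {name}?",
    "Show me the SMILES sequence for {name}, please.",
]


def make_dialogue(name, smiles, rng):
    """Turn one structured (name, SMILES) entry into a single-turn Q/A sample."""
    template = rng.choice(TEMPLATES)  # one template picked at random per entry
    return {"question": template.format(name=name), "answer": smiles}


rng = random.Random(0)  # seeded so the sketch is repeatable
sample = make_dialogue("ethanol", "CCO", rng)
```

<p>Each database entry yields exactly one dialogue sample, so the phrasing varies across the dataset while the answer stays tied to the source record.</p>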
<h3 id="play-as-playwrights-technique">Play as Playwrights Technique</h3>
<p>To generate richer, multi-turn dialogues, the authors prompt GPT-4 with a chain-of-thought (CoT) style &ldquo;script&rdquo; construction method. GPT-4 is guided to create multi-turn exchanges that simulate expert discussions, smoothly transitioning between question and answer stages. An additional &ldquo;answer masking&rdquo; variant has the model inquire about supplementary chemical information before providing a final answer, simulating realistic expert reasoning.</p>
<h3 id="training-objective">Training Objective</h3>
<p>The model is fine-tuned using <a href="https://en.wikipedia.org/wiki/LoRA_(machine_learning)">LoRA</a> with an autoregressive cross-entropy loss:</p>
<p>$$L_{CE} = -\sum_{c=1}^{M} y_{o,c} \log(p_{o,c})$$</p>
<p>where $M$ is the vocabulary size, $y_{o,c}$ is a binary indicator for whether observation $o$ belongs to class $c$, and $p_{o,c}$ is the predicted probability.</p>
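<p>Because $y_{o,c}$ is one-hot, the sum collapses to the negative log-probability of the correct token. A minimal numeric sketch, assuming raw logits over a toy vocabulary and a hypothetical helper name:</p>

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy for one token position: -log softmax(logits)[target_index]."""
    # Log-sum-exp with the maximum subtracted for numerical stability.
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_norm - logits[target_index]


uniform_loss = cross_entropy([0.0, 0.0, 0.0, 0.0], 2)  # uniform over 4 tokens
confident_loss = cross_entropy([10.0, 0.0, 0.0, 0.0], 0)
```

<p>A uniform prediction over four tokens costs $\log 4$, while a confident correct prediction drives the loss toward zero. For a full sequence, the per-position losses are summed or averaged.</p>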
<h2 id="two-stage-training-pipeline-and-chembench-evaluation">Two-Stage Training Pipeline and ChemBench Evaluation</h2>
<h3 id="training-setup">Training Setup</h3>
<p>ChemLLM uses a two-stage instruction tuning approach built on InternLM2-Base-7B:</p>
<p><strong>Stage 1</strong>: Fine-tune on Multi-Corpus (1.7M Q&amp;A pairs from Hugging Face) to enhance general linguistic capabilities, producing InternLM2-Chat-7B.</p>
<p><strong>Stage 2</strong>: Fine-tune on a mixture of ChemData (7M entries) and Multi-Corpus, balancing domain-specific chemical expertise with general language ability.</p>
<p>Training details include:</p>
<ul>
<li>LoRA with rank 8, scale factor 16.0, dropout 0.1</li>
<li>AdamW optimizer with initial learning rate $5.0 \times 10^{-5}$</li>
<li>NEFTune noise injection (alpha = 5) to prevent overfitting</li>
<li>Flash Attention-2 and KV Cache for efficiency</li>
<li>ZeRO Stage-2 for parameter offloading</li>
<li>Per-card batch size of 8 (total batch size 128)</li>
<li>1.06 epochs, 85,255 steps</li>
<li>Training loss reduced from 1.4998 to 0.7158</li>
</ul>
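<p>To make the LoRA settings concrete, the sketch below applies a low-rank update to a single linear layer using plain Python lists. It reads the "scale factor 16.0" as the usual LoRA alpha, so the update is multiplied by alpha / rank = 2. That reading, the helper names, and the 4096 x 4096 layer size used for the parameter count are assumptions for illustration.</p>

```python
def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(x, W, A, B, alpha, rank):
    """Compute W x + (alpha / rank) * B (A x).

    W is d_out x d_in (frozen), A is rank x d_in, B is d_out x rank.
    """
    base = matvec(W, x)
    low_rank = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * l for b, l in zip(base, low_rank)]


def lora_trainable_params(d_in, d_out, rank):
    """Number of trainable values in A and B for one adapted matrix."""
    return rank * (d_in + d_out)


# Rank 8 on an illustrative 4096 x 4096 projection.
adapter_params = lora_trainable_params(4096, 4096, 8)
full_params = 4096 * 4096
```

<p>For that illustrative layer, the adapter trains 65,536 values against 16,777,216 in the frozen matrix, which is why LoRA keeps memory requirements modest.</p>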
<h3 id="chemdata-composition">ChemData Composition</h3>
<p>ChemData spans three principal task categories with 7M instruction-tuning Q&amp;A pairs:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Tasks</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Molecules</td>
          <td>Name Conversion, Caption2Mol, Mol2Caption, Molecular Property Prediction</td>
      </tr>
      <tr>
          <td>Reactions</td>
          <td>Retrosynthesis, Product Prediction, Yield Prediction, Temperature Prediction, Solvent Prediction</td>
      </tr>
      <tr>
          <td>Domain-specific</td>
          <td>General chemical knowledge for broader chemical space understanding</td>
      </tr>
  </tbody>
</table>
<p>Data sources include PubChem, ChEMBL, ChEBI, ZINC, USPTO, ORDerly, ChemRxiv, LibreTexts Chemistry, Wikipedia, and Wikidata.</p>
<h3 id="chembench-design">ChemBench Design</h3>
<p>ChemBench contains 4,100 multiple-choice questions across the same nine tasks as ChemData. The choice of multiple-choice format is deliberate: it minimizes the influence of output style and focuses evaluation on factual correctness, unlike BLEU/ROUGE-based evaluation. Wrong answers are generated by sampling nearby values (for prediction tasks) or using GPT-4 to create plausible distractors. Deduplication ensures no overlap between ChemData training entries and ChemBench questions.</p>
<p>ChemBench has been contributed to the OpenCompass evaluation platform.</p>
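<p>For numeric prediction tasks, distractor generation by sampling nearby values might look like the sketch below. The four-option layout, the 20% relative spread, the rounding to one decimal, and the function name are all assumptions, since the paper does not specify these details here.</p>

```python
import random


def make_numeric_mcq(question, correct, rng, n_choices=4, rel_spread=0.2):
    """Build a multiple-choice item whose wrong options lie near the true value."""
    if correct == 0 or rel_spread <= 0 or n_choices > 4:
        raise ValueError("sketch expects a nonzero answer, positive spread, at most 4 options")
    truth = round(correct, 1)
    distractors = set()
    while len(distractors) != n_choices - 1:
        candidate = round(correct * (1.0 + rng.uniform(-rel_spread, rel_spread)), 1)
        if candidate != truth:
            distractors.add(candidate)
    options = [truth] + sorted(distractors)
    rng.shuffle(options)  # avoid always placing the answer first
    labels = "ABCD"[:n_choices]
    return {
        "question": question,
        "options": dict(zip(labels, options)),
        "answer": labels[options.index(truth)],
    }


rng = random.Random(0)
item = make_numeric_mcq("What is the expected yield (%)?", 78.4, rng)  # toy question
```

<p>Since every option is a plausible value, a model has to actually estimate the quantity instead of exploiting formatting cues, which is the stated motivation for the multiple-choice design.</p>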
<h3 id="baselines">Baselines</h3>
<p>All evaluations use 5-shot prompting. Baselines include:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LLaMA-2</td>
          <td>Open-source</td>
          <td>7B</td>
      </tr>
      <tr>
          <td>Mistral</td>
          <td>Open-source</td>
          <td>7B</td>
      </tr>
      <tr>
          <td>ChatGLM3</td>
          <td>Open-source</td>
          <td>7B</td>
      </tr>
      <tr>
          <td>Qwen</td>
          <td>Open-source</td>
          <td>7B</td>
      </tr>
      <tr>
          <td>InternLM2-Chat-7B</td>
          <td>Open-source (Stage 1 only)</td>
          <td>7B</td>
      </tr>
      <tr>
          <td>GPT-3.5</td>
          <td>Closed-source</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>GPT-4</td>
          <td>Closed-source</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<h2 id="chemllm-matches-gpt-4-on-chemical-tasks-and-outperforms-7b-peers">ChemLLM Matches GPT-4 on Chemical Tasks and Outperforms 7B Peers</h2>
<h3 id="chemical-evaluation-chembench">Chemical Evaluation (ChemBench)</h3>
<p>ChemLLM significantly outperforms general LLMs of similar scale and surpasses GPT-3.5 across all nine tasks. Compared to GPT-4, ChemLLM achieves higher scores on six of nine tasks, with the remaining three ranking just below GPT-4. LLaMA-2 scores near random chance (~25 per task), highlighting the difficulty of these tasks for models without chemical training.</p>
<p>Compared to InternLM2-Chat-7B (the Stage 1 model), ChemLLM shows substantial improvement, confirming the effectiveness of the Stage 2 chemical fine-tuning.</p>
<h3 id="general-evaluation">General Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>ChemLLM</th>
          <th>Best 7B Baseline</th>
          <th>GPT-4</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MMLU</td>
          <td>65.6</td>
          <td>&lt; 65.6</td>
          <td>Higher</td>
      </tr>
      <tr>
          <td>C-Eval</td>
          <td>67.2</td>
          <td>&lt; 67.2</td>
          <td>Higher</td>
      </tr>
      <tr>
          <td>GSM8K</td>
          <td>67.2</td>
          <td>&lt; 67.2</td>
          <td>Higher</td>
      </tr>
      <tr>
          <td>C-MHChem</td>
          <td>76.4</td>
          <td>&lt; 76.4</td>
          <td>&lt; 76.4</td>
      </tr>
  </tbody>
</table>
<p>ChemLLM outperforms all competing 7B models on MMLU, C-Eval, and GSM8K. On C-MHChem (Chinese middle and high school chemistry), ChemLLM scores 76.4, surpassing GPT-4. The authors note that chemical data fine-tuning may enhance reasoning capabilities due to the logical reasoning required in chemical problem-solving. ChemLLM also comprehensively surpasses InternLM2-Chat-7B on all four general benchmarks, indicating that chemical data does not harm general capabilities.</p>
<h3 id="qualitative-capabilities">Qualitative Capabilities</h3>
<p>The paper demonstrates qualitative performance on chemistry-related NLP tasks including:</p>
<ul>
<li>Chemical literature translation (English to Chinese and vice versa)</li>
<li>Chemical poetry creation</li>
<li>Information extraction from chemical text</li>
<li>Text summarization of chemical research</li>
<li>Reading comprehension on chemistry topics</li>
<li>Named entity recognition for chemical entities</li>
<li>Ethics and safety reasoning in chemical contexts</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper reports ChemBench results only as radar charts, so per-task scores for ChemLLM and the baselines are not available numerically, which makes precise comparison difficult. The evaluation is limited to 5-shot prompting without exploration of zero-shot or chain-of-thought prompting variants. The paper also does not discuss failure modes or systematic weaknesses of ChemLLM on particular task types.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Stage 1 Training</td>
          <td>Multi-Corpus</td>
          <td>1.7M Q&amp;A</td>
          <td>Collected from Hugging Face</td>
      </tr>
      <tr>
          <td>Stage 2 Training</td>
          <td>ChemData + Multi-Corpus</td>
          <td>7M + 1.7M</td>
          <td>Chemical + general mixture</td>
      </tr>
      <tr>
          <td>Chemical Evaluation</td>
          <td>ChemBench</td>
          <td>4,100 MCQ</td>
          <td>9 tasks, contributed to OpenCompass</td>
      </tr>
      <tr>
          <td>General Evaluation</td>
          <td>MMLU, C-Eval, GSM8K, C-MHChem</td>
          <td>Varies</td>
          <td>Standard benchmarks</td>
      </tr>
  </tbody>
</table>
<p>Data sources for ChemData: PubChem, ChEMBL, ChEBI, ZINC, USPTO, ORDerly, ChemRxiv, LibreTexts Chemistry, Wikipedia, Wikidata.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Two-stage instruction tuning (general then chemical)</li>
<li>LoRA fine-tuning (rank 8, scale 16.0, dropout 0.1)</li>
<li>Template-based instruction construction with GPT-4 for diversity</li>
<li>Play as Playwrights CoT prompting for multi-turn dialogue generation</li>
<li>NEFTune noise injection (alpha 5)</li>
<li>DeepSpeed ZeRO++ for distributed training</li>
</ul>
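<p>As a rough illustration of the NEFTune item above, the sketch below adds uniform noise to a toy embedding matrix, scaled by alpha / sqrt(L * d) as in the NEFTune formulation, with alpha = 5 taken from the list. The function name <code>neftune_noise</code>, the seed, and the toy shapes are assumptions for illustration; this is not ChemLLM's training code.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random


def neftune_noise(embeddings, alpha=5.0, seed=0):
    """Return a noisy copy of a (seq_len x dim) embedding matrix.

    Noise is drawn from Uniform(-1, 1) and scaled by
    alpha / sqrt(seq_len * dim). Hypothetical helper, not the paper's code.
    """
    rng = random.Random(seed)
    seq_len = len(embeddings)
    dim = len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * dim)
    return [
        [value + scale * rng.uniform(-1.0, 1.0) for value in row]
        for row in embeddings
    ]


# Toy input: 4 tokens with 4 dimensions each, so the scale is 5 / sqrt(16).
toy = [[0.0] * 4 for _ in range(4)]
noisy = neftune_noise(toy, alpha=5.0)
max_abs = max(abs(v) for row in noisy for v in row)
print("largest perturbation:", max_abs)
</code></pre></div>
<p>Because the scale shrinks with both sequence length and embedding width, longer inputs receive smaller per-element perturbations for the same alpha.</p>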
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Base</th>
          <th>Parameters</th>
          <th>Availability</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemLLM-7B-Chat</td>
          <td>InternLM2-Base-7B</td>
          <td>7B</td>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-7B-Chat">Hugging Face</a></td>
      </tr>
      <tr>
          <td>ChemLLM-7B-Chat-1.5-DPO</td>
          <td>InternLM2</td>
          <td>7B</td>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-7B-Chat-1_5-DPO">Hugging Face</a></td>
      </tr>
      <tr>
          <td>ChemLLM-20B-Chat-DPO</td>
          <td>InternLM</td>
          <td>20B</td>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-20B-Chat-DPO">Hugging Face</a></td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>5-shot evaluation across all benchmarks. Multiple-choice format for ChemBench to minimize output style bias.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li>2 machines, each with 8 NVIDIA A100 SXM GPUs</li>
<li>2 AMD EPYC 7742 64-Core CPUs per machine (256 threads each)</li>
<li>SLURM cluster management</li>
<li>BF16 mixed precision training</li>
<li>Flash Attention-2 + KV Cache</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-7B-Chat">ChemLLM-7B-Chat</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Original 7B chat model</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-7B-Chat-1_5-DPO">ChemLLM-7B-Chat-1.5-DPO</a></td>
          <td>Model</td>
          <td>Other</td>
          <td>Updated v1.5 with DPO</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemLLM-20B-Chat-DPO">ChemLLM-20B-Chat-DPO</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>20B parameter variant</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem">AI4Chem HuggingFace</a></td>
          <td>Collection</td>
          <td>Various</td>
          <td>All models, datasets, and code</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, D., Liu, W., Tan, Q., Chen, J., Yan, H., Yan, Y., Li, J., Huang, W., Yue, X., Ouyang, W., Zhou, D., Zhang, S., Su, M., Zhong, H.-S., &amp; Li, Y. (2024). ChemLLM: A Chemical Large Language Model. <em>arXiv preprint arXiv:2402.06852</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2024chemllm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemLLM: A Chemical Large Language Model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Di and Liu, Wei and Tan, Qian and Chen, Jingdan and Yan, Hang and Yan, Yuliang and Li, Jiatong and Huang, Weiran and Yue, Xiangyu and Ouyang, Wanli and Zhou, Dongzhan and Zhang, Shufei and Su, Mao and Zhong, Han-Sen and Li, Yuqiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2402.06852}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChatDrug: Conversational Drug Editing with ChatGPT</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chatdrug-conversational-drug-editing/</guid><description>ChatDrug uses ChatGPT with retrieval and domain feedback for drug editing across small molecules, peptides, and proteins on 39 tasks.</description><content:encoded><![CDATA[<h2 id="a-framework-for-conversational-drug-editing-with-llms">A Framework for Conversational Drug Editing with LLMs</h2>
<p>ChatDrug is a <strong>Method</strong> paper that introduces a parameter-free framework for drug editing using conversational large language models (specifically ChatGPT/GPT-3.5). The primary contribution is a three-module pipeline that combines prompt engineering, retrieval-augmented domain feedback, and iterative conversation to perform text-guided editing of small molecules, peptides, and proteins. The paper also establishes a benchmark of 39 drug editing tasks spanning these three drug types.</p>
<h2 id="bridging-conversational-ai-and-drug-discovery">Bridging Conversational AI and Drug Discovery</h2>
<p>Drug editing (also called <a href="https://en.wikipedia.org/wiki/Hit_to_lead">lead optimization</a> or protein design) is a critical step in the drug discovery pipeline where molecular substructures are modified to achieve desired properties. Traditional approaches rely on domain experts for manual editing, which can be subjective and biased. Recent multi-modal approaches like MoleculeSTM and ProteinDT have started exploring text-guided drug editing, but they are domain-specific (limited to one drug type) and lack conversational capabilities for iterative refinement.</p>
<p>The authors identify three properties of conversational LLMs that make them suitable for drug discovery: (1) pretraining on comprehensive knowledge bases covering drug-related concepts, (2) strong few-shot adaptation and generalization abilities, and (3) interactive communication enabling iterative feedback incorporation. However, directly applying LLMs to drug editing yields suboptimal results because the models do not fully utilize prior domain knowledge. ChatDrug addresses this gap through structured retrieval and feedback mechanisms.</p>
<h2 id="three-module-pipeline-pdds-redf-and-conversation">Three-Module Pipeline: PDDS, ReDF, and Conversation</h2>
<p>ChatDrug consists of three modules that operate sequentially without any parameter learning.</p>
<h3 id="pdds-module-prompt-design-for-domain-specific">PDDS Module (Prompt Design for Domain-Specific)</h3>
<p>The PDDS module constructs domain-specific prompts for ChatGPT. Given an input drug $\pmb{x}_{\text{in}}$ and a text prompt $\pmb{x}_t$ describing the desired property change, the goal is:</p>
<p>$$
\pmb{x}_{\text{out}} = \text{ChatDrug}(\pmb{x}_{\text{in}}, \pmb{x}_t)
$$</p>
<p>The prompts are designed around high-level property descriptions (e.g., &ldquo;more soluble in water&rdquo;) rather than exact substructure replacements. The authors argue that ChatDrug is better suited for &ldquo;fuzzy searching&rdquo; (property-based editing with non-deterministic answers) than for &ldquo;exact searching&rdquo; (precise substructure replacement that experts can do directly).</p>
<h3 id="redf-module-retrieval-and-domain-feedback">ReDF Module (Retrieval and Domain Feedback)</h3>
<p>The ReDF module retrieves structurally similar examples from a domain-specific database and injects them into the conversation as demonstrations. For an input drug $\pmb{x}_{\text{in}}$, a candidate drug $\tilde{\pmb{x}}$ that failed the desired property change, and a retrieval database, ReDF returns:</p>
<p>$$
\pmb{x}_R = \text{ReDF}(\pmb{x}_{\text{in}}, \tilde{\pmb{x}}; \pmb{x}_t) = \underset{\pmb{x}'_R \in \text{RetrievalDB}}{\arg\max} \langle \tilde{\pmb{x}}, \pmb{x}'_R \rangle \wedge D(\pmb{x}_{\text{in}}, \pmb{x}'_R; \pmb{x}_t)
$$</p>
<p>where $D(\cdot, \cdot; \cdot) \in \{\text{True}, \text{False}\}$ is a domain feedback function checking whether the retrieved drug satisfies the desired property change, and $\langle \tilde{\pmb{x}}, \pmb{x}'_R \rangle$ is a similarity function (<a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> for small molecules, <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a> for peptides and proteins).</p>
<p>The retrieved example $\pmb{x}_R$ is injected into the prompt as: &ldquo;Your provided sequence [$\tilde{\pmb{x}}$] is not correct. We find a sequence [$\pmb{x}_R$] which is correct and similar to the molecule you provided. Can you give me a new molecule?&rdquo;</p>
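<p>The retrieval rule above can be sketched as a filtered arg-max: keep only database entries that pass the domain feedback check, then return the one most similar to the failed candidate. In this minimal sketch, molecules are stand-in tuples of fragment labels, and <code>POLAR</code>, <code>polar_count</code>, and <code>more_polar</code> are hypothetical toy helpers; the actual implementation would use molecular fingerprints and a property oracle.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def tanimoto(a, b):
    """Tanimoto (Jaccard) similarity between two feature collections."""
    a, b = set(a), set(b)
    union = a.union(b)
    return len(a.intersection(b)) / len(union) if union else 0.0


def redf(x_in, x_failed, retrieval_db, satisfies, similarity):
    """Return the entry most similar to x_failed among those passing feedback.

    satisfies(x_in, x) plays the role of the domain feedback function D.
    Returns None when no database entry passes the check.
    """
    candidates = [x for x in retrieval_db if satisfies(x_in, x)]
    if not candidates:
        return None
    return max(candidates, key=lambda x: similarity(x_failed, x))


POLAR = {"OH", "NH2", "COOH"}  # hypothetical polar fragment labels


def polar_count(mol):
    return sum(1 for frag in mol if frag in POLAR)


def more_polar(x_in, x_candidate):
    # Toy stand-in for D: True when the candidate has more polar fragments.
    return polar_count(x_candidate) > polar_count(x_in)


db = [
    ("CH3", "CH2", "OH"),
    ("CH3", "C6H5", "Cl"),
    ("CH3", "CH2", "Cl", "NH2"),
]
x_in = ("CH3", "CH2", "CH3")
x_failed = ("CH3", "CH2", "Cl")  # candidate edit that failed the check
x_r = redf(x_in, x_failed, db, more_polar, tanimoto)
print(x_r)
</code></pre></div>
<p>Filtering before ranking mirrors the conjunction in the formula: similarity only matters among entries for which the feedback function returns True.</p>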
<h3 id="conversation-module">Conversation Module</h3>
<p>The conversation module enables iterative refinement over $C$ rounds. At each round $c$, if the edited drug $\pmb{x}_c$ does not satisfy the evaluation condition, ChatDrug retrieves a new example via ReDF using $\tilde{\pmb{x}} = \pmb{x}_c$ and continues the conversation. This aligns with the iterative nature of real drug discovery workflows.</p>
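<p>A minimal sketch of this loop follows, assuming the language model, the evaluation condition, and the retrieval step are passed in as plain callables. The names <code>chat_drug</code>, <code>stub_llm</code>, <code>stub_is_hit</code>, and <code>stub_retrieve</code>, and the canned replies, are hypothetical; only the feedback message wording comes from the text above.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def chat_drug(x_in, prompt, llm, is_hit, retrieve, rounds=2):
    """Edit x_in with up to `rounds` rounds of retrieval feedback.

    llm(messages) returns an edited drug, is_hit(x_in, x_out) is the
    evaluation condition, and retrieve(x_in, x_failed) stands in for ReDF.
    Returns the final edit and the number of feedback rounds used.
    """
    messages = [prompt]
    x_out = llm(messages)
    for _ in range(rounds):
        if is_hit(x_in, x_out):
            break
        x_r = retrieve(x_in, x_out)
        messages.append(
            f"Your provided sequence [{x_out}] is not correct. "
            f"We find a sequence [{x_r}] which is correct and similar to "
            "the molecule you provided. Can you give me a new molecule?"
        )
        x_out = llm(messages)
    return x_out, len(messages) - 1


# Toy stand-ins so the loop can run without any API access.
answers = ["CCC", "CCCl", "CCO"]  # hypothetical canned model replies


def stub_llm(messages):
    return answers[len(messages) - 1]


def stub_is_hit(x_in, x_out):
    return "O" in x_out  # toy condition: an oxygen atom was introduced


def stub_retrieve(x_in, x_failed):
    return "CCCO"  # toy retrieved example


result, feedback_rounds = chat_drug(
    "CC", "Make CC more soluble in water.", stub_llm, stub_is_hit, stub_retrieve
)
print(result, feedback_rounds)
</code></pre></div>
<p>Keeping the full message history lets each new request carry the earlier failed candidates and retrieved examples, which is what makes the refinement conversational.</p>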
<h2 id="experiments-across-39-drug-editing-tasks">Experiments Across 39 Drug Editing Tasks</h2>
<h3 id="task-design">Task Design</h3>
<p>The benchmark includes 39 tasks across three drug types:</p>
<ul>
<li><strong>Small molecules</strong> (28 tasks): 16 single-objective tasks (tasks 101-108, each with loose and strict thresholds) and 12 multi-objective tasks (tasks 201-206, each with two thresholds). Properties include solubility (<a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a>), drug-likeness (QED), permeability (<a href="https://en.wikipedia.org/wiki/Polar_surface_area">tPSA</a>), and <a href="https://en.wikipedia.org/wiki/Hydrogen_bond">hydrogen bond</a> acceptor/donor counts.</li>
<li><strong>Peptides</strong> (9 tasks): 6 single-objective and 3 multi-objective tasks for editing <a href="https://en.wikipedia.org/wiki/Major_histocompatibility_complex">peptide-MHC binding</a> affinity across different <a href="https://en.wikipedia.org/wiki/Human_leukocyte_antigen">HLA allele</a> types.</li>
<li><strong>Proteins</strong> (2 tasks): Editing protein sequences to increase <a href="https://en.wikipedia.org/wiki/Alpha_helix">alpha-helix</a> or <a href="https://en.wikipedia.org/wiki/Beta_sheet">beta-strand</a> secondary structures.</li>
</ul>
<h3 id="baselines">Baselines</h3>
<p>For small molecules, baselines include Random, PCA, High-Variance, and GS-Mutate (all based on MegaMolBART), plus MoleculeSTM with <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and Graph representations. For peptides and proteins, random mutation baselines with 1-3 mutated positions are used.</p>
<h3 id="main-results">Main Results</h3>
<p>ChatDrug achieves the best performance on 33 out of 39 tasks. Key results for small molecule editing (hit ratio):</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Property</th>
          <th>ChatDrug (loose)</th>
          <th>Best Baseline (loose)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>101</td>
          <td>More soluble</td>
          <td>94.13</td>
          <td>67.86 (MoleculeSTM-Graph)</td>
      </tr>
      <tr>
          <td>102</td>
          <td>Less soluble</td>
          <td>96.86</td>
          <td>64.79 (MoleculeSTM-Graph)</td>
      </tr>
      <tr>
          <td>106</td>
          <td>Lower permeability</td>
          <td>77.35</td>
          <td>34.13 (MoleculeSTM-SMILES)</td>
      </tr>
      <tr>
          <td>107</td>
          <td>More HBA</td>
          <td>95.35</td>
          <td>54.01 (MoleculeSTM-SMILES)</td>
      </tr>
      <tr>
          <td>108</td>
          <td>More HBD</td>
          <td>96.54</td>
          <td>60.97 (MoleculeSTM-Graph)</td>
      </tr>
  </tbody>
</table>
<p>ChatDrug underperforms on tasks 104 (less like a drug) and 105 (higher permeability), and on most multi-objective tasks involving permeability (205), where MoleculeSTM variants perform better.</p>
<p>For peptide editing, ChatDrug achieves 41-69% hit ratios compared to 0.4-14.4% for random mutation baselines. For protein editing, ChatDrug reaches 34.79% and 51.38% hit ratios on helix and strand tasks respectively, compared to 26.90% and 21.44% for the best random mutation baseline.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Conversation rounds</strong>: Performance increases with more rounds, converging around $C = 2$. For example, on task 101 (loose threshold), zero-shot achieves 78.26%, $C = 1$ reaches 89.56%, and $C = 2$ reaches 93.37%.</p>
<p><strong>ReDF threshold</strong>: Using a stricter threshold in the domain feedback function $D$ (matching the evaluation threshold) yields substantially higher performance than using a loose threshold. For example, on task 107 with strict evaluation, the strict-threshold ReDF achieves 72.60% vs. 14.96% for the loose-threshold ReDF.</p>
<p><strong>Similarity analysis</strong>: Retrieved molecules $\pmb{x}_R$ tend to have lower similarity to input molecules than the intermediate outputs $\pmb{x}_1$, yet they have higher hit ratios. This suggests the ReDF module explores the chemical space effectively, and the conversation module balances similarity preservation with property optimization.</p>
<p><strong>Knowledge extraction</strong>: ChatDrug can articulate domain-specific reasoning for its edits (e.g., summarizing rules for increasing water solubility by introducing polar functional groups), though the extracted knowledge shows some redundancy.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>ChatDrug demonstrates that conversational LLMs can serve as useful tools for drug editing, achieving strong results across diverse drug types with a parameter-free approach. The framework exhibits open vocabulary and compositional properties, allowing it to handle novel drug concepts and multi-objective tasks through natural language.</p>
<p>The authors acknowledge two main limitations. First, ChatDrug struggles with understanding complex 3D drug geometries, which would require deeper geometric modeling. Second, the framework requires multiple conversation rounds to achieve strong performance, adding computational cost through repeated API calls. The authors suggest that knowledge summarization capabilities of LLMs could help reduce this cost.</p>
<p>The evaluation relies entirely on computational oracles (RDKit for small molecules, MHCflurry2.0 for peptides, ProteinCLAP for proteins) rather than wet-lab validation. The hit ratio metric also excludes invalid outputs from the denominator, so the effective success rate on all attempted edits may be lower than reported.</p>
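<p>To make the denominator issue concrete, the sketch below computes both the reported hit ratio (valid edits only) and an effective rate over all attempted edits. The function name <code>hit_ratios</code> and the outcome counts are made up for illustration and are not taken from the paper.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def hit_ratios(outcomes):
    """Compute hit ratios from a list of 'hit', 'miss', or 'invalid' labels.

    Returns (reported, effective): the reported ratio divides by valid edits
    only, while the effective ratio divides by every attempted edit.
    """
    hits = outcomes.count("hit")
    valid = hits + outcomes.count("miss")
    reported = hits / valid if valid else 0.0
    effective = hits / len(outcomes) if outcomes else 0.0
    return reported, effective


# Made-up counts: 90 hits, 10 valid misses, 25 invalid outputs.
outcomes = ["hit"] * 90 + ["miss"] * 10 + ["invalid"] * 25
reported, effective = hit_ratios(outcomes)
print(f"reported={reported:.2f} effective={effective:.2f}")
</code></pre></div>
<p>With these illustrative counts, 90 hits out of 100 valid edits would be reported as 0.90, while the rate over all 125 attempts is 0.72.</p>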
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Small molecule inputs</td>
          <td><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a></td>
          <td>200 molecules</td>
          <td>Sampled SMILES strings</td>
      </tr>
      <tr>
          <td>Small molecule retrieval DB</td>
          <td>ZINC</td>
          <td>10K molecules</td>
          <td>For ReDF similarity search</td>
      </tr>
      <tr>
          <td>Peptide inputs</td>
          <td>Peptide-MHC binding dataset</td>
          <td>500 peptides per task</td>
          <td>From 30 common MHC alleles</td>
      </tr>
      <tr>
          <td>Peptide retrieval DB</td>
          <td>Experimental binding data</td>
          <td>Varies by allele</td>
          <td>Target allele experimental data</td>
      </tr>
      <tr>
          <td>Protein inputs</td>
          <td>TAPE test set</td>
          <td>Varies</td>
          <td>Secondary structure prediction test data</td>
      </tr>
      <tr>
          <td>Protein retrieval DB</td>
          <td>TAPE training set</td>
          <td>Varies</td>
          <td>Secondary structure prediction training data</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>GPT-3.5-turbo via OpenAI ChatCompletion API, temperature=0, frequency_penalty=0.2</li>
<li>System prompt: &ldquo;You are an expert in the field of molecular chemistry.&rdquo;</li>
<li>$C = 2$ conversation rounds for main results</li>
<li>5 random seeds (0-4) for small molecule main results, seed 0 for ablations</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>ChatGPT (GPT-3.5-turbo): used as-is, no fine-tuning</li>
<li>MHCflurry 2.0: pseudo-oracle for peptide binding affinity evaluation</li>
<li>ProteinCLAP-EBM-NCE from ProteinDT: protein secondary structure prediction</li>
<li>ESMFold: protein folding for visualization</li>
<li>RDKit: molecular property calculations for small molecules</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Hit Ratio</td>
          <td>Fraction of valid edits satisfying property requirements</td>
          <td>Invalid sequences excluded from denominator</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>All experiments were conducted on a single NVIDIA RTX A6000 GPU (used only for peptide and protein evaluation). Total OpenAI API cost was less than $100.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/chao1224/ChatDrug">ChatDrug GitHub</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, S., Wang, J., Yang, Y., Wang, C., Liu, L., Guo, H., &amp; Xiao, C. (2024). Conversational Drug Editing Using Retrieval and Domain Feedback. <em>ICLR 2024</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liu2024chatdrug,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Conversational Drug Editing Using Retrieval and Domain Feedback}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Liu, Shengchao and Wang, Jiongxiao and Yang, Yijin and Wang, Chengpeng and Liu, Ling and Guo, Hongyu and Xiao, Chaowei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>BioT5: Cross-Modal Integration of Biology and Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/biot5-cross-modal-biology/</link><pubDate>Sat, 28 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/biot5-cross-modal-biology/</guid><description>BioT5 is a T5-based pretraining framework that jointly models molecules, proteins, and natural language using SELFIES for robust molecular generation.</description><content:encoded><![CDATA[<h2 id="a-unified-pretraining-framework-for-molecules-proteins-and-text">A Unified Pretraining Framework for Molecules, Proteins, and Text</h2>
<p>BioT5 is a <strong>Method</strong> paper that introduces a comprehensive <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-based pretraining framework for cross-modal integration of molecules, proteins, and natural language. The primary contribution is a multi-task pretraining approach that uses <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (instead of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) for 100% valid molecular representations, separate tokenization for each modality, and a combination of masked language modeling and translation objectives to connect structured biological data with unstructured scientific text. After fine-tuning, BioT5 (252M parameters) achieves state-of-the-art performance on 10 out of 15 downstream tasks spanning molecule property prediction, protein property prediction, drug-target interaction, protein-protein interaction, molecule captioning, and text-based molecule generation.</p>
<h2 id="bridging-the-gap-between-molecular-sequences-and-scientific-knowledge">Bridging the Gap Between Molecular Sequences and Scientific Knowledge</h2>
<p>Prior cross-modal models in computational biology face three recurring challenges. First, models like MolT5 and MolXPT rely on SMILES to represent molecules, but SMILES strings are syntactically fragile: random perturbations or model-generated sequences frequently produce invalid molecular structures. Edwards et al. (2022) and Li et al. (2023) both highlight this validity problem as a bottleneck for text-to-molecule generation. Second, the contextual information surrounding molecular and protein names in scientific literature (e.g., mentions in <a href="https://en.wikipedia.org/wiki/PubMed">PubMed</a> abstracts that describe properties, interactions, and experimental results) remains underutilized. Most models either ignore this context or treat it identically to structured database entries. Third, existing approaches like MolT5 and <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a> share a single tokenizer and embedding space across molecules, proteins, and text. This leads to chemically incorrect tokenization: the bromine atom &ldquo;Br&rdquo; in SMILES gets split into &ldquo;B&rdquo; (boron) and &ldquo;r&rdquo;, producing erroneous downstream predictions.</p>
<p>BioT5 addresses all three issues simultaneously by adopting SELFIES for molecular representation, extracting entity-linked contextual knowledge from PubMed, and employing separate vocabularies for each modality.</p>
<h2 id="selfies-separate-tokenization-and-multi-task-pretraining">SELFIES, Separate Tokenization, and Multi-Task Pretraining</h2>
<p>The core innovations of BioT5 center on three design decisions:</p>
<h3 id="selfies-for-robust-molecular-representation">SELFIES for Robust Molecular Representation</h3>
<p>BioT5 replaces SMILES with SELFIES (Self-referencing Embedded Strings) for all molecular representations. Every permutation of symbols within the SELFIES alphabet generates a chemically valid molecular structure, guaranteeing 100% validity in generation tasks. Molecules from ZINC20 are converted from SMILES to SELFIES during data preprocessing.</p>
<h3 id="modality-specific-tokenization">Modality-Specific Tokenization</h3>
<p>Rather than sharing a single SentencePiece vocabulary across modalities, BioT5 maintains three separate dictionaries:</p>
<ul>
<li><strong>Molecules</strong>: Each SELFIES token corresponds to a chemically meaningful atom group enclosed in brackets (e.g., <code>[C]</code>, <code>[=C]</code>, <code>[Br]</code>).</li>
<li><strong>Proteins</strong>: Amino acids are prefixed with a special <code>&lt;p&gt;</code> token to distinguish them from text characters (e.g., <code>&lt;p&gt;M</code>, <code>&lt;p&gt;K</code>, <code>&lt;p&gt;R</code>).</li>
<li><strong>Text</strong>: The standard T5 vocabulary is retained.</li>
</ul>
<p>This prevents semantic conflation across modalities. The total vocabulary size is 35,073, and the model comprises 252M parameters using the T5-v1.1-base architecture.</p>
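<p>A minimal sketch of the molecule and protein tokenization rules described above, assuming a bracket-matching regular expression for SELFIES and a per-residue prefix for proteins. The function names and the <code>PROTEIN_PREFIX</code> constant are hypothetical and do not come from the BioT5 codebase.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Prefix that marks amino-acid tokens (assumed constant name).
PROTEIN_PREFIX = "&lt;p&gt;"


def tokenize_selfies(selfies):
    """Split a SELFIES string into its bracketed symbols."""
    return re.findall(r"\[[^\]]+\]", selfies)


def tokenize_protein(sequence):
    """Prefix each amino-acid letter so it cannot collide with text tokens."""
    return [PROTEIN_PREFIX + residue for residue in sequence]


mol_tokens = tokenize_selfies("[C][=C][Br]")
protein_tokens = tokenize_protein("MKR")
print(mol_tokens)
print(protein_tokens)
</code></pre></div>
<p>Because each bracketed symbol is kept whole, <code>[Br]</code> stays a single bromine token instead of being split into separate characters.</p>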
<h3 id="multi-task-pretraining-objectives">Multi-Task Pretraining Objectives</h3>
<p>BioT5 uses six pretraining tasks organized into three categories:</p>
<ol>
<li><strong>Single-modal T5 objective</strong>: Standard span corruption and recovery applied independently to molecule SELFIES (task 1), protein <a href="https://en.wikipedia.org/wiki/FASTA_format">FASTA</a> (task 2), and general text from C4 (task 3).</li>
<li><strong>Wrapped text T5 objective</strong> (task 4): Applied to PubMed articles where molecular names are replaced with corresponding SELFIES strings and gene names are appended with protein FASTA sequences, using BERN2 for named entity recognition and entity linking.</li>
<li><strong>Bidirectional translation</strong> (tasks 5 and 6): Molecule SELFIES to text description and vice versa (using 339K pairs from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>), and protein FASTA to text description and vice versa (using 569K pairs from <a href="https://en.wikipedia.org/wiki/UniProt">Swiss-Prot</a>).</li>
</ol>
<p>The translation direction is randomly sampled with probability 0.5 for each example. For downstream tasks, BioT5 uses prompt-based fine-tuning to cast all tasks into a sequence generation format, reducing the gap between pretraining and fine-tuning.</p>
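<p>The direction sampling can be sketched as a coin flip per example. This is an illustrative sketch only: the function name <code>make_translation_example</code>, the dictionary keys, and the toy SELFIES/description pair are assumptions, and the actual prompt templates are not shown here.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def make_translation_example(sequence, description, rng):
    """Pick a translation direction with probability 0.5 each."""
    if rng.random() > 0.5:
        return {"source": sequence, "target": description}
    return {"source": description, "target": sequence}


rng = random.Random(0)  # fixed seed so the sketch is repeatable
pair = ("[C][O]", "methanol")  # toy molecule-text pair
batch = [make_translation_example(pair[0], pair[1], rng) for _ in range(200)]
forward = sum(1 for example in batch if example["source"] == pair[0])
print("sequence-to-text examples:", forward, "of", len(batch))
</code></pre></div>
<p>Sampling the direction per example, rather than duplicating every pair in both directions, keeps the number of training examples equal to the number of pairs while still exposing the model to both mappings.</p>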
<h2 id="evaluation-across-15-downstream-tasks">Evaluation Across 15 Downstream Tasks</h2>
<p>BioT5 is evaluated on 15 tasks organized into three categories: single-instance prediction, multi-instance prediction, and cross-modal generation.</p>
<h3 id="molecule-property-prediction-moleculenet">Molecule Property Prediction (MoleculeNet)</h3>
<p>BioT5 is evaluated on six binary classification tasks from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> using scaffold splitting: BBBP, Tox21, ClinTox, HIV, BACE, and SIDER. Results are averaged over three random runs.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>GEM</th>
          <th>MolXPT</th>
          <th>BioT5</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>72.4</td>
          <td>80.0</td>
          <td>77.7</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>78.1</td>
          <td>77.1</td>
          <td>77.9</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>90.1</td>
          <td>95.3</td>
          <td>95.4</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>80.6</td>
          <td>78.1</td>
          <td><strong>81.0</strong></td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>85.6</td>
          <td>88.4</td>
          <td><strong>89.4</strong></td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>67.2</td>
          <td>71.7</td>
          <td><strong>73.2</strong></td>
      </tr>
      <tr>
          <td><strong>Avg</strong></td>
          <td>79.0</td>
          <td>81.9</td>
          <td><strong>82.4</strong></td>
      </tr>
  </tbody>
</table>
<p>BioT5 achieves the best average AUROC (82.4) across all six datasets, surpassing both GNN-based methods (GEM) and language model baselines (MolXPT).</p>
<h3 id="protein-property-prediction-peer-benchmark">Protein Property Prediction (PEER Benchmark)</h3>
<p>On the PEER benchmark, BioT5 is evaluated on protein solubility and subcellular localization prediction:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params</th>
          <th>Solubility (Acc)</th>
          <th>Localization (Acc)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESM-1b</td>
          <td>652.4M</td>
          <td>70.23</td>
          <td><strong>92.40</strong></td>
      </tr>
      <tr>
          <td>ProtBert</td>
          <td>419.9M</td>
          <td>68.15</td>
          <td>91.32</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td>252.1M</td>
          <td><strong>74.65</strong></td>
          <td>91.69</td>
      </tr>
  </tbody>
</table>
<p>BioT5 achieves the best solubility prediction accuracy (74.65%) despite being roughly 1.7x smaller than ProtBert and 2.6x smaller than ESM-1b, both dedicated protein language models.</p>
<h3 id="drug-target-interaction-prediction">Drug-Target Interaction Prediction</h3>
<p>BioT5 is evaluated on three DTI datasets (BioSNAP, Human, BindingDB) with five random runs:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>BioSNAP AUROC</th>
          <th>Human AUROC</th>
          <th>BindingDB AUROC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DrugBAN</td>
          <td>0.903</td>
          <td>0.982</td>
          <td>0.960</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td><strong>0.937</strong></td>
          <td><strong>0.989</strong></td>
          <td><strong>0.963</strong></td>
      </tr>
  </tbody>
</table>
<p>BioT5 consistently outperforms DrugBAN and other specialized DTI models across all three datasets.</p>
<h3 id="molecule-captioning-and-text-based-molecule-generation">Molecule Captioning and Text-Based Molecule Generation</h3>
<p>On the ChEBI-20 dataset, BioT5 outperforms all baselines in molecule captioning:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params</th>
          <th>BLEU-4</th>
          <th>METEOR</th>
          <th>Text2Mol</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolT5-large</td>
          <td>783M</td>
          <td>0.508</td>
          <td>0.614</td>
          <td>0.582</td>
      </tr>
      <tr>
          <td>MolXPT</td>
          <td>350M</td>
          <td>0.505</td>
          <td>0.626</td>
          <td>0.594</td>
      </tr>
      <tr>
          <td>BioT5</td>
          <td>252M</td>
          <td><strong>0.556</strong></td>
          <td><strong>0.656</strong></td>
          <td><strong>0.603</strong></td>
      </tr>
  </tbody>
</table>
<p>For text-based molecule generation, BioT5 achieves an exact match score of 0.413 (vs. 0.311 for MolT5-large) while maintaining 100% validity, compared to 90.5% for MolT5-large. This demonstrates the direct benefit of SELFIES: every generated sequence is a valid molecule.</p>
<h3 id="protein-protein-interaction-prediction">Protein-Protein Interaction Prediction</h3>
<p>On the PEER PPI benchmarks (Yeast and Human), BioT5 achieves competitive results, outperforming fully fine-tuned ProtBert and ESM-1b on the Yeast dataset (64.89% vs. 63.72% for ProtBert) and placing second on Human (86.22% vs. 88.06% for ESM-1b with frozen weights).</p>
<h2 id="key-findings-limitations-and-future-directions">Key Findings, Limitations, and Future Directions</h2>
<p>BioT5 demonstrates that integrating molecular, protein, and textual modalities within a single pretraining framework yields consistent improvements across diverse biological tasks. Three factors drive BioT5&rsquo;s performance: (1) SELFIES guarantees 100% molecular validity in generation tasks, eliminating a persistent failure mode of SMILES-based models; (2) separate tokenization preserves the semantic integrity of each modality; (3) wrapped text pretraining on PubMed provides contextual biological knowledge that pure sequence models miss.</p>
<p>The authors acknowledge several limitations. BioT5 requires full-parameter fine-tuning for each downstream task because instruction-tuning does not generalize across tasks, and combining datasets via instructions causes data leakage (the authors note overlaps between BindingDB training data and BioSNAP/Human test sets). The model only handles sequence-format bio-entities and does not incorporate 2D or 3D structural information. Additional biological modalities such as DNA/RNA sequences and cell-level data are also left for future work.</p>
<p>The authors also note risks: BioT5 could be misused to generate dangerous molecules, and it may fail to generate effective therapeutic molecules or may produce compounds with adverse side effects.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining (molecules)</td>
          <td>ZINC20</td>
          <td>~300M molecules</td>
          <td>Converted from SMILES to SELFIES</td>
      </tr>
      <tr>
          <td>Pretraining (proteins)</td>
          <td><a href="https://en.wikipedia.org/wiki/UniProt">UniRef50</a></td>
          <td>27M proteins</td>
          <td>Filtered by length</td>
      </tr>
      <tr>
          <td>Pretraining (text)</td>
          <td>C4</td>
          <td>Large</td>
          <td>Standard T5 corpus</td>
      </tr>
      <tr>
          <td>Pretraining (wrapped text)</td>
          <td>PubMed</td>
          <td>33M articles</td>
          <td>Entity linking via BERN2</td>
      </tr>
      <tr>
          <td>Pretraining (molecule-text pairs)</td>
          <td>PubChem</td>
          <td>339K pairs</td>
          <td>Excludes ChEBI-20 molecules</td>
      </tr>
      <tr>
          <td>Pretraining (protein-text pairs)</td>
          <td>Swiss-Prot</td>
          <td>569K pairs</td>
          <td>High-quality annotations</td>
      </tr>
      <tr>
          <td>Evaluation (molecular properties)</td>
          <td>MoleculeNet</td>
          <td>6 datasets</td>
          <td>Scaffold splitting</td>
      </tr>
      <tr>
          <td>Evaluation (protein properties)</td>
          <td>PEER</td>
          <td>2 tasks</td>
          <td>Solubility and localization</td>
      </tr>
      <tr>
          <td>Evaluation (DTI)</td>
          <td>BioSNAP, Human, BindingDB</td>
          <td>3 datasets</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Evaluation (PPI)</td>
          <td>Yeast, Human</td>
          <td>2 datasets</td>
          <td>From PEER benchmark</td>
      </tr>
      <tr>
          <td>Evaluation (generation)</td>
          <td>ChEBI-20</td>
          <td>33K pairs</td>
          <td>Molecule captioning and text-to-molecule</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: T5-v1.1-base (encoder-decoder transformer)</li>
<li>Optimizer: AdamW with RMS scaling</li>
<li>Learning rate: cosine annealing, base $1 \times 10^{-2}$, minimum $1 \times 10^{-5}$</li>
<li>Warmup steps: 10,000</li>
<li>Dropout: 0.0</li>
<li>Maximum input length: 512 tokens</li>
<li>Pretraining steps: 350K</li>
<li>Batch size: 96 per GPU (6 data types per batch)</li>
<li>Prompt-based fine-tuning for all downstream tasks</li>
</ul>
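<p>For orientation, a common form of warmup followed by cosine annealing is written below. It is a generic sketch, not taken from the paper: the symbols $T_w$ (warmup steps) and $T$ (total steps) are notation assumed here, and the exact schedule in the nanoT5 codebase may differ, for example in how the warmup phase is shaped.</p>
<p>$$\eta(t) = \eta_{\min} + \frac{1}{2}\left(\eta_{\max} - \eta_{\min}\right)\left(1 + \cos\left(\pi \frac{t - T_w}{T - T_w}\right)\right), \quad T_w \le t \le T$$</p>
<p>With the values listed above, $\eta_{\max} = 1 \times 10^{-2}$, $\eta_{\min} = 1 \times 10^{-5}$, $T_w = 10{,}000$, and $T = 350{,}000$, so the rate starts at $\eta_{\max}$ when warmup ends and decays to $\eta_{\min}$ at the final step.</p>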
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Vocabulary Size</th>
          <th>Architecture</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BioT5</td>
          <td>252M</td>
          <td>35,073</td>
          <td>T5-v1.1-base</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Molecule property prediction: AUROC on 6 MoleculeNet tasks (scaffold split, 3 runs)</li>
<li>Protein property prediction: accuracy on PEER benchmark (3 runs)</li>
<li>Drug-target interaction: AUROC, AUPRC, accuracy on 3 DTI datasets (5 runs)</li>
<li>Protein-protein interaction: accuracy on 2 PPI datasets (3 runs)</li>
<li>Molecule captioning: BLEU, ROUGE, METEOR, Text2Mol on ChEBI-20</li>
<li>Text-based molecule generation: BLEU, exact match, fingerprint similarities, FCD, validity on ChEBI-20</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>8x NVIDIA A100 80GB GPUs for pretraining</li>
<li>Codebase: nanoT5</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/QizhiPei/BioT5">BioT5 Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Pei, Q., Zhang, W., Zhu, J., Wu, K., Gao, K., Wu, L., Xia, Y., &amp; Yan, R. (2023). BioT5: Enriching Cross-modal Integration in Biology with Chemical Knowledge and Natural Language Associations. <em>Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing</em>, 1102-1123. <a href="https://doi.org/10.18653/v1/2023.emnlp-main.70">https://doi.org/10.18653/v1/2023.emnlp-main.70</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{pei2023biot5,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{BioT5: Enriching Cross-modal Integration in Biology with Chemical Knowledge and Natural Language Associations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Pei, Qizhi and Zhang, Wei and Zhu, Jinhua and Wu, Kehan and Gao, Kaiyuan and Wu, Lijun and Xia, Yingce and Yan, Rui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1102--1123}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2023.emnlp-main.70}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MTL-BERT: Multitask BERT for Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/mtl-bert-multitask-smiles-enumeration/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/mtl-bert-multitask-smiles-enumeration/</guid><description>MTL-BERT combines BERT pretraining, multitask learning, and SMILES enumeration for molecular property prediction across 60 drug discovery datasets.</description><content:encoded><![CDATA[<h2 id="a-multitask-bert-framework-for-molecular-property-prediction">A Multitask BERT Framework for Molecular Property Prediction</h2>
<p>MTL-BERT is a <strong>Method</strong> paper that introduces a multitask learning framework built on BERT for predicting molecular properties from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a>. The primary contribution is the combination of three strategies to address data scarcity in drug discovery: (1) masked token pretraining on 1.7 million unlabeled molecules from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a>, (2) multitask fine-tuning across 60 property prediction datasets simultaneously, and (3) <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> as a data augmentation technique applied during pretraining, fine-tuning, and inference. The model achieves strong performance across 60 <a href="https://en.wikipedia.org/wiki/ADME">ADMET</a> and molecular property datasets (44 classification and 16 regression), outperforming baselines including GNNs, XGBoost with molecular fingerprints, and prior <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a> approaches.</p>
<h2 id="data-scarcity-in-molecular-property-prediction">Data Scarcity in Molecular Property Prediction</h2>
<p>Deep learning methods for molecular property prediction face a fundamental tension: they require large amounts of labeled data to learn effectively, but labeled bioactivity data is scarce due to the cost and time of laboratory experiments. Approaches available at the time of publication each addressed only part of this problem. Graph neural networks (GNNs) learn from molecular graphs but are typically shallow (2-3 layers) and prone to overfitting on small datasets. The original SMILES-BERT model applied masked language modeling to SMILES strings but fine-tuned separately for each task, missing opportunities to share information across related properties. Fixed molecular representations like <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (continuous and data-driven descriptors) cannot be further optimized for specific downstream tasks.</p>
<p>The authors identify three specific gaps: (1) single-task fine-tuning wastes the correlations between related ADMET properties (e.g., <a href="https://en.wikipedia.org/wiki/Lipophilicity">lipophilicity</a> relates to many ADMET endpoints), (2) using only canonical SMILES limits the model&rsquo;s ability to learn robust molecular features, and (3) no prior work had combined pretraining, multitask learning, and SMILES enumeration into a unified framework.</p>
<h2 id="three-strategies-combined-pretraining-multitask-learning-and-smiles-enumeration">Three Strategies Combined: Pretraining, Multitask Learning, and SMILES Enumeration</h2>
<p>MTL-BERT combines three strategies in a single pipeline: masked SMILES pretraining, multitask fine-tuning with task tokens, and SMILES enumeration.</p>
<h3 id="masked-smiles-pretraining">Masked SMILES Pretraining</h3>
<p>Following the BERT paradigm, MTL-BERT pretrains on 1.7 million unlabeled molecules from ChEMBL using a masked token recovery task. For each SMILES string, 15% of tokens are randomly selected: 80% are replaced with a [MASK] token, 10% are replaced with a random token, and 10% remain unchanged. The loss is computed only at masked positions. Unlike the original BERT, MTL-BERT omits the next-sentence prediction task since there is no sequential relationship between SMILES strings (following the RoBERTa finding that this task is unnecessary).</p>
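<p>The corruption rule can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors&rsquo; code: the function name <code>mask_tokens</code>, the toy vocabulary, and the choice to select a fixed number of positions per string (instead of sampling each position independently) are all hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, rate=0.15, seed=None):
    """Return (corrupted tokens, positions that contribute to the loss)."""
    rng = random.Random(seed)
    if not tokens:
        return [], []
    # Select roughly `rate` of the positions, but at least one.
    n_select = max(1, round(rate * len(tokens)))
    selected = sorted(rng.sample(range(len(tokens)), n_select))
    corrupted = list(tokens)
    for i in selected:
        # 80% [MASK], 10% random vocabulary token, 10% left unchanged.
        action = rng.choices(["mask", "random", "keep"], weights=[8, 1, 1])[0]
        if action == "mask":
            corrupted[i] = MASK
        elif action == "random":
            corrupted[i] = rng.choice(vocab)
    return corrupted, selected

# Hypothetical toy input: a pre-tokenized SMILES string and a tiny vocabulary.
tokens = ["C", "C", "(", "=", "O", ")", "N", "c", "1", "c", "c", "c", "c", "c", "1"]
vocab = ["C", "N", "O", "c", "1", "(", ")", "="]
corrupted, selected = mask_tokens(tokens, vocab, seed=0)
</code></pre></div>
<p>In this sketch the recovery loss would be evaluated at every returned position, including the positions whose token was left unchanged.</p>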
<p>SMILES strings are tokenized with a regular expression that captures multi-character tokens (e.g., Si, Br, Cl) and common SMILES syntax. The model uses positional encoding to capture token order.</p>
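<p>A regex tokenizer of this kind might look like the sketch below. The pattern is an assumed, simplified one written for illustration; it is not the expression used by the authors and does not cover every corner of SMILES syntax. The names <code>SMILES_TOKEN</code> and <code>tokenize</code> are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Alternatives are tried left to right, so multi-character tokens come first:
# bracket atoms, two-letter elements, two-digit ring closures, then single
# letters, digits, and bond or branch symbols.
SMILES_TOKEN = re.compile(
    r"\[[^\]]+\]|Br|Cl|Si|%\d{2}|[A-Za-z]|\d|[=#()+\-/\\.:@~*$]"
)

def tokenize(smiles):
    """Split a SMILES string into tokens, failing loudly on unknown characters."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError("untokenizable characters in SMILES: " + smiles)
    return tokens

tokens = tokenize("CC(=O)Nc1ccc(Br)cc1")
</code></pre></div>
<p>Checking that the joined tokens reproduce the input guards against silently dropping characters the pattern does not recognize.</p>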
<h3 id="transformer-architecture">Transformer Architecture</h3>
<p>The model uses a standard Transformer encoder with multihead self-attention. The scaled dot-product attention computes:</p>
<p>$$\mathbf{O}_h = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$$</p>
<p>where $\mathbf{Q}_h$, $\mathbf{K}_h$, and $\mathbf{V}_h$ are the query, key, and value matrices for head $h$, and $\sqrt{d_k}$ is a scaling factor. The outputs from all heads are concatenated and projected. Each attention sublayer is followed by a position-wise feedforward network with GELU activation, layer normalization, and residual connections.</p>
<p>Three model sizes were compared:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>Heads</th>
          <th>Embedding Size</th>
          <th>FFN Size</th>
          <th>Recovery Accuracy</th>
          <th>Fine-tuning Performance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MTL-BERT_SMALL</td>
          <td>4</td>
          <td>4</td>
          <td>128</td>
          <td>512</td>
          <td>0.931</td>
          <td>0.826</td>
      </tr>
      <tr>
          <td>MTL-BERT_MEDIUM</td>
          <td>8</td>
          <td>8</td>
          <td>256</td>
          <td>1,024</td>
          <td>0.962</td>
          <td>0.852</td>
      </tr>
      <tr>
          <td>MTL-BERT_LARGE</td>
          <td>12</td>
          <td>12</td>
          <td>576</td>
          <td>2,304</td>
          <td>0.974</td>
          <td>0.848</td>
      </tr>
  </tbody>
</table>
<p>The medium model was selected because it achieved the best fine-tuning performance at a lower computational cost than the large model, even though the large model reached higher pretraining recovery accuracy. The slight drop in fine-tuning performance for the large model (0.848 vs. 0.852) suggests mild overfitting.</p>
<h3 id="multitask-fine-tuning-with-task-tokens">Multitask Fine-tuning with Task Tokens</h3>
<p>During fine-tuning, task tokens ([T0], [T1], &hellip;) are prepended to each input SMILES string. The Transformer output at each task token position is passed through a task-specific two-layer feedforward network for the corresponding prediction task. An attention mask prevents direct information exchange between task tokens, allowing each task to learn directly from SMILES tokens without interference. This design also reduces the discrepancy between pretraining (no task tokens visible) and fine-tuning.</p>
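<p>One way to picture the mask is as a boolean matrix over the concatenated sequence of task tokens and SMILES tokens. The sketch below is an illustration under assumptions, not the authors&rsquo; implementation: the function name <code>task_token_mask</code> is hypothetical, and it additionally assumes that SMILES tokens never attend to task tokens, which is one reading of &ldquo;no task tokens visible&rdquo;.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def task_token_mask(n_tasks, n_smiles):
    """Return allowed[i][j]: True when position i may attend to position j.

    Positions 0..n_tasks-1 hold task tokens; the rest hold SMILES tokens.
    """
    n = n_tasks + n_smiles
    allowed = [[False] * n for _ in range(n)]
    for i in range(n):
        # Every position may read the SMILES tokens.
        for j in range(n_tasks, n):
            allowed[i][j] = True
    for t in range(n_tasks):
        # A task token also sees itself, but no other task token.
        allowed[t][t] = True
    return allowed

# Two task tokens followed by three SMILES tokens.
mask = task_token_mask(2, 3)
</code></pre></div>
<p>In a real model the <code>False</code> entries would typically be turned into large negative values added to the attention logits before the softmax.</p>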
<p>Cross-entropy loss is used for classification tasks and mean squared error for regression tasks. The total multitask loss is a simple sum of per-task losses without learned weighting.</p>
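<p>Written out, with notation assumed here for illustration ($\mathcal{T}_{\text{cls}}$ and $\mathcal{T}_{\text{reg}}$ denote the sets of classification and regression tasks), the unweighted objective is:</p>
<p>$$\mathcal{L}_{\text{total}} = \sum_{t \in \mathcal{T}_{\text{cls}}} \mathcal{L}_{\text{CE}}^{(t)} + \sum_{t \in \mathcal{T}_{\text{reg}}} \mathcal{L}_{\text{MSE}}^{(t)}$$</p>
<p>Because every term has weight 1, tasks whose losses are numerically larger contribute more to the gradient.</p>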
<h3 id="smiles-enumeration-as-data-augmentation">SMILES Enumeration as Data Augmentation</h3>
<p>A molecule can be represented by multiple valid SMILES strings by varying starting atoms and traversal orders. MTL-BERT applies SMILES enumeration at all three stages:</p>
<ol>
<li><strong>Pretraining</strong>: Enumerated SMILES increase diversity of the self-supervised training data.</li>
<li><strong>Fine-tuning</strong>: Each dataset is augmented 20x with random SMILES variants, increasing data diversity and helping the model learn position-invariant features.</li>
<li><strong>Inference</strong>: Multiple SMILES are generated per test molecule, and their predictions are fused (averaged) for a more robust final prediction.</li>
</ol>
<p>The 20x augmentation factor was chosen based on prior work showing diminishing returns beyond this level while significantly increasing computational cost.</p>
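<p>The enumeration and fusion steps can be sketched without any chemistry library by treating the SMILES sampler and the model as plug-in functions. Everything named here (<code>enumerate_unique</code>, <code>fuse_predictions</code>, <code>sample_fn</code>, <code>predict_fn</code>) is hypothetical, and the retry loop is only one reading of the procedure. In practice <code>sample_fn</code> might wrap a toolkit call that emits a randomized SMILES (RDKit&rsquo;s <code>Chem.MolToSmiles(mol, doRandom=True)</code> is one option); the toy sampler below just draws from three hand-written SMILES for ethanol.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random
from statistics import mean

def enumerate_unique(sample_fn, n_variants=20, max_tries=100):
    """Collect up to n_variants distinct strings from sample_fn()."""
    seen = []
    for _ in range(n_variants):
        # Retry when the sampler repeats a variant that was already collected.
        for _ in range(max_tries):
            candidate = sample_fn()
            if candidate not in seen:
                seen.append(candidate)
                break
    return seen

def fuse_predictions(predict_fn, variants):
    """Average per-variant predictions into one molecule-level prediction."""
    return mean(predict_fn(s) for s in variants)

# Toy stand-ins: three equivalent SMILES for ethanol and a dummy "model".
rng = random.Random(0)
pool = ["CCO", "OCC", "C(O)C"]
variants = enumerate_unique(lambda: rng.choice(pool))
fused = fuse_predictions(len, variants)
</code></pre></div>
<p>Small molecules have few distinct SMILES, so the sketch returns fewer than 20 variants when the sampler runs out of new strings.</p>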
<h2 id="experimental-evaluation-across-60-datasets">Experimental Evaluation Across 60 Datasets</h2>
<h3 id="setup">Setup</h3>
<p>MTL-BERT was evaluated on 60 datasets (44 classification, 16 regression) covering ADMET properties and common molecular benchmarks. Datasets were sourced from ADMETlab and <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>. Each dataset was split 8:1:1 (train/validation/test), and experiments were repeated 10 times with random splits, reporting mean and standard deviation.</p>
<p>Classification tasks were evaluated with <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">ROC-AUC</a> and accuracy; regression tasks with $R^2$ and RMSE.</p>
<h3 id="baselines">Baselines</h3>
<p>Five baselines were compared:</p>
<ul>
<li><strong>ECFP4-XGBoost</strong>: Extended-connectivity fingerprints (diameter 4) with gradient boosting</li>
<li><strong>Graph Attention Network (GAT)</strong></li>
<li><strong>Graph Convolutional Network (GCN)</strong></li>
<li><strong>AttentiveFP</strong>: A GNN with attention for molecular property prediction</li>
<li><strong>CDDD</strong>: Continuous and data-driven descriptors from a pretrained RNN auto-encoder</li>
</ul>
<h3 id="ablation-study">Ablation Study</h3>
<p>Three model variants were compared to isolate contributions:</p>
<ul>
<li><strong>MTL-BERT</strong>: Full model (pretraining + multitask + SMILES enumeration)</li>
<li><strong>STL-BERT</strong>: Single-task fine-tuning with SMILES enumeration (no multitask)</li>
<li><strong>Cano-BERT</strong>: Canonical SMILES only, single-task fine-tuning (equivalent to SMILES-BERT)</li>
</ul>
<p>Cano-BERT showed more than 10% degradation on several datasets (CL, Fu, LC50DM) compared to STL-BERT, demonstrating the importance of SMILES enumeration. MTL-BERT outperformed STL-BERT on most datasets, with improvements exceeding 5% on $F_{20\%}$, SR-ARE, and SR-ATAD5, confirming that multitask learning provides additional benefit on top of enumeration.</p>
<h3 id="results-vs-baselines">Results vs. Baselines</h3>
<p>MTL-BERT outperformed all baselines on nearly all 60 datasets. Specific findings:</p>
<ul>
<li>ECFP4-XGBoost performed inconsistently, doing well on some tasks (e.g., $F_{30\%}$, BACE, CL) but poorly on others, reflecting the limitation of fixed-length fingerprint representations.</li>
<li>GNNs generally improved over fingerprints but still suffered from data scarcity, falling behind ECFP4-XGBoost by more than 3% on $F_{30\%}$, Carcinogenicity, CL, and VD.</li>
<li>MTL-BERT surpassed all baselines except on CYP2C19-sub and BACE (by less than 1.1%).</li>
<li>On 14 tasks (NR-ER, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, Bioconcentration Factor, Fu, LC50FM, Lipophilicity, CL, PPB, VD, LC50DM), MTL-BERT exceeded the best baseline by more than 5-10%.</li>
<li>Improvements were statistically significant at the 95% confidence level (paired t-test, $P \leq 0.001$).</li>
</ul>
<h3 id="representation-analysis">Representation Analysis</h3>
<p><a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding">t-SNE</a> visualization of pretrained token embeddings (from 1,000 randomly selected molecules, approximately 35,000 tokens) showed that:</p>
<ul>
<li>Tokens of the same type cluster together (capturing atomic type information).</li>
<li>Within type clusters, sub-groups correspond to different chemical environments (e.g., oxygen atoms in nitrate groups vs. carbonyl groups).</li>
<li>Nearby embeddings share similar molecular neighborhood environments.</li>
</ul>
<h3 id="attention-based-interpretability">Attention-based Interpretability</h3>
<p>The model&rsquo;s attention weights provide interpretability for predictions:</p>
<ul>
<li>For a solubility task (LogS/LogD), attention concentrated on polar groups, which are known determinants of aqueous solubility.</li>
<li>For <a href="https://en.wikipedia.org/wiki/Ames_test">AMES</a> (mutagenicity), attention focused on <a href="https://en.wikipedia.org/wiki/Azide">azide</a>, nitrosamide, <a href="https://en.wikipedia.org/wiki/Acyl_chloride">acyl chloride</a>, and nitrite groups, which are known mutagenic structural alerts.</li>
</ul>
<h2 id="performance-gains-from-combined-strategies-with-interpretable-attention">Performance Gains from Combined Strategies with Interpretable Attention</h2>
<p>MTL-BERT demonstrates that the combination of pretraining, multitask learning, and SMILES enumeration is more effective than any individual strategy for molecular property prediction. The ablation study provides clear evidence for the additive benefit of each component.</p>
<p>Key strengths include the breadth of evaluation (60 datasets covering diverse ADMET endpoints), the consistent improvement over multiple baseline types (fingerprints, GNNs, pretrained representations), and the interpretable attention mechanism that highlights chemically meaningful substructures.</p>
<p>Limitations to note: the simple sum of multitask losses (no learned task weighting) may not be optimal when tasks have very different scales or when some tasks are unrelated. The authors observe slight degradation on a few datasets (AMES, CYP1A2-Sub, FreeSolv), suggesting negative transfer in those cases. The 20x SMILES enumeration significantly increases computational cost during fine-tuning and inference. The paper does not report wall-clock training times or GPU hours, making it difficult to assess the practical cost of the enumeration strategy. Hardware details are not specified beyond acknowledgment of the High-Performance Computing Center at Central South University.</p>
<p>The hierarchical clustering of task representations reveals meaningful task groupings (e.g., LogD and LogP cluster together due to their shared relationship with water solubility), supporting the premise that multitask learning captures cross-task correlations.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL</td>
          <td>1.7M molecules</td>
          <td>Unlabeled SMILES; 10% held out for evaluation</td>
      </tr>
      <tr>
          <td>Fine-tuning/Evaluation</td>
          <td>ADMETlab + MoleculeNet</td>
          <td>60 datasets (44 classification, 16 regression)</td>
          <td>8:1:1 train/val/test split</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining</strong>: Masked token prediction (15% masking rate: 80% [MASK], 10% random, 10% unchanged). Adam optimizer, learning rate 1e-4, batch size 512, 50 epochs.</li>
<li><strong>Fine-tuning</strong>: Adam optimizer, learning rate 5e-5, batch size 64, dropout 0.1. Cross-entropy for classification, MSE for regression. Early stopping with patience 20, max 200 epochs.</li>
<li><strong>SMILES enumeration</strong>: 20x augmentation. If an enumerated SMILES is identical to a previously generated one, the search is repeated up to 100 times.</li>
<li><strong>Inference fusion</strong>: Predictions from multiple enumerated SMILES are averaged.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>MTL-BERT_MEDIUM (selected model): 8 layers, 8 attention heads, 256 embedding size, 1,024 FFN size</li>
<li>Pretraining recovery accuracy: 0.962</li>
<li>1,000 task tokens pre-allocated for future tasks</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>Classification</td>
          <td>Secondary metric</td>
      </tr>
      <tr>
          <td>$R^2$</td>
          <td>Regression</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression</td>
          <td>Secondary metric</td>
      </tr>
  </tbody>
</table>
<p>All experiments repeated 10 times with random splits; mean and standard deviation reported.</p>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not reported in the paper. The authors acknowledge the High-Performance Computing Center of Central South University.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/zhang-xuan1314/MTL-BERT">MTL-BERT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL</a></td>
          <td>Dataset</td>
          <td>CC BY-SA 3.0</td>
          <td>Pretraining data source</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Fine-tuning benchmark</td>
      </tr>
      <tr>
          <td><a href="https://admetmesh.scbdd.com/">ADMETlab</a></td>
          <td>Dataset</td>
          <td>Free for academic use</td>
          <td>ADMET property datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Wu, C.-K., Yi, J.-C., Zeng, X.-X., Yang, C.-Q., Lu, A.-P., Hou, T.-J., &amp; Cao, D.-S. (2022). Pushing the boundaries of molecular property prediction for drug discovery with multitask learning BERT enhanced by SMILES enumeration. <em>Research</em>, 2022, Article 0004. <a href="https://doi.org/10.34133/research.0004">https://doi.org/10.34133/research.0004</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2022mtlbert,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Pushing the Boundaries of Molecular Property Prediction for Drug Discovery with Multitask Learning BERT Enhanced by SMILES Enumeration}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Xiao-Chen and Wu, Cheng-Kun and Yi, Jia-Cai and Zeng, Xiang-Xiang and Yang, Can-Qun and Lu, Ai-Ping and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{Article 0004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.34133/research.0004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Association for the Advancement of Science (AAAS)}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Mol2vec: Unsupervised ML with Chemical Intuition</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/</guid><description>Mol2vec applies Word2vec to Morgan substructures, learning dense vector representations of molecules that capture chemical similarity for property prediction.</description><content:encoded><![CDATA[<h2 id="word2vec-meets-cheminformatics">Word2vec Meets Cheminformatics</h2>
<p>Mol2vec is a <strong>Method</strong> paper that introduces an unsupervised approach for learning dense vector representations of molecular substructures. The core idea is a direct analogy to <a href="/notes/machine-learning/model-architectures/distributed-representations/">Word2vec</a> from natural language processing: molecular substructures (derived from the Morgan algorithm) are treated as &ldquo;words,&rdquo; and entire molecules are treated as &ldquo;sentences.&rdquo; By training on a large unlabeled corpus of 19.9 million compounds, Mol2vec produces embeddings where chemically related substructures occupy nearby regions of vector space. Compound-level vectors are then obtained by summing constituent substructure vectors, and these can serve as features for downstream supervised learning tasks.</p>
<h2 id="sparse-fingerprints-and-their-limitations">Sparse Fingerprints and Their Limitations</h2>
<p>Molecular fingerprints, particularly Morgan fingerprints (extended-connectivity fingerprints, ECFP), are among the most widely used molecular representations in cheminformatics. They perform well for similarity searching, virtual screening, and activity prediction. However, they suffer from several practical drawbacks:</p>
<ul>
<li><strong>High dimensionality and sparsity</strong>: Morgan fingerprints are typically hashed to fixed-length binary vectors (e.g., 2048 or 4096 bits), resulting in very sparse representations.</li>
<li><strong>Bit collisions</strong>: The hashing step can map distinct substructures to the same bit position, losing structural information.</li>
<li><strong>No learned relationships</strong>: Each bit is independent, so the representation does not encode any notion of chemical similarity between substructures.</li>
</ul>
<p>At the time of this work (2017), NLP techniques had started to appear in cheminformatics. The <a href="https://en.wikipedia.org/wiki/Tf%E2%80%93idf">tf-idf</a> method had been applied to Morgan fingerprints for compound-protein interaction prediction, and <a href="https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation">Latent Dirichlet Allocation</a> had been used for chemical topic modeling. The Word2vec concept had been adapted for protein sequences (ProtVec) but had not yet been applied to small molecules. Mol2vec fills this gap.</p>
<h2 id="from-substructure-identifiers-to-dense-embeddings">From Substructure Identifiers to Dense Embeddings</h2>
<p>The central insight of Mol2vec is that the Morgan algorithm already produces a natural &ldquo;vocabulary&rdquo; of molecular substructures, and the order in which these substructures appear in a molecule provides local context, analogous to word order in a sentence.</p>
<h3 id="corpus-construction">Corpus Construction</h3>
<p>The training corpus was assembled from <a href="https://en.wikipedia.org/wiki/ZINC_database">ZINC</a> v15 and <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> v23, merged and deduplicated, then filtered by molecular weight (12-600), heavy atom count (3-50), clogP (-5 to 7), and allowed elements (H, B, C, N, O, F, P, S, Cl, Br). This yielded 19.9 million compounds.</p>
<h3 id="sentence-generation">Sentence Generation</h3>
<p>For each molecule, the Morgan algorithm generates atom identifiers at radius 0 and radius 1. Each atom contributes two identifiers (one per radius), ordered according to the atom order in the canonical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>. This sequence of identifiers forms a &ldquo;sentence&rdquo; for Word2vec training.</p>
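<p>The ordering step can be illustrated with plain Python. The sketch assumes the Morgan identifiers have already been computed and are supplied as a mapping from identifier to (atom index, radius) pairs, which is the shape RDKit reports through the <code>bitInfo</code> argument of its Morgan fingerprint functions. The function name <code>build_sentence</code> and the toy identifiers are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def build_sentence(bit_info, n_atoms, radius=1):
    """Order Morgan identifiers by atom, then by radius (0 first)."""
    per_atom = {atom: {} for atom in range(n_atoms)}
    for identifier, hits in bit_info.items():
        for atom, r in hits:
            per_atom[atom][r] = identifier
    sentence = []
    for atom in range(n_atoms):
        for r in range(radius + 1):
            if r in per_atom[atom]:
                # Identifiers become string "words" for Word2vec training.
                sentence.append(str(per_atom[atom][r]))
    return sentence

# Toy three-atom molecule: atoms 0 and 2 share the same radius-0 identifier.
bit_info = {11: [(0, 0), (2, 0)], 22: [(1, 0)], 33: [(0, 1)], 44: [(1, 1)], 55: [(2, 1)]}
sentence = build_sentence(bit_info, n_atoms=3)
</code></pre></div>
<p>Each atom contributes its radius-0 identifier followed by its radius-1 identifier, so a molecule with $N$ atoms yields a sentence of up to $2N$ words.</p>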
<h3 id="word2vec-training">Word2vec Training</h3>
<p>The model was trained using the gensim implementation of Word2vec. After evaluating both CBOW and Skip-gram architectures with window sizes of 5, 10, and 20, and embedding dimensions of 100 and 300, the best configuration was:</p>
<ul>
<li><strong>Architecture</strong>: Skip-gram</li>
<li><strong>Window size</strong>: 10</li>
<li><strong>Embedding dimension</strong>: 300</li>
</ul>
<p>Rare identifiers appearing fewer than 3 times in the corpus were replaced with a special &ldquo;UNSEEN&rdquo; token, which learns a near-zero vector. This allows the model to handle novel substructures at inference time.</p>
<h3 id="compound-vector-generation">Compound Vector Generation</h3>
<p>The final vector for a molecule is the sum of all its substructure vectors:</p>
<p>$$\mathbf{v}_{\text{mol}} = \sum_{i=1}^{N} \mathbf{v}_{s_i}$$</p>
<p>where $\mathbf{v}_{s_i}$ is the 300-dimensional embedding for the $i$-th substructure identifier in the molecule. This summation implicitly captures substructure counts and importance through vector amplitude.</p>
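<p>As a hedged illustration of the summation, the sketch below uses 3-dimensional toy embeddings in place of the 300-dimensional ones. The function name <code>compound_vector</code>, the dictionary layout, and the fallback to the &ldquo;UNSEEN&rdquo; vector for unknown identifiers are assumptions made for the example:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def compound_vector(sentence, embeddings, unseen="UNSEEN"):
    """Sum the substructure vectors of one molecule."""
    total = [0.0] * len(embeddings[unseen])
    for token in sentence:
        # Identifiers outside the vocabulary fall back to the UNSEEN vector.
        vector = embeddings.get(token, embeddings[unseen])
        total = [t + v for t, v in zip(total, vector)]
    return total


# Toy 3-dimensional embeddings (values are made up).
toy_embeddings = {
    "a": [1.0, 0.0, 2.0],
    "b": [0.5, 1.0, 0.0],
    "UNSEEN": [0.0, 0.0, 0.0],
}
mol_vector = compound_vector(["a", "b", "a", "zzz"], toy_embeddings)
</code></pre></div>
<p>Because substructure &ldquo;a&rdquo; occurs twice, its vector is counted twice in the sum, which is how repeated substructures increase the amplitude of the compound vector.</p>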
<h2 id="benchmarking-across-regression-and-classification-tasks">Benchmarking Across Regression and Classification Tasks</h2>
<h3 id="datasets">Datasets</h3>
<p>The authors evaluated Mol2vec on four datasets:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Task</th>
          <th>Size</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>Regression</td>
          <td>1,144</td>
          <td>Aqueous solubility prediction</td>
      </tr>
      <tr>
          <td>Ames</td>
          <td>Classification</td>
          <td>6,471</td>
          <td><a href="https://en.wikipedia.org/wiki/Mutagen">Mutagenicity</a> (balanced: 3,481 positive, 2,990 negative)</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>Classification</td>
          <td>8,192</td>
          <td>12 human toxicity targets (imbalanced)</td>
      </tr>
      <tr>
          <td>Kinase</td>
          <td>Classification</td>
          <td>284 kinases</td>
          <td>Bioactivity from ChEMBL v23</td>
      </tr>
  </tbody>
</table>
<h3 id="machine-learning-methods">Machine Learning Methods</h3>
<p>Three ML methods were compared using both Mol2vec and Morgan FP features:</p>
<ul>
<li><strong>Random Forest (RF)</strong>: scikit-learn, 500 estimators</li>
<li><strong>Gradient Boosting Machine (GBM)</strong>: XGBoost, 2000 estimators, max depth 3, learning rate 0.1</li>
<li><strong>Deep Neural Network (DNN)</strong>: Keras/TensorFlow, 4 hidden layers with 2000 neurons each for Mol2vec; 1 hidden layer with 512 neurons for Morgan FP</li>
</ul>
<p>All models were validated using 20 repetitions of 5-fold cross-validation, with the <a href="https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test">Wilcoxon signed-rank test</a> used for statistical comparison.</p>
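<p>The repeated cross-validation protocol can be sketched with the standard library alone. This is an unstratified illustration; the generator name <code>repeated_kfold</code>, the seed, and the sample count are assumptions, and the paper&rsquo;s own splitting code may differ:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def repeated_kfold(n_samples, n_folds=5, n_repeats=20, seed=0):
    """Yield (repeat, fold, train_indices, test_indices) tuples."""
    rng = random.Random(seed)
    for repeat in range(n_repeats):
        indices = list(range(n_samples))
        rng.shuffle(indices)  # fresh shuffle for every repetition
        for fold in range(n_folds):
            test = indices[fold::n_folds]  # every n_folds-th shuffled index
            held_out = set(test)
            train = [i for i in indices if i not in held_out]
            yield repeat, fold, train, test


splits = list(repeated_kfold(n_samples=23))
</code></pre></div>
<p>Each of the 20 repetitions contributes 5 train/test splits, so every model and feature combination is scored 100 times before the paired statistical test is applied.</p>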
<h3 id="esol-regression-results">ESOL Regression Results</h3>
<table>
  <thead>
      <tr>
          <th>Features</th>
          <th>Method</th>
          <th>$R^2_{\text{ext}}$</th>
          <th>MSE</th>
          <th>MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Descriptors</td>
          <td>MLR</td>
          <td>0.81 +/- 0.01</td>
          <td>0.82</td>
          <td>0.69</td>
      </tr>
      <tr>
          <td>Molecular Graph</td>
          <td>CNN</td>
          <td>0.93</td>
          <td>0.31 +/- 0.03</td>
          <td>0.40 +/- 0.00</td>
      </tr>
      <tr>
          <td>Morgan FP</td>
          <td>GBM</td>
          <td>0.66 +/- 0.00</td>
          <td>1.43 +/- 0.00</td>
          <td>0.88 +/- 0.00</td>
      </tr>
      <tr>
          <td>Mol2vec</td>
          <td>GBM</td>
          <td>0.86 +/- 0.00</td>
          <td>0.62 +/- 0.00</td>
          <td>0.60 +/- 0.00</td>
      </tr>
  </tbody>
</table>
<p>Mol2vec substantially outperformed Morgan FP ($R^2_{\text{ext}}$ 0.86 vs. 0.66) but did not match the best graph convolution methods ($R^2_{\text{ext}}$ ~0.93).</p>
<h3 id="classification-results-ames-and-tox21">Classification Results (Ames and Tox21)</h3>
<p>On the Ames dataset, Mol2vec and Morgan FP performed comparably (AUC 0.87 vs. 0.88), both matching or exceeding prior SVM and Naive Bayes results. On Tox21, both achieved an average AUC of 0.83, outperforming literature results from graph convolution (0.71) and DNN/SVM approaches (0.71-0.72).</p>
<h3 id="proteochemometric-pcm-extension">Proteochemometric (PCM) Extension</h3>
<p>Mol2vec was combined with ProtVec (protein sequence embeddings using the same Word2vec approach on 3-grams) by concatenating vectors, forming PCM2vec. This was evaluated using a rigorous 4-level cross-validation scheme:</p>
<ul>
<li><strong>CV1</strong>: New compound-target pairs</li>
<li><strong>CV2</strong>: New targets</li>
<li><strong>CV3</strong>: New compounds</li>
<li><strong>CV4</strong>: New compounds and targets</li>
</ul>
<p>On Tox21, PCM2vec improved predictions for new compound-target pairs (CV1: AUC 0.87 vs. 0.79 for Morgan FP) and new compounds (CV3: AUC 0.85 vs. 0.78). On the kinase dataset, PCM2vec approached the performance of classical PCM (Morgan + z-scales) while being alignment-independent, meaning it can be applied to proteins with low sequence similarity.</p>
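<p>One way to read the four validation levels is as a classification of each test pair by whether its compound and its target were seen during training. The sketch below follows that reading; the function name <code>cv_level</code> and the toy identifiers are hypothetical, and it assumes the test pair itself is not in the training set:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">def cv_level(compound, target, train_pairs):
    """Label a held-out (compound, target) pair with its validation level."""
    seen_compounds = {c for c, _ in train_pairs}
    seen_targets = {t for _, t in train_pairs}
    compound_seen = compound in seen_compounds
    target_seen = target in seen_targets
    if compound_seen and target_seen:
        return "CV1"  # both known, but this pairing is new
    if compound_seen:
        return "CV2"  # known compound, new target
    if target_seen:
        return "CV3"  # new compound, known target
    return "CV4"      # new compound and new target


# Toy training set of (compound, target) pairs.
train_pairs = [("cpd1", "kinA"), ("cpd2", "kinB")]
levels = [
    cv_level("cpd1", "kinB", train_pairs),
    cv_level("cpd1", "kinC", train_pairs),
    cv_level("cpd3", "kinA", train_pairs),
    cv_level("cpd3", "kinC", train_pairs),
]
</code></pre></div>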
<h2 id="chemical-intuition-and-practical-value">Chemical Intuition and Practical Value</h2>
<h3 id="embedding-quality">Embedding Quality</h3>
<p>The learned substructure embeddings capture meaningful chemical relationships. Hierarchical clustering of the 25 most common substructures shows expected groupings: aromatic carbons cluster together, aliphatic ring carbons form a separate group, and carbonyl carbons and oxygens are closely related. Similarly, t-SNE projections of amino acid vectors encoded by Mol2vec reproduce known amino acid relationships (e.g., similar distances between Glu/Gln and Asp/Asn pairs, reflecting the carboxylic acid to amide transition).</p>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Skip-gram with 300-dimensional embeddings</strong> provides the best Mol2vec representations, consistent with NLP best practices.</li>
<li><strong>Mol2vec excels at regression tasks</strong>, substantially outperforming Morgan FP on ESOL solubility prediction ($R^2_{\text{ext}}$ 0.86 vs. 0.66).</li>
<li><strong>Classification performance is competitive</strong> with Morgan FP across Ames and Tox21 datasets.</li>
<li><strong>PCM2vec enables alignment-independent proteochemometrics</strong>, extending PCM approaches to diverse protein families with low sequence similarity.</li>
<li><strong>Tree-based methods (RF, GBM) outperformed DNNs</strong> on these tasks, though the authors note further DNN tuning could help.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The compound vector is a simple sum of substructure vectors, which discards information about substructure arrangement and molecular topology.</li>
<li>Only Morgan identifiers at radii 0 and 1 were used. Larger radii might capture more context but would increase vocabulary size.</li>
<li>DNN architectures were not extensively optimized, leaving open the question of how well Mol2vec pairs with deep learning.</li>
<li>The approach was benchmarked against Morgan FP but not against other learned representations such as graph neural networks in a controlled comparison.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC v15 + ChEMBL v23</td>
          <td>19.9M compounds</td>
          <td>Filtered by MW, atom count, clogP, element types</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,144 compounds</td>
          <td>Aqueous solubility regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Ames</td>
          <td>6,471 compounds</td>
          <td>Mutagenicity classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Tox21</td>
          <td>8,192 compounds</td>
          <td>12 toxicity targets, retrieved via DeepChem</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Kinase (ChEMBL v23)</td>
          <td>284 kinases</td>
          <td>IC50/Kd/Ki binding assays</td>
      </tr>
      <tr>
          <td>Protein corpus</td>
          <td><a href="https://en.wikipedia.org/wiki/UniProt">UniProt</a></td>
          <td>554,241 sequences</td>
          <td>For ProtVec training</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Word2vec</strong>: Skip-gram, window size 10, 300-dimensional embeddings, min count 3</li>
<li><strong>Morgan algorithm</strong>: Radii 0 and 1 (119 and 19,831 unique identifiers respectively)</li>
<li><strong>UNSEEN token</strong>: Replaces identifiers occurring fewer than 3 times</li>
<li><strong>Compound vector</strong>: Sum of all substructure vectors</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>RF</strong>: scikit-learn, 500 estimators, sqrt features, balanced class weights</li>
<li><strong>GBM</strong>: XGBoost, 2000 estimators, max depth 3, learning rate 0.1</li>
<li><strong>DNN</strong>: Keras/TensorFlow, 4 layers x 2000 neurons (Mol2vec) or 1 layer x 512 neurons (Morgan FP), ReLU activation, dropout 0.1</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Mol2vec Best</th>
          <th>Morgan FP Best</th>
          <th>Task</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>$R^2_{\text{ext}}$</td>
          <td>0.86 (GBM)</td>
          <td>0.66 (GBM)</td>
          <td>ESOL regression</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>0.87 (RF)</td>
          <td>0.88 (RF)</td>
          <td>Ames classification</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>0.83 (RF)</td>
          <td>0.83 (RF)</td>
          <td>Tox21 classification</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/samoturk/mol2vec">mol2vec</a></td>
          <td>Code</td>
          <td>BSD-3-Clause</td>
          <td>Python package with pre-trained model</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jaeger, S., Fulle, S., &amp; Turk, S. (2018). Mol2vec: Unsupervised Machine Learning Approach with Chemical Intuition. <em>Journal of Chemical Information and Modeling</em>, 58(1), 27-35. <a href="https://doi.org/10.1021/acs.jcim.7b00616">https://doi.org/10.1021/acs.jcim.7b00616</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jaeger2018mol2vec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Mol2vec: Unsupervised Machine Learning Approach with Chemical Intuition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jaeger, Sabrina and Fulle, Simone and Turk, Samo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{58}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{27--35}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.7b00616}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MG-BERT: Graph BERT for Molecular Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/mg-bert-molecular-graph-bert/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/mg-bert-molecular-graph-bert/</guid><description>MG-BERT integrates graph neural network message passing into BERT with masked atom pretraining on 1.7M molecules for molecular property prediction.</description><content:encoded><![CDATA[<h2 id="a-graph-aware-bert-for-molecular-property-prediction">A Graph-Aware BERT for Molecular Property Prediction</h2>
<p>MG-BERT is a <strong>Method</strong> paper that adapts the BERT pretraining paradigm from NLP to molecular graphs. The primary contribution is a modified Transformer architecture that replaces global self-attention with bond-based local attention, allowing atoms to exchange information only through chemical bonds. This creates a deep message-passing network that avoids the oversmoothing problem of conventional graph neural networks (GNNs). Combined with a masked atom prediction pretraining strategy on 1.7 million unlabeled molecules from ChEMBL, MG-BERT learns context-sensitive atomic representations that transfer effectively to downstream property prediction tasks.</p>
<h2 id="data-scarcity-in-molecular-property-prediction">Data Scarcity in Molecular Property Prediction</h2>
<p><a href="/notes/chemistry/molecular-design/property-prediction/">Molecular property prediction</a> is central to drug discovery, particularly for ADMET (Absorption, Distribution, Metabolism, Excretion, and Toxicity) endpoints. While deep learning has advanced many domains, molecular property prediction faces a persistent challenge: labeled data scarcity. ADMET measurements require expensive, time-consuming experiments, and typical datasets contain only hundreds to thousands of examples.</p>
<p>Prior approaches fall into three categories, each with limitations:</p>
<ol>
<li><strong>Feature engineering</strong> (molecular fingerprints, descriptors): Requires expert design, suffers from low scalability, and fixed representations cannot be optimized for specific tasks.</li>
<li><strong>SMILES-based deep learning</strong> (CNNs, LSTMs, Transformers on SMILES strings): Must learn to parse molecular information from complex string syntax, increasing learning difficulty. Autoencoder-based methods (e.g., <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a>) learn fixed representations that cannot be fine-tuned.</li>
<li><strong>Graph neural networks</strong> (GAT, GCN): Can learn directly from molecular topology, but are limited to 2-3 layers due to oversmoothing, restricting their capacity to capture deep-level patterns.</li>
</ol>
<p>The BERT model from NLP demonstrated that self-supervised pretraining on large unlabeled corpora followed by fine-tuning on small labeled datasets can substantially improve downstream performance. <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a> applied this idea to SMILES strings directly, but suffered from interpretability issues due to auxiliary characters in the SMILES syntax. MG-BERT addresses these limitations by operating directly on molecular graphs.</p>
<h2 id="bond-based-local-attention-and-masked-atom-pretraining">Bond-Based Local Attention and Masked Atom Pretraining</h2>
<p>The core innovation of MG-BERT has two components: a modified Transformer architecture for molecular graphs and a self-supervised pretraining strategy.</p>
<h3 id="architecture-modifications">Architecture Modifications</h3>
<p>The original BERT model uses three components: an embedding layer, Transformer encoder layers, and a task-specific output layer. MG-BERT makes four key modifications:</p>
<ol>
<li>
<p><strong>Atom embeddings replace word embeddings.</strong> The dictionary contains 16 tokens: 13 common atom types ([H], [C], [N], [O], [F], [S], [Cl], [P], [Br], [B], [I], [Si], [Se]), plus [UNK] for rare atoms, [MASK] for pretraining, and [GLOBAL] for graph-level readout.</p>
</li>
<li>
<p><strong>No positional encoding.</strong> Unlike sequential text, atoms in a molecular graph have no inherent ordering, so positional embeddings are removed.</p>
</li>
<li>
<p><strong>Local attention replaces global attention.</strong> The adjacency matrix of the molecular graph is used as a visibility matrix to modulate the attention scores. Each atom can only attend to atoms connected by chemical bonds. Formally, the attention is constrained so that:</p>
</li>
</ol>
<p>$$A'_{ij} = \begin{cases} A_{ij} &amp; \text{if bond exists between } i \text{ and } j \\ -\infty &amp; \text{otherwise} \end{cases}$$</p>
<p>where $A_{ij}$ is the standard scaled dot-product attention score before the softmax, so non-bonded pairs receive zero attention weight. This local message passing makes MG-BERT a variant of GNN, but one that can stack many layers (6 in the medium configuration) without oversmoothing, thanks to the residual connections inherited from the Transformer architecture.</p>
<ol start="4">
<li><strong>Supernode for graph-level readout.</strong> A [GLOBAL] supernode is added to each molecular graph, connected to all atoms. This node aggregates information from the entire molecule and serves as the molecular representation for downstream prediction.</li>
</ol>
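<p>The bond-based masking and the supernode can be sketched together in plain Python. This is an illustration under stated assumptions, not the authors&rsquo; implementation: the function name <code>masked_attention_weights</code> is hypothetical, the raw scores are set to zero for readability, and the toy adjacency matrix omits self-connections because the description above does not say whether atoms attend to themselves:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math


def masked_attention_weights(scores, adjacency):
    """Row-wise softmax over scores, restricted to connected pairs."""
    weights = []
    for score_row, adj_row in zip(scores, adjacency):
        # Non-bonded pairs get a score of minus infinity.
        masked = [s if bonded else -math.inf
                  for s, bonded in zip(score_row, adj_row)]
        peak = max(masked)  # subtract the row maximum for numerical stability
        exps = [math.exp(m - peak) for m in masked]
        norm = sum(exps)
        weights.append([e / norm for e in exps])
    return weights


# Three-atom chain 0-1-2 plus a [GLOBAL] supernode (index 3) bonded to all atoms.
adjacency = [
    [0, 1, 0, 1],
    [1, 0, 1, 1],
    [0, 1, 0, 1],
    [1, 1, 1, 0],
]
scores = [[0.0] * 4 for _ in range(4)]
weights = masked_attention_weights(scores, adjacency)
</code></pre></div>
<p>With uniform scores, atom 0 splits its attention evenly between its single bonded neighbor and the supernode, while the supernode row spreads its attention across all three atoms, which is what lets it act as a graph-level readout.</p>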
<h3 id="masked-atom-prediction">Masked Atom Prediction</h3>
<p>The pretraining strategy mirrors BERT&rsquo;s masked language model but operates on atoms:</p>
<ul>
<li>15% of atoms in each molecule are randomly selected (at least one atom per molecule)</li>
<li>Of selected atoms: 80% are replaced with [MASK], 10% are randomly replaced with another atom type, and 10% remain unchanged</li>
<li>The model is trained to predict the original atom type at masked positions</li>
<li>Loss is computed only at masked positions</li>
</ul>
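<p>The masking procedure listed above can be sketched as follows. The function name <code>mask_atoms</code> is hypothetical, atoms are written as bare element symbols rather than bracketed tokens, and selecting a fixed count of <code>round(0.15 * n)</code> atoms (rather than sampling each atom independently) is an assumption of this example:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

ATOM_TYPES = ["H", "C", "N", "O", "F", "S", "Cl", "P", "Br", "B", "I", "Si", "Se"]


def mask_atoms(atoms, rng, select_rate=0.15):
    """Return (corrupted atoms, {position: original atom}) for one molecule."""
    n_select = max(1, round(len(atoms) * select_rate))  # at least one atom
    selected = rng.sample(range(len(atoms)), n_select)
    corrupted = list(atoms)
    labels = {}
    for index in selected:
        labels[index] = atoms[index]  # loss is computed only at these positions
        action = rng.choices(["mask", "random", "keep"],
                             weights=[0.8, 0.1, 0.1])[0]
        if action == "mask":
            corrupted[index] = "[MASK]"
        elif action == "random":
            others = [a for a in ATOM_TYPES if a != atoms[index]]
            corrupted[index] = rng.choice(others)
        # "keep" leaves the atom unchanged but still counts toward the loss.
    return corrupted, labels


# Ethanol with explicit hydrogens: 2 C, 1 O, 6 H.
ethanol = ["C", "C", "O"] + ["H"] * 6
corrupted, labels = mask_atoms(ethanol, random.Random(0))
</code></pre></div>
<p>For a nine-atom molecule, 15% rounds to a single atom, so the at-least-one rule and the rate agree here; for very small molecules the rule is what guarantees a training signal.</p>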
<h3 id="model-configurations">Model Configurations</h3>
<p>Three model sizes were compared:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Layers</th>
          <th>Heads</th>
          <th>Embedding Size</th>
          <th>FFN Size</th>
          <th>Recovery Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MG-BERT Small</td>
          <td>3</td>
          <td>2</td>
          <td>128</td>
          <td>256</td>
          <td>95.27%</td>
      </tr>
      <tr>
          <td>MG-BERT Medium</td>
          <td>6</td>
          <td>4</td>
          <td>256</td>
          <td>512</td>
          <td>98.31%</td>
      </tr>
      <tr>
          <td>MG-BERT Large</td>
          <td>12</td>
          <td>8</td>
          <td>576</td>
          <td>1152</td>
          <td>98.35%</td>
      </tr>
  </tbody>
</table>
<p>The medium configuration was selected for all experiments because it achieved the best downstream performance, despite the large model having slightly higher pretraining recovery accuracy. The authors attribute this to overfitting risk with the larger model.</p>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<h3 id="pretraining">Pretraining</h3>
<p>MG-BERT was pretrained on 1.7 million compounds randomly selected from ChEMBL, with 10% held out for evaluation (1.53M training molecules). Molecules were converted to 2D undirected graphs using RDKit, with hydrogen atoms explicitly included. The model was pretrained for 10 epochs using Adam with learning rate 1e-4 and batch size 256.</p>
<h3 id="fine-tuning-datasets">Fine-tuning Datasets</h3>
<p>Sixteen datasets covering ADMET endpoints and common molecular properties were collected from ADMETlab and <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>:</p>
<table>
  <thead>
      <tr>
          <th>Type</th>
          <th>Dataset</th>
          <th>Category</th>
          <th>Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Regression</td>
          <td>Caco2</td>
          <td>Absorption</td>
          <td>979</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>logD</td>
          <td>Physicochemical</td>
          <td>10,354</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>logS</td>
          <td>Physicochemical</td>
          <td>5,045</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>PPB</td>
          <td>Distribution</td>
          <td>1,480</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>tox</td>
          <td>Toxicity</td>
          <td>7,295</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL</td>
          <td>Physicochemical</td>
          <td>1,128</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv</td>
          <td>Physicochemical</td>
          <td>642</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipo</td>
          <td>Physicochemical</td>
          <td>4,200</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Ames</td>
          <td>Toxicity</td>
          <td>6,719</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBB</td>
          <td>Distribution</td>
          <td>1,855</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>FDAMDD</td>
          <td>Toxicity</td>
          <td>795</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>H_HT</td>
          <td>Toxicity</td>
          <td>2,170</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Pgp_inh</td>
          <td>Absorption</td>
          <td>2,125</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Pgp_sub</td>
          <td>Absorption</td>
          <td>1,210</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE</td>
          <td>Biophysics</td>
          <td>1,513</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP</td>
          <td>Physiology</td>
          <td>2,039</td>
      </tr>
  </tbody>
</table>
<p>Datasets were split 8:1:1 (train:validation:test) with stratified sampling by SMILES length. Each experiment was repeated 10 times with random splits, reporting mean and standard deviation. Regression was evaluated by R-squared, classification by ROC-AUC. Early stopping with a maximum of 100 epochs was used.</p>
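<p>The note above does not spell out how the length stratification was done. One plausible reading, sketched below as an assumption, sorts molecules by SMILES length and then divides each group of ten similar-length molecules 8:1:1. The function name <code>length_stratified_split</code> and the toy SMILES strings are hypothetical:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random


def length_stratified_split(smiles, seed=0):
    """Split indices 8:1:1 so each part spans the SMILES length range."""
    rng = random.Random(seed)
    by_length = sorted(range(len(smiles)), key=lambda i: len(smiles[i]))
    train, valid, test = [], [], []
    # Walk through groups of 10 molecules with similar SMILES length.
    for start in range(0, len(by_length), 10):
        group = by_length[start:start + 10]
        rng.shuffle(group)
        train.extend(group[:8])
        valid.extend(group[8:9])
        test.extend(group[9:])
    return train, valid, test


# Toy data: linear alkanes with 1 to 50 carbons.
toy_smiles = ["C" * n for n in range(1, 51)]
train_idx, valid_idx, test_idx = length_stratified_split(toy_smiles)
</code></pre></div>
<p>Changing <code>seed</code> produces a different random split, which is how the ten repetitions could be generated under this reading.</p>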
<h3 id="baselines">Baselines</h3>
<p>Five baselines were compared:</p>
<ol>
<li><strong>ECFP4-XGBoost</strong>: Extended connectivity fingerprints (diameter 4) with gradient-boosted trees</li>
<li><strong>GAT</strong>: Graph Attention Network</li>
<li><strong>GCN</strong>: Graph Convolutional Network</li>
<li><strong>CDDD</strong>: Continuous and Data-Driven Descriptors (pretrained RNN encoder on SMILES with a fully connected network)</li>
<li><strong>SMILES-BERT</strong>: Original BERT applied directly to SMILES strings</li>
</ol>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>Two ablation studies were conducted:</p>
<ol>
<li><strong>Pretraining effectiveness</strong>: Comparing pretrained vs. non-pretrained MG-BERT under identical hyperparameters</li>
<li><strong>Hydrogen atoms</strong>: Comparing MG-BERT with and without explicit hydrogen atoms in the molecular graph</li>
</ol>
<h2 id="consistent-improvements-across-admet-benchmarks">Consistent Improvements Across ADMET Benchmarks</h2>
<h3 id="main-results">Main Results</h3>
<p>MG-BERT consistently outperformed all baselines across all 16 datasets. Selected results for seven of them (R2 and ROC-AUC values, shown as percentages):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>ECFP4-XGBoost</th>
          <th>GAT</th>
          <th>GCN</th>
          <th>CDDD</th>
          <th>SMILES-BERT</th>
          <th>MG-BERT</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Caco2 (R2)</td>
          <td>61.41</td>
          <td>69.16</td>
          <td>67.15</td>
          <td>73.42</td>
          <td>72.39</td>
          <td><strong>74.68</strong></td>
      </tr>
      <tr>
          <td>logD (R2)</td>
          <td>70.84</td>
          <td>84.62</td>
          <td>86.22</td>
          <td>85.85</td>
          <td>86.31</td>
          <td><strong>87.46</strong></td>
      </tr>
      <tr>
          <td>logS (R2)</td>
          <td>73.73</td>
          <td>84.06</td>
          <td>83.47</td>
          <td>84.01</td>
          <td>85.20</td>
          <td><strong>87.66</strong></td>
      </tr>
      <tr>
          <td>PPB (R2)</td>
          <td>55.11</td>
          <td>59.96</td>
          <td>57.34</td>
          <td>54.12</td>
          <td>62.37</td>
          <td><strong>65.94</strong></td>
      </tr>
      <tr>
          <td>Ames (AUC)</td>
          <td>87.21</td>
          <td>86.38</td>
          <td>87.04</td>
          <td>86.82</td>
          <td>87.69</td>
          <td><strong>89.33</strong></td>
      </tr>
      <tr>
          <td>BBB (AUC)</td>
          <td>94.62</td>
          <td>93.03</td>
          <td>92.67</td>
          <td>94.44</td>
          <td>94.02</td>
          <td><strong>95.41</strong></td>
      </tr>
      <tr>
          <td>BBBP (AUC)</td>
          <td>89.16</td>
          <td>90.33</td>
          <td>90.74</td>
          <td>91.12</td>
          <td>91.32</td>
          <td><strong>92.08</strong></td>
      </tr>
  </tbody>
</table>
<p>The overall improvement across all datasets was 28.1% (7.02% on classification, 21.28% on regression). Improvements were statistically significant at the 95% confidence level (paired t-test, P &lt;= 0.001).</p>
<h3 id="pretraining-ablation">Pretraining Ablation</h3>
<p>Pretraining improved performance by more than 2% on all datasets. The benefit was largest for small datasets: Caco2 improved by approximately 10 percentage points (64.79 to 74.68 R2), and FDAMDD improved by about 7.5 points (80.76 to 88.23 AUC). This confirms that self-supervised pretraining effectively addresses the labeled data scarcity problem.</p>
<h3 id="hydrogen-atom-ablation">Hydrogen Atom Ablation</h3>
<p>Including explicit hydrogen atoms improved pretraining recovery accuracy from 92.25% to 98.31% and consistently improved downstream performance. The authors provide an intuitive explanation: hydrogen atoms help determine bond counts for neighboring atoms, which is critical for the masked atom recovery task. They also show that removing hydrogens can make structurally distinct molecules (e.g., benzene and cyclohexane) indistinguishable at the graph level.</p>
<h3 id="interpretability-via-attention-visualization">Interpretability via Attention Visualization</h3>
<p>The authors provide two forms of interpretability analysis:</p>
<ol>
<li>
<p><strong>t-SNE visualization of atomic representations</strong>: Pretrained atomic representations cluster by atom type and, more specifically, by local chemical environment (e.g., aromatic carbons separate from aliphatic carbons, C-N bonds from C-O bonds). This demonstrates that pretraining captures neighborhood context beyond simple atom identity.</p>
</li>
<li>
<p><strong>Attention weight visualization</strong>: On the logD task, the supernode&rsquo;s attention focuses on polar groups (which govern lipophilicity). On the Ames mutagenicity task, attention concentrates on known mutagenic structural alerts (acylchloride, nitrosamide, azide groups). This provides chemically meaningful explanations for predictions.</p>
</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The paper does not extensively discuss limitations, but several can be identified:</p>
<ul>
<li>The model uses only 2D molecular topology (atom types and bonds) without 3D conformational information or bond-type features</li>
<li>The atom dictionary is limited to 13 common types plus [UNK], which may lose information for molecules containing rarer elements</li>
<li>Evaluation is limited to ADMET-focused datasets; broader chemical spaces (e.g., materials, catalysts) are not tested</li>
<li>The comparison baselines do not include other graph-based pretraining methods (e.g., the contemporaneous Strategies for Pre-training Graph Neural Networks by Hu et al.)</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL (random subset)</td>
          <td>1.7M molecules (1.53M train)</td>
          <td>10% held out for evaluation</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>ADMETlab + MoleculeNet</td>
          <td>16 datasets (642-10,354 molecules)</td>
          <td>8:1:1 splits, stratified by SMILES length</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam (pretraining: lr=1e-4, batch=256; fine-tuning: lr from {1e-5, 5e-5, 1e-4}, batch from {16, 32, 64})</li>
<li><strong>Pretraining epochs</strong>: 10</li>
<li><strong>Fine-tuning</strong>: Up to 100 epochs with early stopping</li>
<li><strong>Dropout</strong>: Optimized per task in range [0.0, 0.5]</li>
<li><strong>Masking</strong>: 15% of atoms (80% [MASK], 10% random, 10% unchanged)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: MG-BERT Medium (6 layers, 4 heads, embedding size 256, FFN size 512)</li>
<li><strong>Molecule processing</strong>: RDKit for graph conversion with explicit hydrogens</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>R-squared (R2)</td>
          <td>Regression</td>
          <td>Higher is better</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Higher is better</td>
      </tr>
      <tr>
          <td>Accuracy, RMSE</td>
          <td>Both</td>
          <td>Reported in supplementary Table S1</td>
      </tr>
  </tbody>
</table>
<p>All results averaged over 10 random splits with standard deviations reported.</p>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements (GPU type, training time, or memory usage).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/zhang-xuan1314/Molecular-graph-BERT">Molecular-graph-BERT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Jupyter Notebook implementation; last code push August 2021</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Wu, C.-K., Yang, Z.-J., Wu, Z.-X., Yi, J.-C., Hsieh, C.-Y., Hou, T.-J., &amp; Cao, D.-S. (2021). MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction. <em>Briefings in Bioinformatics</em>, 22(6), bbab152. <a href="https://doi.org/10.1093/bib/bbab152">https://doi.org/10.1093/bib/bbab152</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhang2021mgbert,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{MG-BERT}: leveraging unsupervised atomic representation learning for molecular property prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhang, Xiao-Chen and Wu, Cheng-Kun and Yang, Zhi-Jiang and Wu, Zhen-Xing and Yi, Jia-Cai and Hsieh, Chang-Yu and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{bbab152}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bib/bbab152}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Maxsmi: SMILES Augmentation for Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/</guid><description>Maxsmi systematically evaluates five SMILES augmentation strategies with CNN and RNN models across solubility, lipophilicity, and bioactivity tasks.</description><content:encoded><![CDATA[<h2 id="systematic-benchmarking-of-smiles-data-augmentation">Systematic Benchmarking of SMILES Data Augmentation</h2>
<p>This is an <strong>Empirical</strong> paper that systematically evaluates how SMILES augmentation affects deep learning molecular property prediction. The primary contribution is a comprehensive comparison of five augmentation strategies across three neural network architectures and four datasets, producing the &ldquo;Maxsmi&rdquo; models that maximize prediction performance. The study also demonstrates that test-time augmentation provides a practical confidence measure for predictions.</p>
<h2 id="the-data-scarcity-problem-in-qsar-modeling">The Data Scarcity Problem in QSAR Modeling</h2>
<p>Deep learning models require large training sets to perform well, but experimental physico-chemical and bioactivity datasets remain small, typically ranging from hundreds to a few thousand compounds. SMILES augmentation, where the non-unique <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES representation</a> of a molecule is exploited to generate multiple training examples per compound, has been shown to help in prior work by Bjerrum (2017), Kimber et al. (2018), and Li and Fourches (2020). However, no prior study had systematically compared different augmentation strategies, analyzed how much augmentation is needed, or examined the relationship between augmentation factor and prediction confidence. Most previous work chose augmentation numbers a priori without justification. Maxsmi fills this gap by providing a systematic analysis and practical guidelines.</p>
<h2 id="five-augmentation-strategies-and-test-time-ensemble-learning">Five Augmentation Strategies and Test-Time Ensemble Learning</h2>
<p>The core insight is twofold. First, the authors define five distinct strategies for generating augmented SMILES:</p>
<ol>
<li><strong>No augmentation</strong>: use only the canonical SMILES (baseline)</li>
<li><strong>Augmentation with duplication</strong>: generate $m$ random SMILES per compound, allowing duplicates; dataset grows to $N \times m$</li>
<li><strong>Augmentation without duplication</strong>: generate $m$ random SMILES and discard exact duplicates</li>
<li><strong>Augmentation with reduced duplication</strong>: keep only $f(m) = \sqrt{m}$ copies of each duplicate, a compromise between the above</li>
<li><strong>Augmentation with estimated maximum</strong>: sample random SMILES until the same string has been generated 10 times, attempting to cover most of the valid SMILES space</li>
</ol>
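<p>The duplication policies above can be sketched in a few lines. This is a minimal illustration, not the Maxsmi implementation: <code>sample</code> is a hypothetical callable that returns one random SMILES for a compound, the square root in the reduced strategy is assumed to apply to each string&rsquo;s occurrence count and to be rounded up, and the stopping rule for the estimated maximum follows the description above literally. The no-augmentation baseline is omitted because it involves no sampling.</p>

```python
import math
from collections import Counter


def augment(sample, m, strategy):
    """Return the SMILES list for one compound under a duplication policy.

    `sample` is a hypothetical zero-argument callable returning one random
    SMILES string; in practice it would wrap a cheminformatics toolkit.
    """
    drawn = [sample() for _ in range(m)]
    counts = Counter(drawn)
    if strategy == "with_duplication":
        return drawn  # keep all m draws, repeats included
    if strategy == "without_duplication":
        return list(counts)  # one copy of each distinct string
    if strategy == "reduced_duplication":
        # Assumption: sqrt is applied to each string's count and rounded up.
        kept = []
        for smiles, count in counts.items():
            kept.extend([smiles] * math.ceil(math.sqrt(count)))
        return kept
    raise ValueError(f"unknown strategy: {strategy}")


def estimated_maximum(sample, max_repeats=10):
    """Draw until some string has appeared max_repeats times; return the unique strings."""
    counts = Counter()
    while True:
        smiles = sample()
        counts[smiles] += 1
        if counts[smiles] >= max_repeats:
            return list(counts)
```

<p>Passing the sampler in as an argument keeps the duplication logic independent of how random SMILES are produced, so the same functions work with any enumeration backend.</p>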
<p>Second, the authors formalize test-time augmentation as ensemble learning. Given a trained model $M_{\Theta}$, each test compound $C$ is represented by $k$ random SMILES $S_1(C), \ldots, S_k(C)$. The per-SMILES predictions are:</p>
<p>$$
\hat{y}_i(C) = M_{\Theta}(S_i(C))
$$</p>
<p>The compound-level prediction is an aggregation (mean) over these:</p>
<p>$$
\hat{y}(C) = A\big(\hat{y}_1(C), \ldots, \hat{y}_k(C)\big)
$$</p>
<p>The standard deviation of the per-SMILES predictions serves as a confidence measure: high variance indicates the model is uncertain about a compound.</p>
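<p>A minimal sketch of this aggregation follows, assuming <code>model</code> is any callable that maps a SMILES string to a number (a stand-in for the trained network $M_{\Theta}$). The use of the population standard deviation rather than the sample estimator is an assumption; the source only says &ldquo;standard deviation&rdquo;.</p>

```python
import statistics


def predict_with_confidence(model, smiles_variants):
    """Aggregate per-SMILES predictions into a compound-level estimate.

    `model` is a hypothetical callable (SMILES string -> float);
    `smiles_variants` holds the k random SMILES of one compound.
    """
    per_smiles = [model(s) for s in smiles_variants]
    prediction = statistics.fmean(per_smiles)  # aggregation A = mean
    # Population standard deviation; the choice of estimator is an assumption.
    spread = statistics.pstdev(per_smiles)
    return prediction, spread
```

<p>A compound whose random SMILES all receive similar predictions returns a small <code>spread</code>, while disagreement across representations inflates it.</p>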
<h2 id="experimental-design-three-architectures-four-datasets">Experimental Design: Three Architectures, Four Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size (after preprocessing)</th>
          <th>Train / Test</th>
          <th>Task</th>
          <th>Provenance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>1,128</td>
          <td>902 / 226</td>
          <td>Water solubility</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></td>
      </tr>
      <tr>
          <td>ESOL_small</td>
          <td>1,068</td>
          <td>854 / 214</td>
          <td>Solubility (max 25 heavy atoms)</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>513 / 129</td>
          <td>Hydration free energy</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>4,199</td>
          <td>3,359 / 840</td>
          <td>Octanol/water distribution</td>
          <td><a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a></td>
      </tr>
      <tr>
          <td>Affinity (EGFR)</td>
          <td>5,849</td>
          <td>4,679 / 1,170</td>
          <td><a href="https://en.wikipedia.org/wiki/IC50">pIC50</a> against <a href="https://en.wikipedia.org/wiki/Epidermal_growth_factor_receptor">EGFR</a> kinase</td>
          <td>Kinodata</td>
      </tr>
  </tbody>
</table>
<h3 id="architectures">Architectures</h3>
<p>Three shallow neural networks are compared:</p>
<ul>
<li><strong>CONV1D</strong>: 1D convolution (kernel size 10, stride 1) followed by two fully connected layers</li>
<li><strong>CONV2D</strong>: 2D convolution on the one-hot encoded SMILES matrix, followed by two fully connected layers</li>
<li><strong>RNN</strong>: LSTM layer followed by two fully connected layers (128 and 64 units)</li>
</ul>
<p>All models are trained for 250 epochs with batch size 16, MSE loss, SGD optimizer, and learning rate 0.001. A Random Forest baseline with Morgan fingerprints (radius 2, length 1024) is also included.</p>
<h3 id="augmentation-sweep">Augmentation sweep</h3>
<p>The augmentation number $m$ is varied from 1 to 20 (step 1) and from 20 to 100 (step 10) for three strategies (with, without, and reduced duplication). The estimated maximum strategy is tested on the smaller datasets. Both training and test sets receive the same augmentation.</p>
<h2 id="key-findings-augmentation-consistently-improves-rmse">Key Findings: Augmentation Consistently Improves RMSE</h2>
<h3 id="augmentation-always-helps">Augmentation always helps</h3>
<p>Across all datasets and architectures, SMILES augmentation reduces test RMSE compared to the no-augmentation baseline. Performance improves sharply in the low augmentation range (1 to 10) and reaches a plateau around 40 to 70, after which additional augmentation provides diminishing returns.</p>
<h3 id="best-models-maxsmi">Best models (Maxsmi)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Model</th>
          <th>Augmentation Number</th>
          <th>Strategy</th>
          <th>Test RMSE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>CONV1D</td>
          <td>70</td>
          <td>Reduced duplication</td>
          <td>0.569</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>CONV1D</td>
          <td>70</td>
          <td>With duplication</td>
          <td>1.032</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>CONV1D</td>
          <td>80</td>
          <td>Without duplication</td>
          <td>0.593</td>
      </tr>
  </tbody>
</table>
<p>The CONV1D architecture consistently outperforms RNN and CONV2D. For ESOL, the CONV1D model improves from 0.839 RMSE (no augmentation) to 0.569 RMSE (70x reduced duplication), a 32% reduction.</p>
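<p>As a check on the quoted figure, the relative reduction follows directly from the two ESOL RMSE values:</p>
<p>$$
\frac{0.839 - 0.569}{0.839} = \frac{0.270}{0.839} \approx 0.322
$$</p>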
<h3 id="no-single-best-augmentation-strategy">No single best augmentation strategy</h3>
<p>The three main augmentation strategies (with, without, and reduced duplication) perform similarly. Generating the estimated maximum number of unique SMILES does not yield the best results, suggesting a saturation point exists where additional SMILES diversity stops helping.</p>
<h3 id="canonical-smiles-outperform-single-random-smiles">Canonical SMILES outperform single random SMILES</h3>
<p>When augmentation is limited to a single representation ($m = 1$), the canonical SMILES consistently outperforms a single random SMILES. On ESOL with CONV1D, the canonical model achieves 0.839 RMSE versus 0.964 for a random SMILES. The authors attribute this to the simpler, more readable structure of canonical SMILES (fewer branches and brackets).</p>
<h3 id="comparison-to-prior-work">Comparison to prior work</h3>
<table>
  <thead>
      <tr>
          <th>Study</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
          <th>Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Maxsmi</td>
          <td>0.569</td>
          <td>1.032</td>
          <td>0.593</td>
          <td>CNN</td>
      </tr>
      <tr>
          <td>MoleculeNet</td>
          <td>0.58 +/- 0.03</td>
          <td>1.15 +/- 0.12</td>
          <td>0.655 +/- 0.036</td>
          <td>GNN</td>
      </tr>
      <tr>
          <td>CNF</td>
          <td>0.62</td>
          <td>1.11</td>
          <td>0.67</td>
          <td>CNN</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT</a></td>
          <td>N/A</td>
          <td>1.197 +/- 0.127</td>
          <td>0.565 +/- 0.037</td>
          <td>RNN</td>
      </tr>
  </tbody>
</table>
<p>Maxsmi outperforms or matches MoleculeNet&rsquo;s graph neural networks and the CNF model on all three tasks. MolPMoFiT slightly outperforms Maxsmi on lipophilicity (0.565 vs 0.593) but performs worse on FreeSolv (1.197 vs 1.032).</p>
<h3 id="confidence-estimation">Confidence estimation</h3>
<p>The standard deviation of per-SMILES predictions correlates with prediction error. Confidence curves show that sequentially removing compounds with the highest uncertainty leads to monotonically decreasing mean prediction error. For ESOL, keeping only the top 10% most confident predictions yields errors below 0.25.</p>
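<p>The confidence curve described here can be sketched as follows. This is an illustrative reconstruction under stated assumptions: the error is taken to be the absolute error per compound, and compounds are removed one at a time in order of decreasing uncertainty. The function name and inputs are hypothetical.</p>

```python
def confidence_curve(abs_errors, uncertainties):
    """Mean absolute error over progressively more confident subsets.

    Entry i is the mean error after dropping the i most uncertain compounds.
    Using absolute error as the error measure is an assumption.
    """
    order = sorted(range(len(abs_errors)), key=lambda i: uncertainties[i])
    ranked = [abs_errors[i] for i in order]  # most confident first
    curve = []
    for kept in range(len(ranked), 0, -1):
        curve.append(sum(ranked[:kept]) / kept)
    return curve
```

<p>If the uncertainty is informative, the returned list decreases from left to right; a flat or noisy curve would indicate that the standard deviation carries little information about the error.</p>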
<h3 id="egfr-affinity-test-case">EGFR affinity test case</h3>
<p>Applying the Maxsmi approach (CONV1D, 70x augmentation, reduced duplication) to EGFR kinase affinity prediction yields a test RMSE of 0.777 and $R^2$ of 0.712, compared to 1.031 RMSE and 0.494 $R^2$ for the canonical model (a 25% RMSE improvement). The Random Forest baseline (0.758 RMSE, 0.726 $R^2$) performs comparably, and in fact slightly better on both metrics, which the authors note without further explanation.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>All experiments use a single train/test split (80/20) without cross-validation, due to the computational cost of the full augmentation sweep. This means reported RMSE values lack uncertainty estimates for the Maxsmi models.</li>
<li>The study uses shallow networks only. Whether the same augmentation benefits apply to deeper architectures or pre-trained models is untested.</li>
<li>The EGFR test case shows the Random Forest baseline performing comparably to the Maxsmi model, raising questions about when SMILES augmentation provides a meaningful advantage over traditional fingerprint-based methods.</li>
<li>The comparison to prior work uses different splits, preprocessing, and evaluation protocols across studies, which the authors acknowledge limits direct comparability.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>MoleculeNet, water solubility</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>MoleculeNet, hydration free energy</td>
      </tr>
      <tr>
          <td>Training/Evaluation</td>
          <td>Lipophilicity</td>
          <td>4,199</td>
          <td>ChEMBL, logD</td>
      </tr>
      <tr>
          <td>Test case</td>
          <td>EGFR Affinity</td>
          <td>5,849</td>
          <td>Kinodata (ChEMBL v28), pIC50</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available through MoleculeNet/DeepChem and Kinodata.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>SMILES generation via <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>&rsquo;s random SMILES enumeration</li>
<li>One-hot encoding of SMILES characters with padding to max length</li>
<li>Five augmentation strategies applied to both training and test sets</li>
<li>Mean aggregation for compound-level predictions</li>
</ul>
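<p>The encoding step in the list above can be illustrated with plain Python lists. This sketch assumes character-level tokens and a space as the padding symbol; both are simplifications chosen for illustration, and a faithful implementation would need a tokenizer that treats multi-character atoms such as <code>Cl</code> or <code>Br</code> as single symbols.</p>

```python
def one_hot_encode(smiles_list, pad_char=" "):
    """One-hot encode SMILES strings, padded to the longest string.

    Assumption: each character is one token and `pad_char` never occurs
    inside a SMILES string.
    """
    vocab = sorted(set("".join(smiles_list)) | {pad_char})
    index = {ch: i for i, ch in enumerate(vocab)}
    max_len = max(len(s) for s in smiles_list)
    encoded = []
    for smiles in smiles_list:
        padded = smiles.ljust(max_len, pad_char)
        rows = [[1 if index[ch] == j else 0 for j in range(len(vocab))] for ch in padded]
        encoded.append(rows)
    return encoded, vocab
```

<p>Each compound becomes a matrix of shape (maximum length, vocabulary size), which is the kind of input a 1D or 2D convolution can consume.</p>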
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CONV1D</td>
          <td>1D conv (kernel 10, stride 1) + 2 FC layers</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>CONV2D</td>
          <td>2D conv (single channel) + 2 FC layers</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>RNN</td>
          <td>LSTM + FC(128) + FC(64)</td>
          <td>Not specified</td>
      </tr>
      <tr>
          <td>RF Baseline</td>
          <td>Random Forest (default sklearn)</td>
          <td>Morgan FP, radius 2, length 1024</td>
      </tr>
  </tbody>
</table>
<p>Training: 250 epochs, batch size 16, MSE loss, SGD, lr=0.001.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSE (ESOL)</td>
          <td>0.569</td>
          <td>1.102 (RF)</td>
          <td>CONV1D, 70x reduced dup</td>
      </tr>
      <tr>
          <td>RMSE (FreeSolv)</td>
          <td>1.032</td>
          <td>2.563 (RF)</td>
          <td>CONV1D, 70x with dup</td>
      </tr>
      <tr>
          <td>RMSE (Lipophilicity)</td>
          <td>0.593</td>
          <td>0.860 (RF)</td>
          <td>CONV1D, 80x without dup</td>
      </tr>
      <tr>
          <td>RMSE (EGFR)</td>
          <td>0.777</td>
          <td>0.758 (RF)</td>
          <td>CONV1D, 70x reduced dup</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on a GeForce GTX 1080 Ti, provided by the HPC cluster at Freie Universität Berlin. Training CONV1D on ESOL with 100x augmentation (keeping duplicates, 90,200 data points) takes approximately 3 hours. Training with 19x augmentation achieves RMSE of 0.605 in under 30 minutes.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/volkamerlab/maxsmi">volkamerlab/maxsmi</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Full source code, trained models, CLI for prediction</td>
      </tr>
      <tr>
          <td><a href="https://maxsmi.readthedocs.io/en/latest/">Documentation</a></td>
          <td>Docs</td>
          <td>N/A</td>
          <td>Read the Docs documentation</td>
      </tr>
      <tr>
          <td><a href="https://github.com/openkinome/kinodata">Kinodata</a></td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Curated kinase bioactivity data from ChEMBL v28</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Highly Reproducible. Code, data, trained models, and a command-line prediction tool are all publicly available under the MIT license.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kimber, T. B., Gagnebin, M., &amp; Volkamer, A. (2021). Maxsmi: Maximizing molecular property prediction performance with confidence estimation using SMILES augmentation and deep learning. <em>Artificial Intelligence in the Life Sciences</em>, 1, 100014. <a href="https://doi.org/10.1016/j.ailsci.2021.100014">https://doi.org/10.1016/j.ailsci.2021.100014</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kimber2021maxsmi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Maxsmi: Maximizing molecular property prediction performance with confidence estimation using SMILES augmentation and deep learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kimber, Talia B. and Gagnebin, Maxime and Volkamer, Andrea}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Artificial Intelligence in the Life Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{100014}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.ailsci.2021.100014}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MAT: Graph-Augmented Transformer for Molecules (2020)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/molecule-attention-transformer/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/molecule-attention-transformer/</guid><description>MAT augments the Transformer self-attention mechanism with inter-atomic distances and molecular graph adjacency for molecular property prediction.</description><content:encoded><![CDATA[<h2 id="a-graph-augmented-transformer-for-molecular-property-prediction">A Graph-Augmented Transformer for Molecular Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that proposes the Molecule Attention Transformer (MAT), a Transformer-based architecture adapted for molecular property prediction. The primary contribution is a modified self-attention mechanism that incorporates inter-atomic distances and molecular graph structure alongside the standard query-key attention. Combined with self-supervised pretraining on 2 million molecules from ZINC15, MAT achieves competitive performance across seven diverse molecular property prediction tasks while requiring minimal hyperparameter tuning.</p>
<h2 id="challenges-in-deep-learning-for-molecular-properties">Challenges in Deep Learning for Molecular Properties</h2>
<p>Predicting molecular properties is central to drug discovery and materials design, yet deep neural networks have struggled to consistently outperform shallow methods like random forests and SVMs on these tasks. Wu et al. (2018) demonstrated through the MoleculeNet benchmark that graph neural networks do not reliably beat classical models. Two recurring problems compound this:</p>
<ol>
<li><strong>Underfitting</strong>: Graph neural networks tend to underfit training data, with performance failing to scale with model complexity (Ishiguro et al., 2019).</li>
<li><strong>Hyperparameter sensitivity</strong>: Deep models for molecule property prediction require extensive hyperparameter search (often 500+ configurations) to achieve competitive results, making them impractical for many practitioners.</li>
</ol>
<p>Concurrent work explored using vanilla Transformers on SMILES string representations of molecules (Honda et al., 2019; Wang et al., 2019), but these approaches discard the explicit structural information encoded in molecular graphs and 3D conformations. The motivation for MAT is to combine the flexibility of the Transformer architecture with domain-specific inductive biases from molecular structure.</p>
<h2 id="molecule-self-attention-combining-attention-distance-and-graph-structure">Molecule Self-Attention: Combining Attention, Distance, and Graph Structure</h2>
<p>The core innovation is the Molecule Self-Attention layer, which replaces standard Transformer self-attention. In a standard Transformer, head $i$ computes:</p>
<p>$$
\mathcal{A}^{(i)} = \rho\left(\frac{\mathbf{Q}_{i} \mathbf{K}_{i}^{T}}{\sqrt{d_{k}}}\right) \mathbf{V}_{i}
$$</p>
<p>MAT augments this with two additional information sources. Let $\mathbf{A} \in \{0, 1\}^{N_{\text{atoms}} \times N_{\text{atoms}}}$ denote the molecular graph adjacency matrix and $\mathbf{D} \in \mathbb{R}^{N_{\text{atoms}} \times N_{\text{atoms}}}$ denote the inter-atomic distance matrix. The modified attention becomes:</p>
<p>$$
\mathcal{A}^{(i)} = \left(\lambda_{a} \rho\left(\frac{\mathbf{Q}_{i} \mathbf{K}_{i}^{T}}{\sqrt{d_{k}}}\right) + \lambda_{d}\, g(\mathbf{D}) + \lambda_{g}\, \mathbf{A}\right) \mathbf{V}_{i}
$$</p>
<p>where $\rho$ is the softmax function, $\lambda_{a}$, $\lambda_{d}$, and $\lambda_{g}$ are scalar hyperparameters weighting each component, and $g$ is either a row-wise softmax or an element-wise exponential decay $g(d) = \exp(-d)$.</p>
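<p>A single head of this modified attention can be sketched in pure Python. This is a minimal illustration, not the authors&rsquo; code: matrices are nested lists, only the exponential kernel $g(d) = \exp(-d)$ is implemented (the softmax variant is omitted), and the helper names are hypothetical.</p>

```python
import math


def softmax(row):
    peak = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def molecule_self_attention(q, k, v, dist, adj, lam_a, lam_d, lam_g):
    """One head: (lam_a * softmax(Q K^T / sqrt(d_k)) + lam_d * g(D) + lam_g * A) V."""
    d_k = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    scale = math.sqrt(d_k)
    attn = [softmax([s / scale for s in row]) for row in matmul(q, k_t)]
    g = [[math.exp(-d) for d in row] for row in dist]  # g(d) = exp(-d)
    n = len(q)
    weights = [
        [lam_a * attn[i][j] + lam_d * g[i][j] + lam_g * adj[i][j] for j in range(n)]
        for i in range(n)
    ]
    return matmul(weights, v)
```

<p>Setting <code>lam_d</code> and <code>lam_g</code> to zero and <code>lam_a</code> to one recovers standard scaled dot-product attention, which makes the three terms easy to ablate one at a time.</p>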
<p>Key architectural details:</p>
<ul>
<li><strong>Atom embedding</strong>: Each atom is represented as a 26-dimensional vector encoding atomic identity (one-hot over B, N, C, O, F, P, S, Cl, Br, I, dummy, other), number of heavy neighbors, number of hydrogens, formal charge, ring membership, and aromaticity.</li>
<li><strong>Dummy node</strong>: An artificial disconnected node (distance $10^{6}$ from all atoms) is added to each molecule, allowing the model to &ldquo;skip&rdquo; attention heads when no relevant pattern exists, similar to how BERT uses the separation token.</li>
<li><strong>3D conformers</strong>: Distance matrices are computed from RDKit-generated 3D conformers using the Universal Force Field (UFF).</li>
<li><strong>Pretraining</strong>: Node-level masked atom prediction on 2 million ZINC15 molecules (following Hu et al., 2019), where 15% of atom features are masked and the model predicts them.</li>
</ul>
<h2 id="benchmark-evaluation-and-ablation-studies">Benchmark Evaluation and Ablation Studies</h2>
<h3 id="experimental-setup">Experimental setup</h3>
<p>MAT is evaluated on seven molecular property prediction datasets spanning regression and classification:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Task</th>
          <th>Size</th>
          <th>Metric</th>
          <th>Split</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FreeSolv</td>
          <td>Regression (hydration free energy)</td>
          <td>642</td>
          <td>RMSE</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>Regression (log solubility)</td>
          <td>1,128</td>
          <td>RMSE</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Classification (BBB permeability)</td>
          <td>2,039</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>Estrogen-alpha</td>
          <td>Classification (receptor activity)</td>
          <td>2,398</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>Estrogen-beta</td>
          <td>Classification (receptor activity)</td>
          <td>1,961</td>
          <td>ROC AUC</td>
          <td>Scaffold</td>
      </tr>
      <tr>
          <td>MetStab-high</td>
          <td>Classification (metabolic stability)</td>
          <td>2,127</td>
          <td>ROC AUC</td>
          <td>Random</td>
      </tr>
      <tr>
          <td>MetStab-low</td>
          <td>Classification (metabolic stability)</td>
          <td>2,127</td>
          <td>ROC AUC</td>
          <td>Random</td>
      </tr>
  </tbody>
</table>
<p>Baselines include GCN, Weave, EAGCN, Random Forest (RF), and SVM. Each model receives the same hyperparameter search budget (150 or 500 evaluations). Results are averaged over 6 random train/validation/test splits.</p>
<h3 id="main-results">Main results</h3>
<p>MAT achieves the best average rank across all seven tasks:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Avg. Rank (500 budget)</th>
          <th>Avg. Rank (150 budget)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MAT</td>
          <td>2.42</td>
          <td>2.71</td>
      </tr>
      <tr>
          <td>RF</td>
          <td>3.14</td>
          <td>3.14</td>
      </tr>
      <tr>
          <td>SVM</td>
          <td>3.57</td>
          <td>3.28</td>
      </tr>
      <tr>
          <td>GCN</td>
          <td>3.57</td>
          <td>3.71</td>
      </tr>
      <tr>
          <td>Weave</td>
          <td>3.71</td>
          <td>3.57</td>
      </tr>
      <tr>
          <td>EAGCN</td>
          <td>4.14</td>
          <td>4.14</td>
      </tr>
  </tbody>
</table>
<p>With self-supervised pretraining, Pretrained MAT achieves an average rank of 1.57, outperforming both Pretrained EAGCN (4.0) and SMILES Transformer (4.29). Pretrained MAT requires tuning only the learning rate (7 values tested), compared to 500 hyperparameter combinations for the non-pretrained models.</p>
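<p>For readers unfamiliar with the metric, an average rank can be computed as in the sketch below. The function and its inputs are hypothetical, and ties are broken by input order here, which is an assumption; the source does not say how the paper handles ties.</p>

```python
def average_ranks(scores, higher_is_better):
    """Average each model's per-task rank (1 = best).

    `scores[model][task]` holds a metric value and `higher_is_better[task]`
    says how to order it (True for ROC AUC, False for RMSE).
    """
    models = list(scores)
    tasks = list(higher_is_better)
    totals = {m: 0 for m in models}
    for task in tasks:
        ordered = sorted(
            models,
            key=lambda m: scores[m][task],
            reverse=higher_is_better[task],
        )
        for rank, model in enumerate(ordered, start=1):
            totals[model] += rank
    return {m: totals[m] / len(tasks) for m in models}
```

<p>Ranking per task before averaging puts regression and classification metrics on a common scale, which is why the metric suits a benchmark that mixes RMSE and ROC AUC.</p>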
<h3 id="ablation-results">Ablation results</h3>
<p>Ablation studies on BBBP, ESOL, and FreeSolv reveal:</p>
<table>
  <thead>
      <tr>
          <th>Variant</th>
          <th>BBBP (AUC)</th>
          <th>ESOL (RMSE)</th>
          <th>FreeSolv (RMSE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MAT (full)</td>
          <td>.723</td>
          <td>.286</td>
          <td>.250</td>
      </tr>
      <tr>
          <td>- Graph</td>
          <td>.716</td>
          <td>.316</td>
          <td>.276</td>
      </tr>
      <tr>
          <td>- Distance</td>
          <td>.729</td>
          <td>.281</td>
          <td>.281</td>
      </tr>
      <tr>
          <td>- Attention</td>
          <td>.692</td>
          <td>.306</td>
          <td>.329</td>
      </tr>
      <tr>
          <td>- Dummy node</td>
          <td>.714</td>
          <td>.317</td>
          <td>.249</td>
      </tr>
      <tr>
          <td>+ Edge features</td>
          <td>.683</td>
          <td>.314</td>
          <td>.358</td>
      </tr>
  </tbody>
</table>
<p>Removing any single component degrades performance on at least one task, supporting the value of combining all three information sources. Adding edge features does not help, suggesting the adjacency and distance matrices already capture sufficient bond-level information.</p>
<h3 id="interpretability-analysis">Interpretability analysis</h3>
<p>Individual attention heads in the first layer learn chemically meaningful functions. Six heads were identified that focus on specific chemical patterns: 2-neighbored aromatic carbons, sulfur atoms, non-ring nitrogens, carbonyl oxygens, 3-neighbored aromatic atoms (substitution positions), and aromatic ring nitrogens. Statistical validation using Kruskal-Wallis tests confirmed that atoms matching these SMARTS patterns receive significantly higher attention weights ($p &lt; 0.001$ for all patterns).</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>MAT demonstrates that augmenting Transformer self-attention with molecular graph structure and 3D distance information produces a model that performs consistently well across diverse property prediction tasks. The key practical finding is that self-supervised pretraining dramatically reduces the hyperparameter tuning burden: Pretrained MAT matches or exceeds the performance of extensively tuned models while requiring only learning rate selection.</p>
<p>Several limitations are acknowledged:</p>
<ul>
<li><strong>Fingerprint-based models still win on some tasks</strong>: RF and SVM with extended-connectivity fingerprints outperform MAT on metabolic stability and Estrogen-beta tasks, suggesting that incorporating fingerprint representations could improve MAT further.</li>
<li><strong>Single conformer</strong>: Only one pre-computed 3D conformer is used per molecule. More sophisticated conformer sampling or ensemble strategies were not explored.</li>
<li><strong>Limited pretraining exploration</strong>: Only the masked atom prediction task from Hu et al. (2019) was used. The authors note that exploring additional pretraining objectives is a promising direction.</li>
<li><strong>Scalability</strong>: The pretrained model uses 1024-dimensional embeddings with 8 layers and 16 attention heads, a configuration chosen as the largest model that still fits in GPU memory.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ZINC15</td>
          <td>2M molecules</td>
          <td>Sampled from ZINC database</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Hydration free energy regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Log solubility regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td>Blood-brain barrier classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Estrogen-alpha/beta</td>
          <td>2,398 / 1,961</td>
          <td>Receptor activity classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MetStab-high/low</td>
          <td>2,127 each</td>
          <td>Metabolic stability classification</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam with Noam learning rate scheduler (warmup then inverse square root decay)</li>
<li>Pretraining: 8 epochs, learning rate 0.001, batch size 256, binary cross-entropy loss</li>
<li>Fine-tuning: 100 epochs, batch size 32, learning rate selected from {1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6}</li>
<li>Distance kernel: exponential decay $g(d) = \exp(-d)$ for pretrained model</li>
<li>Lambda weights: $\lambda_{a} = \lambda_{d} = 0.33$ for pretrained model</li>
</ul>
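<p>The Noam schedule listed above (linear warmup followed by inverse square root decay) can be sketched as follows. This is a minimal illustration, not the authors&rsquo; code: the function name <code>noam_lr</code>, the <code>warmup_steps</code> value, and the <code>factor</code> multiplier are assumptions, and the note does not say how the reported learning rate of 0.001 maps onto the schedule. Only the 1024-dimensional model width comes from the paper summary.</p>

```python
def noam_lr(step, d_model=1024, warmup_steps=4000, factor=1.0):
    """Noam schedule: linear warmup, then inverse square root decay.

    warmup_steps and factor are illustrative assumptions.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    scale = factor * d_model ** -0.5
    # The first term dominates after warmup, the second during warmup.
    return scale * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate rises until warmup_steps, then falls with 1 / sqrt(step).
for step in (1, 1000, 4000, 16000):
    print(step, noam_lr(step))
```

<p>With these assumed settings the rate peaks at step 4000 and halves each time the step count quadruples afterwards.</p>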
<h3 id="models">Models</h3>
<ul>
<li>Pretrained MAT: 1024-dim embeddings, 8 layers, 16 attention heads, 1 feed-forward layer per block</li>
<li>Dropout: 0.0, weight decay: 0.0 for pretrained model</li>
<li>Atom featurization: 26-dimensional one-hot encoding (Table 1 in paper)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Regression: RMSE (FreeSolv, ESOL)</li>
<li>Classification: ROC AUC (BBBP, Estrogen-alpha/beta, MetStab-high/low)</li>
<li>All experiments repeated 6 times with different train/validation/test splits</li>
<li>Scaffold split for BBBP and Estrogen-alpha/beta; random split for the other datasets</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify exact hardware details. The pretrained model is described as &ldquo;the largest model that still fits the GPU memory.&rdquo;</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gmum/MAT">gmum/MAT</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Maziarka, Ł., Danel, T., Mucha, S., Rataj, K., Tabor, J., &amp; Jastrzębski, S. (2020). Molecule Attention Transformer. <em>arXiv preprint arXiv:2002.08264</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{maziarka2020molecule,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecule Attention Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Maziarka, {\L}ukasz and Danel, Tomasz and Mucha, S{\l}awomir and Rataj, Krzysztof and Tabor, Jacek and Jastrz{\k{e}}bski, Stanis{\l}aw}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2002.08264}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DMP: Dual-View Molecule Pre-training (SMILES+GNN)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/dual-view-molecule-pretraining/</link><pubDate>Fri, 27 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/dual-view-molecule-pretraining/</guid><description>DMP pre-trains molecular encoders using both SMILES Transformer and GNN branches with a BYOL-style dual-view consistency loss for property prediction.</description><content:encoded><![CDATA[<h2 id="a-dual-branch-pre-training-method-for-molecular-property-prediction">A Dual-Branch Pre-training Method for Molecular Property Prediction</h2>
<p>DMP (Dual-view Molecule Pre-training) is a <strong>Method</strong> paper that introduces a pre-training framework combining two complementary molecular encoders: a Transformer operating on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings and a Graph Neural Network (GNN) operating on molecular graphs. The two branches are trained jointly with masked language modeling (MLM) objectives plus a BYOL-style dual-view consistency loss. After pre-training on 10M <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> molecules, either branch (or both) can be fine-tuned for downstream tasks. The authors recommend the Transformer branch based on empirical results. DMP achieves the best reported performance on 7 of 9 <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> property prediction tasks and on 3 retrosynthesis benchmarks (at the time of the 2021 arXiv version).</p>
<h2 id="why-combine-smiles-and-graph-views-for-molecules">Why Combine SMILES and Graph Views for Molecules</h2>
<p>Prior molecule pre-training methods used either graph representations with GNNs or SMILES representations with Transformers, but not both. The authors observe that the two views are complementary: Transformers handle molecules with large atom distances (long chains) well, while GNNs handle molecules with many concatenated rings better. Neither model alone captures the full range of molecular structures effectively.</p>
<p>Existing GNN-based pre-training methods (Hu et al. 2020, MolCLR, GROVER) and SMILES-based methods (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>) each have blind spots dictated by their input representation. DMP addresses this by pre-training both views simultaneously and enforcing representation consistency between them, so each branch benefits from the structural knowledge of the other.</p>
<h2 id="dual-view-consistency-with-byol-style-training">Dual-View Consistency with BYOL-Style Training</h2>
<p>The core innovation is the dual-view consistency objective, inspired by Bootstrap Your Own Latent (BYOL). Given a molecule $M$ with SMILES representation $M_s$ and graph representation $M_g$, DMP obtains high-level features from each branch:</p>
<ul>
<li><strong>Transformer branch</strong>: A RoBERTa-base model encodes the SMILES sequence. The [CLS] token output serves as the molecule representation $f_s$.</li>
<li><strong>GNN branch</strong>: A DeeperGCN network encodes the molecular graph. Mean+max pooling over atom representations yields $f_g$.</li>
</ul>
<p>The dual-view consistency loss uses nonlinear projection heads $\psi_g, \psi_s$ and prediction heads $\rho_g, \rho_s$:</p>
<p>$$
p_g = \psi_g(f_g), \quad q_g = \rho_g(p_g); \quad p_s = \psi_s(f_s), \quad q_s = \rho_s(p_s)
$$</p>
<p>The consistency loss maximizes cross-view <a href="https://en.wikipedia.org/wiki/Cosine_similarity">cosine similarity</a> with stop-gradient (SG) on the target:</p>
<p>$$
\ell_{\text{dual}}(\tilde{M}_g, \tilde{M}_s) = -\cos(q_s, \text{SG}(p_g)) - \cos(q_g, \text{SG}(p_s))
$$</p>
<p>where $\cos(p, q) = \frac{p^\top q}{\lVert p \rVert_2 \lVert q \rVert_2}$ and $\tilde{M}_g, \tilde{M}_s$ are the masked versions of the inputs. The stop-gradient prevents representation collapse without requiring negative samples or a momentum encoder.</p>
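<p>As a rough illustration of the loss above, the following sketch computes its forward value for a single molecule in plain Python. The function names and the 3-dimensional vectors are hypothetical, and the sketch does not reproduce the authors&rsquo; implementation. Stop-gradient only affects the backward pass, so it appears here as a comment; in an autodiff framework it would correspond to something like <code>tensor.detach()</code> in PyTorch or <code>jax.lax.stop_gradient</code> in JAX.</p>

```python
import math


def cosine(p, q):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(p, q))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_q = math.sqrt(sum(b * b for b in q))
    return dot / (norm_p * norm_q)


def dual_view_loss(q_s, p_g, q_g, p_s):
    """Forward value of -cos(q_s, SG(p_g)) - cos(q_g, SG(p_s)).

    SG changes only gradients, so the targets p_g and p_s enter the
    forward computation unchanged.
    """
    return -cosine(q_s, p_g) - cosine(q_g, p_s)


# Hypothetical projection (p) and prediction (q) outputs for one molecule.
p_g, q_g = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]
p_s, q_s = [0.0, 2.0, 0.0], [3.0, 0.0, 0.0]
print(dual_view_loss(q_s, p_g, q_g, p_s))  # prints -2.0
```

<p>The loss is bounded between -2 (each prediction perfectly aligned with the other view&rsquo;s projection) and 2, and it ignores vector magnitude because cosine similarity normalizes both arguments.</p>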
<p>The full training objective combines three losses:</p>
<ol>
<li><strong>MLM on Transformer</strong>: Recover masked tokens in SMILES sequences</li>
<li><strong>MLM on GNN</strong>: Recover masked atoms in molecular graphs</li>
<li><strong>Dual-view consistency</strong>: The BYOL-style loss above</li>
</ol>
<p>Both MLM objectives and the consistency loss are necessary. Ablations show that removing MLM (using only dual-view loss) degrades performance, and using two branches of the same type (two Transformers or two GNNs) is less effective than the heterogeneous Transformer+GNN combination.</p>
<h2 id="experiments-on-moleculenet-and-retrosynthesis">Experiments on MoleculeNet and Retrosynthesis</h2>
<h3 id="pre-training-setup">Pre-training Setup</h3>
<p>DMP is pre-trained on 10M molecules from PubChem (matching prior work). The Transformer branch uses RoBERTa-base (12 layers, hidden dim 768, 87M parameters). The GNN branch uses DeeperGCN (12 layers, hidden dim 384, 7.4M parameters). Including the projection and prediction heads, DMP has 104.1M parameters in total. Training runs for 200K iterations on 8 V100 GPUs over 3.8 days with the Adam optimizer (lr = 5e-4, weight decay 0.01).</p>
<h3 id="molecular-property-prediction-moleculenet">Molecular Property Prediction (MoleculeNet)</h3>
<p>DMP is evaluated on 6 binary classification tasks (BBBP, Tox21, ClinTox, HIV, BACE, SIDER) using official DeepChem splits. Using scaffold splits from GROVER, it is also evaluated on three of those classification datasets (BBBP, SIDER, ClinTox) plus three additional regression datasets (ESOL, QM7, QM8).</p>
<p>Key results on DeepChem splits (ROC-AUC %, higher is better; bold marks the DMP_TF column, the variant the authors recommend):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MolCLR</th>
          <th>TF (MLM)</th>
          <th>DMP_TF</th>
          <th>DMP_TF+GNN</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>73.6</td>
          <td>74.9</td>
          <td><strong>78.1</strong></td>
          <td>77.8</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>79.8</td>
          <td>77.6</td>
          <td><strong>78.8</strong></td>
          <td>79.1</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>93.2</td>
          <td>92.9</td>
          <td><strong>95.0</strong></td>
          <td>95.6</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>80.6</td>
          <td>80.2</td>
          <td><strong>81.0</strong></td>
          <td>81.4</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>89.0</td>
          <td>88.0</td>
          <td><strong>89.3</strong></td>
          <td>89.4</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>68.0</td>
          <td>68.4</td>
          <td><strong>69.2</strong></td>
          <td>69.8</td>
      </tr>
  </tbody>
</table>
<p>On scaffold splits (comparison with GROVER and MPG):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>GROVER</th>
          <th>MPG</th>
          <th>DMP_TF</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP (AUC)</td>
          <td>0.940</td>
          <td>0.922</td>
          <td><strong>0.945</strong></td>
      </tr>
      <tr>
          <td>SIDER (AUC)</td>
          <td>0.658</td>
          <td>0.661</td>
          <td><strong>0.695</strong></td>
      </tr>
      <tr>
          <td>ClinTox (AUC)</td>
          <td>0.944</td>
          <td>0.963</td>
          <td><strong>0.968</strong></td>
      </tr>
      <tr>
          <td>ESOL (RMSE)</td>
          <td>0.831</td>
          <td>0.741</td>
          <td><strong>0.700</strong></td>
      </tr>
      <tr>
          <td>QM7 (MAE)</td>
          <td>72.6</td>
          <td>-</td>
          <td><strong>69.6</strong></td>
      </tr>
      <tr>
          <td>QM8 (MAE)</td>
          <td>0.0125</td>
          <td>-</td>
          <td><strong>0.0124</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="retrosynthesis">Retrosynthesis</h3>
<p>DMP is tested on USPTO-50K (reaction type known/unknown) and USPTO-full. Using a &ldquo;DMP fusion&rdquo; approach (fusing pre-trained representations into a Transformer encoder-decoder for <a href="/notes/chemistry/molecular-design/reaction-prediction/">retrosynthesis</a>), DMP improves top-1 accuracy by 2.1 to 3.8 points over the baseline Transformer across all settings:</p>
<table>
  <thead>
      <tr>
          <th>Setting</th>
          <th>Transformer</th>
          <th>ChemBERTa fusion</th>
          <th>DMP fusion</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>USPTO-50K (unknown)</td>
          <td>42.3</td>
          <td>43.9</td>
          <td><strong>46.1</strong></td>
      </tr>
      <tr>
          <td>USPTO-50K (known)</td>
          <td>54.2</td>
          <td>56.4</td>
          <td><strong>57.5</strong></td>
      </tr>
      <tr>
          <td>USPTO-full</td>
          <td>42.9</td>
          <td>-</td>
          <td><strong>45.0</strong></td>
      </tr>
  </tbody>
</table>
<p>For GNN-based retrosynthesis, replacing GLN&rsquo;s GNN modules with DMP&rsquo;s pre-trained GNN branch improves top-1 accuracy from 52.5% to 54.2% (unknown type) and from 64.2% to 66.5% (known type).</p>
<h3 id="representation-quality">Representation Quality</h3>
<p><a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding">t-SNE</a> visualization of pre-trained representations shows that DMP produces better scaffold-based clustering than either GNN-only or Transformer-only pre-training. The <a href="https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index">Davies-Bouldin index</a>, where lower values indicate more compact and better-separated clusters, improves from 3.56 (GNN) and 3.59 (Transformer) to 2.19 (DMP).</p>
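<p>For readers unfamiliar with the Davies-Bouldin index, the sketch below computes it from scratch using the standard definition (mean distance to the centroid as the scatter, Euclidean distance between centroids as the separation). The function name and the toy 2D points are illustrative assumptions; the paper&rsquo;s own evaluation code and embedding dimensionality are not described in this note.</p>

```python
import math


def davies_bouldin(points, labels):
    """Davies-Bouldin index: for each cluster take the worst ratio
    (scatter_i + scatter_j) / centroid_distance_ij, then average.
    Lower is better. Assumes no two centroids coincide."""
    clusters = {}
    for point, label in zip(points, labels):
        clusters.setdefault(label, []).append(point)
    if len(clusters) < 2:
        raise ValueError("need at least two clusters")

    centroids, scatters = {}, {}
    for label, members in clusters.items():
        dim = len(members[0])
        center = [sum(m[d] for m in members) / len(members) for d in range(dim)]
        centroids[label] = center
        # Scatter: mean Euclidean distance of members to their centroid.
        scatters[label] = sum(math.dist(m, center) for m in members) / len(members)

    total = 0.0
    for i in clusters:
        total += max(
            (scatters[i] + scatters[j]) / math.dist(centroids[i], centroids[j])
            for j in clusters
            if j != i
        )
    return total / len(clusters)


# Toy embeddings grouped by a hypothetical scaffold label.
tight = davies_bouldin([(0, 0), (0, 2), (10, 0), (10, 2)], ["a", "a", "b", "b"])
loose = davies_bouldin([(0, -3), (0, 5), (10, -3), (10, 5)], ["a", "a", "b", "b"])
print(tight, loose)
```

<p>In the toy example the clusters keep the same centroids but spread out further in the second call, so the index rises, which mirrors why a drop from about 3.6 to 2.19 is read as an improvement.</p>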
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Combining heterogeneous views (SMILES + graph) during pre-training is more effective than using two branches of the same type. TF(x2) and GNN(x2) variants show smaller gains.</li>
<li>Both MLM and dual-view consistency loss contribute. Removing MLM (dual-view only) hurts performance, especially on BBBP (71.1 vs 78.1 with both losses).</li>
<li>The Transformer branch alone is recommended for downstream tasks, as it achieves strong results without adding GNN parameters at inference time.</li>
<li>Scaling pre-training data from 10M to 100M compounds yields marginal additional improvement.</li>
</ul>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ol>
<li>Training cost is higher than single-branch methods (3.8 days vs 2.5 days for TF-only on 8 V100s), since both branches must be trained jointly.</li>
<li>A fixed branch selection strategy is used at inference time. The authors note that a meta-controller for dynamic branch selection per molecule would be preferable.</li>
<li>The GNN branch uses simple atom masking without bond deletion or subgraph removal, leaving room for stronger graph-level pre-training objectives.</li>
</ol>
<p><strong>Relation to co-training:</strong> The authors clarify that DMP differs from classical <a href="https://en.wikipedia.org/wiki/Co-training">co-training</a> (Blum and Mitchell 1998) in that it does not require conditional independence between views and produces a pre-trained model rather than additional labeled data.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>PubChem subset</td>
          <td>10M compounds</td>
          <td>Same subset as MolCLR and ChemBERTa</td>
      </tr>
      <tr>
          <td>Pre-training (large)</td>
          <td>PubChem subset</td>
          <td>100M compounds</td>
          <td>Additional scale experiment</td>
      </tr>
      <tr>
          <td>Evaluation (classification)</td>
          <td>MoleculeNet (BBBP, Tox21, ClinTox, HIV, BACE, SIDER)</td>
          <td>1.5K-41K molecules</td>
          <td>Official DeepChem splits</td>
      </tr>
      <tr>
          <td>Evaluation (regression)</td>
          <td>MoleculeNet (ESOL, QM7, QM8)</td>
          <td>Varies</td>
          <td>Scaffold splits from GROVER</td>
      </tr>
      <tr>
          <td>Evaluation (retrosynthesis)</td>
          <td>USPTO-50K, USPTO-full</td>
          <td>50K / 950K reactions</td>
          <td>Splits from Dai et al. (2019)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Transformer branch</strong>: RoBERTa-base with MLM. SMILES tokenized using regex from Schwaller et al. (2019).</li>
<li><strong>GNN branch</strong>: DeeperGCN with 12 layers, atom masking for MLM.</li>
<li><strong>Dual-view loss</strong>: BYOL-style with 3-layer MLP projection heads and 2-layer MLP prediction heads, stop-gradient on targets.</li>
<li><strong>Optimizer</strong>: Adam (lr=5e-4, beta1=0.9, beta2=0.98, epsilon=1e-6), weight decay 0.01, 10K warmup steps, linear decay.</li>
</ul>
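<p>The optimizer schedule above (10K warmup steps followed by linear decay) might look like the sketch below. The peak rate of 5e-4, the warmup length, and the 200K total iterations come from the summary; the function name and the choice to decay all the way to zero at the final iteration are assumptions, since the note does not state the end value.</p>

```python
def warmup_linear_decay_lr(step, peak_lr=5e-4, warmup_steps=10_000, total_steps=200_000):
    """Linear warmup to peak_lr, then linear decay.

    Decaying to exactly zero at total_steps is an assumption.
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)


for step in (0, 5_000, 10_000, 105_000, 200_000):
    print(step, warmup_linear_decay_lr(step))
```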
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Architecture</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Transformer branch</td>
          <td>RoBERTa-base (12L, 768H, 12 heads)</td>
          <td>87M</td>
      </tr>
      <tr>
          <td>GNN branch</td>
          <td>DeeperGCN (12L, 384H)</td>
          <td>7.4M</td>
      </tr>
      <tr>
          <td>DMP (total)</td>
          <td>Transformer + GNN + projection/prediction heads</td>
          <td>104.1M</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Classification: ROC-AUC, averaged over 3 random seeds</li>
<li>Regression: RMSE (ESOL) or MAE (QM7, QM8)</li>
<li>Retrosynthesis: Top-k exact match accuracy (k=1,3,5,10,20,50)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 8 NVIDIA V100 GPUs, batch size 12288 tokens, gradient accumulation 16x</li>
<li>Pre-training time: 3.8 days (DMP), 2.5 days (TF-only), 1.7 days (GNN-only)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>No public code repository or pre-trained model weights were identified for this paper. The paper references GLN&rsquo;s code repository (<a href="https://github.com/Hanjun-Dai/GLN">https://github.com/Hanjun-Dai/GLN</a>) for the retrosynthesis baseline but does not release DMP-specific code.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Hanjun-Dai/GLN">GLN (baseline)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Retrosynthesis baseline, not DMP code</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhu, J., Xia, Y., Wu, L., Xie, S., Zhou, W., Qin, T., Li, H., &amp; Liu, T.-Y. (2023). Dual-view Molecular Pre-training. In <em>Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining</em> (pp. 3615-3627). <a href="https://doi.org/10.1145/3580305.3599317">https://doi.org/10.1145/3580305.3599317</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhu2023dualview,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Dual-view Molecular Pre-training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zhu, Jinhua and Xia, Yingce and Wu, Lijun and Xie, Shufang and Zhou, Wengang and Qin, Tao and Li, Houqiang and Liu, Tie-Yan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3615--3627}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3580305.3599317}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>X-MOL: Pre-training on 1.1B Molecules for SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/x-mol-pretraining-molecular-understanding/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/x-mol-pretraining-molecular-understanding/</guid><description>X-MOL pre-trains a shared encoder-decoder Transformer on 1.1 billion molecules, then fine-tunes for property prediction, reaction analysis, and generation.</description><content:encoded><![CDATA[<h2 id="a-unified-molecular-pre-training-framework">A Unified Molecular Pre-training Framework</h2>
<p>X-MOL is a <strong>Method</strong> paper that introduces a large-scale pre-training framework for <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>-based molecular understanding. The primary contribution is a Transformer encoder-decoder model pre-trained on 1.1 billion molecules from <a href="/notes/chemistry/datasets/zinc-22/">ZINC15</a>, which is then fine-tuned across five distinct molecular analysis tasks: molecular property prediction (classification and regression), chemical reaction productivity prediction, <a href="https://en.wikipedia.org/wiki/Drug_interaction">drug-drug interaction</a> (DDI) prediction, de novo molecule generation (distribution learning and goal-directed), and molecule optimization. The paper demonstrates that a single pre-trained model can serve as a universal foundation for diverse downstream chemistry tasks.</p>
<h2 id="bridging-scale-and-understanding-in-molecular-smiles">Bridging Scale and Understanding in Molecular SMILES</h2>
<p>Prior to X-MOL, most molecular analysis tasks were investigated individually with task-specific models. SMILES-based deep learning methods existed but lacked the benefit of large-scale pre-training that had proven transformative in NLP (BERT, RoBERTa, ERNIE, XLNet, <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>). Two challenges motivated this work:</p>
<ol>
<li><strong>SMILES sacrifices structural information for simplicity.</strong> While SMILES is a convenient linear representation, it does not directly encode molecular topology, making it harder for models to learn 3D structure from string input.</li>
<li><strong>Labelled molecular data is scarce.</strong> Most benchmark datasets (<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>) contain only thousands of labelled examples, making it difficult to train large models from scratch without overfitting.</li>
</ol>
<p>The authors hypothesized that massive-scale pre-training on unlabelled SMILES could teach a model the grammar rules and implicit structural information in SMILES, providing a strong initialization for multiple downstream tasks.</p>
<h2 id="generative-pre-training-with-random-smiles">Generative Pre-training with Random SMILES</h2>
<p>The core innovation in X-MOL is a <strong>generative pre-training strategy</strong> that exploits the non-uniqueness of SMILES. A single molecule can be represented by many valid SMILES strings (<a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">random SMILES</a>), depending on the starting atom, main chain selection, and ring-opening position. X-MOL trains the model to generate a valid alternative SMILES given an input SMILES of the same molecule, forcing the model to:</p>
<ol>
<li>Reconstruct the molecular structure from the input SMILES</li>
<li>Generate a valid output SMILES following SMILES grammar rules</li>
</ol>
<p>The architecture uses a shared-parameter encoder-decoder based on the Transformer. Unlike standard encoder-decoder models (e.g., for machine translation), X-MOL shares all parameters between encoder and decoder, forcing both encoding and decoding to occur in the same semantic space. The output SMILES is fully masked during training, and only unidirectional attention is permitted within the output sequence.</p>
<p>The self-attention mechanism computes attention for each character $i$ as:</p>
<p>$$
Z_{i} = \text{SoftMax}\left(\frac{Q_{i} \cdot K^{T}}{\sqrt{D}}\right) \cdot V
$$</p>
<p>where $Q_{i}$ is the query vector for character $i$, $K$ and $V$ are the key and value matrices over all characters, and $D$ is the feature dimension. The model uses 12 attention heads to capture different relational patterns.</p>
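<p>The attention formula above can be traced by hand for a single query. The sketch below is a plain-Python, single-head illustration with hypothetical names (<code>attend</code>, <code>softmax</code>) and toy 2-dimensional vectors; it omits the learned projections, the multi-head split, and the unidirectional mask that X-MOL applies to the output sequence.</p>

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q_i, keys, values):
    """Z_i = softmax(q_i . K^T / sqrt(D)) . V for one query vector."""
    d = len(q_i)
    scores = [sum(a * b for a, b in zip(q_i, k)) / math.sqrt(d) for k in keys]
    weights = softmax(scores)
    z_i = [
        sum(w * v[c] for w, v in zip(weights, values))
        for c in range(len(values[0]))
    ]
    return z_i, weights


# Toy example: three characters with 2-dimensional keys and values.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
z, w = attend([1.0, 0.0], keys, values)
print(z, w)
```

<p>The weights always sum to one, so $Z_i$ is a weighted average of the value vectors, with more weight on characters whose keys align with the query.</p>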
<h3 id="model-architecture">Model Architecture</h3>
<ul>
<li>12 Transformer encoder layers</li>
<li>768-dimensional hidden units</li>
<li>12 attention heads</li>
<li>Character-level SMILES tokenization (108 chemical characters plus 5 special tokens: [PAD], [CLS], [SEP], [MASK], [UNK])</li>
<li>Characters within square brackets and double digits preceded by &ldquo;%&rdquo; are treated as single tokens</li>
</ul>
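<p>The tokenization rule in the list above (bracketed groups and <code>%</code>-prefixed two-digit ring labels as single tokens, everything else character by character) can be approximated with one regular expression. This is a sketch under that reading only: the pattern and function name are assumptions, and the note does not say how two-letter elements such as Cl or Br are handled, so they are split into two characters here.</p>

```python
import re

# One token per bracketed group, per %NN ring-closure label,
# or per single remaining character.
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]|%\d{2}|.")

# Special tokens named in the model description.
SPECIAL_TOKENS = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]


def tokenize_smiles(smiles):
    """Split a SMILES string into character-level tokens."""
    return TOKEN_PATTERN.findall(smiles)


print(tokenize_smiles("C[C@@H](N)C(=O)O"))
print(tokenize_smiles("C%12CC%12"))
```

<p>Joining the tokens reproduces the input string, which makes a convenient sanity check when adapting the pattern.</p>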
<h3 id="data-augmentation-in-pre-training">Data Augmentation in Pre-training</h3>
<p>Because a molecule has multiple valid random SMILES, the output may differ from the predefined target. To handle this, X-MOL generates multiple training samples per molecule with the same input SMILES but different output random SMILES, and places these in the same mini-batch.</p>
<h2 id="experimental-setup-across-five-tasks">Experimental Setup Across Five Tasks</h2>
<p>X-MOL is fine-tuned with task-specific strategies organized into two categories: prediction tasks and generation tasks.</p>
<h3 id="prediction-tasks">Prediction Tasks</h3>
<p>For prediction tasks, the [CLS] token&rsquo;s output representation is passed through a fully connected network to produce predictions. The input format varies by task:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Input Format</th>
          <th>Loss Function</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Property prediction (classification)</td>
          <td>Single SMILES</td>
          <td>Cross-entropy</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Property prediction (regression)</td>
          <td>Single SMILES</td>
          <td>MSE</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Reaction productivity prediction</td>
          <td>Four SMILES (reactant, additive, base, ligand)</td>
          <td>MSE</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>DDI prediction</td>
          <td>Two SMILES (drug pair)</td>
          <td>Cross-entropy</td>
          <td>Accuracy</td>
      </tr>
  </tbody>
</table>
<p><strong>Molecular Property Prediction (Classification):</strong> Four MoleculeNet benchmarks were used: HIV (41,127 compounds), BACE (1,513), <a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">BBBP</a> (2,039), and ClinTox (1,484). Data were randomly split 20 times, and average ROC-AUC is reported.</p>
<p><strong>Molecular Property Prediction (Regression):</strong> Three MoleculeNet benchmarks: ESOL (1,128), FreeSolv (642), and Lipophilicity (4,200). Data augmentation with random SMILES was applied to the training set. Average RMSE over 20 random splits is reported.</p>
<p><strong>Chemical Reaction Productivity Prediction:</strong> The <a href="https://en.wikipedia.org/wiki/Cross-coupling_reaction">C-N cross-coupling</a> dataset (3,956 reactions) from Ahneman et al. was used with 10-fold cross-validation.</p>
<p><strong>DDI Prediction:</strong> The DeepDDI dataset (192,284 DDI pairs, 86 interaction types) was used as the benchmark.</p>
<h3 id="generation-tasks">Generation Tasks</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Generation Source</th>
          <th>Sampling Strategy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Distribution learning (DL) generation</td>
          <td>Fixed initial symbol ([CLS])</td>
          <td>Random sampling</td>
      </tr>
      <tr>
          <td>Goal-directed (GD) generation</td>
          <td>Unfixed initial symbol</td>
          <td>Random sampling</td>
      </tr>
      <tr>
          <td>Molecule optimization</td>
          <td>Input molecule</td>
          <td>Beam search (beam size = 4)</td>
      </tr>
  </tbody>
</table>
<p><strong>DL-based Generation:</strong> Evaluated on ZINC250K (249,456 molecules) using validity, uniqueness, and novelty.</p>
<p><strong>GD Generation:</strong> Also on ZINC250K, using QED as the goal property with target QED = 0.948 (the dataset maximum). 10,000 molecules were generated for evaluation.</p>
<p><strong>Molecule Optimization:</strong> Evaluated on ZINC250K with QED as the optimization goal. Molecular pairs were constructed by selecting pairs with <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> in [0.6, 0.8], where the lower-QED molecule serves as input and the higher-QED molecule as target.</p>
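<p>The pair construction for molecule optimization can be sketched as below. The fingerprints are represented as hypothetical sets of &ldquo;on&rdquo; bit indices and the molecule records are made up; the note does not specify which fingerprint the authors used or how ties in QED were handled, so pairs with equal QED are simply skipped here.</p>

```python
from itertools import combinations


def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of fingerprint bits."""
    a, b = set(bits_a), set(bits_b)
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def build_pairs(molecules, low=0.6, high=0.8):
    """Return (source, target, similarity) triples.

    molecules: list of (name, fingerprint_bits, qed) tuples (hypothetical).
    The lower-QED molecule becomes the source, the higher-QED one the target.
    """
    pairs = []
    for m1, m2 in combinations(molecules, 2):
        sim = tanimoto(m1[1], m2[1])
        if low <= sim <= high and m1[2] != m2[2]:
            source, target = sorted((m1, m2), key=lambda m: m[2])
            pairs.append((source[0], target[0], sim))
    return pairs


molecules = [
    ("A", {1, 2, 3, 4}, 0.5),
    ("B", {1, 2, 3}, 0.9),
    ("C", {1, 2, 3, 4}, 0.7),
    ("D", {9}, 0.95),
]
print(build_pairs(molecules))
```

<p>Pairs that are too similar (A and C share every bit) or too dissimilar (anything with D) are excluded, leaving only pairs inside the [0.6, 0.8] window.</p>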
<h3 id="key-results">Key Results</h3>
<p><strong>Classification (ROC-AUC, higher is better):</strong> X-MOL achieved state-of-the-art on all four datasets, outperforming both shallow learning methods and deep learning baselines including graph convolutional models.</p>
<p><strong>Regression (RMSE, lower is better):</strong> X-MOL achieved the best RMSE on ESOL, FreeSolv, and Lipophilicity.</p>
<p><strong>Reaction Productivity:</strong> X-MOL obtained an average RMSE of 0.0626, compared to the random forest baseline of 0.078.</p>
<p><strong>DDI Prediction:</strong> X-MOL achieved accuracy of 0.952, improving over DeepDDI&rsquo;s 0.924.</p>
<p><strong>DL-based Generation:</strong></p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Validity</th>
          <th>Uniqueness</th>
          <th>Novelty</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GCPN</td>
          <td>20%</td>
          <td>99.97%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td>MRNN</td>
          <td>65%</td>
          <td>99.89%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td>GraphAF</td>
          <td>68%</td>
          <td>99.10%</td>
          <td>100%</td>
      </tr>
      <tr>
          <td><strong>X-MOL</strong></td>
          <td><strong>85.28%</strong></td>
          <td><strong>99.91%</strong></td>
          <td><strong>100%</strong></td>
      </tr>
  </tbody>
</table>
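<p>The three columns in the table above are usually computed along the lines of the sketch below. The definitions used here (validity over all generated strings, uniqueness over valid strings, novelty over unique valid strings) are common conventions and an assumption about this paper, whose exact definitions are not given in this note. The <code>is_valid</code> argument is a hypothetical stand-in for a real chemistry check, and strings are compared as-is instead of being canonicalized first.</p>

```python
def generation_metrics(generated, training_set, is_valid):
    """Return (validity, uniqueness, novelty) as fractions in [0, 1]."""
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)
    novel = unique - set(training_set)
    validity = len(valid) / len(generated) if generated else 0.0
    uniqueness = len(unique) / len(valid) if valid else 0.0
    novelty = len(novel) / len(unique) if unique else 0.0
    return validity, uniqueness, novelty


def balanced_parentheses(smiles):
    """Toy validity check (assumption): parentheses must be balanced."""
    depth = 0
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0


generated = ["CCO", "CCO", "C(C", "CCN"]
metrics = generation_metrics(generated, {"CCO"}, balanced_parentheses)
print(metrics)
```

<p>In practice the validity check would parse each string with a cheminformatics toolkit, and both generated and training molecules would be canonicalized before the set comparisons.</p>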
<p><strong>GD Generation:</strong> X-MOL&rsquo;s top-3 generated molecules all reached QED = 0.948, matching the dataset maximum. GraphAF reached 0.948/0.948/0.947, while JT-VAE and MRNN fell further behind.</p>
<h3 id="knowledge-embedding-ablation">Knowledge Embedding Ablation</h3>
<p>The paper tested three additional embedding strategies to inject structural information into the model:</p>
<ul>
<li><strong>Link embedding:</strong> Encodes connection information between atoms (position of the previous connected atom)</li>
<li><strong>Ring embedding:</strong> Encodes ring structure information from SMILES number pairs</li>
<li><strong>Type embedding:</strong> Categorizes characters into 9 types (atoms, bonds, structural symbols)</li>
</ul>
<p>None of these additional embeddings improved performance on the HIV or DDI tasks, whether with or without pre-training. The authors conclude that SMILES already contains sufficient information for molecular understanding and that pre-training effectively extracts this information, a finding they label &ldquo;SMILES is all you need.&rdquo;</p>
<h3 id="attention-visualization">Attention Visualization</h3>
<p>The authors provide attention heatmap analysis demonstrating that:</p>
<ul>
<li>Middle layers (e.g., layer 9) reconstruct molecular structure by correctly identifying atom connectivity and ring closures</li>
<li>Later layers abstract higher-level features for property prediction</li>
<li>In multi-input prediction tasks (reaction productivity), attention reveals which reaction components are most important (e.g., the ligand receives the highest cross-attention)</li>
<li>In generation tasks, attention patterns differ between DL (self-focused), GD (source-constrained), and optimization (gradual shift from input to output)</li>
</ul>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>X-MOL demonstrates that large-scale pre-training on SMILES can produce a single model that achieves competitive or state-of-the-art performance across five distinct molecular analysis tasks. The key findings are:</p>
<ol>
<li><strong>Scale enables SMILES understanding.</strong> Pre-training on 1.1 billion molecules allows the model to learn SMILES grammar rules well enough to outperform graph-based methods on molecule generation validity.</li>
<li><strong>Unified framework.</strong> A single pre-trained backbone serves classification, regression, reaction prediction, DDI prediction, and generative tasks through different fine-tuning strategies.</li>
<li><strong>SMILES is sufficient.</strong> Additional knowledge embeddings (link, ring, type) do not improve performance, suggesting pre-training extracts the necessary structural information from SMILES alone.</li>
<li><strong>Interpretable attention.</strong> Attention visualization confirms that the model reconstructs molecular structure internally.</li>
</ol>
<p><strong>Limitations</strong> (observed):</p>
<ul>
<li>Property prediction is benchmarked on only seven MoleculeNet datasets. No scaffold splits or temporal splits are used; all splits are random, which can overestimate performance on structurally novel compounds.</li>
<li>Comparison baselines are somewhat dated (2018-2019 era methods), and the paper does not compare against concurrent SMILES pre-training methods.</li>
<li>The molecule generation validity (85.28%) is much higher than graph baselines like GCPN (20%), but later work achieved near 100% validity with constrained SMILES grammars.</li>
<li>No code or model weights have been publicly released, limiting independent verification.</li>
<li>The paper remains a bioRxiv preprint and has not been published in a peer-reviewed venue.</li>
</ul>
<p><strong>Future directions</strong> proposed by the authors include: better pre-training strategies, extension to graph-based representations, and fine-tuning on additional downstream tasks.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC15</td>
          <td>1.1 billion molecules</td>
          <td>Random SMILES augmentation</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>HIV (MoleculeNet)</td>
          <td>41,127</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE (MoleculeNet)</td>
          <td>1,513</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP (MoleculeNet)</td>
          <td>2,039</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>ClinTox (MoleculeNet)</td>
          <td>1,484</td>
          <td>Two sub-datasets, averaged</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL (MoleculeNet)</td>
          <td>1,128</td>
          <td>Water solubility</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv (MoleculeNet)</td>
          <td>642</td>
          <td>Hydration free energy</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipophilicity (MoleculeNet)</td>
          <td>4,200</td>
          <td>logD at pH 7.4</td>
      </tr>
      <tr>
          <td>Reaction</td>
          <td>C-N cross-coupling</td>
          <td>3,956</td>
          <td>From Ahneman et al. (2018)</td>
      </tr>
      <tr>
          <td>DDI</td>
          <td>DeepDDI</td>
          <td>192,284 DDI pairs</td>
          <td>86 interaction types</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td>ZINC250K</td>
          <td>249,456</td>
          <td>For DL, GD, and optimization</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Pre-training: Generative SMILES-to-SMILES with shared encoder-decoder Transformer</li>
<li>Fine-tuning prediction tasks: [CLS] token passed through fully connected layers</li>
<li>Fine-tuning generation tasks: Autoregressive generation with random sampling (DL, GD) or beam search (optimization)</li>
<li>Data augmentation: Random SMILES augmentation for regression tasks</li>
<li>Repeated training: 20 random splits with averaged results for classification/regression</li>
<li>10-fold cross-validation for reaction productivity</li>
</ul>
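<p>The repeated random-split protocol listed above can be sketched as follows. This is a minimal illustration, not the authors&rsquo; code: the 80/20 split ratio, the <code>evaluate</code> callback, and all function names are assumptions, since these notes only state that results are averaged over 20 random splits.</p>

```python
import random
from statistics import mean

def random_split(items, train_frac, rng):
    """Shuffle a copy of `items` and cut it into train/test parts."""
    shuffled = list(items)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]

def repeated_evaluation(items, evaluate, n_repeats=20, train_frac=0.8, seed=0):
    """Average `evaluate(train, test)` over `n_repeats` independent random splits."""
    rng = random.Random(seed)
    scores = []
    for _ in range(n_repeats):
        train, test = random_split(items, train_frac, rng)
        scores.append(evaluate(train, test))
    return mean(scores), scores

# Toy usage: the "metric" is just the test-set fraction (hypothetical stand-in
# for fine-tuning a model on `train` and scoring it on `test`).
data = list(range(100))
avg, scores = repeated_evaluation(data, lambda tr, te: len(te) / (len(tr) + len(te)))
```

<p>Averaging over many splits reduces the variance of the reported score, but because every split is random it does not address the structural-novelty concern raised in the limitations.</p>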
<h3 id="models">Models</h3>
<ul>
<li>12-layer Transformer, 768 hidden dimensions, 12 attention heads</li>
<li>Character-level tokenization: 108 chemical characters + 5 special tokens</li>
<li>Implemented in PaddlePaddle framework</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>X-MOL</th>
          <th>Best Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HIV (classification)</td>
          <td>ROC-AUC</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>BACE (classification)</td>
          <td>ROC-AUC</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>BBBP (classification)</td>
          <td>ROC-AUC</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>ClinTox (classification)</td>
          <td>ROC-AUC</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>ESOL (regression)</td>
          <td>RMSE</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>FreeSolv (regression)</td>
          <td>RMSE</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>Lipophilicity (regression)</td>
          <td>RMSE</td>
          <td>State-of-the-art</td>
          <td>Previous best (various)</td>
      </tr>
      <tr>
          <td>C-N coupling</td>
          <td>RMSE</td>
          <td>0.0626</td>
          <td>0.078 (random forest)</td>
      </tr>
      <tr>
          <td>DDI prediction</td>
          <td>Accuracy</td>
          <td>0.952</td>
          <td>0.924 (DeepDDI)</td>
      </tr>
      <tr>
          <td>DL generation</td>
          <td>Validity</td>
          <td>85.28%</td>
          <td>68% (GraphAF)</td>
      </tr>
      <tr>
          <td>GD generation</td>
          <td>Top-3 QED</td>
          <td>All 0.948</td>
          <td>0.948/0.948/0.947 (GraphAF)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 8/16 Tesla P40 GPUs (24 GB each), approximately 4 days</li>
<li>Data pre-processing: Over 1,000 CPUs with Hadoop</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>No code, model weights, or pre-trained checkpoints have been publicly released. The model was implemented in Baidu&rsquo;s PaddlePaddle framework, but no repository is available.</p>
<p><strong>Reproducibility status: Closed.</strong> While the datasets are all publicly available (ZINC15, MoleculeNet, ZINC250K, DeepDDI, C-N coupling), the model implementation, pre-trained weights, and fine-tuning code are not released. The computational requirements (1,000+ CPUs for data processing, 8-16 GPUs for 4 days of pre-training) are substantial.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xue, D., Zhang, H., Xiao, D., Gong, Y., Chuai, G., Sun, Y., Tian, H., Wu, H., Li, Y., &amp; Liu, Q. (2020). X-MOL: Large-scale pre-training for molecular understanding and diverse molecular analysis. <em>bioRxiv</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{xue2020xmol,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{X-MOL: large-scale pre-training for molecular understanding and diverse molecular analysis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xue, Dongyu and Zhang, Han and Xiao, Dongling and Gong, Yukang and Chuai, Guohui and Sun, Yu and Tian, Hao and Wu, Hua and Li, Yukun and Liu, Qi}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{bioRxiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1101/2020.12.23.424259}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Cold Spring Harbor Laboratory}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>VAE for Automatic Chemical Design (2018 Seminal)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/</guid><description>A variational autoencoder maps SMILES strings to a continuous latent space, enabling gradient-based optimization for molecular design and generation.</description><content:encoded><![CDATA[<h2 id="a-foundational-method-for-continuous-molecular-representation">A Foundational Method for Continuous Molecular Representation</h2>
<p>This is a <strong>Method</strong> paper that introduces a variational autoencoder (VAE) framework for mapping discrete molecular representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings) into a continuous latent space. The primary contribution is demonstrating that this continuous representation enables three key capabilities: (1) automatic generation of novel molecules by decoding random or perturbed latent vectors, (2) smooth interpolation between molecules in latent space, and (3) gradient-based optimization of molecular properties using a jointly trained property predictor. This work is widely regarded as one of the earliest and most influential applications of deep generative models to molecular design.</p>
<h2 id="the-challenge-of-searching-discrete-chemical-space">The Challenge of Searching Discrete Chemical Space</h2>
<p>Molecular design is fundamentally an optimization problem: identify molecules that maximize some set of desirable properties. The search space is enormous (estimated $10^{23}$ to $10^{60}$ drug-like molecules) and discrete, making systematic exploration difficult. Prior approaches fell into two categories, each with significant limitations:</p>
<ol>
<li><strong>Virtual screening</strong> over fixed libraries: effective but monolithic, costly to enumerate, and requiring hand-crafted rules to avoid impractical chemistries.</li>
<li><strong>Discrete local search</strong> (e.g., genetic algorithms): requires manual specification of mutation and crossover heuristics, and cannot leverage gradient information to guide the search.</li>
</ol>
<p>The core insight is that mapping molecules into a continuous vector space sidesteps these problems entirely. In a continuous space, new compounds can be generated by vector perturbation (no hand-crafted mutation rules), optimization can follow property gradients (enabling larger and more directed jumps), and large unlabeled chemical databases can be leveraged through unsupervised representation learning.</p>
<h2 id="a-vae-architecture-for-smiles-strings-with-joint-property-prediction">A VAE Architecture for SMILES Strings with Joint Property Prediction</h2>
<p>The architecture consists of three coupled neural networks trained jointly:</p>
<ol>
<li>
<p><strong>Encoder</strong>: Converts SMILES character strings into fixed-dimensional continuous vectors (the latent representation). Uses three 1D convolutional layers followed by a fully connected layer. For ZINC molecules, the latent space has 196 dimensions; for <a href="/notes/chemistry/datasets/qm9/">QM9</a>, 156 dimensions.</p>
</li>
<li>
<p><strong>Decoder</strong>: Converts latent vectors back into SMILES strings character by character using three layers of gated recurrent units (GRUs). The output is stochastic, as each character is sampled from a probability distribution over the SMILES alphabet.</p>
</li>
<li>
<p><strong>Property Predictor</strong>: A multilayer perceptron that predicts molecular properties directly from the latent representation. Joint training with the autoencoder reconstruction loss organizes the latent space so that molecules with similar properties cluster together.</p>
</li>
</ol>
<h3 id="the-vae-objective">The VAE Objective</h3>
<p>The model uses the <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoder framework of Kingma and Welling</a>. The training objective combines three terms:</p>
<p>$$\mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL}(q(z|x) \| p(z)) + \lambda \cdot \mathcal{L}_{prop}$$</p>
<p>where $\mathcal{L}_{recon}$ is the reconstruction loss (cross-entropy over SMILES characters), $D_{KL}$ is the KL divergence regularizer that encourages the latent distribution $q(z|x)$ to match a standard Gaussian prior $p(z)$, and $\mathcal{L}_{prop}$ is the property prediction regression loss. The weights $\beta$ and $\lambda$ scale the KL and property terms: both losses are annealed in using a sigmoid schedule after 29 epochs over a total of 120 epochs of training.</p>
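<p>A minimal sketch of how such an annealed objective could be assembled is shown below. The exact functional form of the schedule is not given in these notes, so the logistic curve centered at epoch 29, the <code>slope</code> parameter, and the <code>beta_max</code> and <code>lambda_max</code> ceilings are all assumptions made for illustration.</p>

```python
import math

def sigmoid_weight(epoch, start=29, slope=1.0):
    """Logistic ramp: close to 0 well before `start`, 0.5 at `start`, close to 1 after."""
    return 1.0 / (1.0 + math.exp(slope * (start - epoch)))

def total_loss(recon, kl, prop, epoch, beta_max=1.0, lambda_max=1.0):
    """Combine the three loss terms, annealing the KL and property weights together."""
    w = sigmoid_weight(epoch)
    beta = beta_max * w          # weight on the KL term
    lam = lambda_max * w         # weight on the property-prediction term
    return recon + beta * kl + lam * prop

# Weights at a few points of a 120-epoch run (assumed schedule).
weights = {epoch: sigmoid_weight(epoch) for epoch in (0, 29, 60, 120)}
```

<p>Under these assumed settings the model trains almost purely as an autoencoder during the first epochs, and the regularizer and property loss only reach full strength once reconstruction is already reasonable.</p>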
<p>The KL regularization is critical: it forces the decoder to handle a wider variety of latent points, preventing &ldquo;dead areas&rdquo; in latent space that would decode to invalid molecules.</p>
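<p>For reference, if the encoder outputs a diagonal Gaussian $q(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ over a $d$-dimensional latent space (the standard choice in the Kingma and Welling framework, assumed here rather than stated in these notes), the KL term against the standard normal prior has a closed form:</p>
<p>$$D_{KL}(q(z|x) \| p(z)) = \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)$$</p>
<p>Each summand is non-negative and vanishes only when $\mu_j = 0$ and $\sigma_j = 1$. The penalty therefore discourages the encoder from collapsing each molecule to a narrow point far from the origin, which is one way to read the &ldquo;dead areas&rdquo; argument above.</p>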
<h3 id="gradient-based-optimization">Gradient-Based Optimization</h3>
<p>After training, a Gaussian process (GP) surrogate model is fit on top of the latent representations to predict the target property. Optimization proceeds by:</p>
<ol>
<li>Encoding a seed molecule into the latent space</li>
<li>Using the GP model to define a smooth property surface over the latent space</li>
<li>Optimizing the latent vector $z$ to maximize the predicted property via gradient ascent</li>
<li>Decoding the optimized $z$ back into a SMILES string</li>
</ol>
<p>The objective used for demonstration is $5 \times \text{QED} - \text{SAS}$, balancing drug-likeness (QED) against synthetic accessibility (SAS).</p>
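<p>The optimization loop above can be sketched in a few lines. This is a toy illustration only: a concave quadratic stands in for the GP surrogate, the encoder and decoder are omitted, and the names <code>surrogate</code>, <code>peak</code>, and <code>gradient_ascent</code> are hypothetical.</p>

```python
def objective(qed, sas):
    """Demonstration objective from the text: 5 * QED - SAS."""
    return 5.0 * qed - sas

def surrogate(z, peak):
    """Toy stand-in for the GP prediction: a concave bowl with its maximum at `peak`."""
    return -sum((zi - pi) ** 2 for zi, pi in zip(z, peak))

def surrogate_grad(z, peak):
    """Analytic gradient of the toy surrogate with respect to z."""
    return [-2.0 * (zi - pi) for zi, pi in zip(z, peak)]

def gradient_ascent(z0, peak, step=0.1, n_steps=50):
    """Move the latent vector uphill on the surrogate surface."""
    z = list(z0)
    for _ in range(n_steps):
        grad = surrogate_grad(z, peak)
        z = [zi + step * gi for zi, gi in zip(z, grad)]
    return z

z_seed = [2.0, -1.0, 0.5]   # hypothetical encoding of a seed molecule
peak = [0.0, 1.0, 0.0]      # hypothetical location of the best predicted score
z_opt = gradient_ascent(z_seed, peak)
```

<p>In the real pipeline the gradient would come from the fitted GP rather than a closed-form bowl, and the optimized vector would still need to be decoded and checked for validity.</p>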
<h2 id="experiments-on-zinc-and-qm9-datasets">Experiments on ZINC and QM9 Datasets</h2>
<p>Two autoencoder systems were trained:</p>
<ul>
<li><strong>ZINC</strong>: 250,000 drug-like molecules from the ZINC database, with a 196-dimensional latent space. Properties predicted: logP, QED, SAS.</li>
<li><strong>QM9</strong>: 108,000 molecules with fewer than 9 heavy atoms, with a 156-dimensional latent space. Properties predicted: HOMO energy, LUMO energy, electronic spatial extent ($\langle R^2 \rangle$).</li>
</ul>
<h3 id="latent-space-quality">Latent Space Quality</h3>
<p>The encoded latent dimensions follow approximately normal distributions as enforced by the variational regularizer. Decoding is stochastic: sampling the same latent point multiple times yields different SMILES strings, with the most frequent decoding tending to be closest to the original point in latent space. Decoding validity rates are 73-79% for points near known molecules but only 4% for randomly selected latent points.</p>
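<p>The idea of decoding one latent point many times and keeping the most frequent string can be illustrated with a simplified sketch. Here each position has a fixed, independent character distribution, which is an assumption for brevity: in the actual GRU decoder the distribution at each step depends on the characters already emitted. All names and probabilities are hypothetical.</p>

```python
import random
from collections import Counter

def decode_once(char_probs, rng):
    """Sample one string, drawing one character per position from its distribution."""
    chars = []
    for dist in char_probs:
        symbols = list(dist.keys())
        weights = list(dist.values())
        chars.append(rng.choices(symbols, weights=weights, k=1)[0])
    return "".join(chars)

def most_frequent_decoding(char_probs, n_samples=200, seed=0):
    """Decode the same latent point repeatedly and return the modal string and all counts."""
    rng = random.Random(seed)
    counts = Counter(decode_once(char_probs, rng) for _ in range(n_samples))
    best, _ = counts.most_common(1)[0]
    return best, counts

# Hypothetical per-position output distributions for a single latent point.
char_probs = [{"C": 0.9, "N": 0.1}, {"C": 1.0}, {"O": 0.8, "N": 0.2}]
best, counts = most_frequent_decoding(char_probs)
```

<p>The spread of <code>counts</code> gives a rough sense of how ambiguous a latent point is: a point near a training molecule would be expected to concentrate most of its mass on one string.</p>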
<p>Spherical interpolation (slerp) between molecules in latent space produces smooth structural transitions, accounting for the geometry of high-dimensional Gaussian distributions where linear interpolation would pass through low-probability regions.</p>
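<p>A minimal sketch of spherical interpolation, using the standard slerp formula, is given below. The fallback to linear interpolation for nearly parallel or antiparallel vectors is a design choice of this sketch, not something described in the paper.</p>

```python
import math

def slerp(a, b, t):
    """Spherical interpolation between vectors a and b for t in [0, 1]."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    cos_omega = max(-1.0, min(1.0, dot / (norm_a * norm_b)))
    omega = math.acos(cos_omega)          # angle between the two vectors
    sin_omega = math.sin(omega)
    if abs(sin_omega) < 1e-8:
        # Degenerate angle: fall back to plain linear interpolation.
        return [(1 - t) * x + t * y for x, y in zip(a, b)]
    w_a = math.sin((1 - t) * omega) / sin_omega
    w_b = math.sin(t * omega) / sin_omega
    return [w_a * x + w_b * y for x, y in zip(a, b)]

a = [1.0, 0.0]
b = [0.0, 1.0]
mid = slerp(a, b, 0.5)                      # stays on the unit circle
lin = [0.5 * x + 0.5 * y for x, y in zip(a, b)]  # cuts through the interior
```

<p>For two unit vectors at a right angle, the slerp midpoint keeps unit norm while the linear midpoint shrinks toward the origin, which is the low-probability region the text refers to for high-dimensional Gaussians.</p>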
<h3 id="molecular-generation-comparison">Molecular Generation Comparison</h3>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>Dataset</th>
          <th>Samples</th>
          <th>logP</th>
          <th>SAS</th>
          <th>QED</th>
          <th>% in ZINC</th>
          <th>% in eMolecules</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Data</td>
          <td>ZINC</td>
          <td>249k</td>
          <td>2.46 (1.43)</td>
          <td>3.05 (0.83)</td>
          <td>0.73 (0.14)</td>
          <td>100</td>
          <td>12.9</td>
      </tr>
      <tr>
          <td>GA</td>
          <td>ZINC</td>
          <td>5303</td>
          <td>2.84 (1.86)</td>
          <td>3.80 (1.01)</td>
          <td>0.57 (0.20)</td>
          <td>6.5</td>
          <td>4.8</td>
      </tr>
      <tr>
          <td>VAE</td>
          <td>ZINC</td>
          <td>8728</td>
          <td>2.67 (1.46)</td>
          <td>3.18 (0.86)</td>
          <td>0.70 (0.14)</td>
          <td>5.8</td>
          <td>7.0</td>
      </tr>
      <tr>
          <td>Data</td>
          <td>QM9</td>
          <td>134k</td>
          <td>0.30 (1.00)</td>
          <td>4.25 (0.94)</td>
          <td>0.48 (0.07)</td>
          <td>0.0</td>
          <td>8.6</td>
      </tr>
      <tr>
          <td>GA</td>
          <td>QM9</td>
          <td>5470</td>
          <td>0.96 (1.53)</td>
          <td>4.47 (1.01)</td>
          <td>0.53 (0.13)</td>
          <td>0.018</td>
          <td>3.8</td>
      </tr>
      <tr>
          <td>VAE</td>
          <td>QM9</td>
          <td>2839</td>
          <td>0.30 (0.97)</td>
          <td>4.34 (0.98)</td>
          <td>0.47 (0.08)</td>
          <td>0.0</td>
          <td>8.9</td>
      </tr>
  </tbody>
</table>
<p>The VAE generates molecules whose property distributions closely match the training data. The genetic algorithm baseline, by contrast, drifts toward higher chemical complexity and lower drug-likeness (on ZINC, mean SAS rises from 3.05 to 3.80 and mean QED falls from 0.73 to 0.57). Only 5.8% of VAE-generated ZINC molecules were found in the original ZINC database, indicating genuine novelty.</p>
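<p>The &ldquo;% in ZINC&rdquo; style columns amount to a set-overlap calculation. The sketch below assumes molecules are compared as canonical SMILES strings and that duplicates are removed before counting; both are assumptions, and the function name is hypothetical.</p>

```python
def percent_in_reference(generated, reference):
    """Percentage of unique generated strings that also occur in `reference`."""
    unique = set(generated)
    if not unique:
        return 0.0
    return 100.0 * len(unique & set(reference)) / len(unique)

# Toy strings only; real use would compare canonical SMILES from a toolkit.
reference = {"CCO", "CCN", "c1ccccc1"}
generated = ["CCO", "CCC", "CCCl", "CCO", "OCC=O"]
overlap = percent_in_reference(generated, reference)
```

<p>A low overlap with the training database is what the notes read as novelty; comparing against a second database such as eMolecules is the same calculation with a different reference set.</p>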
<h3 id="property-prediction">Property Prediction</h3>
<table>
  <thead>
      <tr>
          <th>Dataset/Property</th>
          <th>Mean Baseline</th>
          <th>ECFP</th>
          <th>Graph Conv.</th>
          <th>1-hot SMILES</th>
          <th>Encoder Only</th>
          <th>VAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ZINC/logP</td>
          <td>1.14</td>
          <td>0.38</td>
          <td>0.05</td>
          <td>0.16</td>
          <td>0.13</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>ZINC/QED</td>
          <td>0.112</td>
          <td>0.045</td>
          <td>0.017</td>
          <td>0.041</td>
          <td>0.037</td>
          <td>0.054</td>
      </tr>
      <tr>
          <td>QM9/HOMO (eV)</td>
          <td>0.44</td>
          <td>0.20</td>
          <td>0.12</td>
          <td>0.12</td>
          <td>0.13</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>QM9/LUMO (eV)</td>
          <td>1.05</td>
          <td>0.20</td>
          <td>0.15</td>
          <td>0.11</td>
          <td>0.14</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>QM9/Gap (eV)</td>
          <td>1.07</td>
          <td>0.30</td>
          <td>0.18</td>
          <td>0.16</td>
          <td>0.18</td>
          <td>0.21</td>
      </tr>
  </tbody>
</table>
<p>Entries are prediction errors, so lower is better. The VAE latent representation achieves property prediction accuracy comparable to graph convolutions for some properties, though graph convolutions generally perform best. The primary purpose of joint training is not to maximize prediction accuracy but to organize the latent space for optimization.</p>
<h3 id="optimization-results">Optimization Results</h3>
<p>Bayesian optimization with a GP model on the jointly trained latent space consistently produces molecules with higher percentile scores on the $5 \times \text{QED} - \text{SAS}$ objective compared to both random Gaussian search and genetic algorithm baselines. Starting from molecules in the bottom 10th percentile of the ZINC dataset, the optimizer reliably discovers molecules in regions of high objective value. Training the GP with 1000 molecules (vs. 2000) produces a wider diversity of solutions by optimizing to multiple local optima rather than a single global optimum.</p>
<h2 id="key-findings-limitations-and-legacy">Key Findings, Limitations, and Legacy</h2>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li>A continuous latent representation of molecules enables gradient-based search through chemical space, a qualitatively different approach from discrete enumeration or genetic algorithms.</li>
<li>Joint training with property prediction organizes the latent space by property values, creating smooth gradients that optimization can follow.</li>
<li>The VAE generates novel molecules with realistic property distributions, and the latent space encodes an estimated 7.5 million molecules despite training on only 250,000.</li>
</ul>
<h3 id="acknowledged-limitations">Acknowledged Limitations</h3>
<ul>
<li>The SMILES-based decoder sometimes produces formally valid but chemically undesirable molecules (acid chlorides, anhydrides, cyclopentadienes, aziridines, etc.) because the grammar of valid SMILES does not capture all synthetic or stability constraints.</li>
<li>Character-level SMILES generation is fragile: the decoder must implicitly learn which strings are valid SMILES, making the learning problem harder than necessary.</li>
<li>Decoding validity drops to only 4% for random latent points far from training data, limiting the ability to explore truly novel regions of chemical space.</li>
</ul>
<h3 id="directions-identified">Directions Identified</h3>
<p>The authors point to several extensions that were already underway at the time of publication:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a></strong>: Using an explicitly defined SMILES grammar instead of forcing the model to learn one (Kusner et al., 2017).</li>
<li><strong>Graph-based decoders</strong>: Directly outputting molecular graphs to avoid the SMILES validity problem.</li>
<li><strong>Adversarial training</strong>: Using GANs for molecular generation (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN, ORGANIC</a>).</li>
<li><strong>LSTM/RNN generators</strong>: Applying recurrent networks directly to SMILES for generation and reaction prediction.</li>
</ul>
<p>This paper has been cited over 2,900 times and launched a large body of follow-up work in VAE-based, GAN-based, and reinforcement learning-based molecular generation.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ZINC (drug-like subset)</td>
          <td>250,000 molecules</td>
          <td>Randomly sampled from ZINC database</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>QM9</td>
          <td>108,000 molecules</td>
          <td>Molecules with fewer than 9 heavy atoms</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ZINC held-out set</td>
          <td>5,000 molecules</td>
          <td>For latent space analysis</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Encoder</strong>: 3 x 1D convolutional layers (ZINC: filters 9,9,10 with kernels 9,9,11; QM9: filters 2,2,1 with kernels 5,5,4), followed by a fully connected layer</li>
<li><strong>Decoder</strong>: 3 x GRU layers (ZINC: hidden dim 488; QM9: hidden dim 500), trained with teacher forcing</li>
<li><strong>Property Predictor</strong>: 2 fully connected layers of 1000 neurons (dropout 0.20) for prediction; smaller 3-layer MLP of 67 neurons (dropout 0.15) for latent space shaping</li>
<li><strong>Variational loss annealing</strong>: Sigmoid schedule after 29 epochs, total 120 epochs</li>
<li><strong>SMILES validation</strong>: Post-hoc filtering with RDKit; invalid outputs discarded</li>
<li><strong>Optimization</strong>: Gaussian process surrogate model trained on 2000 maximally diverse molecules from latent space</li>
</ul>
<h3 id="models">Models</h3>
<p>Built with Keras and TensorFlow. Latent dimensions: 196 (ZINC), 156 (QM9). SMILES alphabet: 35 characters (ZINC), 22 characters (QM9). Maximum string length: 120 (ZINC), 34 (QM9). Only canonicalized SMILES used for training.</p>
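<p>To make the input format concrete, the sketch below one-hot encodes a SMILES string at the character level and pads it to a fixed length. The space character as padding symbol, the tiny alphabet, and the function name are assumptions for illustration; only the alphabet sizes and maximum lengths above come from the paper.</p>

```python
def one_hot_smiles(smiles, alphabet, max_len, pad=" "):
    """Return a max_len x len(alphabet) list of 0/1 rows for one SMILES string."""
    if len(smiles) > max_len:
        raise ValueError("SMILES longer than max_len")
    index = {ch: i for i, ch in enumerate(alphabet)}
    padded = smiles.ljust(max_len, pad)
    matrix = []
    for ch in padded:
        row = [0] * len(alphabet)
        pos = index[ch]          # KeyError means the character is outside the alphabet
        row[pos] = 1
        matrix.append(row)
    return matrix

# Tiny hypothetical alphabet; the ZINC model uses 35 characters and length 120.
alphabet = [" ", "C", "N", "O", "l", "(", ")", "=", "1", "c"]
encoded = one_hot_smiles("CC(=O)Cl", alphabet, max_len=12)
```

<p>Note that in this character-level sketch a two-letter element such as Cl occupies two positions, which is part of why the decoder has to learn SMILES syntax implicitly.</p>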
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>logP</td>
          <td>Water-octanol partition coefficient</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>Quantitative Estimation of Drug-likeness (0-1)</td>
      </tr>
      <tr>
          <td>SAS</td>
          <td>Synthetic Accessibility Score</td>
      </tr>
      <tr>
          <td>HOMO/LUMO (eV)</td>
          <td>Frontier orbital energies (QM9)</td>
      </tr>
      <tr>
          <td>Decoding validity</td>
          <td>Fraction of latent points producing valid SMILES</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>Fraction of generated molecules not in training set</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on the Harvard FAS Odyssey Cluster. Specific GPU types and training times are not reported. The Gaussian process optimization requires only minutes to train on a few thousand molecules.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/chemical_vae">chemical_vae</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation with training scripts and pre-trained models</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Gómez-Bombarelli, R., Wei, J. N., Duvenaud, D., Hernández-Lobato, J. M., Sánchez-Lengeling, B., Sheberla, D., Aguilera-Iparraguirre, J., Hirzel, T. D., Adams, R. P., &amp; Aspuru-Guzik, A. (2018). Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules. <em>ACS Central Science</em>, 4(2), 268-276. <a href="https://doi.org/10.1021/acscentsci.7b00572">https://doi.org/10.1021/acscentsci.7b00572</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{gomez2018automatic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{G{\&#39;o}mez-Bombarelli, Rafael and Wei, Jennifer N. and Duvenaud, David and Hern{\&#39;a}ndez-Lobato, Jos{\&#39;e} Miguel and S{\&#39;a}nchez-Lengeling, Benjam{\&#39;i}n and Sheberla, Dennis and Aguilera-Iparraguirre, Jorge and Hirzel, Timothy D. and Adams, Ryan P. and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ACS Central Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{268--276}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acscentsci.7b00572}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformers for Molecular Property Prediction Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformers-molecular-property-prediction-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformers-molecular-property-prediction-review/</guid><description>A systematic review of 16 transformer models for molecular property prediction, analyzing architecture, data, tokenization, and benchmarking gaps.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformers-for-molecular-property-prediction">A Systematization of Transformers for Molecular Property Prediction</h2>
<p>This is a <strong>Systematization</strong> paper. Sultan et al. provide the first comprehensive, structured review of sequence-based transformer models applied to molecular property prediction (MPP). The review catalogs 16 models published between 2019 and 2023, organizes them by architecture type (encoder-decoder, encoder-only, decoder-only), and systematically examines seven key design decisions that arise when building a transformer for MPP. The paper&rsquo;s primary contribution is identifying gaps in current evaluation practices and articulating what standardization the field needs for meaningful progress.</p>
<h2 id="the-problem-inconsistent-evaluation-hinders-progress">The Problem: Inconsistent Evaluation Hinders Progress</h2>
<p>Molecular property prediction is essential for drug discovery, crop protection, and environmental science. Deep learning approaches, including transformers, have been increasingly applied to this task by learning molecular representations from string notations like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>. However, the field faces several challenges:</p>
<ol>
<li><strong>Small labeled datasets</strong>: Labeled molecular property datasets typically contain only hundreds or thousands of molecules, making supervised learning alone insufficient.</li>
<li><strong>No standardized evaluation protocol</strong>: Different papers use different data splits (scaffold vs. random), different splitting implementations, different numbers of repetitions (3 to 50), and sometimes do not share their test sets. This makes direct comparison across models infeasible.</li>
<li><strong>Unclear design choices</strong>: With many possible configurations for pre-training data, chemical language, tokenization, positional embeddings, model size, pre-training objectives, and fine-tuning approaches, the field lacks systematic analyses to guide practitioners.</li>
</ol>
<p>The authors note that standard machine learning methods with fixed-size molecular fingerprints remain strong baselines for real-world datasets, illustrating that the promise of transformers for MPP has not yet been fully realized.</p>
<h2 id="seven-design-questions-for-molecular-transformers">Seven Design Questions for Molecular Transformers</h2>
<p>The central organizing framework of this review addresses seven questions practitioners must answer when building a transformer for MPP. For each, the authors synthesize findings across the 16 reviewed models.</p>
<h3 id="reviewed-models">Reviewed Models</h3>
<p>The paper catalogs 16 models organized by architecture:</p>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Base Model</th>
          <th>Models</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Encoder-Decoder</td>
          <td>Transformer, BART</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-transformer/">ST</a>, Transformer-CNN, <a href="/notes/chemistry/molecular-representations/encoders/x-mol-pretraining-molecular-understanding/">X-Mol</a>, <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a></td>
      </tr>
      <tr>
          <td>Encoder-Only</td>
          <td>BERT</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, MAT, <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, Mol-BERT, Chen et al., K-BERT, FP-BERT, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
      </tr>
      <tr>
          <td>Encoder-Only</td>
          <td>RoBERTa</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a></td>
      </tr>
      <tr>
          <td>Decoder-Only</td>
          <td>XLNet</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">Regression Transformer</a> (RT)</td>
      </tr>
  </tbody>
</table>
<p>The core attention mechanism shared by all these models is the scaled dot-product attention:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$</p>
<p>where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_{k}$ is the dimension of the key vectors.</p>
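<p>As a rough illustration of the formula above, the following pure-Python sketch computes scaled dot-product attention on plain lists of row vectors. The function names <code>softmax</code> and <code>attention</code> are chosen here for illustration; none of the reviewed models is implemented this way, and real implementations add batching, multiple heads, masking, and learned projections.</p>

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [value / total for value in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention for lists of row vectors."""
    d_k = len(keys[0])
    outputs = []
    for query in queries:
        # One score per key: dot(query, key) / sqrt(d_k).
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
            for key in keys
        ]
        weights = softmax(scores)
        # Each output row is the attention-weighted average of the value rows.
        outputs.append([
            sum(w * value[col] for w, value in zip(weights, values))
            for col in range(len(values[0]))
        ])
    return outputs
```

<p>When all keys are identical, the softmax weights are uniform and each output row reduces to the plain average of the value rows, which is a quick sanity check for the sketch.</p>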
<h3 id="question-1-which-database-and-how-many-molecules">Question 1: Which Database and How Many Molecules?</h3>
<p>Pre-training data sources vary considerably. The three main databases are ZINC (37 billion molecules in ZINC22), ChEMBL (2.4 million unique molecules with 20 million bioactivity measurements), and PubChem (111 million unique molecules). Pre-training set sizes ranged from 900K (ST on ChEMBL) to 1.1B molecules (MolFormer on ZINC + PubChem).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Database</th>
          <th>Size</th>
          <th>Language</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ST</td>
          <td>ChEMBL</td>
          <td>900K</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>ChEMBL (<a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>)</td>
          <td>1.6M</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></td>
          <td>PubChem</td>
          <td>100K-10M</td>
          <td>SMILES, SELFIES</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a></td>
          <td>PubChem</td>
          <td>5M-77M</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>ZINC</td>
          <td>2M</td>
          <td>List of atoms</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>ZINC + PubChem</td>
          <td>1.1B</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td>Chen et al.</td>
          <td>C, CP, CPZ</td>
          <td>2M-775M</td>
          <td>SMILES</td>
      </tr>
  </tbody>
</table>
<p>A key finding is that larger pre-training datasets do not consistently improve downstream performance. MolFormer showed minimal difference between models trained on 100M vs. 1.1B molecules. ChemBERTa-2 found that its MLM model trained on 5M molecules performed comparably to the one trained on 77M molecules on BBBP (both around 0.70 ROC-AUC). Chen et al. reported comparable $R^{2}$ values of $0.925 \pm 0.01$, $0.917 \pm 0.012$, and $0.915 \pm 0.01$ for ESOL across datasets of 2M, 103M, and 775M molecules, respectively. The data composition and covered chemical space appear to matter more than raw size.</p>
<h3 id="question-2-which-chemical-language">Question 2: Which Chemical Language?</h3>
<p>Most models use SMILES. ChemBERTa, RT, and SELFormer also explored SELFIES. MAT uses a simple list of atoms with structural features, while Mol-BERT and FP-BERT use circular fingerprints.</p>
<p>Direct comparisons between SMILES and SELFIES (by ChemBERTa on Tox21 SR-p53 and RT for drug-likeness prediction) found no significant performance difference. The RT authors reported that SELFIES models performed approximately $0.004 \pm 0.01$ better on RMSE, while SMILES models performed approximately $0.004 \pm 0.01$ better on Pearson correlation. The choice of chemical language does not appear to be a major factor in prediction performance, and even non-string representations (atom lists in MAT, fingerprints in Mol-BERT) perform competitively.</p>
<h3 id="question-3-how-to-tokenize">Question 3: How to Tokenize?</h3>
<p>Tokenization methods span atom-level (42-66 vocabulary tokens), regex-based (47-2,362 tokens), BPE (509-52K tokens), and substructure-based (3,357-13,325 tokens) approaches. No systematic comparison of tokenization strategies exists in the literature. The vocabulary size varied dramatically, from 42 tokens for MolBERT to over 52K for ChemBERTa. The authors argue that chemically meaningful tokenization (e.g., functional group-based fragmentation) could improve both performance and explainability.</p>
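<p>To make the difference between the simplest strategies concrete, the sketch below contrasts character-level splitting with a regex-based tokenizer. The pattern is a deliberately simplified assumption for illustration (bracket atoms, the two-letter halogens Cl and Br, then any single character); it is not the exact pattern used by any reviewed model, and the function names are hypothetical.</p>

```python
import re

# Simplified pattern (an assumption for illustration): bracket atoms,
# two-letter halogens, then any other single character.
REGEX_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|.")


def character_tokens(smiles):
    """Character-level split: every character is its own token."""
    return list(smiles)


def regex_tokens(smiles):
    """Regex split that keeps bracket atoms and Cl/Br together."""
    return REGEX_PATTERN.findall(smiles)


example = "[NH4+].ClCCBr"
print(character_tokens(example))
print(regex_tokens(example))
```

<p>The regex variant keeps <code>[NH4+]</code>, <code>Cl</code>, and <code>Br</code> as single tokens, while the character-level variant splits them apart. This is one reason vocabulary sizes differ so much between tokenization schemes.</p>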
<h3 id="question-4-how-to-add-positional-embeddings">Question 4: How to Add Positional Embeddings?</h3>
<p>Most models inherited the absolute positional embedding from their NLP base models. MolBERT and RT adopted relative positional embeddings. MolFormer combined absolute and Rotary Positional Embedding (RoPE). MAT incorporated spatial information (inter-atomic 3D distances and adjacency) alongside self-attention.</p>
<p>MolFormer&rsquo;s comparison showed that RoPE became superior to absolute embeddings only when the pre-training dataset was very large. The performance difference (MAE on QM9) between absolute and RoPE embeddings for models trained on 111K, 111M, and 1.1B molecules was approximately $-0.20 \pm 0.18$, $-0.44 \pm 0.22$, and $0.27 \pm 0.12$, respectively.</p>
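<p>For readers unfamiliar with rotary embeddings, the sketch below shows the generic RoPE construction: consecutive pairs of a query or key vector are rotated by an angle proportional to the token position. It assumes an even vector length and the conventional base of 10000. The helper names are illustrative, and MolFormer&rsquo;s actual implementation is not reproduced here.</p>

```python
import math


def rotary_embed(vector, position, base=10000.0):
    """Rotate consecutive pairs of `vector` by position-dependent angles."""
    dim = len(vector)  # assumed to be even
    rotated = []
    for pair in range(dim // 2):
        # Pair i uses the frequency base^(-2i/dim).
        angle = position * base ** (-2.0 * pair / dim)
        x, y = vector[2 * pair], vector[2 * pair + 1]
        rotated.append(x * math.cos(angle) - y * math.sin(angle))
        rotated.append(x * math.sin(angle) + y * math.cos(angle))
    return rotated


def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))
```

<p>Because each pair is rotated, the dot product between a rotated query at position $m$ and a rotated key at position $n$ depends only on the offset $n - m$. Shifting both positions by the same amount leaves the attention score unchanged, which is how RoPE encodes relative position.</p>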
<p>The authors highlight that SMILES and SELFIES are linearizations of a 2D molecular graph, so consecutive tokens in a sequence are not necessarily spatially close. Positional embeddings that reflect 2D or 3D molecular structure remain underexplored.</p>
<h3 id="question-5-how-many-parameters">Question 5: How Many Parameters?</h3>
<p>Model sizes range from approximately 7M (ST, Mol-BERT) to over 100M parameters (MAT). Most chemical language models operate with 100M parameters or fewer, much smaller than NLP models like BERT (110M-330M) or GPT-3 (175B).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Dimensions</th>
          <th>Heads</th>
          <th>Layers</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ST</td>
          <td>256</td>
          <td>4</td>
          <td>4</td>
          <td>7M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>768</td>
          <td>12</td>
          <td>12</td>
          <td>85M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>768</td>
          <td>12</td>
          <td>6, 12</td>
          <td>43M, 85M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a></td>
          <td>768</td>
          <td>12, 4</td>
          <td>8, 12</td>
          <td>57M, 85M</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>1024</td>
          <td>16</td>
          <td>8</td>
          <td>101M</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></td>
          <td>768</td>
          <td>12</td>
          <td>6</td>
          <td>43M</td>
      </tr>
  </tbody>
</table>
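<p>A common rule of thumb relates the width and depth columns above to the parameter column. The sketch below counts only the attention and feed-forward weight matrices per layer and assumes a feed-forward hidden size of four times the model dimension. Embeddings, biases, layer norms, and any decoder stack are ignored, so this is an approximation and not how the review or the original papers computed their counts.</p>

```python
def estimate_block_parameters(d_model, num_layers, ffn_multiplier=4):
    """Rough weight count for a stack of Transformer blocks.

    Per layer: 4 * d^2 for the Q, K, V, and output projections, plus
    2 * ffn_multiplier * d^2 for the two feed-forward matrices.
    Embeddings, biases, and layer norms are ignored.
    """
    attention_weights = 4 * d_model * d_model
    feed_forward_weights = 2 * ffn_multiplier * d_model * d_model
    return num_layers * (attention_weights + feed_forward_weights)


for name, d_model, num_layers in [
    ("768-dim, 12 layers", 768, 12),
    ("768-dim, 6 layers", 768, 6),
    ("1024-dim, 8 layers", 1024, 8),
]:
    millions = estimate_block_parameters(d_model, num_layers) / 1e6
    print(f"{name}: about {millions:.1f}M")
```

<p>Under these assumptions the estimates come out at about 84.9M, 42.5M, and 100.7M, close to the 85M, 43M, and 101M entries in the table. The rule does not fit every row: for ST (256 dimensions, 4 layers) it gives about 3.1M against the listed 7M, plausibly because the encoder-decoder structure and embeddings are not counted.</p>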
<p>SELFormer and MolFormer both tested different model sizes. SELFormer&rsquo;s larger model (approximately 86M parameters) achieved a ROC-AUC approximately 0.034 higher on BBBP than the smaller model. MolFormer&rsquo;s larger model (approximately 87M parameters) achieved a ROC-AUC approximately 0.04 higher on average across BBBP, HIV, BACE, and SIDER. The field lacks the systematic scaling analyses (analogous to Kaplan et al. and Hoffmann et al. in NLP) needed to establish proper scaling laws for chemical language models.</p>
<h3 id="question-6-which-pre-training-objectives">Question 6: Which Pre-training Objectives?</h3>
<p>Pre-training objectives fall into domain-agnostic and domain-specific categories:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Pre-training Objective</th>
          <th>Fine-tuning</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
          <td>MLM</td>
          <td>Frozen, Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a></td>
          <td>MLM</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a></td>
          <td>MLM, PhysChemPred, SMILES-EQ</td>
          <td>Frozen, Update</td>
      </tr>
      <tr>
          <td>K-BERT</td>
          <td>Atom feature, MACCS prediction, CL</td>
          <td>Update last layer</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a></td>
          <td>MLM, MTR</td>
          <td>Update</td>
      </tr>
      <tr>
          <td>MAT</td>
          <td>MLM, 2D Adjacency, 3D Distance</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a></td>
          <td>Denoising Span MLM, Augmentation</td>
          <td>Update</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/regression-transformer/">RT</a></td>
          <td>PLM (Permutation Language Modeling)</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Domain-specific objectives (predicting physico-chemical properties, atom features, or MACCS keys) showed promising but inconsistent results. MolBERT&rsquo;s PhysChemPred objective alone performed comparably to the full three-objective model (approximately $0.72 \pm 0.06$ vs. $0.71 \pm 0.06$ ROC-AUC in virtual screening). The SMILES-EQ objective (identifying equivalent SMILES) was found to lower performance when combined with other objectives. K-BERT&rsquo;s contrastive learning objective did not significantly change performance (average ROC-AUC of 0.806 with CL vs. 0.807 without).</p>
<p>ChemBERTa-2&rsquo;s Multi-Task Regression (MTR) objective performed noticeably better than MLM alone on nearly all of the four classification tasks, across pre-training dataset sizes.</p>
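<p>Masked language modeling (MLM), the most common objective in the table above, can be sketched in a few lines. The 15% masking rate and the <code>[MASK]</code> token follow the BERT convention and are assumptions here; the reviewed models may use different rates or BERT&rsquo;s additional random-replacement scheme, which is omitted. The function name and the seed parameter are illustrative.</p>

```python
import random


def mask_tokens(tokens, mask_rate=0.15, mask_token="[MASK]", seed=0):
    """Return (masked_tokens, labels) for masked language modeling.

    labels[i] holds the original token at masked positions and None
    elsewhere, so the loss is computed only where tokens were hidden.
    Assumes a non-empty token list.
    """
    rng = random.Random(seed)
    num_to_mask = max(1, round(mask_rate * len(tokens)))
    masked_positions = set(rng.sample(range(len(tokens)), num_to_mask))
    masked = [
        mask_token if index in masked_positions else token
        for index, token in enumerate(tokens)
    ]
    labels = [
        token if index in masked_positions else None
        for index, token in enumerate(tokens)
    ]
    return masked, labels


masked, labels = mask_tokens(list("CC(=O)Oc1ccccc1C(=O)O"))
print(masked)
```

<p>The model is then trained to predict the entries of <code>labels</code> that are not <code>None</code> from the masked sequence. Domain-specific objectives such as PhysChemPred or MTR replace or complement this target with computed molecular properties.</p>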
<h3 id="question-7-how-to-fine-tune">Question 7: How to Fine-tune?</h3>
<p>Fine-tuning through weight updates generally outperforms frozen representations. SELFormer showed this most dramatically, with a difference of 2.187 RMSE between frozen and updated models on FreeSolv. MolBERT showed a much smaller difference (0.575 RMSE on FreeSolv), likely because its domain-specific pre-training objectives already produced representations closer to the downstream tasks.</p>
<h2 id="benchmarking-challenges-and-performance-comparison">Benchmarking Challenges and Performance Comparison</h2>
<h3 id="downstream-datasets">Downstream Datasets</h3>
<p>The review focuses on nine benchmark datasets across three categories from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Tasks</th>
          <th>Type</th>
          <th>Application</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>1,128</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>Hydration free energy</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td>1 regression</td>
          <td>Physical chemistry</td>
          <td>LogD at pH 7.4</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>2,050</td>
          <td>1 classification</td>
          <td>Physiology</td>
          <td>Blood-brain barrier</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>1,484</td>
          <td>2 classification</td>
          <td>Physiology</td>
          <td>Clinical trial toxicity</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>1,427</td>
          <td>27 classification</td>
          <td>Physiology</td>
          <td>Drug side effects</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>7,831</td>
          <td>12 classification</td>
          <td>Physiology</td>
          <td>Nuclear receptor/stress pathways</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>1,513</td>
          <td>1 classification</td>
          <td>Biophysics</td>
          <td>Beta-secretase 1 binding</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>41,127</td>
          <td>1 classification</td>
          <td>Biophysics</td>
          <td>Anti-HIV activity</td>
      </tr>
  </tbody>
</table>
<h3 id="inconsistencies-in-evaluation">Inconsistencies in Evaluation</h3>
<p>The authors document substantial inconsistencies that prevent fair model comparison:</p>
<ol>
<li><strong>Data splitting</strong>: Models used different splitting methods (scaffold vs. random) and different implementations even when using the same method. Not all models adhered to scaffold splitting for classification tasks as recommended.</li>
<li><strong>Different test sets</strong>: Even models using the same split type may not evaluate on identical test molecules due to different random seeds.</li>
<li><strong>Varying repetitions</strong>: Repetitions ranged from 3 (RT) to 50 (Chen et al.), making some analyses more statistically robust than others.</li>
<li><strong>Metric inconsistency</strong>: Most papers use ROC-AUC for classification and RMSE for regression, but some report only averages without standard deviations, while others report standard errors.</li>
</ol>
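<p>The splitting issue is easier to see with a concrete procedure. The sketch below shows one deterministic way to do a scaffold split: molecules are grouped by scaffold and whole groups are assigned to train or test, so no scaffold appears in both. It assumes the scaffold strings are already computed (for example Bemis-Murcko scaffolds from a cheminformatics toolkit). The function name, the 20% test fraction, and the tie-breaking rule are illustrative choices, and different implementations make different choices here, which is exactly why results diverge.</p>

```python
from collections import defaultdict


def scaffold_split(scaffolds, test_fraction=0.2):
    """Split molecule indices so that no scaffold spans both sets.

    `scaffolds` is a list of scaffold strings, one per molecule, assumed
    to be precomputed.
    """
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)

    # Largest groups first; ties broken by scaffold string so the split is
    # deterministic and does not depend on a random seed.
    ordered = sorted(groups.items(), key=lambda item: (-len(item[1]), item[0]))

    train_budget = len(scaffolds) - round(test_fraction * len(scaffolds))
    train, test = [], []
    for _, members in ordered:
        # Whole groups go to train while they fit, otherwise to test.
        if len(train) + len(members) <= train_budget:
            train.extend(members)
        else:
            test.extend(members)
    return train, test
```

<p>A random split, by contrast, shuffles individual molecules with a seed, so two papers using &ldquo;the same&rdquo; split type can still evaluate on different test molecules unless the test indices themselves are published.</p>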
<h3 id="performance-findings">Performance Findings</h3>
<p>When comparing only models evaluated on the same test sets (Figure 2 in the paper), the authors observe that transformer models show comparable, but not consistently superior, performance to existing ML and DL models. The performance varies considerably across models and datasets.</p>
<p>For BBBP, Mol-BERT reported a lower ROC-AUC than its corresponding MPNN (approximately 0.88 vs. 0.91), while MolBERT outperformed both its corresponding CDDD model (approximately 0.86 vs. 0.76 ROC-AUC) and its SVM baseline (approximately 0.86 vs. 0.70 ROC-AUC). A similar mixed pattern appeared for HIV: ChemBERTa performed worse than its corresponding ML models, while MolBERT performed better than its ML (approximately 0.08 higher ROC-AUC) and DL (approximately 0.03 higher ROC-AUC) baselines. For SIDER, Mol-BERT achieved a ROC-AUC approximately 0.1 higher than its corresponding MPNN. For regression, MAT and MolBERT improved on their ML and DL baselines on ESOL, FreeSolv, and Lipophilicity. For example, on ESOL, MAT achieved an RMSE approximately 0.2 lower than an SVM model and approximately 0.03 lower than the Weave model.</p>
<h2 id="key-takeaways-and-future-directions">Key Takeaways and Future Directions</h2>
<p>The review concludes with six main takeaways:</p>
<ol>
<li><strong>Performance</strong>: Transformers using SMILES show comparable but not consistently superior performance to existing ML and DL models for MPP.</li>
<li><strong>Scaling</strong>: No systematic analysis of model parameter scaling relative to data size exists for chemical language models. Such analysis is essential.</li>
<li><strong>Pre-training data</strong>: Dataset size is not the sole determinant of downstream performance. Composition and chemical space coverage matter.</li>
<li><strong>Chemical language</strong>: SMILES and SELFIES perform similarly. Alternative representations (atom lists, fingerprints) also work when the architecture is adjusted.</li>
<li><strong>Domain knowledge</strong>: Domain-specific pre-training objectives show promise, but tokenization and positional encoding remain underexplored.</li>
<li><strong>Benchmarking</strong>: The community needs standardized data splitting, fixed test sets, statistical analysis, and consistent reporting to enable meaningful comparison.</li>
</ol>
<p>The authors also highlight the need for attention visualization and explainability analysis, investigation of NLP-originated techniques (pre-training regimes, fine-tuning strategies like LoRA, explainability methods), and adaptation of these techniques to the specific characteristics of chemical data (smaller vocabularies, shorter sequences).</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a review paper. No new data or models are introduced. All analyses use previously reported results from the 16 reviewed papers, with additional visualization and comparison. The authors provide a GitHub repository with the code and data used to generate their comparative figures.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Not applicable (review paper). The paper describes training strategies at a conceptual level, referencing the original publications for implementation details.</p>
<h3 id="models">Models</h3>
<p>Not applicable (review paper). The paper catalogs 16 models with their architecture details, parameter counts, and training configurations across Tables 1, 4, 5, 6, and 7.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The paper compiles performance across nine MoleculeNet datasets. Key comparison figures (Figures 2 and 7) restrict to models evaluated on the same test sets for fair comparison, using ROC-AUC for classification and RMSE for regression.</p>
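<p>For reference, the two metrics can be written out directly. The sketch below is a plain-Python illustration with hypothetical function names. The ROC-AUC uses the pairwise definition (the probability that a positive example is scored above a negative one), which is quadratic in dataset size and intended only to make the definition explicit. It assumes binary labels coded as 0 and 1 with at least one example of each class.</p>

```python
import math


def rmse(targets, predictions):
    """Root mean squared error for regression tasks."""
    squared = [(t - p) ** 2 for t, p in zip(targets, predictions)]
    return math.sqrt(sum(squared) / len(squared))


def roc_auc(labels, scores):
    """ROC-AUC as the probability that a positive outranks a negative.

    Ties between a positive and a negative score count as half a win.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for pos in positives:
        for neg in negatives:
            if pos > neg:
                wins += 1.0
            elif pos == neg:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
print(rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

<p>In the first call, three of the four positive-negative pairs are ranked correctly, giving a ROC-AUC of 0.75. For multi-task datasets such as Tox21 or SIDER, papers typically average the per-task values, and how missing labels are handled is another place where implementations can differ.</p>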
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/volkamerlab/Transformers4MPP_review">Transformers4MPP_review</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Figure generation code and compiled data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sultan, A., Sieg, J., Mathea, M., &amp; Volkamer, A. (2024). Transformers for Molecular Property Prediction: Lessons Learned from the Past Five Years. <em>Journal of Chemical Information and Modeling</em>, 64(16), 6259-6280. <a href="https://doi.org/10.1021/acs.jcim.4c00747">https://doi.org/10.1021/acs.jcim.4c00747</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{sultan2024transformers,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformers for Molecular Property Prediction: Lessons Learned from the Past Five Years}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Sultan, Afnan and Sieg, Jochen and Mathea, Miriam and Volkamer, Andrea}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6259--6280}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.4c00747}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformer-CNN: SMILES Embeddings for QSAR Modeling</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformer-cnn-qsar-modeling/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/transformer-cnn-qsar-modeling/</guid><description>Transformer-CNN uses SMILES embeddings from a canonicalization Transformer with a CNN head for interpretable QSAR property prediction.</description><content:encoded><![CDATA[<h2 id="transformer-based-smiles-embeddings-for-property-prediction">Transformer-Based SMILES Embeddings for Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that introduces Transformer-CNN, a two-stage architecture for <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> (Quantitative Structure-Activity Relationship) modeling. The primary contribution is a transfer learning approach: a Transformer model is first trained on the task of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> canonicalization (mapping non-canonical SMILES to canonical forms), and the encoder&rsquo;s internal representations are then used as &ldquo;dynamic SMILES embeddings&rdquo; for downstream property prediction via a convolutional neural network (TextCNN). The authors also contribute an interpretability framework based on Layer-wise Relevance Propagation (LRP) that traces predictions back to individual atom contributions.</p>
<h2 id="from-descriptors-to-learned-embeddings-in-qsar">From Descriptors to Learned Embeddings in QSAR</h2>
<p>Traditional QSAR methods rely on hand-engineered molecular descriptors (fragment counts, physicochemical features) coupled with feature selection and classical ML algorithms. While deep learning approaches that operate on raw SMILES strings or molecular graphs have reduced the need for manual feature engineering, they typically require large training datasets to learn effective representations from scratch. QSAR datasets, in contrast, often contain only hundreds of molecules, making it difficult to train end-to-end deep models.</p>
<p>The authors identify two specific gaps. First, existing SMILES-based autoencoders such as <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (Continuous and Data-Driven molecular Descriptors) produce fixed-length latent vectors, discarding positional information that could be useful for property prediction and interpretation. Second, QSAR models built on deep architectures generally lack interpretability, making it hard to verify that predictions rely on chemically meaningful structural features rather than spurious correlations.</p>
<h2 id="dynamic-smiles-embeddings-via-canonicalization-pre-training">Dynamic SMILES Embeddings via Canonicalization Pre-training</h2>
<p>The core insight is that training a Transformer to perform SMILES canonicalization (a Seq2Seq task mapping non-canonical SMILES to canonical SMILES) produces an encoder whose internal states serve as information-rich, position-dependent molecular embeddings.</p>
<h3 id="pre-training-on-smiles-canonicalization">Pre-training on SMILES Canonicalization</h3>
<p>The Transformer encoder-decoder is trained on approximately 17.7 million canonicalization pairs derived from the <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> database (SMILES with length up to 110 characters). Each molecule is augmented 10 times by generating non-canonical SMILES variants, plus one identity pair where both sides are canonical. The training uses character-level tokenization with a 66-symbol vocabulary covering drug-like molecules including stereochemistry, charges, and inorganic ions.</p>
<p>The Transformer architecture follows Vaswani et al. with 3 layers and 10 self-attention heads. The learning rate schedule follows:</p>
<p>$$\lambda = \text{factor} \cdot \frac{\min(1.0, \text{step} / \text{warmup})}{\max(\text{step}, \text{warmup})}$$</p>
<p>where factor = 20, warmup = 16,000 steps, and $\lambda$ is clipped at a minimum of $10^{-4}$. Training runs for 10 epochs (275,907 batches per epoch) without early stopping.</p>
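<p>The schedule is short enough to write out directly. The sketch below implements the formula with the stated factor and warmup, reading &ldquo;clipped at a minimum of $10^{-4}$&rdquo; as a lower bound applied with <code>max</code>. The function and parameter names are illustrative and not taken from the authors&rsquo; code.</p>

```python
def learning_rate(step, factor=20.0, warmup=16000, floor=1e-4):
    """Warmup-then-decay schedule with a lower bound on the rate."""
    raw = factor * min(1.0, step / warmup) / max(step, warmup)
    return max(raw, floor)


# The last step is 10 epochs x 275,907 batches per epoch.
for step in (1, 1280, 16000, 200000, 2759070):
    print(step, learning_rate(step))
```

<p>Under this reading the rate peaks at $20 / 16000 = 1.25 \times 10^{-3}$ when the step count reaches the warmup value, then decays as $20 / \text{step}$ and returns to the $10^{-4}$ floor at step 200,000. With roughly 2.76 million total steps, most of the training would therefore run at the floor value.</p>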
<p>On validation with 500,000 generated ChEMBL-like SMILES, the model correctly canonicalizes 83.6% of all samples. Performance drops for stereochemistry (37.2% for @-containing SMILES) and cis/trans notation (73.9%).</p>
<h3 id="from-encoder-states-to-qsar-predictions">From Encoder States to QSAR Predictions</h3>
<p>After pre-training, the encoder&rsquo;s output for a molecule with $N$ characters is a matrix of dimensions $(N, \text{EMBEDDINGS})$. Unlike fixed-length CDDD descriptors, these &ldquo;dynamic embeddings&rdquo; preserve positional information, meaning equivalent characters receive different embedding values depending on their context and position.</p>
<p>To handle variable-length embeddings, the authors use a TextCNN architecture (from DeepChem) with 1D convolutional filters at kernel sizes (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20) producing (100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160) filters respectively. After GlobalMaxPool and concatenation, the features pass through Dropout (rate = 0.25), a Dense layer ($N = 512$), a Highway layer, and finally an output layer (1 neuron for regression, 2 for classification).</p>
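<p>The reason this head accepts variable-length inputs is that global max pooling collapses the position axis. The sketch below, with illustrative names, sums the filter counts listed above and shows the pooling step on a toy feature map. It does not reproduce the DeepChem implementation, and details such as padding for SMILES shorter than the largest kernel are not covered.</p>

```python
KERNEL_SIZES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20)
NUM_FILTERS = (100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160)


def pooled_feature_size():
    """Length of the concatenated vector after global max pooling."""
    return sum(NUM_FILTERS)


def global_max_pool(feature_map):
    """Reduce a (positions x filters) feature map to one value per filter."""
    return [max(column) for column in zip(*feature_map)]


print(pooled_feature_size())
print(global_max_pool([[1, 5], [3, 2], [0, 4]]))
```

<p>Each kernel size contributes one pooled value per filter, so the concatenated vector has 1,720 entries regardless of the SMILES length $N$, and the following Dense layer can have a fixed input size.</p>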
<p>The Transformer weights are frozen during QSAR training. The Adam optimizer is used with a fixed learning rate of $10^{-4}$ and early stopping on a 10% held-out validation set. Critically, SMILES augmentation ($n = 10$) is applied during both training and inference, with the final prediction being the average over augmented SMILES for each molecule.</p>
<h3 id="interpretability-via-layer-wise-relevance-propagation">Interpretability via Layer-wise Relevance Propagation</h3>
<p>The LRP algorithm propagates relevance scores from the output back through the CNN layers to the Transformer encoder output (which is position-wise). The relevance conservation property holds:</p>
<p>$$y = R = f(x) = \sum_{l \in (L)} R_{l} = \sum_{l \in (L-1)} R_{l} = \cdots = \sum_{l \in (1)} R_{l}$$</p>
<p>In practice, biases absorb some relevance, so the total propagated to the input is less than the output:</p>
<p>$$\sum_{l \in (L)} R_{l} = \sum_{l \in (L-1)} R_{l} + B$$</p>
<p>For gated connections in the Highway block, the authors implement the signal-take-all redistribution rule. The interpretation algorithm generates one SMILES per non-hydrogen atom (each drawn starting from that atom), runs LRP on each, and averages contributions. If more than 50% of relevance dissipates on biases, the interpretation may be unreliable, serving as an applicability domain indicator.</p>
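<p>The bias term $B$ in the conservation equation can be made concrete with a single dense layer. The sketch below uses the basic LRP rule in which each input receives relevance in proportion to its contribution to the pre-activation. It assumes non-zero pre-activations and omits the stabilizers used in practice. The function name is hypothetical, and the authors&rsquo; implementation, including the rule for the Highway block, may differ.</p>

```python
def lrp_dense(inputs, weights, biases, relevance_out):
    """Redistribute output relevance to the inputs of one dense layer.

    weights[i][j] connects input i to output j. Returns the relevance per
    input and the total relevance absorbed by the biases.
    """
    num_in, num_out = len(inputs), len(biases)
    relevance_in = [0.0] * num_in
    absorbed_by_bias = 0.0
    for j in range(num_out):
        contributions = [inputs[i] * weights[i][j] for i in range(num_in)]
        z_j = sum(contributions) + biases[j]  # pre-activation of output j
        for i in range(num_in):
            relevance_in[i] += contributions[i] / z_j * relevance_out[j]
        absorbed_by_bias += biases[j] / z_j * relevance_out[j]
    return relevance_in, absorbed_by_bias


relevance_in, absorbed = lrp_dense(
    inputs=[1.0, 2.0],
    weights=[[0.5, -1.0], [1.0, 0.5]],
    biases=[0.5, 1.0],
    relevance_out=[3.0, 1.0],
)
print(relevance_in, absorbed)
```

<p>In this toy layer the relevance reaching the inputs plus the amount absorbed by the biases equals the relevance at the output, matching the second equation above. The ratio of absorbed to total relevance is the kind of quantity the 50% reliability threshold refers to, accumulated over all layers.</p>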
<h2 id="benchmarks-across-18-regression-and-classification-datasets">Benchmarks Across 18 Regression and Classification Datasets</h2>
<p>The authors evaluate on the same 18 datasets (9 regression, 9 classification) used in their previous SMILES augmentation study, enabling direct comparison. All experiments use five-fold cross-validation.</p>
<h3 id="regression-results-r2">Regression Results ($r^2$)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">Descriptor-based</th>
          <th style="text-align: center">SMILES-based (augm=10)</th>
          <th style="text-align: center">Transformer-CNN (no augm)</th>
          <th style="text-align: center">Transformer-CNN (augm=10)</th>
          <th style="text-align: center">CDDD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MP (19,104)</td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center"><strong>0.86</strong></td>
          <td style="text-align: center">0.85</td>
      </tr>
      <tr>
          <td>BP (11,893)</td>
          <td style="text-align: center">0.98</td>
          <td style="text-align: center">0.98</td>
          <td style="text-align: center">0.97</td>
          <td style="text-align: center"><strong>0.98</strong></td>
          <td style="text-align: center">0.98</td>
      </tr>
      <tr>
          <td>BCF (378)</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.71</td>
          <td style="text-align: center"><strong>0.85</strong></td>
          <td style="text-align: center">0.81</td>
      </tr>
      <tr>
          <td>FreeSolv (642)</td>
          <td style="text-align: center"><strong>0.94</strong></td>
          <td style="text-align: center">0.93</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.93</td>
      </tr>
      <tr>
          <td>LogS (1,311)</td>
          <td style="text-align: center"><strong>0.92</strong></td>
          <td style="text-align: center">0.92</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.91</td>
      </tr>
      <tr>
          <td>Lipo (4,200)</td>
          <td style="text-align: center">0.70</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.60</td>
          <td style="text-align: center">0.73</td>
          <td style="text-align: center"><strong>0.74</strong></td>
      </tr>
      <tr>
          <td>BACE (1,513)</td>
          <td style="text-align: center">0.73</td>
          <td style="text-align: center">0.72</td>
          <td style="text-align: center">0.66</td>
          <td style="text-align: center"><strong>0.76</strong></td>
          <td style="text-align: center">0.75</td>
      </tr>
      <tr>
          <td>DHFR (739)</td>
          <td style="text-align: center">0.62</td>
          <td style="text-align: center">0.63</td>
          <td style="text-align: center">0.46</td>
          <td style="text-align: center"><strong>0.67</strong></td>
          <td style="text-align: center">0.61</td>
      </tr>
      <tr>
          <td>LEL (483)</td>
          <td style="text-align: center">0.19</td>
          <td style="text-align: center">0.25</td>
          <td style="text-align: center">0.20</td>
          <td style="text-align: center"><strong>0.27</strong></td>
          <td style="text-align: center">0.23</td>
      </tr>
  </tbody>
</table>
<h3 id="classification-results-auc">Classification Results (AUC)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">Descriptor-based</th>
          <th style="text-align: center">SMILES-based (augm=10)</th>
          <th style="text-align: center">Transformer-CNN (no augm)</th>
          <th style="text-align: center">Transformer-CNN (augm=10)</th>
          <th style="text-align: center">CDDD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HIV (41,127)</td>
          <td style="text-align: center">0.82</td>
          <td style="text-align: center">0.78</td>
          <td style="text-align: center">0.81</td>
          <td style="text-align: center"><strong>0.83</strong></td>
          <td style="text-align: center">0.74</td>
      </tr>
      <tr>
          <td>AMES (6,542)</td>
          <td style="text-align: center">0.86</td>
          <td style="text-align: center">0.88</td>
          <td style="text-align: center">0.86</td>
          <td style="text-align: center"><strong>0.89</strong></td>
          <td style="text-align: center">0.86</td>
      </tr>
      <tr>
          <td>BACE (1,513)</td>
          <td style="text-align: center">0.88</td>
          <td style="text-align: center">0.89</td>
          <td style="text-align: center">0.89</td>
          <td style="text-align: center"><strong>0.91</strong></td>
          <td style="text-align: center">0.90</td>
      </tr>
      <tr>
          <td>ClinTox (1,478)</td>
          <td style="text-align: center"><strong>0.77</strong></td>
          <td style="text-align: center">0.76</td>
          <td style="text-align: center">0.71</td>
          <td style="text-align: center">0.77</td>
          <td style="text-align: center">0.73</td>
      </tr>
      <tr>
          <td>Tox21 (7,831)</td>
          <td style="text-align: center">0.79</td>
          <td style="text-align: center"><strong>0.83</strong></td>
          <td style="text-align: center">0.81</td>
          <td style="text-align: center">0.82</td>
          <td style="text-align: center">0.82</td>
      </tr>
      <tr>
          <td>BBBP (2,039)</td>
          <td style="text-align: center">0.90</td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.90</td>
          <td style="text-align: center"><strong>0.92</strong></td>
          <td style="text-align: center">0.89</td>
      </tr>
      <tr>
          <td>JAK3 (886)</td>
          <td style="text-align: center">0.79</td>
          <td style="text-align: center"><strong>0.80</strong></td>
          <td style="text-align: center">0.70</td>
          <td style="text-align: center">0.78</td>
          <td style="text-align: center">0.76</td>
      </tr>
      <tr>
          <td>BioDeg (1,737)</td>
          <td style="text-align: center">0.92</td>
          <td style="text-align: center"><strong>0.93</strong></td>
          <td style="text-align: center">0.91</td>
          <td style="text-align: center">0.93</td>
          <td style="text-align: center">0.92</td>
      </tr>
      <tr>
          <td>RP AR (930)</td>
          <td style="text-align: center">0.85</td>
          <td style="text-align: center"><strong>0.87</strong></td>
          <td style="text-align: center">0.83</td>
          <td style="text-align: center">0.87</td>
          <td style="text-align: center">0.86</td>
      </tr>
  </tbody>
</table>
<h3 id="key-comparisons">Key Comparisons</h3>
<p>Baselines include descriptor-based methods (the best from LibSVM, Random Forest, XGBoost, ASNN, and DNNs), direct SMILES-based models with augmentation, and CDDD descriptors analyzed by the same classical ML methods. CDDD descriptors come from the Sml2canSml autoencoder approach, which produces fixed 512-dimensional vectors.</p>
<p>Transformer-CNN with augmentation matches or exceeds every baseline on 13 of the 18 datasets in the tables above (6 of 9 regression, 7 of 9 classification). Augmentation accounts for much of this: without it, Transformer-CNN underperforms substantially (e.g., BCF drops from 0.85 to 0.71, JAK3 from 0.78 to 0.70). This supports the view that the internal consensus over multiple SMILES representations is essential to the method&rsquo;s effectiveness.</p>
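<p>The consensus idea can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors&rsquo; code: <code>consensus_predict</code> and <code>toy_model</code> are hypothetical names, the scorer is a stand-in for a trained network, and the SMILES variants are written by hand instead of being generated by a cheminformatics toolkit.</p>

```python
from statistics import mean, pstdev
from typing import Callable, Sequence, Tuple


def consensus_predict(
    smiles_variants: Sequence[str],
    predict_one: Callable[[str], float],
) -> Tuple[float, float]:
    """Average per-SMILES predictions for one molecule.

    Returns (mean prediction, spread across the variants).
    """
    if not smiles_variants:
        raise ValueError("need at least one SMILES string")
    predictions = [predict_one(s) for s in smiles_variants]
    return mean(predictions), pstdev(predictions)


def toy_model(smiles: str) -> float:
    # Hypothetical stand-in for a trained network. A real model's output
    # varies with the SMILES string; the string length mimics that here.
    return 1.0 + 0.1 * (len(smiles) % 3)


# Three hand-written SMILES strings that all encode ethanol.
ethanol_variants = ["CCO", "OCC", "C(O)C"]
consensus, spread = consensus_predict(ethanol_variants, toy_model)
```

<p>The spread returned alongside the mean is one simple way to expose how much the per-SMILES predictions disagree for a given molecule.</p>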
<p>A practical advantage over CDDD is that Transformer-CNN imposes no constraints on molecular properties (CDDD requires logP in (-5, 7), molecular weight between 12 and 600, 3-50 heavy atoms, and organic molecules only), since the Transformer was trained on the full diversity of ChEMBL.</p>
<h3 id="interpretability-case-studies">Interpretability Case Studies</h3>
<p>For <a href="https://en.wikipedia.org/wiki/Ames_test">AMES</a> mutagenicity, the LRP analysis of 1-Bromo-4-nitrobenzene correctly identifies the nitro group and halogen as structural alerts, consistent with known mutagenicity rules. For aqueous solubility of <a href="https://en.wikipedia.org/wiki/Haloperidol">haloperidol</a>, the model assigns positive contributions to hydroxyl, carbonyl, and aliphatic nitrogen groups (which increase solubility) and negative contributions to aromatic carbons (which decrease it). Both cases align with established chemical knowledge, supporting the trustworthiness of the model.</p>
<h2 id="effective-transfer-learning-for-small-qsar-datasets">Effective Transfer Learning for Small QSAR Datasets</h2>
<p>Transformer-CNN achieves competitive or superior QSAR performance across 18 diverse benchmarks by combining three ingredients: (1) Transformer-based pre-training via SMILES canonicalization, (2) SMILES augmentation during training and inference, and (3) a lightweight CNN head. The method requires minimal hyperparameter tuning, as the Transformer weights are frozen and the CNN architecture is fixed.</p>
<p>The authors acknowledge several limitations and future directions:</p>
<ul>
<li>Stereochemistry canonicalization accuracy is low (37.2%), which could impact models for stereo-sensitive properties</li>
<li>The LRP interpretability depends on sufficient relevance propagation (at least 50% reaching the input layer)</li>
<li>The variance among augmented SMILES predictions could serve as a confidence estimate, but this is left to future work</li>
<li>Applicability domain assessment based on SMILES reconstruction quality is proposed but not fully developed</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL (SMILES &lt;= 110 chars)</td>
          <td>17.7M pairs</td>
          <td>10x augmentation + 1 identity pair per molecule</td>
      </tr>
      <tr>
          <td>Validation (canon.)</td>
          <td>Generated ChEMBL-like SMILES</td>
          <td>500,000</td>
          <td>From a molecular generator</td>
      </tr>
      <tr>
          <td>QSAR benchmarks</td>
          <td>9 regression + 9 classification</td>
          <td>378-41,127</td>
          <td>Available on OCHEM (<a href="https://ochem.eu">https://ochem.eu</a>)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer: 3 layers, 10 self-attention heads, character-level tokenization (66 symbols)</li>
<li>TextCNN: 12 kernel sizes (1-10, 15, 20) with 100-200 filters each, GlobalMaxPool, Dense(512), Highway, Dropout(0.25)</li>
<li>Augmentation: n=10 non-canonical SMILES per molecule during training and inference</li>
<li>LRP: signal-take-all redistribution for Highway gates, standard LRP for Dense and Conv layers</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Transformer encoder weights pre-trained on canonicalization task (frozen during QSAR training)</li>
<li>QSAR CNN trained with Adam optimizer, learning rate $10^{-4}$, early stopping</li>
<li>Pre-trained embeddings and standalone prediction models available in the GitHub repository</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Regression: coefficient of determination $r^2 = 1 - SS_{\text{res}} / SS_{\text{tot}}$</li>
<li>Classification: Area Under the ROC Curve (AUC)</li>
<li>Five-fold cross-validation with bootstrap standard errors</li>
</ul>
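<p>The regression metric above is short enough to write out directly. The sketch below is a plain-Python version of the formula; the function name <code>r_squared</code> and the example values are illustrative and do not come from the paper.</p>

```python
from typing import Sequence


def r_squared(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    mean_true = sum(y_true) / len(y_true)
    # Residual sum of squares: squared gaps between targets and predictions.
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    # Total sum of squares: squared gaps between targets and their mean.
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("r^2 is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


# Made-up values: SS_res = 0.10 and SS_tot = 5.0, so r^2 = 0.98.
measured = [1.0, 2.0, 3.0, 4.0]
predicted = [1.1, 1.9, 3.2, 3.8]
score = r_squared(measured, predicted)
```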
<h3 id="hardware">Hardware</h3>
<ul>
<li>NVIDIA Quadro P6000, Titan Xp, and Titan V GPUs (donated by NVIDIA)</li>
<li>TensorFlow v1.12.0, RDKit v2018.09.2</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/bigchem/transformer-cnn">transformer-cnn</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Source code, pre-trained embeddings, standalone prediction models</td>
      </tr>
      <tr>
          <td><a href="https://ochem.eu">OCHEM</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online platform hosting the method, training datasets, and models</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Karpov, P., Godin, G., &amp; Tetko, I. V. (2020). Transformer-CNN: Swiss knife for QSAR modeling and interpretation. <em>Journal of Cheminformatics</em>, 12, 17. <a href="https://doi.org/10.1186/s13321-020-00423-w">https://doi.org/10.1186/s13321-020-00423-w</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{karpov2020transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer-{CNN}: Swiss knife for {QSAR} modeling and interpretation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Karpov, Pavel and Godin, Guillaume and Tetko, Igor V.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00423-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Transformer Name-to-SMILES with Atom Count Losses</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/transformer-chemical-name-to-smiles/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/transformer-chemical-name-to-smiles/</guid><description>A Transformer seq2seq model translates chemical compound names to SMILES, using atom-count constraints and SMILES/InChI multi-task learning.</description><content:encoded><![CDATA[<h2 id="translating-chemical-names-to-structures-with-transformers">Translating Chemical Names to Structures with Transformers</h2>
<p>This is a <strong>Method</strong> paper that proposes using Transformer-based sequence-to-sequence models to predict chemical compound structures (represented as SMILES strings) from chemical compound names. The primary contribution is the application of neural machine translation techniques to the name-to-structure problem, along with two domain-specific improvements: an atom-count constraint loss function and a multi-task learning approach that jointly predicts SMILES and InChI strings.</p>
<h2 id="why-rule-based-name-to-structure-fails-for-synonyms">Why Rule-Based Name-to-Structure Fails for Synonyms</h2>
<p>Chemical compound names come in several varieties. IUPAC names follow systematic nomenclature and are well-handled by rule-based parsers like OPSIN. Database IDs (e.g., CAS registry numbers) can be resolved by dictionary lookup. The third category, Synonyms (which includes abbreviations, common names, and other informal designations), is problematic because naming patterns are complex and widely variable.</p>
<p>In preliminary experiments, rule-based tools achieved F-measures of 0.878 to 0.960 on IUPAC names but only 0.719 to 0.758 on Synonyms. This performance gap motivates a data-driven approach. The authors frame name-to-SMILES prediction as a machine translation problem: the source language is the chemical compound name and the target language is the SMILES string. A neural model trained on millions of name-SMILES pairs can learn patterns that rule-based systems miss, particularly for non-systematic nomenclature.</p>
<h2 id="atom-count-constraints-and-multi-task-learning">Atom-Count Constraints and Multi-Task Learning</h2>
<p>The paper introduces two improvements over a vanilla Transformer seq2seq model.</p>
<h3 id="atom-count-constraint-loss">Atom-Count Constraint Loss</h3>
<p>A correct structure prediction must contain the right number of atoms of each element. The authors add an auxiliary loss that penalizes the squared difference between the predicted and true atom counts for each element. The predicted atom counts are obtained by summing Gumbel-softmax outputs across all decoded positions.</p>
<p>For the $i$-th output token, the Gumbel-softmax probability vector is:</p>
<p>$$
y_{ij} = \frac{\exp\left((\log(\pi_{ij}) + g_{ij}) / \tau\right)}{\sum_{k=1}^{|\mathcal{V}|} \exp\left((\log(\pi_{ik}) + g_{ik}) / \tau\right)}
$$</p>
<p>where $\pi_{ij}$ is the model&rsquo;s softmax output, $g_{ij}$ is a Gumbel noise sample, and $\tau = 0.1$ is the temperature. The predicted token frequency vector is $\mathbf{y}^{pred} = \sum_{i=1}^{m} \mathbf{y}_i$, and the atom-count loss is:</p>
<p>$$
\mathcal{L}_{atom} = \frac{1}{|A|} \sum_{a \in A} \left(N_a(T) - y_{idx(a)}^{pred}\right)^2
$$</p>
<p>where $A$ is the set of chemical elements in the vocabulary, $N_a(T)$ returns the number of atoms of element $a$ in the correct SMILES string $T$, and $idx(a)$ returns the vocabulary index of element $a$. Only element tokens (e.g., &ldquo;C&rdquo;, &ldquo;O&rdquo;) are counted; bond symbols (e.g., &ldquo;=&rdquo;, &ldquo;#&rdquo;) are excluded.</p>
<p>The combined objective is:</p>
<p>$$
\mathcal{L}_{smiles} + \lambda_{atom} \mathcal{L}_{atom}
$$</p>
<p>with $\lambda_{atom} = 0.7$.</p>
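<p>A forward-pass sketch of the atom-count term may make the definitions above more concrete. It is written in plain Python under several assumptions: the toy vocabulary, the helper names, and the fixed random seed are all hypothetical, and no gradients are computed (in training this would live inside an autodiff framework).</p>

```python
import math
import random
from typing import Dict, List, Sequence


def gumbel_softmax(probs: Sequence[float], tau: float, rng: random.Random) -> List[float]:
    """Relaxed one-hot sample y_i drawn from a categorical distribution pi_i."""
    logits = []
    for p in probs:
        # Gumbel(0, 1) noise: g = -log(-log(u)), with u clamped inside (0, 1).
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        g = -math.log(-math.log(u))
        logits.append((math.log(max(p, 1e-12)) + g) / tau)
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def atom_count_loss(
    step_probs: Sequence[Sequence[float]],
    true_counts: Dict[str, int],
    vocab: Sequence[str],
    elements: Sequence[str],
    tau: float = 0.1,
    seed: int = 0,
) -> float:
    """Mean squared gap between true and soft-predicted atom counts."""
    rng = random.Random(seed)
    y_pred = [0.0] * len(vocab)
    for probs in step_probs:  # sum the relaxed one-hots over decoded positions
        for j, y in enumerate(gumbel_softmax(probs, tau, rng)):
            y_pred[j] += y
    squared = [
        (true_counts.get(a, 0) - y_pred[vocab.index(a)]) ** 2 for a in elements
    ]
    return sum(squared) / len(elements)


def near_one_hot(j: int, size: int, eps: float = 1e-9) -> List[float]:
    """Almost-certain decoder output for vocabulary index j."""
    return [1.0 - (size - 1) * eps if k == j else eps for k in range(size)]


vocab = ["C", "O", "=", "(", ")"]  # toy vocabulary (assumption)
elements = ["C", "O"]              # bond and bracket symbols are excluded
# Decoder is nearly certain of the token sequence C, C, O.
step_probs = [near_one_hot(0, 5), near_one_hot(0, 5), near_one_hot(1, 5)]
loss_right = atom_count_loss(step_probs, {"C": 2, "O": 1}, vocab, elements)
loss_wrong = atom_count_loss(step_probs, {"C": 3, "O": 1}, vocab, elements)
```

<p>With a confident decoder the soft counts are close to integers, so the loss is near zero when the counts agree and grows quadratically with each missing or extra atom.</p>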
<h3 id="multi-task-smilesinchi-prediction">Multi-Task SMILES/InChI Prediction</h3>
<p>SMILES and InChI strings encode the same chemical structure in different formats. The authors hypothesize that jointly predicting both representations can improve the shared encoder. The multi-task model shares the encoder between a SMILES decoder and an InChI decoder, minimizing:</p>
<p>$$
\mathcal{L}_{smiles} + \lambda_{inchi} \mathcal{L}_{inchi}
$$</p>
<p>where $\mathcal{L}_{inchi} = -\log P(I | X; \boldsymbol{\theta}_{enc}, \boldsymbol{\theta}_{inchi})$ and $\lambda_{inchi} = 0.3$.</p>
<h2 id="experimental-setup-and-evaluation">Experimental Setup and Evaluation</h2>
<h3 id="dataset">Dataset</h3>
<p>The dataset was constructed from PubChem dump data (97M compound records). Chemical compound names categorized as Synonyms were paired with canonical SMILES strings (converted via RDKit). Database-like IDs were filtered out using regular expressions. Duplicate names mapping to different CIDs were removed.</p>
<table>
  <thead>
      <tr>
          <th>Split</th>
          <th>Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>5,000,000</td>
      </tr>
      <tr>
          <td>Development</td>
          <td>1,113</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>11,194</td>
      </tr>
  </tbody>
</table>
<h3 id="model-configuration">Model Configuration</h3>
<p>The Transformer uses 6 encoder/decoder layers, 8 attention heads, 512-dimensional embeddings, and 0.1 dropout. Training used label-smoothing cross-entropy ($\epsilon = 0.1$), Adam optimizer ($\beta_1 = 0.9$, $\beta_2 = 0.98$), and a warmup schedule with peak learning rate 0.0005 over 4,000 steps followed by inverse square root decay. Models were trained for 300,000 update steps. Final predictions averaged the last 10 checkpoints and used beam search (beam size 4, length penalty $\alpha = 0.6$, max output length 200).</p>
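<p>The learning-rate schedule described above can be sketched as follows. The peak rate and warmup length come from the text; the linear shape of the warmup (starting near zero) is an assumption, since only the peak, the warmup length, and the inverse square root decay are stated. The function name is illustrative.</p>

```python
def learning_rate(step: int, peak_lr: float = 5e-4, warmup_steps: int = 4000) -> float:
    """Linear warmup to peak_lr, then decay proportional to 1 / sqrt(step)."""
    if step < 1:
        raise ValueError("step counts from 1")
    if step <= warmup_steps:
        # Assumed linear ramp from 0 up to the peak rate.
        return peak_lr * step / warmup_steps
    # Inverse square root decay, continuous with the peak at warmup_steps.
    return peak_lr * (warmup_steps / step) ** 0.5


lr_mid_warmup = learning_rate(2000)   # halfway through warmup: 2.5e-4
lr_peak = learning_rate(4000)         # end of warmup: 5e-4
lr_late = learning_rate(16000)        # four times the warmup length: 2.5e-4
```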
<h3 id="tokenization">Tokenization</h3>
<p>Three tokenization strategies were compared:</p>
<ul>
<li><strong>BPE</strong>: Byte pair encoding learned on chemical compound names (500 merge operations) via fastBPE</li>
<li><strong>OPSIN-TK</strong>: The OPSIN rule-based tokenizer</li>
<li><strong>OPSIN-TK+BPE</strong>: A hybrid where OPSIN handles tokenizable names and BPE handles the rest</li>
</ul>
<p>SMILES tokens were identified by regular expressions (elements as single tokens, remaining symbols as characters). InChI strings were tokenized by SentencePiece (vocabulary size 1,000).</p>
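<p>A regular-expression tokenizer of the kind described for SMILES might look like the sketch below. The pattern and the element list are assumptions chosen for illustration (the paper&rsquo;s exact expressions are not given here), and folding aromatic lower-case atoms into their element when counting is likewise a choice made for this example. Implicit hydrogens are not counted.</p>

```python
import re
from collections import Counter
from typing import List

# Two-letter element symbols are tried first; any other character is one token.
# The element list is a small illustrative subset, not the paper's vocabulary.
SMILES_TOKEN = re.compile(r"Cl|Br|Si|Se|Na|.")
ELEMENT_TOKENS = {"C", "N", "O", "S", "P", "F", "I", "B", "Cl", "Br", "Si", "Se", "Na"}
AROMATIC_ATOMS = {"c", "n", "o", "s", "p"}


def tokenize_smiles(smiles: str) -> List[str]:
    """Split a SMILES string into element tokens and single-character symbols."""
    return SMILES_TOKEN.findall(smiles)


def count_atoms(smiles: str) -> Counter:
    """Count element tokens, ignoring bonds, brackets, and ring-closure digits."""
    counts: Counter = Counter()
    for token in tokenize_smiles(smiles):
        symbol = token.upper() if token in AROMATIC_ATOMS else token
        if symbol in ELEMENT_TOKENS:
            counts[symbol] += 1
    return counts


acetyl_chloride_tokens = tokenize_smiles("CC(=O)Cl")
bromotoluene_counts = count_atoms("Cc1ccccc1Br")
```

<p>Counts produced this way correspond to the per-element targets $N_a(T)$ used by the atom-count loss.</p>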
<h3 id="baselines">Baselines</h3>
<ul>
<li><strong>OPSIN</strong>: Open-source rule-based parser</li>
<li><strong>Tool A</strong> and <strong>Tool B</strong>: Two commercially available name-to-structure tools</li>
</ul>
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Tokenizer</th>
          <th>Recall</th>
          <th>Precision</th>
          <th>F-measure</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OPSIN</td>
          <td>Rule-based</td>
          <td>0.693</td>
          <td>0.836</td>
          <td>0.758</td>
      </tr>
      <tr>
          <td>Tool A</td>
          <td>Rule-based</td>
          <td>0.711</td>
          <td>0.797</td>
          <td>0.752</td>
      </tr>
      <tr>
          <td>Tool B</td>
          <td>Rule-based</td>
          <td>0.653</td>
          <td>0.800</td>
          <td>0.719</td>
      </tr>
      <tr>
          <td>Transformer</td>
          <td>BPE</td>
          <td>0.793</td>
          <td>0.806</td>
          <td>0.799</td>
      </tr>
      <tr>
          <td>+ atomnum</td>
          <td>BPE</td>
          <td>0.798</td>
          <td>0.808</td>
          <td>0.803</td>
      </tr>
      <tr>
          <td>+ inchigen</td>
          <td>BPE</td>
          <td>0.810</td>
          <td>0.819</td>
          <td>0.814</td>
      </tr>
      <tr>
          <td>Transformer</td>
          <td>OPSIN-TK+BPE</td>
          <td>0.763</td>
          <td>0.873</td>
          <td>0.814</td>
      </tr>
      <tr>
          <td>+ atomnum</td>
          <td>OPSIN-TK+BPE</td>
          <td>0.768</td>
          <td>0.876</td>
          <td>0.818</td>
      </tr>
      <tr>
          <td>+ inchigen</td>
          <td>OPSIN-TK+BPE</td>
          <td>0.779</td>
          <td>0.886</td>
          <td>0.829</td>
      </tr>
      <tr>
          <td>Transformer</td>
          <td>OPSIN-TK</td>
          <td>0.755</td>
          <td>0.868</td>
          <td>0.808</td>
      </tr>
      <tr>
          <td>+ atomnum</td>
          <td>OPSIN-TK</td>
          <td>0.757</td>
          <td>0.867</td>
          <td>0.808</td>
      </tr>
      <tr>
          <td>+ inchigen</td>
          <td>OPSIN-TK</td>
          <td>0.754</td>
          <td>0.869</td>
          <td>0.807</td>
      </tr>
  </tbody>
</table>
<p>The best configuration (inchigen with OPSIN-TK+BPE) achieved an F-measure of 0.829, surpassing OPSIN by 0.071 points. The multi-task learning approach (inchigen) outperformed the atom-count constraint (atomnum) with the BPE and OPSIN-TK+BPE tokenizers. With OPSIN-TK alone, the three variants were essentially tied (F-measures of 0.807 to 0.808).</p>
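<p>The F-measure column in the table is consistent with the harmonic mean of the listed precision and recall, which is easy to check. The function name below is illustrative; the input values are taken from the OPSIN row and the inchigen (OPSIN-TK+BPE) row.</p>

```python
def f_measure(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (balanced F-measure)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


opsin_f = f_measure(precision=0.836, recall=0.693)   # rounds to 0.758
best_f = f_measure(precision=0.886, recall=0.779)    # rounds to 0.829
gap = round(best_f, 3) - round(opsin_f, 3)           # about 0.071
```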
<h2 id="key-findings-and-error-analysis">Key Findings and Error Analysis</h2>
<p>The Transformer-based approach produced grammatically correct SMILES strings (parseable by RDKit) for 99% of test examples, compared to 81.6-88.4% for the rule-based tools. Even when predictions were incorrect, they tended to be structurally similar to the correct answer. Using MACCS fingerprints and Jaccard (Tanimoto) similarity, the average similarity between incorrectly predicted and correct structures was 0.753.</p>
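<p>The similarity measure mentioned above reduces to set arithmetic once each fingerprint is represented as the set of its on-bit positions. The sketch below assumes that representation; the bit indices are made up and are not MACCS keys of any real molecule.</p>

```python
from typing import AbstractSet


def tanimoto(bits_a: AbstractSet[int], bits_b: AbstractSet[int]) -> float:
    """Jaccard (Tanimoto) similarity of two fingerprints given as on-bit sets."""
    union = bits_a | bits_b
    if not union:
        return 1.0  # convention chosen here: two empty fingerprints are identical
    return len(bits_a & bits_b) / len(union)


# Hypothetical on-bit indices for a predicted and a reference structure.
predicted_bits = {3, 17, 42, 88, 120}
reference_bits = {3, 17, 42, 90, 120, 151}
similarity = tanimoto(predicted_bits, reference_bits)  # 4 shared bits / 7 total
```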
<p>The OPSIN-TK tokenizer yielded higher precision than BPE because approximately 11.5% (1,293 of 11,194) of test compounds could not be tokenized by OPSIN and therefore received no prediction; missing outputs lower recall but do not count against precision. BPE-based tokenizers achieved higher recall by covering all inputs. The hybrid OPSIN-TK+BPE approach balanced both, achieving the highest overall F-measure.</p>
<p><strong>Limitations</strong>: The paper does not evaluate the Transformer models on IUPAC names separately (only the rule-based tools are compared on IUPAC). The atom-count constraint and multi-task learning are not combined in a single model. The dataset is released but the training code is not. Hardware details and training times are not reported. Evaluation relies on exact-match recall, precision, and F-measure; Jaccard similarity is reported only for incorrect predictions, so nearly-correct structures receive no partial credit in the headline scores.</p>
<p><strong>Future work</strong>: The authors plan to explore additional tokenization methods, combine the atom-count constraint with multi-task learning, and apply the constraint loss to other chemistry problems including chemical reaction prediction.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>PubChem Synonyms (custom split)</td>
          <td>5,000,000 pairs</td>
          <td>Chemical compound names to canonical SMILES</td>
      </tr>
      <tr>
          <td>Development</td>
          <td>PubChem Synonyms (custom split)</td>
          <td>1,113 pairs</td>
          <td>Filtered for duplicates</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>PubChem Synonyms (custom split)</td>
          <td>11,194 pairs</td>
          <td>Filtered for duplicates; released as benchmark</td>
      </tr>
  </tbody>
</table>
<p>The authors state the dataset is released for future research. The data was constructed from the PubChem dump (97M compound records) using RDKit for SMILES canonicalization. Database-like IDs were removed with regular expressions and duplicate names across CIDs were filtered.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer seq2seq (6 layers, 8 heads, 512-dim embeddings)</li>
<li>BPE tokenization via fastBPE (500 merge operations)</li>
<li>SentencePiece for InChI tokenization (vocabulary size 1,000)</li>
<li>Gumbel-softmax atom-count constraint ($\tau = 0.1$, $\lambda_{atom} = 0.7$)</li>
<li>Multi-task SMILES/InChI loss ($\lambda_{inchi} = 0.3$)</li>
<li>Adam optimizer ($\beta_1 = 0.9$, $\beta_2 = 0.98$, $\epsilon = 10^{-8}$)</li>
<li>Label smoothing ($\epsilon = 0.1$), 300K training steps</li>
<li>Beam search (beam size 4, length penalty $\alpha = 0.6$)</li>
</ul>
<h3 id="models">Models</h3>
<p>Standard Transformer architecture following Vaswani et al. (2017). No pre-trained weights or model checkpoints are released.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value</th>
          <th>Model</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>F-measure</td>
          <td>0.829</td>
          <td>inchigen (OPSIN-TK+BPE)</td>
          <td>Highest overall</td>
      </tr>
      <tr>
          <td>Precision</td>
          <td>0.886</td>
          <td>inchigen (OPSIN-TK+BPE)</td>
          <td>Highest overall</td>
      </tr>
      <tr>
          <td>Recall</td>
          <td>0.810</td>
          <td>inchigen (BPE)</td>
          <td>Highest overall</td>
      </tr>
      <tr>
          <td>Grammatical correctness</td>
          <td>99%</td>
          <td>inchigen (BPE)</td>
          <td>SMILES parseable by RDKit</td>
      </tr>
      <tr>
          <td>Avg. Jaccard similarity (errors)</td>
          <td>0.753</td>
          <td>inchigen (BPE)</td>
          <td>On incorrect predictions only</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not reported.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Omote, Y., Matsushita, K., Iwakura, T., Tamura, A., &amp; Ninomiya, T. (2020). Transformer-based Approach for Predicting Chemical Compound Structures. <em>Proceedings of the 1st Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics and the 10th International Joint Conference on Natural Language Processing</em>, 154-162. <a href="https://doi.org/10.18653/v1/2020.aacl-main.19">https://doi.org/10.18653/v1/2020.aacl-main.19</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{omote2020transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer-based Approach for Predicting Chemical Compound Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Omote, Yutaro and Matsushita, Kyoumoto and Iwakura, Tomoya and Tamura, Akihiro and Ninomiya, Takashi}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 1st Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics and the 10th International Joint Conference on Natural Language Processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{154--162}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2020.aacl-main.19}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Survey of Transformer Architectures in Molecular Science</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/transformers-molecular-science-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/transformers-molecular-science-review/</guid><description>A comprehensive review of 12 transformer architectures applied to molecular science, covering GPT, BERT, BART, graph transformers, and more.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-transformer-architectures-for-molecular-science">A Systematization of Transformer Architectures for Molecular Science</h2>
<p>This paper is a <strong>Systematization</strong> review. It organizes and taxonomizes 12 families of transformer architectures that have been applied across molecular science, including chemistry, biology, and drug discovery. The primary contribution is not a new method or dataset, but a structured technical overview of the algorithmic internals of each transformer variant and their specific applications to molecular problems. The review covers 201 references and provides a unified treatment of how these architectures capture molecular patterns from sequential, graphical, and image-based data.</p>
<h2 id="bridging-the-gap-between-transformer-variants-and-molecular-applications">Bridging the Gap Between Transformer Variants and Molecular Applications</h2>
<p>Transformer-based models have become widespread in molecular science, yet the authors identify a gap: there is no organized taxonomy linking these diverse techniques in the existing literature. Individual papers introduce specific architectures or applications, but practitioners lack a unified reference that explains the technical differences between GPT, BERT, BART, graph transformers, and other variants in the context of molecular data. The review aims to fill this gap by providing an in-depth investigation of the algorithmic components of each model family, explaining how their architectural innovations contribute to processing complex molecular data. The authors note that the success of transformers in molecular science stems from several factors: the sequential nature of chemical and biological molecules (DNA, RNA, proteins, SMILES strings), the attention mechanism&rsquo;s ability to capture long-range dependencies within molecular structures, and the capacity for transfer learning through pre-training on large chemical and biological datasets.</p>
<h2 id="twelve-transformer-families-and-their-molecular-mechanisms">Twelve Transformer Families and Their Molecular Mechanisms</h2>
<p>The review covers transformer preliminaries before diving into 12 specific architecture families. The core self-attention mechanism computes:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$</p>
<p>where $d_k$ is the dimension of the key vectors. The position-wise feed-forward network is:</p>
<p>$$
\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2
$$</p>
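<p>Both formulas can be written out in plain Python for a single attention head. The sketch below is a minimal illustration with hypothetical helper names and toy inputs; it omits masking, multiple heads, and learned projections, and a practical implementation would use an array library.</p>

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product; zip(*b) iterates over the columns of b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V for one head, without masking."""
    d_k = len(k[0])
    k_t = [list(col) for col in zip(*k)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, k_t)]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v)


def ffn(x: Matrix, w1: Matrix, b1: List[float], w2: Matrix, b2: List[float]) -> Matrix:
    """max(0, x W1 + b1) W2 + b2, applied to each position independently."""
    hidden = [[max(0.0, h + b) for h, b in zip(row, b1)] for row in matmul(x, w1)]
    return [[o + b for o, b in zip(row, b2)] for row in matmul(hidden, w2)]


# Identical keys give uniform attention weights, so each output row
# is the plain average of the value rows.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [1.0, 0.0]]
values = [[2.0, 0.0], [0.0, 4.0]]
attended = attention(queries, keys, values)

# ReLU zeroes the negative hidden unit before the second projection.
ffn_out = ffn([[1.0, -1.0]], [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [[1.0], [1.0]], [0.5])
```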
<p>The 12 architecture families covered are:</p>
<ol>
<li>
<p><strong>GPT (Generative Pre-trained Transformer)</strong>: Uses the decoder part of the transformer for autoregressive generation. Applications include MolGPT for molecular generation, DrugGPT for protein-ligand binding, and cMolGPT for target-specific de novo molecular generation.</p>
</li>
<li>
<p><strong>BERT (Bidirectional Encoder Representations from Transformers)</strong>: Uses transformer encoders with masked language modeling and next-sentence prediction for pre-training. Molecular applications include FP-BERT for molecular property prediction using composite fingerprint representations, Graph-BERT for protein-protein interaction identification, SMILES-BERT, and Mol-BERT.</p>
</li>
<li>
<p><strong>BART (Bidirectional and Auto-Regressive Transformers)</strong>: Functions as a denoising autoencoder with both encoder and decoder. Molecular applications include Chemformer for sequence-to-sequence chemistry tasks, MS2Mol for mass spectrometry analysis, and MolBART for molecular feature learning.</p>
</li>
<li>
<p><strong>Graph Transformer</strong>: Leverages self-attention on graph-structured data to capture global context. Applications include GraphSite for protein-DNA binding site prediction (using AlphaFold2 structure predictions), KPGT for knowledge-guided molecular graph pre-training, and PAGTN for establishing long-range dependencies in molecular graphs.</p>
</li>
<li>
<p><strong>Transformer-XL</strong>: Adds segment-level recurrence and relative positional encoding for modeling long sequences. Used for small molecule retention time prediction, drug design with ChEMBL data (1.27 million molecules), and Heck reaction generation.</p>
</li>
<li>
<p><strong><a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5 (Text-to-Text Transfer Transformer)</a></strong>: Unifies NLP tasks into text-to-text mapping. T5Chem was pre-trained on 97 million molecules from PubChem and achieved 99.5% accuracy on reaction classification (USPTO 500 MT). C5T5 uses IUPAC naming for molecular optimization in drug discovery.</p>
</li>
<li>
<p><strong>Vision Transformer (ViT)</strong>: Applies the transformer architecture to image patches. Used for organic molecule classification (97% accuracy with WGAN-generated data), bacterial identification via SERS, and molecular property prediction from mass spectrometry data (TransG-Net).</p>
</li>
<li>
<p><strong>DETR (Detection Transformer)</strong>: End-to-end object detection using transformers. Applied to cryo-EM particle picking (TransPicker), molecular structure image recognition (IMG2SMI), and cell segmentation (Cell-DETR).</p>
</li>
<li>
<p><strong>Conformer</strong>: Integrates convolutional modules into the transformer structure. Used for DNA storage error correction (RRCC-DNN) and drug-target affinity prediction (NG-DTA with Davis and Kiba datasets).</p>
</li>
<li>
<p><strong>CLIP (Contrastive Language-Image Pre-training)</strong>: Multimodal learning linking text and images. Applied to peptide design (Cut&amp;CLIP for protein degradation), gene identification (pathCLIP), and drug discovery (CLOOME for zero-shot transfer learning).</p>
</li>
<li>
<p><strong>Sparse Transformers</strong>: Use sparse attention matrices to reduce complexity to $O(n\sqrt{n})$. Applied to drug-target interaction prediction with gated cross-attention mechanisms.</p>
</li>
<li>
<p><strong>Mobile and Efficient Transformers</strong>: Compressed variants (TinyBERT, MobileBERT) for resource-constrained environments. Molormer uses ProbSparse self-attention for drug-drug interaction prediction. LOGO is a lightweight pre-trained language model for non-coding genome interpretation.</p>
</li>
</ol>
<h2 id="survey-organization-and-coverage-of-molecular-domains">Survey Organization and Coverage of Molecular Domains</h2>
<p>As a survey paper, this work does not present new experiments. Instead, it catalogues existing applications across multiple molecular domains:</p>
<p><strong>Drug Discovery and Design</strong>: GPT-based ligand design (DrugGPT), BART-based molecular generation (Chemformer, MolBART), graph transformer pre-training for molecular property prediction (KPGT), T5-based chemical reaction prediction (T5Chem), and sparse transformer methods for drug-target interactions.</p>
<p><strong>Protein Science</strong>: BERT-based protein-protein interaction prediction (Graph-BERT), graph transformer methods for protein-DNA binding (GraphSite with AlphaFold2 integration), conformer-based drug-target affinity prediction (NG-DTA), and CLIP-based peptide design (Cut&amp;CLIP).</p>
<p><strong>Molecular Property Prediction</strong>: FP-BERT for fingerprint-based prediction, SMILES-BERT and Mol-BERT for end-to-end prediction from SMILES, KPGT for knowledge-guided graph pre-training, and Transformer-XL for property modeling with relative positional encoding.</p>
<p><strong>Structural Biology</strong>: DETR-based cryo-EM particle picking (TransPicker), vision transformer applications in cell imaging, and Cell-DETR for instance segmentation in microscopy.</p>
<p><strong>Genomics</strong>: Conformer-based DNA storage error correction (RRCC-DNN), LOGO for non-coding genome interpretation, and MetaTransformer for metagenomic sequencing analysis.</p>
<h2 id="future-directions-and-limitations-of-the-survey">Future Directions and Limitations of the Survey</h2>
<p>The review concludes with four future directions:</p>
<ol>
<li>
<p><strong>ChatGPT integration into molecular science</strong>: Using LLMs for data analysis, literature review, and hypothesis generation in chemistry and biology.</p>
</li>
<li>
<p><strong>Multifunction transformers</strong>: Models that extract features across diverse molecular structures and sequences simultaneously.</p>
</li>
<li>
<p><strong>Molecular-aware transformers</strong>: Architectures that handle multiple data types (text, sequence, structure, image, energy, molecular dynamics, function) in a unified framework.</p>
</li>
<li>
<p><strong>Self-assessment transformers and superintelligence</strong>: Speculative discussion of models that learn from seemingly unrelated data sources.</p>
</li>
</ol>
<p>The review has several limitations worth noting. The coverage is broad but shallow: each architecture family receives only 1-2 pages of discussion, and the paper largely describes existing work rather than critically evaluating it. The review does not systematically compare the architectures against each other on common benchmarks. The future directions section (particularly the superintelligence discussion) is speculative and lacks concrete proposals. The paper also focuses primarily on technical architecture descriptions rather than analyzing failure modes, scalability challenges, or reproducibility concerns across the surveyed methods. As a review article, no new data were created or analyzed.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a survey paper. No new datasets were created or used. The paper reviews applications involving datasets such as PubChem (97 million molecules for T5Chem), ChEMBL (1.27 million molecules for Transformer-XL drug design), USPTO 500 MT (reaction classification), ESOL (5,328 molecules for property prediction), and Davis/Kiba (drug-target affinity).</p>
<h3 id="algorithms">Algorithms</h3>
<p>No new algorithms are introduced. The paper provides mathematical descriptions of the core transformer components (self-attention, positional encoding, feed-forward networks, layer normalization) and describes how 12 architecture families modify these components.</p>
<h3 id="models">Models</h3>
<p>No new models are presented. The paper surveys existing models including MolGPT, DrugGPT, FP-BERT, SMILES-BERT, Chemformer, MolBART, GraphSite, KPGT, T5Chem, TransPicker, Cell-DETR, CLOOME, and Molormer, among others.</p>
<h3 id="evaluation">Evaluation</h3>
<p>No new evaluation is performed. Performance numbers cited from the literature include: T5Chem reaction classification accuracy of 99.5%, ViT organic molecule classification at 97%, Transformer-XL property prediction RMSE of 0.6 on ESOL, and Heck reaction generation feasibility rate of 47.76%.</p>
<h3 id="hardware">Hardware</h3>
<p>No hardware requirements are specified, as this is a survey paper.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://onlinelibrary.wiley.com/doi/pdfdirect/10.1002/wcms.1725">Paper (open access)</a></td>
          <td>Paper</td>
          <td>CC-BY-NC-ND</td>
          <td>Open access via Wiley</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jiang, J., Ke, L., Chen, L., Dou, B., Zhu, Y., Liu, J., Zhang, B., Zhou, T., &amp; Wei, G.-W. (2024). Transformer technology in molecular science. <em>WIREs Computational Molecular Science</em>, 14(4), e1725. <a href="https://doi.org/10.1002/wcms.1725">https://doi.org/10.1002/wcms.1725</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{jiang2024transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer technology in molecular science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Jiang, Jian and Ke, Lu and Chen, Long and Dou, Bozheng and Zhu, Yueying and Liu, Jie and Zhang, Bengong and Zhou, Tianshou and Wei, Guo-Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{WIREs Computational Molecular Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{e1725}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Wiley}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1002/wcms.1725}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SPMM: A Bidirectional Molecular Foundation Model</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/</guid><description>SPMM is a multimodal molecular foundation model that aligns SMILES structures with property vectors for bidirectional generation and prediction tasks.</description><content:encoded><![CDATA[<h2 id="a-multimodal-foundation-model-for-structure-property-comprehension">A Multimodal Foundation Model for Structure-Property Comprehension</h2>
<p>This is a <strong>Method</strong> paper that introduces the Structure-Property Multi-Modal foundation model (SPMM), a transformer-based architecture that treats SMILES strings and molecular property vectors (PVs) as two separate modalities and learns to align them in a shared embedding space. The primary contribution is enabling bidirectional generation through a single pre-trained model: given a property vector, SPMM can generate molecules (inverse-QSAR), and given a SMILES string, it can predict all 53 properties simultaneously. The model also transfers to unimodal downstream tasks including MoleculeNet benchmarks and reaction prediction.</p>
<h2 id="bridging-the-gap-between-molecular-structure-and-properties">Bridging the Gap Between Molecular Structure and Properties</h2>
<p>Existing chemical pre-trained models typically learn representations from a single modality (SMILES, graphs, or fingerprints) and fine-tune for specific downstream tasks. While some approaches have attempted multimodal learning by combining SMILES with graph representations or InChI strings, these modalities encode nearly identical structural information, limiting the potential for emergent cross-modal knowledge.</p>
<p>The key gap SPMM addresses is the lack of multimodal pre-training that incorporates genuinely complementary modalities. Prior conditional molecule generation methods could typically control only a small number of properties simultaneously and required retraining when target properties changed. The authors draw on successes in vision-language pre-training (VLP), where aligning image and text modalities has enabled rich bidirectional understanding, and apply similar ideas to molecular structure and property domains.</p>
<h2 id="treating-property-vectors-as-a-language">Treating Property Vectors as a Language</h2>
<p>The core innovation in SPMM is treating a collection of 53 RDKit-computed molecular properties as a &ldquo;language&rdquo; where each property value is analogous to a word token. This design allows the model to attend to individual properties independently rather than treating the entire property vector as a single fixed-length condition.</p>
<h3 id="dual-stream-architecture">Dual-Stream Architecture</h3>
<p>SPMM follows the dual-stream VLP architecture. The model has three components:</p>
<ol>
<li><strong>SMILES Encoder</strong>: 6 BERT-base layers that encode tokenized SMILES (using a 300-subword BPE vocabulary) via self-attention</li>
<li><strong>PV Encoder</strong>: 6 BERT-base layers that encode the 53 normalized property values (each passed through a linear layer) with learnable positional embeddings</li>
<li><strong>Fusion Encoder</strong>: 6 BERT-base layers with cross-attention that combines both modalities, using one modality&rsquo;s features as queries and the other as keys/values</li>
</ol>
<h3 id="pre-training-objectives">Pre-training Objectives</h3>
<p>The model is pre-trained with four complementary losses:</p>
<p><strong>Contrastive Learning</strong> aligns SMILES and PV features in a shared embedding space. For [CLS] token outputs $\mathbf{S}_{cls}$ and $\mathbf{P}_{cls}$:</p>
<p>$$
\text{sim}(\mathbf{S}, \mathbf{P}) = \left(h_{S}(\mathbf{S}_{cls})\right)^{\top} h_{P}(\mathbf{P}_{cls})
$$</p>
<p>The intermodal similarities are computed with a learnable temperature $\tau$:</p>
<p>$$
s_{s2p} = \frac{\exp(\text{sim}(\mathbf{S}, \mathbf{P}) / \tau)}{\sum_{n=1}^{N} \exp(\text{sim}(\mathbf{S}, \mathbf{P}_{n}) / \tau)}
$$</p>
<p>The contrastive loss uses cross-entropy with one-hot labels (1 for same-molecule pairs):</p>
<p>$$
L_{\text{contrastive}} = \frac{1}{2}\left(H(y_{s2p}, s_{s2p}) + H(y_{p2s}, s_{p2s}) + H(y_{s2s}, s_{s2s}) + H(y_{p2p}, s_{p2p})\right)
$$</p>
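<p>A minimal sketch of the four-term contrastive loss, written in plain Python for readability. It is an illustration under stated assumptions, not the authors&rsquo; implementation: the temperature is fixed at a hypothetical 0.07 rather than learned, the candidate sets <code>S_bank</code> and <code>P_bank</code> stand in for whatever pool of features the anchors are compared against, and the positive for anchor $i$ is assumed to sit at index $i$ of that pool.</p>

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def log_softmax(row):
    # Numerically stable log-softmax over one row of logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def directional_loss(anchors, candidates, tau):
    # Cross-entropy H(y, s) with one-hot y: the positive for anchor i is candidate i.
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [dot(anchor, c) / tau for c in candidates]
        total -= log_softmax(logits)[i]
    return total / len(anchors)

def contrastive_loss(S, P, S_bank, P_bank, tau=0.07):
    # S[i], P[i]: projected [CLS] features of molecule i for each modality.
    # S_bank, P_bank: candidate features (assumed aligned by index with S and P).
    terms = [
        directional_loss(S, P_bank, tau),  # s2p
        directional_loss(P, S_bank, tau),  # p2s
        directional_loss(S, S_bank, tau),  # s2s
        directional_loss(P, P_bank, tau),  # p2p
    ]
    return 0.5 * sum(terms)

# Toy check: matched pairs should score a lower loss than mismatched pairs.
S = [[1.0, 0.0], [0.0, 1.0]]
P = [[1.0, 0.0], [0.0, 1.0]]
P_shuffled = [[0.0, 1.0], [1.0, 0.0]]
aligned = contrastive_loss(S, P, S, P)
shuffled = contrastive_loss(S, P_shuffled, S, P_shuffled)
```

<p>In the toy check the batch itself is reused as the candidate pool, which makes the intramodal terms trivial; they only become informative when the candidates come from a separate view of the data.</p>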
<p><strong>Next Word Prediction (NWP)</strong> trains autoregressive SMILES generation conditioned on the PV:</p>
<p>$$
L_{NWP} = \sum_{i=1}^{n} H\left(y_{i}^{NWP}, p^{NWP}(s_{i} \mid s_{0:i-1}, \mathbf{P})\right)
$$</p>
<p><strong>Next Property Prediction (NPP)</strong> applies the same autoregressive concept to property values, using mean-square-error loss:</p>
<p>$$
L_{NPP} = \sum_{i=1}^{n} \left(p_{i} - \hat{p}_{i}(p_{0:i-1}, \mathbf{S})\right)^{2}
$$</p>
<p><strong>SMILES-PV Matching (SPM)</strong> is a binary classification loss predicting whether a SMILES-PV pair originated from the same molecule, trained with hard-negative mining.</p>
<p>The overall pre-training loss combines all four:</p>
<p>$$
L = \widetilde{L}_{\text{contrastive}} + \widetilde{L}_{NWP} + L_{NPP} + L_{SPM}
$$</p>
<p>where tildes indicate the use of momentum teacher distillation to soften one-hot labels, acknowledging that multiple valid SMILES-PV pairings may exist.</p>
<h3 id="random-property-masking">Random Property Masking</h3>
<p>During pre-training, 50% of property values are randomly replaced with a special [UNK] token. This serves three purposes: preventing overfitting to specific properties, augmenting data, and enabling flexible inference where users can specify any subset of the 53 properties as generation conditions. Because each property can be either specified or masked, there are $2^{53}$ possible input configurations, and the model accepts any of them at inference time despite never seeing most of them during training.</p>
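<p>The masking step and its use at inference time can be sketched as follows. This is an illustration only: <code>None</code> stands in for the [UNK] token, each property is assumed to be masked independently with probability 0.5, and the index-to-property mapping in <code>build_condition</code> is hypothetical.</p>

```python
import random

UNK = None       # stand-in for the special [UNK] token
N_PROPS = 53     # length of the property vector

def mask_properties(pv, mask_prob=0.5, rng=random):
    # Keep each value with probability 1 - mask_prob, otherwise replace it with UNK.
    return [v if rng.random() >= mask_prob else UNK for v in pv]

def build_condition(spec, n_props=N_PROPS):
    # Inference-time input: only the properties named in `spec` are set,
    # every other slot is UNK. Keys are property indices (hypothetical mapping).
    pv = [UNK] * n_props
    for idx, value in spec.items():
        pv[idx] = value
    return pv

# Training-time masking of a dummy property vector.
full_pv = [float(i) for i in range(N_PROPS)]
masked_pv = mask_properties(full_pv, rng=random.Random(0))

# Inference-time condition that controls a single property (index 3 is arbitrary).
condition = build_condition({3: 150.0})
```

<p>Passing an empty <code>spec</code> gives the all-[UNK] vector, which corresponds to unconditional generation.</p>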
<h2 id="experiments-across-bidirectional-and-unimodal-tasks">Experiments Across Bidirectional and Unimodal Tasks</h2>
<h3 id="pv-to-smiles-generation-conditional-molecule-design">PV-to-SMILES Generation (Conditional Molecule Design)</h3>
<p>The authors evaluate SPMM on multiple generation scenarios using 1000 unseen PubChem PVs:</p>
<table>
  <thead>
      <tr>
          <th>Sampling</th>
          <th>Input PV</th>
          <th>Validity</th>
          <th>Uniqueness</th>
          <th>Novelty</th>
          <th>Norm. RMSE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Deterministic</td>
          <td>1000 unseen PVs</td>
          <td>0.995</td>
          <td>0.999</td>
          <td>0.961</td>
          <td>0.216</td>
      </tr>
      <tr>
          <td>Stochastic</td>
          <td>Full PV (molecule 1)</td>
          <td>0.974</td>
          <td>0.905</td>
          <td>0.998</td>
          <td>0.185</td>
      </tr>
      <tr>
          <td>Stochastic</td>
          <td>Molar mass = 150</td>
          <td>0.974</td>
          <td>0.945</td>
          <td>0.872</td>
          <td>0.192</td>
      </tr>
      <tr>
          <td>Stochastic</td>
          <td>4 properties controlled</td>
          <td>0.998</td>
          <td>0.981</td>
          <td>0.952</td>
          <td>0.257</td>
      </tr>
      <tr>
          <td>Stochastic</td>
          <td>No control (all [UNK])</td>
          <td>0.971</td>
          <td>0.991</td>
          <td>0.950</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>The normalized RMSE of 0.216 across 53 properties indicates that generated molecules closely match the input property conditions. The model can also perform unconditional generation (all properties masked) where outputs follow the pre-training distribution. The authors report that SPMM outperforms benchmark models including MolGAN, GraphVAE, and scaffold-based graph generative models in both conditional and unconditional settings (Supplementary Table 1).</p>
<h3 id="smiles-to-pv-generation-multi-property-prediction">SMILES-to-PV Generation (Multi-Property Prediction)</h3>
<p>When given 1000 unseen ZINC15 molecules, SPMM predicts all 53 properties autoregressively with a mean $r^{2}$ of 0.924 across all properties.</p>
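<p>For reference, a mean $r^{2}$ across properties can be computed as sketched below. This assumes $r^{2}$ denotes the coefficient of determination and that the mean is an unweighted average over the 53 property columns; the source does not spell out either detail, and the function names are illustrative.</p>

```python
def r_squared(y_true, y_pred):
    # Coefficient of determination: 1 - SS_res / SS_tot.
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

def mean_r_squared(true_columns, pred_columns):
    # One score per property column, then the unweighted mean across properties.
    scores = [r_squared(t, p) for t, p in zip(true_columns, pred_columns)]
    return sum(scores) / len(scores)

# Two toy properties measured on three molecules.
true_columns = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
pred_columns = [[1.0, 2.0, 3.0], [20.0, 20.0, 20.0]]
score = mean_r_squared(true_columns, pred_columns)
```

<p>Scoring each property separately before averaging keeps properties with large numeric ranges from dominating the result.</p>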
<h3 id="moleculenet-benchmarks">MoleculeNet Benchmarks</h3>
<p>Using only the SMILES encoder (6 BERT layers), SPMM achieves best or competitive performance on 9 <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> tasks:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Metric</th>
          <th>SPMM</th>
          <th>Best Baseline</th>
          <th>Baseline Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>0.817</td>
          <td>0.798</td>
          <td>ChemRL-GEM</td>
      </tr>
      <tr>
          <td>LIPO</td>
          <td>RMSE</td>
          <td>0.681</td>
          <td>0.660</td>
          <td>ChemRL-GEM</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE</td>
          <td>1.868</td>
          <td>1.877</td>
          <td>ChemRL-GEM</td>
      </tr>
      <tr>
          <td>BACE (reg)</td>
          <td>RMSE</td>
          <td>1.041</td>
          <td>1.047</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MolFormer</a></td>
      </tr>
      <tr>
          <td>Clearance</td>
          <td>RMSE</td>
          <td>42.607</td>
          <td>43.175</td>
          <td>MolFormer</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>AUROC</td>
          <td>75.1%</td>
          <td>73.6%</td>
          <td>MolFormer</td>
      </tr>
      <tr>
          <td>BACE (cls)</td>
          <td>AUROC</td>
          <td>84.4%</td>
          <td>86.3%</td>
          <td>MolFormer</td>
      </tr>
      <tr>
          <td>ClinTox</td>
          <td>AUROC</td>
          <td>92.7%</td>
          <td>91.2%</td>
          <td>MolFormer</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td>AUROC</td>
          <td>66.9%</td>
          <td>67.2%</td>
          <td>ChemRL-GEM</td>
      </tr>
  </tbody>
</table>
<p>SPMM achieved best performance on 5 of 9 tasks, with notable gains on BBBP (75.1% vs. 73.6%) and ClinTox (92.7% vs. 91.2%). Without pre-training, all scores dropped substantially.</p>
<h3 id="dili-classification">DILI Classification</h3>
<p>On Drug-Induced Liver Injury prediction, SPMM achieved 92.6% AUROC, outperforming the 5-ensemble model of Ai et al. (90.4% AUROC) while using a single model.</p>
<h3 id="reaction-prediction">Reaction Prediction</h3>
<p>On USPTO-480k forward reaction prediction, SPMM achieved 91.5% top-1 accuracy, the highest among all models tested (including <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">Chemformer</a> at 91.3%). On USPTO-50k retro-reaction prediction, SPMM reached 53.4% top-1 accuracy, second only to Chemformer (54.3%) among string-based models.</p>
<h2 id="bidirectional-generation-from-a-single-pre-trained-model">Bidirectional Generation From a Single Pre-trained Model</h2>
<p>SPMM demonstrates that multimodal pre-training with genuinely complementary modalities (structure and properties, rather than structurally redundant representations) enables a single foundation model to handle both generation directions and downstream unimodal tasks. Key findings include:</p>
<ol>
<li><strong>Flexible conditional generation</strong>: The [UNK] masking strategy allows controlling any subset of 53 properties at inference time without retraining, a capability not demonstrated by prior methods.</li>
<li><strong>Interpretable cross-attention</strong>: Attention visualizations show that the model learns chemically meaningful structure-property relationships (e.g., hydrogen bonding properties attend to oxygen and nitrogen atoms; ring count properties attend to ring tokens).</li>
<li><strong>Competitive unimodal transfer</strong>: Despite using only 6 BERT layers and 50M pre-training molecules (smaller than <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>&rsquo;s 77M or Chemformer&rsquo;s 100M), the SMILES encoder alone achieves the best result on 5 of 9 MoleculeNet tasks and the highest forward reaction prediction accuracy among tested models.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>SMILES representation constraints</strong>: Implicit connectivity information in SMILES means small structural changes can cause drastic string changes. Graph representations could be a complementary alternative.</li>
<li><strong>Stereochemistry blindness</strong>: All 53 RDKit properties used are invariant to stereochemistry, meaning different stereoisomers produce identical PVs. The contrastive loss then forces their SMILES encoder outputs to converge, which the authors identify as the primary factor limiting MoleculeNet performance on stereo-sensitive tasks.</li>
<li><strong>No wet-lab validation</strong>: Generated molecules and predicted properties are not experimentally verified.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>PubChem</td>
          <td>50M molecules</td>
          <td>SMILES + 53 RDKit properties</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet (9 tasks)</td>
          <td>642-4200 per task</td>
          <td>Scaffold split via DeepChem (8:1:1)</td>
      </tr>
      <tr>
          <td>DILI classification</td>
          <td>Ai et al. dataset</td>
          <td>Not specified</td>
          <td>Following published preparation</td>
      </tr>
      <tr>
          <td>Forward reaction</td>
          <td>USPTO-480k</td>
          <td>479,035 pairs</td>
          <td>Reactant-product pairs</td>
      </tr>
      <tr>
          <td>Retro reaction</td>
          <td>USPTO-50k</td>
          <td>50,037 pairs</td>
          <td>Product-reactant pairs, no reaction types used</td>
      </tr>
      <tr>
          <td>SMILES-to-PV test</td>
          <td>ZINC15</td>
          <td>1000 molecules</td>
          <td>Not in pre-training set</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: BPE with 300-subword dictionary</li>
<li><strong>Property masking</strong>: 50% random replacement with [UNK] during pre-training</li>
<li><strong>Momentum distillation</strong>: EMA parameter $\lambda = 0.995$, soft-label mixing $\alpha$ linearly warmed from 0 to 0.4 over first epoch</li>
<li><strong>Contrastive queue</strong>: Size $k = 24{,}576$ for storing recent SMILES and PV instances</li>
<li><strong>Beam search</strong>: $k = 2$ for PV-to-SMILES generation</li>
<li><strong>SMILES augmentation</strong>: Random non-canonical augmentation with probability 0.5 for reaction tasks</li>
</ul>
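<p>The momentum distillation settings listed above can be illustrated with a short sketch. It assumes the usual momentum-teacher recipe: the teacher is an exponential moving average of the student, and the training target is a convex mix of the one-hot label and the teacher&rsquo;s distribution. The exact mixing rule and all function names are assumptions, not taken from the paper&rsquo;s code.</p>

```python
def ema_update(teacher, student, lam=0.995):
    # teacher = lam * teacher + (1 - lam) * student, parameter by parameter.
    return [lam * t + (1.0 - lam) * s for t, s in zip(teacher, student)]

def alpha_at(step, steps_per_epoch, alpha_max=0.4):
    # Linear warm-up from 0 to alpha_max over the first epoch, then constant.
    return alpha_max * min(1.0, step / steps_per_epoch)

def soft_targets(one_hot, teacher_probs, alpha):
    # Blend the hard label with the teacher's predicted distribution.
    return [(1.0 - alpha) * y + alpha * q for y, q in zip(one_hot, teacher_probs)]

# Toy usage: flat parameter lists and a three-way classification target.
teacher = ema_update([1.0, 2.0], [0.0, 0.0])
alpha = alpha_at(step=50, steps_per_epoch=100)
target = soft_targets([0.0, 1.0, 0.0], [0.2, 0.5, 0.3], alpha)
```

<p>With $\lambda = 0.995$ the teacher changes slowly, so its soft labels are more stable than the student&rsquo;s own predictions early in training.</p>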
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: 6 BERT-base encoder layers each for SMILES encoder, PV encoder, and fusion encoder (18 total layers)</li>
<li><strong>Vocabulary</strong>: 300 BPE subwords for SMILES; 53 property tokens for PV</li>
<li><strong>Pre-trained weights</strong>: Available via GitHub</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PV-to-SMILES (deterministic)</td>
          <td>Validity</td>
          <td>99.5%</td>
          <td>1000 unseen PubChem PVs</td>
      </tr>
      <tr>
          <td>PV-to-SMILES (deterministic)</td>
          <td>Normalized RMSE</td>
          <td>0.216</td>
          <td>Across 53 properties</td>
      </tr>
      <tr>
          <td>SMILES-to-PV</td>
          <td>Mean $r^{2}$</td>
          <td>0.924</td>
          <td>1000 ZINC15 molecules</td>
      </tr>
      <tr>
          <td>Forward reaction (USPTO-480k)</td>
          <td>Top-1 accuracy</td>
          <td>91.5%</td>
          <td>Best among all tested models</td>
      </tr>
      <tr>
          <td>Retro reaction (USPTO-50k)</td>
          <td>Top-1 accuracy</td>
          <td>53.4%</td>
          <td>Second-best string-based</td>
      </tr>
      <tr>
          <td>DILI classification</td>
          <td>AUROC</td>
          <td>92.6%</td>
          <td>Single model vs. 5-ensemble</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Pre-training</strong>: 8 NVIDIA A100 GPUs, approximately 52,000 batch iterations, roughly 12 hours</li>
<li><strong>Batch size</strong>: 96</li>
<li><strong>Optimizer</strong>: AdamW with weight decay 0.02</li>
<li><strong>Learning rate</strong>: Warmed up to $10^{-4}$, cosine decay to $10^{-5}$</li>
</ul>
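<p>The learning rate schedule can be sketched as a function of the training step. The peak ($10^{-4}$), floor ($10^{-5}$), and roughly 52,000 iterations come from the settings above; the warm-up length of 1,000 steps, the linear warm-up shape, and the starting value of zero are assumptions made only for illustration.</p>

```python
import math

def learning_rate(step, total_steps=52_000, warmup_steps=1_000,
                  peak=1e-4, floor=1e-5):
    # Linear warm-up from 0 to the peak (warm-up length is an assumed value).
    if warmup_steps > step:
        return peak * step / warmup_steps
    # Cosine decay from the peak down to the floor over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return floor + 0.5 * (peak - floor) * (1.0 + math.cos(math.pi * progress))

# Sample the schedule at a few points for inspection.
schedule = [learning_rate(s) for s in (0, 500, 1_000, 26_500, 52_000)]
```

<p>The midpoint of the decay phase lands halfway between the peak and the floor, which is a quick sanity check on any cosine schedule.</p>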
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jinhojsk515/SPMM">SPMM Source Code</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation with experimental scripts</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10567599">SPMM Zenodo Archive</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Archived version for reproducibility</td>
      </tr>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td>Dataset</td>
          <td>Public domain</td>
          <td>50M molecules for pre-training</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Dataset</td>
          <td>Varies</td>
          <td>Benchmark datasets via DeepChem</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chang, J., &amp; Ye, J. C. (2024). Bidirectional generation of structure and properties through a single molecular foundation model. <em>Nature Communications</em>, 15, 2323. <a href="https://doi.org/10.1038/s41467-024-46440-3">https://doi.org/10.1038/s41467-024-46440-3</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chang2024bidirectional,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Bidirectional generation of structure and properties through a single molecular foundation model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chang, Jinho and Ye, Jong Chul}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{2323}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-024-46440-3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SPE: Data-Driven SMILES Substructure Tokenization</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/</guid><description>SMILES Pair Encoding adapts byte pair encoding to learn chemically meaningful substructure tokens from SMILES, improving generation and QSAR prediction.</description><content:encoded><![CDATA[<h2 id="a-data-driven-tokenization-method-for-chemical-deep-learning">A Data-Driven Tokenization Method for Chemical Deep Learning</h2>
<p>This is a <strong>Method</strong> paper that introduces SMILES Pair Encoding (SPE), a tokenization algorithm adapted from <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">byte pair encoding (BPE)</a> in natural language processing. The primary contribution is a data-driven approach that learns a vocabulary of high-frequency SMILES substrings from a large chemical dataset and then uses that vocabulary to tokenize SMILES for downstream deep learning tasks. The authors provide an open-source Python package (SmilesPE) and demonstrate improvements on both molecular generation and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a> prediction benchmarks.</p>
<h2 id="limitations-of-atom-level-smiles-tokenization">Limitations of Atom-Level SMILES Tokenization</h2>
<p>SMILES-based deep learning models require tokenization to convert molecular strings into sequences of discrete units. The standard approaches have well-known drawbacks:</p>
<ul>
<li><strong>Character-level tokenization</strong> breaks SMILES character by character, splitting chemically meaningful multi-character atoms. For example, <code>[C@@H]</code> becomes six separate tokens (<code>[</code>, <code>C</code>, <code>@</code>, <code>@</code>, <code>H</code>, <code>]</code>), losing the stereochemistry information of a single carbon.</li>
<li><strong>Atom-level tokenization</strong> addresses some of these issues by treating multi-character element symbols (Cl, Br) and bracketed atoms ([nH], [O-]) as single tokens. However, these tokens still encode only individual atoms, not substructures.</li>
<li><strong>k-mer tokenization</strong> (sequences of k consecutive overlapping characters) captures some connectivity information but suffers from the out-of-vocabulary problem: the model cannot represent k-mers not seen during training.</li>
</ul>
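<p>As a rough illustration of the first two schemes, the sketch below contrasts character-level and atom-level splitting. The regular expression is a simplified assumption (it only special-cases bracketed atoms, <code>Cl</code>, <code>Br</code>, and two-digit ring closures) and is not the pattern used by the authors; the function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Simplified atom-level pattern (an assumption for illustration only):
# bracketed atoms first, then two-letter halogens, then two-digit ring
# closures, then any single remaining character.
ATOM_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def char_tokenize(smiles):
    """Split a SMILES string into individual characters."""
    return list(smiles)


def atom_tokenize(smiles):
    """Split a SMILES string into atom-level tokens."""
    return ATOM_PATTERN.findall(smiles)


smiles = "C[C@@H](Cl)Br"
print(char_tokenize(smiles))  # the bracketed atom and both halogens are split apart
print(atom_tokenize(smiles))  # the bracketed atom and halogens stay whole
</code></pre></div>
<p>Even in this small case the atom-level output is shorter, but each token still describes a single atom or symbol rather than a substructure.</p>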
<p>All three approaches produce relatively long input sequences (mean ~40 tokens per molecule on ChEMBL at the atom level), which increases computational cost for sequential architectures like RNNs and exacerbates long-range dependency issues.</p>
<h2 id="core-innovation-adapting-byte-pair-encoding-for-smiles">Core Innovation: Adapting Byte Pair Encoding for SMILES</h2>
<p>SPE adapts the byte pair encoding algorithm, originally developed for data compression and later adopted for subword tokenization in NLP, to the domain of chemical strings. The algorithm has two phases:</p>
<p><strong>Vocabulary training:</strong></p>
<ol>
<li>Tokenize SMILES from a large dataset (ChEMBL) at the atom level</li>
<li>Initialize the vocabulary with all unique atom-level tokens</li>
<li>Iteratively count the frequency of all adjacent token pairs, merge the most frequent pair into a new token, and add it to the vocabulary</li>
<li>Stop when either the maximum vocabulary size (MVS) or a minimum frequency threshold (FT) is reached</li>
</ol>
<p><strong>Tokenization:</strong> Given a trained SPE vocabulary, a new SMILES string is first tokenized at the atom level, then token pairs are iteratively merged according to their frequency rank in the vocabulary until no further merges are possible.</p>
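<p>The two phases can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the SmilesPE implementation: the atom-level regex is simplified, the function names are hypothetical, and tokenization applies the learned merges in the order they were learned as a stand-in for rank-based merging.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re
from collections import Counter

# Simplified atom-level pattern (an assumption, not the exact regex from the paper).
ATOM_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def atom_tokenize(smiles):
    return ATOM_PATTERN.findall(smiles)


def merge_pair(tokens, pair):
    """Join every adjacent occurrence of `pair` into a single token."""
    merged = []
    skip = False
    for left, right in zip(tokens, tokens[1:] + [None]):
        if skip:
            skip = False  # this token was consumed by the previous merge
        elif (left, right) == pair:
            merged.append(left + right)
            skip = True
        else:
            merged.append(left)
    return merged


def train_spe(smiles_list, max_vocab_size, min_frequency):
    """Learn an ordered list of merges from atom-level tokenized SMILES."""
    corpus = [atom_tokenize(s) for s in smiles_list]
    vocab = {tok for seq in corpus for tok in seq}
    merges = []
    # Each merge adds at most one token, so this bounds the vocabulary size.
    for _ in range(max(0, max_vocab_size - len(vocab))):
        pair_counts = Counter()
        for seq in corpus:
            pair_counts.update(zip(seq, seq[1:]))
        if not pair_counts:
            break
        pair, count = pair_counts.most_common(1)[0]
        if min(count, min_frequency) != min_frequency:
            break  # the most frequent pair is rarer than the threshold
        merges.append(pair)
        vocab.add(pair[0] + pair[1])
        corpus = [merge_pair(seq, pair) for seq in corpus]
    return vocab, merges


def spe_tokenize(smiles, merges):
    """Tokenize at the atom level, then apply the learned merges in order."""
    tokens = atom_tokenize(smiles)
    for pair in merges:
        tokens = merge_pair(tokens, pair)
    return tokens


# Toy corpus; the paper trains on ChEMBL with MVS=30,000 and FT=2,000.
vocab, merges = train_spe(["CCO", "CCN", "CCO"], max_vocab_size=10, min_frequency=2)
print(merges)
print(spe_tokenize("CCOCC", merges))
</code></pre></div>
<p>On the toy corpus, training stops on the frequency threshold long before the vocabulary cap is reached, which mirrors how the reported vocabulary ends up far smaller than the configured maximum.</p>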
<p>The key hyperparameters are MVS and FT. In the reported experiments, MVS was set to 30,000 and FT was set to 2,000. The vocabulary was trained on ~3.4 million SMILES (both canonical and one non-canonical variant per molecule) from ChEMBL25. The resulting vocabulary contained 3,002 unique SMILES substrings with lengths ranging from 1 to 22 atom-level characters.</p>
<p>The trained SPE vocabulary produces tokens that are human-readable and correspond to chemically meaningful substructures and functional groups. SPE tokenization reduces the mean sequence length from approximately 40 tokens (atom-level) to approximately 6 tokens on ChEMBL, a roughly 6-7x compression. This shorter representation directly reduces computational cost for RNN-based and other sequential models.</p>
<p>The algorithm is also compatible with other text-based molecular representations such as <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, since these share atom-level character structures that can serve as the starting point for pair merging.</p>
<h2 id="molecular-generation-and-qsar-prediction-experiments">Molecular Generation and QSAR Prediction Experiments</h2>
<h3 id="molecular-generation">Molecular Generation</h3>
<p>The authors trained AWD-LSTM language models with SPE and atom-level tokenization on 9 million SMILES (1 canonical + 5 non-canonical per compound from ChEMBL25). Each model sampled 1 million SMILES for evaluation. The AWD-LSTM architecture used an embedding size of 400, three LSTM layers with 1,152 hidden units each, and various dropout settings (embedding: 0.1, input: 0.6, weight: 0.5, hidden: 0.2). Models were trained for 10 epochs with a base learning rate of 0.008 using one-cycle scheduling.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SPE</th>
          <th>Atom-level</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>0.941</td>
          <td>0.970</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>0.994</td>
          <td>0.992</td>
      </tr>
      <tr>
          <td>Novelty</td>
          <td>0.983</td>
          <td>0.978</td>
      </tr>
      <tr>
          <td>Internal diversity</td>
          <td>0.897</td>
          <td>0.886</td>
      </tr>
      <tr>
          <td>Nearest neighbor similarity</td>
          <td>0.391</td>
          <td>0.386</td>
      </tr>
  </tbody>
</table>
<p>The SPE model generated a more diverse population of novel molecules at the cost of slightly lower validity (94.1% vs. 97.0%). Internal diversity is defined as:</p>
<p>$$
\text{Internal diversity} = 1 - \frac{1}{|G|^2} \sum_{(x_1, x_2) \in G \times G} T(x_1, x_2)
$$</p>
<p>where $T(x_1, x_2)$ is the Tanimoto similarity between molecules $x_1$ and $x_2$ using 1024-bit ECFP6 fingerprints. Nearest neighbor similarity (SNN) measures how well the generated set resembles the reference set:</p>
<p>$$
\text{SNN} = \frac{1}{|G|} \sum_{x_G \in G} \max_{x_R \in R} T(x_G, x_R)
$$</p>
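<p>Both metrics are straightforward to compute once fingerprints are available. The sketch below reads internal diversity as one minus the mean Tanimoto similarity over all ordered pairs in $G \times G$, and uses plain Python sets of on-bits as stand-ins for ECFP6 fingerprints; the toy fingerprints and function names are assumptions for illustration.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from itertools import product


def tanimoto(a, b):
    """Tanimoto similarity between two fingerprints stored as sets of on-bits."""
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 0.0


def internal_diversity(generated):
    """One minus the mean Tanimoto similarity over all ordered pairs in G x G."""
    pairs = list(product(generated, repeat=2))
    return 1.0 - sum(tanimoto(x1, x2) for x1, x2 in pairs) / len(pairs)


def nearest_neighbor_similarity(generated, reference):
    """Mean, over generated molecules, of the best similarity to the reference set."""
    best = [max(tanimoto(g, r) for r in reference) for g in generated]
    return sum(best) / len(best)


# Toy fingerprints (hypothetical); real ones would be 1024-bit ECFP6 vectors.
generated = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
reference = [{1, 2, 3}, {8, 9}]
print(internal_diversity(generated))
print(nearest_neighbor_similarity(generated, reference))
</code></pre></div>
<p>For sets of a million sampled molecules the all-pairs loop is impractical, so a real evaluation would rely on vectorized fingerprint operations instead.</p>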
<p>Substructure coverage analysis showed both models recovered the same top-1000 BRICS fragments (100% coverage), but SPE consistently outperformed atom-level tokenization on top-5000 coverage across all four substructure types: BRICS fragments (0.997 vs. 0.987), functional groups (0.688 vs. 0.659), scaffolds (0.872 vs. 0.825), and ring systems (0.781 vs. 0.761).</p>
<h3 id="qsar-prediction">QSAR Prediction</h3>
<p>QSAR models were built using the <a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT transfer learning framework</a>, which pre-trains a language model on ChEMBL and then fine-tunes it for specific prediction tasks. The evaluation used 24 regression benchmarks (pIC50 values) from Cortes-Ciriano et al., with dataset sizes ranging from 199 molecules (alpha-2a adrenergic receptor) to 5,010 molecules (<a href="https://en.wikipedia.org/wiki/KCNH2">hERG</a>). Models were evaluated on 10 random 80:10:10 splits using RMSE, R-squared, and MAE. Random forest models trained on 1024-bit ECFP6 fingerprints served as baselines.</p>
<p><a href="https://en.wikipedia.org/wiki/Effect_size">Cohen&rsquo;s d</a> effect sizes were computed to quantify performance differences between tokenization methods. SPE performed comparably to or better than atom-level tokenization on 23 of 24 datasets. Notable results with medium or large effect sizes favoring SPE included <a href="https://en.wikipedia.org/wiki/Cannabinoid_receptor_1">cannabinoid CB1 receptor</a> (large effect), A2a adrenergic receptor, LCK, estrogen receptor, and <a href="https://en.wikipedia.org/wiki/Aurora_kinase_A">Aurora-A kinase</a> (all medium effects). Against k-mer tokenization, SPE matched or outperformed it on 22 of 24 datasets.</p>
<p>Cohen&rsquo;s d is defined as:</p>
<p>$$
\text{Cohen&rsquo;s } d = \frac{\bar{x}_1 - \bar{x}_2}{\sqrt{(\text{SD}_1^2 + \text{SD}_2^2) / 2}}
$$</p>
<p>where $\bar{x}_1, \bar{x}_2$ are the group means and $\text{SD}_1, \text{SD}_2$ are the standard deviations. Thresholds of 0.2 (small), 0.5 (medium), and 0.8 (large) were used following standard recommendations.</p>
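<p>A minimal sketch of this effect-size calculation follows. It assumes sample standard deviations (the summary does not say which estimator was used), the RMSE values are hypothetical, and the &ldquo;negligible&rdquo; label for values below 0.2 is an added convention.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from bisect import bisect_right
from statistics import mean, stdev


def cohens_d(group1, group2):
    """Difference in means divided by the root-mean-square of the two SDs."""
    pooled_sd = ((stdev(group1) ** 2 + stdev(group2) ** 2) / 2) ** 0.5
    return (mean(group1) - mean(group2)) / pooled_sd


def effect_label(d):
    """Bucket |d| using the 0.2 / 0.5 / 0.8 thresholds."""
    labels = ["negligible", "small", "medium", "large"]
    return labels[bisect_right([0.2, 0.5, 0.8], abs(d))]


# Hypothetical RMSE values over five splits for two tokenization methods.
rmse_spe = [0.70, 0.72, 0.68, 0.71, 0.69]
rmse_atom = [0.74, 0.76, 0.73, 0.75, 0.72]
d = cohens_d(rmse_spe, rmse_atom)
print(round(d, 2), effect_label(d))
</code></pre></div>
<p>Because RMSE is an error metric, a negative $d$ in this orientation means the first group has the lower error.</p>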
<p>SMILES-based deep learning models generally performed on par with or better than the RF baseline, with particularly strong advantages on the four largest datasets (<a href="https://en.wikipedia.org/wiki/Cyclooxygenase-2">COX-2</a>, <a href="https://en.wikipedia.org/wiki/Acetylcholinesterase">acetylcholinesterase</a>, erbB1, and hERG).</p>
<p>In addition to performance gains, SPE-based models trained on average 5 times faster than atom-level models due to the shorter input sequences.</p>
<h2 id="results-summary-and-future-directions">Results Summary and Future Directions</h2>
<p>The main findings of this study are:</p>
<ol>
<li>
<p><strong>SPE produces chemically meaningful tokens.</strong> The learned vocabulary contains human-readable SMILES substrings that correspond to common substructures and functional groups, making model interpretations more accessible.</p>
</li>
<li>
<p><strong>SPE compresses input sequences by ~6-7x.</strong> Mean token sequence length drops from ~40 (atom-level) to ~6 (SPE) on ChEMBL, yielding a ~5x training speedup.</p>
</li>
<li>
<p><strong>SPE improves molecular generation diversity.</strong> The SPE-based generative model produces molecules with higher novelty (98.3% vs. 97.8%), internal diversity (0.897 vs. 0.886), and substructure coverage, at the cost of slightly lower validity (94.1% vs. 97.0%).</p>
</li>
<li>
<p><strong>SPE matches or outperforms atom-level and k-mer tokenization on QSAR prediction.</strong> Across 24 benchmarks, SPE showed comparable or better performance in 23/24 comparisons against atom-level and 22/24 against k-mer tokenization.</p>
</li>
</ol>
<p><strong>Limitations acknowledged by the authors:</strong></p>
<ul>
<li>The SPE vocabulary is trained on a specific dataset (ChEMBL25) and may not optimally represent chemical spaces that differ significantly from drug-like compounds.</li>
<li>The validity rate for molecular generation is slightly lower than atom-level tokenization (94.1% vs. 97.0%), since longer substructure tokens can introduce invalid fragments.</li>
<li>The k-mer baseline suffers from an out-of-vocabulary problem, which the authors handle by replacing unseen 4-mers with <code>[UNK]</code> tokens. This is a limitation of the comparison rather than of SPE itself.</li>
</ul>
<p><strong>Future directions:</strong> The authors suggest SPE could serve as a general tokenization method for SMILES-based deep learning, applicable to any task where SMILES strings are used as input (<a href="/notes/chemistry/molecular-design/generation/">generation</a>, <a href="/notes/chemistry/molecular-design/property-prediction/">property prediction</a>, <a href="/notes/chemistry/molecular-design/reaction-prediction/">reaction prediction</a>, retrosynthesis). The algorithm can also be applied to DeepSMILES and SELFIES representations without modification.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SPE vocabulary training</td>
          <td>ChEMBL25</td>
          <td>~3.4M SMILES</td>
          <td>1 canonical + 1 non-canonical per molecule</td>
      </tr>
      <tr>
          <td>Language model training</td>
          <td>ChEMBL25 augmented</td>
          <td>~9M SMILES</td>
          <td>1 canonical + 5 non-canonical per molecule</td>
      </tr>
      <tr>
          <td>Molecular generation evaluation</td>
          <td>Sampled from model</td>
          <td>1M SMILES per model</td>
          <td>Validated with RDKit</td>
      </tr>
      <tr>
          <td>QSAR benchmarks</td>
          <td>Cortes-Ciriano et al.</td>
          <td>24 datasets, 199 to 5,010 molecules</td>
          <td>pIC50 regression tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>SPE vocabulary training: iterative pair merging with MVS=30,000 and FT=2,000</li>
<li>Language model: AWD-LSTM with embedding size 400, 3 LSTM layers with 1,152 hidden units</li>
<li>Dropout: embedding=0.1, input=0.6, weight=0.5, hidden=0.2</li>
<li>Training: 10 epochs, base learning rate 0.008, one-cycle policy</li>
<li>QSAR: MolPMoFiT transfer learning with 25x training augmentation and 15x validation augmentation</li>
<li>Test time augmentation: average of canonical + 4 augmented SMILES predictions</li>
<li>RF baseline: 500 trees, 1024-bit ECFP6, default scikit-learn parameters</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>AWD-LSTM architecture from Merity et al. (2018)</li>
<li>MolPMoFiT framework from Li and Fourches (2020) for transfer learning QSAR</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity, Uniqueness, Novelty</td>
          <td>Generation</td>
          <td>Basic quality metrics</td>
      </tr>
      <tr>
          <td>Internal diversity</td>
          <td>Generation</td>
          <td>1 - mean pairwise Tanimoto (ECFP6)</td>
      </tr>
      <tr>
          <td>Nearest neighbor similarity</td>
          <td>Generation</td>
          <td>Mean max Tanimoto to reference set</td>
      </tr>
      <tr>
          <td>Substructure coverage</td>
          <td>Generation</td>
          <td>BRICS, functional groups, scaffolds, ring systems</td>
      </tr>
      <tr>
          <td>RMSE, R-squared, MAE</td>
          <td>QSAR regression</td>
          <td>10 random 80:10:10 splits</td>
      </tr>
      <tr>
          <td>Cohen&rsquo;s d</td>
          <td>QSAR comparison</td>
          <td>Effect size between tokenization methods</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not explicitly specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/SmilesPE">SmilesPE</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>SPE tokenization Python package</td>
      </tr>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/MolPMoFiT">MolPMoFiT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Transfer learning QSAR framework</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, X., &amp; Fourches, D. (2021). SMILES Pair Encoding: A Data-Driven Substructure Tokenization Algorithm for Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 61(4), 1560-1569. <a href="https://doi.org/10.1021/acs.jcim.0c01127">https://doi.org/10.1021/acs.jcim.0c01127</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2021smiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES Pair Encoding: A Data-Driven Substructure Tokenization Algorithm for Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Xinhao and Fourches, Denis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{61}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1560--1569}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.0c01127}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Smirk: Complete Tokenization for Molecular Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smirk-tokenization-molecular-models/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smirk-tokenization-molecular-models/</guid><description>Smirk tokenizer achieves full OpenSMILES coverage with 165 tokens by decomposing bracketed atoms into glyphs, validated via n-gram proxy models.</description><content:encoded><![CDATA[<h2 id="a-method-for-complete-chemical-tokenization">A Method for Complete Chemical Tokenization</h2>
<p>This is a <strong>Method</strong> paper that introduces two new tokenizers for molecular foundation models: Smirk and Smirk-GPE. The primary contribution is a tokenization scheme that achieves complete coverage of the OpenSMILES specification using only 165 tokens, addressing the vocabulary gaps present in existing atom-wise tokenizers. The paper also proposes n-gram language models as low-cost proxy evaluators for tokenizer quality and validates these proxies against 18 transformer-based models across multiple benchmarks.</p>
<h2 id="vocabulary-gaps-in-molecular-tokenization">Vocabulary Gaps in Molecular Tokenization</h2>
<p>Molecular foundation models overwhelmingly use &ldquo;atom-wise&rdquo; tokenization, where SMILES strings are split at atom boundaries using a regular expression first proposed by Schwaller et al. A key pattern in this regex treats all &ldquo;bracketed atoms&rdquo; (e.g., <code>[C@@H]</code>, <code>[18F]</code>, <code>[Au+]</code>) as single, irreducible tokens. Since bracketed atoms encode isotopes, chirality, charge, hydrogen count, and element identity, the number of possible permutations under the OpenSMILES specification exceeds 28 trillion. In practice, existing atom-wise tokenizers maintain vocabularies of fewer than 3,000 tokens, leaving large portions of chemical space unrepresentable.</p>
<p>This gap has real consequences. Many chemistry-specific tokenizers emit the unknown token <code>[UNK]</code> at non-negligible frequencies, particularly on datasets with diverse elements and stereochemistry. For example, <a href="/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/">SPE and APE</a> tokenizers produce <code>[UNK]</code> for roughly 19% of tokens on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> and approximately 50% on the tmQM transition metal complex dataset. Even models like <a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a> and <a href="/notes/chemistry/molecular-design/reaction-prediction/reactiont5-pretrained-limited-reaction-data/">ReactionT5</a> lack tokens for elements such as copper, ruthenium, gold, and uranium.</p>
<p>The authors also note a subtler issue: some open-vocabulary tokenizers (e.g., <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa&rsquo;s</a> BPE) conflate chemically distinct entities. The same <code>Sc</code> token may represent both a sulfur-carbon bond (in organic SMILES) and the element scandium (in <code>[Sc]</code>), creating ambiguity in downstream analysis.</p>
<h2 id="smirk-glyph-level-decomposition-of-smiles">Smirk: Glyph-Level Decomposition of SMILES</h2>
<p>The core insight behind Smirk is to fully decompose bracketed atoms into their constituent &ldquo;glyphs,&rdquo; the primitive symbols defined by the OpenSMILES specification (element symbols, chirality markers, charges, isotope numbers, hydrogen counts, and brackets themselves). This transforms tokenization from a word-level scheme (one token per bracketed atom) to a character-level scheme over chemically meaningful glyphs.</p>
<p>Smirk uses a two-stage tokenization process:</p>
<ol>
<li><strong>Atom decomposition</strong>: Split a SMILES string into atom-level units using a regex (e.g., <code>OC[C@@H][OH]</code> becomes <code>O C [C@@H] [OH]</code>).</li>
<li><strong>Glyph decomposition</strong>: Further split each unit into its constituent glyphs (e.g., <code>[C@@H]</code> becomes <code>[ C @@ H ]</code>).</li>
</ol>
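<p>The two stages above can be approximated with two regular expressions. This is a Python sketch for illustration only: the patterns are simplified assumptions that do not cover the full OpenSMILES grammar, the function name is hypothetical, and none of this is the actual Smirk implementation.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import re

# Stage 1: atom-level units (simplified assumption, not the Smirk regex).
ATOM_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")

# Stage 2: glyphs inside a bracketed atom: brackets, chirality marks, single
# digits, charge signs, element symbols, then any other single character.
GLYPH_PATTERN = re.compile(r"\[|\]|@@|@|\d|[+-]|[A-Z][a-z]?|[a-z]+|.")


def glyph_tokenize(smiles):
    """Split into atom-level units, then split bracketed atoms into glyphs."""
    tokens = []
    for unit in ATOM_PATTERN.findall(smiles):
        if unit.startswith("["):
            tokens.extend(GLYPH_PATTERN.findall(unit))
        else:
            tokens.append(unit)
    return tokens


print(glyph_tokenize("OC[C@@H][OH]"))
print(glyph_tokenize("CSc1ccccc1"))  # unbracketed: S and c stay separate
print(glyph_tokenize("[Sc]"))        # bracketed: Sc is a single element glyph
</code></pre></div>
<p>Running the glyph pattern only on bracketed units is what keeps the unbracketed <code>Sc</code> from being read as scandium.</p>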
<p>The two-stage process is necessary to resolve ambiguities. For example, <code>Sc</code> in an unbracketed context represents a sulfur-carbon bond, while <code>[Sc]</code> denotes scandium. This ambiguity occurs over half a million times in PubChem&rsquo;s compound dataset.</p>
<p>The resulting vocabulary contains only 165 tokens, requires no training, and by construction can faithfully tokenize any molecule that conforms to the OpenSMILES specification. The implementation is written in Rust using HuggingFace&rsquo;s Tokenizers library and is available on PyPI.</p>
<p><strong>Smirk-GPE</strong> (Glyph Pair Encoding) extends Smirk with a <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">BPE</a>-like compression step. After Smirk tokenization, adjacent tokens are merged using learned rules, reducing sequence length. Unlike standard BPE, merges operate on token IDs rather than character strings, preserving the distinction between chemically different entities that happen to share the same characters. Smirk-GPE was trained on 262 million molecules from Enamine REAL Space with a target vocabulary of 50,000 tokens, though training terminated at 2,300 tokens after exhausting all possible merges.</p>
<h2 id="evaluation-framework-intrinsic-metrics-n-gram-proxies-and-transformer-benchmarks">Evaluation Framework: Intrinsic Metrics, N-Gram Proxies, and Transformer Benchmarks</h2>
<p>The evaluation covers 34 tokenizers across three datasets (Enamine REAL Space, MoleculeNet, and tmQM) using both intrinsic and extrinsic metrics.</p>
<h3 id="intrinsic-metrics">Intrinsic Metrics</h3>
<p>Four intrinsic metrics are computed for each tokenizer:</p>
<p><strong>Fertility</strong> measures the mean tokenized sequence length. Higher fertility increases computational cost due to the quadratic scaling of attention:</p>
<p>$$
\text{cost} \propto \text{fertility}^2
$$</p>
<p><strong>Normalized entropy</strong> quantifies how close a tokenizer comes to the information-theoretic ideal where all tokens are equally probable:</p>
<p>$$
\eta = \frac{-1}{\log |V|} \sum_{x \in V} p(x) \log p(x)
$$</p>
<p>where $V$ is the vocabulary and $p(x)$ is the observed token probability. Higher normalized entropy correlates with better downstream performance.</p>
<p><strong>Token imbalance</strong> measures the distance between observed token frequencies and a uniform distribution:</p>
<p>$$
D = \frac{1}{2} \sum_{x \in V} \left| p(x) - \frac{1}{|V|} \right|
$$</p>
<p><strong>Unknown token frequency</strong> captures the fraction of emitted tokens that are <code>[UNK]</code>. This metric is particularly revealing: all existing chemistry-specific tokenizers (SPE/APE, atom-wise, BPE, and Unigram variants) emit <code>[UNK]</code> at non-negligible rates, while NLP tokenizers, Smirk, and Smirk-GPE do not.</p>
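<p>All four metrics can be computed from token counts alone. The sketch below is a minimal illustration that assumes every emitted token, including <code>[UNK]</code>, is listed in the vocabulary; the function name and toy corpus are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import Counter
from math import log


def intrinsic_metrics(tokenized_corpus, vocab, unk_token="[UNK]"):
    """Fertility, normalized entropy, token imbalance, and [UNK] frequency."""
    counts = Counter(tok for seq in tokenized_corpus for tok in seq)
    total = sum(counts.values())
    # Observed probability of each vocabulary entry (zero if never emitted).
    probs = [counts[tok] / total for tok in vocab]
    entropy = -sum(p * log(p) for p in probs if p)
    uniform = 1 / len(vocab)
    return {
        "fertility": total / len(tokenized_corpus),
        "normalized_entropy": entropy / log(len(vocab)),
        "token_imbalance": 0.5 * sum(abs(p - uniform) for p in probs),
        "unk_frequency": counts[unk_token] / total,
    }


# Toy example with a four-token vocabulary.
corpus = [["C", "C", "O"], ["C", "N", "[UNK]"]]
vocab = ["C", "N", "O", "[UNK]"]
metrics = intrinsic_metrics(corpus, vocab)
print(metrics)
</code></pre></div>
<p>Unused vocabulary entries contribute nothing to the entropy sum but still count toward $|V|$, so a large, sparsely used vocabulary lowers normalized entropy and raises imbalance.</p>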
<h3 id="n-gram-proxy-language-models">N-Gram Proxy Language Models</h3>
<p>The paper proposes using n-gram models as low-cost proxies for transformer-based evaluation. An n-gram model estimates token likelihood with <a href="https://en.wikipedia.org/wiki/Additive_smoothing">add-one smoothing</a>:</p>
<p>$$
P_{n}(x_{i} \mid x_{i-n+1}, \dots, x_{i-1}) = \frac{C(x_{i-n+1}, \dots, x_{i}) + 1}{C(x_{i-n+1}, \dots, x_{i-1}) + |V|}
$$</p>
<p>where $C$ is the count function and $|V|$ is the vocabulary size. N-grams were &ldquo;pretrained&rdquo; on 1.6 billion SMILES from Enamine REAL Space and evaluated on validation splits. Cross-entropy loss and information loss from unknown tokens were computed.</p>
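<p>The smoothed estimate and the resulting cross-entropy can be sketched as follows. This is a small Python illustration, not the Julia implementation from the paper: the class name is hypothetical, sequences are not padded at their boundaries, and contexts are counted only where a next token follows so that the smoothed probabilities sum to one.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import Counter
from math import log


class AddOneNGram:
    """n-gram language model over token sequences with add-one smoothing."""

    def __init__(self, n, vocab_size):
        self.n = n
        self.vocab_size = vocab_size
        self.ngram_counts = Counter()
        self.context_counts = Counter()

    def fit(self, sequences):
        for seq in sequences:
            for i in range(len(seq) - self.n + 1):
                ngram = tuple(seq[i:i + self.n])
                self.ngram_counts[ngram] += 1
                self.context_counts[ngram[:-1]] += 1
        return self

    def prob(self, context, token):
        """(C(context, token) + 1) / (C(context) + |V|)."""
        context = tuple(context)
        numerator = self.ngram_counts[context + (token,)] + 1
        return numerator / (self.context_counts[context] + self.vocab_size)

    def cross_entropy(self, seq):
        """Mean negative log-likelihood in nats per predicted token."""
        width = self.n - 1
        log_probs = [
            log(self.prob(seq[i:i + width], seq[i + width]))
            for i in range(len(seq) - width)
        ]
        return -sum(log_probs) / len(log_probs)


# Toy bigram model over a three-token vocabulary (C, O, N).
model = AddOneNGram(n=2, vocab_size=3).fit([["C", "C", "O"], ["C", "C", "N"]])
print(model.prob(["C"], "C"))
print(model.cross_entropy(["C", "C", "O"]))
</code></pre></div>
<p>An unseen context falls back to the uniform probability $1 / |V|$, which is the behavior add-one smoothing is meant to guarantee.</p>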
<p>To quantify information lost to <code>[UNK]</code> tokens, the authors compute the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL-divergence</a> between token distributions with and without unknown tokens, using a bidirectional character n-gram model:</p>
<p>$$
B_{n}(x_{i} \mid x_{i-n+1}, \dots, x_{i-1}, x_{i+1}, \dots, x_{i+n-1}) \propto \frac{C(x_{i-n+1}, \dots, x_{i}) + 1}{C(x_{i-n+1}, \dots, x_{i-1}) + |V|} \times \frac{C(x_{i}, \dots, x_{i+n-1}) + 1}{C(x_{i+1}, \dots, x_{i+n-1}) + |V|}
$$</p>
<h3 id="transformer-experiments">Transformer Experiments</h3>
<p>Eighteen encoder-only RoBERTa models (25M parameters each, excluding embeddings) were pretrained from scratch using masked language modeling on Enamine REAL Space (245M molecules, 30,000 steps). Each model used a different tokenizer, isolating the tokenizer&rsquo;s effect on performance. Finetuning was conducted on six regression and seven classification tasks from MoleculeNet and tmQM.</p>
<p>Linear fixed-effects models were used to estimate the standardized effect of each tokenization scheme relative to an atom-wise SMILES baseline.</p>
<h2 id="key-findings-and-practical-implications">Key Findings and Practical Implications</h2>
<h3 id="tokenizer-performance">Tokenizer Performance</h3>
<ul>
<li><strong>Smirk</strong> shows a positive effect on pretraining quality and downstream performance on tmQM (the dataset with the most bracketed atoms), but performs comparably to atom-wise tokenization on MoleculeNet tasks.</li>
<li><strong>SPE and APE</strong> tokenizers have a negative impact on both pretraining and downstream performance relative to the atom-wise baseline, likely due to their high <code>[UNK]</code> rates.</li>
<li><strong>Molecular encoding choice</strong> (<a href="/notes/chemistry/molecular-representations/notations/smiles-selfies-tokenization-chemical-lm/">SMILES vs. SELFIES</a>) has a negligible effect on performance.</li>
<li><strong>NLP tokenizers</strong> (GPT-4o, LLaMA, Gemma) score comparably to chemistry-specific tokenizers on intrinsic metrics and do not emit unknown tokens.</li>
</ul>
<h3 id="n-gram-proxy-validation">N-Gram Proxy Validation</h3>
<p>N-gram cross-entropy and information loss metrics show strong rank correlation (Spearman&rsquo;s $\rho$) with downstream transformer performance, validating their use as low-cost evaluation proxies. The effect sizes from n-gram and transformer experiments are directionally consistent.</p>
<h3 id="information-loss-from-unknown-tokens">Information Loss from Unknown Tokens</h3>
<p>Information loss is minimal for tokenizers with robust coverage but substantial for tokenizers with limited vocabularies on chemically diverse datasets. <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a> incurs only 0.1 nats/molecule on MoleculeNet but 40.3 nats/molecule on tmQM. Open-vocabulary tokenizers (Smirk, Smirk-GPE, NLP tokenizers) mitigate this degradation.</p>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>The authors argue that molecular foundation models must encode the entire breadth of chemical space or risk obscuring critical features. Bracketed atoms encode information essential to clinically relevant pharmaceuticals (e.g., <a href="https://en.wikipedia.org/wiki/Amoxicillin">amoxicillin</a>), industrial compounds (e.g., tricalcium silicate), and foundational chemistry (e.g., <a href="https://en.wikipedia.org/wiki/Cisplatin">cisplatin</a>, where omitting the chiral marker erases medically relevant stereochemical information). The paper encourages the community to adopt open-vocabulary tokenizers and develop more chemically diverse benchmarks.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The analysis uses a single-point evaluation for transformer experiments, which may underestimate performance achievable with additional hyperparameter tuning.</li>
<li>Smirk-GPE&rsquo;s learned merges from REAL Space did not fully generalize to tmQM, as indicated by the token imbalance metric.</li>
<li>Current benchmarks (MoleculeNet) lack sufficient diversity to evaluate tokenizer robustness across the full periodic table, isotopes, charged species, and uncommon bond types.</li>
<li>The downstream impact of token ambiguities in BPE-based tokenizers (e.g., ChemBERTa&rsquo;s conflation of <code>Sc</code> as both sulfur-carbon and scandium) remains unclear.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>Enamine REAL Space</td>
          <td>1.6B SMILES (n-gram), 245M molecules (transformer)</td>
          <td>80/10/10 train/val/test split</td>
      </tr>
      <tr>
          <td>Downstream evaluation</td>
          <td>MoleculeNet</td>
          <td>Multiple tasks</td>
          <td>6 regression + 7 classification tasks</td>
      </tr>
      <tr>
          <td>Downstream evaluation</td>
          <td>tmQM</td>
          <td>108K transition metal complexes</td>
          <td>OpenSMILES molecular encodings</td>
      </tr>
      <tr>
          <td>Smirk-GPE training</td>
          <td>Enamine REAL Space (subset)</td>
          <td>262M molecules</td>
          <td>Training split only</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Smirk</strong>: Two-stage regex-based tokenization (atom decomposition, then glyph decomposition). No training required. Vocabulary: 165 tokens.</li>
<li><strong>Smirk-GPE</strong>: BPE-like compression on top of Smirk. Operates on token IDs (not strings) to preserve chemical disambiguation. Final vocabulary: 2,300 tokens.</li>
<li><strong>N-gram models</strong>: Add-one smoothing, bidirectional context ($2n - 2$ total context window). Implemented in Julia with exact integer arithmetic.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: RoBERTa-PreLayerNorm, 8 layers, 8 attention heads, hidden size 512, intermediate size 2048, max sequence length 2048. ~25M parameters (excluding embeddings).</li>
<li><strong>Pretraining</strong>: Masked language modeling, 30,000 steps, effective batch size 8192, FusedLamb optimizer, learning rate $1.6 \times 10^{-4}$.</li>
<li><strong>Finetuning</strong>: 100,000 steps, AdamW optimizer, effective batch size 128, learning rate $1.6 \times 10^{-4}$.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>MoleculeNet preferred metrics per task (AUROC for classification, MAE/RMSE for regression)</li>
<li>Fixed-effects models for standardized effect size estimation</li>
<li>Spearman&rsquo;s rank correlation between n-gram and transformer metrics</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pretraining: 2x NVIDIA A100 GPUs (Delta system at NCSA)</li>
<li>Finetuning: 1x NVIDIA A40 GPU</li>
<li>N-gram models: CPU-based (Julia implementation)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BattModels/Smirk">Smirk tokenizer</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Rust implementation with Python bindings, available on PyPI</td>
      </tr>
      <tr>
          <td>Model checkpoints</td>
          <td>Model</td>
          <td>Not specified</td>
          <td>Pretrained and finetuned checkpoints included in data release</td>
      </tr>
      <tr>
          <td>N-gram code</td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Julia implementation included in data release</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wadell, A., Bhutani, A., &amp; Viswanathan, V. (2026). Tokenization for Molecular Foundation Models. <em>Journal of Chemical Information and Modeling</em>, 66(3), 1384-1393. <a href="https://doi.org/10.1021/acs.jcim.5c01856">https://doi.org/10.1021/acs.jcim.5c01856</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wadell2026tokenization,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Tokenization for Molecular Foundation Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wadell, Alexius and Bhutani, Anoushka and Viswanathan, Venkatasubramanian}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{66}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1384--1393}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2026}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.5c01856}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES2Vec: Interpretable Chemical Property Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/smiles2vec-interpretable-property-prediction/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/smiles2vec-interpretable-property-prediction/</guid><description>SMILES2Vec uses a Bayesian-optimized CNN-GRU architecture to predict chemical properties directly from SMILES strings with an interpretable explanation mask.</description><content:encoded><![CDATA[<h2 id="a-general-purpose-rnn-for-chemical-property-prediction-from-smiles">A General-Purpose RNN for Chemical Property Prediction from SMILES</h2>
<p>SMILES2Vec is a <strong>Method</strong> paper that introduces a deep recurrent neural network architecture for predicting chemical properties directly from <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> text representations. The primary contributions are: (1) a Bayesian-optimized CNN-<a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> architecture that serves as a general-purpose predictor for diverse chemical properties (toxicity, activity, solubility, <a href="https://en.wikipedia.org/wiki/Solvation">solvation</a> energy), (2) an explanation mask mechanism that provides interpretable predictions by identifying which SMILES characters drive the network&rsquo;s decisions, and (3) evidence that representation learning from raw SMILES can match or outperform models using hand-crafted molecular descriptors.</p>
<h2 id="motivation-beyond-engineered-features-in-chemical-modeling">Motivation: Beyond Engineered Features in Chemical Modeling</h2>
<p>At the time of writing (2017), deep learning models in chemistry relied heavily on engineered <a href="https://en.wikipedia.org/wiki/Molecular_descriptor">molecular descriptors</a> and fingerprints as input features. Over 5,000 molecular descriptors had been developed since the late 1940s, and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a>/QSPR modeling remained the dominant paradigm. The authors identified two key limitations with this approach:</p>
<ol>
<li><strong>Restricted search space</strong>: Engineered features limit the neural network&rsquo;s ability to discover potentially useful representations that domain experts have not anticipated.</li>
<li><strong>Incomplete domain knowledge</strong>: For complex properties where first-principles understanding is incomplete, the lack of appropriate descriptors constrains model performance.</li>
</ol>
<p>In contrast, computer vision and NLP had shown that deep learning models trained on raw data (unaltered images, raw text) could learn powerful representations without feature engineering. The chemical SMILES notation, a text-based encoding of molecular structure that serves as the standard interchange format in cheminformatics, provided a natural analog to text data for NLP-style modeling.</p>
<p>A secondary motivation was interpretability. Most ML and DL models for chemistry operated as black boxes, which posed particular problems for regulated applications like FDA drug approval where mechanistic explanations are required.</p>
<h2 id="core-innovation-cnn-gru-architecture-with-explanation-masks">Core Innovation: CNN-GRU Architecture with Explanation Masks</h2>
<h3 id="architecture-design-via-bayesian-optimization">Architecture Design via <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian Optimization</a></h3>
<p>SMILES2Vec treats SMILES strings as character-level text input. The network processes one-hot encoded characters (padded to length 250, covering 99.9% of the <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> database) through three stages:</p>
<ol>
<li><strong>Embedding layer</strong>: Maps one-hot character vectors to a learned embedding space (size 50)</li>
<li><strong>1D convolutional layer</strong>: 192 filters with kernel size 3, stride 1</li>
<li><strong>Bidirectional GRU layers</strong>: Two layers with 224 and 384 units respectively</li>
</ol>
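<p>The input encoding that feeds these stages can be sketched in plain Python. The vocabulary construction and the choice of index 0 for padding are assumptions made for illustration; the paper only specifies character-level one-hot input padded to length 250.</p>

```python
MAX_LEN = 250  # padding length stated in the text


def build_vocab(smiles_list):
    """Map each character to an index; index 0 is reserved for padding (an assumption)."""
    chars = sorted(set("".join(smiles_list)))
    return {ch: i + 1 for i, ch in enumerate(chars)}


def one_hot(smiles, vocab, max_len=MAX_LEN):
    """Return a max_len x (len(vocab) + 1) matrix of 0/1 rows, padded on the right."""
    if len(smiles) > max_len:
        raise ValueError("SMILES longer than max_len are excluded, not truncated")
    width = len(vocab) + 1
    ids = [vocab[ch] for ch in smiles] + [0] * (max_len - len(smiles))
    return [[1 if col == idx else 0 for col in range(width)] for idx in ids]


vocab = build_vocab(["CCO", "c1ccccc1"])
matrix = one_hot("CCO", vocab)
```

<p>Note that a purely character-level encoding splits two-letter atoms such as <code>Cl</code> into separate symbols.</p>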
<p>The authors explored four architectural classes (GRU, LSTM, CNN-GRU, CNN-LSTM) using Bayesian optimization via SigOpt. Each class was evaluated over 60 trials, optimizing embedding size, convolutional filter count, and RNN layer widths. The CNN-GRU class was selected as the best compromise: CNN-LSTM performed best on classification (Tox21), while GRU-based networks excelled at regression (FreeSolv). The final architecture is summarized by the hyperparameters:</p>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Parameter</th>
          <th>Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Embedding</td>
          <td>Size</td>
          <td>50</td>
      </tr>
      <tr>
          <td>Conv1D</td>
          <td>Filters</td>
          <td>192</td>
      </tr>
      <tr>
          <td>BiGRU Layer 1</td>
          <td>Units</td>
          <td>224</td>
      </tr>
      <tr>
          <td>BiGRU Layer 2</td>
          <td>Units</td>
          <td>384</td>
      </tr>
  </tbody>
</table>
<h3 id="explanation-mask-for-interpretability">Explanation Mask for Interpretability</h3>
<p>The explanation mask is a post-hoc interpretability mechanism. Given a trained (frozen) SMILES2Vec base model, a separate explanation network learns to produce a per-character mask over the input SMILES string. The mask is trained to preserve the base model&rsquo;s output while masking as much input as possible. The loss function for a single sample is:</p>
<p>$$
\text{Loss}_i = \lVert f(\text{SMILES}_i, \theta) - \text{Sol}(\text{SMILES}_i) \rVert_2 + 10^{-6} \lVert \text{MASK}_i \rVert_2 + 0.05 \, H(\text{MASK}_i)
$$</p>
<p>where $f(\text{SMILES}_i, \theta)$ is the base network prediction, $\text{Sol}(\text{SMILES}_i)$ is the ground truth solubility, $H$ is the entropy of the normalized mask, and $\text{MASK}_i$ is the per-character mask vector. The L2 term encourages sparsity and the entropy term penalizes uniform attention distributions.</p>
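<p>As a rough numerical illustration of this loss, the sketch below evaluates the three terms for a scalar prediction and a small mask. The helper names and the <code>eps</code> stabilizer are assumptions; <code>prediction</code> stands for whatever the frozen base model returns on the masked input.</p>

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def mask_entropy(mask, eps=1e-12):
    """Shannon entropy of the mask after normalizing it to sum to 1."""
    total = sum(mask) + eps
    probs = [m / total for m in mask]
    return -sum(p * math.log(p + eps) for p in probs)


def explanation_loss(prediction, target, mask, l2_weight=1e-6, entropy_weight=0.05):
    """Per-sample loss: prediction error + L2 mask penalty + entropy penalty."""
    error = l2_norm([prediction - target])  # scalar output, so this is |pred - target|
    return error + l2_weight * l2_norm(mask) + entropy_weight * mask_entropy(mask)


uniform = explanation_loss(-2.0, -2.0, [1.0, 1.0, 1.0, 1.0])
peaked = explanation_loss(-2.0, -2.0, [4.0, 0.0, 0.0, 0.0])
```

<p>With identical prediction error, the peaked mask scores lower than the uniform one, which is the behavior the entropy term is meant to reward.</p>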
<p>The explanation network itself is a 20-layer residual network with SELU activations, ending in a 1D convolution of length 1, batch normalization, and a softplus activation. The softplus output ranges from 0 (fully masked) to infinity (amplified attention), allowing the mask to both suppress and emphasize specific SMILES characters.</p>
<h2 id="experimental-setup-and-baseline-comparisons">Experimental Setup and Baseline Comparisons</h2>
<h3 id="datasets">Datasets</h3>
<p>The model was evaluated on four datasets available through the <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmark, including the ESOL solubility dataset:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Property</th>
          <th>Task</th>
          <th>Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tox21</td>
          <td>Toxicity</td>
          <td>Multi-task classification</td>
          <td>8,014</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Activity</td>
          <td>Single-task classification</td>
          <td>41,193</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Solvation energy</td>
          <td>Single-task regression</td>
          <td>643</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>Solubility</td>
          <td>Single-task regression</td>
          <td>1,128</td>
      </tr>
  </tbody>
</table>
<p>SMILES strings longer than 250 characters were excluded. Classification datasets (Tox21, HIV) used a 1/6 test split with minority-class oversampling; regression datasets (FreeSolv, ESOL) used a 1/10 test split. All experiments used 5-fold cross-validation.</p>
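<p>The text does not say how the minority class was oversampled. One common approach, assumed here purely for illustration, is to duplicate randomly chosen minority samples until the two classes are balanced:</p>

```python
import random


def oversample_minority(samples, labels, seed=0):
    """Duplicate random minority-class samples until both classes are the same size."""
    rng = random.Random(seed)
    by_class = {0: [], 1: []}
    for sample, label in zip(samples, labels):
        by_class[label].append(sample)
    minority, majority = sorted(by_class, key=lambda c: len(by_class[c]))
    if not by_class[minority]:
        raise ValueError("cannot oversample an empty class")
    shortfall = len(by_class[majority]) - len(by_class[minority])
    extra = rng.choices(by_class[minority], k=shortfall)
    balanced = [(s, minority) for s in by_class[minority] + extra]
    balanced += [(s, majority) for s in by_class[majority]]
    rng.shuffle(balanced)
    return balanced


data = oversample_minority(["a", "b", "c", "d", "e"], [0, 0, 0, 0, 1])
```

<p>In a setup like this, oversampling would normally be applied to the training portion only, so that duplicated molecules do not leak into the held-out split.</p>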
<h3 id="training-protocol">Training Protocol</h3>
<ul>
<li><strong>Optimizer</strong>: RMSprop with learning rate $10^{-3}$, $\rho = 0.9$, $\epsilon = 10^{-8}$</li>
<li><strong>Batch size</strong>: 32</li>
<li><strong>Epochs</strong>: 250 with early stopping (patience of 25 epochs based on validation loss)</li>
<li><strong>Classification loss</strong>: Binary cross-entropy</li>
<li><strong>Regression loss</strong>: Mean absolute error</li>
<li><strong>Metrics</strong>: AUC for classification, RMSE for regression</li>
</ul>
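<p>The early-stopping rule in this protocol can be sketched without any deep learning framework. The function below is hypothetical: it takes a precomputed list of validation losses in place of a real training loop and reports when training would halt.</p>

```python
def train_with_early_stopping(validation_losses, max_epochs=250, patience=25):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    `validation_losses` stands in for a real train/validate loop.
    """
    best_loss, best_epoch, epoch = float("inf"), 0, 0
    for epoch, loss in enumerate(validation_losses[:max_epochs], start=1):
        if best_loss > loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            break
    return epoch, best_epoch


# Loss improves for 10 epochs, then plateaus: training stops 25 epochs later.
losses = [1.0 / e for e in range(1, 11)] + [0.5] * 240
stop_epoch, best_epoch = train_with_early_stopping(losses)
```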
<h3 id="baselines">Baselines</h3>
<p>SMILES2Vec was compared against:</p>
<ul>
<li><strong>MLP with engineered features</strong>: Standard multi-layer perceptron using molecular fingerprints (from MoleculeNet)</li>
<li><strong>Molecular graph convolutions</strong>: Graph-based neural network from MoleculeNet</li>
<li><strong>Chemception</strong>: CNN operating on 2D chemical images</li>
</ul>
<h3 id="bayesian-optimization-protocol">Bayesian Optimization Protocol</h3>
<p>Only two datasets were used for architecture optimization: the nr-ahr toxicity task from Tox21 (classification) and FreeSolv (regression). The remaining datasets (full Tox21, HIV, ESOL) served purely for generalization evaluation. A fixed test set was held out during optimization, and correlation between validation and test metrics (0.54 for Tox21, 0.78 for FreeSolv) confirmed limited overfitting to the validation set.</p>
<h2 id="results-competitive-accuracy-with-interpretable-predictions">Results: Competitive Accuracy with Interpretable Predictions</h2>
<h3 id="property-prediction-performance">Property Prediction Performance</h3>
<p>SMILES2Vec achieved the following validation metrics (with a pre-training approach from ChemNet improving performance slightly):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Metric</th>
          <th>SMILES2Vec</th>
          <th>SMILES2Vec + Pre-training</th>
          <th>Graph Conv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tox21</td>
          <td>AUC</td>
          <td>0.80</td>
          <td>0.81</td>
          <td>0.81</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>AUC</td>
          <td>0.78</td>
          <td>0.80</td>
          <td>0.80</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE (kcal/mol)</td>
          <td>1.4</td>
          <td>1.2</td>
          <td>1.3</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>0.63</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Exact numbers for MLP and Chemception baselines were reported only in a bar chart (Figure 6) and not as precise values. The paper states that MLP with fingerprints performed worst across all tasks, and Chemception fell between MLP and the graph/SMILES methods.</p>
<p>Key findings:</p>
<ul>
<li>SMILES2Vec outperformed MLP models using engineered features across all tasks, despite using no feature engineering.</li>
<li>Against graph convolutions (the state-of-the-art at the time), SMILES2Vec matched on classification (Tox21: 0.81 vs 0.81, HIV: 0.80 vs 0.80) and outperformed on regression (FreeSolv: 1.2 vs 1.3).</li>
<li>SMILES2Vec outperformed Chemception (2D image CNN) on classification tasks but slightly underperformed on regression, which the authors attributed to SMILES lacking explicit atomic number information.</li>
</ul>
<h3 id="interpretability-evaluation">Interpretability Evaluation</h3>
<p>On the ESOL solubility dataset, the explanation mask was evaluated against first-principles chemical knowledge. The authors separated compounds into soluble (&gt; 1.0) and insoluble (&lt; -5.0) categories and defined ground truth: soluble compounds should attend to hydrophilic atoms (O, N) while insoluble compounds should attend to hydrophobic atoms (C, F, Cl, Br, I). The top-3 character accuracy was 88%, confirming that SMILES2Vec learned representations consistent with known functional group chemistry.</p>
<p>Qualitative analysis of the masks showed that for low-solubility molecules, characters corresponding to hydrophobic groups (c, C, Cl) received high attention, while high-solubility molecules showed attention focused on hydrophilic groups (O, N).</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The interpretability evaluation was limited to solubility, a well-understood property with simple first-principles rules. The authors acknowledged that quantifying interpretability for complex properties (toxicity, activity) where no simple ground truth exists is nontrivial.</li>
<li>The Bayesian optimization used only a subset of datasets, so the architecture may not be globally optimal across all chemical tasks.</li>
<li>SMILES strings lack explicit atomic number information, which may limit performance on physical property prediction compared to image or graph representations.</li>
<li>The explanation mask approach requires training a separate 20-layer network per property, adding computational overhead.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Architecture optimization</td>
          <td>Tox21 (nr-ahr task)</td>
          <td>8,014</td>
          <td>Single toxicity task for Bayesian optimization</td>
      </tr>
      <tr>
          <td>Architecture optimization</td>
          <td>FreeSolv</td>
          <td>643</td>
          <td>Solvation free energy regression</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Tox21 (full, 12 tasks)</td>
          <td>8,014</td>
          <td>Multi-task classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HIV</td>
          <td>41,193</td>
          <td>Single-task classification</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Solubility regression, also used for interpretability</td>
      </tr>
  </tbody>
</table>
<p>All datasets are publicly available through MoleculeNet. The ESOL dataset is from Delaney (2004).</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Bayesian optimization via SigOpt (60 trials per architectural class, 4 classes, 6 manually seeded initial designs per class)</li>
<li>RMSprop optimizer with standard settings</li>
<li>Explanation mask trained with Adam, learning rate annealed from $10^{-2}$ to $10^{-6}$</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Final architecture: Embedding(50) -&gt; Conv1D(192, kernel=3, stride=1) -&gt; BiGRU(224) -&gt; BiGRU(384)</li>
<li>Explanation network: 20-layer residual network with SELU activations</li>
<li>No pre-trained weights or code were released</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUC</td>
          <td>Tox21</td>
          <td>0.81</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>AUC</td>
          <td>HIV</td>
          <td>0.80</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>FreeSolv</td>
          <td>1.2 kcal/mol</td>
          <td>With pre-training</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>ESOL</td>
          <td>0.63</td>
          <td>Base model</td>
      </tr>
      <tr>
          <td>Top-3 accuracy</td>
          <td>ESOL interpretability</td>
          <td>88%</td>
          <td>Explanation mask</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The authors report using TensorFlow with GPU acceleration via NVIDIA cuDNN libraries. Specific GPU models and training times were not reported.</p>
<h3 id="artifacts">Artifacts</h3>
<p>No code, models, or data artifacts were released by the authors. The datasets used are publicly available through MoleculeNet.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Goh, G. B., Hodas, N. O., Siegel, C., &amp; Vishnu, A. (2017). SMILES2Vec: An Interpretable General-Purpose Deep Neural Network for Predicting Chemical Properties. <em>arXiv preprint arXiv:1712.02034</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{goh2017smiles2vec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES2Vec: An Interpretable General-Purpose Deep Neural Network for Predicting Chemical Properties}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Goh, Garrett B. and Hodas, Nathan O. and Siegel, Charles and Vishnu, Abhinav}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1712.02034}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arxiv.1712.02034}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES-BERT: BERT-Style Pre-Training for Molecules</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-bert/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-bert/</guid><description>SMILES-BERT applies BERT-style masked pre-training to SMILES strings for molecular property prediction, using Transformer encoders fine-tuned on labeled data.</description><content:encoded><![CDATA[<h2 id="pre-training-transformers-on-smiles-for-molecular-properties">Pre-Training Transformers on SMILES for Molecular Properties</h2>
<p>SMILES-BERT is a <strong>Method</strong> paper that introduces a BERT-inspired pre-training and fine-tuning framework for molecular property prediction. The primary contribution is adapting the masked language model paradigm from NLP to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a>, enabling a Transformer encoder to learn molecular representations from large-scale unlabeled data before fine-tuning on smaller labeled datasets.</p>
<h2 id="limited-labels-in-molecular-property-prediction">Limited Labels in Molecular Property Prediction</h2>
<p>Molecular property prediction is central to drug discovery and chemical design, but obtaining labeled data requires expensive biological assays. Deep learning methods for this task fall into three categories: manually designed fingerprints (e.g., ECFP), graph-based methods (GCNs operating on molecular graphs), and sequence-based methods (RNNs or CNNs operating on SMILES strings).</p>
<p>Prior unsupervised approaches like <a href="/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/">Seq2seq Fingerprint</a> used an encoder-decoder architecture to learn representations from unlabeled SMILES, but the decoder acts as scaffolding that consumes GPU memory during pre-training without contributing to downstream prediction. The semi-supervised Seq3seq Fingerprint improved on this by incorporating labeled data, but retained the encoder-decoder inefficiency. RNN-based methods also suffer from difficulty in parallel training and require careful tuning (gradient clipping, early stopping) to converge.</p>
<p>The authors identify two motivations: (1) building a semi-supervised model that effectively leverages large pools of unlabeled SMILES to improve prediction with limited labels, and (2) designing an architecture where the entire pre-trained model participates in fine-tuning (no wasted decoder parameters) and naturally supports parallel training.</p>
<h2 id="masked-smiles-recovery-with-transformer-encoders">Masked SMILES Recovery with Transformer Encoders</h2>
<p>The core innovation is the Masked SMILES Recovery pre-training task, directly analogous to BERT&rsquo;s masked language modeling. The model architecture is a stack of Transformer encoder layers, making it recurrence-free and naturally parallelizable.</p>
<h3 id="architecture">Architecture</h3>
<p>SMILES-BERT uses 6 Transformer encoder layers, each with 4-head multi-head self-attention and feed-forward dimension of 1024. Each Transformer layer contains three components: a pre-attention feed-forward network, a self-attention layer, and a post-attention feed-forward network, all followed by layer normalization with residual connections.</p>
<p>The self-attention mechanism uses scaled dot-product attention:</p>
<p>$$
Z = \text{Softmax}\left(\frac{(XW^{Q})(XW^{K})^{T}}{\sqrt{d_{k}}}\right) XW^{V}
$$</p>
<p>where $X \in \mathbb{R}^{N \times M}$ is the input feature matrix, $W^{Q}$, $W^{K}$, $W^{V} \in \mathbb{R}^{M \times d_{k}}$ are the query, key, and value weight matrices, and $\sqrt{d_{k}}$ is the scaling factor.</p>
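<p>The formula can be traced with a small single-head example in plain Python. The list-of-lists matrix helpers and the identity weight matrices are illustrative choices, not anything taken from the paper.</p>

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, w_q, w_k, w_v):
    """Z = softmax(Q K^T / sqrt(d_k)) V for one head, with x of shape N x M."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(w_q[0])
    k_t = [list(col) for col in zip(*k)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, k_t)]
    weights = [softmax(row) for row in scores]
    return matmul(weights, v), weights


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3 tokens, M = 2 features
z, weights = self_attention(x, identity, identity, identity)
```

<p>Each row of <code>weights</code> sums to 1, and every output row of <code>z</code> is computed independently of the others, which is what allows all positions to be processed in parallel.</p>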
<p>Input SMILES are tokenized at the character level with token embeddings and positional embeddings. A special <code>&lt;GO&gt;</code> token is prepended to each SMILES, and its output representation is used for downstream classification/regression after fine-tuning.</p>
<h3 id="pre-training-masked-smiles-recovery">Pre-training: Masked SMILES Recovery</h3>
<p>As in BERT, 15% of tokens in each SMILES are selected for masking (minimum one per SMILES). The replacement proportions differ from BERT&rsquo;s 80/10/10 split. Of the selected tokens:</p>
<ul>
<li>85% are replaced with a <code>&lt;MASK&gt;</code> token</li>
<li>10% are replaced with a random token from the vocabulary</li>
<li>5% are kept unchanged</li>
</ul>
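<p>A minimal sketch of this corruption step, assuming character-level tokens and a fixed number of selected positions per SMILES (15% of the length, rounded, with a floor of one). The <code>[MASK]</code> spelling and the function name are placeholders, and the paper may select positions differently.</p>

```python
import random

MASK = "[MASK]"  # placeholder spelling for the paper's mask token


def mask_smiles(tokens, vocab, rng, select_rate=0.15):
    """Return (corrupted_tokens, target_positions) using the 85/10/5 scheme."""
    n_select = max(1, round(select_rate * len(tokens)))  # at least one per SMILES
    positions = sorted(rng.sample(range(len(tokens)), n_select))
    corrupted = list(tokens)
    for pos in positions:
        draw = rng.random()
        if draw >= 0.15:    # 85%: replace with the mask token
            corrupted[pos] = MASK
        elif draw >= 0.05:  # 10%: replace with a random vocabulary token
            corrupted[pos] = rng.choice(vocab)
        # remaining 5%: keep the original token
    return corrupted, positions


rng = random.Random(0)
tokens = list("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, character-level
corrupted, positions = mask_smiles(tokens, sorted(set(tokens)), rng)
```

<p>The returned <code>positions</code> are the only places where the recovery loss would be computed; all other tokens pass through unchanged.</p>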
<p>The model is trained to recover the original tokens at masked positions. The loss is computed only on the masked token outputs.</p>
<h3 id="fine-tuning">Fine-tuning</h3>
<p>After pre-training, a classifier or regressor head is added to the <code>&lt;GO&gt;</code> token output. The entire model (all Transformer layers plus the new head) is fine-tuned on the labeled dataset.</p>
<p>Key differences from the original BERT:</p>
<ol>
<li>Only the Masked SMILES Recovery task is used (BERT&rsquo;s next sentence prediction is dropped since SMILES have no consecutive-sentence structure)</li>
<li>Segment embeddings are removed</li>
<li>The architecture is smaller (6 layers, 4 heads, 1024 FFN dim) since SMILES have a much smaller vocabulary and shorter sequences than natural language</li>
</ol>
<p>The authors compared this configuration against a larger BERT-base setup (12 layers, 12 heads, 3072 FFN dim) and found no meaningful performance difference, confirming that the smaller model is sufficient for SMILES.</p>
<h2 id="experimental-setup-and-baseline-comparisons">Experimental Setup and Baseline Comparisons</h2>
<h3 id="pre-training-data">Pre-training Data</h3>
<p>SMILES-BERT was pre-trained on the <a href="/notes/chemistry/datasets/zinc-22/">ZINC database</a> with 18,671,355 training SMILES, 10,000 for validation, and 10,000 for evaluation. Pre-training ran for 10 epochs using the Adam optimizer with a warm-up strategy (learning rate from $10^{-9}$ to $10^{-4}$ over 4,000 steps, then inverse-square-root decay). Batch size was 256 and dropout was 0.1. The pre-training masked SMILES exact recovery rate reached 82.85% on the validation set.</p>
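<p>The learning-rate schedule described here can be written as a small function. The linear shape of the warm-up is an assumption (it is the usual form of an inverse-square-root schedule with warm-up), and the function name is hypothetical.</p>

```python
def learning_rate(step, warmup_steps=4000, init_lr=1e-9, peak_lr=1e-4):
    """Linear warm-up from init_lr to peak_lr, then inverse-square-root decay."""
    if warmup_steps >= step:
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    return peak_lr * (warmup_steps / step) ** 0.5


for step in (0, 2000, 4000, 16000):
    print(step, f"{learning_rate(step):.3e}")
```

<p>Under this form the rate peaks at step 4,000 and halves every time the step count quadruples thereafter.</p>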
<h3 id="fine-tuning-datasets">Fine-tuning Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Source</th>
          <th>Size</th>
          <th>Task</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a></td>
          <td>NCATS/NIH</td>
          <td>10,850</td>
          <td>Classification (threshold 1.88)</td>
          <td>Accuracy</td>
      </tr>
      <tr>
          <td>PM2</td>
          <td>NCATS/NIH</td>
          <td>323,242</td>
          <td>Classification (threshold 0.024896)</td>
          <td>Accuracy</td>
      </tr>
      <tr>
          <td>PCBA-686978</td>
          <td><a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a></td>
          <td>302,175</td>
          <td>Classification</td>
          <td>Accuracy</td>
      </tr>
  </tbody>
</table>
<p>All datasets were split 80/10/10 for train/validation/test. Fine-tuning used Adam with a fixed learning rate for 50 epochs, selecting the best model on validation data.</p>
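<p>For reference, an 80/10/10 partition can be produced as follows. A seeded random shuffle is assumed; the text does not state whether the split was random, stratified, or scaffold-based.</p>

```python
import random


def split_80_10_10(items, seed=0):
    """Shuffle and split into train/validation/test lists (random split assumed)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 8 // 10
    n_valid = len(shuffled) // 10
    return (shuffled[:n_train],
            shuffled[n_train:n_train + n_valid],
            shuffled[n_train + n_valid:])


train, valid, test = split_80_10_10(range(10850))  # LogP dataset size
```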
<h3 id="baselines">Baselines</h3>
<ul>
<li><strong>Circular Fingerprint (CircularFP)</strong>: Manually designed hash-based fingerprint (ECFP family)</li>
<li><strong>Neural Fingerprint (NeuralFP)</strong>: Graph-based neural network replacing hash functions with learned layers</li>
<li><strong>Seq2seq Fingerprint (Seq2seqFP)</strong>: Unsupervised encoder-decoder model on SMILES</li>
<li><strong>Seq3seq Fingerprint (Seq3seqFP)</strong>: Semi-supervised encoder-decoder model on SMILES</li>
</ul>
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>LogP</th>
          <th>PM2</th>
          <th>PCBA-686978</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CircularFP</td>
          <td>~0.90</td>
          <td>0.6858</td>
          <td>~0.82</td>
      </tr>
      <tr>
          <td>NeuralFP</td>
          <td>~0.90</td>
          <td>0.6802</td>
          <td>~0.82</td>
      </tr>
      <tr>
          <td>Seq2seqFP</td>
          <td>~0.87</td>
          <td>0.6112</td>
          <td>~0.80</td>
      </tr>
      <tr>
          <td>Seq3seqFP</td>
          <td>~0.90</td>
          <td>0.7038</td>
          <td>~0.84</td>
      </tr>
      <tr>
          <td><strong>SMILES-BERT</strong></td>
          <td><strong>0.9154</strong></td>
          <td><strong>0.7589</strong></td>
          <td><strong>0.8784</strong></td>
      </tr>
  </tbody>
</table>
<p>SMILES-BERT outperformed all baselines on all three datasets. The improvement in accuracy over Seq3seqFP was approximately 2 percentage points on LogP, 5.5 on PM2, and 3.8 on PCBA-686978. The results on PM2 (the largest labeled dataset) show that pre-training benefits persist even with substantial labeled data.</p>
<h3 id="structure-study">Structure Study</h3>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Layers</th>
          <th>Attention Heads</th>
          <th>FFN Dim</th>
          <th>LogP Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SMILES-BERT</td>
          <td>6</td>
          <td>4</td>
          <td>1024</td>
          <td>0.9154</td>
      </tr>
      <tr>
          <td>SMILES-BERT (large)</td>
          <td>12</td>
          <td>12</td>
          <td>3072</td>
          <td>0.9147</td>
      </tr>
  </tbody>
</table>
<p>The larger configuration provided no improvement, supporting the choice of the smaller, more efficient architecture.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>SMILES-BERT demonstrated that BERT-style masked pre-training on SMILES strings produces transferable molecular representations that improve property prediction across datasets of varying sizes and property types.</p>
<p>Key findings:</p>
<ul>
<li>The Masked SMILES Recovery pre-training task transfers effectively to molecular property prediction</li>
<li>The full model participates in fine-tuning (no wasted decoder), making SMILES-BERT more parameter-efficient than encoder-decoder alternatives</li>
<li>A smaller Transformer configuration (6 layers, 4 heads) matches the performance of a BERT-base-sized model for SMILES data</li>
<li>Pre-training on ~18.7M SMILES from ZINC provides robust initialization across different downstream tasks</li>
</ul>
<p><strong>Limitations</strong>: The evaluation uses only classification accuracy as the metric, without reporting AUC-ROC, F1, or other metrics common in molecular property prediction. The comparison is limited to four baselines, and two of the three evaluation datasets (LogP, PM2) are non-public NIH datasets. The paper does not explore different pre-training dataset sizes or ablate the masking strategy. Only classification tasks are evaluated, though the architecture supports regression.</p>
<p><strong>Future work</strong>: The authors propose incorporating Quantitative Estimate of Druglikeness (QED) prediction as an additional pre-training task to warm up the model&rsquo;s classification capability, analogous to BERT&rsquo;s next sentence prediction.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC</td>
          <td>18,671,355 SMILES</td>
          <td>Publicly available database</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>LogP</td>
          <td>10,850</td>
          <td>Non-public, from NCATS/NIH</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>PM2</td>
          <td>323,242</td>
          <td>Non-public, from NCATS/NIH</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>PCBA-686978</td>
          <td>302,175</td>
          <td>Public, from PubChem BioAssay</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Pre-training: Adam optimizer, warm-up for 4,000 steps ($10^{-9}$ to $10^{-4}$), inverse-square-root LR schedule, batch size 256, dropout 0.1, 10 epochs</li>
<li>Fine-tuning: Adam optimizer, fixed LR (insensitive to choice among $10^{-5}$, $10^{-6}$, $10^{-7}$), 50 epochs, best model on validation</li>
</ul>
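<p>The pre-training schedule above can be sketched in a few lines. This is a minimal illustration, assuming the common formulation in which the inverse-square-root decay is scaled to equal the peak rate at the end of warm-up; the function name and the 1-indexed step convention are assumptions, and the paper&rsquo;s exact implementation may differ.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def inverse_sqrt_lr(step, peak_lr=1e-4, init_lr=1e-9, warmup_steps=4000):
    """Learning rate at a given optimizer step (assumed 1-indexed)."""
    step = max(step, 1)
    # Linear ramp from init_lr to peak_lr over the warm-up steps.
    linear = init_lr + (peak_lr - init_lr) * step / warmup_steps
    # Inverse-square-root decay, scaled to equal peak_lr at warmup_steps.
    decay = peak_lr * math.sqrt(warmup_steps / step)
    # The ramp is the smaller value during warm-up, the decay afterwards.
    return min(linear, decay)


if __name__ == "__main__":
    for s in (1, 2000, 4000, 16000):
        print(s, inverse_sqrt_lr(s))
</code></pre></div>
<p>Taking the minimum of the two curves avoids an explicit branch: the ramp stays below the peak rate until step 4,000, and the decay stays below it afterwards.</p>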
<h3 id="models">Models</h3>
<ul>
<li>6 Transformer encoder layers, 4-head multi-head attention, FFN dim 1024</li>
<li>Token embedding + positional embedding, <code>&lt;GO&gt;</code> special token</li>
<li>Implemented with FairSeq (Facebook AI Research Sequence-to-Sequence Toolkit)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SMILES-BERT</th>
          <th>Best Baseline (Seq3seqFP)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LogP Accuracy</td>
          <td>0.9154</td>
          <td>~0.90</td>
          <td>~2 percentage-point improvement</td>
      </tr>
      <tr>
          <td>PM2 Accuracy</td>
          <td>0.7589</td>
          <td>0.7038</td>
          <td>~5.5 percentage-point improvement</td>
      </tr>
      <tr>
          <td>PCBA Accuracy</td>
          <td>0.8784</td>
          <td>~0.84</td>
          <td>~3.8 percentage-point improvement</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper mentions GPU training and acknowledges an NVIDIA GPU donation but does not name the GPU model. The only training-time detail is that pre-training for 10 epochs on a single GPU takes over a week.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>No public code or model release identified</td>
          <td>-</td>
          <td>-</td>
          <td>Paper does not provide a GitHub link or model checkpoint</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. The ZINC pre-training data is public and the architecture is described in detail, but no code or pre-trained weights are released. Two of three evaluation datasets (LogP, PM2) are non-public.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, S., Guo, Y., Wang, Y., Sun, H., &amp; Huang, J. (2019). SMILES-BERT: Large Scale Unsupervised Pre-Training for Molecular Property Prediction. In <em>Proceedings of the 10th ACM International Conference on Bioinformatics, Computational Biology and Health Informatics (ACM-BCB &rsquo;19)</em>, 429-436. <a href="https://doi.org/10.1145/3307339.3342186">https://doi.org/10.1145/3307339.3342186</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{wang2019smilesbert,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES-BERT: Large Scale Unsupervised Pre-Training for Molecular Property Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wang, Sheng and Guo, Yuzhi and Wang, Yuhong and Sun, Hongmao and Huang, Junzhou}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 10th ACM International Conference on Bioinformatics, Computational Biology and Health Informatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{429--436}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3307339.3342186}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES vs SELFIES Tokenization for Chemical LMs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-selfies-tokenization-chemical-lm/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/smiles-selfies-tokenization-chemical-lm/</guid><description>Atom Pair Encoding (APE) tokenizer outperforms BPE on SMILES and SELFIES in RoBERTa-based chemical language models across MoleculeNet classification tasks.</description><content:encoded><![CDATA[<h2 id="atom-pair-encoding-for-chemical-language-modeling">Atom Pair Encoding for Chemical Language Modeling</h2>
<p>This is a <strong>Method</strong> paper that introduces Atom Pair Encoding (APE), a tokenization algorithm designed specifically for chemical string representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>). The primary contribution is demonstrating that a chemistry-aware tokenizer, which preserves atomic identity during subword merging, leads to improved molecular property classification accuracy in transformer-based models compared to the standard Byte Pair Encoding (BPE) approach.</p>
<h2 id="why-tokenization-matters-for-chemical-strings">Why Tokenization Matters for Chemical Strings</h2>
<p>Existing chemical language models based on BERT/RoBERTa architectures have typically relied on BPE for tokenizing SMILES and SELFIES strings. <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">Byte Pair Encoding (BPE)</a> was originally designed for natural language and data compression, where it excels at breaking words into meaningful subword units. When applied to chemical strings, BPE operates at the character level without understanding chemical semantics, leading to several problems:</p>
<ul>
<li><strong>Stray characters</strong>: BPE may create tokens like &ldquo;C)(&rdquo; that have no chemical meaning.</li>
<li><strong>Element splitting</strong>: Multi-character elements like chlorine (&ldquo;Cl&rdquo;) can be split into &ldquo;C&rdquo; and &ldquo;l&rdquo;, so the model sees a carbon atom followed by a meaningless dangling character.</li>
<li><strong>Lost structural context</strong>: BPE compresses sequences without considering how character position encodes molecular structure.</li>
</ul>
<p>Previous work on <a href="/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/">SMILES Pair Encoding (SPE)</a> attempted to address this by iteratively merging SMILES substrings into chemically meaningful tokens. However, SPE had practical limitations: its Python implementation did not support SELFIES, and it produced a smaller vocabulary (~3000 tokens) than what the data could support. These gaps motivated the development of APE.</p>
<h2 id="the-ape-tokenizer-chemistry-aware-subword-merging">The APE Tokenizer: Chemistry-Aware Subword Merging</h2>
<p>APE draws inspiration from both BPE and SPE but addresses their shortcomings. The key design decisions are:</p>
<ol>
<li>
<p><strong>Atom-level initialization</strong>: Instead of starting from individual characters (as BPE does), APE begins with chemically valid atomic units. For SMILES, this means recognizing multi-character elements (e.g., &ldquo;Cl&rdquo;, &ldquo;Br&rdquo;) as single tokens. For SELFIES, each bracketed string (e.g., [C], [Ring1], [=O]) serves as the fundamental unit.</p>
</li>
<li>
<p><strong>Iterative pair merging</strong>: Like BPE, APE iteratively merges the most frequent adjacent token pairs. The difference is that the initial tokenization preserves atomic boundaries, so merged tokens always represent valid chemical substructures.</p>
</li>
<li>
<p><strong>Larger vocabulary</strong>: Using the same minimum frequency threshold of 2000, APE generates approximately 5300 unique tokens from the PubChem dataset, compared to SPE&rsquo;s approximately 3000. This richer vocabulary provides more expressive power for representing chemical substructures.</p>
</li>
<li>
<p><strong>SELFIES compatibility</strong>: APE natively supports both SMILES and SELFIES, using the bracketed token structure of SELFIES as its starting point for that representation.</p>
</li>
</ol>
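<p>The first two design decisions can be illustrated with a small sketch. It is not the authors&rsquo; implementation: the regular expressions are simplified, hypothetical patterns, and the function names and the <code>max_merges</code> cap are assumptions made for illustration. The point is only that merging starts from atom-level units, so a token such as &ldquo;Cl&rdquo; is never split.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import re
from collections import Counter

# Hypothetical, simplified patterns: bracket atoms, two-letter halogens,
# two-digit ring closures, then any remaining single character.
SMILES_ATOM = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")
SELFIES_UNIT = re.compile(r"\[[^\]]+\]")


def atom_tokens(text, selfies=False):
    """Split a SMILES or SELFIES string into atom-level starting units."""
    pattern = SELFIES_UNIT if selfies else SMILES_ATOM
    return pattern.findall(text)


def merge_pair(tokens, pair):
    """Replace every adjacent occurrence of `pair` with one merged token."""
    merged, i = [], 0
    while len(tokens) > i:
        if tokens[i:i + 2] == list(pair):
            merged.append(pair[0] + pair[1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def learn_merges(corpus, min_frequency=2, max_merges=100, selfies=False):
    """BPE-style merge learning that starts from atom-level tokens."""
    sequences = [atom_tokens(s, selfies) for s in corpus]
    merges = []
    for _ in range(max_merges):
        # Count adjacent token pairs across the whole corpus.
        pairs = Counter()
        for seq in sequences:
            pairs.update(zip(seq, seq[1:]))
        if not pairs:
            break
        best, count = pairs.most_common(1)[0]
        # Stop once the most frequent pair falls below the threshold.
        if min_frequency > count:
            break
        merges.append(best)
        sequences = [merge_pair(seq, best) for seq in sequences]
    return merges, sequences
</code></pre></div>
<p>With these patterns, <code>atom_tokens("ClCCBr")</code> yields <code>['Cl', 'C', 'C', 'Br']</code>, whereas a character-level start would separate &ldquo;C&rdquo; from &ldquo;l&rdquo; before any merge is learned.</p>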
<p>The tokenizers were trained on a 2-million-molecule subset of a 10-million-SMILES <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> collection. Combining each tokenization algorithm with each notation produced four tokenizer variants: SMILES-BPE, SMILES-APE, SELFIES-BPE, and SELFIES-APE.</p>
<h2 id="pre-training-and-evaluation-on-moleculenet-benchmarks">Pre-training and Evaluation on MoleculeNet Benchmarks</h2>
<h3 id="model-architecture">Model architecture</h3>
<p>All four models use the RoBERTa architecture with 6 hidden layers, a hidden size of 768, an intermediate size of 1536, and 12 attention heads. Pre-training used masked language modeling (MLM) with 15% token masking on 1 million molecules from PubChem, with a validation set of 100,000 molecules. Each model was pre-trained for 20 epochs using AdamW, with hyperparameter optimization via Optuna.</p>
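<p>As a rough illustration of the 15% masking objective, the sketch below selects positions to mask in a token sequence. It is a simplification under stated assumptions: the <code>[MASK]</code> placeholder is hypothetical (the real symbol depends on the tokenizer), masking at least one token per molecule is a choice made here, and the random/keep replacement variants used by BERT-style implementations are omitted.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random

MASK_TOKEN = "[MASK]"  # hypothetical placeholder, not the tokenizer's real symbol


def mask_tokens(tokens, mask_rate=0.15, seed=None):
    """Return (masked_tokens, labels) for masked language modeling.

    labels[i] holds the original token at masked positions and None elsewhere.
    """
    rng = random.Random(seed)
    # Assumption: mask at least one position so short molecules still
    # contribute a training signal, but never more than the sequence length.
    n_mask = min(len(tokens), max(1, round(mask_rate * len(tokens))))
    positions = set(rng.sample(range(len(tokens)), n_mask))
    masked = [MASK_TOKEN if i in positions else tok
              for i, tok in enumerate(tokens)]
    labels = [tok if i in positions else None
              for i, tok in enumerate(tokens)]
    return masked, labels
</code></pre></div>
<p>The model is then trained to predict the entries of <code>labels</code> that are not <code>None</code> from the masked sequence.</p>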
<h3 id="downstream-tasks">Downstream tasks</h3>
<p>The models were fine-tuned on three <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification tasks:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Category</th>
          <th>Compounds</th>
          <th>Tasks</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP</td>
          <td>Physiology</td>
          <td>2,039</td>
          <td>1</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Biophysics</td>
          <td>41,127</td>
          <td>1</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>Physiology</td>
          <td>7,831</td>
          <td>12</td>
          <td>ROC-AUC</td>
      </tr>
  </tbody>
</table>
<p>Data was split 80/10/10 (train/validation/test) following MoleculeNet recommendations. Models were fine-tuned for 5 epochs with early stopping based on validation ROC-AUC.</p>
<h3 id="baselines">Baselines</h3>
<p>Results were compared against two text-based models (<a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a> MTR-77M and <a href="/notes/chemistry/molecular-representations/encoders/selformer/">SELFormer</a>) and two graph-based models (D-MPNN from Chemprop and MoleculeNet Graph-Conv).</p>
<h3 id="main-results">Main results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP ROC</th>
          <th>HIV ROC</th>
          <th>Tox21 ROC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SMILYAPE-1M</td>
          <td>0.754 +/- 0.006</td>
          <td>0.772 +/- 0.010</td>
          <td>0.838 +/- 0.002</td>
      </tr>
      <tr>
          <td>SMILYBPE-1M</td>
          <td>0.746 +/- 0.006</td>
          <td>0.754 +/- 0.015</td>
          <td>0.849 +/- 0.002</td>
      </tr>
      <tr>
          <td>SELFYAPE-1M</td>
          <td>0.735 +/- 0.015</td>
          <td>0.768 +/- 0.012</td>
          <td>0.842 +/- 0.002</td>
      </tr>
      <tr>
          <td>SELFYBPE-1M</td>
          <td>0.676 +/- 0.014</td>
          <td>0.709 +/- 0.012</td>
          <td>0.825 +/- 0.001</td>
      </tr>
      <tr>
          <td>ChemBERTa-2-MTR-77M</td>
          <td>0.698 +/- 0.014</td>
          <td>0.735 +/- 0.008</td>
          <td>0.790 +/- 0.003</td>
      </tr>
      <tr>
          <td>SELFormer</td>
          <td>0.716 +/- 0.021</td>
          <td>0.769 +/- 0.010</td>
          <td>0.838 +/- 0.005</td>
      </tr>
      <tr>
          <td>MoleculeNet-Graph-Conv</td>
          <td>0.690</td>
          <td>0.763</td>
          <td>0.829</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.737</td>
          <td>0.776</td>
          <td>0.851</td>
      </tr>
  </tbody>
</table>
<p>APE outperforms BPE in five of the six pairwise comparisons; the exception is SMILES on Tox21, where SMILYBPE (0.849) beats SMILYAPE (0.838). SMILYAPE achieves the best BBBP score (0.754), beating D-MPNN (0.737). On HIV, SMILYAPE (0.772) is competitive with D-MPNN (0.776). On Tox21, D-MPNN (0.851) leads, with SMILYBPE (0.849) and SELFYAPE (0.842) close behind.</p>
<h3 id="statistical-significance">Statistical significance</h3>
<p><a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Mann-Whitney U tests</a> confirmed statistically significant differences between SMILYAPE and SMILYBPE (p &lt; 0.05 on all datasets). Cliff&rsquo;s delta values indicate large effect sizes: 0.74 (BBBP), 0.70 (HIV), and -1.00 (Tox21, favoring BPE). For SELFIES models, SELFYAPE achieved Cliff&rsquo;s delta of 1.00 across all three datasets, indicating complete separation from SELFYBPE.</p>
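<p>For readers unfamiliar with Cliff&rsquo;s delta, a minimal sketch of the statistic follows. It compares every score in one group with every score in the other, so a value of 1.00 means every run of the first model beat every run of the second, and -1.00 means the reverse. The function name and the example numbers are illustrative and are not taken from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def cliffs_delta(xs, ys):
    """Cliff's delta: P(x > y) - P(y > x) over all cross-group pairs."""
    greater = sum(1 for x in xs for y in ys if x > y)
    less = sum(1 for x in xs for y in ys if y > x)
    return (greater - less) / (len(xs) * len(ys))


if __name__ == "__main__":
    # Made-up scores, chosen only to show complete separation.
    print(cliffs_delta([3, 4, 5], [1, 2]))
    print(cliffs_delta([1, 2], [3, 4, 5]))
</code></pre></div>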
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="ape-outperforms-bpe-by-preserving-atomic-identity">APE outperforms BPE by preserving atomic identity</h3>
<p>APE&rsquo;s advantage over BPE, seen in every comparison except SMILES on Tox21, stems from its atom-level initialization. By starting with chemically valid units rather than individual characters, APE avoids creating nonsensical tokens that break chemical elements or mix structural delimiters with atoms.</p>
<h3 id="smiles-outperforms-selfies-with-ape-tokenization">SMILES outperforms SELFIES with APE tokenization</h3>
<p>SMILYAPE generally outperforms SELFYAPE across tasks. Attention weight analysis revealed that SMILYAPE assigns more weight to immediate neighboring tokens (0.108 vs. 0.096) and less to distant tokens (0.030 vs. 0.043). This pattern aligns with chemical intuition: bonding is primarily determined by directly connected atoms. SMILYAPE also produces more compact tokenizations (8.6 tokens per molecule vs. 11.9 for SELFYAPE), potentially allowing more efficient attention allocation.</p>
<h3 id="selfies-models-show-higher-inter-tokenizer-agreement">SELFIES models show higher inter-tokenizer agreement</h3>
<p>On the BBBP dataset, all true positives identified by SELFYBPE were also captured by SELFYAPE, with SELFYAPE achieving higher recall (61.68% vs. 55.14%). In contrast, SMILES-based models shared only 29.3% of true positives between APE and BPE variants, indicating that tokenization choice has a larger impact on SMILES models.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Pre-training used only 1 million molecules, compared to 77 million for ChemBERTa-2. Despite this, APE models were competitive or superior, but scaling effects remain unexplored.</li>
<li>Evaluation was limited to three MoleculeNet classification datasets (BBBP, HIV, and the 12-task Tox21). Regression tasks, molecular generation, and reaction prediction were not tested.</li>
<li>The Tox21 result is notable: SMILYBPE outperforms SMILYAPE (0.849 vs. 0.838), suggesting APE&rsquo;s advantage may be task-dependent.</li>
<li>No comparison with recent atom-level tokenizers like <a href="/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/">Atom-in-SMILES</a> or newer approaches beyond SPE.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tokenizer training</td>
          <td>PubChem subset</td>
          <td>2M molecules</td>
          <td>SMILES strings converted to SELFIES via selfies library</td>
      </tr>
      <tr>
          <td>Pre-training</td>
          <td>PubChem subset</td>
          <td>1M molecules</td>
          <td>100K validation set</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BBBP</td>
          <td>2,039 compounds</td>
          <td>80/10/10 split</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HIV</td>
          <td>41,127 compounds</td>
          <td>80/10/10 split</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Tox21</td>
          <td>7,831 compounds</td>
          <td>80/10/10 split, 12 tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Tokenizers: BPE (via Hugging Face), APE (custom implementation, minimum frequency 2000)</li>
<li>Pre-training: Masked Language Modeling (15% masking) for 20 epochs</li>
<li>Optimizer: AdamW with Optuna hyperparameter search</li>
<li>Fine-tuning: 5 epochs with early stopping on validation ROC-AUC</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Architecture: RoBERTa with 6 layers, hidden size 768, intermediate size 1536, 12 attention heads</li>
<li>Four variants: SMILYAPE, SMILYBPE, SELFYAPE, SELFYBPE</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SMILYAPE</th>
          <th>SMILYBPE</th>
          <th>SELFYAPE</th>
          <th>SELFYBPE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BBBP ROC-AUC</td>
          <td>0.754</td>
          <td>0.746</td>
          <td>0.735</td>
          <td>0.676</td>
      </tr>
      <tr>
          <td>HIV ROC-AUC</td>
          <td>0.772</td>
          <td>0.754</td>
          <td>0.768</td>
          <td>0.709</td>
      </tr>
      <tr>
          <td>Tox21 ROC-AUC</td>
          <td>0.838</td>
          <td>0.849</td>
          <td>0.842</td>
          <td>0.825</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>NVIDIA RTX 3060 GPU with 12 GiB VRAM</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mikemayuare/apetokenizer">APE Tokenizer</a></td>
          <td>Code</td>
          <td>Other (unspecified SPDX)</td>
          <td>Official APE tokenizer implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/mikemayuare/PubChem10M_SMILES_SELFIES">PubChem10M SMILES/SELFIES</a></td>
          <td>Dataset</td>
          <td>Not specified</td>
          <td>10M SMILES with SELFIES conversions</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/mikemayuare">Pre-trained and fine-tuned models</a></td>
          <td>Model</td>
          <td>Not specified</td>
          <td>All four model variants on Hugging Face</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Leon, M., Perezhohin, Y., Peres, F., Popovič, A., &amp; Castelli, M. (2024). Comparing SMILES and SELFIES tokenization for enhanced chemical language modeling. <em>Scientific Reports</em>, 14(1), 25016. <a href="https://doi.org/10.1038/s41598-024-76440-8">https://doi.org/10.1038/s41598-024-76440-8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{leon2024comparing,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Comparing SMILES and SELFIES tokenization for enhanced chemical language modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Leon, Miguelangel and Perezhohin, Yuriy and Peres, Fernando and Popovi{\v{c}}, Ale{\v{s}} and Castelli, Mauro}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{25016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-024-76440-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMILES Transformer: Low-Data Molecular Fingerprints</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-transformer/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smiles-transformer/</guid><description>SMILES Transformer uses unsupervised Transformer pre-training on SMILES strings to produce molecular fingerprints that excel in low-data drug discovery tasks.</description><content:encoded><![CDATA[<h2 id="a-transformer-approach-to-learned-molecular-fingerprints">A Transformer Approach to Learned Molecular Fingerprints</h2>
<p>This is a <strong>Method</strong> paper that introduces SMILES Transformer (ST), a Transformer-based sequence-to-sequence model pre-trained on unlabeled SMILES strings to produce continuous, data-driven molecular fingerprints. The primary contribution is demonstrating that unsupervised pre-training on chemical text representations yields fingerprints that generalize well under low-data conditions, outperforming both rule-based fingerprints (ECFP) and graph convolution models on several <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks. A secondary contribution is the Data Efficiency Metric (DEM), a scalar metric for evaluating model performance across varying training set sizes.</p>
<h2 id="the-low-data-problem-in-molecular-property-prediction">The Low-Data Problem in Molecular Property Prediction</h2>
<p>Machine learning for drug discovery depends on molecular representations, but labeled datasets of experimentally validated properties are typically small. Conventional approaches fall into two camps: rule-based fingerprints like ECFP that hash substructures into sparse binary vectors, and graph-based methods like GraphConv that learn representations end-to-end. Rule-based fingerprints perform poorly with shallow models or limited data, while graph-based methods are designed for large fully-labeled settings.</p>
<p>Pre-training on unlabeled data had shown strong results in NLP (ELMo, BERT, XLNet), and prior work in cheminformatics had explored RNN-based and VAE-based pre-training on SMILES (<a href="/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/">Seq2Seq fingerprints</a>, <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a>, heteroencoders). However, none of these studies systematically evaluated performance in small-data settings. Honda et al. fill this gap by applying Transformer-based pre-training to SMILES and measuring data efficiency explicitly.</p>
<h2 id="transformer-pre-training-on-smiles-with-pooled-fingerprint-extraction">Transformer Pre-training on SMILES with Pooled Fingerprint Extraction</h2>
<p>The core innovation is a Transformer encoder-decoder architecture pre-trained as an autoencoder on SMILES strings, with a specific fingerprint extraction strategy that pools the encoder outputs into a fixed-length vector.</p>
<h3 id="architecture">Architecture</h3>
<p>The model uses 4 Transformer blocks for both the encoder and decoder, each with 4-head attention and 256 embedding dimensions plus 2 linear layers. Input SMILES are tokenized at the symbol level (e.g., &lsquo;c&rsquo;, &lsquo;Br&rsquo;, &lsquo;=&rsquo;, &lsquo;(&rsquo;, &lsquo;2&rsquo;) and one-hot encoded. Following Vaswani et al. (2017), the input uses the sum of token encoding and positional encoding.</p>
<h3 id="pre-training">Pre-training</h3>
<p>The model is pre-trained on 861,000 unlabeled SMILES sampled from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL24</a> to minimize cross-entropy between input and output SMILES (i.e., reconstruction). <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> (Bjerrum, 2017) randomly generates non-canonical SMILES at each epoch to reduce representation bias. Training runs for 5 epochs with Adam optimization, reaching a perplexity of 1.0 (perfect decoding).</p>
<h3 id="fingerprint-extraction">Fingerprint Extraction</h3>
<p>Since the Transformer outputs symbol-level (atom-level) representations, a pooling strategy produces molecule-level fingerprints. Four vectors are concatenated:</p>
<ol>
<li>Mean-pooled output of the last encoder layer</li>
<li>Max-pooled output of the last encoder layer</li>
<li>First output token of the last encoder layer</li>
<li>First output token of the penultimate encoder layer</li>
</ol>
<p>With 256-dimensional encoder outputs, concatenating these four vectors produces a 1024-dimensional fingerprint ($4 \times 256$), matching the dimensionality of ECFP for fair comparison.</p>
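<p>The pooling step can be sketched in plain Python. This is an illustration rather than the authors&rsquo; code: it assumes each encoder layer&rsquo;s output is available as a list of per-token vectors, and the function name is hypothetical. A real implementation would operate on tensors.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random


def pool_fingerprint(last_layer, penultimate_layer):
    """Concatenate four pooled views of the encoder outputs.

    Each argument is a list of per-token vectors (seq_len x hidden_dim).
    """
    seq_len = len(last_layer)
    # zip(*rows) iterates over hidden dimensions instead of tokens.
    mean_pool = [sum(column) / seq_len for column in zip(*last_layer)]
    max_pool = [max(column) for column in zip(*last_layer)]
    first_last = list(last_layer[0])
    first_penultimate = list(penultimate_layer[0])
    return mean_pool + max_pool + first_last + first_penultimate


if __name__ == "__main__":
    rng = random.Random(0)
    hidden_dim, seq_len = 256, 12
    last = [[rng.random() for _ in range(hidden_dim)] for _ in range(seq_len)]
    prev = [[rng.random() for _ in range(hidden_dim)] for _ in range(seq_len)]
    print(len(pool_fingerprint(last, prev)))
</code></pre></div>
<p>Because every pooled vector has the encoder&rsquo;s hidden size, the fingerprint length is independent of the SMILES length.</p>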
<h3 id="data-efficiency-metric">Data Efficiency Metric</h3>
<p>The paper proposes DEM to measure how well a model performs across different training set sizes:</p>
<p>$$
M_{DE}(f, m) = \frac{1}{|I|} \sum_{i \in I} m(f_i, X_i, Y_i)
$$</p>
<p>where $f_i$ is the model trained on fraction $i$ of the training data, $m$ is the task metric, $(X_i, Y_i)$ are the inputs and labels on which $f_i$ is evaluated, and $I = \{0.0125, 0.025, 0.05, 0.1, 0.2, 0.4, 0.8\}$ is a set of training fractions that doubles at each step. The metric averages performance across a range of data availability, giving a single scalar that balances accuracy and data efficiency.</p>
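<p>Computationally, DEM is just an average over seven training runs. The sketch below makes that explicit. The callable <code>train_and_score</code> is a hypothetical stand-in for &ldquo;train on this fraction, then return the metric on the evaluation data&rdquo;; it is not part of any released code.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Seven training fractions, doubling from 0.0125 up to 0.8.
FRACTIONS = [0.0125 * 2 ** k for k in range(7)]


def data_efficiency_metric(train_and_score, fractions=FRACTIONS):
    """Average a task metric over models trained on growing data fractions.

    `train_and_score` is a hypothetical callable: it trains a model on the
    given fraction of the training set and returns the task metric.
    """
    scores = [train_and_score(fraction) for fraction in fractions]
    return sum(scores) / len(scores)
</code></pre></div>
<p>Since DEM inherits the direction of the underlying metric, lower is better for RMSE tasks and higher is better for ROC-AUC and PRC-AUC tasks.</p>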
<h2 id="benchmarking-across-moleculenet-with-data-efficiency-focus">Benchmarking Across MoleculeNet with Data Efficiency Focus</h2>
<h3 id="datasets">Datasets</h3>
<p>The evaluation uses 10 datasets from MoleculeNet spanning three categories:</p>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Dataset</th>
          <th>Tasks</th>
          <th>Type</th>
          <th>Molecules</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Physical chemistry</td>
          <td>ESOL</td>
          <td>1</td>
          <td>Regression</td>
          <td>1,128</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Physical chemistry</td>
          <td>FreeSolv</td>
          <td>1</td>
          <td>Regression</td>
          <td>643</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Physical chemistry</td>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>1</td>
          <td>Regression</td>
          <td>4,200</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>MUV</td>
          <td>17</td>
          <td>Classification</td>
          <td>93,127</td>
          <td>PRC-AUC</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>HIV</td>
          <td>1</td>
          <td>Classification</td>
          <td>41,913</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Biophysics</td>
          <td>BACE</td>
          <td>1</td>
          <td>Classification</td>
          <td>1,522</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>BBBP</td>
          <td>1</td>
          <td>Classification</td>
          <td>2,053</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>Tox21</td>
          <td>12</td>
          <td>Classification</td>
          <td>8,014</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>SIDER</td>
          <td>27</td>
          <td>Classification</td>
          <td>1,427</td>
          <td>ROC-AUC</td>
      </tr>
      <tr>
          <td>Physiology</td>
          <td>ClinTox</td>
          <td>2</td>
          <td>Classification</td>
          <td>1,491</td>
          <td>ROC-AUC</td>
      </tr>
  </tbody>
</table>
<h3 id="baselines">Baselines</h3>
<ul>
<li><strong>ECFP4</strong>: Rule-based extended-connectivity fingerprint with 1024 dimensions</li>
<li><strong>RNNS2S</strong>: RNN-based Seq2Seq pre-trained fingerprint (3-layer bidirectional GRU, same pre-training data as ST)</li>
<li><strong>GraphConv</strong>: Graph convolution network trained end-to-end on labeled data</li>
</ul>
<h3 id="experimental-setup">Experimental Setup</h3>
<p>All fingerprint methods use a simple MLP classifier/regressor from scikit-learn with default hyperparameters to isolate the fingerprint quality from model capacity. Datasets are randomly split (stratified for classification), and results are averaged over 20 trials. Note that random splits are used rather than scaffold splits for the DEM experiments.</p>
<h3 id="data-efficiency-results-dem">Data Efficiency Results (DEM)</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>ST+MLP</th>
          <th>ECFP+MLP</th>
          <th>RNNS2S+MLP</th>
          <th>GraphConv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL (RMSE, lower is better)</td>
          <td><strong>1.144</strong></td>
          <td>1.741</td>
          <td>1.317</td>
          <td>1.673</td>
      </tr>
      <tr>
          <td>FreeSolv (RMSE, lower is better)</td>
          <td><strong>2.246</strong></td>
          <td>3.043</td>
          <td>2.987</td>
          <td>3.476</td>
      </tr>
      <tr>
          <td>Lipophilicity (RMSE, lower is better)</td>
          <td>1.169</td>
          <td>1.090</td>
          <td>1.219</td>
          <td><strong>1.062</strong></td>
      </tr>
      <tr>
          <td>MUV (PRC-AUC, higher is better)</td>
          <td>0.009</td>
          <td><strong>0.036</strong></td>
          <td>0.010</td>
          <td>0.004</td>
      </tr>
      <tr>
          <td>HIV (ROC-AUC, higher is better)</td>
          <td>0.683</td>
          <td>0.697</td>
          <td>0.682</td>
          <td><strong>0.723</strong></td>
      </tr>
      <tr>
          <td>BACE (ROC-AUC, higher is better)</td>
          <td>0.719</td>
          <td><strong>0.769</strong></td>
          <td>0.717</td>
          <td>0.744</td>
      </tr>
      <tr>
          <td>BBBP (ROC-AUC, higher is better)</td>
          <td><strong>0.900</strong></td>
          <td>0.760</td>
          <td>0.884</td>
          <td>0.795</td>
      </tr>
      <tr>
          <td>Tox21 (ROC-AUC, higher is better)</td>
          <td><strong>0.706</strong></td>
          <td>0.616</td>
          <td>0.702</td>
          <td>0.687</td>
      </tr>
      <tr>
          <td>SIDER (ROC-AUC, higher is better)</td>
          <td>0.559</td>
          <td><strong>0.588</strong></td>
          <td>0.558</td>
          <td>0.557</td>
      </tr>
      <tr>
          <td>ClinTox (ROC-AUC, higher is better)</td>
          <td><strong>0.963</strong></td>
          <td>0.515</td>
          <td>0.904</td>
          <td>0.936</td>
      </tr>
  </tbody>
</table>
<p>ST achieves the best DEM on 5 of 10 datasets (ESOL, FreeSolv, BBBP, Tox21, ClinTox), with clear margins over the runner-up on ClinTox (+0.027 over GraphConv) and BBBP (+0.016 over RNNS2S).</p>
<h3 id="linear-model-experiments">Linear Model Experiments</h3>
<p>To further isolate fingerprint quality, the authors replace the MLP with L2-regularized linear models (ridge regression for regression tasks, logistic regression for classification tasks). On 8 datasets (excluding MUV and SIDER due to class imbalance issues), ST achieves the best DEM in 5 of 8, indicating that the fingerprint quality holds regardless of the downstream model.</p>
<h3 id="stratified-analysis-by-molecule-size">Stratified Analysis by Molecule Size</h3>
<p>On BBBP stratified by SMILES length, ST&rsquo;s ROC-AUC increases with longer SMILES, similar to RNNS2S but unlike GraphConv, which shows stable performance across lengths. This suggests text-based models extract richer information from longer sequences.</p>
<h3 id="comparison-with-record-scores-large-data">Comparison with Record Scores (Large Data)</h3>
<p>Under the large-data setting (80/10/10 train/val/test split with hyperparameter tuning via Optuna), ST achieves first place only in ClinTox (0.954) but performs comparably to ECFP and graph-based models on the other datasets. This confirms that ST&rsquo;s main advantage is in the low-data regime.</p>
<h2 id="strong-low-data-performance-with-caveats-on-scalability">Strong Low-Data Performance with Caveats on Scalability</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li>Transformer-based unsupervised pre-training on SMILES produces fingerprints that excel in low-data molecular property prediction, achieving best data efficiency on 5 of 10 MoleculeNet tasks.</li>
<li>The advantage is most pronounced on small datasets (ESOL with 1,128 molecules, FreeSolv with 643, BBBP with 2,053, ClinTox with 1,491) where pre-training enables good generalization.</li>
<li>With sufficient labeled data and hyperparameter tuning, ST fingerprints perform comparably to (but do not surpass) graph-based methods.</li>
<li>Longer SMILES provide richer information for text-based models, as shown by the stratified analysis on BBBP.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Random splits are used for most DEM experiments rather than scaffold splits, which may inflate performance estimates for drug discovery applications where training and test molecules are structurally distinct.</li>
<li>The pre-training corpus (861K SMILES from ChEMBL24) is relatively small by modern standards.</li>
<li>MUV performance is poor across all methods (PRC-AUC near zero), suggesting the DEM framework may not be informative for extremely imbalanced or noisy datasets.</li>
<li>No comparison with BERT-style masked language model pre-training, which later work (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>) would show as a viable alternative.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose three directions: (1) replacing the Transformer with Transformer-XL to handle longer SMILES, (2) multi-task pre-training that jointly predicts molecular descriptors (e.g., molecular weight, <a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a>) alongside SMILES reconstruction, and (3) better exploitation of enumerated SMILES to constrain the latent space.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL24</td>
          <td>861,000 SMILES</td>
          <td>Unlabeled, randomly sampled</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>MoleculeNet (10 datasets)</td>
          <td>643 to 93,127 molecules</td>
          <td>See Table 1 for per-dataset details</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer encoder-decoder: 4 blocks each, 4-head attention, 256 embedding dimensions</li>
<li>Pre-training: 5 epochs, Adam optimizer, cross-entropy loss, SMILES enumeration for augmentation</li>
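<li>Fingerprint: 1024 dimensions, formed by concatenating four 256-dimensional vectors: the mean-pooled, max-pooled, and first-token outputs of the last encoder layer, plus the first-token output of the penultimate layer</li>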
<li>Downstream: scikit-learn MLP (default hyperparameters) for DEM experiments; ridge/logistic regression for linear model experiments; Optuna for hyperparameter search in large-data comparison</li>
</ul>
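<p>The fingerprint construction can be sketched in a few lines. This is a minimal illustration, not the authors&rsquo; code: <code>pool_fingerprint</code> and the toy inputs are hypothetical, and the fourth vector (first token of the penultimate layer) is an assumption made here so that four 256-dimensional pieces add up to 1024 dimensions.</p>

```python
from typing import List

Vector = List[float]


def pool_fingerprint(last_layer: List[Vector], penultimate_layer: List[Vector]) -> Vector:
    """Concatenate pooled encoder outputs into one fixed-length fingerprint.

    Each argument is a list of per-token output vectors (one per SMILES token).
    """
    dim = len(last_layer[0])
    n_tokens = len(last_layer)
    # Mean and max over the token axis of the last encoder layer.
    mean_pool = [sum(tok[j] for tok in last_layer) / n_tokens for j in range(dim)]
    max_pool = [max(tok[j] for tok in last_layer) for j in range(dim)]
    # First-token outputs of the last and penultimate layers (assumed layout).
    first_last = list(last_layer[0])
    first_penultimate = list(penultimate_layer[0])
    return mean_pool + max_pool + first_last + first_penultimate


# Toy inputs: 3 tokens with 256-dimensional outputs per layer.
last = [[float(i + j) for j in range(256)] for i in range(3)]
penult = [[float(i - j) for j in range(256)] for i in range(3)]
fingerprint = pool_fingerprint(last, penult)
print(len(fingerprint))
```

<p>Because every pooled piece has a fixed width, the fingerprint length does not depend on the SMILES length, which is what lets a standard MLP or linear model consume it.</p>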
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DSPsleeporg/smiles-transformer">smiles-transformer</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation (Jupyter notebooks)</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>DEM averaged over 7 training fractions (1.25% to 80%), 20 trials each</li>
<li>Random splits for DEM; scaffold splits for HIV, BACE, BBBP in large-data comparison</li>
<li>Metrics: RMSE (regression), ROC-AUC or PRC-AUC (classification) per MoleculeNet conventions</li>
</ul>
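<p>As a rough sketch of how the DEM averaging works, the snippet below averages a test metric over training fractions and repeated trials. The specific fractions (doubling from 1.25% to 80%) are an assumption consistent with &ldquo;7 training fractions&rdquo;, and <code>evaluate</code> is a hypothetical callback standing in for a real train-and-test loop.</p>

```python
from statistics import mean
from typing import Callable, Sequence

# Assumed training fractions: doubling from 1.25% to 80% gives seven values.
TRAIN_FRACTIONS = [0.0125, 0.025, 0.05, 0.10, 0.20, 0.40, 0.80]


def data_efficiency_metric(
    evaluate: Callable[[float, int], float],
    fractions: Sequence[float] = TRAIN_FRACTIONS,
    n_trials: int = 20,
) -> float:
    """Average a test metric over training fractions and repeated trials.

    `evaluate(fraction, seed)` is a hypothetical callback that trains on a
    random `fraction` of the data and returns the test metric.
    """
    per_fraction = [
        mean(evaluate(frac, seed) for seed in range(n_trials)) for frac in fractions
    ]
    return mean(per_fraction)


# Toy stand-in for a real train/evaluate loop: the score grows with more data.
def toy_evaluate(fraction: float, seed: int) -> float:
    return 0.5 + 0.5 * fraction


dem = data_efficiency_metric(toy_evaluate)
print(round(dem, 4))
```

<p>Since small fractions count as much as large ones in the average, a model that performs well with little data scores better than one that only catches up at 80%.</p>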
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU type or training time for the pre-training phase.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Honda, S., Shi, S., &amp; Ueda, H. R. (2019). SMILES Transformer: Pre-trained Molecular Fingerprint for Low Data Drug Discovery. <em>arXiv preprint arXiv:1911.04738</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{honda2019smiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SMILES Transformer: Pre-trained Molecular Fingerprint for Low Data Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Honda, Shion and Shi, Shoi and Ueda, Hiroki R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1911.04738}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SMI-TED: Encoder-Decoder Foundation Models for Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smi-ted-encoder-decoder-chemistry/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/smi-ted-encoder-decoder-chemistry/</guid><description>SMI-TED is a family of encoder-decoder transformer models pre-trained on 91M PubChem molecules for molecular property prediction and generation.</description><content:encoded><![CDATA[<h2 id="an-encoder-decoder-chemical-foundation-model-family">An Encoder-Decoder Chemical Foundation Model Family</h2>
<p>SMI-TED is a <strong>Method</strong> paper that introduces a family of encoder-decoder transformer-based foundation models for chemistry. The primary contribution is the SMI-TED289M architecture, a 289-million parameter model pre-trained on 91 million curated SMILES from PubChem, along with a Mixture-of-Experts variant (MoE-OSMI) that scales to 8x289M parameters. The models support molecular property prediction, molecule reconstruction, reaction yield prediction, and few-shot reasoning over molecular embeddings. All model weights and code are open-sourced under an Apache 2.0 license.</p>
<h2 id="bridging-encoding-and-decoding-for-molecular-representations">Bridging Encoding and Decoding for Molecular Representations</h2>
<p>Chemical language models based on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> have gained traction for molecular property prediction and generation. Most existing models, such as <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a> and <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, are encoder-only architectures that produce molecular embeddings through mean pooling. While effective for downstream classification and regression, this encoder-only approach has a limitation: mean pooling has no natural inverse, meaning the model cannot reconstruct the input molecule from its latent representation. This restricts the model&rsquo;s utility for generative tasks and limits the interpretability of the learned latent space.</p>
<p>The authors argue that adding a decoder with a reconstruction objective forces the model to encode a more complete set of structural features. Prior work has shown that the quality of pre-training data matters more than the choice of SMILES vs. <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, and that large-scale pre-training can yield useful chemical representations. SMI-TED builds on these observations by combining an encoder-decoder architecture with a carefully curated 91-million molecule dataset from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>.</p>
<h2 id="invertible-pooling-and-two-phase-pre-training">Invertible Pooling and Two-Phase Pre-Training</h2>
<p>The core architectural innovation in SMI-TED is a learned pooling mechanism that replaces standard mean or max pooling with an invertible projection. Given token embeddings $\mathbf{x} \in \mathbb{R}^{D \times L}$ (where $D = 202$ is the maximum token count and $L = 768$ is the embedding dimension), the submersion into the latent space $\mathbf{z} \in \mathbb{R}^{L}$ is computed as:</p>
<p>$$
\mathbf{z} = \left(\text{LayerNorm}\left(\text{GELU}\left(\mathbf{W}_1^T \mathbf{x} + \mathbf{b}_1\right)\right)\right) \mathbf{W}_2
$$</p>
<p>where $\mathbf{W}_1 \in \mathbb{R}^{D \times L}$, $\mathbf{b}_1 \in \mathbb{R}^{L}$, and $\mathbf{W}_2 \in \mathbb{R}^{L \times L}$. The immersion (inverse mapping) back to the token space is:</p>
<p>$$
\tilde{\mathbf{x}}^T = \left(\text{LayerNorm}\left(\text{GELU}\left(\mathbf{z} \mathbf{W}_3 + \mathbf{b}_3\right)\right)\right) \mathbf{W}_4
$$</p>
<p>where $\mathbf{W}_3 \in \mathbb{R}^{L \times L}$, $\mathbf{b}_3 \in \mathbb{R}^{L}$, and $\mathbf{W}_4 \in \mathbb{R}^{L \times D}$. A decoder language model then predicts the next token from $\tilde{\mathbf{x}}$.</p>
<p>The encoder uses a modified RoFormer attention mechanism with rotary position embeddings:</p>
<p>$$
\text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle v_n}{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle}
$$</p>
<p>where $R_m$ are position-dependent rotation matrices and $\varphi$ is a random feature map.</p>
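<p>The attention formula can be made concrete with a small pure-Python sketch. Several choices here are assumptions for illustration only: the feature map uses positive random features of the form $\exp(w^\top x - \lVert x \rVert^2 / 2)$, the rotation angles follow the common $10000^{-i/d}$ schedule, and positions start at 0. The function names are hypothetical and not taken from the released code.</p>

```python
import math
import random
from typing import List

Vector = List[float]


def rotate(vec: Vector, position: int) -> Vector:
    """Rotary embedding: rotate each (even, odd) coordinate pair by a position-dependent angle."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = position * 10000.0 ** (-i / dim)
        c, s = math.cos(theta), math.sin(theta)
        out += [vec[i] * c - vec[i + 1] * s, vec[i] * s + vec[i + 1] * c]
    return out


def feature_map(vec: Vector, projections: List[Vector]) -> Vector:
    """Positive random features: exp(w . x - |x|^2 / 2) for each random direction w."""
    sq_norm = sum(v * v for v in vec)
    return [
        math.exp(sum(w_i * v_i for w_i, v_i in zip(w, vec)) - sq_norm / 2.0)
        for w in projections
    ]


def linear_rotary_attention(queries, keys, values, projections):
    """Evaluate the formula directly; real implementations reorder the sums for linear cost."""
    phi_q = [feature_map(rotate(q, m), projections) for m, q in enumerate(queries)]
    phi_k = [feature_map(rotate(k, n), projections) for n, k in enumerate(keys)]
    outputs = []
    for fq in phi_q:
        # Unnormalized similarity between this query and every key.
        weights = [sum(a * b for a, b in zip(fq, fk)) for fk in phi_k]
        total = sum(weights)
        outputs.append([
            sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))
        ])
    return outputs


rng = random.Random(0)
dim, n_tokens, n_features = 4, 5, 8
projections = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_features)]
queries = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]
keys = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]
values = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_tokens)]
attended = linear_rotary_attention(queries, keys, values, projections)
print(len(attended), len(attended[0]))
```

<p>Because the feature map is strictly positive, every attention weight is positive and each output row is a weighted average of the value vectors, mirroring the role of the softmax in standard attention.</p>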
<p><strong>Two-phase pre-training strategy:</strong></p>
<ul>
<li><strong>Phase 1</strong>: The token encoder is pre-trained on 95% of the data using masked language modeling (15% token selection, of which 80% masked, 10% random, 10% unchanged). The remaining 5% trains the encoder-decoder layer, preventing convergence issues from unstable early embeddings.</li>
<li><strong>Phase 2</strong>: After the token embeddings converge, both the encoder and decoder train on 100% of the data jointly.</li>
</ul>
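<p>The masking scheme in Phase 1 (15% selection with an 80/10/10 split) can be sketched as follows. This is a hedged illustration: the character-level tokens, the <code>[MASK]</code> symbol, and the function name <code>mask_tokens</code> are assumptions, not the tokenizer or code used by the authors.</p>

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"  # hypothetical mask symbol


def mask_tokens(
    tokens: List[str],
    vocab: List[str],
    rng: random.Random,
    select_prob: float = 0.15,
) -> Tuple[List[str], List[Optional[str]]]:
    """Return (corrupted tokens, labels); labels are None where no loss is computed."""
    corrupted: List[str] = []
    labels: List[Optional[str]] = []
    for tok in tokens:
        if rng.random() >= select_prob:
            corrupted.append(tok)   # not selected: left alone, no loss
            labels.append(None)
            continue
        labels.append(tok)          # selected: the model must predict the original
        roll = rng.random()
        if roll >= 0.2:             # 80%: replace with the mask token
            corrupted.append(MASK)
        elif roll >= 0.1:           # 10%: replace with a random vocabulary token
            corrupted.append(rng.choice(vocab))
        else:                       # 10%: keep the token unchanged
            corrupted.append(tok)
    return corrupted, labels


# Toy example with character-level tokens (a simplification of real SMILES tokenization).
tokens = list("CC(=O)Oc1ccccc1C(=O)O")
vocab = sorted(set(tokens))
corrupted, labels = mask_tokens(tokens, vocab, random.Random(0))
print(corrupted)
```

<p>Keeping 10% of the selected tokens unchanged means the encoder cannot assume that every unmasked position is correct, which is the usual motivation for this split.</p>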
<p><strong><a href="https://en.wikipedia.org/wiki/Mixture_of_experts">Mixture-of-Experts</a> (MoE-OSMI):</strong> The MoE variant composes 8 fine-tuned SMI-TED289M expert models with a gating network. Given an input embedding $x$, the output is:</p>
<p>$$
y = \sum_{i=1}^{n} G(x)_i E_i(\hat{x})
$$</p>
<p>where $G(x) = \text{Softmax}(\text{TopK}(x \cdot W_g))$ selects the top $k = 2$ experts per input, setting all other gate values to zero.</p>
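<p>A minimal sketch of top-$k$ gating follows. It assumes the gate logits $x \cdot W_g$ have already been computed and that each expert returns a vector; <code>top_k_gates</code> and <code>moe_output</code> are hypothetical names, and the toy experts are placeholders rather than fine-tuned SMI-TED models.</p>

```python
import math
from typing import List, Sequence


def top_k_gates(logits: Sequence[float], k: int = 2) -> List[float]:
    """Softmax over the k largest logits; every other gate is exactly zero."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    peak = max(logits[i] for i in top)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - peak) for i in top}
    norm = sum(exps.values())
    return [exps[i] / norm if i in exps else 0.0 for i in range(len(logits))]


def moe_output(gate_logits: Sequence[float], expert_outputs: Sequence[Sequence[float]], k: int = 2) -> List[float]:
    """Weighted sum of expert outputs, y = sum_i G_i * E_i."""
    gates = top_k_gates(gate_logits, k)
    dim = len(expert_outputs[0])
    return [sum(g * out[j] for g, out in zip(gates, expert_outputs)) for j in range(dim)]


# Toy example: 8 experts, each returning a 1-dimensional output equal to its index.
gate_logits = [0.1, 2.0, -1.0, 0.5, 1.0, 0.0, -0.3, 1.2]
expert_outputs = [[float(i)] for i in range(8)]
gates = top_k_gates(gate_logits)
y = moe_output(gate_logits, expert_outputs)
print(gates, y)
```

<p>In a real system only the selected experts would be evaluated; this sketch computes the full weighted sum for clarity, relying on the zero gates to drop the rest.</p>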
<h2 id="benchmarks-across-property-prediction-generation-and-reaction-yield">Benchmarks Across Property Prediction, Generation, and Reaction Yield</h2>
<h3 id="moleculenet-classification-6-datasets-roc-auc"><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification (6 datasets, ROC-AUC)</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>BBBP</th>
          <th>ClinTox</th>
          <th>HIV</th>
          <th>BACE</th>
          <th>SIDER</th>
          <th>Tox21</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer</td>
          <td>73.6 +/- 0.8</td>
          <td>91.2 +/- 1.4</td>
          <td>80.5 +/- 1.65</td>
          <td>86.3 +/- 0.6</td>
          <td>65.5 +/- 0.2</td>
          <td>80.46 +/- 0.2</td>
      </tr>
      <tr>
          <td>Uni-Mol</td>
          <td>72.9 +/- 0.6</td>
          <td>91.9 +/- 1.8</td>
          <td>80.8 +/- 0.3</td>
          <td>85.7 +/- 0.2</td>
          <td>65.9 +/- 1.3</td>
          <td>79.6 +/- 0.5</td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>72.4 +/- 0.4</td>
          <td>90.1 +/- 1.3</td>
          <td>80.6 +/- 0.9</td>
          <td>85.6 +/- 1.1</td>
          <td>67.2 +/- 0.4</td>
          <td>78.1 +/- 0.1</td>
      </tr>
      <tr>
          <td>SMI-TED289M (pre-trained)</td>
          <td>91.46 +/- 0.47</td>
          <td>93.49 +/- 0.85</td>
          <td>80.51 +/- 1.34</td>
          <td>85.58 +/- 0.92</td>
          <td>66.01 +/- 0.88</td>
          <td>81.53 +/- 0.45</td>
      </tr>
      <tr>
          <td>SMI-TED289M (fine-tuned)</td>
          <td><strong>92.26 +/- 0.57</strong></td>
          <td><strong>94.27 +/- 1.83</strong></td>
          <td>76.85 +/- 0.89</td>
          <td><strong>88.24 +/- 0.50</strong></td>
          <td>65.68 +/- 0.45</td>
          <td><strong>81.85 +/- 1.42</strong></td>
      </tr>
  </tbody>
</table>
<p>SMI-TED achieves the best results in 4 of 6 classification tasks. Notably, the pre-trained version (without fine-tuning) already matches or exceeds many baselines on BBBP, ClinTox, and Tox21.</p>
<h3 id="moleculenet-regression-5-datasets-mae-for-qm9qm8-rmse-for-esolfreesolvlipophilicity">MoleculeNet regression (5 datasets, MAE for QM9/QM8, RMSE for ESOL/FreeSolv/Lipophilicity)</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>QM9</th>
          <th>QM8</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer</td>
          <td>1.5894</td>
          <td>0.0102</td>
          <td>0.880</td>
          <td>2.342</td>
          <td>0.700</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>3.241</td>
          <td>0.0143</td>
          <td>0.98</td>
          <td>2.18</td>
          <td>0.65</td>
      </tr>
      <tr>
          <td>SMI-TED289M (fine-tuned)</td>
          <td><strong>1.3246</strong></td>
          <td><strong>0.0095</strong></td>
          <td><strong>0.6112</strong></td>
          <td><strong>1.2233</strong></td>
          <td><strong>0.5522</strong></td>
      </tr>
  </tbody>
</table>
<p>SMI-TED289M achieves the best results across all 5 regression tasks when fine-tuned. The improvements are substantial on ESOL (0.61 vs. 0.82 for next best) and FreeSolv (1.22 vs. 1.91 for next best).</p>
<h3 id="reaction-yield-prediction-buchwald-hartwig-c-n-cross-coupling">Reaction yield prediction (<a href="https://en.wikipedia.org/wiki/Buchwald%E2%80%93Hartwig_amination">Buchwald-Hartwig</a> C-N cross-coupling)</h3>
<p>The model was tested on Pd-catalyzed Buchwald-Hartwig reactions with 3,955 reactions across varying train/test splits. Selected $R^2$ results:</p>
<table>
  <thead>
      <tr>
          <th>Split</th>
          <th>Yield-BERT (Aug)</th>
          <th>DRFP</th>
          <th>SMI-TED289M</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>70/30</td>
          <td>0.97</td>
          <td>0.95</td>
          <td><strong>0.984</strong></td>
      </tr>
      <tr>
          <td>10/90</td>
          <td>0.81</td>
          <td>0.81</td>
          <td><strong>0.961</strong></td>
      </tr>
      <tr>
          <td>2.5/97.5</td>
          <td>0.61</td>
          <td>0.62</td>
          <td><strong>0.875</strong></td>
      </tr>
      <tr>
          <td>Test 1-4 avg</td>
          <td>0.58</td>
          <td>0.71</td>
          <td><strong>0.983</strong></td>
      </tr>
  </tbody>
</table>
<p>SMI-TED shows particularly strong performance in low-data regimes. With only 2.5% training data, it achieves $R^2 = 0.875$, compared to 0.61-0.62 for competing methods.</p>
<h3 id="moses-molecular-generation-benchmarks"><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> molecular generation benchmarks</h3>
<p>SMI-TED is competitive with baselines including CharRNN, SMILES VAE, JT-VAE, <a href="/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/">LIMO</a>, <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen-7b</a>, and <a href="/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/">GP-MoLFormer</a> on standard metrics (validity, uniqueness, novelty, FCD, internal diversity). It achieves superior scaffold cosine similarity (Scaf) and nearest-neighbor similarity (SNN) scores.</p>
<h3 id="latent-space-compositionality">Latent space compositionality</h3>
<p>Using six families of carbon chains ($\mathcal{F} = \{CC, CO, CN, CS, CF, CP\}$), the authors test whether the embedding space respects hierarchical distance structures. A linear regression on SMI-TED embeddings yields $R^2 = 0.99$ and $\text{MSE} = 0.002$, compared to $R^2 = 0.55$ and $\text{MSE} = 0.237$ for MoLFormer. This indicates that the SMI-TED latent space captures compositional chemical relationships far more faithfully.</p>
<p>For structure-property analysis on <a href="/notes/chemistry/datasets/qm9/">QM9</a>, nitrogen-containing molecules represent 9.10% of the dataset but account for 32.81% of the top 10% by HOMO energy. In the SMI-TED latent space, these molecules cluster distinctly (<a href="https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index">Davies-Bouldin index</a> of 2.82 vs. 4.28 for MoLFormer), suggesting the decoder objective encourages encoding of functional group information.</p>
<h2 id="strong-performance-with-a-compositional-latent-space">Strong Performance with a Compositional Latent Space</h2>
<p>SMI-TED289M demonstrates competitive or superior performance across molecular property prediction, reaction yield prediction, and molecular generation benchmarks. The key findings include:</p>
<ol>
<li><strong>Broad applicability</strong>: The single pre-trained model achieves strong results across classification (4/6 best), regression (5/5 best), reaction yield, and generation tasks.</li>
<li><strong>Low-data robustness</strong>: The pre-training on 91M molecules provides chemical knowledge that transfers well to small training sets, as shown by the reaction yield experiments where SMI-TED maintains high accuracy even at 2.5% training data.</li>
<li><strong>Compositional embeddings</strong>: The encoder-decoder architecture produces a latent space where molecular similarity follows chemical intuition, with near-perfect linear relationships between functional group families ($R^2 = 0.99$).</li>
<li><strong>Structure-property capture</strong>: The reconstruction objective appears to enforce encoding of chemically meaningful features like nitrogen substituent effects on <a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO</a> energy, outperforming encoder-only models in latent space organization.</li>
</ol>
<p><strong>Limitations</strong>: The paper evaluates on MoleculeNet benchmarks, which are well-studied but may not reflect performance on more diverse chemical tasks. The BBBP classification result (92.26) shows a large jump from prior methods (73.6 for MoLFormer), which is worth scrutinizing. The MoE variant is evaluated only in supplementary materials, and scaling behavior beyond 8 experts is not explored.</p>
<p><strong>Future directions</strong>: The authors note that compositionality of the learned representations suggests potential for reasoning applications, though they acknowledge that stronger claims require further studies following compositionality analysis methodologies from natural language processing. The model has been integrated into the dZiner agent for inverse molecular design.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>PubChem (curated)</td>
          <td>91M molecules, 4B tokens</td>
          <td>Deduplicated, canonicalized, validity-checked</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>MoleculeNet (BBBP, ClinTox, HIV, BACE, SIDER, Tox21)</td>
          <td>Varies</td>
          <td>Original benchmark splits</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>MoleculeNet (QM9, QM8, ESOL, FreeSolv, Lipophilicity)</td>
          <td>Varies</td>
          <td>Original benchmark splits</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td>MOSES</td>
          <td>1.94M molecules</td>
          <td>Train/test/scaffold test splits</td>
      </tr>
      <tr>
          <td>Reaction yield</td>
          <td>Buchwald-Hartwig HTE</td>
          <td>3,955 reactions</td>
          <td>3x 1536-well plates</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Masked language modeling for token encoder (15% selection: 80% masked, 10% random, 10% unchanged)</li>
<li>Two-phase pre-training (95/5 split then 100% joint training)</li>
<li>RoFormer attention with rotary position embeddings</li>
<li>Vocabulary: 2,993 tokens (2,988 molecular + 5 special)</li>
<li>Maximum sequence length: 202 tokens (covers 99.4% of PubChem)</li>
<li>Learning rate: 1.6e-4, batch size: 288 molecules</li>
<li>40 epochs over the full PubChem corpus</li>
<li>10 random seeds per experiment for robustness</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Variant</th>
          <th>Parameters</th>
          <th>Encoder</th>
          <th>Decoder</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SMI-TED289M base</td>
          <td>289M</td>
          <td>47M</td>
          <td>242M</td>
          <td>12 layers, 12 attention heads, hidden size 768, dropout 0.2</td>
      </tr>
      <tr>
          <td>MoE-OSMI</td>
          <td>8x289M</td>
          <td>-</td>
          <td>-</td>
          <td>8 experts, top-k=2 routing, gating network</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Classification: ROC-AUC</li>
<li>Regression: MAE (QM9, QM8), RMSE (ESOL, FreeSolv, Lipophilicity)</li>
<li>Reaction yield: $R^2$</li>
<li>Generation: Validity, uniqueness, novelty, FCD, IntDiv, Scaf, SNN (MOSES metrics)</li>
<li>Latent space: Linear regression $R^2$, MSE, Davies-Bouldin index, t-SNE visualization</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>24 NVIDIA V100 GPUs (16GB)</li>
<li>4 nodes with DDP (Distributed Data Parallel)</li>
<li>Pre-training: 40 epochs on 91M molecules</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/materials/tree/main/models/smi_ted">IBM/materials (smi_ted)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Training, fine-tuning scripts, Jupyter notebooks</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm/materials.smi-ted">ibm/materials.smi-ted</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Pre-trained model weights</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.15603701">Zenodo archive</a></td>
          <td>Code + Data</td>
          <td>Apache-2.0</td>
          <td>Archival copy of scripts</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Soares, E., Vital Brazil, E., Shirasuna, V., Zubarev, D., Cerqueira, R., &amp; Schmidt, K. (2025). An open-source family of large encoder-decoder foundation models for chemistry. <em>Communications Chemistry</em>, 8(1). <a href="https://doi.org/10.1038/s42004-025-01585-0">https://doi.org/10.1038/s42004-025-01585-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{soares2025smited,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{An open-source family of large encoder-decoder foundation models for chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Soares, Eduardo and Vital Brazil, Emilio and Shirasuna, Victor and Zubarev, Dmitry and Cerqueira, Renato and Schmidt, Kristin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Communications Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42004-025-01585-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Seq2seq Fingerprint: Unsupervised Molecular Embedding</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/</guid><description>Seq2seq fingerprint uses a GRU encoder-decoder trained on SMILES self-translation to produce unsupervised molecular embeddings for property prediction.</description><content:encoded><![CDATA[<h2 id="an-unsupervised-seq2seq-method-for-molecular-fingerprints">An Unsupervised Seq2seq Method for Molecular Fingerprints</h2>
<p>This is a <strong>Method</strong> paper that introduces seq2seq fingerprint, an unsupervised molecular embedding approach based on sequence-to-sequence learning. The core idea is to train a <a href="https://en.wikipedia.org/wiki/Gated_recurrent_unit">GRU</a> encoder-decoder network to translate <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings to themselves, then extract the intermediate fixed-length vector as a molecular fingerprint. These fingerprints are then used with standard supervised classifiers for downstream property prediction tasks such as solubility classification and promiscuity prediction.</p>
<h2 id="the-labeled-data-bottleneck-in-drug-discovery">The Labeled Data Bottleneck in Drug Discovery</h2>
<p>Machine learning approaches to molecular property prediction depend on fixed-length feature vectors as inputs. Traditional molecular fingerprints fall into two categories: hash-based methods like Extended-Connectivity Fingerprints (ECFP) that are fast but lossy and non-invertible, and biologist-guided local-feature fingerprints that require domain expertise and are task-specific. Supervised deep learning fingerprints (e.g., neural fingerprints) can learn representations from data but require large amounts of labeled data, which is expensive to obtain in drug discovery due to the cost of biological experiments.</p>
<p>The authors identify three limitations of existing approaches:</p>
<ol>
<li>Hash-based fingerprints discard information during the hashing process and cannot reconstruct the original molecule</li>
<li>Local-feature fingerprints require expert knowledge and generalize poorly across tasks</li>
<li>Supervised deep learning fingerprints are data-hungry and fail when labeled data is limited</li>
</ol>
<h2 id="self-translation-as-unsupervised-molecular-encoding">Self-Translation as Unsupervised Molecular Encoding</h2>
<p>The key insight is to adapt the <a href="https://en.wikipedia.org/wiki/Seq2seq">sequence-to-sequence</a> learning framework from machine translation (originally English-to-French) to molecular representation learning by setting both the input and output to the same SMILES string. Since the intermediate vector must contain enough information to reconstruct the original SMILES, it serves as a rich, task-agnostic molecular fingerprint.</p>
<p>The architecture consists of two components:</p>
<ul>
<li><strong>Perceiver network</strong>: A multi-layer GRU encoder that reads the SMILES string and compresses it into a fixed-length vector</li>
<li><strong>Interpreter network</strong>: A multi-layer GRU decoder that reconstructs the original SMILES from the fingerprint vector</li>
</ul>
<p>The GRU cell computes a sequence of outputs $(s_1, \ldots, s_T)$ from input sequences $(x_1, \ldots, x_T)$ by iterating:</p>
<p>$$
z_t = \sigma_g(W_z x_t + U_z s_{t-1} + b_z)
$$</p>
<p>$$
r_t = \sigma_r(W_r x_t + U_r s_{t-1} + b_r)
$$</p>
<p>$$
h_t = \tanh(U_h x_t + W_h(s_{t-1} \circ r_t))
$$</p>
<p>$$
s_t = (1 - z_t) \circ h_t + z_t \circ s_{t-1}
$$</p>
<p>where $z_t$ is the update gate, $r_t$ is the reset gate, $\circ$ denotes element-wise multiplication, and $W$, $U$, $b$ are trainable parameters.</p>
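<p>One GRU step can be written out in pure Python to make the gate arithmetic explicit. This is a sketch under stated assumptions: the final update interpolates between the current candidate $h_t$ and the previous state $s_{t-1}$ (the standard GRU form), the parameter shapes and random initialization are illustrative, and the one-hot inputs stand in for SMILES tokens.</p>

```python
import math
import random
from typing import Dict, List

Vector = List[float]
Matrix = List[Vector]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def gru_step(x: Vector, s_prev: Vector, p: Dict[str, Matrix], b: Dict[str, Vector]) -> Vector:
    """One GRU update: gates z and r, candidate h, then the new state."""
    wz, uz = matvec(p["W_z"], x), matvec(p["U_z"], s_prev)
    wr, ur = matvec(p["W_r"], x), matvec(p["U_r"], s_prev)
    z = [sigmoid(a + c + d) for a, c, d in zip(wz, uz, b["z"])]  # update gate
    r = [sigmoid(a + c + d) for a, c, d in zip(wr, ur, b["r"])]  # reset gate
    gated = [s * g for s, g in zip(s_prev, r)]                   # s_prev gated by r
    h = [math.tanh(a + c) for a, c in zip(matvec(p["U_h"], x), matvec(p["W_h"], gated))]
    # Interpolate between the candidate h and the previous state.
    return [(1.0 - zi) * hi + zi * si for zi, hi, si in zip(z, h, s_prev)]


rng = random.Random(0)
input_dim, hidden_dim = 3, 4


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


params = {
    "W_z": rand_matrix(hidden_dim, input_dim), "U_z": rand_matrix(hidden_dim, hidden_dim),
    "W_r": rand_matrix(hidden_dim, input_dim), "U_r": rand_matrix(hidden_dim, hidden_dim),
    "U_h": rand_matrix(hidden_dim, input_dim), "W_h": rand_matrix(hidden_dim, hidden_dim),
}
biases = {"z": [0.0] * hidden_dim, "r": [0.0] * hidden_dim}

state = [0.0] * hidden_dim
for x_t in [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]:  # one-hot toy tokens
    state = gru_step(x_t, state, params, biases)
print(state)
```

<p>The final <code>state</code> has a fixed length regardless of how many tokens were read, which is the property the fingerprint relies on.</p>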
<p>Several adaptations to the original seq2seq framework make this work for molecular data:</p>
<ol>
<li><strong>GRU instead of LSTM</strong>: GRU provides comparable performance with faster training, which is important given the large training data pool</li>
<li><strong>Attention mechanism</strong>: Establishes a stronger connection between the perceiver and interpreter networks via soft alignment, addressing the challenge of passing information through hidden memory for long sequences (SMILES can be up to 250 characters)</li>
<li><strong>Dropout layers</strong>: Added to input and output gates (but not hidden memory transfer) following the approach of Zaremba et al. to combat overfitting when training on large datasets</li>
<li><strong>Fingerprint extraction layer</strong>: A fixed-unit fully connected layer combined with a GRU cell state concatenation layer is inserted between encoder and decoder to explicitly output the fingerprint vector</li>
<li><strong>Reverse target sequence</strong>: Following Sutskever et al., the target sequence is reversed to improve SGD optimization</li>
<li><strong>Bucket training</strong>: Sequences are distributed into buckets by length and padded to enable GPU parallelization</li>
</ol>
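<p>Two of the adaptations above, target reversal and bucket training, are easy to illustrate together. The sketch below is a simplification: it works at the character level (multi-character atoms such as <code>Cl</code> would need a real tokenizer), and the padding symbol, bucket sizes, and function name are hypothetical.</p>

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

PAD = "_"  # hypothetical padding symbol


def bucket_pairs(
    smiles: Sequence[str], bucket_sizes: Sequence[int]
) -> Dict[int, List[Tuple[str, str]]]:
    """Group (source, reversed target) pairs by the smallest bucket that fits."""
    buckets: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
    for s in smiles:
        fitting = [size for size in sorted(bucket_sizes) if size >= len(s)]
        if not fitting:
            continue  # longer than every bucket: skipped in this sketch
        size = fitting[0]
        source = s.ljust(size, PAD)
        target = s[::-1].ljust(size, PAD)  # target is the reversed SMILES string
        buckets[size].append((source, target))
    return dict(buckets)


examples = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]
buckets = bucket_pairs(examples, bucket_sizes=[5, 10, 30])
for size, pairs in sorted(buckets.items()):
    print(size, pairs)
```

<p>Within a bucket every sequence has the same padded length, so a batch can be stacked into a single tensor without wasting much computation on padding.</p>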
<h2 id="classification-experiments-on-logp-and-pm2-datasets">Classification Experiments on LogP and PM2 Datasets</h2>
<h3 id="training-setup">Training Setup</h3>
<p>The unsupervised training used 334,092 valid SMILES representations from combined LogP and PM2-full datasets obtained from the National Center for Advancing Translational Sciences (NCATS) at NIH. Three model variants were trained with fingerprint dimensions of 512, 768, and 1024, differing in the number of GRU layers (2, 3, and 4 respectively) while keeping the latent dimension at 256. Each model was trained for 24 hours on a workstation with an Intel i7-6700K CPU, 16 GB RAM, and an NVIDIA GTX 1080 GPU.</p>
<h3 id="reconstruction-performance">Reconstruction Performance</h3>
<p>The models were evaluated on their ability to reconstruct SMILES strings from their fingerprints:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>GRU Layers</th>
          <th>Latent Dim</th>
          <th>Perplexity</th>
          <th>Exact Match Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>seq2seq-512</td>
          <td>2</td>
          <td>256</td>
          <td>1.00897</td>
          <td>94.24%</td>
      </tr>
      <tr>
          <td>seq2seq-768</td>
          <td>3</td>
          <td>256</td>
          <td>1.00949</td>
          <td>92.92%</td>
      </tr>
      <tr>
          <td>seq2seq-1024</td>
          <td>4</td>
          <td>256</td>
          <td>1.01472</td>
          <td>90.26%</td>
      </tr>
  </tbody>
</table>
<p>Deeper models showed lower reconstruction accuracy, possibly because larger fingerprint spaces introduce more null spaces and require longer training to converge.</p>
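<p>The two reconstruction metrics can be computed roughly as sketched below. The sketch assumes the usual definition of perplexity (the exponential of the mean negative log-likelihood per token) and plain string equality for exact match; the summary does not state the precise definitions used in the paper.</p>

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood per token (natural log)."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

def exact_match_accuracy(originals, reconstructions):
    """Fraction of reconstructions that equal the original string."""
    matches = sum(o == r for o, r in zip(originals, reconstructions))
    return matches / len(originals)

# Toy decoder probabilities for the correct token at each position
log_probs = [math.log(0.99), math.log(0.995), math.log(0.98)]
ppl = perplexity(log_probs)
acc = exact_match_accuracy(["CCO", "CCN", "c1ccccc1"], ["CCO", "CCC", "c1ccccc1"])
```

<p>A perplexity close to 1 means the decoder assigns nearly all probability mass to the correct token at each step, yet a single wrong token still makes the whole string fail the exact-match test.</p>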
<h3 id="classification-results">Classification Results</h3>
<p>Two labeled datasets were used for downstream classification:</p>
<ul>
<li><strong>LogP</strong>: 10,850 samples with <a href="https://en.wikipedia.org/wiki/Partition_coefficient">water-octanol partition coefficient</a> values, binarized at a threshold of 1.88</li>
<li><strong>PM2-10k</strong>: 10,000 samples with binary promiscuity class labels</li>
</ul>
<p>The seq2seq fingerprints were evaluated with three ensemble classifiers (<a href="https://en.wikipedia.org/wiki/AdaBoost">AdaBoost</a>, <a href="https://en.wikipedia.org/wiki/Gradient_boosting">GradientBoost</a>, RandomForest) against circular fingerprints (ECFP) and neural fingerprints. Results are 100-run averages of 5-fold cross-validation accuracy.</p>
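<p>The evaluation protocol (binarize the label, run 5-fold cross-validation, average over repeated runs) can be sketched as follows. The majority-class predictor is a trivial stand-in for the ensemble classifiers, and both the direction of the threshold and the choice to report the standard deviation across run means are assumptions.</p>

```python
import random
import statistics

def binarize(values, threshold=1.88):
    # Which side of the threshold counts as the positive class is an assumption.
    return [1 if v > threshold else 0 for v in values]

def kfold_indices(n, k, rng):
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[i::k] for i in range(k)]

def repeated_cv_accuracy(X, y, fit_predict, k=5, runs=100, seed=0):
    rng = random.Random(seed)
    run_means = []
    for _ in range(runs):
        fold_acc = []
        for test_idx in kfold_indices(len(X), k, rng):
            held_out = set(test_idx)
            train_idx = [i for i in range(len(X)) if i not in held_out]
            preds = fit_predict(
                [X[i] for i in train_idx],
                [y[i] for i in train_idx],
                [X[i] for i in test_idx],
            )
            correct = sum(p == y[i] for p, i in zip(preds, test_idx))
            fold_acc.append(correct / len(test_idx))
        run_means.append(statistics.mean(fold_acc))
    return statistics.mean(run_means), statistics.stdev(run_means)

def majority_class(train_X, train_y, test_X):
    # Trivial stand-in classifier: predict the most common training label.
    label = max(set(train_y), key=train_y.count)
    return [label] * len(test_X)

logp = [0.5, 2.3, 1.9, 3.1, 1.2, 2.0, 0.1, 2.8, 1.5, 2.2]
labels = binarize(logp)
mean_acc, std_acc = repeated_cv_accuracy(logp, labels, majority_class, runs=10)
```

<p>Any classifier that exposes the same <code>fit_predict(train_X, train_y, test_X)</code> shape could be dropped in place of the stand-in.</p>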
<p><strong>LogP classification accuracy:</strong></p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Mean Accuracy</th>
          <th>Std Dev</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Circular FP (ECFP)</td>
          <td>0.3674</td>
          <td>0.0074</td>
      </tr>
      <tr>
          <td>Neural FP</td>
          <td>0.6080</td>
          <td>0.0135</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + GradientBoost</td>
          <td><strong>0.7664</strong></td>
          <td>0.0043</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + AdaBoost</td>
          <td>0.7342</td>
          <td>0.0042</td>
      </tr>
      <tr>
          <td>Seq2seq-512 + GradientBoost</td>
          <td>0.7350</td>
          <td>0.0060</td>
      </tr>
  </tbody>
</table>
<p><strong>PM2-10k classification accuracy:</strong></p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Mean Accuracy</th>
          <th>Std Dev</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Circular FP (ECFP)</td>
          <td>0.3938</td>
          <td>0.0114</td>
      </tr>
      <tr>
          <td>Neural FP</td>
          <td>0.5227</td>
          <td>0.0112</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + GradientBoost</td>
          <td><strong>0.6206</strong></td>
          <td>0.0198</td>
      </tr>
      <tr>
          <td>Seq2seq-1024 + AdaBoost</td>
          <td>0.6036</td>
          <td>0.0147</td>
      </tr>
      <tr>
          <td>Seq2seq-512 + GradientBoost</td>
          <td>0.5741</td>
          <td>0.0086</td>
      </tr>
  </tbody>
</table>
<p>The seq2seq fingerprint outperformed both baselines across all configurations. Although seq2seq-1024 had the lowest reconstruction accuracy (90.26%), it gave the best classification performance, suggesting that the longer fingerprint captures more discriminative information for downstream tasks even if the reconstruction is less exact.</p>
<h2 id="unsupervised-transfer-learning-for-molecular-properties">Unsupervised Transfer Learning for Molecular Properties</h2>
<p>The results demonstrate that unsupervised pretraining on large unlabeled molecular datasets can produce fingerprints that transfer well to supervised property prediction with limited labels. The key advantages confirmed by the experiments are:</p>
<ol>
<li><strong>Label-free training</strong>: The unsupervised approach uses essentially unlimited SMILES data, avoiding the expensive label collection process</li>
<li><strong>Task-agnostic representations</strong>: The same fingerprints work across different classification tasks (LogP-based lipophilicity and promiscuity) without retraining</li>
<li><strong>Invertibility</strong>: The fingerprints contain enough information to reconstruct the original SMILES (up to 94.24% exact match), unlike hash-based methods</li>
</ol>
<p><strong>Limitations</strong> acknowledged by the authors include:</p>
<ul>
<li>Long training times (24 hours per model variant), motivating future work on distributed training</li>
<li>The relationship between fingerprint dimensionality and downstream performance is non-monotonic (768-dim underperforms 512-dim on some tasks), suggesting sensitivity to hyperparameter choices</li>
<li>Only classification tasks were evaluated; regression performance was not assessed</li>
<li>The comparison baselines are limited to ECFP and the 2015 neural fingerprint method</li>
</ul>
<p><strong>Future directions</strong> proposed include distributed training strategies, hyperparameter optimization methods, and semi-supervised extensions that incorporate label information into the fingerprint training.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unsupervised training</td>
          <td>LogP + PM2-full (combined)</td>
          <td>334,092 SMILES</td>
          <td>Obtained from NCATS at NIH</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>LogP</td>
          <td>10,850 samples</td>
          <td>Binary labels at LogP threshold 1.88</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>PM2-10k</td>
          <td>10,000 samples</td>
          <td>Binary promiscuity labels</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder-decoder: Multi-layer GRU with attention mechanism and dropout</li>
<li>Fingerprint dimensions: 512, 768, 1024 (with 2, 3, 4 GRU layers respectively)</li>
<li>Latent dimension: 256 for all variants</li>
<li>Downstream classifiers: AdaBoost, GradientBoost, RandomForest</li>
<li>Evaluation: 5-fold cross-validation, 100-run averages</li>
<li>Baselines: ECFP via RDKit, Neural Fingerprint from HIPS/neural-fingerprint</li>
</ul>
<h3 id="models">Models</h3>
<p>Three model variants trained for 24 hours each. The paper states code would become publicly available after acceptance, but no public repository has been confirmed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value</th>
          <th>Task</th>
          <th>Configuration</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Classification accuracy</td>
          <td>0.7664</td>
          <td>LogP</td>
          <td>seq2seq-1024 + GradientBoost</td>
      </tr>
      <tr>
          <td>Classification accuracy</td>
          <td>0.6206</td>
          <td>PM2-10k</td>
          <td>seq2seq-1024 + GradientBoost</td>
      </tr>
      <tr>
          <td>Exact match reconstruction</td>
          <td>94.24%</td>
          <td>SMILES recovery</td>
          <td>seq2seq-512</td>
      </tr>
      <tr>
          <td>Perplexity</td>
          <td>1.00897</td>
          <td>SMILES recovery</td>
          <td>seq2seq-512</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: Intel i7-6700K @ 4.00 GHz, 16 GB RAM, NVIDIA GTX 1080 GPU</li>
<li>Hyperparameter search and classifier training: TACC Lonestar 5 cluster</li>
<li>Training time: 24 hours per model variant</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/HIPS/neural-fingerprint">Neural Fingerprint (baseline)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Baseline comparison code</td>
      </tr>
  </tbody>
</table>
<p>The authors indicated the seq2seq fingerprint code would be released after acceptance, but no public repository has been found as of this writing. The datasets were sourced from NCATS/NIH.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, Z., Wang, S., Zhu, F., &amp; Huang, J. (2017). Seq2seq Fingerprint: An Unsupervised Deep Molecular Embedding for Drug Discovery. <em>Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics (ACM-BCB &lsquo;17)</em>, 285-294. <a href="https://doi.org/10.1145/3107411.3107424">https://doi.org/10.1145/3107411.3107424</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xu2017seq2seq,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Seq2seq Fingerprint: An Unsupervised Deep Molecular Embedding for Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xu, Zheng and Wang, Sheng and Zhu, Feiyun and Huang, Junzhou}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{285--294}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2017}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1145/3107411.3107424}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>S4 Structured State Space Models for De Novo Drug Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/s4-chemical-language-modeling/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/s4-chemical-language-modeling/</guid><description>S4 state space models are applied to chemical language modeling for de novo drug design, outperforming LSTMs and GPTs in bioactivity learning from SMILES.</description><content:encoded><![CDATA[<h2 id="structured-state-spaces-meet-chemical-language-modeling">Structured State Spaces Meet Chemical Language Modeling</h2>
<p>This is a <strong>Method</strong> paper that introduces structured state space sequence (S4) models to chemical language modeling (CLM) for de novo drug design. S4 models have a dual formulation: they process entire input sequences via convolution during training (like Transformers) and generate sequences element-by-element via recurrence during inference (like LSTMs). The authors benchmark S4 against LSTM and GPT architectures across multiple drug discovery tasks, including drug-like molecule generation, bioactivity learning, chemical space exploration, natural product design, and prospective kinase inhibitor design validated by molecular dynamics simulations.</p>
<h2 id="bridging-the-lstm-transformer-gap-in-molecular-generation">Bridging the LSTM-Transformer Gap in Molecular Generation</h2>
<p>Chemical language models (CLMs) generate molecules by learning the &ldquo;chemical language&rdquo; of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> string representations. The two dominant architectures for CLMs are LSTMs and GPTs, each with complementary strengths and limitations:</p>
<ul>
<li><strong>LSTMs</strong> generate sequences recurrently (element-by-element), which enables efficient generation and good learning of local/short-range dependencies. However, their sequential information bottleneck limits learning of global sequence properties.</li>
<li><strong>GPTs</strong> (Transformer decoders) process the entire input at once, better capturing global properties like bioactivity. However, they become increasingly compute-intensive for longer SMILES strings and struggle with chemical space exploration at higher sampling temperatures.</li>
</ul>
<p>Complex molecular properties like bioactivity can emerge from separated portions of a SMILES string (e.g., distant functional groups in the linear notation). Neither architecture fully addresses the need to learn these long-range dependencies while maintaining efficient, robust generation. The chemical space, estimated at up to $10^{60}$ small molecules, demands models that can both capture complex property relationships and explore diverse scaffolds efficiently.</p>
<h2 id="the-dual-nature-of-s4-convolution-meets-recurrence">The Dual Nature of S4: Convolution Meets Recurrence</h2>
<p>S4 models are built on discrete <a href="https://en.wikipedia.org/wiki/State-space_model">state space models</a>, which map an input sequence $\mathbf{u}$ to an output sequence $\mathbf{y}$ through learnable parameters $\overline{\mathbf{A}} \in \mathbb{R}^{N \times N}$, $\overline{\mathbf{B}} \in \mathbb{R}^{N \times 1}$, $\overline{\mathbf{C}} \in \mathbb{R}^{1 \times N}$, and $\overline{\mathbf{D}} \in \mathbb{R}^{1 \times 1}$:</p>
<p>$$
x_{k} = \overline{\mathbf{A}} x_{k-1} + \overline{\mathbf{B}} u_{k}
$$</p>
<p>$$
y_{k} = \overline{\mathbf{C}} x_{k} + \overline{\mathbf{D}} u_{k}
$$</p>
<p>This linear recurrence can equivalently be &ldquo;unrolled&rdquo; into a global convolution:</p>
<p>$$
\mathbf{y} = \mathbf{u} * \overline{\mathbf{K}}
$$</p>
<p>where $\overline{\mathbf{K}}$ is a convolution filter parameterized by $\overline{\mathbf{A}}$, $\overline{\mathbf{B}}$, and $\overline{\mathbf{C}}$. This duality is the core innovation for CLMs:</p>
<ul>
<li><strong>Training</strong>: S4 uses the convolutional formulation to learn from entire SMILES sequences simultaneously, capturing global molecular properties.</li>
<li><strong>Generation</strong>: S4 switches to the recurrent formulation, producing SMILES tokens one at a time for efficient, robust chemical space exploration.</li>
</ul>
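<p>The equivalence between the two views can be checked numerically on a toy system. The sketch below is a naive illustration with hand-picked matrices and a zero initial state. It builds the kernel by repeated matrix-vector products rather than the structured computation that S4 uses, and it adds the $\overline{\mathbf{D}}$ skip term explicitly in both views.</p>

```python
def matvec(M, v):
    return [sum(m * a for m, a in zip(row, v)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def ssm_recurrent(A, B, C, D, u):
    """x_k = A x_{k-1} + B u_k ; y_k = C x_k + D u_k, starting from a zero state."""
    x = [0.0] * len(B)
    ys = []
    for u_k in u:
        x = [ax + b * u_k for ax, b in zip(matvec(A, x), B)]
        ys.append(dot(C, x) + D * u_k)
    return ys

def ssm_kernel(A, B, C, length):
    """K = (C B, C A B, C A^2 B, ...), truncated to the sequence length."""
    kernel, v = [], list(B)
    for _ in range(length):
        kernel.append(dot(C, v))
        v = matvec(A, v)
    return kernel

def ssm_convolutional(A, B, C, D, u):
    K = ssm_kernel(A, B, C, len(u))
    # Causal convolution of the input with the kernel, plus the skip term
    return [
        sum(K[j] * u[k - j] for j in range(k + 1)) + D * u[k]
        for k in range(len(u))
    ]

# Toy parameters (assumed values, state size N = 2)
A = [[0.9, 0.1], [0.0, 0.8]]
B = [1.0, 0.5]
C = [0.3, -0.2]
D = 0.1
u = [1.0, 0.0, 2.0, -1.0]

y_rec = ssm_recurrent(A, B, C, D, u)
y_conv = ssm_convolutional(A, B, C, D, u)
```

<p>Both functions return the same outputs up to floating-point error. The recurrent form needs only the current state to emit the next output, which is what makes token-by-token generation cheap.</p>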
<p>S4 addresses the numerical instabilities of naive state space models through high-order polynomial projection operators (HiPPO) and reduction to the stable Cauchy kernel computation, enabling effective learning of long-range dependencies.</p>
<p>For molecular ranking after fine-tuning, the log-likelihood score subtracts the pre-training likelihood to isolate target-specific information:</p>
<p>$$
\mathcal{L}_{\text{score}}(\mathbf{M}) = \mathcal{L}(\mathbf{M}_{\text{ft}}) - \mathcal{L}(\mathbf{M}_{\text{pt}})
$$</p>
<p>where $\mathcal{L}(\mathbf{M}_{\text{ft}})$ and $\mathcal{L}(\mathbf{M}_{\text{pt}})$ are the fine-tuned and pre-trained model log-likelihoods, respectively.</p>
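<p>A small sketch shows why the subtraction matters for ranking. The molecule names and log-likelihood values are made up for illustration.</p>

```python
def rank_by_score(log_lik_ft, log_lik_pt):
    """Rank molecules by fine-tuned minus pre-trained log-likelihood."""
    scores = {m: log_lik_ft[m] - log_lik_pt[m] for m in log_lik_ft}
    ranking = sorted(scores, key=scores.get, reverse=True)
    return ranking, scores

# Hypothetical log-likelihoods under the fine-tuned (ft) and pre-trained (pt) models
ft = {"mol_a": -20.0, "mol_b": -18.0, "mol_c": -25.0}
pt = {"mol_a": -30.0, "mol_b": -19.0, "mol_c": -26.5}

ranking, scores = rank_by_score(ft, pt)
```

<p>Here <code>mol_b</code> has the highest fine-tuned likelihood but ranks last, because it was already likely under the pre-trained model; <code>mol_a</code> gained the most from fine-tuning and ranks first.</p>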
<h2 id="benchmarking-s4-across-drug-discovery-tasks">Benchmarking S4 Across Drug Discovery Tasks</h2>
<h3 id="drug-like-molecule-generation">Drug-like molecule generation</h3>
<p>All three CLMs (S4, LSTM, GPT) were pre-trained on 1.9M canonical SMILES from ChEMBL v31 (molecules with fewer than 100 tokens). Each model generated 102,400 SMILES strings de novo.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Valid</th>
          <th>Unique</th>
          <th>Novel</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>S4</td>
          <td>99,268 (97%)</td>
          <td>98,712 (96%)</td>
          <td>95,552 (93%)</td>
      </tr>
      <tr>
          <td>LSTM</td>
          <td>97,151 (95%)</td>
          <td>96,618 (94%)</td>
          <td>82,988 (81%)</td>
      </tr>
      <tr>
          <td>GPT</td>
          <td>93,580 (91%)</td>
          <td>93,263 (91%)</td>
          <td>91,590 (89%)</td>
      </tr>
  </tbody>
</table>
<p>S4 produces the most valid, unique, and novel molecules. Error analysis reveals that each architecture shows different failure modes: LSTMs struggle most with branching errors, GPTs with ring and bond assignment errors, while S4 generates fewer branching and ring errors but more bond assignment errors than LSTM. This pattern supports the hypothesis that S4 captures long-range dependencies (branching, ring opening/closure) better while local dependencies (bond assignment) are handled better by recurrent processing.</p>
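<p>The three counts in the table are nested: unique molecules are a subset of valid ones, and novel molecules a subset of unique ones, with all percentages taken relative to the total number of generated strings. A rough sketch follows, with a toy validity check standing in for a real SMILES parser (an assumption; an actual pipeline would parse and canonicalize with a cheminformatics toolkit).</p>

```python
def toy_is_valid(smiles):
    # Toy stand-in: non-empty with balanced parentheses. Not a real SMILES parser.
    depth = 0
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == -1:
                return False
    return bool(smiles) and depth == 0

def generation_metrics(generated, training_set, is_valid=toy_is_valid):
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)                  # valid and distinct
    novel = unique - set(training_set)   # valid, distinct, and unseen in training
    total = len(generated)
    return {
        "valid": (len(valid), len(valid) / total),
        "unique": (len(unique), len(unique) / total),
        "novel": (len(novel), len(novel) / total),
    }

generated = ["CCO", "CCO", "CC(=O)O", "CC(C", "CCN"]
metrics = generation_metrics(generated, training_set={"CCO"})
```

<p>For this toy input, four of five strings pass the check, three are distinct, and two are absent from the training set.</p>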
<h3 id="bioactivity-learning-via-transfer-learning">Bioactivity learning via transfer learning</h3>
<p>Five fine-tuning campaigns were conducted on targets from the LIT-PCBA dataset: PKM2, <a href="https://en.wikipedia.org/wiki/Mitogen-activated_protein_kinase_1">MAPK1</a>, GBA, mTORC1, and TP53. After fine-tuning, models ranked held-out test molecules by learned log-likelihoods to evaluate bioactive compound prioritization.</p>
<p>S4 outperformed both benchmarks across targets. Wilcoxon signed-rank tests on pooled scores confirmed statistically significant superiority:</p>
<ul>
<li>S4 vs. LSTM: $p$ [top 10] = 8.41e-6, $p$ [top 50] = 2.93e-7, $p$ [top 100] = 1.45e-7</li>
<li>S4 vs. GPT: $p$ [top 10] = 2.33e-3, $p$ [top 50] = 3.72e-3, $p$ [top 100] = 2.61e-2</li>
</ul>
<p>TP53 was the most challenging target, where no model consistently retrieved actives in the top 10, possibly due to <a href="/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/">activity cliffs</a> in the test set.</p>
<h3 id="chemical-space-exploration-with-temperature-sampling">Chemical space exploration with temperature sampling</h3>
<p>Models were evaluated across sampling temperatures from $T = 1.0$ to $T = 2.0$ on three metrics: SMILES validity, rediscovery rate of known actives, and scaffold diversity. Key findings:</p>
<ul>
<li><strong>Validity</strong>: S4 and LSTM maintain higher validity than GPT at elevated temperatures (GPT median validity drops below 40% at high T).</li>
<li><strong>Rediscovery</strong>: S4 outperforms LSTM in rediscovering bioactive molecules at all temperatures.</li>
<li><strong>Scaffold diversity</strong>: LSTM achieves the highest number of unique scaffold clusters (median 6,602 at $T = 1.75$), with S4 a close second (6,520 clusters).</li>
</ul>
<p>S4 provides the best balance between bioactivity capture and structural diversity.</p>
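<p>Temperature sampling rescales the model's next-token logits before the softmax, so higher temperatures flatten the distribution and encourage exploration at the cost of more invalid strings. A minimal sketch, with made-up logits over a three-token vocabulary:</p>

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(logits, temperature, rng):
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [3.0, 1.0, 0.2]  # hypothetical next-token logits
p_low = softmax_with_temperature(logits, 1.0)
p_high = softmax_with_temperature(logits, 2.0)
token = sample_token(logits, 1.75, random.Random(0))
```

<p>At $T = 2.0$ the most likely token receives less probability mass than at $T = 1.0$, and the remainder is spread over the less likely tokens.</p>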
<h3 id="natural-product-design">Natural product design</h3>
<p>Models were trained on 32,360 large natural product SMILES (length &gt; 100 tokens) from the COCONUT database and used to generate 102,400 designs each.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>S4</th>
          <th>LSTM</th>
          <th>GPT</th>
          <th>Training Set</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid</td>
          <td>82,633 (81%)</td>
          <td>76,264 (74%)</td>
          <td>70,117 (68%)</td>
          <td>n.a.</td>
      </tr>
      <tr>
          <td>Unique</td>
          <td>53,293 (52%)</td>
          <td>51,326 (50%)</td>
          <td>50,487 (49%)</td>
          <td>n.a.</td>
      </tr>
      <tr>
          <td>Novel</td>
          <td>40,897 (40%)</td>
          <td>43,245 (42%)</td>
          <td>43,168 (42%)</td>
          <td>n.a.</td>
      </tr>
      <tr>
          <td>NP-likeness</td>
          <td>1.6 +/- 0.7</td>
          <td>1.5 +/- 0.7</td>
          <td>1.5 +/- 0.7</td>
          <td>1.6 +/- 0.7</td>
      </tr>
  </tbody>
</table>
<p>S4 designs the most valid molecules (about 6,400 more than LSTM and 12,500 more than GPT) and achieves significantly higher NP-likeness ($p = 1.41 \times 10^{-53}$ vs. LSTM, $p = 1.02 \times 10^{-82}$ vs. GPT). S4 also achieves the lowest Kolmogorov-Smirnov distances to the training/test distributions across multiple structural properties (sp3 carbons, aliphatic rings, spiro atoms, molecular weight, fused ring size, heavy atoms).</p>
<p>For computational efficiency, S4 trains as fast as GPT (both approximately 1.3x faster than LSTM) and generates fastest among all architectures.</p>
<h3 id="prospective-mapk1-inhibitor-design">Prospective MAPK1 inhibitor design</h3>
<p>The pre-trained S4 model was fine-tuned on 68 manually curated MAPK1 inhibitors ($K_i &lt; 1 \mu M$) from ChEMBL v33. Models from the last five fine-tuning epochs were used to generate 256K molecules across five temperature values. After ranking and filtering by log-likelihood score and scaffold similarity, the top 10 designs were evaluated via <a href="/notes/chemistry/molecular-simulation/classical-methods/umbrella-sampling/">Umbrella Sampling</a> <a href="/notes/chemistry/molecular-simulation/">molecular dynamics</a> simulations.</p>
<p>Eight out of ten designs showed high predicted affinity, with $\Delta G$ values ranging from $-10.3 \pm 0.6$ to $-23 \pm 4$ kcal/mol. These affinities are comparable to or exceed those of the closest known active neighbors ($\Delta G = -9.1 \pm 0.8$ to $-13 \pm 2$ kcal/mol). The most potent predicted design (molecule 2, $\Delta G = -23 \pm 4$ kcal/mol) engages extensively with the MAPK1 binding pocket, though synthetic accessibility may be limited. Several designs incorporate halogen substitutions favorable for MAPK1 inhibition, consistent with known structure-activity relationships.</p>
<h2 id="s4-combines-the-best-of-lstms-and-gpts-for-molecular-design">S4 Combines the Best of LSTMs and GPTs for Molecular Design</h2>
<p>The main findings of this study are:</p>
<ol>
<li><strong>S4 outperforms both LSTM and GPT</strong> in learning complex molecular properties like bioactivity, while maintaining competitive or superior performance in syntax learning and chemical space exploration.</li>
<li><strong>The dual formulation is key</strong>: holistic training (convolution) enables better capture of global molecular properties, while recurrent generation preserves robust chemical syntax and diverse scaffold exploration.</li>
<li><strong>S4 is especially strong for longer sequences</strong>: natural product design (SMILES &gt; 100 tokens) shows the largest advantages over benchmarks in validity and property matching.</li>
<li><strong>Prospective validation</strong>: 8/10 S4-designed MAPK1 inhibitors are predicted as highly active by molecular dynamics, with affinities comparable to or exceeding known actives.</li>
</ol>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>All evaluations are computational; no wet-lab experimental validation is reported.</li>
<li>Bioactivity evaluation relies on likelihood-based ranking, which is an indirect proxy.</li>
<li>The MD simulations, while more rigorous than simple docking, still represent in silico predictions.</li>
<li>SMILES augmentation and improved ranking protocols could further boost performance.</li>
</ul>
<p><strong>Future directions</strong> include application to macrocyclic peptides and protein sequences, organic reaction planning, structure-based drug design, and integration with wet-lab experimental validation.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL v31</td>
          <td>1.9M SMILES</td>
          <td>Molecules with SMILES length &lt;= 100 tokens</td>
      </tr>
      <tr>
          <td>Fine-tuning (bioactivity)</td>
          <td>LIT-PCBA (5 targets)</td>
          <td>11-56 actives + ~10K inactives per target</td>
          <td>PKM2, MAPK1, GBA, mTORC1, TP53</td>
      </tr>
      <tr>
          <td>Natural product training</td>
          <td>COCONUT</td>
          <td>32,360 SMILES</td>
          <td>SMILES length &gt; 100 tokens</td>
      </tr>
      <tr>
          <td>Prospective fine-tuning</td>
          <td>ChEMBL v33 (MAPK1)</td>
          <td>68 inhibitors</td>
          <td>$K_i &lt; 1 \mu M$, target ID CHEMBL4040</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Pre-training: next-token prediction on SMILES strings</li>
<li>Fine-tuning: transfer learning with early stopping (patience 5, tolerance $10^{-5}$)</li>
<li>Molecule ranking: log-likelihood scoring with pre-training bias subtraction (Eq. 5)</li>
<li>Temperature sampling: $T$ from 1.0 to 2.0 (step 0.25) for chemical space exploration</li>
</ul>
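<p>The early-stopping rule used during fine-tuning can be sketched as below. The sketch assumes the monitored quantity is a validation loss and that an epoch counts as an improvement only if it lowers the best loss by more than the tolerance; the exact criterion in the authors' code may differ.</p>

```python
def early_stopping_epoch(val_losses, patience=5, tolerance=1e-5):
    """Return the index of the epoch at which training would stop."""
    best = float("inf")
    stale = 0
    for epoch, loss in enumerate(val_losses):
        if best - loss > tolerance:
            best, stale = loss, 0   # meaningful improvement: reset the counter
        else:
            stale += 1
            if stale == patience:
                return epoch        # no improvement for `patience` epochs
    return len(val_losses) - 1      # ran out of epochs without triggering

# Hypothetical validation losses per epoch
losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6]
stop = early_stopping_epoch(losses)
```

<p>With these losses the plateau starts after epoch 2, so training stops at epoch 7 and never reaches the later improvement at epoch 8.</p>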
<h3 id="models">Models</h3>
<ul>
<li><strong>S4</strong>: Structured state space sequence model with HiPPO initialization; hyperparameter search over 242 + 108 configurations</li>
<li><strong>LSTM</strong>: 40 configurations optimized via random search</li>
<li><strong>GPT</strong>: 35 configurations optimized via random search</li>
<li>All models share the same pre-training data and fine-tuning protocol for fair comparison</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Model</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity (ChEMBL)</td>
          <td>S4</td>
          <td>97%</td>
          <td>Out of 102,400 generated SMILES</td>
      </tr>
      <tr>
          <td>Uniqueness (ChEMBL)</td>
          <td>S4</td>
          <td>96%</td>
          <td>Valid and unique, out of 102,400 generated SMILES</td>
      </tr>
      <tr>
          <td>Novelty (ChEMBL)</td>
          <td>S4</td>
          <td>93%</td>
          <td>Not in training set</td>
      </tr>
      <tr>
          <td>Bioactivity ranking (top 10)</td>
          <td>S4</td>
          <td>Significant (p = 8.41e-6 vs LSTM)</td>
          <td>Wilcoxon signed-rank test</td>
      </tr>
      <tr>
          <td>NP validity</td>
          <td>S4</td>
          <td>81%</td>
          <td>COCONUT, SMILES &gt; 100 tokens</td>
      </tr>
      <tr>
          <td>MAPK1 inhibitor success</td>
          <td>S4</td>
          <td>8/10 designs predicted active</td>
          <td>Predicted by MD (Umbrella Sampling)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Hyperparameter search: NVIDIA A100 40GB GPUs</li>
<li>LSTM/GPT search: 5 days on single A100</li>
<li>S4 search: 10 days on multiple A100 GPUs</li>
<li>MD simulations: Dutch supercomputer Snellius; 1.2-1.6 microseconds per ligand (<a href="/notes/chemistry/molecular-simulation/classical-methods/umbrella-sampling/">Umbrella Sampling</a>)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/molML/s4-for-de-novo-drug-design">S4 for de novo drug design</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with data and trained models</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.12666371">Zenodo archive</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Source data and molecule designs</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Özçelik, R., de Ruiter, S., Criscuolo, E., &amp; Grisoni, F. (2024). Chemical language modeling with structured state space sequence models. <em>Nature Communications</em>, 15, 6176. <a href="https://doi.org/10.1038/s41467-024-50469-9">https://doi.org/10.1038/s41467-024-50469-9</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ozcelik2024chemical,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Chemical language modeling with structured state space sequence models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{\&#34;O{}z\c{c}elik, R{\i}za and de Ruiter, Sarah and Criscuolo, Emanuele and Grisoni, Francesca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6176}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-024-50469-9}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RNNs vs Transformers for Molecular Generation Tasks</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molecular-language-models-rnns-or-transformer/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molecular-language-models-rnns-or-transformer/</guid><description>Empirical comparison of RNN and Transformer architectures for molecular generation using SMILES and SELFIES across three generative tasks.</description><content:encoded><![CDATA[<h2 id="an-empirical-comparison-of-sequence-architectures-for-molecular-generation">An Empirical Comparison of Sequence Architectures for Molecular Generation</h2>
<p>This is an <strong>Empirical</strong> paper that systematically compares two dominant sequence modeling architectures, recurrent neural networks (RNNs) and the Transformer, for chemical language modeling. The primary contribution is a controlled experimental comparison across three generative tasks of increasing complexity, combined with an evaluation of two molecular string representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>). The paper does not propose a new method; instead, it provides practical guidance on when each architecture is more appropriate for molecular generation.</p>
<h2 id="why-compare-rnns-and-transformers-for-molecular-design">Why Compare RNNs and Transformers for Molecular Design?</h2>
<p>Exploring unknown molecular space and designing molecules with target properties is a central goal in computational drug design. Language models trained on molecular string representations (SMILES, SELFIES) have shown the capacity to learn complex molecular distributions. RNN-based models, including LSTM and GRU variants, were the first widely adopted architectures for this task. Models like <a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">CharRNN</a>, ReLeaSE, and conditional RNNs demonstrated success in generating focused molecular libraries. More recently, self-attention-based Transformer models (Mol-GPT, LigGPT) have gained popularity due to their parallelizability and ability to capture long-range dependencies.</p>
<p>Despite the widespread adoption of Transformers across NLP, it was not clear whether they uniformly outperform RNNs for molecular generation. Prior work by Dollar et al. showed that RNN-based models achieved higher validity than Transformer-based models in some settings. Flam-Shepherd et al. demonstrated that RNN language models could learn complex molecular distributions across challenging generative tasks. This paper extends that comparison by adding the Transformer architecture to the same set of challenging tasks and evaluating both SMILES and SELFIES representations.</p>
<h2 id="experimental-design-three-tasks-two-architectures-two-representations">Experimental Design: Three Tasks, Two Architectures, Two Representations</h2>
<p>The core experimental design uses a 2x2 setup: two architectures (RNN and Transformer) crossed with two molecular representations (SMILES and SELFIES), yielding four model variants: SM-RNN, SF-RNN, SM-Transformer, and SF-Transformer.</p>
<h3 id="three-generative-tasks">Three generative tasks</h3>
<p>The three tasks, drawn from <a href="/notes/chemistry/molecular-design/property-prediction/lm-complex-molecular-distributions/">Flam-Shepherd et al.</a>, are designed with increasing complexity:</p>
<ol>
<li>
<p><strong>Penalized LogP task</strong>: Generate molecules with high penalized LogP scores (LogP minus synthetic accessibility and long-cycle penalties). The dataset is built from ZINC15 molecules with penalized LogP &gt; 4.0. Molecule sequences are relatively short (50-75 tokens).</p>
</li>
<li>
<p><strong>Multidistribution task</strong>: Learn a multimodal molecular weight distribution constructed from four distinct subsets: GDB13 (MW &lt;= 185), ZINC (185 &lt;= MW &lt;= 425), Harvard Clean Energy Project (460 &lt;= MW &lt;= 600), and POLYMERS (MW &gt; 600). This tests the ability to capture multiple modes simultaneously.</p>
</li>
<li>
<p><strong>Large-scale task</strong>: Generate large molecules from PubChem with more than 100 heavy atoms and MW ranging from 1250 to 5000. This tests long-sequence generation capability.</p>
</li>
</ol>
<h3 id="model-configuration">Model configuration</h3>
<p>Models are compared at matched parameter counts, ranging from roughly 5.2-5.3M to 36.4M parameters. Hyperparameter optimization uses random search over learning rate [0.0001, 0.001], hidden units (500-1000 for RNNs, 376-776 for Transformers), number of layers [3, 5], and dropout [0.0, 0.5]. A regex-based tokenizer replaces character-by-character tokenization, reducing sequence lengths for large molecules from about 10,000 tokens to under 3,000.</p>
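<p>To make the tokenization step concrete, here is a minimal sketch of a regex-based SMILES tokenizer. The pattern is an assumption for illustration, not the one used in the paper: it groups bracket atoms, two-letter halogens, and two-digit ring closures into single tokens, and the name <code>tokenize_smiles</code> is hypothetical.</p>

```python
import re

# Assumed token pattern, not taken from the paper: bracket atoms, two-letter
# halogens, two-digit ring closures (%NN), then single characters.
SMILES_TOKEN_PATTERN = re.compile(
    r"\[[^\]]+\]|Br|Cl|%[0-9]{2}|[A-Za-z]|[0-9]|[()=#+\-\\/.:~@*$]"
)


def tokenize_smiles(smiles):
    """Split a SMILES string into multi-character tokens where possible."""
    tokens = SMILES_TOKEN_PATTERN.findall(smiles)
    # Refuse to silently drop characters the pattern does not cover.
    if "".join(tokens) != smiles:
        raise ValueError("unrecognized characters in SMILES: " + smiles)
    return tokens


example = "C[C@@H](N)C(=O)O"  # alanine, with the stereocentre as a bracket atom
tokens = tokenize_smiles(example)
print(tokens)
print(len(example), "characters ->", len(tokens), "tokens")
```

<p>The bracket atom <code>[C@@H]</code> becomes one token instead of six characters, which is where the length reduction comes from on large, stereochemistry-rich molecules.</p>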
<h3 id="evaluation-metrics">Evaluation metrics</h3>
<p>The evaluation covers multiple dimensions:</p>
<ul>
<li><strong>Standard metrics</strong>: validity, uniqueness, novelty</li>
<li><strong>Molecular properties</strong>: <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">FCD</a>, LogP, SA, QED, Bertz complexity (BCT), natural product likeness (NP), molecular weight (MW)</li>
<li><strong>Wasserstein distance</strong>: measures distributional similarity between generated and training molecules for each property</li>
<li><strong>Tanimoto similarity</strong>: structural and scaffold similarity between generated and training molecules</li>
<li><strong>Token length (TL)</strong>: comparison of generated vs. training sequence lengths</li>
</ul>
<p>For each task, 10,000 molecules are generated and evaluated.</p>
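<p>The Wasserstein distance used for the property comparisons can be sketched for the simplest case. This is an illustrative version that assumes two equal-size one-dimensional samples; the paper does not state which implementation it uses, and the molecular weights below are made up.</p>

```python
def wasserstein_1d(sample_a, sample_b):
    """Wasserstein-1 distance between two equal-size 1-D samples.

    For empirical distributions with the same number of points, the optimal
    transport plan pairs the i-th smallest value of one sample with the i-th
    smallest value of the other.
    """
    if len(sample_a) != len(sample_b) or not sample_a:
        raise ValueError("samples must be non-empty and of equal size")
    pairs = zip(sorted(sample_a), sorted(sample_b))
    return sum(abs(a - b) for a, b in pairs) / len(sample_a)


# Toy molecular-weight values (made up for illustration only).
training_mw = [180.2, 250.1, 310.4, 402.7]
generated_mw = [185.2, 245.1, 320.4, 402.7]
print(round(wasserstein_1d(training_mw, generated_mw), 2))
```

<p>A value of zero means the two samples have identical sorted values; larger values mean more probability mass has to be moved, measured in the units of the property itself. That is why the MW and BCT columns are on a much larger scale than QED.</p>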
<h2 id="key-results-across-tasks">Key Results Across Tasks</h2>
<h3 id="penalized-logp-task">Penalized LogP task</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>FCD</th>
          <th>LogP</th>
          <th>SA</th>
          <th>QED</th>
          <th>BCT</th>
          <th>NP</th>
          <th>MW</th>
          <th>TL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SM-RNN</td>
          <td>0.56</td>
          <td>0.12</td>
          <td>0.02</td>
          <td>0.01</td>
          <td>16.61</td>
          <td>0.09</td>
          <td>5.90</td>
          <td>0.43</td>
      </tr>
      <tr>
          <td>SF-RNN</td>
          <td>1.63</td>
          <td>0.25</td>
          <td>0.42</td>
          <td>0.02</td>
          <td>36.43</td>
          <td>0.23</td>
          <td>2.35</td>
          <td>0.40</td>
      </tr>
      <tr>
          <td>SM-Transformer</td>
          <td>0.83</td>
          <td>0.18</td>
          <td>0.02</td>
          <td>0.01</td>
          <td>23.77</td>
          <td>0.09</td>
          <td>7.99</td>
          <td>0.84</td>
      </tr>
      <tr>
          <td>SF-Transformer</td>
          <td>1.97</td>
          <td>0.22</td>
          <td>0.47</td>
          <td>0.02</td>
          <td>44.43</td>
          <td>0.28</td>
          <td>5.04</td>
          <td>0.53</td>
      </tr>
  </tbody>
</table>
<p>The table reports FCD and, for each remaining property, the Wasserstein distance between the generated and training distributions, so lower is better in every column. RNN-based models achieve smaller distances on most properties. The authors attribute this to LogP being computed as a sum of atomic contributions (a local property), which aligns with RNNs&rsquo; strength in capturing local structural features. RNNs also generated ring counts closer to the training distribution (4.10 for SM-RNN vs. 4.04 for SM-Transformer, with training data at 4.21). The Transformer performed better on global structural similarity (higher Tanimoto similarity to training data).</p>
<h3 id="multidistribution-task">Multidistribution task</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>FCD</th>
          <th>LogP</th>
          <th>SA</th>
          <th>QED</th>
          <th>BCT</th>
          <th>NP</th>
          <th>MW</th>
          <th>TL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SM-RNN</td>
          <td>0.16</td>
          <td>0.07</td>
          <td>0.03</td>
          <td>0.01</td>
          <td>18.34</td>
          <td>0.02</td>
          <td>7.07</td>
          <td>0.81</td>
      </tr>
      <tr>
          <td>SF-RNN</td>
          <td>1.46</td>
          <td>0.38</td>
          <td>0.55</td>
          <td>0.03</td>
          <td>110.72</td>
          <td>0.24</td>
          <td>10.00</td>
          <td>1.58</td>
      </tr>
      <tr>
          <td>SM-Transformer</td>
          <td>0.16</td>
          <td>0.16</td>
          <td>0.03</td>
          <td>0.01</td>
          <td>39.94</td>
          <td>0.02</td>
          <td>10.03</td>
          <td>1.28</td>
      </tr>
      <tr>
          <td>SF-Transformer</td>
          <td>1.73</td>
          <td>0.37</td>
          <td>0.63</td>
          <td>0.04</td>
          <td>107.46</td>
          <td>0.30</td>
          <td>17.57</td>
          <td>2.40</td>
      </tr>
  </tbody>
</table>
<p>Both SMILES-based models captured all four modes of the MW distribution well. While RNNs had smaller overall Wasserstein distances, the Transformer fitted the higher-MW modes better. This aligns with the observation that longer molecular sequences (which correlate with higher MW) favor the Transformer&rsquo;s global attention mechanism over the RNN&rsquo;s sequential processing.</p>
<h3 id="large-scale-task">Large-scale task</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>FCD</th>
          <th>LogP</th>
          <th>SA</th>
          <th>QED</th>
          <th>BCT</th>
          <th>NP</th>
          <th>MW</th>
          <th>TL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SM-RNN</td>
          <td>0.46</td>
          <td>1.89</td>
          <td>0.20</td>
          <td>0.01</td>
          <td>307.09</td>
          <td>0.03</td>
          <td>105.29</td>
          <td>12.05</td>
      </tr>
      <tr>
          <td>SF-RNN</td>
          <td>1.65</td>
          <td>1.78</td>
          <td>0.43</td>
          <td>0.01</td>
          <td>456.98</td>
          <td>0.14</td>
          <td>100.79</td>
          <td>15.26</td>
      </tr>
      <tr>
          <td>SM-Transformer</td>
          <td>0.36</td>
          <td>1.64</td>
          <td>0.07</td>
          <td>0.01</td>
          <td>172.93</td>
          <td>0.02</td>
          <td>59.04</td>
          <td>7.41</td>
      </tr>
      <tr>
          <td>SF-Transformer</td>
          <td>1.91</td>
          <td>2.82</td>
          <td>0.47</td>
          <td>0.01</td>
          <td>464.75</td>
          <td>0.18</td>
          <td>92.91</td>
          <td>11.57</td>
      </tr>
  </tbody>
</table>
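<p>With SMILES, the Transformer demonstrates a clear advantage on large molecules. SM-Transformer achieves substantially lower Wasserstein distances than SM-RNN across nearly all properties, with particularly large improvements in BCT (172.93 vs. 307.09) and MW (59.04 vs. 105.29). The picture is less consistent with SELFIES: SF-Transformer improves on SF-RNN for MW (92.91 vs. 100.79) and TL (11.57 vs. 15.26) but is worse on FCD (1.91 vs. 1.65) and LogP (2.82 vs. 1.78). The Transformer also produces better Tanimoto similarity scores and more accurate token length distributions.</p>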
<h3 id="standard-metrics-across-all-tasks">Standard metrics across all tasks</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>SM-RNN</th>
          <th>SF-RNN</th>
          <th>SM-Transformer</th>
          <th>SF-Transformer</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LogP</td>
          <td>Valid</td>
          <td>0.90</td>
          <td>1.00</td>
          <td>0.89</td>
          <td>1.00</td>
      </tr>
      <tr>
          <td>LogP</td>
          <td>Uniqueness</td>
          <td>0.98</td>
          <td>0.99</td>
          <td>0.98</td>
          <td>0.99</td>
      </tr>
      <tr>
          <td>LogP</td>
          <td>Novelty</td>
          <td>0.75</td>
          <td>0.71</td>
          <td>0.71</td>
          <td>0.71</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>Valid</td>
          <td>0.95</td>
          <td>1.00</td>
          <td>0.97</td>
          <td>1.00</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>Uniqueness</td>
          <td>0.96</td>
          <td>1.00</td>
          <td>1.00</td>
          <td>1.00</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>Novelty</td>
          <td>0.91</td>
          <td>0.98</td>
          <td>0.91</td>
          <td>0.98</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>Valid</td>
          <td>0.84</td>
          <td>1.00</td>
          <td>0.88</td>
          <td>1.00</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>Uniqueness</td>
          <td>0.99</td>
          <td>0.99</td>
          <td>0.98</td>
          <td>0.99</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>Novelty</td>
          <td>0.85</td>
          <td>0.92</td>
          <td>0.86</td>
          <td>0.94</td>
      </tr>
  </tbody>
</table>
<p>SELFIES achieves 100% validity across all tasks by construction, while SMILES validity drops for large molecules. Among SMILES-based models, the Transformer achieves slightly higher validity than the RNN on the multidistribution task (0.97 vs. 0.95) and the large-scale task (0.88 vs. 0.84), but not on the penalized LogP task (0.89 vs. 0.90).</p>
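<p>For readers who want to reproduce the three standard metrics, here is a rough sketch. The denominators are assumptions (uniqueness over valid molecules, novelty over unique valid molecules), since the paper&rsquo;s exact conventions are not restated here, and <code>balanced_parentheses</code> is a hypothetical stand-in for a real validity check with a cheminformatics toolkit.</p>

```python
def standard_metrics(generated, training, is_valid):
    """Return (validity, uniqueness, novelty) for a batch of generated strings.

    Assumed conventions (not confirmed by the paper): uniqueness is measured
    over valid molecules, novelty over unique valid molecules.
    """
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)
    novel = unique - set(training)
    validity = len(valid) / len(generated) if generated else 0.0
    uniqueness = len(unique) / len(valid) if valid else 0.0
    novelty = len(novel) / len(unique) if unique else 0.0
    return validity, uniqueness, novelty


def balanced_parentheses(smiles):
    """Hypothetical stand-in for a validity check: branches must be balanced."""
    depth = 0
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == -1:
                return False
    return depth == 0


generated = ["CCO", "CCO", "CC(=O)O", "CC(C", "c1ccccc1"]
training = ["CCO", "CCN"]
print(standard_metrics(generated, training, balanced_parentheses))
```

<p>In practice the strings would be canonicalized before the set operations, otherwise two different SMILES for the same molecule would count as distinct.</p>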
<h2 id="conclusions-and-practical-guidelines">Conclusions and Practical Guidelines</h2>
<p>The central finding is that neither architecture universally dominates. The choice between RNNs and Transformers should depend on the characteristics of the molecular data:</p>
<ul>
<li>
<p><strong>RNNs are preferred</strong> when molecular properties depend on local structural features (e.g., LogP, ring counts) and when sequences are relatively short. They better capture local fragment distributions.</p>
</li>
<li>
<p><strong>Transformers are preferred</strong> when dealing with large molecules (high MW, long sequences) where global attention can capture the overall distribution more effectively. RNNs progressively lose information from early tokens on long sequences.</p>
</li>
<li>
<p><strong>SMILES outperforms SELFIES</strong> on property distribution metrics across nearly all tasks and models. While SELFIES guarantees 100% syntactic validity, its generated molecules show worse distributional fidelity to training data. The authors argue that validity is a less important concern than property fidelity, since invalid SMILES can be filtered easily.</p>
</li>
</ul>
<p>The authors acknowledge that longer sequences remain challenging for both architectures. For Transformers, the quadratic growth of the attention matrix limits scalability. For RNNs, the vanishing gradient problem limits effective context length.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Task 1</td>
          <td>ZINC15 (penalized LogP &gt; 4.0)</td>
          <td>Not specified</td>
          <td>High penalized LogP molecules</td>
      </tr>
      <tr>
          <td>Task 2</td>
          <td><a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> + ZINC + CEP + POLYMERS</td>
          <td>~200K</td>
          <td>Multimodal MW distribution</td>
      </tr>
      <tr>
          <td>Task 3</td>
          <td>PubChem (&gt;100 heavy atoms)</td>
          <td>Not specified</td>
          <td>MW range 1250-5000</td>
      </tr>
  </tbody>
</table>
<p>Data processing code available at <a href="https://github.com/danielflamshep/genmoltasks">https://github.com/danielflamshep/genmoltasks</a> (from the original Flam-Shepherd et al. study).</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: Regex-based tokenizer (not character-by-character)</li>
<li><strong>Hyperparameter search</strong>: Random search over learning rate [0.0001, 0.001], hidden units, layers [3, 5], dropout [0.0, 0.5]</li>
<li><strong>Selection</strong>: Top 20% of configurations by the sum of validity + uniqueness + novelty, then final selection on all indicators</li>
<li><strong>Generation</strong>: 10K molecules per model per task</li>
</ul>
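<p>The search-and-select procedure listed above could look roughly like the following. This is a sketch under stated assumptions: uniform sampling within each range, the layer range read as the inclusive integers 3 to 5, and placeholder scores in place of real training runs. The function names are hypothetical.</p>

```python
import random


def sample_config(architecture, rng):
    """Draw one configuration from the search ranges quoted above."""
    hidden_range = (500, 1000) if architecture == "rnn" else (376, 776)
    return {
        "learning_rate": rng.uniform(0.0001, 0.001),
        "hidden_units": rng.randint(*hidden_range),
        "num_layers": rng.randint(3, 5),  # assumed inclusive integer range
        "dropout": rng.uniform(0.0, 0.5),
    }


def top_fraction(scored_configs, fraction=0.2):
    """Keep the best-scoring fraction of (config, score) pairs."""
    keep = max(1, int(len(scored_configs) * fraction))
    ranked = sorted(scored_configs, key=lambda pair: pair[1], reverse=True)
    return ranked[:keep]


rng = random.Random(0)
configs = [sample_config("transformer", rng) for _ in range(10)]
# Placeholder scores standing in for validity + uniqueness + novelty
# measured after training each configuration.
scored = [(cfg, rng.uniform(2.0, 3.0)) for cfg in configs]
shortlist = top_fraction(scored)
print(len(shortlist), "of", len(scored), "configurations kept")
```

<p>The shortlist would then be compared on the full set of distributional indicators to pick the final model, as the selection bullet describes.</p>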
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Parameters</th>
          <th>Architecture</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN variants</td>
          <td>5.2M - 36.4M</td>
          <td>RNN (LSTM/GRU)</td>
      </tr>
      <tr>
          <td>Transformer variants</td>
          <td>5.3M - 36.4M</td>
          <td>Transformer decoder</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>Wasserstein distance for property distributions (FCD, LogP, SA, QED, BCT, NP, MW, TL), Tanimoto similarity (molecular and scaffold), validity, uniqueness, novelty.</p>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/viko-3/language_model">trans_language</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Transformer implementation by the authors</td>
      </tr>
      <tr>
          <td><a href="https://github.com/danielflamshep/genmoltasks">genmoltasks</a></td>
          <td>Code/Data</td>
          <td>Apache-2.0</td>
          <td>Dataset construction from Flam-Shepherd et al.</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, Y., Wang, Z., Zeng, X., Li, Y., Li, P., Ye, X., &amp; Sakurai, T. (2023). Molecular language models: RNNs or transformer? <em>Briefings in Functional Genomics</em>, 22(4), 392-400. <a href="https://doi.org/10.1093/bfgp/elad012">https://doi.org/10.1093/bfgp/elad012</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chen2023molecular,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecular language models: RNNs or transformer?}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Yangyang and Wang, Zixu and Zeng, Xiangxiang and Li, Yayang and Li, Pengyong and Ye, Xiucai and Sakurai, Tetsuya}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Functional Genomics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{392--400}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bfgp/elad012}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Review: Deep Learning for Molecular Design (2019)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/deep-learning-molecular-design-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/deep-learning-molecular-design-review/</guid><description>A 2019 review surveying deep generative models for molecular design, covering RNNs, VAEs, GANs, and RL approaches with SMILES and graph representations.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-deep-generative-models-for-molecular-design">A Systematization of Deep Generative Models for Molecular Design</h2>
<p>This is a <strong>Systematization</strong> paper that organizes and compares the rapidly growing literature on deep generative modeling for molecules. Published in 2019, it catalogs 45 papers from the preceding two years, classifying them by architecture (RNNs, VAEs, GANs, reinforcement learning) and molecular representation (SMILES strings, context-free grammars, graph tensors, 3D voxels). The review provides mathematical foundations for each technique, identifies cross-cutting themes, and proposes a framework for reward function design that addresses diversity, novelty, stability, and synthesizability.</p>
<h2 id="the-challenge-of-navigating-vast-chemical-space">The Challenge of Navigating Vast Chemical Space</h2>
<p>The space of potential drug-like molecules has been estimated to contain between $10^{23}$ and $10^{60}$ compounds, while only about $10^{8}$ have ever been synthesized. Traditional approaches to molecular design rely on combinatorial methods, mixing known scaffolds and functional groups, but these generate many unstable or unsynthesizable candidates. High-throughput screening (HTS) and high-throughput virtual screening (HTVS) help but remain resource-intensive. The average cost to bring a new drug to market exceeds one billion USD, with a 13-year average timeline from discovery to market.</p>
<p>By 2016, <a href="/notes/machine-learning/generative-models/">deep generative models</a> had shown strong results in producing original images, music, and text. The &ldquo;molecular autoencoder&rdquo; of <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al. (2016/2018)</a> first applied these techniques to molecular generation, triggering an explosion of follow-up work. By the time of this review, the landscape had grown complex enough, with many architectures, representation schemes, and no agreed-upon benchmarking standards, to warrant systematic organization.</p>
<h2 id="molecular-representations-and-architecture-taxonomy">Molecular Representations and Architecture Taxonomy</h2>
<p>The review&rsquo;s core organizational contribution is a two-axis taxonomy: molecular representations on one axis and deep learning architectures on the other.</p>
<h3 id="molecular-representations">Molecular Representations</h3>
<p>The review categorizes representations into 3D and 2D graph-based schemes:</p>
<p><strong>3D representations</strong> include raw voxels (placing nuclear charges on a grid), smoothed voxels (Gaussian blurring around nuclei), and tensor field networks. These capture full geometric information but suffer from high dimensionality, sparsity, and difficulty encoding rotation/translation invariance.</p>
<p><strong>2D graph representations</strong> include:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings</strong>: The dominant representation, encoding molecular graphs as ASCII character sequences via depth-first traversal. Non-unique (each molecule with $N$ heavy atoms has at least $N$ SMILES representations), but invertible and widely supported.</li>
<li><strong>Canonical SMILES</strong>: Unique for each molecule, but models trained on them may end up encoding the canonicalization grammar rules rather than chemical structure.</li>
<li><strong>Context-free grammars (CFGs)</strong>: Decompose SMILES into grammar rules to improve validity rates, though not to 100%.</li>
<li><strong>Tensor representations</strong>: Store atom types in a vertex feature matrix $X \in \mathbb{R}^{N \times |\mathcal{A}|}$ and bond types in an adjacency tensor $A \in \mathbb{R}^{N \times N \times Y}$.</li>
<li><strong>Graph operations</strong>: Directly build molecular graphs by adding atoms and bonds, guaranteeing 100% chemical validity.</li>
</ul>
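<p>As an illustration of the tensor representation in the list above, the sketch below builds the vertex feature matrix $X$ and the adjacency tensor $A$ for a three-atom fragment using plain nested lists. The atom and bond vocabularies are assumptions chosen for the example, and <code>encode_graph</code> is a hypothetical helper.</p>

```python
ATOM_TYPES = ["C", "N", "O"]                 # assumed atom vocabulary
BOND_TYPES = ["single", "double", "triple"]  # assumed bond vocabulary


def encode_graph(atoms, bonds):
    """Return (X, A): one-hot atom features and a bond-typed adjacency tensor."""
    n = len(atoms)
    features = [[1 if a == t else 0 for t in ATOM_TYPES] for a in atoms]
    adjacency = [[[0] * len(BOND_TYPES) for _ in range(n)] for _ in range(n)]
    for i, j, bond in bonds:
        k = BOND_TYPES.index(bond)
        adjacency[i][j][k] = 1
        adjacency[j][i][k] = 1  # undirected graph, so the tensor is symmetric
    return features, adjacency


# Acetaldehyde heavy atoms: C-C=O
X, A = encode_graph(["C", "C", "O"], [(0, 1, "single"), (1, 2, "double")])
print(X)        # N x |atom types| one-hot rows
print(A[1][2])  # bond-type channel vector between atoms 1 and 2
```

<p>Here $N = 3$, the atom vocabulary has three entries, and $Y = 3$, so $X$ is $3 \times 3$ and $A$ is $3 \times 3 \times 3$.</p>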
<h3 id="deep-learning-architectures">Deep Learning Architectures</h3>
<p><strong>Recurrent Neural Networks (RNNs)</strong> generate SMILES strings character by character, typically using LSTM or GRU units. Training uses maximum likelihood estimation (MLE) with teacher forcing:</p>
<p>$$
L^{\text{MLE}} = -\sum_{s \in \mathcal{X}} \sum_{t=2}^{T} \log \pi_{\theta}(s_{t} \mid S_{1:t-1})
$$</p>
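<p>Thermal rescaling of the output distribution controls the diversity-validity tradeoff via a temperature parameter $T$ (not to be confused with the sequence length $T$ in the sum above). Lower temperatures concentrate probability on the most likely tokens, which tends to favor validity, while higher temperatures flatten the distribution and favor diversity. The reviewed RNNs achieved SMILES validity rates of 94-98%.</p>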
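<p>A minimal sketch of thermal rescaling, assuming the common formulation in which the logits are divided by the temperature before the softmax. The logits are made up, and the function names are illustrative rather than taken from any reviewed model.</p>

```python
import math
import random


def rescale(logits, temperature):
    """Softmax over logits divided by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_token(logits, temperature, rng):
    """Draw one token index from the temperature-rescaled distribution."""
    probs = rescale(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]  # made-up next-token scores
for temperature in (0.5, 1.0, 2.0):
    print(temperature, [round(p, 3) for p in rescale(logits, temperature)])
print(sample_token(logits, 1.0, random.Random(0)))
```

<p>Printing the three distributions side by side shows the top token gaining probability as the temperature falls and losing it as the temperature rises.</p>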
<p><strong><a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational Autoencoders (VAEs)</a></strong> learn a continuous latent space by maximizing the evidence lower bound (ELBO):</p>
<p>$$
\mathcal{L}_{\theta,\phi}(x) = \mathbb{E}_{z \sim q_{\phi}(z|x)}[\log p_{\theta}(x|z)] - D_{\text{KL}}[q_{\phi}(z|x), p(z)]
$$</p>
<p>The first term encourages accurate reconstruction while the KL divergence term regularizes the latent distribution toward a standard Gaussian prior $p(z) = \mathcal{N}(z; 0, I)$. Variants include <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">grammar VAEs</a> (GVAEs), syntax-directed VAEs, junction tree VAEs, and adversarial autoencoders (AAEs) that replace the KL term with adversarial training.</p>
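<p>When the encoder outputs a diagonal Gaussian (an assumption made here for illustration; it is the usual choice but is not stated above), the KL term has a closed form. The sketch below computes it and combines it with a reconstruction log-likelihood into a loss. Both function names are hypothetical.</p>

```python
import math


def kl_to_standard_normal(mean, log_var):
    """KL( N(mean, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return 0.5 * sum(
        m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, log_var)
    )


def negative_elbo(reconstruction_log_likelihood, mean, log_var):
    """Loss to minimize: minus the ELBO for one example (single-sample estimate)."""
    return -(reconstruction_log_likelihood - kl_to_standard_normal(mean, log_var))


print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # posterior equals prior
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))  # mean shifted by one unit
print(negative_elbo(-3.0, [1.0, 0.0], [0.0, 0.0]))
```

<p>The KL term vanishes only when the approximate posterior matches the prior exactly, which is what pulls encoded molecules toward a shared, smooth latent space.</p>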
<p><strong><a href="/posts/what-is-a-gan/">Generative Adversarial Networks (GANs)</a></strong> train a generator against a discriminator using the minimax objective:</p>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{d}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]
$$</p>
<p>The review shows that with an optimal discriminator, the generator objective reduces to minimizing the Jensen-Shannon divergence, which captures both forward and reverse KL divergence terms. This provides a more &ldquo;balanced&rdquo; training signal than MLE alone. The Wasserstein GAN (WGAN) uses the Earth mover&rsquo;s distance for more stable training:</p>
<p>$$
W(p, q) = \inf_{\gamma \in \Pi(p,q)} \mathbb{E}_{(x,y) \sim \gamma} \lVert x - y \rVert
$$</p>
<p><strong>Reinforcement Learning</strong> recasts molecular generation as a sequential decision problem. The policy gradient (REINFORCE) update is:</p>
<p>$$
\nabla J(\theta) = \mathbb{E}\left[G_{t} \frac{\nabla_{\theta} \pi_{\theta}(a_{t} \mid y_{1:t-1})}{\pi_{\theta}(a_{t} \mid y_{1:t-1})}\right]
$$</p>
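<p>The ratio inside the expectation is the derivative of a logarithm, so the same update is often written in its log-likelihood form. The restatement below uses only the symbols already defined above:</p>
<p>$$
\frac{\nabla_{\theta} \pi_{\theta}(a_{t} \mid y_{1:t-1})}{\pi_{\theta}(a_{t} \mid y_{1:t-1})} = \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid y_{1:t-1})
\quad \Longrightarrow \quad
\nabla J(\theta) = \mathbb{E}\left[G_{t} \nabla_{\theta} \log \pi_{\theta}(a_{t} \mid y_{1:t-1})\right]
$$</p>
<p>In this form the update looks like the MLE gradient with each sampled token weighted by its return $G_{t}$, which makes the connection between pretraining and RL fine-tuning easier to see.</p>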
<p>To prevent RL fine-tuning from causing the generator to &ldquo;drift&rdquo; away from viable chemical structures, an augmented reward function incorporates the prior likelihood:</p>
<p>$$
R'(S) = [\sigma R(S) + \log P_{\text{prior}}(S) - \log P_{\text{current}}(S)]^{2}
$$</p>
<h2 id="cataloging-45-models-and-their-design-choices">Cataloging 45 Models and Their Design Choices</h2>
<p>Rather than running new experiments, the review&rsquo;s methodology involves systematically cataloging and comparing 45 published models. Table 2 in the paper lists each model&rsquo;s architecture, representation, training dataset, and dataset size. Key patterns include:</p>
<ul>
<li><strong>RNN-based models</strong> (16 entries): Almost exclusively use SMILES, trained on ZINC or ChEMBL datasets with 0.1M-1.7M molecules.</li>
<li><strong>VAE variants</strong> (20 entries): The most diverse category, spanning SMILES VAEs, grammar VAEs, junction tree VAEs, graph-based VAEs, and 3D VAEs. Training sets range from 10K to 72M molecules.</li>
<li><strong>GAN models</strong> (7 entries): Include <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a>, RANC, ATNC, MolGAN, and CycleGAN approaches. Notably, GANs appear to work with fewer training samples.</li>
<li><strong>Other approaches</strong> (2 entries): Pure RL methods from Zhou et al. and Stahl et al. that do not require pretraining on a dataset.</li>
</ul>
<p>The review also catalogs 13 publicly available datasets (Table 3), ranging from <a href="/notes/chemistry/datasets/qm9/">QM9</a> (133K molecules with quantum chemical properties) to <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> (977M combinatorially generated molecules) and ZINC15 (750M+ commercially available compounds).</p>
<h3 id="metrics-and-reward-function-design">Metrics and Reward Function Design</h3>
<p>A significant contribution is the systematic treatment of reward functions. The review argues that generated molecules should satisfy six desiderata: diversity, novelty, stability, synthesizability, non-triviality, and good properties. Key metrics formalized include:</p>
<p><strong>Diversity</strong> using Tanimoto similarity over fingerprints:</p>
<p>$$
r_{\text{diversity}} = 1 - \frac{1}{|\mathcal{G}|} \sum_{(x_{1}, x_{2}) \in \mathcal{G} \times \mathcal{G}} D(x_{1}, x_{2})
$$</p>
<p><strong>Novelty</strong> measured against a hold-out test set $\mathcal{T}$, as one minus the overlap between the generated set $\mathcal{G}$ and $\mathcal{T}$, normalized by the test-set size:</p>
<p>$$
r_{\text{novel}} = 1 - \frac{|\mathcal{G} \cap \mathcal{T}|}{|\mathcal{T}|}
$$</p>
<p><strong>Synthesizability</strong> primarily assessed via the SA score, sometimes augmented with ring penalties and medicinal chemistry filters.</p>
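<p>The diversity and novelty rewards above can be sketched on toy data. Two assumptions are worth flagging: fingerprints are represented as sets of on-bit indices, and the diversity average is taken over distinct unordered pairs, which is one possible normalization and not necessarily the one intended by the formula as printed. The novelty function follows the printed formula, normalizing by the test-set size.</p>

```python
from itertools import combinations


def tanimoto(bits_a, bits_b):
    """Tanimoto similarity between two fingerprints given as sets of on-bits."""
    union = bits_a.union(bits_b)
    if not union:
        return 1.0
    return len(bits_a.intersection(bits_b)) / len(union)


def diversity(fingerprints):
    """One minus the mean similarity over distinct unordered pairs (assumed)."""
    pairs = list(combinations(fingerprints, 2))
    if not pairs:
        return 0.0
    return 1.0 - sum(tanimoto(a, b) for a, b in pairs) / len(pairs)


def novelty(generated, test_set):
    """One minus the generated/test overlap divided by the test-set size."""
    held_out = set(test_set)
    return 1.0 - len(set(generated).intersection(held_out)) / len(held_out)


toy_fingerprints = [{1, 2, 3}, {2, 3, 4}, {7, 8}]  # made-up on-bit sets
print(round(diversity(toy_fingerprints), 4))
print(novelty(["CCO", "CCN", "CCC"], ["CCO", "CCCl"]))
```

<p>A real pipeline would compute the fingerprints with a cheminformatics toolkit and canonicalize the strings before comparing sets.</p>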
<p>The review also discusses the <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> as an analog of FID for molecular generation, and notes the emergence of standardized benchmarking platforms including <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>, and DiversityNet.</p>
<h2 id="key-findings-and-future-directions">Key Findings and Future Directions</h2>
<p>The review identifies several major trends and conclusions:</p>
<p><strong>Shift from SMILES to graph-based representations.</strong> SMILES-based methods struggle with validity (the molecular autoencoder VAE achieved only 0.7-75% valid SMILES depending on sampling strategy). Methods that work directly on molecular graphs with chemistry-preserving operations achieve 100% validity, and the review predicts this trend will continue.</p>
<p><strong>Advantages of adversarial and RL training over MLE.</strong> The mathematical analysis shows that MLE only optimizes forward KL divergence, which can lead to models that place probability mass where the data distribution is zero. GAN training optimizes the Jensen-Shannon divergence, which balances forward and reverse KL terms. RL approaches, particularly pure RL without pretraining, showed competitive performance with much less training data.</p>
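<p>A toy example can make the asymmetry between the divergences concrete. The three-bin distributions below are made up: the model leaks probability into a bin where the data has none. This sketch only illustrates the definitions and does not reproduce any analysis from the review.</p>

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions; infinite if q misses mass p has."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue
        if qi == 0.0:
            return math.inf
        total += pi * math.log(pi / qi)
    return total


def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture."""
    mix = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, mix) + 0.5 * kl(q, mix)


data = [0.5, 0.5, 0.0]   # toy data distribution with an empty third bin
model = [0.4, 0.4, 0.2]  # toy model that leaks mass into the empty bin
print("forward KL(data || model):", round(kl(data, model), 4))
print("reverse KL(model || data):", kl(model, data))
print("JS(data, model):", round(js(data, model), 4))
```

<p>The forward KL, which MLE minimizes, charges only a small finite penalty for the leaked mass, while the reverse KL is infinite. The Jensen-Shannon divergence stays finite and is bounded above by $\log 2$, since each distribution is compared against the mixture.</p>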
<p><strong>Genetic algorithms remain competitive.</strong> The review notes that the latest genetic algorithm approaches (Grammatical Evolution) could match deep learning methods for molecular optimization under some metrics, and at 100x lower computational cost in some comparisons. This serves as an important baseline calibration.</p>
<p><strong>Reward function design is underappreciated.</strong> Early models generated unstable molecules with labile groups (enamines, hemiaminals, enol ethers). Better reward functions that incorporate synthesizability, diversity, and stability constraints significantly improved practical utility.</p>
<p><strong>Need for standardized benchmarks.</strong> The review identifies a lack of agreement on evaluation methodology as a major barrier to progress, noting that published comparisons are often subtly biased toward novel methods.</p>
<h3 id="limitations">Limitations</h3>
<p>As a review paper from early 2019, the work predates several important developments: transformer-based architectures (which would soon dominate), SELFIES representations, diffusion models for molecules, and large-scale pretrained chemical language models. The review focuses primarily on drug-like small molecules and does not deeply cover protein design or materials optimization.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a review paper that does not present new experimental results. The paper catalogs 13 publicly available datasets used across the reviewed works:</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td><a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a></td>
          <td>977M</td>
          <td>Combinatorially generated library</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>ZINC15</td>
          <td>750M+</td>
          <td>Commercially available compounds</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td><a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a></td>
          <td>50M</td>
          <td>Combinatorially generated library</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>ChEMBL</td>
          <td>2M</td>
          <td>Curated bioactive molecules</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>QM9</td>
          <td>133,885</td>
          <td>Small organic molecules with DFT properties</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>PubChemQC</td>
          <td>3.98M</td>
          <td>PubChem compounds with DFT data</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The review provides mathematical derivations for MLE training (Eq. 1), VAE ELBO (Eqs. 9-13), AAE objectives (Eqs. 15-16), GAN objectives (Eqs. 19-22), WGAN (Eq. 24), REINFORCE gradient (Eq. 7), and numerous reward function formulations (Eqs. 26-36).</p>
<h3 id="evaluation">Evaluation</h3>
<p>Key evaluation frameworks discussed:</p>
<ul>
<li><a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> (molecular analog of FID)</li>
<li><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> benchmarking platform</li>
<li><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> benchmarking suite</li>
<li>Validity rate, uniqueness, novelty, and internal diversity metrics</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Elton, D. C., Boukouvalas, Z., Fuge, M. D., &amp; Chung, P. W. (2019). Deep Learning for Molecular Design: A Review of the State of the Art. <em>Molecular Systems Design &amp; Engineering</em>, 4(4), 828-849. <a href="https://doi.org/10.1039/C9ME00039A">https://doi.org/10.1039/C9ME00039A</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{elton2019deep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Deep Learning for Molecular Design -- A Review of the State of the Art}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Elton, Daniel C. and Boukouvalas, Zois and Fuge, Mark D. and Chung, Peter W.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Molecular Systems Design \&amp; Engineering}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{828--849}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/C9ME00039A}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>REINVENT 4: Open-Source Generative Molecule Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/</guid><description>REINVENT 4 is an open-source generative AI framework combining RNNs and transformers with reinforcement and curriculum learning for de novo molecular design.</description><content:encoded><![CDATA[<h2 id="an-open-source-reference-implementation-for-generative-molecular-design">An Open-Source Reference Implementation for Generative Molecular Design</h2>
<p>REINVENT 4 is a <strong>Resource</strong> paper presenting a production-grade, open-source software framework for AI-driven generative molecular design. The primary contribution is a unified codebase that integrates four distinct molecule generators (de novo, scaffold decoration, linker design, molecular optimization) with three machine learning optimization algorithms (transfer learning, reinforcement learning, <a href="/notes/chemistry/molecular-design/generation/rl-tuned/curriculum-learning-molecular-design/">curriculum learning</a>). The software is released under the Apache 2.0 license and represents the fourth major version of the REINVENT platform, which has been in continuous production use at AstraZeneca for drug discovery.</p>
<h2 id="bridging-the-gap-between-research-prototypes-and-production-molecular-design">Bridging the Gap Between Research Prototypes and Production Molecular Design</h2>
<p>The motivation for REINVENT 4 stems from several gaps in the generative molecular design landscape. While numerous AI model architectures have been developed for molecular generation (<a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">VAEs</a>, GANs, RNNs, transformers, flow models, diffusion models), most exist as research prototypes released alongside individual publications rather than as maintained, integrated software. The authors argue that the scientific community needs reference implementations of common generative molecular design algorithms in the public domain to:</p>
<ol>
<li>Enable nuanced debate about the application of AI in drug discovery</li>
<li>Serve as educational tools for practitioners entering the field</li>
<li>Increase transparency around AI-driven molecular design</li>
<li>Provide a foundation for future innovation</li>
</ol>
<p>REINVENT 4 consolidates previously separate codebases (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> v1, v2, LibInvent, LinkInvent, Mol2Mol) into a single repository with a consistent interface, addressing the fragmentation that characterized earlier releases.</p>
<h2 id="unified-framework-for-sequence-based-molecular-generation">Unified Framework for Sequence-Based Molecular Generation</h2>
<p>The core design of REINVENT 4 centers on sequence-based neural network models that generate <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings in an autoregressive manner. All generators model the probability of producing a token sequence, with two formulations.</p>
<p>For unconditional agents (de novo generation), the joint probability of a sequence $T$ with tokens $t_1, t_2, \ldots, t_\ell$ is:</p>
<p>$$
\mathbf{P}(T) = \prod_{i=1}^{\ell} \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1)
$$</p>
<p>For conditional agents (scaffold decoration, linker design, molecular optimization), the joint probability given an input sequence $S$ is:</p>
<p>$$
\mathbf{P}(T \mid S) = \prod_{i=1}^{\ell} \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1, S)
$$</p>
<p>The negative log-likelihood for unconditional agents is:</p>
<p>$$
NLL(T) = -\log \mathbf{P}(T) = -\sum_{i=1}^{\ell} \log \mathbf{P}(t_i \mid t_{i-1}, t_{i-2}, \ldots, t_1)
$$</p>
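<p>As a minimal sketch of the relationship between the product form and the NLL, the snippet below sums the negative logs of per-token conditional probabilities. The function name <code>sequence_nll</code> and the probability values are hypothetical and chosen only for illustration; in practice these probabilities would come from the generator&rsquo;s softmax output.</p>

```python
import math

def sequence_nll(token_probs):
    """Return -sum(log p_i) for the per-token conditional probabilities p_i."""
    if not token_probs:
        raise ValueError("need at least one token probability")
    return -sum(math.log(p) for p in token_probs)

# Hypothetical conditional probabilities P(t_i | t_{i-1}, ..., t_1) for a 4-token sequence.
probs = [0.5, 0.25, 0.8, 0.1]
nll = sequence_nll(probs)
joint = math.prod(probs)  # P(T), the product form of the same quantity
```

<p>Exponentiating the negated NLL recovers the joint probability, which is why the two formulations can be used interchangeably.</p>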
<h3 id="reinforcement-learning-with-dap">Reinforcement Learning with DAP</h3>
<p>The key optimization mechanism is reinforcement learning via the &ldquo;Difference between Augmented and Posterior&rdquo; (DAP) strategy. For each generated sequence $T$, the augmented likelihood is defined as:</p>
<p>$$
\log \mathbf{P}_{\text{aug}}(T) = \log \mathbf{P}_{\text{prior}}(T) + \sigma \mathbf{S}(T)
$$</p>
<p>where $\mathbf{S}(T) \in [0, 1]$ is the scalar score and $\sigma \geq 0$ controls the balance between reward and regularization. The DAP loss is:</p>
<p>$$
\mathcal{L}(T) = \left(\log \mathbf{P}_{\text{aug}}(T) - \log \mathbf{P}_{\text{agent}}(T)\right)^2
$$</p>
<p>The presence of the prior likelihood in the augmented likelihood constrains how far the agent can deviate from chemically plausible space, functioning similarly to proximal policy gradient methods. Because the agent log-likelihood is never positive ($\log \mathbf{P}_{\text{agent}}(T) \leq 0$), the loss is lower-bounded by:</p>
<p>$$
\mathcal{L}(T) \geq \max\left(0, \log \mathbf{P}_{\text{prior}}(T) + \sigma \mathbf{S}(T)\right)^2
$$</p>
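<p>The DAP quantities above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the REINVENT 4 implementation: the function names are hypothetical, the log-likelihoods are passed in as plain floats, and the numbers (including the value of $\sigma$) are made up for the example.</p>

```python
def augmented_log_likelihood(prior_logp, score, sigma):
    """log P_aug(T) = log P_prior(T) + sigma * S(T)."""
    return prior_logp + sigma * score

def dap_loss(prior_logp, agent_logp, score, sigma):
    """Squared difference between augmented and agent log-likelihoods."""
    return (augmented_log_likelihood(prior_logp, score, sigma) - agent_logp) ** 2

def dap_lower_bound(prior_logp, score, sigma):
    """max(0, log P_aug(T))^2, valid because log P_agent(T) is never positive."""
    return max(0.0, augmented_log_likelihood(prior_logp, score, sigma)) ** 2

# Illustrative numbers only (not taken from the paper).
prior_logp, agent_logp, score, sigma = -30.0, -25.0, 0.8, 128.0
loss = dap_loss(prior_logp, agent_logp, score, sigma)
bound = dap_lower_bound(prior_logp, score, sigma)
```

<p>With a score of zero the augmented likelihood reduces to the prior likelihood, so the bound collapses to zero and the loss simply pulls the agent back toward the prior.</p>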
<h3 id="four-molecule-generators">Four Molecule Generators</h3>
<p>REINVENT 4 supports four generator types:</p>
<table>
  <thead>
      <tr>
          <th>Generator</th>
          <th>Architecture</th>
          <th>Input</th>
          <th>Task</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Reinvent</td>
          <td>RNN</td>
          <td>None</td>
          <td>De novo design from scratch</td>
      </tr>
      <tr>
          <td>LibInvent</td>
          <td>RNN</td>
          <td>Scaffold SMILES</td>
          <td>R-group replacement, library design</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/">LinkInvent</a></td>
          <td>RNN</td>
          <td>Two warhead fragments</td>
          <td>Linker design, scaffold hopping</td>
      </tr>
      <tr>
          <td>Mol2Mol</td>
          <td>Transformer</td>
          <td>Input molecule</td>
          <td>Molecular optimization within similarity bounds</td>
      </tr>
  </tbody>
</table>
<p>All generators are fully integrated with all three optimization algorithms (TL, RL, CL). The Mol2Mol transformer was trained on over 200 billion molecular pairs from PubChem with <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> $\geq 0.50$, using ranking loss to directly link negative log-likelihood to molecular similarity.</p>
<h3 id="staged-learning-curriculum-learning">Staged Learning (Curriculum Learning)</h3>
<p>A key new feature is staged learning, which implements curriculum learning as multi-stage RL. Each stage can define a different scoring profile, allowing users to gradually phase in computationally expensive scoring functions. For example, cheap drug-likeness filters can run first, followed by docking in later stages. Stages terminate when a maximum score threshold is exceeded or a step limit is reached.</p>
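<p>The termination logic can be sketched as a simple loop. Everything here is hypothetical: the <code>Stage</code> class, its field names, and the toy score curves are stand-ins for illustration, and a real run would perform an RL update and score a sampled batch at every step rather than call a fixed function.</p>

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class Stage:
    name: str
    mean_score_at: Callable[[int], float]  # hypothetical stand-in for one RL step
    max_score: float                       # stop once the mean score exceeds this
    max_steps: int                         # or once this many steps have run

def run_stages(stages):
    """Run stages in order; return (stage name, steps used) for each."""
    history = []
    for stage in stages:
        steps_used = 0
        for step in range(1, stage.max_steps + 1):
            steps_used = step
            if stage.mean_score_at(step) > stage.max_score:
                break  # score threshold exceeded: move on to the next stage
        history.append((stage.name, steps_used))
    return history

stages = [
    Stage("cheap filters", lambda step: 0.1 * step, max_score=0.55, max_steps=10),
    Stage("docking", lambda step: 0.05 * step, max_score=0.9, max_steps=8),
]
history = run_stages(stages)
```

<p>In this toy run the first stage ends early on its score threshold, while the second exhausts its step limit, mirroring the two termination conditions described above.</p>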
<h3 id="scoring-subsystem">Scoring Subsystem</h3>
<p>The scoring subsystem implements a plugin architecture supporting over 25 scoring components, including:</p>
<ul>
<li>Physicochemical descriptors from RDKit (QED, SLogP, TPSA, molecular weight, etc.)</li>
<li>Molecular docking via DockStream (<a href="https://en.wikipedia.org/wiki/AutoDock">AutoDock Vina</a>, rDock, Hybrid, Glide, GOLD)</li>
<li>QSAR models via Qptuna and ChemProp (D-MPNN)</li>
<li>Shape similarity via ROCS</li>
<li>Synthesizability estimation via SA score</li>
<li>Matched molecular pairs via mmpdb</li>
<li>Generic REST and external process interfaces</li>
</ul>
<p>Scores are aggregated via weighted arithmetic or geometric mean. A transform system (sigmoid, step functions, value maps) normalizes individual component scores to $[0, 1]$.</p>
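<p>A rough sketch of transform-then-aggregate scoring follows. The function names and the particular sigmoid parameterisation are assumptions made for this example and are not taken from the REINVENT 4 code base; only the general idea (normalize each component to $[0, 1]$, then combine with a weighted arithmetic or geometric mean) comes from the description above.</p>

```python
import math

def sigmoid_transform(x, low, high, steepness=10.0):
    """Map a raw value to (0, 1), returning 0.5 at the midpoint of [low, high].

    The parameterisation is an assumption made for this sketch.
    """
    mid = (low + high) / 2.0
    return 1.0 / (1.0 + math.exp(-steepness * (x - mid) / (high - low)))

def weighted_arithmetic_mean(scores, weights):
    return sum(w * s for s, w in zip(scores, weights)) / sum(weights)

def weighted_geometric_mean(scores, weights):
    if any(s == 0.0 for s in scores):
        return 0.0  # a single zero component zeroes the whole product
    total = sum(weights)
    return math.exp(sum(w * math.log(s) for s, w in zip(scores, weights)) / total)

# Two already-normalized component scores with equal weights (illustrative values).
scores, weights = [0.9, 0.4], [1.0, 1.0]
arithmetic = weighted_arithmetic_mean(scores, weights)
geometric = weighted_geometric_mean(scores, weights)
```

<p>The geometric mean never exceeds the arithmetic mean and drops to zero when any component is zero, so it penalizes molecules that fail a single criterion more strongly.</p>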
<h2 id="pdk1-inhibitor-case-study">PDK1 Inhibitor Case Study</h2>
<p>The paper demonstrates REINVENT 4 through a structure-based drug design exercise targeting <a href="https://en.wikipedia.org/wiki/PDPK1">Phosphoinositide-dependent kinase-1 (PDK1)</a> inhibitors. The experimental setup uses PDB crystal structure 2XCH with DockStream and Glide for docking, defining hits as molecules with docking score $\leq -8$ kcal/mol and QED $\geq 0.7$.</p>
<p><strong>Baseline RL from prior</strong>: 50 epochs of staged learning with batch size 128 produced 119 hits from 6,400 generated molecules (1.9% hit rate), spread across 103 generic Bemis-Murcko scaffolds.</p>
<p><strong>Transfer learning + RL</strong>: After 10 epochs of TL on 315 congeneric pyridinone PDK1 actives from PubChem Assay AID1798002, the same 50-epoch RL run produced 222 hits (3.5% hit rate) across 176 unique generic scaffolds, nearly doubling productivity.</p>
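<p>The hit rates follow directly from the run size. The short check below reproduces the arithmetic from the numbers quoted above; the variable names are illustrative.</p>

```python
epochs, batch_size = 50, 128
generated = epochs * batch_size  # molecules generated per run

hits = {"RL from prior": 119, "TL + RL": 222}
hit_rates = {name: count / generated for name, count in hits.items()}
ratio = hits["TL + RL"] / hits["RL from prior"]  # relative gain from transfer learning
```

<p>The ratio comes out at roughly 1.87, which is the basis for describing the improvement as nearly a doubling.</p>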
<p>Both approaches generated top-scoring molecules (docking score of -10.1 kcal/mol each) with plausible binding poses reproducing key protein-ligand interactions seen in the native crystal structure, including hinge interactions with ALA 162 and contacts with LYS 111.</p>
<p>The paper also demonstrates the agent&rsquo;s plasticity through a molecular weight switching experiment: after 500 epochs driving generation toward 1500 Da molecules, switching the reward to favor molecules $\leq 500$ Da resulted in rapid adaptation within ~50 epochs, showing that the RL agent can recover from extreme biases.</p>
<h2 id="practical-software-for-ai-driven-drug-discovery">Practical Software for AI-Driven Drug Discovery</h2>
<p>REINVENT 4 represents a mature, well-documented framework that consolidates years of incremental development into a single codebase. Key practical features include TOML/JSON configuration, TensorBoard visualization, multinomial sampling and beam search decoding, diversity filters for scaffold-level novelty, experience replay (inception), and a plugin mechanism for extending the scoring subsystem.</p>
<p>The authors acknowledge that this is one approach among many and that there is no single solution that uniformly outperforms others. REINVENT has demonstrated strong sample efficiency in benchmarks and produced realistic 3D docking poses, but the paper does not claim universal superiority. The focus is on providing a well-engineered, transparent reference implementation rather than advancing a novel algorithm.</p>
<p>The limitations are threefold: only the Mol2Mol prior supports stereochemistry, biases in the training data constrain the explorable chemical space, and the SMILES-based representation inherits the known fragility of string-based molecular encodings.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior training (Reinvent)</td>
          <td>ChEMBL 25</td>
          <td>~1.7M molecules</td>
          <td>Drug-like compounds</td>
      </tr>
      <tr>
          <td>Prior training (LibInvent)</td>
          <td>ChEMBL 27</td>
          <td>~1.9M molecules</td>
          <td>Scaffold-decoration pairs</td>
      </tr>
      <tr>
          <td>Prior training (LinkInvent)</td>
          <td>ChEMBL 27</td>
          <td>~1.9M molecules</td>
          <td>Fragment-linker pairs</td>
      </tr>
      <tr>
          <td>Prior training (Mol2Mol)</td>
          <td>ChEMBL 28 / PubChem</td>
          <td>~200B pairs</td>
          <td>Tanimoto similarity $\geq 0.50$</td>
      </tr>
      <tr>
          <td>Case study TL</td>
          <td>PubChem AID1798002</td>
          <td>315 compounds</td>
          <td>Congeneric PDK1 actives</td>
      </tr>
      <tr>
          <td>Case study docking</td>
          <td>PDB 2XCH</td>
          <td>1 structure</td>
          <td>PDK1 crystal structure</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimization</strong>: DAP (recommended), plus three deprecated alternatives (REINFORCE, A2C, MAULI)</li>
<li><strong>Decoding</strong>: Multinomial sampling (default, temperature $K = 1$) and beam search</li>
<li><strong>Diversity filter</strong>: Murcko scaffold, topological scaffold, scaffold similarity, same-SMILES penalty</li>
<li><strong>Experience replay</strong>: Inception memory with configurable size and sampling rate</li>
<li><strong>Gradient descent</strong>: Adam optimizer</li>
</ul>
<h3 id="models">Models</h3>
<p>All pre-trained priors are distributed with the repository: the RNN-based generators (Reinvent, LibInvent, LinkInvent) and the transformer-based generator (Mol2Mol), the latter with multiple similarity-conditioned variants.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Condition</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Hit rate (RL)</td>
          <td>1.9%</td>
          <td>50 epochs, batch 128</td>
          <td>PDK1 case study</td>
      </tr>
      <tr>
          <td>Hit rate (TL+RL)</td>
          <td>3.5%</td>
          <td>10 TL + 50 RL epochs</td>
          <td>PDK1 case study</td>
      </tr>
      <tr>
          <td>Scaffold diversity (RL)</td>
          <td>103 scaffolds</td>
          <td>From 119 hits</td>
          <td>Generic Bemis-Murcko</td>
      </tr>
      <tr>
          <td>Scaffold diversity (TL+RL)</td>
          <td>176 scaffolds</td>
          <td>From 222 hits</td>
          <td>Generic Bemis-Murcko</td>
      </tr>
      <tr>
          <td>Best docking score</td>
          <td>-10.1 kcal/mol</td>
          <td>Both methods</td>
          <td>Glide SP</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify hardware requirements. REINVENT 4 supports both GPU and CPU execution. Python 3.10+ is required, with PyTorch 1.x (2.0 also compatible) and RDKit 2022.9+.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MolecularAI/REINVENT4">REINVENT4</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Full framework with pre-trained priors</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MolecularAI/DockStream">DockStream</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Docking wrapper for scoring</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Loeffler, H. H., He, J., Tibo, A., Janet, J. P., Voronov, A., Mervin, L. H., &amp; Engkvist, O. (2024). Reinvent 4: Modern AI-driven generative molecule design. <em>Journal of Cheminformatics</em>, 16, 20. <a href="https://doi.org/10.1186/s13321-024-00812-5">https://doi.org/10.1186/s13321-024-00812-5</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{loeffler2024reinvent,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Reinvent 4: Modern AI-driven generative molecule design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Loeffler, Hannes H. and He, Jiazhen and Tibo, Alessandro and Janet, Jon Paul and Voronov, Alexey and Mervin, Lewis H. and Engkvist, Ola}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-024-00812-5}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Randomized SMILES Improve Molecular Generative Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/</guid><description>Randomized SMILES improve RNN molecular generative models by increasing chemical space coverage, uniformity, and completeness versus canonical SMILES.</description><content:encoded><![CDATA[<h2 id="data-augmentation-through-smiles-randomization">Data Augmentation Through SMILES Randomization</h2>
<p>This is an <strong>Empirical</strong> paper that performs an extensive benchmark of RNN-based molecular generative models trained with different SMILES string variants. The primary contribution is demonstrating that randomized SMILES (non-unique molecular string representations obtained by randomizing atom orderings) substantially improve the quality of the generated chemical space compared to canonical SMILES, without requiring any changes to the model architecture.</p>
<p>The paper evaluates three properties of generated chemical spaces: uniformity (equal probability of sampling each molecule), completeness (coverage of the target space), and closedness (generating only molecules within the target space). These are measured using a new composite metric called UC-JSD.</p>
<h2 id="canonical-smiles-bias-in-generative-models">Canonical SMILES Bias in Generative Models</h2>
<p>Recurrent Neural Networks trained on SMILES strings have shown the capacity to create large chemical spaces of valid molecules. However, when trained with canonical SMILES (the unique string representation produced by a canonicalization algorithm), these models exhibit biases. Specifically, prior work by the same group showed that models trained on one million <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> molecules could only recover 68% of GDB-13 when sampled two billion times, compared to the theoretical maximum of 87% from an ideal uniform sampler.</p>
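<p>One way to see where a ceiling of this size comes from is to assume an ideal sampler that draws uniformly, with replacement, from a space of $N$ molecules: after $k$ draws the expected fraction seen at least once is $1 - (1 - 1/N)^{k}$. The sketch below evaluates this under that assumption, using an approximate GDB-13 size of 975 million; the function name is hypothetical.</p>

```python
import math

def expected_coverage(space_size, n_samples):
    """Expected fraction of a space hit at least once by uniform sampling with replacement."""
    # 1 - (1 - 1/N)^k, computed stably via log1p/expm1.
    return -math.expm1(n_samples * math.log1p(-1.0 / space_size))

gdb13_size = 975_000_000   # approximate size of GDB-13 (assumption for this sketch)
samples = 2_000_000_000    # two billion draws
coverage = expected_coverage(gdb13_size, samples)
```

<p>Under these assumptions the expected coverage is about 0.87, in line with the quoted theoretical maximum.</p>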
<p>The canonical SMILES representation introduces two problems. First, the canonicalization algorithm constrains how the molecular graph is traversed (e.g., prioritizing sidechains over ring atoms), forcing the model to learn both valid SMILES syntax and the specific canonical ordering rules. Second, structurally similar molecules can have substantially different canonical SMILES, making some molecules harder to sample than others. Molecules with more ring systems and complex topologies are particularly underrepresented.</p>
<p>The authors also note that DeepSMILES, a recently proposed alternative syntax, had not been benchmarked against randomized SMILES, and that the data augmentation capabilities of randomized SMILES at different training set sizes were unexplored.</p>
<h2 id="randomized-smiles-as-non-canonical-representations">Randomized SMILES as Non-Canonical Representations</h2>
<p>The core insight is that by randomizing the atom ordering before SMILES generation, each molecule can be represented by multiple different but equally valid SMILES strings. This effectively provides data augmentation: a molecule with $n$ heavy atoms admits up to $n!$ atom orderings, each of which yields a valid SMILES string (though the number of distinct strings is typically far lower, because many orderings produce the same traversal and because of molecular symmetry).</p>
<p>Two randomized SMILES variants are explored:</p>
<ul>
<li><strong>Restricted randomized SMILES</strong>: Atom ordering is randomized, but RDKit&rsquo;s built-in fixes are applied. These fixes avoid overly complicated traversals, for example by prioritizing sidechains when traversing the atoms of a ring.</li>
<li><strong>Unrestricted randomized SMILES</strong>: Atom ordering is randomized without any RDKit restrictions, producing a superset of the restricted variant that includes more convoluted SMILES strings.</li>
</ul>
<p>For each training epoch, a new set of randomized SMILES is generated for the same molecules, so a model trained for 300 epochs on one million molecules sees approximately 300 million different SMILES strings (with some overlap due to sampling).</p>
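<p>To make the idea concrete, here is a toy sketch that writes randomized SMILES for acyclic, single-bonded molecules by starting a depth-first traversal at a random atom and visiting neighbours in random order. It is not the procedure used in the paper, which relies on RDKit, and it deliberately ignores ring closures, bond orders, aromaticity, and stereochemistry. The function name and the graph encoding are assumptions made for this example.</p>

```python
import random

def random_smiles(symbols, bonds, rng):
    """Write one SMILES string for an acyclic, single-bonded molecule.

    symbols: list of atom symbols; bonds: list of (i, j) index pairs.
    The start atom and the neighbour visiting order are both randomized.
    """
    neighbours = {i: [] for i in range(len(symbols))}
    for i, j in bonds:
        neighbours[i].append(j)
        neighbours[j].append(i)

    def write(atom, parent):
        children = [n for n in neighbours[atom] if n != parent]
        rng.shuffle(children)
        out = symbols[atom]
        for child in children[:-1]:
            out += "(" + write(child, atom) + ")"  # all but the last child become branches
        if children:
            out += write(children[-1], atom)       # the last child continues the main chain
        return out

    return write(rng.randrange(len(symbols)), None)

# Isobutanol, CC(C)CO: five heavy atoms, no rings.
symbols = ["C", "C", "C", "C", "O"]
bonds = [(0, 1), (1, 2), (1, 3), (3, 4)]
rng = random.Random(0)
variants = {random_smiles(symbols, bonds, rng) for _ in range(200)}
```

<p>Every string in <code>variants</code> describes the same molecule, which is what lets a fresh set of strings be drawn for each epoch without changing the underlying training set.</p>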
<p>The model architecture is a standard RNN with an embedding layer, $l$ layers of LSTM or GRU cells of size $w$, optional dropout, and a linear output layer with softmax. The training objective minimizes the average negative log-likelihood (NLL):</p>
<p>$$
J(T) = -\ln P(X_{0} = x_{0}) - \sum_{t=1}^{T} \ln P(X_{t} = x_{t} \mid X_{t-1} = x_{t-1} \dots X_{1} = x_{1})
$$</p>
<p>The key metric is the Uniformity-Completeness JSD (UC-JSD), which extends the Jensen-Shannon Divergence to measure how uniform, complete, and closed the generated chemical space is:</p>
<p>$$
JSD = H\left(\sum_{d_i \in D} \alpha_{i} \cdot d_{i}\right) - \sum_{d_i \in D} \alpha_{i} H(d_{i})
$$</p>
<p>where $H(d)$ is the Shannon entropy of a probability distribution. The UC-JSD is computed over the NLL vectors of the validation, training, and sampled sets. The composite UCC score is defined as:</p>
<p>$$
UCC = \text{completeness} \times \text{uniformity} \times \text{closedness}
$$</p>
<p>where completeness measures coverage of GDB-13, uniformity measures how equal the sampling probabilities are, and closedness measures how few invalid (out-of-target-space) molecules are generated.</p>
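<p>The two formulas can be sketched for discrete distributions as follows. This is only an illustration: the equal default weights are an assumption, the function names are hypothetical, and the step that turns NLL vectors into the distributions compared by UC-JSD is not reproduced here.</p>

```python
import math

def shannon_entropy(dist):
    """Shannon entropy (natural log) of a discrete probability distribution."""
    return -sum(p * math.log(p) for p in dist if p > 0.0)

def generalized_jsd(dists, weights=None):
    """H(sum_i a_i * d_i) - sum_i a_i * H(d_i) for distributions on a shared support."""
    if weights is None:
        weights = [1.0 / len(dists)] * len(dists)  # equal weights: an assumption
    support = range(len(dists[0]))
    mixture = [sum(a * d[j] for a, d in zip(weights, dists)) for j in support]
    weighted = sum(a * shannon_entropy(d) for a, d in zip(weights, dists))
    return shannon_entropy(mixture) - weighted

def ucc(completeness, uniformity, closedness):
    """Composite score: the product of the three quality measures."""
    return completeness * uniformity * closedness

identical = generalized_jsd([[0.5, 0.5], [0.5, 0.5]])  # identical distributions
disjoint = generalized_jsd([[1.0, 0.0], [0.0, 1.0]])   # non-overlapping distributions
```

<p>Identical distributions give a divergence of zero, while two non-overlapping distributions with equal weights give $\ln 2$, the maximum for a pair. Because UCC is a product, a weakness in any one of the three measures pulls the composite score down.</p>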
<h2 id="benchmark-design-across-smiles-variants-training-sizes-and-architectures">Benchmark Design Across SMILES Variants, Training Sizes, and Architectures</h2>
<p>The benchmark covers a systematic grid of experimental conditions:</p>
<p><strong>SMILES variants</strong>: Canonical, restricted randomized, unrestricted randomized, and three DeepSMILES variants (branch syntax, ring syntax, both).</p>
<p><strong>Training set sizes from GDB-13</strong>: 1,000,000, 10,000, and 1,000 molecules with corresponding validation sets.</p>
<p><strong>Architecture choices</strong>: LSTM vs. GRU cells, with hyperparameter grids over number of layers ($l$), hidden size ($w$), dropout rate ($d$), and batch size ($b$).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers ($l$)</th>
          <th>Hidden ($w$)</th>
          <th>Dropout % ($d$)</th>
          <th>Batch ($b$)</th>
          <th>Cell</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GDB-13 1M</td>
          <td>3</td>
          <td>512</td>
          <td>0, 25, 50</td>
          <td>64, 128, 256, 512</td>
          <td>GRU, LSTM</td>
      </tr>
      <tr>
          <td>GDB-13 10K</td>
          <td>2, 3, 4</td>
          <td>256, 384, 512</td>
          <td>0, 25, 50</td>
          <td>8, 16, 32</td>
          <td>LSTM</td>
      </tr>
      <tr>
          <td>GDB-13 1K</td>
          <td>2, 3, 4</td>
          <td>128, 192, 256</td>
          <td>0, 25, 50</td>
          <td>4, 8, 16</td>
          <td>LSTM</td>
      </tr>
      <tr>
          <td>ChEMBL</td>
          <td>3</td>
          <td>512</td>
          <td>0, 25, 50</td>
          <td>64, 128, 256, 512</td>
          <td>LSTM</td>
      </tr>
  </tbody>
</table>
<p>Each model&rsquo;s best epoch was selected using a smoothed UC-JSD curve, and the best epoch was then sampled with replacement $k = 2 \times 10^{9}$ times for GDB-13 benchmarks.</p>
<p>For ChEMBL experiments, models were trained on 1,483,943 molecules with a validation set of 78,102 molecules. Evaluation used validity, unique molecule count, and Fréchet ChemNet Distance (FCD).</p>
<h2 id="randomized-smiles-produce-more-complete-and-uniform-chemical-spaces">Randomized SMILES Produce More Complete and Uniform Chemical Spaces</h2>
<h3 id="gdb-13-results-1m-training-set">GDB-13 results (1M training set)</h3>
<p>The restricted randomized SMILES model recovered 83.0% of GDB-13, compared to 72.8% for canonical SMILES and 68.4-72.1% for DeepSMILES variants. All three quality metrics improved substantially:</p>
<table>
  <thead>
      <tr>
          <th>SMILES Variant</th>
          <th>% GDB-13</th>
          <th>Uniformity</th>
          <th>Completeness</th>
          <th>Closedness</th>
          <th>UCC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Canonical</td>
          <td>72.8</td>
          <td>0.879</td>
          <td>0.836</td>
          <td>0.861</td>
          <td>0.633</td>
      </tr>
      <tr>
          <td>Rand. restricted</td>
          <td>83.0</td>
          <td>0.977</td>
          <td>0.953</td>
          <td>0.925</td>
          <td>0.860</td>
      </tr>
      <tr>
          <td>Rand. unrestricted</td>
          <td>80.9</td>
          <td>0.970</td>
          <td>0.929</td>
          <td>0.876</td>
          <td>0.790</td>
      </tr>
      <tr>
          <td>DeepSMILES (both)</td>
          <td>68.4</td>
          <td>0.851</td>
          <td>0.785</td>
          <td>0.796</td>
          <td>0.532</td>
      </tr>
  </tbody>
</table>
<p>The NLL distribution of GDB-13 molecules under the randomized SMILES model was centered near $NLL_{GDB13} = -\ln(1/|GDB13|) = 20.6$ with a narrow spread, indicating near-uniform sampling probability. The canonical model showed a much wider NLL distribution, meaning some molecules were orders of magnitude harder to sample.</p>
<p>Randomized SMILES without data augmentation (same SMILES each epoch) still outperformed canonical SMILES (UCC 0.712 vs. 0.633 for restricted), confirming that the non-canonical representation itself is beneficial beyond the augmentation effect.</p>
<h3 id="smaller-training-sets-amplify-the-advantage">Smaller training sets amplify the advantage</h3>
<p>With only 10,000 training molecules (0.001% of GDB-13), the randomized model generated 62.3% of GDB-13 vs. 38.8% for canonical. With 1,000 training molecules, the relative advantage widened further: 34.1% vs. 14.5%, roughly a 2.4-fold difference compared with 1.6-fold at 10,000 molecules. Validity also improved markedly (81.2% vs. 50.4% for the 1K setting), suggesting that randomized SMILES help the model learn valid SMILES syntax more effectively from limited data.</p>
<h3 id="chembl-results">ChEMBL results</h3>
<p>On the drug-like ChEMBL dataset, the randomized SMILES model generated nearly twice as many unique molecules as the canonical model (64.09% vs. 34.67% unique in a 2B sample), with comparable validity (98.33% vs. 98.26%). The canonical model showed a lower FCD (0.0712 vs. 0.1265), but the authors argue this reflects overfitting: the canonical model&rsquo;s NLL distributions for training and validation sets overlapped tightly, while the randomized model showed more uniform coverage. Physicochemical property distributions (molecular weight, logP, SA score, QED, NP score, internal diversity) were nearly identical across both models.</p>
<h3 id="architecture-findings">Architecture findings</h3>
<p>LSTM cells consistently outperformed GRU cells across all SMILES variants. Despite GRU&rsquo;s faster per-epoch training time, LSTM models converged in fewer epochs, making them faster overall. Dropout improved canonical SMILES models but was less beneficial (or detrimental) for randomized SMILES, suggesting that randomized SMILES themselves serve as a regularization mechanism. Larger batch sizes generally improved performance across all variants.</p>
<h3 id="uc-jsd-as-a-model-selection-metric">UC-JSD as a model selection metric</h3>
<p>The UC-JSD showed strong correlation with UCC ($R^{2} = 0.931$ for canonical, $R^{2} = 0.856$ for restricted randomized, $R^{2} = 0.885$ for unrestricted randomized), validating its use as a model selection criterion without requiring expensive sampling of every model.</p>
<p>The authors interpret randomized SMILES models as occupying a hybrid space between grammar-based and action-based generative models. The vocabulary serves as a fixed action space where atom tokens are &ldquo;add atom&rdquo; actions, bond tokens are &ldquo;add bond&rdquo; actions, and ring/branching tokens enable graph traversal. Canonical SMILES constrain this action space to a single deterministic path, while randomized SMILES allow the model to explore multiple valid traversals. This perspective also explains why DeepSMILES performed worse: its altered syntax creates a more complex action space without compensating benefits.</p>
<p>The authors encourage the use of randomized SMILES across different model architectures and tasks, including classification and property prediction, and suggest that finding optimal restricted variants of randomized SMILES is a promising research direction.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>GDB-13 subsets</td>
          <td>1M / 10K / 1K molecules</td>
          <td>Randomly sampled from 975M GDB-13</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>ChEMBL</td>
          <td>1,483,943 training / 78,102 validation</td>
          <td>Filtered subset of ChEMBL database</td>
      </tr>
  </tbody>
</table>
<p>GDB-13 is available from the <a href="http://gdb.unibe.ch/downloads">Reymond group website</a>. ChEMBL is publicly available.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Character-level tokenization with special handling for multi-character tokens (Cl, Br, bracketed atoms, %-prefixed ring numbers)</li>
<li>Teacher forcing during training with NLL loss</li>
<li>Gradient norm clipping to 1.0</li>
<li>Weight initialization from $\mathcal{U}(-\sqrt{1/w}, \sqrt{1/w})$</li>
<li>Adaptive learning rate decay based on UC-JSD</li>
<li>Best epoch selection via smoothed UC-JSD (window size 4)</li>
</ul>
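<p>The tokenization rule in the first bullet can be sketched with a single regular expression. This is a minimal illustration, not the authors&rsquo; implementation: the name <code>tokenize_smiles</code> and the exact pattern are assumptions. It treats a bracketed atom, <code>Cl</code>, <code>Br</code>, or <code>%</code> followed by two digits as one token, and every other character as its own token.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import re

# Hypothetical pattern: alternatives are tried left to right, so the
# multi-character tokens must come before the single-character fallback.
TOKEN_PATTERN = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize_smiles(smiles):
    """Split a SMILES string into tokens (illustrative sketch only)."""
    return TOKEN_PATTERN.findall(smiles)


tokens = tokenize_smiles("C[C@H](Cl)c1ccccc1Br")
print(tokens)
</code></pre></div>
<p>Putting the multi-character alternatives first is the key design choice: with the order reversed, <code>Cl</code> would be split into a carbon and an unknown <code>l</code> token.</p>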
<h3 id="models">Models</h3>
<p>Standard RNN architecture: embedding layer, stacked LSTM/GRU layers with optional dropout, linear output with softmax. Best models used 3 layers of 512-dimensional LSTM cells. Vocabulary sizes: 26 (GDB-13), 31 (ChEMBL).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Randomized</th>
          <th>Best Canonical</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>% GDB-13 (1M)</td>
          <td>83.0%</td>
          <td>72.8%</td>
          <td>2B sample with replacement</td>
      </tr>
      <tr>
          <td>UCC (1M)</td>
          <td>0.860</td>
          <td>0.633</td>
          <td>Composite score</td>
      </tr>
      <tr>
          <td>% GDB-13 (10K)</td>
          <td>62.3%</td>
          <td>38.8%</td>
          <td>2B sample with replacement</td>
      </tr>
      <tr>
          <td>% GDB-13 (1K)</td>
          <td>34.1%</td>
          <td>14.5%</td>
          <td>2B sample with replacement</td>
      </tr>
      <tr>
          <td>% Unique ChEMBL</td>
          <td>64.09%</td>
          <td>34.67%</td>
          <td>2B sample with replacement</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Nvidia Tesla V100 (Volta) 16 GB VRAM with CUDA 9.1, driver 390.30. Training times ranged from 1 minute (1K canonical) to 131 hours (ChEMBL canonical). Randomized SMILES models required longer per-epoch training due to augmentation overhead but converged to better solutions.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/undeadpixel/reinvent-randomized">reinvent-randomized</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training and benchmarking code</td>
      </tr>
      <tr>
          <td><a href="http://gdb.unibe.ch/downloads">GDB-13</a></td>
          <td>Dataset</td>
          <td>Academic use</td>
          <td>975 million fragment-like molecules</td>
      </tr>
      <tr>
          <td><a href="https://github.com/molecularsets/moses">MOSES benchmark</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Used for FCD and property calculations</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Arús-Pous, J., Johansson, S. V., Prykhodko, O., Bjerrum, E. J., Tyrchan, C., Reymond, J.-L., Chen, H., &amp; Engkvist, O. (2019). Randomized SMILES strings improve the quality of molecular generative models. <em>Journal of Cheminformatics</em>, 11(1), 71. <a href="https://doi.org/10.1186/s13321-019-0393-0">https://doi.org/10.1186/s13321-019-0393-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{aruspous2019randomized,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Randomized SMILES strings improve the quality of molecular generative models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ar{\&#39;u}s-Pous, Josep and Johansson, Simon Viet and Prykhodko, Oleksii and Bjerrum, Esben Jannik and Tyrchan, Christian and Reymond, Jean-Louis and Chen, Hongming and Engkvist, Ola}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{71}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-019-0393-0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Protein-to-Drug Molecule Translation via Transformer</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/transformer-protein-drug-generation/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/transformer-protein-drug-generation/</guid><description>A Transformer model frames protein-targeted drug generation as machine translation from amino acid sequences to SMILES molecular strings.</description><content:encoded><![CDATA[<h2 id="protein-targeted-drug-generation-as-machine-translation">Protein-Targeted Drug Generation as Machine Translation</h2>
<p>This is a <strong>Method</strong> paper that proposes using the Transformer neural network architecture for protein-specific de novo drug generation. The primary contribution is framing the problem of generating molecules that bind to a target protein as a machine translation task: translating from the &ldquo;language&rdquo; of amino acid sequences to the SMILES representation of candidate drug molecules. The model takes only a protein&rsquo;s amino acid sequence as input and generates novel molecules with predicted binding affinity, requiring no prior knowledge of active ligands, physicochemical descriptors, or the protein&rsquo;s three-dimensional structure.</p>
<h2 id="limitations-of-existing-generative-drug-design-approaches">Limitations of Existing Generative Drug Design Approaches</h2>
<p>Existing deep learning methods for de novo molecule generation suffer from several limitations. Most RNN-based approaches require a library of known active compounds against the target protein to fine-tune the generator or train a reward predictor for reinforcement learning. Structure-based drug design methods require the three-dimensional structure of the target protein, which can be costly and technically difficult to obtain through protein expression, purification, and crystallization. Autoencoder-based approaches (variational and adversarial) similarly depend on prior knowledge of protein binders or their physicochemical characteristics.</p>
<p>The estimated drug-like molecule space is on the order of $10^{60}$, while only around $10^{8}$ compounds have been synthesized. High-throughput screening is expensive and time-consuming, and virtual screening operates only on known molecules. Computational de novo design methods often generate molecules that are hard to synthesize or restrict accessible chemical space through coded rules. A method that requires only a protein&rsquo;s amino acid sequence would substantially simplify the initial stages of drug discovery, particularly for targets with limited or no information about inhibitors and 3D structure.</p>
<h2 id="sequence-to-sequence-translation-with-self-attention">Sequence-to-Sequence Translation with Self-Attention</h2>
<p>The core insight is to treat protein-targeted drug generation as a translation problem between two &ldquo;languages,&rdquo; applying the Transformer architecture that had demonstrated strong results in neural machine translation. The encoder maps a protein amino acid sequence $(a_1, \ldots, a_n)$ to continuous representations $\mathbf{z} = (z_1, \ldots, z_n)$, and the decoder autoregressively generates a SMILES string conditioned on $\mathbf{z}$.</p>
<p>The self-attention mechanism computes:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$</p>
<p>where $d_k$ is the dimensionality of the key vectors and $\sqrt{d_k}$ acts as the scaling factor. Multihead attention runs $h$ parallel attention heads:</p>
<p>$$
\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$</p>
<p>$$
\text{Multihead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O
$$</p>
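<p>As a rough illustration of the attention formula above, the following sketch computes scaled dot-product attention for a single head on plain Python lists. The function names and the toy inputs are assumptions made for this example; the learned projections $W_i^Q$, $W_i^K$, $W_i^V$ and the output matrix $W^O$ are omitted.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention on lists of row vectors (single head)."""
    d_k = len(keys[0])
    dim = len(values[0])
    output = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Each output row is a weighted average of the value vectors.
        output.append([sum(w * v[j] for w, v in zip(weights, values))
                       for j in range(dim)])
    return output


# Toy input: two positions with 2-dimensional queries, keys, and values.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = attention(Q, K, V)
print(out)
</code></pre></div>
<p>Because each query matches one key more strongly than the other, each output row is pulled toward the corresponding value vector without ignoring the other one.</p>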
<p>Positional encoding uses sinusoidal functions:</p>
<p>$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i / d_{model}}}\right)
$$</p>
<p>$$
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i / d_{model}}}\right)
$$</p>
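<p>The two positional-encoding formulas can be evaluated directly. The sketch below is a minimal version assuming an even $d_{model}$; the function name <code>positional_encoding</code> and the toy sizes are illustrative and not taken from the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def positional_encoding(num_positions, d_model):
    """Return a num_positions x d_model table of sinusoidal encodings."""
    table = []
    for pos in range(num_positions):
        row = []
        for dim in range(d_model):
            # Dimensions 2i and 2i+1 share the same frequency.
            i = dim // 2
            angle = pos / (10000 ** (2 * i / d_model))
            # Even dimensions use sine, odd dimensions use cosine.
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding(num_positions=4, d_model=8)
print([round(v, 3) for v in pe[1]])
</code></pre></div>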
<p>The self-attention mechanism is well suited to this task for two reasons. First, protein sequences can be dozens of times longer than SMILES strings, making the ability to capture long-range dependencies essential. Second, three-dimensional structural features of the binding pocket may be formed by amino acid residues far apart in the linear sequence, and multihead attention can attend to several such distant positions at once.</p>
<h2 id="data-model-architecture-and-docking-evaluation">Data, Model Architecture, and Docking Evaluation</h2>
<h3 id="data">Data</h3>
<p>The training data was retrieved from BindingDB, filtering for interactions between proteins from Homo sapiens, Rattus norvegicus, Mus musculus, and Bos taurus with binding affinity below 100 nM (IC50, Kd, or EC50). After filtering for valid PubChem CIDs, SMILES representations, UniProt IDs, molecular weight under 1000 Da, and amino acid sequence lengths between 80 and 2050, the final dataset contained 238,147 records with 1,613 unique proteins and 154,924 unique ligand SMILES strings.</p>
<p>Five Monte Carlo cross-validation splits were created, with the constraint that test set proteins share less than 20% sequence similarity with training set proteins (measured via <a href="https://en.wikipedia.org/wiki/Needleman%E2%80%93Wunsch_algorithm">Needleman-Wunsch</a> global alignment).</p>
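<p>To make the alignment step concrete, here is a small Needleman-Wunsch sketch that returns the fraction of identical aligned positions. It is a simplification under stated assumptions: unit match, mismatch, and gap scores stand in for a substitution matrix, and identity is reported rather than the similarity measure used for the actual splits. The name <code>global_identity</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def global_identity(seq_a, seq_b, match=1, mismatch=-1, gap=-1):
    """Fraction of identical columns in one optimal global alignment."""
    rows, cols = len(seq_a) + 1, len(seq_b) + 1
    score = [[0] * cols for _ in range(rows)]
    # First row and column: aligning a prefix against gaps only.
    for i in range(1, rows):
        score[i][0] = i * gap
    for j in range(1, cols):
        score[0][j] = j * gap
    for i in range(1, rows):
        for j in range(1, cols):
            same = seq_a[i - 1] == seq_b[j - 1]
            diag = score[i - 1][j - 1] + (match if same else mismatch)
            score[i][j] = max(diag, score[i - 1][j] + gap, score[i][j - 1] + gap)
    # Trace back from the bottom-right corner, counting alignment columns.
    i, j = len(seq_a), len(seq_b)
    identical = aligned = 0
    while i != 0 or j != 0:
        aligned += 1
        if i != 0 and j != 0:
            same = seq_a[i - 1] == seq_b[j - 1]
            if score[i][j] == score[i - 1][j - 1] + (match if same else mismatch):
                identical += 1 if same else 0
                i, j = i - 1, j - 1
                continue
        if i != 0 and score[i][j] == score[i - 1][j] + gap:
            i -= 1  # gap in seq_b
        else:
            j -= 1  # gap in seq_a
    return identical / aligned if aligned else 0.0


print(global_identity("ACDE", "ADE"))
</code></pre></div>
<p>In a split like the one described above, a candidate test protein would be kept only if its score against every training protein stays below 0.20.</p>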
<h3 id="model-configuration">Model Configuration</h3>
<p>The model uses the original Transformer implementation via the tensor2tensor library with:</p>
<ul>
<li>4 encoder/decoder layers of size 128</li>
<li>4 attention heads</li>
<li>Adam optimizer with learning rate decay from the original Transformer paper</li>
<li>Batch size of 4,096 tokens</li>
<li>Training for 600K epochs on a single GPU in Google Colaboratory</li>
<li>Vocabulary of 71 symbols (character-level tokenization)</li>
</ul>
<p>Beam search decoding was used with two modes: beam size 4 keeping only the top-1 result (&ldquo;one per one&rdquo; mode) and beam size 10 keeping all 10 results (&ldquo;ten per one&rdquo; mode).</p>
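<p>The two decoding modes differ only in beam width and in how many finished hypotheses are kept. The sketch below shows this on a toy next-token table; <code>TOY_MODEL</code>, its probabilities, and the <code>^</code>/<code>$</code> start and end markers are invented for illustration, whereas the real decoder conditions on the encoded protein sequence.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

# Hypothetical toy model: next-token probabilities depend only on the
# last token. "^" marks the start of a sequence and "$" its end.
TOY_MODEL = {
    "^": {"C": 0.6, "N": 0.4},
    "C": {"C": 0.5, "O": 0.3, "$": 0.2},
    "N": {"C": 0.6, "$": 0.4},
    "O": {"$": 1.0},
}


def beam_search(beam_size, max_len=6):
    """Return up to beam_size finished sequences, most probable first."""
    beams = [(0.0, "^")]  # (log-probability, sequence so far)
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            for token, prob in TOY_MODEL[seq[-1]].items():
                candidates.append((score + math.log(prob), seq + token))
        # Keep only the beam_size most probable extensions.
        candidates.sort(reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq.endswith("$"):
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    finished.sort(reverse=True)
    return [seq for _, seq in finished[:beam_size]]


one_per_one = beam_search(beam_size=4)[:1]  # keep only the top hypothesis
ten_per_one = beam_search(beam_size=10)     # keep every returned hypothesis
print(one_per_one, ten_per_one)
</code></pre></div>
<p>A wider beam returns more hypotheses per input, but the lower-ranked ones are less probable under the model, which is consistent with the lower validity reported for the ten-per-one mode.</p>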
<h3 id="chemical-validity-and-uniqueness">Chemical Validity and Uniqueness</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>One per One (avg)</th>
          <th>Ten per One (avg)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid SMILES (%)</td>
          <td>90.2</td>
          <td>82.6</td>
      </tr>
      <tr>
          <td>Unique SMILES (%)</td>
          <td>92.3</td>
          <td>81.7</td>
      </tr>
      <tr>
          <td>ZINC15 match (%)</td>
          <td>30.6</td>
          <td>17.1</td>
      </tr>
  </tbody>
</table>
<h3 id="docking-evaluation">Docking Evaluation</h3>
<p>To assess binding affinity, the authors selected two receptor tyrosine kinases from the test set (IGF-1R and VEGFR2) and performed molecular docking with <a href="/notes/chemistry/molecular-design/generation/evaluation/smina-docking-benchmark/">SMINA</a>. Four sets of ligands were compared: known binders, randomly selected compounds, molecules generated for the target protein, and molecules generated for other targets (cross-docking control).</p>
<p>ROC-AUC analysis showed that the docking tool classified generated molecules for the correct target as binders at rates comparable to known binders. For the best-discriminating structures (PDB 3O23 for IGF-1R, PDB 3BE2 for VEGFR2), Mann-Whitney U tests confirmed statistically significant differences between generated-for-target molecules and random compounds, while the difference between generated-for-target and known binders was not significant (p = 0.40 and 0.26 respectively), suggesting the model generates plausible binders.</p>
<h3 id="drug-likeness-properties">Drug-Likeness Properties</h3>
<p>Generated molecules were evaluated against <a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski&rsquo;s Rule of Five</a> and other drug-likeness criteria:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Constraint</th>
          <th>One per One (%)</th>
          <th>Ten per One (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>logP</td>
          <td>&lt; 5</td>
          <td>84.4</td>
          <td>85.6</td>
      </tr>
      <tr>
          <td>Molecular weight</td>
          <td>&lt; 500 Da</td>
          <td>95.8</td>
          <td>88.9</td>
      </tr>
      <tr>
          <td>H-bond donors</td>
          <td>&lt; 5</td>
          <td>95.8</td>
          <td>91.9</td>
      </tr>
      <tr>
          <td>H-bond acceptors</td>
          <td>&lt; 10</td>
          <td>97.9</td>
          <td>93.5</td>
      </tr>
      <tr>
          <td>Rotatable bonds</td>
          <td>&lt; 10</td>
          <td>97.9</td>
          <td>91.2</td>
      </tr>
      <tr>
          <td>TPSA</td>
          <td>&lt; 140</td>
          <td>98.0</td>
          <td>92.7</td>
      </tr>
      <tr>
          <td>SAS</td>
          <td>&lt; 6</td>
          <td>99.9</td>
          <td>100.0</td>
      </tr>
  </tbody>
</table>
<p>Mean QED values were 0.66 +/- 0.19 (one per one) and 0.58 +/- 0.21 (ten per one).</p>
<h3 id="structural-novelty">Structural Novelty</h3>
<p>Tanimoto similarity analysis showed that only 8% of generated structures had similarity above the threshold (&gt; 0.85) to training compounds. The majority (51%) had Tanimoto scores below 0.5. The mean nearest-neighbor Tanimoto similarity of generated molecules to the training set (0.54 +/- 0.17 in one-per-one mode) was substantially lower than the mean within-training-set similarity (0.74 +/- 0.14), indicating the model generates structurally diverse molecules outside the training distribution.</p>
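<p>The nearest-neighbor statistic above can be sketched as follows. Fingerprints are represented here as sets of &ldquo;on&rdquo; bit indices, which is an assumption made for this example; the fingerprint type is not specified in this summary, and the toy bit sets are invented.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of fingerprint bits."""
    union = bits_a.union(bits_b)
    if not union:
        return 0.0
    return len(bits_a.intersection(bits_b)) / len(union)


def nearest_neighbor_similarity(query, reference_set):
    """Highest Tanimoto similarity between query and any reference."""
    return max(tanimoto(query, ref) for ref in reference_set)


# Hypothetical fingerprints written as sets of bit indices.
training = [{1, 2, 3, 4}, {2, 3, 5, 8}]
generated = {1, 2, 3, 9}
print(nearest_neighbor_similarity(generated, training))
</code></pre></div>
<p>Averaging this value over all generated molecules gives the mean nearest-neighbor similarity; applying the same function within the training set gives the comparison figure.</p>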
<h2 id="generated-molecules-show-drug-like-properties-and-predicted-binding">Generated Molecules Show Drug-Like Properties and Predicted Binding</h2>
<p>The model generates roughly 90% chemically valid SMILES in one-per-one mode, with 92% uniqueness. Docking simulations on IGF-1R and VEGFR2 suggest that molecules generated for the correct target do not differ significantly from known binders, while molecules generated for other targets behave more like random compounds. Drug-likeness properties fall within acceptable ranges for the vast majority of generated compounds.</p>
<p>The authors acknowledge several limitations:</p>
<ul>
<li>Only two protein targets were analyzed via docking due to computational constraints, and the analysis was limited to proteins with a single well-known druggable binding pocket.</li>
<li>Beam search produces molecules that differ only slightly; diverse beam search or coupling with variational/adversarial autoencoders could improve diversity.</li>
<li>The fraction of molecules matching the ZINC15 database (30.6% in one-per-one mode) could potentially be reduced by pretraining on a larger compound set (e.g., ChEMBL&rsquo;s 1.5 million molecules).</li>
<li>Model interpretability remains limited and is identified as important future work.</li>
<li>The approach is a proof of concept and requires further validation via in vitro assays across diverse protein targets.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data-1">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Test</td>
          <td>BindingDB (filtered)</td>
          <td>238,147 records</td>
          <td>1,613 unique proteins, 154,924 unique SMILES; IC50/Kd/EC50 &lt; 100 nM</td>
      </tr>
      <tr>
          <td>Docking validation</td>
          <td>PDB structures</td>
          <td>11 (IGF-1R), 20 (VEGFR2)</td>
          <td>SMINA docking with default settings</td>
      </tr>
      <tr>
          <td>Database matching</td>
          <td>ZINC15</td>
          <td>N/A</td>
          <td>Used for novelty assessment</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer (encoder-decoder) via tensor2tensor library</li>
<li>Beam search decoding (beam sizes 4 and 10)</li>
<li>Needleman-Wunsch global alignment for protein sequence similarity (EMBOSS)</li>
<li>SMINA for molecular docking</li>
<li>RDKit for validity checking, property calculation, and canonicalization</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>4 layers, 128 hidden size, 4 attention heads</li>
<li>Character-level tokenization with 71-symbol vocabulary</li>
<li>Five Monte Carlo cross-validation splits with &lt; 20% sequence similarity between train/test proteins</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Valid SMILES</td>
          <td>90.2% (1-per-1), 82.6% (10-per-1)</td>
          <td>Averaged across 5 splits</td>
      </tr>
      <tr>
          <td>Unique SMILES</td>
          <td>92.3% (1-per-1), 81.7% (10-per-1)</td>
          <td>Averaged across 5 splits</td>
      </tr>
      <tr>
          <td>ZINC15 match</td>
          <td>30.6% (1-per-1), 17.1% (10-per-1)</td>
          <td>Averaged across 5 splits</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>0.66 +/- 0.19 (1-per-1), 0.58 +/- 0.21 (10-per-1)</td>
          <td>Drug-likeness score</td>
      </tr>
      <tr>
          <td>SAS compliance</td>
          <td>99.9% (1-per-1), 100% (10-per-1)</td>
          <td>SAS &lt; 6</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Google Colaboratory with one GPU</li>
<li>Training for 600K epochs</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/dariagrechishnikova/molecule_structure_generation">molecule_structure_generation</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Jupyter Notebook implementation using tensor2tensor</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Grechishnikova, D. (2021). Transformer neural network for protein-specific de novo drug generation as a machine translation problem. <em>Scientific Reports</em>, 11, 321. <a href="https://doi.org/10.1038/s41598-020-79682-4">https://doi.org/10.1038/s41598-020-79682-4</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{grechishnikova2021transformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Transformer neural network for protein-specific de novo drug generation as a machine translation problem}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Grechishnikova, Daria}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{321}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-020-79682-4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>PASITHEA: Gradient-Based Molecular Design via Dreaming</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/deep-molecular-dreaming-pasithea/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/deep-molecular-dreaming-pasithea/</guid><description>PASITHEA applies inceptionism to molecular design, using gradient-based optimization on SELFIES representations to generate molecules with target properties.</description><content:encoded><![CDATA[<h2 id="inceptionism-applied-to-molecular-inverse-design">Inceptionism Applied to Molecular Inverse Design</h2>
<p>This is a <strong>Method</strong> paper that introduces PASITHEA, a gradient-based approach to de-novo molecular design inspired by inceptionism (deep dreaming) techniques from computer vision. The core contribution is a direct optimization framework that modifies molecular structures by backpropagating through a trained property-prediction network, with the molecular input (rather than weights) serving as the optimizable variable. PASITHEA is enabled by SELFIES, a surjective molecular string representation that guarantees 100% validity of generated molecules.</p>
<h2 id="the-need-for-direct-gradient-based-molecular-optimization">The Need for Direct Gradient-Based Molecular Optimization</h2>
<p>Existing inverse molecular design methods, including variational autoencoders (VAEs), generative adversarial networks (GANs), reinforcement learning (RL), and genetic algorithms (GAs), share a common characteristic: they optimize molecules indirectly. VAEs and GANs learn distributions and scan latent spaces. RL agents learn policies from environmental rewards. GAs iteratively apply mutations and selections. None of these approaches directly maximize an objective function in a gradient-based manner with respect to the molecular representation itself.</p>
<p>This indirection has several consequences. VAE-based methods require learning a latent space, and the optimization happens in that space rather than directly on molecular structures. RL and GA methods require expensive function evaluations for each candidate molecule. The authors identify an opportunity to exploit gradients more directly by reversing the learning process of a neural network trained to predict molecular properties, thereby sidestepping latent spaces, policies, and population-based search entirely.</p>
<p>A second motivation is interpretability. By operating directly on the molecular representation (rather than a learned latent space), PASITHEA can reveal what a regression network has learned about structure-property relationships, a capability the authors frame as analogous to how deep dreaming reveals what image classifiers have learned about visual features.</p>
<h2 id="core-innovation-inverting-regression-networks-on-selfies">Core Innovation: Inverting Regression Networks on SELFIES</h2>
<p>PASITHEA&rsquo;s key insight is a two-phase training procedure that repurposes the standard neural network training loop for molecule generation.</p>
<p><strong>Phase 1: Prediction training.</strong> A fully connected neural network is trained to predict a real-valued chemical property (logP) from one-hot encoded SELFIES strings. The standard feedforward and backpropagation process updates the network weights to minimize mean squared error between predicted and ground-truth property values:</p>
<p>$$
\min_{\theta} \frac{1}{N} \sum_{i=1}^{N} (f_{\theta}(\mathbf{x}_i) - y_i)^2
$$</p>
<p>where $f_{\theta}$ is the neural network with parameters $\theta$, $\mathbf{x}_i$ is the one-hot encoded SELFIES input, and $y_i$ is the target logP value.</p>
<p><strong>Phase 2: Inverse training (deep dreaming).</strong> The network weights $\theta$ are frozen. For a given input molecule $\mathbf{x}$ and a desired target property value $y_{\text{target}}$, the gradients are computed with respect to the input representation rather than the weights:</p>
<p>$$
\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla_{\mathbf{x}} \mathcal{L}(f_{\theta}(\mathbf{x}), y_{\text{target}})
$$</p>
<p>This gradient descent on the input incrementally modifies the one-hot encoding of the molecular string, transforming it toward a structure whose predicted property matches the target value. At each step, an argmax over each position converts the now continuous-valued encoding back to a discrete SELFIES string, which always maps to a valid molecular graph due to the surjective property of SELFIES.</p>
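<p>A toy version of the inverse-training update may help. In this sketch a fixed linear scorer stands in for the trained network so the input gradient can be written by hand; the alphabet, the weights, the step size, and the target value are all invented for illustration and are not the authors&rsquo; settings.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python"># Hypothetical frozen "property predictor": a linear score over a
# (position x symbol) encoding. PASITHEA itself uses a trained fully
# connected network; this stand-in only keeps the gradient simple.
ALPHABET = ["[C]", "[N]", "[O]"]
WEIGHTS = [[1.0, -1.0, -0.5],   # contribution of each symbol at position 0
           [1.0, -1.0, -0.5]]   # contribution of each symbol at position 1


def predict(x):
    """Frozen model f(x): weighted sum of all encoding entries."""
    return sum(w * v for w_row, x_row in zip(WEIGHTS, x)
               for w, v in zip(w_row, x_row))


def dream_step(x, y_target, lr):
    """One gradient-descent step on the input for (f(x) - y_target)^2."""
    error = predict(x) - y_target
    # d(loss)/d(x[i][j]) = 2 * error * WEIGHTS[i][j]; the weights stay fixed.
    return [[v - lr * 2.0 * error * w for w, v in zip(w_row, x_row)]
            for w_row, x_row in zip(WEIGHTS, x)]


def decode(x):
    """Argmax over each row recovers a discrete symbol sequence."""
    return [ALPHABET[row.index(max(row))] for row in x]


x = [[0.1, 0.9, 0.2], [0.8, 0.1, 0.3]]  # relaxed encoding of [N][C]
start = decode(x)
for _ in range(50):
    x = dream_step(x, y_target=2.0, lr=0.05)
print(start, decode(x), round(predict(x), 3))
</code></pre></div>
<p>Only the input changes between steps. The entries are not constrained to stay in $[0, 1]$ here, and the discrete string is read off by argmax whenever it is needed.</p>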
<p><strong>The role of SELFIES.</strong> The surjective mapping from strings to molecular graphs is essential. With SMILES, intermediate strings during optimization can become syntactically invalid (e.g., an unclosed ring like &ldquo;CCCC1CCCCC&rdquo;), producing no valid molecule. SELFIES enforces constraints that guarantee every string maps to a valid molecular graph, making the continuous gradient-based optimization feasible.</p>
<p><strong>Input noise injection.</strong> Because inverse training transforms a one-hot encoding from binary values to real numbers, the discrete-to-continuous transition can cause convergence problems. The authors address this by initializing the input with noise: every zero in the one-hot encoding is replaced by a random number in $[0, k]$, where $k$ is a hyperparameter between 0.5 and 0.95. This smooths the optimization landscape and enables incremental molecular modifications rather than abrupt changes.</p>
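<p>The noise initialization can be written in a few lines. This is a minimal sketch assuming uniform draws, with the hypothetical helper name <code>add_noise</code>; the upper bound plays the role of $k$.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import random


def add_noise(one_hot, upper_bound, rng):
    """Replace every zero with a draw from [0, upper_bound]; ones are kept."""
    return [[1.0 if v == 1 else rng.uniform(0.0, upper_bound) for v in row]
            for row in one_hot]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
one_hot = [[0, 1, 0], [1, 0, 0]]
noisy = add_noise(one_hot, upper_bound=0.9, rng=rng)
# Because upper_bound is below 1, the argmax of each row is unchanged.
decoded = [row.index(max(row)) for row in noisy]
print(noisy, decoded)
</code></pre></div>
<p>The noisy encoding still decodes to the starting molecule, but the non-selected symbols now carry nonzero values that gradient updates can raise gradually.</p>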
<h2 id="experimental-setup-on-qm9-with-logp-optimization">Experimental Setup on QM9 with LogP Optimization</h2>
<h3 id="dataset-and-property">Dataset and Property</h3>
<p>The experiments use a random subset of 10,000 molecules from the <a href="/notes/chemistry/datasets/qm9/">QM9</a> dataset. The target property is the logarithm of the partition coefficient (logP), computed using RDKit. LogP measures lipophilicity, an important drug-likeness indicator that follows an approximately normal distribution in QM9 and has a nearly continuous range, making it suitable for gradient-based optimization.</p>
<h3 id="network-architecture">Network Architecture</h3>
<p>PASITHEA uses a fully connected neural network with four layers, each containing 500 nodes with ReLU activation. The loss function is mean squared error. Data is split 85%/15% for training/testing. The prediction model trains for approximately 1,500 epochs with an Adam optimizer and a learning rate of $1 \times 10^{-6}$.</p>
<p>For inverse training, the authors select a noise upper-bound of 0.9 and a learning rate of 0.01, chosen from hyperparameter tuning experiments that evaluate the percentage of molecules optimized toward the target property.</p>
<h3 id="optimization-targets">Optimization Targets</h3>
<p>Two extreme logP targets are used: $+6$ (high lipophilicity) and $-6$ (low lipophilicity). These values exceed the range of logP values in the QM9 dataset (minimum: $-2.19$, maximum: $3.08$), testing whether the model can extrapolate beyond the training distribution.</p>
<h2 id="distribution-shifts-and-interpretable-molecular-transformations">Distribution Shifts and Interpretable Molecular Transformations</h2>
<h3 id="distribution-level-results">Distribution-Level Results</h3>
<p>Applying deep dreaming to the full set of 10,000 molecules produces a clear shift in the logP distribution:</p>
<table>
  <thead>
      <tr>
          <th>Statistic</th>
          <th>QM9 Original</th>
          <th>Optimized (target +6)</th>
          <th>Optimized (target -6)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Mean logP</td>
          <td>0.3909</td>
          <td>1.8172</td>
          <td>-0.3360</td>
      </tr>
      <tr>
          <td>Min logP</td>
          <td>-2.1903</td>
          <td>-0.8240</td>
          <td>-2.452</td>
      </tr>
      <tr>
          <td>Max logP</td>
          <td>3.0786</td>
          <td>4.2442</td>
          <td>0.9018</td>
      </tr>
  </tbody>
</table>
<p>The optimized distributions extend beyond the original dataset&rsquo;s property range. The right-shifted distribution (target +6) produces molecules with logP values up to 4.24, exceeding the original maximum of 3.08. The left-shifted distribution (target -6) reaches -2.45, below the original minimum. This indicates that PASITHEA can generate molecules with properties outside the training data bounds.</p>
<p>Additionally, 97.2% of the generated molecules do not exist in the original training set, indicating that the network is not memorizing data but rather using structural features to guide optimization. Some generated molecules contain more heavy atoms than the QM9 maximum of 9, since the SELFIES string length allows for larger structures.</p>
<h3 id="molecule-level-interpretability">Molecule-Level Interpretability</h3>
<p>The stepwise molecular transformations reveal interpretable &ldquo;strategies&rdquo; the network employs:</p>
<ol>
<li>
<p><strong>Nitrogen appendage</strong>: When optimizing for lower logP, the network repeatedly appends nitrogen atoms to the molecule. The authors observe this as a consistent pattern across multiple test molecules, reflecting the known relationship between nitrogen content and reduced lipophilicity.</p>
</li>
<li>
<p><strong>Length modulation</strong>: When optimizing for higher logP, the network tends to increase molecular chain length (e.g., extending a carbon chain). When optimizing for lower logP, it shortens chains. This captures the intuition that larger, more carbon-heavy molecules tend to be more lipophilic.</p>
</li>
<li>
<p><strong>Bond order changes</strong>: The network replaces single bonds with double or triple bonds during optimization, suggesting that it has captured a relationship between bond order and logP.</p>
</li>
<li>
<p><strong>Consistency across trials</strong>: Because the input initialization includes random noise, repeated trials with the same molecule produce different transformation sequences. Despite this stochasticity, the network applies consistent strategies across trials (e.g., always shortening chains when optimizing toward lower logP), which suggests that it has learned genuine structure-property relationships.</p>
</li>
</ol>
<h3 id="thermodynamic-stability">Thermodynamic Stability</h3>
<p>The authors assess synthesizability by computing heats of formation using MOPAC2016 at the PM7 level of theory. Some optimization trajectories move toward thermodynamically stable molecules (negative heats of formation), while others produce less stable structures. The authors acknowledge this limitation and propose multi-objective optimization incorporating stability as a future direction.</p>
<h3 id="comparison-to-vaes">Comparison to VAEs</h3>
<p>The key distinction from VAEs is where gradient computation occurs. In VAEs, a latent space is learned through encoding and decoding, and property optimization happens in that latent space. In PASITHEA, gradients are computed directly with respect to the molecular representation (SELFIES one-hot encoding). The authors argue this makes the approach more interpretable, since we can probe what the network learned about molecular structure without the &ldquo;detour&rdquo; through a latent space.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors are forthright about the preliminary nature of these results:</p>
<ul>
<li>The method is demonstrated only on a small subset of QM9 with a single, computationally inexpensive property (logP).</li>
<li>The simple four-layer architecture may not scale to larger molecular spaces or more complex properties.</li>
<li>Generated molecules are not always thermodynamically stable, requiring additional optimization objectives.</li>
<li>The approach has not been benchmarked against established methods (VAEs, GANs, RL) on standard generative benchmarks.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>QM9 (random subset)</td>
          <td>10,000 molecules</td>
          <td>logP values computed via RDKit</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Prediction training</strong>: 4-layer fully connected NN, 500 nodes/layer, ReLU activation, MSE loss, Adam optimizer, LR $1 \times 10^{-6}$, ~1,500 epochs, 85/15 train/test split</li>
<li><strong>Inverse training</strong>: Frozen weights, Adam optimizer, LR 0.01, noise upper-bound 0.9, logP targets of +6 and -6</li>
<li><strong>Heats of formation</strong>: MOPAC2016, PM7 level, geometry optimization with eigenvector following (EF)</li>
</ul>
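<p>The inverse-training step can be illustrated with a toy example: the weights stay frozen and gradient descent updates only the input until the predicted property approaches the target. This is a rough sketch under stated assumptions, not the official implementation. It uses a two-layer bias-free ReLU network with hypothetical weights instead of the 4-layer, 500-node model, plain gradient descent instead of Adam, and a noise scheme (zeros of the one-hot vector replaced by uniform noise up to 0.9) that is an assumption about how the noise upper bound is applied.</p>

```python
import random


def noisy_one_hot(index, size, upper, rng):
    """One-hot vector whose zeros are replaced by uniform noise in [0, upper]."""
    return [1.0 if i == index else rng.uniform(0.0, upper) for i in range(size)]


def forward(x, W1, w2):
    """Tiny bias-free ReLU regressor; returns (prediction, hidden pre-activations)."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    pred = sum(w * max(0.0, p) for w, p in zip(w2, pre))
    return pred, pre


def input_gradient(x, target, W1, w2):
    """Gradient of (prediction - target)^2 with respect to the input vector."""
    pred, pre = forward(x, W1, w2)
    # Backpropagate through the output layer and the ReLU mask.
    upstream = [2.0 * (pred - target) * w if p > 0 else 0.0 for w, p in zip(w2, pre)]
    return [sum(u * row[j] for u, row in zip(upstream, W1)) for j in range(len(x))]


def dream(x, target, W1, w2, lr=0.01, steps=500):
    """Update only the input; the frozen weights W1 and w2 are never modified."""
    x = list(x)
    for _ in range(steps):
        grad = input_gradient(x, target, W1, w2)
        x = [xi - lr * g for xi, g in zip(x, grad)]
    return x


# Hypothetical frozen weights standing in for the trained property predictor.
W1 = [[1.0, 0.5, 0.2, 0.1], [0.3, 0.8, 0.4, 0.6]]
w2 = [1.0, 0.5]
x0 = noisy_one_hot(index=0, size=4, upper=0.9, rng=random.Random(0))
x_opt = dream(x0, target=6.0, W1=W1, w2=w2)
print(forward(x0, W1, w2)[0], forward(x_opt, W1, w2)[0])
```

<p>The toy weights are all positive so that no hidden unit switches off during optimization; a real predictor gives no such guarantee.</p>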
<h3 id="models">Models</h3>
<p>The architecture is a simple 4-layer MLP. No pre-trained weights are distributed, but the full code is available.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Novel molecules</td>
          <td>97.2%</td>
          <td>Generated molecules not in training set</td>
      </tr>
      <tr>
          <td>Max logP (target +6)</td>
          <td>4.2442</td>
          <td>Exceeds QM9 max of 3.0786</td>
      </tr>
      <tr>
          <td>Min logP (target -6)</td>
          <td>-2.452</td>
          <td>Below QM9 min of -2.1903</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/aspuru-guzik-group/Pasithea">Pasithea</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Shen, C., Krenn, M., Eppel, S., &amp; Aspuru-Guzik, A. (2021). Deep molecular dreaming: inverse machine learning for de-novo molecular design and interpretability with surjective representations. <em>Machine Learning: Science and Technology</em>, 2(3), 03LT02. <a href="https://doi.org/10.1088/2632-2153/ac09d6">https://doi.org/10.1088/2632-2153/ac09d6</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shen2021deep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Deep molecular dreaming: inverse machine learning for de-novo molecular design and interpretability with surjective representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shen, Cynthia and Krenn, Mario and Eppel, Sagi and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{03LT02}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1088/2632-2153/ac09d6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Machine Translation of Chemical Nomenclature</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/nmt-chemical-nomenclature-en-zh/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/nmt-chemical-nomenclature-en-zh/</guid><description>Xu et al. apply CNN and LSTM seq2seq models to translate chemical nomenclature between English and Chinese, outperforming rule-based tools.</description><content:encoded><![CDATA[<h2 id="a-method-for-neural-translation-of-chemical-names">A Method for Neural Translation of Chemical Names</h2>
<p>This is a <strong>Method</strong> paper that introduces deep learning approaches for translating chemical nomenclature between English and Chinese. The primary contribution is demonstrating that character-level sequence-to-sequence neural networks (both CNN-based and LSTM-based) can serve as viable alternatives to hand-crafted rule-based translation systems for chemical names. The work compares two neural architectures against an existing rule-based tool on bilingual chemical name datasets.</p>
<h2 id="bridging-the-english-chinese-chemical-nomenclature-gap">Bridging the English-Chinese Chemical Nomenclature Gap</h2>
<p>English and Chinese are the two most widely used languages for chemical nomenclature worldwide. Translation between them is important for chemical data processing, especially for converting Chinese chemical names extracted via named entity recognition into English names that existing name-to-structure tools can parse. Rule-based translation between these languages faces considerable challenges:</p>
<ol>
<li>Chinese chemical names lack word boundaries (no spaces), making segmentation difficult.</li>
<li>Word order is often reversed between English and Chinese chemical names (e.g., &ldquo;ethyl acetate&rdquo; maps to characters meaning &ldquo;acetate-ethyl&rdquo; in Chinese).</li>
<li>The same English morpheme can map to different Chinese characters depending on chemical context (e.g., &ldquo;ethyl&rdquo; translates differently in &ldquo;ethyl acetate&rdquo; vs. &ldquo;ethyl alcohol&rdquo;).</li>
<li>Trivial names, especially for natural products, follow irregular translation patterns or are transliterations.</li>
</ol>
<p>Building comprehensive rule sets requires a formally trained chemist fluent in both languages, making rule-based approaches expensive and fragile.</p>
<h2 id="character-level-sequence-to-sequence-translation">Character-Level Sequence-to-Sequence Translation</h2>
<p>The core idea is to treat chemical name translation as a character-level machine translation task, applying encoder-decoder architectures with attention mechanisms. Two architectures are proposed:</p>
<p><strong>CNN-based architecture</strong>: Three 1D convolutional layers encode the input character sequence. A decoder with three 1D convolutional layers processes the target sequence offset by one timestep, combined with attention mechanism layers that connect encoder and decoder outputs. Two additional 1D convolutional layers produce the final decoded output sequence.</p>
<p><strong>LSTM-based architecture</strong>: An LSTM encoder converts the input sequence into two state vectors. An LSTM decoder is trained with teacher forcing, using the encoder&rsquo;s state vectors as its initial state, and generating the target sequence offset by one timestep.</p>
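<p>The &ldquo;offset by one timestep&rdquo; relationship used in teacher forcing can be shown with plain strings. The following is a minimal sketch; the tab and newline start/end markers are an assumption borrowed from common character-level seq2seq tutorials, not something the paper is confirmed to use.</p>

```python
START, END = "\t", "\n"  # hypothetical start and end markers


def teacher_forcing_pair(target_name):
    """Return (decoder_input, decoder_target), shifted by one character."""
    wrapped = START + target_name + END
    decoder_input = wrapped[:-1]   # start marker followed by the name
    decoder_target = wrapped[1:]   # the name followed by the end marker
    return decoder_input, decoder_target


dec_in, dec_out = teacher_forcing_pair("ethanol")
print(repr(dec_in), repr(dec_out))
```

<p>At each position the decoder sees the true previous character and is trained to predict the next one, which is what makes the training signal independent of its own earlier mistakes.</p>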
<p>Both models operate at the character level. Input chemical name strings are transformed into embedding vectors, with the vocabulary size equal to the number of unique characters in the respective language (100 unique characters for English names, 2,056 unique characters for Chinese names).</p>
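<p>A character-level vocabulary of this kind can be built directly from the training strings. The sketch below is illustrative only; reserving index 0 for padding and the fixed <code>max_len</code> are assumptions, and the two names are toy inputs rather than corpus data.</p>

```python
def build_vocab(names):
    """Map each unique character to an integer index (0 is reserved for padding)."""
    chars = sorted({ch for name in names for ch in name})
    return {ch: i + 1 for i, ch in enumerate(chars)}


def encode(name, vocab, max_len):
    """Convert a name to a fixed-length list of indices, padded with 0."""
    ids = [vocab[ch] for ch in name[:max_len]]
    return ids + [0] * (max_len - len(ids))


english_names = ["ethanol", "ethyl acetate"]  # toy examples
vocab = build_vocab(english_names)
print(len(vocab))
print(encode("ethanol", vocab, 10))
```

<p>Applied to the full corpus, the same procedure would yield the vocabulary sizes quoted above; the resulting indices are what an embedding layer consumes.</p>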
<h2 id="experimental-setup-and-comparison-with-rule-based-tool">Experimental Setup and Comparison with Rule-Based Tool</h2>
<h3 id="datasets">Datasets</h3>
<p>The authors built two directional datasets from a manually curated corpus of scientific literature maintained at their institution:</p>
<ul>
<li><strong>En2Ch (English to Chinese)</strong>: 30,394 name pairs after deduplication</li>
<li><strong>Ch2En (Chinese to English)</strong>: 37,207 name pairs after deduplication</li>
</ul>
<p>The datasets range from systematic compound names to trivial names. For names with multiple valid translations, the most commonly used translation was selected. Each dataset was split 80/20 for training and validation.</p>
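<p>As a quick arithmetic check, an 80/20 split of the two dataset sizes gives held-out sets of 6,079 and 7,441 records, matching the test-set sizes listed in the reproducibility details. The rounding rule in this sketch is an assumption.</p>

```python
def split_sizes(n_pairs, val_fraction=0.2):
    """Return (train, validation) sizes, rounding the validation share."""
    n_val = round(n_pairs * val_fraction)
    return n_pairs - n_val, n_val


for name, n in [("En2Ch", 30394), ("Ch2En", 37207)]:
    print(name, split_sizes(n))
```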
<h3 id="model-configuration">Model Configuration</h3>
<p>Both neural network models used the following hyperparameters:</p>
<ul>
<li>Batch size: 64</li>
<li>Epochs: 100</li>
<li>Latent dimensionality: 256 (encoding and decoding space)</li>
<li>Implementation: Python 3.7 with Keras 2.3 and TensorFlow backend</li>
</ul>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<p>The models were evaluated on five metrics across both translation directions:</p>
<ul>
<li><strong>Success Rate</strong>: Percentage of inputs that produced any output</li>
<li><strong>String Matching Accuracy</strong>: Exact match with the single target name</li>
<li><strong>Data Matching Accuracy</strong>: Exact match allowing any valid translation from the corpus</li>
<li><strong>Manual Spot Check</strong>: Blind evaluation of 100 random samples per approach</li>
<li><strong>Running Time</strong>: Wall-clock time on the same hardware</li>
</ul>
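<p>The three automatic metrics differ only in what counts as a hit. The functions below are a minimal sketch, assuming that a failed translation is represented as <code>None</code> and that failures stay in the denominator of the accuracy metrics; both choices, and the toy data, are assumptions rather than details taken from the paper.</p>

```python
def success_rate(outputs):
    """Share of inputs for which the system returned a non-empty output."""
    return sum(1 for out in outputs if out) / len(outputs)


def string_match_accuracy(outputs, targets):
    """Exact match against the single reference translation."""
    return sum(o == t for o, t in zip(outputs, targets)) / len(targets)


def data_match_accuracy(outputs, valid_translations):
    """Exact match against any translation recorded for that input."""
    hits = sum(o in valid for o, valid in zip(outputs, valid_translations))
    return hits / len(outputs)


# Toy predictions (hypothetical); None marks an input with no output.
outputs = ["ethanol", "ethyl alcohol", None, "aceton"]
targets = ["ethanol", "ethanol", "benzene", "acetone"]
valid = [
    {"ethanol", "ethyl alcohol"},
    {"ethanol", "ethyl alcohol"},
    {"benzene"},
    {"acetone"},
]
print(success_rate(outputs))
print(string_match_accuracy(outputs, targets))
print(data_match_accuracy(outputs, valid))
```

<p>Data matching can never score below string matching, because the single reference is always one of the valid translations.</p>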
<h3 id="baseline">Baseline</h3>
<p>The rule-based comparison system operates in three steps: disassemble the input name into word fragments, translate each fragment, and reassemble into the target language. This tool had been deployed as an online service with over one million uses at the time of publication.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>CNN</th>
          <th>LSTM</th>
          <th>Rule-based</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success Rate En2Ch</td>
          <td>100%</td>
          <td>100%</td>
          <td>75.97%</td>
      </tr>
      <tr>
          <td>Success Rate Ch2En</td>
          <td>100%</td>
          <td>100%</td>
          <td>59.90%</td>
      </tr>
      <tr>
          <td>String Match En2Ch</td>
          <td>82.92%</td>
          <td>89.64%</td>
          <td>39.81%</td>
      </tr>
      <tr>
          <td>String Match Ch2En</td>
          <td>78.11%</td>
          <td>55.44%</td>
          <td>43.77%</td>
      </tr>
      <tr>
          <td>Data Match En2Ch</td>
          <td>84.44%</td>
          <td>90.82%</td>
          <td>45.15%</td>
      </tr>
      <tr>
          <td>Data Match Ch2En</td>
          <td>80.22%</td>
          <td>57.40%</td>
          <td>44.91%</td>
      </tr>
      <tr>
          <td>Manual Check En2Ch</td>
          <td>90.00%</td>
          <td>89.00%</td>
          <td>80.00%</td>
      </tr>
      <tr>
          <td>Manual Check Ch2En</td>
          <td>82.00%</td>
          <td>61.00%</td>
          <td>78.00%</td>
      </tr>
      <tr>
          <td>Time En2Ch (s)</td>
          <td>1423</td>
          <td>190</td>
          <td>288</td>
      </tr>
      <tr>
          <td>Time Ch2En (s)</td>
          <td>1876</td>
          <td>303</td>
          <td>322</td>
      </tr>
  </tbody>
</table>
<p>Both neural approaches achieved a 100% success rate (always producing output), while the rule-based tool failed on roughly 24% and 40% of inputs for En2Ch and Ch2En, respectively. The rule-based tool&rsquo;s failures were concentrated on Chinese names lacking word boundaries and on trivial names of natural products.</p>
<p>For English-to-Chinese translation, LSTM performed best at 89.64% string matching accuracy (90.82% data matching), followed by CNN at 82.92%. For Chinese-to-English, CNN substantially outperformed LSTM (78.11% vs. 55.44% string matching), suggesting that LSTM had difficulty with long-term dependencies in Chinese character sequences. The authors observed that many LSTM errors appeared at the ends of chemical names.</p>
<h3 id="analysis-by-name-type">Analysis by Name Type</h3>
<p>The CNN-based approach outperformed LSTM on CAS names (80% vs. 52% in manual checks) and was more robust for longer names. The rule-based tool showed consistent performance regardless of name length, suggesting it was more suited to regular systematic names but struggled with the diversity of real-world chemical nomenclature.</p>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Performance depends heavily on training data quality and quantity.</li>
<li>Neither neural approach was validated on an external test set outside the institution&rsquo;s corpus.</li>
<li>The CNN model was considerably slower than the other two approaches (roughly 5x to 7.5x, based on the running times reported above).</li>
<li>No comparison against modern transformer-based NMT architectures (the study predates widespread adoption of transformers for this task).</li>
<li>The dataset is relatively small by modern NMT standards (30-37K pairs).</li>
<li>The authors noted that some neural translations were actually better than the target labels, suggesting the evaluation metrics understate true performance.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest that combining CNN and LSTM architectures could yield further improvements, and that the approach has practical applications in scientific publishing (Chinese journals requiring English abstracts) and chemical database interoperability.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Validation (En2Ch)</td>
          <td>Curated bilingual corpus</td>
          <td>30,394 pairs</td>
          <td>80/20 split, from SIOC chemical data system</td>
      </tr>
      <tr>
          <td>Training/Validation (Ch2En)</td>
          <td>Curated bilingual corpus</td>
          <td>37,207 pairs</td>
          <td>80/20 split, from SIOC chemical data system</td>
      </tr>
      <tr>
          <td>Testing (En2Ch)</td>
          <td>Held-out validation split</td>
          <td>6,079 records</td>
          <td>Same source</td>
      </tr>
      <tr>
          <td>Testing (Ch2En)</td>
          <td>Held-out validation split</td>
          <td>7,441 records</td>
          <td>Same source</td>
      </tr>
  </tbody>
</table>
<p>Training data, Python code for both models, and result data are provided as supplementary files with the paper.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Character-level CNN encoder-decoder with attention (3+3+2 conv layers)</li>
<li>Character-level LSTM encoder-decoder with teacher forcing</li>
<li>Batch size: 64, epochs: 100, latent dim: 256</li>
</ul>
<h3 id="models">Models</h3>
<p>Both models are implemented in Python 3.7 with Keras 2.3 / TensorFlow. No pre-trained weights are released separately, but the training code is provided as supplementary material.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Value (En2Ch)</th>
          <th>Best Value (Ch2En)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Success Rate</td>
          <td>100% (both DL)</td>
          <td>100% (both DL)</td>
          <td>Rule-based: 75.97% / 59.90%</td>
      </tr>
      <tr>
          <td>String Matching</td>
          <td>89.64% (LSTM)</td>
          <td>78.11% (CNN)</td>
          <td>Best neural model per direction</td>
      </tr>
      <tr>
          <td>Data Matching</td>
          <td>90.82% (LSTM)</td>
          <td>80.22% (CNN)</td>
          <td>Allows multiple valid translations</td>
      </tr>
      <tr>
          <td>Manual Spot Check</td>
          <td>90.00% (CNN)</td>
          <td>82.00% (CNN)</td>
          <td>Blind evaluation of 100 samples</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Running times are reported, but hardware details are not provided.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.1186/s13321-020-00457-0">Supplementary files</a></td>
          <td>Code + Data</td>
          <td>CC-BY 4.0</td>
          <td>Training data, CNN/LSTM code, results (Additional files 1-6)</td>
      </tr>
      <tr>
          <td><a href="https://www.organchem.csdb.cn/translate">SIOC Translation Tool</a></td>
          <td>Other</td>
          <td>Not specified</td>
          <td>Rule-based baseline tool, online service</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, T., Chen, W., Zhou, J., Dai, J., Li, Y., &amp; Zhao, Y. (2020). Neural machine translation of chemical nomenclature between English and Chinese. <em>Journal of Cheminformatics</em>, 12, 50. <a href="https://doi.org/10.1186/s13321-020-00457-0">https://doi.org/10.1186/s13321-020-00457-0</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{xu2020neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural machine translation of chemical nomenclature between English and Chinese}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xu, Tingjun and Chen, Weiming and Zhou, Junhong and Dai, Jingfang and Li, Yingyong and Zhao, Yingli}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{50}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00457-0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>nach0: A Multimodal Chemical and NLP Foundation Model</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/nach0-multimodal-chemical-language-model/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/multimodal/nach0-multimodal-chemical-language-model/</guid><description>nach0 is a T5-based encoder-decoder model pre-trained on SMILES, scientific text, and patents, then instruction-tuned for chemical and NLP tasks.</description><content:encoded><![CDATA[<h2 id="a-multi-domain-encoder-decoder-for-chemistry-and-nlp">A Multi-Domain Encoder-Decoder for Chemistry and NLP</h2>
<p>nach0 is a <strong>Method</strong> paper that introduces a unified encoder-decoder foundation model capable of handling both natural language processing (NLP) tasks and chemistry tasks within a single architecture. The primary contribution is demonstrating that a T5-based model pre-trained on scientific text, patents, and <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> molecular strings can be instruction-tuned to perform molecular property prediction, reaction prediction, molecular generation, named entity recognition, question answering, and cross-domain translation (text-to-molecule and molecule-to-text) simultaneously. The model is available in base (250M parameters) and large (780M parameters) configurations.</p>
<h2 id="bridging-chemical-and-linguistic-representations">Bridging Chemical and Linguistic Representations</h2>
<p>Most existing biomedical language models (BioBERT, SciFive, BioMegatron) are trained exclusively on natural language text from sources like PubMed, omitting chemical structure information encoded in SMILES strings. Conversely, chemistry-specific models trained on SMILES data often lack the ability to process natural language instructions or perform NLP tasks. Models like <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a> and MolT5 attempted to bridge this gap by training on both natural language and chemical data, but they were not fine-tuned on a diverse set of chemical tasks using instruction tuning in a multi-task fashion.</p>
<p>nach0 addresses this by creating a shared representation space for both modalities and fine-tuning across a comprehensive set of tasks spanning three domains: NLP-only tasks, chemistry-only tasks, and cross-domain tasks that require translating between natural language and molecular representations.</p>
<h2 id="unified-text-to-text-framework-with-smiles-tokenization">Unified Text-to-Text Framework with SMILES Tokenization</h2>
<p>The core innovation in nach0 is formulating all chemical and linguistic tasks as text-to-text problems within a single encoder-decoder transformer, combined with a specialized SMILES tokenization strategy.</p>
<h3 id="smiles-token-integration">SMILES Token Integration</h3>
<p>Rather than treating SMILES as plain text, nach0 extends the T5 vocabulary with dedicated SMILES tokens. Each SMILES token is annotated with special symbols in the format <code>&lt;sm_{token}&gt;</code>, creating a distinct vocabulary space for molecular representations while preserving the natural language vocabulary from FLAN-T5. The embedding matrix is initialized by reusing learned embeddings from the pre-trained model for original tokens, with new chemical tokens initialized from the first embeddings.</p>
<h3 id="architecture">Architecture</h3>
<p>Both model sizes use the standard <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a> encoder-decoder architecture:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Parameters</th>
          <th>Layers</th>
          <th>Hidden Size</th>
          <th>FFN Size</th>
          <th>Attention Heads</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Base</td>
          <td>250M</td>
          <td>12</td>
          <td>768</td>
          <td>3072</td>
          <td>12</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>780M</td>
          <td>24</td>
          <td>1024</td>
          <td>4096</td>
          <td>16</td>
      </tr>
  </tbody>
</table>
<h3 id="pre-training-data">Pre-training Data</h3>
<p>The model is pre-trained with a language modeling objective on three data sources:</p>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>Documents</th>
          <th>Tokens</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PubMed abstracts (chemistry-filtered)</td>
          <td>13M</td>
          <td>355M</td>
      </tr>
      <tr>
          <td>USPTO patent descriptions</td>
          <td>119K</td>
          <td>2.9B</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/datasets/zinc-22/">ZINC</a> molecular database</td>
          <td>~100M</td>
          <td>4.7B</td>
      </tr>
  </tbody>
</table>
<h3 id="instruction-tuning">Instruction Tuning</h3>
<p>Following the approach of Raffel et al. and Chung et al., nach0 uses natural language prompts to formulate each task. For example, a retrosynthesis task might be phrased as &ldquo;What reactants could be used to synthesize [SMILES]?&rdquo; and a property prediction task as &ldquo;Can [SMILES] penetrate the <a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">BBB</a>?&rdquo; This enables multi-task training across all domains with a single loss function and shared hyperparameters.</p>
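<p>Prompt-based task formulation amounts to filling a natural language template with the molecule string. The sketch below reuses the two example phrasings quoted above; the task keys and the helper name are hypothetical and do not reflect the model&rsquo;s actual prompt set.</p>

```python
# Templates taken from the two example prompts; the dictionary keys are hypothetical.
TEMPLATES = {
    "retrosynthesis": "What reactants could be used to synthesize {smiles}?",
    "bbb_penetration": "Can {smiles} penetrate the BBB?",
}


def build_prompt(task, smiles):
    """Fill the template for the given task with a SMILES string."""
    return TEMPLATES[task].format(smiles=smiles)


print(build_prompt("retrosynthesis", "CC(=O)OCC"))
print(build_prompt("bbb_penetration", "CCO"))
```

<p>Because every task reduces to a text input and a text output, the same model and loss can be shared across all of them.</p>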
<p>Training uses a batch size of 1024, learning rate of $1 \times 10^{-4}$, and weight decay of 0.01. Pre-training runs for one epoch, and fine-tuning for 10 epochs. Data mixing follows the examples-proportional mixing strategy from T5.</p>
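<p>In the T5 formulation of examples-proportional mixing, task $m$ with $e_m$ examples is sampled at rate $r_m = \min(e_m, K) / \sum_n \min(e_n, K)$, where $K$ is an artificial dataset-size limit that keeps very large tasks from dominating. The sketch below illustrates the formula; the task names, sizes, and the value of $K$ are hypothetical, since the limit used for nach0 is not stated here.</p>

```python
def mixing_rates(dataset_sizes, limit):
    """Sampling rate per task: min(size, limit) / sum of min(size, limit)."""
    capped = {task: min(n, limit) for task, n in dataset_sizes.items()}
    total = sum(capped.values())
    return {task: c / total for task, c in capped.items()}


sizes = {"task_a": 1000, "task_b": 50000, "task_c": 4000}  # hypothetical sizes
rates = mixing_rates(sizes, limit=5000)
print(rates)
```

<p>Without the cap, the largest task here would receive about 91% of the samples; with it, the share drops to half.</p>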
<h2 id="multi-task-evaluation-across-nlp-and-chemistry-benchmarks">Multi-Task Evaluation Across NLP and Chemistry Benchmarks</h2>
<p>nach0 is evaluated on a comprehensive set of benchmarks spanning three task categories.</p>
<h3 id="task-categories">Task Categories</h3>
<p><strong>NLP tasks</strong>: Named entity recognition (BC5CDR-Chemical, BC5CDR-Disease, NCBI-Disease, BC2GM, JNLPBA), PICO extraction (EBM PICO), textual entailment (MedNLI, SciTail), relation extraction (ChemProt, DDI, GAD), sentence similarity (BIOSSES), document classification (HoC), and question answering (PubMedQA, BioASQ, MedMCQA, MMLU).</p>
<p><strong>Chemistry tasks</strong>: Molecular property prediction (ESOL, FreeSolv, Lipophilicity, BBBP, HIV, BACE from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>; <a href="/notes/chemistry/datasets/qm9/">QM9</a> from Mol-Instructions), molecular generation (<a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>), forward reaction prediction, reagent prediction, and <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthesis</a> (from Mol-Instructions/USPTO).</p>
<p><strong>Cross-domain tasks</strong>: Description-guided molecule design and molecular description generation (from Mol-Instructions).</p>
<h3 id="baselines">Baselines</h3>
<p>nach0 is compared against FLAN-T5 (250M), SciFive (220M), and MolT5 (220M), all trained in multi-task fashion.</p>
<h3 id="key-results">Key Results</h3>
<p>On chemistry and cross-domain tasks, nach0 base consistently outperforms all base-sized baselines. Selected highlights from Table 3:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>MolT5</th>
          <th>SciFive</th>
          <th>FLAN</th>
          <th>nach0 Base</th>
          <th>nach0 Large</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Forward reaction</td>
          <td>Acc@1</td>
          <td>27.0%</td>
          <td>60.0%</td>
          <td>59.0%</td>
          <td>88.0%</td>
          <td>89.9%</td>
      </tr>
      <tr>
          <td>Retrosynthesis</td>
          <td>Acc@1</td>
          <td>15.0%</td>
          <td>31.0%</td>
          <td>31.0%</td>
          <td>53.0%</td>
          <td>56.3%</td>
      </tr>
      <tr>
          <td>Reagent prediction</td>
          <td>Acc@1</td>
          <td>1.1%</td>
          <td>3.8%</td>
          <td>4.0%</td>
          <td>6.3%</td>
          <td>13.1%</td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>BA</td>
          <td>0.58</td>
          <td>0.65</td>
          <td>0.65</td>
          <td>0.74</td>
          <td>0.71</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>BA</td>
          <td>0.55</td>
          <td>0.66</td>
          <td>0.60</td>
          <td>0.67</td>
          <td>0.68</td>
      </tr>
      <tr>
          <td>HFE (FreeSolv)</td>
          <td>$R^2$</td>
          <td>-0.36</td>
          <td>0.51</td>
          <td>0.55</td>
          <td>0.77</td>
          <td>0.78</td>
      </tr>
      <tr>
          <td>MOSES (FCD)</td>
          <td>FCD/Test (lower is better)</td>
          <td>0.521</td>
          <td>0.578</td>
          <td>0.529</td>
          <td>0.311</td>
          <td>0.304</td>
      </tr>
      <tr>
          <td>Description-guided mol. design</td>
          <td>BLEU-2</td>
          <td>30.3%</td>
          <td>44.2%</td>
          <td>43.6%</td>
          <td>49.0%</td>
          <td>48.8%</td>
      </tr>
      <tr>
          <td>Mol. description gen.</td>
          <td>BLEU-2</td>
          <td>35.6%</td>
          <td>39.6%</td>
          <td>38.6%</td>
          <td>43.9%</td>
          <td>41.7%</td>
      </tr>
  </tbody>
</table>
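<p>Two of the metrics in the table are easy to misread. Assuming &ldquo;BA&rdquo; denotes balanced accuracy (the mean of per-class recall) and the FreeSolv score is the coefficient of determination, the minimal sketch below shows how both are computed and why the latter can be negative, as it is for MolT5: a negative value means the predictions fit worse than always predicting the mean. The toy labels are hypothetical.</p>

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall for binary labels 0 and 1."""
    recalls = []
    for cls in (0, 1):
        idx = [i for i, y in enumerate(y_true) if y == cls]
        recalls.append(sum(y_pred[i] == cls for i in idx) / len(idx))
    return sum(recalls) / 2


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


print(balanced_accuracy([1, 1, 1, 0], [1, 1, 0, 0]))
print(r_squared([1.0, 2.0, 3.0], [2.0, 2.0, 2.0]))  # predicting the mean
print(r_squared([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]))  # worse than the mean
```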
<p>On NLP tasks, nach0 base performs comparably to FLAN base, with the two models trading wins across different tasks. nach0 large improves substantially over nach0 base on most tasks.</p>
<h3 id="ablation-study">Ablation Study</h3>
<p>The ablation study (Table 4) examines the impact of multi-task training across chemical task groups. Key findings:</p>
<ul>
<li>nach0 trained on all chemical tasks jointly outperforms models trained on individual task groups (prediction-only, reaction-only, or generation-only) on the total set of metrics.</li>
<li>The joint model shows lower novelty scores on MOSES compared to the generation-only model, but this reflects less overfitting to training data rather than worse performance.</li>
<li>nach0 consistently outperforms MolT5 across all chemical task configurations, demonstrating the benefit of pre-training on both natural language and chemical data with specialized SMILES tokens.</li>
</ul>
<h3 id="case-studies">Case Studies</h3>
<p>Two applied case studies demonstrate nach0 in drug discovery scenarios:</p>
<ol>
<li>
<p><strong>End-to-end drug discovery for <a href="https://en.wikipedia.org/wiki/Diabetes">diabetes mellitus</a></strong>: Using a sequence of prompts, nach0 identifies biological targets, analyzes mechanisms of action, generates molecular structures, proposes synthesis routes, and predicts molecular properties.</p>
</li>
<li>
<p><strong><a href="https://en.wikipedia.org/wiki/Janus_kinase_3">JAK3</a> inhibitor generation with Chemistry42</strong>: nach0 replaces 42 specialized generative models in Insilico Medicine&rsquo;s Chemistry42 platform. In 45 minutes, nach0 generates 8 molecules satisfying all 2D and 3D requirements (hinge binding, active site binding), compared to a 0.04% discovery rate from a combinatorial generator over 24 hours. Chemistry42&rsquo;s full pipeline (72 hours) still produces better structures since it uses reinforcement learning feedback and explicit structural constraints.</p>
</li>
</ol>
<h3 id="comparison-with-chatgpt">Comparison with ChatGPT</h3>
<p>On a subset evaluation, fine-tuned nach0 base outperforms GPT-3.5-turbo on all tested tasks: EBM PICO (F1: 67.6% vs. 64.4%), MedMCQA-Open (BLEU-2: 6.3% vs. 1.7%), and molecular description generation (BLEU-2: 42.8% vs. 2.2%).</p>
<h2 id="competitive-multi-task-performance-with-clear-limitations">Competitive Multi-Task Performance with Clear Limitations</h2>
<p>nach0 demonstrates that a single encoder-decoder model can achieve competitive results across both chemical and NLP tasks when pre-trained on mixed-modality data and then instruction-tuned. The model&rsquo;s strongest advantages appear on chemistry tasks (reaction prediction, property prediction, molecular generation), where specialized SMILES tokenization and chemical pre-training provide clear benefits over general-purpose models of similar scale.</p>
<h3 id="limitations-acknowledged-by-the-authors">Limitations Acknowledged by the Authors</h3>
<ol>
<li>
<p><strong>Not at chemist expert level</strong>: Human evaluations indicate the model does not match domain expert performance. Key gaps include chemical reasoning, knowledge alignment with domain-specific knowledge graphs, and the ability to learn from expert feedback.</p>
</li>
<li>
<p><strong>SMILES-only molecular representation</strong>: The model lacks 3D geometric information. SMILES notation is not one-to-one with molecular structures, and the model does not incorporate molecular graphs or 3D coordinates. The authors suggest <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> as a potential alternative representation.</p>
</li>
<li>
<p><strong>Prompt sensitivity</strong>: Performance depends on prompt quality and specificity. Over-reliance on domain-specific prompts may limit response diversity.</p>
</li>
<li>
<p><strong>Limited chemical diversity</strong>: Cross-domain datasets from Mol-Instructions primarily cover known drugs and chemical probes from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>, representing only a fraction of predicted chemical space.</p>
</li>
</ol>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose extending nach0 with protein sequence modalities (using <a href="/notes/chemistry/molecular-representations/notations/group-selfies-fragment-molecular-representation/">Group SELFIES</a>), expanding zero-shot evaluation capabilities, and integrating knowledge graph information through self-supervised approaches.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training (text)</td>
          <td>PubMed abstracts</td>
          <td>13M docs, 355M tokens</td>
          <td>Filtered for chemistry-related content</td>
      </tr>
      <tr>
          <td>Pre-training (text)</td>
          <td>USPTO patents</td>
          <td>119K docs, 2.9B tokens</td>
          <td>Patent descriptions</td>
      </tr>
      <tr>
          <td>Pre-training (chemical)</td>
          <td>ZINC</td>
          <td>~100M docs, 4.7B tokens</td>
          <td>Molecular SMILES strings</td>
      </tr>
      <tr>
          <td>Fine-tuning (NLP)</td>
          <td>17 NLP datasets</td>
          <td>Varies</td>
          <td>See Table 1 in paper</td>
      </tr>
      <tr>
          <td>Fine-tuning (chemistry)</td>
          <td>MoleculeNet, MOSES, Mol-Instructions</td>
          <td>Varies</td>
          <td>Predefined or random splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: T5 encoder-decoder (base: 250M, large: 780M parameters)</li>
<li>Pre-training objective: Language modeling (masked span prediction)</li>
<li>Fine-tuning: Multi-task instruction tuning with examples-proportional mixing</li>
<li>Hyperparameters: batch size 1024, learning rate $1 \times 10^{-4}$, weight decay 0.01</li>
<li>Pre-training: 1 epoch; fine-tuning: 10 epochs</li>
</ul>
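<p>As a rough illustration of examples-proportional mixing, the sketch below samples each task in proportion to its dataset size. The optional <code>cap</code> argument follows the artificial size limit described in the T5 paper; whether nach0 uses a cap, and the task names and sizes shown, are assumptions made only for this example.</p>

```python
def mixing_rates(dataset_sizes, cap=None):
    """Return per-task sampling probabilities proportional to dataset size.

    `cap` is an optional artificial size limit (as in the T5 paper); whether
    nach0 applies one, and its value, is an assumption of this sketch.
    """
    capped = {
        name: size if cap is None else min(size, cap)
        for name, size in dataset_sizes.items()
    }
    total = sum(capped.values())
    return {name: size / total for name, size in capped.items()}


# Hypothetical task sizes, for illustration only.
sizes = {"task_a": 1000, "task_b": 3000, "task_c": 6000}
rates = mixing_rates(sizes)                    # purely proportional
capped_rates = mixing_rates(sizes, cap=2000)   # large tasks are limited
```

<p>With a cap, the largest tasks no longer dominate the mixture, which is the usual motivation for this scheme in multi-task fine-tuning.</p>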
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/insilicomedicine/nach0_base">nach0 Base (HuggingFace)</a></td>
          <td>Model</td>
          <td>CC-BY-NC-4.0</td>
          <td>250M parameter encoder-decoder</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/insilicomedicine/nach0_large">nach0 Large (HuggingFace)</a></td>
          <td>Model</td>
          <td>CC-BY-NC-4.0</td>
          <td>780M parameter encoder-decoder</td>
      </tr>
      <tr>
          <td><a href="https://github.com/insilicomedicine/nach0">nach0 GitHub Repository</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Training and inference code</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation spans 17+ NLP benchmarks and 10+ chemistry benchmarks. Metrics include F1 (NER, RE, classification), accuracy (QA, entailment, reaction prediction), balanced accuracy (molecular property classification), R2/RMSE (regression), BLEU-2 (generation), and FCD/SNN/validity/novelty (molecular generation via MOSES).</p>
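<p>Balanced accuracy, used above for molecular property classification, is the mean of per-class recall. The following is a minimal pure-Python sketch with made-up labels; it is not the evaluation code from the paper.</p>

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so each class counts equally."""
    recalls = []
    for cls in sorted(set(y_true)):
        idx = [i for i, y in enumerate(y_true) if y == cls]
        hits = sum(1 for i in idx if y_pred[i] == cls)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


# Made-up labels: four positives (three recovered), two negatives (one recovered).
score = balanced_accuracy([1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 1])
```

<p>Unlike plain accuracy, this metric does not reward a model for always predicting the majority class on imbalanced datasets.</p>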
<h3 id="hardware">Hardware</h3>
<ul>
<li>Base models: NVIDIA A4000 and A5000 GPUs</li>
<li>Large models: NVIDIA DGX cloud platform</li>
<li>Training used tensor and pipeline parallelism via NeMo toolkit</li>
<li>Specific GPU counts and training times not reported</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Livne, M., Miftahutdinov, Z., Tutubalina, E., Kuznetsov, M., Polykovskiy, D., Brundyn, A., Jhunjhunwala, A., Costa, A., Aliper, A., Aspuru-Guzik, A., &amp; Zhavoronkov, A. (2024). nach0: Multimodal Natural and Chemical Languages Foundation Model. <em>Chemical Science</em>, 15(22), 8380-8389. <a href="https://doi.org/10.1039/D4SC00966E">https://doi.org/10.1039/D4SC00966E</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{livne2024nach0,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{nach0: multimodal natural and chemical languages foundation model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Livne, Micha and Miftahutdinov, Zulfat and Tutubalina, Elena and Kuznetsov, Maksim and Polykovskiy, Daniil and Brundyn, Annika and Jhunjhunwala, Aastha and Costa, Anthony and Aliper, Alex and Aspuru-Guzik, Al{\&#39;a}n and Zhavoronkov, Alex}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{8380--8389}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D4SC00966E}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolPMoFiT: Inductive Transfer Learning for QSAR</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/</guid><description>MolPMoFiT adapts ULMFiT for QSAR by pre-training an LSTM language model on 1M ChEMBL SMILES and fine-tuning on small molecular property datasets.</description><content:encoded><![CDATA[<h2 id="transfer-learning-meets-molecular-property-prediction">Transfer Learning Meets Molecular Property Prediction</h2>
<p>This is a <strong>Method</strong> paper that introduces MolPMoFiT (Molecular Prediction Model Fine-Tuning), a transfer learning approach for <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSPR/QSAR</a> modeling. The primary contribution is adapting the ULMFiT framework from NLP to molecular property prediction by treating <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a> as a chemical language. A general-purpose molecular structure prediction model (MSPM) is pre-trained on one million <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> molecules via self-supervised next-token prediction, then fine-tuned for specific QSAR endpoints. The approach achieves competitive or superior results to graph neural networks and descriptor-based methods across four benchmark datasets, with particular benefits for small datasets.</p>
<h2 id="the-small-data-problem-in-qsar-modeling">The Small Data Problem in QSAR Modeling</h2>
<p>Deep learning models for molecular property prediction typically require large labeled training sets to learn useful structural representations. While methods like graph convolutional neural networks and SMILES-based models have achieved strong results on well-studied endpoints, they must be trained from scratch for each new task. This presents a challenge for small chemical datasets with limited labeled data, which remain common in drug discovery for specialized endpoints like <a href="https://en.wikipedia.org/wiki/Allosteric_regulation">allosteric inhibition</a>, renal clearance, and inhibitor residence times.</p>
<p>Transfer learning had already shown transformative impact in computer vision (ImageNet pre-training) and NLP (ELMo, BERT, ULMFiT). In chemistry, prior transfer learning efforts included ChemNet (supervised pre-training on computed descriptors), <a href="/notes/chemistry/molecular-representations/encoders/mol2vec-unsupervised-chemical-intuition/">Mol2vec</a> (unsupervised substructure embeddings), and pre-trained graph neural networks. However, a systematic application of the ULMFiT self-supervised pre-training pipeline to SMILES-based molecular models had not been explored. MolPMoFiT fills this gap by treating the vast corpus of unlabeled molecular structures as a self-supervised training signal, analogous to how language models learn from unlabeled text.</p>
<h2 id="core-innovation-ulmfit-adapted-for-smiles">Core Innovation: ULMFiT Adapted for SMILES</h2>
<p>MolPMoFiT adapts ULMFiT&rsquo;s three-stage transfer learning pipeline to molecular property prediction:</p>
<p><strong>Stage 1: General-Domain MSPM Pre-training.</strong> A molecular structure prediction model is trained on one million curated ChEMBL molecules to predict the next token in a SMILES string. This is purely self-supervised: the SMILES string provides its own labels. The model learns general chemical syntax and structural patterns.</p>
<p><strong>Stage 2: Task-Specific MSPM Fine-tuning (Optional).</strong> The general MSPM is further fine-tuned on the unlabeled SMILES of the target task dataset. This adapts the language model to the specific chemical distribution of interest (e.g., HIV inhibitors vs. general bioactive molecules). Discriminative fine-tuning adjusts learning rates per layer:</p>
<p>$$\eta^{layer-1} = \eta^{layer} / 2.6$$</p>
<p>where higher layers (containing more task-specific features) receive higher learning rates.</p>
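<p>The rule above can be sketched as a small helper that starts from the learning rate of the top layer and divides by 2.6 for each layer below it. The function name, the number of layers, and the top learning rate are illustrative assumptions.</p>

```python
def discriminative_lrs(top_lr, n_layers, factor=2.6):
    """Learning rates ordered from the lowest layer to the highest layer.

    Each layer receives the rate of the layer above it divided by `factor`.
    """
    return [top_lr / factor ** (n_layers - 1 - i) for i in range(n_layers)]


# Hypothetical setting: four layers and a top-layer rate of 5e-3.
lrs = discriminative_lrs(top_lr=5e-3, n_layers=4)
```

<p>The last entry equals <code>top_lr</code>, and every earlier entry is smaller by a further factor of 2.6.</p>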
<p><strong>Stage 3: QSAR/QSPR Model Fine-tuning.</strong> The embedding and encoder weights from the pre-trained MSPM are transferred to a new model with a task-specific classifier head. Fine-tuning uses three key techniques from ULMFiT:</p>
<ul>
<li><strong>Discriminative fine-tuning</strong>: Different learning rates per layer group</li>
<li><strong>Gradual unfreezing</strong>: Layers are unfrozen sequentially, starting with the classifier head and then working backwards from the final LSTM layer to the first</li>
<li><strong>One cycle policy</strong>: Learning rate scheduling following Smith&rsquo;s approach</li>
</ul>
<p>The model architecture is AWD-LSTM (ASGD Weight-Dropped LSTM) with an embedding dimension of 400, three LSTM layers with 1152 hidden units, and dropouts applied at every layer (embedding, input, weights, hidden). The QSAR classifier concatenates max pooling, mean pooling, and the last hidden state $h_T$ from the final LSTM layer, feeding this into two feedforward layers.</p>
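<p>The pooling step of the classifier can be sketched in plain Python as follows. Here <code>hidden_states</code> stands in for the per-token outputs of the final LSTM layer; the toy dimensions and values are assumptions, and a real implementation would operate on batched tensors.</p>

```python
def concat_pool(hidden_states):
    """Concatenate max pooling, mean pooling, and the last hidden state.

    `hidden_states` is a list of T vectors (one per token), each a list of
    floats, standing in for the outputs of the final LSTM layer.
    """
    n_steps = len(hidden_states)
    columns = list(zip(*hidden_states))  # one tuple per hidden dimension
    max_pool = [max(col) for col in columns]
    mean_pool = [sum(col) / n_steps for col in columns]
    last = list(hidden_states[-1])
    return max_pool + mean_pool + last


# Toy example: three time steps, two hidden units.
features = concat_pool([[1.0, -2.0], [3.0, 0.0], [2.0, 4.0]])
```

<p>The resulting feature vector is three times as wide as a single hidden state and is what the two feedforward layers would consume.</p>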
<p><strong>SMILES Augmentation.</strong> Since multiple valid SMILES can represent the same molecule through different atom orderings, the authors use <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES enumeration</a> as data augmentation. For regression tasks, Gaussian noise ($\sigma_{noise}$) is added to labels of augmented SMILES to simulate experimental error. Test-time augmentation (TTA) averages predictions across the canonical SMILES and four randomized SMILES.</p>
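<p>A minimal sketch of the two augmentation ideas above, label noise for regression and test-time averaging, is given below. The SMILES variants are written by hand, the label value is hypothetical, and the stand-in predictor is only there to make the example runnable; a real pipeline would enumerate SMILES with a cheminformatics toolkit and call the fine-tuned model.</p>

```python
import random


def augment_regression(smiles_variants, label, sigma_noise, rng):
    """Pair each SMILES variant with the label plus Gaussian noise."""
    return [(s, label + rng.gauss(0.0, sigma_noise)) for s in smiles_variants]


def tta_predict(predict_fn, canonical, randomized):
    """Average predictions over the canonical SMILES and randomized variants."""
    inputs = [canonical] + list(randomized)
    return sum(predict_fn(s) for s in inputs) / len(inputs)


rng = random.Random(0)
# Three hand-written SMILES for ethanol; the label is a hypothetical value.
variants = ["CCO", "OCC", "C(O)C"]
train_rows = augment_regression(variants, label=-0.31, sigma_noise=0.3, rng=rng)

# Stand-in model: the "prediction" is just the string length.
prediction = tta_predict(len, "CCO", ["OCC", "C(O)C"])
```

<p>In the setting described above, <code>randomized</code> would hold four randomized SMILES, so each test prediction averages five model outputs.</p>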
<h2 id="benchmarks-across-four-qsar-datasets">Benchmarks Across Four QSAR Datasets</h2>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size</th>
          <th>Task</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Lipophilicity">Lipophilicity</a></td>
          <td>4,200</td>
          <td>Regression (logD)</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Regression (<a href="https://en.wikipedia.org/wiki/Solvation">solvation energy</a>)</td>
          <td>RMSE</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>41,127</td>
          <td>Classification (replication inhibition)</td>
          <td>AUROC</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>2,039</td>
          <td>Classification (<a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">blood-brain barrier</a>)</td>
          <td>AUROC</td>
      </tr>
  </tbody>
</table>
<p>All datasets use the same 10 train/validation/test splits (80:10:10) from <a href="/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/">Yang et al. (2019)</a> for fair comparison. Both random and scaffold splitting were evaluated, with scaffold splits representing a more realistic test of generalization to novel chemical scaffolds.</p>
<h3 id="baselines">Baselines</h3>
<p>Models were compared against results reported by Yang et al. (2019): directed message passing neural network (D-MPNN), D-MPNN with RDKit features, random forest on Morgan fingerprints, feed-forward networks on Morgan fingerprints, and feed-forward networks on <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> descriptors.</p>
<h3 id="hyperparameters">Hyperparameters</h3>
<p>The same set of fine-tuning hyperparameters was used across all four tasks (tuned on the HIV dataset):</p>
<table>
  <thead>
      <tr>
          <th>Layer Group</th>
          <th>Base Learning Rate</th>
          <th>Epochs</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Linear head only</td>
          <td>3e-2</td>
          <td>4</td>
      </tr>
      <tr>
          <td>+ Final LSTM layer</td>
          <td>5e-3</td>
          <td>4</td>
      </tr>
      <tr>
          <td>+ Final two LSTM layers</td>
          <td>5e-4</td>
          <td>4</td>
      </tr>
      <tr>
          <td>Full model</td>
          <td>5e-5</td>
          <td>6</td>
      </tr>
  </tbody>
</table>
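<p>The schedule in the table above can be written down as a small plan of which layer groups are trainable at each stage and at what rate. The grouping into five named groups and the use of the 2.6 divisor inside each stage are assumptions of this sketch; only the base learning rates and epoch counts come from the table.</p>

```python
# Assumed layer grouping, ordered from the input side to the output side.
GROUPS = ["embedding", "lstm_1", "lstm_2", "lstm_3", "head"]

# (number of trainable groups counted from the head, base LR, epochs)
STAGES = [
    (1, 3e-2, 4),            # linear head only
    (2, 5e-3, 4),            # + final LSTM layer
    (3, 5e-4, 4),            # + final two LSTM layers
    (len(GROUPS), 5e-5, 6),  # full model
]


def stage_plan(n_trainable, base_lr, factor=2.6):
    """Map each trainable group to its learning rate; frozen groups are omitted."""
    trainable = GROUPS[len(GROUPS) - n_trainable:]
    top = len(trainable) - 1
    return {name: base_lr / factor ** (top - i) for i, name in enumerate(trainable)}


plans = [(stage_plan(n, lr), epochs) for n, lr, epochs in STAGES]
total_epochs = sum(epochs for _, epochs in plans)
```

<p>Each stage trains the head at the base learning rate from the table, while any newly unfrozen groups below it receive progressively smaller rates.</p>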
<p>Data augmentation settings were task-specific: lipophilicity training SMILES augmented 25x ($\sigma_{noise} = 0.3$); FreeSolv augmented 50x ($\sigma_{noise} = 0.5$); HIV active class augmented 60x and inactive 2x; BBBP positive class 10x and negative 30x.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="benchmark-results">Benchmark Results</h3>
<p><strong>Lipophilicity (random split):</strong> MolPMoFiT achieved RMSE of $0.565 \pm 0.037$ with TTA and $0.625 \pm 0.032$ without, outperforming D-MPNN and other baselines.</p>
<p><strong>FreeSolv (random split):</strong> RMSE of $1.197 \pm 0.127$ with TTA. The small dataset size (642 compounds) led to high variance across splits.</p>
<p><strong>BBBP (random split):</strong> AUROC of $0.950 \pm 0.020$, outperforming all comparison models. Task-specific MSPM fine-tuning showed no clear benefit over the general MSPM.</p>
<p><strong>HIV (random split):</strong> General MolPMoFiT achieved AUROC of $0.828 \pm 0.029$ with TTA. Task-specific fine-tuning yielded a slightly higher $0.834 \pm 0.025$ with TTA.</p>
<p>Scaffold splits consistently produced lower performance than random splits across all datasets, as expected for out-of-distribution generalization.</p>
<h3 id="transfer-learning-impact">Transfer Learning Impact</h3>
<p>Across all four datasets and varying training set sizes, MolPMoFiT consistently outperformed models trained from scratch with the same architecture. The improvement was most pronounced at smaller training set sizes, confirming the utility of pre-trained representations for low-data regimes.</p>
<h3 id="smiles-augmentation-analysis">SMILES Augmentation Analysis</h3>
<p>Training data augmentation provided significant improvements across all tasks. For classification (HIV, BBBP), augmentation improved performance regardless of whether class re-balancing was applied. For regression (lipophilicity, FreeSolv), both SMILES augmentation and label noise were beneficial, with optimal noise levels varying by dataset.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors note a fundamental limitation: the model learns mappings from individual SMILES strings to properties rather than from molecular structures to properties. SMILES augmentation acts as a regularization technique to mitigate this, making the model more robust to different SMILES representations of the same molecule. The task-specific MSPM fine-tuning stage did not consistently improve results and warrants further investigation. All hyperparameters were tuned on one dataset (HIV) and applied uniformly, which may not be optimal for all endpoints.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL (curated)</td>
          <td>1M molecules</td>
          <td>Filtered: no mixtures, max 50 heavy atoms, standardized with MolVS, canonicalized with RDKit</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>MoleculeNet benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HIV</td>
          <td>41,127</td>
          <td>MoleculeNet benchmark</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td>MoleculeNet benchmark</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>AWD-LSTM architecture with embedding dim 400, three LSTM layers (1152 hidden units), dropouts at all layers</li>
<li>ULMFiT fine-tuning: discriminative learning rates ($\eta^{layer-1} = \eta^{layer}/2.6$), gradual unfreezing, one cycle policy</li>
<li>SMILES character-level tokenization with special handling for two-character tokens (Cl, Br) and bracket-enclosed tokens</li>
<li>SMILES enumeration for data augmentation with optional Gaussian label noise for regression</li>
</ul>
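<p>The tokenization rule in the list above can be approximated with a single regular expression. This is a sketch under the stated assumptions (bracket atoms and the halogens Cl and Br are kept whole, everything else is one character per token); the authors&rsquo; tokenizer may handle further cases differently.</p>

```python
import re

# Bracket atoms first, then two-character halogens, then any single character.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|.")


def tokenize(smiles):
    """Split a SMILES string into a list of tokens."""
    return SMILES_TOKEN.findall(smiles)


tokens = tokenize("C[C@H](Cl)c1ccc(Br)cc1[N+](=O)[O-]")
```

<p>Because alternatives are tried left to right, <code>Cl</code> is matched as one token before the fallback can split it into <code>C</code> and <code>l</code>, and joining the tokens reproduces the input string.</p>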
<h3 id="models">Models</h3>
<ul>
<li>General-domain MSPM pre-trained on 1M ChEMBL molecules (10 epochs)</li>
<li>Task-specific MSPMs fine-tuned per dataset (optional stage)</li>
<li>QSAR models fine-tuned with transferred embeddings and encoder</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Split</th>
          <th>Metric</th>
          <th>MolPMoFiT (TTA)</th>
          <th>Best Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Lipophilicity</td>
          <td>Random</td>
          <td>RMSE</td>
          <td>$0.565 \pm 0.037$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>Scaffold</td>
          <td>RMSE</td>
          <td>$0.635 \pm 0.031$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Random</td>
          <td>RMSE</td>
          <td>$1.197 \pm 0.127$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>Scaffold</td>
          <td>RMSE</td>
          <td>$2.082 \pm 0.460$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Random</td>
          <td>AUROC</td>
          <td>$0.950 \pm 0.020$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>Scaffold</td>
          <td>AUROC</td>
          <td>$0.931 \pm 0.025$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Random</td>
          <td>AUROC</td>
          <td>$0.828 \pm 0.029$</td>
          <td>D-MPNN</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>Scaffold</td>
          <td>AUROC</td>
          <td>$0.816 \pm 0.022$</td>
          <td>D-MPNN</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>NVIDIA Quadro P4000 GPU (single GPU)</li>
<li>General-domain MSPM pre-training: approximately 1 day</li>
<li>Pre-training needs to be done only once; fine-tuning is fast per task</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/XinhaoLi74/MolPMoFiT">MolPMoFiT</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>PyTorch + fastai v1 implementation with curated datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, X., &amp; Fourches, D. (2020). Inductive transfer learning for molecular activity prediction: Next-Gen QSAR Models with MolPMoFiT. <em>Journal of Cheminformatics</em>, 12, 27. <a href="https://doi.org/10.1186/s13321-020-00430-x">https://doi.org/10.1186/s13321-020-00430-x</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2020molpmofit,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Inductive transfer learning for molecular activity prediction: Next-Gen QSAR Models with MolPMoFiT}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Xinhao and Fourches, Denis}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{27}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-020-00430-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolBERT: Auxiliary Tasks for Molecular BERT Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/</guid><description>MolBERT applies BERT to SMILES with domain-relevant auxiliary tasks like physicochemical property prediction, improving virtual screening and QSAR.</description><content:encoded><![CDATA[<h2 id="bert-based-molecular-representations-with-auxiliary-pre-training-tasks">BERT-Based Molecular Representations with Auxiliary Pre-Training Tasks</h2>
<p>This is a <strong>Method</strong> paper that introduces MolBERT, a bidirectional Transformer (BERT) architecture applied to SMILES-based molecular representations for drug discovery. The primary contribution is a systematic study of how different domain-relevant self-supervised pre-training tasks affect the quality of learned molecular embeddings, paired with a model that achieves state-of-the-art performance on <a href="https://en.wikipedia.org/wiki/Virtual_screening">virtual screening</a> and <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">quantitative structure-activity relationship (QSAR)</a> benchmarks.</p>
<h2 id="why-domain-relevant-pre-training-matters-for-molecular-language-models">Why Domain-Relevant Pre-Training Matters for Molecular Language Models</h2>
<p>Molecular representations are foundational for predictive, generative, and analytical tasks in drug discovery. Language models applied to text-based molecular representations like SMILES have demonstrated strong performance across property prediction, reaction prediction, and molecular generation. However, several open questions remained at the time of this work:</p>
<ol>
<li><strong>Task selection for pre-training</strong>: Prior work explored masked token prediction, input translation, and property concatenation, but there was no systematic comparison of how different self-supervised tasks affect downstream performance.</li>
<li><strong>SMILES ambiguity</strong>: The same molecule can be encoded as many different SMILES strings depending on how the molecular graph is traversed. Canonicalization algorithms address this but introduce their own artifacts that may distract the model.</li>
<li><strong>Domain knowledge integration</strong>: Standard NLP pre-training objectives (e.g., masked language modeling) do not explicitly encode chemical knowledge. It was unclear whether incorporating chemistry-specific supervision during pre-training could improve representation quality.</li>
</ol>
<p>MolBERT addresses these gaps by evaluating three pre-training tasks, including a novel physicochemical property prediction objective, and measuring their individual and combined effects on downstream drug discovery benchmarks.</p>
<h2 id="three-auxiliary-tasks-for-chemistry-aware-pre-training">Three Auxiliary Tasks for Chemistry-Aware Pre-Training</h2>
<p>MolBERT uses the BERT-Base architecture (12 attention heads, 12 layers, 768-dimensional hidden states, approximately 85M parameters) and explores three self-supervised pre-training tasks:</p>
<p><strong>Masked Language Modeling (MaskedLM)</strong>: The standard BERT objective where 15% of input tokens are masked and the model predicts their identity. The loss is cross-entropy between predicted and true tokens.</p>
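<p>A simplified view of the masking step is sketched below: a fixed 15% of positions is replaced by a mask token and the original tokens are kept as prediction targets. Character-level tokens, the <code>[MASK]</code> symbol, and sampling an exact count of positions are simplifying assumptions, not details confirmed by the paper.</p>

```python
import random


def mask_tokens(tokens, mask_rate=0.15, mask_token="[MASK]", rng=None):
    """Return (masked tokens, {position: original token})."""
    rng = rng or random.Random()
    n_mask = max(1, round(mask_rate * len(tokens)))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    masked = list(tokens)
    targets = {}
    for pos in positions:
        targets[pos] = masked[pos]
        masked[pos] = mask_token
    return masked, targets


tokens = list("CC(=O)Oc1ccccc1C(=O)O")  # aspirin, one character per token
masked, targets = mask_tokens(tokens, rng=random.Random(0))
```

<p>The cross-entropy loss would then be computed only at the positions stored in <code>targets</code>.</p>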
<p><strong>SMILES Equivalence (SMILES-Eq)</strong>: A binary classification task where the model receives two SMILES strings and predicts whether they represent the same molecule. The second string is either a random permutation of the first (same molecule, different traversal) or a randomly sampled molecule. This is optimized with cross-entropy loss.</p>
<p><strong>Physicochemical Property Prediction (PhysChemPred)</strong>: Using <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>, a set of 200 real-valued molecular descriptors is computed for each molecule. The model predicts these normalized descriptors from the SMILES input using mean squared error:</p>
<p>$$\mathcal{L}_{\text{PhysChemPred}} = \frac{1}{D} \sum_{d=1}^{D} (y_d - \hat{y}_d)^2$$</p>
<p>where $D = 200$ is the number of descriptors, $y_d$ is the true normalized descriptor value, and $\hat{y}_d$ is the model&rsquo;s prediction.</p>
<p>The final training loss is the arithmetic mean of all active task losses:</p>
<p>$$\mathcal{L}_{\text{total}} = \frac{1}{|\mathcal{T}|} \sum_{t \in \mathcal{T}} \mathcal{L}_t$$</p>
<p>where $\mathcal{T}$ is the set of active pre-training tasks.</p>
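<p>The two formulas above translate directly into code. The sketch below uses four descriptors instead of 200 and made-up loss values for the other two tasks, purely to show the arithmetic.</p>

```python
def physchem_loss(y_true, y_pred):
    """Mean squared error over the D normalized descriptors."""
    d = len(y_true)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / d


def total_loss(task_losses):
    """Arithmetic mean of the losses of the active pre-training tasks."""
    return sum(task_losses.values()) / len(task_losses)


# Toy values with D = 4 instead of 200; the other two losses are made up.
mse = physchem_loss([0.0, 1.0, -1.0, 0.5], [0.0, 0.5, -1.0, 1.5])
loss = total_loss({"MaskedLM": 1.2, "SMILES-Eq": 0.6, "PhysChemPred": mse})
```

<p>Dropping a task from the dictionary corresponds to one of the ablation settings, since the mean is taken only over the active tasks.</p>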
<p>Additionally, MolBERT supports SMILES permutation augmentation during training, where each input molecule is represented by a randomly sampled non-canonical SMILES string rather than the canonical form. The model uses a fixed vocabulary of 42 tokens, a sequence length of 128, and relative positional embeddings (from Transformer-XL) to support arbitrary-length SMILES at inference time.</p>
<h2 id="ablation-study-and-benchmark-evaluation">Ablation Study and Benchmark Evaluation</h2>
<h3 id="pre-training-setup">Pre-Training Setup</h3>
<p>All models were pre-trained on the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol benchmark dataset</a>, consisting of approximately 1.6M compounds curated from ChEMBL, with 80% of the compounds used for training and 5% for validation. Training used the Adam optimizer with a learning rate of $3 \times 10^{-5}$ for 20 epochs (ablation) or 100 epochs (final model).</p>
<h3 id="ablation-impact-of-task-combinations-on-virtual-screening">Ablation: Impact of Task Combinations on Virtual Screening</h3>
<p>The ablation study evaluated all seven possible task combinations on the RDKit virtual screening benchmark (69 datasets, 5 query molecules per target). Results are reported as AUROC and BEDROC20 (an early enrichment metric with $\alpha = 20$), with and without SMILES permutation augmentation:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: center">MaskedLM</th>
          <th style="text-align: center">PhysChemPred</th>
          <th style="text-align: center">SMILES-Eq</th>
          <th style="text-align: center">AUROC (w/ perm)</th>
          <th style="text-align: center">BEDROC20 (w/ perm)</th>
          <th style="text-align: center">AUROC (w/o perm)</th>
          <th style="text-align: center">BEDROC20 (w/o perm)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.685 +/- 0.069</td>
          <td style="text-align: center">0.246 +/- 0.041</td>
          <td style="text-align: center">0.707 +/- 0.059</td>
          <td style="text-align: center">0.280 +/- 0.042</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.738 +/- 0.060</td>
          <td style="text-align: center">0.323 +/- 0.071</td>
          <td style="text-align: center">0.740 +/- 0.066</td>
          <td style="text-align: center">0.322 +/- 0.065</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.483 +/- 0.092</td>
          <td style="text-align: center">0.092 +/- 0.069</td>
          <td style="text-align: center">0.493 +/- 0.068</td>
          <td style="text-align: center">0.108 +/- 0.070</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.476 +/- 0.077</td>
          <td style="text-align: center">0.064 +/- 0.034</td>
          <td style="text-align: center">0.514 +/- 0.165</td>
          <td style="text-align: center">0.084 +/- 0.014</td>
      </tr>
      <tr>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.696 +/- 0.058</td>
          <td style="text-align: center">0.283 +/- 0.077</td>
          <td style="text-align: center">0.676 +/- 0.060</td>
          <td style="text-align: center">0.250 +/- 0.073</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">0.719 +/- 0.057</td>
          <td style="text-align: center">0.293 +/- 0.071</td>
          <td style="text-align: center">0.716 +/- 0.061</td>
          <td style="text-align: center">0.290 +/- 0.076</td>
      </tr>
      <tr>
          <td style="text-align: center">No</td>
          <td style="text-align: center">No</td>
          <td style="text-align: center">Yes</td>
          <td style="text-align: center">0.129 +/- 0.067</td>
          <td style="text-align: center">0.005 +/- 0.037</td>
          <td style="text-align: center">0.508 +/- 0.068</td>
          <td style="text-align: center">0.048 +/- 0.035</td>
      </tr>
  </tbody>
</table>
<p>Key findings from the ablation:</p>
<ul>
<li>PhysChemPred had the highest individual impact (BEDROC20 of 0.292 when used alone, averaged over the runs with and without permutation, vs. 0.266 for MaskedLM alone).</li>
<li>Combining MaskedLM + PhysChemPred achieved the best performance (BEDROC20 of 0.323), though the additive gain from MaskedLM was modest (+0.031).</li>
<li>The SMILES-Eq task consistently decreased performance when added to other task combinations.</li>
</ul>
<p>A further sub-ablation on PhysChemPred descriptor groups showed that surface descriptors alone (49 of 200 descriptors) achieved nearly the same performance as the full set, suggesting molecular surface properties provide particularly informative supervision.</p>
<h3 id="virtual-screening-results">Virtual Screening Results</h3>
<p>Using the best task combination (MaskedLM + PhysChemPred) trained for 100 epochs:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>AUROC</th>
          <th>BEDROC20</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolBERT (100 epochs)</td>
          <td>0.743 +/- 0.062</td>
          <td>0.344 +/- 0.062</td>
      </tr>
      <tr>
          <td>CDDD</td>
          <td>0.725 +/- 0.057</td>
          <td>0.310 +/- 0.080</td>
      </tr>
      <tr>
          <td>RDKit descriptors</td>
          <td>0.633 +/- 0.027</td>
          <td>0.217 +/- 0.000</td>
      </tr>
      <tr>
          <td>ECFC4</td>
          <td>0.603 +/- 0.056</td>
          <td>0.170 +/- 0.079</td>
      </tr>
  </tbody>
</table>
<p>MolBERT outperformed all baselines including <a href="/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/">CDDD</a> (the prior state of the art), RDKit calculated descriptors, and extended-connectivity fingerprints (ECFC4).</p>
<h3 id="qsar-results">QSAR Results</h3>
<p>On <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> regression tasks (RMSE, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">RDKit (norm)</th>
          <th style="text-align: center">ECFC4</th>
          <th style="text-align: center">CDDD</th>
          <th style="text-align: center">MolBERT</th>
          <th style="text-align: center">MolBERT (finetune)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td style="text-align: center">0.687 +/- 0.08</td>
          <td style="text-align: center">0.902 +/- 0.06</td>
          <td style="text-align: center">0.567 +/- 0.06</td>
          <td style="text-align: center">0.552 +/- 0.07</td>
          <td style="text-align: center"><strong>0.531 +/- 0.04</strong></td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td style="text-align: center">1.671 +/- 0.45</td>
          <td style="text-align: center">2.876 +/- 0.38</td>
          <td style="text-align: center">1.456 +/- 0.43</td>
          <td style="text-align: center">1.523 +/- 0.66</td>
          <td style="text-align: center"><strong>0.948 +/- 0.33</strong></td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td style="text-align: center">0.738 +/- 0.04</td>
          <td style="text-align: center">0.770 +/- 0.03</td>
          <td style="text-align: center">0.669 +/- 0.02</td>
          <td style="text-align: center">0.602 +/- 0.01</td>
          <td style="text-align: center"><strong>0.561 +/- 0.03</strong></td>
      </tr>
  </tbody>
</table>
<p>On MoleculeNet classification tasks (AUROC, higher is better):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th style="text-align: center">RDKit (norm)</th>
          <th style="text-align: center">ECFC4</th>
          <th style="text-align: center">CDDD</th>
          <th style="text-align: center">MolBERT</th>
          <th style="text-align: center">MolBERT (finetune)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>BACE</td>
          <td style="text-align: center">0.831</td>
          <td style="text-align: center">0.845</td>
          <td style="text-align: center">0.833</td>
          <td style="text-align: center">0.849</td>
          <td style="text-align: center"><strong>0.866</strong></td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td style="text-align: center">0.696</td>
          <td style="text-align: center">0.678</td>
          <td style="text-align: center">0.761</td>
          <td style="text-align: center">0.750</td>
          <td style="text-align: center"><strong>0.762</strong></td>
      </tr>
      <tr>
          <td>HIV</td>
          <td style="text-align: center">0.708</td>
          <td style="text-align: center">0.714</td>
          <td style="text-align: center">0.753</td>
          <td style="text-align: center">0.747</td>
          <td style="text-align: center"><strong>0.783</strong></td>
      </tr>
  </tbody>
</table>
<p>Fine-tuned MolBERT achieved the best performance on all six QSAR datasets. When used as a fixed feature extractor with an SVM, MolBERT embeddings outperformed other representations on three of six tasks.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Pre-training task selection matters significantly.</strong> The choice of auxiliary tasks during pre-training has a large effect on downstream performance. PhysChemPred provides the strongest individual signal.</li>
<li><strong>Domain-relevant auxiliary tasks improve representation quality.</strong> Predicting physicochemical properties during pre-training encodes chemical knowledge directly into the embeddings, outperforming purely linguistic objectives.</li>
<li><strong>The SMILES equivalence task hurts performance.</strong> Despite being chemically motivated, the SMILES-Eq task consistently degraded results, suggesting it may introduce conflicting learning signals.</li>
<li><strong>PhysChemPred organizes the embedding space.</strong> Analysis of pairwise cosine similarities showed that models trained with PhysChemPred assign high similarity to permutations of the same molecule and low similarity to different molecules, creating a more semantically meaningful representation space.</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li>The paper evaluates only SMILES-based representations, inheriting all limitations of string-based molecular encodings (inability to capture 3D structure, sensitivity to tokenization).</li>
<li>The virtual screening evaluation uses a fixed number of query molecules ($n = 5$), which may not reflect realistic screening scenarios.</li>
<li>QSAR evaluation used cross-validation splits from ChemBench rather than scaffold splits, a choice that may overestimate performance on structurally novel compounds.</li>
<li>The model&rsquo;s 128-token sequence length limit may truncate larger molecules, though relative positional embeddings partially address this at inference time.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors propose extending MolBERT to learn representations for other biological entities such as proteins, and developing more advanced pre-training strategies.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>GuacaMol (ChEMBL)</td>
          <td>~1.6M compounds</td>
          <td>80% train / 5% validation split</td>
      </tr>
      <tr>
          <td>Virtual Screening</td>
          <td>RDKit benchmark v1.2</td>
          <td>69 target datasets</td>
          <td>Filtered subset with active/decoy compounds</td>
      </tr>
      <tr>
          <td>QSAR (Regression)</td>
          <td>ESOL, FreeSolv, Lipophilicity</td>
          <td>Varies</td>
          <td>From MoleculeNet, ChemBench splits</td>
      </tr>
      <tr>
          <td>QSAR (Classification)</td>
          <td>BACE, BBBP, HIV</td>
          <td>Varies</td>
          <td>From MoleculeNet, ChemBench splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: BERT-Base (12 heads, 12 layers, 768-dim hidden, ~85M params)</li>
<li>Optimizer: Adam, learning rate $3 \times 10^{-5}$</li>
<li>Vocabulary: 42 tokens, sequence length 128</li>
<li>Masking: 15% of tokenized input</li>
<li>Positional encoding: relative positional embeddings (Transformer-XL)</li>
<li>SVM on fixed embeddings (feature-extraction setting): $C = 5.0$, RBF kernel (from Winter et al.)</li>
<li>Fine-tuning head: single linear layer on pooled output</li>
<li>Embeddings: pooled output (or average sequence output when only MaskedLM is used)</li>
</ul>
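<p>The 15% masking rate listed above can be illustrated with a small sketch. It only selects positions and swaps in a mask token. The mask token name, the character-level tokenization, and the fixed seed are assumptions made for illustration, and any further replacement rules used in BERT-style training are not modeled here.</p>

```python
import random

MASK_TOKEN = "[MASK]"  # assumed name of the mask token


def mask_tokens(tokens, mask_rate=0.15, seed=0):
    """Return (masked_tokens, labels) for one masked-language-modeling example.

    labels[i] holds the original token at masked positions and None elsewhere.
    """
    if not tokens:
        return [], []
    rng = random.Random(seed)
    n_mask = max(1, round(mask_rate * len(tokens)))
    positions = set(rng.sample(range(len(tokens)), n_mask))
    masked = [MASK_TOKEN if i in positions else tok for i, tok in enumerate(tokens)]
    labels = [tok if i in positions else None for i, tok in enumerate(tokens)]
    return masked, labels


# Character-level tokens of aspirin's SMILES, used only as a toy input.
tokens = list("CC(=O)Oc1ccccc1C(=O)O")
masked, labels = mask_tokens(tokens)
print("".join(masked))
```

<p>With 21 input tokens, the sketch masks 3 positions. The loss would then be computed only where <code>labels</code> is not <code>None</code>.</p>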
<h3 id="models">Models</h3>
<ul>
<li>BERT-Base with ~85M parameters</li>
<li>Pre-trained weights available at <a href="https://github.com/BenevolentAI/MolBERT">BenevolentAI/MolBERT</a></li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUROC</td>
          <td>Virtual Screening, Classification QSAR</td>
          <td>Standard area under ROC curve</td>
      </tr>
      <tr>
          <td>BEDROC20</td>
          <td>Virtual Screening</td>
          <td>Early enrichment metric, $\alpha = 20$</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression QSAR</td>
          <td>Root mean squared error</td>
      </tr>
  </tbody>
</table>
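<p>For readers unfamiliar with the two ranking metrics in the table, the following is a rough pure-Python sketch. It assumes each compound has a score where higher means "more likely active" and a binary label. AUROC is computed by pairwise comparison, and BEDROC follows the Truchon and Bayly definition in its min/max-normalized form. The function names are hypothetical, and this is not the benchmark's own evaluation code.</p>

```python
import math


def auroc(scores, labels):
    """Probability that a random active outranks a random decoy (ties count 0.5)."""
    actives = [s for s, y in zip(scores, labels) if y]
    decoys = [s for s, y in zip(scores, labels) if not y]
    if not actives or not decoys:
        raise ValueError("need at least one active and one decoy")
    wins = 0.0
    for a in actives:
        for d in decoys:
            if a > d:
                wins += 1.0
            elif a == d:
                wins += 0.5
    return wins / (len(actives) * len(decoys))


def bedroc(scores, labels, alpha=20.0):
    """BEDROC, rescaled to the range 0..1 using the minimum and maximum RIE."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    n_total = len(ranked)
    n_active = sum(1 for _, y in ranked if y)
    if n_active == 0 or n_active == n_total:
        raise ValueError("need at least one active and one decoy")
    ratio = n_active / n_total
    # Exponentially weighted sum over the 1-based ranks of the actives.
    weighted = sum(
        math.exp(-alpha * rank / n_total)
        for rank, (_, y) in enumerate(ranked, start=1)
        if y
    )
    # Expected value of the same sum if the actives were ranked at random.
    random_sum = ratio * (1 - math.exp(-alpha)) / (math.exp(alpha / n_total) - 1)
    rie = weighted / random_sum
    rie_max = (1 - math.exp(-alpha * ratio)) / (ratio * (1 - math.exp(-alpha)))
    rie_min = (1 - math.exp(alpha * ratio)) / (ratio * (1 - math.exp(alpha)))
    return (rie - rie_min) / (rie_max - rie_min)


# Toy ranking of 10 compounds with 2 actives (made-up scores).
scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
perfect = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
worst = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
print(auroc(scores, perfect), bedroc(scores, perfect))
print(auroc(scores, worst), bedroc(scores, worst))
```

<p>A larger $\alpha$ concentrates the weight on the top of the ranked list, which is why BEDROC20 rewards early enrichment more strongly than AUROC does.</p>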
<h3 id="hardware">Hardware</h3>
<ul>
<li>2 GPUs, 16 CPUs</li>
<li>Pre-training time: ~40 hours (20 epochs)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BenevolentAI/MolBERT">BenevolentAI/MolBERT</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Official implementation with pre-trained weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fabian, B., Edlich, T., Gaspar, H., Segler, M., Meyers, J., Fiscato, M., &amp; Ahmed, M. (2020). Molecular representation learning with language models and domain-relevant auxiliary tasks. <em>arXiv preprint arXiv:2011.13230</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{fabian2020molecular,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Molecular representation learning with language models and domain-relevant auxiliary tasks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fabian, Benedek and Edlich, Thomas and Gaspar, H{\&#39;e}l{\&#39;e}na and Segler, Marwin and Meyers, Joshua and Fiscato, Marco and Ahmed, Mohamed}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2011.13230}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LLM-Prop: Predicting Crystal Properties from Text</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/llm-prop-crystal-property-prediction/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/llm-prop-crystal-property-prediction/</guid><description>LLM-Prop fine-tunes the T5 encoder on crystal text descriptions to predict band gap, volume, and other properties, outperforming GNN baselines.</description><content:encoded><![CDATA[<h2 id="text-based-crystal-property-prediction-with-llms">Text-Based Crystal Property Prediction with LLMs</h2>
<p>LLM-Prop is a <strong>Method</strong> paper that proposes using the encoder portion of <a href="https://en.wikipedia.org/wiki/T5_(language_model)">T5</a> (a general-purpose language model) fine-tuned on crystal text descriptions to predict physical and electronic properties of crystalline materials. The primary contribution is demonstrating that text-based representations of crystals, generated by Robocrystallographer, can serve as effective inputs for <a href="/notes/chemistry/molecular-design/property-prediction/">property prediction</a>, outperforming graph neural network (GNN) baselines on several tasks despite using a non-domain-specific pre-trained model with fewer parameters.</p>
<h2 id="why-text-instead-of-crystal-graphs">Why Text Instead of Crystal Graphs?</h2>
<p>Graph neural networks have been the dominant approach for crystal property prediction. Models like CGCNN, MEGNet, and ALIGNN represent crystals as graphs where atoms are nodes and bonds are edges. However, GNNs face several fundamental challenges for crystals:</p>
<ol>
<li><strong>Periodicity encoding</strong>: Crystals have repetitive unit cell arrangements that are distinct from standard molecular graphs, and GNNs struggle to encode this periodicity efficiently.</li>
<li><strong>Information incorporation</strong>: Critical structural information like bond angles, <a href="https://en.wikipedia.org/wiki/Space_group">space group</a> symmetry, and <a href="https://en.wikipedia.org/wiki/Wyckoff_positions">Wyckoff sites</a> is difficult to incorporate into graph representations.</li>
<li><strong>Expressiveness</strong>: Graphs may lack the expressiveness needed to convey complex crystal information relevant to property prediction.</li>
</ol>
<p>Meanwhile, textual descriptions of crystals (generated by tools like Robocrystallographer) naturally encode space group information, bond geometries, coordination environments, and symmetry details in human-readable form. Despite this richness, text-based approaches for crystal property prediction had been largely unexplored.</p>
<h2 id="core-innovation-t5-encoder-with-careful-fine-tuning">Core Innovation: T5 Encoder with Careful Fine-Tuning</h2>
<p>The key insight of LLM-Prop is to take a pre-trained encoder-decoder model (<a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-small) and discard the decoder entirely, using only the encoder with a linear prediction head. This design has several advantages:</p>
<ul>
<li>Dropping the decoder shrinks the model from ~60M to ~37M parameters, which allows processing of longer input sequences</li>
<li>Longer sequences mean more crystal information can be included</li>
<li>The encoder-only approach avoids T5&rsquo;s known weakness at regression in text-to-text format</li>
</ul>
<p>The framework applies several preprocessing strategies to the crystal text descriptions:</p>
<ol>
<li><strong>Stopword removal</strong>: Standard English stopwords are removed, while digits and symbols that carry chemical information are kept</li>
<li><strong>Numerical token replacement</strong>: Bond distances are replaced with a <code>[NUM]</code> token and bond angles with <code>[ANG]</code>, reducing sequence length while preserving structural cues</li>
<li><strong>[CLS] token prepending</strong>: A classification token is added at the start, and its learned embedding is used as input to the prediction layer</li>
<li><strong>Label scaling</strong>: For regression tasks, targets are normalized using z-score, min-max, or log normalization</li>
</ol>
<p>The normalization schemes are defined as:</p>
<p>$$
\hat{Y}_{i}(\text{z-score}) = \frac{Y_{i} - \mu}{\sigma}
$$</p>
<p>$$
\hat{Y}_{i}(\text{min-max}) = \frac{Y_{i} - Y_{\min}}{Y_{\max} - Y_{\min}}
$$</p>
<p>$$
\hat{Y}_{i}(\text{log-norm}) = \log(Y_{i} + 1)
$$</p>
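<p>The three label-scaling schemes can be sketched as forward/inverse function pairs, since predictions made in the scaled space have to be mapped back to physical units before computing MAE. The helper names, the use of the population standard deviation, and the natural logarithm are assumptions. The toy volumes are made up.</p>

```python
import math
import statistics


def fit_z_score(values):
    """Return (forward, inverse) functions for z-score scaling."""
    mu = statistics.fmean(values)
    sigma = statistics.pstdev(values)  # population std; an assumption
    return (lambda y: (y - mu) / sigma), (lambda z: z * sigma + mu)


def fit_min_max(values):
    """Return (forward, inverse) functions for min-max scaling."""
    lo, hi = min(values), max(values)
    return (lambda y: (y - lo) / (hi - lo)), (lambda s: s * (hi - lo) + lo)


def fit_log_norm():
    """Return (forward, inverse) functions for log(Y + 1) scaling."""
    return (lambda y: math.log(y + 1)), (lambda s: math.exp(s) - 1)


# Statistics are fitted on training labels only (toy values).
train_volumes = [100.0, 200.0, 300.0, 400.0]
to_z, from_z = fit_z_score(train_volumes)
to_unit, from_unit = fit_min_max(train_volumes)
to_log, from_log = fit_log_norm()

print(to_z(100.0), to_unit(200.0), to_log(100.0))
print(from_z(to_z(100.0)), from_unit(to_unit(200.0)), from_log(to_log(100.0)))
```

<p>Fitting $\mu$, $\sigma$, $Y_{\min}$, and $Y_{\max}$ on the training split and reusing them for validation and test labels avoids leaking test statistics into training.</p>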
<p>The tokenizer is also retrained on the crystal text corpus with a vocabulary size of 32k, and the special tokens <code>[NUM]</code>, <code>[ANG]</code>, and <code>[CLS]</code> are added to the vocabulary.</p>
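<p>The text preprocessing steps described above (stopword removal, <code>[NUM]</code>/<code>[ANG]</code> replacement, and <code>[CLS]</code> prepending) might look roughly like the sketch below. The stopword list, the regular expressions, and the example sentence are all illustrative assumptions. The real descriptions and the authors' exact patterns may differ, and value ranges are not handled here.</p>

```python
import re

# Tiny illustrative stopword list; not the list used by the authors.
STOPWORDS = {"the", "is", "are", "a", "an", "in", "of", "to", "and", "all"}

# Assumed surface forms: distances such as "2.13 Å", angles such as "45°".
DISTANCE = re.compile(r"\d+(?:\.\d+)?\s*Å")
ANGLE = re.compile(r"\d+(?:\.\d+)?\s*(?:°|degrees)")


def preprocess(description):
    """Replace numeric values with special tokens, drop stopwords, prepend [CLS]."""
    text = DISTANCE.sub("[NUM]", description)
    text = ANGLE.sub("[ANG]", text)
    kept = [word for word in text.split() if word.lower() not in STOPWORDS]
    return "[CLS] " + " ".join(kept)


# Made-up sentence in the style of a generated crystal description.
example = "All Mg-O bond lengths are 2.13 Å and the tilt angles are 45°."
result = preprocess(example)
print(result)
```

<p>Collapsing each number into one special token shortens the sequence, which leaves more of the token budget for the rest of the description.</p>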
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<h3 id="dataset-textedge">Dataset: TextEdge</h3>
<p>The authors collected data from the <a href="https://en.wikipedia.org/wiki/Materials_Project">Materials Project</a> database (as of November 2022), yielding 144,931 crystal structure-description pairs split into 125,098 training, 9,945 validation, and 9,888 test samples. Crystal text descriptions were generated using Robocrystallographer. The dataset covers six prediction tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Type</th>
          <th>Metric</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Band gap (eV)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Unit cell volume (Å³/cell)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Formation energy per atom (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Energy per atom (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Energy above hull (eV/atom)</td>
          <td>Regression</td>
          <td>MAE (lower is better)</td>
      </tr>
      <tr>
          <td>Is-gap-direct</td>
          <td>Classification</td>
          <td>AUC (higher is better)</td>
      </tr>
  </tbody>
</table>
<h3 id="baselines">Baselines</h3>
<p>Seven baselines were compared:</p>
<ul>
<li><strong>GNN-based</strong>: CGCNN, MEGNet, ALIGNN, DeeperGATGNN</li>
<li><strong>Classic ML</strong>: XGBoost, Random Forest (on Robocrystallographer features)</li>
<li><strong>Text-based</strong>: MatBERT (domain-specific pre-trained BERT, ~110M parameters)</li>
</ul>
<p>All models were trained and evaluated on the same dataset splits for fair comparison. GNN models were retrained on the new data rather than using results from older, smaller Materials Project versions.</p>
<h3 id="main-results-llm-prop-vs-gnn-baselines">Main Results: LLM-Prop vs. GNN Baselines</h3>
<p>The table below compares LLM-Prop, using crystal text descriptions as input, against the GNN baselines (MAE for the regression tasks, AUC for Is-gap-direct):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Band gap (eV)</th>
          <th>Volume (Å³/cell)</th>
          <th>FEPA (eV/atom)</th>
          <th>EPA (eV/atom)</th>
          <th>Ehull (eV/atom)</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN</td>
          <td>0.293</td>
          <td>188.834</td>
          <td>0.046</td>
          <td>0.082</td>
          <td>0.040</td>
          <td>0.830</td>
      </tr>
      <tr>
          <td>MEGNet</td>
          <td>0.304</td>
          <td>297.948</td>
          <td>0.077</td>
          <td>0.056</td>
          <td>0.051</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>ALIGNN</td>
          <td>0.250</td>
          <td>129.580</td>
          <td>0.027</td>
          <td>0.059</td>
          <td>0.028</td>
          <td>0.678</td>
      </tr>
      <tr>
          <td>DeeperGATGNN</td>
          <td>0.291</td>
          <td>111.857</td>
          <td>0.081</td>
          <td>0.116</td>
          <td>0.045</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>LLM-Prop (Descr.)</td>
          <td><strong>0.231</strong></td>
          <td><strong>39.252</strong></td>
          <td>0.056</td>
          <td>0.067</td>
          <td>0.047</td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
<p>Relative to the best GNN baseline on each task, LLM-Prop improved by approximately 8% on <a href="https://en.wikipedia.org/wiki/Band_gap">band gap</a> prediction (vs. ALIGNN), 65% on volume prediction (vs. DeeperGATGNN), and 3% on classifying whether the band gap is direct (Is-gap-direct, vs. CGCNN). For formation energy per atom, energy per atom, and energy above hull, ALIGNN retained an advantage over LLM-Prop.</p>
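<p>The percentages can be checked directly from the table. The sketch below picks the strongest GNN baseline per task and computes the improvement relative to that baseline's value. Measuring the gain relative to the baseline is an assumption about how the figures were derived, although it does reproduce them.</p>

```python
# Values copied from the results table: (LLM-Prop, {GNN baseline: value}).
RESULTS = {
    "band_gap_mae": (0.231, {"CGCNN": 0.293, "MEGNet": 0.304,
                             "ALIGNN": 0.250, "DeeperGATGNN": 0.291}),
    "volume_mae": (39.252, {"CGCNN": 188.834, "MEGNet": 297.948,
                            "ALIGNN": 129.580, "DeeperGATGNN": 111.857}),
    "is_gap_direct_auc": (0.857, {"CGCNN": 0.830, "ALIGNN": 0.678}),
}
LOWER_IS_BETTER = {"band_gap_mae": True, "volume_mae": True,
                   "is_gap_direct_auc": False}


def relative_improvement(ours, baselines, lower_is_better):
    """Return (best baseline name, percent improvement over that baseline)."""
    pick = min if lower_is_better else max
    name = pick(baselines, key=baselines.get)
    best = baselines[name]
    gain = (best - ours) / best if lower_is_better else (ours - best) / best
    return name, 100.0 * gain


summary = {
    task: relative_improvement(ours, baselines, LOWER_IS_BETTER[task])
    for task, (ours, baselines) in RESULTS.items()
}
for task, (name, gain) in summary.items():
    print(f"{task}: {gain:.1f}% better than {name}")
```

<p>The strongest GNN differs by task: ALIGNN for band gap, DeeperGATGNN for volume, and CGCNN for Is-gap-direct.</p>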
<h3 id="llm-prop-vs-matbert">LLM-Prop vs. MatBERT</h3>
<p>LLM-Prop also outperformed MatBERT (a domain-specific pre-trained BERT) across all tasks despite having roughly 3x fewer parameters. The table below shows the best result for each model across the three input preprocessing strategies (w/ Numbers, w/o Numbers, w/ [NUM]&amp;[ANG]):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Band gap (eV)</th>
          <th>Volume (Å³/cell)</th>
          <th>FEPA (eV/atom)</th>
          <th>EPA (eV/atom)</th>
          <th>Ehull (eV/atom)</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MatBERT (best)</td>
          <td>0.258</td>
          <td>54.969</td>
          <td>0.071</td>
          <td>0.098</td>
          <td>0.050</td>
          <td>0.722</td>
      </tr>
      <tr>
          <td>LLM-Prop (best)</td>
          <td><strong>0.231</strong></td>
          <td><strong>39.138</strong></td>
          <td><strong>0.056</strong></td>
          <td><strong>0.067</strong></td>
          <td><strong>0.047</strong></td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: LLM-Prop&rsquo;s best band gap (0.231) comes from the &ldquo;w/o Numbers&rdquo; configuration, while the best volume (39.138) comes from &ldquo;w/ Numbers&rdquo;. The best Is-gap-direct AUC (0.857) uses the &ldquo;[NUM]&amp;[ANG]&rdquo; configuration.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>The contribution of each preprocessing strategy was evaluated:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Band gap</th>
          <th>Volume</th>
          <th>Is-gap-direct (AUC)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LLM-Prop (baseline)</td>
          <td>0.256</td>
          <td>69.352</td>
          <td>0.796</td>
      </tr>
      <tr>
          <td>+ modified tokenizer</td>
          <td>0.247</td>
          <td>78.632</td>
          <td>0.785</td>
      </tr>
      <tr>
          <td>+ label scaling</td>
          <td>0.242</td>
          <td>44.515</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>+ [CLS] token</td>
          <td>0.231</td>
          <td>39.520</td>
          <td>0.842</td>
      </tr>
      <tr>
          <td>+ [NUM] token</td>
          <td>0.251</td>
          <td>86.090</td>
          <td>0.793</td>
      </tr>
      <tr>
          <td>+ [ANG] token</td>
          <td>0.242</td>
          <td>64.965</td>
          <td>0.810</td>
      </tr>
      <tr>
          <td>- stopwords</td>
          <td>0.252</td>
          <td>56.593</td>
          <td>0.779</td>
      </tr>
      <tr>
          <td>LLM-Prop+all (no space group)</td>
          <td>0.235</td>
          <td>97.457</td>
          <td>0.705</td>
      </tr>
      <tr>
          <td>LLM-Prop+all</td>
          <td><strong>0.229</strong></td>
          <td>42.259</td>
          <td><strong>0.857</strong></td>
      </tr>
  </tbody>
</table>
<p>The [CLS] token provided the single largest improvement across all tasks. Label scaling was critical for volume prediction (reducing MAE from 69.352 to 44.515). Removing space group information from descriptions degraded volume prediction dramatically (from 42.259 to 97.457), confirming that space group symmetry is a key factor.</p>
<h3 id="data-efficiency-and-transfer-learning">Data Efficiency and Transfer Learning</h3>
<p>LLM-Prop achieved state-of-the-art results on band gap and volume prediction with only about 90k training samples, roughly 35k fewer than the full training set used for the baselines. For volume prediction specifically, LLM-Prop outperformed all GNN baselines with just 30k training samples.</p>
<p>Transfer learning experiments showed that LLM-Prop transferred well between band gap and volume prediction tasks:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Volume-to-Band gap (Test)</th>
          <th>Band gap-to-Volume (Test)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CGCNN-transfer</td>
          <td>0.295</td>
          <td>182.997</td>
      </tr>
      <tr>
          <td>ALIGNN-transfer</td>
          <td>0.322</td>
          <td>136.164</td>
      </tr>
      <tr>
          <td>MatBERT-transfer</td>
          <td>0.266</td>
          <td>54.289</td>
      </tr>
      <tr>
          <td>LLM-Prop-transfer</td>
          <td><strong>0.244</strong></td>
          <td><strong>50.753</strong></td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-limitations-and-future-directions">Key Findings, Limitations, and Future Directions</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Text descriptions of crystals carry rich structural information (space groups, Wyckoff sites, coordination geometries) that is difficult to encode in graphs but naturally expressed in text</li>
<li>A carefully fine-tuned general-purpose LLM encoder can outperform domain-specific pre-trained models, challenging the assumption that in-domain pre-training is always necessary</li>
<li>Removing numerical information (bond distances and angles) from descriptions often improves performance, which the authors attribute to current LLMs treating numbers as ordinary tokens without modeling their quantitative meaning</li>
<li>Longer input sequences correlate with better performance, with 888 tokens as the default maximum on the hardware used</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>The origin of LLM-Prop&rsquo;s performance advantage over GNNs is not fully understood. It remains unclear whether the boost comes from additional structured information in text or from the different data modality itself</li>
<li>LLM-Prop cannot perform zero-shot predictions since T5 was not pre-trained on materials science data</li>
<li>The approach depends on Robocrystallographer to generate text descriptions, adding a preprocessing dependency</li>
<li>Current LLMs&rsquo; inability to reason about numerical values limits the use of quantitative information in descriptions</li>
</ul>
<p><strong>Future directions</strong> suggested by the authors include investigating techniques to use <a href="/notes/chemistry/molecular-design/generation/autoregressive/3d-chemical-language-models-xyz-cif-pdb/">CIF files</a> directly as LLM inputs, developing new GNN architectures that incorporate space group and Wyckoff site information, and further exploring which information in crystal descriptions contributes most to each property prediction task.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>TextEdge</td>
          <td>144,931 crystals</td>
          <td>From Materials Project (Nov 2022), text generated by Robocrystallographer</td>
      </tr>
      <tr>
          <td>Training split</td>
          <td>TextEdge</td>
          <td>125,098</td>
          <td>Random split</td>
      </tr>
      <tr>
          <td>Validation split</td>
          <td>TextEdge</td>
          <td>9,945</td>
          <td>Random split</td>
      </tr>
      <tr>
          <td>Test split</td>
          <td>TextEdge</td>
          <td>9,888</td>
          <td>Random split</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam with one-cycle learning rate scheduler</li>
<li><strong>Learning rate</strong>: 1e-3 for LLM-Prop, 5e-5 for MatBERT</li>
<li><strong>Dropout</strong>: 0.2 for LLM-Prop, 0.5 for MatBERT</li>
<li><strong>Batch size</strong>: 64 (888 tokens) or 16 (2000 tokens) for LLM-Prop</li>
<li><strong>Epochs</strong>: 200-300 depending on task</li>
<li><strong>Loss</strong>: MAE for regression, BCE for classification</li>
<li><strong>Evaluation</strong>: MAE for regression, AUC for classification</li>
<li><strong>Runs</strong>: each model is run 5 times on the test set, and the averaged MAE is reported</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Base model</strong>: T5-small encoder (~60M parameters total, ~37M after discarding decoder and adding prediction head)</li>
<li><strong>Vocabulary size</strong>: 32k (retrained tokenizer)</li>
<li><strong>Max input tokens</strong>: 888 (default) or 2000</li>
<li><strong>Special tokens</strong>: [CLS], [NUM], [ANG]</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/vertaix/LLM-Prop">LLM-Prop</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://drive.google.com/drive/folders/1YCDBzwjwNRIc1FRkB662G3Y5AOWaokUG">TextEdge + Checkpoints</a></td>
          <td>Dataset + Model</td>
          <td>Not specified</td>
          <td>Benchmark dataset and trained model checkpoints</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: NVIDIA RTX A6000</li>
<li><strong>Training time</strong>: ~40 minutes per epoch for LLM-Prop</li>
<li><strong>Inference</strong>: ~1 minute for 10,000 materials on one GPU</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rubungo, A. N., Arnold, C. B., Rand, B. P., &amp; Dieng, A. B. (2025). LLM-Prop: predicting the properties of crystalline materials using large language models. <em>npj Computational Materials</em>, 11, 186. <a href="https://doi.org/10.1038/s41524-025-01536-2">https://doi.org/10.1038/s41524-025-01536-2</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rubungo2025llmprop,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{LLM-Prop: predicting the properties of crystalline materials using large language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Rubungo, Andre Niyongabo and Arnold, Craig B. and Rand, Barry P. and Dieng, Adji Bousso}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{npj Computational Materials}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{186}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41524-025-01536-2}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Link-INVENT: RL-Driven Molecular Linker Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/link-invent-generative-linker-design/</guid><description>Link-INVENT extends REINVENT for molecular linker design using RNN-based generation and reinforcement learning with flexible multi-parameter scoring.</description><content:encoded><![CDATA[<h2 id="a-method-for-generative-linker-design-with-reinforcement-learning">A Method for Generative Linker Design with Reinforcement Learning</h2>
<p>Link-INVENT is a <strong>Method</strong> paper that introduces a generative model for molecular linker design built on the <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> de novo design platform. The primary contribution is an encoder-decoder recurrent neural network (RNN) architecture that generates SMILES-based linkers connecting two molecular subunits, combined with a flexible multi-parameter optimization (MPO) scoring function and reinforcement learning (RL) to steer generation toward desired properties. Link-INVENT targets three practical drug discovery tasks: fragment linking, scaffold hopping, and <a href="https://en.wikipedia.org/wiki/Proteolysis_targeting_chimera">proteolysis targeting chimera</a> (PROTAC) design.</p>
<h2 id="why-linker-design-needs-flexible-multi-parameter-optimization">Why Linker Design Needs Flexible Multi-Parameter Optimization</h2>
<p>Generating suitable chemical linkers between molecular subunits is a central challenge in <a href="https://en.wikipedia.org/wiki/Fragment-based_lead_discovery">fragment-based drug discovery</a> (FBDD), scaffold hopping, and PROTAC design. Traditional computational approaches rely on database searches, which restricts proposed linkers to a pre-defined collection. Recent deep learning methods (DeLinker, SyntaLinker, 3DLinker, DiffLinker) can generate novel linkers but offer limited support for optimizing specific physicochemical properties. Users can typically control only linker length and a few properties such as hydrogen-bond donor count.</p>
<p>The key gaps that Link-INVENT addresses are:</p>
<ol>
<li><strong>Conditioning on both subunits</strong>: Prior RNN-based approaches (SAMOA) generate linkers conditioned only on the SMILES sequence seen so far, which may not account for the second molecular subunit. Link-INVENT conditions on both warheads simultaneously.</li>
<li><strong>Flexible scoring</strong>: Existing DL-based linker design tools lack the ability to define tailored MPO objectives. Link-INVENT inherits the full scoring infrastructure of the REINVENT platform (later consolidated into <a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent4-generative-molecule-design/">REINVENT 4</a>) and adds linker-specific properties.</li>
<li><strong>Generalizability</strong>: A single trained prior handles fragment linking, scaffold hopping, and PROTAC tasks without retraining.</li>
</ol>
<h2 id="core-innovation-conditional-linker-generation-with-augmented-likelihood-rl">Core Innovation: Conditional Linker Generation with Augmented Likelihood RL</h2>
<p>Link-INVENT&rsquo;s architecture is an encoder-decoder RNN adapted from the Lib-INVENT library design model. The encoder processes a pair of warheads (molecular subunits with defined exit vectors), and the decoder generates a linker token by token, yielding a connected molecule in SMILES format. The model uses three hidden layers of 512 LSTM cells with an embedding size of 256.</p>
<h3 id="training">Training</h3>
<p>The prior is trained on ChEMBL v27 data processed through reaction-based slicing to generate (linker, warheads pair, full molecule) tuples. <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">SMILES randomization</a> augments the training data at each epoch, improving chemical space generalizability. The prior is trained by maximizing the likelihood of generating a linker conditioned on the input warhead pair, with teacher forcing for stability.</p>
<h3 id="multi-parameter-optimization-via-rl">Multi-Parameter Optimization via RL</h3>
<p>The scoring function $S(x)$ is a weighted geometric mean of individual component scores:</p>
<p>$$
S(x) = \left(\prod_{i=1}^{n} C_{i}(x)^{w_{i}}\right)^{\frac{1}{\sum_{i=1}^{n} w_{i}}}
$$</p>
<p>where $x$ is a sampled linked molecule, $C_{i}(x)$ is the score for the $i$-th component, and $w_{i}$ is its weight.</p>
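<p>The aggregation above can be sketched in a few lines. This is a minimal illustration rather than the REINVENT implementation: the function name <code>weighted_geometric_mean</code> is hypothetical, and component scores are assumed to lie in $[0, 1]$ with strictly positive weights.</p>

```python
import math


def weighted_geometric_mean(scores, weights):
    """Aggregate component scores C_i(x) with weights w_i into S(x).

    Assumption: scores lie in [0, 1] and weights are strictly positive.
    """
    if len(scores) != len(weights) or not scores:
        raise ValueError("scores and weights must be non-empty and equally long")
    if any(w <= 0 for w in weights):
        raise ValueError("weights must be positive")
    # A single zero component drives the product, and hence S(x), to zero.
    if any(s == 0.0 for s in scores):
        return 0.0
    # Work in log space: S = exp(sum(w_i * log C_i) / sum(w_i)).
    log_sum = sum(w * math.log(s) for s, w in zip(scores, weights))
    return math.exp(log_sum / sum(weights))
```

<p>Because the mean is geometric, a molecule that fails any one component outright scores zero overall, which is why this form suits objectives where every constraint must hold at once.</p>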
<p>The agent (initialized as a copy of the prior) is updated via the Difference of Augmented and Posterior likelihoods (DAP) loss. The <a href="/notes/chemistry/molecular-design/generation/rl-tuned/augmented-hill-climb-rl-molecule-generation/">augmented log likelihood</a> is:</p>
<p>$$
\log \pi_{\text{augmented}} = \log \pi_{\text{prior}} + \sigma \cdot S(x)
$$</p>
<p>where $\pi$ denotes a policy (token sampling probabilities conditioned on the sequence so far) and $\sigma$ is a scalar factor. The loss function is:</p>
<p>$$
J(\theta) = \left(\log \pi_{\text{augmented}} - \log \pi_{\text{agent}}\right)^{2}
$$</p>
<p>Minimizing $J(\theta)$ steers the agent to generate molecules that satisfy the scoring function while remaining anchored to the prior&rsquo;s chemical space.</p>
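<p>A minimal sketch of the loss for a single sampled molecule, assuming the sequence log likelihoods under the prior and the agent have already been computed. The function names are hypothetical and the value of $\sigma$ used in the example is arbitrary; in practice the loss is averaged over a batch and differentiated with respect to the agent parameters.</p>

```python
def dap_loss(prior_logp, agent_logp, score, sigma):
    """Squared gap between the augmented and the agent log likelihoods."""
    # log pi_augmented = log pi_prior + sigma * S(x)
    augmented_logp = prior_logp + sigma * score
    return (augmented_logp - agent_logp) ** 2


def batch_dap_loss(prior_logps, agent_logps, scores, sigma):
    """Mean DAP loss over a batch of sampled molecules."""
    losses = [
        dap_loss(p, a, s, sigma)
        for p, a, s in zip(prior_logps, agent_logps, scores)
    ]
    return sum(losses) / len(losses)


# Illustrative numbers only: sigma = 100 is an arbitrary choice.
example = batch_dap_loss([-60.0, -60.0], [-10.0, -12.0], [0.5, 0.5], sigma=100.0)
```

<p>The loss is zero exactly when the agent assigns the augmented log likelihood to the sample, so high-scoring molecules are pushed toward higher agent likelihood than the prior would give them.</p>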
<h3 id="diversity-filters">Diversity Filters</h3>
<p>Link-INVENT uses Diversity Filters (DFs) to balance exploration and exploitation. Buckets of limited size track unique <a href="/notes/chemistry/molecular-design/generation/rl-tuned/memory-assisted-rl-diverse-molecular-design/">Bemis-Murcko scaffolds</a>. When a bucket is full, further sampling of that scaffold receives a score of zero, encouraging the agent to explore diverse chemical space regions.</p>
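<p>The bucket mechanism can be sketched as follows. This is an illustrative simplification, not the REINVENT code: the class name is hypothetical, scaffolds are passed in as precomputed strings (deriving Bemis-Murcko scaffolds requires a cheminformatics toolkit), and the <code>min_score</code> threshold for storing a molecule is an assumed parameter.</p>

```python
from collections import Counter


class ScaffoldDiversityFilter:
    """Zero out scores for scaffolds whose bucket is already full."""

    def __init__(self, bucket_size=25, min_score=0.0):
        self.bucket_size = bucket_size
        self.min_score = min_score
        self._counts = Counter()  # scaffold string -> molecules stored so far

    def apply(self, scaffold, score):
        # A full bucket means this scaffold has been explored enough.
        if self._counts[scaffold] >= self.bucket_size:
            return 0.0
        # Only molecules at or above the threshold occupy a bucket slot.
        if score >= self.min_score:
            self._counts[scaffold] += 1
        return score
```

<p>For example, with <code>bucket_size=2</code> the third molecule sharing a scaffold is scored zero, while a molecule with a new scaffold keeps its score.</p>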
<h3 id="linker-specific-scoring-components">Linker-Specific Scoring Components</h3>
<p>New scoring components provide direct control over linker properties:</p>
<ul>
<li><strong>Linker effective length</strong>: number of bonds between attachment atoms</li>
<li><strong>Linker maximum graph length</strong>: bonds in the longest graph traversal path</li>
<li><strong>Linker length ratio</strong>: effective length divided by maximum graph length (controls branching)</li>
<li><strong>Linker ratio of rotatable bonds</strong>: rotatable bonds over total bonds (controls flexibility)</li>
<li><strong>Linker number of rings</strong>: controls linearity vs. cyclicity</li>
<li><strong>Linker number of HBDs</strong>: hydrogen-bond donors in the linker itself</li>
</ul>
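<p>The three length-based components can be illustrated on a plain adjacency-list graph of the linker. The sketch below is a toolkit-free approximation under stated assumptions: the function names are hypothetical, the &ldquo;longest graph traversal path&rdquo; is interpreted as the largest shortest-path distance between any two linker atoms, and the ratio is reported on a 0-100 scale.</p>

```python
from collections import deque


def bfs_distances(adjacency, start):
    """Bond-count distances from `start` to every reachable atom."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency[node]:
            if neighbor not in dist:
                dist[neighbor] = dist[node] + 1
                queue.append(neighbor)
    return dist


def linker_length_metrics(adjacency, attach_a, attach_b):
    """Return (effective length, maximum graph length, length ratio in %)."""
    # Effective length: bonds on the shortest path between attachment atoms.
    effective = bfs_distances(adjacency, attach_a)[attach_b]
    # Maximum graph length: largest shortest-path distance in the linker.
    max_graph = max(
        max(bfs_distances(adjacency, atom).values()) for atom in adjacency
    )
    if max_graph == 0:
        raise ValueError("linker must contain at least one bond")
    return effective, max_graph, 100.0 * effective / max_graph


# Chain 0-1-2 with a two-atom branch (3, 4) hanging off atom 1.
branched = {0: [1], 1: [0, 2, 3], 2: [1], 3: [1, 4], 4: [3]}
metrics = linker_length_metrics(branched, 0, 2)
```

<p>An unbranched chain gives a ratio of 100, while the branched example above gives a lower ratio because the side chain lengthens the longest path without lengthening the route between the attachment atoms.</p>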
<h2 id="experimental-evaluation-across-three-drug-discovery-tasks">Experimental Evaluation Across Three Drug Discovery Tasks</h2>
<p>Link-INVENT was evaluated through four experiments across three drug discovery applications, all using the same pre-trained prior.</p>
<h3 id="illustrative-example-two-benzene-rings">Illustrative Example: Two Benzene Rings</h3>
<p>A simple experiment linked two benzene rings with the objectives of limiting HBDs and requiring exactly one ring in the linker. Over 20 epochs, the agent learned to satisfy both objectives, demonstrating the basic RL-guided generation process.</p>
<h3 id="experiment-1a-fragment-linking-ck2-alpha-inhibitors">Experiment 1a: Fragment Linking (CK2 alpha Inhibitors)</h3>
<p>Based on the <a href="https://en.wikipedia.org/wiki/Casein_kinase_2">casein kinase 2</a> (CK2 alpha) fragment linking campaign by Fusco and Brear et al., Link-INVENT was tasked with linking two fragment hits while retaining the Lys68 hydrogen-bond interaction via a DockStream docking constraint (Glide/LigPrep backend). The scoring function also enforced linker length ratio &gt;= 70 and linker MW &lt;= 200 Da.</p>
<p>Over 100 epochs in triplicate, the agent generated molecules with gradually improving docking scores. Key results:</p>
<ul>
<li>Docking score distributions across triplicates were nearly identical, demonstrating reproducibility</li>
<li>Some generated molecules achieved more favorable docking scores than the reference ligand CAM4066 (-15.20 kcal/mol)</li>
<li>More than 5000 unique Bemis-Murcko scaffolds were generated, with minimal overlap across replicates</li>
<li>Binding pose analysis showed the generated linker closely resembled the ground-truth linker, retaining the Lys68 interaction</li>
</ul>
<h3 id="experiment-1b-comparison-fragment-linking-impdh-inhibitors">Experiment 1b: Comparison Fragment Linking (IMPDH Inhibitors)</h3>
<p>Using the IMPDH inhibitor fragment linking case study from Trapero et al., this experiment applied core constrained docking (fragment pose within 0.3 Å of reference) and compared results to DeLinker and SyntaLinker. The scoring function enforced linker effective length in [3, 5], length ratio &gt;= 70, and linker MW &lt;= 150 Da.</p>
<p>Link-INVENT generated 8960 SMILES across 70 epochs (comparable to DeLinker&rsquo;s 9000 molecular graphs). Results:</p>
<ul>
<li>Link-INVENT generated molecules with more favorable docking scores than the reference ligand across triplicate runs</li>
<li>None of the 20 DeLinker example molecules and only one of the 3 SyntaLinker example molecules (the recovered reference itself) docked as well as or better than the reference</li>
<li>Approximately 3000 unique Bemis-Murcko scaffolds were generated from 5000 total molecules</li>
<li>Link-INVENT&rsquo;s advantage comes from including docking explicitly as a learning objective rather than applying it post hoc</li>
</ul>
<h3 id="experiment-2-scaffold-hopping-dlk-inhibitor-cns-optimization">Experiment 2: Scaffold Hopping (DLK Inhibitor CNS Optimization)</h3>
<p>Based on Patel et al.&rsquo;s <a href="https://en.wikipedia.org/wiki/MAP3K12">dual leucine zipper kinase</a> (DLK) inhibitor campaign, Link-INVENT generated new scaffold ideas to improve CNS penetration while retaining potency. The scoring function included a Cys193 docking constraint plus CNS-compatible properties (HBDs &lt; 2, tPSA &lt;= 90 Å<sup>2</sup>, 3 &lt;= SlogP &lt;= 4, MW &lt;= 450 Da, 1-2 aromatic rings in linker).</p>
<p>The solution space was significantly narrower than fragment linking. The agent still generated diverse scaffolds with favorable docking scores, though fewer exceeded the reference ligand&rsquo;s score. Binding pose analysis confirmed retained Cys193 interactions and predicted additional Gln195 hydrogen bonds.</p>
<h3 id="experiment-3-protac-design-bcl-2mcl-1-dual-degradation">Experiment 3: PROTAC Design (Bcl-2/Mcl-1 Dual Degradation)</h3>
<p>Three sub-experiments demonstrated linker-specific scoring components for PROTAC design based on Wang et al.&rsquo;s Bcl-2/Mcl-1 dual degradation strategy:</p>
<table>
  <thead>
      <tr>
          <th>Sub-Experiment</th>
          <th>Objective</th>
          <th>Key Finding</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Sub-Exp 1: Linker length</td>
          <td>Generate linkers within specified length intervals [4,6], [7,9], [10,12], [13,15]</td>
          <td>Clear enrichment within target intervals vs. baseline broad distribution</td>
      </tr>
      <tr>
          <td>Sub-Exp 2: Linearity</td>
          <td>Control linear vs. cyclic linkers at fixed length [7,9]</td>
          <td>Baseline ratio ~1:2 linear:cyclic; enforcing linearity or cyclicity achieved strong enrichment</td>
      </tr>
      <tr>
          <td>Sub-Exp 3: Flexibility</td>
          <td>Generate linkers with Low [0,30], Moderate [40,60], or High [70,100] rotatable bond ratios</td>
          <td>Agent learned that rings and sp2 atoms yield rigidity; linear sp3 chains yield flexibility</td>
      </tr>
  </tbody>
</table>
<h2 id="key-findings-and-practical-implications-for-drug-discovery">Key Findings and Practical Implications for Drug Discovery</h2>
<p>Link-INVENT demonstrates several practical advantages for molecular linker design:</p>
<ol>
<li><strong>Single prior, multiple tasks</strong>: The same pre-trained model handles fragment linking, scaffold hopping, and PROTAC design without retraining.</li>
<li><strong>Docking as a learning signal</strong>: Including molecular docking explicitly in the scoring function (via DockStream) during RL yields molecules with more favorable docking scores than approaches that apply docking post hoc.</li>
<li><strong>Implicit 3D awareness</strong>: The docking constraint guides the agent toward 3D structural awareness without explicit 3D coordinate inputs, as demonstrated by the overlap between generated and reference binding poses.</li>
<li><strong>Diverse and reproducible output</strong>: Diversity filters ensure exploration of multiple chemical space regions, and triplicate experiments show consistent docking score distributions with minimal scaffold overlap.</li>
</ol>
<p>Limitations acknowledged by the authors include:</p>
<ul>
<li>The linker flexibility metric (ratio of rotatable bonds) is agnostic to intra-molecular hydrogen bonds and does not account for all rigidity factors</li>
<li>Molecular docking is an approximation that can be exploited (e.g., excessive HBDs achieving favorable scores at the expense of permeability)</li>
<li>Experiments 1a, 1b, and 2 require a proprietary Schrödinger license for Glide/LigPrep docking</li>
<li>No direct experimental (wet-lab) validation was performed in this study</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Prior training</td>
          <td>ChEMBL v27 (reaction-sliced)</td>
          <td>Not specified</td>
          <td>Filtered for drug-like compounds, then reaction-based slicing with SMIRKS</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Held-out Bemis-Murcko scaffolds</td>
          <td>287 scaffolds</td>
          <td>Held out from training set</td>
      </tr>
      <tr>
          <td>SMILES augmentation</td>
          <td>Randomized SMILES per epoch</td>
          <td>Same tuples, different representations</td>
          <td>Improves generalizability</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-decoder RNN with 3 hidden layers of 512 LSTM cells, embedding size 256</li>
<li><strong>RL loss</strong>: DAP (Difference of Augmented and Posterior likelihoods)</li>
<li><strong>Batch size</strong>: 128 molecules per epoch</li>
<li><strong>Diversity filter</strong>: Bemis-Murcko scaffold buckets of size 25</li>
<li><strong>Score threshold</strong>: 0 (to store all molecules for analysis)</li>
<li><strong>Scoring function</strong>: Weighted geometric mean of component scores</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Single pre-trained prior used across all experiments</li>
<li>Agent initialized as copy of prior, updated via RL</li>
<li>Pre-trained prior available in the ReinventCommunity GitHub repository</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>Molecular docking via DockStream with Glide/LigPrep backend</li>
<li>Triplicate runs for all experiments</li>
<li>Metrics: docking scores, unique Bemis-Murcko scaffold counts, binding pose overlap</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not reported in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/MolecularAI/Reinvent">REINVENT (Link-INVENT code)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Main codebase for Link-INVENT</td>
      </tr>
      <tr>
          <td><a href="https://github.com/MolecularAI/ReinventCommunity">ReinventCommunity (data + tutorial)</a></td>
          <td>Code + Data</td>
          <td>MIT</td>
          <td>Training/validation data, reaction SMIRKS, pre-trained prior, Jupyter tutorial</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. Code, training data, and pre-trained prior are publicly available. However, reproducing the docking-based experiments (1a, 1b, and 2) requires a proprietary Schrödinger license for Glide and LigPrep. The PROTAC experiments (Experiment 3) that use only physicochemical scoring are fully reproducible with the open-source code.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Guo, J., Knuth, F., Margreitter, C., Janet, J. P., Papadopoulos, K., Engkvist, O., &amp; Patronov, A. (2023). Link-INVENT: generative linker design with reinforcement learning. <em>Digital Discovery</em>, 2, 392-408. <a href="https://doi.org/10.1039/D2DD00115B">https://doi.org/10.1039/D2DD00115B</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{guo2023link,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Link-INVENT: generative linker design with reinforcement learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Guo, Jeff and Knuth, Franziska and Margreitter, Christian and Janet, Jon Paul and Papadopoulos, Kostas and Engkvist, Ola and Patronov, Atanas}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{392--408}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D2DD00115B}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Generative AI Survey for De Novo Molecule and Protein Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/generative-ai-drug-design-survey/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/generative-ai-drug-design-survey/</guid><description>Comprehensive survey of generative AI for de novo drug design covering molecule and protein generation with VAEs, GANs, diffusion, and flow models.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-generative-ai-for-drug-design">A Systematization of Generative AI for Drug Design</h2>
<p>This is a <strong>Systematization</strong> paper that provides a broad survey of generative AI methods applied to de novo drug design. The survey organizes the field into two overarching themes: small molecule generation and protein generation. Within each theme, the authors identify subtasks, catalog datasets and benchmarks, describe model architectures, and compare the performance of leading methods using standardized metrics. The paper covers over 200 references and provides 12 comparative benchmark tables.</p>
<p>The primary contribution is a unified organizational framework that allows both micro-level comparisons within each subtask and macro-level observations across the two application domains. The authors highlight parallel developments in both fields, particularly the shift from sequence-based to structure-based approaches and the growing dominance of diffusion models.</p>
<h2 id="the-challenge-of-navigating-de-novo-drug-design">The Challenge of Navigating De Novo Drug Design</h2>
<p>The drug design process requires creating ligands that interact with specific biological targets. These range from small molecules (tens of atoms) to large proteins (monoclonal antibodies). Traditional discovery methods are expensive and slow, with preclinical trials costing hundreds of millions of dollars and taking 3-6 years. The chemical space of potential drug-like compounds is estimated at $10^{23}$ to $10^{60}$ molecules, making brute-force exploration infeasible.</p>
<p>AI-driven generative methods have gained traction in recent years, with AI-focused biotech companies initiating over 150 small-molecule drugs in the discovery phase and 15 in clinical trials. The use of AI in drug design has grown by almost 40% each year.</p>
<p>The rapid development of the field, combined with its inherent complexity, creates barriers for new researchers. Several prior surveys exist, but they focus on specific aspects: molecule generation, protein generation, antibody generation, or specific model architectures like diffusion models. This survey takes a broader approach, covering both molecule and protein generation under a single organizational framework.</p>
<h2 id="unified-taxonomy-two-themes-seven-subtasks">Unified Taxonomy: Two Themes, Seven Subtasks</h2>
<p>The survey&rsquo;s core organizational insight is structuring de novo drug design into two themes with distinct subtasks, while identifying common architectural patterns across them.</p>
<h3 id="generative-model-architectures">Generative Model Architectures</h3>
<p>The survey covers four main generative model families used across both molecule and protein generation:</p>
<p><strong><a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational Autoencoders (VAEs)</a></strong> encode inputs into a latent distribution and decode from sampled points. The encoder maps input $x$ to a distribution parameterized by mean $\mu_\phi(x)$ and variance $\sigma^2_\phi(x)$. Training minimizes reconstruction loss plus KL divergence:</p>
<p>$$\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \mathcal{L}_{\text{KL}}$$</p>
<p>where the KL loss is:</p>
<p>$$\mathcal{L}_{\text{KL}} = -\frac{1}{2} \sum_{k} \left(1 + \log(\sigma_k^{(i)2}) - \mu_k^{(i)2} - \sigma_k^{(i)2}\right)$$</p>
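<p>The KL term is simple enough to evaluate by hand. The sketch below is illustrative only: it assumes a diagonal Gaussian posterior compared against a standard normal prior, parameterizes the variance as $\log \sigma^2$ (a common but assumed choice), and uses hypothetical function names.</p>

```python
import math


def gaussian_kl(mu, log_var):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)
    )


def vae_loss(recon_loss, mu, log_var, beta=1.0):
    """Total loss: reconstruction term plus beta-weighted KL term."""
    return recon_loss + beta * gaussian_kl(mu, log_var)
```

<p>When the posterior equals the prior ($\mu = 0$, $\sigma^2 = 1$) the KL term vanishes, and shifting the mean of one unit-variance dimension to 1 adds exactly 0.5.</p>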
<p><strong><a href="/posts/what-is-a-gan/">Generative Adversarial Networks (GANs)</a></strong> use a generator-discriminator game. The generator $G$ creates instances from random noise $z$ sampled from a prior $p_z(z)$, while the discriminator $D$ distinguishes real from synthetic data:</p>
<p>$$\min_{G} \max_{D} \mathbb{E}_x[\log D(x; \theta_d)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z; \theta_g); \theta_d))]$$</p>
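<p>As a rough illustration, the value function can be estimated from a batch of discriminator outputs. The function name <code>gan_value</code> is hypothetical, and the inputs are assumed to be probabilities strictly between 0 and 1 that the discriminator assigns to real samples and to generated samples.</p>

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of E[log D(x)] + E[log(1 - D(G(z)))]."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term
```

<p>The discriminator tries to raise this value and the generator tries to lower it. If the discriminator cannot tell the two sources apart and outputs 0.5 everywhere, the estimate equals $-\log 4$.</p>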
<p><strong>Flow-Based Models</strong> generate data by applying an invertible function $f: z_0 \mapsto x$ to transform a simple latent distribution (Gaussian) to the target distribution. The log-likelihood is computed using the change-of-variable formula:</p>
<p>$$\log p(x) = \log p_0(z_0) - \log \left| \det \frac{\partial f}{\partial z_0} \right|, \quad z_0 = f^{-1}(x)$$</p>
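<p>A one-dimensional affine map $f(z_0) = a z_0 + b$ makes the change-of-variable computation concrete. The sketch below is illustrative and its function names are hypothetical. Expressed with the Jacobian of the forward map $f$, the log-determinant is subtracted; equivalently, the log-determinant of the inverse map $f^{-1}$ is added.</p>

```python
import math


def standard_normal_logpdf(z):
    """Log density of the base distribution p_0 = N(0, 1)."""
    return -0.5 * (z * z + math.log(2.0 * math.pi))


def affine_flow_logpdf(x, scale, shift):
    """Log density of x = f(z0) = scale * z0 + shift with z0 ~ N(0, 1)."""
    if scale == 0:
        raise ValueError("scale must be non-zero for f to be invertible")
    z0 = (x - shift) / scale           # invert the flow
    log_det = math.log(abs(scale))     # log |det df/dz0| for a 1D affine map
    return standard_normal_logpdf(z0) - log_det
```

<p>The result matches the closed-form log density of $\mathcal{N}(b, a^2)$, which is a convenient sanity check for the sign of the log-determinant term.</p>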
<p><strong>Diffusion Models</strong> gradually add Gaussian noise over $T$ steps in a forward process and learn to reverse the noising via a denoising neural network. The forward step is:</p>
<p>$$x_{t+1} = \sqrt{1 - \beta_t} x_t + \sqrt{\beta_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$</p>
<p>The training loss minimizes the difference between the true noise and the predicted noise:</p>
<p>$$L_t = \mathbb{E}_{t \sim [1,T], x_0, \epsilon_t} \left[ \left\| \epsilon_t - \epsilon_\theta(x_t, t) \right\|^2 \right]$$</p>
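<p>The forward step and the per-sample loss can be sketched with the standard library alone. This is a toy illustration with hypothetical function names: vectors are plain Python lists, the noise schedule value is supplied by the caller, and the predicted noise stands in for the output of a denoising network that is not modeled here.</p>

```python
import math
import random


def forward_step(x_t, beta_t, rng):
    """One noising step: sqrt(1 - beta_t) * x_t + sqrt(beta_t) * eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x_t]
    x_next = [
        math.sqrt(1.0 - beta_t) * x + math.sqrt(beta_t) * e
        for x, e in zip(x_t, eps)
    ]
    return x_next, eps


def noise_prediction_loss(eps_true, eps_pred):
    """Squared L2 distance between the true and the predicted noise."""
    return sum((a - b) ** 2 for a, b in zip(eps_true, eps_pred))


rng = random.Random(0)  # fixed seed so the sketch is repeatable
x_next, eps = forward_step([1.0, -2.0], 0.1, rng)
```

<p>In training, the expectation in the loss is approximated by averaging this squared error over sampled timesteps, data points, and noise draws.</p>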
<p>Graph neural networks (GNNs), particularly equivariant GNNs (EGNNs), are commonly paired with these generative methods to handle 2D/3D molecular and protein inputs. Diffusion and flow-based models are often paired with GNNs for processing 2D/3D-based input, while VAEs and GANs are typically used for 1D input.</p>
<h2 id="small-molecule-generation-tasks-datasets-and-models">Small Molecule Generation: Tasks, Datasets, and Models</h2>
<h3 id="target-agnostic-molecule-design">Target-Agnostic Molecule Design</h3>
<p>The goal is to generate a set of novel, valid, and stable molecules without conditioning on any specific biological target. Models are evaluated on atom stability, molecule stability, validity, uniqueness, novelty, and QED (Quantitative Estimate of Drug-Likeness).</p>
<p><strong>Datasets</strong>: <a href="/notes/chemistry/datasets/qm9/">QM9</a> (small stable molecules from <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a>) and <a href="/notes/chemistry/datasets/geom/">GEOM</a>-Drugs (more complex, drug-like molecules).</p>
<p>The field has shifted from SMILES-based VAEs (<a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">CVAE</a>, <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">GVAE</a>, SD-VAE) to 2D graph methods (JTVAE) and then to 3D diffusion-based models. Current leading methods on QM9:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Atom Stab. (%)</th>
          <th>Mol. Stab. (%)</th>
          <th>Valid (%)</th>
          <th>Valid and Unique (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MiDi</td>
          <td>EGNN, Diffusion</td>
          <td>99.8</td>
          <td>97.5</td>
          <td>97.9</td>
          <td>97.6</td>
      </tr>
      <tr>
          <td>MDM</td>
          <td>EGNN, VAE, Diffusion</td>
          <td>99.2</td>
          <td>89.6</td>
          <td>98.6</td>
          <td>94.6</td>
      </tr>
      <tr>
          <td>JODO</td>
          <td>EGNN, Diffusion</td>
          <td>99.2</td>
          <td>93.4</td>
          <td>99.0</td>
          <td>96.0</td>
      </tr>
      <tr>
          <td>GeoLDM</td>
          <td>VAE, Diffusion</td>
          <td>98.9</td>
          <td>89.4</td>
          <td>93.8</td>
          <td>92.7</td>
      </tr>
      <tr>
          <td>EDM</td>
          <td>EGNN, Diffusion</td>
          <td>98.7</td>
          <td>82.0</td>
          <td>91.9</td>
          <td>90.7</td>
      </tr>
  </tbody>
</table>
<p>EDM provided an initial baseline using diffusion with an equivariant GNN. GCDM introduced attention-based geometric message-passing. MDM handles covalent bond edges and van der Waals forces separately, and also addresses diversity through an additional distribution-controlling noise variable. GeoLDM maps molecules to a lower-dimensional latent space for more efficient diffusion. MiDi uses a &ldquo;relaxed&rdquo; EGNN and jointly models 2D and 3D information through a graph representation capturing both spatial and connectivity data.</p>
<p>On the larger GEOM-Drugs dataset, performance drops for most models:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Atom Stab. (%)</th>
          <th>Mol. Stab. (%)</th>
          <th>Valid (%)</th>
          <th>Valid and Unique (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MiDi</td>
          <td>99.8</td>
          <td>91.6</td>
          <td>77.8</td>
          <td>77.8</td>
      </tr>
      <tr>
          <td>MDM</td>
          <td>&ndash;</td>
          <td>62.2</td>
          <td>99.5</td>
          <td>99.0</td>
      </tr>
      <tr>
          <td>GeoLDM</td>
          <td>84.4</td>
          <td>&ndash;</td>
          <td>99.3</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>EDM</td>
          <td>81.3</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
  </tbody>
</table>
<p>MiDi stands out for generating more stable complex molecules (91.6% molecule stability), though at the expense of validity (77.8%, versus above 99% for MDM and GeoLDM). Models generally perform well on QM9 but show room for improvement on the more complex GEOM-Drugs molecules.</p>
<h3 id="target-aware-molecule-design">Target-Aware Molecule Design</h3>
<p>Target-aware generation produces molecules for specific protein targets, using either ligand-based (LBDD) or structure-based (SBDD) approaches. SBDD methods have become more prevalent as protein structure information becomes increasingly available.</p>
<p><strong>Datasets</strong>: CrossDocked2020 (22.5M ligand-protein pairs), ZINC20, Binding MOAD.</p>
<p><strong>Metrics</strong>: Vina Score (docking energy; lower is better), High Affinity Percentage, QED, SA Score (synthetic accessibility), Diversity (derived from pairwise Tanimoto similarity between generated molecules).</p>
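<p>As a rough illustration of the Diversity metric, the sketch below computes one minus the mean pairwise Tanimoto similarity over a set of molecules. It assumes fingerprints are already available as Python sets of &ldquo;on&rdquo; bit indices; the function names and the toy fingerprints are hypothetical, and individual papers may define diversity slightly differently.</p>

```python
from itertools import combinations


def tanimoto(a: set, b: set) -> float:
    """Tanimoto (Jaccard) similarity between two sets of fingerprint bits."""
    if not a and not b:
        return 1.0  # convention assumed here: two empty fingerprints are identical
    return len(a & b) / len(a | b)


def diversity(fingerprints: list) -> float:
    """One minus the mean pairwise Tanimoto similarity (assumed definition)."""
    pairs = list(combinations(fingerprints, 2))
    if not pairs:
        return 0.0
    mean_sim = sum(tanimoto(a, b) for a, b in pairs) / len(pairs)
    return 1.0 - mean_sim


# Toy fingerprints (hypothetical bit indices, not taken from real molecules)
fps = [{1, 2, 3}, {2, 3, 4}, {7, 8}]
print(round(diversity(fps), 3))
```

<p>Higher values mean the generated molecules share fewer substructure bits with one another on average.</p>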
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Vina</th>
          <th>Affinity (%)</th>
          <th>QED</th>
          <th>SA</th>
          <th>Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DiffSBDD</td>
          <td>EGNN, Diffusion</td>
          <td>-7.333</td>
          <td>&ndash;</td>
          <td>0.467</td>
          <td>0.554</td>
          <td>0.758</td>
      </tr>
      <tr>
          <td>Luo et al.</td>
          <td>SchNet</td>
          <td>-6.344</td>
          <td>29.09</td>
          <td>0.525</td>
          <td>0.657</td>
          <td>0.720</td>
      </tr>
      <tr>
          <td>TargetDiff</td>
          <td>EGNN, Diffusion</td>
          <td>-6.3</td>
          <td>58.1</td>
          <td>0.48</td>
          <td>0.58</td>
          <td>0.72</td>
      </tr>
      <tr>
          <td>LiGAN</td>
          <td>CNN, VAE</td>
          <td>-6.144</td>
          <td>21.1</td>
          <td>0.39</td>
          <td>0.59</td>
          <td>0.66</td>
      </tr>
      <tr>
          <td>Pocket2Mol</td>
          <td>EGNN, MLP</td>
          <td>-5.14</td>
          <td>48.4</td>
          <td>0.56</td>
          <td>0.74</td>
          <td>0.69</td>
      </tr>
  </tbody>
</table>
<p>DrugGPT is an LBDD autoregressive model using transformers on tokenized protein-ligand pairs. Among the SBDD models, LiGAN introduces a 3D CNN-VAE framework, Pocket2Mol emphasizes binding pocket geometry using an EGNN with geometric vector MLP layers, and Luo et al. model atomic probabilities in the binding site using SchNet. TargetDiff performs diffusion on an EGNN and optimizes binding affinity by reflecting low atom type entropy. DiffSBDD applies an inpainting approach by masking and replacing segments of ligand-protein complexes. DiffSBDD leads in Vina score and diversity, while TargetDiff leads in high affinity. Interestingly, diffusion-based methods are outperformed by Pocket2Mol on drug-likeness metrics (QED and SA).</p>
<h3 id="molecular-conformation-generation">Molecular Conformation Generation</h3>
<p>Conformation generation involves producing 3D structures from 2D connectivity graphs. Models are evaluated on Coverage (COV, percentage of ground-truth conformations &ldquo;covered&rdquo; within an RMSD threshold) and Matching (MAT, average RMSD to closest ground-truth conformation).</p>
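<p>The two metrics can be sketched from a precomputed RMSD matrix. This is a minimal illustration assuming that <code>rmsd[i][j]</code> holds the aligned RMSD between ground-truth conformer <code>i</code> and generated conformer <code>j</code>; the function name, the toy values, and the use of &ldquo;less than or equal&rdquo; at the threshold are assumptions rather than details taken from the survey.</p>

```python
def cov_mat(rmsd, threshold):
    """Coverage (%) and matching (mean RMSD) from a reference-by-generated RMSD matrix."""
    # For each ground-truth conformer, keep the RMSD to its closest generated conformer.
    best = [min(row) for row in rmsd]
    cov = 100.0 * sum(d <= threshold for d in best) / len(best)
    mat = sum(best) / len(best)
    return cov, mat


# Toy matrix: 3 ground-truth conformers, 2 generated conformers (made-up values)
rmsd = [
    [0.30, 1.10],
    [0.90, 0.40],
    [1.60, 1.40],
]
cov, mat = cov_mat(rmsd, threshold=1.25)
print(cov, mat)
```

<p>Because coverage depends directly on the chosen threshold, scores computed at different thresholds are not comparable.</p>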
<p><strong>Datasets</strong>: GEOM-QM9, GEOM-Drugs, ISO17.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>GEOM-QM9 COV (%)</th>
          <th>GEOM-QM9 MAT</th>
          <th>GEOM-Drugs COV (%)</th>
          <th>GEOM-Drugs MAT</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Torsional Diff.</td>
          <td>Diffusion</td>
          <td>92.8</td>
          <td>0.178</td>
          <td>72.7*</td>
          <td>0.582</td>
      </tr>
      <tr>
          <td>DGSM</td>
          <td>MPNN, Diffusion</td>
          <td>91.49</td>
          <td>0.2139</td>
          <td>78.73</td>
          <td>1.0154</td>
      </tr>
      <tr>
          <td>GeoDiff</td>
          <td>GFN, Diffusion</td>
          <td>90.07</td>
          <td>0.209</td>
          <td>89.13</td>
          <td>0.8629</td>
      </tr>
      <tr>
          <td>ConfGF</td>
          <td>GIN, Diffusion</td>
          <td>88.49</td>
          <td>0.2673</td>
          <td>62.15</td>
          <td>1.1629</td>
      </tr>
      <tr>
          <td>GeoMol</td>
          <td>MPNN</td>
          <td>71.26</td>
          <td>0.3731</td>
          <td>67.16</td>
          <td>1.0875</td>
      </tr>
  </tbody>
</table>
<p>*Torsional Diffusion uses a 0.75 &Aring; threshold instead of the standard 1.25 &Aring; for GEOM-Drugs coverage, leading to a deflated score. It outperforms GeoDiff and GeoMol when evaluated at the same threshold.</p>
<p>Torsional Diffusion operates in the space of torsion angles rather than Cartesian coordinates, allowing for improved representation and fewer denoising steps. GeoDiff uses Euclidean-space diffusion, treating each atom as a particle and incorporating Markov kernels that preserve E(3) equivariance through a graph field network (GFN) layer.</p>
<h2 id="protein-generation-from-sequence-to-structure">Protein Generation: From Sequence to Structure</h2>
<h3 id="protein-representation-learning">Protein Representation Learning</h3>
<p>Representation learning creates embeddings for protein inputs to support downstream tasks. Models are evaluated on contact prediction, fold classification (at family, superfamily, and fold levels), and stability prediction (Spearman&rsquo;s $\rho$).</p>
<p>Key models include: UniRep (mLSTM RNN), ProtBERT (BERT applied to amino acid sequences), ESM-1B (33-layer, 650M parameter transformer), MSA Transformer (pre-trained on MSA input), and GearNET (Geo-EGNN using 3D structure with directed edges). OntoProtein and KeAP incorporate knowledge graphs for direct knowledge injection.</p>
<h3 id="protein-structure-prediction">Protein Structure Prediction</h3>
<p>Given an amino acid sequence, models predict 3D point coordinates for each residue. Models are evaluated using RMSD, GDT-TS, TM-score, and lDDT on the CASP14 and CAMEO benchmarks.</p>
<p>AlphaFold2 is the landmark model, integrating MSA and pair representations through transformers with invariant point attention (IPA). ESMFold uses ESM-2 language model representations instead of MSAs, achieving faster processing. RoseTTAFold uses a three-track neural network learning from 1D sequence, 2D distance map, and 3D backbone coordinate information simultaneously. EigenFold uses diffusion, representing the protein as a system of harmonic oscillators.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>CAMEO RMSD</th>
          <th>CAMEO TMScore</th>
          <th>CAMEO GDT-TS</th>
          <th>CAMEO lDDT</th>
          <th>CASP14 TMScore</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AlphaFold2</td>
          <td>Transformer</td>
          <td>3.30</td>
          <td>0.87</td>
          <td>0.86</td>
          <td>0.90</td>
          <td>0.38</td>
      </tr>
      <tr>
          <td>ESMFold</td>
          <td>Transformer</td>
          <td>3.99</td>
          <td>0.85</td>
          <td>0.83</td>
          <td>0.87</td>
          <td>0.68</td>
      </tr>
      <tr>
          <td>RoseTTAFold</td>
          <td>Transformer</td>
          <td>5.72</td>
          <td>0.77</td>
          <td>0.71</td>
          <td>0.79</td>
          <td>0.37</td>
      </tr>
      <tr>
          <td>EigenFold</td>
          <td>Diffusion</td>
          <td>7.37</td>
          <td>0.75</td>
          <td>0.71</td>
          <td>0.78</td>
          <td>&ndash;</td>
      </tr>
  </tbody>
</table>
<h3 id="sequence-generation-inverse-folding">Sequence Generation (Inverse Folding)</h3>
<p>Given a fixed protein backbone structure, models generate amino acid sequences that will fold into that structure. The space of valid sequences is between $10^{65}$ and $10^{130}$.</p>
<p>Models are evaluated using Amino Acid Recovery (AAR), diversity, RMSD, nonpolar loss, and perplexity (PPL), the exponentiated average negative log-likelihood of the sequence:</p>
<p>$$\text{PPL} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(x_i \mid x_1, x_2, \ldots, x_{i-1})\right)$$</p>
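<p>A minimal sketch of two of these metrics, assuming the designed and native sequences are already aligned position by position and that the model exposes a per-residue probability for each true residue. The function names and toy inputs are hypothetical.</p>

```python
import math


def amino_acid_recovery(designed: str, native: str) -> float:
    """Percentage of positions where the designed residue matches the native one."""
    if len(designed) != len(native):
        raise ValueError("sequences must have equal length")
    matches = sum(d == n for d, n in zip(designed, native))
    return 100.0 * matches / len(native)


def perplexity(token_probs: list) -> float:
    """Exponential of the average negative log-probability per residue."""
    n = len(token_probs)
    return math.exp(-sum(math.log(p) for p in token_probs) / n)


# Toy inputs: 6 of the 8 positions match; every residue gets probability 0.25
aar = amino_acid_recovery("MKTAYIAK", "MKTVYIAR")
ppl = perplexity([0.25, 0.25, 0.25, 0.25])
print(aar, ppl)
```

<p>A model that assigns probability 0.25 to every residue is as uncertain as a uniform choice among four options, so its perplexity is about 4. Lower perplexity indicates a more confident model.</p>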
<p>ProteinMPNN is the current top performer, generating the most accurate sequences and leading in AAR and RMSD. Its nonpolar loss (1.061) is competitive, though ABACUS-R (0.968) and 3D CNN (1.027) report lower values. It uses a message-passing neural network with a flexible, order-agnostic autoregressive approach.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>AAR (%)</th>
          <th>Div.</th>
          <th>RMSD</th>
          <th>Non.</th>
          <th>Time (s)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ProteinMPNN</td>
          <td>MPNN</td>
          <td>48.7</td>
          <td>0.168</td>
          <td>1.019</td>
          <td>1.061</td>
          <td>112</td>
      </tr>
      <tr>
          <td>ESM-IF1</td>
          <td>Transformer</td>
          <td>47.7</td>
          <td>0.184</td>
          <td>1.265</td>
          <td>1.201</td>
          <td>1980</td>
      </tr>
      <tr>
          <td>GPD</td>
          <td>Transformer</td>
          <td>46.2</td>
          <td>0.219</td>
          <td>1.758</td>
          <td>1.333</td>
          <td>35</td>
      </tr>
      <tr>
          <td>ABACUS-R</td>
          <td>Transformer</td>
          <td>45.7</td>
          <td>0.124</td>
          <td>1.482</td>
          <td>0.968</td>
          <td>233280</td>
      </tr>
      <tr>
          <td>3D CNN</td>
          <td>CNN</td>
          <td>44.5</td>
          <td>0.272</td>
          <td>1.62</td>
          <td>1.027</td>
          <td>536544</td>
      </tr>
      <tr>
          <td>PiFold</td>
          <td>GNN</td>
          <td>42.8</td>
          <td>0.141</td>
          <td>1.592</td>
          <td>1.464</td>
          <td>221</td>
      </tr>
      <tr>
          <td>ProteinSolver</td>
          <td>GNN</td>
          <td>24.6</td>
          <td>0.186</td>
          <td>5.354</td>
          <td>1.389</td>
          <td>180</td>
      </tr>
  </tbody>
</table>
<p>Results are from the independent benchmark by Yu et al. GPD remains the fastest method, generating sequences around three times faster than ProteinMPNN. Current SOTA models recover fewer than half of target amino acid residues, indicating room for improvement.</p>
<h3 id="backbone-design">Backbone Design</h3>
<p>Backbone design creates protein structures from scratch, representing the core of de novo protein design. Models generate coordinates for the backbone atoms (nitrogen, alpha-carbon, carbonyl carbon, and oxygen) and use external tools like Rosetta for side-chain packing.</p>
<p>Two evaluation paradigms exist: context-free generation (evaluated by self-consistency TM, or scTM) and context-given generation (inpainting, evaluated by AAR, PPL, RMSD).</p>
<p>ProtDiff represents residues as 3D Cartesian coordinates and uses particle-filtering diffusion. FoldingDiff instead uses an angular representation (six angles per residue) with a BERT-based DDPM. LatentDiff embeds proteins into a latent space using an equivariant autoencoder, then applies equivariant diffusion, analogous to GeoLDM for molecules. These early models work well for short proteins (up to 128 residues) but struggle with longer structures.</p>
<p>Frame-based methods address this scaling limitation. Genie uses Frenet-Serret frames with paired residue representations and IPA for noise prediction. FrameDiff parameterizes backbone structures on the $SE(3)^N$ manifold of frames using a score-based generative model. RFDiffusion is the current leading model, combining RoseTTAFold structure prediction with diffusion. It fine-tunes RoseTTAFold weights on a masked input sequence and random noise coordinates, using &ldquo;self-conditioning&rdquo; on predicted structures. Protpardelle co-designs sequence and structure by creating a &ldquo;superposition&rdquo; over possible sidechain states and collapsing them during each iterative diffusion step.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>scTM (%)</th>
          <th>Design. (%)</th>
          <th>PPL</th>
          <th>AAR (%)</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RFDiffusion</td>
          <td>Diffusion</td>
          <td>&ndash;</td>
          <td>95.1</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>Protpardelle</td>
          <td>Diffusion</td>
          <td>85</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>FrameDiff</td>
          <td>Diffusion</td>
          <td>84</td>
          <td>48.3</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>Genie</td>
          <td>Diffusion</td>
          <td>81.5</td>
          <td>79.0</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>LatentDiff</td>
          <td>EGNN, Diffusion</td>
          <td>31.6</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>FoldingDiff</td>
          <td>Diffusion</td>
          <td>14.2</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>ProtDiff</td>
          <td>EGNN, Diffusion</td>
          <td>11.8</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>12.47*</td>
          <td>8.01*</td>
      </tr>
  </tbody>
</table>
<p>*ProtDiff context-given results are tested only on beta-lactamase metalloproteins from PDB.</p>
<h3 id="antibody-design">Antibody Design</h3>
<p>The survey covers antibody structure prediction, representation learning, and CDR-H3 generation. Antibodies are Y-shaped proteins with complementarity-determining regions (CDRs), where CDR-H3 is the most variable and functionally important region.</p>
<p>For CDR-H3 generation, models have progressed from sequence-based (LSTM) to structure-based (RefineGNN) and sequence-structure co-design approaches (MEAN, AntiDesigner, DiffAb). dyMEAN is the current leading model, providing an end-to-end method incorporating structure prediction, docking, and CDR generation into a single framework. MSAs cannot be used for antibody input, which makes general models like AlphaFold2 inefficient for antibody prediction. Specialized models like IgFold use sequence embeddings from AntiBERTy with invariant point attention to achieve faster antibody structure prediction.</p>
<h3 id="peptide-design">Peptide Design</h3>
<p>The survey briefly covers peptide generation, including models for therapeutic peptide generation (MMCD), peptide-protein interaction prediction (PepGB), peptide representation learning (PepHarmony), peptide sequencing (AdaNovo), and signal peptide prediction (PEFT-SP).</p>
<h2 id="current-trends-challenges-and-future-directions">Current Trends, Challenges, and Future Directions</h2>
<h3 id="current-trends">Current Trends</h3>
<p>The survey identifies several parallel trends across molecule and protein generation:</p>
<ol>
<li>
<p><strong>Shift from sequence to structure</strong>: In molecule generation, graph-based diffusion models (GeoLDM, MiDi, TargetDiff) now dominate. In protein generation, structure-based representation learning (GearNET) and diffusion-based backbone design (RFDiffusion) have overtaken sequence-only methods.</p>
</li>
<li>
<p><strong>Dominance of E(3) equivariant architectures</strong>: EGNNs appear across nearly all subtasks, reflecting the physical requirement that molecular and protein properties should be invariant to rotation and translation.</p>
</li>
<li>
<p><strong>Structure-based over ligand-based approaches</strong>: In target-aware molecule design, SBDD methods that use 3D protein structures demonstrate clear advantages over LBDD approaches that operate on amino acid sequences alone.</p>
</li>
</ol>
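<p>The invariance requirement behind the second trend can be checked numerically. The sketch below is a toy illustration, not code from any of the surveyed models: it applies a rotation about the z-axis plus a translation to a made-up three-atom geometry and confirms that the interatomic distances are unchanged. Distances are the kind of geometric feature that equivariant layers can use without depending on the choice of coordinate frame.</p>

```python
import math


def pairwise_distances(coords):
    """All interatomic distances for a list of (x, y, z) tuples."""
    return [
        math.dist(a, b)
        for i, a in enumerate(coords)
        for b in coords[i + 1:]
    ]


def rotate_z_and_translate(coords, angle, shift):
    """Rotate every point about the z-axis, then translate it by `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [
        (c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])
        for x, y, z in coords
    ]


# Toy three-atom geometry (arbitrary coordinates)
atoms = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
moved = rotate_z_and_translate(atoms, angle=0.7, shift=(1.0, -2.0, 3.0))

before = pairwise_distances(atoms)
after = pairwise_distances(moved)
unchanged = all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(before, after))
print(unchanged)
```

<p>The coordinates themselves change under the transformation, while scalar quantities built from distances do not.</p>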
<h3 id="challenges">Challenges</h3>
<p><strong>For small molecule generation:</strong></p>
<ul>
<li><strong>Complexity</strong>: Models perform well on simple QM9 but struggle with complex GEOM-Drugs molecules.</li>
<li><strong>Applicability</strong>: Generating molecules with high binding affinity to targets remains difficult.</li>
<li><strong>Explainability</strong>: Methods are black-box, offering no insight into why generated molecules have desired properties.</li>
</ul>
<p><strong>For protein generation:</strong></p>
<ul>
<li><strong>Benchmarking</strong>: Protein generative tasks lack a standard evaluative procedure, with variance between each model&rsquo;s metrics and testing conditions.</li>
<li><strong>Performance</strong>: SOTA models still struggle with fold classification, gene ontology, and antibody CDR-H3 generation.</li>
</ul>
<p>The authors also note that many generative tasks are evaluated using predictive models (e.g., classifier networks for binding affinity or molecular properties). Improvements to these classification methods would lead to more precise alignment with real-world biological applications.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors identify three key future directions: improving performance on existing tasks, defining more applicable tasks (especially in molecule-protein binding and antibody generation), and exploring entirely new areas of research.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>As a survey paper, this work does not produce new models, datasets, or experimental results. All benchmark numbers reported are from the original papers cited.</p>
<h3 id="data">Data</h3>
<p>The survey catalogs the following key datasets across subtasks:</p>
<table>
  <thead>
      <tr>
          <th>Subtask</th>
          <th>Datasets</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Target-agnostic molecule</td>
          <td>QM9, <a href="/notes/chemistry/datasets/geom/">GEOM</a>-Drugs</td>
          <td>QM9 from <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a>; GEOM-Drugs for complex molecules</td>
      </tr>
      <tr>
          <td>Target-aware molecule</td>
          <td>CrossDocked2020, ZINC20, Binding MOAD</td>
          <td>CrossDocked2020 most used (22.5M pairs)</td>
      </tr>
      <tr>
          <td>Conformation generation</td>
          <td><a href="/notes/chemistry/datasets/geom/">GEOM</a>-QM9, GEOM-Drugs, ISO17</td>
          <td>Conformer sets for molecules</td>
      </tr>
      <tr>
          <td>Protein structure prediction</td>
          <td>PDB, CASP14, CAMEO</td>
          <td>CASP biennial blind evaluation</td>
      </tr>
      <tr>
          <td>Protein sequence generation</td>
          <td>PDB, UniRef, UniParc, CATH, TS500</td>
          <td>CATH for domain classification</td>
      </tr>
      <tr>
          <td>Backbone design</td>
          <td>PDB, AlphaFoldDB, SCOP, CATH</td>
          <td>AlphaFoldDB for expanded structural coverage</td>
      </tr>
      <tr>
          <td>Antibody structure</td>
          <td>SAbDab, RAB</td>
          <td>SAbDab: all antibody structures from PDB</td>
      </tr>
      <tr>
          <td>Antibody CDR generation</td>
          <td>SAbDab, RAB, SKEMPI</td>
          <td>SKEMPI for affinity optimization</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gersteinlab/GenAI4Drug">GenAI4Drug</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Organized repository of all covered sources</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Tang, X., Dai, H., Knight, E., Wu, F., Li, Y., Li, T., &amp; Gerstein, M. (2024). A survey of generative AI for de novo drug design: New frontiers in molecule and protein generation. <em>Briefings in Bioinformatics</em>, 25(4), bbae338. <a href="https://doi.org/10.1093/bib/bbae338">https://doi.org/10.1093/bib/bbae338</a></p>
<p><strong>Publication</strong>: Briefings in Bioinformatics, Volume 25, Issue 4, 2024.</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2402.08703">arXiv: 2402.08703</a></li>
<li><a href="https://github.com/gersteinlab/GenAI4Drug">GitHub: GenAI4Drug</a></li>
<li><a href="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11247410/">PMC: PMC11247410</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{tang2024survey,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A survey of generative AI for de novo drug design: new frontiers in molecule and protein generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Tang, Xiangru and Dai, Howard and Knight, Elizabeth and Wu, Fang and Li, Yunyang and Li, Tianxiao and Gerstein, Mark}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{25}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{bbae338}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1093/bib/bbae338}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Foundation Models in Chemistry: A 2025 Perspective</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/foundation-models-chemistry-perspective/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/foundation-models-chemistry-perspective/</guid><description>Perspective reviewing foundation models for chemistry across property prediction, MLIPs, inverse design, and multi-domain applications.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-foundation-models-for-chemistry">A Systematization of Foundation Models for Chemistry</h2>
<p>This is a <strong>Systematization</strong> paper. It organizes the rapidly growing landscape of foundation models in chemistry into a coherent taxonomy. The paper distinguishes between &ldquo;small&rdquo; foundation models (pretrained for a single application domain) and &ldquo;big&rdquo; foundation models (adaptable across multiple domains such as property prediction and inverse design). It covers models based on graph neural networks (GNNs) and language models, reviews pretraining strategies (self-supervised, multimodal, supervised), and maps approximately 40 models across four application domains.</p>
<h2 id="why-a-foundation-model-perspective-for-chemistry">Why a Foundation Model Perspective for Chemistry?</h2>
<p>Foundation models have transformed NLP and computer vision through large-scale pretraining and transfer learning. In chemistry, however, several persistent challenges motivate the adoption of this paradigm:</p>
<ol>
<li><strong>Data scarcity</strong>: Chemical datasets are often small and expensive to generate (requiring experiments or quantum mechanical calculations), unlike the large annotated datasets available in NLP/CV.</li>
<li><strong>Poor generalization</strong>: ML models in chemistry frequently need to extrapolate to out-of-domain compounds (e.g., novel drug candidates, unseen crystal structures), where conventional models struggle.</li>
<li><strong>Limited transferability</strong>: Traditional ML interatomic potentials (MLIPs) are trained on system-specific datasets and cannot be easily transferred across different chemical systems.</li>
</ol>
<p>Foundation models address these by learning general representations from large unlabeled datasets, which can then be adapted to specific downstream tasks via finetuning. The paper argues that summarizing this fast-moving field is timely, given the diversity of approaches emerging across molecular property prediction, MLIPs, inverse design, and multi-domain applications.</p>
<h2 id="small-vs-big-foundation-models-a-two-tier-taxonomy">Small vs. Big Foundation Models: A Two-Tier Taxonomy</h2>
<p>The paper&rsquo;s central organizing framework distinguishes two scopes of foundation model:</p>
<p><strong>Small foundation models</strong> are pretrained models adapted to various tasks within a single application domain. Examples include:</p>
<ul>
<li>A model pretrained on large molecular databases that predicts multiple molecular properties (band gap, formation energy, etc.)</li>
<li>A universal MLIP that can simulate diverse chemical systems</li>
<li>A pretrained generative model adapted for inverse design of different target properties</li>
</ul>
<p><strong>Big foundation models</strong> span multiple application domains, handling both property prediction and inverse design within a single framework. These typically use multimodal learning (combining SMILES/graphs with text) or build on large language models.</p>
<h3 id="architectures">Architectures</h3>
<p>The paper reviews two primary architecture families:</p>
<p><strong>Graph Neural Networks (GNNs)</strong> represent molecules and crystals as graphs $G = (V, E)$ with nodes (atoms) and edges (bonds). Node features are updated through message passing:</p>
<p>$$
m_{i}^{t+1} = \sum_{j \in N(i)} M_{t}(v_{i}^{t}, v_{j}^{t}, e_{ij}^{t})
$$</p>
<p>$$
v_{i}^{t+1} = U_{t}(v_{i}^{t}, m_{i}^{t+1})
$$</p>
<p>After $T$ message-passing steps, a readout function produces a graph-level feature:</p>
<p>$$
g = R(\{v_{i}^{T} \mid i \in G\})
$$</p>
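<p>The message, update, and readout equations above can be sketched in a few lines. This is a toy illustration with scalar node features; the specific message function <code>M</code>, update function <code>U</code>, and sum readout are hypothetical stand-ins for the learned functions a real GNN would use.</p>

```python
def message_passing_step(node_feats, edges, edge_feats, message_fn, update_fn):
    """One step: m_i = sum_j M(v_i, v_j, e_ij), then v_i <- U(v_i, m_i)."""
    messages = {i: 0.0 for i in node_feats}
    for (i, j) in edges:  # directed edge: neighbor j sends a message to node i
        messages[i] += message_fn(node_feats[i], node_feats[j], edge_feats[(i, j)])
    return {i: update_fn(v, messages[i]) for i, v in node_feats.items()}


def readout(node_feats):
    """Sum readout R over the final node features."""
    return sum(node_feats.values())


# Toy graph: a three-atom chain 0-1-2 with scalar features (illustrative values)
nodes = {0: 1.0, 1: 2.0, 2: 3.0}
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
bond = {e: 1.0 for e in edges}


def M(v_i, v_j, e_ij):
    """Hypothetical message function: neighbor feature scaled by the edge feature."""
    return e_ij * v_j


def U(v_i, m_i):
    """Hypothetical update function: residual update with a fixed weight."""
    return v_i + 0.5 * m_i


h = nodes
for _ in range(2):  # T = 2 message-passing steps
    h = message_passing_step(h, edges, bond, M, U)
g = readout(h)
print(h, g)
```

<p>After two steps, information from atom 0 has reached atom 2 and vice versa, which is why the number of steps $T$ controls how far structural context propagates.</p>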
<p>Recent equivariant GNNs (e.g., NequIP, MACE, EquformerV2) use vectorial features that respect geometric symmetries, improving expressivity for tasks sensitive to 3D structure.</p>
<p><strong>Language Models</strong> operate on string representations of molecules (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>) or crystal structures. Autoregressive models like GPT maximize:</p>
<p>$$
\prod_{t=1}^{T} P(x_{t} \mid x_{1}, x_{2}, \ldots, x_{t-1})
$$</p>
<p>Transformers use self-attention:</p>
<p>$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$</p>
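<p>A dependency-free sketch of the attention formula, written with plain Python lists so each term is visible. The matrices are arbitrary toy values, and a practical implementation would use a tensor library with batching, masking, and multiple heads.</p>

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with Q, K, V given as lists of row vectors."""
    d_k = len(K[0])
    out = []
    for q in Q:
        # Scaled dot product between this query and every key.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        out.append([
            sum(w * v[c] for w, v in zip(weights, V))
            for c in range(len(V[0]))
        ])
    return out


# Toy example: 2 queries, 2 keys/values, d_k = 2 (arbitrary numbers)
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = attention(Q, K, V)
print(out)
```

<p>Each output row is a convex combination of the value vectors, weighted toward the key that best matches the query.</p>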
<h3 id="pretraining-strategies">Pretraining Strategies</h3>
<p>The paper categorizes pretraining methods into three self-supervised learning (SSL) approaches plus supervised and multimodal strategies:</p>
<table>
  <thead>
      <tr>
          <th>Strategy</th>
          <th>Mechanism</th>
          <th>Example Models</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Contrastive learning</td>
          <td>Maximize similarity between positive pairs, minimize for negatives</td>
          <td>GraphCL, MolCLR, GraphMVP, CrysGNN</td>
      </tr>
      <tr>
          <td>Predictive learning</td>
          <td>Predict self-generated labels (node context, functional groups, space group)</td>
          <td>GROVER, Hu et al., CrysGNN</td>
      </tr>
      <tr>
          <td>Generative learning</td>
          <td>Reconstruct masked nodes/edges or entire molecules/SMILES</td>
          <td><a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a></td>
      </tr>
      <tr>
          <td>Supervised pretraining</td>
          <td>Train on energy, forces, stress from DFT databases</td>
          <td>M3GNet, CHGNet, MACE-MP-0, MatterSim</td>
      </tr>
      <tr>
          <td>Multimodal learning</td>
          <td>Learn joint representations across SMILES/graph + text modalities</td>
          <td>KV-PLM, <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a>, MoleculeSTM, <a href="/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/">SPMM</a></td>
      </tr>
  </tbody>
</table>
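<p>To make the contrastive row concrete, the sketch below implements a generic InfoNCE-style loss on toy embedding vectors. It is an assumption-laden illustration: the cosine similarity, the temperature of 0.1, and the vectors are all chosen for demonstration, and the listed models may use different loss variants and augmentation schemes.</p>

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def info_nce(anchor, positive, negatives, temperature=0.1):
    """-log( exp(s_pos / t) / (exp(s_pos / t) + sum_k exp(s_neg_k / t)) )."""
    pos = math.exp(cosine(anchor, positive) / temperature)
    neg = sum(math.exp(cosine(anchor, n) / temperature) for n in negatives)
    return -math.log(pos / (pos + neg))


# Toy embeddings: the loss is small when the positive view is close to the anchor
anchor = [1.0, 0.0]
aligned = info_nce(anchor, [0.9, 0.1], negatives=[[0.0, 1.0], [-1.0, 0.0]])
misaligned = info_nce(anchor, [0.0, 1.0], negatives=[[0.9, 0.1], [-1.0, 0.0]])
print(aligned, misaligned)
```

<p>Minimizing this loss pulls the two views of the same molecule together in embedding space while pushing views of other molecules away.</p>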
<p>A common finding across studies is that combining local and global information (e.g., via contrastive learning between node-level and graph-level views, or supervised learning on both forces and total energy) produces more transferable representations.</p>
<h2 id="survey-of-models-across-four-domains">Survey of Models Across Four Domains</h2>
<h3 id="property-prediction">Property Prediction</h3>
<p>The paper reviews 13 models for molecular and materials property prediction. Key findings:</p>
<ul>
<li><strong>Contrastive learning approaches</strong> (GraphCL, MolCLR, GraphMVP) achieve strong results by defining positive pairs through augmentation, 2D/3D structure views, or crystal system membership.</li>
<li><strong>Language model approaches</strong> (<a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>) show that transformers trained on SMILES via masked language modeling can compete with GNN-based approaches.</li>
<li><a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>, pretrained on 1.1 billion SMILES from PubChem and ZINC, outperformed many baselines including GNNs on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> and <a href="/notes/chemistry/datasets/qm9/">QM9</a> benchmarks. Its attention maps captured molecular structural features directly from SMILES strings.</li>
<li>For crystalline materials, CrysGNN combined contrastive, predictive, and generative learning, demonstrating improvements even on small experimental datasets.</li>
</ul>
<h3 id="machine-learning-interatomic-potentials-mlips">Machine Learning Interatomic Potentials (MLIPs)</h3>
<p>The paper surveys 10 universal MLIPs, all using supervised learning on DFT-calculated energies, forces, and stresses:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture</th>
          <th>Training Data Size</th>
          <th>Key Capability</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>M3GNet</td>
          <td>GNN</td>
          <td>187K (MP)</td>
          <td>First universal MLIP</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>GNN</td>
          <td>1.58M (MPtrj)</td>
          <td>Predicts magnetic moments</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>MACE</td>
          <td>1.58M (MPtrj)</td>
          <td>35 diverse applications</td>
      </tr>
      <tr>
          <td>GNoME potential</td>
          <td>NequIP</td>
          <td>89M</td>
          <td>Zero-shot comparable to trained MLIPs</td>
      </tr>
      <tr>
          <td>MatterSim</td>
          <td>M3GNet/Graphormer</td>
          <td>17M</td>
          <td>SOTA on Matbench Discovery</td>
      </tr>
      <tr>
          <td>eqV2</td>
          <td>EquformerV2</td>
          <td>118M (OMat24)</td>
          <td>Structural relaxation</td>
      </tr>
  </tbody>
</table>
<p>The GNoME potential, trained on approximately 89 million data points, achieved zero-shot performance comparable to state-of-the-art MLIPs trained from scratch. MatterSim, trained on over 17 million entries across wide temperature (0-5000 K) and pressure (0-1000 GPa) ranges, achieved state-of-the-art results on Matbench Discovery and accurately computed thermodynamic and lattice dynamic properties.</p>
<h3 id="inverse-design">Inverse Design</h3>
<p>Few pretrained generative models for inverse design exist. The paper highlights three:</p>
<ul>
<li><strong>MatterGen</strong> (Microsoft): Diffusion model pretrained on Alexandria/MP databases (607K structures), finetuned for conditional generation on band gap, elastic modulus, space group, and composition. Generated S.U.N. (stable, unique, novel) materials at rates more than 2x the previous state of the art.</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/">GP-MoLFormer</a></strong> (IBM): MoLFormer pretrained on 1.1B SMILES, finetuned via pair-tuning for property-guided molecular optimization.</li>
<li><strong>CrystalLLM</strong>: Finetuned LLaMA-2 70B for crystal generation with target space group and composition using string representations and prompting.</li>
</ul>
<h3 id="multi-domain-models">Multi-Domain Models</h3>
<p>The paper covers two multi-domain categories:</p>
<p><strong>Property prediction + MLIP</strong>: Denoising pretraining learns virtual forces that guide noisy configurations back to equilibrium, connecting to force prediction. Joint multi-domain pretraining (JMP) from Meta FAIR achieved state-of-the-art on 34 of 40 tasks spanning molecules, crystals, and MOFs by training simultaneously on diverse energy/force databases.</p>
<p><strong>Property prediction + inverse design</strong>: Multimodal models (KV-PLM, <a href="/notes/chemistry/molecular-representations/multimodal/momu-molecular-multimodal-foundation/">MoMu</a>, MoleculeSTM, <a href="/notes/chemistry/molecular-representations/multimodal/molfm-multimodal-molecular-foundation/">MolFM</a>, <a href="/notes/chemistry/molecular-representations/multimodal/spmm-bidirectional-structure-property/">SPMM</a>) learn joint representations from molecular structures and text, enabling text-based inverse design and property prediction in a single framework. LLM-based models (<a href="/notes/chemistry/llm-applications/chemdfm-x/">ChemDFM</a>, <a href="/notes/chemistry/molecular-representations/multimodal/nach0-multimodal-chemical-language-model/">nach0</a>, <a href="/notes/chemistry/llm-applications/fine-tuning-gpt3-molecular-properties/">finetuned GPT-3</a>) can interact with humans and handle diverse chemistry tasks through instruction tuning.</p>
<h2 id="trends-and-future-directions">Trends and Future Directions</h2>
<h3 id="scope-expansion">Scope Expansion</h3>
<p>The authors identify three axes for expanding foundation model scope:</p>
<ol>
<li><strong>Material types</strong>: Most models target molecules or a single material class. Foundation models that span molecules, crystals, surfaces, and MOFs could exploit shared chemistry across materials.</li>
<li><strong>Modalities</strong>: Beyond SMILES, graphs, and text, additional modalities (images, spectral data like XRD patterns) remain underexplored.</li>
<li><strong>Downstream tasks</strong>: Extending to new chemistry and tasks through emergent capabilities, analogous to the capabilities observed in LLMs at scale.</li>
</ol>
<h3 id="performance-and-scaling">Performance and Scaling</h3>
<p>Key scaling challenges include:</p>
<ul>
<li><strong>Data quality vs. quantity</strong>: Noisy DFT labels (e.g., HOMO-LUMO gaps with high uncertainty from different functionals/basis sets) can limit scalability and out-of-distribution performance.</li>
<li><strong>GNN scalability</strong>: While transformers scale to hundreds of billions of parameters, GNNs have rarely been explored above one million parameters due to oversmoothing and the curse of dimensionality. Recent work by Sypetkowski et al. demonstrated scaling GNNs to 3 billion parameters with consistent improvements.</li>
<li><strong>Database integration</strong>: Combining datasets from different DFT codes requires proper alignment (e.g., total energy alignment methods).</li>
</ul>
<h3 id="efficiency">Efficiency</h3>
<p>For MLIPs, efficiency is critical since MD simulations require millions of inference steps. Approaches include:</p>
<ul>
<li>Knowledge distillation from expensive teacher models to lighter student models</li>
<li>Model compression techniques (quantization, pruning) adapted for GNNs</li>
<li>Investigating whether strict equivariance is always necessary</li>
</ul>
<h3 id="interpretability">Interpretability</h3>
<p>Foundation models can generate hallucinations or mode-collapsed outputs. The authors highlight recent interpretability advances (feature extraction from Claude 3, knowledge localization and editing in transformers) as promising directions for more reliable chemical applications.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<p><strong>Key findings</strong>:</p>
<ul>
<li>Combining local and global information in pretraining consistently improves downstream performance across all domains reviewed.</li>
<li>Self-supervised pretraining enables effective transfer learning even in low-data regimes, a critical advantage for chemistry.</li>
<li>Universal MLIPs have reached the point where zero-shot performance can be comparable to system-specific trained models.</li>
<li>Multimodal learning is the most promising approach for large foundation models capable of spanning property prediction and inverse design.</li>
</ul>
<p><strong>Limitations acknowledged by the authors</strong>:</p>
<ul>
<li>The precise definition of &ldquo;foundation model&rdquo; in chemistry is not established and varies by scope.</li>
<li>Most surveyed models focus on molecules, with crystalline materials less explored.</li>
<li>Benchmarks for low-data regimes and out-of-distribution performance are insufficient.</li>
<li>The paper focuses on three domains (property prediction, MLIPs, inverse design) and does not cover retrosynthesis, reaction prediction, or other chemical tasks in depth.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a perspective/review paper. No new data or models are introduced. The paper surveys existing models and their training datasets, summarized in Table 1 of the paper.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Not applicable (review paper). The paper describes pretraining strategies (contrastive, predictive, generative, supervised, multimodal) at a conceptual level with references to the original works.</p>
<h3 id="models">Models</h3>
<p>Not applicable (review paper). The paper catalogs approximately 40 foundation models across four domains. See Table 1 in the paper for the complete listing.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Not applicable (review paper). The paper references benchmark results from the original studies (MoleculeNet, QM9, Matbench, Matbench Discovery, JARVIS-DFT) but does not perform independent evaluation.</p>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Choi, J., Nam, G., Choi, J., &amp; Jung, Y. (2025). A Perspective on Foundation Models in Chemistry. <em>JACS Au</em>, 5(4), 1499-1518. <a href="https://doi.org/10.1021/jacsau.4c01160">https://doi.org/10.1021/jacsau.4c01160</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{choi2025perspective,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Perspective on Foundation Models in Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Choi, Junyoung and Nam, Gunwook and Choi, Jaesik and Jung, Yousung}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{JACS Au}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1499--1518}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/jacsau.4c01160}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Evolutionary Molecular Design via Deep Learning + GA</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/evolutionary-design-deep-learning-genetic-algorithm/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/evolutionary-design-deep-learning-genetic-algorithm/</guid><description>Kwon et al. combine an RNN decoder for SMILES reconstruction with a genetic algorithm operating on ECFP fingerprints for goal-directed molecular design.</description><content:encoded><![CDATA[<h2 id="fingerprint-based-evolutionary-molecular-design">Fingerprint-Based Evolutionary Molecular Design</h2>
<p>This is a <strong>Method</strong> paper that introduces an evolutionary design methodology (EDM) for goal-directed molecular optimization. The primary contribution is a framework with four components: (1) molecules are encoded as <a href="https://en.wikipedia.org/wiki/Chemical_similarity">extended-connectivity fingerprint</a> (ECFP) vectors, (2) a genetic algorithm evolves these fingerprint vectors through mutation and crossover, (3) a recurrent neural network (RNN) decodes the evolved fingerprints back into valid SMILES strings, and (4) a deep neural network (DNN) evaluates molecular fitness. The key advantage over prior evolutionary approaches is that no hand-crafted chemical rules or fragment libraries are needed, as the RNN learns valid molecular reconstruction from data.</p>
<h2 id="challenges-in-evolutionary-molecular-optimization">Challenges in Evolutionary Molecular Optimization</h2>
<p>Evolutionary algorithms for molecular design face two core challenges. First, maintaining chemical validity of evolved molecules is difficult when operating on graph or string representations directly. Prior methods rely on predefined chemical rules and fragment libraries to constrain structural modifications (atom/bond additions, deletions, substitutions), but these introduce bias and risk convergence to local optima. Each new application domain requires specifying new chemical rules, which may not exist for emerging areas. Second, fitness evaluation must be both efficient and accurate. Simple evaluation methods like structural similarity indices or semi-empirical quantum chemistry calculations reduce computational cost but may not capture complex property relationships.</p>
<p>High-throughput computational screening (HTCS) is a common alternative, but it depends on the quality of predefined virtual chemical libraries and often requires multiple iterative enumerations, limiting its ability to explore novel chemical space.</p>
<h2 id="core-innovation-evolving-fingerprints-with-neural-decoding">Core Innovation: Evolving Fingerprints with Neural Decoding</h2>
<p>The key insight is to perform genetic operations in fingerprint space rather than in molecular graph or SMILES string space. The framework comprises three functions, two of which (decoding and property prediction) are learned from data:</p>
<p><strong>Encoding function</strong> $e(\cdot)$: Converts a SMILES string $\mathbf{m}$ into a 5000-dimensional ECFP vector $\mathbf{x}$ using Morgan fingerprints with a neighborhood radius of 6. This is a deterministic hash-based encoding (not learned).</p>
<p><strong>Decoding function</strong> $d(\cdot)$: An RNN with three hidden layers of 500 LSTM units that reconstructs a SMILES string from an ECFP vector. The RNN generates SMILES as a sequence of three-character substrings, conditioning each prediction on the current substring and the input ECFP vector:</p>
<p>$$d(\mathbf{x}) = \mathbf{m}, \quad \text{where each substring is drawn from } p(\mathbf{m}_{t+1} \mid \mathbf{m}_{t}, \mathbf{x})$$</p>
<p>The three-character substring approach reduces the ratio of invalid SMILES by imposing additional constraints on subsequent characters.</p>
<p><strong>Property prediction function</strong> $f(\cdot)$: A five-layer DNN with 250 hidden units per layer that predicts molecular properties from ECFP vectors:</p>
<p>$$\mathbf{t} = f(e(\mathbf{m}))$$</p>
<p>The RNN is trained by minimizing cross-entropy loss between the softmax output and the target SMILES string $\mathbf{m}_{i}$, learning the relationship $d(e(\mathbf{m}_{i})) = \mathbf{m}_{i}$. The DNN is trained by minimizing mean squared error between predicted and computed property values. Both use the Adam optimizer with mini-batch size 100, 500 training epochs, and dropout rate 0.5.</p>
<h3 id="genetic-algorithm-operations">Genetic Algorithm Operations</h3>
<p>The GA evolves ECFP vectors using the DEAP library with the following parameters:</p>
<ul>
<li><strong>Population size</strong>: 50</li>
<li><strong>Crossover rate</strong>: 0.7 (uniform crossover, mixing ratio 0.2)</li>
<li><strong>Mutation rate</strong>: 0.3 (Gaussian mutation, $N(0, 0.2^{2})$, applied to 1% of elements)</li>
<li><strong>Selection</strong>: Tournament selection with size 3, top 3 individuals as parents</li>
<li><strong>Termination</strong>: 500 generations or 30 consecutive generations without fitness improvement</li>
</ul>
<p>The evolutionary loop proceeds as follows: a seed molecule $\mathbf{m}_{0}$ is encoded to $\mathbf{x}_{0}$, mutated to generate a population $\mathbf{P}^{0} = \{\mathbf{z}_{1}, \mathbf{z}_{2}, \ldots, \mathbf{z}_{L}\}$, each vector is decoded via the RNN, validity is checked with RDKit, fitness is evaluated via the DNN, and the top parents produce the next generation through crossover and mutation.</p>
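<p>The loop and the operators listed above can be sketched in plain Python. This is a minimal illustration, not the authors' DEAP code: the function names are hypothetical, the reading of &ldquo;top 3 individuals as parents&rdquo; as three tournament winners is an assumption, and the RNN decoding and RDKit validity check are omitted so that any callable on a vector can stand in for the DNN fitness.</p>

```python
import random


def uniform_crossover(a, b, mix=0.2, rng=random):
    """Swap each position between two parents with probability `mix`."""
    c1, c2 = list(a), list(b)
    for i in range(len(c1)):
        if rng.random() < mix:
            c1[i], c2[i] = c2[i], c1[i]
    return c1, c2


def gaussian_mutation(vec, sigma=0.2, frac=0.01, rng=random):
    """Add N(0, sigma^2) noise to each element with probability `frac`."""
    return [x + rng.gauss(0.0, sigma) if rng.random() < frac else x for x in vec]


def tournament(pop, fitness, k=3, rng=random):
    """Return the fittest of k individuals drawn without replacement."""
    return max(rng.sample(pop, k), key=fitness)


def evolve(seed_vec, fitness, pop_size=50, cx_rate=0.7, mut_rate=0.3,
           max_gen=500, patience=30, rng=random):
    """Evolve a fingerprint-like vector toward higher fitness.

    Assumption: `fitness` is a stand-in for decode -> validity check -> DNN.
    """
    # Initial population: mutated copies of the encoded seed.
    pop = [gaussian_mutation(seed_vec, rng=rng) for _ in range(pop_size)]
    best = max(pop, key=fitness)
    stale = 0
    for _ in range(max_gen):
        # Assumed reading of the source: three tournament winners act as parents.
        parents = [tournament(pop, fitness, rng=rng) for _ in range(3)]
        children = []
        while len(children) < pop_size:
            a, b = rng.sample(parents, 2)
            if rng.random() < cx_rate:
                a, b = uniform_crossover(a, b, rng=rng)
            if rng.random() < mut_rate:
                a = gaussian_mutation(a, rng=rng)
            children.append(list(a))
        pop = children
        gen_best = max(pop, key=fitness)
        if fitness(gen_best) > fitness(best):
            best, stale = gen_best, 0
        else:
            stale += 1
            if stale >= patience:  # stop after `patience` generations without improvement
                break
    return best
```

<p>In the published setup the vectors are 5000-dimensional, so a 1% per-element mutation probability perturbs roughly 50 entries per mutated individual.</p>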
<h2 id="experimental-setup-light-absorbing-wavelength-optimization">Experimental Setup: Light-Absorbing Wavelength Optimization</h2>
<h3 id="training-data-and-deep-learning-performance">Training Data and Deep Learning Performance</h3>
<p>The models were trained on 10,000 to 100,000 molecules randomly sampled from PubChem (molecular weight 200-600 g/mol). Each molecule was labeled with DFT-computed excitation energy ($S_{1}$), <a href="https://en.wikipedia.org/wiki/HOMO_and_LUMO">HOMO, and LUMO</a> energies using B3LYP/6-31G.</p>
<table>
  <thead>
      <tr>
          <th>Training Data</th>
          <th>Validity (%)</th>
          <th>Reconstructability (%)</th>
          <th>$S_{1}$ (R, MAE)</th>
          <th>HOMO (R, MAE)</th>
          <th>LUMO (R, MAE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>100,000</td>
          <td>88.8</td>
          <td>62.4</td>
          <td>0.977, 0.185 eV</td>
          <td>0.948, 0.168 eV</td>
          <td>0.960, 0.195 eV</td>
      </tr>
      <tr>
          <td>50,000</td>
          <td>86.7</td>
          <td>60.1</td>
          <td>0.973, 0.198 eV</td>
          <td>0.945, 0.172 eV</td>
          <td>0.955, 0.209 eV</td>
      </tr>
      <tr>
          <td>30,000</td>
          <td>85.3</td>
          <td>59.8</td>
          <td>0.930, 0.228 eV</td>
          <td>0.934, 0.191 eV</td>
          <td>0.945, 0.224 eV</td>
      </tr>
      <tr>
          <td>10,000</td>
          <td>83.2</td>
          <td>55.7</td>
          <td>0.913, 0.278 eV</td>
          <td>0.885, 0.244 eV</td>
          <td>0.917, 0.287 eV</td>
      </tr>
  </tbody>
</table>
<p>Validity refers to the proportion of chemically valid SMILES after RDKit inspection. Reconstructability measures how often the RNN reproduces the original molecule from its ECFP, counted as a canonical-SMILES match among 10,000 generated strings (62.4% at 100k training samples).</p>
<h3 id="design-task-1-unconstrained-s1-modification">Design Task 1: Unconstrained S1 Modification</h3>
<p>Fifty seed molecules with $S_{1}$ values between 3.8 eV and 4.2 eV were evolved in both increasing and decreasing directions. With 50,000 training samples, $S_{1}$ increased by approximately 60% on average in the increasing direction and showed slightly lower rates of change in the decreasing direction. The asymmetry is attributed to the skewed $S_{1}$ distribution of training data (average $S_{1}$ of 4.3-4.4 eV, higher than the seed median of 4.0 eV). Performance saturated at approximately 50,000 training samples.</p>
<h3 id="design-task-2-s1-modification-with-homolumo-constraints">Design Task 2: S1 Modification with HOMO/LUMO Constraints</h3>
<p>The same 50 seeds were evolved with constraints: $-7.0 \text{ eV} &lt; \text{HOMO} &lt; -5.0 \text{ eV}$ and $\text{LUMO} &lt; 0.0 \text{ eV}$. In the increasing $S_{1}$ direction, constraints suppressed the rate of change because both HOMO and LUMO bounds limit the achievable HOMO-LUMO gap. In the decreasing direction, constraints had minimal effect because LUMO could freely decrease while HOMO had sufficient room to rise within the allowed range.</p>
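<p>One simple way to fold these bounds into a fitness function is to reject any candidate that violates them. The sketch below is an assumption about how such constraints could be applied; the summary above does not say whether the authors used hard rejection or a soft penalty, and the function names are hypothetical.</p>

```python
def satisfies_constraints(homo_ev, lumo_ev):
    """True when -7.0 eV < HOMO < -5.0 eV and LUMO < 0.0 eV."""
    return -7.0 < homo_ev < -5.0 and lumo_ev < 0.0


def constrained_fitness(s1_ev, homo_ev, lumo_ev, maximize=True):
    """Score a candidate by predicted S1, rejecting constraint violations.

    Assumption: violations receive -inf so they never win a tournament.
    """
    if not satisfies_constraints(homo_ev, lumo_ev):
        return float("-inf")
    return s1_ev if maximize else -s1_ev
```

<p>Setting <code>maximize=False</code> flips the sign so the same selection code can drive $S_{1}$ downward.</p>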
<h3 id="design-task-3-extrapolation-beyond-training-data">Design Task 3: Extrapolation Beyond Training Data</h3>
<p>To generate molecules with $S_{1}$ values below 1.77 eV (outside the training distribution, which had mean $S_{1}$ of 4.91 eV), the authors introduced iterative &ldquo;phases&rdquo;: generate molecules, compute their properties via DFT, retrain the models, and repeat. Starting from the 30 lowest-$S_{1}$ seed molecules with 300 generation runs per phase:</p>
<ul>
<li>Phase 1: Average $S_{1}$ = 2.20 eV, 12 molecules below 1.77 eV</li>
<li>Phase 2: Average $S_{1}$ = 2.22 eV, 37 molecules below 1.77 eV</li>
<li>Phase 3: Average $S_{1}$ = 2.31 eV, 58 molecules below 1.77 eV</li>
</ul>
<p>While the average $S_{1}$ rose slightly across phases, variance decreased (from 1.40 to 1.36), indicating the model concentrated its outputs closer to the target range. This active-learning-like loop demonstrates the framework can extend beyond the training distribution.</p>
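<p>For readers more used to wavelengths than excitation energies, the 1.77 eV threshold can be converted with $\lambda = hc / E$. The snippet below is a small worked conversion plus a helper for the per-phase counts; it suggests the threshold sits near 700 nm, at the red edge of the visible range, although the summary above does not state that motivation explicitly. The helper names are hypothetical.</p>

```python
HC_EV_NM = 1239.84198  # Planck constant times speed of light, in eV*nm


def ev_to_nm(energy_ev):
    """Convert a photon energy in eV to a wavelength in nm."""
    return HC_EV_NM / energy_ev


def count_below(s1_values_ev, threshold_ev=1.77):
    """Count candidates whose S1 falls below the target threshold."""
    return sum(1 for s1 in s1_values_ev if s1 < threshold_ev)


print(round(ev_to_nm(1.77), 1))  # wavelength in nm for the 1.77 eV threshold
```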
<h3 id="design-task-4-guacamol-benchmarks">Design Task 4: GuacaMol Benchmarks</h3>
<p>The method was evaluated on the <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> goal-directed benchmark suite using the ChEMBL25 training dataset. The RNN model was retrained with three-character substrings.</p>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>Best of Dataset</th>
          <th><a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">SMILES LSTM</a></th>
          <th>SMILES GA</th>
          <th><a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph GA</a></th>
          <th><a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph MCTS</a></th>
          <th>cRNN</th>
          <th>EDM (Kwon et al.)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Celecoxib rediscovery</td>
          <td>0.505</td>
          <td>1.000</td>
          <td>0.607</td>
          <td>1.000</td>
          <td>0.378</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Troglitazone rediscovery</td>
          <td>0.419</td>
          <td>1.000</td>
          <td>0.558</td>
          <td>1.000</td>
          <td>0.312</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>Thiothixene rediscovery</td>
          <td>0.456</td>
          <td>1.000</td>
          <td>0.495</td>
          <td>1.000</td>
          <td>0.308</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>LogP(-1.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.980</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>LogP(8.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>0.979</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>TPSA(150.0)</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>CNS MPO</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
          <td>1.000</td>
      </tr>
      <tr>
          <td>QED</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.948</td>
          <td>0.944</td>
          <td>0.948</td>
          <td>0.948</td>
      </tr>
  </tbody>
</table>
<p>The EDM achieves maximum scores on all eight tasks, matching the cRNN baseline. The 256 highest-scoring molecules from the ChEMBL25 test set were used as seeds, with 500 SMILES strings generated per seed.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="results">Results</h3>
<p>The evolutionary design framework successfully evolved seed molecules toward target properties across all four design tasks. The RNN decoder maintained 88.8% chemical validity at 100k training samples, and the DNN property predictor achieved correlation coefficients above 0.94 for $S_{1}$, HOMO, and LUMO prediction. The iterative retraining procedure enabled exploration outside the training data distribution, generating 58 molecules with $S_{1}$ below 1.77 eV after three phases. On GuacaMol benchmarks, the method achieved maximum scores on all eight tasks, matching <a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">SMILES LSTM</a>, <a href="/notes/chemistry/molecular-design/generation/search-based/graph-based-genetic-algorithm-chemical-space/">Graph GA</a>, and cRNN baselines.</p>
<h3 id="limitations">Limitations</h3>
<p>Several limitations are worth noting:</p>
<ol>
<li><strong>Reconstructability ceiling</strong>: Only 62.4% of molecules could be reconstructed from their ECFP vectors, meaning the RNN decoder fails to recover the original molecule approximately 38% of the time. This information loss in the ECFP encoding is a fundamental bottleneck.</li>
<li><strong>Data dependence</strong>: Performance is sensitive to the training data distribution. The asymmetric evolution rates for increasing vs. decreasing $S_{1}$ directly reflect the skewed training data.</li>
<li><strong>Structural constraints</strong>: Three heuristic constraints (fused ring sizes, number of fused rings, alkyl chain lengths) were still needed to maintain reasonable molecular structures, partially undermining the claim of a fully data-driven approach.</li>
<li><strong>DFT reliance</strong>: The extrapolation experiment requires DFT calculations in the loop, which are computationally expensive and may limit scalability.</li>
<li><strong>Limited benchmark scope</strong>: Only 8 GuacaMol tasks were tested, and EDM reached the best reported score on every one (as did several baselines), making it difficult to differentiate from competing methods. The paper does not report on harder multi-objective benchmarks.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Evaluation</td>
          <td>PubChem random sample</td>
          <td>10,000-100,000 molecules</td>
          <td>MW 200-600 g/mol, labeled with DFT-computed $S_{1}$, HOMO, LUMO</td>
      </tr>
      <tr>
          <td>GuacaMol Benchmark</td>
          <td>ChEMBL25</td>
          <td>Standard split</td>
          <td>Used for retraining RNN; 256 top-scoring seeds</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Genetic algorithm</strong>: DEAP library; population 50, crossover rate 0.7, mutation rate 0.3, tournament size 3</li>
<li><strong>RNN decoder</strong>: 3 hidden layers, 500 LSTM units each, three-character substring generation</li>
<li><strong>DNN predictor</strong>: 5 layers, 250 hidden units, sigmoid activations, linear output</li>
<li><strong>Training</strong>: Adam optimizer, mini-batch 100, 500 epochs, dropout 0.5</li>
</ul>
<h3 id="models">Models</h3>
<p>All neural networks were implemented using Keras with the Theano backend (GPU-accelerated). No pre-trained model weights are publicly available.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>RNN validity</strong>: Proportion of chemically valid SMILES (RDKit check)</li>
<li><strong>Reconstructability</strong>: Fraction of seed molecules recoverable from ECFP (canonical SMILES match in 10,000 generated strings)</li>
<li><strong>DNN accuracy</strong>: Correlation coefficient (R) and MAE via 10-fold cross-validation</li>
<li><strong>Evolutionary performance</strong>: Average rate of $S_{1}$ change across 50 seeds; molecule count in target range</li>
<li><strong>GuacaMol</strong>: Standard rediscovery and property satisfaction benchmarks</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU models, training times, or computational requirements for the evolutionary runs. DFT calculations used the Gaussian 09 program suite with B3LYP/6-31G.</p>
<h3 id="artifacts">Artifacts</h3>
<p>No public code repository or pre-trained models are available. The paper is published under a CC-BY 4.0 license as open access in Scientific Reports.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.nature.com/articles/s41598-021-96812-8">Paper (Scientific Reports)</a></td>
          <td>Paper</td>
          <td>CC-BY 4.0</td>
          <td>Open access</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Partially Reproducible. The method is described in sufficient detail for reimplementation, but no code, trained models, or preprocessed datasets are released. The DFT calculations require Gaussian 09, a commercial software package.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kwon, Y., Kang, S., Choi, Y.-S., &amp; Kim, I. (2021). Evolutionary design of molecules based on deep learning and a genetic algorithm. <em>Scientific Reports</em>, 11, 17304. <a href="https://doi.org/10.1038/s41598-021-96812-8">https://doi.org/10.1038/s41598-021-96812-8</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kwon2021evolutionary,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Evolutionary design of molecules based on deep learning and a genetic algorithm}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kwon, Youngchun and Kang, Seokho and Choi, Youn-Suk and Kim, Inkoo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17304}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41598-021-96812-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DrugEx v3: Scaffold-Constrained Graph Transformer</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/drugex-v3-scaffold-graph-transformer/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/rl-tuned/drugex-v3-scaffold-graph-transformer/</guid><description>DrugEx v3 proposes a Graph Transformer with novel positional encoding for scaffold-constrained molecular generation via multi-objective reinforcement learning.</description><content:encoded><![CDATA[<h2 id="a-graph-transformer-method-for-scaffold-constrained-drug-design">A Graph Transformer Method for Scaffold-Constrained Drug Design</h2>
<p>This is a <strong>Method</strong> paper that introduces DrugEx v3, a Graph Transformer model for scaffold-constrained de novo drug design. The primary contribution is a novel positional encoding scheme for molecular graphs that allows a Transformer architecture to operate on graph-structured molecular data rather than <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. The model takes user-provided scaffold fragments as input and generates complete molecules through growing and connecting operations, trained with multi-objective reinforcement learning to optimize for both target affinity and drug-likeness.</p>
<h2 id="from-fixed-objectives-to-user-guided-scaffold-design">From Fixed Objectives to User-Guided Scaffold Design</h2>
<p>Prior versions of DrugEx (v1 and <a href="/notes/chemistry/molecular-design/generation/rl-tuned/drugex-v2-pareto-multi-objective-rl/">v2</a>) used RNN-based generators trained with reinforcement learning for de novo drug design, but they operated under fixed objectives and could not accept user-provided structural priors. If a medicinal chemist wanted to explore analogs of a specific scaffold, the model needed retraining from scratch. Meanwhile, SMILES-based molecular generators face inherent limitations for scaffold-constrained design: SMILES is a linear notation, so inserting fragments at multiple positions of a scaffold requires complex grammar handling, and small token changes can produce invalid molecules.</p>
<p>Several approaches had been proposed for scaffold-based generation, including graph generative models (Lim et al., 2019), DeepScaffold (Li et al., 2020), SMILES-based scaffold decorators (Arus-Pous et al., 2020), and SyntaLinker for fragment linking (Yang et al., 2020). DrugEx v3 aims to combine the advantages of graph representations (validity guarantees, local invariance, flexible extension) with the Transformer architecture&rsquo;s ability to handle complex dependencies, while maintaining the multi-objective reinforcement learning framework from DrugEx v2.</p>
<h2 id="graph-positional-encoding-for-molecular-transformers">Graph Positional Encoding for Molecular Transformers</h2>
<p>The core innovation is adapting the Transformer architecture to work directly with molecular graph representations. Two key modifications make this possible.</p>
<p><strong>Graph word encoding.</strong> Each step of the graph sequence carries both an atom and a bond, so the authors merge the two into a single token index:</p>
<p>$$
W = T_{atom} \times 4 + T_{bond}
$$</p>
<p>where $T_{atom}$ is the atom type index and $T_{bond}$ is the bond type index (four bond types: single, double, triple, and none).</p>
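<p>The combined index can be read as a base-4 packing of the two type indices. The following is a minimal sketch of that arithmetic, not the authors&rsquo; code: the function names are hypothetical, and the integer assigned to each bond type is an assumption, since only the count of four bond types is stated here.</p>

```python
# Four bond types are stated above (single, double, triple, none).
# Which integer maps to which bond type is an assumption of this sketch.
N_BOND_TYPES = 4


def encode_word(atom_type: int, bond_type: int) -> int:
    """Pack an atom type index and a bond type index into one token index."""
    if bond_type not in range(N_BOND_TYPES):
        raise ValueError("bond_type must be one of 0, 1, 2, 3")
    return atom_type * N_BOND_TYPES + bond_type


def decode_word(word: int) -> tuple[int, int]:
    """Recover (atom_type, bond_type) from a packed token index."""
    return divmod(word, N_BOND_TYPES)
```

<p>Because the bond index is always smaller than 4, the packing is lossless: <code>decode_word(encode_word(a, b))</code> returns <code>(a, b)</code> for any valid pair.</p>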
<p><strong>Graph positional encoding.</strong> Standard sequential position encoding does not capture molecular topology. The authors propose an adjacency-matrix-based positional encoding:</p>
<p>$$
P = I_{Atom} \times L_{max} + I_{Connected}
$$</p>
<p>where $I_{Atom}$ is the current atom index, $L_{max}$ is the maximum sequence length, and $I_{Connected}$ is the index of the atom connected by the current bond. This encoding is then processed through the standard sinusoidal positional encoding:</p>
<p>$$
PE_{(P, 2i)} = \sin(P / 10000^{2i / d_{m}})
$$</p>
<p>$$
PE_{(P, 2i+1)} = \cos(P / 10000^{2i / d_{m}})
$$</p>
<p>with $d_{m} = 512$.</p>
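<p>A small NumPy sketch of the two steps, under stated assumptions: the graph position is first computed from the atom index and the connected-atom index, then passed through the sinusoidal encoding. The function names and the value <code>l_max = 80</code> used in the usage note are illustrative placeholders, not values taken from the paper.</p>

```python
import numpy as np


def graph_position(atom_idx: int, connected_idx: int, l_max: int) -> int:
    """Adjacency-based position: current atom index times L_max plus connected atom index."""
    return atom_idx * l_max + connected_idx


def sinusoidal_encoding(positions, d_model: int = 512) -> np.ndarray:
    """Standard sinusoidal encoding applied to arbitrary integer positions."""
    pos = np.asarray(positions, dtype=np.float64)[:, None]          # shape (n, 1)
    two_i = np.arange(0, d_model, 2, dtype=np.float64)[None, :]     # the values 2i
    angles = pos / np.power(10000.0, two_i / d_model)               # shape (n, d_model / 2)
    pe = np.empty((pos.shape[0], d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions
    pe[:, 1::2] = np.cos(angles)  # odd dimensions
    return pe
```

<p>For example, <code>sinusoidal_encoding([graph_position(2, 1, 80)])</code> encodes the bond between atom 2 and atom 1 as a single 512-dimensional vector, so two bonds that share an atom but differ in their partner receive different encodings.</p>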
<p><strong>Molecule generation procedure.</strong> Each molecule in the training data is represented as a five-row matrix encoding atom type, bond type, connected atom index, current atom index, and fragment index. The columns are divided into three sections: fragment (the scaffold), growing (new atoms added to fragments), and linking (bonds connecting grown fragments). The decoder uses a GRU-based recurrent layer to sequentially output atom type, bond type, connected atom index, and current atom index at each step, with chemical valence rules enforced at every generation step to guarantee valid molecules.</p>
<p><strong>Multi-objective reinforcement learning.</strong> The generator is trained with a policy gradient objective:</p>
<p>$$
J(\theta) = \mathbb{E}\left[R^{*}(y_{1:T}) | \theta\right] = \sum_{t=1}^{T} \log G(y_{t} | y_{1:t-1}) \cdot R^{\ast}(y_{1:T})
$$</p>
<p>where $R^{*}$ is a Pareto-based reward combining target affinity and QED drug-likeness score:</p>
<p>$$
R^{*} = \begin{cases} 0.5 + \frac{k - N_{undesired}}{2N_{desired}}, &amp; \text{if desired} \\ \frac{k}{2N_{undesired}}, &amp; \text{if undesired} \end{cases}
$$</p>
<p>with $k$ being the solution&rsquo;s index in the Pareto rank. An exploration strategy uses two networks: an exploitation network $G_{\theta}$ (updated by policy gradient) and an exploration network $G_{\phi}$ (fixed, pre-trained on ChEMBL), with an exploration rate $\varepsilon$ controlling how many scaffolds are routed to $G_{\phi}$ during training.</p>
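<p>The piecewise reward can be sketched as a plain function. This is an illustration of the formula as written above, assuming that $k$ runs from 1 to $N_{undesired} + N_{desired}$ with undesired solutions ranked below desired ones; the function name and argument layout are hypothetical.</p>

```python
def pareto_reward(k: int, desired: bool, n_desired: int, n_undesired: int) -> float:
    """Rank-based reward: undesired solutions map into (0, 0.5], desired ones into (0.5, 1.0].

    Assumes k is the 1-based index in the Pareto ranking, with all undesired
    solutions placed before the desired ones (an assumption of this sketch).
    """
    if desired:
        return 0.5 + (k - n_undesired) / (2 * n_desired)
    return k / (2 * n_undesired)
```

<p>With four undesired and two desired solutions, the best-ranked desired solution (<code>k = 6</code>) receives 1.0 and the best-ranked undesired one (<code>k = 4</code>) receives 0.5, so every desired molecule is rewarded above every undesired one.</p>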
<h2 id="experimental-setup-architecture-comparison-and-rl-optimization">Experimental Setup: Architecture Comparison and RL Optimization</h2>
<h3 id="data">Data</h3>
<p>The ChEMBL set (version 27) contained approximately 1.7 million molecules for pre-training, preprocessed via RDKit (charge neutralization, metal/fragment removal). The LIGAND set comprised 10,828 adenosine receptor ligands for fine-tuning. Each molecule was decomposed into fragments using the BRICS algorithm, creating scaffold-molecule pairs (up to 15 pairs per molecule with four fragments). The ChEMBL set yielded 9.3 million training pairs, and the LIGAND set produced 53,888 training pairs.</p>
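<p>The figure of 15 pairs is consistent with counting every non-empty subset of four fragments ($2^{4} - 1 = 15$). The sketch below only illustrates that count; treating each non-empty subset as one scaffold input is an assumption here, and the helper name and fragment labels are hypothetical.</p>

```python
from itertools import combinations


def scaffold_subsets(fragments):
    """Enumerate every non-empty subset of the given fragments."""
    subsets = []
    for size in range(1, len(fragments) + 1):
        subsets.extend(combinations(fragments, size))
    return subsets


# Four placeholder fragment labels standing in for BRICS fragments.
example = scaffold_subsets(["frag_a", "frag_b", "frag_c", "frag_d"])
```

<p>Here <code>len(example)</code> is 15: four single fragments, six pairs, four triples, and the full set of four.</p>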
<h3 id="architecture-comparison">Architecture comparison</h3>
<p>Four architectures were compared:</p>
<ol>
<li><strong>Graph Transformer</strong>: graph input with novel positional encoding</li>
<li><strong>Sequential Transformer</strong>: SMILES input with standard Transformer</li>
<li><strong>LSTM-BASE</strong>: SMILES encoder-decoder with three recurrent layers</li>
<li><strong>LSTM+ATTN</strong>: LSTM-BASE with an attention mechanism between encoder and decoder</li>
</ol>
<p>All models were pre-trained on ChEMBL and fine-tuned on the LIGAND set. The bioactivity predictor was a random forest regression model using 2048D ECFP6 fingerprints and 19D physicochemical descriptors, with an activity threshold of pX = 6.5 for the A2A adenosine receptor.</p>
<h3 id="evaluation-metrics">Evaluation metrics</h3>
<p>Five metrics were used: validity (parseable molecules), accuracy (scaffold containment), desirability (meeting all objectives), uniqueness, and novelty (not in ChEMBL). Diversity was measured using the Solow-Polasky index with Tanimoto distance on ECFP6 fingerprints:</p>
<p>$$
I(A) = \frac{1}{|A|} \mathbf{e}^{\intercal} F(\mathbf{s})^{-1} \mathbf{e}
$$</p>
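<p>A minimal NumPy sketch of this index follows. It assumes that $\mathbf{e}$ is a vector of ones and that $F(\mathbf{s})$ is built from pairwise distances with an exponential kernel $f(d) = e^{-\theta d}$; both the kernel form and the default <code>theta = 0.5</code> are assumptions of this sketch rather than values confirmed in this summary.</p>

```python
import numpy as np


def solow_polasky(distances, theta: float = 0.5) -> float:
    """Normalized Solow-Polasky diversity for a square pairwise distance matrix.

    distances[i][j] is the distance (e.g. Tanimoto distance) between items i and j.
    The exponential kernel and theta are assumptions of this sketch.
    """
    d = np.asarray(distances, dtype=np.float64)
    f = np.exp(-theta * d)                 # kernel matrix F(s)
    e = np.ones(d.shape[0])
    # Solve F x = e instead of forming the explicit inverse.
    return float(e @ np.linalg.solve(f, e)) / d.shape[0]
```

<p>Under these assumptions the value approaches 1 when all items are far apart and shrinks toward $1/|A|$ as the items become near-duplicates. Exact duplicates make the kernel matrix singular, so they should be removed before calling the function.</p>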
<h3 id="hardware">Hardware</h3>
<p>Models were benchmarked on a server with NVIDIA Tesla P100 GPUs.</p>
<h2 id="key-results-graph-representation-advantages-and-rl-trade-offs">Key Results: Graph Representation Advantages and RL Trade-offs</h2>
<h3 id="pre-training-and-fine-tuning-performance">Pre-training and fine-tuning performance</h3>
<p>The Graph Transformer outperformed the three SMILES-based models on validity, accuracy, novelty, and uniqueness:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Validity (PT)</th>
          <th>Accuracy (PT)</th>
          <th>Validity (FT)</th>
          <th>Accuracy (FT)</th>
          <th>Novelty (FT)</th>
          <th>Uniqueness (FT)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph Transformer (512)</td>
          <td>100.0%</td>
          <td>99.3%</td>
          <td>100.0%</td>
          <td>99.2%</td>
          <td>68.9%</td>
          <td>82.9%</td>
      </tr>
      <tr>
          <td>Seq. Transformer (512)</td>
          <td>96.7%</td>
          <td>74.0%</td>
          <td>99.3%</td>
          <td>92.7%</td>
          <td>8.9%</td>
          <td>28.9%</td>
      </tr>
      <tr>
          <td>LSTM+ATTN (512)</td>
          <td>94.3%</td>
          <td>72.8%</td>
          <td>96.9%</td>
          <td>85.9%</td>
          <td>6.3%</td>
          <td>20.7%</td>
      </tr>
      <tr>
          <td>LSTM-BASE (512)</td>
          <td>93.9%</td>
          <td>52.4%</td>
          <td>98.7%</td>
          <td>81.6%</td>
          <td>3.9%</td>
          <td>19.2%</td>
      </tr>
  </tbody>
</table>
<p>PT = pre-trained, FT = fine-tuned. The Graph Transformer achieved 100% validity due to its explicit valence checking at each generation step. It also produced substantially more novel and unique molecules after fine-tuning compared to SMILES-based methods.</p>
<p>The authors identified four advantages of the graph representation over SMILES: (1) local invariance, where fragment ordering does not affect output; (2) global extendibility, where new atoms can be appended without restructuring existing data; (3) freedom from grammar constraints; and (4) direct accessibility of chemical valence rules for validity enforcement.</p>
<h3 id="reinforcement-learning-results">Reinforcement learning results</h3>
<p>With multi-objective RL (affinity + QED), 74.6% of generated molecules were desirable (meeting all objectives) at $\varepsilon = 0.0$. Raising the exploration rate $\varepsilon$ trades desirability for uniqueness and novelty:</p>
<table>
  <thead>
      <tr>
          <th>$\varepsilon$</th>
          <th>Desirability</th>
          <th>Uniqueness</th>
          <th>Novelty</th>
          <th>Diversity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>0.0</td>
          <td>74.6%</td>
          <td>60.7%</td>
          <td>60.6%</td>
          <td>0.879</td>
      </tr>
      <tr>
          <td>0.1</td>
          <td>66.8%</td>
          <td>75.0%</td>
          <td>74.6%</td>
          <td>0.842</td>
      </tr>
      <tr>
          <td>0.2</td>
          <td>61.6%</td>
          <td>80.2%</td>
          <td>79.4%</td>
          <td>0.879</td>
      </tr>
      <tr>
          <td>0.3</td>
          <td>56.8%</td>
          <td>89.8%</td>
          <td>88.8%</td>
          <td>0.874</td>
      </tr>
  </tbody>
</table>
<p>The authors report that $\varepsilon = 0.3$ produced the best balance between desirability and uniqueness, with 56.8% desired molecules and 89.8% uniqueness. Diversity remained above 0.84 across all settings.</p>
<h3 id="limitations">Limitations</h3>
<p>The Graph Transformer produced molecules with worse synthetic accessibility (SA scores) than the SMILES-based methods, particularly after fine-tuning on the smaller LIGAND set. The authors attribute this to uncommon ring systems generated when the model handles long-distance dependencies. A kekulization issue also causes a small fraction of molecules to fail scaffold matching: aromatic bond inference during sanitization can alter the scaffold substructure. When affinity is the only objective (the single-objective setting, without the QED term), the model generates molecules with molecular weight exceeding 500 Da, reducing drug-likeness. All bioactivity predictions rely on a random forest model rather than experimental validation, and the t-SNE analysis suggests some generated molecules fall outside the model&rsquo;s applicability domain.</p>
<h3 id="future-directions">Future directions</h3>
<p>The authors propose extending the Graph Transformer to accept protein information as input via proteochemometric modeling, enabling design of ligands for targets without known ligands. Lead optimization, where a &ldquo;hit&rdquo; serves as input to generate improved analogs, is also identified as a natural extension.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data-1">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ChEMBL v27</td>
          <td>~1.7M molecules (9.3M scaffold-molecule pairs)</td>
          <td>Preprocessed via RDKit</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>LIGAND set (A2A AR ligands from ChEMBL)</td>
          <td>10,828 ligands (53,888 pairs)</td>
          <td>Split 8:1:1 train/val/test</td>
      </tr>
      <tr>
          <td>Bioactivity labels</td>
          <td>ChEMBL A2A AR activity data</td>
          <td>pX threshold = 6.5</td>
          <td>Average pChEMBL values</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Fragment decomposition: BRICS algorithm via RDKit (max 4 fragments per molecule)</li>
<li>Optimizer: Adam with learning rate $10^{-4}$, batch size 256</li>
<li>Pre-training: 20 epochs; fine-tuning: up to 1,000 epochs with early stopping (patience: 100 epochs)</li>
<li>Bioactivity predictor: random forest regression (scikit-learn) with 2048D ECFP6 + 19D physicochemical descriptors</li>
<li>Pareto-based multi-objective ranking with GPU acceleration</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>Graph Transformer: 512 hidden units, 8 attention heads, $d_{k} = d_{v} = 64$</li>
<li>Sequential Transformer: same hidden size, sinusoidal positional encoding</li>
<li>LSTM-BASE / LSTM+ATTN: 128 embedding units, 512 hidden units, 3 recurrent layers</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Graph Transformer</th>
          <th>Best SMILES Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity (fine-tuned)</td>
          <td>100.0%</td>
          <td>99.6% (LSTM-BASE 1024)</td>
          <td>Valence checking guarantees validity</td>
      </tr>
      <tr>
          <td>Accuracy (fine-tuned)</td>
          <td>99.2%</td>
          <td>94.3% (Seq. Transformer 1024)</td>
          <td>Scaffold containment</td>
      </tr>
      <tr>
          <td>Desirability (RL, $\varepsilon$=0.0)</td>
          <td>74.6%</td>
          <td>N/A</td>
          <td>Only Graph Transformer used for RL</td>
      </tr>
      <tr>
          <td>Diversity (RL)</td>
          <td>0.879</td>
          <td>N/A</td>
          <td>Solow-Polasky index</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware-1">Hardware</h3>
<p>NVIDIA Tesla P100 GPUs. Specific training times not reported, but Transformer models trained faster than LSTM models with the same hidden layer size.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/CDDLeiden/DrugEx">CDDLeiden/DrugEx</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation (v1, v2, v3)</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL v27</a></td>
          <td>Dataset</td>
          <td>CC-BY-SA 3.0</td>
          <td>Pre-training data source</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, X., Ye, K., van Vlijmen, H. W. T., IJzerman, A. P., &amp; van Westen, G. J. P. (2023). DrugEx v3: scaffold-constrained drug design with graph transformer-based reinforcement learning. <em>Journal of Cheminformatics</em>, 15, 24. <a href="https://doi.org/10.1186/s13321-023-00694-z">https://doi.org/10.1186/s13321-023-00694-z</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{liu2023drugex,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DrugEx v3: scaffold-constrained drug design with graph transformer-based reinforcement learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Liu, Xuhan and Ye, Kai and van Vlijmen, Herman W. T. and IJzerman, Adriaan P. and van Westen, Gerard J. P.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{24}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-023-00694-z}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DeepSMILES: Adapting SMILES Syntax for Machine Learning</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/</guid><description>DeepSMILES modifies SMILES syntax to eliminate unbalanced parentheses and unpaired ring closures, reducing invalid outputs from generative molecular models.</description><content:encoded><![CDATA[<h2 id="a-new-molecular-string-notation-for-generative-models">A New Molecular String Notation for Generative Models</h2>
<p>This is a <strong>Method</strong> paper that introduces DeepSMILES, a modified SMILES syntax designed to reduce the rate of syntactically invalid strings produced by machine-learning generative models. The primary contribution is a pair of string-level transformations (for ring closures and for branches) that can be applied independently and interconverted with standard SMILES without loss of information, including stereochemistry.</p>
<h2 id="the-problem-of-invalid-smiles-in-molecular-generation">The Problem of Invalid SMILES in Molecular Generation</h2>
<p>Deep neural networks for de novo molecular design commonly operate on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational autoencoders</a> (<a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al., 2018</a>), recurrent neural networks with LSTM (<a href="/notes/chemistry/molecular-design/generation/autoregressive/lstm-drug-like-molecule-generation/">Segler et al., 2018</a>; Olivecrona et al., 2017), and grammar-based approaches (<a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Kusner et al., 2017</a>) all generate molecules by sampling character sequences. A persistent problem is that many generated strings are syntactically invalid SMILES, with reported validity rates ranging from 7% to 80%.</p>
<p>Two structural features of SMILES syntax are responsible for most invalid strings:</p>
<ol>
<li><strong>Balanced parentheses</strong>: Branches require matched open/close parenthesis pairs. A generative model must track nesting state across long sequences to produce valid brackets.</li>
<li><strong>Paired ring closure symbols</strong>: Rings require two identical digit tokens at corresponding positions. The model must remember which digits are &ldquo;open&rdquo; and close them appropriately.</li>
</ol>
<p>Grammar-based approaches (e.g., <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">Grammar VAE</a>) can enforce balanced parentheses through a context-free grammar, but they cannot enforce the ring closure pairing constraint because that constraint is context-sensitive. Syntax-directed approaches (Dai et al., 2018) add explicit ring closure constraints but at the cost of significantly more complex decoder architectures.</p>
<h2 id="core-innovation-postfix-branch-notation-and-single-ring-closure-symbols">Core Innovation: Postfix Branch Notation and Single Ring Closure Symbols</h2>
<p>DeepSMILES addresses both syntax problems through two independent string transformations.</p>
<h3 id="ring-closure-transformation">Ring closure transformation</h3>
<p>Standard SMILES uses a pair of identical digits to mark ring openings and closings (e.g., <code>c1ccccc1</code> for benzene). DeepSMILES eliminates the ring-opening digit and replaces the ring-closing digit with the ring size, counting back along the tree path to the ring-opening atom. Benzene becomes <code>cccccc6</code>, where <code>6</code> closes a six-membered ring by bonding the last atom to the atom five positions earlier on the path (six atoms counted inclusively).</p>
<p>This transformation has three key properties:</p>
<ul>
<li>Every ring of a given size always uses the same digit, regardless of context. A phenyl ring is always <code>cccccc6</code> in DeepSMILES, whereas in SMILES it might be <code>c1ccccc1</code>, <code>c2ccccc2</code>, <code>c3ccccc3</code>, etc.</li>
<li>A single symbol cannot be &ldquo;unmatched&rdquo; since there is no corresponding opening symbol.</li>
<li>For double-digit ring sizes, the <code>%N</code> notation is used (and <code>%(N)</code> for sizes above 99).</li>
</ul>
<p>Bond stereochemistry is preserved by moving any explicit or stereo bond from the eliminated ring-opening symbol to the ring-closing symbol, with direction adjusted as needed.</p>
<h3 id="branch-parenthesis-transformation">Branch (parenthesis) transformation</h3>
<p>Standard SMILES uses matched open/close parenthesis pairs for branches (e.g., <code>C(OC)(SC)F</code>). DeepSMILES replaces this with a postfix notation inspired by Reverse Polish Notation (RPN). Only close parentheses are used, and the number of consecutive close parentheses indicates how far back on the current branch the next atom attaches.</p>
<p>For example, <code>C(OC)(SC)F</code> becomes <code>COC))SC))F</code>. The interpretation uses a stack: atoms are pushed onto the stack as they are read, each close parenthesis pops one atom from the stack, and the next atom connects to whatever is on top of the stack.</p>
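<p>The stack interpretation, together with the ring-size rule described earlier, can be sketched as a toy decoder that turns a DeepSMILES string into a bond list. This is a simplified illustration, not the reference implementation: it handles only single-character atoms, close parentheses, and single-digit ring sizes, ignores bond symbols and stereochemistry, and the function name is hypothetical.</p>

```python
def deepsmiles_bonds(text: str):
    """Return (atoms, bonds) for a simplified DeepSMILES string.

    Supports single-character atoms, ')' and single-digit ring sizes only.
    Bonds are pairs of atom indices in reading order.
    """
    atoms, bonds, stack = [], [], []
    for ch in text:
        if ch == ")":
            # Each close parenthesis pops one atom off the current path.
            if not stack:
                raise ValueError("more close parentheses than atoms on the stack")
            stack.pop()
        elif ch.isdigit():
            # A ring size n bonds the top of the stack to the n-th atom
            # counting back along the path (top of the stack counted as 1).
            size = int(ch)
            if size not in range(3, len(stack) + 1):
                raise ValueError("ring size does not fit the atoms on the stack")
            bonds.append((stack[-1], stack[-size]))
        else:
            # A new atom attaches to whatever is on top of the stack.
            idx = len(atoms)
            atoms.append(ch)
            if stack:
                bonds.append((stack[-1], idx))
            stack.append(idx)
    return atoms, bonds
```

<p>Calling <code>deepsmiles_bonds("COC))SC))F")</code> attaches O, S, and F to the first carbon, which matches the branching of <code>C(OC)(SC)F</code>, and <code>deepsmiles_bonds("cccccc6")</code> adds a closing bond between the last and first atoms. The two error branches show that strings with too many close parentheses or an oversized ring digit remain undecodable in this scheme.</p>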
<h3 id="stereochemistry-preservation">Stereochemistry preservation</h3>
<p>Tetrahedral stereochemistry is fully preserved through the transformations. When ring closure symbol reordering would change the stereo configuration, the <code>@</code>/<code>@@</code> annotation is inverted during encoding to compensate.</p>
<h3 id="independence-of-transformations">Independence of transformations</h3>
<p>The two transformations are independent and can be applied separately or together. Any application of DeepSMILES should specify which transformations were applied.</p>
<h2 id="roundtrip-validation-on-chembl-23">Roundtrip Validation on ChEMBL 23</h2>
<p>The authors validated DeepSMILES by roundtripping all entries in the ChEMBL 23 database through SMILES-to-DeepSMILES-to-SMILES conversion. Canonical SMILES (including stereochemistry) were generated by four independent cheminformatics toolkits: CDK, OEChem, Open Babel, and RDKit. Using multiple toolkits ensures coverage of different traversal orders and ring closure ordering conventions.</p>
<p>All SMILES strings roundtripped without error across all three configurations (branches only, rings only, both). The exact string representation may differ in ring closure digit assignment or digit ordering, sometimes with an associated stereo inversion at tetrahedral centers, but the canonical SMILES of the original and roundtripped molecules are identical.</p>
<h3 id="performance-characteristics">Performance characteristics</h3>
<p>The following table shows the effect of DeepSMILES conversion on string length and throughput, measured on canonical SMILES from Open Babel for ChEMBL 23:</p>
<table>
  <thead>
      <tr>
          <th>Transformation</th>
          <th>Mean % change in length</th>
          <th>Encoding (per sec)</th>
          <th>Decoding (per sec)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Branches only</td>
          <td>+8.2%</td>
          <td>32,000</td>
          <td>16,000</td>
      </tr>
      <tr>
          <td>Rings only</td>
          <td>-6.4%</td>
          <td>26,000</td>
          <td>24,000</td>
      </tr>
      <tr>
          <td>Both</td>
          <td>+1.9%</td>
          <td>26,000</td>
          <td>17,500</td>
      </tr>
  </tbody>
</table>
<p>The ring transformation slightly shortens strings (by removing one digit per ring), while the branch transformation slightly lengthens them (additional close parentheses). Combined, the net effect is a small increase of about 2%. Throughput is in the tens of thousands of conversions per second in pure Python.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>DeepSMILES does not eliminate all invalid strings. Invalid DeepSMILES can still be generated, for example when there are more close parentheses than atoms on the stack, or when a ring size exceeds the number of available atoms. The reference implementation raises a <code>DecodeError</code> in these cases, though the authors note that a more tolerant decoder (ignoring extra parentheses or defaulting to the first atom for oversized rings) could be used during generation.</p>
<p>The paper assumes that input SMILES are generated by a standard cheminformatics toolkit as a depth-first traversal of the molecular graph. Non-standard SMILES (e.g., <code>CC(C1)CCCC1</code>) cannot be directly encoded.</p>
<p>The authors suggest several directions for future work:</p>
<ul>
<li>Investigating whether a preferred traversal order (e.g., shorter branches first) would make DeepSMILES even easier for models to learn.</li>
<li>Exploring notations where atoms in the organic subset explicitly list their hydrogen count, which would allow a fully parenthesis-free representation.</li>
<li>Using SMILES augmentation with random traversal orders (as explored by Bjerrum and Threlfall, 2017) in combination with DeepSMILES.</li>
<li>Designing entirely new line notations optimized for ML, where every string maps to a valid molecule, there are few duplicate representations, small string changes produce small structural changes, and string length correlates with pharmaceutical relevance.</li>
</ul>
<p>The fused ring case presents additional complexity: a bicyclic system has three cycles, and depending on traversal order, the ring size digit may not directly correspond to the ring size of any individual ring. This is an inherent limitation of depth-first traversal-based notations.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validation</td>
          <td>ChEMBL 23</td>
          <td>~1.7M compounds</td>
          <td>Canonical SMILES from CDK, OEChem, Open Babel, RDKit</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The DeepSMILES encoder and decoder are pure string-processing algorithms with no machine-learning components. The transformations operate on SMILES syntax tokens (atoms, bonds, parentheses, ring closure digits) without chemical interpretation.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Roundtrip accuracy</td>
          <td>100%</td>
          <td>All ChEMBL 23 entries across 4 toolkits</td>
      </tr>
      <tr>
          <td>Encoding throughput</td>
          <td>26,000-32,000/s</td>
          <td>Pure Python, varies by transformation</td>
      </tr>
      <tr>
          <td>Decoding throughput</td>
          <td>16,000-24,000/s</td>
          <td>Pure Python, varies by transformation</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>No specific hardware requirements. The implementation is a pure Python module with no GPU dependencies.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/nextmovesoftware/deepsmiles">deepsmiles</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Pure Python encoder/decoder</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: O&rsquo;Boyle, N. M., &amp; Dalke, A. (2018). DeepSMILES: An Adaptation of SMILES for Use in Machine-Learning of Chemical Structures. <em>ChemRxiv</em>. <a href="https://doi.org/10.26434/chemrxiv.7097960.v1">https://doi.org/10.26434/chemrxiv.7097960.v1</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{oboyle2018deepsmiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DeepSMILES: An Adaptation of SMILES for Use in Machine-Learning of Chemical Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{O&#39;Boyle, Noel M. and Dalke, Andrew}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{ChemRxiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.26434/chemrxiv.7097960.v1}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CogMol: Controlled Molecule Generation for COVID-19</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/cogmol-target-specific-drug-design/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/cogmol-target-specific-drug-design/</guid><description>CogMol combines a SMILES VAE with controlled latent space sampling to generate drug-like molecules with target specificity for novel viral proteins.</description><content:encoded><![CDATA[<h2 id="a-controlled-generation-framework-for-target-specific-drug-design">A Controlled Generation Framework for Target-Specific Drug Design</h2>
<p>This is a <strong>Method</strong> paper that introduces CogMol (Controlled Generation of Molecules), an end-to-end framework for de novo drug design. The primary contribution is a pipeline that combines a SMILES-based Variational Autoencoder (VAE) with multi-attribute controlled latent space sampling (CLaSS) to generate novel drug-like molecules with high binding affinity to specified protein targets, off-target selectivity, and favorable drug-likeness properties. The framework operates on protein sequence embeddings, allowing it to generalize to unseen target proteins without model retraining.</p>
<h2 id="multi-constraint-drug-design-for-novel-viral-targets">Multi-Constraint Drug Design for Novel Viral Targets</h2>
<p>Traditional drug discovery costs 2-3 billion USD and takes over a decade, with a success rate below 10%. Generating drug molecules requires satisfying multiple competing objectives simultaneously: target binding affinity, off-target selectivity, synthetic accessibility, drug-likeness, and low toxicity. Prior generative approaches using reinforcement learning or Bayesian optimization are computationally expensive and typically require fine-tuning on target-specific ligand libraries, making them unable to generalize to unseen protein targets.</p>
<p>The emergence of SARS-CoV-2 in 2020 created an urgent need for antiviral drug candidates targeting novel viral proteins. Because no binding affinity data existed for these new targets, and the viral proteins were not closely related to proteins in existing databases like BindingDB, existing target-specific generative frameworks could not be directly applied. CogMol addresses this by using pre-trained protein sequence embeddings from UniRep (trained on 24 million UniRef50 sequences) rather than learning protein representations from the limited BindingDB training set.</p>
<h2 id="controlled-latent-space-sampling-with-pre-trained-protein-embeddings">Controlled Latent Space Sampling with Pre-trained Protein Embeddings</h2>
<p>CogMol&rsquo;s core innovation is a three-component architecture that enables multi-constraint molecule generation for unseen targets:</p>
<p><strong>1. SMILES VAE with adaptive pre-training.</strong> A Variational Autoencoder is first trained unsupervised on the MOSES/ZINC dataset (1.6M molecules), then jointly fine-tuned with QED and SA property predictors on BindingDB molecules. The standard VAE objective is:</p>
<p>$$\mathcal{L}_{\text{VAE}}(\theta, \phi) = \mathbb{E}_{p(x)} \left\{ \mathbb{E}_{q_\phi(z|x)} [\log p_\theta(x|z)] - D_{\text{KL}}(q_\phi(z|x) \,\|\, p(z)) \right\}$$</p>
<p>where $q_\phi(z|x) = \mathcal{N}(z; \mu(x), \Sigma(x))$ specifies a diagonal Gaussian encoder distribution.</p>
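<p>For a diagonal Gaussian encoder, the KL term has a closed form when the prior is a standard normal. The sketch below assumes $p(z) = \mathcal{N}(0, I)$, which is the usual choice but is not stated explicitly in this summary; the function name and the log-variance parameterization are likewise illustrative.</p>

```python
import numpy as np


def kl_diag_gaussian_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions.

    Assumes a standard normal prior (an assumption of this sketch).
    """
    mu = np.asarray(mu, dtype=np.float64)
    log_var = np.asarray(log_var, dtype=np.float64)
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)
```

<p>The term is zero when the encoder output equals the prior (<code>mu = 0</code>, <code>log_var = 0</code>) and grows as the posterior mean moves away from the origin or the variance departs from one.</p>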
<p><strong>2. Protein-molecule binding affinity predictor.</strong> A regression model takes pre-trained UniRep protein sequence embeddings and molecule latent embeddings $z$ as input and predicts pIC50 binding affinity ($= -\log_{10}(\text{IC50})$). Because UniRep embeddings capture sequence, structural, and functional relationships from a large unsupervised corpus, the predictor can estimate binding affinity for novel target sequences not present in the training data.</p>
<p><strong>3. CLaSS controlled sampling.</strong> The Conditional Latent attribute Space Sampling scheme generates molecules satisfying multiple constraints (affinity, QED, selectivity) through rejection sampling in the VAE latent space:</p>
<p>$$p(\mathbf{x} | \mathbf{a}) = \mathbb{E}_{\mathbf{z}} [p(\mathbf{z} | \mathbf{a}) \, p(\mathbf{x} | \mathbf{z})] \approx \mathbb{E}_{\mathbf{z}} [\hat{p}_\xi(\mathbf{z} | \mathbf{a}) \, p_\theta(\mathbf{x} | \mathbf{z})]$$</p>
<p>where $\mathbf{a} = [a_1, a_2, \ldots, a_n]$ is a set of independent attribute constraints. The conditional density $\hat{p}_\xi(\mathbf{z} | \mathbf{a})$ is approximated using a Gaussian mixture model $Q_\xi(\mathbf{z})$ and per-attribute classifiers $q_\xi(a_i | \mathbf{z})$, with Bayes&rsquo; rule and conditional independence assumptions. The acceptance probability equals the product of all attribute predictor scores, enabling efficient multi-constraint sampling without surrogate model or policy learning.</p>
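<p>The rejection step described above can be sketched in a few lines. This is a toy under stated assumptions, not the authors' implementation: a single standard Gaussian stands in for the mixture model $Q_\xi(\mathbf{z})$, the two sigmoid "classifiers" are hypothetical placeholders for the trained attribute predictors, and decoding accepted latents into SMILES is omitted.</p>

```python
import math
import random

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

# Hypothetical stand-ins for the per-attribute classifiers q(a_i | z).
def affinity_score(z):
    return sigmoid(4.0 * z[0])

def qed_score(z):
    return sigmoid(4.0 * z[1])

def class_rejection_sample(n_accept, classifiers, dim=2, seed=0):
    """Draw z from a toy density model and accept with probability prod_i q(a_i | z)."""
    rng = random.Random(seed)
    accepted, n_proposed = [], 0
    while n_accept > len(accepted):
        # Toy proposal: one standard Gaussian standing in for the mixture Q(z).
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        n_proposed += 1
        accept_prob = math.prod(clf(z) for clf in classifiers)
        if accept_prob > rng.random():
            accepted.append(z)
    return accepted, n_proposed

samples, proposals = class_rejection_sample(200, [affinity_score, qed_score])
mean_z0 = sum(z[0] for z in samples) / len(samples)
```

<p>Accepted latents concentrate in the region where every classifier score is high (here, positive coordinates), and adding a constraint only multiplies in one more score, with no retraining of the generator.</p>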
<p><strong>Selectivity modeling.</strong> Off-target selectivity for a molecule $m$ against target $T$ is defined as:</p>
<p>$$\text{Sel}_{T,m} = \text{BA}(T, m) - \frac{1}{k} \sum_{i=1}^{k} \text{BA}(T_i, m)$$</p>
<p>where $\text{BA}(T, m)$ is binding affinity to the target and $T_i$ are $k$ randomly selected off-targets. This selectivity score is incorporated as a control attribute during CLaSS sampling.</p>
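<p>The selectivity score is simple arithmetic once the affinities are predicted. A minimal sketch, with illustrative pIC50-like numbers that are not taken from the paper:</p>

```python
def selectivity(target_affinity, off_target_affinities):
    """Sel = BA(T, m) minus the mean BA over the k off-targets."""
    k = len(off_target_affinities)
    if k == 0:
        raise ValueError("need at least one off-target affinity")
    return target_affinity - sum(off_target_affinities) / k

# Illustrative values only: target affinity 7.5, three off-targets averaging 5.5.
example = selectivity(7.5, [5.0, 6.0, 5.5])
```

<p>A positive score means the molecule is predicted to bind the intended target more strongly than the average sampled off-target.</p>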
<h2 id="experimental-setup-covid-19-targets-and-in-silico-screening">Experimental Setup: COVID-19 Targets and In Silico Screening</h2>
<p><strong>Target proteins.</strong> CogMol was applied to three SARS-CoV-2 targets not present in BindingDB: NSP9 Replicase dimer, Main Protease (Mpro), and the Receptor-Binding Domain (RBD) of the spike protein. A cancer target (human HDAC1) with low ligand coverage in the training data was also evaluated.</p>
<p><strong>Training data.</strong> The SMILES VAE was trained on the MOSES benchmark (1.6M molecules from ZINC). The binding affinity predictor used curated IC50 data from BindingDB as reported in DeepAffinity, with all protein classes included in training.</p>
<p><strong>CLaSS controlled generation.</strong> Molecules were generated with simultaneous constraints on binding affinity (&gt; 0.5 normalized), QED (&gt; 0.8 normalized), and selectivity (&gt; 0.5 normalized). Approximately 1000 molecules per target were selected for downstream evaluation.</p>
<p><strong>In silico screening pipeline.</strong> Generated molecules underwent:</p>
<ul>
<li>Toxicity prediction via a multi-task deep neural network (MT-DNN) on 12 Tox21 in vitro endpoints and ClinTox clinical trial failure</li>
<li>Binding affinity rescoring with a higher-accuracy SMILES-level predictor</li>
<li>Blind docking (5 independent runs per molecule) using AutoDock Vina against target protein structures</li>
<li>Synthetic feasibility assessment using a retrosynthetic algorithm based on the Molecular Transformer trained on patent reaction data</li>
</ul>
<p><strong>Baselines.</strong> VAE performance was benchmarked against models from the MOSES platform. CLaSS-accepted molecules were compared against randomly sampled molecules from the latent space. Generated molecules were compared against FDA-approved drugs for toxicity and synthesizability.</p>
<h3 id="key-results">Key Results</h3>
<p><strong>CLaSS enrichment (Table 1).</strong> CLaSS consistently produced higher fractions of molecules meeting all criteria compared to random sampling. For the triple constraint (affinity &gt; 0.5, QED &gt; 0.8, selectivity &gt; 0.5), the enrichment was substantial: 6.9% vs. 0.7% for NSP9, 9.0% vs. 0.9% for RBD, and 10.4% vs. 1.1% for Mpro.</p>
<table>
  <thead>
      <tr>
          <th>Target</th>
          <th>CLaSS (Aff+QED+Sel)</th>
          <th>Random (Aff+QED+Sel)</th>
          <th>Enrichment</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NSP9</td>
          <td>6.9%</td>
          <td>0.7%</td>
          <td>~10x</td>
      </tr>
      <tr>
          <td>RBD</td>
          <td>9.0%</td>
          <td>0.9%</td>
          <td>~10x</td>
      </tr>
      <tr>
          <td>Mpro</td>
          <td>10.4%</td>
          <td>1.1%</td>
          <td>~9.5x</td>
      </tr>
  </tbody>
</table>
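<p>The enrichment column is the ratio of the two rates in each row. A quick check of that arithmetic, using only the percentages reported in the table:</p>

```python
# (CLaSS rate, random-sampling rate) in percent, as listed in the table above.
rates = {"NSP9": (6.9, 0.7), "RBD": (9.0, 0.9), "Mpro": (10.4, 1.1)}

# Enrichment = CLaSS rate divided by random rate, rounded to one decimal.
enrichment = {target: round(c / r, 1) for target, (c, r) in rates.items()}
# -> {"NSP9": 9.9, "RBD": 10.0, "Mpro": 9.5}
```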
<p><strong>Docking results (Table 3).</strong> 87-95% of high-affinity generated molecules showed docking binding free energy (BFE) below -6 kcal/mol, with minimum BFEs reaching -8.6 to -9.5 kcal/mol depending on the target.</p>
<p><strong>Novelty.</strong> The likelihood of generating an exact duplicate of a training molecule was 2% or less. Against the full PubChem database (~103M molecules), exact matches ranged from 3.7% to 9.5%. Generated molecules also showed novel chemical scaffolds as confirmed by high Fréchet ChemNet Distance.</p>
<p><strong>Synthesizability.</strong> Generated molecules for COVID-19 targets showed 85-90% synthetic feasibility using retrosynthetic analysis, exceeding the ~78% rate of FDA-approved drugs.</p>
<p><strong>Toxicity.</strong> Approximately 70% of generated parent molecules and ~80% of predicted metabolites were predicted to be toxic in at most one of the 13 endpoints (12 Tox21 plus ClinTox), comparable to FDA-approved drugs.</p>
<h2 id="generated-molecules-show-favorable-binding-and-drug-like-properties">Generated Molecules Show Favorable Binding and Drug-Like Properties</h2>
<p>CogMol demonstrates that controlled latent space sampling with pre-trained protein embeddings can generate novel, drug-like molecules for unseen viral targets. The key findings are:</p>
<ol>
<li>CLaSS provides roughly 10x enrichment over random latent space sampling for molecules satisfying all three constraints (affinity, QED, selectivity).</li>
<li>Generated molecules bind favorably to druggable pockets in target protein 3D structures, even though the generation model uses only 1D sequence information.</li>
<li>Some generated SMILES matched existing PubChem molecules with known biological activity, suggesting the model identifies chemically relevant regions of molecular space.</li>
<li>The framework generalizes across targets of varying novelty, with Mpro (more similar to training proteins) yielding easier generation than NSP9 or RBD.</li>
</ol>
<p><strong>Limitations.</strong> The authors note that no wet-lab validation was performed on generated candidates. There may be divergence between ML-predicted properties and experimental measurements. The binding affinity predictor&rsquo;s accuracy is bounded by the quality and coverage of BindingDB training data. Selectivity modeling uses a random sample of off-targets rather than a pharmacologically curated panel.</p>
<p><strong>Future directions.</strong> The authors propose incorporating additional contexts beyond target protein (e.g., metabolic pathways), adding more pharmacologically relevant controls, and weighting objectives by relative importance.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>VAE pre-training</td>
          <td>MOSES/ZINC</td>
          <td>1.6M train, 176K test</td>
          <td>Publicly available benchmark</td>
      </tr>
      <tr>
          <td>VAE adaptive training</td>
          <td>BindingDB (DeepAffinity split)</td>
          <td>~27K protein-ligand pairs</td>
          <td>Curated IC50 data</td>
      </tr>
      <tr>
          <td>Protein embeddings</td>
          <td>UniRef50 via UniRep</td>
          <td>24M sequences</td>
          <td>Pre-trained, publicly available</td>
      </tr>
      <tr>
          <td>Toxicity prediction</td>
          <td>Tox21 + ClinTox</td>
          <td>12 in vitro + clinical endpoints</td>
          <td>Public benchmark datasets</td>
      </tr>
      <tr>
          <td>Docking validation</td>
          <td>PDB structures</td>
          <td>3 SARS-CoV-2 targets</td>
          <td>Public crystal structures</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>VAE architecture: SMILES encoder-decoder with diagonal Gaussian latent space, jointly trained with QED and SA regressors</li>
<li>CLaSS: rejection sampling from Gaussian mixture model of latent space with per-attribute classifiers</li>
<li>Binding affinity: regression on concatenated UniRep protein embeddings and VAE molecule embeddings</li>
<li>Selectivity: excess binding affinity over average of $k$ random off-targets</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>SMILES VAE with adaptive pre-training (ZINC then BindingDB)</li>
<li>Multi-task toxicity classifier (MT-DNN) for Tox21 and ClinTox endpoints</li>
<li>Binding affinity predictor (latent-level for generation, SMILES-level for screening)</li>
<li>Retrosynthetic predictor based on Molecular Transformer</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validity</td>
          <td>90%</td>
          <td>-</td>
          <td>Generated SMILES</td>
      </tr>
      <tr>
          <td>Uniqueness</td>
          <td>99%</td>
          <td>-</td>
          <td>Among valid molecules</td>
      </tr>
      <tr>
          <td>Filter pass</td>
          <td>95%</td>
          <td>-</td>
          <td>Relevant chemical filters</td>
      </tr>
      <tr>
          <td>Docking BFE &lt; -6 kcal/mol</td>
          <td>87-95%</td>
          <td>-</td>
          <td>Varies by target</td>
      </tr>
      <tr>
          <td>Synthetic feasibility</td>
          <td>85-90%</td>
          <td>78% (FDA drugs)</td>
          <td>COVID-19 targets</td>
      </tr>
      <tr>
          <td>Low toxicity (0-1 endpoints)</td>
          <td>~70% parent, ~80% metabolite</td>
          <td>Comparable to FDA drugs</td>
          <td>MT-DNN prediction</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU types or training times. The work was funded internally by IBM Research.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/CogMol">CogMol (GitHub)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://github.com/IBM/CogMol">~3500 generated molecules</a></td>
          <td>Dataset</td>
          <td>Open license</td>
          <td>For three SARS-CoV-2 targets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chenthamarakshan, V., Das, P., Hoffman, S. C., Strobelt, H., Padhi, I., Lim, K. W., Hoover, B., Manica, M., Born, J., Laino, T., &amp; Mojsilovic, A. (2020). CogMol: Target-Specific and Selective Drug Design for COVID-19 Using Deep Generative Models. <em>Advances in Neural Information Processing Systems</em>, 33, 4320-4332.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chenthamarakshan2020cogmol,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{CogMol: Target-Specific and Selective Drug Design for COVID-19 Using Deep Generative Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chenthamarakshan, Vijil and Das, Payel and Hoffman, Samuel C. and Strobelt, Hendrik and Padhi, Inkit and Lim, Kar Wai and Hoover, Benjamin and Manica, Matteo and Born, Jannis and Laino, Teodoro and Mojsilovi{\&#39;c}, Aleksandra}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{33}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{4320--4332}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemical Language Models for De Novo Drug Design Review</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/</guid><description>Review of chemical language models for de novo drug design covering string representations, architectures, training strategies, and experimental validation.</description><content:encoded><![CDATA[<h2 id="a-systematization-of-chemical-language-models-for-drug-design">A Systematization of Chemical Language Models for Drug Design</h2>
<p>This paper is a <strong>Systematization</strong> (minireview) that surveys the landscape of chemical language models (CLMs) for de novo drug design. It organizes the field along three axes: molecular string representations, deep learning architectures, and generation strategies (distribution learning, goal-directed, and conditional). The review also highlights experimental validations, current gaps, and future opportunities.</p>
<h2 id="why-chemical-language-models-matter-for-drug-design">Why Chemical Language Models Matter for Drug Design</h2>
<p>De novo drug design faces an enormous combinatorial challenge: the &ldquo;chemical universe&rdquo; is estimated to contain up to $10^{60}$ drug-like small molecules. Exhaustive enumeration is infeasible, and traditional design algorithms rely on hand-crafted assembly rules. Chemical language models address this by borrowing natural language processing techniques to learn the &ldquo;chemical language,&rdquo; generating molecules as string representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, DeepSMILES) that satisfy both syntactic validity (chemically valid structures) and semantic correctness (desired pharmacological properties).</p>
<p>CLMs have gained traction because string representations are readily available for most molecular databases, generation is computationally cheap (each molecule is a single token sequence sampled from a sequence model), and the same architecture can be applied to diverse tasks (property prediction, de novo generation, reaction prediction). At the time of this review, CLMs had produced experimentally validated bioactive molecules in several prospective studies, establishing them as practical tools for drug discovery.</p>
<h2 id="molecular-string-representations-smiles-deepsmiles-and-selfies">Molecular String Representations: SMILES, DeepSMILES, and SELFIES</h2>
<p>The review covers three main string representations used as input/output for CLMs:</p>
<p><strong>SMILES</strong> (Simplified Molecular Input Line Entry System) converts hydrogen-depleted molecular graphs into strings where atoms are denoted by atomic symbols, bonds and branching by punctuation, and ring openings/closures by numbers. SMILES are non-univocal (multiple valid strings per molecule), and canonicalization algorithms are needed for unique representations. Multiple studies show that using randomized (non-canonical) SMILES for data augmentation improves CLM performance, with diminishing returns beyond 10- to 20-fold augmentation.</p>
<p><strong><a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a></strong> modifies SMILES to improve machine-readability by replacing the paired ring-opening/closure digits with a count-based system and using closing parentheses only (no opening ones). This reduces the frequency of syntactically invalid strings but does not eliminate them entirely.</p>
<p><strong><a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a></strong> (Self-Referencing Embedded Strings) use a formal grammar that guarantees 100% syntactic validity of decoded molecules. Every SELFIES string maps to a valid molecular graph. However, SELFIES can produce chemically unrealistic molecules (e.g., highly strained ring systems), and the mapping between string edits and molecular changes is less intuitive than for SMILES.</p>
<p>The review notes a key tradeoff: SMILES offer a richer, more interpretable language with well-studied augmentation strategies, while SELFIES guarantee validity at the cost of chemical realism and edit interpretability.</p>
<h2 id="clm-architectures-and-training-strategies">CLM Architectures and Training Strategies</h2>
<h3 id="architectures">Architectures</h3>
<p>The review describes the main architectures used in CLMs:</p>
<p><strong>Recurrent Neural Networks (RNNs)</strong>, particularly LSTMs and GRUs, dominated early CLM work. These models process SMILES character-by-character and generate new strings autoregressively via next-token prediction. RNNs are computationally efficient and well-suited to the sequential nature of molecular strings.</p>
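<p>The autoregressive loop itself is short. In the sketch below a hypothetical bigram table stands in for the trained RNN, which would instead condition each next-token distribution on the whole prefix through its hidden state; the start and end markers are illustrative choices.</p>

```python
import random

# Hypothetical next-token distributions; a trained RNN would supply these
# conditioned on the full prefix rather than only the previous token.
NEXT_TOKEN = {
    "^": {"C": 0.7, "N": 0.2, "O": 0.1},           # "^" marks start of string
    "C": {"C": 0.5, "O": 0.2, "N": 0.1, "$": 0.2},  # "$" marks end of string
    "N": {"C": 0.8, "$": 0.2},
    "O": {"C": 0.6, "$": 0.4},
}

def sample_string(rng, max_len=20):
    """Sample one string token by token until the end marker or max_len."""
    token, out = "^", []
    for _ in range(max_len):
        choices = NEXT_TOKEN[token]
        token = rng.choices(list(choices), weights=list(choices.values()))[0]
        if token == "$":
            break
        out.append(token)
    return "".join(out)

rng = random.Random(0)
generated = [sample_string(rng) for _ in range(5)]
```

<p>Each sampled token is fed back as input for the next step, so generating a molecule takes one model evaluation per token.</p>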
<p><strong><a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Variational Autoencoders (VAEs)</a></strong> encode molecules into a continuous latent space and decode them back into strings. This enables smooth interpolation between molecules and latent-space optimization, but generated strings may be syntactically invalid.</p>
<p><strong><a href="/posts/what-is-a-gan/">Generative Adversarial Networks (GANs)</a></strong> have been adapted for molecular string generation (e.g., <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a>), though they face training instability and mode collapse challenges that limit their adoption.</p>
<p><strong>Transformers</strong> have emerged as an increasingly popular alternative, offering parallelized training and the ability to capture long-range dependencies in molecular strings. The review notes the growing relevance of Transformer-based CLMs, particularly for large-scale pretraining.</p>
<h3 id="generation-strategies">Generation Strategies</h3>
<p>The review organizes CLM generation into three categories:</p>
<ol>
<li>
<p><strong>Distribution learning</strong>: The model learns to reproduce the statistical distribution of a training set of molecules. No explicit scoring function is used during generation. The generated molecules are evaluated post-hoc by comparing their property distributions to the training set. This approach is end-to-end but provides no direct indication of individual molecule quality.</p>
</li>
<li>
<p><strong>Goal-directed generation</strong>: A pretrained CLM is steered toward molecules optimizing a specified scoring function (e.g., predicted bioactivity, physicochemical properties). Common approaches include reinforcement learning (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/reinvent-deep-rl-molecular-design/">REINVENT</a> and variants), hill-climbing, and Bayesian optimization. Scoring functions provide direct quality signals but can introduce biases, shortcuts, and limited structural diversity.</p>
</li>
<li>
<p><strong>Conditional generation</strong>: An intermediate approach that learns a joint semantic space between molecular structures and desired properties. The desired property profile serves as an input &ldquo;prompt&rdquo; for generation (e.g., a protein target, gene expression signature, or 3D shape). This bypasses the need for external scoring functions but has seen limited experimental application.</p>
</li>
</ol>
<h3 id="transfer-learning-and-chemical-space-exploration">Transfer Learning and Chemical Space Exploration</h3>
<p>Transfer learning is the dominant paradigm for CLM-driven chemical space exploration. A large-scale pretraining step (on $10^5$ to $10^6$ molecules via next-character prediction) is followed by fine-tuning on a smaller set of molecules with desired properties (often 10 to $10^2$ molecules). Key findings from the literature:</p>
<ul>
<li>The minimum training set size depends on target molecule complexity and heterogeneity.</li>
<li>SMILES augmentation is most beneficial with small training sets (fewer than 10,000 molecules) and plateaus for large, structurally complex datasets.</li>
<li>Fine-tuning with as few as 10 to 100 molecules has produced experimentally validated bioactive designs.</li>
<li>Hyperparameter tuning has relatively little effect on overall CLM performance.</li>
</ul>
<h2 id="evaluating-clm-designs-and-experimental-validation">Evaluating CLM Designs and Experimental Validation</h2>
<p>The review identifies evaluation as a critical gap. CLMs are often benchmarked on &ldquo;toy&rdquo; properties such as calculated logP, molecular weight, or QED (quantitative estimate of drug-likeness). These metrics capture the ability to satisfy predefined criteria but fail to reflect real-world drug discovery complexity and may lead to trivial solutions.</p>
<p>Existing benchmarks (<a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>) enable comparability across independently developed approaches but do not fully address the quality of generated compounds. The review emphasizes that experimental validation is the ultimate test. At the time of writing, only a few prospective applications had been published:</p>
<ul>
<li>Dual modulator of <a href="https://en.wikipedia.org/wiki/Retinoid_X_receptor">retinoid X</a> and <a href="https://en.wikipedia.org/wiki/Peroxisome_proliferator-activated_receptor">PPAR</a> receptors (EC50 ranging from 0.06 to 2.3 μM)</li>
<li>Inhibitor of <a href="https://en.wikipedia.org/wiki/Pim_kinase">Pim1 kinase</a> and <a href="https://en.wikipedia.org/wiki/Cyclin-dependent_kinase_4">CDK4</a> (manually modified from generated design)</li>
<li>Natural-product-inspired <a href="https://en.wikipedia.org/wiki/RAR-related_orphan_receptor_gamma">RORgamma</a> agonist (EC50 = 0.68 μM)</li>
<li>Molecules designed via combined generative AI and on-chip synthesis</li>
</ul>
<p>The scarcity of experimental validations reflects the interdisciplinary expertise required and the time/cost of chemical synthesis.</p>
<h2 id="gaps-limitations-and-future-directions">Gaps, Limitations, and Future Directions</h2>
<p>The review identifies several key gaps and opportunities:</p>
<p><strong>Scoring function limitations</strong>: Current scoring functions struggle with activity cliffs and non-additive structure-activity relationships. Conditional generation methods may help overcome these limitations by learning direct structure-property mappings.</p>
<p><strong>Structure-based design</strong>: Generating molecules that match electrostatic and shape features of protein binding pockets holds promise for addressing unexplored targets. However, prospective applications have been limited, potentially due to bias in existing protein-ligand affinity datasets.</p>
<p><strong>Synthesizability</strong>: Improving the ability of CLMs to propose synthesizable molecules is expected to increase practical relevance. Automated synthesis platforms may help but could also limit accessible chemical space.</p>
<p><strong>Few-shot learning</strong>: Large-scale pretrained CLMs combined with few-shot learning approaches are expected to boost prospective applications.</p>
<p><strong>Extensions beyond small molecules</strong>: Extending chemical languages to more complex molecular entities (proteins with non-natural amino acids, crystals, supramolecular chemistry) is an open frontier.</p>
<p><strong>Failure modes</strong>: Several studies have documented failure modes in goal-directed generation, including model shortcuts (exploiting scoring function artifacts), limited structural diversity, and generation of chemically unrealistic molecules.</p>
<p><strong>Interdisciplinary collaboration</strong>: The review emphasizes that bridging deep learning, cheminformatics, and medicinal chemistry expertise is essential for translating CLM designs into real-world drug candidates.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>This is a review paper and does not present novel experimental data. The paper surveys results from the literature.</p>
<h3 id="algorithms">Algorithms</h3>
<p>No novel algorithms are introduced. The review categorizes existing approaches (RNNs, VAEs, GANs, Transformers) and generation strategies (distribution learning, goal-directed, conditional).</p>
<h3 id="models">Models</h3>
<p>No new models are presented. The paper references existing implementations including REINVENT, ORGAN, and various RNN-based and Transformer-based CLMs.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The review discusses existing benchmarks:</p>
<ul>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a></strong>: Benchmarking suite for de novo molecular design</li>
<li><strong><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></strong>: Benchmarking platform for molecular generation models</li>
<li><strong>QED</strong>: Quantitative estimate of drug-likeness</li>
<li>Various physicochemical property metrics (logP, molecular weight)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not applicable (review paper).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Grisoni, F. (2023). Chemical language models for de novo drug design: Challenges and opportunities. <em>Current Opinion in Structural Biology</em>, 79, 102527. <a href="https://doi.org/10.1016/j.sbi.2023.102527">https://doi.org/10.1016/j.sbi.2023.102527</a></p>
<p><strong>Publication</strong>: Current Opinion in Structural Biology, Volume 79, April 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{grisoni2023chemical,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Chemical language models for de novo drug design: Challenges and opportunities}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Grisoni, Francesca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Current Opinion in Structural Biology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{79}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{102527}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Elsevier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.sbi.2023.102527}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CDDD: Learning Descriptors by Translating SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/cddd-translation-molecular-descriptors/</guid><description>CDDD learns continuous molecular descriptors by translating between SMILES and InChI representations, outperforming fingerprints in virtual screening.</description><content:encoded><![CDATA[<h2 id="a-translation-based-method-for-learned-molecular-descriptors">A Translation-Based Method for Learned Molecular Descriptors</h2>
<p>This is a <strong>Method</strong> paper that introduces Continuous and Data-Driven Descriptors (CDDD), a neural machine translation approach for learning fixed-size, continuous molecular representations. Rather than training an autoencoder to reconstruct <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, Winter et al. train an encoder-decoder model to translate between semantically equivalent but syntactically different molecular representations (e.g., randomized SMILES to canonical SMILES, or <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> to canonical SMILES). The bottleneck latent vector serves as a general-purpose molecular descriptor. Pretrained on approximately 72 million compounds from <a href="/notes/chemistry/datasets/zinc-22/">ZINC15</a> and PubChem, CDDD produces 512-dimensional descriptors that achieve competitive QSAR performance and significantly outperform all tested molecular fingerprints in ligand-based virtual screening.</p>
<h2 id="why-translation-instead-of-reconstruction">Why Translation Instead of Reconstruction?</h2>
<p>Molecular descriptors are central to cheminformatics. Traditional approaches rely on human-engineered fingerprints like ECFPs, which encode structural features as fixed-length bit vectors. While effective, these representations are constrained by predefined feature extraction rules.</p>
<p>Recent work applied deep neural networks directly to molecular graphs or SMILES strings to learn task-specific representations. However, these end-to-end approaches must learn features from scratch for each new dataset, making them prone to overfitting on the small bioactivity datasets typical in drug discovery.</p>
<p>Unsupervised approaches based on autoencoders (notably <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">Gomez-Bombarelli et al.&rsquo;s VAE</a> and <a href="/notes/chemistry/molecular-representations/encoders/seq2seq-fingerprint-molecular-embedding/">Xu et al.&rsquo;s seq2seq model</a>) offered a path toward general-purpose learned descriptors. These models reconstruct SMILES strings through an information bottleneck, forcing the latent space to capture molecular information. The concern with reconstruction, however, is that the model may focus on syntactic patterns of the string representation rather than the underlying chemical semantics. A model that memorizes SMILES syntax shortcuts can achieve low reconstruction error without truly encoding chemical meaning.</p>
<p>Winter et al. address this by drawing on the analogy to neural machine translation: a translator must understand the meaning of a sentence to produce a correct translation in another language. By training the model to translate between different molecular representations (which share chemical semantics but differ in syntax), the latent space is forced to capture the chemical information common to both representations, rather than representation-specific syntactic artifacts.</p>
<h2 id="translation-as-semantic-compression">Translation as Semantic Compression</h2>
<p>The core insight is that translating between two syntactically different but semantically equivalent representations forces the encoder to capture only the chemical meaning shared by both. The model architecture follows the standard encoder-decoder framework from neural machine translation.</p>
<p>The encoder reads a source molecular string (e.g., a randomized SMILES or InChI) and compresses it into a fixed-size latent vector. The decoder takes this latent vector and generates the target molecular string (canonical SMILES). The model is trained to minimize character-level cross-entropy between the decoder output and the target sequence.</p>
<p>Four translation tasks were evaluated:</p>
<ol>
<li><strong>Randomized SMILES to canonical SMILES</strong> (best performing)</li>
<li><strong>InChI to canonical SMILES</strong></li>
<li><strong>Canonical SMILES to canonical SMILES</strong> (autoencoding baseline)</li>
<li><strong>Canonical SMILES to InChI</strong> (failed to learn)</li>
</ol>
<p>The final model uses an RNN encoder with 3 stacked GRU layers (512, 1024, and 2048 units). The hidden states of the three layers are concatenated and passed through a fully connected layer with tanh activation to produce a 512-dimensional latent vector. The decoder mirrors this architecture, initializing its GRU states from the latent vector via separate fully connected layers. Teacher forcing is used during training, and left-to-right beam search is used at inference.</p>
<p>An auxiliary property prediction network takes the latent vector as input and predicts nine molecular properties (logP, minimum and maximum partial charge, number of valence electrons, number of H-bond donors and acceptors, Balaban&rsquo;s J, <a href="https://en.wikipedia.org/wiki/Molar_refractivity">molar refractivity</a>, and TPSA). This multi-task signal encourages the latent space to encode physically meaningful information. The full training objective combines the translation cross-entropy loss with the property prediction mean squared error:</p>
<p>$$\mathcal{L} = \mathcal{L}_{\text{translation}} + \mathcal{L}_{\text{properties}}$$</p>
<p>To ensure invariance to input SMILES representation at inference time, the model uses randomized SMILES as input half the time and canonical SMILES the other half during training. Input dropout (15% at the character level) and Gaussian noise (standard deviation 0.05) are applied for regularization.</p>
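<p>A minimal sketch of how the two loss terms above could be combined, written in plain Python for clarity. The function names, the per-character averaging, and the toy numbers are illustrative assumptions, not the authors&rsquo; TensorFlow implementation; only the unweighted sum of the two terms comes from the objective stated above.</p>

```python
import math

def char_cross_entropy(probs, target_ids):
    """Mean negative log-likelihood of the target characters.

    probs[t][k] is the decoder's predicted probability of character k at step t.
    Averaging over steps is an assumption made for this sketch.
    """
    nll = [-math.log(probs[t][k]) for t, k in enumerate(target_ids)]
    return sum(nll) / len(nll)

def mean_squared_error(predicted, actual):
    """Mean squared error over the auxiliary property targets."""
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)

def combined_loss(probs, target_ids, pred_props, true_props):
    """Unweighted sum of the translation term and the property term."""
    translation = char_cross_entropy(probs, target_ids)
    properties = mean_squared_error(pred_props, true_props)
    return translation + properties

# Toy example: 2 decoding steps over a 3-character vocabulary, 2 properties.
probs = [[0.5, 0.25, 0.25], [0.25, 0.5, 0.25]]
loss = combined_loss(probs, [0, 1], [1.0, 2.0], [1.0, 4.0])
```

<p>In this toy case the translation term is $\ln 2$ (both target characters receive probability 0.5) and the property term is 2.0, so the property error dominates the sum. Any rescaling of the property targets or of the two terms would be a further design choice that is not shown here.</p>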
<h2 id="qsar-benchmarks-virtual-screening-and-latent-space-exploration">QSAR Benchmarks, Virtual Screening, and Latent Space Exploration</h2>
<h3 id="pretraining">Pretraining</h3>
<p>The model was pretrained on approximately 72 million compounds from ZINC15 and PubChem (merged, deduplicated, filtered for organic molecules with MW 12-600, &gt;3 heavy atoms, logP between -7 and 5). All evaluation compounds were removed from the pretraining set.</p>
<h3 id="qsar-experiments">QSAR Experiments</h3>
<p>Ten QSAR datasets were used, spanning classification (<a href="https://en.wikipedia.org/wiki/Ames_test">Ames mutagenicity</a>, <a href="https://en.wikipedia.org/wiki/KCNH2">hERG inhibition</a>, <a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">BBB penetration</a>, BACE inhibition, bee toxicity) and regression (EGFR inhibition, <a href="https://en.wikipedia.org/wiki/Plasmodium_falciparum">Plasmodium falciparum</a> inhibition, lipophilicity, aqueous solubility, melting point). Two datasets (Ames and lipophilicity) served as validation for architecture selection; the remaining eight were held out for final evaluation.</p>
<p>CDDD descriptors with an SVM were benchmarked against:</p>
<ul>
<li>Nine circular fingerprint variants (Morgan fingerprints, radius 1-3, folded to 512/1024/2048 bits) with RF, SVM, and GB</li>
<li>Graph convolution models (<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">DeepChem</a>)</li>
</ul>
<p>Both random-split and cluster-split (K-means on MACCS fingerprints, K=5) cross-validation were performed.</p>
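<p>The cluster split can be sketched as leave-one-cluster-out cross-validation. The snippet below assumes that cluster labels are already available (the K-means step on MACCS fingerprints is not shown) and that each of the five clusters serves as one held-out fold; <code>cluster_folds</code> and the toy labels are hypothetical.</p>

```python
from collections import defaultdict

def cluster_folds(cluster_labels):
    """Yield (held_out_cluster, train_idx, test_idx), one fold per cluster."""
    members = defaultdict(list)
    for idx, label in enumerate(cluster_labels):
        members[label].append(idx)
    for held_out in sorted(members):
        test_idx = members[held_out]
        train_idx = [i for label in sorted(members) if label != held_out
                     for i in members[label]]
        yield held_out, train_idx, test_idx

# Hypothetical cluster assignments for 8 compounds.
labels = [0, 1, 0, 2, 1, 2, 0, 1]
folds = list(cluster_folds(labels))
```

<p>Because every test compound comes from a cluster the model never saw during training, this split probes extrapolation to dissimilar chemistry more strictly than a random split does.</p>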
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Split</th>
          <th>CDDD + SVM</th>
          <th>Best Fingerprint</th>
          <th>Graph Conv</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Ames (ROC-AUC)</td>
          <td>Random</td>
          <td>0.89</td>
          <td>0.89 (ecfc2, RF)</td>
          <td>0.88</td>
      </tr>
      <tr>
          <td>hERG (ROC-AUC)</td>
          <td>Random</td>
          <td>0.86</td>
          <td>0.85 (ecfc4, RF)</td>
          <td>0.86</td>
      </tr>
      <tr>
          <td>BBBP (ROC-AUC)</td>
          <td>Random</td>
          <td>0.93</td>
          <td>0.93 (ecfc2, RF)</td>
          <td>0.92</td>
      </tr>
      <tr>
          <td>BACE (ROC-AUC)</td>
          <td>Random</td>
          <td>0.90</td>
          <td>0.91 (ecfc2, RF)</td>
          <td>0.91</td>
      </tr>
      <tr>
          <td>Bee toxicity (ROC-AUC)</td>
          <td>Random</td>
          <td>0.92</td>
          <td>0.91 (ecfc6, RF)</td>
          <td>0.89</td>
      </tr>
      <tr>
          <td>Lipophilicity ($r^2$)</td>
          <td>Random</td>
          <td>0.72</td>
          <td>0.69 (ecfc2, SVM)</td>
          <td>0.73</td>
      </tr>
      <tr>
          <td>ESOL ($r^2$)</td>
          <td>Random</td>
          <td>0.92</td>
          <td>0.58 (ecfc6, SVM)</td>
          <td>0.86</td>
      </tr>
      <tr>
          <td>Melting point ($r^2$)</td>
          <td>Random</td>
          <td>0.42</td>
          <td>0.38 (ecfc2, SVM)</td>
          <td>0.39</td>
      </tr>
  </tbody>
</table>
<p>CDDD descriptors showed competitive or better performance across all tasks. Notably, CDDD achieved substantially higher $r^2$ on aqueous solubility (0.92 vs. 0.58 for the best fingerprint). The authors emphasize that CDDD&rsquo;s feature extraction was fixed based on two validation tasks, while baseline methods selected the best fingerprint/model combination per task, making the comparison conservative for CDDD.</p>
<h3 id="virtual-screening">Virtual Screening</h3>
<p>Ligand-based virtual screening experiments followed the Riniker et al. benchmarking protocol on 40 DUD targets and 17 MUV targets. Five active compounds were randomly selected per target, and remaining compounds were ranked by similarity (cosine similarity for CDDD, <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto</a> for fingerprints). This process was repeated 50 times per target.</p>
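<p>The ranking step of this protocol can be sketched in plain Python. Two things here are assumptions rather than details from the summary above: each candidate is scored by its highest cosine similarity to any query active (MAX fusion), and ROC-AUC is computed from the rank-sum statistic. The vectors are toy 2-dimensional stand-ins for 512-dimensional descriptors.</p>

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def screen_score(candidate, query_actives):
    """Highest similarity to any query active (MAX fusion, assumed)."""
    return max(cosine(candidate, q) for q in query_actives)

def roc_auc(scores, is_active):
    """ROC-AUC via the rank-sum statistic, averaging ranks over ties."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    start = 0
    while start != len(order):
        end = start
        while end + 1 != len(order) and scores[order[end + 1]] == scores[order[start]]:
            end += 1
        for pos in range(start, end + 1):
            ranks[order[pos]] = (start + end) / 2 + 1  # average 1-based rank
        start = end + 1
    n_pos = sum(is_active)
    n_neg = len(is_active) - n_pos
    rank_sum = sum(r for r, a in zip(ranks, is_active) if a)
    return (rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

# Toy screen: two query actives, four library compounds (1 = active, 0 = decoy).
queries = [[1.0, 0.0], [0.0, 1.0]]
library = [[0.9, 0.1], [0.5, 0.5], [0.1, 0.9], [-1.0, -1.0]]
labels = [1, 0, 1, 0]
scores = [screen_score(c, queries) for c in library]
auc = roc_auc(scores, labels)
```

<p>In the toy screen both actives score above both decoys, so the AUC is 1.0. Repeating the query selection with different random actives and averaging the resulting AUCs would mirror the 50 repetitions per target described above.</p>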
<table>
  <thead>
      <tr>
          <th>Database</th>
          <th>CDDD (ROC-AUC)</th>
          <th>Second Best</th>
          <th>p-value (Wilcoxon)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DUD</td>
          <td>0.949</td>
          <td>0.899 (laval)</td>
          <td>$5 \times 10^{-38}$</td>
      </tr>
      <tr>
          <td>MUV</td>
          <td>0.679</td>
          <td>0.677 (ap)</td>
          <td>0.04</td>
      </tr>
  </tbody>
</table>
<p>CDDD significantly outperformed all 14 baseline fingerprints on both databases. The DUD improvement was particularly large: 0.949 vs. 0.899 for the next best, a gap of 0.050 ROC-AUC. On MUV, which is designed to be harder, the advantage was smaller (0.679 vs. 0.677) but still statistically significant at the 0.05 level. While the best baseline fingerprint varied between DUD and MUV (laval vs. ap), CDDD ranked first on both, indicating consistent performance.</p>
<h3 id="latent-space-exploration">Latent Space Exploration</h3>
<p>The continuous, reversible nature of CDDD enables chemical space navigation. Shifting a molecule&rsquo;s embedding along the first principal component of the pretraining data correlates with molecular size (Spearman $r = 0.947$, $p = 0.00048$), while the second principal component correlates with polarity/logP ($r = -0.916$, $p = 0.00015$).</p>
<p>When shifting 1000 compounds along 100 random directions, the model maintained high valid SMILES generation rates (&gt;97% for the top beam search output, &gt;99% when considering the top 3 outputs). Euclidean distance in the descriptor space correlated smoothly with Tanimoto distance in fingerprint space, confirming that the latent space supports meaningful interpolation.</p>
<h2 id="consistent-learned-descriptors-for-chemistry">Consistent Learned Descriptors for Chemistry</h2>
<p>CDDD demonstrated that translation between molecular representations produces more informative latent spaces than autoencoder reconstruction. The key findings are:</p>
<ol>
<li><strong>Translation outperforms reconstruction</strong>: Models trained on translating between different representations consistently produced better downstream descriptors than autoencoding models, despite autoencoding being an easier task.</li>
<li><strong>Auxiliary property prediction helps</strong>: The additional regression task for molecular properties improved descriptor quality, particularly for physicochemical endpoints correlated with the predicted properties.</li>
<li><strong>Consistent performance</strong>: Unlike baseline methods where the best fingerprint varies by task, CDDD showed consistent performance across all QSAR and VS experiments.</li>
<li><strong>Smooth latent space</strong>: The continuous descriptor space supports meaningful interpolation and chemical space exploration with high valid SMILES rates.</li>
</ol>
<p>The authors acknowledge several limitations. InChI-to-SMILES translation worked but produced inferior descriptors compared to randomized-to-canonical SMILES translation, and canonical-SMILES-to-InChI translation failed entirely, likely due to InChI&rsquo;s complex syntax (counting, arithmetic). The approach was only tested with string-based representations; translation between conceptually different representations (e.g., 3D structures) remains future work. The QSAR evaluation, while extensive, used relatively standard datasets, and the method&rsquo;s advantage over graph convolution models was modest on tasks where end-to-end learning had sufficient data.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ZINC15 + PubChem (merged)</td>
          <td>~72M compounds</td>
          <td>Filtered: organic, MW 12-600, &gt;3 heavy atoms, logP -7 to 5</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Ames mutagenicity</td>
          <td>6,130</td>
          <td>Classification</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Lipophilicity</td>
          <td>3,817</td>
          <td>Regression</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>hERG, BBBP, BACE, bee toxicity</td>
          <td>188-3,440</td>
          <td>Classification</td>
      </tr>
      <tr>
          <td>Test</td>
          <td>EGFR, Plasmodium, ESOL, melting point</td>
          <td>184-4,451</td>
          <td>Regression</td>
      </tr>
      <tr>
          <td>VS</td>
          <td>DUD</td>
          <td>40 targets</td>
          <td>Ligand-based virtual screening</td>
      </tr>
      <tr>
          <td>VS</td>
          <td>MUV</td>
          <td>17 targets</td>
          <td>Maximum unbiased validation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Encoder: 3 stacked GRU layers (512, 1024, 2048 units) with tanh bottleneck to 512-dim latent space</li>
<li>Decoder: Matching 3 stacked GRU layers, initialized from latent space</li>
<li>Auxiliary property predictor: 3 FC layers (512, 128, 9) regressing nine molecular properties</li>
<li>Optimizer: Adam, initial LR $5 \times 10^{-4}$, decayed by 0.9 every 50,000 steps</li>
<li>Batch size: 64 with bucketing by sequence length</li>
<li>Input regularization: 15% character dropout + Gaussian noise (std 0.05)</li>
<li>Beam search for decoding at inference</li>
</ul>
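<p>The learning-rate schedule listed above can be written as a one-line function. This sketch assumes a staircase decay (the rate drops once every 50,000 steps instead of decaying continuously); the summary does not say which variant was used, and <code>learning_rate</code> is a hypothetical helper.</p>

```python
def learning_rate(step, initial_lr=5e-4, decay=0.9, decay_steps=50_000):
    """Staircase exponential decay: multiply by `decay` once per `decay_steps` steps."""
    return initial_lr * decay ** (step // decay_steps)

# Rates at the start, just before and at the first decay, and after two decays.
schedule = [learning_rate(s) for s in (0, 49_999, 50_000, 100_000)]
```

<p>Under this assumption the rate stays at $5 \times 10^{-4}$ for the first 50,000 steps, then drops to roughly $4.5 \times 10^{-4}$ and $4.05 \times 10^{-4}$ after the first and second decay.</p>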
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jrwnter/cddd">CDDD (GitHub)</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Pretrained model and extraction code</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>QSAR: 5-fold random CV and 5-fold cluster CV (K-means on MACCS, K=5)</li>
<li>Classification metric: ROC-AUC</li>
<li>Regression metric: $r^2$</li>
<li>VS: ROC-AUC averaged over 50 random active set selections per target</li>
<li>Statistical test: <a href="https://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test">Wilcoxon signed-rank test</a> for VS comparisons</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Framework: TensorFlow 1.4.1</li>
<li>CDDD descriptor extraction on GPU is comparable in speed to RDKit fingerprint computation on CPU</li>
<li>SVM training on 512-dim CDDD descriptors takes seconds (vs. minutes for 2048-dim fingerprints)</li>
<li>Graph convolution training: ~30 minutes per task on GPU</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Winter, R., Montanari, F., Noé, F., &amp; Clevert, D.-A. (2019). Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations. <em>Chemical Science</em>, 10(6), 1692-1701. <a href="https://doi.org/10.1039/C8SC04175J">https://doi.org/10.1039/C8SC04175J</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{winter2019learning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Learning continuous and data-driven molecular descriptors by translating equivalent chemical representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Winter, Robin and Montanari, Floriane and No{\&#39;e}, Frank and Clevert, Djork-Arn{\&#39;e}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1692--1701}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/C8SC04175J}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Atom-in-SMILES: Better Tokens for Chemical Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/</link><pubDate>Thu, 26 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/atom-in-smiles-tokenization/</guid><description>Atom-in-SMILES replaces generic SMILES tokens with environment-aware atomic tokens, reducing token degeneration and improving chemical translation accuracy.</description><content:encoded><![CDATA[<h2 id="a-new-tokenization-method-for-chemical-language-models">A New Tokenization Method for Chemical Language Models</h2>
<p>This is a <strong>Method</strong> paper that introduces Atom-in-SMILES (AIS), a tokenization scheme for SMILES strings that replaces generic atomic tokens with environment-aware tokens encoding each atom&rsquo;s local chemical neighborhood. The primary contribution is demonstrating that tokenization quality has a significant impact on chemical language model outcomes across multiple tasks: SMILES canonicalization, <a href="/notes/chemistry/molecular-design/reaction-prediction/">single-step retrosynthesis</a>, and <a href="/notes/chemistry/molecular-design/property-prediction/">molecular property prediction</a>.</p>
<h2 id="why-standard-smiles-tokenization-falls-short">Why Standard SMILES Tokenization Falls Short</h2>
<p>Standard atom-wise SMILES tokenization treats all atoms of the same element identically. Every carbon is tokenized as &ldquo;C&rdquo; regardless of whether it is part of an aromatic ring, a carbonyl group, or a methyl chain. This creates a highly degenerate token space where chemically distinct atoms share the same representation.</p>
<p>The authors draw an analogy between natural language and chemical language. A typical SMILES sequence is about three times longer than a natural language sentence, yet the token vocabulary is roughly 1000 times smaller. This mismatch leads to extreme token repetition: the same tokens (C, c, N, O) appear many times within a single sequence. In natural language processing, token degeneration (where models repeatedly predict the same token) is a known failure mode of autoregressive decoders. The repetitive nature of SMILES tokens exacerbates this problem in chemical language models.</p>
<p>SMILES also lacks a one-to-one correspondence between tokens and chemical meaning. Two molecules that differ in only one atom substitution (e.g., swapping a carbon for a nitrogen in a ring) produce nearly identical token sequences under atom-wise tokenization, making it harder for models to distinguish structurally similar molecules.</p>
<h2 id="core-innovation-encoding-atom-environments-into-tokens">Core Innovation: Encoding Atom Environments into Tokens</h2>
<p>The key insight is to replace each atomic token with a richer token that encodes the atom&rsquo;s local chemical environment, inspired by the <a href="https://en.wikipedia.org/wiki/Atoms_in_molecules">atoms-in-molecules (AIM)</a> concept from quantum chemistry. For a given SMILES string, the AIS mapping function $f$ operates on the token space:</p>
<p>$$
f(X) = \begin{cases} AE|_{X_{\text{central}}} &amp; \text{if } X \text{ is an atom} \\ X &amp; \text{otherwise} \end{cases}
$$</p>
<p>where $AE|_{X_{\text{central}}}$ denotes the atomic environment centered on atom $X$. Non-atomic tokens (brackets, bond symbols, ring closures) pass through unchanged.</p>
<p>Each AIS token is formatted as <code>[Sym;Ring;Neighbors]</code> where:</p>
<ul>
<li><strong>Sym</strong> is the atomic symbol with chirality, aromaticity (lowercase for aromatic), hydrogen count, and formal charge</li>
<li><strong>Ring</strong> indicates whether the atom is in a ring (<code>R</code>) or not (<code>!R</code>)</li>
<li><strong>Neighbors</strong> lists the neighboring atoms interacting with the central atom</li>
</ul>
<p>This mapping is bijective: SMILES strings can be fully recovered from AIS strings via an inverse projection. The algorithm iterates over atoms in a molecule, computes their local environments using RDKit, and produces environment-aware token variants.</p>
<p>As a concrete example, in glycine the two carbons and two oxygens are indistinguishable under atom-wise tokenization. Under AIS, each receives a unique token reflecting its bonding environment (e.g., the carboxyl carbon is distinguished from the alpha carbon).</p>
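<p>The glycine example can be made concrete with a toy sketch. The published implementation derives atom environments with RDKit; the version below instead uses a hand-written molecular graph and a simplified token builder, so <code>ais_like_token</code>, the neighbor ordering, and the omission of chirality, aromaticity, and charge are all assumptions made for illustration.</p>

```python
def ais_like_token(symbol, h_count, in_ring, neighbor_symbols):
    """Build a simplified [Sym;Ring;Neighbors] token for one atom."""
    h_part = "" if h_count == 0 else "H" if h_count == 1 else f"H{h_count}"
    ring = "R" if in_ring else "!R"
    neighbors = "".join(sorted(neighbor_symbols))
    return f"[{symbol}{h_part};{ring};{neighbors}]"

# Glycine, NCC(=O)O, as a hand-written graph: (symbol, hydrogen count) per atom.
atoms = [("N", 2), ("C", 2), ("C", 0), ("O", 0), ("O", 1)]
bonds = [(0, 1), (1, 2), (2, 3), (2, 4)]

# Collect the element symbols of each atom's bonded neighbors.
neighbors = {i: [] for i in range(len(atoms))}
for a, b in bonds:
    neighbors[a].append(atoms[b][0])
    neighbors[b].append(atoms[a][0])

atom_wise = [symbol for symbol, _ in atoms]
ais_like = [ais_like_token(symbol, h, False, neighbors[i])
            for i, (symbol, h) in enumerate(atoms)]
```

<p>Atom-wise tokenization yields only three distinct tokens for the five heavy atoms (N, C, O), while the environment-aware tokens are all different: <code>[NH2;!R;C]</code>, <code>[CH2;!R;CN]</code>, <code>[C;!R;COO]</code>, <code>[O;!R;C]</code>, and <code>[OH;!R;C]</code>.</p>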
<p>The AIS tokenization also exhibits a fingerprint-like property. Because each token encodes local structural information, the set of AIS tokens for a molecule functions similarly to circular fingerprints like ECFP2. The authors show that pairwise <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarities</a> computed from AIS token sets have resolution comparable to ECFP2 and HashAP fingerprints, and better resolution than MACCS, Avalon, and RDKit fingerprints.</p>
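<p>A minimal sketch of the token-set similarity described above. The two token lists are hand-written in a simplified <code>[Sym;Ring;Neighbors]</code> style for glycine and alanine (stereochemistry ignored) and are illustrative assumptions, not output of the published tokenizer.</p>

```python
def tanimoto(tokens_a, tokens_b):
    """Tanimoto (Jaccard) similarity between two token sets."""
    a, b = set(tokens_a), set(tokens_b)
    union = a.union(b)
    if not union:
        return 1.0  # two empty sets are treated as identical
    return len(a.intersection(b)) / len(union)

# Hand-written, simplified tokens for two amino acids.
glycine = ["[NH2;!R;C]", "[CH2;!R;CN]", "[C;!R;COO]", "[O;!R;C]", "[OH;!R;C]"]
alanine = ["[CH3;!R;C]", "[CH;!R;CCN]", "[NH2;!R;C]",
           "[C;!R;COO]", "[O;!R;C]", "[OH;!R;C]"]

similarity = tanimoto(glycine, alanine)
```

<p>The two molecules share four tokens (the amine and the carboxyl group) out of seven distinct tokens overall, giving a similarity of 4/7. Under atom-wise tokenization both molecules reduce to the same set {N, C, O}, so the similarity would be 1.0 and carry no resolution.</p>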
<p>Token repetition can be quantified as:</p>
<p>$$
\text{rep-}l = \sum_{t=1}^{|s|} \mathbb{1}[s_t \in s_{t-w-1:t-1}]
$$</p>
<p>where $s$ is the predicted sequence, $|s|$ is the token count, and $w$ is the window size. AIS tokens exhibit consistently lower normalized repetition rates compared to SMILES, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, and <a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a> across diverse molecular datasets (drugs, natural products, steroids, lipids, metal complexes, octane isomers).</p>
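<p>The repetition count can be sketched directly from the formula. The exact window endpoints in the subscript are open to interpretation; this sketch assumes that a token counts as repeated when it also occurs among the <code>window</code> tokens immediately preceding it. <code>rep_l</code> is a hypothetical helper name.</p>

```python
def rep_l(tokens, window):
    """Count positions whose token also occurs among the preceding `window` tokens."""
    count = 0
    for t, token in enumerate(tokens):
        previous = tokens[max(0, t - window):t]
        if token in previous:
            count += 1
    return count

# Atom-wise tokens for a small fragment, CC(=O)O.
tokens = ["C", "C", "(", "=", "O", ")", "O"]
narrow = rep_l(tokens, window=1)
wide = rep_l(tokens, window=2)
```

<p>With a window of 1 only the second carbon counts as a repetition; widening the window to 2 also catches the final oxygen, which repeats the carbonyl oxygen two positions earlier. Dividing the count by the sequence length would give a normalized rate that is comparable across sequences of different lengths.</p>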
<h2 id="experimental-evaluation-across-three-chemical-tasks">Experimental Evaluation Across Three Chemical Tasks</h2>
<h3 id="input-output-equivalent-mapping-smiles-canonicalization">Input-Output Equivalent Mapping (SMILES Canonicalization)</h3>
<p>The first task tests whether a model can translate non-canonical SMILES enumerations into canonical form. The authors constructed deliberately challenging datasets from <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> subsets with cumulative structural constraints (no cyclic heteroatom-heteroatom bonds, stable functional groups only, fragment-like, scaffold-like, etc.), generating training sets of 1M molecules augmented with 150K molecules from the most restrictive subset at 10x, 30x, and 50x augmentation levels.</p>
<table>
  <thead>
      <tr>
          <th>GDB-13 Subset</th>
          <th>Atom-wise (x10)</th>
          <th>Atom-wise (x50)</th>
          <th>AIS (x10)</th>
          <th>AIS (x50)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ab</td>
          <td>34.2%</td>
          <td>33.2%</td>
          <td>37.3%</td>
          <td>34.1%</td>
      </tr>
      <tr>
          <td>abc</td>
          <td>31.0%</td>
          <td>29.6%</td>
          <td>33.7%</td>
          <td>30.4%</td>
      </tr>
      <tr>
          <td>abcde</td>
          <td>48.7%</td>
          <td>45.5%</td>
          <td>53.6%</td>
          <td>47.0%</td>
      </tr>
      <tr>
          <td>abcdef</td>
          <td>41.8%</td>
          <td>39.1%</td>
          <td>52.5%</td>
          <td>46.9%</td>
      </tr>
      <tr>
          <td>abcdefg</td>
          <td>50.9%</td>
          <td>50.0%</td>
          <td>59.9%</td>
          <td>56.8%</td>
      </tr>
  </tbody>
</table>
<p>AIS outperformed atom-wise tokenization on all subsets and augmentation levels. The performance gap generally grew for more restrictive (more similar) subsets, reaching 10.7 percentage points on the abcdef subset at 10x augmentation (52.5% vs. 41.8%). This suggests that AIS is particularly effective when molecules are structurally similar and harder to distinguish.</p>
<h3 id="single-step-retrosynthesis">Single-Step Retrosynthesis</h3>
<p>The second task uses the USPTO-50K benchmark for single-step <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthetic prediction</a> via a template-free transformer encoder-decoder model. The model was trained for 200,000 steps with Adam optimizer, negative log-likelihood loss, and cyclic learning rate scheduling.</p>
<table>
  <thead>
      <tr>
          <th>Tokenization</th>
          <th>rep-l<sub>P</sub> - rep-l<sub>GT</sub> &gt;= 2</th>
          <th>String Exact (%)</th>
          <th>Tc Exact (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Atom-wise baseline</td>
          <td>&ndash;</td>
          <td>42.00</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>Atom-wise (reproduced)</td>
          <td>801</td>
          <td>42.05</td>
          <td>44.72</td>
      </tr>
      <tr>
          <td>SmilesPE</td>
          <td>821</td>
          <td>19.82</td>
          <td>22.74</td>
      </tr>
      <tr>
          <td>SELFIES</td>
          <td>886</td>
          <td>28.82</td>
          <td>30.76</td>
      </tr>
      <tr>
          <td>DeepSMILES</td>
          <td>902</td>
          <td>38.63</td>
          <td>41.20</td>
      </tr>
      <tr>
          <td><strong>Atom-in-SMILES</strong></td>
          <td><strong>727</strong></td>
          <td><strong>46.32</strong></td>
          <td><strong>47.62</strong></td>
      </tr>
  </tbody>
</table>
<p>AIS achieved 46.32% string exact accuracy (4.3 percentage points above the reproduced atom-wise baseline) and 47.62% Tanimoto exact accuracy (2.9 percentage points above it). AIS also had the fewest degenerate token repetitions (727 vs. 801 for atom-wise), representing approximately a 10% reduction. DeepSMILES had the highest repetition count (902) despite reasonable overall accuracy. SELFIES and <a href="/notes/chemistry/molecular-representations/notations/smiles-pair-encoding/">SmilesPE</a> both performed substantially worse than the atom-wise baseline on this task.</p>
<p>The authors identified six common token repetition patterns in retrosynthetic predictions: long head repetitions, long tail repetitions, repetitive rings, repetitive chains, and halogen repetitions on both aliphatic and aromatic carbons.</p>
<h3 id="molecular-property-prediction">Molecular Property Prediction</h3>
<p>The third task evaluates tokenization schemes on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks using Random Forest models with 5-fold cross-validation. AIS tokens were converted to fingerprint-like feature vectors.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>SMILES</th>
          <th>DeepSMILES</th>
          <th>SELFIES</th>
          <th>SmilesPE</th>
          <th>AIS</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Regression (RMSE, lower is better)</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>0.628</td>
          <td>0.631</td>
          <td>0.675</td>
          <td>0.689</td>
          <td><strong>0.553</strong></td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>0.545</td>
          <td>0.544</td>
          <td>0.564</td>
          <td>0.761</td>
          <td><strong>0.441</strong></td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>0.924</td>
          <td>0.895</td>
          <td>0.938</td>
          <td>0.800</td>
          <td><strong>0.683</strong></td>
      </tr>
      <tr>
          <td><strong>Classification (ROC-AUC, higher is better)</strong></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
          <td></td>
      </tr>
      <tr>
          <td>BBBP</td>
          <td>0.758</td>
          <td>0.777</td>
          <td>0.799</td>
          <td>0.847</td>
          <td><strong>0.885</strong></td>
      </tr>
      <tr>
          <td>BACE</td>
          <td>0.740</td>
          <td>0.774</td>
          <td>0.746</td>
          <td><strong>0.837</strong></td>
          <td>0.835</td>
      </tr>
      <tr>
          <td>HIV</td>
          <td>0.649</td>
          <td>0.648</td>
          <td>0.653</td>
          <td><strong>0.739</strong></td>
          <td>0.729</td>
      </tr>
  </tbody>
</table>
<p>AIS achieved the best performance on all three regression datasets and on BBBP. On BACE and HIV, SmilesPE was marginally ahead (0.837 vs. 0.835 and 0.739 vs. 0.729). On ESOL, AIS reduced RMSE by 12% relative to standard SMILES; on lipophilicity, the reduction was 26%.</p>
<h2 id="key-findings-better-tokens-yield-better-chemical-models">Key Findings: Better Tokens Yield Better Chemical Models</h2>
<p>The main findings of this work are:</p>
<ol>
<li>
<p><strong>Tokenization significantly impacts chemical language model quality.</strong> The choice of tokenization scheme can change prediction accuracy by over 10 percentage points on equivalent mapping tasks.</p>
</li>
<li>
<p><strong>AIS reduces token degeneration by approximately 10%</strong> compared to atom-wise SMILES tokenization, with consistently lower normalized repetition rates across diverse molecular datasets.</p>
</li>
<li>
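<p><strong>AIS matches or outperforms the compared tokenization schemes</strong> (atom-wise SMILES, SmilesPE, SELFIES, DeepSMILES): it beats atom-wise tokenization on canonicalization, leads on retrosynthesis, and is best on four of six property prediction benchmarks, with SmilesPE marginally ahead on BACE and HIV.</p>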
</li>
<li>
<p><strong>The fingerprint-like nature of AIS tokens</strong> enables direct use as molecular features for property prediction and provides resolution comparable to established circular fingerprints.</p>
</li>
<li>
<p><strong>The mapping is invertible</strong>, so AIS strings can always be converted back to valid SMILES. This is a practical advantage over approaches that may lose structural information.</p>
</li>
</ol>
<p><strong>Limitations</strong>: AIS cannot distinguish environmentally identical substructures or atoms related by a molecular symmetry plane, since it only considers nearest-neighbor environments. Performance on long-chain molecules (e.g., lipids) is similar across all tokenization schemes, suggesting that local environment encoding is less informative for repetitive linear structures.</p>
<p><strong>Future directions</strong>: The authors suggest AIS has potential for broader adoption in molecular generative models, chemical translation, and property prediction tasks across the cheminformatics community.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Canonicalization training</td>
          <td>GDB-13 subsets</td>
          <td>1M + 150K augmented</td>
          <td>Cumulative structural constraints a-h</td>
      </tr>
      <tr>
          <td>Canonicalization testing</td>
          <td>GDB-13 disjoint test sets</td>
          <td>20K per subset</td>
          <td>Various restriction levels</td>
      </tr>
      <tr>
          <td>Retrosynthesis</td>
          <td>USPTO-50K</td>
          <td>~50K reactions</td>
          <td>Sequences &gt; 150 tokens removed</td>
      </tr>
      <tr>
          <td>Property prediction</td>
          <td>MoleculeNet (ESOL, FreeSolv, Lipophilicity, BBBP, BACE, HIV)</td>
          <td>Varies</td>
          <td>Standard benchmark splits</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Transformer encoder-decoder architecture for canonicalization and retrosynthesis tasks</li>
<li>200,000 training steps with Adam optimizer, negative log-likelihood loss, cyclic learning rate scheduler</li>
<li>Random Forest with 5-fold cross-validation for property prediction</li>
<li>AIS tokenization implemented via RDKit for atom environment extraction</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>String exact match (%)</td>
          <td>Canonicalization, Retrosynthesis</td>
          <td>Exact SMILES match</td>
      </tr>
      <tr>
          <td>Tanimoto exactness (Tc)</td>
          <td>Retrosynthesis</td>
          <td>Morgan FP radius 3, 2048 bits</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression property prediction</td>
          <td>ESOL, FreeSolv, Lipophilicity</td>
      </tr>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification property prediction</td>
          <td>BBBP, BACE, HIV</td>
      </tr>
      <tr>
          <td>rep-l</td>
          <td>Token degeneration</td>
          <td>Single-token repetition count</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Not explicitly specified in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/snu-lcbc/atom-in-SMILES">atom-in-SMILES</a></td>
          <td>Code</td>
          <td>CC-BY-NC-SA-4.0</td>
          <td>AIS tokenization implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ucak, U. V., Ashyrmamatov, I., &amp; Lee, J. (2023). Improving the quality of chemical language model outcomes with atom-in-SMILES tokenization. <em>Journal of Cheminformatics</em>, 15, 55. <a href="https://doi.org/10.1186/s13321-023-00725-9">https://doi.org/10.1186/s13321-023-00725-9</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ucak2023improving,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Improving the quality of chemical language model outcomes with atom-in-SMILES tokenization}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ucak, Umit V. and Ashyrmamatov, Islambek and Lee, Juyong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{55}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-023-00725-9}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>TamGen: GPT-Based Target-Aware Drug Design and Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/tamgen-target-aware-molecule-generation/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/target-aware/tamgen-target-aware-molecule-generation/</guid><description>TamGen combines a GPT-like chemical language model with protein pocket encoding and VAE refinement to generate drug candidates with experimental validation.</description><content:encoded><![CDATA[<h2 id="a-method-for-target-conditioned-molecular-generation">A Method for Target-Conditioned Molecular Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces TamGen (Target-aware molecule generation), a three-module architecture for generating drug-like compounds conditioned on protein binding pocket structures. The primary contribution is a GPT-like chemical language model pre-trained on 10 million SMILES from PubChem, combined with a Transformer-based protein encoder and a VAE-based contextual encoder for compound refinement. The authors validate TamGen on the CrossDocked2020 benchmark and apply it through a Design-Refine-Test pipeline to discover 14 novel inhibitors of the <em>Mycobacterium tuberculosis</em> ClpP protease, with $\text{IC}_{50}$ values ranging from 1.88 to 35.2 $\mu$M.</p>
<h2 id="bridging-generative-ai-and-practical-drug-discovery">Bridging Generative AI and Practical Drug Discovery</h2>
<p>Target-based generative drug design aims to create novel compounds with desired pharmacological properties from scratch, exploring the estimated $10^{60}$ feasible compounds in chemical space rather than screening existing libraries of $10^{4}$ to $10^{8}$ molecules. Prior approaches using diffusion models, GANs, VAEs, and autoregressive models have demonstrated the feasibility of generating compounds conditioned on target proteins. However, most generated compounds lack satisfactory physicochemical properties for drug-likeness, and validations with biophysical or biochemical assays are largely missing.</p>
<p>The key limitations of existing 3D generation methods (TargetDiff, Pocket2Mol, ResGen, 3D-AR) include:</p>
<ul>
<li>Generated compounds frequently contain multiple fused rings, leading to poor synthetic accessibility</li>
<li>High cellular toxicity and decreased developability associated with excessive fused ring counts</li>
<li>Slow generation speeds (tens of minutes to hours per 100 compounds)</li>
<li>Limited real-world experimental validation of generated candidates</li>
</ul>
<p>TamGen addresses these issues by operating in 1D SMILES space rather than 3D coordinate space, leveraging pre-training on natural compound distributions to produce more drug-like molecules.</p>
<h2 id="three-module-architecture-with-pre-training-and-refinement">Three-Module Architecture with Pre-Training and Refinement</h2>
<p>TamGen consists of three components: a compound decoder, a protein encoder, and a contextual encoder.</p>
<h3 id="compound-decoder-chemical-language-model">Compound Decoder (Chemical Language Model)</h3>
<p>The compound decoder is a GPT-style autoregressive model pre-trained on 10 million SMILES randomly sampled from PubChem. The pre-training objective follows standard next-token prediction:</p>
<p>$$
\min -\sum_{y \in \mathcal{D}_0} \frac{1}{M_y} \sum_{i=1}^{M_y} \log P(y_i \mid y_{i-1}, y_{i-2}, \ldots, y_1)
$$</p>
<p>where $\mathcal{D}_0$ is the pre-training set of SMILES strings and $M_y$ is the length of sequence $y$. This enables both unconditional and conditional generation. The decoder uses 12 Transformer layers with hidden dimension 768.</p>
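<p>As a minimal sketch of the objective above, the snippet below computes the length-normalized negative log-likelihood from per-token probabilities. The function names and the toy probabilities are hypothetical, and a real implementation would take these probabilities from the decoder&rsquo;s softmax output.</p>

```python
import math


def length_normalized_nll(token_probs):
    """Average negative log-likelihood of one sequence.

    token_probs[i] is the probability the model assigns to the i-th true
    token given all previous tokens, i.e. P(y_i | y_{i-1}, ..., y_1).
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)


def pretraining_loss(dataset_token_probs):
    """Sum the per-sequence averages over the pre-training set D_0."""
    return sum(length_normalized_nll(probs) for probs in dataset_token_probs)


# Toy example: two sequences of different lengths with made-up probabilities.
toy = [[0.5, 0.5], [0.25, 0.25, 0.25, 0.25]]
loss = pretraining_loss(toy)
```

<p>Dividing by $M_y$ means a long SMILES string contributes the same weight to the loss as a short one.</p>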
<h3 id="protein-encoder-with-distance-aware-attention">Protein Encoder with Distance-Aware Attention</h3>
<p>The protein encoder processes binding pocket residues using both sequential and geometric information. Given amino acids $\mathbf{a} = (a_1, \ldots, a_N)$ with 3D coordinates $\mathbf{r} = (r_1, \ldots, r_N)$, the input representation combines amino acid embeddings with coordinate embeddings:</p>
<p>$$
h_i^{(0)} = E_a a_i + E_r \rho\left(r_i - \frac{1}{N}\sum_{j=1}^{N} r_j\right)
$$</p>
<p>where $\rho$ denotes a random roto-translation operation applied as data augmentation, and coordinates are centered to the origin.</p>
<p>The encoder uses a distance-aware self-attention mechanism that weights attention scores by spatial proximity:</p>
<p>$$
\begin{aligned}
\hat{\alpha}_j &amp;= \exp\left(-\frac{\|r_i - r_j\|^2}{\tau}\right)(h_i^{(l)\top} W h_j^{(l)}) \\
\alpha_j &amp;= \frac{\exp \hat{\alpha}_j}{\sum_{k=1}^{N} \exp \hat{\alpha}_k} \\
\hat{\boldsymbol{h}}_i^{(l+1)} &amp;= \sum_{j=1}^{N} \alpha_j (W_v h_j^{(l)})
\end{aligned}
$$</p>
<p>where $\tau$ is a temperature hyperparameter and $W$, $W_v$ are learnable parameters. The encoder uses 4 layers with hidden dimension 256. Outputs are passed to the compound decoder via cross-attention.</p>
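<p>The three attention equations can be sketched in plain Python for a single query residue $i$. This is an illustrative single-head version under stated assumptions: the function name, the nested-list matrix format, and the max-shift inside the softmax are choices made here, not details from the paper.</p>

```python
import math


def distance_aware_attention(h, r, W, W_v, tau, i):
    """Update for residue i following the three equations above.

    h: list of N hidden vectors (each of length d)
    r: list of N 3D coordinates
    W, W_v: d x d matrices given as nested lists
    tau: temperature hyperparameter
    """
    def matvec(M, v):
        return [sum(m * x for m, x in zip(row, v)) for row in M]

    def dot(u, v):
        return sum(a * b for a, b in zip(u, v))

    scores = []
    for j in range(len(h)):
        sq_dist = sum((a - b) ** 2 for a, b in zip(r[i], r[j]))
        # Bilinear score h_i^T W h_j, scaled by the distance kernel.
        scores.append(math.exp(-sq_dist / tau) * dot(h[i], matvec(W, h[j])))

    # Softmax over j, shifted by the max score for numerical stability.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]

    # Weighted sum of the value projections W_v h_j.
    values = [matvec(W_v, hj) for hj in h]
    dim = len(values[0])
    out = [sum(weights[j] * values[j][k] for j in range(len(h))) for k in range(dim)]
    return out, weights
```

<p>Because the distance kernel multiplies the score before the softmax, a distant residue has its raw score pushed toward zero instead of being masked out, so it still receives some attention weight.</p>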
<h3 id="vae-based-contextual-encoder">VAE-Based Contextual Encoder</h3>
<p>A VAE-based contextual encoder outputs the mean $\mu$ and standard deviation $\sigma$ of the latent variable $z$ for any (compound, protein) pair. During training, the model learns to reconstruct the input compound. At inference, supplying a seed compound lets the model refine that compound for the target. The full training objective combines reconstruction loss with KL regularization:</p>
<p>$$
\min_{\Theta, q} \frac{1}{|\mathcal{D}|} \sum_{(\mathbf{x}, \mathbf{y}) \in \mathcal{D}} -\log P(\mathbf{y} \mid \mathbf{x}, z; \Theta) + \beta \mathcal{D}_{\text{KL}}(q(z \mid \mathbf{x}, \mathbf{y}) \,\|\, p(z))
$$</p>
<p>where $\beta$ is a hyperparameter controlling the KL divergence weight, and $p(z)$ is a standard Gaussian prior.</p>
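<p>Assuming the approximate posterior is a diagonal Gaussian parameterized by the encoder&rsquo;s $\mu$ and $\sigma$ (an assumption here, since the summary specifies only $\mu$, $\sigma$, and the standard Gaussian prior), the KL term has the standard closed form for a $d$-dimensional latent:</p>
<p>$$
\mathcal{D}_{\text{KL}}\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\right) = \frac{1}{2} \sum_{k=1}^{d} \left(\mu_k^2 + \sigma_k^2 - 1 - \log \sigma_k^2\right)
$$</p>
<p>The term vanishes when $\mu = 0$ and $\sigma = 1$, so a larger $\beta$ pulls the latent code toward the prior, while a smaller $\beta$ lets it carry more information about the seed compound.</p>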
<h2 id="benchmark-evaluation-and-tuberculosis-drug-discovery">Benchmark Evaluation and Tuberculosis Drug Discovery</h2>
<h3 id="crossdocked2020-benchmark">CrossDocked2020 Benchmark</h3>
<p>TamGen was evaluated against five baselines (liGAN, 3D-AR, Pocket2Mol, ResGen, TargetDiff) on the CrossDocked2020 dataset (~100k drug-target pairs for training, 100 test binding pockets). For each target, 100 compounds were generated per method. Evaluation metrics included:</p>
<ul>
<li><strong>Docking score</strong> (AutoDock-Vina): binding affinity estimate</li>
<li><strong>QED</strong>: quantitative estimate of drug-likeness</li>
<li><strong><a href="https://en.wikipedia.org/wiki/Lipinski%27s_rule_of_five">Lipinski&rsquo;s Rule of Five</a></strong>: physicochemical property compliance</li>
<li><strong>SAS</strong>: synthetic accessibility score</li>
<li><strong>LogP</strong>: lipophilicity (optimal range 0-5 for oral administration)</li>
<li><strong>Molecular diversity</strong>: based on Tanimoto similarity between Morgan fingerprints of the generated compounds</li>
</ul>
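<p>To make the diversity metric concrete, here is a minimal sketch that treats each fingerprint as a set of on-bit indices. Defining diversity as one minus the mean pairwise Tanimoto similarity is an assumption for illustration, and the function names are hypothetical. In practice the fingerprints would come from a cheminformatics toolkit.</p>

```python
from itertools import combinations


def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = len(bits_a.union(bits_b))
    if union == 0:
        return 1.0
    return len(bits_a.intersection(bits_b)) / union


def internal_diversity(fingerprints):
    """One minus the mean pairwise Tanimoto similarity (assumed definition)."""
    pairs = list(combinations(fingerprints, 2))
    mean_similarity = sum(tanimoto(a, b) for a, b in pairs) / len(pairs)
    return 1.0 - mean_similarity


# Made-up fingerprints for three compounds.
fps = [{1, 2, 3}, {2, 3, 4}, {1, 2, 3}]
diversity = internal_diversity(fps)
```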
<p>TamGen ranked first or second on 5 of 6 metrics and achieved the best overall score using mean reciprocal rank (MRR) across all metrics. On synthetic accessibility for high-affinity compounds, TamGen performed best. The generated compounds averaged 1.78 fused rings, closely matching FDA-approved drugs, while competing 3D methods produced compounds with significantly more fused rings.</p>
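<p>The mean reciprocal rank aggregation can be sketched as follows. The method names and ranks are made up for illustration, and averaging the reciprocal rank uniformly over metrics is the assumed aggregation.</p>

```python
def mean_reciprocal_rank(ranks_by_metric):
    """ranks_by_metric maps metric name to {method name: rank}, with 1 as best."""
    methods = next(iter(ranks_by_metric.values())).keys()
    n_metrics = len(ranks_by_metric)
    return {
        m: sum(1.0 / ranks[m] for ranks in ranks_by_metric.values()) / n_metrics
        for m in methods
    }


# Made-up ranks for two hypothetical methods on three metrics.
toy_ranks = {
    "docking": {"A": 1, "B": 2},
    "qed": {"A": 2, "B": 1},
    "sas": {"A": 1, "B": 2},
}
mrr = mean_reciprocal_rank(toy_ranks)
```

<p>Because the reciprocal rewards first place most heavily, a method that is consistently first or second scores higher than one that wins a single metric and ranks poorly elsewhere.</p>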
<p>TamGen was also 85x to 394x faster than competing methods, generating 100 compounds per target in an average of 9 seconds on a single A6000 GPU, compared to tens of minutes or hours for the baselines.</p>
<h3 id="design-refine-test-pipeline-for-clpp-inhibitors">Design-Refine-Test Pipeline for ClpP Inhibitors</h3>
<p>The practical application targeted the ClpP protease of <em>Mycobacterium tuberculosis</em>, an emerging antibiotic target with no documented advanced inhibitors beyond <a href="https://en.wikipedia.org/wiki/Bortezomib">Bortezomib</a>.</p>
<p><strong>Design stage</strong>: Using the ClpP binding pocket from PDB structure 5DZK, TamGen generated 2,612 unique compounds. Compounds were filtered by molecular docking (retaining those with better scores than Bortezomib) and Ligandformer phenotypic activity prediction. Peptidomimetic compounds were excluded for poor ADME properties. Four seed compounds were selected.</p>
<p><strong>Refine stage</strong>: Using the 4 seed compounds plus 3 weakly active compounds ($\text{IC}_{50}$ 100-200 $\mu$M) from prior experiments, TamGen generated 8,635 unique compounds conditioned on both the target and seeds. After filtering, 296 compounds were selected for testing.</p>
<p><strong>Test stage</strong>: A search of a 446k commercial compound library identified 159 analogs of the selected compounds (maximum common substructure (MCS) similarity &gt; 0.55). Five analogs showed significant inhibitory effects. Dose-response experiments revealed $\text{IC}_{50}$ values below 20 $\mu$M for all five, with Analog-005 achieving an $\text{IC}_{50}$ of 1.9 $\mu$M. Three additional novel compounds were synthesized for SAR analysis. Representative compounds are summarized below:</p>
<table>
  <thead>
      <tr>
          <th>Compound</th>
          <th>Series</th>
          <th>Source</th>
          <th>$\text{IC}_{50}$ ($\mu$M)</th>
          <th>Key Feature</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Analog-005</td>
          <td>II</td>
          <td>Commercial library</td>
          <td>1.9</td>
          <td>Most potent analog</td>
      </tr>
      <tr>
          <td>Analog-003</td>
          <td>I</td>
          <td>Commercial library</td>
          <td>&lt; 20</td>
          <td>Strongest single-dose inhibition</td>
      </tr>
      <tr>
          <td>Syn-A003-01</td>
          <td>I</td>
          <td>TamGen (synthesized)</td>
          <td>&lt; 20</td>
          <td>Diphenylurea scaffold</td>
      </tr>
  </tbody>
</table>
<p>Both compound series (diphenylurea and benzenesulfonamide scaffolds) represent novel ClpP inhibitor chemotypes distinct from Bortezomib. Additionally, 6 out of 8 directly synthesized TamGen compounds demonstrated $\text{IC}_{50}$ below 40 $\mu$M, confirming TamGen&rsquo;s ability to produce viable hits without the library search step.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>Four ablation experiments clarified the contributions of TamGen&rsquo;s components:</p>
<ol>
<li><strong>Without pre-training</strong>: Significantly worse docking scores and simpler structures. The optimal decoder depth dropped from 12 to 4 layers without pre-training due to overfitting.</li>
<li><strong>Shuffled pocket-ligand pairs (TamGen-r)</strong>: Substantially worse docking scores, confirming TamGen learns meaningful pocket-ligand interactions rather than generic compound distributions.</li>
<li><strong>Without distance-aware attention</strong>: Significant decline in docking scores when removing the distance-dependent weighting from the protein encoder&rsquo;s attention scores.</li>
<li><strong>Without coordinate augmentation</strong>: Performance degradation when removing the roto-translation augmentation $\rho$, highlighting the importance of geometric invariance.</li>
</ol>
<h2 id="validated-drug-like-generation-with-practical-limitations">Validated Drug-Like Generation with Practical Limitations</h2>
<p>TamGen demonstrates that 1D SMILES-based generation with pre-training on natural compounds produces molecules with better drug-likeness properties than 3D generation methods. The experimental validation against ClpP is a notable strength, as most generative drug design methods lack biochemical assay confirmation.</p>
<p>Key limitations acknowledged by the authors include:</p>
<ul>
<li><strong>Insufficient sensitivity to minor target differences</strong>: TamGen cannot reliably distinguish targets with point mutations or protein isoforms, limiting applicability for cancer-related proteins</li>
<li><strong>Requires known structure and pocket</strong>: As a structure-based method, TamGen needs the 3D structure of the target protein and binding pocket information</li>
<li><strong>Limited cellular validation</strong>: The study focuses on hit identification; cellular activities and toxicities of proposed compounds were not extensively tested</li>
<li><strong>1D generation trade-off</strong>: SMILES-based generation does not fully exploit 3D protein-ligand geometric interactions available in coordinate space</li>
</ul>
<p>Future directions include integrating insights from 3D autoregressive methods, using Monte Carlo Tree Search or reinforcement learning to guide generation for better docking scores and ADME/T properties, and property-guided generation as explored in <a href="/notes/chemistry/molecular-design/generation/target-aware/prefixmol-target-chemistry-aware-generation/">PrefixMol</a>.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>PubChem (random sample)</td>
          <td>10M SMILES</td>
          <td>Compound decoder pre-training</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>CrossDocked2020</td>
          <td>~100k pairs</td>
          <td>Filtered pocket-ligand pairs</td>
      </tr>
      <tr>
          <td>Extended fine-tuning</td>
          <td>CrossDocked + PDB</td>
          <td>~300k pairs</td>
          <td>Used for TB compound generation</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>CrossDocked2020 test</td>
          <td>100 pockets</td>
          <td>Same split as TargetDiff/Pocket2Mol</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Compound decoder</strong>: 12-layer GPT with hidden dimension 768, pre-trained for 200k steps</li>
<li><strong>Protein encoder</strong>: 4-layer Transformer with hidden dimension 256, distance-aware attention</li>
<li><strong>VAE encoder</strong>: 4-layer standard Transformer encoder with hidden dimension 256</li>
<li><strong>Optimizer</strong>: Adam with initial learning rate $3 \times 10^{-5}$</li>
<li><strong>VAE $\beta$</strong>: 0.1 or 1.0 depending on generation stage</li>
<li><strong>Beam search</strong>: beam sizes of 4, 10, or 20 depending on stage</li>
<li><strong>Pocket definition</strong>: residues within a 10 or 15 Angstrom cutoff of the ligand center</li>
</ul>
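<p>The pocket definition above can be sketched as a simple distance filter. This is a hedged illustration: using one representative coordinate per residue and taking the ligand center as the mean of its atom coordinates are assumptions, and <code>pocket_residues</code> is a hypothetical helper, not a function from the TamGen code.</p>

```python
import math


def pocket_residues(residue_coords, ligand_coords, cutoff=10.0):
    """Return indices of residues within `cutoff` angstroms of the ligand center.

    residue_coords: one representative (x, y, z) per residue (assumed).
    ligand_coords: (x, y, z) for each ligand atom; the center is their mean.
    """
    n_atoms = len(ligand_coords)
    center = [sum(atom[k] for atom in ligand_coords) / n_atoms for k in range(3)]
    return [
        idx
        for idx, xyz in enumerate(residue_coords)
        if math.dist(xyz, center) <= cutoff
    ]
```

<p>Raising the cutoff from 10 to 15 Angstrom admits more distant residues, which gives the encoder a larger context at the cost of longer inputs.</p>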
<h3 id="models">Models</h3>
<p>Pre-trained model weights are available via Zenodo at <a href="https://doi.org/10.5281/zenodo.13751391">https://doi.org/10.5281/zenodo.13751391</a>.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>TamGen</th>
          <th>Best Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Overall MRR</td>
          <td>Best</td>
          <td>TargetDiff (2nd)</td>
          <td>Ranked across 6 metrics</td>
      </tr>
      <tr>
          <td>Fused rings (avg)</td>
          <td>1.78</td>
          <td>~3-5 (others)</td>
          <td>Matches FDA-approved drug average</td>
      </tr>
      <tr>
          <td>Generation speed</td>
          <td>9 sec/100 compounds</td>
          <td>~13 min (ResGen)</td>
          <td>Single A6000 GPU</td>
      </tr>
      <tr>
          <td>ClpP hit rate</td>
          <td>6/8 synthesized</td>
          <td>N/A</td>
          <td>$\text{IC}_{50}$ &lt; 40 $\mu$M</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 8x V100 GPUs for 200k steps</li>
<li>Inference benchmarking: 1x A6000 GPU</li>
<li>Generation time: ~9 seconds per 100 compounds per target</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/SigmaGenX/TamGen">TamGen code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.13751391">Model weights and data</a></td>
          <td>Model + Data</td>
          <td>CC-BY-4.0</td>
          <td>Pre-trained weights, source data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wu, K., Xia, Y., Deng, P., Liu, R., Zhang, Y., Guo, H., Cui, Y., Pei, Q., Wu, L., Xie, S., Chen, S., Lu, X., Hu, S., Wu, J., Chan, C.-K., Chen, S., Zhou, L., Yu, N., Chen, E., Liu, H., Guo, J., Qin, T., &amp; Liu, T.-Y. (2024). TamGen: drug design with target-aware molecule generation through a chemical language model. <em>Nature Communications</em>, 15, 9360. <a href="https://doi.org/10.1038/s41467-024-53632-4">https://doi.org/10.1038/s41467-024-53632-4</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wu2024tamgen,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{TamGen: drug design with target-aware molecule generation through a chemical language model}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wu, Kehan and Xia, Yingce and Deng, Pan and Liu, Renhe and Zhang, Yuan and Guo, Han and Cui, Yumeng and Pei, Qizhi and Wu, Lijun and Xie, Shufang and Chen, Si and Lu, Xi and Hu, Song and Wu, Jinzhi and Chan, Chi-Kin and Chen, Shawn and Zhou, Liangliang and Yu, Nenghai and Chen, Enhong and Liu, Haiguang and Guo, Jinjiang and Qin, Tao and Liu, Tie-Yan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{9360}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-024-53632-4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SPECTRA: Evaluating Generalizability of Molecular AI</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/spectra-evaluating-generalizability-molecular-ai/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/spectra-evaluating-generalizability-molecular-ai/</guid><description>SPECTRA evaluates ML model generalizability on molecular datasets by plotting performance across a spectrum of train-test overlap levels.</description><content:encoded><![CDATA[<h2 id="a-spectral-framework-for-evaluating-molecular-ml-generalizability">A Spectral Framework for Evaluating Molecular ML Generalizability</h2>
<p>This is a <strong>Method</strong> paper that introduces SPECTRA (SPECtral framework for model evaluaTion on moleculaR dAtasets), a systematic approach for evaluating how well machine learning models generalize on molecular sequencing data. The primary contribution is a framework that generates train-test splits with controlled, decreasing levels of overlap, producing a spectral performance curve (SPC) and a single summary metric, the area under the spectral performance curve (AUSPC), for comparing model generalizability across tasks and architectures.</p>
<h2 id="why-existing-molecular-benchmarks-overestimate-generalizability">Why Existing Molecular Benchmarks Overestimate Generalizability</h2>
<p>Deep learning has achieved high performance on molecular sequencing benchmarks, but a persistent gap exists between benchmark performance and real-world deployment. The authors identify the root cause: existing evaluation approaches use either metadata-based (MB) splits or similarity-based (SB) splits, both of which provide an incomplete picture of generalizability.</p>
<p>MB splits partition data by metadata properties (e.g., temporal splits, random splits) without controlling sequence similarity between train and test sets. This means high train-test similarity can inflate performance metrics. SB splits control similarity at a single threshold, but the model&rsquo;s behavior at other similarity levels remains unknown.</p>
<p>For example, the TAPE benchmark&rsquo;s remote homology family split has 97% cross-split overlap, while the superfamily split has 71%. Model accuracy drops by 50% between these two points, yet the full curve of performance degradation is never characterized. This gap between evaluated and real-world overlap levels leads to overoptimistic deployment expectations, as demonstrated by the case of <a href="https://en.wikipedia.org/wiki/Rifampicin">rifampicin</a> resistance prediction in <em>M. tuberculosis</em>, where commercial genotypic assays later proved unreliable in specific geographic regions.</p>
<h2 id="the-spectra-framework-spectral-properties-graphs-and-performance-curves">The SPECTRA Framework: Spectral Properties, Graphs, and Performance Curves</h2>
<p>SPECTRA takes three inputs: a molecular sequencing dataset, a machine learning model, and a spectral property definition. A spectral property (SP) is a molecular sequence property expected to influence model generalizability for a specific task. For sequence-to-sequence datasets, the spectral property is typically sequence identity: two samples share the property when the proportion of aligned positions exceeds 0.3. For mutational scan datasets, it is defined by sample barcodes (string representations of the mutations present in each sample).</p>
<h3 id="spectral-property-graph-construction">Spectral Property Graph Construction</h3>
<p>SPECTRA constructs a spectral property graph (SPG) where nodes represent samples and edges connect samples that share the spectral property. The goal is to generate train-test splits with controlled levels of cross-split overlap by finding approximate <a href="https://en.wikipedia.org/wiki/Maximal_independent_set">maximal independent sets</a> of this graph.</p>
<p>Finding a maximum independent set is NP-hard, so SPECTRA uses a greedy randomized algorithm parameterized by a spectral parameter $\mathbf{SP} \in [0, 1]$:</p>
<ol>
<li>Randomly order SPG vertices</li>
<li>Select the first remaining vertex and delete each of its neighbors with probability $\mathbf{SP}$</li>
<li>Repeat with the next remaining vertex until no vertices remain</li>
</ol>
<p>When $\mathbf{SP} = 0$, no samples are deleted and the result is a random split (maximum cross-split overlap). When $\mathbf{SP} = 1$, every neighbor of a selected vertex is deleted, so the selected vertices form a maximal independent set (minimum cross-split overlap). For each spectral parameter value (0 to 1 in steps of 0.05), three splits with different random seeds are generated.</p>
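<p>The greedy procedure can be sketched as below. This is one reading of the three steps, not the authors&rsquo; implementation: the name <code>spectra_subsample</code>, the adjacency-dictionary format, and the choice to never delete an already selected vertex are assumptions. How the kept samples are then divided into train and test sets is outside this sketch.</p>

```python
import random


def spectra_subsample(adjacency, sp, seed=0):
    """Greedy randomized pass over a spectral property graph.

    adjacency maps each sample to the set of samples that share the
    spectral property with it. Returns the selected samples in visit order.
    """
    rng = random.Random(seed)
    order = list(adjacency)
    rng.shuffle(order)  # step 1: random vertex order

    kept, kept_set, deleted = [], set(), set()
    for v in order:
        if v in deleted:
            continue
        kept.append(v)  # step 2: select the next remaining vertex
        kept_set.add(v)
        for u in adjacency[v]:
            # Delete each not-yet-selected neighbor with probability sp.
            if u not in kept_set and u not in deleted and rng.random() < sp:
                deleted.add(u)
    return kept
```

<p>With <code>sp=0.0</code> nothing is deleted and every sample is kept. With <code>sp=1.0</code> no two kept samples are neighbors, which is the low-overlap end of the spectrum.</p>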
<h3 id="the-spectral-performance-curve-and-auspc">The Spectral Performance Curve and AUSPC</h3>
<p>The model is trained and evaluated on each split. Plotting test performance against the spectral parameter produces the spectral performance curve (SPC). The area under this curve, the AUSPC, serves as a single summary metric for model generalizability that captures behavior across the full spectrum of train-test overlap.</p>
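<p>A minimal sketch of the AUSPC computation follows, assuming trapezoidal integration over the spectral parameter grid (the integration rule is an assumption here). The performance values are made up.</p>

```python
def auspc(spectral_params, performances):
    """Trapezoidal area under the spectral performance curve."""
    points = sorted(zip(spectral_params, performances))
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


# Made-up curve: performance falls linearly from 0.9 to 0.5 as SP goes from 0 to 1.
sps = [i * 0.05 for i in range(21)]
perf = [0.9 - 0.4 * s for s in sps]
area = auspc(sps, perf)
```

<p>Because the spectral parameter spans an interval of length 1, the area equals the average height of the curve, which keeps the AUSPC on the same scale as the underlying metric.</p>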
<h3 id="handling-mutational-scan-datasets">Handling Mutational Scan Datasets</h3>
<p>For mutational scan datasets where sample barcodes map to multiple samples, SPECTRA introduces two modifications: (1) weighting nodes in the SPG by the number of samples they represent, and (2) running a subset sum algorithm to ensure 80/20 train-test splits by sample count.</p>
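<p>One plausible reading of the subset sum step is sketched here: given the number of samples behind each barcode node, pick a set of nodes whose total is as close as possible to 80% of all samples. The function name and the exact objective are assumptions, since the summary does not spell out how the algorithm is applied.</p>

```python
def closest_subset(weights, target):
    """Pick indices whose weights sum as close as possible to target.

    Classic 0/1 subset-sum dynamic programming over achievable sums.
    """
    reachable = {0: []}  # achievable sum mapped to the indices that realize it
    for idx, w in enumerate(weights):
        # Snapshot the current sums so each node is used at most once.
        for s, chosen in list(reachable.items()):
            if s + w not in reachable:
                reachable[s + w] = chosen + [idx]
    best = min(reachable, key=lambda s: abs(s - target))
    return reachable[best], best


# Made-up node weights: samples per barcode. Target is 80% of 100 samples.
node_weights = [50, 30, 15, 5]
train_nodes, train_count = closest_subset(node_weights, 80)
```

<p>Assigning whole nodes keeps every sample with a given barcode on the same side of the split, so the 80/20 ratio is met by sample count without breaking up a barcode.</p>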
<h2 id="evaluation-across-18-datasets-and-19-models">Evaluation Across 18 Datasets and 19 Models</h2>
<p>The authors apply SPECTRA to 18 molecular sequencing datasets spanning three benchmarks (TAPE, PEER, ProteinGym) plus PDBBind, evaluating 19 models including CNNs, LSTMs, GNNs (GearNet), LLMs (ESM2), diffusion models (DiffDock), variational autoencoders (EVE), and logistic regression.</p>
<h3 id="benchmark-datasets">Benchmark Datasets</h3>
<p>The core evaluation covers five primary tasks:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Dataset</th>
          <th>Type</th>
          <th>Metric</th>
          <th>Samples</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Rifampicin resistance (RIF)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>17,474</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Isoniazid">Isoniazid</a> resistance (INH)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>26,574</td>
      </tr>
      <tr>
          <td><a href="https://en.wikipedia.org/wiki/Pyrazinamide">Pyrazinamide</a> resistance (PZA)</td>
          <td>TB clinical isolates</td>
          <td>MSD</td>
          <td>AUROC</td>
          <td>12,146</td>
      </tr>
      <tr>
          <td>Fluorescence prediction</td>
          <td><a href="https://en.wikipedia.org/wiki/Green_fluorescent_protein">GFP</a> variants</td>
          <td>MSD</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>54,024</td>
      </tr>
      <tr>
          <td>Vaccine escape</td>
          <td>SARS-CoV-2 RBD</td>
          <td>MSD</td>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>438,046</td>
      </tr>
  </tbody>
</table>
<p>Additional benchmarks include remote homology detection, secondary structure prediction, subcellular localization, and protein-ligand binding (PDBBind, Astex diverse set, Posebusters).</p>
<h3 id="models-evaluated">Models Evaluated</h3>
<p>Eight models were evaluated in depth across the five primary tasks: logistic regression, CNN, ESM2 (pretrained), ESM2-Finetuned, GearNet, GearNet-Finetuned, EVE, and SeqDesign. Additional models (LSTM, ResNet, DeepSF, Transformer, HHblits, Equibind, DiffDock, TankBind, Transception, MSA Transformer, ESM1v, Progen2) were evaluated on specific benchmark tasks.</p>
<h3 id="existing-splits-as-points-on-the-spc">Existing Splits as Points on the SPC</h3>
<p>SPECTRA reveals that existing benchmark splits correspond to specific points on the spectral performance curve. For instance:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Benchmark Split</th>
          <th>Cross-Split Overlap</th>
          <th>Spectral Parameter</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Remote homology</td>
          <td>TAPE family</td>
          <td>97%</td>
          <td>0.025</td>
      </tr>
      <tr>
          <td>Remote homology</td>
          <td>TAPE superfamily</td>
          <td>71%</td>
          <td>0.475</td>
      </tr>
      <tr>
          <td>Secondary structure</td>
          <td>CASP12</td>
          <td>48%</td>
          <td>0.5</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>Equibind temporal</td>
          <td>76%</td>
          <td>0.55</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>LPPDBind similarity</td>
          <td>91%</td>
          <td>0.275</td>
      </tr>
      <tr>
          <td>Protein-ligand binding</td>
          <td>Posebusters</td>
          <td>70%</td>
          <td>0.575</td>
      </tr>
  </tbody>
</table>
<h2 id="performance-degradation-and-foundation-model-insights">Performance Degradation and Foundation Model Insights</h2>
<h3 id="universal-performance-decline">Universal Performance Decline</h3>
<p>All evaluated models demonstrate decreased performance as cross-split overlap decreases. Logistic regression drops from AUROC &gt; 0.9 to 0.5 for rifampicin resistance. ESM2-Finetuned decreases from Spearman&rsquo;s $\rho &gt; 0.9$ to less than 0.4 for GFP fluorescence prediction.</p>
<p>No single model achieves the highest AUSPC across all tasks. CNN maintains AUSPC &gt; 0.6 across all tasks but is surpassed by ESM2-Finetuned and ESM2 on rifampicin resistance. Some models retain reasonable performance even at $\mathbf{SP} = 1$ (minimal overlap): ESM2, ESM2-Finetuned, and CNN maintain AUROC &gt; 0.7 for RIF and PZA at this extreme.</p>
<h3 id="uncovering-hidden-spectral-properties">Uncovering Hidden Spectral Properties</h3>
<p>SPECTRA can detect unconsidered spectral properties through high variance in model performance at fixed spectral parameters. For rifampicin resistance, the CNN shows high variance at $\mathbf{SP} = 0.9$, $0.95$, and $1.0$ (standard deviations of 0.09, 0.10, and 0.08 respectively).</p>
<p>The authors trace this to the rifampicin resistance determining region (RRDR), a 26-amino-acid region of the rpoB gene. They define diff-RRDR as:</p>
<p>$$
\text{diff-RRDR} = \left(\max\left(\text{position}_{\text{train}}\right) - \max\left(\text{position}_{\text{test}}\right)\right) + \left(\min\left(\text{position}_{\text{train}}\right) - \min\left(\text{position}_{\text{test}}\right)\right)
$$</p>
<p>diff-RRDR correlates with CNN performance variance (Spearman&rsquo;s $\rho = -0.51$, p-value $= 1.79 \times 10^{-5}$) but not with ESM2 performance. The authors attribute this to ESM2&rsquo;s larger context window (512 positions vs. CNN&rsquo;s 12), making it more invariant to positional shifts in resistance-determining mutations.</p>
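<p>The diff-RRDR definition above can be sketched in a few lines. This assumes that &ldquo;position&rdquo; refers to the mutation positions observed in each split; the function name and the example positions are hypothetical.</p>

```python
def diff_rrdr(train_positions, test_positions):
    """Sum of the max-position and min-position offsets between two splits."""
    if not train_positions or not test_positions:
        raise ValueError("both splits need at least one mutation position")
    max_shift = max(train_positions) - max(test_positions)
    min_shift = min(train_positions) - min(test_positions)
    return max_shift + min_shift


# Hypothetical mutation positions observed in each split.
train = [430, 435, 445, 450]
test = [432, 441, 452]
shift = diff_rrdr(train, test)
```

<p>A value of zero means the two splits cover the same positional range, while a large magnitude means the test mutations sit in a shifted window relative to training.</p>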
<h3 id="foundation-model-generalizability">Foundation Model Generalizability</h3>
<p>For protein foundation models, SPECTRA reveals that AUSPC correlates with the similarity between task-specific datasets and the pretraining dataset. ESM2&rsquo;s AUSPC varies from 0.91 (RIF) to 0.26 (SARS-CoV-2). The correlation between UniRef50 overlap and AUSPC is strong (Spearman&rsquo;s $\rho = 0.9$, p-value $= 1.4 \times 10^{-27}$).</p>
<p>This finding holds across multiple foundation models (Transception, MSA Transformer, ESM1v, Progen2) evaluated on five ProteinGym datasets (Spearman&rsquo;s $\rho = 0.9$, p-value $= 0.04$). Fine-tuning improves AUSPC for tasks with low pretraining overlap (PZA, SARS-CoV-2, GFP).</p>
<h3 id="computational-cost">Computational Cost</h3>
<p>Generating SPECTRA splits ranges from 5 minutes (amyloid beta aggregation) to 9 hours (PDBBind). Generating spectral performance curves ranges from 1 hour (logistic regression) to 5 days (ESM2-Finetuned). The authors recommend releasing SPECTRA splits alongside new benchmarks to amortize this cost.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>The authors acknowledge several limitations:</p>
<ul>
<li><strong>Spectral property selection is pivotal</strong>: The choice of spectral property must be biologically informed and task-specific. Standardized definitions across the community are needed.</li>
<li><strong>Computational cost</strong>: Running SPECTRA is expensive, especially for large models. The authors mitigate this with multi-core CPU parallelization and multi-GPU training.</li>
<li><strong>Not a model ranking tool</strong>: SPECTRA is designed for understanding generalizability patterns, not for ranking models. Proper ranking requires averaging AUSPCs across many tasks in a standardized benchmark.</li>
<li><strong>Spectral parameter vs. cross-split overlap</strong>: The minimal achievable cross-split overlap varies across tasks, so SPECTRA plots performance against the spectral parameter rather than overlap directly. This means the AUSPC reflects relative impact on performance per unit decrease in overlap.</li>
</ul>
<p>The authors envision SPECTRA as a foundation for next-generation molecular benchmarks that explicitly characterize generalizability across the full spectrum of distribution shift. They also expect it to extend beyond molecular sequence data to small molecule therapeutics, inverse protein folding, and patient-level clinical datasets.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>All data used in this study is publicly available.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>TB RIF resistance</td>
          <td>17,474 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TB INH resistance</td>
          <td>26,574 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TB PZA resistance</td>
          <td>12,146 isolates</td>
          <td>From Green et al. (2022)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>GFP fluorescence</td>
          <td>54,024 samples</td>
          <td>From Sarkisyan et al. (2016)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>SARS-CoV-2 escape</td>
          <td>438,046 samples</td>
          <td>From Greaney et al. (2021)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>TAPE (remote homology, secondary structure)</td>
          <td>Various</td>
          <td>From Rao et al. (2019)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>PEER (subcellular localization)</td>
          <td>13,949 samples</td>
          <td>From Xu et al. (2022)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>ProteinGym (amyloid, RRM)</td>
          <td>Various</td>
          <td>From Notin et al. (2022)</td>
      </tr>
      <tr>
          <td>Benchmark</td>
          <td>PDBBind (protein-ligand binding)</td>
          <td>14,993-16,742 complexes</td>
          <td>From Wang et al. (2005)</td>
      </tr>
  </tbody>
</table>
<p>Data is also available on <a href="https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/W5UUNN">Harvard Dataverse</a>.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Spectral property comparison uses Biopython pairwise alignment (match=1, mismatch=-2, gap=-2.5) with a 0.3 similarity threshold for sequence-to-sequence datasets</li>
<li>Greedy randomized maximal independent set approximation for split generation</li>
<li>Spectral parameter incremented in 0.05 steps from 0 to 1</li>
<li>Three random seeds per spectral parameter value</li>
<li>80/20 train-test split ratio enforced via subset sum for mutational scan datasets</li>
</ul>
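<p>The greedy randomized independent set step listed above can be illustrated as follows. This is a simplified sketch and not the SPECTRA implementation: it builds a plain maximal independent set on a hypothetical similarity graph and does not model how the spectral parameter relaxes the constraint. All names are illustrative.</p>

```python
import random


def greedy_random_independent_set(adjacency, seed=0):
    """Return a maximal independent set chosen by a random greedy pass.

    adjacency maps each node to the set of nodes it shares a spectral
    property with (assumed symmetric).
    """
    rng = random.Random(seed)
    order = sorted(adjacency)  # fixed base order keeps runs reproducible
    rng.shuffle(order)
    chosen = set()
    blocked = set()
    for node in order:
        if node in blocked:
            continue
        chosen.add(node)
        # Neighbors of a chosen node can no longer be selected.
        blocked.update(adjacency[node])
    return chosen


# Hypothetical similarity graph over five samples.
graph = {
    "a": {"b", "c"},
    "b": {"a"},
    "c": {"a", "d"},
    "d": {"c"},
    "e": set(),
}
selected = greedy_random_independent_set(graph, seed=1)
```

<p>No two selected samples share a spectral property, so placing the selected set and its complement on opposite sides of a split is one way to reduce cross-split overlap. Different seeds give different sets, which is consistent with running several random seeds per spectral parameter value.</p>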
<h3 id="models">Models</h3>
<ul>
<li>ESM2: 650M parameter version from Lin et al. (2023)</li>
<li>ESM2-Finetuned: First 30 layers frozen, masked language head replaced with linear prediction layer</li>
<li>GearNet and GearNet-Finetuned: Protein structures generated via ESMFold</li>
<li>CNN: Architecture from Green et al. (2022), one-hot encoded sequences</li>
<li>Logistic regression: One-hot encoded mutational barcodes</li>
<li>EVE and SeqDesign: MSAs constructed via Jackhmmer against UniRef100</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>AUROC</td>
          <td>TB resistance (RIF, INH, PZA)</td>
          <td>Binary classification</td>
      </tr>
      <tr>
          <td>Spearman&rsquo;s $\rho$</td>
          <td>GFP fluorescence, SARS-CoV-2 escape</td>
          <td>Regression tasks</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>Remote homology, secondary structure, subcellular localization</td>
          <td>Per-label/class accuracy</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Protein-ligand binding</td>
          <td>Predicted vs. actual complex</td>
      </tr>
      <tr>
          <td>AUSPC</td>
          <td>All tasks</td>
          <td>Area under spectral performance curve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Most models: 1x Tesla A10 GPU</li>
<li>ESM2-Finetuned: 4x Tesla A100 GPUs on Azure cluster</li>
<li>Hyperparameter optimization: Weights &amp; Biases random search over learning rate</li>
<li>All code in PyTorch</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mims-harvard/SPECTRA">SPECTRA Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Framework implementation and reproduction scripts</td>
      </tr>
      <tr>
          <td><a href="https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/W5UUNN">Harvard Dataverse</a></td>
          <td>Dataset</td>
          <td>CC0 1.0</td>
          <td>All datasets and generated splits</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ektefaie, Y., Shen, A., Bykova, D., Marin, M. G., Zitnik, M., &amp; Farhat, M. (2024). Evaluating generalizability of artificial intelligence models for molecular datasets. <em>Nature Machine Intelligence</em>, 6(12), 1512-1524. <a href="https://doi.org/10.1038/s42256-024-00931-6">https://doi.org/10.1038/s42256-024-00931-6</a></p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ektefaie2024evaluating,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Evaluating generalizability of artificial intelligence models for molecular datasets}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ektefaie, Yasha and Shen, Andrew and Bykova, Daria and Marin, Maximillian G. and Zitnik, Marinka and Farhat, Maha}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1512--1524}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-024-00931-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Perplexity for Molecule Ranking and CLM Bias Detection</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/perplexity-molecule-ranking-bias-clms/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/perplexity-molecule-ranking-bias-clms/</guid><description>Perplexity scoring enables intrinsic molecule ranking and pretraining bias detection in chemical language models for de novo drug design.</description><content:encoded><![CDATA[<h2 id="a-method-for-intrinsic-scoring-and-bias-detection-in-chemical-language-models">A Method for Intrinsic Scoring and Bias Detection in Chemical Language Models</h2>
<p>This is a <strong>Method</strong> paper that introduces two contributions to the chemical language model (CLM) pipeline for <a href="/notes/chemistry/molecular-design/generation/evaluation/clms-de-novo-drug-design-review/">de novo molecular design</a>. First, the authors propose using perplexity as a model-intrinsic score to rank generated <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings by how well they match the design objectives encoded in the fine-tuning data. Second, they introduce a &ldquo;delta score&rdquo; that compares molecule rankings from pretrained and fine-tuned CLMs to detect pretraining bias, where molecules are generated primarily based on generic pretraining knowledge rather than task-specific fine-tuning objectives.</p>
<h2 id="the-ranking-and-bias-problem-in-clm-based-molecule-generation">The Ranking and Bias Problem in CLM-Based Molecule Generation</h2>
<p>Chemical language models generate new molecules as SMILES strings by iteratively predicting the next character based on learned probability distributions. After training, CLMs can produce large virtual libraries of candidate molecules via multinomial sampling. However, two key challenges remain: (1) the generated molecules lack a natural ranking, requiring external scoring methods such as similarity assessment or activity prediction for prioritization, and (2) <a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">transfer learning</a> (pretraining on a large corpus followed by fine-tuning on a small target set) can introduce &ldquo;pretraining bias,&rdquo; where some generated molecules reflect generic chemical knowledge from pretraining rather than the specific design objectives of the fine-tuning data.</p>
<p>Beam search offers an alternative sampling approach that produces inherently ranked molecules by greedily selecting the most probable SMILES strings. However, beam search explores only a narrow portion of chemical space. The authors sought to combine the ranking advantage of beam search with the chemical space exploration of multinomial sampling by applying perplexity scoring as a post-hoc ranking criterion.</p>
<h2 id="perplexity-scoring-and-the-delta-score-for-bias-estimation">Perplexity Scoring and the Delta Score for Bias Estimation</h2>
<p>The core innovation is the application of <a href="https://en.wikipedia.org/wiki/Perplexity">perplexity</a>, a standard evaluation metric from natural language processing, to score SMILES strings generated by CLMs. For a SMILES string of length $N$ with character probabilities $p_i$ assigned by the CLM, perplexity is computed as:</p>
<p>$$
\text{perplexity} = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log(p_{i})}
$$</p>
<p>Low perplexity indicates that the CLM assigns high probability to each character in the SMILES string, suggesting the molecule closely matches the learned distribution of the fine-tuning data. The metric is normalized by string length, making it comparable across molecules of different sizes.</p>
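<p>A minimal sketch of the perplexity computation follows. It assumes a base-2 logarithm to match the base-2 exponent (the formula as written does not state the base), and the per-character probabilities are made up for illustration.</p>

```python
import math


def perplexity(char_probs):
    """Length-normalized perplexity of one SMILES string.

    char_probs holds the probability the CLM assigned to each character.
    """
    if not char_probs:
        raise ValueError("need at least one character probability")
    # Assumption: log base 2, so that equal probabilities p give 1 / p.
    mean_log2 = sum(math.log2(p) for p in char_probs) / len(char_probs)
    return 2.0 ** (-mean_log2)


# Hypothetical per-character probabilities for two sampled strings.
confident = perplexity([0.5, 0.5, 0.5, 0.5])
uncertain = perplexity([0.5, 0.125, 0.25, 0.5])
```

<p>With this convention, a string whose characters all receive probability 0.5 has perplexity 2, and any less probable character pushes the score up regardless of string length.</p>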
<p>To address pretraining bias, the authors introduce a delta score. For each generated molecule, the perplexity-based rank from the fine-tuned model ($\text{rank}_{ft}$) is compared against the rank from the pretrained model ($\text{rank}_{pt}$):</p>
<p>$$
\text{delta} = \text{rank}_{ft} - \text{rank}_{pt}
$$</p>
<p>A positive delta score indicates that the fine-tuned model ranks the molecule higher than the pretrained model, suggesting the molecule was generated based on task-specific fine-tuning knowledge. A negative delta score flags molecules that may have been generated primarily from pretraining information, which do not necessarily match the design objectives.</p>
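<p>The delta score can be sketched as below. The ranking convention is an assumption made here so that the sign matches the description above: the molecule with the lowest perplexity receives the largest rank value. The helper names and the perplexity values are hypothetical.</p>

```python
def ranks_by_perplexity(perplexities):
    """Map each molecule to a rank; the lowest perplexity gets the largest rank."""
    # Assumed convention: larger rank value = better, chosen so that a
    # positive delta matches "ranked higher by the fine-tuned model".
    ordered = sorted(perplexities, key=perplexities.get, reverse=True)
    return {mol: rank for rank, mol in enumerate(ordered, start=1)}


def delta_scores(ppl_finetuned, ppl_pretrained):
    """delta = rank_ft - rank_pt for every molecule scored by both models."""
    rank_ft = ranks_by_perplexity(ppl_finetuned)
    rank_pt = ranks_by_perplexity(ppl_pretrained)
    return {mol: rank_ft[mol] - rank_pt[mol] for mol in rank_ft}


# Hypothetical perplexities for three generated SMILES strings.
ft = {"CCO": 1.8, "c1ccccc1": 2.5, "CC(=O)O": 3.1}
pt = {"CCO": 3.0, "c1ccccc1": 2.0, "CC(=O)O": 2.4}
deltas = delta_scores(ft, pt)
# Molecules with a negative delta are candidates for pretraining bias.
flagged = sorted(mol for mol, d in deltas.items() if 0 > d)
```

<p>Because the score compares ranks and not raw perplexities, it is insensitive to the two models operating on different perplexity scales.</p>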
<p>The multinomial sampling probability for each character is computed via the softmax function:</p>
<p>$$
p_{i} = \frac{e^{z_{i}/T}}{\sum_{j} e^{z_{j}/T}}
$$</p>
<p>where $z_{i}$ is the CLM output logit for the $i$th character, $j$ runs over all dictionary characters, and $T$ is the temperature parameter (set to $T = 1$ in this study).</p>
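<p>The temperature-scaled softmax and the multinomial draw can be sketched in plain Python. The three-character dictionary and logits are hypothetical, and a real CLM would repeat this step for every position until an end-of-sequence character is sampled.</p>

```python
import math
import random


def softmax_with_temperature(logits, temperature=1.0):
    """Convert logits z_i into probabilities p_i at temperature T."""
    if not temperature > 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    # Subtracting the maximum leaves the result unchanged but avoids overflow.
    shift = max(scaled)
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_character(vocab, logits, temperature=1.0, rng=random):
    """Draw one character from the multinomial defined by the softmax."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(vocab, weights=probs, k=1)[0]


# Hypothetical three-character dictionary and logits.
vocab = ["C", "N", "O"]
logits = [2.0, 1.0, 0.0]
probs_t1 = softmax_with_temperature(logits, temperature=1.0)
probs_t2 = softmax_with_temperature(logits, temperature=2.0)
next_char = sample_character(vocab, logits, rng=random.Random(0))
```

<p>Raising $T$ above 1 flattens the distribution and makes sampling more exploratory, while $T = 1$ samples directly from the distribution the model learned.</p>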
<h2 id="experimental-setup-10-protein-targets-across-four-data-regimes">Experimental Setup: 10 Protein Targets Across Four Data Regimes</h2>
<p>The authors systematically evaluated perplexity scoring across 10 macromolecular targets and four low-data fine-tuning regimes (5, 10, 20, and 40 molecules per target).</p>
<p><strong>Model architecture</strong>: A four-layer LSTM-based RNN (5,820,515 parameters) comprising batch normalization layers and LSTM layers of 1024 and 256 units, trained with the Adam optimizer at a learning rate of $10^{-4}$.</p>
<p><strong>Pretraining</strong>: The model was pretrained on 1,683,181 molecules from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> (version 28), encoded as canonical SMILES (20-90 characters), for 90 epochs.</p>
<p><strong>Fine-tuning</strong>: For each of 10 randomly selected protein targets (Table 1), bioactive ligands with pChEMBL &gt; 6 were selected. Fine-tuning sets of 5, 10, 20, and 40 molecules were compiled for each target. Fine-tuning ran for 100 epochs, with 1,000 SMILES strings sampled every second epoch via multinomial sampling ($T = 1$).</p>
<table>
  <thead>
      <tr>
          <th>CHEMBL ID</th>
          <th>Target</th>
          <th>Protein Classification</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CHEMBL1836</td>
          <td>Prostanoid EP4 receptor</td>
          <td><a href="https://en.wikipedia.org/wiki/G_protein-coupled_receptor">G protein-coupled receptor</a></td>
      </tr>
      <tr>
          <td>CHEMBL1945</td>
          <td>Melatonin receptor 1A</td>
          <td>G protein-coupled receptor</td>
      </tr>
      <tr>
          <td>CHEMBL1983</td>
          <td>Serotonin 1D (5-HT1D) receptor</td>
          <td>Family A GPCR</td>
      </tr>
      <tr>
          <td>CHEMBL202</td>
          <td><a href="https://en.wikipedia.org/wiki/Dihydrofolate_reductase">Dihydrofolate reductase</a></td>
          <td>Oxidoreductase</td>
      </tr>
      <tr>
          <td>CHEMBL3522</td>
          <td><a href="https://en.wikipedia.org/wiki/Cytochrome_P450">Cytochrome P450</a> 17A1</td>
          <td>Cytochrome P450</td>
      </tr>
      <tr>
          <td>CHEMBL4029</td>
          <td>Interleukin-8 receptor A</td>
          <td>Family A GPCR</td>
      </tr>
      <tr>
          <td>CHEMBL5073</td>
          <td>CaM kinase I delta</td>
          <td>Kinase</td>
      </tr>
      <tr>
          <td>CHEMBL5137</td>
          <td>Metabotropic glutamate receptor 2</td>
          <td>G protein-coupled receptor</td>
      </tr>
      <tr>
          <td>CHEMBL5408</td>
          <td>Serine/threonine-protein kinase TBK1</td>
          <td>Kinase</td>
      </tr>
      <tr>
          <td>CHEMBL5608</td>
          <td>NT-3 growth factor receptor</td>
          <td>Kinase</td>
      </tr>
  </tbody>
</table>
<p><strong>Sampling comparison</strong>: Beam search sampling was performed with beam widths $k = 10$ and $k = 50$ for comparison against multinomial sampling.</p>
<p><strong>Molecular similarity</strong>: <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> was computed using Morgan fingerprints (radius 2, length 1024) and 2D <a href="https://en.wikipedia.org/wiki/Pharmacophore">pharmacophore</a> fingerprints via RDKit (2019.03.2).</p>
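<p>For reference, Tanimoto similarity on binary fingerprints reduces to a set ratio. The sketch below works on hypothetical sets of on-bit indices and does not call RDKit, which the study used to compute the actual fingerprints.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)


# Hypothetical on-bit indices of two 1024-bit fingerprints.
fp_query = {3, 17, 42, 101}
fp_reference = {3, 42, 256}
similarity = tanimoto(fp_query, fp_reference)
distance = 1.0 - similarity
```

<p>The Tanimoto distance used in the correlation analysis is one minus this similarity.</p>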
<h2 id="key-findings-multinomial-sampling-outperforms-beam-search">Key Findings: Multinomial Sampling Outperforms Beam Search</h2>
<p><strong>Perplexity correlates with molecular similarity.</strong> The Pearson correlation between perplexity and Tanimoto distance to the fine-tuning set stabilized at approximately 0.5 across all data regimes. This correlation emerged earlier with larger fine-tuning sets. The result confirms that perplexity captures both substructural and pharmacophore features while also incorporating additional CLM-learned information.</p>
<p><strong>Multinomial sampling produces better-ranked molecules than beam search.</strong> With the smallest fine-tuning sets (5 molecules), the top 50 molecules from multinomial sampling consistently exhibited lower (better) perplexity values than beam search at $k = 10$ or $k = 50$. Increasing the beam width from 10 to 50 did not markedly improve beam search performance. For novel molecules (Tanimoto similarity below 50% to the nearest fine-tuning compound), multinomial sampling identified lower-perplexity molecules in 72% of cases with the smallest fine-tuning sets.</p>
<p><strong>Perplexity scoring narrows the quality distribution.</strong> The top 50 molecules selected by perplexity from multinomial sampling spanned a narrower range of perplexity values compared to beam search, suggesting a more consistent pool of high-quality candidates for follow-up synthesis.</p>
<p><strong>Pretraining bias is substantial.</strong> The delta score analysis revealed that more than 40% of sampled molecules had negative delta scores during the first 20 fine-tuning epochs, meaning they were ranked higher by the pretrained model than the fine-tuned model. This fraction remained above 10% even at the end of 100 fine-tuning epochs across all data regimes, confirming that 10-40% of generated molecules reflect &ldquo;generic&rdquo; pretraining rather than task-focused fine-tuning.</p>
<p><strong>Perplexity alone partially mitigates bias.</strong> Among the top 50 molecules selected by perplexity from multinomial sampling, only up to 3% had negative delta scores, compared to 10-40% in the unfiltered population. This suggests that perplexity-based ranking already reduces pretraining bias, though the delta score provides additional filtering power.</p>
<p><strong>SMILES validity remained high.</strong> Mean SMILES string validity consistently exceeded 90% across all fine-tuned models and fine-tuning epochs.</p>
<h3 id="limitations">Limitations</h3>
<p>The authors note several limitations and future directions. The study used a fixed temperature of $T = 1$ for multinomial sampling; combining perplexity with temperature tuning or <a href="/notes/chemistry/molecular-design/property-prediction/maxsmi-smiles-augmentation-property-prediction/">SMILES augmentation</a> remains unexplored. The evaluation focused on 10 protein targets, and broader validation across diverse target classes would strengthen the conclusions. The authors also suggest that combining CLMs with perplexity scoring could be applied to screen large collections of commercially available compounds, which has not yet been tested.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL v28</td>
          <td>1,683,181 molecules</td>
          <td>Canonical SMILES, 20-90 characters, salts and duplicates removed</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>ChEMBL v28 (split)</td>
          <td>84,160 molecules</td>
          <td>Random split from pretraining set</td>
      </tr>
      <tr>
          <td>Fine-tuning</td>
          <td>ChEMBL v28 (per target)</td>
          <td>5, 10, 20, or 40 molecules</td>
          <td>pChEMBL &gt; 6, 10 targets</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>LSTM-based CLM with character-level SMILES prediction</li>
<li>Multinomial sampling at $T = 1$</li>
<li>Beam search at $k = 10$ and $k = 50$</li>
<li>Perplexity computed per Equation 1; delta score per Equation 2</li>
<li>Adam optimizer, learning rate $10^{-4}$, 90 pretraining epochs, 100 fine-tuning epochs</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>4-layer LSTM RNN: batch normalization, LSTM (1024 units), LSTM (256 units), batch normalization</li>
<li>5,820,515 parameters total</li>
<li>One-hot encoded SMILES input</li>
<li>Pretrained weights available in the GitHub repository</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Perplexity</td>
          <td>Model confidence in generated SMILES</td>
          <td>Lower is better</td>
      </tr>
      <tr>
          <td>Delta score</td>
          <td>Rank difference between fine-tuned and pretrained models</td>
          <td>Positive indicates task-relevant generation</td>
      </tr>
      <tr>
          <td>Tanimoto similarity</td>
          <td>Morgan and pharmacophore fingerprints</td>
          <td>Compared to fine-tuning set</td>
      </tr>
      <tr>
          <td>Pearson correlation</td>
          <td>Perplexity vs. Tanimoto distance</td>
          <td>Stabilizes at ~0.5</td>
      </tr>
      <tr>
          <td>SMILES validity</td>
          <td>Fraction of valid SMILES strings</td>
          <td>Consistently &gt; 90%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not reported in the paper. The implementation uses Keras (v2.2.0) with TensorFlow GPU backend (v1.9.0).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ETHmodlab/CLM_perplexity">CLM_perplexity</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Framework, pretrained weights, and training data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/ETHmodlab/molecular_design_with_beam_search">Beam search implementation</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Referenced beam search implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Moret, M., Grisoni, F., Katzberger, P., &amp; Schneider, G. (2022). Perplexity-Based Molecule Ranking and Bias Estimation of Chemical Language Models. <em>Journal of Chemical Information and Modeling</em>, 62(5), 1199-1206. <a href="https://doi.org/10.1021/acs.jcim.2c00079">https://doi.org/10.1021/acs.jcim.2c00079</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling, 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/ETHmodlab/CLM_perplexity">GitHub: CLM_perplexity (MIT License)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{moret2022perplexity,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Perplexity-Based Molecule Ranking and Bias Estimation of Chemical Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Moret, Michael and Grisoni, Francesca and Katzberger, Paul and Schneider, Gisbert}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{62}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1199--1206}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.2c00079}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Frechet ChemNet Distance for Molecular Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/</guid><description>FCD uses ChemNet activations and the Wasserstein-2 distance to evaluate molecular generative models for chemical validity, biological relevance, and diversity.</description><content:encoded><![CDATA[<h2 id="a-unified-evaluation-metric-for-molecular-generation">A Unified Evaluation Metric for Molecular Generation</h2>
<p>This is a <strong>Method</strong> paper that introduces the Frechet ChemNet Distance (FCD), a single scalar metric for evaluating generative models that produce molecules for drug discovery. FCD adapts the Frechet Inception Distance (FID) from image generation to the molecular domain. By comparing distributions of learned representations from a drug-activity prediction network (ChemNet), FCD simultaneously captures whether generated molecules are chemically valid, biologically relevant, and structurally diverse.</p>
<h2 id="inconsistent-evaluation-of-molecular-generative-models">Inconsistent Evaluation of Molecular Generative Models</h2>
<p>At the time of this work (2018), deep generative models for molecules were proliferating: RNNs combined with <a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">variational autoencoders</a>, reinforcement learning, and <a href="/posts/what-is-a-gan/">GANs</a> all produced <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings representing novel molecules. The evaluation landscape was fragmented. Different papers reported different metrics: percentage of valid SMILES, mean logP, druglikeness, synthetic accessibility (SA) scores, or internal diversity via Tanimoto distance.</p>
<p>This inconsistency created several problems. First, method comparison across publications was difficult because no common metric existed. Second, simple metrics like &ldquo;fraction of valid SMILES&rdquo; could be trivially maximized by generating short, simple molecules (e.g., &ldquo;CC&rdquo; or &ldquo;CCC&rdquo;). Third, individual property metrics (logP, druglikeness) each captured only one dimension of quality. A model could score well on logP but produce molecules that were not diverse or not biologically meaningful.</p>
<p>The authors argued that a good metric should capture three properties simultaneously: (1) chemical validity and similarity to real drug-like molecules, (2) biological relevance, and (3) diversity within the generated set.</p>
<h2 id="core-innovation-frechet-distance-over-chemnet-activations">Core Innovation: Frechet Distance over ChemNet Activations</h2>
<p>The key insight is to use a neural network trained on biological activity prediction as a feature extractor for molecules, then compare distributions of these features using the Frechet (Wasserstein-2) distance.</p>
<h3 id="chemnet-architecture">ChemNet Architecture</h3>
<p>ChemNet is a multi-task neural network trained to predict bioactivities across approximately 6,000 assays from three major drug discovery databases (ChEMBL, ZINC, PubChem). The architecture processes one-hot encoded SMILES strings through:</p>
<ol>
<li>Two 1D convolutional layers with SELU activations</li>
<li>A max-pooling layer</li>
<li>Two stacked LSTM layers</li>
<li>A fully connected output layer</li>
</ol>
<p>The penultimate layer (the second LSTM&rsquo;s hidden state after processing the full input sequence) serves as the molecular representation. Because ChemNet was trained to predict drug activities, its internal representations encode both chemical structure (from the input side) and biological function (from the output side).</p>
<h3 id="the-fcd-formula">The FCD Formula</h3>
<p>Given a set of real molecules and a set of generated molecules, FCD is computed as follows:</p>
<ol>
<li>Pass each molecule (as a SMILES string) through ChemNet and extract penultimate-layer activations.</li>
<li>Fit a multivariate Gaussian to each set by computing the mean $\mathbf{m}$ and covariance $\mathbf{C}$ for the generated set, and mean $\mathbf{m}_w$ and covariance $\mathbf{C}_w$ for the real set.</li>
<li>Compute the squared Frechet distance:</li>
</ol>
<p>$$
d^{2}\left((\mathbf{m}, \mathbf{C}), (\mathbf{m}_w, \mathbf{C}_w)\right) = \lVert \mathbf{m} - \mathbf{m}_w \rVert_2^{2} + \mathrm{Tr}\left(\mathbf{C} + \mathbf{C}_w - 2(\mathbf{C}\mathbf{C}_w)^{1/2}\right)
$$</p>
<p>The Gaussian assumption is justified by the maximum entropy principle: the Gaussian is the maximum-entropy distribution for given mean and covariance. A lower FCD indicates that the generated distribution is closer to the real distribution.</p>
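<p>As a sanity check on the formula, consider the one-dimensional case. This is an illustrative special case and is not taken from the paper. The covariances reduce to variances $\sigma^{2}$ and $\sigma_w^{2}$, and the matrix square root becomes $\sigma\sigma_w$:</p>
<p>$$
d^{2} = (m - m_w)^{2} + \sigma^{2} + \sigma_w^{2} - 2\sigma\sigma_w = (m - m_w)^{2} + (\sigma - \sigma_w)^{2}
$$</p>
<p>The distance is zero only when both the means and the spreads agree. A generator that matches the mean activation but collapses to low diversity ($\sigma \ll \sigma_w$) is still penalized through the second term, which is how a single scalar can reflect diversity as well as location.</p>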
<h3 id="why-not-just-fingerprints">Why Not Just Fingerprints?</h3>
<p>The authors also define a Frechet Fingerprint Distance (FFD) that replaces ChemNet activations with 2048-bit ECFP_4 fingerprints. FFD captures chemical structure but not biological function. The experimental comparison shows that FCD produces more distinct separations between biased and unbiased molecule sets, particularly for biologically meaningful biases.</p>
<h2 id="detecting-flaws-in-generative-models">Detecting Flaws in Generative Models</h2>
<p>The experiments evaluate whether FCD can detect specific failure modes in generative models. The authors simulate five types of biased generators by selecting molecules from real databases that exhibit particular properties, then compare FCD against individual metrics (logP, druglikeness, SA score, internal diversity) and FFD.</p>
<h3 id="simulated-bias-experiments">Simulated Bias Experiments</h3>
<p>Each experiment draws 5,000 molecules and is repeated 5 times. The reference distribution is 200,000 randomly drawn real molecules that were not used for ChemNet training.</p>
<table>
  <thead>
      <tr>
          <th>Bias Type</th>
          <th>logP</th>
          <th>Druglikeness</th>
          <th>SA Score</th>
          <th>Int. Diversity</th>
          <th>FFD</th>
          <th>FCD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Low druglikeness (&lt;5th pct)</td>
          <td>-</td>
          <td>Detects</td>
          <td>-</td>
          <td>-</td>
          <td>Detects</td>
          <td>Detects</td>
      </tr>
      <tr>
          <td>High logP (&gt;95th pct)</td>
          <td>Detects</td>
          <td>Detects</td>
          <td>-</td>
          <td>-</td>
          <td>Detects</td>
          <td>Detects</td>
      </tr>
      <tr>
          <td>Low SA score (&lt;5th pct)</td>
          <td>-</td>
          <td>Partial</td>
          <td>-</td>
          <td>Partial</td>
          <td>Detects</td>
          <td>Detects</td>
      </tr>
      <tr>
          <td>Mode collapse (cluster)</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
          <td>Detects</td>
          <td>Detects</td>
          <td>Detects</td>
      </tr>
      <tr>
          <td>Kinase inhibitors (PLK1)</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
          <td>Detects</td>
          <td>Detects</td>
      </tr>
  </tbody>
</table>
<p>FCD is the only metric that detects all five bias types. The biological bias test (kinase inhibitors for PLK1-PBD from PubChem AID 720504) is particularly notable: only FFD and FCD detect this bias, and FCD provides a more distinct separation. This validates the hypothesis that incorporating biological information through ChemNet activations improves evaluation beyond purely chemical descriptors.</p>
<h3 id="sample-size-requirements">Sample Size Requirements</h3>
<p>The authors tested FCD convergence with varying sample sizes (5 to 300,000 molecules). Mean FCD values for samples drawn from the real distribution:</p>
<table>
  <thead>
      <tr>
          <th>Sample Size</th>
          <th>Mean FCD</th>
          <th>Std Dev</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>5</td>
          <td>76.46</td>
          <td>5.03</td>
      </tr>
      <tr>
          <td>50</td>
          <td>31.86</td>
          <td>0.75</td>
      </tr>
      <tr>
          <td>500</td>
          <td>4.41</td>
          <td>0.03</td>
      </tr>
      <tr>
          <td>5,000</td>
          <td>0.42</td>
          <td>0.01</td>
      </tr>
      <tr>
          <td>50,000</td>
          <td>0.05</td>
          <td>0.00</td>
      </tr>
      <tr>
          <td>300,000</td>
          <td>0.02</td>
          <td>0.00</td>
      </tr>
  </tbody>
</table>
<p>A sample size of 5,000 molecules is sufficient for reliable estimation: at that size the mean FCD falls to 0.42 with a standard deviation of 0.01.</p>
<h3 id="benchmarking-published-generative-models">Benchmarking Published Generative Models</h3>
<p>The authors computed FCD for several published generative methods:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>FCD</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Random real molecules</td>
          <td>0.22</td>
          <td>Baseline (near zero as expected)</td>
      </tr>
      <tr>
          <td>Segler et al. (LSTM)</td>
          <td>1.62</td>
          <td>Trained to approximate full ChEMBL distribution</td>
      </tr>
      <tr>
          <td>DRD2-targeted methods</td>
          <td>24.14 to 47.85</td>
          <td>Olivecrona, RL, and ORGAN agents</td>
      </tr>
      <tr>
          <td>Rule-based baseline</td>
          <td>58.76</td>
          <td>Random concatenation of C, N, O atoms</td>
      </tr>
  </tbody>
</table>
<p>The ranking matches expectations. The Segler model, trained to approximate the overall molecule distribution, achieves the lowest FCD among the generative models (1.62). Models optimized for a specific target (DRD2), including the Olivecrona RL agents, the RL method by Benhenda, and ORGAN, produce higher FCD values (24.14 to 47.85) against the general distribution. More training iterations push these models further from the general distribution, as they become increasingly DRD2-specific. The canonical and reduced Olivecrona agents learn similar chemical spaces, consistent with the original authors&rsquo; conclusions. The rule-based system scores worst (58.76), confirming FCD as a meaningful quality metric.</p>
<h2 id="conclusions-and-impact">Conclusions and Impact</h2>
<p>FCD provides a single metric that unifies the evaluation of chemical validity, biological relevance, and diversity for molecular generative models. Its main advantages are:</p>
<ol>
<li>It captures multiple quality dimensions in one score, simplifying method comparison.</li>
<li>It detects biases that no single existing metric can catch alone.</li>
<li>It requires only SMILES strings as input, making it applicable to any generative method (including graph-based approaches via SMILES conversion).</li>
<li>It incorporates biological information through ChemNet, distinguishing it from purely chemical metrics like FFD.</li>
</ol>
<p><strong>Limitations</strong>: The metric depends on the ChemNet model, which was trained on a specific set of bioactivity assays. Molecules outside the training distribution of ChemNet may not be well-represented. The Gaussian assumption for the activation distributions may not hold perfectly. FCD measures distance to a reference set, so it evaluates how well a generator approximates a given distribution rather than the absolute quality of individual molecules. When using FCD for targeted generation (e.g., molecules active against a specific protein), the reference set should be chosen accordingly, not the general drug-like molecule distribution.</p>
<p>FCD has since become a standard evaluation metric in the molecular generation community, adopted by benchmarking platforms like <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a>.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemNet training</td>
          <td>ChEMBL, ZINC, PubChem</td>
          <td>~6,000 assays</td>
          <td>Two-thirds for training, one-third for testing</td>
      </tr>
      <tr>
          <td>Reference distribution</td>
          <td>Combined databases</td>
          <td>200,000 molecules</td>
          <td>Excluded from ChemNet training</td>
      </tr>
      <tr>
          <td>Bias simulations</td>
          <td>Subsets of combined databases</td>
          <td>5,000 per experiment</td>
          <td>5 repetitions each</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>ChemNet: 2x 1D-conv (SELU), max-pool, 2x stacked LSTM, FC output</li>
<li>FCD: Squared Frechet distance between Gaussian-fitted ChemNet penultimate-layer activations</li>
<li>FFD: Same as FCD but using 2048-bit ECFP_4 fingerprints instead of ChemNet activations</li>
<li>Molecular property calculations: RDKit (logP, druglikeness, SA score, Morgan fingerprints with radius 2)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FCD</td>
          <td>Frechet distance over ChemNet activations (lower = closer to reference)</td>
      </tr>
      <tr>
          <td>FFD</td>
          <td>Frechet distance over ECFP_4 fingerprints</td>
      </tr>
      <tr>
          <td>logP</td>
          <td>Mean partition coefficient</td>
      </tr>
      <tr>
          <td>Druglikeness</td>
          <td>Geometric mean of desired molecular properties (QED)</td>
      </tr>
      <tr>
          <td>SA Score</td>
          <td>Synthetic accessibility score</td>
      </tr>
      <tr>
          <td>Internal Diversity</td>
          <td>Tanimoto distance within generated set</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Hardware specifications are not provided in the paper.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/bioinf-jku/FCD">FCD Implementation</a></td>
          <td>Code</td>
          <td>LGPL-3.0</td>
          <td>Official Python implementation; requires only SMILES input</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Preuer, K., Renz, P., Unterthiner, T., Hochreiter, S., &amp; Klambauer, G. (2018). Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery. <em>Journal of Chemical Information and Modeling</em>, 58(9), 1736-1741.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{preuer2018frechet,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Fr{\&#39;e}chet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Preuer, Kristina and Renz, Philipp and Unterthiner, Thomas and Hochreiter, Sepp and Klambauer, G{\&#34;u}nter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{58}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1736--1741}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.8b00234}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Benchmarking Molecular Property Prediction at Scale</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/</link><pubDate>Wed, 25 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/systematic-study-molecular-property-prediction/</guid><description>A study training 62,820 models finds fixed molecular representations often outperform learned representations for property prediction.</description><content:encoded><![CDATA[<h2 id="a-large-scale-empirical-study-of-molecular-property-prediction">A Large-Scale Empirical Study of Molecular Property Prediction</h2>
<p>This is an <strong>Empirical</strong> paper that systematically benchmarks molecular property prediction across multiple dimensions: molecular representations, model architectures, evaluation metrics, data splitting strategies, and chemical space generalization. The primary contribution is a rigorous, large-scale comparison (62,820 trained models) showing that traditional machine learning models on fixed molecular representations frequently outperform recent deep representation learning approaches, and that several overlooked evaluation factors (statistical testing, metric choice, activity cliffs, dataset size) significantly influence conclusions about model performance.</p>
<h2 id="motivation-overlooked-evaluation-pitfalls-in-molecular-property-prediction">Motivation: Overlooked Evaluation Pitfalls in Molecular Property Prediction</h2>
<p>Molecular property prediction is a core task in AI-driven drug discovery, and recent years have seen a proliferation of representation learning methods (transformers on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, GNNs on molecular graphs) claiming improved performance on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet benchmark datasets</a>. However, the authors identify several systemic problems in how these methods are evaluated:</p>
<ol>
<li><strong>Heavy reliance on MoleculeNet benchmarks</strong>, which may not reflect real-world drug discovery challenges. Some benchmark tasks (e.g., SIDER, ClinTox) are arguably unreasonable because they try to predict outcomes from chemical structure alone when other factors (food-drug interactions, patient-level variables) dominate.</li>
<li><strong>Lack of statistical rigor.</strong> Most papers report mean metrics over 3 or 10 splits without statistical tests. Without rigorous analysis, improved metrics could be statistical noise.</li>
<li><strong>Inconsistent data splits.</strong> Across studies, the actual splits vary because seeds and splitting implementations differ, making cross-paper comparisons unreliable.</li>
<li><strong>Inappropriate metrics.</strong> AUROC, the default for classification, can overestimate performance, especially on imbalanced datasets. Precision-oriented metrics (PPV, NPV) may be more relevant for virtual screening.</li>
<li><strong>Neglect of activity cliffs.</strong> Most studies only evaluate inter-scaffold generalization via scaffold splits, ignoring intra-scaffold generalization where structurally similar molecules exhibit drastically different activities (<a href="/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/">activity cliffs</a>).</li>
</ol>
<h2 id="core-contribution-fixed-representations-often-outperform-learned-representations">Core Contribution: Fixed Representations Often Outperform Learned Representations</h2>
<p>The central finding is that traditional ML models (RF, SVM, XGBoost) operating on fixed molecular representations (RDKit2D descriptors, Morgan fingerprints, MACCS keys, AtomPairs) frequently outperform recent self-supervised pretrained models (<a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, GROVER) across diverse datasets. The authors frame the paper around a central thesis:</p>
<blockquote>
<p>&ldquo;A model cannot save an unqualified dataset which cannot remedy an improper evaluation for an ambiguous chemical space generalization claim.&rdquo;</p></blockquote>
<p>Key findings on representations and models:</p>
<ul>
<li><strong>RF on RDKit2D descriptors</strong> achieves the best performance on BACE, BBBP, ESOL, and Lipop under scaffold split. MolBERT matches RF only on HIV.</li>
<li><strong>Concatenating RDKit2D descriptors to GROVER&rsquo;s learned embeddings (GROVER_RDKit)</strong> significantly improves performance, suggesting the learned representations alone are insufficient and that fixed descriptors carry substantial predictive signal.</li>
<li><strong>For binding activity datasets</strong> (<a href="https://en.wikipedia.org/wiki/Opioid_receptor">opioid receptors</a> MOR, DOR, KOR), MorganBits fingerprints outperform other representations, consistent with the structural nature of binding.</li>
<li><strong>PhysChem descriptors</strong> excel on datasets where properties correlate strongly with simple molecular features (e.g., ESOL has a near-linear relationship between MolLogP and solubility), but perform poorly on binding activity datasets where the relationship is more complex.</li>
</ul>
<h2 id="experimental-setup-62820-models-across-diverse-datasets">Experimental Setup: 62,820 Models Across Diverse Datasets</h2>
<h3 id="models-evaluated">Models evaluated</h3>
<p>The study evaluates nine models across three categories:</p>
<ul>
<li><strong>Traditional ML</strong>: Random Forest (RF), Support Vector Machine (SVM), XGBoost</li>
<li><strong>Regular neural networks</strong>: RNN (GRU variant), GCN, GIN</li>
<li><strong>Pretrained models</strong>: MolBERT (SMILES-based, ~85M parameters, pretrained on 1.6M molecules), GROVER (graph-based, ~48M parameters, pretrained on ~10M molecules), and GROVER_RDKit (GROVER with concatenated RDKit2D descriptors)</li>
</ul>
<h3 id="molecular-representations">Molecular representations</h3>
<p>Six fixed representations are evaluated: RDKit2D descriptors (200 features), PhysChem descriptors (11 features), MACCS keys, MorganBits fingerprints, MorganCounts fingerprints, and AtomPairs fingerprints. Morgan fingerprints use radius 2 and 2048 bits after testing showed little difference between common parameter choices.</p>
<h3 id="datasets">Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Category</th>
          <th>Datasets</th>
          <th>Task Type</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoleculeNet benchmarks</td>
          <td>BACE, BBBP, HIV</td>
          <td>Classification</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>MoleculeNet benchmarks</td>
          <td>ESOL, FreeSolv, Lipop</td>
          <td>Regression</td>
          <td>MoleculeNet</td>
      </tr>
      <tr>
          <td>Opioids-related</td>
          <td>MDR1, CYP2D6, CYP3A4, MOR, DOR, KOR</td>
          <td>Classification + Regression</td>
          <td>ChEMBL</td>
      </tr>
      <tr>
          <td>Activity datasets</td>
          <td>24 targets</td>
          <td>Regression</td>
          <td>Cortes-Ciriano et al.</td>
      </tr>
      <tr>
          <td>Activity datasets</td>
          <td>30 targets (MoleculeACE)</td>
          <td>Regression</td>
          <td>Tilborg et al.</td>
      </tr>
      <tr>
          <td>Descriptor datasets</td>
          <td>MolWt, NumAtoms (16 sizes each)</td>
          <td>Regression</td>
          <td>ZINC250k</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation-protocol">Evaluation protocol</h3>
<ul>
<li>Both scaffold and random splits (80:10:10 ratio)</li>
<li><strong>30 different random seeds</strong> per experiment for statistical rigor</li>
<li><a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Mann-Whitney U test</a> for pairwise significance ($p &lt; 0.05$, two-sided)</li>
<li>Multiple metrics per task: AUROC, AUPRC, PPV, NPV for classification; RMSE, MAE, $R^2$, Pearson $R$ for regression</li>
</ul>
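<p>The pairwise significance test can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the authors&rsquo; code: the function name <code>mann_whitney_u</code> is hypothetical, and the p-value uses the normal approximation without tie or continuity correction, so it will differ slightly from an exact test.</p>

```python
import math


def mann_whitney_u(a, b):
    """Two-sided Mann-Whitney U test using the normal approximation.

    Returns (U statistic for sample a, approximate p-value). Ties receive
    average ranks; no tie or continuity correction is applied to the variance.
    """
    pooled = sorted([(v, 0) for v in a] + [(v, 1) for v in b])
    ranks = [0.0] * len(pooled)
    i = 0
    while i < len(pooled):
        j = i
        # Extend j over a run of tied values.
        while j + 1 < len(pooled) and pooled[j + 1][0] == pooled[i][0]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[k] = avg_rank
        i = j + 1

    n1, n2 = len(a), len(b)
    rank_sum_a = sum(r for r, (_, g) in zip(ranks, pooled) if g == 0)
    u_a = rank_sum_a - n1 * (n1 + 1) / 2
    mean_u = n1 * n2 / 2
    sd_u = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z = (u_a - mean_u) / sd_u
    p_value = math.erfc(abs(z) / math.sqrt(2))
    return u_a, p_value


# Hypothetical per-seed scores for two models over 30 splits.
model_a = [0.90 + 0.001 * i for i in range(30)]
model_b = [0.80 + 0.001 * i for i in range(30)]
u_stat, p_value = mann_whitney_u(model_a, model_b)
```

<p>With 30 seeds per model, the test compares the two sets of 30 scores as unpaired samples; a difference is called significant when the p-value falls below 0.05.</p>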
<h3 id="key-metrics">Key metrics</h3>
<p>Classification:</p>
<p>$$
\text{PPV} = \frac{\text{TP}}{\text{TP} + \text{FP}}
$$</p>
<p>$$
\text{NPV} = \frac{\text{TN}}{\text{TN} + \text{FN}}
$$</p>
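<p>As a minimal sketch, both classification metrics follow directly from confusion-matrix counts. The helper name <code>ppv_npv</code> is hypothetical, and labels and predictions are assumed to be already binarized as 0 or 1.</p>

```python
def ppv_npv(y_true, y_pred):
    """Return (PPV, NPV) for binary labels and binary predictions (0 or 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    # Undefined when nothing is predicted positive (or negative); return NaN.
    ppv = tp / (tp + fp) if (tp + fp) else float("nan")
    npv = tn / (tn + fn) if (tn + fn) else float("nan")
    return ppv, npv


ppv, npv = ppv_npv([1, 1, 0, 0, 1, 0], [1, 0, 0, 0, 1, 1])
```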
<p>Regression:</p>
<p>$$
\text{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}
$$</p>
<p>$$
\text{MAE} = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|
$$</p>
<p>$$
\text{Pearson}_R = \frac{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})(\hat{y}_i - \bar{y}_{pred})}{\sqrt{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})^2 \sum_{i=1}^{N} (\hat{y}_i - \bar{y}_{pred})^2}}
$$</p>
<p>$$
R^2 = 1 - \frac{\sum_{i=1}^{N} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} (y_i - \bar{y}_{obs})^2}
$$</p>
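<p>The four regression formulas above translate directly into code. The sketch below uses only the Python standard library; the function name <code>regression_metrics</code> and the toy values are illustrative assumptions.</p>

```python
import math


def regression_metrics(y_obs, y_pred):
    """Return RMSE, MAE, Pearson R, and R^2 as defined in the formulas above."""
    n = len(y_obs)
    mean_obs = sum(y_obs) / n
    mean_pred = sum(y_pred) / n
    sq_err = sum((y - p) ** 2 for y, p in zip(y_obs, y_pred))
    ss_obs = sum((y - mean_obs) ** 2 for y in y_obs)
    ss_pred = sum((p - mean_pred) ** 2 for p in y_pred)
    cov = sum((y - mean_obs) * (p - mean_pred) for y, p in zip(y_obs, y_pred))
    return {
        "rmse": math.sqrt(sq_err / n),
        "mae": sum(abs(y - p) for y, p in zip(y_obs, y_pred)) / n,
        "pearson_r": cov / math.sqrt(ss_obs * ss_pred),
        "r2": 1.0 - sq_err / ss_obs,
    }


# Predictions that track the trend perfectly but are offset by a constant.
y_obs = [1.0, 2.0, 3.0, 4.0]
y_pred = [3.0, 4.0, 5.0, 6.0]
metrics = regression_metrics(y_obs, y_pred)
```

<p>In this toy case Pearson $R$ is 1.0 while $R^2$ is $-2.2$: Pearson $R$ measures only linear association, whereas $R^2$ also penalizes the constant offset.</p>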
<h2 id="key-findings-metrics-activity-cliffs-and-dataset-size">Key Findings: Metrics, Activity Cliffs, and Dataset Size</h2>
<h3 id="statistical-testing-is-essential">Statistical testing is essential</h3>
<p>Without statistical tests, there is a real risk of drawing incorrect conclusions. Analysis of individual splits shows that in certain splits, MolBERT or GROVER can appear to outperform RF, even though on aggregate with proper statistical testing, RF is significantly better. For example, in BBBP, RF dominates in 20 of 30 splits, but the remaining 10 could mislead a researcher using only a single split.</p>
<h3 id="metric-choice-changes-conclusions">Metric choice changes conclusions</h3>
<p>Different evaluation metrics can lead to contradictory conclusions about the same models:</p>
<ul>
<li>In BBBP under scaffold split, RF significantly outperforms other models by AUROC, but shows similar performance when evaluated by PPV or NPV.</li>
<li>In FreeSolv, GROVER outperforms RF by Pearson $R$ ($p &lt; 0.05$) but shows similar performance by $R^2$.</li>
<li>Pearson $R$ can overestimate $R^2$: even when $R^2$ drops to zero or negative, Pearson $R$ can remain around 0.5.</li>
<li>AUROC can be over-optimistic, especially on imbalanced datasets like CYP2D6 and CYP3A4.</li>
</ul>
<p>The authors argue that PPV and NPV are more practically relevant for <a href="/notes/chemistry/molecular-design/generation/evaluation/molscore-scoring-benchmarking-framework/">virtual screening</a> than AUROC or AUPRC, since the goal is to identify true hits among predicted positives (or true non-binders among predicted negatives).</p>
<h3 id="activity-cliffs-pose-a-major-challenge">Activity cliffs pose a major challenge</h3>
<p>Activity cliffs, defined as <a href="https://en.wikipedia.org/wiki/IC50">IC50</a> values spanning at least two orders of magnitude within one scaffold, are prevalent in the opioid-related datasets. Although AC scaffolds represent only 7-13% of scaffolds, they encompass 25-46% of all molecules:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>AC scaffolds (%)</th>
          <th>AC molecules (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MDR1</td>
          <td>62 (10.2%)</td>
          <td>594 (41.3%)</td>
      </tr>
      <tr>
          <td>CYP2D6</td>
          <td>124 (9.3%)</td>
          <td>710 (31.0%)</td>
      </tr>
      <tr>
          <td>CYP3A4</td>
          <td>146 (7.2%)</td>
          <td>926 (25.2%)</td>
      </tr>
      <tr>
          <td>MOR</td>
          <td>213 (13.1%)</td>
          <td>1627 (46.1%)</td>
      </tr>
      <tr>
          <td>DOR</td>
          <td>178 (11.6%)</td>
          <td>1342 (41.6%)</td>
      </tr>
      <tr>
          <td>KOR</td>
          <td>218 (13.1%)</td>
          <td>1502 (45.2%)</td>
      </tr>
  </tbody>
</table>
<p>Prediction performance is consistently worse for AC molecules, indicating limited intra-scaffold generalization. Removing edge-case molecules (those sharing scaffolds with pIC50 spanning 5 to 7) from test sets generally improves classification performance, confirming that activity cliffs are a key source of prediction error.</p>
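<p>The activity-cliff definition above can be sketched as a simple grouping step. This is an illustration under stated assumptions: scaffold keys are taken as precomputed strings, the function name <code>activity_cliff_scaffolds</code> and the records are hypothetical, and a span of two orders of magnitude in IC50 is expressed as 2.0 units of pIC50.</p>

```python
from collections import defaultdict


def activity_cliff_scaffolds(records, min_span=2.0):
    """Return scaffolds whose pIC50 values span at least `min_span` log units.

    `records` is an iterable of (scaffold_key, pIC50) pairs. A span of 2.0 in
    pIC50 corresponds to IC50 values two orders of magnitude apart.
    """
    by_scaffold = defaultdict(list)
    for scaffold, pic50 in records:
        by_scaffold[scaffold].append(pic50)
    return {
        scaffold
        for scaffold, values in by_scaffold.items()
        if max(values) - min(values) >= min_span
    }


# Hypothetical records, not data from the paper.
records = [
    ("scaffold_A", 5.1), ("scaffold_A", 7.4),  # span 2.3: activity cliff
    ("scaffold_B", 6.0), ("scaffold_B", 6.8),  # span 0.8: no cliff
    ("scaffold_C", 8.2),                       # single molecule: no cliff
]
ac_scaffolds = activity_cliff_scaffolds(records)
ac_molecule_fraction = sum(s in ac_scaffolds for s, _ in records) / len(records)
```

<p>Counting molecules rather than scaffolds shows how a small share of scaffolds can cover a large share of the data: here one of three scaffolds accounts for two of five molecules.</p>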
<h3 id="dataset-size-is-critical-for-representation-learning">Dataset size is critical for representation learning</h3>
<p>Experiments on descriptor datasets (predicting MolWt and NumAtoms) reveal clear patterns:</p>
<ul>
<li>With fewer than 1K data points, traditional ML on fixed representations outperforms all neural network models except pretrained GROVER, which shows competitive performance in the low-data regime.</li>
<li>MolBERT shows severely limited performance (RMSE &gt; 200 for MolWt) with fewer than 10K data points.</li>
<li>RNN achieves the best performance when dataset size exceeds 10K, demonstrating the promise of representation learning in the &ldquo;big-data&rdquo; regime.</li>
<li>SVM achieves near-perfect RMSE (close to zero) on datasets larger than 10K when paired with AtomPairs fingerprints.</li>
<li>GROVER&rsquo;s performance does not substantially improve with increasing dataset size, while MolBERT improves at 100K but is slow to benefit from more data.</li>
</ul>
<h3 id="representation-learning-models-show-higher-metric-variability">Representation learning models show higher metric variability</h3>
<p>Representation learning models, particularly GROVER, exhibit higher variability in performance metrics across splits. This variability correlates negatively with mean performance: models with higher variability tend to perform worse on average. The authors emphasize the importance of reporting metric variability alongside means.</p>
<h3 id="scaffold-split-versus-random-split">Scaffold split versus random split</h3>
<p>Prediction performance under scaffold split is consistently worse than under random split, confirming the inter-scaffold generalization challenge. Notably, random split alleviates the intra-scaffold generalization challenge because some AC scaffolds are seen during training.</p>
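<p>A scaffold split keeps every scaffold within a single subset, which is what forces inter-scaffold generalization. The sketch below is one possible implementation, not the procedure used in the paper: it assumes scaffold keys are precomputed (for example, Bemis-Murcko scaffolds from a cheminformatics toolkit), and the function name <code>scaffold_split</code> and the greedy assignment are illustrative choices.</p>

```python
import random
from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1, seed=0):
    """Split molecule indices so that no scaffold spans two subsets.

    `scaffolds[i]` is a precomputed scaffold key for molecule i. Scaffold
    groups are shuffled with `seed`, then assigned greedily to train,
    validation, and test until each target size is reached.
    """
    groups = defaultdict(list)
    for idx, key in enumerate(scaffolds):
        groups[key].append(idx)
    group_list = list(groups.values())
    random.Random(seed).shuffle(group_list)

    n = len(scaffolds)
    train, valid, test = [], [], []
    for group in group_list:
        if len(train) + len(group) <= frac_train * n:
            train.extend(group)
        elif len(valid) + len(group) <= frac_valid * n:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Hypothetical dataset: 100 molecules, 50 scaffolds with 2 molecules each.
scaffolds = [f"s{i // 2}" for i in range(100)]
train, valid, test = scaffold_split(scaffolds, seed=0)
```

<p>Changing <code>seed</code> changes which scaffolds land in each subset, which is how repeated splits over seeds 0-29 could be generated under this scheme. A random split would instead shuffle individual molecules, so molecules sharing a scaffold can appear in both training and test sets.</p>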
<h3 id="descriptors-correlate-with-specific-properties">Descriptors correlate with specific properties</h3>
<p>PhysChem descriptors excel on datasets where molecular properties correlate with simple descriptors (e.g., MolLogP has near $-1$ correlation with ESOL labels). For binding activity datasets, correlation coefficients mostly fall within $[-0.5, 0.5]$, explaining why PhysChem descriptors show limited performance on those tasks, while structural fingerprints are more useful.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<p>The authors acknowledge several limitations:</p>
<ol>
<li><strong>Uncertainty from model training</strong> (random initialization, mini-batch shuffling) was not fully addressed. Ensembling was not evaluated due to computational cost.</li>
<li><strong>Experimental uncertainty in labels</strong> (noise, measurement error in pIC50 values) was not modeled, though it can be <a href="https://en.wikipedia.org/wiki/Homoscedasticity_and_heteroscedasticity">heteroscedastic</a> and impact performance.</li>
<li><strong>Model explainability</strong> was not covered, although it is important for building trust in AI tools for drug discovery.</li>
<li>The study focused on GROVERbase only (not GROVERlarge) due to computational constraints.</li>
</ol>
<p>Future directions include: exploring better ways to use fixed representations alongside learned ones, developing techniques for chemical space generalization (both inter- and intra-scaffold), incorporating experimental uncertainty into model training and evaluation, and generating larger high-quality datasets to fully harness representation learning models.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Benchmark</td>
          <td>MoleculeNet (BACE, BBBP, HIV, ESOL, FreeSolv, Lipop)</td>
          <td>642-41,127 molecules</td>
          <td>Downloaded from MolMapNet; max length &lt; 400</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>Opioids-related (MDR1, CYP2D6, CYP3A4, MOR, DOR, KOR)</td>
          <td>Varies</td>
          <td>Collected from ChEMBL27; pIC50 values</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>Cortes-Ciriano et al. 24 targets</td>
          <td>Varies</td>
          <td>Activity data for drug targets</td>
      </tr>
      <tr>
          <td>Activity</td>
          <td>MoleculeACE 30 targets</td>
          <td>Varies</td>
          <td>Activity cliffs emphasis</td>
      </tr>
      <tr>
          <td>Descriptor</td>
          <td>MolWt, NumAtoms from <a href="/notes/chemistry/datasets/zinc-22/">ZINC250k</a></td>
          <td>0.1K to 100K</td>
          <td>16 dataset sizes per descriptor</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>RF: 500 trees (following Chemprop)</li>
<li>SVM: linear kernel</li>
<li>XGBoost: gradient boosting regressor/classifier with default hyperparameters</li>
<li>RNN: GRU variant, hidden size 512, 3 fully connected layers</li>
<li>GCN/GIN: embedding dimension 300, 5 convolutional layers, hidden size 512</li>
<li>MolBERT: BERTBase architecture, 768 embedding, 12 layers, 12 heads, ~85M parameters (769 fine-tuned)</li>
<li>GROVER: GROVERbase, ~48M parameters (~5.2M fine-tuned)</li>
<li>All splits repeated 30 times with seeds 0-29</li>
</ul>
<h3 id="models">Models</h3>
<p>All model configurations, splits, and raw predictions are available in the <a href="https://github.com/dengjianyuan/Respite_MPP">GitHub repository</a>.</p>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics: AUROC, AUPRC, PPV, NPV (classification); RMSE, MAE, $R^2$, Pearson $R$ (regression). Statistical testing via Mann-Whitney U test ($p &lt; 0.05$, two-sided). <a href="https://en.wikipedia.org/wiki/Youden%27s_J_statistic">Youden&rsquo;s $J$ statistic</a> used to determine classification threshold for PPV/NPV.</p>
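<p>Threshold selection with Youden&rsquo;s $J$ can be sketched as a scan over candidate thresholds. This is a minimal illustration, not the authors&rsquo; code: the function name <code>youden_threshold</code> is hypothetical, both classes are assumed to be present, and which data subset the threshold is chosen on is not specified here.</p>

```python
def youden_threshold(y_true, scores):
    """Return the score threshold maximizing J = sensitivity + specificity - 1.

    A molecule is predicted positive when its score is at or above the
    threshold. Both classes must be present in `y_true`.
    """
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    best_j, best_t = -1.0, None
    for t in sorted(set(scores)):
        tp = sum(1 for y, s in zip(y_true, scores) if y == 1 and s >= t)
        tn = sum(1 for y, s in zip(y_true, scores) if y == 0 and s < t)
        j = tp / n_pos + tn / n_neg - 1.0
        if j > best_j:  # keep the lowest threshold on ties
            best_j, best_t = j, t
    return best_t, best_j


# Hypothetical labels and predicted scores.
threshold, j_stat = youden_threshold(
    [0, 0, 0, 1, 1, 1], [0.1, 0.2, 0.6, 0.4, 0.7, 0.9]
)
```

<p>Predictions binarized at the selected threshold can then be used to compute PPV and NPV.</p>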
<h3 id="hardware">Hardware</h3>
<p>All neural network experiments ran on a single NVIDIA V100 GPU for 100 epochs. Batch size was 32 for most experiments and 256 for GROVER on HIV due to compute time (MolBERT takes ~3 hours per split on HIV at batch size 32; GROVER takes ~5 hours at batch size 256). The study was partially funded by a Stony Brook University OVPR Seed Grant and used the AI Institute at Stony Brook for computational resources.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/dengjianyuan/Respite_MPP">Respite_MPP</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Code, data, and raw predictions</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.1038/s41467-023-41948-6">Nature Communications article</a></td>
          <td>Paper</td>
          <td>CC-BY-4.0</td>
          <td>Open access</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Deng, J., Yang, Z., Wang, H., Ojima, I., Samaras, D., &amp; Wang, F. (2023). A systematic study of key elements underlying molecular property prediction. <em>Nature Communications</em>, 14, 6395. <a href="https://doi.org/10.1038/s41467-023-41948-6">https://doi.org/10.1038/s41467-023-41948-6</a></p>
<p><strong>Publication</strong>: Nature Communications 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/dengjianyuan/Respite_MPP">Respite_MPP GitHub Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{deng2023systematic,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A systematic study of key elements underlying molecular property prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Deng, Jianyuan and Yang, Zhibo and Wang, Hehe and Ojima, Iwao and Samaras, Dimitris and Wang, Fusheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6395}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-023-41948-6}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ROGI-XD: Roughness of Pretrained Molecular Representations</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/rogi-xd-roughness-pretrained-representations/</link><pubDate>Tue, 24 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/rogi-xd-roughness-pretrained-representations/</guid><description>ROGI-XD enables cross-representation roughness comparison, showing pretrained chemical models produce no smoother QSPR surfaces than fingerprints.</description><content:encoded><![CDATA[<h2 id="evaluating-chemical-foundation-models-through-surface-roughness">Evaluating Chemical Foundation Models Through Surface Roughness</h2>
<p>This is a <strong>Systematization</strong> paper that introduces a metric reformulation (ROGI-XD) and uses it to evaluate whether pretrained chemical models (PCMs) learn representations that produce smoother <a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">quantitative structure-property relationship</a> (QSPR) surfaces than simple baselines. The key finding is negative: pretrained representations are no smoother than molecular fingerprints or descriptors, offering a principled explanation for their inconsistent performance on property prediction benchmarks.</p>
<h2 id="the-smoothness-gap-in-chemical-foundation-models">The Smoothness Gap in Chemical Foundation Models</h2>
<p>Chemical foundation models like ChemBERTa, ChemGPT, and graph-based pretrained networks promise to learn meaningful molecular representations from large unlabeled datasets via self-supervised learning. However, empirical benchmarks consistently show mixed results: these learned representations sometimes match and sometimes underperform simple baselines like Morgan fingerprints or RDKit descriptors.</p>
<p>Prior work by Deng et al. demonstrated that a random forest trained on 2048-bit Morgan fingerprints was competitive with, or superior to, pretrained models like <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a> and GROVER on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> and opioid bioactivity tasks. The authors sought to explain this pattern through the lens of QSPR surface roughness: if pretrained representations do not produce smoother mappings from molecular structure to property, they cannot consistently outperform baselines.</p>
<h2 id="rogi-xd-a-dimensionality-independent-roughness-metric">ROGI-XD: A Dimensionality-Independent Roughness Metric</h2>
<p>The original ROuGhness Index (ROGI) captures global surface roughness by measuring the loss in property dispersion as a dataset is progressively coarse-grained through <a href="https://en.wikipedia.org/wiki/Hierarchical_clustering">hierarchical clustering</a>. However, ROGI values are not comparable across representations of different dimensionalities because distances between randomly sampled points increase with dimension, artificially deflating ROGI for high-dimensional representations.</p>
<p>ROGI-XD addresses this by changing the integration variable. Instead of integrating over normalized distance threshold $t$, ROGI-XD integrates over $1 - \log N_{\text{clusters}} / \log N$, where $N_{\text{clusters}}$ is the number of clusters at a given dendrogram step and $N$ is the dataset size. This variable captures the degree of coarse-graining independent of representation dimensionality, producing comparable roughness values across representations ranging from 14 dimensions (descriptors) to 2048 dimensions (ChemGPT).</p>
<p>The procedure follows five steps: (1) cluster molecules using <a href="https://en.wikipedia.org/wiki/Complete-linkage_clustering">complete linkage</a> at distance threshold $t$, (2) coarse-grain by replacing each property label $y_i$ with its cluster mean $\bar{y}_j$, (3) compute the standard deviation $\sigma_t$ of the coarse-grained dataset, (4) repeat for all dendrogram steps, and (5) compute the area under the curve of $2(\sigma_0 - \sigma_t)$ versus the new integration variable.</p>
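<p>The five steps can be sketched in pure Python for a tiny dataset. This is a minimal illustration under stated assumptions, not the official <code>coleygroup/rogi-xd</code> implementation: it uses a naive complete-linkage merge loop, Euclidean distance, population standard deviation, and trapezoidal integration, and it omits any normalization of distances or property values that the official code may apply. All function names are hypothetical.</p>

```python
import math
from itertools import combinations
from statistics import pstdev


def complete_linkage_partitions(points):
    """Yield partitions (lists of index lists), from N singletons down to one cluster."""
    clusters = [[i] for i in range(len(points))]
    yield [list(c) for c in clusters]

    def linkage(a, b):
        # Complete linkage: cluster distance is the largest pairwise distance.
        return max(math.dist(points[i], points[j]) for i in a for j in b)

    while len(clusters) > 1:
        ia, ib = min(
            combinations(range(len(clusters)), 2),
            key=lambda pair: linkage(clusters[pair[0]], clusters[pair[1]]),
        )
        merged = clusters[ia] + clusters[ib]
        clusters = [c for k, c in enumerate(clusters) if k not in (ia, ib)] + [merged]
        yield [list(c) for c in clusters]


def rogi_xd_sketch(points, y):
    """Area under 2*(sigma_0 - sigma_t) versus 1 - log(N_clusters)/log(N)."""
    n = len(y)
    sigma_0 = pstdev(y)
    xs, heights = [], []
    for partition in complete_linkage_partitions(points):
        # Coarse-grain: replace each label with the mean of its cluster.
        coarse = [0.0] * n
        for cluster in partition:
            cluster_mean = sum(y[i] for i in cluster) / len(cluster)
            for i in cluster:
                coarse[i] = cluster_mean
        xs.append(1 - math.log(len(partition)) / math.log(n))
        heights.append(2 * (sigma_0 - pstdev(coarse)))
    # Trapezoidal rule over the dendrogram steps.
    return sum(
        (xs[k + 1] - xs[k]) * (heights[k] + heights[k + 1]) / 2
        for k in range(len(xs) - 1)
    )


# Four points on a line; labels that vary smoothly vs. labels that alternate.
points = [(0.0,), (1.0,), (2.0,), (3.0,)]
smooth = rogi_xd_sketch(points, [0, 0, 1, 1])
rough = rogi_xd_sketch(points, [0, 1, 0, 1])
print(smooth, rough)
```

<p>In this toy case the alternating labels yield the larger value, because merging neighbouring points destroys property dispersion early in the coarse-graining.</p>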
<h2 id="representations-and-tasks-evaluated">Representations and Tasks Evaluated</h2>
<p>The study compares seven molecular representations:</p>
<table>
  <thead>
      <tr>
          <th>Representation</th>
          <th>Type</th>
          <th>Dimensionality</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Descriptors</td>
          <td>Fixed</td>
          <td>14</td>
          <td>RDKit (14 properties)</td>
      </tr>
      <tr>
          <td>Morgan FP</td>
          <td>Fixed</td>
          <td>512</td>
          <td>Radius 2, 512-bit</td>
      </tr>
      <tr>
          <td>VAE</td>
          <td>Pretrained</td>
          <td>128</td>
          <td>Character-based <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> VAE, <a href="/notes/chemistry/datasets/zinc-22/">ZINC 250k</a></td>
      </tr>
      <tr>
          <td>GIN</td>
          <td>Pretrained</td>
          <td>300</td>
          <td>Node attribute masking, ZINC 250k</td>
      </tr>
      <tr>
          <td>ChemBERTa</td>
          <td>Pretrained</td>
          <td>384</td>
          <td>77M molecules, masked LM</td>
      </tr>
      <tr>
          <td>ChemGPT</td>
          <td>Pretrained</td>
          <td>2048</td>
          <td>PubChem 10M, causal LM</td>
      </tr>
      <tr>
          <td>Random</td>
          <td>Baseline</td>
          <td>128</td>
          <td>Uniform $[0,1]^{128}$</td>
      </tr>
  </tbody>
</table>
<p>These are evaluated on 17 regression tasks drawn from two sources: ADMET datasets from the Therapeutics Data Commons (TDC) and toy datasets generated using <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> oracle functions. Five ML models are used for cross-validation: KNN, MLP, <a href="https://en.wikipedia.org/wiki/Partial_least_squares_regression">PLS</a>, random forest, and SVR.</p>
<h2 id="pretrained-representations-are-not-smoother">Pretrained Representations Are Not Smoother</h2>
<p>ROGI-XD correlates strongly with cross-validated RMSE across representations (median Pearson $r = 0.72$-$0.88$ depending on model), whereas the original ROGI produces weak cross-representation correlations (median $r \in [-0.32, 0.28]$). When correlating over both representations and tasks simultaneously, ROGI-XD achieves $r = 0.91$-$0.99$ versus $r = 0.68$-$0.84$ for the original ROGI.</p>
<p>Using this validated metric, the authors find that pretrained representations do not produce smoother QSPR surfaces than fingerprints or descriptors. In more than 50% of tasks, both descriptors and fingerprints generate smoother surfaces than the pretrained representations. The median relative ROGI-XD increase for pretrained representations is 9.1-21.3% compared to descriptors and 2.3-10.1% compared to fingerprints, indicating rougher surfaces.</p>
<p>As a practical tool, ROGI-XD can guide representation selection without exhaustive benchmarking. Selecting the representation with the lowest ROGI-XD for each task and then optimizing over model architecture results in only a 6.8% average relative increase in best-case model error across the 17 tasks. In 8 of 17 tasks, the lowest ROGI-XD correctly identifies the optimal representation.</p>
<p>Fine-tuning can improve smoothness. On the Lipophilicity task ($N_{\text{tot}} = 4200$), fine-tuning the VAE with a contrastive loss reduces ROGI-XD from 0.254 to 0.107 ($\pm 0.02$), well below the descriptor baseline of 0.227. On the smaller CACO2 task ($N_{\text{tot}} = 910$), fine-tuning yields ROGI-XD of 0.143 ($\pm 0.05$), comparable to descriptors at 0.132. The impact of fine-tuning is sensitive to both the task and the amount of labeled data.</p>
<h2 id="implications-for-chemical-foundation-model-development">Implications for Chemical Foundation Model Development</h2>
<p>The lack of smoothness in pretrained QSPR surfaces explains the inconsistent empirical performance of chemical foundation models. The authors note that ROGI-XD is thematically similar to a contrastive loss, as both scale proportionally with the frequency and severity of activity cliffs. This connection suggests that imposing stronger smoothness assumptions during pretraining, for example through weak supervision on calculable molecular properties, could help produce representations that generalize better to downstream property prediction. ROGI-XD provides a practical tool for evaluating new pretraining strategies without exhaustive benchmark testing: a representation with lower ROGI-XD on a given task is likely to yield lower model error.</p>
<p>A limitation is that the study largely treats pretrained representations as static (frozen features). Fine-tuning introduces many additional design choices and can substantially improve representation quality, but a systematic evaluation is left for future work. Additionally, the survey of pretrained models is not exhaustive and focuses on four representative architectures.</p>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/coleygroup/rogi-xd">coleygroup/rogi-xd</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained models and notebooks; results reproducible via <code>make all</code></td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining (VAE, GIN)</td>
          <td>ZINC 250k</td>
          <td>250,000</td>
          <td>80/20 train/val split</td>
      </tr>
      <tr>
          <td>Pretraining (ChemBERTa)</td>
          <td>PubChem</td>
          <td>77M</td>
          <td>Masked language modeling</td>
      </tr>
      <tr>
          <td>Pretraining (ChemGPT)</td>
          <td>PubChem 10M</td>
          <td>10M</td>
          <td>Causal language modeling</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TDC ADMET</td>
          <td>~900-10,000 per task</td>
          <td>12 regression tasks</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>GuacaMol oracles</td>
          <td>10,000 per task</td>
          <td>5 synthetic tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>ROGI-XD</strong>: Hierarchical clustering (complete linkage) with integration over $1 - \log N_{\text{clusters}} / \log N$</li>
<li><strong>Cross-validation</strong>: 5-fold CV with KNN, MLP, PLS, RF (n_estimators=50), SVR from scikit-learn</li>
<li><strong>Fine-tuning loss</strong>: $\mathscr{L} = \mathscr{L}_{\text{CE}} + \beta \cdot \mathscr{L}_{\text{KL}} + \gamma \cdot \mathscr{L}_{\text{cont}}$ with $\beta = 0.1$, $\gamma = 50$; contrastive term uses cosine distance in latent space and absolute value in target space</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Two AMD Ryzen Threadripper PRO 3995WX CPUs, four NVIDIA A5000 GPUs, 512 GB RAM, Ubuntu 20.04 LTS.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Graff, D. E., Pyzer-Knapp, E. O., Jordan, K. E., Shakhnovich, E. I., &amp; Coley, C. W. (2023). Evaluating the roughness of structure-property relationships using pretrained molecular representations. <em>Digital Discovery</em>, 2(5), 1452-1460. <a href="https://doi.org/10.1039/d3dd00088e">https://doi.org/10.1039/d3dd00088e</a></p>
<p><strong>Publication</strong>: Digital Discovery 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/coleygroup/rogi-xd">ROGI-XD Code Repository</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{graff2023roughness,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Evaluating the roughness of structure--property relationships using pretrained molecular representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Graff, David E. and Pyzer-Knapp, Edward O. and Jordan, Kirk E. and Shakhnovich, Eugene I. and Coley, Connor W.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1452--1460}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/d3dd00088e}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Neural Scaling of Deep Chemical Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/neural-scaling-of-deep-chemical-models/</link><pubDate>Tue, 24 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/neural-scaling-of-deep-chemical-models/</guid><description>Frey et al. discover neural scaling laws for chemical LLMs and GNN interatomic potentials, showing power-law loss improvements with scale.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>discovery paper</strong> that identifies empirical neural scaling laws in two distinct domains of chemical deep learning: large language models (LLMs) for generative chemistry and graph neural networks (GNNs) for machine-learned interatomic potentials. The paper also introduces training performance estimation (TPE) as a practical tool for accelerating hyperparameter optimization in these domains.</p>
<h2 id="why-scaling-laws-matter-for-chemistry">Why scaling laws matter for chemistry</h2>
<p>Neural scaling laws, first characterized for NLP models by Kaplan et al. (2020), describe how model loss decreases as a power law with increasing model size, dataset size, or compute:</p>
<p>$$
L(R) = \alpha R^{-\beta}
$$</p>
<p>where $\alpha$ is a coefficient, $\beta$ is the scaling exponent, and $R$ is the resource being scaled (parameters, data, or compute). These relationships have guided resource allocation decisions in NLP and computer vision, but their applicability to scientific deep learning was unknown.</p>
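<p>Because $\log L = \log \alpha - \beta \log R$, the exponent can be estimated with an ordinary least-squares line in log-log space. The sketch below shows that idea on synthetic points generated from a known power law; <code>fit_power_law</code> is a hypothetical helper, and this is not necessarily how the authors performed their fits.</p>

```python
import math


def fit_power_law(resources, losses):
    """Least-squares fit of log L = log(alpha) - beta * log R; returns (alpha, beta)."""
    xs = [math.log(r) for r in resources]
    ys = [math.log(loss) for loss in losses]
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    slope = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    intercept = mean_y - slope * mean_x
    return math.exp(intercept), -slope


# Synthetic data from L = 5 * R**-0.17, used only to check the fit.
resources = [1e5, 1e6, 1e7, 1e8]
losses = [5.0 * r ** -0.17 for r in resources]
alpha, beta = fit_power_law(resources, losses)
print(alpha, beta)
```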
<p>Chemical deep learning differs from standard NLP and vision tasks in several key ways. Physics-based priors (like symmetry constraints) may reduce the need for massive scale. The heterogeneity of chemical space and molecular tasks makes general pre-training more challenging. There are no established default architectures, datasets, or training recipes at large scale for chemistry.</p>
<p>This paper asks: do the same scaling behaviors hold for chemical models, and how do physical priors affect them?</p>
<h2 id="training-performance-estimation-for-efficient-scaling">Training performance estimation for efficient scaling</h2>
<p>Before running expensive scaling experiments, the authors needed a way to efficiently select hyperparameters. They introduced TPE, a generalization of training speed estimation (TSE) to new domains. TSE computes the cumulative training loss over the first $T$ epochs:</p>
<p>$$
\text{TSE} = \sum_{t=1}^{T} \left( \frac{1}{B} \sum_{i=1}^{B} \mathcal{L}\left(f_{\theta(t,i)}(\mathbf{X}_i), \mathbf{y}_i\right) \right)
$$</p>
<p>where $B$ is the number of training steps per epoch, $\mathcal{L}$ is the loss function, and $f_{\theta(t,i)}$ is the network at epoch $t$ and mini-batch $i$. A linear regression then predicts converged loss from early-training TSE:</p>
<p>$$
L = m \times \text{TSE} + b
$$</p>
<p>Using only 20% of the total training budget, TPE achieves $R^2 = 0.98$ and Spearman&rsquo;s $\rho = 1.0$ for ChemGPT on the MOSES dataset. For GNNs, it achieves $R^2 \geq 0.86$ and $\rho \geq 0.92$ across SchNet, PaiNN, and SpookyNet. This enables discarding suboptimal configurations early, saving up to 90% of compute.</p>
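<p>The two formulas above can be combined in a few lines. The sketch below computes TSE from per-step training losses and fits the linear map from TSE to converged loss. The helper names and all numbers are made up for illustration and do not come from the paper.</p>

```python
def training_speed_estimate(step_losses_per_epoch, num_epochs):
    """Sum, over the first num_epochs epochs, of the mean mini-batch training loss."""
    early = step_losses_per_epoch[:num_epochs]
    return sum(sum(epoch) / len(epoch) for epoch in early)


def fit_line(xs, ys):
    """Ordinary least squares for y = m * x + b; returns (m, b)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    m = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    return m, mean_y - m * mean_x


# Hypothetical early-training TSE values and converged losses for three
# fully trained configurations (illustrative numbers only).
tse_values = [4.5, 3.0, 6.0]
converged = [1.2, 0.9, 1.5]
m, b = fit_line(tse_values, converged)

# Predict the converged loss of a new configuration from two epochs of losses.
new_run = [[4.0, 2.0], [2.0, 1.0], [1.0, 1.0]]
predicted = m * training_speed_estimate(new_run, num_epochs=2) + b
print(predicted)
```

<p>Configurations whose predicted converged loss ranks poorly can then be stopped early, which is where the compute savings come from.</p>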
<h2 id="chemgpt-scaling-chemical-language-models">ChemGPT: scaling chemical language models</h2>
<p>ChemGPT is a GPT-3-style autoregressive transformer for molecular generation. It uses GPT-Neo as its backbone with a SELFIES tokenizer, factorizing the probability of a molecular sequence as:</p>
<p>$$
p(x) = \prod_{i=1}^{n} p\left(s_i \mid s_1, \dots, s_{i-1}\right)
$$</p>
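<p>Taking the negative logarithm of this factorization gives the per-sequence negative log-likelihood, which is the quantity a token-level cross-entropy loss sums over the sequence (a standard identity, written here in the same notation):</p>
<p>$$
-\log p(x) = -\sum_{i=1}^{n} \log p\left(s_i \mid s_1, \dots, s_{i-1}\right)
$$</p>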
<p>The authors trained ChemGPT models ranging from ~78K to over 1 billion non-embedding parameters on subsets of PubChem10M (up to ~10 million molecules, or ~300 million tokens). Key findings from the scaling experiments:</p>
<ul>
<li><strong>Pre-training loss monotonically improves</strong> with increasing dataset size up to nearly 10 million molecules, with no saturation observed.</li>
<li><strong>For a fixed data budget</strong>, increasing model size provides monotonic improvements until models reach ~1 billion parameters.</li>
<li><strong>The scaling exponent</strong> $\beta = 0.17 \pm 0.01$ for the largest dataset (after excluding the three largest models from the power-law fit), and $\beta = 0.30 \pm 0.01$ for the next largest dataset.</li>
<li><strong>Resolution-limited regimes</strong> appear where the power-law behavior breaks down, indicating either insufficient data for a given model size or vice versa. These regimes shift depending on the data budget.</li>
</ul>
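<p>To make the exponents concrete, here is a short worked calculation under the power law $L(R) = \alpha R^{-\beta}$. It assumes the fit holds across the whole interval, which the resolution-limited regimes show is not always the case. Multiplying the scaled resource by 10 multiplies the loss by $10^{-\beta}$:</p>
<p>$$
\frac{L(10R)}{L(R)} = 10^{-\beta}, \qquad 10^{-0.17} \approx 0.68, \qquad 10^{-0.30} \approx 0.50
$$</p>
<p>So, within the power-law regime, a tenfold increase in the scaled resource would reduce the loss by roughly 32% at $\beta = 0.17$ and by roughly 50% at $\beta = 0.30$.</p>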
<p>An interesting observation: for small datasets, large models ($10^7$ parameters and above) still provide notable loss improvements, suggesting that scaling up model size helps even when data is limited.</p>
<h2 id="neural-force-field-scaling-with-gnns">Neural force field scaling with GNNs</h2>
<p>For tasks requiring three-dimensional molecular geometry, the authors studied GNN-based neural force fields (NFFs). These models predict energies $\hat{E} = f_\theta(X)$ and derive forces by differentiation:</p>
<p>$$
\hat{F}_{ij} = -\frac{\partial \hat{E}}{\partial r_{ij}}
$$</p>
<p>Training uses an L1 loss over energies and forces:</p>
<p>$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left[ \alpha_E | E_i - \hat{E}_i | + \alpha_F | \mathbf{F}_i - \hat{\mathbf{F}}_i | \right]
$$</p>
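<p>Both pieces can be illustrated on a toy system. The sketch below obtains a force from an energy function by a central finite difference (real NFFs use automatic differentiation instead) and evaluates the weighted L1 loss. Treating $|\mathbf{F}_i - \hat{\mathbf{F}}_i|$ as a sum of absolute component differences is an assumption, and the function names and the harmonic toy energy are hypothetical.</p>

```python
def force_from_energy(energy_fn, r, h=1e-5):
    """Approximate F = -dE/dr with a central finite difference."""
    return -(energy_fn(r + h) - energy_fn(r - h)) / (2 * h)


def l1_energy_force_loss(e_true, e_pred, f_true, f_pred, alpha_e=1.0, alpha_f=1.0):
    """Mean over samples of alpha_e*|E - E_hat| + alpha_f*|F - F_hat|."""
    total = 0.0
    for et, ep, ft, fp in zip(e_true, e_pred, f_true, f_pred):
        # Assumption: the force error is the sum of absolute component errors.
        force_err = sum(abs(a - b) for a, b in zip(ft, fp))
        total += alpha_e * abs(et - ep) + alpha_f * force_err
    return total / len(e_true)


def harmonic(r):
    """Toy energy E(r) = (r - 1)^2, whose exact force is -2 * (r - 1)."""
    return (r - 1.0) ** 2


print(force_from_energy(harmonic, 1.5))

energies_true, energies_pred = [1.0, 2.0], [1.5, 2.0]
forces_true = [(0.0, 1.0), (1.0, 1.0)]
forces_pred = [(0.0, 0.5), (1.0, 1.0)]
print(l1_energy_force_loss(energies_true, energies_pred, forces_true, forces_pred))
# Setting alpha_e=0.0 gives force-only training.
print(l1_energy_force_loss(energies_true, energies_pred, forces_true, forces_pred, alpha_e=0.0))
```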
<p>Four NFF architectures were studied, spanning a range of physical priors:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Key Characteristic</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SchNet</td>
          <td>E(3) invariant</td>
          <td>Continuous filter convolutions</td>
      </tr>
      <tr>
          <td>PaiNN</td>
          <td>E(3) equivariant</td>
          <td>Equivariant message passing</td>
      </tr>
      <tr>
          <td>Allegro</td>
          <td>E(3) equivariant</td>
          <td>Local, learned many-body functions</td>
      </tr>
      <tr>
          <td>SpookyNet</td>
          <td>E(3) equivariant</td>
          <td>Non-local interactions, empirical corrections</td>
      </tr>
  </tbody>
</table>
<p>Model capacity is parameterized as $c = d \times w$ (depth times width). Models were trained on subsets of the ANI-1x dataset (up to 100,000 geometries, corresponding to ~4.5 million force labels).</p>
<p>Key GNN scaling findings:</p>
<ul>
<li><strong>PaiNN shows monotonic loss improvement</strong> with increasing dataset size and strong correlation between converged loss and model capacity (Spearman&rsquo;s $\rho \geq 0.88$).</li>
<li><strong>Equivariant GNNs (PaiNN, Allegro) show better scaling efficiency</strong> than invariant GNNs (SchNet), with larger $\beta$ values.</li>
<li><strong>The scaling exponent for equivariant GNNs</strong> is $\beta = 0.26$, indicating that physics-based equivariance priors provide greater sample efficiency that persists to much larger and more chemically diverse datasets than previously studied.</li>
<li><strong>At $10^4$ datapoints</strong>, model capacity and converged loss show nearly perfect rank correlation ($\rho \geq 0.93$), suggesting this may be a threshold where models move from memorization to generalization.</li>
</ul>
<h2 id="results-and-practical-implications">Results and practical implications</h2>
<p>The scaling results provide actionable guidance for resource allocation:</p>
<ul>
<li>For <strong>chemical LLMs with large data budgets</strong>, the greatest loss improvements come from scaling up small models (around $10^5$ parameters).</li>
<li>For <strong>small data budgets</strong>, rapid improvements come from scaling medium-sized models ($10^7$ parameters).</li>
<li>For <strong>NFFs</strong>, low-capacity models show diminishing returns with more data, while high-capacity models show rapid improvements with increasing dataset size.</li>
<li><strong>Neither model type has saturated</strong> with respect to model size, dataset size, or compute, suggesting substantial room for improvement with further scaling.</li>
</ul>
<p>The 300-million-parameter ChemGPT trained on 300 million tokens and the PaiNN model with capacity ~1,000 trained on $10^5$ frames achieved the minimum losses in their respective scaling plots, providing concrete targets for practitioners.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Data:</strong></p>
<ul>
<li>PubChem10M (10M SMILES strings, via DeepChem)</li>
<li>MOSES (2M molecules, for TPE validation)</li>
<li>ANI-1x (5M DFT calculations, via Figshare)</li>
<li>Revised MD-17 (10 small organic molecules, 10,000 frames for TPE)</li>
</ul>
<p><strong>Models:</strong></p>
<ul>
<li>ChemGPT: GPT-Neo backbone, 24 layers, widths from 16 to 2,048, sizes from ~78K to ~1.2B non-embedding parameters</li>
<li>SchNet, PaiNN, Allegro, SpookyNet: widths of 16, 64, 256; depths of 2, 3, 4; 5 Angstrom cutoff</li>
</ul>
<p><strong>Training:</strong></p>
<ul>
<li>ChemGPT: AdamW optimizer, learning rate $2 \times 10^{-5}$, batch size 8 per GPU, 10 epochs, cross-entropy loss</li>
<li>GNNs: Adam optimizer, learning rate scheduler (halved after 30 epochs without improvement), early stopping after 50 stagnant epochs, max 1,000 epochs, L1 loss (force-only training)</li>
</ul>
<p><strong>Hardware:</strong></p>
<ul>
<li>NVIDIA Volta V100 GPUs (32 GB), 2 GPUs per node</li>
<li>PyTorch with distributed data parallel (DDP), PyTorch Lightning, LitMatter</li>
</ul>
<p><strong>Code:</strong> <a href="https://github.com/ncfrey/litmatter">LitMatter repository</a></p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation:</strong> Frey, N.C., Soklaski, R., Axelrod, S. et al. Neural scaling of deep chemical models. <em>Nat Mach Intell</em> <strong>5</strong>, 1297-1305 (2023).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{frey2023neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural scaling of deep chemical models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Frey, Nathan C. and Soklaski, Ryan and Axelrod, Simon and Samsi, Siddharth and G{\&#39;o}mez-Bombarelli, Rafael and Coley, Connor W. and Gadepally, Vijay}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1297--1305}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-023-00740-3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Tied Two-Way Transformers for Diverse Retrosynthesis</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/tied-two-way-transformers-retrosynthesis/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/tied-two-way-transformers-retrosynthesis/</guid><description>Tied two-way transformers with cycle consistency and multinomial latent variables improve retrosynthetic prediction validity, plausibility, and diversity.</description><content:encoded><![CDATA[<h2 id="bridging-forward-and-backward-reaction-prediction">Bridging Forward and Backward Reaction Prediction</h2>
<p>This is a <strong>Method</strong> paper that addresses three key limitations of template-free <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthesis</a> models: invalid <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> outputs, chemically implausible predictions, and lack of diversity in reactant candidates. The solution combines three techniques: (1) cycle consistency checks using a paired forward reaction transformer, (2) parameter tying between the forward and backward transformers, and (3) multinomial latent variables with a learned prior to capture multiple reaction pathways.</p>
<h2 id="three-problems-in-template-free-retrosynthesis">Three Problems in Template-Free Retrosynthesis</h2>
<p>Template-free retrosynthesis models cast retrosynthesis as a <a href="/notes/chemistry/molecular-design/reaction-prediction/data-transfer-seq-to-seq-retrosynthesis/">sequence-to-sequence</a> translation problem (product SMILES to reactant SMILES). While these models avoid the cost of hand-coded reaction templates, they suffer from:</p>
<ol>
<li><strong>Invalid SMILES</strong>: predicted reactant strings that contain grammatical errors and cannot be parsed into molecules</li>
<li><strong>Implausibility</strong>: predicted reactants that are valid molecules but cannot actually synthesize the target product</li>
<li><strong>Lack of diversity</strong>: beam search produces duplicate or near-duplicate candidates, reducing the number of useful suggestions</li>
</ol>
<p>Prior work addressed these individually (SCROP adds a syntax corrector for validity, and Chen et al. use latent variables for diversity), but this paper tackles all three simultaneously.</p>
<h2 id="model-architecture">Model Architecture</h2>
<h3 id="tied-two-way-transformers">Tied Two-Way Transformers</h3>
<p>The model pairs a retrosynthesis transformer $p(y|z, x)$ (product to reactants) with a forward reaction transformer $p(\tilde{x}|z, y)$ (reactants to product). Both use the standard encoder-decoder transformer architecture with 6 layers, 8 attention heads, and 256-dimensional embeddings.</p>
<p>The key architectural innovation is aggressive parameter tying: the two transformers share the entire encoder and all decoder parameters except layer normalization. This means the two-transformer system has approximately the same parameter count as a single transformer (17.5M vs. 17.4M). The shared parameters force the model to learn bidirectional reaction patterns from both forward and backward training data simultaneously, improving grammar learning and reducing invalid outputs.</p>
<h3 id="multinomial-latent-variables">Multinomial Latent Variables</h3>
<p>A discrete latent variable $z \in \{1, \ldots, K\}$ is introduced to capture multiple reaction modes. Each latent value conditions a different decoding path, encouraging diverse reactant predictions. The decoder initializes with a latent-class-specific start token (e.g., &ldquo;&lt;CLS2&gt;&rdquo;) and then decodes autoregressively.</p>
<p>The prior $p(z|x)$ is a learned multinomial distribution parametrized by a two-layer feed-forward network with tanh activation, taking the mean-pooled encoder output as input. This learned prior outperforms the uniform prior used by Chen et al., producing a smaller trade-off between top-1 and top-10 accuracy as $K$ increases.</p>
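<p>A minimal pure-Python sketch of such a prior is shown below: mean-pool the encoder states, apply a linear layer with tanh, apply a second linear layer, and normalize with a softmax over the $K$ latent classes. The exact layer layout, the weight shapes, and the function name are assumptions for illustration; the weights here are placeholders, not trained values.</p>

```python
import math


def latent_prior(encoder_states, w1, b1, w2, b2):
    """Return p(z|x) over K latent classes from mean-pooled encoder states."""
    dim = len(encoder_states[0])
    pooled = [
        sum(token[d] for token in encoder_states) / len(encoder_states)
        for d in range(dim)
    ]
    hidden = [
        math.tanh(sum(w * v for w, v in zip(row, pooled)) + b)
        for row, b in zip(w1, b1)
    ]
    logits = [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(w2, b2)]
    peak = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(logit - peak) for logit in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Two tokens with 2-dimensional states, 2 hidden units, K = 3 latent classes.
states = [[0.2, -0.1], [0.4, 0.3]]
w1, b1 = [[0.5, -0.5], [1.0, 0.25]], [0.0, 0.1]
w2, b2 = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]], [0.0, 0.0, 0.0]
print(latent_prior(states, w1, b1, w2, b2))
```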
<h3 id="training-with-hard-em">Training with Hard EM</h3>
<p>Since the latent variable $z$ is unobserved during training, the model is trained with the online <a href="https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm">hard-EM algorithm</a>. The loss function is:</p>
<p>$$\mathcal{L}(\theta) = \mathbb{E}_{(x,y) \sim \text{data}} \left[ \min_{z} \mathcal{L}_h(x, y, z; \theta) \right]$$</p>
<p>where $\mathcal{L}_h = -(\log p(z|x) + \log p(y|z,x) + \log p(\tilde{x}=x|z,y))$. The E-step selects the best $z$ for each training pair (with dropout disabled), and the M-step updates parameters given the complete data.</p>
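<p>For a single training pair, the hard E-step amounts to evaluating $\mathcal{L}_h$ for every latent value and keeping the smallest. The sketch below assumes the three log-probabilities have already been computed by the model for each $z$; the function name and the numbers are hypothetical, and indices are 0-based in code while the text numbers latent values from 1.</p>

```python
import math


def hard_em_assign(log_prior, log_retro, log_forward):
    """Return (best_z, loss): the latent index minimizing L_h and that minimum."""
    per_z = [
        -(lp + lr + lf) for lp, lr, lf in zip(log_prior, log_retro, log_forward)
    ]
    best_z = min(range(len(per_z)), key=per_z.__getitem__)
    return best_z, per_z[best_z]


# Illustrative probabilities for K = 3 latent values.
log_prior = [math.log(p) for p in (0.5, 0.3, 0.2)]    # log p(z|x)
log_retro = [math.log(p) for p in (0.1, 0.4, 0.2)]    # log p(y|z,x)
log_forward = [math.log(p) for p in (0.9, 0.8, 0.5)]  # log p(x|z,y)
best_z, loss = hard_em_assign(log_prior, log_retro, log_forward)
print(best_z, loss)
```

<p>In the M-step, only the loss for the selected latent value would be backpropagated for that pair.</p>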
<h3 id="inference-with-cycle-consistency-reranking">Inference with Cycle Consistency Reranking</h3>
<p>At inference, the model: (1) generates $K$ sets of beam search hypotheses from the retrosynthesis transformer (one per latent value), (2) scores each candidate with the forward reaction transformer for cycle consistency $p(\tilde{x}=x|z,y)$, and (3) reranks candidates by the full likelihood $p(z|x) \cdot p(y|z,x) \cdot p(\tilde{x}=x|z,y)$. This pushes chemically plausible predictions to higher ranks.</p>
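<p>Step (3) can be illustrated with a short sketch that works in log space. The candidate dictionaries and their keys are hypothetical, and the scores are made up; the point is only to show how the forward-model term can reorder candidates.</p>

```python
def full_log_likelihood(candidate):
    """log p(z|x) + log p(y|z,x) + log p(x~=x|z,y) for one candidate."""
    return candidate["log_prior"] + candidate["log_retro"] + candidate["log_forward"]

def rerank(candidates):
    """Sort candidates by full log-likelihood, best first."""
    return sorted(candidates, key=full_log_likelihood, reverse=True)

# Made-up candidates pooled from two latent values.
candidates = [
    {"reactants": "A.B", "z": 0, "log_prior": -0.7, "log_retro": -1.0, "log_forward": -3.0},
    {"reactants": "C.D", "z": 1, "log_prior": -1.2, "log_retro": -1.5, "log_forward": -0.2},
    {"reactants": "E.F", "z": 0, "log_prior": -0.7, "log_retro": -2.0, "log_forward": -0.5},
]
ranked = [c["reactants"] for c in rerank(candidates)]
```

<p>Here <code>A.B</code> has the best retrosynthesis score on its own but falls to last place, because the forward model rarely reproduces the original product from it.</p>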
<h2 id="results-on-uspto-50k">Results on USPTO-50K</h2>
<p>All results are averaged over 5 random seeds with beam size 10.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Top-1 Acc.</th>
          <th>Top-5 Acc.</th>
          <th>Top-10 Acc.</th>
          <th>Top-1 Invalid</th>
          <th>Top-10 Invalid</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Liu-LSTM</td>
          <td>37.4%</td>
          <td>57.0%</td>
          <td>61.7%</td>
          <td>12.2%</td>
          <td>22.0%</td>
      </tr>
      <tr>
          <td>SCROP</td>
          <td>43.7%</td>
          <td>65.2%</td>
          <td>68.7%</td>
          <td>0.7%</td>
          <td>2.3%</td>
      </tr>
      <tr>
          <td>Lin-TF</td>
          <td>42.0%</td>
          <td>71.3%</td>
          <td>77.6%</td>
          <td>2.2%</td>
          <td>7.8%</td>
      </tr>
      <tr>
          <td>Base transformer</td>
          <td>44.3%</td>
          <td>68.4%</td>
          <td>72.7%</td>
          <td>1.7%</td>
          <td>12.1%</td>
      </tr>
      <tr>
          <td>Proposed ($K$=5)</td>
          <td>46.8%</td>
          <td>73.5%</td>
          <td>78.5%</td>
          <td>0.1%</td>
          <td>2.6%</td>
      </tr>
  </tbody>
</table>
<p>The proposed model improves top-1 accuracy by 3.1 percentage points over SCROP, the best previous template-free method in the table (46.8% vs. 43.7%), and reduces the top-1 invalid rate to 0.1%.</p>
<h3 id="ablation-analysis">Ablation Analysis</h3>
<p>The ablation study isolates the contribution of each component:</p>
<ul>
<li><strong>Base+CC</strong> (cycle consistency only): reranks candidates to improve top-1/3/5 accuracy and validity, but top-10 stays the same since the candidate set is unchanged. Parameter count doubles (34.8M).</li>
<li><strong>Base+PT</strong> (parameter tying only): improves accuracy and validity at all top-$k$ levels with negligible parameter increase. Parameter tying during training improves the retrosynthesis transformer itself, even without cycle consistency at inference.</li>
<li><strong>Proposed ($K$=1)</strong>: combines tying with cycle consistency reranking.</li>
<li><strong>Proposed ($K$=5)</strong>: adds latent diversity, further improving top-10 accuracy (+2.2%) and reducing top-10 invalid rate (from 10.2% to 2.6%).</li>
</ul>
<h3 id="diversity-unique-rate">Diversity: Unique Rate</h3>
<p>As $K$ increases from 1 to 5, the unique molecule rate among 10 predictions rises substantially, confirming that latent modeling produces more diverse candidates. The learned prior reduces the top-1/top-10 accuracy trade-off compared to Chen et al.&rsquo;s uniform prior.</p>
<h2 id="results-on-in-house-multi-pathway-dataset">Results on In-House Multi-Pathway Dataset</h2>
<p>The in-house dataset (162K reactions from <a href="https://en.wikipedia.org/wiki/Reaxys">Reaxys</a>) contains multiple ground-truth reactions per product, enabling direct evaluation of pathway diversity through coverage (proportion of ground-truth pathways correctly predicted in the top-10 candidates).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Top-1 Acc.</th>
          <th>Top-10 Acc.</th>
          <th>Unique Rate</th>
          <th>Coverage</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Base</td>
          <td>64.2%</td>
          <td>91.6%</td>
          <td>76.1%</td>
          <td>84.4%</td>
      </tr>
      <tr>
          <td>Proposed</td>
          <td>66.0%</td>
          <td>92.8%</td>
          <td>93.2%</td>
          <td>87.3%</td>
      </tr>
  </tbody>
</table>
<p>The proposed model covers 87.3% of ground-truth reaction pathways on average, compared to 84.4% for the baseline. The unique rate jumps from 76.1% to 93.2%, confirming that the latent variables effectively encourage diverse predictions.</p>
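<p>Both diversity metrics are simple set computations for a single product. The sketch below is one plausible reading of the definitions above; it assumes predictions and ground truths are already canonicalized strings, and the function names are illustrative.</p>

```python
def unique_rate(predictions):
    """Fraction of distinct strings among the predictions for one product."""
    return len(set(predictions)) / len(predictions)

def coverage(predictions, ground_truths):
    """Fraction of ground-truth reactant sets that appear among the predictions."""
    predicted = set(predictions)
    hits = sum(1 for gt in ground_truths if gt in predicted)
    return hits / len(ground_truths)

# Toy product with three known pathways and five predictions.
preds = ["A.B", "A.B", "C.D", "E.F", "C.D"]
truths = ["A.B", "E.F", "G.H"]
u = unique_rate(preds)
c = coverage(preds, truths)
```

<p>Dataset-level figures such as those in the table would then be averages of these per-product values.</p>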
<h2 id="limitations">Limitations</h2>
<p>The model uses SMILES string representation, which linearizes molecules and does not exploit the inherently rich chemical graph structure. Graph-based retrosynthesis models (e.g., GraphRetro at 63.8% top-1) substantially outperform template-free string-based models. The USPTO-50K dataset provides only one ground-truth pathway per product, making diversity evaluation limited on this benchmark. The in-house dataset is not publicly available. The model also does not predict reaction conditions (solvents, catalysts, temperature) or reagents.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/ejklike/tied-twoway-transformer">ejklike/tied-twoway-transformer</a></td>
          <td>Code</td>
          <td>Not specified</td>
          <td>Training and inference code</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: USPTO-50K dataset (public, 50K reactions from USPTO patents). In-house dataset (162K reactions from Reaxys, not publicly available).</p>
<p><strong>Hardware</strong>: 4 NVIDIA Tesla M40 GPUs. Checkpoints saved every 5000 steps, last 5 averaged.</p>
<p><strong>Training</strong>: Adam optimizer ($\beta_1 = 0.9$, $\beta_2 = 0.98$), initial learning rate 2 with 8000 warm-up steps, dropout 0.3, gradient accumulation over 4 batches. Label smoothing set to 0.</p>
<p><strong>Inference</strong>: Beam size 10, generating 10 candidates per product.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kim, E., Lee, D., Kwon, Y., Park, M. S., &amp; Choi, Y.-S. (2021). Valid, Plausible, and Diverse Retrosynthesis Using Tied Two-Way Transformers with Latent Variables. <em>Journal of Chemical Information and Modeling</em>, 61, 123-133.</p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling, 2021</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/ejklike/tied-twoway-transformer">GitHub: ejklike/tied-twoway-transformer</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{kim2021valid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Valid, Plausible, and Diverse Retrosynthesis Using Tied Two-Way Transformers with Latent Variables}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kim, Eunji and Lee, Dongseon and Kwon, Youngchun and Park, Min Sik and Choi, Youn-Suk}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{61}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{123--133}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACS Publications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.0c01074}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolGenSurvey: Systematic Survey of ML for Molecule Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molgensurvey-molecule-design/</link><pubDate>Mon, 23 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/molgensurvey-molecule-design/</guid><description>Survey of ML molecule design methods across 1D string, 2D graph, and 3D geometry representations with deep generative and optimization approaches.</description><content:encoded><![CDATA[<h2 id="a-taxonomy-for-ml-driven-molecule-design">A Taxonomy for ML-Driven Molecule Design</h2>
<p>This is a <strong>Systematization</strong> paper that reviews machine learning approaches for molecule design across all three major molecular representations (1D string, 2D graph, 3D geometry) and both deep generative and combinatorial optimization paradigms. Prior surveys (including <a href="/notes/chemistry/molecular-design/generation/evaluation/inverse-molecular-design-ml-review/">Sánchez-Lengeling &amp; Aspuru-Guzik, 2018</a>, <a href="/notes/chemistry/molecular-design/generation/evaluation/deep-learning-molecular-design-review/">Elton et al., 2019</a>, Xue et al. 2019, Vanhaelen et al. 2020, Alshehri et al. 2020, Jiménez-Luna et al. 2020, and Axelrod et al. 2022) each covered subsets of the literature (e.g., only generative methods, or only specific task types). MolGenSurvey extends these by unifying the field into a single taxonomy based on input type, output type, and generation goal, identifying eight distinct molecule generation tasks. It catalogs over 100 methods across these categories and provides a structured comparison of evaluation metrics, datasets, and experimental setups.</p>
<p>The chemical space of drug-like molecules is estimated to contain $10^{23}$ to $10^{60}$ compounds, making exhaustive enumeration computationally infeasible. Traditional high-throughput screening searches existing databases but is slow and expensive. ML-based generative approaches offer a way to explore this space more selectively, either by learning continuous latent representations (deep generative models) or by directly searching the discrete chemical space (combinatorial optimization methods).</p>
<h2 id="molecular-representations">Molecular Representations</h2>
<p>The survey identifies three mainstream featurization approaches for molecules, each carrying different tradeoffs for generation tasks.</p>
<h3 id="1d-string-descriptions">1D String Descriptions</h3>
<p><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> and <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> are the two dominant string representations. SMILES encodes molecules as character strings following grammar rules for bonds, branches, and ring closures. Its main limitation is that arbitrary strings are often chemically invalid. SELFIES augments the encoding rules for branches and rings to achieve 100% validity by construction.</p>
<p>Other string representations exist (InChI, SMARTS) but are less commonly used for generation. Representation learning over strings has adopted CNNs, RNNs, and Transformers from NLP.</p>
<h3 id="2d-molecular-graphs">2D Molecular Graphs</h3>
<p>Molecules naturally map to graphs where atoms are nodes and bonds are edges. Graph neural networks (GNNs), particularly those following the message-passing neural network (MPNN) framework, have become the standard representation method. An MPNN updates each node&rsquo;s representation by aggregating messages from its immediate neighbors; after $K$ rounds of message passing, each representation reflects the node&rsquo;s $K$-hop neighborhood. Notable architectures include D-MPNN (directional message passing), PNA (diverse aggregation methods), AttentiveFP (attention-based), and Graphormer (transformer-based).</p>
<h3 id="3d-molecular-geometry">3D Molecular Geometry</h3>
<p>Molecules are inherently 3D objects with conformations (3D structures at local energy minima) that determine function. Representing 3D geometry requires models that respect E(3) or SE(3) symmetry: predictions should be invariant (for scalar properties) or equivariant (for vector quantities such as coordinates) under rotation and translation. The survey catalogs architectures along this line including SchNet, DimeNet, EGNN, SphereNet, and PaiNN.</p>
<p>Additional featurization methods (molecular fingerprints/descriptors, 3D density maps, 3D surface meshes, and chemical images) are noted but have seen limited use in generation tasks.</p>
<h2 id="deep-generative-models">Deep Generative Models</h2>
<p>The survey covers six families of deep generative models applied to molecule design.</p>
<h3 id="autoregressive-models-ars">Autoregressive Models (ARs)</h3>
<p>ARs factorize the joint distribution of a molecule as a product of conditional distributions over its subcomponents:</p>
<p>$$p(\boldsymbol{x}) = \prod_{i=1}^{d} p(\bar{x}_i \mid \bar{x}_1, \bar{x}_2, \ldots, \bar{x}_{i-1})$$</p>
<p>For molecular graphs, this means sequentially predicting the next atom or bond conditioned on the partial structure built so far. RNNs, Transformers, and BERT-style models all implement this paradigm.</p>
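<p>The factorization can be made concrete with a toy string model. The bigram table below is entirely hypothetical (a real model would condition on the full prefix); the sketch only shows that the sequence log-probability is the sum of per-step conditional log-probabilities.</p>

```python
import math

def sequence_log_prob(tokens, cond_prob):
    """Chain rule: sum of log p(x_i | x_1, ..., x_{i-1}) over the sequence."""
    total = 0.0
    for i, tok in enumerate(tokens):
        total += math.log(cond_prob(tuple(tokens[:i]), tok))
    return total

def toy_cond_prob(prefix, token):
    """Hypothetical bigram model over SMILES-like tokens; EOS ends the string."""
    table = {
        (None, "C"): 0.8, (None, "O"): 0.2,
        ("C", "C"): 0.5, ("C", "O"): 0.3, ("C", "EOS"): 0.2,
        ("O", "C"): 0.4, ("O", "EOS"): 0.6,
    }
    prev = prefix[-1] if prefix else None
    return table[(prev, token)]

log_p = sequence_log_prob(["C", "C", "O", "EOS"], toy_cond_prob)
prob = math.exp(log_p)  # 0.8 * 0.5 * 0.3 * 0.6
```

<p>Sampling proceeds in the same order: draw one token from the conditional, append it to the prefix, and repeat until the end token is produced.</p>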
<h3 id="variational-autoencoders-vaes">Variational Autoencoders (VAEs)</h3>
<p>VAEs learn a continuous latent space by maximizing the evidence lower bound (ELBO):</p>
<p>$$\log p(\boldsymbol{x}) \geq \mathbb{E}_{q(\boldsymbol{z}|\boldsymbol{x})}[\log p(\boldsymbol{x}|\boldsymbol{z})] - D_{KL}(q(\boldsymbol{z}|\boldsymbol{x}) \,\|\, p(\boldsymbol{z}))$$</p>
<p>The first term is the reconstruction objective, and the second is a KL-divergence regularizer encouraging diverse, disentangled latent codes. Key molecular VAEs include <a href="/notes/chemistry/molecular-design/generation/latent-space/automatic-chemical-design-vae/">ChemVAE</a> (SMILES-based), JT-VAE (junction tree graphs), and <a href="/notes/chemistry/molecular-design/generation/latent-space/grammar-variational-autoencoder/">GrammarVAE</a> (grammar-constrained SMILES).</p>
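<p>When the approximate posterior is a diagonal Gaussian and the prior is a standard normal (a common choice, assumed here rather than stated by the survey), the KL term has a closed form. The sketch below computes it from a mean vector and a log-variance vector; the function names are illustrative.</p>

```python
import math

def gaussian_kl(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))

def elbo(recon_log_likelihood, mu, log_var):
    """Single-sample ELBO estimate: reconstruction term minus KL regularizer."""
    return recon_log_likelihood - gaussian_kl(mu, log_var)

kl_zero = gaussian_kl([0.0, 0.0], [0.0, 0.0])  # posterior equals prior
kl_shift = gaussian_kl([1.0], [0.0])           # unit variance, mean shifted by 1
bound = elbo(-3.0, [1.0], [0.0])
```

<p>The KL term vanishes only when the posterior matches the prior, so the encoder pays a cost for every bit of information it stores about the input molecule.</p>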
<h3 id="normalizing-flows-nfs">Normalizing Flows (NFs)</h3>
<p>NFs model $p(\boldsymbol{x})$ via an invertible, deterministic mapping between data and latent space, using the change-of-variable formula with Jacobian determinants. Molecular applications include GraphNVP, MoFlow (one-shot graph generation), GraphAF (autoregressive flow), and GraphDF (discrete flow).</p>
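<p>For reference, the change-of-variable identity behind these models can be written as follows, where $f$ denotes the invertible map from data to latent space and $p_Z$ the base density (both symbols are notation assumed here, not taken from the survey):</p>
<p>$$\log p(\boldsymbol{x}) = \log p_Z(f(\boldsymbol{x})) + \log \left| \det \frac{\partial f(\boldsymbol{x})}{\partial \boldsymbol{x}} \right|$$</p>
<p>The Jacobian determinant must be cheap to evaluate, which is why flow architectures restrict $f$ to structured transformations such as coupling or autoregressive layers.</p>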
<h3 id="generative-adversarial-networks-gans">Generative Adversarial Networks (GANs)</h3>
<p>GANs use a generator-discriminator game where the generator produces molecules and the discriminator distinguishes real from generated samples. Molecular GANs include MolGAN (graph-based with RL reward), <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a> (SMILES-based with RL), and Mol-CycleGAN (molecule-to-molecule translation).</p>
<h3 id="diffusion-models">Diffusion Models</h3>
<p>Diffusion models learn to reverse a gradual noising process. The forward process adds Gaussian noise over $T$ steps; a neural network learns to denoise at each step. The training objective reduces to predicting the noise added at each step:</p>
<p>$$\mathcal{L}_t = \mathbb{E}_{\boldsymbol{x}_0, \boldsymbol{\epsilon}_t}\left[\left\|\boldsymbol{\epsilon}_t - \boldsymbol{\epsilon}_\theta(\sqrt{\bar{\alpha}_t}\boldsymbol{x}_0 + \sqrt{1 - \bar{\alpha}_t}\boldsymbol{\epsilon}_t, t)\right\|^2\right]$$</p>
<p>Diffusion has been particularly successful for 3D conformation generation (ConfGF, GeoDiff, DGSM).</p>
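<p>The noise-prediction objective can be sketched for a single sample and a single step. This is a toy illustration under stated assumptions: coordinates are a flat list of floats, the cumulative schedule value is passed in directly, and an oracle that inverts the forward step stands in for the neural network.</p>

```python
import math
import random

def forward_noise(x0, alpha_bar_t, rng):
    """Sample x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(alpha_bar_t) * x + math.sqrt(1.0 - alpha_bar_t) * e
          for x, e in zip(x0, eps)]
    return xt, eps

def noise_prediction_loss(eps, eps_pred):
    """Squared L2 distance between the true and the predicted noise."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))

rng = random.Random(0)
x0 = [1.0, -2.0, 0.5]
xt, eps = forward_noise(x0, alpha_bar_t=0.9, rng=rng)

# Oracle "network": recovers the noise exactly because it is given x0.
oracle = [(xi - math.sqrt(0.9) * x) / math.sqrt(0.1) for xi, x in zip(xt, x0)]
loss = noise_prediction_loss(eps, oracle)
```

<p>A trained denoiser only sees the noised sample and the step index, so its loss stays above zero; the oracle merely confirms that the forward step and the loss are consistent with each other.</p>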
<h3 id="energy-based-models-ebms">Energy-Based Models (EBMs)</h3>
<p>EBMs define $p(\boldsymbol{x}) = \frac{\exp(-E_\theta(\boldsymbol{x}))}{A}$ where $E_\theta$ is a learned energy function. The challenge is computing the intractable partition function $A$, addressed via contrastive divergence, noise-contrastive estimation, or score matching.</p>
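<p>On a tiny discrete space the partition function can be computed by brute force, which makes the definition easy to check. The three strings and their energies below are made up; for real chemical space the sum has far too many terms, which is why the approximate training methods above are needed.</p>

```python
import math

def ebm_distribution(energies):
    """Turn a dict of energies into normalized probabilities p(x) = exp(-E(x)) / A."""
    weights = {x: math.exp(-e) for x, e in energies.items()}
    partition = sum(weights.values())  # A, tractable only because the space is tiny
    probs = {x: w / partition for x, w in weights.items()}
    return probs, partition

# Hypothetical energies: lower energy means higher probability.
energies = {"CCO": 0.0, "CCN": 1.0, "C=O": 2.0}
probs, partition = ebm_distribution(energies)
```

<p>Ratios of probabilities depend only on energy differences, so two candidates can be compared without ever computing $A$.</p>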
<h2 id="combinatorial-optimization-methods">Combinatorial Optimization Methods</h2>
<p>Unlike DGMs that learn from data distributions, combinatorial optimization methods (COMs) search directly over discrete chemical space using oracle calls to evaluate candidate molecules.</p>
<h3 id="reinforcement-learning-rl">Reinforcement Learning (RL)</h3>
<p>RL formulates molecule generation as a Markov Decision Process: states are partial molecules, actions are adding/removing atoms or bonds, and rewards come from property oracles. Methods include GCPN (graph convolutional policy network), MolDQN (deep Q-network), RationaleRL (property-aware substructure assembly), and REINVENT (SMILES-based policy gradient).</p>
<h3 id="genetic-algorithms-ga">Genetic Algorithms (GA)</h3>
<p>GAs maintain a population of molecules and evolve them through mutation and crossover operations. GB-GA operates on molecular graphs, GA+D uses SELFIES with adversarial discriminator enhancement, and JANUS uses SELFIES with parallel exploration strategies.</p>
<h3 id="bayesian-optimization-bo">Bayesian Optimization (BO)</h3>
<p>BO builds a Gaussian process surrogate of the objective function and uses an acquisition function to decide which molecules to evaluate next. It is often combined with VAE latent spaces (Constrained-BO-VAE, MSO) to enable continuous optimization.</p>
<h3 id="monte-carlo-tree-search-mcts">Monte Carlo Tree Search (MCTS)</h3>
<p>MCTS explores the molecular construction tree by branching and evaluating promising intermediates. ChemTS and MP-MCTS combine MCTS with autoregressive SMILES generators.</p>
<h3 id="mcmc-sampling">MCMC Sampling</h3>
<p>MCMC methods (MIMOSA, MARS) formulate molecule optimization as sampling from a target distribution defined by multiple property objectives, using graph neural networks as proposal distributions.</p>
<h3 id="other-approaches">Other Approaches</h3>
<p>The survey also identifies two additional paradigms that do not fit neatly into either DGM or COM categories. <strong>Optimal Transport (OT)</strong> is used when matching between groups of molecules, particularly for conformation generation where each molecule has multiple associated 3D structures (e.g., GeoMol, EquiBind). <strong>Differentiable Learning</strong> formulates discrete molecules as differentiable objects, enabling gradient-based continuous optimization directly on molecular graphs (e.g., DST).</p>
<h2 id="task-taxonomy-eight-molecule-generation-tasks">Task Taxonomy: Eight Molecule Generation Tasks</h2>
<p>The survey&rsquo;s central organizational contribution is a unified taxonomy of eight distinct molecule design tasks, defined by three axes: (1) whether generation is <em>de novo</em> (from scratch, no reference molecule) or conditioned on an input molecule, (2) whether the goal is <em>generation</em> (distribution learning, producing valid and diverse molecules) or <em>optimization</em> (goal-directed search for molecules with specific properties), and (3) the input/output data representation (1D string, 2D graph, 3D geometry). The paper&rsquo;s Table 2 maps all combinations of these axes, showing that many are not meaningful (e.g., 1D string input to 2D graph output with no goal). Only eight combinations correspond to active research areas.</p>
<h3 id="1d2d-tasks">1D/2D Tasks</h3>
<ul>
<li><strong>De novo 1D/2D molecule generation</strong>: Generate new molecules from scratch to match a training distribution. Methods span VAEs (ChemVAE, JT-VAE), flows (GraphNVP, MoFlow, GraphAF), GANs (MolGAN, <a href="/notes/chemistry/molecular-design/generation/rl-tuned/organ-objective-reinforced-gan/">ORGAN</a>), ARs (<a href="/notes/chemistry/molecular-design/generation/rl-tuned/molecularrnn-graph-generation-optimized-properties/">MolecularRNN</a>), and EBMs (GraphEBM).</li>
<li><strong>De novo 1D/2D molecule optimization</strong>: Generate molecules with optimal properties from scratch, using oracle feedback. Methods include RL (GCPN, MolDQN), GA (GB-GA, JANUS), MCTS (ChemTS), and MCMC (MIMOSA, MARS).</li>
<li><strong>1D/2D molecule optimization</strong>: Optimize properties of a given input molecule via local search. Methods include graph-to-graph translation (VJTNN, CORE, MOLER), VAE+BO (MSO, Constrained-BO-VAE), GANs (Mol-CycleGAN, <a href="/notes/chemistry/molecular-design/generation/latent-space/latentgan-de-novo-molecular-generation/">LatentGAN</a>), and differentiable approaches (DST).</li>
</ul>
<h3 id="3d-tasks">3D Tasks</h3>
<ul>
<li><strong>De novo 3D molecule generation</strong>: Generate novel 3D molecular structures from scratch, respecting geometric validity. Methods include ARs (G-SchNet, G-SphereNet), VAEs (3DMolNet), flows (E-NFs), and RL (MolGym).</li>
<li><strong>De novo 3D conformation generation</strong>: Generate 3D conformations from given 2D molecular graphs. Methods include VAEs (CVGAE, ConfVAE), diffusion models (ConfGF, GeoDiff, DGSM), and optimal transport (GeoMol).</li>
<li><strong>De novo binding-based 3D molecule generation</strong>: Design 3D molecules for specific protein binding pockets. Methods include density-based VAEs (liGAN), RL (DeepLigBuilder), and ARs (3DSBDD).</li>
<li><strong>De novo binding-pose conformation generation</strong>: Find the appropriate 3D conformation of a given molecule for a given protein pocket. Methods include EBMs (DeepDock) and optimal transport (EquiBind).</li>
<li><strong>3D molecule optimization</strong>: Optimize 3D molecular properties (scaffold replacement, conformation refinement). Methods include BO (BOA), ARs (3D-Scaffold, cG-SchNet), and VAEs (Coarse-GrainingVAE).</li>
</ul>
<h2 id="evaluation-metrics">Evaluation Metrics</h2>
<p>The survey organizes evaluation metrics into four categories.</p>
<h3 id="generation-evaluation">Generation Evaluation</h3>
<p>Basic metrics assess the quality of generated molecules:</p>
<ul>
<li><strong>Validity</strong>: fraction of chemically valid molecules among all generated molecules</li>
<li><strong>Novelty</strong>: fraction of generated molecules absent from the training set</li>
<li><strong>Uniqueness</strong>: fraction of distinct molecules among generated samples</li>
<li><strong>Quality</strong>: fraction passing a predefined chemical rule filter</li>
<li><strong>Diversity</strong> (internal/external): measured via pairwise similarity (Tanimoto, scaffold, or fragment) within generated set and between generated and training sets</li>
</ul>
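<p>The first three metrics reduce to counting. The sketch below follows one common convention, which is an assumption here: uniqueness is computed over valid molecules and novelty over unique valid molecules. The validity check is a deliberately crude stand-in (balanced parentheses only); in practice a cheminformatics toolkit would perform this step.</p>

```python
def generation_metrics(generated, training_set, is_valid):
    """Validity, uniqueness, and novelty for a list of generated strings."""
    valid = [s for s in generated if is_valid(s)]
    unique = set(valid)
    novel = unique - set(training_set)
    return {
        "validity": len(valid) / len(generated),
        "uniqueness": len(unique) / len(valid) if valid else 0.0,
        "novelty": len(novel) / len(unique) if unique else 0.0,
    }

def toy_is_valid(smiles):
    """Stand-in validity check: parentheses must balance. Not a real SMILES parser."""
    depth = 0
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                return False
            depth -= 1
    return depth == 0

generated = ["CCO", "CCO", "CC(C", "CC(=O)O", "CCN"]
training = ["CCO"]
metrics = generation_metrics(generated, training, toy_is_valid)
```

<p>Conventions for the denominators differ between benchmarks, so reported numbers are only comparable when the same convention is used.</p>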
<h3 id="distribution-evaluation">Distribution Evaluation</h3>
<p>Metrics measuring how well generated molecules capture the training distribution: KL divergence over physicochemical descriptors, <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance</a> (FCD), and Maximum Mean Discrepancy (MMD).</p>
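<p>As an illustration of MMD, the sketch below computes the biased squared estimate between two samples of a scalar descriptor. The RBF kernel, its bandwidth, and the toy values are assumptions; benchmarks typically apply graph or fingerprint kernels instead.</p>

```python
import math

def rbf(a, b, gamma=1.0):
    """Gaussian (RBF) kernel on scalars."""
    return math.exp(-gamma * (a - b) ** 2)

def mmd_squared(xs, ys, kernel=rbf):
    """Biased estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
    def mean_k(us, vs):
        return sum(kernel(u, v) for u in us for v in vs) / (len(us) * len(vs))
    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)

# Toy descriptor values (e.g. a property computed on each molecule).
train_values = [0.0, 0.1]
close_values = [0.05, 0.12]
far_values = [5.0, 5.1]
mmd_close = mmd_squared(train_values, close_values)
mmd_far = mmd_squared(train_values, far_values)
```

<p>A generator whose samples match the training distribution drives this quantity toward zero, while a shifted distribution yields a clearly larger value.</p>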
<h3 id="optimization-evaluation">Optimization Evaluation</h3>
<p>Property oracles used as optimization targets: Synthetic Accessibility (SA), Quantitative Estimate of Drug-likeness (QED), LogP, kinase inhibition scores (GSK3-beta, JNK3), DRD2 activity, <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> benchmark oracles, and Vina docking scores. Constrained optimization additionally considers structural similarity to reference molecules via Tanimoto, scaffold, or fragment similarity.</p>
<h3 id="3d-evaluation">3D Evaluation</h3>
<p>3D-specific metrics include stability (matching valence rules in 3D), RMSD and Kabsch-RMSD (conformation alignment), and Coverage/Matching scores for conformation ensembles.</p>
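<p>A minimal RMSD sketch is given below, assuming atoms are already matched one-to-one between the two conformations. It removes translation by centering both structures but deliberately omits the optimal rotation that Kabsch alignment would add, so it is an upper bound on Kabsch-RMSD rather than a replacement for it.</p>

```python
import math

def centered(coords):
    """Translate a list of [x, y, z] points so the centroid sits at the origin."""
    n = len(coords)
    centroid = [sum(p[d] for p in coords) / n for d in range(3)]
    return [[p[d] - centroid[d] for d in range(3)] for p in coords]

def rmsd(coords_a, coords_b):
    """Root-mean-square deviation after centering (no rotational alignment)."""
    a, b = centered(coords_a), centered(coords_b)
    sq = sum((pa[d] - pb[d]) ** 2 for pa, pb in zip(a, b) for d in range(3))
    return math.sqrt(sq / len(a))

ref = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
shifted = [[3.0, 4.0, 5.0], [4.0, 4.0, 5.0]]    # same shape, translated
stretched = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]  # bond length doubled
r_shift = rmsd(ref, shifted)
r_stretch = rmsd(ref, stretched)
```

<p>The translated copy scores zero, while the stretched structure scores 0.5 because each atom sits half a unit from its counterpart after centering.</p>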
<h2 id="datasets">Datasets</h2>
<p>The survey catalogs 12 major datasets spanning 1D/2D and 3D molecule generation:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Scale</th>
          <th>Dimensionality</th>
          <th>Purpose</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ZINC</td>
          <td>250K</td>
          <td>1D/2D</td>
          <td>Virtual screening compounds</td>
      </tr>
      <tr>
          <td>ChEMBL</td>
          <td>2.1M</td>
          <td>1D/2D</td>
          <td>Bioactive molecules</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a></td>
          <td>1.9M</td>
          <td>1D/2D</td>
          <td>Benchmarking generation</td>
      </tr>
      <tr>
          <td>CEPDB</td>
          <td>4.3M</td>
          <td>1D/2D</td>
          <td>Organic photovoltaics</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a></td>
          <td>970M</td>
          <td>1D/2D</td>
          <td>Enumerated small molecules</td>
      </tr>
      <tr>
          <td>QM9</td>
          <td>134K</td>
          <td>1D/2D/3D</td>
          <td>Quantum chemistry properties</td>
      </tr>
      <tr>
          <td><a href="/notes/chemistry/datasets/geom/">GEOM</a></td>
          <td>450K/37M</td>
          <td>1D/2D/3D</td>
          <td>Conformer ensembles</td>
      </tr>
      <tr>
          <td>ISO17</td>
          <td>200/431K</td>
          <td>1D/2D/3D</td>
          <td>Molecule-conformation pairs</td>
      </tr>
      <tr>
          <td>Molecule3D</td>
          <td>3.9M</td>
          <td>1D/2D/3D</td>
          <td>DFT ground-state geometries</td>
      </tr>
      <tr>
          <td>CrossDock2020</td>
          <td>22.5M</td>
          <td>1D/2D/3D</td>
          <td>Docked ligand poses</td>
      </tr>
      <tr>
          <td>scPDB</td>
          <td>16K</td>
          <td>1D/2D/3D</td>
          <td>Binding sites</td>
      </tr>
      <tr>
          <td>DUD-E</td>
          <td>23K</td>
          <td>1D/2D/3D</td>
          <td>Active compounds with decoys</td>
      </tr>
  </tbody>
</table>
<h2 id="challenges-and-opportunities">Challenges and Opportunities</h2>
<h3 id="challenges">Challenges</h3>
<ol>
<li><strong>Out-of-distribution generation</strong>: Most deep generative models imitate known molecule distributions and struggle to explore truly novel chemical space.</li>
<li><strong>Unrealistic problem formulation</strong>: Many task setups do not respect real-world chemistry constraints.</li>
<li><strong>Expensive oracle calls</strong>: Methods typically assume unlimited access to property evaluators, which is unrealistic in drug discovery.</li>
<li><strong>Lack of interpretability</strong>: Few methods explain why generated molecules have desired properties. Quantitative interpretability evaluation remains an open problem.</li>
<li><strong>No unified evaluation protocols</strong>: The field lacks consensus on what defines a &ldquo;good&rdquo; drug candidate and how to fairly compare methods.</li>
<li><strong>Insufficient benchmarking</strong>: Despite the enormous chemical space ($10^{23}$ to $10^{60}$ drug-like molecules), available benchmarks use only small fractions of large databases.</li>
<li><strong>Low-data regime</strong>: Many real-world applications have limited training data, and generating molecules under data scarcity remains difficult.</li>
</ol>
<h3 id="opportunities">Opportunities</h3>
<ol>
<li><strong>Extension to complex structured data</strong>: Techniques from small molecule generation may transfer to proteins, antibodies, genes, crystal structures, and polysaccharides.</li>
<li><strong>Connection to later drug development phases</strong>: Bridging the gap between molecule design and preclinical/clinical trial outcomes could improve real-world impact.</li>
<li><strong>Knowledge discovery</strong>: Generative models over molecular latent spaces could reveal chemical rules governing molecular properties, and graph structure learning could uncover implicit non-bonded interactions.</li>
</ol>
<h2 id="limitations">Limitations</h2>
<ul>
<li>The survey was published in March 2022, so it does not cover subsequent advances in diffusion models for molecules (e.g., EDM, DiffSBDD), large language models applied to chemistry, or flow matching approaches.</li>
<li>Coverage focuses on small molecules. Macromolecule design (proteins, nucleic acids) is noted as a future direction rather than surveyed.</li>
<li>The survey catalogs methods but does not provide head-to-head experimental comparisons across all 100+ methods. Empirical discussion relies on individual papers&rsquo; reported results.</li>
<li>1D string-based methods receive less detailed coverage than graph and geometry-based approaches, reflecting the field&rsquo;s shift toward structured representations at the time of writing.</li>
<li>As a survey, this paper produces no code, models, or datasets. The surveyed methods&rsquo; individual repositories are referenced in their original publications but are not aggregated here.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Du, Y., Fu, T., Sun, J., &amp; Liu, S. (2022). MolGenSurvey: A Systematic Survey in Machine Learning Models for Molecule Design. <em>arXiv preprint arXiv:2203.14500</em>.</p>
<p><strong>Publication</strong>: arXiv preprint, March 2022. <strong>Note</strong>: This survey covers literature through early 2022 and does not include subsequent advances in diffusion models, LLMs for chemistry, or flow matching.</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2203.14500">arXiv: 2203.14500</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{du2022molgensurvey,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolGenSurvey: A Systematic Survey in Machine Learning Models for Molecule Design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Du, Yuanqi and Fu, Tianfan and Sun, Jimeng and Liu, Shengchao}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2203.14500}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>UnCorrupt SMILES: Post Hoc Correction for De Novo Design</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/uncorrupt-smiles/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/evaluation/uncorrupt-smiles/</guid><description>A transformer-based SMILES corrector that fixes invalid outputs from molecular generators, recovering 60-95% of erroneous SMILES strings.</description><content:encoded><![CDATA[<h2 id="a-transformer-based-smiles-error-corrector">A Transformer-Based SMILES Error Corrector</h2>
<p>This is a <strong>Method</strong> paper that proposes a post hoc approach to fixing invalid SMILES produced by de novo molecular generators. Rather than trying to prevent invalid outputs through alternative representations (<a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>) or constrained architectures (graph models), the authors train a transformer model to translate invalid SMILES into valid ones. The corrector is framed as a sequence-to-sequence translation task, drawing on techniques from grammatical error correction (GEC) in natural language processing.</p>
<h2 id="the-problem-of-invalid-smiles-in-molecular-generation">The Problem of Invalid SMILES in Molecular Generation</h2>
<p><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>-based generative models produce some percentage of invalid outputs that cannot be converted to molecules. The invalidity rate varies substantially across model types:</p>
<ul>
<li><strong>RNN models</strong> (DrugEx): 5.7% invalid (pretrained) and 4.7% invalid (target-directed)</li>
<li><strong>GANs</strong> (ORGANIC): 9.5% invalid</li>
<li><strong>VAEs</strong> (GENTRL): 88.9% invalid</li>
</ul>
<p>These invalid outputs represent wasted computation and potentially introduce bias toward molecules that are easier to generate correctly. Previous approaches to this problem include using alternative representations (<a href="/notes/chemistry/molecular-representations/notations/deepsmiles-adaptation-for-ml/">DeepSMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>) or graph-based models, but these either limit the search space or increase computational cost. The authors propose a complementary strategy: fix the errors after generation.</p>
<h2 id="error-taxonomy-across-generator-types">Error Taxonomy Across Generator Types</h2>
<p>The paper classifies invalid SMILES errors into six categories based on RDKit error messages:</p>
<ol>
<li><strong>Syntax errors</strong>: malformed SMILES grammar</li>
<li><strong>Unclosed rings</strong>: unmatched ring closure digits</li>
<li><strong>Parentheses errors</strong>: unbalanced open/close parentheses</li>
<li><strong>Bond already exists</strong>: duplicate bonds between the same atoms</li>
<li><strong>Aromaticity errors</strong>: atoms incorrectly marked as aromatic or kekulization failures</li>
<li><strong>Valence errors</strong>: atoms exceeding their maximum bond count</li>
</ol>
<p>The distribution of error types differs across generators. RNN-based models primarily produce aromaticity errors, suggesting they learn SMILES grammar well but struggle with chemical validity. The GAN (ORGANIC) produces mostly valence errors. The VAE (GENTRL) produces more grammar-level errors (syntax, parentheses, unclosed rings), indicating that sampling from the continuous latent space often produces sequences that violate basic SMILES structure.</p>
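<p>Two of the grammar-level categories above can be checked without a chemistry toolkit. The sketch below is illustrative only: the paper derives its categories from RDKit error messages, whereas the hypothetical helpers <code>count_paren_errors</code> and <code>unclosed_ring_digits</code> are plain string scans that ignore two-digit <code>%nn</code> ring labels.</p>

```python
def count_paren_errors(smiles: str) -> int:
    """Return the number of unbalanced parentheses in a SMILES string."""
    depth = 0
    unmatched_close = 0
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                unmatched_close += 1  # a ")" with no open branch
            else:
                depth -= 1
    # Branches still open at the end are errors too.
    return depth + unmatched_close


def unclosed_ring_digits(smiles: str) -> set:
    """Return single-digit ring labels that are opened but never closed."""
    open_labels = set()
    in_bracket = False
    for ch in smiles:
        if ch == "[":
            in_bracket = True
        elif ch == "]":
            in_bracket = False
        elif ch.isdigit() and not in_bracket:
            # Outside bracket atoms a digit toggles a ring bond:
            # the first occurrence opens it, the second closes it.
            open_labels ^= {ch}
    return open_labels
```

<p>For example, <code>count_paren_errors("CC(C")</code> returns 1 and <code>unclosed_ring_digits("c1ccccc")</code> returns <code>{"1"}</code>. Aromaticity and valence errors need a real parser such as RDKit and are out of scope for this sketch.</p>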
<h2 id="architecture-and-training">Architecture and Training</h2>
<p>The SMILES corrector uses a standard encoder-decoder transformer architecture based on Vaswani et al., with learned positional encodings. Key specifications:</p>
<ul>
<li>Embedding dimension: 256</li>
<li>Encoder/decoder layers: 3 each</li>
<li>Attention heads: 8</li>
<li>Feed-forward dimension: 512</li>
<li>Dropout: 0.1</li>
<li>Optimizer: Adam (learning rate 0.0005)</li>
<li>Training: 20 epochs, batch size 16</li>
</ul>
<p>Since no dataset of manually corrected invalid-valid SMILES pairs exists, the authors create synthetic training data by introducing errors into valid SMILES from the Papyrus bioactivity dataset (approximately 1.3M pairs). Errors are introduced through random perturbations following SMILES syntax rules: character substitutions, bond order changes, fragment additions from the <a href="/notes/chemistry/datasets/gdb-11/">GDB</a>-8 database to atoms with full valence, and other structural modifications.</p>
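<p>The idea of building invalid-valid training pairs can be sketched with character-level edits. This is a simplified stand-in, not the authors&rsquo; procedure: the function names, the edit alphabet, and the three edit operations are assumptions, and the sketch neither follows SMILES syntax rules nor verifies that the corrupted string is actually invalid (which would require a parser such as RDKit).</p>

```python
import random


def corrupt_smiles(smiles: str, n_errors: int, rng: random.Random) -> str:
    """Apply n_errors random character-level edits to a SMILES string."""
    alphabet = "CNOcno()=#12"  # hypothetical edit alphabet, chosen for illustration
    chars = list(smiles)
    for _ in range(n_errors):
        op = rng.choice(["substitute", "delete", "insert"])
        if op == "insert" or not chars:
            # Inserting is the only edit possible on an empty string.
            chars.insert(rng.randrange(len(chars) + 1), rng.choice(alphabet))
        elif op == "substitute":
            chars[rng.randrange(len(chars))] = rng.choice(alphabet)
        else:
            del chars[rng.randrange(len(chars))]
    return "".join(chars)


def make_pairs(valid_smiles, n_errors, seed=0):
    """Return (corrupted, original) pairs for sequence-to-sequence training."""
    rng = random.Random(seed)
    return [(corrupt_smiles(s, n_errors, rng), s) for s in valid_smiles]
```

<p>Calling <code>make_pairs(["CCO", "c1ccccc1"], n_errors=12)</code> would mimic the multi-error setting in spirit; the corrupted string is the model input and the original is the target.</p>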
<h2 id="training-with-multiple-errors-improves-correction">Training with Multiple Errors Improves Correction</h2>
<p>A key finding is that training the corrector on inputs with multiple errors per SMILES substantially improves performance on real generator outputs. The baseline model (1 error per input) fixes 35-80% of invalid outputs depending on the generator. Increasing errors per training input to 12 raises this to 62-95%:</p>
<table>
  <thead>
      <tr>
          <th>Generator</th>
          <th>1 error/input</th>
          <th>12 errors/input</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN (DrugEx)</td>
          <td>~60% fixed</td>
          <td>62% fixed</td>
      </tr>
      <tr>
          <td>Target-directed RNN</td>
          <td>~60% fixed</td>
          <td>68% fixed</td>
      </tr>
      <tr>
          <td>GAN (ORGANIC)</td>
          <td>~80% fixed</td>
          <td>95% fixed</td>
      </tr>
      <tr>
          <td>VAE (GENTRL)</td>
          <td>~35% fixed</td>
          <td>80% fixed</td>
      </tr>
  </tbody>
</table>
<p>Training beyond 12 errors per input yields diminishing returns (80% average at 20 errors vs. 78% at 12). The improvement from multi-error training is consistent with GEC literature, where models learn to &ldquo;distrust&rdquo; inputs more when exposed to higher error rates.</p>
<p>The model also shows low overcorrection: only 14% of valid SMILES are altered during translation, comparable to overcorrection rates in spelling correction systems.</p>
<h2 id="fixed-molecules-are-comparable-to-generator-outputs">Fixed Molecules Are Comparable to Generator Outputs</h2>
<p>The corrected molecules are evaluated against both the training set and the valid molecules that each generator produces without correction:</p>
<ul>
<li><strong>Uniqueness</strong>: 97% of corrected molecules are unique</li>
<li><strong>Novelty vs. generated</strong>: 97% of corrected molecules are novel compared to the valid generator outputs</li>
<li><strong>Similarity to nearest neighbor (SNN)</strong>: 0.45 between fixed and generated sets, indicating the corrected molecules explore different parts of chemical space</li>
<li><strong>Property distributions</strong>: KL divergence scores between fixed molecules and the training set are comparable to those between generated molecules and the training set</li>
</ul>
<p>This demonstrates that SMILES correction produces molecules that are as chemically reasonable as the generator&rsquo;s valid outputs while exploring complementary regions of chemical space.</p>
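<p>The uniqueness and novelty figures above are set-based ratios. A minimal sketch follows, assuming the molecules are already canonical SMILES strings and using one common convention (novelty measured over distinct molecules); the paper&rsquo;s exact definitions may differ in detail.</p>

```python
def uniqueness(molecules):
    """Fraction of distinct entries among the given canonical SMILES."""
    return len(set(molecules)) / len(molecules) if molecules else 0.0


def novelty(molecules, reference):
    """Fraction of distinct molecules that do not appear in the reference set."""
    distinct = set(molecules)
    if not distinct:
        return 0.0
    return len(distinct - set(reference)) / len(distinct)
```

<p>Passing the generator&rsquo;s valid outputs as <code>reference</code> gives novelty relative to the generated set; passing the training set gives the usual training-set novelty.</p>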
<h2 id="local-chemical-space-exploration-via-error-introduction">Local Chemical Space Exploration via Error Introduction</h2>
<p>Beyond fixing generator errors, the authors propose using the SMILES corrector for analog generation. The workflow is:</p>
<ol>
<li>Take a known active molecule</li>
<li>Introduce random errors into its SMILES (repeated 1000 times)</li>
<li>Correct the errors using the trained corrector</li>
</ol>
<p>This &ldquo;local sequence exploration&rdquo; generates novel analogs with 97% validity. The uniqueness (39%) and novelty (16-37%) are lower than for generator correction because the corrector often regenerates the original molecule. However, the approach produces molecules that are structurally similar to the starting compound (SNN of 0.85 to known ligands).</p>
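<p>The three-step workflow can be sketched as a loop over corrupt-then-correct calls. Everything here is hypothetical scaffolding: <code>explore_locally</code> takes the corruption routine, the corrector, and the validity check as arguments, and the toy stand-ins exist only so the example runs without a trained model or RDKit.</p>

```python
import random


def explore_locally(seed_smiles, corrupt, correct, is_valid, n_samples=1000, seed=0):
    """Corrupt a seed SMILES repeatedly, correct each copy, and summarize the results."""
    rng = random.Random(seed)
    valid = []
    for _ in range(n_samples):
        candidate = correct(corrupt(seed_smiles, rng))
        if is_valid(candidate):
            valid.append(candidate)
    distinct = set(valid)
    return {
        "validity": len(valid) / n_samples,
        "uniqueness": len(distinct) / len(valid) if valid else 0.0,
        # Candidates identical to the seed are regenerated inputs, not new analogs.
        "new_analogs": sorted(distinct - {seed_smiles}),
    }


# Toy stand-ins (assumptions for illustration, not the trained corrector).
def toy_corrupt(smiles, rng):
    return smiles + rng.choice(["C", "N", "("])


def toy_correct(smiles):
    return smiles.rstrip("(")


def toy_is_valid(smiles):
    return smiles.count("(") == smiles.count(")")


summary = explore_locally("CCO", toy_corrupt, toy_correct, toy_is_valid, n_samples=50)
print(summary)
```

<p>Whenever the corrector simply restores the seed, the candidate adds nothing to <code>new_analogs</code>, which mirrors why uniqueness and novelty are lower in this setting.</p>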
<p>The authors demonstrate this on selective <a href="https://en.wikipedia.org/wiki/Aurora_kinase_B">Aurora kinase B</a> (AURKB) inhibitors. The generated analogs occupy the same binding site region as the co-crystallized ligand VX-680 in docking studies, with predicted bioactivities similar to known compounds. Compared to target-directed RNN generation, SMILES exploration produces molecules closer to known actives (higher SNN, scaffold similarity, and KL divergence scores).</p>
<h2 id="limitations">Limitations</h2>
<p>The corrector&rsquo;s performance drops on real generator outputs relative to synthetic test data, because the synthetic error distribution does not perfectly match the errors that generators actually produce. Generator-specific correctors trained on actual invalid outputs could improve performance. The local exploration approach has limited novelty since the corrector frequently regenerates the original molecule. The Aurora kinase case study relies on predicted rather than experimental bioactivities.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/LindeSchoenmaker/SMILES-corrector">LindeSchoenmaker/SMILES-corrector</a></td>
          <td>Code + Data</td>
          <td>MIT</td>
          <td>Training code, synthetic error generation, and evaluation scripts</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: Synthetic training pairs derived from the Papyrus bioactivity dataset (v5.5). Approximately 1.3M invalid-valid pairs per error-count setting.</p>
<p><strong>Code</strong>: Transformer implemented in PyTorch, adapted from Ben Trevett&rsquo;s seq2seq tutorial. Generative model baselines use DrugEx, GENTRL, and ORGANIC.</p>
<p><strong>Evaluation</strong>: Validity assessed with RDKit. Similarity metrics (SNN, fragment, scaffold) and KL divergence computed following <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> and <a href="/notes/chemistry/molecular-design/generation/evaluation/guacamol-benchmarking-de-novo-molecular-design/">GuacaMol</a> benchmark protocols.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Schoenmaker, L., Béquignon, O. J. M., Jespers, W., &amp; van Westen, G. J. P. (2023). UnCorrupt SMILES: a novel approach to de novo design. <em>Journal of Cheminformatics</em>, 15, 22.</p>
<p><strong>Publication</strong>: Journal of Cheminformatics, 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/LindeSchoenmaker/SMILES-corrector">GitHub: LindeSchoenmaker/SMILES-corrector</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{schoenmaker2023uncorrupt,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{UnCorrupt SMILES: a novel approach to de novo design}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Schoenmaker, Linde and B{\&#39;e}quignon, Olivier J. M. and Jespers, Willem and van Westen, Gerard J. P.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1186/s13321-023-00696-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>LIMO: Latent Inceptionism for Targeted Molecule Generation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/</guid><description>LIMO uses gradient-based optimization through a VAE latent space and stacked property predictor to generate drug-like molecules with high binding affinity.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Eckmann, P., Sun, K., Zhao, B., Feng, M., Gilson, M. K., &amp; Yu, R. (2022). LIMO: Latent Inceptionism for Targeted Molecule Generation. <em>Proceedings of the 39th International Conference on Machine Learning (ICML 2022)</em>, PMLR 162, 5777&ndash;5792.</p>
<p><strong>Publication</strong>: ICML 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Rose-STL-Lab/LIMO">GitHub: Rose-STL-Lab/LIMO</a></li>
<li><a href="https://arxiv.org/abs/2206.09010">arXiv: 2206.09010</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{eckmann2022limo,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{LIMO: Latent Inceptionism for Targeted Molecule Generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Eckmann, Peter and Sun, Kunyang and Zhao, Bo and Feng, Mudong and Gilson, Michael K and Yu, Rose}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5777--5792}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">organization</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="gradient-based-reverse-optimization-in-molecular-latent-space">Gradient-Based Reverse Optimization in Molecular Latent Space</h2>
<p>This is a <strong>Method</strong> paper that introduces LIMO, a framework for generating molecules with desired properties using gradient-based optimization on a VAE latent space. The key innovation is a stacked architecture where a property predictor operates on the decoded molecular representation rather than directly on the latent space, combined with an inceptionism-like technique that backpropagates through the frozen decoder and predictor to optimize the latent code. This approach is 6-8x faster than RL baselines and 12x faster than sampling-based approaches while producing molecules with higher binding affinities.</p>
<h2 id="slow-property-optimization-in-existing-methods">Slow Property Optimization in Existing Methods</h2>
<p>Generating molecules with high binding affinity to target proteins is a central goal of early drug discovery, but existing computational approaches are slow when optimizing for properties that are expensive to evaluate (such as docking-based binding affinity). RL-based methods require many calls to the property function during training. Sampling-based approaches like MARS need hundreds of iterations. Latent optimization methods that predict properties directly from the latent space suffer from poor prediction accuracy because the mapping from latent space to molecular properties is difficult to learn.</p>
<h2 id="the-limo-framework">The LIMO Framework</h2>
<p>LIMO consists of three components: a VAE for learning a molecular latent space, a property predictor with a novel stacked architecture, and a gradient-based reverse optimization procedure.</p>
<h3 id="selfies-based-vae">SELFIES-Based VAE</h3>
<p>The VAE encodes molecules represented as SELFIES strings into a latent vector $\mathbf{z} \in \mathbb{R}^m$ with $m = 1024$ and decodes to a probability distribution over the $d$ SELFIES symbols at each string position. Since all SELFIES strings correspond to valid molecules, this guarantees 100% chemical validity. Writing $y_{i,j}$ for the decoded probability of symbol $s_j$ at position $i$, the output molecule is obtained by taking the argmax at each position:</p>
<p>$$\hat{x}_i = s_{d_i^*}, \quad d_i^* = \operatorname{argmax}_{j \in \{1, \ldots, d\}} y_{i,j}$$</p>
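<p>As a small illustration of the argmax step, the sketch below decodes a toy probability matrix (one row per string position) into symbols. The vocabulary and probabilities are made up for the example and are not LIMO&rsquo;s actual alphabet.</p>

```python
def decode_argmax(probs, symbols):
    """Pick the most probable symbol at each sequence position."""
    decoded = []
    for row in probs:
        best = max(range(len(row)), key=row.__getitem__)
        decoded.append(symbols[best])
    return decoded


# Toy vocabulary and decoder output (assumed values for illustration).
symbols = ["[C]", "[O]", "[N]", "[nop]"]
probs = [
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.6, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
print(decode_argmax(probs, symbols))  # prints ['[C]', '[O]', '[nop]']
```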
<p>The VAE uses fully-connected layers (not recurrent), with a 64-dimensional embedding layer, four batch-normalized linear layers (2000-dimensional first layer, 1000-dimensional for the rest) with ReLU activation, and is trained with ELBO loss (0.9 weight on reconstruction, 0.1 on KL divergence).</p>
<h3 id="stacked-property-predictor">Stacked Property Predictor</h3>
<p>The critical architectural choice: the property predictor $g_\theta$ takes the decoded molecular representation $\hat{\mathbf{x}}$ as input rather than the latent code $\mathbf{z}$. The predictor is trained after the VAE is frozen by minimizing MSE on VAE-generated molecules:</p>
<p>$$\ell_0(\theta) = \left\| g_\theta\left(f_{\text{dec}}(\mathbf{z})\right) - \pi\left(f_{\text{dec}}(\mathbf{z})\right) \right\|^2$$</p>
<p>where $\pi$ is the ground-truth property function. This stacking improves prediction accuracy from $r^2 = 0.04$ (predicting from $\mathbf{z}$) to $r^2 = 0.38$ (predicting from $\hat{\mathbf{x}}$) on an unseen test set. The improvement comes because the mapping from molecular space to property is easier to learn than the mapping from latent space to property.</p>
<h3 id="reverse-optimization-inceptionism">Reverse Optimization (Inceptionism)</h3>
<p>After training, the decoder and predictor weights are frozen and $\mathbf{z}$ becomes the trainable parameter. For multiple properties with weights $(w_1, \ldots, w_k)$, the optimization minimizes:</p>
<p>$$\ell_1(\mathbf{z}) = -\sum_{i=1}^{k} w_i \cdot g^i\left(f_{\text{dec}}(\mathbf{z})\right)$$</p>
<p>Since both the decoder and predictor are neural networks, gradients flow through the entire chain, enabling efficient optimization with Adam. This is analogous to the &ldquo;inceptionism&rdquo; (DeepDream) technique from computer vision, where network inputs are optimized to maximize specific outputs.</p>
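<p>The mechanics of optimizing $\mathbf{z}$ through frozen networks can be shown on a toy problem. In this sketch the &ldquo;decoder&rdquo; is $\tanh(W\mathbf{z})$ and the &ldquo;predictor&rdquo; is a linear score, both with made-up weights, the gradient is derived by hand, and plain gradient descent replaces Adam to keep the example dependency-free. It is not LIMO&rsquo;s architecture, only the same pattern of updating the input while the weights stay fixed.</p>

```python
import math

W = [[0.5, -0.2], [0.1, 0.8], [-0.3, 0.4]]  # frozen toy "decoder" weights (3x2)
v = [1.0, -0.5, 2.0]                          # frozen toy "predictor" weights


def decode(z):
    """Toy stand-in for f_dec: x = tanh(W z)."""
    return [math.tanh(sum(w * zi for w, zi in zip(row, z))) for row in W]


def predict(x):
    """Toy stand-in for g: a linear property score."""
    return sum(vi * xi for vi, xi in zip(v, x))


def loss_and_grad(z):
    """Return l1(z) = -g(f_dec(z)) and its gradient with respect to z."""
    x = decode(z)
    loss = -predict(x)
    # Chain rule: d(-v . tanh(W z))/dz = -W^T (v * (1 - tanh^2(W z)))
    upstream = [-vi * (1.0 - xi * xi) for vi, xi in zip(v, x)]
    grad = [sum(W[i][j] * upstream[i] for i in range(len(W))) for j in range(len(z))]
    return loss, grad


def optimize(z, lr=0.1, steps=50):
    """Gradient descent on z only; W and v are never updated."""
    for _ in range(steps):
        _, grad = loss_and_grad(z)
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return z
```

<p>Starting from <code>z = [0.0, 0.0]</code>, <code>optimize</code> lowers the loss, that is, raises the predicted property. In practice an autodiff framework would compute the gradient through the real decoder and predictor.</p>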
<h3 id="substructure-constrained-optimization">Substructure-Constrained Optimization</h3>
<p>For lead optimization, LIMO can fix a molecular substructure during optimization by adding a regularization term:</p>
<p>$$\ell_2(\mathbf{z}) = \lambda \sum_{i=1}^{n} \sum_{j=1}^{d} \left(M_{i,j} \cdot \left(f_{\text{dec}}(\mathbf{z})_{i,j} - (\hat{\mathbf{x}}_{\text{start}})_{i,j}\right)\right)^2$$</p>
<p>where $M$ is a binary mask specifying which SELFIES positions must remain unchanged and $\lambda = 1000$. This capability is enabled by the intermediate decoded representation, which most VAE-based methods lack.</p>
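<p>The penalty $\ell_2$ is a masked squared difference between the decoded probabilities and the starting molecule&rsquo;s representation. A minimal sketch with nested lists follows; the function name and the toy inputs are assumptions, and a real implementation would operate on tensors.</p>

```python
def substructure_penalty(decoded, start, mask, lam=1000.0):
    """Compute lam * sum_ij (M_ij * (decoded_ij - start_ij))^2 over nested lists."""
    total = 0.0
    for dec_row, start_row, mask_row in zip(decoded, start, mask):
        for p, p0, m in zip(dec_row, start_row, mask_row):
            total += (m * (p - p0)) ** 2
    return lam * total


# Toy example: two positions, two symbols; only position 0 is constrained.
decoded = [[0.9, 0.1], [0.4, 0.6]]
start = [[1.0, 0.0], [0.0, 1.0]]
mask = [[1, 1], [0, 0]]
penalty = substructure_penalty(decoded, start, mask)
```

<p>Here only the first position contributes, giving a penalty of roughly 20; deviations at the unmasked second position are free. Setting the mask to all zeros removes the constraint entirely.</p>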
<h2 id="experiments-and-results">Experiments and Results</h2>
<h3 id="benchmark-tasks-qed-and-penalized-logp">Benchmark Tasks (QED and Penalized LogP)</h3>
<p>LIMO achieves results competitive with deep generative and RL-based models in 1 hour, compared to 8-24 hours for baselines. Its top QED score is 0.947 (maximum possible: 0.948). Its top penalized LogP is 10.5 (among length-limited models, comparable to MolDQN&rsquo;s 11.8).</p>
<p>The ablation study (&ldquo;LIMO on z&rdquo;) confirms the stacked predictor architecture: predicting from $\hat{\mathbf{x}}$ yields top p-logP of 10.5 versus 6.52 when predicting directly from $\mathbf{z}$.</p>
<h3 id="binding-affinity-maximization">Binding Affinity Maximization</h3>
<p>Binding affinity maximization is the paper&rsquo;s primary contribution. LIMO generates molecules with substantially higher computed binding affinities (lower $K_D$) than baselines against two protein targets:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>ESR1 best $K_D$ (nM)</th>
          <th>ACAA1 best $K_D$ (nM)</th>
          <th>Time (hrs)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>GCPN</td>
          <td>6.4</td>
          <td>75</td>
          <td>6</td>
      </tr>
      <tr>
          <td>MolDQN</td>
          <td>373</td>
          <td>240</td>
          <td>6</td>
      </tr>
      <tr>
          <td>MARS</td>
          <td>17</td>
          <td>163</td>
          <td>6</td>
      </tr>
      <tr>
          <td>GraphDF</td>
          <td>25</td>
          <td>370</td>
          <td>12</td>
      </tr>
      <tr>
          <td>LIMO</td>
          <td>0.72</td>
          <td>37</td>
          <td>1</td>
      </tr>
  </tbody>
</table>
<p>For ESR1, LIMO&rsquo;s best molecule has a $K_D$ of 0.72 nM from docking, nearly 10x better than the next method (GCPN at 6.4 nM). When corroborated with more rigorous absolute binding free energy (ABFE) calculations, one LIMO compound achieved a predicted $K_D$ of $6 \times 10^{-14}$ M (0.00006 nM), far exceeding the affinities of approved drugs tamoxifen ($K_D$ = 1.5 nM) and raloxifene ($K_D$ = 0.03 nM).</p>
<h3 id="multi-objective-optimization">Multi-Objective Optimization</h3>
<p>Single-objective optimization produces molecules with high affinity but problematic structures (polyenes, large rings). Multi-objective optimization simultaneously targeting binding affinity, QED ($&gt;$ 0.4), and SA ($&lt;$ 5.5) produces drug-like, synthesizable molecules that still have nanomolar binding affinities. Generated molecules satisfy Lipinski&rsquo;s rule of 5 with zero PAINS alerts.</p>
<h2 id="limitations">Limitations</h2>
<p>The LIMO property predictor achieves only moderate prediction accuracy ($r^2$ = 0.38), meaning the optimization relies on gradient direction being correct rather than absolute predictions being accurate. AutoDock-GPU docking scores do not correlate well with the more accurate ABFE results, a known limitation of docking. The fully-connected VAE architecture limits the molecular diversity compared to recurrent or attention-based alternatives (LSTM decoder produced max QED of only 0.3). The greedy fine-tuning step (replacing carbons with heteroatoms) is a heuristic rather than a learned procedure.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Rose-STL-Lab/LIMO">Rose-STL-Lab/LIMO</a></td>
          <td>Code</td>
          <td>UC San Diego Custom (non-commercial)</td>
          <td>Full training, optimization, and evaluation code</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: ZINC250k dataset for optimization tasks. MOSES dataset for random generation evaluation. Binding affinities computed with AutoDock-GPU.</p>
<p><strong>Hardware</strong>: Two GTX 1080 Ti GPUs (one for PyTorch, one for AutoDock-GPU), 4 CPU cores, 32 GB memory.</p>
<p><strong>Training</strong>: VAE trained for 18 epochs with learning rate 0.0001. Property predictor uses 3 layers of 1000 units, trained for 5 epochs. Reverse optimization uses learning rate 0.1 for 10 epochs.</p>
<p><strong>Targets</strong>: Human estrogen receptor (ESR1, PDB 1ERR) and human peroxisomal acetyl-CoA acyl transferase 1 (ACAA1, PDB 2IIK).</p>
]]></content:encoded></item><item><title>Language Models Learn Complex Molecular Distributions</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/lm-complex-molecular-distributions/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/lm-complex-molecular-distributions/</guid><description>RNN language models trained on SMILES and SELFIES outperform graph models at learning complex, multi-modal, and large-scale molecular distributions.</description><content:encoded><![CDATA[<h2 id="rnn-language-models-as-flexible-molecular-generators">RNN Language Models as Flexible Molecular Generators</h2>
<p>This is an <strong>Empirical</strong> paper that investigates the capacity of simple recurrent neural network (RNN) language models to learn complex molecular distributions. The core finding is that LSTM-based models trained on <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> (SM-RNN) or <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> (SF-RNN) string representations consistently outperform popular graph generative models (JTVAE, CGVAE) across three increasingly challenging generative modeling tasks. The paper positions language models as flexible, scalable alternatives to graph-based approaches for molecular generation.</p>
<h2 id="scaling-beyond-standard-benchmarks">Scaling Beyond Standard Benchmarks</h2>
<p>Most molecular generative models are evaluated on relatively small, drug-like molecules from datasets like <a href="https://en.wikipedia.org/wiki/ZINC_database">ZINC</a> or <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a>. These standard benchmarks do not test whether models can handle larger, more structurally diverse molecules or distributions with complex shapes (multi-modal, heavy-tailed). This gap matters because there is increasing interest in larger, more complex molecules for therapeutics, including peptides and natural products.</p>
<p>Graph generative models like JTVAE and CGVAE impose structural constraints (tree decompositions, valency restrictions) that help with validity but limit their ability to scale. Language models, by contrast, only need to generate a single character sequence, making them inherently more flexible.</p>
<h2 id="three-challenging-generative-modeling-tasks">Three Challenging Generative Modeling Tasks</h2>
<p>The paper introduces three benchmark tasks designed to stress-test generative models:</p>
<h3 id="task-1-penalized-logp-distribution">Task 1: Penalized LogP Distribution</h3>
<p>A dataset of approximately 160K molecules from ZINC15 with penalized <a href="https://en.wikipedia.org/wiki/Partition_coefficient">LogP</a> scores exceeding 4.0. The training distribution is sharply peaked around 4.0 to 4.5 with a subtle tail extending above 6.0. Molecules in the tail tend to have long carbon chains and fewer rings. The challenge is learning this skewed distribution rather than just finding individual high-scoring molecules.</p>
<h3 id="task-2-multi-modal-distribution">Task 2: Multi-Modal Distribution</h3>
<p>A composite dataset of approximately 200K molecules drawn from four sources with distinct molecular weight ranges:</p>
<ul>
<li><a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> (MW $\leq$ 185)</li>
<li>ZINC (185 $\leq$ MW $\leq$ 425)</li>
<li>Harvard Clean Energy Project (460 $\leq$ MW $\leq$ 600)</li>
<li>POLYMERS (MW $&gt;$ 600)</li>
</ul>
<p>Models must learn to generate from all four modes simultaneously, each with very different molecular structures.</p>
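<p>One way to check whether a generator covers all four modes is to bin generated molecules by molecular weight using the ranges listed above. The sketch below is an assumption-laden helper, not part of the paper: it assigns the shared boundary at 185 to GDB-13, returns <code>None</code> for weights in the 425 to 460 gap, and expects molecular weights to be computed elsewhere.</p>

```python
from collections import Counter

# (mode name, lower bound, upper bound), checked in order.
MODES = [
    ("GDB-13", 0.0, 185.0),
    ("ZINC", 185.0, 425.0),
    ("CEP", 460.0, 600.0),
    ("POLYMERS", 600.0, float("inf")),
]


def assign_mode(mw):
    """Return the first mode whose molecular-weight range contains mw, else None."""
    for name, low, high in MODES:
        if low <= mw <= high:
            return name
    return None


def mode_fractions(weights):
    """Fraction of molecules falling in each mode (None counts unassigned ones)."""
    counts = Counter(assign_mode(mw) for mw in weights)
    return {name: n / len(weights) for name, n in counts.items()}
```

<p>A model that misses a mode entirely, as reported for some graph baselines, would show a fraction near zero for that key.</p>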
<h3 id="task-3-large-scale-molecules">Task 3: Large-Scale Molecules</h3>
<p>The largest molecules in <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> with more than 100 heavy atoms, yielding approximately 300K molecules with molecular weights ranging from 1,250 to 5,000. These include small biomolecules, photovoltaics, peptides, and cyclic peptides. This task is particularly challenging because the SMILES/SELFIES strings are very long.</p>
<h2 id="evaluation-by-distributional-fidelity">Evaluation by Distributional Fidelity</h2>
<p>The evaluation framework focuses on how well a model learns the full training distribution rather than generating individual good molecules. The primary quantitative metric is the <a href="https://en.wikipedia.org/wiki/Wasserstein_metric">Wasserstein distance</a> (earth mover&rsquo;s distance) between molecular property distributions of generated and training molecules:</p>
<p>$$W(P, Q) = \inf_{\gamma \in \Gamma(P,Q)} \int |x - y| \, d\gamma(x, y)$$</p>
<p>Properties evaluated include LogP, synthetic accessibility (SA), quantitative estimate of drug-likeness (QED), molecular weight (MW), Bertz complexity (BCT), and natural product likeness (NP). An oracle baseline is computed by measuring the Wasserstein distance between different random samples of the training data itself.</p>
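<p>For a scalar property and two samples of equal size, the Wasserstein-1 distance reduces to the mean absolute difference between sorted values. The sketch below uses that special case and adds a hypothetical <code>oracle_baseline</code> that compares two disjoint random halves of the training values; the paper&rsquo;s exact sampling scheme for the oracle is not specified here, so the split is an assumption.</p>

```python
import random


def wasserstein_1d(xs, ys):
    """W1 distance between two equal-size empirical samples of a scalar property."""
    if len(xs) != len(ys) or not xs:
        raise ValueError("this sketch assumes two non-empty samples of equal size")
    # In 1-D the optimal coupling matches the i-th smallest x to the i-th smallest y.
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def oracle_baseline(train_values, rng):
    """Distance between two disjoint random halves of the training sample."""
    shuffled = list(train_values)
    rng.shuffle(shuffled)
    half = len(shuffled) // 2
    return wasserstein_1d(shuffled[:half], shuffled[half:2 * half])
```

<p>For instance, <code>wasserstein_1d([0, 1, 3], [5, 6, 8])</code> returns 5.0. For samples of unequal size, a library routine such as <code>scipy.stats.wasserstein_distance</code> is the more practical choice.</p>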
<p>Standard metrics (validity, uniqueness, novelty) are also reported but are secondary to distributional fidelity.</p>
<h2 id="architecture-lstm-language-models">Architecture: LSTM Language Models</h2>
<p>The language models use standard LSTM architectures trained autoregressively on molecular strings. Two variants are compared:</p>
<ul>
<li><strong>SM-RNN</strong>: Trained on canonical SMILES</li>
<li><strong>SF-RNN</strong>: Trained on SELFIES representations</li>
</ul>
<p>Hyperparameters are tuned via random search over learning rate ($\in [0.0001, 0.001]$), hidden units ($\in [100, 1000]$), layers (1 to 5), and dropout ($\in [0.0, 0.5]$). Model selection uses a combination of standard metrics and Wasserstein distance rankings.</p>
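<p>A random search over the ranges above can be sketched with the standard library. The uniform sampling (rather than, say, log-uniform for the learning rate) and the dictionary keys are assumptions; the source only gives the ranges.</p>

```python
import random


def sample_config(rng):
    """Draw one hyperparameter configuration from the stated search ranges."""
    return {
        "learning_rate": rng.uniform(1e-4, 1e-3),
        "hidden_units": rng.randint(100, 1000),  # inclusive bounds
        "num_layers": rng.randint(1, 5),
        "dropout": rng.uniform(0.0, 0.5),
    }


rng = random.Random(0)
candidates = [sample_config(rng) for _ in range(20)]
```

<p>Each candidate would then be trained and ranked; the paper combines standard metrics with Wasserstein distance rankings for that selection step.</p>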
<p>The graph model baselines include JTVAE (junction tree VAE) and CGVAE (constrained graph VAE), along with several additional baselines (MolGAN, GraphNVP, and others).</p>
<h2 id="results-language-models-outperform-graph-models-across-all-tasks">Results: Language Models Outperform Graph Models Across All Tasks</h2>
<h3 id="penalized-logp">Penalized LogP</h3>
<p>Both RNN models learn the sharp training distribution far better than graph models. The SM-RNN achieves the lowest Wasserstein distances across most properties. The graph models produce substantial out-of-distribution mass around penalized LogP scores of 1.75 to 2.25, failing to capture the peaked nature of the training distribution.</p>
<p>Critically, the RNNs also learn the subtle tail above penalized LogP of 6.0, generating molecules with long carbon chains and fewer rings that match the structural characteristics of high-scoring training molecules. CGVAE and JTVAE almost entirely miss this tail.</p>
<h3 id="multi-modal-distribution">Multi-Modal Distribution</h3>
<p>Both RNN models capture all four modes of the training distribution. JTVAE entirely misses the GDB-13 mode and poorly learns the ZINC and Clean Energy Project (CEP) modes. CGVAE learns GDB-13 but misses the CEP mode. The SM-RNN again achieves the best Wasserstein metrics.</p>
<h3 id="large-scale-molecules">Large-Scale Molecules</h3>
<p>This is the most discriminating task. Both JTVAE and CGVAE completely fail to train on these large molecules. JTVAE&rsquo;s tree decomposition produces a vocabulary of approximately 11,000 substructures, making training intractable. Only the RNN models succeed, with the SF-RNN achieving slightly better distributional match due to SELFIES guaranteeing 100% validity even for very long strings.</p>
<p>Both RNN models also learn the bimodal LogP structure within the large-molecule distribution and can generate molecules with substructures resembling peptides, including backbone chains and standard amino acid side chains.</p>
<h3 id="summary-of-wasserstein-distance-results">Summary of Wasserstein Distance Results</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Model</th>
          <th>LogP</th>
          <th>SA</th>
          <th>QED</th>
          <th>MW</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>LogP</td>
          <td>SM-RNN</td>
          <td>0.095</td>
          <td>0.031</td>
          <td>0.007</td>
          <td>3.3</td>
      </tr>
      <tr>
          <td>LogP</td>
          <td>SF-RNN</td>
          <td>0.177</td>
          <td>0.290</td>
          <td>0.010</td>
          <td>6.3</td>
      </tr>
      <tr>
          <td>LogP</td>
          <td>JTVAE</td>
          <td>0.536</td>
          <td>0.289</td>
          <td>0.081</td>
          <td>35.9</td>
      </tr>
      <tr>
          <td>LogP</td>
          <td>CGVAE</td>
          <td>1.000</td>
          <td>2.120</td>
          <td>0.115</td>
          <td>69.3</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>SM-RNN</td>
          <td>0.081</td>
          <td>0.025</td>
          <td>0.006</td>
          <td>5.5</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>SF-RNN</td>
          <td>0.286</td>
          <td>0.179</td>
          <td>0.023</td>
          <td>11.4</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>JTVAE</td>
          <td>0.495</td>
          <td>0.274</td>
          <td>0.034</td>
          <td>27.7</td>
      </tr>
      <tr>
          <td>Multi</td>
          <td>CGVAE</td>
          <td>1.617</td>
          <td>1.802</td>
          <td>0.076</td>
          <td>30.3</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>SM-RNN</td>
          <td>1.367</td>
          <td>0.213</td>
          <td>0.003</td>
          <td>124.5</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>SF-RNN</td>
          <td>1.095</td>
          <td>0.342</td>
          <td>0.010</td>
          <td>67.3</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>JTVAE</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>CGVAE</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
          <td>&ndash;</td>
      </tr>
  </tbody>
</table>
<h3 id="smiles-vs-selfies-trade-off">SMILES vs. SELFIES Trade-off</h3>
<p>The SMILES and SELFIES RNNs have complementary strengths. The SF-RNN consistently achieves better standard metrics (validity, uniqueness, novelty) across all tasks, while the SM-RNN achieves lower Wasserstein distances on every property in the LogP and multi-distribution tasks. On the large-molecule task the picture is mixed: the SF-RNN is closer on LogP and MW, the SM-RNN on SA and QED. The authors suggest that the SELFIES grammar may reduce memorization of the training data, improving novelty but slightly hurting distributional fidelity.</p>
<h2 id="limitations">Limitations</h2>
<p>The authors acknowledge several limitations. Language models cannot account for molecular geometry or 3D information, which is important for many applications. The study evaluates distributional fidelity but does not test downstream utility for specific molecular design tasks (e.g., optimizing for a particular biological target). Additionally, while the graph models (JTVAE, CGVAE) are more interpretable, the language models operate as black boxes over string representations. The comparison is also limited to two specific graph model architectures, and more recent or specialized graph models may close the performance gap. Finally, trained model weights are only available upon request rather than being publicly released.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/danielflamshep/genmoltasks">danielflamshep/genmoltasks</a></td>
          <td>Dataset</td>
          <td>Apache-2.0</td>
          <td>Processed training data and generated samples</td>
      </tr>
  </tbody>
</table>
<p><strong>Data</strong>: Three custom datasets constructed from ZINC15, <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a>, Harvard Clean Energy Project, POLYMERS, and PubChem. Processed data available at the GitHub repository.</p>
<p><strong>Code</strong>: LSTM networks implemented in PyTorch using the char-rnn code from the <a href="https://github.com/molecularsets/moses">MOSES repository</a>. Baselines use the official <a href="https://github.com/wengong-jin/icml18-jtnn">JTVAE</a> and <a href="https://github.com/microsoft/constrained-graph-variational-autoencoder">CGVAE</a> implementations. No unified training script is provided in the repository.</p>
<p><strong>Evaluation</strong>: Wasserstein distances computed using SciPy. Molecular properties computed using RDKit. 10K molecules generated from each model for evaluation.</p>
<p><strong>Hyperparameters</strong>: Task-specific configurations reported. For example, the LogP task SM-RNN uses 2 hidden layers with 400 units, dropout of 0.2, and learning rate of 0.0001.</p>
<p><strong>Hardware</strong>: Models were trained on Compute Canada systems. Specific GPU types and training times are not reported.</p>
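<p>The evaluation metric above can be illustrated without any dependencies. The following is a minimal sketch, assuming both samples contain the same number of values (as when 10K generated molecules are compared against 10K reference molecules); in that special case the 1D Wasserstein distance reduces to the mean absolute difference between sorted values. The function name <code>wasserstein_1d</code> and the LogP values are hypothetical and are not taken from the paper or its repository. For unequal sample sizes, SciPy's <code>scipy.stats.wasserstein_distance</code> handles the general case.</p>

```python
# Pure-Python stand-in for scipy.stats.wasserstein_distance, valid only when
# both samples contain the same number of values.
def wasserstein_1d(reference, generated):
    """1D Wasserstein-1 distance between two equal-size empirical samples."""
    if len(reference) != len(generated):
        raise ValueError("this sketch assumes equal-size samples")
    if not reference:
        raise ValueError("samples must be non-empty")
    # For equal-size samples the optimal transport plan matches sorted values.
    pairs = zip(sorted(reference), sorted(generated))
    return sum(abs(r - g) for r, g in pairs) / len(reference)


# Hypothetical LogP values, not data from the paper.
train_logp = [1.2, 2.5, 3.1, 0.4, 2.2]
model_logp = [1.0, 2.9, 3.0, 0.9, 2.0]
print(round(wasserstein_1d(train_logp, model_logp), 3))
```

<p>Lower values mean the generated property distribution sits closer to the training distribution, which is how the summary table should be read.</p>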
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Flam-Shepherd, D., Zhu, K., &amp; Aspuru-Guzik, A. (2022). Language models can learn complex molecular distributions. <em>Nature Communications</em>, 13, 3293. <a href="https://doi.org/10.1038/s41467-022-30839-x">https://doi.org/10.1038/s41467-022-30839-x</a></p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/danielflamshep/genmoltasks">GitHub: danielflamshep/genmoltasks</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{flamshepherd2022language,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Language models can learn complex molecular distributions}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Flam-Shepherd, Daniel and Zhu, Kevin and Aspuru-Guzik, Al{\&#39;a}n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3293}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s41467-022-30839-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>BARTSmiles: BART Pre-Training for Molecular SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/</link><pubDate>Sun, 22 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/bartsmiles-molecular-representations/</guid><description>BARTSmiles applies BART-style denoising pre-training to 1.7B SMILES from ZINC20, achieving top results on 11 molecular property and reaction tasks.</description><content:encoded><![CDATA[<h2 id="a-bart-based-method-for-molecular-self-supervised-learning">A BART-Based Method for Molecular Self-Supervised Learning</h2>
<p>BARTSmiles is a <strong>Method</strong> paper. It introduces a self-supervised pre-training approach for molecular representations based on the BART (Bidirectional and Auto-Regressive Transformers) architecture from Lewis et al. (2019). The primary contribution is a pre-training strategy, discovered through systematic ablations, that trains a BART-large model on 1.7 billion deduplicated <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES strings</a> from the <a href="/notes/chemistry/datasets/zinc-22/">ZINC20 dataset</a>. BARTSmiles achieves the best reported results on 11 tasks spanning molecular property classification, regression, and chemical reaction generation.</p>
<h2 id="scaling-self-supervised-molecular-representations-beyond-prior-work">Scaling Self-Supervised Molecular Representations Beyond Prior Work</h2>
<p>At the time of publication, large-scale self-supervised representation learning had produced significant improvements in NLP, computer vision, and speech, but molecular representation learning had not benefited from comparable scale. Previous SMILES-based pre-trained models such as <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> (Chithrananda et al., 2020) and <a href="/notes/chemistry/molecular-design/generation/autoregressive/chemformer/">ChemFormer</a> (Irwin et al., 2022) used encoder-only or encoder-decoder architectures with substantially less compute. ChemFormer, the most closely related prior work, also trained a BART-like model but with a fraction of the compute and data.</p>
<p>The paper argues that three gaps needed to be addressed:</p>
<ol>
<li><strong>Scale</strong>: Prior molecular pre-training used orders of magnitude less compute than NLP pre-training.</li>
<li><strong>Architecture choice</strong>: Encoder-only models like ChemBERTa cannot perform generative fine-tuning (retrosynthesis, reaction prediction), limiting their applicability.</li>
<li><strong>Pre-training recipe</strong>: Standard BART hyperparameters (e.g., 30% mask token budget) were tuned for natural language and had not been validated for molecular SMILES strings.</li>
</ol>
<h2 id="core-innovation-ablation-driven-pre-training-recipe-for-smiles">Core Innovation: Ablation-Driven Pre-Training Recipe for SMILES</h2>
<p>The key insight of BARTSmiles is that the BART denoising objective, when carefully tuned for the molecular domain, learns representations that implicitly encode downstream task information. The authors discover this through a systematic three-stage ablation:</p>
<h3 id="tokenization">Tokenization</h3>
<p>Rather than using hand-crafted tokenization rules that separate individual atoms (C, N, H) and bond symbols (#, =), BARTSmiles uses a learned SentencePiece unigram tokenizer trained on 10 million random SMILES with a vocabulary size of 1,021. On matched compute budgets, learned tokenization achieves 0.801 average AUC-ROC vs. 0.779 for hand-crafted tokenization on the ablation benchmark (HIV, BBBP, ClinTox).</p>
<h3 id="masking-strategy">Masking Strategy</h3>
<p>The BART denoising objective has three main hyperparameters: the mask token budget (fraction of tokens masked), random mask probability, and the Poisson $\lambda$ controlling mask span length. The ablation results show:</p>
<ul>
<li><strong>Mask token budget</strong>: The standard BART value of 0.30 is suboptimal for molecules. A budget of 0.20 performs best (0.821 AUC-ROC), with performance degrading at both lower (0.10: 0.753) and higher (0.40: 0.701) budgets.</li>
<li><strong>Span masking</strong>: The choice of random mask probability and $\lambda$ has a minor effect once the budget is set to 0.20. With a random mask probability of 0.10, both $\lambda$ = 2.5 and $\lambda$ = 3.5 yield 0.821.</li>
<li><strong>Token randomization</strong>: Disabling the randomize-tokens noise (where some tokens are replaced with random tokens rather than masked) improves performance from 0.821 to 0.835.</li>
</ul>
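<p>The interaction between the mask token budget and the Poisson span length can be sketched in a few lines. This is a simplified illustration, not the FairSeq implementation the authors used: it assumes spans never overlap, clamps zero-length draws to length 1, and omits random-token replacement entirely (which mirrors the ablation finding that disabling token randomization helps). The names <code>infill_noise</code>, <code>sample_poisson</code>, and the <code>[MASK]</code> symbol are hypothetical.</p>

```python
import math
import random

MASK = "[MASK]"  # placeholder symbol chosen for this sketch


def sample_poisson(lam, rng):
    """Draw one Poisson(lam) sample with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def infill_noise(tokens, budget=0.20, lam=2.5, seed=0):
    """Replace spans covering about `budget` of the tokens with one MASK each."""
    rng = random.Random(seed)
    to_mask = math.ceil(budget * len(tokens))
    covered = [False] * len(tokens)
    starts = set()
    while to_mask > 0:
        length = min(max(1, sample_poisson(lam, rng)), to_mask)
        start = rng.randrange(len(tokens))
        span = range(start, min(start + length, len(tokens)))
        if any(covered[i] for i in span):
            continue  # resample rather than overlap an earlier span
        for i in span:
            covered[i] = True
        starts.add(start)
        to_mask -= len(span)
    noisy = []
    for i, token in enumerate(tokens):
        if i in starts:
            noisy.append(MASK)  # the whole span collapses to one symbol
        if not covered[i]:
            noisy.append(token)
    return noisy


# Character-level tokens keep the example short; BARTSmiles uses SentencePiece.
print(infill_noise(list("CCOC(=O)c1ccccc1NC=O")))
```

<p>Because each span collapses to a single mask symbol, the decoder has to infer both the content and the length of what was removed, which is the property that distinguishes BART-style infilling from per-token masking.</p>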
<h3 id="scale">Scale</h3>
<p>Training on the full 1.7 billion molecule ZINC20 dataset (20 hours on 1,024 A100 GPUs, totaling 20,480 A100 GPU-hours) improves performance by 5 absolute AUC-ROC points over the same model trained on 100 million samples. The previous most compute-intensive molecular pre-training used 3,330 V100-hours (Ross et al., 2021).</p>
<h3 id="implicit-task-encoding">Implicit Task Encoding</h3>
<p>The paper provides a quantitative demonstration that frozen BARTSmiles representations encode task-specific information. Using L1-regularized logistic regression on frozen 1,024-dimensional mean-pooled representations, just 7 neurons are sufficient to achieve 0.987 AUC-ROC on ClinTox (within 2 percentage points of full fine-tuning). Even a single neuron achieves 0.77 AUC-ROC on ClinTox subtask 1.</p>
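<p>The probing idea can be sketched on synthetic data. The example below is an assumption-laden stand-in, not the authors' code: it fits an L1-penalized logistic regression with plain proximal gradient descent on random 8-dimensional vectors in which only dimension 0 determines the label, in place of the 1,024-dimensional frozen BARTSmiles features. The penalty strength, step size, and all names are hypothetical.</p>

```python
import math
import random


def soft_threshold(value, amount):
    """Shrink value toward zero by amount, the proximal step for an L1 penalty."""
    return math.copysign(max(abs(value) - amount, 0.0), value)


def fit_l1_logistic(rows, labels, l1=0.1, lr=0.1, steps=300):
    """Full-batch proximal gradient descent for L1-penalized logistic regression."""
    n_features = len(rows[0])
    weights, bias = [0.0] * n_features, 0.0
    for _ in range(steps):
        grad_w, grad_b = [0.0] * n_features, 0.0
        for row, label in zip(rows, labels):
            z = bias + sum(w * x for w, x in zip(weights, row))
            error = 1.0 / (1.0 + math.exp(-z)) - label
            grad_b += error
            for j, x in enumerate(row):
                grad_w[j] += error * x
        bias -= lr * grad_b / len(rows)
        weights = [
            soft_threshold(w - lr * g / len(rows), lr * l1)
            for w, g in zip(weights, grad_w)
        ]
    return weights, bias


# Synthetic stand-in for frozen embeddings: only dimension 0 carries the label.
rng = random.Random(0)
features = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(200)]
targets = [1 if row[0] > 0 else 0 for row in features]

weights, bias = fit_l1_logistic(features, targets)
selected = [j for j, w in enumerate(weights) if w != 0.0]
print("non-zero dimensions:", selected)
```

<p>The L1 penalty drives uninformative weights to exactly zero, so the surviving indices play the role of the handful of neurons the paper reports as sufficient for ClinTox.</p>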
<h2 id="experimental-setup-moleculenet-toxicology-and-generative-benchmarks">Experimental Setup: MoleculeNet, Toxicology, and Generative Benchmarks</h2>
<h3 id="classification-tasks">Classification Tasks</h3>
<p>BARTSmiles is evaluated on 7 classification datasets from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> (SIDER, ClinTox, Tox21, ToxCast, HIV, BACE, BBBP) plus 2 toxicology datasets (<a href="https://en.wikipedia.org/wiki/Ames_test">Ames</a>, <a href="https://en.wikipedia.org/wiki/Micronucleus_test">Micronucleus Assay</a>). All classification tasks use AUC-ROC. Baselines include both supervised graph models (D-MPNN, Attentive FP, 3D InfoMax) and self-supervised methods (ChemBERTa, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer-XL</a>, GROVER-large, MolCLR, iMolCLR).</p>
<p>Selected classification results (AUC-ROC):</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>BARTSmiles</th>
          <th>Previous Best</th>
          <th>Previous Best Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ClinTox</td>
          <td><strong>0.997</strong></td>
          <td>0.954</td>
          <td>iMolCLR</td>
      </tr>
      <tr>
          <td>ToxCast</td>
          <td><strong>0.825</strong></td>
          <td>0.805</td>
          <td>Attentive FP</td>
      </tr>
      <tr>
          <td>SIDER</td>
          <td><strong>0.705</strong></td>
          <td>0.699</td>
          <td>iMolCLR</td>
      </tr>
      <tr>
          <td>Tox21</td>
          <td>0.851</td>
          <td>0.858</td>
          <td>Attentive FP</td>
      </tr>
  </tbody>
</table>
<p>The authors note that three scaffold-split datasets (HIV, BACE, BBBP) are highly sensitive to the specific split used, and they suspect some baseline results use different or random splits. These results are marked with caveats in the paper.</p>
<h3 id="regression-tasks">Regression Tasks</h3>
<p>All three MoleculeNet regression tasks (ESOL, FreeSolv, Lipophilicity) are evaluated using RMSE:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>BARTSmiles</th>
          <th>Previous Best</th>
          <th>Previous Best Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ESOL</td>
          <td><strong>0.095</strong></td>
          <td>0.279</td>
          <td>MoLFormer-XL</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td><strong>0.114</strong></td>
          <td>0.231</td>
          <td>MoLFormer-XL</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td><strong>0.292</strong></td>
          <td>0.529</td>
          <td>MoLFormer-XL</td>
      </tr>
  </tbody>
</table>
<p>BARTSmiles lowers the previous best RMSE on all three regression tasks, by roughly 66% on ESOL, 51% on FreeSolv, and 45% on Lipophilicity.</p>
<h3 id="generative-tasks">Generative Tasks</h3>
<p><strong><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></strong> (USPTO-50k): BARTSmiles achieves 55.6% Top-1 accuracy using a sample-128 + perplexity re-ranking strategy, compared to 55.3% for Dual-TF and 54.3% for ChemFormer. Top-5 and Top-10 results are 74.2% and 80.9% respectively.</p>
<p><strong>Chemical Reaction Prediction</strong> (USPTO MIT/LEF/STEREO): BARTSmiles with beam search outperforms the <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a> baseline across all six evaluation settings. On USPTO-MIT (split), BARTSmiles achieves 91.8% vs. 90.4% for the Transformer baseline.</p>
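<p>The sample-and-re-rank strategy used for retrosynthesis can be sketched as follows. This is a minimal illustration that assumes each sampled sequence comes with per-token natural-log probabilities from the decoder; how the paper computes or normalizes perplexity in detail is not reproduced here. The function names, SMILES strings, and scores are hypothetical.</p>

```python
import math


def perplexity(token_logprobs):
    """Perplexity of one sequence from its per-token natural-log probabilities."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def rerank(candidates):
    """Sort (smiles, token_logprobs) pairs from lowest to highest perplexity."""
    return sorted(candidates, key=lambda item: perplexity(item[1]))


# Hypothetical samples; the paper draws 128 per input.
samples = [
    ("CCO.CC(=O)O", [-0.9, -1.2, -0.7]),
    ("CCOC(C)=O", [-0.1, -0.2, -0.1, -0.3]),
    ("CC(=O)OC", [-0.5, -0.4, -0.6]),
]
top_k = [smiles for smiles, _ in rerank(samples)[:2]]
print(top_k)
```

<p>Perplexity normalizes by sequence length, so longer candidate reactant sets are not penalized simply for containing more tokens, unlike a raw sequence probability.</p>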
<h3 id="fine-tuning-recipe">Fine-Tuning Recipe</h3>
<p>The fine-tuning approach is designed to minimize hyperparameter tuning:</p>
<ul>
<li>Batch size 16, 10 epochs, polynomial decay learning rate schedule with warmup at 16% of training</li>
<li>Grid search over dropout (0.1, 0.2, 0.3) and learning rate ($5 \times 10^{-6}$, $1 \times 10^{-5}$, $3 \times 10^{-5}$)</li>
<li>Stochastic Weight Averaging (SWA) over three sets of four checkpoints</li>
<li>For generative tasks: R3F regularization (Aghajanyan et al., 2020a) and full fp32 precision</li>
<li>For generation: beam search (beam size 10) or sample 128 sequences with perplexity re-ranking</li>
</ul>
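<p>The search space and learning rate schedule in the recipe above are small enough to write out. The sketch below enumerates the 3 x 3 grid and implements linear warmup over the first 16% of steps followed by polynomial decay. The decay power of 1.0 and the final learning rate of 0.0 are assumptions, since the summary above does not state them, and <code>polynomial_decay_lr</code> is a hypothetical helper rather than a FairSeq function.</p>

```python
from itertools import product

DROPOUTS = (0.1, 0.2, 0.3)
LEARNING_RATES = (5e-6, 1e-5, 3e-5)
GRID = list(product(DROPOUTS, LEARNING_RATES))  # 9 configurations


def polynomial_decay_lr(step, total_steps, peak_lr,
                        warmup_fraction=0.16, power=1.0, end_lr=0.0):
    """Linear warmup to peak_lr, then polynomial decay to end_lr."""
    warmup_steps = max(1, round(warmup_fraction * total_steps))
    if warmup_steps > step:
        return peak_lr * step / warmup_steps
    if step >= total_steps:
        return end_lr
    remaining = (total_steps - step) / (total_steps - warmup_steps)
    return (peak_lr - end_lr) * remaining ** power + end_lr


# Print the schedule at a few steps for every configuration in the grid.
for dropout, peak in GRID:
    schedule = [polynomial_decay_lr(s, 1000, peak) for s in (0, 160, 580, 1000)]
    print(dropout, peak, schedule)
```

<p>With only nine configurations per task, the full grid is cheap relative to pre-training, which is consistent with the stated goal of minimizing hyperparameter tuning.</p>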
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<h3 id="key-findings">Key Findings</h3>
<ol>
<li><strong>Scale matters for molecular pre-training</strong>: Training on 1.7B molecules with 20,480 A100 GPU-hours yields 5 absolute points of AUC-ROC improvement over training on 100M molecules.</li>
<li><strong>Domain-specific ablation is necessary</strong>: The optimal BART masking configuration for molecules (20% budget, no token randomization) differs from the standard NLP configuration (30% budget, with randomization).</li>
<li><strong>Frozen representations capture task structure</strong>: A small number of neurons from the frozen model can nearly match full fine-tuning performance on certain tasks, suggesting the pre-training objective implicitly encodes molecular properties.</li>
<li><strong>Interpretability aligns with domain knowledge</strong>: Integrated Gradients attribution on fine-tuned BARTSmiles highlights known structural alerts (e.g., <a href="https://en.wikipedia.org/wiki/Nitro_compound">nitro groups</a> in mutagenic compounds, hydroxyl groups in soluble compounds).</li>
</ol>
<h3 id="limitations">Limitations</h3>
<ul>
<li><strong>Scaffold split sensitivity</strong>: Results on HIV, BACE, and BBBP are sensitive to the specific scaffold split, making direct comparison with baselines difficult.</li>
<li><strong>Pre-training data distribution</strong>: The <a href="https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance">Fréchet distance</a> analysis shows that some downstream datasets (BBBP, SIDER) are far from ZINC20 in representation space, which may explain weaker performance on those tasks.</li>
<li><strong>Fingerprints carry complementary information</strong>: On the Ames and Micronucleus Assay datasets, BARTSmiles alone does not beat fingerprint-based baselines. Combining BARTSmiles with ECFP4 fingerprints closes the gap, implying that SMILES-based pre-training does not fully capture all structural information.</li>
<li><strong>Compute requirements</strong>: Pre-training used 1,024 A100 GPUs for 20 hours (20,480 A100 GPU-hours), which limits accessibility.</li>
</ul>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest investigating the impact of pre-training data composition, noting that ZINC20 contains over a billion molecules but its distribution may be irrelevant for many downstream tasks. They also propose further collaboration between ML and chemistry experts to discover new molecular substructure-property relationships.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/YerevaNN/BARTSmiles">BARTSmiles (GitHub)</a></td>
          <td>Code + Model</td>
          <td>MIT</td>
          <td>Pre-training, fine-tuning, and evaluation scripts with pre-trained weights</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>ZINC20 (deduplicated)</td>
          <td>~1.7B molecules</td>
          <td>Canonicalized SMILES, 10K validation holdout</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>MoleculeNet (7 datasets)</td>
          <td>1,427-41,127 compounds</td>
          <td>AUC-ROC metric</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>MoleculeNet (3 datasets)</td>
          <td>642-4,200 compounds</td>
          <td>RMSE metric</td>
      </tr>
      <tr>
          <td>Toxicology</td>
          <td>Ames, MN Assay</td>
          <td>6,512 / 641 compounds</td>
          <td>Cross-validation for Ames; external test for MN</td>
      </tr>
      <tr>
          <td>Retrosynthesis</td>
          <td>USPTO-50k</td>
          <td>Standard split</td>
          <td>Top-K accuracy</td>
      </tr>
      <tr>
          <td>Reaction prediction</td>
          <td>USPTO (MIT/LEF/STEREO)</td>
          <td>Standard splits</td>
          <td>Top-1 accuracy</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Architecture: BART-Large (pre-layer norm Transformer encoder-decoder)</li>
<li>Tokenizer: SentencePiece unigram, vocabulary size 1,021, max sequence length 128</li>
<li>Pre-training objective: BART denoising (mask token budget 0.20, Poisson span masking with $\lambda$ = 2.5, no token randomization)</li>
<li>Fine-tuning: polynomial decay LR, SWA, grid search over dropout and LR</li>
<li>Generative fine-tuning: R3F regularization, fp32 precision, Adam initialized from pre-training moving averages</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li>BART-Large architecture (exact parameter count not specified in paper)</li>
<li>Pre-trained checkpoint released on GitHub</li>
<li>Maximum sequence length: 128 tokens</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>BARTSmiles</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ClinTox</td>
          <td>AUC-ROC</td>
          <td>0.997</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>ToxCast</td>
          <td>AUC-ROC</td>
          <td>0.825</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>ESOL</td>
          <td>RMSE</td>
          <td>0.095</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>FreeSolv</td>
          <td>RMSE</td>
          <td>0.114</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>Lipophilicity</td>
          <td>RMSE</td>
          <td>0.292</td>
          <td>New SOTA</td>
      </tr>
      <tr>
          <td>USPTO-50k Retro (Top-1)</td>
          <td>Accuracy</td>
          <td>55.6%</td>
          <td>New SOTA (sample + re-rank)</td>
      </tr>
      <tr>
          <td>USPTO-MIT Rxn (Split)</td>
          <td>Accuracy</td>
          <td>91.8%</td>
          <td>New SOTA (beam-10)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Pre-training: 1,024 NVIDIA A100 GPUs for 20 hours (20,480 A100 GPU-hours)</li>
<li>Ablation runs: 128 A100 GPUs per run</li>
<li>Framework: FairSeq with FairScale (fully sharded data parallel), automatic mixed precision</li>
<li>Experiment tracking: Aim</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chilingaryan, G., Tamoyan, H., Tevosyan, A., Babayan, N., Khondkaryan, L., Hambardzumyan, K., Navoyan, Z., Khachatrian, H., &amp; Aghajanyan, A. (2024). BARTSmiles: Generative Masked Language Models for Molecular Representations. <em>Journal of Chemical Information and Modeling</em>, 64(15), 5832-5843. <a href="https://doi.org/10.1021/acs.jcim.4c00512">https://doi.org/10.1021/acs.jcim.4c00512</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling, 2024 (preprint: arXiv 2022)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/YerevaNN/BARTSmiles">BARTSmiles GitHub Repository (MIT License)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chilingaryan2024bartsmiles,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{BARTSmiles: Generative Masked Language Models for Molecular Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chilingaryan, Gayane and Tamoyan, Hovhannes and Tevosyan, Ani and Babayan, Nelly and Khondkaryan, Lusine and Hambardzumyan, Karen and Navoyan, Zaven and Khachatrian, Hrant and Aghajanyan, Armen}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5832--5843}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.4c00512}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Molecular Transformer: Calibrated Reaction Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/</link><pubDate>Wed, 18 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/</guid><description>A Transformer seq2seq model for chemical reaction prediction achieving 90.4% top-1 accuracy on USPTO_MIT with calibrated uncertainty estimation.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-methodological-classification">Paper Contribution and Methodological Classification</h2>
<p>This is a <strong>Method</strong> paper. It adapts the Transformer architecture to chemical reaction prediction, treating it as a machine translation problem from reactant <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> to product SMILES. The key contributions are (1) demonstrating that a fully attention-based model outperforms all prior template-based, graph-based, and RNN-based methods, (2) showing the model works without separating reactants from reagents, and (3) introducing calibrated uncertainty estimation for ranking synthesis pathways.</p>
<h2 id="motivation-limitations-of-existing-reaction-prediction">Motivation: Limitations of Existing Reaction Prediction</h2>
<p>Prior approaches to reaction prediction fell into two broad groups, template-based and template-free, each with fundamental limitations:</p>
<ul>
<li><strong>Template-based methods</strong> rely on libraries of reaction rules, either handcrafted or automatically extracted from atom-mapped data. Automatic template extraction itself depends on atom mapping, which depends on templates, creating a circular dependency.</li>
<li><strong>Graph-based template-free methods</strong> (e.g., WLDN, ELECTRO) avoid explicit templates but still require atom-mapped training data and cannot handle stereochemistry.</li>
<li><strong><a href="/notes/chemistry/molecular-design/reaction-prediction/nmt-organic-reaction-prediction/">RNN-based seq2seq models</a></strong> (also template-free) treat reactions as SMILES translation but impose a positional inductive bias: tokens far apart in the SMILES string are assumed to be less related. This assumption fails because distance within a SMILES string does not reflect 3D spatial distance between atoms.</li>
</ul>
<h2 id="core-innovation-transformer-for-reaction-prediction">Core Innovation: Transformer for Reaction Prediction</h2>
<p>The Molecular Transformer adapts the Transformer architecture to chemical reactions by treating SMILES strings of reactants and reagents as source sequences and product SMILES as target sequences.</p>
<ul>
<li><strong>Architecture</strong>: Encoder-decoder Transformer with 4 layers, 256-dimensional hidden states, 8 attention heads, and 12M parameters (reduced from the original 65M NMT model).</li>
<li><strong>Tokenization</strong>: Atom-wise regex tokenization of SMILES strings, applied uniformly to both reactants and reagents (no special reagent tokens).</li>
<li><strong>Data augmentation</strong>: Training data is doubled by generating <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">random (non-canonical) SMILES</a> for each reaction, which improves top-1 accuracy by roughly 1%.</li>
<li><strong>Weight averaging</strong>: Final model weights are averaged over the last 20 checkpoints, providing a further accuracy boost without the inference cost of ensembling.</li>
<li><strong>Mixed input</strong>: Unlike all prior work that separates reactants from reagents (which implicitly assumes knowledge of the product), the Molecular Transformer operates on mixed inputs where no distinction is made.</li>
</ul>
<p>The multihead attention mechanism is the key architectural advantage over RNNs. It allows the model to attend to any pair of tokens regardless of their position in the SMILES string, correctly capturing long-range chemical relationships that RNNs miss.</p>
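<p>The atom-wise tokenization step listed above can be illustrated with a short regular expression. The pattern below is a simplified one written for this sketch and should be treated as an assumption; it is not claimed to be the exact regex from the paper and covers only common organic-subset symbols. The <code>tokenize</code> helper is hypothetical.</p>

```python
import re

# Simplified atom-wise pattern written for this sketch (an assumption, not the
# paper's exact regex): bracket atoms, two-letter halogens, single-letter
# atoms, ring-closure digits, then bond and branch symbols.
SMILES_TOKEN = re.compile(
    r"\[[^\]]+\]|Br|Cl|[BCNOSPFIbcnosp]|%[0-9]{2}|[0-9]|[()=#+\-/\\.:@~*$]"
)


def tokenize(smiles):
    """Split a SMILES string into atom-level tokens, rejecting unknown symbols."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens


# Reactant and reagent are joined with "." and tokenized identically.
print(tokenize("CC(=O)Cl.c1ccccc1[N+](=O)[O-]"))
```

<p>Bracket atoms such as <code>[N+]</code> stay intact as single tokens, and because reactants and reagents share one vocabulary, the mixed-input setting needs no extra markers.</p>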
<h2 id="uncertainty-estimation">Uncertainty Estimation</h2>
<p>A central contribution is calibrated uncertainty scoring. The product of predicted token probabilities serves as a confidence score for each prediction. This score achieves 0.89 AUC-ROC for classifying whether a prediction is correct.</p>
<p>An important finding: <strong>label smoothing hurts uncertainty calibration</strong>. While label smoothing (as used in the original Transformer) marginally improves top-1 accuracy (87.44% vs 87.28%), it destroys the model&rsquo;s ability to distinguish correct from incorrect predictions. Setting the label smoothing parameter to 0.0 preserves calibration.</p>
<p>The confidence score shows essentially no correlation with SMILES length (Pearson $r = 0.06$), indicating it is not biased against predictions of larger molecules.</p>
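<p>The confidence score described in this section is simple to compute. The sketch below multiplies per-token probabilities by summing their logarithms, which avoids underflow on long sequences. The probability values and the function name <code>confidence</code> are hypothetical, and the example assumes every probability is strictly positive.</p>

```python
import math


def confidence(token_probs):
    """Product of per-token probabilities, accumulated in log space."""
    return math.exp(sum(math.log(p) for p in token_probs))


# Hypothetical decoder outputs for two predicted product strings.
confident = [0.99, 0.98, 0.99, 0.97]
uncertain = [0.99, 0.55, 0.60, 0.97]
print(round(confidence(confident), 3), round(confidence(uncertain), 3))
```

<p>A single low-probability token is enough to pull the score down, which is what makes the product useful for flagging predictions where the model hesitated at one decision point.</p>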
<h2 id="experimental-results">Experimental Results</h2>
<h3 id="forward-synthesis-prediction">Forward Synthesis Prediction</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Setting</th>
          <th style="text-align: left">Top-1 (%)</th>
          <th style="text-align: left">Top-2 (%)</th>
          <th style="text-align: left">Top-5 (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">separated</td>
          <td style="text-align: left">90.4</td>
          <td style="text-align: left">93.7</td>
          <td style="text-align: left">95.3</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">mixed</td>
          <td style="text-align: left">88.6</td>
          <td style="text-align: left">92.4</td>
          <td style="text-align: left">94.2</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">separated</td>
          <td style="text-align: left">78.1</td>
          <td style="text-align: left">84.0</td>
          <td style="text-align: left">87.1</td>
      </tr>
      <tr>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">mixed</td>
          <td style="text-align: left">76.2</td>
          <td style="text-align: left">82.4</td>
          <td style="text-align: left">85.8</td>
      </tr>
  </tbody>
</table>
<p>The mixed-input model (88.6%) outperforms all prior methods that used separated inputs (best previous: WLDN5 at 85.6%).</p>
<h3 id="comparison-with-quantum-chemistry">Comparison with Quantum Chemistry</h3>
<p>On <a href="https://en.wikipedia.org/wiki/Regioselectivity">regioselectivity</a> of <a href="https://en.wikipedia.org/wiki/Electrophilic_aromatic_substitution">electrophilic aromatic substitution</a> in heteroaromatics, the Molecular Transformer achieves 83% top-1 accuracy vs 81% for RegioSQM (a quantum-chemistry-based predictor), at a fraction of the computational cost.</p>
<h3 id="comparison-with-human-chemists">Comparison with Human Chemists</h3>
<p>On 80 reactions sampled across rarity bins, the Molecular Transformer achieves 87.5% top-1 accuracy vs 76.5% for the best human chemist and 72.5% for the best graph-based model (WLDN5).</p>
<h3 id="chemically-constrained-beam-search">Chemically Constrained Beam Search</h3>
<p>Constraining beam search to only predict atoms present in the reactants (preventing &ldquo;alchemy&rdquo;) produces no change in accuracy, confirming the model has learned conservation of atoms from data alone.</p>
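<p>One way to picture the constraint is as a filter that rejects any candidate product containing an element absent from the input. The sketch below is a post-hoc filter over finished candidates, not the in-search constraint the paper describes, and it checks element types only (not atom counts). The regular expressions and helper names are illustrative assumptions:</p>

```python
import re

# First element symbol inside a bracket atom, e.g. [Na+], [nH], [13CH3].
BRACKET_ATOM = re.compile(r"\[\d*([A-Za-z][a-z]?)")
# Organic-subset atoms that may appear without brackets.
BARE_ATOM = re.compile(r"Br|Cl|[BCNOSPFI]|[bcnosp]")


def element_set(smiles):
    """Collect element symbols in a SMILES string (bracket hydrogen counts ignored)."""
    elements = {m.capitalize() for m in BRACKET_ATOM.findall(smiles)}
    # Remove bracket atoms before scanning the organic-subset symbols.
    stripped = re.sub(r"\[[^\]]*\]", "", smiles)
    elements.update(m.capitalize() for m in BARE_ATOM.findall(stripped))
    return elements


def conserves_atoms(reactants, product):
    """True if every element in the product also occurs in the reactants."""
    return element_set(product).issubset(element_set(reactants))


def filter_candidates(reactants, candidates):
    """Drop beam candidates that introduce elements absent from the input."""
    return [c for c in candidates if conserves_atoms(reactants, c)]


reactants = "CC(=O)Cl.OCC"
candidates = ["CC(=O)OCC", "CC(=O)NCC", "CC(=O)[O-]"]
kept = filter_candidates(reactants, candidates)
```

<p>In this toy case only the candidate that introduces nitrogen is rejected. The finding reported above is that such a constraint leaves the accuracy of the trained model unchanged.</p>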
<h2 id="trade-offs-and-limitations">Trade-offs and Limitations</h2>
<ul>
<li><strong><a href="https://en.wikipedia.org/wiki/Stereochemistry">Stereochemistry</a></strong>: Accuracy drops significantly on USPTO_STEREO (76-78% vs 88-90% on USPTO_MIT), indicating stereochemical prediction remains challenging.</li>
<li><strong>Resolution reactions</strong>: Accuracy is low on resolution reactions (28.6%), where reagent information is often missing from patent data.</li>
<li><strong>Unclassified reactions</strong>: Accuracy on &ldquo;unrecognized&rdquo; reaction classes is 46.3%, likely reflecting noisy or mistranscribed data.</li>
<li><strong>No atom mapping</strong>: The model provides no explicit atom mapping between reactants and products, which limits interpretability for understanding reaction mechanisms.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Primary benchmark</strong></td>
          <td style="text-align: left">USPTO_MIT</td>
          <td style="text-align: left">479K</td>
          <td style="text-align: left">Filtered by Jin et al., no stereochemistry</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>LEF subset</strong></td>
          <td style="text-align: left">USPTO_LEF</td>
          <td style="text-align: left">350K</td>
          <td style="text-align: left">Subset of MIT with linear electron flow only</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stereo benchmark</strong></td>
          <td style="text-align: left">USPTO_STEREO</td>
          <td style="text-align: left">1.0M</td>
          <td style="text-align: left">Patent reactions through Sept 2016, includes stereochemistry</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Time-split test</strong></td>
          <td style="text-align: left">Pistachio_2017</td>
          <td style="text-align: left">15.4K</td>
          <td style="text-align: left">Non-public, reactions from 2017</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>: SMILES canonicalized with RDKit. Regex tokenization from Schwaller et al. (2018). Two input modes: &ldquo;separated&rdquo; (reactants &gt; reagents) and &ldquo;mixed&rdquo; (all molecules concatenated).</p>
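<p>The preprocessing steps can be sketched as follows. The regular expression is the pattern commonly circulated for SMILES tokenization and is assumed here to match the one from Schwaller et al. (2018). Joining molecules with <code>.</code> in the mixed mode, and the helper names, are also assumptions. RDKit canonicalization is omitted so the example runs with the standard library alone:</p>

```python
import re

# Assumed to match the tokenization pattern from Schwaller et al. (2018).
SMILES_TOKEN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def tokenize(smiles):
    """Split a SMILES or reaction SMILES string into tokens."""
    tokens = SMILES_TOKEN.findall(smiles)
    # Guard against characters the pattern does not cover.
    if "".join(tokens) != smiles:
        raise ValueError("untokenizable SMILES: " + smiles)
    return tokens


def build_source(reactants, reagents, mode):
    """Build the model input string for the two input modes.

    'separated' keeps reactants and reagents on either side of '>'.
    'mixed' joins all molecules with '.', so roles must be inferred.
    """
    if mode == "separated":
        return ".".join(reactants) + ">" + ".".join(reagents)
    if mode == "mixed":
        return ".".join(reactants + reagents)
    raise ValueError("unknown mode: " + mode)


reactants = ["CC(=O)Cl", "OCC"]
reagents = ["c1ccncc1"]  # pyridine, used here as an example reagent

separated = build_source(reactants, reagents, "separated")
mixed = build_source(reactants, reagents, "mixed")
separated_tokens = tokenize(separated)
```

<p>The two modes differ only in whether the <code>&gt;</code> separator tells the model which molecules are reagents. In the mixed mode that distinction has to be learned.</p>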
<h3 id="model">Model</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hyperparameter</th>
          <th style="text-align: left">Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Layers</strong></td>
          <td style="text-align: left">4</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Model dimension</strong></td>
          <td style="text-align: left">256</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Attention heads</strong></td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Parameters</strong></td>
          <td style="text-align: left">~12M</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Label smoothing</strong></td>
          <td style="text-align: left">0.0</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Optimizer</strong></td>
          <td style="text-align: left">Adam</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Warm-up steps</strong></td>
          <td style="text-align: left">8000</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Batch size</strong></td>
          <td style="text-align: left">~4096 tokens</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Beam width</strong></td>
          <td style="text-align: left">5</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Key Result</th>
          <th style="text-align: left">Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">USPTO_MIT (sep)</td>
          <td style="text-align: left"><strong>90.4%</strong></td>
          <td style="text-align: left">85.6% (WLDN5)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">USPTO_MIT (mixed)</td>
          <td style="text-align: left"><strong>88.6%</strong></td>
          <td style="text-align: left">80.3% (S2S RNN)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>AUC-ROC</strong></td>
          <td style="text-align: left">Uncertainty calibration</td>
          <td style="text-align: left"><strong>0.89</strong></td>
          <td style="text-align: left">N/A</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">Regioselectivity</td>
          <td style="text-align: left"><strong>83%</strong></td>
          <td style="text-align: left">81% (RegioSQM)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 accuracy</strong></td>
          <td style="text-align: left">Human comparison</td>
          <td style="text-align: left"><strong>87.5%</strong></td>
          <td style="text-align: left">76.5% (best human)</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: Single Nvidia P100 GPU, 48h for best single model</li>
<li>Inference: 20 min for 40K reactions on single P100</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Schwaller, P., Laino, T., Gaudin, T., Bolgar, P., Hunter, C. A., Bekas, C., &amp; Lee, A. A. (2019). Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction. <em>ACS Central Science</em>, 5(9), 1572-1583. <a href="https://doi.org/10.1021/acscentsci.9b00576">https://doi.org/10.1021/acscentsci.9b00576</a></p>
<p><strong>Publication</strong>: ACS Central Science 2019</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{schwallerMolecularTransformerModel2019,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Schwaller, Philippe and Laino, Teodoro and Gaudin, Th{\&#39;e}ophile and Bolgar, Peter and Hunter, Christopher A. and Bekas, Costas and Lee, Alpha A.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2019</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{ACS Central Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1572--1583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/acscentsci.9b00576}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SELFormer: A SELFIES-Based Molecular Language Model</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/selformer/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/selformer/</guid><description>A SELFIES-based RoBERTa model pretrained on 2M ChEMBL molecules for molecular property prediction on MoleculeNet benchmarks.</description><content:encoded><![CDATA[<h2 id="a-selfies-based-chemical-language-model">A SELFIES-Based Chemical Language Model</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$) with a secondary <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>SELFormer applies the RoBERTa transformer architecture to <a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">SELFIES</a> molecular string representations instead of the <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> notation used by prior chemical language models. The model is pretrained via masked language modeling (MLM) on 2M drug-like compounds from <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> and fine-tuned for molecular property prediction tasks on <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> benchmarks. The authors release pretrained models, fine-tuning code, and datasets as open-source resources.</p>
<h2 id="why-selfies-over-smiles-for-pretraining">Why SELFIES Over SMILES for Pretraining?</h2>
<p>Existing chemical language models, including <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/chemberta-2/">ChemBERTa-2</a>, <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>, and <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>, all use SMILES as their input representation. SMILES has well-documented validity and robustness issues: arbitrary perturbations to a SMILES string frequently produce syntactically invalid outputs. This means a pretrained model must spend capacity learning SMILES grammar rules rather than chemical semantics.</p>
<p><a href="/notes/chemistry/molecular-representations/notations/selfies-original-paper/">SELFIES</a> addresses this by construction: every possible SELFIES string decodes to a valid molecule. Despite this theoretical advantage and SELFIES&rsquo; growing adoption in generative chemistry, no prior work had systematically evaluated SELFIES as input for large-scale transformer pretraining. SELFormer fills this gap by providing a direct comparison between SELFIES-based and SMILES-based chemical language models on standard benchmarks.</p>
<h2 id="masked-language-modeling-on-guaranteed-valid-molecular-strings">Masked Language Modeling on Guaranteed-Valid Molecular Strings</h2>
<p>SELFormer uses byte-level Byte-Pair Encoding (BPE) to tokenize SELFIES strings, then pretrains a RoBERTa encoder using the standard MLM objective. 15% of input tokens are masked, and the model minimizes the cross-entropy loss over the masked positions:</p>
<p>$$
\mathcal{L}_{\text{MLM}} = -\frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid x_{\setminus \mathcal{M}}; \theta)
$$</p>
<p>where $\mathcal{M}$ is the set of masked token indices, $x_i$ is the true token at position $i$, $x_{\setminus \mathcal{M}}$ is the corrupted input context, and $\theta$ are the model parameters.</p>
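<p>As a concrete reading of the loss, the sketch below averages the negative log-probability of the true token over the masked positions only. It is a plain-Python illustration with made-up logits, not the training code used by the authors; the function names are hypothetical:</p>

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def mlm_loss(logits, token_ids, masked_positions):
    """Mean negative log-likelihood of the true tokens at masked positions.

    logits: per-position lists of vocabulary scores from the model.
    token_ids: the original (uncorrupted) token id at each position.
    masked_positions: the index set M from the equation above.
    """
    total = 0.0
    for i in masked_positions:
        total -= log_softmax(logits[i])[token_ids[i]]
    return total / len(masked_positions)


# Toy example: 4 positions, vocabulary of 3 tokens, positions 1 and 3 masked.
logits = [
    [2.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 3.0, 0.0],
    [0.0, 0.0, 0.0],
]
token_ids = [0, 2, 1, 0]
loss = mlm_loss(logits, token_ids, masked_positions=[1, 3])
```

<p>With uniform logits at both masked positions the loss equals $\log 3$, the value for a model that guesses uniformly over a three-token vocabulary. Unmasked positions do not contribute.</p>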
<p>The key insight is that because SELFIES guarantees 100% validity, every masked token prediction corresponds to a valid molecular fragment. The model never wastes capacity predicting invalid chemistry. For fine-tuning, a two-layer classification or regression head is added on top of the encoder&rsquo;s output embedding.</p>
<p>Two model sizes were trained. Notably, the larger SELFormer uses fewer attention heads (4) but more hidden layers (12) than SELFormer-Lite (12 heads, 8 layers). This counterintuitive configuration emerged from the authors&rsquo; hyperparameter search over ~100 models, where deeper architectures with fewer heads outperformed wider, shallower ones:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>SELFormer-Lite</th>
          <th>SELFormer</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Attention Heads</td>
          <td>12</td>
          <td>4</td>
      </tr>
      <tr>
          <td>Hidden Layers</td>
          <td>8</td>
          <td>12</td>
      </tr>
      <tr>
          <td>Batch Size</td>
          <td>16</td>
          <td>16</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>5e-5</td>
          <td>5e-5</td>
      </tr>
      <tr>
          <td>Weight Decay</td>
          <td>0.01</td>
          <td>0.01</td>
      </tr>
      <tr>
          <td>Pretraining Epochs</td>
          <td>100</td>
          <td>100</td>
      </tr>
      <tr>
          <td>Parameters</td>
          <td>58.3M</td>
          <td>86.7M</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmarking-against-smiles-transformers-and-graph-models">Benchmarking Against SMILES Transformers and Graph Models</h2>
<p>SELFormer was pretrained on 2.08M drug-like compounds from ChEMBL v30 (converted from SMILES to SELFIES), then fine-tuned on nine MoleculeNet tasks. All evaluations use scaffold splitting via the Chemprop library.</p>
<p><strong>Classification tasks</strong> (ROC-AUC, scaffold split):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BACE</th>
          <th>BBBP</th>
          <th>HIV</th>
          <th>Tox21</th>
          <th>SIDER</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SELFormer</td>
          <td>0.832</td>
          <td><strong>0.902</strong></td>
          <td>0.681</td>
          <td>0.653</td>
          <td><strong>0.745</strong></td>
      </tr>
      <tr>
          <td>ChemBERTa-2</td>
          <td>0.799</td>
          <td>0.728</td>
          <td>0.622</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>MolBERT</td>
          <td><strong>0.866</strong></td>
          <td>0.762</td>
          <td><strong>0.783</strong></td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.809</td>
          <td>0.710</td>
          <td>0.771</td>
          <td>0.759</td>
          <td>0.570</td>
      </tr>
      <tr>
          <td>MolCLR</td>
          <td><strong>0.890</strong></td>
          <td>0.736</td>
          <td><strong>0.806</strong></td>
          <td><strong>0.787</strong></td>
          <td>0.652</td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.856</td>
          <td>0.724</td>
          <td><strong>0.806</strong></td>
          <td>0.781</td>
          <td>0.672</td>
      </tr>
      <tr>
          <td>KPGT</td>
          <td>0.855</td>
          <td><strong>0.908</strong></td>
          <td>-</td>
          <td><strong>0.848</strong></td>
          <td>0.649</td>
      </tr>
  </tbody>
</table>
<p><strong>Regression tasks</strong> (RMSE, scaffold split, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
          <th>PDBbind</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>SELFormer</td>
          <td><strong>0.682</strong></td>
          <td>2.797</td>
          <td>0.735</td>
          <td>1.488</td>
      </tr>
      <tr>
          <td>ChemBERTa-2</td>
          <td>-</td>
          <td>-</td>
          <td>0.986</td>
          <td>-</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>1.050</td>
          <td><strong>2.082</strong></td>
          <td><strong>0.683</strong></td>
          <td><strong>1.397</strong></td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.798</td>
          <td><strong>1.877</strong></td>
          <td>0.660</td>
          <td>-</td>
      </tr>
      <tr>
          <td>KPGT</td>
          <td>0.803</td>
          <td>2.121</td>
          <td><strong>0.600</strong></td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>The ablation study compared SELFormer vs. SELFormer-Lite across pretrained-only, 25-epoch, and 50-epoch fine-tuning configurations on randomly split datasets. SELFormer consistently outperformed SELFormer-Lite, confirming the benefit of the deeper (12-layer) architecture.</p>
<h2 id="strong-classification-performance-with-compact-pretraining">Strong Classification Performance with Compact Pretraining</h2>
<p>SELFormer&rsquo;s strongest results come on classification tasks where molecular substructure matters, along with one regression task (ESOL):</p>
<ul>
<li><strong>SIDER</strong>: Best overall ROC-AUC (0.745), outperforming the closest method in the table above (GEM at 0.672) by 7.3 percentage points. The authors attribute this to SELFIES&rsquo; ability to capture subtle structural differences relevant to drug side effects.</li>
<li><strong>BBBP</strong>: Second best (0.902), behind only KPGT (0.908). SELFormer scored 17.4 percentage points above ChemBERTa-2 (0.728) on this task.</li>
<li><strong>BACE/HIV vs. ChemBERTa-2</strong>: Beyond the BBBP gap, SELFormer outperformed ChemBERTa-2 by 3.3 points on BACE (0.832 vs 0.799) and 5.9 points on HIV (0.681 vs 0.622). Since both models use similar RoBERTa architectures, this comparison is suggestive of a SELFIES advantage, though differences in pretraining corpus (ChEMBL vs PubChem), corpus size, and training procedure confound a clean attribution to the input representation alone.</li>
<li><strong>ESOL regression</strong>: Best RMSE (0.682) vs GEM (0.798), a 14.5% relative improvement.</li>
</ul>
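<p>The margins quoted in this list mix two conventions: absolute differences in percentage points for ROC-AUC and a relative reduction for RMSE. A short check of the arithmetic, using values copied from the tables above (the helper names are illustrative):</p>

```python
def point_gain(ours, theirs):
    """Absolute ROC-AUC difference in percentage points."""
    return round((ours - theirs) * 100, 1)


def relative_rmse_reduction(ours, theirs):
    """Relative RMSE reduction in percent (positive means lower error)."""
    return round((theirs - ours) / theirs * 100, 1)


# SELFormer vs ChemBERTa-2 on classification (ROC-AUC).
bace = point_gain(0.832, 0.799)
bbbp = point_gain(0.902, 0.728)
hiv = point_gain(0.681, 0.622)

# SELFormer vs GEM on ESOL (RMSE).
esol = relative_rmse_reduction(0.682, 0.798)
```

<p>The results reproduce the 3.3, 17.4, and 5.9 point gaps and the 14.5% relative improvement on ESOL.</p>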
<p>Limitations are also apparent:</p>
<ul>
<li><strong>HIV and Tox21</strong>: SELFormer underperforms graph-based methods on these larger datasets (MolCLR and GEM on both, KPGT on Tox21). The authors attribute this to insufficient hyperparameter search given computational constraints.</li>
<li><strong>FreeSolv and Lipophilicity regression</strong>: D-MPNN and graph-based methods maintain an edge, suggesting that explicit 2D/3D structural inductive biases remain valuable for certain property types.</li>
<li><strong>Small pretraining corpus</strong>: At 2M molecules, SELFormer&rsquo;s corpus is orders of magnitude smaller than MoLFormer&rsquo;s 1.1B. Despite this, SELFormer outperforms MoLFormer on SIDER (0.745 vs 0.690), which suggests a representational advantage for SELFIES on that task.</li>
<li><strong>Single-task ablation scope</strong>: Some architectural claims rest on limited task coverage, and broader benchmarking would strengthen the conclusions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>ChEMBL v30</td>
          <td>2,084,725 compounds (2,084,472 after SELFIES conversion)</td>
          <td>Drug-like bioactive small molecules</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BACE</td>
          <td>1,513</td>
          <td><a href="https://en.wikipedia.org/wiki/Beta-secretase_1">Beta-secretase 1</a> inhibitor binding</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP</td>
          <td>2,039</td>
          <td><a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">Blood-brain barrier</a> permeability</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>HIV</td>
          <td>41,127</td>
          <td>HIV replication inhibition</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>SIDER</td>
          <td>1,427</td>
          <td>Drug side effects (27 classes)</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>Tox21</td>
          <td>7,831</td>
          <td>Toxicity (12 targets)</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>ESOL</td>
          <td>1,128</td>
          <td>Aqueous solubility</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>FreeSolv</td>
          <td>642</td>
          <td>Hydration free energy</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>Lipophilicity</td>
          <td>4,200</td>
          <td>Octanol/water distribution coefficient</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>PDBbind</td>
          <td>11,908</td>
          <td>Binding affinity</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining objective</strong>: Masked language modeling (MLM), 15% token masking</li>
<li><strong>Tokenization</strong>: Byte-level Byte-Pair Encoding (BPE) on SELFIES strings</li>
<li><strong>SMILES to SELFIES conversion</strong>: SELFIES API with Pandarallel for parallelization</li>
<li><strong>Splitting</strong>: Scaffold splitting via Chemprop library (80/10/10 train/validation/test)</li>
<li><strong>Fine-tuning</strong>: Two-layer classification/regression head on encoder output; up to 200 epochs with hyperparameter search</li>
</ul>
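<p>Scaffold splitting keeps all molecules that share a core scaffold in the same partition, so test molecules are structurally distinct from training molecules. The sketch below shows a generic greedy variant, not necessarily the exact procedure in Chemprop. Scaffold keys are assumed to be precomputed strings (for example Bemis-Murcko scaffold SMILES from a cheminformatics toolkit), and the function name is hypothetical:</p>

```python
from collections import defaultdict


def scaffold_split(scaffolds, fractions=(0.8, 0.1, 0.1)):
    """Assign whole scaffold groups to train/validation/test.

    scaffolds: one precomputed scaffold key per molecule (an assumption).
    Returns three lists of molecule indices.
    """
    groups = defaultdict(list)
    for index, key in enumerate(scaffolds):
        groups[key].append(index)

    # Largest groups first, so rare scaffolds tend to land in validation/test.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n = len(scaffolds)
    train_cap = fractions[0] * n
    valid_cap = (fractions[0] + fractions[1]) * n

    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) > train_cap:
            if len(train) + len(valid) + len(group) > valid_cap:
                test.extend(group)
            else:
                valid.extend(group)
        else:
            train.extend(group)
    return train, valid, test


# Ten molecules over four scaffolds (placeholder keys).
keys = ["A"] * 6 + ["B"] * 2 + ["C", "D"]
train, valid, test = scaffold_split(keys)
```

<p>Because groups are never divided, the realized proportions only approximate 80/10/10 when scaffold groups are large relative to the dataset.</p>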
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: RoBERTa (HuggingFace Transformers)</li>
<li><strong>SELFormer</strong>: 12 hidden layers, 4 attention heads, 86.7M parameters</li>
<li><strong>SELFormer-Lite</strong>: 8 hidden layers, 12 attention heads, 58.3M parameters</li>
<li><strong>Hyperparameter search</strong>: Sequential search over ~100 configurations on 100K molecule subset</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Area under receiver operating characteristic curve</td>
      </tr>
      <tr>
          <td>PRC-AUC</td>
          <td>Classification</td>
          <td>Area under precision-recall curve (reported for random splits)</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression</td>
          <td>Root mean squared error</td>
      </tr>
  </tbody>
</table>
<p>Results are reported on both scaffold-split and random-split datasets.</p>
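<p>For reference, two of the metrics in the table can be computed directly from their definitions. This is a dependency-free sketch for small inputs; the toy scores and labels are made up, and a library implementation would normally be used in practice:</p>

```python
import math


def roc_auc(scores, labels):
    """ROC-AUC as the probability that a positive outranks a negative.

    Ties count as half a win. Quadratic in dataset size, fine for a sketch.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


def rmse(predictions, targets):
    """Root mean squared error."""
    squared = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared) / len(squared))


auc = roc_auc([0.9, 0.8, 0.35, 0.1], [1, 0, 1, 0])
error = rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```

<p>ROC-AUC depends only on the ranking of scores, so it is unaffected by any monotonic rescaling of model outputs. RMSE is in the units of the target property.</p>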
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 2x NVIDIA A5000 GPUs</li>
<li><strong>Hyperparameter optimization time</strong>: ~11 days</li>
<li><strong>Full pretraining</strong>: 100 epochs on 2.08M molecules</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/HUBioDataLab/SELFormer">SELFormer GitHub</a></td>
          <td>Code</td>
          <td>GPL-3.0</td>
          <td>Pretraining, fine-tuning, and evaluation scripts</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/HUBioDataLab/SELFormer">SELFormer on HuggingFace</a></td>
          <td>Model</td>
          <td>GPL-3.0</td>
          <td>Pretrained SELFormer weights</td>
      </tr>
      <tr>
          <td><a href="https://www.ebi.ac.uk/chembl/">ChEMBL v30</a></td>
          <td>Dataset</td>
          <td>CC BY-SA 3.0</td>
          <td>Source pretraining data</td>
      </tr>
      <tr>
          <td><a href="https://moleculenet.org/">MoleculeNet</a></td>
          <td>Benchmark</td>
          <td>Unknown</td>
          <td>Downstream evaluation tasks</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yüksel, A., Ulusoy, E., Ünlü, A., &amp; Doğan, T. (2023). SELFormer: Molecular Representation Learning via SELFIES Language Models. <em>Machine Learning: Science and Technology</em>, 4(2), 025035. <a href="https://doi.org/10.1088/2632-2153/acdb30">https://doi.org/10.1088/2632-2153/acdb30</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/HUBioDataLab/SELFormer">GitHub Repository (SELFormer)</a></li>
<li><a href="https://huggingface.co/HUBioDataLab/SELFormer">HuggingFace Model Hub (SELFormer)</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yuksel2023selformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{SELFormer}: Molecular Representation Learning via {SELFIES} Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Y{\&#34;u}ksel, Atakan and Ulusoy, Erva and {\&#34;U}nl{\&#34;u}, Atabey and Do{\u{g}}an, Tunca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{025035}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1088/2632-2153/acdb30}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MoLFormer: Large-Scale Chemical Language Representations</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molformer/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/molformer/</guid><description>A linear-attention transformer pretrained on 1.1B SMILES from PubChem and ZINC for molecular property prediction across MoleculeNet benchmarks.</description><content:encoded><![CDATA[<h2 id="a-billion-scale-chemical-language-model">A Billion-Scale Chemical Language Model</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>MoLFormer is a transformer encoder pretrained via masked language modeling on 1.1 billion <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> and <a href="https://en.wikipedia.org/wiki/ZINC_database">ZINC</a>. The key architectural choices are linear attention (for $O(N)$ complexity instead of $O(N^2)$) and rotary positional embeddings (RoPE). The resulting model, MoLFormer-XL, produces molecular embeddings that outperform or match GNN baselines across a wide range of <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification and regression tasks, including quantum-chemical property prediction from SMILES alone.</p>
<h2 id="bridging-the-gap-between-molecular-languages-and-graph-neural-networks">Bridging the Gap Between Molecular Languages and Graph Neural Networks</h2>
<p>Prior chemical language models like <a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> were pretrained on relatively small datasets (10M-77M molecules) and generally underperformed GNNs on molecular property prediction. The core question: does a transformer trained on a sufficiently large SMILES corpus learn enough chemical structure to compete with graph-based methods that have explicit topological inductive biases?</p>
<p>Two specific challenges motivated this work:</p>
<ul>
<li><strong>Scale</strong>: The chemical space spans $10^{60}$ to $10^{100}$ plausible molecules, yet labeled property data is scarce. Self-supervised pretraining on the ~1.1B unlabeled molecules available in public databases could provide a general-purpose representation.</li>
<li><strong>Efficiency</strong>: Standard transformer attention is $O(N^2)$ in sequence length, making billion-scale pretraining impractical without architectural modifications.</li>
</ul>
<h2 id="linear-attention-with-rotary-positional-embeddings">Linear Attention with Rotary Positional Embeddings</h2>
<p>MoLFormer&rsquo;s two key architectural choices are its attention mechanism and positional encoding scheme.</p>
<p><strong>Standard attention</strong> computes:</p>
<p>$$
\text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^{N} \exp(\langle q_m, k_n \rangle) v_n}{\sum_{n=1}^{N} \exp(\langle q_m, k_n \rangle)}
$$</p>
<p>MoLFormer replaces this with <strong>linear attention</strong> using a generalized feature map $\varphi$, combined with <strong>rotary positional embeddings</strong> $R_m$ applied before the feature map:</p>
<p>$$
\text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle v_n}{\sum_{n=1}^{N} \langle \varphi(R_m q_m), \varphi(R_n k_n) \rangle}
$$</p>
<p>This differs from the original RoFormer formulation, which applies the rotation after the feature map. The authors found that rotating the raw queries and keys before applying the feature map led to faster convergence and lower validation loss. The combination of linear attention and adaptive sequence-length bucketing reduces GPU requirements from ~1000 to 16 for training on the full 1.1B corpus.</p>
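<p>The linear form is cheaper because the sums over keys do not depend on the query and can be computed once. The sketch below evaluates the rotate-then-feature-map formula both ways and is meant only to illustrate that factorization. The feature map <code>elu(x) + 1</code> is a stand-in assumption rather than the generalized feature map used in the paper, and all function names are hypothetical:</p>

```python
import math


def rotate(x, position, base=10000.0):
    """Apply the rotary matrix R_position to an even-length vector x."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out.extend([x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c])
    return out


def phi(x):
    """Positive feature map elu(x) + 1 (a stand-in, not the paper's map)."""
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


def features(X):
    # Rotate first, then apply the feature map, as in the equation above.
    return [phi(rotate(x, m)) for m, x in enumerate(X, start=1)]


def linear_attention(Q, K, V):
    """O(N) evaluation: summarize keys and values once, reuse per query."""
    fq, fk = features(Q), features(K)
    d_phi, d_v = len(fk[0]), len(V[0])
    kv = [[sum(f[i] * v[j] for f, v in zip(fk, V)) for j in range(d_v)]
          for i in range(d_phi)]
    k_sum = [sum(f[i] for f in fk) for i in range(d_phi)]
    return [
        [sum(f[i] * kv[i][j] for i in range(d_phi)) / dot(f, k_sum)
         for j in range(d_v)]
        for f in fq
    ]


def quadratic_attention(Q, K, V):
    """Direct O(N^2) evaluation of the same formula, for comparison."""
    fq, fk = features(Q), features(K)
    out = []
    for f in fq:
        weights = [dot(f, g) for g in fk]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out


Q = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.1, -0.3, 0.2], [-0.4, 0.2, 0.1, 0.0]]
K = [[0.2, 0.1, -0.1, 0.3], [-0.2, 0.4, 0.0, 0.1], [0.3, -0.3, 0.2, 0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

fast = linear_attention(Q, K, V)
slow = quadratic_attention(Q, K, V)
```

<p>Both functions return the same values up to floating-point error. The linear version builds the key-value summary in one pass over the sequence, whereas the quadratic version forms every query-key pair.</p>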
<p>The model uses masked language modeling (15% token masking, following BERT conventions) with a vocabulary of 2,362 SMILES tokens. Sequence length is capped at 202 tokens, covering 99.4% of all molecules.</p>
<h2 id="broad-moleculenet-benchmarking-with-scaling-ablations">Broad MoleculeNet Benchmarking with Scaling Ablations</h2>
<p>MoLFormer-XL was evaluated on 11 MoleculeNet tasks against supervised GNNs, self-supervised GNNs, and prior language models.</p>
<p><strong>Classification tasks</strong> (ROC-AUC, scaffold split; values reported as percentages in the original paper, converted to proportions here for consistency):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP</th>
          <th>Tox21</th>
          <th>ClinTox</th>
          <th>HIV</th>
          <th>BACE</th>
          <th>SIDER</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer-XL</td>
          <td><strong>0.937</strong></td>
          <td><strong>0.847</strong></td>
          <td><strong>0.948</strong></td>
          <td>0.822</td>
          <td>0.882</td>
          <td><strong>0.690</strong></td>
      </tr>
      <tr>
          <td>N-Gram</td>
          <td>0.912</td>
          <td>0.769</td>
          <td>0.855</td>
          <td>0.830</td>
          <td>0.876</td>
          <td>0.632</td>
      </tr>
      <tr>
          <td>MolCLR</td>
          <td>0.736</td>
          <td>0.798</td>
          <td>0.932</td>
          <td>0.806</td>
          <td><strong>0.890</strong></td>
          <td>0.680</td>
      </tr>
      <tr>
          <td>GEM</td>
          <td>0.724</td>
          <td>0.781</td>
          <td>0.901</td>
          <td>0.806</td>
          <td>0.856</td>
          <td>0.672</td>
      </tr>
      <tr>
          <td>Hu et al.</td>
          <td>0.708</td>
          <td>0.787</td>
          <td>0.789</td>
          <td>0.802</td>
          <td>0.859</td>
          <td>0.652</td>
      </tr>
      <tr>
          <td>GeomGCL</td>
          <td>-</td>
          <td>0.850</td>
          <td>0.919</td>
          <td>-</td>
          <td>-</td>
          <td>0.648</td>
      </tr>
      <tr>
          <td>ChemBERTa</td>
          <td>0.643</td>
          <td>-</td>
          <td>0.906</td>
          <td>0.622</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p><strong>Regression tasks</strong> (RMSE for ESOL/FreeSolv/Lipophilicity, avg MAE for QM9/QM8):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>QM9</th>
          <th>QM8</th>
          <th>ESOL</th>
          <th>FreeSolv</th>
          <th>Lipophilicity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MoLFormer-XL</td>
          <td><strong>1.5894</strong></td>
          <td><strong>0.0102</strong></td>
          <td><strong>0.2787</strong></td>
          <td><strong>0.2308</strong></td>
          <td><strong>0.5289</strong></td>
      </tr>
      <tr>
          <td>A-FP</td>
          <td>2.6355</td>
          <td>0.0282</td>
          <td>0.5030</td>
          <td>0.736</td>
          <td>0.578</td>
      </tr>
      <tr>
          <td>MPNN</td>
          <td>3.1898</td>
          <td>0.0143</td>
          <td>0.58</td>
          <td>1.150</td>
          <td>0.7190</td>
      </tr>
      <tr>
          <td>GC</td>
          <td>4.3536</td>
          <td>0.0148</td>
          <td>0.970</td>
          <td>1.40</td>
          <td>0.655</td>
      </tr>
  </tbody>
</table>
<p>MoLFormer-XL also outperforms geometry-aware GNNs (DimeNet, GeomGCL, GEM) on ESOL (0.279 vs 0.575), FreeSolv (0.231 vs 0.866), and Lipophilicity (0.529 vs 0.541).</p>
<p><strong>Key ablation findings</strong>:</p>
<ul>
<li><strong>Data scale matters</strong>: Performance improves monotonically from 10% subsets through the full 1.1B corpus. Training on 100% ZINC alone performed worst, likely due to its smaller vocabulary and less diverse molecule lengths.</li>
<li><strong>Model depth matters</strong>: MoLFormer-Base (6 layers) underperforms MoLFormer-XL (12 layers) on most tasks.</li>
<li><strong>Fine-tuning &gt;&gt; frozen</strong>: Fine-tuning the full encoder consistently outperforms using frozen embeddings with a downstream classifier.</li>
<li><strong>Rotary &gt; absolute at scale</strong>: Rotary embeddings underperform absolute embeddings on smaller pretraining sets but overtake them once the corpus exceeds 1B molecules.</li>
</ul>
<h2 id="smiles-transformers-learn-molecular-geometry">SMILES Transformers Learn Molecular Geometry</h2>
<p>The most striking finding is that MoLFormer&rsquo;s attention patterns correlate with 3D interatomic distances, despite training only on 1D SMILES strings.</p>
<p>Using <a href="/notes/chemistry/datasets/qm9/">QM9</a> molecules with known 3D geometries, the authors computed cosine similarity between attention maps and spatial distance matrices across three distance categories:</p>
<table>
  <thead>
      <tr>
          <th>Distance Category</th>
          <th>Range</th>
          <th>Linear Attention (Rotary)</th>
          <th>Full Attention (Rotary)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Short</td>
          <td>$\leq$ 2 Å</td>
          <td>0.594-0.602</td>
          <td>0.598-0.615</td>
      </tr>
      <tr>
          <td>Medium</td>
          <td>2-4 Å</td>
          <td>0.724-0.730</td>
          <td>0.716-0.727</td>
      </tr>
      <tr>
          <td>Long</td>
          <td>4-10 Å</td>
          <td>0.209-0.211</td>
          <td>0.204-0.210</td>
      </tr>
  </tbody>
</table>
<p>The high cosine similarity in the short and medium categories suggests the model captures covalent bond connectivity and near-neighbor spatial relationships. Linear attention shows marginally higher cosine similarity than full attention on medium-range distances (0.724-0.730 vs 0.716-0.727), though the differences are small.</p>
<p>MoLFormer-XL embeddings also correlate more strongly with molecular fingerprint similarity (0.64 vs 0.48 for ChemBERTa) and maximum common subgraph size (-0.60 vs -0.44), confirming that the representations encode structural information.</p>
<p><strong>Limitations</strong>:</p>
<ul>
<li><strong>Quantum-chemical energies</strong>: SchNet and DimeNet (which encode explicit 3D geometry) outperform MoLFormer-XL on QM9 atomization energy tasks, with DimeNet achieving roughly 10x lower MAE on U0_atom (0.008 vs 0.083 eV). 3D information remains important for these properties.</li>
<li><strong>Sequence length cap</strong>: The 202-token limit excludes 0.6% of molecules, potentially limiting applicability to larger structures.</li>
<li><strong>SMILES canonicalization</strong>: The model depends on <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> canonical SMILES; sensitivity to non-canonical forms is not evaluated.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>PubChem</td>
          <td>111M molecules</td>
          <td>Canonical SMILES via RDKit</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>ZINC</td>
          <td>~1B molecules</td>
          <td>Canonical SMILES via RDKit</td>
      </tr>
      <tr>
          <td>Pretraining (combined)</td>
          <td>PubChem + ZINC</td>
          <td>~1.1B molecules</td>
          <td>MoLFormer-XL training set</td>
      </tr>
      <tr>
          <td>Classification</td>
          <td>BBBP, Tox21, ClinTox, HIV, BACE, SIDER</td>
          <td>1,427-41,127</td>
          <td>MoleculeNet scaffold splits</td>
      </tr>
      <tr>
          <td>Regression</td>
          <td>QM9, QM8, ESOL, FreeSolv, Lipophilicity</td>
          <td>642-133,885</td>
          <td>MoleculeNet random splits (QM9/QM8), scaffold (others)</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pretraining objective</strong>: Masked language modeling (15% selection: 80% masked, 10% random, 10% unchanged)</li>
<li><strong>Tokenization</strong>: SMILES tokenizer from Schwaller et al., vocabulary of 2,362 tokens</li>
<li><strong>Sequence length</strong>: 1-202 tokens (99.4% coverage)</li>
<li><strong>Optimizer</strong>: Fused LAMB (via APEX), chosen for stability with large batch sizes and no need for learning rate warm-up</li>
<li><strong>Adaptive bucketing</strong>: Sequences grouped by length into buckets to minimize padding waste</li>
</ul>
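<p>The adaptive bucketing idea can be sketched as follows. This is a minimal illustration, assuming fixed-width length buckets; the bucket width and function names are hypothetical, and the actual bucketing scheme in the MoLFormer code may differ.</p>

```python
def bucket_by_length(lengths, bucket_width=10):
    """Group sequence indices so that each bucket spans at most bucket_width lengths."""
    buckets = {}
    for idx, n in enumerate(lengths):
        buckets.setdefault((n - 1) // bucket_width, []).append(idx)
    return list(buckets.values())

def padding_tokens(lengths, groups):
    """Count pad tokens needed when every group is padded to its own longest member."""
    total = 0
    for members in groups:
        longest = max(lengths[i] for i in members)
        total += sum(longest - lengths[i] for i in members)
    return total

# Toy token counts for six molecules (capped at 202 in the paper).
lengths = [5, 198, 12, 201, 7, 15]
one_batch = padding_tokens(lengths, [list(range(len(lengths)))])
bucketed = padding_tokens(lengths, bucket_by_length(lengths))
```

<p>On this toy input, padding everything to the longest sequence wastes far more positions than padding within buckets, which is the saving the technique targets.</p>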
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Transformer encoder with linear attention and rotary positional embeddings</li>
<li><strong>MoLFormer-XL</strong>: 12 layers, 12 attention heads, hidden size 768</li>
<li><strong>MoLFormer-Base</strong>: 6 layers (ablation only)</li>
<li><strong>Feature map size</strong>: 32 (generalized feature map for linear attention)</li>
<li><strong>Frozen head</strong>: Fully connected model with hyperparameter sweep (learning rate, batch size, hidden dim, number of layers)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task Type</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ROC-AUC</td>
          <td>Classification</td>
          <td>Scaffold splits per MoleculeNet</td>
      </tr>
      <tr>
          <td>RMSE</td>
          <td>Regression (ESOL, FreeSolv, Lipophilicity)</td>
          <td>Scaffold splits</td>
      </tr>
      <tr>
          <td>Avg MAE</td>
          <td>Regression (QM9, QM8)</td>
          <td>Random splits per MoleculeNet</td>
      </tr>
  </tbody>
</table>
<p>QM9 results are also reported with 5-fold cross-validation for robustness.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: GPU cluster with nodes containing either 8 NVIDIA Tesla V100 (32GB) or 8 Ampere A100 (40GB) GPUs connected via NVLink and InfiniBand</li>
<li><strong>GPU reduction</strong>: Linear attention + bucketing reduced GPU requirements from ~1000 to 16</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/molformer">IBM/molformer</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Pretraining, fine-tuning, and attention visualization</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm/MoLFormer-XL-both-10pct">MoLFormer-XL (HuggingFace)</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Pretrained weights (46.8M parameters); this checkpoint was pretrained on 10% of both PubChem and ZINC</td>
      </tr>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td>Dataset</td>
          <td>Public domain</td>
          <td>111M molecules</td>
      </tr>
      <tr>
          <td><a href="https://zinc.docking.org/">ZINC</a></td>
          <td>Dataset</td>
          <td>See ZINC terms</td>
          <td>~1B molecules</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ross, J., Belgodere, B., Chenthamarakshan, V., Padhi, I., Mroueh, Y., &amp; Das, P. (2022). Large-Scale Chemical Language Representations Capture Molecular Structure and Properties. <em>Nature Machine Intelligence</em>, 4, 1256-1264. <a href="https://doi.org/10.1038/s42256-022-00580-7">https://doi.org/10.1038/s42256-022-00580-7</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IBM/molformer">GitHub Repository (MoLFormer)</a></li>
<li><a href="https://huggingface.co/ibm/MoLFormer-XL-both-10pct">HuggingFace Models</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ross2022molformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Large-Scale Chemical Language Representations Capture Molecular Structure and Properties}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ross, Jerret and Belgodere, Brian and Chenthamarakshan, Vijil and Padhi, Inkit and Mroueh, Youssef and Das, Payel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1256--1264}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1038/s42256-022-00580-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Exposing Limitations of Molecular ML with Activity Cliffs</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/</link><pubDate>Mon, 16 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/property-prediction/activity-cliffs-benchmark/</guid><description>A benchmark of 24 ML methods on activity cliff compounds across 30 drug targets, showing descriptor-based models outperform deep learning.</description><content:encoded><![CDATA[<h2 id="a-benchmark-for-activity-cliff-prediction">A Benchmark for Activity Cliff Prediction</h2>
<p>This is a <strong>Systematization</strong> paper ($\Psi_{\text{Systematization}}$) with a significant <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>The paper systematically benchmarks 24 machine learning and deep learning approaches on their ability to predict bioactivity for activity cliff compounds: pairs of structurally similar molecules that exhibit large differences in potency. These cases violate the similarity principle (similar structure implies similar activity) and represent a practical failure mode for <a href="/notes/chemistry/molecular-design/property-prediction/">molecular property prediction</a> in drug discovery. The authors release MoleculeACE, an open-source benchmarking platform for evaluating ML models on activity cliffs.</p>
<h2 id="activity-cliffs-as-a-blind-spot-in-molecular-ml">Activity Cliffs as a Blind Spot in Molecular ML</h2>
<p>The <a href="https://en.wikipedia.org/wiki/Chemical_similarity">similarity principle</a> underpins most molecular ML: structurally similar compounds should have similar properties. Activity cliffs are the exceptions, where small structural changes cause large potency shifts (e.g., a single substituent change causing a 10x difference in $K_i$).</p>
<p>Despite their importance for <a href="https://en.wikipedia.org/wiki/Hit_to_lead">hit-to-lead optimization</a>, activity cliffs have received limited attention in ML benchmarking. Standard metrics like RMSE computed over entire test sets can mask poor predictions on cliff compounds. A model might achieve low overall error while systematically mispredicting these edge cases, which are precisely the molecules that matter most for medicinal chemistry applications.</p>
<p>The authors identify 7-52% of compounds as activity cliff molecules across their 30 target datasets, showing this is not a rare phenomenon.</p>
<h2 id="defining-and-detecting-activity-cliffs">Defining and Detecting Activity Cliffs</h2>
<p>The authors use three complementary similarity metrics to identify activity cliffs:</p>
<ol>
<li><strong>Substructure similarity</strong>: <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto coefficient</a> on extended connectivity fingerprints (ECFPs), capturing shared radial substructures</li>
<li><strong>Scaffold similarity</strong>: Tanimoto coefficient on ECFPs computed from molecular graph frameworks, detecting core/decoration differences</li>
<li><strong>SMILES similarity</strong>: <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a> on canonical <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, capturing character-level insertions, deletions, and translocations</li>
</ol>
<p>Pairs with $\geq 90\%$ similarity on <strong>any one</strong> of the three metrics and $&gt; 10\times$ difference in bioactivity ($K_i$ or $\text{EC}_{50}$) are classified as activity cliff pairs. This union-based approach (rather than requiring agreement across all metrics) captures different types of structural relationships relevant to medicinal chemistry.</p>
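<p>The union-based rule can be sketched in plain Python. The assumptions here are loud ones: fingerprints are modeled as sets of on-bit indices rather than real ECFPs, the Levenshtein distance is turned into a similarity by dividing by the longer string length (the normalization is not specified in this note), and all function names are illustrative rather than taken from MoleculeACE.</p>

```python
def tanimoto(a, b):
    """Tanimoto (Jaccard) coefficient between two sets of fingerprint on-bits."""
    a, b = set(a), set(b)
    union = len(a.union(b))
    return len(a.intersection(b)) / union if union else 1.0

def levenshtein(s, t):
    """Classic dynamic-programming edit distance between two strings."""
    prev = list(range(len(t) + 1))
    for i, cs in enumerate(s, start=1):
        cur = [i]
        for j, ct in enumerate(t, start=1):
            cur.append(min(prev[j] + 1,                 # deletion
                           cur[j - 1] + 1,              # insertion
                           prev[j - 1] + (cs != ct)))   # substitution
        prev = cur
    return prev[-1]

def smiles_similarity(s, t):
    """Edit distance rescaled to [0, 1]; the normalization is an assumption."""
    longest = max(len(s), len(t))
    return 1.0 - levenshtein(s, t) / longest if longest else 1.0

def is_cliff_pair(similarities, potency_a, potency_b, sim_cutoff=0.9, fold_cutoff=10.0):
    """True if ANY similarity passes the cutoff and potency differs by more than 10x."""
    similar = any(s >= sim_cutoff for s in similarities)
    fold = max(potency_a, potency_b) / min(potency_a, potency_b)
    return similar and fold > fold_cutoff

# Hypothetical pair: similar on one metric only, with a 20-fold difference in K_i (nM).
example = is_cliff_pair([0.95, 0.20, 0.30], 1.0, 20.0)
```

<p>Using <code>any</code> rather than <code>all</code> is what makes the definition a union over the three metrics.</p>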
<h2 id="24-methods-across-30-drug-targets">24 Methods Across 30 Drug Targets</h2>
<p>The benchmark evaluates 16 traditional ML configurations (4 algorithms $\times$ 4 descriptor types) and 8 deep learning approaches across 30 curated <a href="https://en.wikipedia.org/wiki/ChEMBL">ChEMBL</a> v29 datasets (48,707 total molecules).</p>
<p><strong>Traditional ML algorithms</strong>: KNN, RF, GBM, SVM, each combined with ECFPs, MACCS keys, WHIM descriptors, or physicochemical properties.</p>
<p><strong>Deep learning methods</strong>: MPNN, GCN, GAT, Attentive FP (graph-based), plus LSTM, CNN, Transformer/<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a> (SMILES-based), and an MLP on ECFPs.</p>
<p>Performance is measured with both standard RMSE and a dedicated $\text{RMSE}_{\text{cliff}}$ computed only on activity cliff compounds in the test set:</p>
<p>$$
\text{RMSE}_{\text{cliff}} = \sqrt{\frac{\sum_{j=1}^{n_c} (\hat{y}_j - y_j)^2}{n_c}}
$$</p>
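<p>As a small worked example of the formula above, the sketch below computes RMSE over all test molecules and over the cliff subset only. The numbers and the <code>is_cliff</code> flags are made up for illustration.</p>

```python
import math

def rmse(y_true, y_pred, keep=None):
    """RMSE over all molecules, or only over those where keep[i] is True."""
    if keep is None:
        keep = [True] * len(y_true)
    sq = [(p - t) ** 2 for t, p, k in zip(y_true, y_pred, keep) if k]
    if not sq:
        raise ValueError("no molecules selected")
    return math.sqrt(sum(sq) / len(sq))

# Hypothetical pK_i values: two of the four test molecules are cliff compounds.
y_true = [5.0, 6.0, 7.0, 8.0]
y_pred = [5.0, 7.0, 7.0, 6.0]
is_cliff = [False, True, False, True]
overall = rmse(y_true, y_pred)
cliff = rmse(y_true, y_pred, is_cliff)
```

<p>Here the errors sit entirely on the cliff compounds, so the cliff-only score is noticeably worse than the overall score, which is the masking effect the dedicated metric is designed to reveal.</p>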
<p>Key results:</p>
<ul>
<li><strong>Molecular descriptors matter more than algorithms</strong>: The choice of descriptor (ECFPs vs. MACCS vs. WHIM vs. physicochemical) had a larger impact on $\text{RMSE}_{\text{cliff}}$ than the choice of ML algorithm ($p &lt; 0.05$, <a href="https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test">Wilcoxon rank-sum test</a> with <a href="https://en.wikipedia.org/wiki/False_discovery_rate">Benjamini-Hochberg correction</a>).</li>
<li><strong>SVM + ECFPs wins on average</strong>: The best overall method for activity cliff prediction, though the difference from RF + ECFPs or GBM + ECFPs was not statistically significant.</li>
<li><strong>Deep learning underperforms</strong>: All graph and SMILES-based deep learning methods performed worse than a simple MLP on ECFPs. Among deep learning, LSTM with transfer learning (pretrained on 36K molecules) was the best, outperforming the ChemBERTa transformer pretrained on 10M compounds.</li>
<li><strong>Large case-by-case variation</strong>: $\text{RMSE}_{\text{cliff}}$ ranged from 0.62 to 1.60 log units across datasets, with no method consistently best. Deep learning methods showed the highest variance across targets.</li>
</ul>
<h2 id="simple-descriptors-beat-complex-architectures-on-cliffs">Simple Descriptors Beat Complex Architectures on Cliffs</h2>
<p>The core finding is that activity cliffs expose a gap in learned molecular representations. Although graph neural networks and transformers can learn directly from molecular structure, they fail to capture the subtle structural differences that drive activity cliffs.</p>
<p>Key observations:</p>
<ul>
<li><strong>RMSE and $\text{RMSE}_{\text{cliff}}$ correlate ($r = 0.81$ on average)</strong>, so optimizing overall error usually helps with cliffs too. But this correlation breaks down for some targets (e.g., CLK4), where methods with similar RMSE can have very different $\text{RMSE}_{\text{cliff}}$.</li>
<li><strong>Training set size matters for the RMSE/$\text{RMSE}_{\text{cliff}}$ correlation</strong>: Datasets with $&gt; 1000$ training molecules show $r &gt; 0.80$ between the two metrics. In low-data regimes, the correlation weakens, making dedicated cliff evaluation more important.</li>
<li><strong>No relationship between % cliff compounds and model performance</strong>, and no target-family-specific effects were found.</li>
<li><strong>Transfer learning helped SMILES models (LSTM) but not graph models</strong>: Self-supervised pretraining strategies (context prediction, infomax, edge prediction, masking) did not improve GNN performance, consistent with findings from other studies.</li>
</ul>
<p>The MoleculeACE platform provides standardized data curation, activity cliff detection, and cliff-specific evaluation, enabling researchers to assess new methods against this benchmark.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Source</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Benchmarking</td>
          <td>ChEMBL v29</td>
          <td>48,707 molecules (35,632 unique) across 30 targets</td>
          <td>Curated for duplicates, salts, outliers</td>
      </tr>
      <tr>
          <td>Smallest dataset</td>
          <td>JAK1</td>
          <td>615 molecules</td>
          <td>7% activity cliffs</td>
      </tr>
      <tr>
          <td>Largest dataset</td>
          <td>DRD3</td>
          <td>3,657 molecules</td>
          <td>39% activity cliffs</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Activity cliff detection</strong>: Pairwise similarity $\geq 0.9$ (Tanimoto on ECFPs, scaffold ECFPs, or Levenshtein on SMILES) with $&gt; 10\times$ potency difference</li>
<li><strong>Splitting</strong>: <a href="https://en.wikipedia.org/wiki/Spectral_clustering">Spectral clustering</a> on ECFPs (5 clusters), 80/20 stratified split preserving cliff proportion</li>
<li><strong>Hyperparameter optimization</strong>: <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">Bayesian optimization</a> with Gaussian process, max 50 combinations, 5-fold cross-validation</li>
<li><strong>SMILES augmentation</strong>: 10-fold for all SMILES-based methods</li>
<li><strong>Transfer learning</strong>: LSTM pretrained on 36,281 merged training molecules (next-character prediction); ChemBERTa pretrained on 10M <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> compounds</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Traditional ML</strong>: KNN, RF, GBM, SVM (scikit-learn v1.0.2)</li>
<li><strong>Descriptors</strong>: ECFPs (1024-bit, radius 2), MACCS keys (166-bit), WHIM (114 descriptors), physicochemical (11 properties)</li>
<li><strong>GNNs</strong>: MPNN, GCN, GAT, AFP (PyTorch Geometric v2.0.4), with graph multiset transformer pooling</li>
<li><strong>SMILES models</strong>: LSTM (4 layers, 5.8M params), 1D CNN, ChemBERTa transformer</li>
<li><strong>Total models trained</strong>: 720 (24 methods $\times$ 30 targets)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RMSE</td>
          <td>All test molecules</td>
          <td>Standard root-mean-square error on $\text{pK}_i$ / $\text{pEC}_{50}$</td>
      </tr>
      <tr>
          <td>$\text{RMSE}_{\text{cliff}}$</td>
          <td>Activity cliff compounds only</td>
          <td>RMSE restricted to cliff molecules in test set</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/molML/MoleculeACE">MoleculeACE</a></td>
          <td>Code + Data</td>
          <td>MIT</td>
          <td>Benchmark platform with all 30 curated datasets</td>
      </tr>
      <tr>
          <td><a href="https://github.com/molML/MoleculeACE/tree/main/MoleculeACE/Data/benchmark_data">Curated datasets</a></td>
          <td>Data</td>
          <td>MIT</td>
          <td>Processed ChEMBL bioactivity data</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: van Tilborg, D., Alenicheva, A., &amp; Grisoni, F. (2022). Exposing the Limitations of Molecular Machine Learning with Activity Cliffs. <em>Journal of Chemical Information and Modeling</em>, 62(23), 5938-5951. <a href="https://doi.org/10.1021/acs.jcim.2c01073">https://doi.org/10.1021/acs.jcim.2c01073</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/molML/MoleculeACE">MoleculeACE GitHub Repository</a></li>
<li><a href="https://chemrxiv.org/engage/chemrxiv/article-details/630cc44058843b8403a19810">ChemRxiv Preprint</a></li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vantilborg2022activity,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exposing the Limitations of Molecular Machine Learning with Activity Cliffs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{van Tilborg, Derek and Alenicheva, Alisa and Grisoni, Francesca}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{62}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{5938--5951}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.2c01073}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Latent Diffusion Models for High-Res Image Synthesis</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/latent-diffusion-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/latent-diffusion-models/</guid><description>Latent Diffusion Models train diffusion in a compressed latent space, enabling high-res image synthesis with cross-attention conditioning at reduced compute.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It introduces Latent Diffusion Models (LDMs), which train denoising diffusion models in the latent space of pretrained autoencoders rather than directly in pixel space. The key insight is that separating perceptual compression from generative learning enables high-resolution image synthesis at a fraction of the computational cost of pixel-based diffusion. The paper also introduces a cross-attention conditioning mechanism for flexible multi-modal generation.</p>
<h2 id="computational-cost-of-pixel-space-diffusion">Computational Cost of Pixel-Space Diffusion</h2>
<p>Training diffusion models directly in pixel space is computationally expensive (150 to 1000 V100 GPU-days for leading models at the time) because the model must process high-dimensional RGB data at every denoising step. Much of this compute is spent modeling imperceptible high-frequency details. The authors observe that learning can be split into two stages: a perceptual compression stage that removes high-frequency detail, and a semantic compression stage where the generative model learns the conceptual composition. Prior two-stage approaches (VQGAN, DALL-E) relied on aggressive compression and autoregressive modeling in discrete latent spaces, trading off reconstruction quality for tractability.</p>
<h2 id="core-innovation-diffusion-in-latent-space">Core Innovation: Diffusion in Latent Space</h2>
<p>LDMs decompose image synthesis into two phases:</p>
<p><strong>Phase 1: Perceptual Compression.</strong> A pretrained autoencoder (encoder $\mathcal{E}$, decoder $\mathcal{D}$) maps images $x \in \mathbb{R}^{H \times W \times 3}$ to a lower-dimensional latent representation $z = \mathcal{E}(x) \in \mathbb{R}^{h \times w \times c}$ with spatial downsampling factor $f = H/h$. The autoencoder is trained with a perceptual loss (matching deep features from a pretrained VGG network) and a patch-based adversarial objective, with either KL or VQ regularization on the latent space.</p>
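<p>To make the effect of the downsampling factor concrete, here is a small arithmetic sketch. The latent channel count <code>c = 4</code> is an illustrative choice rather than a value taken from this note, and the function names are hypothetical.</p>

```python
def latent_shape(height, width, f, channels):
    """Latent size (h, w, c) for a spatial downsampling factor f = H / h."""
    if height % f or width % f:
        raise ValueError("image size must be divisible by f")
    return height // f, width // f, channels

def value_ratio(height, width, f, channels):
    """Number of RGB pixel values per latent value."""
    h, w, c = latent_shape(height, width, f, channels)
    return (height * width * 3) / (h * w * c)

shape = latent_shape(256, 256, 8, 4)   # c = 4 is an assumption for illustration
ratio = value_ratio(256, 256, 8, 4)
```

<p>Under these assumptions a 256x256 RGB image maps to a 32x32x4 latent, so the diffusion model handles 48 times fewer values per denoising step.</p>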
<p><strong>Phase 2: Latent Diffusion.</strong> A standard denoising diffusion model operates in this latent space. The training objective becomes:</p>
<p>$$L_{\text{LDM}} := \mathbb{E}_{\mathcal{E}(x), \epsilon \sim \mathcal{N}(0,1), t} \left[ \left\| \epsilon - \epsilon_\theta(z_t, t) \right\|_2^2 \right]$$</p>
<p>where $z_t$ is the noised latent at timestep $t$, and $\epsilon_\theta$ is a time-conditional UNet.</p>
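<p>For reference, a sketch of how $z_t$ is typically produced, assuming the standard DDPM forward process with a variance schedule $\beta_1, \dots, \beta_T$ (the schedule symbols are not defined in this note and are introduced here for illustration):</p>
<p>$$z_t = \sqrt{\bar{\alpha}_t} \, z_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad z_0 = \mathcal{E}(x), \quad \epsilon \sim \mathcal{N}(0,1), \quad \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)$$</p>
<p>Under this parameterization, the network $\epsilon_\theta$ is trained to recover the noise $\epsilon$ that was mixed into the clean latent $z_0$, which is what the squared-error objective measures.</p>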
<p><strong>Cross-Attention Conditioning.</strong> To enable conditioning on text, semantic maps, or other modalities, the authors introduce cross-attention layers into the UNet. A domain-specific encoder $\tau_\theta$ maps conditioning input $y$ to an intermediate representation $\tau_\theta(y) \in \mathbb{R}^{M \times d_\tau}$, which interacts with the UNet features via:</p>
<p>$$Q = W_Q^{(i)} \cdot \varphi_i(z_t), \quad K = W_K^{(i)} \cdot \tau_\theta(y), \quad V = W_V^{(i)} \cdot \tau_\theta(y)$$</p>
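<p>These projections feed standard scaled dot-product attention. As a sketch, with $N$ denoting the number of flattened spatial positions in the UNet feature map and $d$ the projection dimension (both symbols introduced here for illustration):</p>
<p>$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d}}\right) \cdot V, \quad Q \in \mathbb{R}^{N \times d}, \quad K, V \in \mathbb{R}^{M \times d}$$</p>
<p>The softmax is taken over the $M$ conditioning tokens, so each spatial position forms its own weighted mixture of the encoded conditioning input $\tau_\theta(y)$.</p>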
<p>The conditional objective then becomes:</p>
<p>$$L_{\text{LDM}} := \mathbb{E}_{\mathcal{E}(x), y, \epsilon \sim \mathcal{N}(0,1), t} \left[ \left\| \epsilon - \epsilon_\theta(z_t, t, \tau_\theta(y)) \right\|_2^2 \right]$$</p>
<p>Both $\tau_\theta$ and $\epsilon_\theta$ are optimized jointly.</p>
<h2 id="experimental-setup-and-results">Experimental Setup and Results</h2>
<p>The authors evaluate across multiple tasks and datasets:</p>
<p><strong>Perceptual compression tradeoffs.</strong> Downsampling factors $f \in \{1, 2, 4, 8, 16, 32\}$ are compared on ImageNet class-conditional generation. LDM-1 (pixel-based) trains slowly; LDM-32 loses too much information. LDM-4 and LDM-8 achieve the best balance, with LDM-8 outperforming pixel-based diffusion by 38 FID points after 2M training steps on a single A100.</p>
<p><strong>Unconditional image synthesis</strong> on CelebA-HQ 256, FFHQ 256, LSUN Churches/Bedrooms 256: LDM-4 achieves FID 5.11 on CelebA-HQ (state of the art at the time), outperforming LSGM, GANs, and other likelihood-based models. On LSUN-Bedrooms, LDM-4 achieves FID 2.95, close to ADM (1.90) with half the parameters and roughly 4x less training compute (see Appendix E.3.5).</p>
<p><strong>Text-to-image synthesis</strong> on MS-COCO: A 1.45B parameter LDM-KL-8 model trained on LAION-400M achieves FID 12.63 with classifier-free guidance at scale $s = 1.5$, on par with GLIDE (FID 12.24, 6B params) and Make-A-Scene (FID 11.84, 4B params) with substantially fewer parameters. Classifier-free guidance amplifies the conditioning signal at the cost of diversity by extrapolating from the unconditional prediction toward the conditional one.</p>
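<p>A minimal sketch of the classifier-free guidance combination step, assuming the common convention in which scale 1 recovers the purely conditional prediction and scale 0 the unconditional one. Other papers parameterize the scale differently, so treat the exact form as an assumption; the function name and toy vectors are illustrative.</p>

```python
def guided_noise(eps_cond, eps_uncond, scale):
    """Move from the unconditional noise prediction toward the conditional one.

    scale = 0 gives eps_uncond, scale = 1 gives eps_cond, and scale > 1
    pushes past the conditional prediction (stronger conditioning).
    """
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]

# Toy two-dimensional noise predictions for a single denoising step.
eps_cond = [1.0, 2.0]
eps_uncond = [0.0, 1.0]
guided = guided_noise(eps_cond, eps_uncond, 1.5)
```

<p>In a sampler, the guided vector would replace the plain conditional prediction at each denoising step, at the cost of two network evaluations per step.</p>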
<p><strong>Class-conditional ImageNet 256:</strong> LDM-4-G achieves FID 3.60, IS 247.67, outperforming ADM-G (FID 4.59) with fewer parameters and less compute.</p>
<p><strong>Super-resolution:</strong> LDM-4 (big) achieves FID 2.4 on ImageNet 64-to-256 upscaling (validation split), outperforming SR3 in FID.</p>
<p><strong>Inpainting</strong> on Places: LDM-4 (big, w/ ft) achieves FID 1.50, setting a new state of the art on image inpainting.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<ul>
<li>LDM-4 and LDM-8 offer the best tradeoff between perceptual compression and generation quality.</li>
<li>The autoencoder only needs to be trained once and can be reused across different diffusion models and tasks.</li>
<li>Cross-attention conditioning generalizes to text, semantic layouts, and bounding boxes without architecture changes.</li>
<li>Convolutional sampling enables generation at resolutions higher than the training resolution (up to 1024x1024).</li>
<li>Sequential sampling remains slower than GANs. The autoencoder reconstruction can become a bottleneck for tasks requiring pixel-level precision.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unconditional</td>
          <td>CelebA-HQ, FFHQ, LSUN</td>
          <td>256x256</td>
          <td>Standard benchmarks</td>
      </tr>
      <tr>
          <td>Class-conditional</td>
          <td>ImageNet</td>
          <td>256x256</td>
          <td>1000 classes</td>
      </tr>
      <tr>
          <td>Text-to-image</td>
          <td>LAION-400M</td>
          <td>256x256</td>
          <td>400M image-text pairs</td>
      </tr>
      <tr>
          <td>Inpainting</td>
          <td>Places</td>
          <td>256x256, 512x512</td>
          <td>Following LaMa protocol</td>
      </tr>
      <tr>
          <td>Super-resolution</td>
          <td>ImageNet</td>
          <td>64 to 256</td>
          <td>Following SR3 pipeline</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Autoencoder regularization</strong>: KL-reg (KL penalty toward standard normal, weighted by ~$10^{-6}$) or VQ-reg (vector quantization layer on the latent space with a learned codebook)</li>
<li><strong>Diffusion</strong>: Standard DDPM denoising with reweighted objective</li>
<li><strong>Sampling</strong>: DDIM sampler with configurable steps (100 to 500 depending on task)</li>
<li><strong>Guidance</strong>: Classifier-free diffusion guidance with scale $s$ (1.5 for class-conditional and text-to-image quantitative evaluation; 10.0 for qualitative text-to-image samples)</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Autoencoder</strong>: Based on VQGAN architecture with perceptual + adversarial loss</li>
<li><strong>UNet backbone</strong>: Time-conditional with cross-attention layers at multiple resolutions</li>
<li><strong>Text encoder</strong>: BERT-tokenizer with transformer $\tau_\theta$ for LAION text-to-image model</li>
<li><strong>LDM-4-G</strong>: 400M parameters, $f=4$ downsampling</li>
<li><strong>LDM-KL-8 (text)</strong>: 1.45B parameters, $f=8$ downsampling, KL-regularized</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Best Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CelebA-HQ unconditional</td>
          <td>5.11</td>
          <td>500 DDIM steps</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet class-conditional</td>
          <td>3.60</td>
          <td>LDM-4-G, cfg s=1.5</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>MS-COCO text-to-image</td>
          <td>12.63</td>
          <td>LDM-KL-8-G, 250 steps, cfg s=1.5</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>Places inpainting</td>
          <td>1.50</td>
          <td>LDM-4 big, w/ ft</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet 4x super-resolution</td>
          <td>2.4</td>
          <td>LDM-4 big, 100 steps</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Perceptual compression tradeoff experiments: single NVIDIA A100</li>
<li>Inpainting model trained on eight V100</li>
<li>Training at least 2.7x faster than pixel-based diffusion at equal parameters</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/CompVis/latent-diffusion">CompVis/latent-diffusion</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained models</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rombach, R., Blattmann, A., Lorenz, D., Esser, P., &amp; Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. <em>CVPR 2022</em>. <a href="https://arxiv.org/abs/2112.10752">https://arxiv.org/abs/2112.10752</a></p>
<p><strong>Publication</strong>: CVPR 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{rombach2022highresolution,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{High-Resolution Image Synthesis with Latent Diffusion Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj{\&#34;o}rn}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>     = <span style="color:#e6db74">{10684--10695}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/CompVis/latent-diffusion">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>GraSP: Graph Recognition via Subgraph Prediction (2026)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/grasp-2026/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/grasp-2026/</guid><description>GraSP is a general image-to-graph framework using sequential subgraph prediction, applied to OCSR with 67.5% accuracy on QM9.</description><content:encoded><![CDATA[<h2 id="a-general-framework-for-visual-graph-recognition">A General Framework for Visual Graph Recognition</h2>
<p>GraSP (Graph Recognition via Subgraph Prediction) addresses a fundamental limitation in image-to-graph methods: existing solutions are task-specific and do not transfer between domains. Whether the task is OCSR, scene graph recognition, music notation parsing, or road network extraction, each domain has developed independent solutions despite solving the same conceptual problem of extracting a graph from an image.</p>
<p>The key insight is that graph recognition can be reformulated as sequential subgraph prediction using a binary classifier, sidestepping two core difficulties of using graphs as neural network outputs:</p>
<ol>
<li><strong>Graph isomorphism</strong>: An uncolored graph with $n$ nodes has $n!$ equivalent representations, making direct output comparison intractable</li>
<li><strong>Compositional outputs</strong>: Nodes, edges, and features are interdependent, so standard i.i.d. loss functions are insufficient</li>
</ol>
<h2 id="sequential-subgraph-prediction-as-an-mdp">Sequential Subgraph Prediction as an MDP</h2>
<p>GraSP formulates graph recognition as a Markov Decision Process. Starting from an empty graph, the method iteratively expands the current graph by adding one edge at a time (connecting either a new node or two existing nodes). At each step, a binary classifier predicts whether each candidate successor graph is a subgraph of the target graph shown in the image.</p>
<p>The critical observation is that the optimal value function $V^{\pi^*}$ satisfies:</p>
<p>$$V^{\pi^*}(\mathcal{G}_t | \mathcal{I}) = 1 \iff \mathcal{G}_t \subseteq \mathcal{G}_{\mathcal{I}}$$</p>
<p>This means the value function reduces to a subgraph membership test, which can be learned as a binary classifier rather than requiring reinforcement learning. Greedy decoding then suffices: at each step, select any successor that the classifier predicts is a valid subgraph, and terminate when the classifier indicates the current graph is complete.</p>
<p>This formulation decouples <strong>decision</strong> (what to add) from <strong>generation</strong> (in what order), making the same model applicable across different graph types without modification.</p>
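<p>The greedy decoding loop can be sketched as follows. This is a toy under strong assumptions: nodes carry fixed labels (which sidesteps the isomorphism problem), and an exact subset test plays the role of the learned classifier. The helper names <code>is_subgraph</code>, <code>successors</code>, and <code>greedy_decode</code> are hypothetical.</p>

```python
from itertools import combinations

def is_subgraph(edges, target):
    """Oracle stand-in for the learned binary classifier."""
    return edges.issubset(target)

def successors(edges, nodes):
    """Graphs reachable by adding one edge that touches the current graph."""
    used = {v for e in edges for v in e}
    out = []
    for u, v in combinations(sorted(nodes), 2):
        e = frozenset((u, v))
        if e in edges:
            continue
        # From the empty graph any edge is allowed; afterwards stay connected.
        if not edges or u in used or v in used:
            out.append(edges | {e})
    return out

def greedy_decode(target, nodes):
    current = frozenset()
    while True:
        if current == target:  # terminal check, also an oracle stand-in
            return current
        valid = [g for g in successors(current, nodes) if is_subgraph(g, target)]
        if not valid:
            return current
        current = valid[0]  # any valid successor works, so take the first

# Toy target: a star-like tree on four labeled nodes.
target = frozenset(frozenset(e) for e in [(0, 1), (1, 2), (1, 3)])
decoded = greedy_decode(target, {0, 1, 2, 3})
```

<p>Because every accepted successor is itself a subgraph of the target, the order in which edges are added does not affect the final result, which is the point of separating the decision from the generation order.</p>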
<h2 id="architecture-gnn--film-conditioned-cnn">Architecture: GNN + FiLM-Conditioned CNN</h2>
<p>The architecture has three components:</p>
<ol>
<li>
<p><strong>GNN encoder</strong>: A Message Passing Neural Network processes the candidate subgraph, producing a graph embedding. Messages are constructed as concatenations of source node features, target node features, and connecting edge features.</p>
</li>
<li>
<p><strong>FiLM-conditioned CNN</strong>: A ResNet-v2 processes the image, with FiLM layers placed after every normalization layer within each block. The graph embedding conditions the image processing, producing a joint graph-image representation.</p>
</li>
<li>
<p><strong>MLP classification head</strong>: Takes the conditioned image embedding plus a binary terminal flag (indicating whether this is a termination check) and predicts subgraph membership.</p>
</li>
</ol>
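<p>The FiLM conditioning in the second component amounts to a per-channel scale and shift of the image features, with both predicted from the graph embedding. The sketch below assumes the plain $\gamma \cdot x + \beta$ form with linear projections; the weight names and shapes are hypothetical, and some implementations use $(1 + \gamma)$ instead.</p>

```python
import numpy as np

def film(features, graph_emb, W_gamma, W_beta):
    """features: (C, H, W) image feature map; graph_emb: (D,) graph embedding."""
    gamma = graph_emb @ W_gamma  # (C,) per-channel scale
    beta = graph_emb @ W_beta    # (C,) per-channel shift
    # Broadcast the per-channel parameters over the spatial dimensions.
    return gamma[:, None, None] * features + beta[:, None, None]

features = np.ones((3, 2, 2))
graph_emb = np.array([1.0, 2.0])
W_gamma = np.array([[1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])  # gives gamma = [1, 2, 0]
W_beta = np.array([[0.0, 0.0, 1.0],
                   [0.0, 0.0, 0.0]])   # gives beta = [0, 0, 1]
out = film(features, graph_emb, W_gamma, W_beta)
```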
<p>The model uses only 7.25M parameters. Group Normalization is used in the CNN (8 groups per layer), Layer Normalization in the GNN and MLP.</p>
<h2 id="training-via-streaming-data-generation">Training via Streaming Data Generation</h2>
<p>Training uses a streaming architecture rather than a fixed dataset:</p>
<ul>
<li>For each iteration, a target graph $\mathcal{G}_T$ is sampled and rendered as an image</li>
<li><strong>Positive samples</strong> are generated by deleting edges that do not disconnect the graph (yielding valid subgraphs)</li>
<li><strong>Negative samples</strong> are generated by expanding successor states and checking via approximate subgraph matching</li>
<li>Two FIFO buffers (one for positives, one for negatives), each holding up to 25,000 images, maintain diverse and balanced mini-batches of 1024 samples</li>
<li>Training uses the RAdam optimizer with a cosine learning rate schedule (warmup over 50M samples, cycle of 250M samples) on 4 A100 GPUs with a 24h budget</li>
</ul>
<h2 id="synthetic-benchmarks-on-colored-trees">Synthetic Benchmarks on Colored Trees</h2>
<p>GraSP is evaluated on increasingly complex synthetic tasks involving colored tree graphs:</p>
<ul>
<li><strong>Small trees (6-9 nodes)</strong>: Tasks with varying numbers of node colors (1, 3, 5) and edge colors (1, 3, 5). The model works well across all configurations, with simpler tasks (fewer colors) converging faster.</li>
<li><strong>Larger trees (10-15 nodes)</strong>: The same trends hold but convergence is slower due to increased structural complexity.</li>
<li><strong>Out-of-distribution generalization</strong>: Models trained on 6-9 node trees show zero-shot generalization to 10-node trees, indicating learned patterns are size-independent.</li>
</ul>
<h2 id="ocsr-evaluation-on-qm9">OCSR Evaluation on QM9</h2>
<p>For the real-world OCSR evaluation, GraSP is applied to <a href="/notes/chemistry/datasets/qm9/">QM9</a> molecular images (grayscale, no stereo-bonds) with a 10,000-molecule held-out test set:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>OSRA</td>
          <td>45.61%</td>
      </tr>
      <tr>
          <td>GraSP</td>
          <td>67.51%</td>
      </tr>
      <tr>
          <td>MolGrapher</td>
          <td>88.36%</td>
      </tr>
      <tr>
          <td>DECIMER</td>
          <td>92.08%</td>
      </tr>
  </tbody>
</table>
<p>GraSP does not match state-of-the-art OCSR tools, but the authors emphasize that the same model architecture and training procedure transfers directly from synthetic tree tasks to molecular graphs with no task-specific modifications. The only domain knowledge incorporated is a simple chemistry rule: not extending nodes that already have degree four.</p>
<p>The method highlights the practical advantage of decoupling decision from generation. Functional groups can be represented at different granularities (as single nodes to reduce trajectory depth, or expanded to reduce trajectory breadth) without changing the model.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/c72bcbf4/grasp">GraSP Code</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official implementation with pre-trained models</td>
      </tr>
  </tbody>
</table>
<p>The repository includes pre-trained models and example trajectories for interactive exploration. Training requires 4 A100 GPUs with a 24h time budget. The QM9 dataset used for OCSR evaluation is publicly available. No license file is included in the repository.</p>
<h2 id="limitations-and-future-directions">Limitations and Future Directions</h2>
<ul>
<li><strong>Finite type assumption</strong>: The current framework assumes a finite set of node and edge types, limiting applicability to open-vocabulary tasks like scene graph recognition</li>
<li><strong>Scaling to large graphs</strong>: For very large graphs, the branching factor of successor states becomes expensive. Learned filters to prune irrelevant successor states could help</li>
<li><strong>OCSR performance gap</strong>: While GraSP demonstrates transferability, it falls short of specialized OCSR tools that use domain-specific encodings (SMILES) or pixel-level supervision</li>
<li><strong>Modality extension</strong>: The framework could extend beyond images to other input modalities, such as vector embeddings of graphs</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Eberhard, A., Neumann, G., &amp; Friederich, P. (2026). Graph Recognition via Subgraph Prediction. <em>arXiv preprint arXiv:2601.15133</em>. <a href="https://arxiv.org/abs/2601.15133">https://arxiv.org/abs/2601.15133</a></p>
<p><strong>Publication</strong>: arXiv 2026</p>
]]></content:encoded></item><item><title>D3PM: Discrete Denoising Diffusion Probabilistic Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</guid><description>D3PMs extend diffusion models to discrete data with structured transition matrices, connecting diffusion to masked language models.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It extends denoising diffusion probabilistic models (DDPMs) from continuous to discrete state-spaces by introducing structured Markov transition matrices for the corruption process. The paper unifies several corruption strategies, draws a formal connection between absorbing-state diffusion and masked language models, and demonstrates competitive results on both image and text generation.</p>
<h2 id="diffusion-beyond-continuous-spaces">Diffusion Beyond Continuous Spaces</h2>
<p>Standard DDPMs operate in continuous state-spaces (e.g., pixel values treated as real numbers) and use Gaussian noise for corruption. Many important data types are inherently discrete: text (tokens from a vocabulary), quantized images (discrete pixel values), molecular structures, and segmentation maps. Prior work by Hoogeboom et al. extended binary diffusion to multinomial diffusion with uniform transition probabilities, but this limits the structure of the corruption process. D3PMs generalize this by allowing arbitrary transition matrices that encode domain-specific inductive biases.</p>
<h2 id="core-innovation-structured-transition-matrices">Core Innovation: Structured Transition Matrices</h2>
<p>D3PMs define a forward corruption process over discrete variables $\mathbf{x} \in \{1, \ldots, K\}^D$ using transition matrices $\mathbf{Q}_t \in \mathbb{R}^{K \times K}$:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_{t-1} \mathbf{Q}_t)$$</p>
<p>where $\mathbf{x}_{t-1}$ is a one-hot row vector. The cumulative transition after $t$ steps is $\overline{\mathbf{Q}}_t = \mathbf{Q}_1 \mathbf{Q}_2 \cdots \mathbf{Q}_t$, giving:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_0) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_0 \overline{\mathbf{Q}}_t)$$</p>
<p>The paper explores several transition matrix designs:</p>
<p><strong>Uniform diffusion:</strong> $[\mathbf{Q}_t]_{ij} = (1 - \beta_t) \mathbf{1}_{i=j} + \beta_t / K$. With probability $\beta_t$, a token is resampled uniformly over all $K$ states, so the stationary distribution is uniform.</p>
<p><strong>Absorbing state:</strong> Each non-mask token transitions to a designated absorbing state $m$ (e.g., [MASK] for text, gray pixel for images) with probability $\beta_t$ per step, while tokens already at $m$ remain there:</p>
<p>$$[\mathbf{Q}_t]_{ij} = (1-\beta_t)\mathbf{1}_{i=j\neq m} + \beta_t \mathbf{1}_{i \neq m,\, j=m} + \mathbf{1}_{i=j=m}$$</p>
<p>This establishes a direct connection to masked language models like BERT.</p>
<p><strong>Discretized Gaussian:</strong> Transition probabilities decay as a function of the distance $|i-j|$ between states, mimicking Gaussian diffusion on ordinal data like pixel values.</p>
<p><strong>Embedding-based nearest neighbor:</strong> For text, transitions are weighted by proximity in a pretrained word embedding space, so corruption preferentially swaps words with semantically similar ones.</p>
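<p>To make the forward process concrete, the sketch below builds the uniform and absorbing-state matrices, accumulates them into $\overline{\mathbf{Q}}_t$, and samples $\mathbf{x}_t$ from $\mathbf{x}_0$. States are 0-indexed here, and the linear $\beta_t$ schedule, sizes, and helper names are hypothetical choices for illustration only.</p>

```python
import numpy as np

def uniform_Q(K, beta):
    """Keep the state with prob. 1 - beta, else resample uniformly over K states."""
    return (1.0 - beta) * np.eye(K) + (beta / K) * np.ones((K, K))

def absorbing_Q(K, beta, m):
    """Non-mask states jump to state m with prob. beta; state m never leaves."""
    Q = (1.0 - beta) * np.eye(K)
    Q[:, m] += beta
    Q[m] = 0.0
    Q[m, m] = 1.0
    return Q

def cumulative(Qs):
    """Running products Q_1, Q_1 Q_2, ..., Q_1 ... Q_T."""
    out = [Qs[0]]
    for Q in Qs[1:]:
        out.append(out[-1] @ Q)
    return out

def sample_xt(x0, Qbar_t, rng):
    """Draw x_t ~ Cat(x0 Qbar_t); x0 holds 0-indexed integer states."""
    probs = Qbar_t[x0]  # row lookup equals one-hot(x0) @ Qbar_t
    return np.array([rng.choice(len(p), p=p) for p in probs])

K, T, MASK = 4, 50, 3
betas = np.linspace(0.02, 0.2, T)  # hypothetical schedule
Qbar_uniform = cumulative([uniform_Q(K, b) for b in betas])
Qbar_absorb = cumulative([absorbing_Q(K, b, MASK) for b in betas])

rng = np.random.default_rng(0)
x0 = np.array([0, 1, 2, 0, 1])
x_mid = sample_xt(x0, Qbar_absorb[9], rng)  # corrupted sample after 10 steps
```

<p>After enough steps the uniform chain forgets its starting state, while the absorbing chain concentrates all mass on the mask state. Under absorbing corruption, every entry of <code>x_mid</code> is either unchanged or equal to <code>MASK</code>.</p>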
<p><strong>Training objective.</strong> The reverse process $p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)$ is parameterized by predicting $\tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$ and computing the posterior:</p>
<p>$$p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) \propto \sum_{\tilde{\mathbf{x}}_0} q(\mathbf{x}_{t-1} | \mathbf{x}_t, \tilde{\mathbf{x}}_0) \, \tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$$</p>
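<p>For a single token, the reverse step above reduces to a small amount of linear algebra. The sketch below applies Bayes&rsquo; rule, $q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \propto q(\mathbf{x}_t | \mathbf{x}_{t-1}) \, q(\mathbf{x}_{t-1} | \mathbf{x}_0)$, which relies on the Markov property of the forward process. The matrices and the predicted distribution <code>p_x0</code> are hypothetical stand-ins for a real schedule and network output.</p>

```python
import numpy as np

def uniform_Q(K, beta):
    return (1.0 - beta) * np.eye(K) + (beta / K) * np.ones((K, K))

def posterior(xt, x0, Q_t, Qbar_prev):
    """q(x_{t-1} | x_t, x_0) as a length-K vector over x_{t-1}."""
    unnorm = Q_t[:, xt] * Qbar_prev[x0]  # q(x_t | x_{t-1}) * q(x_{t-1} | x_0)
    return unnorm / unnorm.sum()

def reverse_step(xt, p_x0, Q_t, Qbar_prev):
    """p_theta(x_{t-1} | x_t): posterior averaged over the predicted x_0 distribution."""
    mix = sum(p * posterior(xt, x0, Q_t, Qbar_prev) for x0, p in enumerate(p_x0))
    return mix / mix.sum()

K = 3
Q_t = uniform_Q(K, 0.1)         # single-step matrix at time t
Qbar_prev = uniform_Q(K, 0.3)   # stand-in for the cumulative matrix up to t-1
p_x0 = np.array([0.7, 0.2, 0.1])  # hypothetical network prediction for x_0
p_prev = reverse_step(0, p_x0, Q_t, Qbar_prev)
```

<p>Entries of the posterior are zero wherever the forward transition has zero probability, so the reverse distribution inherits the sparsity pattern of the transition matrix.</p>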
<p>The training loss $L_\lambda$ combines the variational lower bound (VLB) with an auxiliary cross-entropy term weighted by $\lambda$:</p>
<p>$$L_\lambda = L_{\text{VLB}} + \lambda \, L_{\text{CE}}$$</p>
<p>where $L_{\text{CE}}$ is a reweighted cross-entropy loss on the $\mathbf{x}_0$ prediction that stabilizes training and improves sample quality. The VLB decomposes into per-timestep KL divergences between the true and predicted reverse transitions.</p>
<h2 id="experiments-and-results">Experiments and Results</h2>
<p><strong>Image generation (CIFAR-10):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Loss</th>
          <th>IS</th>
          <th>FID</th>
          <th>NLL (bpd)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM uniform</td>
          <td>$L_{\text{VLB}}$</td>
          <td>5.99</td>
          <td>51.27</td>
          <td>5.08</td>
      </tr>
      <tr>
          <td>D3PM absorbing</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>6.78</td>
          <td>30.97</td>
          <td>4.40</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_{\text{VLB}}$</td>
          <td>7.75</td>
          <td>15.30</td>
          <td>3.97</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.54</td>
          <td>8.34</td>
          <td>3.98</td>
      </tr>
      <tr>
          <td>D3PM Gauss + logistic</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.56</td>
          <td>7.34</td>
          <td>3.44</td>
      </tr>
      <tr>
          <td>DDPM $L_{\text{simple}}$ (continuous)</td>
          <td>&ndash;</td>
          <td>9.46</td>
          <td>3.17</td>
          <td>3.75</td>
      </tr>
  </tbody>
</table>
<p>The best discrete D3PM variant is D3PM Gauss + logistic, which achieves FID 7.34 and NLL 3.44 bpd using the combined $L_\lambda$ loss with a truncated logistic parameterization. The truncated logistic parameterization replaces the standard softmax output with a discretized logistic distribution over pixel values, assigning probability mass to each discrete bin based on a continuous logistic CDF. This provides a smoother output distribution that better captures the ordinal structure of pixel intensities. This variant exceeds the continuous DDPM in log-likelihood (3.44 vs. 3.75 bpd) while approaching its sample quality (FID 7.34 vs. 3.17).</p>
<p><strong>Text generation (text8, character-level, 1000 steps):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>bpc</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM absorbing ($L_\lambda$)</td>
          <td>1.45</td>
      </tr>
      <tr>
          <td>D3PM NN ($L_{\text{VLB}}$)</td>
          <td>1.59</td>
      </tr>
      <tr>
          <td>D3PM uniform</td>
          <td>1.61</td>
      </tr>
      <tr>
          <td>Discrete Flow (Tran et al.)</td>
          <td>1.23</td>
      </tr>
  </tbody>
</table>
<p>On text8, D3PM absorbing achieves the best bpc among the D3PM variants (1.45), though it trails Discrete Flow (1.23; Tran et al., 2019). On LM1B (sentencepiece vocabulary of 8192 tokens), D3PM absorbing achieves a perplexity of 76.9 at 1000 steps, compared to 137.9 for D3PM uniform and 43.6 for a comparable autoregressive transformer. Discrete diffusion therefore scales to large vocabularies, although a gap to the autoregressive baseline remains.</p>
<p><strong>Ablation findings:</strong></p>
<ul>
<li>The auxiliary cross-entropy term is critical: for D3PM Gauss, training with $L_\lambda$ ($\lambda{=}0.001$) instead of $L_{\text{VLB}}$ alone improves FID from 15.30 to 8.34. Adding the truncated logistic parameterization further improves FID to 7.34.</li>
<li>Discretized Gaussian transitions outperform both uniform and absorbing-state transitions on CIFAR-10 across all metrics.</li>
<li>For text, the absorbing-state (mask) model outperforms uniform and nearest-neighbor models. Nearest-neighbor diffusion provides only marginal improvement over uniform, a surprising negative result.</li>
<li>The $\mathbf{x}_0$-parameterization ensures the learned reverse distribution has the correct sparsity pattern dictated by the transition matrix $\mathbf{Q}_t$.</li>
</ul>
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>The choice of transition matrix is an important design decision that encodes domain-specific inductive biases. Discretized Gaussian transitions work best for ordinal image data; absorbing-state transitions work best for text.</li>
<li>D3PMs formally unify diffusion models and masked language models: absorbing-state diffusion with a [MASK] token is equivalent to a reweighted BERT-style training objective.</li>
<li>The combined VLB + auxiliary loss ($L_\lambda$) achieves better density estimation (3.44 bpd) than continuous DDPMs (3.75 bpd) while producing competitive samples.</li>
<li>Sample quality (best FID 7.34 for D3PM Gauss + logistic) still lags behind continuous-space DDPMs (FID 3.17) on CIFAR-10, though the gap narrows with structured transitions and the auxiliary loss.</li>
<li>Scaling to very large numbers of categories $K$ requires special techniques (low-rank corruption or matrix exponentials) to manage the $O(K^2 T)$ memory cost of storing transition matrices.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Image generation</td>
          <td>CIFAR-10</td>
          <td>32x32, 256 categories</td>
          <td>Quantized to 256 ordinal values per channel</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>text8</td>
          <td>Character-level</td>
          <td>27 character vocabulary, sequences of length 256</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>LM1B</td>
          <td>Word-level</td>
          <td>Sentencepiece vocabulary of 8192 tokens, sequence length 128</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise schedules</strong>: Linear schedule for D3PM Gauss, cosine schedule for D3PM uniform, and a novel mutual information schedule for absorbing and nearest-neighbor models</li>
<li><strong>Reverse parameterization</strong>: $\mathbf{x}_0$-parameterization with posterior computation via Bayes&rsquo; rule</li>
<li><strong>Loss</strong>: $L_{\text{VLB}} + \lambda L_{\text{CE}}$ with $\lambda = 0.001$ for images and $\lambda = 0.01$ for text absorbing models</li>
<li><strong>Scaling</strong>: Low-rank corruption (absorbing, uniform) scales as $O(r^2 T)$; matrix exponentials for nearest-neighbor transitions</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Image models</strong>: Modified U-Net architecture from Ho et al. (2020) adapted for categorical output via softmax over $K$ classes</li>
<li><strong>Text models</strong>: 12-layer <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style transformer encoder with 70M parameters (12 heads, MLP dim 3072, QKV dim 768)</li>
<li><strong>Timesteps</strong>: $T = 1000$ for both images and text, though text models can be evaluated with fewer steps (e.g., 256 or 20)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>Best D3PM</th>
          <th>Continuous DDPM</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>7.34 (Gauss + logistic)</td>
          <td>3.17</td>
      </tr>
      <tr>
          <td>NLL (bpd)</td>
          <td>CIFAR-10</td>
          <td>3.44 (Gauss + logistic)</td>
          <td>3.75</td>
      </tr>
      <tr>
          <td>BPC</td>
          <td>text8 (char)</td>
          <td>1.45 (absorbing, $L_\lambda$)</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Perplexity</td>
          <td>LM1B</td>
          <td>76.9 (absorbing)</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All models trained for 1M steps with batch size 512 on TPUv2 or TPUv3</li>
<li>Text models: 12-layer transformer encoder (T5 architecture), 70M parameters</li>
<li>Image models: Modified U-Net architecture from Ho et al. (2020)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/google-research/tree/master/d3pm">google-research/d3pm</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX/Flax implementation for image and text experiments</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Austin, J., Johnson, D. D., Ho, J., Tarlow, D., &amp; van den Berg, R. (2021). Structured Denoising Diffusion Models in Discrete State-Spaces. <em>NeurIPS 2021</em>. <a href="https://arxiv.org/abs/2107.03006">https://arxiv.org/abs/2107.03006</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{austin2021structured,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Structured Denoising Diffusion Models in Discrete State-Spaces}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Austin, Jacob and Johnson, Daniel D. and Ho, Jonathan and Tarlow, Daniel and van den Berg, Rianne}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>Consistency Models: Fast One-Step Diffusion Generation</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</guid><description>Consistency models enable one-step generation by learning to map any point on a diffusion ODE trajectory to its origin, achieving FID 3.55 on CIFAR-10.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It proposes consistency models, a new class of generative models designed for fast one-step (or few-step) generation. The models can be trained either by distilling pretrained diffusion models (consistency distillation) or as standalone generative models from scratch (consistency training). The paper provides theoretical analysis of both training modes and achieves FID 3.55 on CIFAR-10 for single-step non-adversarial generation (state of the art at the time of publication).</p>
<h2 id="the-slow-sampling-problem-in-diffusion">The Slow Sampling Problem in Diffusion</h2>
<p>Diffusion models produce high-quality samples but require iterating through many denoising steps (often tens to hundreds), making generation slow compared to GANs or VAEs. Previous approaches to speed up sampling include faster ODE/SDE solvers (DDIM, DPM-Solver) and progressive distillation. These either still require multiple steps or depend on a complex multi-stage distillation pipeline. The goal is a model that can generate high-quality samples in a single forward pass while optionally allowing more steps for better quality.</p>
<h2 id="core-innovation-the-self-consistency-property">Core Innovation: The Self-Consistency Property</h2>
<p>The key idea builds on the Probability Flow (PF) ODE from the score-based SDE framework. The PF ODE describes a deterministic trajectory that converts noise into data, governed by the learned score function. For the VE-SDE parameterization used by EDM (Karras et al., 2022), this takes the form:</p>
<p>$$\frac{d\mathbf{x}_t}{dt} = -t \, s_\phi(\mathbf{x}_t, t)$$</p>
<p>where $s_\phi$ is a pretrained score model. A <strong>consistency function</strong> $f(\mathbf{x}_t, t)$ maps any point on an ODE trajectory to the trajectory&rsquo;s origin $\mathbf{x}_\epsilon$. The defining property is self-consistency:</p>
<p>$$f(\mathbf{x}_t, t) = f(\mathbf{x}_{t'}, t') \quad \text{for all } t, t' \in [\epsilon, T]$$</p>
<p>for any points $\mathbf{x}_t$ and $\mathbf{x}_{t'}$ on the same PF ODE trajectory.</p>
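<p>As a sanity check of the self-consistency property, consider a toy case that is not taken from the paper: one-dimensional data drawn from $\mathcal{N}(0, \sigma_d^2)$, for which both the score and the PF ODE solution have closed forms. The sketch below integrates the ODE with Euler steps and checks that the closed-form consistency function returns approximately the same origin from every point on the trajectory. All names and constants are illustrative assumptions.</p>

```python
import math

SIGMA_D = 0.5   # assumed data standard deviation (toy value)
EPS = 0.002     # assumed smallest time
T_MAX = 80.0    # assumed largest time


def score(x, t):
    """Exact score of the noised toy data, which is N(0, SIGMA_D^2 + t^2)."""
    return -x / (SIGMA_D**2 + t**2)


def pf_ode_drift(x, t):
    """Right-hand side of the PF ODE: dx/dt = -t * score(x, t)."""
    return -t * score(x, t)


def consistency_fn(x, t):
    """Closed-form map from (x_t, t) to the trajectory origin x_eps (toy case only)."""
    return x * math.sqrt((SIGMA_D**2 + EPS**2) / (SIGMA_D**2 + t**2))


def euler_trajectory(x_start, t_start, t_end, n_steps):
    """Integrate the PF ODE from t_start to t_end and return every visited (x, t)."""
    points = [(x_start, t_start)]
    x, t = x_start, t_start
    dt = (t_end - t_start) / n_steps
    for _ in range(n_steps):
        x = x + dt * pf_ode_drift(x, t)
        t = t + dt
        points.append((x, t))
    return points


trajectory = euler_trajectory(x_start=40.0, t_start=T_MAX, t_end=EPS, n_steps=20000)
origins = [consistency_fn(x, t) for x, t in trajectory]
spread = max(origins) - min(origins)  # small, up to Euler discretization error
```

<p>The residual spread comes only from the Euler discretization. A learned consistency model has no closed form to rely on, which is why the training objectives enforce this agreement explicitly.</p>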
<p><strong>Parameterization.</strong> The model enforces the boundary condition $f(\mathbf{x}_\epsilon, \epsilon) = \mathbf{x}_\epsilon$ using skip connections:</p>
<p>$$f_\theta(\mathbf{x}, t) = c_{\text{skip}}(t) \, \mathbf{x} + c_{\text{out}}(t) \, F_\theta(\mathbf{x}, t)$$</p>
<p>where $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$, ensuring the boundary condition is satisfied by construction.</p>
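<p>A minimal sketch of this parameterization follows. The specific forms of $c_{\text{skip}}$ and $c_{\text{out}}$ are written as I recall them from the paper&rsquo;s appendix, and the values of $\sigma_{\text{data}}$ and $\epsilon$ are assumptions, so verify both against the paper before relying on them. The point of the sketch is only that the boundary condition holds for any network $F_\theta$.</p>

```python
import math

SIGMA_DATA = 0.5  # assumed value
EPS = 0.002       # assumed value


def c_skip(t):
    """Equals 1 at t = EPS and decays toward 0 for large t."""
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t):
    """Equals 0 at t = EPS, so the network output is ignored at the boundary."""
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)


def consistency_model(raw_network, x, t):
    """f_theta(x, t) = c_skip(t) * x + c_out(t) * F_theta(x, t)."""
    return c_skip(t) * x + c_out(t) * raw_network(x, t)


def arbitrary_network(x, t):
    """Hypothetical stand-in for F_theta; any function works for this check."""
    return 123.0 * x - 7.0 * t


boundary_output = consistency_model(arbitrary_network, 0.37, EPS)  # returns the input
```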
<p><strong>Consistency Distillation (CD).</strong> Given a pretrained diffusion model, CD trains a consistency model by enforcing self-consistency between adjacent timesteps:</p>
<p>$$\mathcal{L}_{\text{CD}}^N(\theta, \theta^-; \phi) = \mathbb{E}\left[\lambda(t_n) \, d\!\left(f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1}), \, f_{\theta^-}(\hat{\mathbf{x}}_{t_n}^\phi, t_n)\right)\right]$$</p>
<p>where $\hat{\mathbf{x}}_{t_n}^\phi$ is obtained by running one step of the ODE solver using the pretrained score model, $\theta^-$ is an exponential moving average (EMA) of $\theta$, and $d(\cdot, \cdot)$ is a distance metric. The use of a target network $\theta^-$ (updated via EMA) parallels techniques from deep Q-learning and momentum contrastive learning.</p>
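<p>The pieces of the CD objective can be sketched for a single training pair. This is a scalar toy under stated assumptions, not the official implementation: the teacher score is the exact score of Gaussian toy data, the solver is a single Euler step, the distance is squared $\ell_2$, and every function name is hypothetical.</p>

```python
import math

SIGMA_D = 0.5  # assumed toy data standard deviation
EPS = 0.002    # assumed smallest time


def teacher_score(x, t):
    """Stand-in for the pretrained score model (exact for the Gaussian toy data)."""
    return -x / (SIGMA_D**2 + t**2)


def euler_solver_step(x_next, t_next, t_cur):
    """One Euler step of dx/dt = -t * score, moving from t_next down to t_cur."""
    drift = -t_next * teacher_score(x_next, t_next)
    return x_next + (t_cur - t_next) * drift


def squared_l2(a, b):
    return (a - b) ** 2


def cd_loss_term(f_online, f_target, x_data, z, t_cur, t_next, weight=1.0):
    """Single-sample CD loss: compare online and target outputs at adjacent times."""
    x_next = x_data + t_next * z                      # noisy point at t_{n+1}
    x_cur_hat = euler_solver_step(x_next, t_next, t_cur)  # solver estimate at t_n
    return weight * squared_l2(f_online(x_next, t_next), f_target(x_cur_hat, t_cur))


def ema_update(target_params, online_params, mu):
    """Target parameters track the online parameters with decay rate mu."""
    return [mu * tp + (1.0 - mu) * op for tp, op in zip(target_params, online_params)]


def exact_f(x, t):
    """Closed-form consistency function of the toy problem."""
    return x * math.sqrt((SIGMA_D**2 + EPS**2) / (SIGMA_D**2 + t**2))


def identity_f(x, t):
    """A function that is not self-consistent."""
    return x


loss_exact = cd_loss_term(exact_f, exact_f, 0.3, 1.2, t_cur=1.0, t_next=1.1)
loss_identity = cd_loss_term(identity_f, identity_f, 0.3, 1.2, t_cur=1.0, t_next=1.1)
```

<p>The exact consistency function yields a loss close to zero, limited by the Euler step error, while the identity map is penalized. In real training the target network output is treated as a constant with respect to $\theta$.</p>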
<p><strong>Consistency Training (CT).</strong> CT eliminates the need for a pretrained diffusion model. It replaces the ODE solver step with a score estimate derived from the denoising score matching identity:</p>
<p>$$\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) = \mathbb{E}\left[\frac{\mathbf{x} - \mathbf{x}_t}{t^2} \,\middle|\, \mathbf{x}_t\right]$$</p>
<p>Because this identity lets us estimate the score from noisy data alone (without a pretrained model), we can compute the ODE update directly from training samples. This allows training directly on data pairs $(\mathbf{x}, \mathbf{x} + t\mathbf{z})$ where $\mathbf{z} \sim \mathcal{N}(0, I)$.</p>
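<p>To see why training pairs that share the same noise suffice, substitute the single-sample estimate $(\mathbf{x} - \mathbf{x}_{t_{n+1}})/t_{n+1}^2$ for the score in one Euler step of the PF ODE. The step lands exactly on $\mathbf{x} + t_n \mathbf{z}$. The scalar sketch below checks this numerically; the function names are hypothetical.</p>

```python
def ct_pair(x, z, t_cur, t_next):
    """Two noisy views of the same data point that share the noise sample z."""
    return x + t_cur * z, x + t_next * z


def euler_step_with_sample_score(x, x_next, t_cur, t_next):
    """Euler step of dx/dt = -t * score using the single-sample score estimate."""
    score_estimate = (x - x_next) / t_next**2
    drift = -t_next * score_estimate
    return x_next + (t_cur - t_next) * drift


x, z = 0.3, -1.5
t_cur, t_next = 2.0, 2.5
x_cur, x_next = ct_pair(x, z, t_cur, t_next)
x_cur_from_euler = euler_step_with_sample_score(x, x_next, t_cur, t_next)
gap = abs(x_cur_from_euler - x_cur)  # zero up to floating-point rounding
```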
<p><strong>Theoretical guarantee.</strong> If CD achieves zero loss, the consistency model error is bounded by $O((\Delta t)^p)$ where $\Delta t$ is the maximum timestep gap and $p$ is the order of the ODE solver.</p>
<h2 id="experiments-and-benchmarks">Experiments and Benchmarks</h2>
<p><strong>Datasets:</strong> CIFAR-10 (32x32), ImageNet 64x64, LSUN Bedroom 256x256, LSUN Cat 256x256.</p>
<p><strong>Architecture:</strong> All models use the NCSN++/EDM architecture. CD distills from pretrained EDM models.</p>
<p><strong>Key results for consistency distillation (CD):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>3.55</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>2.93</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>6.20</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>4.70</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>7.80</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>2</td>
          <td>5.22</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>11.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>2</td>
          <td>8.84</td>
      </tr>
  </tbody>
</table>
<p>CD outperforms progressive distillation (PD) across all datasets and sampling steps, with the exception of single-step generation on Bedroom 256x256 where CD with $\ell_2$ slightly underperforms PD with $\ell_2$.</p>
<p><strong>Key results for consistency training (CT):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>8.70</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>5.83</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>13.0</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>11.1</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>16.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>20.7</td>
      </tr>
  </tbody>
</table>
<p>CT outperforms existing single-step non-adversarial models (VAEs, normalizing flows), e.g., improving over DC-VAE&rsquo;s FID of 17.90 on CIFAR-10. Samples from CT share structural similarity with EDM samples from the same initial noise, suggesting CT does not suffer from mode collapse.</p>
<p><strong>Zero-shot editing:</strong> Consistency models support colorization, super-resolution, inpainting, stroke-guided generation, interpolation, and denoising at test time without task-specific training, by modifying the multi-step sampling algorithm.</p>
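<p>The multi-step sampling algorithm referred to above alternates between re-noising the current sample and applying the consistency function again. The sketch below follows that structure as I recall it from the paper&rsquo;s sampling algorithm, applied to a toy closed-form consistency function for one-dimensional Gaussian data. The constants, the intermediate time point, and all names are assumptions for illustration.</p>

```python
import math
import random
import statistics

SIGMA_D = 0.5   # assumed toy data standard deviation
EPS = 0.002     # assumed smallest time
T_MAX = 80.0    # assumed largest time


def toy_consistency_fn(x, t):
    """Closed-form consistency function for data drawn from N(0, SIGMA_D^2)."""
    return x * math.sqrt((SIGMA_D**2 + EPS**2) / (SIGMA_D**2 + t**2))


def multistep_sample(f, time_points, rng):
    """One-step generation from noise, then optional re-noise and denoise rounds."""
    x = f(rng.gauss(0.0, T_MAX), T_MAX)
    for tau in time_points:
        noisy = x + math.sqrt(tau**2 - EPS**2) * rng.gauss(0.0, 1.0)
        x = f(noisy, tau)
    return x


rng = random.Random(0)
one_step = [multistep_sample(toy_consistency_fn, [], rng) for _ in range(20000)]
two_step = [multistep_sample(toy_consistency_fn, [1.0], rng) for _ in range(20000)]
one_step_std = statistics.pstdev(one_step)  # close to SIGMA_D
two_step_std = statistics.pstdev(two_step)  # close to SIGMA_D
```

<p>With an exact consistency function, extra steps do not change the sample distribution. With a learned, imperfect model they give it further chances to correct errors, which is the source of the quality-compute tradeoff.</p>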
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>Consistency distillation achieved state-of-the-art FID for one-step generation at the time of publication (3.55 on CIFAR-10, 6.20 on ImageNet 64x64).</li>
<li>Multi-step sampling provides a smooth quality-compute tradeoff: more steps yield better FID.</li>
<li>CT produces competitive results without any pretrained diffusion model, making consistency models a standalone generative model family.</li>
<li>The LPIPS distance metric $d(\cdot, \cdot)$ generally outperforms $\ell_1$ and $\ell_2$ for training consistency models.</li>
<li>At higher resolutions (LSUN 256x256), the gap between CD/CT and full EDM sampling widens.</li>
<li>CT currently underperforms CD, suggesting room for improvement in the standalone training paradigm.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Primary benchmark</td>
          <td>CIFAR-10</td>
          <td>32x32, 50K train</td>
          <td>FID on 50K samples</td>
      </tr>
      <tr>
          <td>Scaling benchmark</td>
          <td>ImageNet 64x64</td>
          <td>64x64, 1.28M</td>
          <td>Unconditional generation</td>
      </tr>
      <tr>
          <td>High-res benchmark</td>
          <td>LSUN Bedroom, Cat</td>
          <td>256x256</td>
          <td>Unconditional generation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>ODE solver for CD</strong>: Euler and Heun (2nd order) solvers on the empirical PF ODE</li>
<li><strong>EMA for target network</strong>: Decay rate $\mu$ scheduled as a function of training step</li>
<li><strong>Schedule functions</strong>: $N$ (number of discretization steps) and $\mu$ (EMA rate) increase over training following specific schedules (see Appendix C of the paper)</li>
<li><strong>Distance metric</strong>: LPIPS performs best; $\ell_2$ and $\ell_1$ also evaluated</li>
</ul>
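<p>The $N$ discretization steps partition the interval $[\epsilon, T]$ into time points. A minimal sketch of one such grid follows, assuming the polynomial spacing of Karras et al. (2022) with $\rho = 7$, $\epsilon = 0.002$, and $T = 80$. These values and the choice of 18 points are assumptions to check against the paper, and the function name is hypothetical.</p>

```python
def karras_time_grid(n_points, eps=0.002, t_max=80.0, rho=7.0):
    """Time points from eps to t_max, spaced uniformly in t ** (1 / rho)."""
    lo = eps ** (1.0 / rho)
    hi = t_max ** (1.0 / rho)
    return [(lo + i / (n_points - 1) * (hi - lo)) ** rho for i in range(n_points)]


grid = karras_time_grid(18)
gaps = [b - a for a, b in zip(grid, grid[1:])]  # gaps grow with t
```

<p>The spacing concentrates points near $\epsilon$, where the trajectory changes fastest relative to the noise level. Increasing $N$ over training shrinks the largest gap, which is the quantity $\Delta t$ in the error bound.</p>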
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: NCSN++/EDM architecture from Karras et al. (2022)</li>
<li><strong>CD teacher</strong>: Pretrained EDM models</li>
<li><strong>Parameterization</strong>: Skip-connection formulation with $c_{\text{skip}}(t)$ and $c_{\text{out}}(t)$ from EDM</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>CD 1-step</th>
          <th>CT 1-step</th>
          <th>EDM (full)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>3.55</td>
          <td>8.70</td>
          <td>2.04</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet 64</td>
          <td>6.20</td>
          <td>13.0</td>
          <td>2.44</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Bedroom</td>
          <td>7.80</td>
          <td>16.0</td>
          <td>3.57</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Cat</td>
          <td>11.0</td>
          <td>20.7</td>
          <td>6.69</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training details follow EDM conventions</li>
<li>CD and CT use the same batch sizes and learning rate schedules as EDM training</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/openai/consistency_models">openai/consistency_models</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained checkpoints</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Dhariwal, P., Chen, M., &amp; Sutskever, I. (2023). Consistency Models. <em>ICML 2023</em>. <a href="https://arxiv.org/abs/2303.01469">https://arxiv.org/abs/2303.01469</a></p>
<p><strong>Publication</strong>: ICML 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2023consistency,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Consistency Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://arxiv.org/abs/2303.01469}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/openai/consistency_models">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>AdaptMol: Domain Adaptation for Molecular OCSR (2026)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/adaptmol-2026/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/adaptmol-2026/</guid><description>AdaptMol is an image-to-graph OCSR model using MMD-based domain adaptation and self-training for hand-drawn molecule recognition.</description><content:encoded><![CDATA[<h2 id="bridging-the-synthetic-to-real-gap-in-graph-based-ocsr">Bridging the Synthetic-to-Real Gap in Graph-Based OCSR</h2>
<p>Most OCSR methods are trained on synthetic molecular images and evaluated on high-quality literature figures, both exhibiting relatively uniform styles. Hand-drawn molecules represent a particularly challenging domain with irregular bond lengths, variable stroke widths, and inconsistent atom symbols. Prior graph reconstruction methods like MolScribe and MolGrapher drop below 15% accuracy on hand-drawn images, despite achieving over 65% on literature datasets.</p>
<p>AdaptMol addresses this with a three-stage pipeline that enables effective transfer from synthetic to real-world data without requiring graph annotations in the target domain:</p>
<ol>
<li><strong>Base model training</strong> on synthetic data with comprehensive augmentation and dual position representation</li>
<li><strong>MMD alignment</strong> of bond-level features between source and target domains</li>
<li><strong>Self-training</strong> with SMILES-validated pseudo-labels on unlabeled target images</li>
</ol>
<h2 id="end-to-end-graph-reconstruction-architecture">End-to-End Graph Reconstruction Architecture</h2>
<p>AdaptMol builds on MolScribe&rsquo;s architecture, using a Swin Transformer base encoder ($384 \times 384$ input) with a 6-layer Transformer decoder (8 heads, hidden dim 256). The model jointly predicts atoms and bonds:</p>
<p><strong>Atom prediction</strong> follows the Pix2Seq approach, autoregressively generating a sequence of atom tokens:</p>
<p>$$S_N = [l_1, x_1, y_1, l_2, x_2, y_2, \dots, l_n, x_n, y_n]$$</p>
<p>where $l_i$ is the atom label and $(x_i, y_i)$ are discretized coordinate bin indices.</p>
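<p>The sequence construction can be sketched as follows. The bin count, the normalization of coordinates to $[0, 1]$, and the function names are assumptions for illustration and are not taken from the paper.</p>

```python
def quantize(value, n_bins):
    """Map a coordinate in [0, 1] to an integer bin index in [0, n_bins - 1]."""
    return min(int(value * n_bins), n_bins - 1)


def atoms_to_sequence(atoms, n_bins=64):
    """Flatten (label, x, y) atoms into [l_1, x_1, y_1, l_2, x_2, y_2, ...]."""
    sequence = []
    for label, x, y in atoms:
        sequence.extend([label, quantize(x, n_bins), quantize(y, n_bins)])
    return sequence


# Three atoms with normalized image coordinates (toy layout).
ethanol = [("C", 0.10, 0.50), ("C", 0.45, 0.30), ("O", 0.80, 0.50)]
tokens = atoms_to_sequence(ethanol)
```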
<p><strong>Dual position representation</strong> adds a 2D spatial heatmap on top of token-based coordinate prediction. The heatmap aggregates joint spatial distributions of all atoms:</p>
<p>$$\mathbf{H} = \text{Upsample}\left(\sum_{i=1}^{n} P_y^{(i)} \otimes P_x^{(i)}\right)$$</p>
<p>where $P_x^{(i)}$ and $P_y^{(i)}$ are coordinate probability distributions from the softmax logits. During training, this heatmap is supervised with Gaussian kernels at ground-truth atom positions. This reduces false positive atom predictions substantially (from 356 to 33 false positives at IoU 0.05).</p>
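<p>The heatmap formula can be sketched with plain lists. The outer product and the sum over atoms follow the equation above. The nearest-neighbour upsampling and the function names are assumptions, since the upsampling operator is not specified here.</p>

```python
def outer(p_y, p_x):
    """Outer product of a row distribution and a column distribution."""
    return [[py * px for px in p_x] for py in p_y]


def atom_heatmap(dists_y, dists_x, scale=2):
    """Sum per-atom joint distributions, then upsample by an integer factor."""
    n_rows, n_cols = len(dists_y[0]), len(dists_x[0])
    heat = [[0.0] * n_cols for _ in range(n_rows)]
    for p_y, p_x in zip(dists_y, dists_x):
        joint = outer(p_y, p_x)
        for r in range(n_rows):
            for c in range(n_cols):
                heat[r][c] += joint[r][c]
    upsampled = []
    for row in heat:
        wide = [value for value in row for _ in range(scale)]
        upsampled.extend([list(wide) for _ in range(scale)])
    return upsampled


# Two atoms on a 3x3 grid: one sharply localized, one uncertain in x.
dists_x = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]]
dists_y = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
heatmap = atom_heatmap(dists_y, dists_x)
total_mass = sum(sum(row) for row in heatmap)
```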
<p><strong>Bond prediction</strong> extracts atom-level features from decoder hidden states and enriches them with encoder visual features via multi-head attention with a learnable residual weight $\alpha$:</p>
<p>$$\mathbf{F}_{\text{enriched}} = \text{LayerNorm}(\mathbf{F}_{\text{atom}} + \alpha \cdot \text{MHA}(\mathbf{F}_{\text{atom}}, \mathbf{E}_{\text{vis}}))$$</p>
<p>A feed-forward network then predicts bond types between all atom pairs.</p>
<h2 id="bond-level-domain-adaptation-via-mmd">Bond-Level Domain Adaptation via MMD</h2>
<p>The key insight is that bond features are domain-invariant: they encode structural relationships (single, double, triple, aromatic) independent of visual style. Atom-level alignment is problematic due to class imbalance (carbon dominates), multi-token spanning (functional groups), and position-dependent features.</p>
<p>AdaptMol aligns bond-level feature distributions via class-conditional Maximum Mean Discrepancy:</p>
<p>$$L_{\text{MMD}} = \frac{1}{|\mathcal{C}'|} \sum_{c \in \mathcal{C}'} \text{MMD}(F_c^{\text{src}}, F_c^{\text{tgt}})$$</p>
<p>where $\mathcal{C}'$ contains classes with sufficient samples in both domains. Confidence-based filtering retains only high-confidence predictions (confidence &gt; 0.95, entropy &lt; 0.1) for alignment, tightening to 0.98 and 0.05 after the first epoch. Progressive loss weighting follows a schedule of 0.1 (epoch 0), 0.075 (epoch 1), and 0.05 thereafter.</p>
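<p>A minimal sketch of the class-conditional loss and its schedules follows. The RBF kernel, the biased squared-MMD estimator, the minimum sample count, and all function names are assumptions, since the kernel and estimator are not specified here. The thresholds and loss weights are the ones quoted above.</p>

```python
import math


def rbf(u, v, gamma=1.0):
    """Gaussian kernel between two feature vectors (assumed kernel choice)."""
    return math.exp(-gamma * sum((a - b) ** 2 for a, b in zip(u, v)))


def mmd2(xs, ys, gamma=1.0):
    """Biased estimate of squared MMD between two lists of feature vectors."""
    def mean_kernel(a, b):
        return sum(rbf(u, v, gamma) for u in a for v in b) / (len(a) * len(b))
    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


def class_conditional_mmd(src, tgt, min_samples=2):
    """Average MMD over bond classes with enough samples in both domains."""
    shared = [c for c in src if c in tgt
              and len(src[c]) >= min_samples and len(tgt[c]) >= min_samples]
    if not shared:
        return 0.0
    return sum(mmd2(src[c], tgt[c]) for c in shared) / len(shared)


def mmd_weight(epoch):
    """Progressive loss weight: 0.1, then 0.075, then 0.05."""
    return {0: 0.1, 1: 0.075}.get(epoch, 0.05)


def confidence_thresholds(epoch):
    """(minimum confidence, maximum entropy) used to filter target predictions."""
    return (0.95, 0.1) if epoch == 0 else (0.98, 0.05)


source = {"single": [[0.0, 0.0], [0.1, 0.0]], "double": [[1.0, 1.0], [1.0, 1.1]],
          "triple": [[5.0, 5.0]]}          # "triple" has too few samples
aligned = {"single": [[0.0, 0.0], [0.1, 0.0]], "double": [[1.0, 1.0], [1.0, 1.1]]}
shifted = {"single": [[0.8, 0.0], [0.9, 0.0]], "double": [[1.0, 2.0], [1.0, 2.1]]}
loss_aligned = class_conditional_mmd(source, aligned)
loss_shifted = class_conditional_mmd(source, shifted)
```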
<p>An important side effect: MMD alignment improves inter-class bond discrimination, reducing confusion between visually similar bond types (e.g., jagged double bonds vs. aromatic bonds).</p>
<h2 id="self-training-with-smiles-validation">Self-Training with SMILES Validation</h2>
<p>After MMD alignment, the model generates predictions on unlabeled target images. Predicted molecular graphs are converted to SMILES and validated against ground-truth SMILES annotations. Only exact matches are retained as pseudo-labels, providing complete graph supervision (atom coordinates, element types, bond types) that was previously unavailable in the target domain.</p>
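<p>The pseudo-label filter described above can be sketched as a simple exact-match check. The data layout and function names are hypothetical. The <code>canonicalize</code> argument stands in for a real canonical-SMILES routine from a cheminformatics toolkit, and the toy example uses plain whitespace stripping only to keep the sketch dependency-free.</p>

```python
def select_pseudo_labels(predictions, reference_smiles, canonicalize):
    """Keep predicted graphs whose SMILES exactly matches the reference SMILES."""
    kept = []
    for image_id, (pred_smiles, pred_graph) in predictions.items():
        reference = reference_smiles.get(image_id)
        if reference is None:
            continue  # no annotation to validate against
        if canonicalize(pred_smiles) == canonicalize(reference):
            kept.append((image_id, pred_graph))
    return kept


# Toy data: each prediction is (SMILES string, predicted graph).
predictions = {
    "img1": ("CCO", {"atoms": ["C", "C", "O"], "bonds": [(0, 1, 1), (1, 2, 1)]}),
    "img2": ("CC=O", {"atoms": ["C", "C", "O"], "bonds": [(0, 1, 1), (1, 2, 2)]}),
    "img3": ("C", {"atoms": ["C"], "bonds": []}),
}
reference_smiles = {"img1": " CCO ", "img2": "CCO"}
pseudo_labels = select_pseudo_labels(predictions, reference_smiles, str.strip)
kept_ids = [image_id for image_id, _ in pseudo_labels]
```

<p>Only the predicted graph is kept as supervision. The SMILES annotation acts as a validator and does not itself supply atom coordinates or bond-level labels.</p>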
<p>This approach is far more data-efficient than alternatives: AdaptMol uses only 4,080 real hand-drawn images vs. DECIMER-Handdraw&rsquo;s 38 million synthetic hand-drawn images.</p>
<h2 id="comprehensive-data-augmentation">Comprehensive Data Augmentation</h2>
<p>Two categories of augmentation are applied during synthetic data generation:</p>
<ul>
<li><strong>Structure-rendering augmentation</strong>: Functional group abbreviation substitution, bond type conversions (single to wavy/aromatic, Kekule to aromatic rings), R-group insertion, and rendering parameter randomization (font family/size, bond width/spacing)</li>
<li><strong>Image-level augmentation</strong>: Geometric operations, quality degradation, layout variations, and chemical document artifacts (caption injection, arrows, marginal annotations)</li>
</ul>
<p>Structure-rendering augmentation provides the larger benefit, contributing ~20% accuracy improvement on JPO and ~30% on ACS benchmarks.</p>
<h2 id="results">Results</h2>
<h3 id="hand-drawn-molecule-recognition">Hand-Drawn Molecule Recognition</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>DECIMER test (Acc)</th>
          <th>ChemPix (Acc)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>AdaptMol</strong></td>
          <td><strong>82.6</strong></td>
          <td><strong>60.5</strong></td>
      </tr>
      <tr>
          <td>DECIMER v2.2</td>
          <td>71.9</td>
          <td>51.4</td>
      </tr>
      <tr>
          <td>AtomLenz</td>
          <td>30.0</td>
          <td>48.4</td>
      </tr>
      <tr>
          <td>MolScribe</td>
          <td>10.1</td>
          <td>26.1</td>
      </tr>
      <tr>
          <td>MolGrapher</td>
          <td>10.7</td>
          <td>14.5</td>
      </tr>
  </tbody>
</table>
<h3 id="literature-and-synthetic-benchmarks">Literature and Synthetic Benchmarks</h3>
<p>AdaptMol achieves state-of-the-art on 4 of 6 literature benchmarks:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>AdaptMol</th>
          <th>MolScribe</th>
          <th>MolGrapher</th>
          <th>DECIMER v2.2</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CLEF</td>
          <td><strong>92.7</strong></td>
          <td>87.5</td>
          <td>57.2</td>
          <td>77.7</td>
      </tr>
      <tr>
          <td>JPO</td>
          <td><strong>88.2</strong></td>
          <td>78.8</td>
          <td>73.0</td>
          <td>75.7</td>
      </tr>
      <tr>
          <td>UOB</td>
          <td><strong>89.3</strong></td>
          <td>88.2</td>
          <td>85.1</td>
          <td>87.2</td>
      </tr>
      <tr>
          <td>ACS</td>
          <td><strong>75.5</strong></td>
          <td>72.8</td>
          <td>41.0</td>
          <td>37.7</td>
      </tr>
      <tr>
          <td>USPTO</td>
          <td>90.9</td>
          <td><strong>92.6</strong></td>
          <td>74.9</td>
          <td>59.6</td>
      </tr>
      <tr>
          <td>Staker</td>
          <td>84.0</td>
          <td><strong>84.4</strong></td>
          <td>0.0</td>
          <td>66.3</td>
      </tr>
  </tbody>
</table>
<p>MolScribe edges out AdaptMol on USPTO and Staker. The authors attribute this to MolScribe training directly on all 680K USPTO samples, which may cause it to specialize to that distribution.</p>
<h3 id="pipeline-ablation">Pipeline Ablation</h3>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Hand-drawn</th>
          <th>ChemDraw</th>
          <th>JPO</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Base model</td>
          <td>10.4</td>
          <td>92.3</td>
          <td>82.7</td>
      </tr>
      <tr>
          <td>+ Font augmentation</td>
          <td>30.2</td>
          <td>92.5</td>
          <td>82.8</td>
      </tr>
      <tr>
          <td>+ Font aug + MMD</td>
          <td>42.1</td>
          <td>94.0</td>
          <td>83.0</td>
      </tr>
      <tr>
          <td>+ Font aug + MMD + Self-training</td>
          <td><strong>82.6</strong></td>
          <td><strong>95.9</strong></td>
          <td><strong>88.2</strong></td>
      </tr>
  </tbody>
</table>
<p>Each component contributes meaningfully: font augmentation (+19.8), MMD alignment (+11.9), and self-training (+40.5) on hand-drawn accuracy.</p>
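<p>The per-stage gains quoted above are successive differences of the hand-drawn column in the ablation table, which a few lines of arithmetic reproduce:</p>

```python
# Hand-drawn accuracy after each cumulative pipeline stage (from the table).
hand_drawn = [
    ("Base model", 10.4),
    ("Font augmentation", 30.2),
    ("MMD alignment", 42.1),
    ("Self-training", 82.6),
]

# Gain attributed to each stage = accuracy minus the previous stage's accuracy.
gains = {
    name: round(acc - prev_acc, 1)
    for (_, prev_acc), (name, acc) in zip(hand_drawn, hand_drawn[1:])
}
total_gain = round(hand_drawn[-1][1] - hand_drawn[0][1], 1)
```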
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/fffh1/AdaptMol">AdaptMol Code</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/fffh1/AdaptMol/tree/main">Model + Data</a></td>
          <td>Model/Dataset</td>
          <td>MIT</td>
          <td>Pretrained checkpoint and datasets</td>
      </tr>
  </tbody>
</table>
<p>Training uses 2 NVIDIA A100 GPUs (40GB each). Base model trains for 30 epochs on 1M synthetic samples. Domain adaptation involves 3 steps: USPTO self-training (3 iterations of 3 epochs), MMD alignment on hand-drawn data (5 epochs), and hand-drawn self-training (5 iterations).</p>
<h2 id="limitations">Limitations</h2>
<ul>
<li>Sequence length constraints prevent accurate prediction of very large molecules (&gt;120 atoms), where resizing causes significant information loss</li>
<li>Cannot recognize Markush structures with repeating unit notation (parentheses/brackets), as synthetic training data lacks such cases</li>
<li>Stereochemistry information is lost when stereo bonds connect to abbreviated functional groups due to RDKit post-processing limitations</li>
<li>The retrained baseline (30 epochs from scratch on synthetic + pseudo-labels) achieves higher hand-drawn accuracy (87.2%) but at the cost of cross-domain robustness on literature benchmarks</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hu, F., He, E., &amp; Verspoor, K. (2026). AdaptMol: Domain Adaptation for Molecular Image Recognition with Limited Supervision. <em>Research Square preprint</em>. <a href="https://doi.org/10.21203/rs.3.rs-8365561/v1">https://doi.org/10.21203/rs.3.rs-8365561/v1</a></p>
<p><strong>Publication</strong>: Research Square preprint, February 2026</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/fffh1/AdaptMol">GitHub</a></li>
<li><a href="https://huggingface.co/fffh1/AdaptMol/tree/main">HuggingFace (model + data)</a></li>
</ul>
]]></content:encoded></item><item><title>The Quarks of Attention: Building Blocks of Attention</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/quarks-of-attention/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/quarks-of-attention/</guid><description>Baldi and Vershynin's 2023 theoretical analysis decomposing attention into fundamental building blocks and proving capacity bounds for attentional circuits.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>theory paper</strong> that takes a reductionist approach to attention mechanisms. It classifies all possible fundamental building blocks of attention (&ldquo;quarks&rdquo;) within a formal neural network framework, then proves capacity theorems for circuits built from these primitives using linear and polynomial threshold gates.</p>
<h2 id="why-decompose-attention-into-primitives">Why decompose attention into primitives?</h2>
<p>Descriptions of attention in deep learning often seem complex and obscure the underlying neural architecture. Despite the widespread use of attention in transformers and beyond, there has been little formal theory about the computational nature and capacity of attention mechanisms. Baldi and Vershynin address this by identifying the smallest building blocks and rigorously analyzing what they can compute.</p>
<h2 id="the-standard-model-and-its-extensions">The Standard Model and its extensions</h2>
<p>The paper defines the &ldquo;Standard Model&rdquo; (SM) as the class of all neural networks built from McCulloch-Pitts neurons: directed weighted graphs where neuron $i$ computes $O_i = f_i(S_i)$ with activation $S_i = \sum_j w_{ij} O_j$. The SM already has universal approximation properties, so extensions should be evaluated on efficiency (circuit size, depth, learning), not on what functions can be represented.</p>
<p>Three variable types exist in the SM: activations ($S$), outputs ($O$), and synaptic weights ($w$). Cross these with two mechanisms (addition, multiplication) and the constraint that attending signals originate from neuronal outputs, and you get six possible attention primitives.</p>
<h2 id="the-six-quarks-reduced-to-three">The six quarks, reduced to three</h2>
<table>
  <thead>
      <tr>
          <th></th>
          <th>$S$ (activation)</th>
          <th>$O$ (output)</th>
          <th>$w$ (synapse)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Addition</strong></td>
          <td>Multiplexing (in SM)</td>
          <td>Additive output (in SM)</td>
          <td>Additive synaptic</td>
      </tr>
      <tr>
          <td><strong>Multiplication</strong></td>
          <td>Activation gating</td>
          <td><strong>Output gating</strong></td>
          <td><strong>Synaptic gating</strong></td>
      </tr>
  </tbody>
</table>
<p>The paper shows these reduce to three cases worth studying:</p>
<h3 id="multiplexing-additive-activation-attention">Multiplexing (additive activation attention)</h3>
<p>The attending signal $S_2$ is added to the normal activation $S_1$, producing $O_i = f_i(S_1 + S_2)$. With sigmoid or threshold activations, a large negative $S_2$ forces the output to zero regardless of $S_1$, suppressing unattended stimuli. This mechanism lives entirely within the SM and plays a central role in proving capacity lower bounds.</p>
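<p>A small numeric illustration of multiplexing with a sigmoid unit follows. The input values and function names are arbitrary choices for the sketch.</p>

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def multiplexed_output(s_normal, s_attend):
    """Additive activation attention: O = f(S_1 + S_2)."""
    return sigmoid(s_normal + s_attend)


attended = multiplexed_output(3.0, 0.0)      # attending signal is neutral
suppressed = multiplexed_output(3.0, -50.0)  # large negative signal saturates to ~0
```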
<h3 id="output-gating">Output gating</h3>
<p>Neuron $j$ multiplies the output of neuron $i$, producing $O_i O_j$. This quadratic term is not available within the SM. The gated signal $O_i O_j$ propagates to all downstream neurons of $i$. When $O_j \approx 0$, the attended neuron is silenced; when $O_j$ is large, it is enhanced.</p>
<h3 id="synaptic-gating">Synaptic gating</h3>
<p>Neuron $j$ multiplies a synaptic weight $w_{ki}$, creating a dynamic weight $w_{ki} O_j$. This produces the same local term $w_{ki} O_i O_j$ at neuron $k$ as output gating, but affects only the single downstream connection rather than all of neuron $i$&rsquo;s outputs. Synaptic gating is a fast weight mechanism: the attending network dynamically changes the program executed by the attended network.</p>
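<p>The difference in scope between the two gating types shows up in a toy circuit where neuron $i$ feeds two downstream neurons $k_1$ and $k_2$. In this sketch (names and values are hypothetical), output gating changes both contributions, while synaptic gating changes only the gated connection into $k_1$.</p>

```python
def downstream_contributions(o_i, o_j, w_k1, w_k2, mode):
    """Contribution of neuron i to k1 and k2 when neuron j gates it."""
    if mode == "output":
        gated = o_i * o_j                      # the output itself is gated
        return w_k1 * gated, w_k2 * gated
    if mode == "synaptic":
        return (w_k1 * o_j) * o_i, w_k2 * o_i  # only the synapse into k1 is gated
    raise ValueError(mode)


silenced_output = downstream_contributions(2.0, 0.0, 1.5, 0.5, "output")
silenced_synapse = downstream_contributions(2.0, 0.0, 1.5, 0.5, "synaptic")
enhanced_output = downstream_contributions(2.0, 3.0, 1.5, 0.5, "output")
enhanced_synapse = downstream_contributions(2.0, 3.0, 1.5, 0.5, "synaptic")
```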
<h2 id="transformers-are-built-entirely-from-gating">Transformers are built entirely from gating</h2>
<p>The paper shows that transformer encoder modules decompose into:</p>
<ol>
<li><strong>Output gating</strong> ($mn^2$ operations): computing all $n^2$ pairwise dot products of $Q$ and $K$ vectors, each requiring $m$ element-wise multiplications</li>
<li><strong>Softmax</strong>: a standard SM extension</li>
<li><strong>Synaptic gating</strong> ($n^2$ operations): weighting $V$ vectors by the softmax outputs to form convex combinations</li>
</ol>
<p>The entire attention mechanism uses $O(mn^2)$ gating operations. The permutation invariance of transformers follows directly from the weight sharing across input positions.</p>
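<p>The decomposition can be made concrete by counting gating operations in a plain single-head attention computation. This is a sketch under assumptions: the $1/\sqrt{m}$ scaling is the usual transformer convention and is not stated above, the learned projections that produce $Q$, $K$, and $V$ are omitted, and the function name is hypothetical.</p>

```python
import math


def attention_with_gating_counts(Q, K, V):
    """Single-head attention that also counts the gating operations it performs."""
    m = len(Q[0])
    output_gating_ops = 0
    synaptic_gating_ops = 0
    outputs = []
    for q in Q:
        scores = []
        for k in K:
            scores.append(sum(qa * ka for qa, ka in zip(q, k)) / math.sqrt(m))
            output_gating_ops += m        # m pairwise products per dot product
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]  # softmax: non-negative, sums to 1
        mixed = [0.0] * len(V[0])
        for w, v in zip(weights, V):
            mixed = [acc + w * vc for acc, vc in zip(mixed, v)]
            synaptic_gating_ops += 1      # one softmax weight gates one V vector
        outputs.append(mixed)
    return outputs, output_gating_ops, synaptic_gating_ops


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # n = 3 positions, m = 2 dimensions
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
outputs, n_output_gating, n_synaptic_gating = attention_with_gating_counts(Q, K, V)
```

<p>For $n = 3$ and $m = 2$, the counts are $mn^2 = 18$ and $n^2 = 9$. Because every row of the toy $V$ sums to one, each output row does too, which illustrates the convex-combination property.</p>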
<h2 id="relationship-to-polynomial-neural-networks">Relationship to polynomial neural networks</h2>
<p>Gating is a special case of polynomial activation. A neuron with full quadratic activation over $n$ inputs has the form $S_i = \sum_{jk} w_{ijk} O_j O_k$, requiring $O(n^2)$ three-way synaptic weights for all possible pairs. Gating introduces only one new quadratic term per operation. The same gating concepts can also be applied to more complex units with polynomial activations of degree $d$, where one polynomial threshold unit gates the output or synapse of another.</p>
<h2 id="functional-properties-of-gating">Functional properties of gating</h2>
<p>Several examples illustrate what gating enables:</p>
<ul>
<li><strong>Shaping activation functions</strong>: When a unit with activation function $f$ is output-gated by a unit with activation function $g$ (both having the same inputs), the result is $f(S)g(S) = fg(S)$. This changes the effective activation function from $f$ to $fg$. For instance, a linear unit gated by a $(0,1)$ threshold function produces the ReLU activation.</li>
<li><strong>XOR without hidden layers</strong>: The XOR function cannot be computed by a single linear threshold gate. However, gating the OR function by the NAND function (both implementable by single linear threshold gates) produces XOR in a shallow network with no hidden layers.</li>
<li><strong>Universal approximation</strong>: Every continuous function on a compact set can be approximated to arbitrary precision by a shallow attention network of linear units gated by linear threshold gates (Theorem 4.3).</li>
</ul>
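<p>The first two examples above are small enough to check directly. In this sketch the particular weights and biases of the OR and NAND gates are one valid choice among many, and the function names are hypothetical.</p>

```python
def threshold(s):
    """(0, 1) linear threshold gate."""
    return 1 if s > 0 else 0


def gated_relu(s):
    """Linear unit output-gated by a threshold unit on the same input."""
    return s * threshold(s)


def or_gate(x1, x2):
    return threshold(x1 + x2 - 0.5)


def nand_gate(x1, x2):
    return threshold(1.5 - x1 - x2)


def gated_xor(x1, x2):
    """OR output-gated by NAND: the product of two single threshold gates."""
    return or_gate(x1, x2) * nand_gate(x1, x2)


xor_table = {(a, b): gated_xor(a, b) for a in (0, 1) for b in (0, 1)}
relu_values = [gated_relu(s) for s in (-2.0, 0.0, 3.5)]
```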
<h2 id="attention-as-sparse-quadratic-terms">Attention as sparse quadratic terms</h2>
<p>Both output and synaptic gating introduce quadratic terms of the form $w_{ki} O_i O_j$. A neuron with full quadratic activation over $n$ inputs would require $O(n^2)$ parameters. Gating introduces only one new quadratic term per operation. This is the key insight: attention mechanisms gain some of the expressiveness of quadratic activations while avoiding the combinatorial parameter explosion.</p>
<h2 id="capacity-results">Capacity results</h2>
<p>Using cardinal capacity (the base-2 logarithm of the number of distinct Boolean functions a class of circuits can implement), the paper proves bounds for attentional circuits with linear and polynomial threshold gates:</p>
<ul>
<li><strong>Single unit with output gating</strong>: a gated pair of linear threshold gates on $n$ inputs has capacity $2n^2(1 + o(1))$, compared to $n^2(1 + o(1))$ for a single gate (Theorem 6.1). This represents a doubling of capacity with a doubling of parameters (from $n$ to $2n$), a sign of efficiency.</li>
<li><strong>Multiplexing technique</strong>: additive activation attention enables a &ldquo;multiplexing&rdquo; proof strategy where one unit in a layer is selected as a function of the attending units while driving remaining units to saturation. This is the key tool for proving lower bounds.</li>
<li><strong>Attention layers</strong>: extending to layers of $m$ gated units with $n$ inputs, the capacity is $2mn^2(1 + o(1))$ for output gating (Theorem 7.1), confirming that gating approximately doubles the capacity relative to ungated layers.</li>
<li><strong>Depth reduction</strong>: gating operations available as physical primitives in a neural network can reduce the depth required for certain basic circuits.</li>
</ul>
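<p>Cardinal capacity can be computed by brute force for very small circuits. The sketch below enumerates linear threshold gates on $n = 2$ binary inputs over a small weight grid (an assumption that happens to cover all such gates) and then all output-gated pairs. This tiny case is far from the asymptotic regime of the theorems, so it illustrates the definition rather than the $n^2$ versus $2n^2$ scaling.</p>

```python
import math
from itertools import product

INPUTS = list(product((0, 1), repeat=2))


def truth_table(w1, w2, b):
    """Truth table of the threshold gate [w1 * x1 + w2 * x2 + b > 0]."""
    return tuple(1 if w1 * x1 + w2 * x2 + b > 0 else 0 for x1, x2 in INPUTS)


# Distinct Boolean functions realized by a single linear threshold gate.
threshold_functions = {
    truth_table(w1, w2, b)
    for w1, w2 in product((-1, 0, 1), repeat=2)
    for b in (-1.5, -0.5, 0.5, 1.5)
}

# Distinct functions realized when one gate output-gates another (product).
gated_functions = {
    tuple(f * g for f, g in zip(tf, tg))
    for tf in threshold_functions
    for tg in threshold_functions
}

single_capacity = math.log2(len(threshold_functions))
gated_capacity = math.log2(len(gated_functions))
```

<p>A single gate realizes 14 of the 16 Boolean functions of two variables (all except XOR and XNOR), while a gated pair realizes all 16, so the capacity rises from $\log_2 14$ to 4 bits.</p>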
<h2 id="limitations-and-future-work">Limitations and future work</h2>
<p>The authors note several open directions:</p>
<ul>
<li>The capacity estimates for some configurations (e.g., single-weight synaptic gating in Proposition 6.9) have gaps between the lower and upper bounds that remain to be tightened.</li>
<li>The analysis uses Boolean neurons (linear and polynomial threshold gates) as approximations. Extending results to other activation functions (sigmoid, ReLU) is left for future work.</li>
<li>The paper focuses on single layers and pairs of units. Capacity analysis for deeper attention architectures with multiple stacked layers is not addressed.</li>
<li>The theory treats attention on the time scale of individual inputs. The paper briefly notes that fast synaptic mechanisms operating on different time scales raise interesting architectural questions but does not develop this direction.</li>
</ul>
<h2 id="reproducibility">Reproducibility</h2>
<p>This is a purely theoretical paper with no associated code, datasets, or pretrained models. All results are mathematical theorems and proofs that can be verified from the paper itself. The paper is freely available on arXiv under a CC BY-NC-ND 4.0 license.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Baldi, P. &amp; Vershynin, R. (2023). The quarks of attention: Structure and capacity of neural attention building blocks. <em>Artificial Intelligence</em>, 319, 103901.</p>
<p><strong>Publication</strong>: Artificial Intelligence 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.sciencedirect.com/science/article/pii/S0004370223000474">Journal (ScienceDirect)</a></li>
<li><a href="https://arxiv.org/abs/2202.08371">arXiv</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{baldi2023quarks,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{The quarks of attention: Structure and capacity of neural attention building blocks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Baldi, Pierre and Vershynin, Roman}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{319}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{103901}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1016/j.artint.2023.103901}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Spherical CNNs: Rotation-Equivariant Networks on the Sphere</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/spherical-cnns/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/spherical-cnns/</guid><description>Cohen et al. introduce rotation-equivariant spherical CNNs that define cross-correlation on SO(3), computed via generalized FFT from harmonic analysis.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces the theory and implementation of convolutional neural networks on the sphere. The key contribution is defining spherical cross-correlation that is SO(3)-equivariant and can be computed efficiently using generalized Fast Fourier Transforms from non-commutative harmonic analysis.</p>
<h2 id="why-planar-convolutions-fail-on-spherical-data">Why planar convolutions fail on spherical data</h2>
<p>Many problems require analyzing spherical signals: omnidirectional vision for robots and autonomous vehicles, molecular regression, and global weather modeling. A naive approach of projecting spherical data to a plane introduces space-varying distortions that break translational weight sharing. Rotating a spherical signal cannot be emulated by translating its planar projection.</p>
<p>The fundamental issue is geometric: patterns on a plane move via translations, but patterns on a sphere move via 3D rotations. A spherical CNN should detect patterns regardless of how they are rotated over the sphere. The relevant symmetry group is SO(3) (the group of all 3D rotations).</p>
<h2 id="spherical-cross-correlation-and-the-so3-output-space">Spherical cross-correlation and the SO(3) output space</h2>
<p>The paper defines spherical cross-correlation by replacing filter translations with rotations. For spherical signals $f$ on $S^2$ (the unit sphere) and filter $\psi$, the correlation is:</p>
<p>$$\lbrack\psi \star f\rbrack(R) = \langle L_R \psi, f \rangle = \int_{S^2} \sum_{k=1}^{K} \psi_k(R^{-1}x) f_k(x) \, dx$$</p>
<p>where $L_R$ is the rotation operator $\lbrack L_R f\rbrack(x) = f(R^{-1}x)$.</p>
<p>A crucial subtlety: whereas the space of moves for the plane (2D translations) is isomorphic to the plane itself, the space of moves for the sphere (3D rotations) is SO(3), a different three-dimensional manifold. The output of a spherical correlation is therefore a function on SO(3), not on $S^2$. This means subsequent layers must use SO(3) correlation:</p>
<p>$$\lbrack\psi \star f\rbrack(R) = \int_{\text{SO}(3)} \sum_{k=1}^{K} \psi_k(R^{-1}Q) f_k(Q) \, dQ$$</p>
<h3 id="equivariance-proof">Equivariance proof</h3>
<p>Equivariance follows from the unitarity of $L_R$ in a single line:</p>
<p>$$\lbrack\psi \star \lbrack L_Q f\rbrack\rbrack(R) = \langle L_R \psi, L_Q f \rangle = \langle L_{Q^{-1}R} \psi, f \rangle = \lbrack\psi \star f\rbrack(Q^{-1}R) = \lbrack L_Q\lbrack\psi \star f\rbrack\rbrack(R)$$</p>
<p>This holds for both $S^2$ and SO(3) correlation.</p>
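<p>The second equality in the chain is where unitarity enters. Spelled out, assuming the standard facts that $L$ is a group homomorphism ($L_Q L_R = L_{QR}$) and that the adjoint of the unitary operator $L_Q$ is $L_{Q^{-1}}$:</p>
<p>$$\langle L_R \psi, L_Q f \rangle = \langle L_Q^{\dagger} L_R \psi, f \rangle = \langle L_{Q^{-1}} L_R \psi, f \rangle = \langle L_{Q^{-1}R} \psi, f \rangle$$</p>
<p>The last equality of the chain is the definition of the rotation operator applied to a function $h$ on SO(3), here $h = \psi \star f$: $\lbrack L_Q h\rbrack(R) = h(Q^{-1}R)$.</p>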
<h2 id="efficient-computation-via-generalized-fft">Efficient computation via generalized FFT</h2>
<p>A naive SO(3) correlation is $O(n^6)$. The paper addresses this using the generalized Fourier transform (GFT) from non-commutative harmonic analysis.</p>
<p>The GFT projects functions onto orthogonal basis functions: spherical harmonics $Y_m^l(x)$ for $S^2$, and Wigner D-functions $D_{mn}^l(R)$ for SO(3). Both satisfy generalized Fourier theorems:</p>
<ul>
<li><strong>SO(3) correlation theorem</strong>: $\widehat{\psi \star f}^l = \hat{f}^l \cdot \hat{\psi}^{l\dagger}$ (matrix multiplication of the degree-$l$ blocks of Fourier coefficients)</li>
<li><strong>$S^2$ correlation theorem</strong>: $\widehat{\psi \star f}^l = \hat{f}^l \cdot \hat{\psi}^{l\dagger}$ (outer product of the degree-$l$ $S^2$ Fourier coefficient vectors)</li>
</ul>
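<p>The shapes behind these two statements can be tallied directly. This sketch assumes a signal bandlimited to degrees $l = 0, \dots, b-1$ (the bandwidth value and the function names are illustrative, not taken from the paper or its code). Each degree contributes a vector of $2l+1$ coefficients on $S^2$ and a $(2l+1) \times (2l+1)$ block on SO(3), so the outer product of two $S^2$ coefficient vectors has exactly the shape of an SO(3) block.</p>

```python
def s2_coefficient_count(bandwidth):
    """Number of spherical harmonic coefficients for degrees 0 .. bandwidth - 1."""
    return sum(2 * l + 1 for l in range(bandwidth))

def so3_coefficient_count(bandwidth):
    """Number of Wigner D-function coefficients for degrees 0 .. bandwidth - 1."""
    return sum((2 * l + 1) ** 2 for l in range(bandwidth))

def outer_product_shape(l):
    """Shape of the degree-l block formed from two S^2 coefficient vectors."""
    vector_length = 2 * l + 1
    return (vector_length, vector_length)

bandwidth = 4  # illustrative value, not taken from the paper
s2_total = s2_coefficient_count(bandwidth)
so3_total = so3_coefficient_count(bandwidth)
block_shapes = [outer_product_shape(l) for l in range(bandwidth)]
print(s2_total, so3_total, block_shapes)
```

<p>The totals follow the closed forms $b^2$ for $S^2$ and $b(4b^2 - 1)/3$ for SO(3), which is why the output of an $S^2$ correlation needs the larger SO(3) spectrum to be stored.</p>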
<p>The SO(3) FFT works in two steps: (1) standard 2D FFT over the $\alpha$ and $\gamma$ Euler angles, then (2) linear contraction of the $\beta$ axis with precomputed Wigner-d function samples, implemented as a custom GPU kernel.</p>
<h2 id="experiments">Experiments</h2>
<h3 id="equivariance-error">Equivariance error</h3>
<p>Since the theory applies to continuous functions but the implementation is discretized, the authors rigorously measure equivariance error. The approximation error grows with resolution and depth but stays manageable for practical bandwidths. With ReLU activations, the error is higher but stays flat across layers, indicating the error comes from feature map rotation (exact only for bandlimited functions) rather than accumulating through the network.</p>
<h3 id="spherical-mnist">Spherical MNIST</h3>
<p>MNIST digits projected onto the sphere, tested in non-rotated (NR) and rotated (R) settings with ~165K parameters per model:</p>
<table>
  <thead>
      <tr>
          <th>Train / Test</th>
          <th>Planar CNN</th>
          <th>Spherical CNN</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>NR / NR</td>
          <td>99%</td>
          <td>91%</td>
      </tr>
      <tr>
          <td>R / R</td>
          <td>45%</td>
          <td>91%</td>
      </tr>
      <tr>
          <td>NR / R</td>
          <td>9%</td>
          <td>85%</td>
      </tr>
  </tbody>
</table>
<p>The planar CNN collapses to chance when trained on non-rotated data and tested on rotated data. The spherical CNN maintains strong performance across all settings.</p>
<h3 id="3d-shape-recognition-shrec17">3D shape recognition (SHREC17)</h3>
<p>3D meshes projected onto an enclosing sphere via ray casting. For each point on the sphere, a ray is cast toward the origin, collecting three types of information from the intersection: ray length and cos/sin of the surface angle. The same three channels are computed for the convex hull, giving 6 channels total. The network (~1.4M parameters) placed 2nd on recall, mAP, and NDCG, and 3rd on precision and F1 in the SHREC17 competition, competing against methods with highly task-specialized architectures.</p>
<h3 id="molecular-atomization-energy-qm7">Molecular atomization energy (QM7)</h3>
<p>Molecules represented as spherical potential functions around each atom (generalizing the Coulomb matrix). A deep ResNet-style $S^2$CNN with DeepSets-style permutation-invariant aggregation over atoms achieved 8.47 RMSE, outperforming all kernel-based approaches and sorted Coulomb matrix methods.</p>
<h2 id="discussion-and-future-directions">Discussion and future directions</h2>
<p>The authors highlight several avenues for future work. For volumetric tasks like 3D model recognition, extending beyond SO(3) to the roto-translation group SE(3) could improve results. They also note that a Steerable CNN for the sphere would enable analysis of vector fields (e.g., global wind directions). Omnidirectional vision is mentioned as a compelling application as 360-degree sensors become more prevalent.</p>
<h2 id="reproducibility">Reproducibility</h2>
<p>The official PyTorch implementation is publicly available. The code does not support recent PyTorch versions due to changes in the FFT interface.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/jonkhler/s2cnn">s2cnn</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation (deprecated for modern PyTorch)</td>
      </tr>
  </tbody>
</table>
<p>Hardware requirements from the paper: the SHREC17 model uses 8GB GPU memory at batch size 16 and takes 50 hours to train. The QM7 model uses 7GB at batch size 20 and takes 3 hours to train. Datasets used (Spherical MNIST, SHREC17, QM7) are all publicly available.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cohen, T. S., Geiger, M., Köhler, J., &amp; Welling, M. (2018). Spherical CNNs. <em>International Conference on Learning Representations</em>. <a href="https://arxiv.org/abs/1801.10130">https://arxiv.org/abs/1801.10130</a></p>
<p><strong>Publication</strong>: ICLR 2018</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=Hkbd5xZRb">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/1801.10130">arXiv</a></li>
<li><a href="https://github.com/jonkhler/s2cnn">GitHub</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{cohen2018spherical,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Spherical {CNNs}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Cohen, Taco S. and Geiger, Mario and K{\&#34;o}hler, Jonas and Welling, Max}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SE(3)-Transformers: Equivariant Attention for 3D Data</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/se3-transformers/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/se3-transformers/</guid><description>Fuchs et al. combine self-attention with SE(3)-equivariance for 3D point clouds using invariant attention weights and equivariant value messages.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces the SE(3)-Transformer, a self-attention mechanism for 3D point clouds and graphs that is equivariant under continuous 3D rotations and translations. It builds on tensor field networks (TFNs) by adding data-dependent attention weights, resolving a known expressiveness limitation of equivariant convolutions.</p>
<h2 id="why-equivariant-attention-for-point-clouds">Why equivariant attention for point clouds?</h2>
<p>Point cloud data appears in 3D object scans, molecular structures, and particle simulations. Two properties are essential: handling varying numbers of irregularly sampled points, and invariance to global changes in pose (rotations and translations).</p>
<p>Self-attention handles variable-size inputs naturally and has proven effective across many domains. Tensor field networks provide SE(3)-equivariant convolutions but suffer from a key limitation: their filter kernels are decomposed into learnable radial functions and fixed angular components (spherical harmonics). The angular dependence is completely constrained by the equivariance condition, leaving no learnable degrees of freedom in the angular direction. This has been identified in the literature as severely limiting performance.</p>
<p>The SE(3)-Transformer resolves this by introducing data-dependent attention weights that modulate the angular profile of the kernels while maintaining equivariance.</p>
<h2 id="architecture-invariant-attention-meets-equivariant-values">Architecture: invariant attention meets equivariant values</h2>
<p>The core layer combines three components:</p>
<p>$$\mathbf{f}_{\text{out},i}^{\ell} = \underbrace{\mathbf{W}_V^{\ell\ell} \mathbf{f}_{\text{in},i}^{\ell}}_{\text{self-interaction}} + \sum_{k \geq 0} \sum_{j \in \mathcal{N}_i \setminus i} \underbrace{\alpha_{ij}}_{\text{attention}} \underbrace{\mathbf{W}_V^{\ell k}(\mathbf{x}_j - \mathbf{x}_i) \mathbf{f}_{\text{in},j}^k}_{\text{value message}}$$</p>
<h3 id="invariant-attention-weights">Invariant attention weights</h3>
<p>The attention weights use dot-product attention between equivariant queries and keys:</p>
<p>$$\alpha_{ij} = \frac{\exp(\mathbf{q}_i^\top \mathbf{k}_{ij})}{\sum_{j' \in \mathcal{N}_i \setminus i} \exp(\mathbf{q}_i^\top \mathbf{k}_{ij'})}$$</p>
<p>Both $\mathbf{q}_i$ and $\mathbf{k}_{ij}$ are constructed using TFN-type linear embeddings, making them SE(3)-equivariant. Their inner product is invariant because SO(3) representations are orthogonal: $\mathbf{q}^\top \mathbf{S}_g^\top \mathbf{S}_g \mathbf{k} = \mathbf{q}^\top \mathbf{k}$.</p>
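<p>A small numerical check of this invariance, restricted to type-1 (ordinary 3D vector) features, for which the representation matrix $\mathbf{S}_g$ can be taken to be the rotation matrix itself. The query, keys, and rotation angles below are made-up values for illustration. Rotating every vector by the same rotation leaves the dot products, and therefore the softmax weights, unchanged up to floating-point error.</p>

```python
import math

def rotation_z(theta):
    """3x3 rotation matrix about the z-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def rotation_x(theta):
    """3x3 rotation matrix about the x-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]

def matmul(a, b):
    """Product of two 3x3 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def rotate(matrix, vector):
    """Matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def attention_weights(query, keys):
    """Softmax over the dot products between one query and its neighbours' keys."""
    logits = [sum(q * k for q, k in zip(query, key)) for key in keys]
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical type-1 query and keys for one node with three neighbours.
query = [0.3, -1.2, 0.5]
keys = [[1.0, 0.2, -0.4], [-0.7, 0.9, 0.1], [0.0, -0.5, 1.5]]

rotation = matmul(rotation_z(0.8), rotation_x(-1.3))
rotated_query = rotate(rotation, query)
rotated_keys = [rotate(rotation, key) for key in keys]

original = attention_weights(query, keys)
rotated = attention_weights(rotated_query, rotated_keys)
max_difference = max(abs(a - b) for a, b in zip(original, rotated))
print(original, max_difference)
```

<p>Higher-degree features behave the same way with the corresponding orthogonal representation matrices in place of the rotation matrix.</p>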
<h3 id="equivariant-value-messages">Equivariant value messages</h3>
<p>The value messages use the same TFN kernel structure as tensor field networks: weight kernels $\mathbf{W}_V^{\ell k}(\mathbf{x})$ decomposed into learnable radial functions and Clebsch-Gordan/spherical harmonic angular components. Features are typed by irreducible representation degree $\ell$ (the independent matrix blocks into which SO(3) group actions decompose): type-0 vectors are rotation-invariant scalars, type-1 vectors transform as 3D vectors, and so on.</p>
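<p>In symbols, using the standard convention (an assumption here, since the notation is not spelled out above) that a type-$\ell$ feature has $2\ell + 1$ components and rotates by the Wigner D-matrix $\mathbf{D}_\ell(g)$:</p>
<p>$$\mathbf{f}^{\ell} \in \mathbb{R}^{2\ell+1}, \qquad \mathbf{f}^{\ell} \mapsto \mathbf{D}_{\ell}(g)\, \mathbf{f}^{\ell}, \qquad \mathbf{D}_0(g) = 1, \qquad \mathbf{D}_1(g) = \mathbf{R}_g \ \text{(up to a change of basis)}$$</p>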
<h3 id="angular-modulation">Angular modulation</h3>
<p>The attention weights $\alpha_{ij}$ multiply the value messages, creating data-dependent kernels $\alpha_{ij} \mathbf{W}_V^{\ell k}(\mathbf{x})$. This effectively modulates the angular profile of the fixed spherical harmonic components, adding learnable angular degrees of freedom while preserving equivariance. The authors describe this as one of the first examples of a nonlinear equivariant layer.</p>
<h3 id="attentive-self-interaction">Attentive self-interaction</h3>
<p>The paper also introduces attentive self-interaction as an alternative to the standard linear self-interaction (analogous to 1x1 convolutions). Instead of fixed learned weights across all points, the weights are generated by an MLP operating on invariant inner products of the input features:</p>
<p>$$w_{i,c'c}^{\ell\ell} = \text{MLP}\left(\bigoplus_{c,c'} \mathbf{f}_{\text{in},i,c'}^{\ell\top} \mathbf{f}_{\text{in},i,c}^{\ell}\right)$$</p>
<h2 id="experiments">Experiments</h2>
<h3 id="n-body-particle-simulation">N-body particle simulation</h3>
<p>Five charged particles carrying positive or negative charges, exerting repulsive or attractive forces on each other. The task is predicting positions and velocities 500 timesteps ahead. The SE(3)-Transformer achieves 0.0076 MSE on position (vs. 0.0139 for Set Transformer and 0.0151 for TFN), with equivariance error on the order of $10^{-7}$, confirming exact equivariance up to numerical precision.</p>
<h3 id="scanobjectnn-real-world-3d-object-classification">ScanObjectNN (real-world 3D object classification)</h3>
<p>2902 real-world scanned objects across 15 categories. This task is only SO(2)-invariant (gravity axis matters), so the authors provide the z-component as an additional scalar input. With only 128 input points, the SE(3)-Transformer+z achieves 85.0% accuracy, competitive with methods using 1024 points and task-specific architectures. The model learns to ignore the symmetry-breaking z-input when trained on rotation-augmented data.</p>
<h3 id="qm9-molecular-property-regression"><a href="/notes/chemistry/datasets/qm9/">QM9</a> molecular property regression</h3>
<p>134k molecules with up to 29 atoms, predicting 6 quantum chemical properties. The SE(3)-Transformer achieves competitive results against other equivariant models (TFN, Cormorant), with improvements over TFN on all six targets. Across all three experiments, the SE(3)-Transformer outperforms both a non-equivariant attention baseline (Set Transformer) and equivariant models without attention (TFN).</p>
<h3 id="practical-contributions">Practical contributions</h3>
<p>The paper includes a PyTorch spherical harmonics implementation that is 10x faster than SciPy on CPU and 100-1000x faster on GPU. For a ScanObjectNN model, this yields roughly a 22x speedup of the forward pass compared to the lie-learn library, directly addressing a major bottleneck of TFN-based architectures.</p>
<h2 id="conclusions-and-limitations">Conclusions and limitations</h2>
<p>Adding attention to a roto-translation-equivariant model consistently led to higher accuracy and increased training stability across all three experiments. For large neighbourhoods, attention proved essential for model convergence. The equivariance constraints also improved performance compared to conventional (non-equivariant) attention in all experiments.</p>
<p>The authors note that the SE(3)-Transformer is inherently suited for classification and regression on molecular data and discuss applications in drug research, including early-stage suitability classification of molecules for inhibiting viral reproductive cycles.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/FabianFuchsML/se3-transformer-public">se3-transformer-public</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch + DGL implementation</td>
      </tr>
  </tbody>
</table>
<p>The repository includes code for N-body simulations and QM9 experiments. Hyperparameters and architecture details are provided in the paper&rsquo;s appendix (4 equivariant layers, representation degrees, channels per degree, learning rates, batch sizes). Hardware requirements are not explicitly stated in the paper.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fuchs, F. B., Worrall, D. E., Fischer, V., &amp; Welling, M. (2020). SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks. <em>Advances in Neural Information Processing Systems</em>, 33. <a href="https://arxiv.org/abs/2006.10503">https://arxiv.org/abs/2006.10503</a></p>
<p><strong>Publication</strong>: NeurIPS 2020</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2006.10503">arXiv</a></li>
<li><a href="https://github.com/FabianFuchsML/se3-transformer-public">GitHub</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{fuchs2020se3,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{{SE(3)-Transformers}: 3D Roto-Translation Equivariant Attention Networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fuchs, Fabian B. and Worrall, Daniel E. and Fischer, Volker and Welling, Max}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{33}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Scaling Laws vs Model Architectures: Inductive Bias</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/</guid><description>Tay et al.'s 2022 study comparing scaling behavior across ten model architectures, showing that inductive bias affects scaling properties in distinct ways.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>systematization paper</strong> that conducts a large-scale empirical comparison of how ten different model architectures scale. Rather than proposing a new architecture, it characterizes the relationship between inductive bias and scaling behavior across both upstream (pretraining) and downstream (transfer) performance.</p>
<h2 id="why-architecture-aware-scaling-matters">Why architecture-aware scaling matters</h2>
<p>Prior work on scaling laws (Kaplan et al., 2020) focused almost exclusively on vanilla Transformers, finding that loss scales as a power law with model size, dataset size, and compute. A common assumption in the field is that improvements observed at one scale transfer to other scales, and new architectures are often evaluated at a single compute point (e.g., base size). This paper challenges that assumption by asking whether different inductive biases scale differently.</p>
<h2 id="ten-architectures-one-controlled-setup">Ten architectures, one controlled setup</h2>
<p>All models are implemented in Mesh TensorFlow under a shared encoder-decoder (<a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style) framework, pretrained on C4 for $2^{19}$ steps with the Adafactor optimizer and an inverse square root learning rate schedule, and finetuned for 100K steps on GLUE + SuperGLUE + SQuAD. Models range from 15M to 40B parameters and are trained on 16 TPU-v3 chips. The architectures fall into four categories:</p>
<p><strong>Transformer variants</strong>: vanilla Transformer, Evolved Transformer (AutoML-derived), Universal Transformer (parameter sharing + recurrence), Switch Transformer (sparse MoE)</p>
<p><strong>Efficient variants</strong>: Performer (linear attention), Funnel Transformer (sequence downsampling), ALBERT (cross-layer parameter sharing + embedding factorization)</p>
<p><strong>General improvements</strong>: Mixture of Softmaxes (MoS), Gated Linear Units (GLU)</p>
<p><strong>Non-Transformers</strong>: Lightweight Convolutions, Dynamic Convolutions, MLP-Mixer</p>
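<p>The pretraining setup above mentions an inverse square root learning rate schedule. A minimal sketch of such a schedule follows, assuming the form used in the T5 paper, where the rate is held constant during a warm-up period and then decays as the inverse square root of the step count. The warm-up length of 10,000 steps is an assumption carried over from that convention; the exact constants used in this study are not stated in these notes.</p>

```python
import math

def inverse_sqrt_lr(step, warmup_steps=10_000):
    """Learning rate 1 / sqrt(max(step, warmup_steps)).

    The warm-up length is an assumed default for illustration,
    not a value reported in the notes above.
    """
    return 1.0 / math.sqrt(max(step, warmup_steps))

total_steps = 2 ** 19  # pretraining length quoted above
checkpoints = [1, 10_000, 40_000, total_steps]
schedule = {step: inverse_sqrt_lr(step) for step in checkpoints}
for step, lr in schedule.items():
    print(step, round(lr, 6))
```

<p>Under these assumptions the rate stays flat through warm-up and has dropped by a factor of about seven by the final step.</p>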
<h2 id="key-findings-on-scaling-behavior">Key findings on scaling behavior</h2>
<h3 id="architecture-changes-the-scaling-slope">Architecture changes the scaling slope</h3>
<p>The paper fits linear scaling laws in log-log space (i.e., power law fits of the form $L \propto C^{-\alpha}$) for each model across multiple axes (FLOPs vs. upstream, FLOPs vs. downstream, etc.). The vanilla Transformer has the highest scaling coefficient on most reported axes ($\alpha_{F,U} = 0.54$, $\alpha_{F,D} = 0.28$). Models that make minimal changes to the Transformer (GLU, MoS) retain similar scaling behavior. Models with more radical inductive biases show worse scaling:</p>
<ul>
<li><strong>Performer</strong> (linear attention): $\alpha_{F,U} = 0.25$, upstream perplexity decreases only 2.7% from base to large vs. 8.4% for vanilla Transformer</li>
<li><strong>ALBERT</strong>: scales negatively on downstream ($\alpha_{F,D} = -0.12$), getting worse as compute increases. ALBERT was designed for parameter efficiency (cross-layer weight sharing, embedding factorization), not compute efficiency, so this result is expected: additional FLOPs reuse the same parameters without adding capacity</li>
<li><strong>MLP-Mixer</strong>: near-zero downstream scaling ($\alpha_{F,D} = -0.03$)</li>
</ul>
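<p>A fit of this kind is ordinary least squares on the logarithms of both quantities. The sketch below uses synthetic, noise-free data generated from a known exponent (the compute values, prefactor, and exponent are made up for illustration and are not numbers from the paper) and recovers $\alpha$ as the negated slope of the fitted line.</p>

```python
import math

def fit_power_law(compute, loss):
    """Least-squares fit of log(loss) = intercept - alpha * log(compute)."""
    xs = [math.log(c) for c in compute]
    ys = [math.log(v) for v in loss]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    variance = sum((x - mean_x) ** 2 for x in xs)
    slope = covariance / variance
    intercept = mean_y - slope * mean_x
    return -slope, math.exp(intercept)

# Synthetic data from loss = 3.0 * compute ** -0.5 (made-up values).
compute = [1e15, 1e16, 1e17, 1e18, 1e19]
loss = [3.0 * c ** -0.5 for c in compute]

alpha, prefactor = fit_power_law(compute, loss)
print(round(alpha, 6), round(prefactor, 6))
```

<p>With real measurements the points scatter around the line, and comparing the fitted $\alpha$ across architectures is what the coefficients quoted above summarize.</p>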
<h3 id="the-best-architecture-changes-with-scale">The best architecture changes with scale</h3>
<p>Models that perform well at small compute budgets are not necessarily the best at larger budgets. For example, the Evolved Transformer outperforms vanilla Transformers at tiny-to-small scale on downstream tasks but falls behind when scaled up. MoS-Transformer outperforms vanilla Transformers at some compute regions but not others.</p>
<h3 id="upstream-and-downstream-scaling-diverge">Upstream and downstream scaling diverge</h3>
<p>Good upstream perplexity scaling does not guarantee good downstream transfer scaling. Funnel Transformers and Lightweight Convolutions hold up reasonably well on upstream perplexity but suffer substantially on downstream tasks. Switch Transformers show the best upstream-to-downstream transfer ratio ($\alpha_{U,D} = 0.58$).</p>
<h3 id="depth-and-width-affect-architectures-differently">Depth and width affect architectures differently</h3>
<p>Depth scaling has a more substantial impact on downstream performance than width scaling across most architectures. Evolved Transformers are a partial exception, scaling slightly better under width scaling compared to other architectures on downstream tasks.</p>
<h2 id="practical-implications">Practical implications</h2>
<p>The authors offer concrete guidance: practitioners should be cautious about staking expensive large-scale runs on architectures that drastically modify the attention mechanism. Performers and MLP-Mixers are characterized as &ldquo;high risk&rdquo; options. This helps explain why most large language models at the time (PaLM, Gopher, UL2) use relatively vanilla Transformer architectures.</p>
<p>The paper also notes that not every use case requires billion-parameter models. Inductive biases tailored to small or low-compute regimes remain valuable when scaling is not the priority.</p>
<h2 id="reproducibility">Reproducibility</h2>
<p>No code or trained model weights were publicly released with this paper. The experiments rely on Google&rsquo;s internal Mesh TensorFlow infrastructure with 16 TPU-v3 chips, and pretraining uses the publicly available C4 corpus. Finetuning benchmarks (GLUE, SuperGLUE, SQuAD) are all publicly available. However, reproducing the full study would require substantial compute resources and re-implementation of all ten architectures within a shared framework.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://arxiv.org/abs/2207.10551">arXiv paper</a></td>
          <td>Paper</td>
          <td>Open access</td>
          <td>Full paper with appendices</td>
      </tr>
      <tr>
          <td><a href="https://www.tensorflow.org/datasets/catalog/c4">C4 corpus</a></td>
          <td>Dataset</td>
          <td>ODC-BY</td>
          <td>Pretraining data</td>
      </tr>
  </tbody>
</table>
<p><strong>Missing components</strong>: No released code, model checkpoints, or training scripts. Internal Mesh TensorFlow codebase is not publicly available.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Tay, Y., Dehghani, M., Abnar, S., Chung, H. W., Fedus, W., Rao, J., Narang, S., Tran, V. Q., Yogatama, D., &amp; Metzler, D. (2022). Scaling Laws vs Model Architectures: How does Inductive Bias Influence Scaling? <em>EMNLP 2022</em>.</p>
<p><strong>Publication</strong>: EMNLP 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2207.10551">arXiv</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{tay2022scaling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scaling Laws vs Model Architectures: How does Inductive Bias Influence Scaling?}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Tay, Yi and Dehghani, Mostafa and Abnar, Samira and Chung, Hyung Won and Fedus, William and Rao, Jinfeng and Narang, Sharan and Tran, Vinh Q. and Yogatama, Dani and Metzler, Donald}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Relational Inductive Biases in Deep Learning (2018)</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/relational-inductive-biases-deep-learning-graph-networks/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/relational-inductive-biases-deep-learning-graph-networks/</guid><description>Battaglia et al.'s 2018 paper unifying graph neural network variants under a general graph network framework and analyzing relational inductive biases.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>systematization paper</strong> that is part position paper, part review, and part unification. It argues that combinatorial generalization, the ability to construct new inferences from known building blocks, is a top priority for AI. It frames relational inductive biases as the key design principle connecting standard deep learning architectures, presents the graph network (GN) as a general framework subsuming prior graph neural network variants, and advocates for combining structured approaches with deep learning rather than choosing between them.</p>
<h2 id="the-case-for-relational-inductive-biases">The case for relational inductive biases</h2>
<p>Human intelligence relies on representing the world as compositions of entities, relations, and rules. We understand complex systems by decomposing them into parts and their interactions. Modern deep learning&rsquo;s &ldquo;end-to-end&rdquo; philosophy minimizes structural assumptions, relying on data and compute to learn representations from scratch. The paper argues this approach struggles with combinatorial generalization: generalizing beyond one&rsquo;s experiences by composing known elements in new ways.</p>
<p>The authors reject the false dichotomy between &ldquo;hand-engineering&rdquo; and &ldquo;end-to-end&rdquo; learning. Just as biology uses both nature and nurture, they advocate for architectures that bake in useful structural assumptions (inductive biases) while still learning flexibly from data.</p>
<h2 id="inductive-biases-across-standard-architectures">Inductive biases across standard architectures</h2>
<p>The paper provides a systematic analysis of how existing architectures encode relational structure:</p>
<p><strong>Fully connected networks (MLPs)</strong>: These carry the weakest relational inductive bias. All input units can interact with all others, no parameters are reused, and no assumptions are made about the structure of the input.</p>
<p><strong>Convolutional networks (CNNs)</strong>: Encode locality (nearby elements interact) and translation invariance (the same local function is applied everywhere). The entities are individual units or grid elements (e.g., pixels), the relations are defined by the grid neighborhood, and the rule (convolution kernel) is shared across all positions.</p>
<p><strong>Recurrent networks (RNNs)</strong>: Encode sequential structure and temporal invariance. The entities are time steps, and each step relates to the previous one through a shared transition function. This imposes a Markovian bias (the future depends on the present state, not the full history directly).</p>
<p><strong>Sets and self-attention</strong>: Permutation-invariant architectures impose no ordering on entities. Self-attention (as in Transformers) allows all pairwise interactions but with no structural prior on which interactions matter.</p>
<p>Each architecture can be understood as making specific commitments about what the entities are, what the relations between them are, and what rules govern their interactions.</p>
<h2 id="the-graph-network-framework">The graph network framework</h2>
<p>The paper defines a general &ldquo;graph network&rdquo; (GN) block that operates on graphs with attributes on nodes, edges, and the global graph level. A GN block performs three update steps and three aggregation steps:</p>
<ol>
<li><strong>Edge update</strong>: For each edge, compute updated edge attributes using the current edge attributes, the sender node attributes, the receiver node attributes, and the global attributes</li>
<li><strong>Node update</strong>: For each node, aggregate incoming updated edge attributes, then compute updated node attributes using the aggregated edges, current node attributes, and global attributes</li>
<li><strong>Global update</strong>: Aggregate all updated edge and node attributes, then compute updated global attributes</li>
</ol>
<p>Each update function is learned (typically a small neural network), and each aggregation function must be permutation invariant (typically sum, mean, or max).</p>
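<p>The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the Graph Nets library API: the names <code>gn_block</code>, <code>random_layer</code>, <code>phi_e</code>, <code>phi_v</code>, and <code>phi_u</code> are hypothetical, the update functions are random single-layer stand-ins for learned networks, and sum is assumed as the aggregator.</p>

```python
import numpy as np

def gn_block(nodes, edges, senders, receivers, globals_, phi_e, phi_v, phi_u):
    """One pass of a graph network block: edge, node, then global update."""
    n_nodes, n_edges = nodes.shape[0], edges.shape[0]

    # 1. Edge update: each edge sees itself, its sender, its receiver, and the globals.
    g_per_edge = np.tile(globals_, (n_edges, 1))
    new_edges = phi_e(np.concatenate(
        [edges, nodes[senders], nodes[receivers], g_per_edge], axis=1))

    # 2. Node update: sum the updated edges arriving at each node, then update.
    incoming = np.zeros((n_nodes, new_edges.shape[1]))
    np.add.at(incoming, receivers, new_edges)  # permutation-invariant sum
    g_per_node = np.tile(globals_, (n_nodes, 1))
    new_nodes = phi_v(np.concatenate([incoming, nodes, g_per_node], axis=1))

    # 3. Global update: sum all updated edges and nodes, then update.
    new_globals = phi_u(np.concatenate(
        [new_edges.sum(axis=0), new_nodes.sum(axis=0), globals_]))
    return new_nodes, new_edges, new_globals

def random_layer(in_dim, out_dim, rng):
    """Hypothetical stand-in for a learned update function (one tanh layer)."""
    weights = rng.normal(size=(in_dim, out_dim))
    return lambda x: np.tanh(x @ weights)

rng = np.random.default_rng(0)
nodes = rng.normal(size=(3, 2))      # 3 nodes, 2 features each
edges = rng.normal(size=(3, 1))      # 3 directed edges, 1 feature each
senders = np.array([0, 1, 2])
receivers = np.array([1, 2, 0])
globals_ = np.zeros(1)               # 1 global feature

phi_e = random_layer(1 + 2 + 2 + 1, 4, rng)  # edge, sender, receiver, global
phi_v = random_layer(4 + 2 + 1, 3, rng)      # aggregated edges, node, global
phi_u = random_layer(4 + 3 + 1, 2, rng)      # summed edges, summed nodes, global

new_nodes, new_edges, new_globals = gn_block(
    nodes, edges, senders, receivers, globals_, phi_e, phi_v, phi_u)
```

<p>Because the same <code>phi_e</code> and <code>phi_v</code> are applied to every edge and node, and the aggregation is a sum, reordering the edge list leaves the node and global outputs unchanged, and the same block can be applied to graphs of any size.</p>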
<p>This framework generalizes prior work:</p>
<ul>
<li><strong>Message Passing Neural Networks</strong> (Gilmer et al., 2017): edge and node updates with a readout function but no explicit global attribute in message passing</li>
<li><strong>Non-local Neural Networks</strong> (Wang et al., 2018): attention-weighted edge interactions</li>
<li><strong>Interaction Networks</strong> (Battaglia et al., 2016): physics-inspired message passing</li>
<li><strong>Relation Networks</strong> (Santoro et al., 2017): a simple neural network module for relational reasoning</li>
<li><strong>Raposo et al. (2017)</strong>: discovering objects and their relations from entangled scene representations</li>
<li><strong>Deep Sets</strong> (Zaheer et al., 2017): node-only aggregation without edge structure</li>
<li><strong>CommNet, Structure2Vec, GGNNs</strong>, and others</li>
</ul>
<p>The paper shows how each prior approach corresponds to a specific configuration of which GN components are used and how they are connected.</p>
<h2 id="design-principles-for-graph-networks">Design principles for graph networks</h2>
<p>The paper identifies several key design choices:</p>
<p><strong>Flexible representations</strong>: GN blocks can output graphs with different structure than their input (e.g., predicting edge existence), enabling tasks like link prediction, clustering, or property regression.</p>
<p><strong>Configurable within-block structure</strong>: The internal update and aggregation functions can be swapped freely. The framework separates what is computed (the relational structure) from how it is computed (the function approximators).</p>
<p><strong>Composable multi-block architectures</strong>: GN blocks can be stacked, sharing or not sharing weights across layers. They can be composed with other architectures (e.g., an encoder-GN-decoder pattern) or arranged in recurrent configurations.</p>
<p><strong>Combinatorial generalization</strong>: Because GN blocks share functions across edges and nodes, they can generalize to graphs of different sizes and topologies than those seen during training. A GN trained on small graphs can, in principle, be applied to larger ones.</p>
<h2 id="connections-to-broader-ai-themes">Connections to broader AI themes</h2>
<p>The paper frames graph networks as supporting:</p>
<ul>
<li><strong>Relational reasoning</strong>: Learning about entities and their interactions</li>
<li><strong>Combinatorial generalization</strong>: Applying learned rules to novel combinations</li>
<li><strong>Structured prediction</strong>: Producing complex, structured outputs including graphs and sequences</li>
<li><strong>Interpretable representations</strong>: Graph structure provides a natural vocabulary for understanding what the model has learned</li>
</ul>
<p>The authors also discuss connections to classical AI (logic, planning, causal reasoning) and argue that graph networks provide a bridge between the flexibility of deep learning and the compositionality of symbolic approaches.</p>
<h2 id="limitations-and-open-questions">Limitations and open questions</h2>
<p>The paper identifies several limitations of graph networks:</p>
<ul>
<li><strong>Graph isomorphism</strong>: Learned message-passing cannot be guaranteed to discriminate between certain non-isomorphic graphs. Kondor et al. (2018) suggested that covariance, rather than invariance to permutations, may be preferable.</li>
<li><strong>Expressivity limits of graphs</strong>: Notions like recursion, control flow, and conditional iteration are not straightforward to represent with graphs. Programs and more &ldquo;computer-like&rdquo; processing may offer greater representational and computational expressivity for these concepts.</li>
<li><strong>Where do graphs come from?</strong>: Converting raw sensory data (images, text) into graph-structured representations remains an open problem. Fully connected graphs between spatial or linguistic entities are a common workaround but may not reflect the true underlying structure.</li>
<li><strong>Adaptive graph structure</strong>: How to modify graph topology during computation (e.g., splitting a node when an object fractures, or adding/removing edges based on contact) is an active research direction.</li>
</ul>
<h2 id="reproducibility">Reproducibility</h2>
<p>The authors released an open-source software library for building graph networks in TensorFlow/Sonnet, including demos for shortest-path finding, sorting, and physical prediction tasks.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/deepmind/graph_nets">Graph Nets library</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official TensorFlow/Sonnet implementation with demos</td>
      </tr>
  </tbody>
</table>
<p>This is a position/systematization paper rather than an empirical one, so reproducibility pertains to the accompanying library rather than experimental results. The library and demos are publicly available, making the framework highly accessible.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Battaglia, P. W., Hamrick, J. B., Bapst, V., Sanchez-Gonzalez, A., Zambaldi, V., Malinowski, M., &hellip; &amp; Pascanu, R. (2018). Relational inductive biases, deep learning, and graph networks. <em>arXiv preprint arXiv:1806.01261</em>.</p>
<p><strong>Publication</strong>: arXiv 2018</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1806.01261">arXiv</a></li>
<li><a href="https://github.com/deepmind/graph_nets">Graph Nets library (GitHub)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{battaglia2018relational,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Relational inductive biases, deep learning, and graph networks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Battaglia, Peter W. and Hamrick, Jessica B. and Bapst, Victor and Sanchez-Gonzalez, Alvaro and Zambaldi, Vinicius and Malinowski, Mateusz and Tacchetti, Andrea and Raposo, David and Santoro, Adam and Faulkner, Ryan and others}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:1806.01261}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>OCSU: Optical Chemical Structure Understanding (2025)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/ocsu/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/ocsu/</guid><description>OCSU task for translating molecular images into multi-level descriptions. Introduces Vis-CheBI20 dataset and DoubleCheck/Mol-VL for molecular understanding.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fan, S., Xie, Y., Cai, B., Xie, A., Liu, G., Qiao, M., Xing, J., &amp; Nie, Z. (2025). OCSU: Optical Chemical Structure Understanding for Molecule-centric Scientific Discovery. <em>arXiv preprint arXiv:2501.15415</em>. <a href="https://doi.org/10.48550/arXiv.2501.15415">https://doi.org/10.48550/arXiv.2501.15415</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/PharMolix/OCSU">Code and Dataset (GitHub)</a></li>
</ul>
<h2 id="multi-level-chemical-understanding-method-and-resource">Multi-Level Chemical Understanding (Method and Resource)</h2>
<p>This is primarily a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a significant <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution.</p>
<ul>
<li><strong>Methodological</strong>: It proposes two novel architectures, <strong>DoubleCheck</strong> (an enhanced recognition model) and <strong>Mol-VL</strong> (an end-to-end vision-language model), to solve the newly formulated OCSU task.</li>
<li><strong>Resource</strong>: It constructs and releases <strong>Vis-CheBI20</strong>, the first large-scale dataset specifically designed for optical chemical structure understanding, containing 29.7K images and 117.7K image-text pairs.</li>
</ul>
<h2 id="the-motivation-for-ocsu-beyond-basic-graph-recognition">The Motivation for OCSU Beyond Basic Graph Recognition</h2>
<p>Existing methods for processing molecular images focus narrowly on <strong>Optical Chemical Structure Recognition (OCSR)</strong>, which translates an image solely into a machine-readable graph or SMILES string. However, SMILES strings are not chemist-friendly and lack high-level semantic context.</p>
<ul>
<li><strong>Gap</strong>: There is a lack of systems that can translate chemical diagrams into human-readable descriptions (e.g., functional groups, IUPAC names) alongside the graph structure.</li>
<li><strong>Goal</strong>: To enable <strong>Optical Chemical Structure Understanding (OCSU)</strong>, bridging the gap between visual representations and both machine/chemist-readable descriptions to support drug discovery and property prediction.</li>
</ul>
<h2 id="key-innovations-doublecheck-mol-vl-and-the-vis-chebi20-dataset">Key Innovations: DoubleCheck, Mol-VL, and the Vis-CheBI20 Dataset</h2>
<p>The paper introduces the <strong>OCSU task</strong>, enabling multi-level understanding (motif, molecule, and abstract levels). To solve this, it introduces two distinct paradigms:</p>
<ol>
<li><strong>DoubleCheck (OCSR-based)</strong>: An enhancement to standard OCSR models (like MolScribe) that performs a &ldquo;second look&rdquo; at locally ambiguous atoms. It uses attentive feature enhancement to fuse global molecular features with local features from ambiguous regions.</li>
<li><strong>Mol-VL (OCSR-free)</strong>: An end-to-end Vision-Language Model (VLM) based on Qwen2-VL. It uses multi-task learning to directly generate text descriptions from molecular images without an intermediate SMILES step.</li>
<li><strong>Vis-CheBI20 Dataset</strong>: A new benchmark specifically constructed for OCSU, deriving captions and functional group data from ChEBI-20 and PubChem.</li>
</ol>
<h2 id="methodology-and-experimental-evaluation">Methodology and Experimental Evaluation</h2>
<p>The authors evaluated both paradigms on <strong>Vis-CheBI20</strong> and existing benchmarks (USPTO, ACS) across four subtasks:</p>
<ol>
<li><strong>Functional Group Caption</strong>: Retrieval/F1 score evaluation.</li>
<li><strong>Molecule Description</strong>: Natural language generation metrics (BLEU, ROUGE, METEOR).</li>
<li><strong>IUPAC Naming</strong>: Text generation metrics (BLEU, ROUGE).</li>
<li><strong>SMILES Naming (OCSR)</strong>: Exact matching accuracy ($Acc_s$).</li>
</ol>
<p><strong>Baselines</strong>:</p>
<ul>
<li><strong>Task-Specific</strong>: MolScribe, MolVec, OSRA.</li>
<li><strong>LLM/VLM</strong>: Qwen2-VL, BioT5+, Mol-Instructions.</li>
<li><strong>Ablation</strong>: DoubleCheck vs. MolScribe backbone to test the &ldquo;feature enhancement&rdquo; mechanism.</li>
</ul>
<h2 id="results-and-conclusions-paradigm-trade-offs">Results and Conclusions: Paradigm Trade-Offs</h2>
<ul>
<li><strong>DoubleCheck Superiority</strong>: DoubleCheck outperformed MolScribe on OCSR tasks across all benchmarks. On USPTO, it achieved <strong>92.85%</strong> $Acc_s$ (vs. 92.57%), and on the ACS dataset it showed a <strong>+3.12%</strong> gain on chiral molecules. On Vis-CheBI20, DoubleCheck improved over MolScribe by an average of 2.27% across all metrics.</li>
<li><strong>Paradigm Trade-offs</strong>:
<ul>
<li><strong>Mol-VL (OCSR-free)</strong> excelled at semantic tasks like <strong>Functional Group Captioning</strong>, achieving <strong>97.32%</strong> F1 (vs. 93.63% for DoubleCheck &amp; RDKit and 89.60% for MolScribe &amp; RDKit). It benefits from end-to-end learning of structural context.</li>
<li><strong>DoubleCheck (OCSR-based)</strong> performed better on <strong>IUPAC naming recall</strong> and exact SMILES recovery, as explicit graph reconstruction is more precise for rigid nomenclature than VLM generation.</li>
</ul>
</li>
<li><strong>Conclusion</strong>: Enhancing submodules improves OCSR-based paradigms, while end-to-end VLMs offer stronger semantic understanding but struggle with exact syntax generation (SMILES/IUPAC).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Vis-CheBI20 Dataset</strong></p>
<ul>
<li><strong>Source</strong>: Derived from ChEBI-20 and PubChem.</li>
<li><strong>Size</strong>: 29,700 molecular diagrams, 117,700 image-text pairs.</li>
<li><strong>Generation</strong>: Images generated from SMILES using RDKit to simulate real-world journal/patent styles.</li>
<li><strong>Splits</strong> (vary by task, see table below):</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Train Size</th>
          <th style="text-align: left">Test Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Functional Group</td>
          <td style="text-align: left">26,144</td>
          <td style="text-align: left">3,269</td>
      </tr>
      <tr>
          <td style="text-align: left">Description</td>
          <td style="text-align: left">26,407</td>
          <td style="text-align: left">3,300</td>
      </tr>
      <tr>
          <td style="text-align: left">IUPAC Naming</td>
          <td style="text-align: left">26,200</td>
          <td style="text-align: left">2,680</td>
      </tr>
      <tr>
          <td style="text-align: left">SMILES Naming</td>
          <td style="text-align: left">26,407</td>
          <td style="text-align: left">3,300</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>DoubleCheck (Attentive Feature Enhancement)</strong></p>
<ol>
<li><strong>Ambiguity Detection</strong>: Uses atom prediction confidence to identify &ldquo;ambiguous atoms&rdquo;.</li>
<li><strong>Masking</strong>: Applies a 2D Gaussian mask to the image centered on the ambiguous atom.</li>
<li><strong>Local Encoding</strong>: A Swin-B encoder ($\Phi_l$) encodes the masked image region.</li>
<li><strong>Fusion</strong>: Aligns local features ($\mathcal{F}_l$) with global features ($\mathcal{F}_g$) using a 2-layer MLP, giving aligned local features $\hat{\mathcal{F}}_l$, and fuses them into enhanced features $\mathcal{F}_e$ via weighted summation.</li>
</ol>
<p>$$
\begin{aligned}
\mathcal{F}_e = \mathcal{F}_g + \text{MLP}(\mathcal{F}_g \oplus \hat{\mathcal{F}}_l) \cdot \hat{\mathcal{F}}_l
\end{aligned}
$$</p>
<ol start="5">
<li><strong>Two-Stage Training</strong>:
<ul>
<li>Stage 1: Train atom/bond predictors (30 epochs).</li>
<li>Stage 2: Train alignment/fusion modules with random Gaussian mask noise (10 epochs).</li>
</ul>
</li>
</ol>
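<p>The masking and fusion steps can be illustrated with a small NumPy sketch. It is not the authors&rsquo; implementation, and several details are assumptions made only for illustration: $\oplus$ is read as concatenation, the weighting MLP is taken to output a single sigmoid gate, the mask is applied by elementwise multiplication, and the names <code>gaussian_mask</code>, <code>two_layer_gate</code>, and <code>fuse</code>, along with the feature size and <code>sigma</code>, are hypothetical.</p>

```python
import numpy as np

def gaussian_mask(height, width, center_xy, sigma):
    """2D Gaussian weight map peaked at the ambiguous atom's pixel location."""
    ys, xs = np.mgrid[0:height, 0:width]
    cx, cy = center_xy
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2.0 * sigma ** 2))

def two_layer_gate(in_dim, hidden_dim, rng):
    """Hypothetical 2-layer MLP with random weights that outputs one gate in (0, 1)."""
    w1 = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
    w2 = rng.normal(scale=0.1, size=(hidden_dim, 1))

    def gate(x):
        hidden = np.maximum(x @ w1, 0.0)                 # ReLU
        return 1.0 / (1.0 + np.exp(-(hidden @ w2)))      # sigmoid

    return gate

def fuse(f_global, f_local_aligned, mlp):
    """F_e = F_g + MLP(F_g (+) F_l_hat) * F_l_hat, with (+) assumed to be concatenation."""
    weight = mlp(np.concatenate([f_global, f_local_aligned], axis=-1))
    return f_global + weight * f_local_aligned

rng = np.random.default_rng(0)

# Step 2: emphasize the region around an ambiguous atom at pixel (x=20, y=30).
image = np.ones((64, 64))
mask = gaussian_mask(64, 64, center_xy=(20, 30), sigma=5.0)
masked_image = image * mask

# Step 4: fuse global features with (already aligned) local features.
f_global = rng.normal(size=8)
f_local_aligned = rng.normal(size=8)
gate = two_layer_gate(in_dim=16, hidden_dim=32, rng=rng)
fused = fuse(f_global, f_local_aligned, gate)
```

<p>The residual form means that when the local features contribute nothing, the enhanced features fall back to the global features unchanged.</p>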
<p><strong>Mol-VL (Multi-Task VLM)</strong></p>
<ul>
<li><strong>Prompting</strong>: System prompt: &ldquo;You are working as an excellent assistant in chemistry&hellip;&rdquo;</li>
<li><strong>Tokens</strong>: Uses <code>&lt;image&gt;</code> and <code>&lt;/image&gt;</code> special tokens.</li>
<li><strong>Auxiliary Task</strong>: Functional group recognition (identifying highlighted groups) added to training to improve context learning.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>DoubleCheck</strong>:
<ul>
<li><strong>Backbone</strong>: MolScribe architecture.</li>
<li><strong>Encoders</strong>: Swin-B for both global and local atom encoding.</li>
</ul>
</li>
<li><strong>Mol-VL</strong>:
<ul>
<li><strong>Base Model</strong>: Qwen2-VL (2B and 7B versions).</li>
<li><strong>Vision Encoder</strong>: ViT with naive dynamic resolution and M-RoPE.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Key Metrics</strong>:</p>
<ul>
<li><strong>SMILES</strong>: Exact Match Accuracy ($Acc_s$), Chiral Accuracy ($Acc_c$).</li>
<li><strong>Functional Groups</strong>: F1 Score (Information Retrieval task).</li>
<li><strong>Text Generation</strong>: BLEU-2/4, METEOR, ROUGE-L.</li>
</ul>
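<p>As a rough illustration of the functional group metric, F1 can be computed by treating the predicted and reference functional groups of one molecule as sets. This is a sketch under that assumption; the paper&rsquo;s exact matching and averaging protocol may differ, and the function name <code>set_f1</code> and the example group names are hypothetical.</p>

```python
import math

def set_f1(predicted, reference):
    """F1 between a predicted and a reference set of functional group names."""
    predicted, reference = set(predicted), set(reference)
    if not predicted and not reference:
        return 1.0  # assumption: two empty sets count as a perfect match
    true_pos = len(predicted.intersection(reference))
    if true_pos == 0:
        return 0.0
    precision = true_pos / len(predicted)
    recall = true_pos / len(reference)
    return 2 * precision * recall / (precision + recall)

# One spurious group: precision 2/3, recall 1.
score = set_f1(["hydroxyl", "carbonyl", "amine"], ["hydroxyl", "carbonyl"])
```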
<p><strong>Selected Results</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Model</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Score</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>DoubleCheck</strong></td>
          <td style="text-align: left">OCSR (USPTO)</td>
          <td style="text-align: left">$Acc_s$</td>
          <td style="text-align: left"><strong>92.85%</strong></td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>MolScribe</strong></td>
          <td style="text-align: left">OCSR (USPTO)</td>
          <td style="text-align: left">$Acc_s$</td>
          <td style="text-align: left">92.57%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Mol-VL-7B</strong></td>
          <td style="text-align: left">Func. Group Caption</td>
          <td style="text-align: left">F1</td>
          <td style="text-align: left"><strong>97.32%</strong></td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>DoubleCheck &amp; RDKit</strong></td>
          <td style="text-align: left">Func. Group Caption</td>
          <td style="text-align: left">F1</td>
          <td style="text-align: left">93.63%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>DoubleCheck</strong>: Trained on <strong>4 NVIDIA A100 GPUs</strong> for <strong>4 days</strong>.
<ul>
<li>Max LR: 4e-4.</li>
</ul>
</li>
<li><strong>Mol-VL</strong>: Trained on <strong>4 NVIDIA A100 GPUs</strong> for <strong>10 days</strong>.
<ul>
<li>Max LR: 1e-5, 50 epochs.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/PharMolix/OCSU">PharMolix/OCSU (GitHub)</a></td>
          <td style="text-align: left">Code, Model, Dataset</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Official implementation, Mol-VL-7B weights, and Vis-CheBI20 dataset</td>
      </tr>
  </tbody>
</table>
<h3 id="limitations">Limitations</h3>
<p>The authors acknowledge several limitations:</p>
<ul>
<li>The long-tail distribution of functional groups in training data limits performance on uncommon chemical structures.</li>
<li>Mol-VL struggles with exact syntax generation (SMILES and IUPAC) compared to explicit graph-reconstruction approaches.</li>
<li>Vis-CheBI20 images are synthetically generated via RDKit, which may not fully capture the diversity of real-world journal and patent images.</li>
<li>The authors note that OCSU technologies should be restricted to research purposes, as downstream molecule discovery applications could potentially generate harmful molecules.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{fanOCSUOpticalChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{OCSU: Optical Chemical Structure Understanding for Molecule-centric Scientific Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{OCSU}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Fan, Siqi and Xie, Yuguang and Cai, Bowen and Xie, Ailin and Liu, Gaochao and Qiao, Mu and Xing, Jie and Nie, Zaiqing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2501.15415}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>GTR-CoT: Graph Traversal Chain-of-Thought for Molecules</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/gtr-mol-vlm/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/gtr-mol-vlm/</guid><description>GTR-VL uses graph traversal chain-of-thought and two-stage training to improve optical chemical structure recognition on printed and hand-drawn molecules.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, J., He, Y., Yang, H., Wu, J., Ge, L., Wei, X., Wang, Y., Li, L., Ao, H., Liu, C., Wang, B., Wu, L., &amp; He, C. (2025). GTR-CoT: Graph Traversal as Visual Chain of Thought for Molecular Structure Recognition (arXiv:2506.07553). arXiv. <a href="https://doi.org/10.48550/arXiv.2506.07553">https://doi.org/10.48550/arXiv.2506.07553</a></p>
<p><strong>Publication</strong>: arXiv preprint (2025)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2506.07553">Paper on arXiv</a></li>
</ul>
<h2 id="contribution-vision-language-modeling-for-ocsr">Contribution: Vision-Language Modeling for OCSR</h2>
<p>This is a <strong>method paper</strong> that introduces GTR-VL, a Vision-Language Model for Optical Chemical Structure Recognition (OCSR). The work addresses the persistent challenge of converting molecular structure images into machine-readable formats, with a particular focus on handling chemical abbreviations that cause errors in existing systems.</p>
<h2 id="motivation-the-abbreviation-bottleneck">Motivation: The Abbreviation Bottleneck</h2>
<p>The work targets a long-standing bottleneck in chemical informatics: most existing OCSR systems produce incorrect structures when they encounter abbreviated functional groups. When a chemist draws &ldquo;Ph&rdquo; for phenyl or &ldquo;Et&rdquo; for ethyl, current models fail because they were trained on data where images contain abbreviations but the ground-truth labels contain fully expanded molecular graphs.</p>
<p>This creates a fundamental mismatch. The model sees &ldquo;Ph&rdquo; in the image but is told the &ldquo;correct&rdquo; answer is a full benzene ring. The supervision signal is inconsistent with what is actually visible.</p>
<p>Beyond this data problem, existing graph-parsing methods use a two-stage approach: predict all atoms first, then predict all bonds. This is inefficient and ignores the structural constraints that could help during prediction. The authors argue that mimicking how humans analyze molecular structures, following bonds from atom to atom in a connected traversal, would be more effective.</p>
<h2 id="novelty-graph-traversal-as-visual-chain-of-thought">Novelty: Graph Traversal as Visual Chain-of-Thought</h2>
<p>The novelty lies in combining two key insights about how to properly train and architect OCSR systems. The main contributions are:</p>
<ol>
<li>
<p><strong>Graph Traversal as Visual Chain of Thought</strong>: GTR-VL generates molecular graphs by traversing them sequentially, predicting an atom, then its connected bond, then the next atom, and so on. This mimics how a human chemist would trace through a structure and allows the model to use previously predicted atoms and bonds as context for subsequent predictions.</p>
<p>Formally, the model output sequence for image $I_m$ is generated as:</p>
<p>$$ R_m = \text{concat}(CoT_m, S_m) $$</p>
<p>where $CoT_m$ represents the deterministic graph traversal steps (atoms and bonds) and $S_m$ is the final SMILES representation. This intermediate reasoning step makes the model more interpretable and helps it learn the structural logic of molecules.</p>
</li>
<li>
<p><strong>&ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; Principle</strong>: This addresses the abbreviation problem head-on. The authors correct the ground-truth annotations to match what&rsquo;s actually visible in the image.</p>
<p>They treat abbreviations like &ldquo;Ph&rdquo; as single &ldquo;superatoms&rdquo; and build a pipeline to automatically detect and correct training data. Using OCR to extract visible text from molecular images, they replace the corresponding expanded substructures in the ground-truth with the appropriate abbreviation tokens. This ensures the supervision signal is consistent with the visual input.</p>
</li>
<li>
<p><strong>Large-Scale Dataset (GTR-1.3M)</strong>: To support this approach, the authors created a large-scale dataset combining 1M synthetic molecules from PubChem with 351K corrected real-world patent images from USPTO. The key innovation is the correction pipeline that identifies abbreviations in patent images and fixes the inconsistent ground-truth labels.</p>
</li>
<li>
<p><strong>GRPO for Hand-Drawn OCSR</strong>: Hand-drawn molecular data lacks fine-grained atom/bond coordinate annotations, making SFT-based graph parsing inapplicable. The authors use Group Relative Policy Optimization (GRPO) with a composite reward function that combines format, SMILES, and graph-level rewards. The graph reward is computed from the maximum common subgraph (MCS) between the predicted and ground-truth molecular graphs:</p>
<p>$$ R_{\text{graph}} = \frac{|N_m^a|}{|N_g^a| + |N_p^a|} + \frac{|N_m^b|}{|N_g^b| + |N_p^b|} $$</p>
<p>where $N_m^a$, $N_g^a$, $N_p^a$ are atom counts in the MCS, ground truth, and prediction, and $N_m^b$, $N_g^b$, $N_p^b$ are the corresponding bond counts.</p>
</li>
<li>
<p><strong>Two-Stage Training</strong>: Stage 1 performs SFT on GTR-1.3M for printed molecule recognition. Stage 2 applies GRPO on a mixture of printed data (GTR-USPTO-4K) and hand-drawn data (DECIMER-HD-Train, 4,070 samples) to extend capabilities to hand-drawn structures.</p>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: Traditional SMILES-based evaluation fails for molecules with abbreviations because canonicalization breaks down. The authors created a new benchmark that evaluates graph structure directly, providing three metrics: direct SMILES generation, graph-derived SMILES, and exact graph matching.</p>
</li>
</ol>
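<p>The graph reward defined above can be sketched in a few lines. This is a minimal illustration rather than the authors&rsquo; implementation: it assumes the MCS sizes have already been computed by some external routine, the function and argument names are hypothetical, and the zero-denominator guard is an added assumption. Under the formula as written, each term is at most 0.5, so a perfect prediction scores 1.0.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def graph_reward(mcs_atoms, gt_atoms, pred_atoms, mcs_bonds, gt_bonds, pred_bonds):
    """Graph reward from MCS, ground-truth, and predicted atom/bond counts."""
    atom_denominator = gt_atoms + pred_atoms
    bond_denominator = gt_bonds + pred_bonds
    # Assumption: a term contributes 0 when both graphs are empty
    # (for example, single atoms have no bonds).
    atom_term = mcs_atoms / atom_denominator if atom_denominator else 0.0
    bond_term = mcs_bonds / bond_denominator if bond_denominator else 0.0
    return atom_term + bond_term


# Perfect prediction of a 6-atom, 6-bond ring: each term is 6 / 12.
perfect = graph_reward(6, 6, 6, 6, 6, 6)

# Prediction sharing only 4 atoms and 3 bonds with the ground truth.
partial = graph_reward(4, 6, 6, 3, 6, 6)
</code></pre></div>
<p>Because the reward grows with the size of the shared subgraph, a partially correct hand-drawn prediction still receives a graded signal instead of the all-or-nothing outcome of exact SMILES matching.</p>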
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The evaluation focused on demonstrating that GTR-VL&rsquo;s design principles solve real problems that plague existing OCSR systems:</p>
<ol>
<li>
<p><strong>Comprehensive Baseline Comparison</strong>: GTR-VL was tested against three categories of models:</p>
<ul>
<li><strong>Specialist OCSR systems</strong>: MolScribe and MolNexTR</li>
<li><strong>Chemistry-focused VLMs</strong>: ChemVLM, ChemDFM-X, OCSU</li>
<li><strong>General-purpose VLMs</strong>: GPT-4o, GPT-4o-mini, Qwen-VL-Max</li>
</ul>
</li>
<li>
<p><strong>MolRec-Bench Evaluation</strong>: The new benchmark includes two subsets of patent images:</p>
<ul>
<li><strong>MolRec-USPTO</strong>: 5,423 standard patent images similar to existing benchmarks</li>
<li><strong>MolRec-Abb</strong>: 9,311 molecular images with abbreviated superatoms, derived from MolGrapher&rsquo;s USPTO 10K abb subset</li>
</ul>
<p>This design directly tests whether models can handle the abbreviation problem that breaks existing systems.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Systematic experiments isolated the contribution of key design choices:</p>
<ul>
<li><strong>Chain-of-Thought vs. Direct</strong>: Comparing graph traversal CoT against direct SMILES prediction</li>
<li><strong>Traversal Strategy</strong>: Graph traversal vs. the traditional &ldquo;atoms-then-bonds&rdquo; approach</li>
<li><strong>Dataset Quality</strong>: Training on corrected vs. uncorrected data</li>
</ul>
</li>
<li>
<p><strong>Retraining Experiments</strong>: Existing specialist models (MolScribe, MolNexTR) were retrained from scratch on the corrected GTR-1.3M dataset to isolate the effect of data quality from architectural improvements.</p>
</li>
<li>
<p><strong>Hand-Drawn OCSR Evaluation</strong>: GTR-VL was also evaluated on the DECIMER Hand-drawn test set and ChemPix dataset, comparing against DECIMER and AtomLenz+EditKT baselines.</p>
</li>
<li>
<p><strong>Qualitative Analysis</strong>: Visual inspection of predictions on challenging cases with heavy abbreviation usage, complex structures, and edge cases to understand failure modes.</p>
</li>
</ol>
<h2 id="results--conclusions-resolving-the-abbreviation-bottleneck">Results &amp; Conclusions: Resolving the Abbreviation Bottleneck</h2>
<ul>
<li>
<p><strong>Performance Gains on Abbreviations</strong>: On MolRec-Abb, GTR-VL-Stage1 achieves 85.49% Graph accuracy, whereas MolScribe and MolNexTR with their original checkpoints fall to around 20% once abbreviations are present. On MolRec-USPTO, GTR-VL-Stage1 reaches 93.45% Graph accuracy.</p>
</li>
<li>
<p><strong>Data Correction is Critical</strong>: When MolScribe and MolNexTR were retrained on GTR-1.3M, their MolRec-Abb Graph accuracy jumped from around 20% to 70.60% and 71.85% respectively. GTR-VL-Stage1 still outperformed these retrained baselines at 85.49%, confirming that both data correction and the graph traversal approach contribute.</p>
</li>
<li>
<p><strong>Chain-of-Thought Helps</strong>: Ablation on GTR-USPTO-351K shows that CoT yields 68.85% Gen-SMILES vs. 66.54% without CoT, a 2.31 percentage point improvement.</p>
</li>
<li>
<p><strong>Graph Traversal Beats Traditional Parsing</strong>: Graph traversal achieves 83.26% Graph accuracy vs. 80.15% for the atoms-then-bonds approach, and 81.88% vs. 79.02% on Gra-SMILES.</p>
</li>
<li>
<p><strong>General VLMs Still Struggle</strong>: General-purpose VLMs like GPT-4o scored near 0% on MolRec-Bench across all metrics, highlighting the importance of domain-specific training for OCSR.</p>
</li>
<li>
<p><strong>Hand-Drawn Recognition via GRPO</strong>: GTR-VL-Stage1 (SFT only) achieves only 9.53% Graph accuracy on DECIMER-HD-Test, but after GRPO training in Stage 2, performance jumps to 75.44%. On ChemPix, Graph accuracy rises from 22.02% to 86.13%. The graph reward is essential: GRPO without graph supervision achieves only 11.00% SMILES on DECIMER-HD-Test, while adding graph reward reaches 75.64%.</p>
</li>
<li>
<p><strong>Evaluation Methodology Matters</strong>: The new graph-based evaluation metrics revealed problems with traditional SMILES-based evaluation that previous work had missed. Many &ldquo;failures&rdquo; in existing benchmarks were actually correct graph predictions that got marked wrong due to canonicalization issues with abbreviations.</p>
</li>
</ul>
<p>The work establishes that addressing the abbreviation problem requires both correcting the training data and rethinking how the model generates the molecular graph. The combination of faithful data annotation and sequential graph generation improves OCSR performance on molecules with abbreviations by a large margin over previous methods.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Base Model</strong>: GTR-VL fine-tunes <strong>Qwen2.5-VL</strong>.</p>
<p><strong>Input/Output Mechanism</strong>:</p>
<ul>
<li><strong>Input</strong>: The model takes an image $I_m$ and a text prompt</li>
<li><strong>Output</strong>: The model generates $R_m = \text{concat}(CoT_m, S_m)$, where it first produces the Chain-of-Thought (the graph traversal steps) followed immediately by the final SMILES string</li>
<li><strong>Traversal Strategy</strong>: Uses <strong>depth-first traversal</strong> to alternately predict atoms and bonds</li>
</ul>
<p><strong>Prompt Structure</strong>: The model is prompted to &ldquo;list the types of atomic elements&hellip; the coordinates&hellip; and the chemical bonds&hellip; then&hellip; output a canonical SMILES&rdquo;. The CoT output is formatted as a JSON list of atoms (with coordinates) and bonds (with indices referring to previous atoms), interleaved.</p>
<h3 id="data">Data</h3>
<p><strong>Training Dataset (GTR-1.3M)</strong>:</p>
<ul>
<li><strong>Synthetic Component</strong>: 1 million molecular SMILES from PubChem, converted to images using Indigo</li>
<li><strong>Real Component</strong>: 351,000 samples from USPTO patents (filtered from an original 680,000)
<ul>
<li>Processed using an OCR pipeline to detect abbreviations (e.g., &ldquo;Ph&rdquo;, &ldquo;Et&rdquo;)</li>
<li>Ground truth expanded structures replaced with superatoms to match visible abbreviations in images</li>
<li>This &ldquo;Faithfully Recognize What You&rsquo;ve Seen&rdquo; correction ensures training supervision matches visual input</li>
</ul>
</li>
</ul>
<p><strong>Evaluation Dataset (MolRec-Bench)</strong>:</p>
<ul>
<li><strong>MolRec-USPTO</strong>: 5,423 molecular images from USPTO patents</li>
<li><strong>MolRec-Abb</strong>: 9,311 molecular images with abbreviated superatoms, derived from MolGrapher&rsquo;s USPTO 10K abb subset</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Graph Traversal Algorithm</strong>:</p>
<ul>
<li>Depth-first traversal strategy</li>
<li>Alternating atom-bond prediction sequence</li>
<li>Each step uses previously predicted atoms and bonds as context</li>
</ul>
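<p>A depth-first, alternating atom-bond ordering of this kind can be sketched as follows. This is an illustrative serialization under stated assumptions, not the paper&rsquo;s exact output format: the function name, the tuple layout, and the choice to emit each bond as soon as both of its endpoints have been listed are all hypothetical. Atom coordinates are omitted for brevity.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def dfs_atom_bond_sequence(atoms, bonds, start=0):
    """Serialize a molecular graph as interleaved atom and bond records.

    atoms: list of element symbols, indexed by original atom id.
    bonds: dict mapping (atom_a, atom_b) to a bond order.
    Only the connected component containing `start` is traversed.
    """
    neighbors = {i: [] for i in range(len(atoms))}
    for (a, b), order in bonds.items():
        neighbors[a].append((b, order))
        neighbors[b].append((a, order))

    new_index = {}  # original atom id -> position in the emitted sequence
    sequence = []
    stack = [start]
    while stack:
        atom = stack.pop()
        if atom in new_index:
            continue
        new_index[atom] = len(new_index)
        sequence.append(("atom", new_index[atom], atoms[atom]))
        # Emit bonds to atoms that were already listed, so every bond
        # refers back to previous atoms (this also covers ring closures).
        for nbr, order in neighbors[atom]:
            if nbr in new_index:
                sequence.append(("bond", new_index[nbr], new_index[atom], order))
        # Push unvisited neighbors in reverse so the first one is explored next.
        for nbr, _ in reversed(neighbors[atom]):
            if nbr not in new_index:
                stack.append(nbr)
    return sequence


# Ethanol (C-C-O) and cyclopropane (three-membered ring) as toy inputs.
ethanol = dfs_atom_bond_sequence(["C", "C", "O"], {(0, 1): 1, (1, 2): 1})
ring = dfs_atom_bond_sequence(["C", "C", "C"], {(0, 1): 1, (1, 2): 1, (0, 2): 1})
</code></pre></div>
<p>Each bond record only references indices that have already appeared, which mirrors the note above that every step uses previously predicted atoms and bonds as context.</p>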
<p><strong>Two-Stage Training</strong>:</p>
<ul>
<li><strong>Stage 1 (SFT)</strong>: Train on GTR-1.3M to learn visual CoT mechanism for printed molecules (produces GTR-VL-Stage1)</li>
<li><strong>Stage 2 (GRPO)</strong>: Apply GRPO on GTR-USPTO-4K + DECIMER-HD-Train (4,070 samples) for hand-drawn recognition (produces GTR-VL-Stage2, i.e., GTR-VL)</li>
</ul>
<p><strong>Training Procedure</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate (SFT)</strong>: Peak learning rate of $1.6 \times 10^{-4}$ with cosine decay</li>
<li><strong>Learning Rate (GRPO)</strong>: Peak learning rate of $1 \times 10^{-5}$ with cosine decay</li>
<li><strong>Warm-up</strong>: Linear warm-up for the first 10% of iterations</li>
<li><strong>Batch Size (SFT)</strong>: 2 per GPU with gradient accumulation over 16 steps, yielding an <strong>effective batch size of 1024</strong> across 32 GPUs ($2 \times 16 \times 32 = 1024$)</li>
<li><strong>Batch Size (GRPO)</strong>: 4 per GPU with gradient accumulation of 1, yielding an <strong>effective batch size of 128</strong> across 32 GPUs ($4 \times 1 \times 32 = 128$)</li>
</ul>
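<p>The batch-size arithmetic and the warm-up plus cosine schedule can be checked with a short sketch. The GPU count of 32 comes from the Hardware section of this note; the helper names, the total step count, and the choice to decay all the way to zero are assumptions made for illustration, since the note does not specify a final learning rate.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def effective_batch_size(per_gpu, grad_accum_steps, num_gpus):
    """Samples contributing to one optimizer update."""
    return per_gpu * grad_accum_steps * num_gpus


def lr_at(step, total_steps, peak_lr, warmup_fraction=0.1):
    """Linear warm-up to peak_lr, then cosine decay (assumed to end at 0)."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    if step in range(warmup_steps):
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))


sft_batch = effective_batch_size(per_gpu=2, grad_accum_steps=16, num_gpus=32)
grpo_batch = effective_batch_size(per_gpu=4, grad_accum_steps=1, num_gpus=32)

# Hypothetical 1,000-step SFT run with the reported peak learning rate.
total_steps = 1000
peak_reached = lr_at(99, total_steps, peak_lr=1.6e-4)
final_lr = lr_at(total_steps, total_steps, peak_lr=1.6e-4)
</code></pre></div>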
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong> (three complementary measures to handle abbreviation issues):</p>
<ul>
<li><strong>Gen-SMILES</strong>: Exact match ratio of SMILES strings directly generated by the VLM (image-captioning style)</li>
<li><strong>Gra-SMILES</strong>: Exact match ratio of SMILES strings derived from the predicted graph structure (graph-parsing style)</li>
<li><strong>Graph</strong>: Exact match ratio between ground truth and predicted graphs (node/edge comparison, bypassing SMILES canonicalization issues)</li>
</ul>
<p><strong>Baselines Compared</strong>:</p>
<ul>
<li>Specialist OCSR systems: MolScribe, MolNexTR</li>
<li>Chemistry-focused VLMs: ChemVLM, ChemDFM-X, OCSU</li>
<li>General-purpose VLMs: GPT-4o, GPT-4o-mini, Qwen-VL-Max</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute</strong>: Training performed on <strong>32 NVIDIA A100 GPUs</strong></p>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<p><strong>Status</strong>: Closed. As of the paper&rsquo;s publication, no source code, pre-trained model weights, or dataset downloads (GTR-1.3M, MolRec-Bench) have been publicly released. The paper does not mention plans for open-source release. The training data pipeline relies on PubChem SMILES (public), USPTO patent images (publicly available through prior work), the Indigo rendering tool (open-source), and an unspecified OCR system for abbreviation detection. Without the released code and data corrections, reproducing the full pipeline would require substantial re-implementation effort.</p>
]]></content:encoded></item><item><title>Can Recurrent Neural Networks Warp Time? (ICLR 2018)</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/can-recurrent-neural-networks-warp-time/</link><pubDate>Sat, 14 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/can-recurrent-neural-networks-warp-time/</guid><description>Tallec and Ollivier's ICLR 2018 paper deriving gating mechanisms in RNNs from time warping invariance and proposing chrono initialization for LSTMs.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>theory paper</strong> that provides a principled derivation of gating mechanisms in recurrent neural networks from an axiom of invariance to time transformations. The theoretical insights also yield a practical contribution: the <strong>chrono initialization</strong> for LSTM gate biases.</p>
<h2 id="why-time-warping-invariance-matters-for-recurrent-models">Why time warping invariance matters for recurrent models</h2>
<p>Standard recurrent neural networks are highly sensitive to changes in the time scale of their input data. Inserting a fixed number of blank steps between elements of an input sequence can make an otherwise easy task impossible for a vanilla RNN to learn. This fragility arises because the class of functions representable by an ordinary RNN is not closed under time rescaling.</p>
<p>The vanishing gradient problem compounds this issue: learning long-term dependencies requires gradient signals to persist across many time steps, but stability of the dynamical system causes these signals to decay exponentially. Prior solutions include gating mechanisms (LSTMs, GRUs) introduced on engineering grounds, and orthogonal weight constraints that limit representational power and make forgetting difficult.</p>
<p>Tallec and Ollivier ask a clean theoretical question: what structural properties must a recurrent model have to be invariant to arbitrary time transformations in its input?</p>
<h2 id="deriving-gates-from-time-warping-invariance">Deriving gates from time warping invariance</h2>
<p>The core insight starts from the continuous-time formulation of a basic RNN:</p>
<p>$$\frac{\mathrm{d}h(t)}{\mathrm{d}t} = \tanh(W_x x(t) + W_h h(t) + b) - h(t)$$</p>
<p>Applying a time warping $t \gets c(t)$ (any increasing differentiable function) to the input data $x(c(t))$ transforms this equation into:</p>
<p>$$\frac{\mathrm{d}h(t)}{\mathrm{d}t} = \frac{\mathrm{d}c(t)}{\mathrm{d}t} \tanh(W_x x(t) + W_h h(t) + b) - \frac{\mathrm{d}c(t)}{\mathrm{d}t} h(t)$$</p>
<p>The derivative $\frac{\mathrm{d}c(t)}{\mathrm{d}t}$ of the time warping appears as a multiplicative factor. For the model class to represent this equation for any time warping, a learnable function $g(t)$ must replace the unknown derivative:</p>
<p>$$\frac{\mathrm{d}h(t)}{\mathrm{d}t} = g(t) \tanh(W_x x(t) + W_h h(t) + b) - g(t) h(t)$$</p>
<p>Discretizing with a Taylor expansion ($\delta t = 1$) yields:</p>
<p>$$h_{t+1} = g_t \tanh(W_x x_t + W_h h_t + b) + (1 - g_t) h_t$$</p>
<p>This is a gated recurrent network with input gate $g_t$ and forget gate $(1 - g_t)$, where $g_t$ is computed by a sigmoid function of the inputs. The value $1/g(t_0)$ represents the local forgetting time of the network at time $t_0$.</p>
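<p>A scalar version of this update makes the forgetting-time interpretation concrete. The sketch assumes a single hidden unit and a gate of the form $g_t = \sigma(w_{gx} x_t + w_{gh} h_t + b_g)$; the gate weight names are illustrative and not taken from the note. With the candidate term switched off, the state decays as $(1 - g)^t$, so after $1/g$ steps it has dropped to roughly $1/e$ of its initial value.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_step(h, x, w_x, w_h, b, w_gx, w_gh, b_g):
    """One step h_next = g * tanh(w_x x + w_h h + b) + (1 - g) * h."""
    g = sigmoid(w_gx * x + w_gh * h + b_g)
    candidate = math.tanh(w_x * x + w_h * h + b)
    return g * candidate + (1.0 - g) * h


# Fix the gate at g = 0.01 by choosing b_g so that sigmoid(b_g) = 0.01,
# and zero out all other weights so the candidate term vanishes.
g_target = 0.01
b_g = math.log(g_target / (1.0 - g_target))

h = 1.0
for _ in range(100):  # 100 steps = 1 / g_target
    h = gated_step(h, x=0.0, w_x=0.0, w_h=0.0, b=0.0,
                   w_gx=0.0, w_gh=0.0, b_g=b_g)
</code></pre></div>
<p>After the loop, <code>h</code> equals $0.99^{100} \approx 0.366$, close to $e^{-1} \approx 0.368$, which is why $1/g$ acts as the unit&rsquo;s characteristic memory length.</p>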
<h3 id="the-special-case-of-linear-time-rescaling">The special case of linear time rescaling</h3>
<p>For the simpler case of a constant time rescaling $c(t) = \alpha t$, the same derivation produces a leaky RNN:</p>
<p>$$h_{t+1} = \alpha \tanh(W_x x_t + W_h h_t + b) + (1 - \alpha) h_t$$</p>
<p>Leaky RNNs are invariant to global time rescalings but fail with variable warpings. Full gating (where $g_t$ depends on the input) is required for invariance to general time warpings.</p>
<h3 id="per-unit-gates-and-the-connection-to-lstms">Per-unit gates and the connection to LSTMs</h3>
<p>Extending to per-unit gates $g_t^i$ allows different units to operate at different characteristic timescales:</p>
<p>$$h_{t+1}^i = g_t^i \tanh(W_x^i x_t + W_h^i h_t + b^i) + (1 - g_t^i) h_t^i$$</p>
<p>This closely resembles the LSTM cell update equation, where $(1 - g_t^i)$ corresponds to the forget gate $f_t$ and $g_t^i$ corresponds to the input gate $i_t$. The derivation naturally ties these two gates (they sum to 1), a constraint that has been used successfully in practice.</p>
<h2 id="chrono-initialization-for-gate-biases">Chrono initialization for gate biases</h2>
<p>The theoretical framework provides a principled initialization strategy. If the sequential data has temporal dependencies in a range $[T_{\text{min}}, T_{\text{max}}]$, then gate values $g$ should lie in $[1/T_{\text{max}}, 1/T_{\text{min}}]$. Since gate values center around $\sigma(b_g)$ when inputs are centered, the biases should be initialized as:</p>
<p>$$b_g \sim -\log(\mathcal{U}([T_{\text{min}}, T_{\text{max}}]) - 1)$$</p>
<p>For LSTMs specifically, the <strong>chrono initialization</strong> sets:</p>
<p>$$b_f \sim \log(\mathcal{U}([1, T_{\text{max}} - 1]))$$
$$b_i = -b_f$$</p>
<p>where $T_{\text{max}}$ is the expected range of long-term dependencies. This contrasts with the standard practice of setting forget gate biases to 1 or 2.</p>
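<p>The LSTM chrono initialization amounts to a few lines of bias sampling. The sketch below uses only the Python standard library and assumes $T_{\text{max}}$ is at least 2; the function names and the fixed seed are illustrative choices. It also checks the timescale implied by each forget bias: since $1 - \sigma(b_f) = 1/(1 + e^{b_f})$, the forgetting time $1/(1 - \sigma(b_f))$ equals $1 + e^{b_f}$, which lands in $[2, T_{\text{max}}]$.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math
import random


def chrono_init(hidden_size, t_max, seed=0):
    """Sample LSTM gate biases: b_f = log(U([1, t_max - 1])), b_i = -b_f."""
    rng = random.Random(seed)
    b_f = [math.log(rng.uniform(1.0, t_max - 1.0)) for _ in range(hidden_size)]
    b_i = [-b for b in b_f]
    return b_f, b_i


def forgetting_time(forget_bias):
    """Timescale 1 / (1 - sigmoid(b_f)), which simplifies to 1 + exp(b_f)."""
    return 1.0 + math.exp(forget_bias)


b_f, b_i = chrono_init(hidden_size=128, t_max=1000)
timescales = [forgetting_time(b) for b in b_f]
</code></pre></div>
<p>In a framework implementation, these values would be copied into the forget-gate and input-gate slices of the LSTM bias vector, leaving all weights and the remaining biases at their usual initialization.</p>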
<h2 id="experimental-validation">Experimental validation</h2>
<h3 id="time-warping-robustness">Time warping robustness</h3>
<p>On a character recall task with artificially warped sequences, three architectures are compared (64 units each):</p>
<ul>
<li><strong>Vanilla RNNs</strong> fail with even moderate warping coefficients</li>
<li><strong>Leaky RNNs</strong> perfectly solve uniform warpings but fail with variable warpings</li>
<li><strong>Gated RNNs</strong> achieve perfect performance under both uniform and variable warpings for all tested warping factors</li>
</ul>
<p>This directly validates the theory: leaky RNNs handle constant time rescalings, but only gated models handle general time warpings.</p>
<h3 id="synthetic-tasks-copy-and-adding">Synthetic tasks (copy and adding)</h3>
<p>Using 128-unit LSTMs:</p>
<ul>
<li><strong>Copy task</strong> ($T = 500, 2000$): Chrono initialization converges to the solution while standard initialization plateaus at the memoryless baseline</li>
<li><strong>Variable copy</strong> ($T = 500, 1000$): Chrono matches standard for smaller $T$ but outperforms for $T = 1000$</li>
<li><strong>Adding task</strong> ($T = 200, 750$): Chrono converges significantly faster, approximately 7x faster for $T = 750$</li>
</ul>
<h3 id="real-world-tasks">Real-world tasks</h3>
<ul>
<li><strong>Permuted MNIST</strong> (512-unit LSTM): Chrono achieves 96.3% vs. 95.4% for standard initialization</li>
<li><strong>Character-level text8</strong> (2000-unit LSTM): Slight improvement (1.37 vs. 1.38 bits-per-character)</li>
<li><strong>Word-level Penn Treebank</strong> (10-layer RHN): Comparable results to the baseline (65.4 test perplexity)</li>
</ul>
<p>Short-term dependency tasks show minimal differences, consistent with the theory that chrono initialization primarily helps when long-term dependencies dominate.</p>
<h2 id="limitations">Limitations</h2>
<p>The continuous-to-discrete time correspondence relies on a Taylor expansion with step size $\delta t = 1$. This approximation holds when the derivative of the time warping is not too large ($g_t \lesssim 1$). Discrete-time gated models are therefore invariant to time warpings that stretch time (such as interspersing data with blanks or introducing long-term dependencies), but they cannot handle warpings that compress events faster than the model&rsquo;s time step. Additionally, the chrono initialization requires specifying $T_{\text{max}}$, the expected range of long-term dependencies, which may not be known in advance.</p>
<h2 id="reproducibility">Reproducibility</h2>
<p><strong>Status: Partially Reproducible.</strong></p>
<p>The paper describes all hyperparameters, architectures, and training procedures in sufficient detail to reproduce the experiments. The synthetic tasks (copy, adding, time warping) follow standard setups from prior work with clearly specified parameters. The real-world experiments (permuted MNIST, text8, Penn Treebank) use established benchmarks with referenced codebases (the text8 setup reuses code from Cooijmans et al. 2016).</p>
<p>The chrono initialization itself requires minimal implementation effort: it only changes the bias initialization of gate units, with no modifications to the model architecture or training procedure.</p>
<p>No official code repository is provided by the authors. No pre-trained models or datasets beyond standard benchmarks are released.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Tallec, C. &amp; Ollivier, Y. (2018). Can recurrent neural networks warp time? <em>International Conference on Learning Representations (ICLR 2018)</em>.</p>
<p><strong>Publication</strong>: ICLR 2018</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=SJcKhk-Ab">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/1804.11188">arXiv</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{tallec2018can,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Can recurrent neural networks warp time?}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Tallec, Corentin and Ollivier, Yann}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemDFM-R: Chemical Reasoning LLM with Atomized Knowledge</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemdfm-r/</link><pubDate>Fri, 26 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemdfm-r/</guid><description>A 14B-parameter chemical reasoning LLM enhanced with atomized functional group knowledge and mix-sourced distillation strategy.</description><content:encoded><![CDATA[<h2 id="method-and-resource-contributions">Method and Resource Contributions</h2>
<p>This is primarily a <strong>Method</strong> paper with significant <strong>Resource</strong> contributions.</p>
<ul>
<li><strong>Methodological Basis</strong>: The paper introduces a training pipeline (&ldquo;mix-sourced distillation&rdquo;) and domain-specific reinforcement learning to improve reasoning capabilities in chemical LLMs. It validates the approach through ablation studies across training stages.</li>
<li><strong>Resource Contribution</strong>: The authors constructed <strong>ChemFG</strong>, a 101 billion-token corpus annotated with &ldquo;atomized&rdquo; knowledge regarding functional groups and reaction centers.</li>
</ul>
<h2 id="bridging-the-chemical-reasoning-gap">Bridging the Chemical Reasoning Gap</h2>
<p>Current chemical LLMs struggle to reason logically for two main reasons:</p>
<ol>
<li><strong>Shallow Domain Understanding</strong>: Models generally learn molecule-level properties directly, bypassing the intermediate &ldquo;atomized&rdquo; characteristics (e.g., <a href="https://en.wikipedia.org/wiki/Functional_group">functional groups</a>) that ultimately dictate chemical behavior.</li>
<li><strong>Specialized Reasoning Logic</strong>: Chemical logic differs fundamentally from math or code. Distilling reasoning from general teacher models like DeepSeek-R1 frequently fails because the teachers lack the domain intuition required to generate valid chemical rationales.</li>
</ol>
<h2 id="atomized-knowledge-and-mixed-source-distillation">Atomized Knowledge and Mixed-Source Distillation</h2>
<p>The authors introduce three structural innovations to solve the reasoning gap:</p>
<ol>
<li><strong>Atomized Knowledge Enhancement (ChemFG)</strong>: A toolkit was built leveraging SMARTS notations to identify functional group changes during reactions. A critique of this approach is that it relies heavily on 2D cheminformatics abstractions, potentially missing deeper 3D stereochemical interactions.</li>
<li><strong>Mix-Sourced Distillation</strong>: General models (DeepSeek-R1/o3-mini) are fed &ldquo;pseudo-reasoning&rdquo; prompts that include ground truth answers and functional group data. While this forces the teacher to generate high-quality rationales for the student to learn, it introduces a layer of hindsight bias into the generated reasoning chains. During inference, the student model lacks both the pre-calculated functional group metadata and the ground truth, forcing it to bridge an artificially steep generalization gap.</li>
<li><strong>Chemical Reinforcement Learning</strong>: The intermediate model undergoes domain-specific reinforcement learning. The RL details are described in the paper&rsquo;s Appendix D, with the authors citing the open-source DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) framework. The optimization relies on rule-based rewards (format adherence and canonicalized <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> accuracy) across a variety of chemical tasks.</li>
</ol>
<h2 id="benchmark-evaluation-and-ablation-studies">Benchmark Evaluation and Ablation Studies</h2>
<p>The model was evaluated on comprehensive chemical benchmarks: <strong>SciKnowEval</strong> (19 tasks) and <strong><a href="/notes/chemistry/llm-applications/chemeval-multilevel-chemical-evaluation/">ChemEval</a></strong> (36 tasks).</p>
<ul>
<li><strong>Baselines</strong>: Compared against similarly sized open models (Qwen2.5-14B-Instruct, Qwen3-14B), domain models (<a href="/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/">ChemLLM</a>, MolInst), and frontier models (GPT-4o, DeepSeek-R1).</li>
<li><strong>Ablation</strong>: Evaluated across training stages (Base → ChemDFM-I → ChemDFM-R) to measure the specific impact of the instruction tuning versus the reasoning stages.</li>
<li><strong>Qualitative Analysis</strong>: The paper includes case studies demonstrating the model&rsquo;s step-by-step chemical reasoning and its potential for human-AI collaboration (Sections 4.2 and 4.3).</li>
</ul>
<h2 id="performance-outcomes-and-numerical-limitations">Performance Outcomes and Numerical Limitations</h2>
<ul>
<li><strong>Performance vs. Baselines</strong>: ChemDFM-R outperforms similarly sized open models and domain models on molecule-centric and reaction-centric tasks, and surpasses the much larger DeepSeek-R1 on ChemEval (0.78 vs. 0.58 overall). It shows competitive results relative to o4-mini, though o4-mini leads on SciKnowEval (0.74 vs. 0.70).</li>
<li><strong>Reasoning Interactivity</strong>: The model generates readable rationales that allow users to catch structural errors or identify reaction mechanisms accurately. Section 4.3 of the paper demonstrates human-AI collaboration scenarios.</li>
<li><strong>Quantitative Limitations</strong>: The model struggles with tasks involving numerical prediction and calculation (e.g., yield extraction, molecular property calculation). The paper notes that all molecule-centric and reaction-centric tasks where ChemDFM-R falls short of Qwen2.5-14B-Instruct involve numerical reasoning.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is constructed in three phases:</p>
<p><strong>1. Domain Pre-training (ChemFG)</strong>:</p>
<ul>
<li><strong>Size</strong>: 101 billion tokens</li>
<li><strong>Composition</strong>:
<ul>
<li>12M literature documents (79B tokens)</li>
<li>30M molecules from PubChem/PubChemQC</li>
<li>7M reactions from USPTO-FULL</li>
</ul>
</li>
<li><strong>Augmentation</strong>: SMILES augmentation (10x) using R-SMILES</li>
<li><strong>Atomized Features</strong>: Annotated with a custom &ldquo;Functional Group Identification Toolkit&rdquo; that identifies 241 functional group types and tracks changes in reaction centers. <em>Note: Only part of this is publicly available: the toolkit (<a href="https://github.com/OpenDFM/ChemFG-Tool">ChemFG-Tool</a>) is open-sourced on GitHub, but the 101 billion-token ChemFG dataset itself has not been released.</em></li>
</ul>
<p><strong>2. Instruction Tuning</strong>:</p>
<ul>
<li><strong>Sources</strong>: Molecule-centric (<a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a>, <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>), Reaction-centric (USPTO), and Knowledge-centric (Exams, Literature QA) tasks</li>
<li><strong>Mixing</strong>: Mixed with general instruction data in a 1:2 ratio</li>
</ul>
<p><strong>3. Distillation Dataset</strong>:</p>
<ul>
<li><strong>Sources</strong>:
<ul>
<li>~70% ChemDFM-R instruction data</li>
<li>~22% constructed pseudo-reasoning (functional group descriptions)</li>
<li>~8% teacher rationales (from DeepSeek-R1/o3-mini)</li>
</ul>
</li>
<li><strong>Mixing</strong>: Mixed with general data (including AM-Deepseek-R1-Distill-1.4M) in a 1:2 ratio</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Functional Group Identification</strong>:</p>
<ul>
<li>Extends the <code>thermo</code> library&rsquo;s SMARTS list</li>
<li>For reactions, identifies &ldquo;reacting functional groups&rdquo; by finding reactants containing atoms involved in bond changes (reaction centers) that do not appear in the product</li>
</ul>
<p><strong>Mix-Sourced Distillation</strong>:</p>
<ul>
<li>Teacher models (DeepSeek-R1, o3-mini) are prompted with Question + Ground Truth + Functional Group Info to generate high-quality &ldquo;Thoughts&rdquo;</li>
<li>These rationales are distilled into the student model using a supervised fine-tuning loss across target tokens $y_t$:
$$ \mathcal{L}_{\text{SFT}} = - \sum_{t=1}^T \log P_\theta(y_t \mid x, y_{&lt;t}) $$</li>
</ul>
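<p>The SFT objective above is a summed negative log-likelihood over the target tokens. As a minimal sketch, assume the model&rsquo;s probabilities for the correct tokens are already available as a list; the function name and the example probabilities are hypothetical. In practice the loss is computed from logits and is often averaged over tokens.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def sft_loss(target_token_probs):
    """Negative log-likelihood summed over target tokens.

    target_token_probs[t] stands for P(y_t | x, y_1..y_{t-1}), the
    probability the model assigns to the correct token at position t.
    """
    return -sum(math.log(p) for p in target_token_probs)


# Three target tokens predicted with probabilities 0.9, 0.5, and 0.8.
loss = sft_loss([0.9, 0.5, 0.8])
</code></pre></div>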
<p><strong>Reinforcement Learning</strong>:</p>
<ul>
<li><strong>Algorithm</strong>: The paper cites DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) as the RL framework; full details are in Appendix D of the paper. <em>Note: While the underlying DAPO framework is open-source, the specific chemistry-oriented RL pipeline and environment used for ChemDFM-R have not been publicly released.</em></li>
<li><strong>Hyperparameters</strong> (from paper appendix): Learning rate <code>5e-7</code>, rollout batch size <code>512</code>, training batch size <code>128</code></li>
<li><strong>Rewards</strong>: The reward system applies rule-based constraints focusing on physical form and chemical validity. The total reward $R(y, y^*)$ for a generated response $y$ given target $y^*$ combines a format adherence reward ($R_{\text{format}}$) and an accuracy reward ($R_{\text{acc}}$) evaluated on canonicalized SMILES:
$$ R(y, y^*) = R_{\text{format}}(y) + R_{\text{acc}}(\text{canonicalize}(y), \text{canonicalize}(y^*)) $$</li>
</ul>
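<p>The additive structure of this reward can be illustrated with a toy sketch. Everything specific here is an assumption rather than the paper&rsquo;s pipeline: the 0/1 reward values, the <code>Answer:</code> line convention used as the format rule, and the lookup-table canonicalizer are all hypothetical. A real setup would canonicalize with a cheminformatics library, for example RDKit&rsquo;s <code>Chem.MolToSmiles(Chem.MolFromSmiles(s))</code>.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def rule_based_reward(response, target, canonicalize, answer_prefix="Answer:"):
    """R = R_format + R_acc, with both components assumed to be 0 or 1."""
    lines = [line.strip() for line in response.strip().splitlines()]
    answer_lines = [line for line in lines if line.startswith(answer_prefix)]
    # Hypothetical format rule: the response has exactly one answer line.
    r_format = 1.0 if len(answer_lines) == 1 else 0.0
    if not answer_lines:
        return r_format
    predicted = answer_lines[-1][len(answer_prefix):].strip()
    # Accuracy is judged on canonical forms, so equivalent SMILES match.
    r_acc = 1.0 if canonicalize(predicted) == canonicalize(target) else 0.0
    return r_format + r_acc


# Toy canonicalizer: a lookup table standing in for a real SMILES canonicalizer.
TOY_CANONICAL = {"OCC": "CCO", "CCO": "CCO"}


def toy_canonicalize(smiles):
    return TOY_CANONICAL.get(smiles, smiles)


response = "The product keeps its hydroxyl group.\nAnswer: OCC"
reward = rule_based_reward(response, "CCO", toy_canonicalize)
</code></pre></div>
<p>Comparing canonical forms is what lets two differently written but equivalent SMILES strings, such as <code>OCC</code> and <code>CCO</code> for ethanol, earn the accuracy reward.</p>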
<h3 id="models">Models</h3>
<ul>
<li><strong>Base Model</strong>: Qwen2.5-14B</li>
<li><strong>ChemDFM-I</strong>: Result of instruction tuning the domain-pretrained model for 2 epochs</li>
<li><strong>ChemDFM-R</strong>: Result of applying mix-sourced distillation (1 epoch) followed by RL on ChemDFM-I. <em>Note: Model weights are publicly available on <a href="https://huggingface.co/OpenDFM/ChemDFM-R-14B">Hugging Face</a>.</em></li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Hardware and training time details are described in the paper&rsquo;s appendices, which are not available in the extracted text. The details below are reported from the paper but could not be independently cross-verified against the main text:</p>
<ul>
<li><strong>Compute</strong>: NVIDIA A800 Tensor Core GPUs</li>
<li><strong>Training Time</strong>: 30,840 GPU hours total (Domain Pretraining: 24,728 hours; Instruction Tuning: 3,785 hours; Distillation: 2,059 hours; Reinforcement Learning: 268 hours)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>SciKnowEval</strong>: 19 tasks (text-centric, molecule-centric, reaction-centric)</li>
<li><strong>ChemEval</strong>: 36 tasks, categorized similarly</li>
</ul>
<p><strong>Key Metrics</strong>: Accuracy, F1 Score, BLEU score (with PRS normalization for ChemEval)</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>SciKnowEval (all)</th>
          <th>ChemEval* (all)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Qwen2.5-14B-Instruct</td>
          <td>0.61</td>
          <td>0.57</td>
          <td>General-domain baseline</td>
      </tr>
      <tr>
          <td>ChemDFM-I</td>
          <td>0.69</td>
          <td>0.72</td>
          <td>After domain pretraining + instruction tuning</td>
      </tr>
      <tr>
          <td>ChemDFM-R</td>
          <td><strong>0.70</strong></td>
          <td><strong>0.78</strong></td>
          <td>After distillation + RL</td>
      </tr>
      <tr>
          <td>DeepSeek-R1</td>
          <td>0.62</td>
          <td>0.58</td>
          <td>General-domain reasoning model</td>
      </tr>
      <tr>
          <td>o4-mini</td>
          <td><strong>0.74</strong></td>
          <td>0.69</td>
          <td>Frontier reasoning model</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/OpenDFM/ChemDFM-R-14B">ChemDFM-R-14B</a></td>
          <td>Model</td>
          <td>AGPL-3.0</td>
          <td>Final reasoning model weights on Hugging Face</td>
      </tr>
      <tr>
          <td><a href="https://github.com/OpenDFM/ChemFG-Tool">ChemFG-Tool</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Functional group identification toolkit (241 groups)</td>
      </tr>
  </tbody>
</table>
<p><strong>Missing components</strong>: The 101B-token ChemFG pretraining dataset is not publicly released. The chemistry-oriented RL pipeline and training code are not open-sourced. The instruction tuning and distillation datasets are not available.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhao, Z., Chen, B., Wan, Z., Chen, L., Lin, X., Yu, S., Zhang, S., Ma, D., Zhu, Z., Zhang, D., Wang, H., Dai, Z., Wen, L., Chen, X., &amp; Yu, K. (2025). ChemDFM-R: A Chemical Reasoning LLM Enhanced with Atomized Chemical Knowledge. <em>arXiv preprint arXiv:2507.21990</em>. <a href="https://doi.org/10.48550/arXiv.2507.21990">https://doi.org/10.48550/arXiv.2507.21990</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{zhao2025chemdfmr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemDFM-R: A Chemical Reasoning LLM Enhanced with Atomized Chemical Knowledge}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Zihan Zhao and Bo Chen and Ziping Wan and Lu Chen and Xuanze Lin and Shiyang Yu and Situo Zhang and Da Ma and Zichen Zhu and Danyang Zhang and Huayang Wang and Zhongyang Dai and Liyang Wen and Xin Chen and Kai Yu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2507.21990}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.CE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2507.21990}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa-3: Open Source Chemical Foundation Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-3/</link><pubDate>Fri, 26 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-3/</guid><description>An open-source framework integrating DeepChem and Ray for training and benchmarking chemical foundation models like MoLFormer and GROVER at scale.</description><content:encoded><![CDATA[<h2 id="core-contribution-an-open-source-framework">Core Contribution: An Open-Source Framework</h2>
<p>This is primarily a <strong>Resource ($\Psi_{\text{Resource}}$)</strong> paper, with secondary <strong>Method ($\Psi_{\text{Method}}$)</strong> contributions.</p>
<ul>
<li><strong>Resource Basis</strong>: The core contribution is &ldquo;ChemBERTa-3,&rdquo; an open-source framework integrated into DeepChem that standardizes the pretraining and benchmarking of chemical foundation models. The authors focus heavily on infrastructure (AWS/Ray integration) and correcting benchmarking inconsistencies in the field.</li>
<li><strong>Method Basis</strong>: It trains models like &ldquo;c3-MoLFormer&rdquo; to reproduce and validate the infrastructure.</li>
</ul>
<h2 id="the-pretraining-scalability-challenge">The Pretraining Scalability Challenge</h2>
<ul>
<li><strong>Scalability Challenges</strong>: Building robust molecular models is difficult due to the vast size of chemical space and the computational intensity of pretraining on large datasets.</li>
<li><strong>Proprietary Barriers</strong>: Many high-performing chemical foundation models (e.g., the full <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer-XL</a>) are partially closed-source or difficult to reproduce.</li>
<li><strong>Benchmarking Inconsistencies</strong>: There is a lack of systematic comparison between architectures (e.g., Graph vs. Transformer) using unified protocols. Specifically, previous comparisons relied on reported results obtained with differing scaffold splitting algorithms, so the numbers were not directly comparable.</li>
</ul>
<h2 id="unified-infrastructure--standardized-benchmarking">Unified Infrastructure &amp; Standardized Benchmarking</h2>
<ul>
<li><strong>Unified Infrastructure</strong>: Integration of DeepChem with Ray for distributed, scalable pretraining and fine-tuning of both graph and transformer models.</li>
<li><strong>Standardized Benchmarking</strong>: Identification that MoLFormer&rsquo;s scaffold splitting algorithm differs from the standard DeepChem/<a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> splitter, and the subsequent standardization of these benchmarks for fair comparison.</li>
<li><strong>New DeepChem Tools</strong>: Introduction of the <code>ModularTorchModel</code> class for flexible loss computation and <code>HuggingFaceModel</code> wrappers to bridge ecosystems.</li>
</ul>
<h2 id="benchmarking-transformers-vs-graph-models">Benchmarking Transformers vs. Graph Models</h2>
<ul>
<li><strong>Architecture Comparison</strong>: Benchmarked Transformers (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a>) against Graph models (GROVER, InfoGraph, InfoMax3D, DMPNN, GCN) and baselines (Random Forest).</li>
<li><strong>Pretraining Scale Disparity</strong>:
<ul>
<li>Transformers were pretrained on ZINC20 subsets ranging from 10M to 1.1B molecules (combining ZINC and PubChem).</li>
<li>Graph models were limited to 250K-molecule subsets due to the memory and computational overhead of message passing on large graphs. While this highlights the superior scalability of Transformer architectures, comparing a Transformer pretrained on 1.1B molecules to a Graph model pretrained on 250K molecules confounds architectural capacity with data scale.</li>
</ul>
</li>
<li><strong>Reproducibility Validation</strong>: Trained &ldquo;c3-MoLFormer&rdquo; (a reproduction of MoLFormer) on 1.1B molecules using two distinct hardware setups: AWS spot instances (Ray) and a local HPC cluster.</li>
<li><strong>Scaffold Split Analysis</strong>: Compared performance metrics using &ldquo;DeepChem scaffold splits&rdquo; vs. &ldquo;MoLFormer scaffold splits&rdquo; to quantify the impact of data leakage/overlap.</li>
</ul>
<h2 id="overcoming-scaffold-splitting-inconsistencies">Overcoming Scaffold Splitting Inconsistencies</h2>
<ul>
<li><strong>Scaling Transformers vs. Graphs</strong>: Transformer-based models are significantly easier to scale to large datasets than current graph-based approaches, though performance is comparable at small scales.</li>
<li><strong>Benchmarking Sensitivity</strong>: MoLFormer&rsquo;s reported superiority over baselines was partly inflated by its specific scaffold splitting method, which had higher structural overlap between train and test sets (yielding a lower <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto distance</a>, generally quantified via $1 - \frac{|A \cap B|}{|A \cup B|}$) than DeepChem splits. When standardized, baselines like DMPNN perform more competitively.</li>
<li><strong>Infrastructure Viability</strong>: The framework successfully replicated large-scale training (MoLFormer-1.1B) on both cloud and on-premise HPC, confirming reproducibility.</li>
<li><strong>Open Source Release</strong>: All code, configurations, and the c3-MoLFormer-1.1B model weights are released to facilitate future research.</li>
</ul>
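<p>The train/test overlap argument above can be made concrete with a small sketch. It assumes each molecule is already represented as a set of fingerprint on-bit indices (in practice these would come from a cheminformatics toolkit); the function names and toy fingerprints are hypothetical and are not taken from the paper or its repository.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def tanimoto_distance(bits_a, bits_b):
    """Return 1 - |A intersect B| / |A union B| for two sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    if not union:
        return 0.0  # two empty fingerprints are treated as identical here
    return 1.0 - len(a.intersection(b)) / len(union)


def nearest_train_distance(test_bits, train_fingerprints):
    """Distance from one test molecule to its closest training molecule."""
    return min(tanimoto_distance(test_bits, fp) for fp in train_fingerprints)


# Toy fingerprints as sets of on-bit indices (hypothetical values).
train = [{1, 2, 3, 4}, {2, 3, 7}]
test_molecule = {3, 4, 5, 6}
print(nearest_train_distance(test_molecule, train))
</code></pre></div>
<p>Averaging this nearest-neighbor distance over a test set gives one rough measure of how far a split pushes test molecules from the training data; a split with more structural overlap would score lower.</p>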
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Pretraining</strong>:
<ul>
<li><strong>Source</strong>: <a href="/notes/chemistry/datasets/zinc-22/">ZINC20</a> (1.4B compounds) and PubChem.</li>
<li><strong>Scale</strong>: Subsets of 10M, 100M, and 1.1B (100% ZINC20 + 100% PubChem) were used for Transformers. Graph models used a 250K subset.</li>
</ul>
</li>
<li><strong>Fine-tuning</strong>:
<ul>
<li><strong>Suite</strong>: MoleculeNet.</li>
<li><strong>Tasks</strong>: Classification (BACE, BBBP, Tox21, HIV, SIDER, ClinTox) and Regression (ESOL, FreeSolv, Lipophilicity, QM9).</li>
<li><strong>Splits</strong>: Critical distinction made between &ldquo;DeepChem scaffold splits&rdquo; (80/10/10) and &ldquo;MoLFormer scaffold splits&rdquo; (which can be downloaded from <a href="https://ibm.ent.box.com/v/MoLFormer-data"><code>https://ibm.ent.box.com/v/MoLFormer-data</code></a>). The paper notes these algorithms differ.</li>
</ul>
</li>
</ul>
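<p>To illustrate what an 80/10/10 scaffold split does, here is a simplified sketch. It assumes scaffold labels have already been computed for each molecule and uses a greedy largest-group-first rule; this is an illustrative assumption and reproduces neither the DeepChem splitter nor the MoLFormer splitting algorithm. The helper names are hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from collections import defaultdict


def fits(current_size, group_size, cap):
    """True when adding the group keeps the running size within the cap."""
    return max(0, current_size + group_size - cap) == 0


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Greedy group split: molecules sharing a scaffold never straddle subsets."""
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)
    # Largest scaffold groups first; ties broken by first index for determinism.
    ordered = sorted(groups.values(), key=lambda group: (-len(group), group[0]))
    train_cap = frac_train * len(scaffolds)
    valid_cap = (frac_train + frac_valid) * len(scaffolds)
    train, valid, test = [], [], []
    for group in ordered:
        if fits(len(train), len(group), train_cap):
            train.extend(group)
        elif fits(len(train) + len(valid), len(group), valid_cap):
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Placeholder scaffold labels; a real pipeline would derive them from SMILES.
labels = ["A"] * 5 + ["B"] * 3 + ["C", "D"]
train, valid, test = scaffold_split(labels)
print(len(train), len(valid), len(test))
</code></pre></div>
<p>Because whole scaffold groups are assigned together, small changes in grouping or ordering rules can move entire chemical families between subsets, which is one way two &ldquo;scaffold splits&rdquo; of the same dataset can end up with different train/test overlap.</p>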
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework</strong>: DeepChem integrated with Ray for distributed training. To recreate the environment, the repository relies on a nightly version of DeepChem (<code>pip install --pre deepchem</code>) and specific dependencies found within the <code>requirements.txt</code>. Pretraining scripts are available in the <code>chemberta3_benchmarking/pretraining</code> directory of the repository.</li>
<li><strong>Data Preparation</strong>: Featurization workflows (e.g., <code>CircularFingerprint</code>, <code>RDKitConformer</code>) are documented under <code>chemberta3_benchmarking/data/data_preprocessing/</code> in the codebase.</li>
<li><strong>Modular Training</strong>: Uses <code>ModularTorchModel</code> to allow loss computation from intermediate values and flexible component connection.</li>
<li><strong>Training Brittleness</strong>:
<ul>
<li><strong>Learning Rate Schedule</strong>: Linear learning rate scheduler with warmup.</li>
<li><strong>Instability Handling</strong>: The authors observed significant loss spikes during warmup. Their primary mitigation strategy involved checkpointing frequently and restarting from the last stable state upon a spike, highlighting a persistent brittleness in optimizing these large chemical foundation models.</li>
<li><strong>Numerical Issues</strong>: Addressed NaN values by pretraining on a small dataset with low LR before scaling up.</li>
</ul>
</li>
</ul>
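<p>A linear schedule with warmup, as mentioned above, can be sketched in a few lines. This is a minimal illustration assuming a ramp from zero to a peak rate followed by a linear decay to zero; the function name and all step counts and rates are hypothetical, and the authors&rsquo; exact settings live in their repository scripts.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Ramp linearly from 0 to base_lr, then decay linearly back to 0."""
    ramp_up = step / warmup_steps
    ramp_down = (total_steps - step) / (total_steps - warmup_steps)
    return base_lr * max(0.0, min(ramp_up, ramp_down))


# Hypothetical settings: peak rate 1e-3, 100 warmup steps, 1000 total steps.
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_lr(step, 1e-3, 100, 1000))
</code></pre></div>
<p>The warmup region is where the learning rate changes fastest relative to its value, which is consistent with the loss spikes being reported during warmup.</p>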
<h3 id="models">Models</h3>
<ul>
<li><strong><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a></strong>: RoBERTa-based architecture trained with Masked Language Modeling (MLM) and Multitask Regression (MTR). Specific model identifiers (e.g., <a href="https://huggingface.co/DeepChem/ChemBERTa-100M-MLM"><code>DeepChem/ChemBERTa-100M-MLM</code></a>) are hosted on Hugging Face so researchers can pull them directly via the <code>transformers</code> library. The core pretraining objective minimized the standard MLM loss:
$$ \mathcal{L}_{\text{MLM}} = - \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log \hat{y}_{i} $$
where $\mathcal{M}$ represents the set of masked SMILES token indices, and $\hat{y}_{i}$ is the model&rsquo;s predicted probability for the correct token given the corrupted sequence context.</li>
<li><strong>MoLFormer (c3-MoLFormer)</strong>: Re-implementation of the MoLFormer architecture (Rotary embeddings, linear attention). Specific model identifiers (e.g., <a href="https://huggingface.co/DeepChem/MoLFormer-c3-1.1B"><code>DeepChem/MoLFormer-c3-1.1B</code></a>) are similarly available on Hugging Face.
<ul>
<li>Tokenizer: <code>ibm/MoLFormer-XL-both-10pct</code> tokenizer.</li>
</ul>
</li>
<li><strong>Graph Models</strong>:
<ul>
<li><strong>GROVER</strong>: Graph Transformer with node/edge/graph level self-supervision.</li>
<li><strong>InfoGraph</strong>: Maximizes mutual information between graph-level and substructure representations.</li>
<li><strong>InfoMax3D</strong>: Incorporates 3D conformer data (via <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> ETKDGv2) into contrastive pretraining.</li>
<li><strong>DMPNN</strong>: Directed Message Passing Neural Network (Chemprop variant).</li>
</ul>
</li>
</ul>
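<p>The MLM loss above is the average negative log-probability assigned to the correct token at each masked position. A minimal sketch, assuming the per-position probabilities $\hat{y}_i$ have already been read off the model&rsquo;s output distribution (the function name and example values are hypothetical):</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def mlm_loss(correct_token_probs):
    """Mean negative log-probability over the masked positions only."""
    if not correct_token_probs:
        raise ValueError("at least one masked position is required")
    total = sum(math.log(p) for p in correct_token_probs)
    return -total / len(correct_token_probs)


# Hypothetical probabilities for three masked SMILES tokens.
print(mlm_loss([0.9, 0.5, 0.25]))
</code></pre></div>
<p>Unmasked positions do not contribute, so a confident model that is wrong on a single masked token can still incur a large loss.</p>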
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">ROC-AUC</a> for classification; RMSE for regression (MAE for QM9).</li>
<li><strong>Baselines</strong>: Random Forest, GCN, DMPNN trained on fine-tuning splits only.</li>
<li><strong>Protocol</strong>: Three independent runs per configuration to report mean and range (not a confidence interval), with the exception of the compute-heavy QM9 dataset, which only received a single run. Benchmarking execution scripts (e.g., GCN, RF, DMPNN, ChemBERTa) are stored in the repo under <code>chemberta3_benchmarking/models_benchmarking/</code> and contain the specific fine-tuning hyperparameters and optimizer configurations used for each downstream task.</li>
<li><strong>Key Results</strong>:
<ul>
<li><em>c3-MoLFormer-1.1B</em> achieved ~0.848 ROC-AUC on BACE and ~0.900 on BBBP (using MoLFormer splits). This closely matches the original IBM MoLFormer metrics, validating the reproducibility of the open-source framework.</li>
<li>When constrained to the equivalent 250K subset, Graph models (InfoGraph, GROVER) performed comparably to Transformers, indicating that Transformer superiority in chemistry is largely driven by data scalability rather than an inherent architectural advantage at small scales.</li>
</ul>
</li>
</ul>
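<p>The reporting protocol above (mean and range over three runs) can be sketched as follows. The scores are made-up placeholders, not results from the paper, and the function name is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def mean_and_range(scores):
    """Summarize repeated runs as (mean, lowest, highest)."""
    if not scores:
        raise ValueError("need at least one run")
    return sum(scores) / len(scores), min(scores), max(scores)


# Hypothetical ROC-AUC values from three seeds; not results from the paper.
mean, lowest, highest = mean_and_range([0.84, 0.85, 0.86])
print(f"{mean:.3f} (range {lowest:.2f} to {highest:.2f})")
</code></pre></div>
<p>With only three runs, the range simply brackets the observed values; it says nothing about how a fourth run would behave, which is why it should not be read as a confidence interval.</p>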
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Cloud (AWS)</strong>:
<ul>
<li><strong>Compute</strong>: 40 NVIDIA T4 GPUs (<code>g4dn.12xlarge</code> spot instances for pretraining, <code>g4dn.2xlarge</code> for benchmarking).</li>
<li><strong>Cost</strong>: ~$4000 for MoLFormer 1.1B pretraining.</li>
<li><strong>Time</strong>: ~10 days (260 hours) for 1.1B model pretraining.</li>
<li><strong>Setup</strong>: Setup scripts for single-node and multi-node spot EC2 clusters are provided in the GitHub repository&rsquo;s <code>infra/</code> and <code>spot/</code> folders.</li>
</ul>
</li>
<li><strong>On-Premise HPC</strong>:
<ul>
<li><strong>Compute</strong>: 16 nodes (AMD EPYC), each with 4 AMD MI300A APUs.</li>
<li><strong>Environment</strong>: Ray multi-node multi-GPU framework.</li>
</ul>
</li>
</ul>
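<p>A back-of-the-envelope reading of the AWS figures above, assuming all 40 GPUs were occupied for the full 260 hours (an assumption; the summary does not break down utilization or spot interruptions):</p>
<p>$$ 40 \times 260\ \text{h} = 10{,}400\ \text{GPU-hours}, \qquad \frac{4000\ \text{USD}}{10{,}400\ \text{GPU-hours}} \approx 0.38\ \text{USD per GPU-hour}, \qquad \frac{260\ \text{h}}{24\ \text{h/day}} \approx 10.8\ \text{days} $$</p>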
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/deepforestsci/chemberta3">ChemBERTa-3 GitHub Repository</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training, fine-tuning, and benchmarking framework</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-1.1B">DeepChem/MoLFormer-c3-1.1B</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer re-implementation pretrained on 1.1B molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/ChemBERTa-100M-MLM">DeepChem/ChemBERTa-100M-MLM</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>ChemBERTa pretrained on 100M ZINC molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-100M">DeepChem/MoLFormer-c3-100M</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer pretrained on 100M molecules</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/DeepChem/MoLFormer-c3-550M">DeepChem/MoLFormer-c3-550M</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>MoLFormer pretrained on 550M molecules</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Singh, R. et al. (2026). ChemBERTa-3: an open source training framework for chemical foundation models. <em>Digital Discovery</em>, 5, 662-685. <a href="https://doi.org/10.1039/D5DD00348B">https://doi.org/10.1039/D5DD00348B</a></p>
<p><strong>Publication</strong>: Digital Discovery 2026</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/deepforestsci/chemberta3">ChemBERTa-3 GitHub Repository</a></li>
<li><a href="https://deepchem.io/">DeepChem Project</a></li>
<li><a href="https://huggingface.co/DeepChem">DeepChem Hugging Face Models</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{singhChemBERTa3OpenSource2026,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Singh, Riya and Barsainyan, Aryan Amit and Irfan, Rida and Amorin, Connor Joseph and He, Stewart and Davis, Tony and Thiagarajan, Arun and Sankaran, Shiva and Chithrananda, Seyone and Ahmad, Walid and Jones, Derek and McLoughlin, Kevin and Kim, Hyojin and Bhutani, Anoushka and Sathyanarayana, Shreyas Vinaya and Viswanathan, Venkat and Allen, Jonathan E. and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa-3}}: an open source training framework for chemical foundation models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2026}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{662-685}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{The Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D5DD00348B}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1039/D5DD00348B}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>GP-MoLFormer: Molecular Generation via Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/gp-molformer/</guid><description>A 46.8M parameter transformer for molecular generation trained on 1.1B SMILES, introducing pair-tuning for efficient property optimization.</description><content:encoded><![CDATA[<h2 id="contribution-and-taxonomic-focus">Contribution and Taxonomic Focus</h2>
<p>This is primarily a <strong>Methodological</strong> paper, as it proposes a specific neural architecture (GP-MoLFormer) and a novel fine-tuning algorithm (Pair-tuning) for molecular generation. It validates these contributions against standard baselines (e.g., JT-VAE, <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen</a>-7b).</p>
<p>It also contains a secondary <strong>Theoretical</strong> contribution by establishing an empirical <a href="/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/">scaling law</a> that relates inference compute (generation size) to the novelty of the generated molecules.</p>
<h2 id="motivation-data-scale-and-prompt-based-optimization">Motivation: Data Scale and Prompt-Based Optimization</h2>
<p>While large language models (LLMs) have transformed text generation, the impact of training data scale and memorization on <em>molecular</em> generative models remains under-explored. Specifically, there is a need to understand how training on billion-scale datasets affects the novelty of generated molecules and whether biases in public databases (like ZINC and PubChem) perpetuate memorization. Furthermore, existing optimization methods often require computationally expensive property predictors or reinforcement learning loops; there is a practical need for more efficient &ldquo;prompt-based&rdquo; optimization techniques.</p>
<h2 id="core-innovations-architecture-and-pair-tuning">Core Innovations: Architecture and Pair-Tuning</h2>
<ol>
<li><strong>Architecture</strong>: The application of a linear-attention transformer decoder with Rotary Positional Embeddings (RoPE) to generative chemistry, allowing for efficient training on 1.1 billion SMILES.</li>
<li><strong>Pair-Tuning</strong>: A novel, parameter-efficient fine-tuning method that uses property-ordered molecular pairs to learn &ldquo;soft prompts&rdquo; for optimization without updating the base model weights.</li>
<li><strong>Scaling Analysis</strong>: An extensive empirical investigation mapping the trade-off between inference compute (up to 10B generations) and chemical novelty, fitting an exponential decay curve that demonstrates how novelty saturates as generation volume grows.</li>
</ol>
<h2 id="experimental-methodology-and-downstream-tasks">Experimental Methodology and Downstream Tasks</h2>
<p>The authors evaluated GP-MoLFormer on three distinct tasks, though the comparisons highlight the difficulty of evaluating foundation models against classical baselines:</p>
<ol>
<li><strong>De Novo Generation</strong>: Comparing validity, uniqueness, and novelty against baselines (CharRNN, VAE, <a href="/notes/chemistry/molecular-design/generation/latent-space/limo-latent-inceptionism/">LIMO</a>, MolGen-7b) on a held-out test set. Notably, this is an unequal comparison; most baselines were trained on the 1.6M molecule <a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> dataset, whereas GP-MoLFormer uses up to 1.1B molecules, meaning performance gains are heavily driven by data scale.</li>
<li><strong>Scaffold-Constrained Decoration</strong>: Generating molecules from DRD2 active binder scaffolds and measuring the hit rate of active compounds against specialized scaffold decorators.</li>
<li><strong>Property-Guided Optimization</strong>: Using Pair-tuning to optimize for Drug-likeness (QED), Penalized <a href="https://en.wikipedia.org/wiki/Octanol-water_partition_coefficient">logP</a>, and <a href="https://en.wikipedia.org/wiki/Dopamine_receptor_D2">DRD2</a> binding activity, comparing the results to graph-based and reinforcement learning benchmarks.</li>
</ol>
<p>Additionally, they performed a <strong>Scaling Study</strong>:</p>
<ul>
<li>Comparing models trained on raw (1.1B) vs. de-duplicated (650M) data.</li>
<li>Generating up to 10 billion molecules to fit empirical scaling laws for novelty.</li>
</ul>
<h2 id="key-findings-and-scaling-laws">Key Findings and Scaling Laws</h2>
<ul>
<li><strong>Scale-Driven Performance</strong>: GP-MoLFormer achieves high internal diversity and validity on generation metrics. However, its baseline novelty percentage (~32%) is considerably lower than that of classical models. The authors attribute this to the massive training scale pushing the model to prioritize matching real-world molecule frequencies over pure exploration. GP-MoLFormer&rsquo;s advantage in generation metrics over LLM baselines like <a href="/notes/chemistry/molecular-design/generation/autoregressive/molgen-molecular-generation-chemical-feedback/">MolGen</a>-7b likely stems largely from its 10x larger training dataset rather than from fundamental architectural superiority.</li>
<li><strong>Pair-Tuning Efficacy</strong>: The proposed pair-tuning method effectively optimizes properties (e.g., improving DRD2 activity scores) without requiring full model fine-tuning or external reward loops. While successful, the text-based generation yields ~94.5% validity during optimization, which lags behind graph and SELFIES-based baselines that guarantee 100% structural validity.</li>
<li><strong>Memorization vs. Novelty</strong>: Training on de-duplicated data (GP-MoLFormer-UNIQ) yields higher novelty (approx. 5-8% higher) than training on raw data, confirming that duplication bias in public databases leads directly to memorization.</li>
<li><strong>Inference Scaling Law</strong>: Novelty decays exponentially with generation size ($y = ae^{-bx}$), yet the model maintains generative capability (~16.7% novelty) even after generating an unprecedented 10 billion molecules.</li>
</ul>
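<p>An exponential decay of the form $y = ae^{-bx}$ can be fitted by ordinary least squares after taking logarithms, since $\log y = \log a - bx$ is linear in $x$. The sketch below uses synthetic data; the authors&rsquo; actual fitting procedure and the scale of $x$ are not specified in this summary, so both the method and the values here are illustrative assumptions.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def fit_exponential_decay(xs, ys):
    """Fit y = a * exp(-b * x) by least squares on log(y) = log(a) - b * x."""
    n = len(xs)
    logs = [math.log(y) for y in ys]
    mean_x = sum(xs) / n
    mean_log = sum(logs) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxl = sum((x - mean_x) * (value - mean_log) for x, value in zip(xs, logs))
    slope = sxl / sxx
    a = math.exp(mean_log - slope * mean_x)
    return a, -slope


# Synthetic points generated from a = 0.5, b = 0.3 (not data from the paper).
xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.5 * math.exp(-0.3 * x) for x in xs]
a, b = fit_exponential_decay(xs, ys)
print(round(a, 6), round(b, 6))
</code></pre></div>
<p>Fitting in log space weights relative errors equally across the curve; a direct nonlinear fit would instead emphasize the large-novelty region.</p>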
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Sources</strong>: A combination of <strong>PubChem</strong> (111M SMILES) and <strong>ZINC</strong> (1B SMILES) databases. Downloading and pre-training instructions are located in the repository&rsquo;s <code>data/README.md</code>.</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>All SMILES were canonicalized using RDKit (no isomeric information).</li>
<li><strong>GP-MoLFormer (Base)</strong>: Trained on the full 1.1B dataset (includes duplicates).</li>
<li><strong>GP-MoLFormer-UNIQ</strong>: Trained on a de-duplicated subset of 650M SMILES.</li>
</ul>
</li>
<li><strong>Tokenization</strong>: Uses the tokenizer from Schwaller et al. (2019) with a vocabulary size of <strong>2,362 tokens</strong>.</li>
<li><strong>Filtering</strong>: Sequences restricted to a maximum length of <strong>202 tokens</strong>.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pair-Tuning (Algorithm 1)</strong>:</p>
<ul>
<li><strong>Objective</strong>: Learn task-specific soft prompts $\phi_T$ to maximize the conditional probability of target molecule $b$ given a seed molecule $a$, where each pair $(a, b)$ is ordered so that $b$ has a higher value of the target property than $a$ (written $b &gt; a$). The base model parameters $\theta$ remain frozen.</li>
<li><strong>Prompt Structure</strong>: Autoregressive training optimizes the continuous embeddings of $n$ enhancement tokens against the cross-entropy loss of the target sequence:
$$ \mathcal{L}(\phi_T) = - \sum_{i=1}^{|b|} \log P_{\theta}(b_i | \phi_T, a, b_{&lt;i}) $$</li>
<li><strong>Hyperparameters</strong>: Trained for 1,000 epochs with a batch size of 35 and a fixed learning rate of $3 \times 10^{-2}$.</li>
<li><strong>Inference</strong>: The learned prompt $\phi_T$ and seed molecule $a$ are prepended as context, and candidates are sampled autoregressively until a termination token is produced.</li>
</ul>
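<p>The data side of pair-tuning and its loss can be sketched without a model. The example below assumes molecules come with a scalar property score and pairs every lower-scoring molecule with every higher-scoring one; whether the paper enumerates all such pairs is not stated here, so that choice, the function names, and the scores are hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math
from itertools import combinations


def property_ordered_pairs(molecules):
    """Build (seed, target) SMILES pairs where the target scores strictly higher."""
    ranked = sorted(molecules, key=lambda item: item[1])
    return [
        (low[0], high[0])
        for low, high in combinations(ranked, 2)
        if low[1] != high[1]  # drop ties so the ordering is strict
    ]


def pair_tuning_loss(target_token_probs):
    """Summed negative log-likelihood of the target tokens."""
    return -sum(math.log(p) for p in target_token_probs)


# Made-up property scores for three SMILES strings.
scored = [("CCO", 0.41), ("c1ccccc1", 0.44), ("CC(=O)O", 0.43)]
for seed, target in property_ordered_pairs(scored):
    print(seed, "to", target)

# Hypothetical per-token probabilities the frozen model assigns to a target.
print(pair_tuning_loss([0.8, 0.6, 0.9]))
</code></pre></div>
<p>In actual training only the prompt embeddings $\phi_T$ would receive gradients from this loss, while the probabilities themselves come from the frozen base model.</p>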
<h3 id="models">Models</h3>
<ul>
<li><strong>Availability</strong>: The model trained on deduplicated data (GP-MoLFormer-UNIQ) is publicly available on <a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">Hugging Face</a>. The full 1.1B base model is not explicitly hosted. The source code repository includes a disclosure that IBM will not maintain the code going forward.</li>
<li><strong>Architecture</strong>: Transformer decoder (~47M parameters: 12 layers, 12 heads, hidden size 768).</li>
<li><strong>Attention Mechanism</strong>: Combines Linear Attention (Generalized Random Feature map, $\phi$) with Rotary Positional Embeddings (RoPE). To avoid the quadratic complexity of standard attention while maintaining relative positional awareness, RoPE is applied to queries ($Q$) and keys ($K$) prior to the random feature mapping:
$$ \text{Attention}_m(Q, K, V) = \frac{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle v_n}{\sum_{n=1}^N \langle \phi(R_m q_m), \phi(R_n k_n) \rangle} $$</li>
<li><strong>Inference Speed</strong>: ~3ms per forward pass on a single A100 GPU.</li>
</ul>
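<p>To see why this form avoids quadratic cost, the inner product can be regrouped so that the sums over $n$ no longer depend on the query position $m$. The following is a sketch assuming a real-valued feature map $\phi(\cdot) \in \mathbb{R}^r$, column-vector values $v_n$, and the full sum over $n$ as written above; the symbols $S$ and $z$ are introduced here for illustration:</p>
<p>$$ S = \sum_{n=1}^N v_n \phi(R_n k_n)^\top, \qquad z = \sum_{n=1}^N \phi(R_n k_n), \qquad \text{output at position } m = \frac{S \phi(R_m q_m)}{z^\top \phi(R_m q_m)} $$</p>
<p>$S$ and $z$ are computed once in time linear in $N$ and reused for every query, so the total cost grows linearly with sequence length. In an autoregressive decoder, the same idea would use running prefix sums up to position $m$ in place of the full sums.</p>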
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Generation Quality Metrics</strong>: Validity, Uniqueness, Novelty (<a href="/notes/chemistry/molecular-design/generation/evaluation/molecular-sets-moses/">MOSES</a> suite), <a href="/notes/chemistry/molecular-design/generation/evaluation/frechet-chemnet-distance/">Fréchet ChemNet Distance (FCD)</a>, Scaffold similarity (Scaf), and Similarity to Nearest Neighbor (SNN).</li>
<li><strong>MoLFormer-Based Metrics</strong>: The authors introduce Fréchet <a href="/notes/chemistry/molecular-representations/encoders/molformer/">MoLFormer</a> Distance (FMD) and MoLFormer-space IntDiv2 to measure distributional similarity using their own pre-trained continuous embeddings instead of standard fingerprints.</li>
<li><strong>Optimization Metrics</strong>: Penalized logP (calculated as $\text{logP} - \text{SA} - \max(\text{maxrings}(\text{size}) - 6, 0)$), Drug-likeness (QED), and DRD2 activity scores.</li>
<li><strong>Scaling Metrics</strong>: Empirical fit for novelty decay: $y = ae^{-bx}$.</li>
</ul>
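<p>The penalized logP expression listed above combines three quantities that are normally computed by a cheminformatics toolkit. The sketch below takes them as plain numbers so the arithmetic is visible; the function name and input values are hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def penalized_logp(logp, sa_score, ring_sizes):
    """logP minus synthetic accessibility minus a penalty for rings above size 6."""
    largest_ring = max(ring_sizes, default=0)  # acyclic molecules get no penalty
    ring_penalty = max(largest_ring - 6, 0)
    return logp - sa_score - ring_penalty


# Hypothetical molecule: logP 2.5, SA score 3.0, one 6-ring and one 8-ring.
print(penalized_logp(2.5, 3.0, [6, 8]))
</code></pre></div>
<p>Only the largest ring matters for the penalty, so a molecule with several six-membered rings is not penalized while a single macrocycle is.</p>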
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 16 x NVIDIA A100 (80 GB) GPUs across 2 nodes connected via EDR Infiniband.</li>
<li><strong>Training Time</strong>:
<ul>
<li>GP-MoLFormer (1.1B data): ~115 hours total (28.75 hours/epoch for 4 epochs).</li>
<li>GP-MoLFormer-UNIQ (650M data): ~80 hours total.</li>
</ul>
</li>
<li><strong>Hyperparameters</strong>: Used a batch size of 1,600 molecules per GPU with a fixed learning rate of $1.6 \times 10^{-4}$ (scaled by a factor of up to $8\times$ as the number of GPUs increased).</li>
<li><strong>Optimization</strong>: Used distributed data-parallel training and adaptive bucketing by sequence length to handle scale.</li>
</ul>
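<p>The figures above imply the following totals, assuming all 16 GPUs each processed a full batch at every step and were in use for the whole base-model run (an assumption not confirmed in this summary):</p>
<p>$$ 16 \times 1{,}600 = 25{,}600\ \text{molecules per step}, \qquad 4 \times 28.75\ \text{h} = 115\ \text{h}, \qquad 16 \times 115\ \text{h} = 1{,}840\ \text{GPU-hours} $$</p>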
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/IBM/gp-molformer/">GP-MoLFormer (GitHub)</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official implementation; IBM will not maintain going forward</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/ibm-research/GP-MoLFormer-Uniq">GP-MoLFormer-Uniq (Hugging Face)</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained on 650M de-duplicated SMILES</td>
      </tr>
  </tbody>
</table>
<p>The full 1.1B base model weights are not publicly hosted. The training data (PubChem and ZINC) is publicly available, and instructions for downloading and pre-processing are in the repository&rsquo;s <code>data/README.md</code>.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ross, J., Belgodere, B., Hoffman, S. C., Chenthamarakshan, V., Navratil, J., Mroueh, Y., &amp; Das, P. (2025). GP-MoLFormer: A Foundation Model For Molecular Generation. <em>Digital Discovery</em>, 4(10), 2684&ndash;2696. <a href="https://doi.org/10.1039/D5DD00122F">https://doi.org/10.1039/D5DD00122F</a></p>
<p><strong>Publication</strong>: Digital Discovery, vol. 4, no. 10, pp. 2684&ndash;2696 (2025)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ross2025gpmolformer,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{GP-MoLFormer: a foundation model for molecular generation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ross, Jerret and Belgodere, Brian and Hoffman, Samuel C and Chenthamarakshan, Vijil and Navratil, Jiri and Mroueh, Youssef and Das, Payel}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{2684--2696}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D5DD00122F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa-2: Scaling Molecular Transformers to 77M</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-2/</link><pubDate>Thu, 25 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta-2/</guid><description>Optimizing transformer pretraining for molecules using MLM vs MTR objectives, scaling to 77M compounds from PubChem for improved property prediction.</description><content:encoded><![CDATA[<h2 id="classifying-chemberta-2s-methodological-contributions">Classifying ChemBERTa-2&rsquo;s Methodological Contributions</h2>
<p>This is primarily a <strong>Methodological</strong> paper with a secondary <strong>Resource</strong> contribution.</p>
<p>It fits the Method classification because it focuses on optimizing the architecture and pretraining pipeline for molecular transformers. The authors perform extensive ablation studies (varying dataset size from 5M to 77M, comparing MLM vs. MTR objectives) to determine &ldquo;how well&rdquo; these strategies work compared to baselines. The secondary Resource classification applies because they open-source the trained models and establish a benchmark on a massive 77M compound dataset.</p>
<p><strong>Key methodological indicators</strong>:</p>
<ul>
<li><strong>Baseline comparison</strong>: The paper explicitly compares ChemBERTa-2 against standard baselines (D-MPNN, Random Forest, GCN) and its predecessor (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa-1</a>) with prominent benchmark tables</li>
<li><strong>Ablation studies</strong>: Extensive experiments comparing multi-task and self-supervised pretraining by varying hyperparameters and pretraining dataset size</li>
<li><strong>Scaling analysis</strong>: Systematic investigation of whether larger datasets (up to 77M compounds) yield better performance</li>
</ul>
<h2 id="motivations-for-scaling-molecular-transformers">Motivations for Scaling Molecular Transformers</h2>
<p>The authors aim to bridge the gap between NLP success stories (like GPT-3) and molecular machine learning by developing a &ldquo;chemical foundation model&rdquo;.</p>
<p><strong>Key motivations</strong>:</p>
<ul>
<li><strong>Label scarcity</strong>: Experimental labels for molecular properties are rare and expensive, but unlabeled SMILES strings are abundant</li>
<li><strong>Scaling hypothesis</strong>: Testing if scaling pretraining data (up to 77M compounds) yields consistent downstream improvements, similar to scaling laws in NLP</li>
<li><strong>Efficiency</strong>: Optimizing the pretraining process introduced in the original ChemBERTa by comparing self-supervised (MLM) and weakly supervised (MTR, using <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> computed properties as labels) approaches</li>
</ul>
<h2 id="novelty-in-multi-task-regression-objectives">Novelty in Multi-Task Regression Objectives</h2>
<p><strong>Scale</strong>: Training on 77M unique SMILES from PubChem, which is one of the largest molecular pretraining datasets used to date (compared to 10M for ChemBERTa-1 or 18.7M for <a href="/notes/chemistry/molecular-representations/encoders/smiles-bert/">SMILES-BERT</a>).</p>
<p><strong>Pipeline optimization</strong>: A direct, controlled comparison of <strong>Masked Language Modeling (MLM)</strong> vs. <strong>Multi-Task Regression (MTR)</strong> pretraining objectives on identical datasets.</p>
<p><strong>Proxy selection</strong>: The finding that MLM loss correlates well with MTR loss, allowing the cheaper MLM task to be used for hyperparameter tuning before running the expensive MTR pretraining.</p>
<h2 id="experimental-pretraining-setup-on-77m-compounds">Experimental Pretraining Setup on 77M Compounds</h2>
<h3 id="pretraining-setup">Pretraining Setup</h3>
<p><strong>Datasets</strong>: Subsets of <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> containing 5M, 10M, and 77M unique SMILES.</p>
<p><strong>Tasks</strong>:</p>
<ul>
<li><strong>MLM</strong>: Masking 15% of tokens (following RoBERTa procedure). The model is optimized by minimizing the cross-entropy loss over the predicted masked tokens:
$$ \mathcal{L}_{MLM} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \mathbf{x}_{\setminus \mathcal{M}}) $$
where $\mathcal{M}$ represents the set of masked token indices.</li>
<li><strong>MTR</strong>: Predicting 200 calculated molecular properties (via RDKit) simultaneously using a mean squared error objective:
$$ \mathcal{L}_{MTR} = \frac{1}{200} \sum_{j=1}^{200} \frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_{ij} - y_{ij} \right)^2 $$
Continuous target labels $y_{ij}$ are mean-normalized prior to training to equilibrate the disparate scales of different chemical properties.</li>
</ul>
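<p>The two objectives above can be written out in a few lines. This is a minimal pure-Python sketch under stated assumptions, not the authors&rsquo; training code: the function names and toy inputs are illustrative, and the toy batch uses 3 properties per molecule where the paper uses 200.</p>

```python
import math


def mlm_loss(masked_token_probs):
    """Negative log-likelihood summed over the masked positions.

    Each entry is the probability the model assigned to the true token
    at one masked index.
    """
    return -sum(math.log(p) for p in masked_token_probs)


def mtr_loss(predictions, targets):
    """Squared error averaged over N molecules and all property columns."""
    n_molecules = len(targets)
    n_properties = len(targets[0])
    total = 0.0
    for pred_row, true_row in zip(predictions, targets):
        for y_hat, y in zip(pred_row, true_row):
            total += (y_hat - y) ** 2
    return total / (n_molecules * n_properties)


# Toy inputs (made up for illustration)
toy_probs = [0.5, 0.25, 0.125]
toy_pred = [[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]]
toy_true = [[0.0, 0.0, 2.0], [1.0, 3.0, 1.0]]

print(mlm_loss(toy_probs))
print(mtr_loss(toy_pred, toy_true))
```

<p>Dividing by the product of molecule count and property count is the same as the nested averages in the MTR formula, since every molecule carries the same number of properties.</p>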
<p><strong>Hyperparameter search</strong>: Ran 50 random configurations on the 5M dataset; selected the top 5 to scale up to 10M and 77M.</p>
<h3 id="downstream-validation">Downstream Validation</h3>
<p><strong>Finetuning</strong>: Evaluated on 8 tasks from <strong><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></strong> (BACE, BBBP, ClinTox, Delaney, etc.) using scaffold splits (80/10/10).</p>
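<p>A scaffold split keeps every molecule that shares a core scaffold in the same subset, so the test set probes generalization to unseen scaffolds. The sketch below is one common greedy way to do this, assuming the scaffold string for each molecule has already been computed (for example with a cheminformatics toolkit); <code>scaffold_split</code> and the placeholder scaffold labels are illustrative and not taken from the paper.</p>

```python
from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split molecule indices so that no scaffold spans two subsets."""
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)

    # Largest scaffold groups first; ties broken by first index for determinism
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    n_total = len(scaffolds)
    train_cutoff = frac_train * n_total
    valid_cutoff = (frac_train + frac_valid) * n_total

    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) <= train_cutoff:
            train.extend(group)
        elif len(train) + len(valid) + len(group) <= valid_cutoff:
            valid.extend(group)
        else:
            test.extend(group)
    return train, valid, test


# Placeholder scaffold labels for ten molecules
toy_scaffolds = ["A"] * 6 + ["B"] * 2 + ["C", "D"]
train_idx, valid_idx, test_idx = scaffold_split(toy_scaffolds)
print(len(train_idx), len(valid_idx), len(test_idx))
```

<p>Because whole groups are assigned at once, the realized fractions only approximate 80/10/10 when scaffold groups are large relative to the dataset.</p>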
<p><strong>Analysis</strong>: Used UMAP to visualize embeddings from MLM, MTR, and ECFP to check for clustering by label without finetuning.</p>
<h2 id="key-performance-outcomes-and-scaling-realities">Key Performance Outcomes and Scaling Realities</h2>
<p><strong>Highly competitive performance</strong>: ChemBERTa-2 outperforms the D-MPNN baseline (chemprop) on 6 out of 8 MoleculeNet tasks, though the margins are often narrow and D-MPNN remains a strong baseline.</p>
<p><strong>MTR superiority</strong>: Models pretrained with Multi-Task Regression (MTR) outperform their MLM-pretrained counterparts on every finetuning task evaluated. MTR pretraining is substantially slower than MLM because each example carries a 200-element label vector, but MLM loss serves as a reliable proxy for MTR loss, enabling cheaper architecture search before committing to full MTR pretraining.</p>
<p><strong>Scaling laws versus downstream utility</strong>: Pretraining loss improved by 25-35% when increasing the dataset from 5M to 77M compounds. However, this improvement in pretraining loss does not uniformly transfer to downstream tasks. For MTR models, SR-p53 ROC-AUC decreases monotonically from 0.834 (5M) to 0.827 (10M) to 0.817 (77M), and Lipophilicity RMSE is worse at 77M (0.798) than at 5M (0.758), despite a dip at 10M (0.744). This variability in transfer challenges the assumption that pretraining improvements always yield downstream gains.</p>
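<p>To put the quoted MTR numbers on a common scale, the relative change between the 5M and 77M runs can be computed directly. The snippet below only re-uses the figures in the paragraph above; the helper name is illustrative.</p>

```python
# MTR results quoted above, keyed by pretraining set size
sr_p53_auc = {"5M": 0.834, "10M": 0.827, "77M": 0.817}
lipo_rmse = {"5M": 0.758, "10M": 0.744, "77M": 0.798}


def relative_change(start, end):
    """Signed fractional change from start to end."""
    return (end - start) / start


auc_change = relative_change(sr_p53_auc["5M"], sr_p53_auc["77M"])
rmse_change = relative_change(lipo_rmse["5M"], lipo_rmse["77M"])

print(f"SR-p53 ROC-AUC, 5M to 77M: {auc_change:+.1%}")
print(f"Lipophilicity RMSE, 5M to 77M: {rmse_change:+.1%}")
```

<p>Both metrics move in the wrong direction at 77M (ROC-AUC down by about 2%, RMSE up by about 5%) even though pretraining loss improved by 25-35% over the same range.</p>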
<p><strong>Transfer learning</strong>: The correlation between pretraining loss and downstream performance is task-dependent; it is strong for Lipophilicity but weaker for BACE classification.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The pretraining corpus is derived from <strong>PubChem</strong>.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Pretraining</strong></td>
          <td>PubChem</td>
          <td>77M SMILES</td>
          <td>Canonicalized and globally shuffled. Subsets of 5M and 10M used. <strong>Note: Exact splits and datasets are not published.</strong></td>
      </tr>
      <tr>
          <td><strong>Validation</strong></td>
          <td>PubChem</td>
          <td>100k SMILES</td>
          <td>A fixed set held out from the 77M corpus. <strong>Note: Exact 100k subset is not published.</strong></td>
      </tr>
      <tr>
          <td><strong>MTR Labels</strong></td>
          <td>RDKit</td>
          <td>200 props</td>
          <td>200 molecular properties calculated from SMILES using RDKit. Labels are mean-normalized. <strong>Note: Calculated labels are not published and must be re-computed.</strong></td>
      </tr>
      <tr>
          <td><strong>Finetuning</strong></td>
          <td>MoleculeNet</td>
          <td>1.5k - 8k</td>
          <td>Tasks: BACE, Clearance, Delaney, Lipophilicity, BBBP, ClinTox, HIV, Tox21. Split 80/10/10 via scaffold splitter.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Pretraining Objectives:</strong></p>
<ol>
<li><strong>Masked Language Modeling (MLM)</strong>: Follows RoBERTa procedure. Masks 15% of tokens. Max sequence length 512.</li>
<li><strong>Multi-Task Regression (MTR)</strong>: Predicting 200 RDKit properties. Labels are mean-normalized.</li>
</ol>
<p><strong>Tokenizer:</strong></p>
<ul>
<li>Dictionary of common SMILES characters</li>
<li>Maximum vocabulary size: <strong>591 tokens</strong></li>
</ul>
<p><strong>Optimization:</strong></p>
<ul>
<li><strong>Patience</strong>: Early-stopping patience was set to one full pass through the dataset, so every model sees all of the training data at least once before training can stop</li>
<li><strong>Hyperparameter search</strong>: Random search (50 configs) varying hidden size, attention heads, dropout, intermediate size, hidden layers, and learning rate. <strong>Note: The precise configuration of the winning models that were scaled to 77M is absent from the paper.</strong></li>
</ul>
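<p>The search procedure (sample 50 configurations, keep the best 5) can be sketched as follows. The search dimensions are the ones listed above, but the candidate values in <code>SEARCH_SPACE</code> are hypothetical because the paper does not publish them, and <code>toy_validation_loss</code> is a stand-in for actually pretraining on the 5M subset.</p>

```python
import random

# Hypothetical candidate values: the paper names the dimensions, not these grids
SEARCH_SPACE = {
    "hidden_size": [128, 256, 384, 512],
    "num_attention_heads": [2, 4, 8],
    "num_hidden_layers": [2, 4, 6, 8],
    "intermediate_size": [256, 512, 1024, 2048],
    "hidden_dropout_prob": [0.0, 0.1, 0.2],
    "learning_rate": [1e-5, 5e-5, 1e-4, 5e-4],
}


def sample_config(rng):
    """Draw one value per hyperparameter, independently and uniformly."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}


def random_search(evaluate, n_configs=50, top_k=5, seed=0):
    """Evaluate n_configs random configurations and return the top_k by loss."""
    rng = random.Random(seed)
    scored = []
    for _ in range(n_configs):
        config = sample_config(rng)
        scored.append((evaluate(config), config))
    scored.sort(key=lambda pair: pair[0])  # lower validation loss is better
    return [config for _, config in scored[:top_k]]


def toy_validation_loss(config):
    """Stand-in objective; a real run would pretrain and read validation loss."""
    return abs(config["hidden_size"] - 384) / 384 + config["hidden_dropout_prob"]


finalists = random_search(toy_validation_loss)
for config in finalists:
    print(config)
```

<p>In the workflow described above, the five finalists would then be retrained on the 10M and 77M subsets.</p>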
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Based on <strong>RoBERTa</strong> (HuggingFace implementation)</li>
<li><strong>Parameter scale</strong>: Models ranged between <strong>5M and 46M parameters</strong></li>
<li><strong>Selection</strong>: Top 5 configurations from the 5M-dataset random search were trained on the full 77M dataset</li>
<li><strong>Checkpoints</strong>: Pre-trained weights are hosted by DeepChem on <a href="https://huggingface.co/DeepChem">Hugging Face</a>. Direct links include <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MTR">DeepChem/ChemBERTa-77M-MTR</a> and <a href="https://huggingface.co/DeepChem/ChemBERTa-77M-MLM">DeepChem/ChemBERTa-77M-MLM</a> (Note: Model cards are currently empty).</li>
<li><strong>Code Reference</strong>: The paper points to the <a href="https://github.com/deepchem/deepchem">DeepChem</a> repository for code, but standalone training scripts that reproduce ChemBERTa-2&rsquo;s exact pipeline are not provided separately from the general-purpose library tooling.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Benchmarks were performed on <strong>MoleculeNet</strong> using DeepChem.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Tasks</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>RMSE</strong> ($\downarrow$)</td>
          <td>Delaney, Lipo, BACE (Reg), Clearance</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 outperformed D-MPNN on Delaney (0.889 vs 1.105) and Clearance (48.5 vs 49.8).</td>
      </tr>
      <tr>
          <td><strong>ROC-AUC</strong> ($\uparrow$)</td>
          <td>BBBP, ClinTox, HIV, Tox21, BACE (Cls)</td>
          <td>D-MPNN</td>
          <td>ChemBERTa-2 generally competitive; MTR-77M achieved 0.728 on BBBP vs D-MPNN 0.697.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: AWS EC2 instances with <strong>Nvidia T4 GPUs</strong></li>
<li><strong>Strategy</strong>: AWS Spot instances were used to reduce cost; implemented frequent checkpointing to handle interruptions.</li>
<li><strong>Note</strong>: For MTR, the authors wrote a custom data loader wrapper around HuggingFace&rsquo;s text loader, because the default CSV loader was a major bottleneck when parsing the 200-element target vectors.</li>
</ul>
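<p>Training on interruptible Spot instances only works if a run can resume from its last checkpoint. The sketch below shows the general pattern with an atomic write and a simulated interruption; it is a generic illustration, not the authors&rsquo; setup, and the JSON state, function names, and <code>interrupt_at</code> parameter are all hypothetical.</p>

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    """Write to a temp file, then rename, so a crash never leaves a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as handle:
        json.dump(state, handle)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Return the saved state, or a fresh state if no checkpoint exists."""
    if not os.path.exists(path):
        return {"step": 0}
    with open(path) as handle:
        return json.load(handle)


def train(path, total_steps, checkpoint_every, interrupt_at=None):
    """Resume from the checkpoint at `path` and run until total_steps."""
    state = load_checkpoint(path)
    while state["step"] < total_steps:
        if interrupt_at is not None and state["step"] == interrupt_at:
            return state["step"]  # simulate the instance being reclaimed
        state["step"] += 1  # placeholder for one optimizer step
        if state["step"] % checkpoint_every == 0:
            save_checkpoint(path, state)
    save_checkpoint(path, state)
    return state["step"]
```

<p>The checkpoint interval bounds the work lost per interruption: with <code>checkpoint_every=3</code>, an interruption at step 7 resumes from step 6.</p>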
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ahmad, W., Simon, E., Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2022). ChemBERTa-2: Towards Chemical Foundation Models. <em>arXiv preprint arXiv:2209.01712</em>. <a href="https://doi.org/10.48550/arXiv.2209.01712">https://doi.org/10.48550/arXiv.2209.01712</a></p>
<p><strong>Publication</strong>: arXiv 2022 (Presented at 2021 ELLIS ML for Molecule Discovery Workshop)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa-1 Paper</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{ahmadChemBERTa2ChemicalFoundation2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}: {{Towards Chemical Foundation Models}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa-2}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2209.01712}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-25}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemformer: A Pre-trained Transformer for Comp Chem</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/chemformer/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-design/generation/autoregressive/chemformer/</guid><description>BART-based Transformer pre-trained on 100M molecules using self-supervision to accelerate convergence on chemical sequence tasks.</description><content:encoded><![CDATA[<h2 id="paper-contribution-and-methodological-classification">Paper Contribution and Methodological Classification</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It proposes an architecture adaptation (Chemformer based on BART) and a specific pre-training strategy (&ldquo;Combined&rdquo; masking and augmentation). The paper validates this method by benchmarking against established models on multiple tasks, including direct synthesis, retrosynthesis, and molecular optimization. It also includes a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution by making the pre-trained models and code available.</p>
<h2 id="motivation-computational-bottlenecks-in-cheminformatics">Motivation: Computational Bottlenecks in Cheminformatics</h2>
<p>Existing Transformer models for cheminformatics are often developed for single applications and are computationally expensive to train from scratch. For example, training a Molecular Transformer for reaction prediction can take days, limiting hyperparameter exploration. Self-supervised pre-training (like BERT or T5) has significantly advanced NLP by reducing fine-tuning time and improving performance. In chemistry, applications have traditionally focused on task-specific datasets or encoder-only architectures, which perform poorly on sequence generation tasks. The authors aim to use transfer learning on a large unlabelled dataset to create a model that converges quickly and performs well across diverse sequence-to-sequence and discriminative tasks.</p>
<h2 id="core-innovation-bart-architecture-and-combined-pre-training">Core Innovation: BART Architecture and Combined Pre-training</h2>
<p>The primary insight lies in the adaptation of the <strong>BART architecture</strong> for chemistry and the introduction of a <strong>&ldquo;Combined&rdquo; self-supervised pre-training task</strong>.</p>
<ul>
<li><strong>Architecture</strong>: Chemformer uses the BART encoder-decoder structure, allowing it to handle both discriminative (property prediction) and generative (reaction prediction) tasks efficiently. This provides an alternative to encoder-only (BERT) or decoder-only (GPT) models.</li>
<li><strong>Combined Pre-training</strong>: The authors introduce a task that applies both <strong>Span Masking</strong> (replacing randomly chosen spans of tokens with a single <code>&lt;mask&gt;</code> token) and <strong><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> Augmentation</strong> (permuting atom order, see <a href="/notes/chemistry/molecular-representations/notations/randomized-smiles-generative-models/">Randomized SMILES</a>) simultaneously. Formally, given a canonical SMILES sequence $x$, a corrupted sequence $\tilde{x} = \text{Mask}(\text{Augment}(x))$ is generated. The model is trained using an autoregressive cross-entropy loss to reconstruct the canonical sequence from the corrupted input:
$$ \mathcal{L}_{\text{pre-train}} = -\sum_{t=1}^{|x|} \log P(x_t \mid x_{&lt;t}, \tilde{x}) $$</li>
<li><strong>Tunable Augmentation</strong>: A downstream augmentation strategy is proposed where the probability of augmenting the input/output SMILES ($p_{aug}$) is a tunable hyperparameter, performed on-the-fly.</li>
</ul>
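<p>The composition $\tilde{x} = \text{Mask}(\text{Augment}(x))$ can be sketched on token lists. This is a simplified illustration under stated assumptions: <code>mask_prob</code> and <code>max_span</code> are made-up values rather than the paper&rsquo;s span-sampling settings, and <code>augment</code> is a hypothetical callable standing in for a real SMILES randomizer, which needs a cheminformatics toolkit.</p>

```python
import random

MASK = "<mask>"


def span_mask(tokens, rng, mask_prob=0.1, max_span=3):
    """Replace random short spans of tokens with a single mask token each."""
    corrupted = []
    i = 0
    while i < len(tokens):
        if rng.random() < mask_prob:
            corrupted.append(MASK)
            i += rng.randint(1, max_span)  # skip the whole masked span
        else:
            corrupted.append(tokens[i])
            i += 1
    return corrupted


def combined_corruption(canonical_tokens, augment, rng, p_aug=1.0):
    """Return (corrupted input, reconstruction target) for one molecule."""
    tokens = list(canonical_tokens)
    if rng.random() < p_aug:  # p_aug plays the role of the tunable probability
        tokens = augment(tokens, rng)
    return span_mask(tokens, rng), list(canonical_tokens)


def toy_augment(tokens, rng):
    """Toy randomizer: reversing is only valid for simple linear chains."""
    return list(reversed(tokens))


rng = random.Random(0)
ethanol = ["C", "C", "O"]
corrupted, target = combined_corruption(ethanol, toy_augment, rng)
print(corrupted, "->", target)
```

<p>The target is always the canonical sequence, so the model learns to undo both the reordering and the masking in one pass.</p>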
<h2 id="experimental-setup-and-pre-training-tasks">Experimental Setup and Pre-training Tasks</h2>
<p>The authors pre-trained Chemformer on <strong>100 million molecules</strong> from ZINC-15 and fine-tuned it on three distinct task types:</p>
<ol>
<li><strong>Seq2Seq Reaction Prediction</strong>:
<ul>
<li><em>Direct Synthesis</em>: USPTO-MIT dataset (Mixed and Separated).</li>
<li><em><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></em>: USPTO-50K dataset (see also <a href="/notes/chemistry/molecular-design/reaction-prediction/molecular-transformer/">Molecular Transformer</a>, <a href="/notes/chemistry/molecular-design/reaction-prediction/tied-two-way-transformers-retrosynthesis/">Tied Two-Way Transformers</a>).</li>
</ul>
</li>
<li><strong>Molecular Optimization</strong>: Generating molecules with improved properties (<a href="https://en.wikipedia.org/wiki/Distribution_coefficient">LogD</a>, solubility, clearance) starting from ChEMBL matched molecular pairs.</li>
<li><strong>Discriminative Tasks</strong>:
<ul>
<li><em><a href="https://en.wikipedia.org/wiki/Quantitative_structure%E2%80%93activity_relationship">QSAR</a></em>: Predicting properties (ESOL, FreeSolv, Lipophilicity) from <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>.</li>
<li><em>Bioactivity</em>: Predicting pXC50 values for 133 genes using ExCAPE data.</li>
</ul>
</li>
</ol>
<p>Ablation studies compared three pre-training strategies (Masking, Augmentation, Combined) against a randomly initialized baseline.</p>
<h2 id="results-trade-offs-and-conclusions">Results, Trade-offs, and Conclusions</h2>
<ul>
<li><strong>Performance</strong>: Chemformer achieved <strong>competitive top-1 accuracy</strong> on USPTO-MIT (91.3% Mixed) and USPTO-50K (53.6-54.3%), outperforming the Augmented Transformer and graph-based models (GLN, GraphRetro).</li>
<li><strong>Convergence Speed</strong>: Pre-training significantly accelerated training; fine-tuning for just 20 epochs (about 30 minutes) outperformed previous baselines that had been trained for much longer.</li>
<li><strong>Pre-training Tasks</strong>: The &ldquo;Combined&rdquo; task generally performed best for reaction prediction and bioactivity, while &ldquo;Masking&rdquo; was superior for molecular optimization.</li>
<li><strong>Augmentation Trade-off</strong>: The augmentation strategy improved top-1 accuracy but significantly degraded top-5/10 accuracy because beam search outputs became populated with augmented versions of the same molecule. This presents a considerable limitation for practical applications like retrosynthesis mapping, where retrieving a diverse set of candidate reactions is often critical.</li>
<li><strong>Discriminative Evaluation Caveats</strong>: Chemformer underperformed specialized baselines (like D-MPNN or <a href="/notes/chemistry/molecular-representations/encoders/molbert-molecular-representations/">MolBERT</a>) on small discriminative datasets. The authors note that direct comparison is difficult: Chemformer was trained simultaneously on multiple subtasks (multi-task learning), while the literature baselines were trained and tuned on each subtask separately. Additionally, the Chemformer encoder uses fewer than 20M parameters compared to MolBERT&rsquo;s approximately 85M, and Chemformer&rsquo;s pre-training does not include molecular property objectives. For other transfer learning approaches to QSAR, see <a href="/notes/chemistry/molecular-design/property-prediction/molpmofit-transfer-learning-qsar/">MolPMoFiT</a>.</li>
<li><strong>Pre-training Data Scope</strong>: The 100M pre-training dataset from ZINC-15 was selected with constraints on molecular weight ($\le 500$ Da) and LogP ($\le 5$), focusing the learned representations on small, drug-like molecules.</li>
</ul>
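<p>The augmentation trade-off in the list above is easy to see with a toy top-N calculation. The sketch assumes beam outputs have already been canonicalized, so different SMILES strings of the same molecule collapse to one label; the single-letter molecule labels and beam contents are invented for illustration.</p>

```python
def top_k_accuracy(beams, targets, k):
    """Fraction of examples whose target appears in the first k beam outputs."""
    hits = sum(target in beam[:k] for beam, target in zip(beams, targets))
    return hits / len(targets)


def unique_candidates(beam):
    """Distinct molecules in a beam, preserving rank order."""
    return list(dict.fromkeys(beam))


targets = ["A", "C"]

# Without augmentation: diverse beams, correct answer ranked second
plain_beams = [["B", "A", "D"], ["B", "C", "D"]]

# With augmentation: each beam is one molecule written three different ways
augmented_beams = [["A", "A", "A"], ["B", "B", "B"]]

for name, beams in [("plain", plain_beams), ("augmented", augmented_beams)]:
    print(name, top_k_accuracy(beams, targets, 1), top_k_accuracy(beams, targets, 3))
```

<p>In this toy case the augmented beams win at top-1 but lose at top-3, because a beam filled with duplicates offers only one distinct candidate.</p>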
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><em>Note: The primary GitHub repository for Chemformer was officially archived on February 11, 2026. Pre-trained weights and datasets used in the paper are still hosted externally on <a href="https://az.app.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Box</a>. Active development of Chemformer models has moved to the <a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels</a> repository.</em></p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/Chemformer">Chemformer (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Archived; original PyTorch implementation</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/MolecularAI/aizynthmodels">AiZynthModels (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Active successor repository</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://az.app.box.com/s/7eci3nd9vy0xplqniitpk02rbg9q2zcq">Pre-trained weights (Box)</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Base and Large model checkpoints</td>
      </tr>
  </tbody>
</table>
<p>The following datasets were used for pre-training and benchmarking.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Pre-training</strong></td>
          <td style="text-align: left">ZINC-15</td>
          <td style="text-align: left">100M</td>
          <td style="text-align: left">Selected subset (reactive, annotated purchasability, MW $\le 500$, LogP $\le 5$). Split: 99% Train / 0.5% Val / 0.5% Test.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Direct Synthesis</strong></td>
          <td style="text-align: left">USPTO-MIT</td>
          <td style="text-align: left">~470k</td>
          <td style="text-align: left">Evaluated on &ldquo;Mixed&rdquo; and &ldquo;Separated&rdquo; variants.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Retrosynthesis</strong></td>
          <td style="text-align: left">USPTO-50K</td>
          <td style="text-align: left">~50k</td>
          <td style="text-align: left">Standard benchmark for retrosynthesis.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Optimization</strong></td>
          <td style="text-align: left">ChEMBL MMPs</td>
          <td style="text-align: left">~160k Train</td>
          <td style="text-align: left">Matched Molecular Pairs for LogD, solubility, and clearance optimization.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Properties</strong></td>
          <td style="text-align: left">MoleculeNet</td>
          <td style="text-align: left">Small</td>
          <td style="text-align: left">ESOL (1128), FreeSolv (642), Lipophilicity (4200).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Bioactivity</strong></td>
          <td style="text-align: left">ExCAPE</td>
          <td style="text-align: left">~312k</td>
          <td style="text-align: left">133 gene targets; &gt;1200 compounds per gene.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Tokenization</strong>: Regex-based tokenization (523 tokens total) derived from ChEMBL 27 canonical SMILES.</li>
<li><strong>Augmentation</strong>: SMILES enumeration (permuting atom order) used for pre-training and on-the-fly during fine-tuning ($p_{aug}=0.5$ for Seq2Seq, $p_{aug}=1.0$ for discriminative).</li>
</ul>
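<p>Regex-based SMILES tokenization splits a string into chemically meaningful units such as bracket atoms and two-letter elements. The pattern below is a widely used one from the reaction-prediction literature and is shown as an assumption; the exact expression behind Chemformer&rsquo;s 523-token vocabulary may differ.</p>

```python
import re

# Commonly used SMILES pattern (assumed; not necessarily Chemformer's exact regex)
SMILES_PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)


def tokenize_smiles(smiles):
    """Split a SMILES string into tokens, failing loudly on unknown characters."""
    tokens = SMILES_PATTERN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens


print(tokenize_smiles("C[C@@H](Cl)Br"))
print(tokenize_smiles("CC(=O)Oc1ccccc1C(=O)O"))
```

<p>Checking that the tokens rejoin to the input catches characters the pattern silently skips, which would otherwise corrupt the training data.</p>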
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pre-training Tasks</strong>:
<ol>
<li><em>Masking</em>: Span masking (BART style).</li>
<li><em>Augmentation</em>: Input is a randomized SMILES; target is canonical SMILES.</li>
<li><em>Combined</em>: Input is augmented <em>then</em> masked; target is canonical SMILES.</li>
</ol>
</li>
<li><strong>Optimization</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$).</li>
<li>Schedule: Linear warm-up (8000 steps) for pre-training; One-cycle schedule for fine-tuning.</li>
</ul>
</li>
<li><strong>Inference</strong>: <a href="https://en.wikipedia.org/wiki/Beam_search">Beam search</a> with width 10 for Seq2Seq tasks. Used <code>molbart/inference_score.py</code> and <code>molbart/retrosynthesis/round_trip_inference.py</code> for standard and round-trip validation.</li>
</ul>
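<p>Beam search keeps the highest-scoring partial sequences at each decoding step. The following is a generic, minimal sketch with the width of 10 mentioned above, not the <code>molbart</code> implementation: <code>toy_next_log_probs</code> is a made-up stand-in for the decoder, and <code>BOS</code>/<code>EOS</code> are placeholder start and end markers.</p>

```python
import math

BOS, EOS = "BOS", "EOS"  # placeholder start/end markers


def beam_search(next_log_probs, width=10, max_len=20):
    """Return up to `width` finished sequences, best cumulative log-prob first."""
    beams = [(0.0, [BOS])]
    finished = []
    for _ in range(max_len):
        # Extend every live beam by every possible next token
        candidates = [
            (score + log_p, seq + [token])
            for score, seq in beams
            for token, log_p in next_log_probs(seq).items()
        ]
        candidates.sort(key=lambda item: item[0], reverse=True)
        beams = []
        for score, seq in candidates[:width]:
            if seq[-1] == EOS:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
        if not beams:
            break
    finished.sort(key=lambda item: item[0], reverse=True)
    return finished[:width]


def toy_next_log_probs(seq):
    """Made-up model: prefers C over O and must stop after two atoms."""
    if len(seq) >= 3:
        return {EOS: 0.0}
    return {"C": math.log(0.6), "O": math.log(0.3), EOS: math.log(0.1)}


results = beam_search(toy_next_log_probs)
for score, seq in results:
    print(round(math.exp(score), 3), seq)
```

<p>Ranking by raw cumulative log-probability favors shorter sequences; production decoders often add length normalization, which is omitted here for brevity.</p>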
<h3 id="models">Models</h3>
<p>Two model sizes were trained. Both use the Pre-Norm Transformer layout with GELU activation.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hyperparameter</th>
          <th style="text-align: left">Chemformer (Base)</th>
          <th style="text-align: left">Chemformer-Large</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Layers</strong></td>
          <td style="text-align: left">6</td>
          <td style="text-align: left">8</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Model Dimension</strong></td>
          <td style="text-align: left">512</td>
          <td style="text-align: left">1024</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Feed-forward Dim</strong></td>
          <td style="text-align: left">2048</td>
          <td style="text-align: left">4096</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Attention Heads</strong></td>
          <td style="text-align: left">8</td>
          <td style="text-align: left">16</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Parameters</strong></td>
          <td style="text-align: left">~45M</td>
          <td style="text-align: left">~230M</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Pre-training Task</strong></td>
          <td style="text-align: left">All 3 variants</td>
          <td style="text-align: left">Combined only</td>
      </tr>
  </tbody>
</table>
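<p>The parameter counts in the table can be sanity-checked from the other rows. The estimate below assumes the layer count applies to both the encoder and the decoder stack, and it counts only attention and feed-forward weight matrices, ignoring biases, layer norms, and embeddings; the function name is illustrative.</p>

```python
def transformer_weight_count(layers, d_model, d_ff):
    """Approximate (encoder, decoder) weight counts for an encoder-decoder model."""
    attention = 4 * d_model * d_model      # Q, K, V and output projections
    feed_forward = 2 * d_model * d_ff      # two linear maps
    encoder_layer = attention + feed_forward
    decoder_layer = 2 * attention + feed_forward  # self- plus cross-attention
    return layers * encoder_layer, layers * decoder_layer


base_enc, base_dec = transformer_weight_count(layers=6, d_model=512, d_ff=2048)
large_enc, large_dec = transformer_weight_count(layers=8, d_model=1024, d_ff=4096)

print(f"Base:  {(base_enc + base_dec) / 1e6:.1f}M (encoder {base_enc / 1e6:.1f}M)")
print(f"Large: {(large_enc + large_dec) / 1e6:.1f}M")
```

<p>This gives roughly 44M weights for Base and 235M for Large, in line with the ~45M and ~230M in the table, and a Base encoder of about 18.9M weights.</p>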
<h3 id="evaluation">Evaluation</h3>
<p>Comparisons relied on top-N accuracy for reaction tasks, the percentage of generated molecules with desirable properties for molecular optimization, and RMSE for property prediction.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Task</th>
          <th style="text-align: left">Key Result</th>
          <th style="text-align: left">Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Direct Synthesis (Sep)</td>
          <td style="text-align: left"><strong>92.8%</strong> (Large)</td>
          <td style="text-align: left">91.1% (Aug Transformer)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Top-1 Acc</strong></td>
          <td style="text-align: left">Retrosynthesis</td>
          <td style="text-align: left"><strong>54.3%</strong> (Large)</td>
          <td style="text-align: left">53.7% (GraphRetro) / 52.5% (GLN)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Desirable %</strong></td>
          <td style="text-align: left">Mol Optimization</td>
          <td style="text-align: left"><strong>75.0%</strong> (Base-Mask)</td>
          <td style="text-align: left">70.2% (Transformer-R)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>RMSE</strong></td>
          <td style="text-align: left">Lipophilicity</td>
          <td style="text-align: left">0.598 (Combined)</td>
          <td style="text-align: left">0.555 (D-MPNN)</td>
      </tr>
  </tbody>
</table>
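<p>Top-N accuracy counts an example as correct when the reference structure appears anywhere among the model&rsquo;s N highest-ranked candidates. The following is a minimal sketch, assuming the candidates are SMILES strings sorted best-first; the function name and toy data are hypothetical and not taken from the Chemformer code.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def top_n_accuracy(ranked_candidates, references, n):
    """Fraction of examples whose reference is among the first n candidates."""
    hits = 0
    for candidates, reference in zip(ranked_candidates, references):
        if reference in candidates[:n]:
            hits += 1
    return hits / len(references)


# Toy beam-search outputs (best candidate first) for three examples.
ranked = [
    ["CCO", "CCN", "CCC"],
    ["c1ccccc1", "C1CCCCC1", "CC"],
    ["CC(=O)O", "CC=O", "CCO"],
]
# "OCC" and "CCO" denote the same molecule but are different strings,
# so without canonicalization the third example counts as a miss.
truth = ["CCO", "C1CCCCC1", "OCC"]

top1 = top_n_accuracy(ranked, truth, 1)
top3 = top_n_accuracy(ranked, truth, 3)
</code></pre></div>
<p>Because the comparison is exact string matching, predictions and references would normally be canonicalized with the same toolkit before scoring.</p>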
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 NVIDIA V100 GPUs (batch size 128 per GPU).</li>
<li><strong>Training Time</strong>:
<ul>
<li>Pre-training: 2.5 days (Base) / 6 days (Large) for 1M steps.</li>
<li>Fine-tuning: ~20-40 epochs for reaction prediction (&lt;12 hours).</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Irwin, R., Dimitriadis, S., He, J., &amp; Bjerrum, E. J. (2022). Chemformer: a pre-trained transformer for computational chemistry. <em>Machine Learning: Science and Technology</em>, 3(1), 015022. <a href="https://doi.org/10.1088/2632-2153/ac3ffb">https://doi.org/10.1088/2632-2153/ac3ffb</a></p>
<p><strong>Publication</strong>: Machine Learning: Science and Technology 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{irwinChemformerPretrainedTransformer2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Chemformer: A Pre-Trained Transformer for Computational Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Chemformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Irwin, Ross and Dimitriadis, Spyridon and He, Jiazhen and Bjerrum, Esben Jannik}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Machine Learning: Science and Technology}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{015022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{IOP Publishing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2632-2153}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/2632-2153/ac3ffb}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemBERTa: Molecular Property Prediction via Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta/</link><pubDate>Tue, 23 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/encoders/chemberta/</guid><description>A systematic evaluation of RoBERTa transformers pretrained on 77M PubChem SMILES for molecular property prediction tasks.</description><content:encoded><![CDATA[<h2 id="taxonomy-and-paper-contributions">Taxonomy and Paper Contributions</h2>
<p>This is primarily a <strong>Method</strong> paper ($\Psi_{\text{Method}}$), with a significant <strong>Resource</strong> component ($\Psi_{\text{Resource}}$).</p>
<p>It is a methodological investigation because it systematically evaluates a specific architecture (Transformers/RoBERTa) against established State-of-the-Art (SOTA) baselines like directed Message Passing Neural Networks (D-MPNNs) to determine &ldquo;how well does this work?&rdquo; in the chemical domain. It ablates dataset size, tokenization, and input representation.</p>
<p>It is also a resource paper as it introduces &ldquo;PubChem-77M,&rdquo; a curated dataset of 77 million SMILES strings designed to facilitate large-scale self-supervised pretraining for the community.</p>
<h2 id="overcoming-data-scarcity-in-property-prediction">Overcoming Data Scarcity in Property Prediction</h2>
<p>The primary motivation is <strong>data scarcity</strong> in molecular property prediction. Graph Neural Networks (GNNs) achieve strong performance on property prediction tasks when provided with sufficient labeled data. Generating these labels requires costly and time-consuming laboratory testing, leading to severe data scarcity in specialized chemical domains.</p>
<p>Massive quantities of <strong>unlabeled chemical structure data</strong> exist in the form of <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings. Inspired by the success of Transformers in NLP, where self-supervised pretraining on large corpora yields strong transfer learning, the authors aim to use these unlabeled datasets to learn effective molecular representations. Additionally, Transformers benefit from a mature software ecosystem (HuggingFace) that offers efficiency advantages over GNNs.</p>
<h2 id="pretraining-scaling-laws-and-novelty">Pretraining Scaling Laws and Novelty</h2>
<p>Previous works applied Transformers to SMILES strings. This paper advances the field by systematically evaluating scaling laws and architectural components for this domain. Specifically:</p>
<ul>
<li><strong>Scaling Analysis</strong>: It explicitly tests how pretraining dataset size (100K to 10M) impacts downstream performance.</li>
<li><strong>Tokenizer Comparison</strong>: It compares standard NLP <a href="https://en.wikipedia.org/wiki/Byte-pair_encoding">Byte-Pair Encoding (BPE)</a> against a chemically-aware &ldquo;SmilesTokenizer&rdquo;.</li>
<li><strong>Representation Comparison</strong>: It evaluates whether the robust <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a> string representation offers advantages over standard SMILES in a Transformer context.</li>
</ul>
<h2 id="experimental-setup-pretraining-and-finetuning">Experimental Setup: Pretraining and Finetuning</h2>
<p>The authors trained <strong>ChemBERTa</strong> (based on RoBERTa) using Masked Language Modeling (MLM) on subsets of the <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> dataset. The training objective minimizes the cross-entropy loss over a corrupted input in which a subset of token positions, denoted by $\mathcal{M}$, is masked:</p>
<p>$$
\mathcal{L}_{\text{MLM}} = - \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} \log P(x_i \mid x_{\setminus \mathcal{M}}; \theta)
$$</p>
<p>where $x_i$ is the original token at masked position $i$, $x_{\setminus \mathcal{M}}$ is the corrupted SMILES string that serves as context, and $\theta$ represents the network parameters.</p>
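<p>As a small worked instance of this loss, the sketch below averages the negative log-probability of the true token over the masked positions only. The probabilities are made-up numbers standing in for model outputs, and <code>mlm_loss</code> is a hypothetical helper name.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math


def mlm_loss(token_probs, masked_positions):
    """Average negative log-likelihood of the true tokens at masked positions.

    token_probs[i] is the probability the model assigns to the true token x_i
    given the corrupted context; masked_positions is the index set M.
    """
    total = sum(-math.log(token_probs[i]) for i in masked_positions)
    return total / len(masked_positions)


# Hypothetical probabilities for the true tokens of a 6-token SMILES string.
probs = [0.9, 0.5, 0.8, 0.25, 0.7, 0.6]
masked = [1, 3]  # positions that were replaced by the mask token

# Only positions 1 and 3 contribute: (-ln 0.5 - ln 0.25) / 2 = 1.5 * ln 2.
loss = mlm_loss(probs, masked)
</code></pre></div>
<p>Unmasked positions do not enter the sum, which is why the normalizer is $|\mathcal{M}|$ and not the sequence length.</p>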
<ul>
<li><strong>Pretraining</strong>: Models were pretrained on dataset sizes of 100K, 250K, 1M, and 10M compounds.</li>
<li><strong>Baselines</strong>: Performance was compared against D-MPNN (Graph Neural Network), Random Forest (RF), and SVM using 2048-bit Morgan Fingerprints.</li>
<li><strong>Downstream Tasks</strong>: Finetuning was performed individually on small <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a> classification tasks: BBBP (<a href="https://en.wikipedia.org/wiki/Blood%E2%80%93brain_barrier">blood-brain barrier</a>), ClinTox (clinical toxicity), HIV, and Tox21 (p53 stress-response). This poses a transfer learning challenge, as the model must adapt from pretraining on 10 million molecules to classifying datasets ranging from ~1.5K to ~41K examples.</li>
<li><strong>Ablations</strong>:
<ul>
<li><strong>Tokenization</strong>: BPE vs. SmilesTokenizer on the 1M dataset, evaluated on Tox21.</li>
<li><strong>Input</strong>: SMILES vs. SELFIES strings on the Tox21 task.</li>
</ul>
</li>
</ul>
<h2 id="results-vs-graph-neural-network-baselines">Results vs. Graph Neural Network Baselines</h2>
<p>The main comparison between ChemBERTa (pretrained on 10M compounds) and Chemprop baselines on MoleculeNet tasks is summarized below (Table 1 from the paper):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>BBBP ROC</th>
          <th>BBBP PRC</th>
          <th>ClinTox ROC</th>
          <th>ClinTox PRC</th>
          <th>HIV ROC</th>
          <th>HIV PRC</th>
          <th>Tox21 ROC</th>
          <th>Tox21 PRC</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ChemBERTa 10M</td>
          <td>0.643</td>
          <td>0.620</td>
          <td>0.733</td>
          <td>0.975</td>
          <td>0.622</td>
          <td>0.119</td>
          <td>0.728</td>
          <td>0.207</td>
      </tr>
      <tr>
          <td>D-MPNN</td>
          <td>0.708</td>
          <td>0.697</td>
          <td>0.906</td>
          <td>0.993</td>
          <td>0.752</td>
          <td>0.152</td>
          <td>0.688</td>
          <td>0.429</td>
      </tr>
      <tr>
          <td>RF</td>
          <td>0.681</td>
          <td>0.692</td>
          <td>0.693</td>
          <td>0.968</td>
          <td>0.780</td>
          <td>0.383</td>
          <td>0.724</td>
          <td>0.335</td>
      </tr>
      <tr>
          <td>SVM</td>
          <td>0.702</td>
          <td>0.724</td>
          <td>0.833</td>
          <td>0.986</td>
          <td>0.763</td>
          <td>0.364</td>
          <td>0.708</td>
          <td>0.345</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Scaling Improvements &amp; Training Dynamics</strong>: Performance scales predictably with pretraining data size. Increasing data from 100K to 10M improved ROC-AUC by +0.110 and PRC-AUC by +0.059 on average across BBBP, ClinTox, and Tox21 (HIV was omitted due to resource constraints). Notably, the authors had to halt pretraining on the 10M subset after just 3 epochs due to overfitting, suggesting that simple 15% token masking may not be a sufficiently difficult pretraining task for large-scale chemical representation learning.</li>
<li><strong>Performance Limits vs. GNNs</strong>: ChemBERTa generally performs below the D-MPNN baseline. On the Tox21 dataset, ChemBERTa-10M achieved a higher ROC-AUC (0.728) than D-MPNN (0.688); nonetheless, it recorded a substantially lower PRC-AUC (0.207 vs 0.429). This gap indicates that current Transformer iterations lack the explicit inductive biases of graph algorithms and struggle with the severe class imbalances typical of chemical datasets.</li>
<li><strong>Ablation Limitations (Tokenization &amp; SELFIES)</strong>: The authors&rsquo; ablation studies for tokenization (SmilesTokenizer narrowly beating BPE) and input representation (SELFIES performing comparably to SMILES) were evaluated exclusively on the single Tox21 task. Deriving broad architectural conclusions regarding &ldquo;semantically-aware tokenization&rdquo; or string robustness from an $N=1$ empirical evaluation is a significant limitation of the study. Broader benchmarking is required to validate these findings.</li>
<li><strong>Interpretability</strong>: Attention heads organically learn to track chemically relevant substructures (like specific functional groups and aromatic rings), mimicking the inductive biases of graph convolutions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors curated a massive dataset for pretraining and utilized standard benchmarks for evaluation.</p>
<ul>
<li><strong>Pretraining Data</strong>: <strong>PubChem-77M</strong>.
<ul>
<li>Source: 77 million unique SMILES from PubChem.</li>
<li>Preprocessing: Canonicalized and globally shuffled.</li>
<li>Subsets used: 100K, 250K, 1M, and 10M subsets.</li>
<li><em>Availability Note</em>: The authors provided a direct link to the <a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">canonicalized 10M compound subset</a> used for their largest experiments. Reproducing the smaller subsets (100K, 250K, 1M) or the full 77M set may require re-extracting the data from PubChem.</li>
</ul>
</li>
<li><strong>Evaluation Data</strong>: <strong>MoleculeNet</strong>.
<ul>
<li>Tasks: BBBP (2,039), ClinTox (1,478), HIV (41,127), Tox21 (7,831).</li>
<li>Splitting: 80/10/10 train/valid/test split using a <strong>scaffold splitter</strong> to ensure chemical diversity between splits.</li>
</ul>
</li>
</ul>
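<p>A scaffold split keeps all molecules that share a core scaffold in the same subset, so the test set contains scaffolds never seen in training. The sketch below shows one common greedy variant (largest scaffold groups assigned first); it assumes the scaffold keys were computed beforehand, for example as Bemis-Murcko scaffold SMILES, and the function name and toy labels are hypothetical. It is not claimed to reproduce the exact splitter used in the paper.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from collections import defaultdict


def scaffold_split(scaffolds, frac_train=0.8, frac_valid=0.1):
    """Split indices so that no scaffold is shared between subsets.

    scaffolds[i] is a precomputed scaffold key for molecule i; computing
    the key itself is outside the scope of this sketch.
    """
    groups = defaultdict(list)
    for index, key in enumerate(scaffolds):
        groups[key].append(index)

    # Largest scaffold groups first, so rare scaffolds land in valid/test.
    ordered = sorted(groups.values(), key=lambda g: (-len(g), g[0]))

    train_cutoff = frac_train * len(scaffolds)
    valid_cutoff = (frac_train + frac_valid) * len(scaffolds)
    train, valid, test = [], [], []
    for group in ordered:
        if len(train) + len(group) > train_cutoff:
            if len(train) + len(valid) + len(group) > valid_cutoff:
                test.extend(group)
            else:
                valid.extend(group)
        else:
            train.extend(group)
    return train, valid, test


# Ten toy molecules with hypothetical scaffold labels.
keys = ["A", "A", "A", "A", "A", "B", "B", "B", "C", "D"]
train_idx, valid_idx, test_idx = scaffold_split(keys)
</code></pre></div>
<p>On the toy input, the two large groups fill the training set and the singleton scaffolds go to validation and test, which is what makes this kind of split harder than a random one.</p>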
<h3 id="algorithms">Algorithms</h3>
<p>The core training methodology mirrors standard BERT/RoBERTa procedures adapted for chemical strings.</p>
<ul>
<li><strong>Objective</strong>: Masked Language Modeling (MLM) with <strong>15% token masking</strong>.</li>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>BPE</strong>: Byte-Pair Encoding (vocab size 52K).</li>
<li><strong>SmilesTokenizer</strong>: Regex-based custom tokenizer available in DeepChem (documented <a href="https://deepchem.readthedocs.io/en/latest/tokenizers.html#smilestokenizer">here</a>).</li>
</ul>
</li>
<li><strong>Sequence Length</strong>: Maximum sequence length of <strong>512 tokens</strong>.</li>
<li><strong>Finetuning</strong>: Appended a linear classification layer; backpropagated through the base model for up to 25 epochs with early stopping on ROC-AUC.</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: <strong>RoBERTa</strong> (via HuggingFace).
<ul>
<li>Layers: 6</li>
<li>Attention Heads: 12 per layer (72 attention heads in total across the 6 layers).</li>
<li><em>Implementation Note</em>: The original training notebooks and scripts are maintained in the authors&rsquo; <a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry repository</a>, alongside the primary downstream tasks integrated into DeepChem. A <a href="https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Transfer_Learning_With_ChemBERTa_Transformers.ipynb">full Tox21 transfer learning tutorial</a> has been incorporated into the DeepChem repository.</li>
</ul>
</li>
<li><strong>Baselines</strong> (via Chemprop library):
<ul>
<li><strong>D-MPNN</strong>: Directed Message Passing Neural Network with default hyperparameters.</li>
<li><strong>RF/SVM</strong>: Scikit-learn Random Forest and SVM using 2048-bit Morgan fingerprints (<a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a>).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is measured using dual metrics to account for class imbalance common in toxicity datasets.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ROC-AUC</strong></td>
          <td>Area Under Receiver Operating Characteristic Curve</td>
      </tr>
      <tr>
          <td><strong>PRC-AUC</strong></td>
          <td>Area Under Precision-Recall Curve (vital for imbalanced data)</td>
      </tr>
  </tbody>
</table>
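<p>To make the difference between the two metrics concrete, here is a dependency-free sketch. ROC-AUC is computed as the probability that a random positive outranks a random negative; the precision-recall curve is summarized by average precision, which is one common estimator of PRC-AUC and an assumption here, since the exact estimator is not stated above. The toy labels and scores are invented.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def roc_auc(labels, scores):
    """Probability that a random positive is scored above a random negative."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


def average_precision(labels, scores):
    """Mean of the precision values at the rank of each positive example."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits


# Imbalanced toy data: 2 actives among 10 compounds.
y_true = [1, 0, 0, 0, 1, 0, 0, 0, 0, 0]
y_score = [0.9, 0.8, 0.7, 0.3, 0.6, 0.5, 0.4, 0.2, 0.1, 0.05]

auc = roc_auc(y_true, y_score)          # 14 of 16 pairs ordered correctly
ap = average_precision(y_true, y_score) # (1/1 + 2/4) / 2
</code></pre></div>
<p>On this toy input the ROC-AUC is 0.875 while the average precision is 0.75: two inactive compounds ranked above the second active barely affect the pairwise ranking statistic but halve the precision at that rank. In practice, library routines such as scikit-learn&rsquo;s <code>roc_auc_score</code> and <code>average_precision_score</code> would be used instead of these loops.</p>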
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Single <strong>NVIDIA V100 GPU</strong>.</li>
<li><strong>Training Time</strong>: Approximately <strong>48 hours</strong> for the 10M compound subset.</li>
<li><strong>Carbon Footprint</strong>: Estimated 17.1 kg $\text{CO}_2\text{eq}$ (offset by Google Cloud).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training notebooks and finetuning scripts</td>
      </tr>
      <tr>
          <td><a href="https://github.com/deepchem/deepchem">DeepChem</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Integration of ChemBERTa and SmilesTokenizer</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">ChemBERTa-zinc-base-v1</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Pre-trained RoBERTa on 100K ZINC SMILES</td>
      </tr>
      <tr>
          <td><a href="https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/pubchem_10m.txt.zip">PubChem-10M subset</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Canonicalized 10M compound subset used for largest experiments</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status</strong>: Partially Reproducible. Code and pre-trained models are available, and the 10M pretraining subset is downloadable. However, smaller subsets (100K, 250K, 1M) may need re-extraction from PubChem, and exact hyperparameter details for finetuning (learning rate, batch size) are not fully specified in the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chithrananda, S., Grand, G., &amp; Ramsundar, B. (2020). ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction. <em>arXiv preprint arXiv:2010.09885</em>. <a href="https://doi.org/10.48550/arXiv.2010.09885">https://doi.org/10.48550/arXiv.2010.09885</a></p>
<p><strong>Publication</strong>: arXiv 2020 (Preprint)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://huggingface.co/seyonec/ChemBERTa-zinc-base-v1">HuggingFace Model Hub (ChemBERTa-zinc-base-v1)</a> - <em>Additional pre-trained variations on PubChem &amp; ZINC datasets are available on the author&rsquo;s <a href="https://huggingface.co/seyonec">seyonec</a> HF profile.</em></li>
<li><a href="https://github.com/seyonechithrananda/bert-loves-chemistry">bert-loves-chemistry GitHub Repository</a> - <em>Notebooks and scripts used for MLM pretraining and finetuning evaluations.</em></li>
</ul>
<h3 id="bibtex">BibTeX</h3>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{chithranandaChemBERTaLargeScaleSelfSupervised2020,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemBERTa}}: {{Large-Scale Self-Supervised Pretraining}} for {{Molecular Property Prediction}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{ChemBERTa}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2010.09885}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-24}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Score-Based Generative Modeling with SDEs (Song 2021)</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</guid><description>Unified SDE framework for score-based generative models, introducing Predictor-Corrector samplers and setting CIFAR-10 records with FID 2.20 and 2.99 bits/dim.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper. It proposes a unified framework that generalizes previous discrete score-based models (SMLD and DDPM) into continuous-time Stochastic Differential Equations (SDEs). The paper introduces algorithms for sampling (Predictor-Corrector) and likelihood computation (Probability Flow ODE), validated by setting new records on CIFAR-10 (FID 2.20, IS 9.89 at the time of publication). It also contains elements of <strong>Systematization</strong> by showing how existing methods are special cases of this broader framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Prior successful generative models, specifically Score Matching with Langevin Dynamics (SMLD) and Denoising Diffusion Probabilistic Models (DDPM), operate by sequentially corrupting data with slowly increasing noise and learning to reverse the process. Both methods treat the noise scales as a finite set of discrete steps. The authors aim to generalize this to a continuum of noise scales by modeling the diffusion process as a Stochastic Differential Equation (SDE). This continuous formulation enables:</p>
<ul>
<li><strong>Flexible sampling:</strong> Use of general-purpose SDE solvers.</li>
<li><strong>Exact likelihood computation:</strong> Via connection to Neural ODEs.</li>
<li><strong>Controllable generation:</strong> Solving inverse problems (inpainting, colorization) without retraining.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>SDE framework</strong> for score-based generative modeling:</p>
<ul>
<li><strong>Continuous Generalization:</strong> Proving that SMLD and DDPM noise perturbations correspond to discretizations of Variance Exploding (VE) SDEs and Variance Preserving (VP) SDEs, respectively.</li>
<li><strong>Reverse-Time SDE:</strong> Leveraging Anderson&rsquo;s (1982) result on the time reversal of diffusion processes: the reverse of a diffusion is itself a diffusion, whose drift combines the forward drift with a correction term involving the score (the gradient of the log marginal density).</li>
<li><strong>Predictor-Corrector (PC) Samplers:</strong> A hybrid sampling strategy where a numerical SDE solver (Predictor) estimates the next step, and a score-based MCMC approach (Corrector) corrects the marginal distribution.</li>
<li><strong>Probability Flow ODE:</strong> Deriving a deterministic ODE that shares the same marginal densities as the SDE, enabling near-exact likelihood computation (accuracy is limited by both numerical ODE solver discretization and variance of the unbiased Hutchinson trace estimator) and latent space manipulation.</li>
<li><strong>Sub-VP SDE:</strong> A new SDE class whose variance is always bounded above by that of the corresponding VP SDE, proposed to improve likelihoods.</li>
</ul>
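<p>For reference, writing the forward SDE as $dx = f(x, t)\,dt + g(t)\,dw$, the reverse-time SDE and the probability flow ODE named in the list above take the following form, where $\bar{w}$ is a Wiener process running backward in time and $p_t$ is the marginal density at time $t$:</p>
<p>$$
dx = \left[ f(x, t) - g(t)^2 \nabla_x \log p_t(x) \right] dt + g(t)\, d\bar{w}
\qquad \text{(reverse-time SDE)}
$$</p>
<p>$$
dx = \left[ f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x) \right] dt
\qquad \text{(probability flow ODE)}
$$</p>
<p>In practice the unknown score $\nabla_x \log p_t(x)$ is replaced by the trained network $s_\theta(x, t)$. The only differences between the two equations are the factor $\frac{1}{2}$ on the score term and the absence of the noise term in the ODE.</p>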
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the framework on standard image benchmarks:</p>
<ul>
<li><strong>Datasets:</strong> CIFAR-10 (32x32), CelebA (64x64), LSUN (Bedroom, Church), and CelebA-HQ (256x256 and 1024x1024).</li>
<li><strong>Ablation Studies:</strong> Comparing samplers (Ancestral vs. Reverse Diffusion vs. Probability Flow vs. PC) and SDE types (VE, VP, sub-VP).</li>
<li><strong>Architecture Search:</strong> Exploring improvements like FIR up/downsampling, rescaling skip connections, and increasing depth (leading to NCSN++ and DDPM++ architectures).</li>
<li><strong>Likelihood Evaluation:</strong> Computing Negative Log-Likelihood (NLL) in bits/dim using the Probability Flow ODE.</li>
<li><strong>Inverse Problems:</strong> Testing class-conditional generation, inpainting, and colorization using the conditional reverse-time SDE.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Record Performance:</strong> The <strong>NCSN++ cont. (deep, VE)</strong> model achieved an Inception Score of 9.89 and FID of 2.20 on CIFAR-10 (as of ICLR 2021).</li>
<li><strong>High-Fidelity Generation:</strong> First score-based model to generate 1024x1024 images (CelebA-HQ).</li>
<li><strong>Competitive Likelihoods:</strong> The <strong>DDPM++ cont. (deep, sub-VP)</strong> model achieved 2.99 bits/dim on uniformly dequantized CIFAR-10, a record at the time.</li>
<li><strong>Sampling Efficiency:</strong> PC samplers consistently outperformed predictor-only methods (like standard ancestral sampling) for the same computational cost.</li>
<li><strong>Controllable Generation:</strong> Successful application to inpainting and colorization using a single unconditional model.</li>
<li><strong>Limitations:</strong> Sampling remains slower than GANs on the same datasets. The breadth of available samplers introduces many hyperparameters (SDE type, predictor, corrector, signal-to-noise ratio, number of steps) that require tuning.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>CIFAR-10</strong>: Used for main benchmarking (FID, Inception Score, NLL).</li>
<li><strong>CelebA-HQ</strong>: Used for high-resolution experiments at 256x256 and 1024x1024.</li>
<li><strong>LSUN</strong>: Bedroom and Church Outdoor categories (256x256) used for sampler comparison and controllable generation (inpainting, colorization).</li>
<li><strong>Preprocessing</strong>: CIFAR-10 images are 32x32; CelebA pre-processed to 64x64 following Song &amp; Ermon (2020). Data is typically scaled to $[0, 1]$ or standardized depending on the specific SDE config.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Forward SDEs</strong>:</p>
<p>Here $dw$ denotes a Wiener process increment (a small, independent Gaussian noise burst at each timestep).</p>
<ul>
<li><strong>VE SDE (Variance Exploding)</strong>: $dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw$. Corresponds to SMLD. Used with $\sigma_{\min}=0.01$ and $\sigma_{\max}$ chosen via heuristics.</li>
<li><strong>VP SDE (Variance Preserving)</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)} dw$. Corresponds to DDPM.</li>
<li><strong>Sub-VP SDE</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)ds})} dw$. Its variance is bounded above by that of the VP SDE, and it yielded the best likelihoods in the paper.</li>
</ul>
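<p>Each of these SDEs has a Gaussian perturbation kernel $p(x(t) \mid x(0))$ in closed form, which is what makes training tractable. The sketch below evaluates the kernel standard deviations, assuming a linear $\beta(t)$ schedule and a geometric $\sigma(t)$ schedule; the numeric constants and function names are assumptions chosen for illustration.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

BETA_MIN, BETA_MAX = 0.1, 20.0      # linear beta(t) schedule (assumed values)
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # geometric sigma(t) schedule (assumed values)


def beta_integral(t):
    """Integral of beta(s) = BETA_MIN + s * (BETA_MAX - BETA_MIN) from 0 to t."""
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t * t


def ve_std(t):
    """Std of x(t) given x(0) under the VE SDE; the mean stays at x(0)."""
    sigma_t = SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t
    return math.sqrt(sigma_t ** 2 - SIGMA_MIN ** 2)


def vp_mean_scale(t):
    """Factor multiplying x(0) in the mean of x(t) for the VP and sub-VP SDEs."""
    return math.exp(-0.5 * beta_integral(t))


def vp_std(t):
    """Std of x(t) given x(0) under the VP SDE."""
    return math.sqrt(1.0 - math.exp(-beta_integral(t)))


def subvp_std(t):
    """Std of x(t) given x(0) under the sub-VP SDE (the square of the VP std)."""
    return 1.0 - math.exp(-beta_integral(t))


for t in (0.1, 0.5, 1.0):
    print(t, ve_std(t), vp_mean_scale(t), vp_std(t), subvp_std(t))
</code></pre></div>
<p>Since the sub-VP standard deviation is the square of a number in $[0, 1]$, it never exceeds the VP value, which is the bounded-variance property in concrete form. The VE standard deviation grows without any shrinkage of the mean, hence &ldquo;variance exploding&rdquo;.</p>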
<p><strong>Reverse-Time SDE Solver (Predictor)</strong>:</p>
<ul>
<li>Discretized via <strong>Reverse Diffusion Sampling</strong>, which matches the forward discretization.</li>
<li><strong>Euler-Maruyama</strong> solver used for continuously-trained models.</li>
</ul>
<p><strong>Corrector Algorithm</strong>:</p>
<ul>
<li><strong>Langevin MCMC</strong>: Applies annealed Langevin dynamics: adds noise and takes a score-guided gradient step to correct the marginal distribution at each timestep.</li>
<li><strong>PC Sampling</strong>: Alternates between one step of the Predictor and one step of the Corrector.</li>
<li><strong>Signal-to-Noise Ratio ($r$)</strong>: A hyperparameter for the corrector step size. Tuned values: $r \approx 0.16$ for VE SDEs on CIFAR-10.</li>
</ul>
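<p>The predictor and corrector steps above can be sketched on a toy problem where the score is known exactly. The example assumes a VE SDE, data whose coordinates are independent zero-mean Gaussians, and an analytic score in place of a trained network; all constants and helper names are illustrative assumptions, and this is not the reference implementation.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math
import random

random.seed(0)

DATA_STD = 2.0                      # toy data: each coordinate ~ N(0, DATA_STD^2)
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # VE noise range (assumed values)
NUM_STEPS, DIM, SNR = 200, 2000, 0.16


def sigma(i):
    """Geometric schedule from SIGMA_MIN (i = 0) to SIGMA_MAX (i = NUM_STEPS)."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** (i / NUM_STEPS)


def score(x, sig):
    """Exact score of N(0, DATA_STD^2 + sig^2); stands in for a trained network."""
    var = DATA_STD ** 2 + sig ** 2
    return [-v / var for v in x]


def norm(v):
    return math.sqrt(sum(a * a for a in v))


def gauss(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


x = [SIGMA_MAX * z for z in gauss(DIM)]  # draw from the prior N(0, SIGMA_MAX^2)
for i in range(NUM_STEPS, 0, -1):
    # Predictor: one reverse-diffusion step from noise level i to i - 1.
    step = sigma(i) ** 2 - sigma(i - 1) ** 2
    g, z = score(x, sigma(i)), gauss(DIM)
    x = [xv + step * gv + math.sqrt(step) * zv for xv, gv, zv in zip(x, g, z)]

    # Corrector: one Langevin step at noise level i - 1, step size set by SNR.
    g, z = score(x, sigma(i - 1)), gauss(DIM)
    eps = 2.0 * (SNR * norm(z) / norm(g)) ** 2
    x = [xv + eps * gv + math.sqrt(2.0 * eps) * zv for xv, gv, zv in zip(x, g, z)]

# Empirical std over coordinates; should land close to DATA_STD.
sample_std = math.sqrt(sum(v * v for v in x) / DIM)
</code></pre></div>
<p>The corrector step size is derived from the ratio of the noise norm to the score norm, scaled by $r$, so the same $r$ gives sensible step sizes across very different noise levels. With a learned score the predictor accumulates discretization and estimation error, and the corrector is what pulls samples back toward the correct marginal at each noise level.</p>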
<h3 id="models">Models</h3>
<ul>
<li><strong>NCSN++</strong>: Optimized architecture for VE SDEs. Key features:
<ul>
<li>4 residual blocks per resolution.</li>
<li>BigGAN-type residual blocks.</li>
<li>Rescaling skip connections by $1/\sqrt{2}$.</li>
<li>FIR (Finite Impulse Response) up/downsampling.</li>
<li>&ldquo;Residual&rdquo; progressive architecture for input, no progressive growing for output.</li>
</ul>
</li>
<li><strong>DDPM++</strong>: Optimized architecture for VP/sub-VP SDEs. Similar to NCSN++ but without FIR upsampling and no progressive growing.</li>
<li><strong>Deep Variants</strong>: &ldquo;cont. (deep)&rdquo; models double the depth (from 4 to 8 blocks per resolution) for the best reported results.</li>
<li><strong>Conditioning</strong>: Time $t$ is conditioned via random Fourier feature embeddings (scale 16) for continuous models.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>FID (Fréchet Inception Distance)</strong>: Computed on 50k samples.</li>
<li><strong>Inception Score</strong>: Reported for CIFAR-10.</li>
<li><strong>NLL (Negative Log-Likelihood)</strong>: Reported in bits/dim on uniformly dequantized data using the Probability Flow ODE.</li>
</ul>
<p><strong>Denoising</strong>: A single denoising step using Tweedie&rsquo;s formula is applied at the end of sampling to remove residual noise, which significantly improves FID.</p>
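<p>As a sketch of what this step computes, assume the final sample $\tilde{x}$ equals clean data plus Gaussian noise of standard deviation $\sigma$ (for the VE SDE, $\sigma = \sigma_{\min}$). Tweedie&rsquo;s formula then gives the posterior mean of the clean data, with the trained score network $s_\theta$ standing in for the true score:</p>
<p>$$
\hat{x} = \mathbb{E}[x \mid \tilde{x}] = \tilde{x} + \sigma^2 \nabla_{\tilde{x}} \log p_\sigma(\tilde{x}) \approx \tilde{x} + \sigma^2 s_\theta(\tilde{x}, \sigma)
$$</p>
<p>The update is deterministic: it takes one score-guided step and adds no fresh noise, which removes the small residual noise that is hard to see but that FID is sensitive to.</p>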
<h3 id="hardware">Hardware</h3>
<p><strong>Training</strong>:</p>
<ul>
<li>Batch size: 128 for CIFAR-10, 64 for LSUN, 8 for high-res CelebA-HQ.</li>
<li>Iterations: Discrete-objective models trained for 1.3M iterations during architecture exploration. Continuous-objective models (cont.) trained for 0.95M iterations. High-res CelebA-HQ (1024x1024) trained for approximately 2.4M iterations.</li>
<li><strong>EMA</strong>: Exponential Moving Average rate of 0.999 used for VE models, 0.9999 for VP models.</li>
</ul>
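<p>The EMA rates listed above control how slowly the averaged weights used for sampling follow the raw training weights. A minimal sketch of the update rule, with a hypothetical helper name and a one-parameter toy model:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def ema_update(shadow, params, decay):
    """Move each shadow weight a fraction (1 - decay) toward the current weight."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


# Toy example: the raw weight jumps from 0 to 1 and then stays there.
shadow = [0.0]
for _ in range(1000):
    shadow = ema_update(shadow, [1.0], decay=0.999)

ema_value = shadow[0]  # equals 1 - 0.999 ** 1000, roughly 0.632
</code></pre></div>
<p>A decay of 0.999 averages over roughly the last 1,000 updates, while 0.9999 averages over roughly the last 10,000, so the VP models use a considerably smoother average than the VE models.</p>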
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yang-song/score_sde">yang-song/score_sde</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX and PyTorch implementation with pretrained checkpoints</td>
      </tr>
  </tbody>
</table>
<p>All datasets used (CIFAR-10, CelebA-HQ, LSUN) are publicly available. Pretrained model checkpoints for CIFAR-10, CelebA-HQ, and FFHQ are provided in the repository. Specific hardware requirements (GPU type, training time) are not detailed in the paper.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., &amp; Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. <em>ICLR 2021</em>. <a href="https://arxiv.org/abs/2011.13456">https://arxiv.org/abs/2011.13456</a></p>
<p><strong>Publication</strong>: ICLR 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2021scorebased,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Score-Based Generative Modeling through Stochastic Differential Equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Sohl-Dickstein, Jascha and Kingma, Diederik P and Kumar, Abhishek and Ermon, Stefano and Poole, Ben}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://openreview.net/forum?id=PxTIG12RRHS}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/yang-song/score_sde">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/">Score Matching and Denoising Autoencoders</a></li>
</ul>
]]></content:encoded></item><item><title>Score Matching and Denoising Autoencoders: A Connection</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</guid><description>Theoretical paper proving the equivalence between training Denoising Autoencoders and performing Score Matching on a Parzen density estimator.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Theory Paper</strong>.</p>
<p>Its primary contribution is a formal mathematical derivation connecting two previously distinct techniques: Score Matching (SM) and Denoising Autoencoders (DAE). It provides the &ldquo;why&rdquo; behind the empirical success of DAEs by grounding them in the probabilistic framework of energy-based models. It relies on proofs and equivalence relations (e.g., $J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$).</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper bridges a gap between two successful but disconnected approaches in unsupervised learning:</p>
<ol>
<li><strong>Denoising Autoencoders (DAE):</strong> Empirically successful for pre-training deep networks. They previously lacked a clear probabilistic interpretation.</li>
<li><strong>Score Matching (SM):</strong> A theoretically sound method for estimating unnormalized density models that avoids the partition function problem but requires computing expensive second derivatives.</li>
</ol>
<p>By connecting them, the authors aim to define a proper probabilistic model for DAEs (allowing sampling/ranking) and find a simpler way to apply score matching that avoids second derivatives.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Denoising Score Matching (DSM)</strong> framework and the proof of its equivalence to DAEs. Key contributions include:</p>
<ul>
<li><strong>Equivalence Proof:</strong> Showing that training a DAE with Gaussian noise is equivalent to matching the score of a model against a non-parametric Parzen density estimator of the data.</li>
<li><strong>Denoising Score Matching ($J_{DSM}$):</strong> A new objective that learns a score function by trying to denoise corrupted samples. This avoids the explicit second derivatives required by standard Implicit Score Matching ($J_{ISM}$).</li>
<li><strong>Explicit Energy Function:</strong> Deriving the specific energy function $E(x;\theta)$ that corresponds to the standard sigmoid DAE architecture.</li>
<li><strong>Justification for Tied Weights:</strong> Providing a theoretical justification for tying encoder and decoder weights, which arises naturally from differentiating the energy function.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation in this theoretical paper is purely mathematical and focuses on formal proofs:</p>
<ul>
<li><strong>Derivation of Equivalence:</strong> The paper formally proves the chain of equivalences:
$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$
where $q_{\sigma}$ is the Parzen density estimate.</li>
<li><strong>Appendix Proof:</strong> A detailed proof is provided to show that Explicit Score Matching ($J_{ESM}$) on the Parzen density is equivalent to the proposed Denoising Score Matching ($J_{DSM}$) objective.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Theoretical Unification:</strong> DAE training is formally equivalent to Score Matching on a smoothed data distribution ($q_{\sigma}$).</li>
<li><strong>New Training Objective:</strong> The $J_{DSM}$ objective offers a computationally efficient way to perform score matching (no Hessian required) by using a denoising objective.</li>
<li><strong>Probabilistic Interpretation:</strong> DAEs can now be understood as Energy-Based Models (EBMs), allowing for operations like sampling (via Hybrid Monte Carlo) and likelihood ranking, which were previously ill-defined for standard autoencoders.</li>
<li><strong>Regularization Insight:</strong> The smoothing kernel width $\sigma$ in the Parzen estimator corresponds to the noise level in the DAE. This suggests that DAEs are learning a regularized version of the score, which may explain their robustness.</li>
<li><strong>Connection to Regularized Score Matching:</strong> The paper notes that Kingma and LeCun (2010) independently proposed a regularized score matching criterion $J_{ISMreg}$ derived by approximating $J_{ISMq_{\sigma}}$. The four $q_{\sigma}$-based objectives in this work (including the DAE objective) can be seen as approximation-free forms of regularized score matching, with the additional advantage that $J_{DSMq_{\sigma}}$ does not require second derivatives.</li>
</ul>
<hr>
<h2 id="key-concepts-explained">Key Concepts Explained</h2>
<h3 id="1-score-and-score-matching">1. &ldquo;Score&rdquo; and &ldquo;Score Matching&rdquo;</h3>
<p><strong>What does &ldquo;score&rdquo; actually mean?</strong></p>
<p>In this paper (and probabilistic modeling generally), the <strong>score</strong> is the gradient of the log-density with respect to the <em>data vector</em> $x$.</p>
<ul>
<li><strong>Definition:</strong> $\psi(x) = \nabla_x \log p(x)$.</li>
<li><strong>Intuition:</strong> It is a vector field pointing in the direction of steepest increase in log-density. Crucially, calculating the score avoids the intractable partition function $Z$: writing $p(x) = \tilde{p}(x)/Z$ gives $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$. The constant $Z$ does not depend on $x$, so it vanishes upon differentiation.</li>
</ul>
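<p>As a small illustration of why the partition function drops out, the sketch below estimates the score of an unnormalized Gaussian by finite differences under two different scalings. It uses NumPy, and the helper names <code>log_p_tilde</code> and <code>numerical_score</code> are illustrative assumptions, not anything from the paper.</p>

```python
import numpy as np

def log_p_tilde(x, scale=1.0):
    # Unnormalized log-density of a standard Gaussian.
    # Multiplying p_tilde by a constant only adds log(scale) here.
    return -0.5 * np.sum(x ** 2) + np.log(scale)

def numerical_score(log_density, x, eps=1e-5):
    # Central finite-difference estimate of grad_x log_density(x).
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (log_density(x + step) - log_density(x - step)) / (2 * eps)
    return grad

x = np.array([0.3, -1.2])
score_a = numerical_score(lambda v: log_p_tilde(v, scale=1.0), x)
score_b = numerical_score(lambda v: log_p_tilde(v, scale=123.0), x)
# Both estimates agree with each other and with the analytic score -x.
```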
<p><strong>What is Score Matching?</strong></p>
<p>Score Matching is a training objective for unnormalized models. It minimizes the expected squared Euclidean distance between the model&rsquo;s score $\psi(x;\theta)$ and the data&rsquo;s true score $\nabla_x \log q(x)$, with the expectation taken over the data distribution $q$.</p>
<h3 id="2-the-parzen-density-estimator">2. The Parzen Density Estimator</h3>
<p><strong>What is it?</strong></p>
<p>It is a non-parametric method for estimating a probability density function from finite data. It places a smooth kernel (here, a Gaussian) centered at every data point in the training set $D_n$.</p>
<ul>
<li><strong>Formula:</strong> $q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n \mathcal{N}(\tilde{x}; x^{(t)}, \sigma^2 I)$.</li>
</ul>
<p><strong>Why smooth the data?</strong></p>
<ol>
<li>
<p><strong>To define the score:</strong> The empirical data distribution is a set of Dirac deltas (spikes). The gradient (score) of a Dirac delta is undefined. Smoothing creates a differentiable surface, allowing a valid target score $\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x})$ to be computed.</p>
</li>
<li>
<p><strong>To model corruption:</strong> The Parzen estimator with Gaussian kernels mathematically models the process of taking a clean data point $x$ and adding Gaussian noise, which is exactly the corruption procedure used in Denoising Autoencoders.</p>
</li>
</ol>
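<p>The Parzen estimate and its score are both available in closed form, which a short NumPy sketch can make concrete. The toy data, $\sigma$, and function names below are assumptions chosen for illustration. The score turns out to be a softmax-weighted average of the vectors $(x^{(t)} - \tilde{x})/\sigma^2$.</p>

```python
import numpy as np

def parzen_log_density(x_tilde, data, sigma):
    # log q_sigma(x_tilde) with an isotropic Gaussian kernel on each data point.
    d = data.shape[1]
    sq_dist = np.sum((data - x_tilde) ** 2, axis=1)
    log_kernels = -0.5 * sq_dist / sigma**2 - 0.5 * d * np.log(2 * np.pi * sigma**2)
    m = np.max(log_kernels)  # stable log-mean-exp over the n kernels
    return m + np.log(np.mean(np.exp(log_kernels - m)))

def parzen_score(x_tilde, data, sigma):
    # grad of log q_sigma: responsibility-weighted average of (x_t - x_tilde) / sigma^2.
    sq_dist = np.sum((data - x_tilde) ** 2, axis=1)
    logits = -0.5 * sq_dist / sigma**2
    weights = np.exp(logits - np.max(logits))
    weights /= np.sum(weights)
    return weights @ (data - x_tilde) / sigma**2

data = np.array([[0.0, 0.0], [1.0, 0.5], [-0.5, 2.0]])  # toy training set
sigma = 0.7
x_tilde = np.array([0.4, 0.9])
score = parzen_score(x_tilde, data, sigma)
```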
<h3 id="3-why-avoiding-second-derivatives-matters">3. Why avoiding second derivatives matters</h3>
<p>Standard <strong>Implicit Score Matching (ISM)</strong> eliminates the need for the unknown data score, but introduces a new cost: it requires computing the trace of the Hessian (the sum of second partial derivatives) of the log-density.</p>
<ul>
<li><strong>The Cost:</strong> For high-dimensional data (like images) and deep networks, computing second derivatives of the log-density is computationally expensive.</li>
<li>This paper shows that <strong>Denoising Score Matching (DSM)</strong> allows you to bypass Hessian computation entirely. By using the Parzen target, the objective simplifies to matching a first-order vector, making it scalable to deep neural networks.</li>
</ul>
<h3 id="4-the-equivalence-chain---why-each-step">4. The equivalence chain - why each step?</h3>
<p>The chain $J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ connects the concepts.</p>
<ul>
<li>
<p><strong>$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}}$ (Implicit $\to$ Explicit):</strong>
<strong>Why:</strong> Integration by parts, as in Hyvärinen&rsquo;s original proof (2005). Expanding $J_{ESMq_{\sigma}}$ leaves a cross term between the model score $\psi$ and the data score $\nabla \log q_{\sigma}$; integration by parts trades that cross term for the divergence of $\psi$ (the trace of its Jacobian), which is the term that appears in $J_{ISMq_{\sigma}}$. The boundary term vanishes because $q_{\sigma}$ decays to zero at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching). The result replaces the unknown data score with a computable term involving only the model&rsquo;s score and its Jacobian.</p>
</li>
<li>
<p><strong>$J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$ (Explicit $\to$ Denoising):</strong>
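<strong>Why:</strong> The Parzen density is a mixture of Gaussians centered on the training points, so its score can be expressed through the conditional $q_{\sigma}(\tilde{x}|x)$. When $x$ is perturbed to $\tilde{x} = x + \epsilon$ with $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, the gradient of $\log q_{\sigma}(\tilde{x}|x)$ with respect to $\tilde{x}$ is exactly $\frac{1}{\sigma^2}(x - \tilde{x})$, a vector pointing from the corrupted sample back to the clean one. Minimizing the error against the Parzen score is then equivalent, up to a constant independent of $\theta$, to minimizing the error against this restoration vector.</p>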
</li>
<li>
<p><strong>$J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ (Denoising $\to$ Autoencoder):</strong>
<strong>Why:</strong> Algebraic substitution. If you define the model&rsquo;s score as $\psi(\tilde{x};\theta) = \frac{1}{\sigma^2}(x^r - \tilde{x})$, the displacement from the corrupted input to its reconstruction, the $\tilde{x}$ terms cancel against the target $\frac{1}{\sigma^2}(x - \tilde{x})$, and $J_{DSM}$ becomes proportional to the standard autoencoder squared loss $\|x^r - x\|^2$.</p>
</li>
</ul>
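<p>The last link of the chain can be written out in one line. Assuming the model score is parameterized as $\psi(\tilde{x};\theta) = \frac{1}{\sigma^2}(x^r - \tilde{x})$ with $x^r = W^T \text{sigmoid}(W\tilde{x} + b) + c$, substituting into the denoising objective gives:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left\| \frac{1}{\sigma^2}(x^r - \tilde{x}) - \frac{1}{\sigma^2}(x - \tilde{x}) \right\|^2 \right] = \frac{1}{2\sigma^4} \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \left\| x^r - x \right\|^2 \right]$$</p>
<p>The right-hand side is the DAE squared reconstruction loss up to the constant factor $\frac{1}{2\sigma^4}$, so both objectives share the same minimizers in $\theta$.</p>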
<h3 id="5-energy-based-models-ebms-connection">5. Energy-Based Models (EBMs) connection</h3>
<p><strong>What is an EBM?</strong></p>
<p>An EBM defines a probability distribution via an energy function $E(x;\theta)$, where $p(x;\theta) \propto e^{-E(x;\theta)}$.</p>
<p><strong>Why standard autoencoders lack probabilistic interpretation:</strong></p>
<p>A standard autoencoder acts as a deterministic map $x \to x^r$, providing a reconstruction error. It lacks a normalization constant or a defined density function to support sampling or probability queries.</p>
<p><strong>What does this enable?</strong></p>
<p>By proving the equivalence, the DAE is formally defined as an EBM. This enables:</p>
<ol>
<li><strong>Sampling:</strong> Using MCMC methods (like Hybrid Monte Carlo) to generate new data from the DAE.</li>
<li><strong>Ranking:</strong> Calculating the energy of inputs to determine which are more &ldquo;likely&rdquo; or &ldquo;normal&rdquo; (useful for anomaly detection).</li>
</ol>
<h3 id="6-the-specific-energy-function-form">6. The specific energy function form</h3>
<p>The function is:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}\|x\|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p><strong>Why does it have that specific form?</strong></p>
<p>It was derived via integration to ensure its derivative matches the DAE architecture. The authors worked backward from the DAE&rsquo;s reconstruction function (sigmoid + linear) to find the scalar field that generates it.</p>
<p><strong>Where does the quadratic term come from?</strong></p>
<p>The score (negative energy gradient) needs to look like $\psi(x) \propto c - x + W^T\text{sigmoid}(Wx + b)$.</p>
<ul>
<li>The term $-x$ in the score arises because $\nabla_x(-\frac{1}{2}\|x\|^2) = -x$. Including $-\frac{1}{2}\|x\|^2$ inside the parenthesized sum of the energy produces this linear term after differentiation.</li>
</ul>
<p><strong>How does differentiating it recover the DAE reconstruction?</strong></p>
<ul>
<li>$\nabla_x \sum_j \text{softplus}(\langle W_j, x \rangle + b_j) = W^T \text{sigmoid}(Wx + b)$ (The encoder part; the derivative of softplus is the sigmoid).</li>
<li>$\nabla_x \langle c, x \rangle = c$ (The bias).</li>
<li>$\nabla_x (-\frac{1}{2}\|x\|^2) = -x$ (The input subtraction).</li>
<li>Result: $-\nabla_x E \propto c + W^T h - x = x^r - x$.</li>
</ul>
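<p>The differentiation steps above can be checked numerically. The following NumPy sketch, with randomly drawn parameters that are purely illustrative, compares a finite-difference estimate of $-\nabla_x E$ with the closed form $(x^r - x)/\sigma^2$.</p>

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def energy(x, W, b, c, sigma):
    # E(x) = -(1/sigma^2) * ( c.x - 0.5*||x||^2 + sum_j softplus(W_j.x + b_j) )
    softplus = np.logaddexp(0.0, W @ x + b)
    return -(c @ x - 0.5 * x @ x + np.sum(softplus)) / sigma**2

def model_score(x, W, b, c, sigma):
    # -grad_x E(x) = (x_r - x) / sigma^2, with x_r the tied-weight reconstruction
    x_r = W.T @ sigmoid(W @ x + b) + c
    return (x_r - x) / sigma**2

rng = np.random.default_rng(0)
d, d_h, sigma = 4, 3, 0.5
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
c = rng.normal(size=d)
x = rng.normal(size=d)

# Central finite differences of -E along each coordinate axis
eps = 1e-5
fd_score = np.array([
    -(energy(x + eps * e, W, b, c, sigma) - energy(x - eps * e, W, b, c, sigma)) / (2 * eps)
    for e in np.eye(d)
])
```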
<h3 id="7-tied-weights-justification">7. &ldquo;Tied weights&rdquo; justification</h3>
<p><strong>What does it mean for weights to be &ldquo;tied&rdquo;?</strong></p>
<p>The decoder matrix is the transpose of the encoder matrix ($W^T$).</p>
<p><strong>Why is this theoretically justified?</strong></p>
<p>Because the reconstruction function is interpreted as the <strong>gradient</strong> of an energy function. A vector field can only be the gradient of a scalar field if its Jacobian is symmetric.</p>
<ul>
<li>In the DAE energy derivative, the encoder contributes $W^T \text{sigmoid}(Wx + b)$. If the decoder used a separate matrix $U$, the resulting vector field would not be a valid gradient of any scalar energy function (unless $U = W^T$).</li>
<li>Therefore, for a DAE to correspond to a valid probabilistic Energy-Based Model, the weights <em>must</em> be tied.</li>
</ul>
<p><strong>The necessity of tied weights:</strong></p>
<p>Within this parametrization, tied weights are a mathematical necessity: a separate decoder matrix $U \neq W^T$ would make the reconstruction function an invalid gradient of any scalar energy, breaking the EBM correspondence.</p>
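<p>The symmetry argument can be illustrated by computing the Jacobian of the reconstruction field $r(x) = U\,\text{sigmoid}(Wx + b) + c - x$ directly. The sketch below uses random, hypothetical parameters and compares a tied decoder ($U = W^T$) with a free one.</p>

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def reconstruction_jacobian(x, W, U, b):
    # Jacobian of r(x) = U @ sigmoid(W @ x + b) + c - x with respect to x.
    # The bias c drops out after differentiation.
    h = sigmoid(W @ x + b)
    return U @ np.diag(h * (1.0 - h)) @ W - np.eye(x.size)

rng = np.random.default_rng(0)
d, d_h = 4, 3
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
x = rng.normal(size=d)
U_free = rng.normal(size=(d, d_h))  # untied decoder, drawn independently of W

J_tied = reconstruction_jacobian(x, W, W.T, b)
J_untied = reconstruction_jacobian(x, W, U_free, b)

tied_is_symmetric = np.allclose(J_tied, J_tied.T)
untied_is_symmetric = np.allclose(J_untied, J_untied.T)
```

<p>With tied weights the Jacobian has the form $W^T D W - I$ with $D$ diagonal, which is symmetric by construction; a generic untied $U$ gives no such guarantee.</p>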
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a theoretical paper, the &ldquo;reproducibility&rdquo; lies in the mathematical formulations derived.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Data ($D_n$):</strong> The theory assumes a set of training examples $D_n = \{x^{(1)}, \ldots, x^{(n)}\}$ drawn from an unknown true pdf $q(x)$.</li>
<li><strong>Parzen Density Estimate ($q_{\sigma}$):</strong> The theoretical targets are derived from a kernel-smoothed empirical distribution:
$$q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n q_{\sigma}(\tilde{x}|x^{(t)})$$
where the kernel is an isotropic Gaussian of variance $\sigma^2$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Denoising Score Matching (DSM) Objective</strong></p>
<p>The paper proposes this objective as a tractable alternative to standard score matching. It minimizes the distance between the model score and the gradient of the log-noise density:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left\| \psi(\tilde{x};\theta) - \frac{\partial \log q_{\sigma}(\tilde{x}|x)}{\partial \tilde{x}} \right\|^2 \right]$$</p>
<p>For Gaussian noise, the target score is simply $\frac{1}{\sigma^2}(x - \tilde{x})$.</p>
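<p>To see that the denoising objective and explicit score matching on $q_{\sigma}$ differ only by a constant, the following sketch evaluates both by simple quadrature in one dimension. The toy data, the grid, and the linear score model <code>psi</code> are assumptions made for illustration; the gap between the two objectives should be the same for any $\theta$.</p>

```python
import numpy as np

data = np.array([-1.0, 0.5, 2.0])      # toy 1-D training set (illustrative)
sigma = 0.5
grid = np.linspace(-6.0, 7.0, 20001)   # quadrature nodes for x_tilde
dx = grid[1] - grid[0]

# kernels[t, g] = N(grid[g]; data[t], sigma^2)
diff = data[:, None] - grid[None, :]
kernels = np.exp(-0.5 * diff**2 / sigma**2) / np.sqrt(2 * np.pi * sigma**2)
q = kernels.mean(axis=0)                           # Parzen density q_sigma
targets = diff / sigma**2                          # DSM targets (x - x_tilde) / sigma^2
parzen_score = (kernels * targets).mean(axis=0) / q

def psi(theta):
    # Hypothetical linear score model psi(x_tilde) = theta[0] + theta[1] * x_tilde
    return theta[0] + theta[1] * grid

def j_esm(theta):
    # Explicit score matching against the Parzen score
    return np.sum(q * 0.5 * (psi(theta) - parzen_score) ** 2) * dx

def j_dsm(theta):
    # Denoising score matching against the per-pair restoration vectors
    per_pair = 0.5 * (psi(theta)[None, :] - targets) ** 2
    return np.sum((kernels * per_pair).mean(axis=0)) * dx

theta_a = np.array([0.0, -1.0])
theta_b = np.array([0.7, 0.3])
gap_a = j_dsm(theta_a) - j_esm(theta_a)
gap_b = j_dsm(theta_b) - j_esm(theta_b)
# gap_a and gap_b coincide: the two objectives differ by a theta-independent constant.
```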
<p><strong>2. Equivalence Chain</strong></p>
<p>The central result connects four objectives:</p>
<p>$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$</p>
<p>This implies that minimizing the DAE reconstruction error also minimizes a score matching objective on the smoothed density $q_{\sigma}$.</p>
<h3 id="models">Models</h3>
<p><strong>1. The Denoising Autoencoder (DAE)</strong></p>
<ul>
<li><strong>Corruption:</strong> Additive isotropic Gaussian noise $\tilde{x} = x + \epsilon, \epsilon \sim \mathcal{N}(0, \sigma^2 I)$.</li>
<li><strong>Encoder:</strong> $h = \text{sigmoid}(W\tilde{x} + b)$.</li>
<li><strong>Decoder:</strong> $x^r = W^T h + c$ (Tied weights $W$).</li>
<li><strong>Loss:</strong> Squared reconstruction error $\|x^r - x\|^2$. (The equivalence with DSM introduces a $\frac{1}{2\sigma^4}$ scaling factor.)</li>
</ul>
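<p>A minimal NumPy sketch of one DAE forward pass, assuming small random tied weights and a single sample, shows how the reconstruction loss relates to the per-sample DSM loss. All numeric values here are illustrative.</p>

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
d, d_h, sigma = 5, 3, 0.3
W = rng.normal(scale=0.5, size=(d_h, d))
b = np.zeros(d_h)
c = np.zeros(d)

x = rng.normal(size=d)                      # clean input
x_tilde = x + sigma * rng.normal(size=d)    # corruption: additive Gaussian noise
h = sigmoid(W @ x_tilde + b)                # encoder
x_r = W.T @ h + c                           # decoder with tied weights

dae_loss = np.sum((x_r - x) ** 2)           # squared reconstruction error

# Per-sample DSM loss with model score psi = (x_r - x_tilde) / sigma^2
psi = (x_r - x_tilde) / sigma**2
target = (x - x_tilde) / sigma**2
dsm_loss = 0.5 * np.sum((psi - target) ** 2)
# dsm_loss equals dae_loss / (2 * sigma**4) up to floating-point rounding.
```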
<p><strong>2. The Corresponding Energy Function</strong></p>
<p>To make the DAE equivalent to Score Matching, the underlying Energy-Based Model $p(x;\theta) \propto e^{-E(x;\theta)}$ must have the following energy function:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}\|x\|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p>Note the scaling by $1/\sigma^2$ and the quadratic term $\|x\|^2$.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric:</strong> Theoretical Equivalence ($\sim$).</li>
<li><strong>Condition:</strong> The equivalence holds provided $\sigma &gt; 0$ and the density $q_{\sigma}$ is differentiable and vanishes at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Vincent, P. (2011). A Connection Between Score Matching and Denoising Autoencoders. <em>Neural Computation</em>, 23(7), 1661-1674. <a href="https://doi.org/10.1162/NECO_a_00142">https://doi.org/10.1162/NECO_a_00142</a></p>
<p><strong>Publication</strong>: Neural Computation 2011</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vincentConnectionScoreMatching2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Connection Between Score Matching}} and {{Denoising Autoencoders}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Vincent, Pascal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2011</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Neural Computation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1661--1674}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1162/NECO_a_00142}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Official PDF</a></li>
</ul>
]]></content:encoded></item><item><title>Neural ODEs: Continuous-Depth Deep Learning Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/neural-odes/</guid><description>Introduces ODE-Nets, a continuous-depth neural network model parameterized by ODEs, enabling constant memory backpropagation and adaptive computation.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Key Prerequisites</strong>: Before diving in, note that for the initial value problem to have a unique solution, the neural network $f(h(t), t, \theta)$ parameterizing the dynamics must be <a href="https://en.wikipedia.org/wiki/Lipschitz_continuity">Lipschitz continuous</a> in $h$ and continuous in $t$. Under these conditions the <a href="https://en.wikipedia.org/wiki/Picard%E2%80%93Lindel%C3%B6f_theorem">Picard-Lindelöf theorem</a> applies, which prevents trajectories from crossing and makes the backward pass well defined.</p></blockquote>
<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel family of deep neural network models where the derivative of the hidden state is parameterized by a neural network. It provides specific algorithms (Algorithm 1) for training these models scalably.</li>
<li><strong>Theory</strong>: It derives the adjoint sensitivity method for backpropagating through black-box ODE solvers and proves the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1) for continuous normalizing flows.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors aim to address limitations in discrete deep learning architectures:</p>
<ul>
<li><strong>Discrete vs. Continuous</strong>: Existing models like Residual Networks build transformations by composing discrete steps, which can be seen as an Euler discretization of a continuous transformation. The authors investigate the limit as step sizes go to zero.</li>
<li><strong>Memory Efficiency</strong>: Backpropagating through deep discrete networks requires storing intermediate activations, leading to linear memory cost in terms of depth, which is a major bottleneck.</li>
<li><strong>Irregular Data</strong>: Recurrent Neural Networks (RNNs) struggle with data arriving at arbitrary times, typically requiring discretization into fixed bins.</li>
<li><strong>Normalizing Flow Costs</strong>: Standard normalizing flows have a bottleneck in computing the determinant of the Jacobian, which is computationally expensive ($O(D^3)$).</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core contribution is the <strong>Neural ODE</strong> formulation:
$$\frac{dh(t)}{dt} = f(h(t), t, \theta)$$
where the output is computed using a black-box differential equation solver.</p>
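<p>The link between residual layers and this ODE can be sketched in a few lines. Below, a stack of residual updates $h \leftarrow h + \Delta t \, f(h, t)$ is treated as explicit Euler integration of a fixed, hypothetical dynamics function (standing in for a learned network), and the error against the exact solution shrinks as the number of layers grows.</p>

```python
import numpy as np

def f(h, t):
    # Hypothetical fixed dynamics standing in for a learned network: dh/dt = -h
    return -h

def residual_stack(h0, num_layers, t0=0.0, t1=1.0):
    # Each "layer" applies h = h + dt * f(h, t), i.e. one explicit Euler step.
    h, dt = h0, (t1 - t0) / num_layers
    for k in range(num_layers):
        h = h + dt * f(h, t0 + k * dt)
    return h

h0 = np.array([1.0, 2.0])
exact = h0 * np.exp(-1.0)   # closed-form solution of dh/dt = -h at t = 1
errors = [np.max(np.abs(residual_stack(h0, n) - exact)) for n in (4, 16, 64)]
# errors decreases as the number of layers (Euler steps) increases
```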
<p>Key technical innovations include:</p>
<ol>
<li><strong>Adjoint Sensitivity Method for Backprop</strong>: The authors treat the solver as a black box and compute gradients by solving a second, augmented ODE backwards in time. This allows for <strong>constant memory cost</strong> regardless of depth.</li>
<li><strong>Adaptive Computation</strong>: The model uses modern ODE solvers that adapt evaluation steps based on error tolerance, allowing the model to trade precision for speed explicitly.</li>
<li><strong>Continuous Normalizing Flows (CNF)</strong>: By moving to continuous time, the change of variables formula simplifies from a log-determinant (cubic cost) to a trace operation (linear cost), enabling scalable generative modeling.</li>
<li><strong>Latent ODEs</strong>: A generative time-series model that represents time-series as latent trajectories determined by a local initial state and global shared dynamics, handling irregular sampling naturally.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across three distinct domains:</p>
<ol>
<li><strong>Supervised Learning (MNIST)</strong>:
<ul>
<li>Compared <strong>ODE-Net</strong> against a standard <strong>ResNet</strong> and a Runge-Kutta network (<strong>RK-Net</strong>).</li>
<li>Measured test error, parameter count, and memory usage.</li>
<li>Analyzed the trade-off between numerical precision (tolerance) and speed (NFE).</li>
</ul>
</li>
<li><strong>Continuous Normalizing Flows (Generative)</strong>:
<ul>
<li>Compared CNF against standard Normalizing Flows (NF) on density matching and maximum likelihood estimation tasks using toy 2D datasets (Two Circles, Two Moons, and other target distributions).</li>
<li>Evaluated training loss (KL divergence) and maximum likelihood estimation.</li>
</ul>
</li>
<li><strong>Time-Series Modeling (Latent ODE)</strong>:
<ul>
<li>Tested on a dataset of bi-directional spirals with irregular timestamps and Gaussian noise.</li>
<li>Compared Latent ODEs against an RNN baseline on predictive RMSE. A second RNN variant with time-difference concatenation was also trained.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Efficiency</strong>: ODE-Nets achieved roughly equivalent accuracy to ResNets on MNIST (0.42% vs 0.41% error) but with <strong>constant memory cost</strong> ($O(1)$) compared to ResNet&rsquo;s linear cost ($O(L)$).</li>
<li><strong>Adaptive Depth</strong>: The number of function evaluations (NFE) in ODE-Nets increases with training epoch, suggesting the model adapts its complexity as it learns. The backward pass NFE is roughly half the forward pass NFE, indicating that the adjoint method is also more computationally efficient than direct backpropagation through the integrator.</li>
<li><strong>Generative Performance</strong>: Continuous Normalizing Flows (CNF) achieved lower KL divergence loss than standard Normalizing Flows (NF), trained with only 10,000 iterations (Adam) compared to 500,000 iterations (RMSprop) for NF. Note that the two models used different optimizers, so the comparison is not fully controlled. CNF can also expand capacity by increasing width ($M$) without architectural constraints.</li>
<li><strong>Irregular Time-Series</strong>: Latent ODEs significantly outperformed RNNs across all observation counts on irregular spiral data. The advantage is most pronounced with sparse observations (0.1642 vs 0.3937 RMSE at 30 obs), and the model learns interpretable latent trajectories that switch direction smoothly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: Standard handwritten digit dataset used for supervised learning benchmarks.</li>
<li><strong>Toy 2D Densities</strong>: &ldquo;Two Circles&rdquo; and &ldquo;Two Moons&rdquo; distributions used for visualizing normalizing flows.</li>
<li><strong>Bi-directional Spirals</strong>: A generated dataset of 1,000 2D spirals (half clockwise, half counter-clockwise). Each spiral is sampled at 100 equally-spaced timesteps with added Gaussian noise. For training, each spiral is then subsampled without replacement to $n \in \{30, 50, 100\}$ irregularly-spaced observations, simulating realistic missing data.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Adjoint Sensitivity Method (Backpropagation)</strong></p>
<p>To optimize the parameters of the ODE-Net, the authors use the adjoint sensitivity method to compute gradients. Standard backpropagation would require storing the activations at every step of the ODE solver, incurring a high memory cost that scales linearly with the number of steps.</p>
<p>Instead, this method treats the ODE solver as a &ldquo;black box&rdquo; and computes gradients by solving a second, <strong>augmented ODE</strong> backwards in time from the final time $t_1$ to the initial time $t_0$.</p>
<p>The augmented state contains three components that are solved simultaneously:</p>
<ol>
<li><strong>The State</strong>: The original hidden state $z(t)$, which is reconstructed backwards.</li>
<li><strong>The Adjoint</strong>: The sensitivity of the loss with respect to the state, $a(t) = \partial L / \partial z(t)$.</li>
<li><strong>The Gradient</strong>: The accumulating gradients with respect to parameters, $\partial L / \partial \theta$.</li>
</ol>
<p>The dynamics of this augmented system are defined as:
$$\frac{d}{dt}\begin{bmatrix} z(t) \\ a(t) \\ \partial L/\partial \theta \end{bmatrix} = \begin{bmatrix} f(z(t), t, \theta) \\ -a(t)^T \frac{\partial f}{\partial z} \\ -a(t)^T \frac{\partial f}{\partial \theta} \end{bmatrix}$$</p>
<p>Using this approach, the vector-Jacobian products (e.g., $a(t)^T \frac{\partial f}{\partial z}$) are evaluated efficiently using automatic differentiation.</p>
<blockquote>
<p><strong>Why:</strong> Reconstructing $z(t)$ backwards avoids storing the forward pass, enabling <strong>constant memory cost</strong> ($O(1)$) regardless of depth.</p>
<p><strong>Origin:</strong> Adapted from Pontryagin&rsquo;s maximum principle (1962) for optimal control.</p></blockquote>
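<p>To make the augmented system concrete, here is a hand-rolled sketch for a scalar ODE $dz/dt = \theta z$ with loss $L = z(t_1)$, where the exact gradient is known in closed form. The fixed-step RK4 integrator and all constants are assumptions for illustration; the paper relies on adaptive black-box solvers.</p>

```python
import numpy as np

theta, z0, t0, t1 = 0.8, 1.5, 0.0, 1.0

def augmented_dynamics(state):
    # state = [z, a, g]; for f(z) = theta * z: df/dz = theta, df/dtheta = z
    z, a, g = state
    return np.array([theta * z, -a * theta, -a * z])

def rk4_step(func, state, dt):
    # One classical fourth-order Runge-Kutta step (dt may be negative)
    k1 = func(state)
    k2 = func(state + 0.5 * dt * k1)
    k3 = func(state + 0.5 * dt * k2)
    k4 = func(state + dt * k3)
    return state + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0

# Forward pass: only the final state is kept (closed form here for brevity).
z1 = z0 * np.exp(theta * (t1 - t0))

# Backward pass for L = z(t1): a(t1) = dL/dz(t1) = 1, gradient accumulator starts at 0.
num_steps = 200
dt = (t0 - t1) / num_steps          # negative step: integrate from t1 down to t0
state = np.array([z1, 1.0, 0.0])
for _ in range(num_steps):
    state = rk4_step(augmented_dynamics, state, dt)

z0_reconstructed, a_t0, grad_theta = state
# Analytic reference: d/dtheta of z0 * exp(theta * (t1 - t0))
analytic_grad = z0 * (t1 - t0) * np.exp(theta * (t1 - t0))
```

<p>Only the final state is carried into the backward solve; the trajectory is reconstructed on the way back. The PyTorch snippet that follows delegates the same backward solve to <code>odeint_adjoint</code>.</p>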
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> torchdiffeq <span style="color:#f92672">import</span> odeint_adjoint
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEFunc</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, dim):
</span></span><span style="display:flex;"><span>        super(ODEFunc, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>net <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(dim, <span style="color:#ae81ff">50</span>),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">50</span>, dim),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, t, y):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Defines dy/dt = f(y, t)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>net(y)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Usage with adjoint method for O(1) memory backprop</span>
</span></span><span style="display:flex;"><span>func <span style="color:#f92672">=</span> ODEFunc(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">2</span>)
</span></span><span style="display:flex;"><span>y0 <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([[<span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">0.</span>]]) <span style="color:#75715e"># Initial state</span>
</span></span><span style="display:flex;"><span>t <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>linspace(<span style="color:#ae81ff">0.</span>, <span style="color:#ae81ff">1.</span>, <span style="color:#ae81ff">10</span>) <span style="color:#75715e"># Time points to solve for</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># &#39;odeint_adjoint&#39; automatically handles the augmented state backward pass</span>
</span></span><span style="display:flex;"><span>out <span style="color:#f92672">=</span> odeint_adjoint(func, y0, t, method<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;dopri5&#39;</span>)
</span></span></code></pre></div><p><strong>2. Instantaneous Change of Variables (CNF)</strong></p>
<p>For generative modeling, the authors introduce <strong>Continuous Normalizing Flows (CNF)</strong>. In discrete normalizing flows, the probability density of a transformed variable is calculated using the change of variables theorem, which requires computing the log-determinant of the Jacobian: $\log p(z_1) = \log p(z_0) - \log |\det \frac{\partial z_1}{\partial z_0}|$. This operation is computationally expensive ($O(D^3)$) and often restricts model architectures to ensure the Jacobian is easy to compute (e.g., triangular).</p>
<p>Moving to continuous time simplifies this requirement. The paper proves that if the transformation is defined by an ODE, the change in log-probability follows a differential equation determined by the <strong>trace</strong> of the Jacobian:
$$\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left( \frac{\partial f}{\partial z(t)} \right)$$</p>
<p>The total change in log-density is obtained by integrating this value over time.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_trace</span>(y, f):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    Computes trace of Jacobian df/dy.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    For high dimensions, use Hutchinson&#39;s trace estimator (approximate).
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">    &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    tr <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.</span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">for</span> i <span style="color:#f92672">in</span> range(y<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">1</span>)):
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Gradients of f&#39;s i-th component w.r.t y&#39;s i-th component</span>
</span></span><span style="display:flex;"><span>        tr <span style="color:#f92672">+=</span> torch<span style="color:#f92672">.</span>autograd<span style="color:#f92672">.</span>grad(f[:, i]<span style="color:#f92672">.</span>sum(), y, create_graph<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>)[<span style="color:#ae81ff">0</span>][:, i]
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> tr
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In the ODE function:</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># d(log_p)/dt = -trace(df/dy)</span>
</span></span></code></pre></div><blockquote>
<p><strong>Why:</strong> Given the Jacobian, the trace costs only $O(D)$ to evaluate, whereas the determinant costs $O(D^3)$. Because the transformation is defined by an ODE, it is automatically bijective whenever the ODE has a unique solution, so $f$ can use unrestricted, &ldquo;wide&rdquo; architectures with no special Jacobian structure.</p>
<p><strong>Origin:</strong> This is the &ldquo;Instantaneous Change of Variables&rdquo; theorem (Theorem 1), derived in Appendix A of the paper.</p></blockquote>
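<p>As a sanity check on the theorem, the following minimal sketch integrates a one-dimensional flow together with its log-density change and compares the result with the discrete change of variables formula. The dynamics <code>f</code>, the hand-written RK4 integrator, and the finite-difference Jacobian are illustrative assumptions; none of this is code from the paper or from <code>torchdiffeq</code>.</p>

```python
import math


def f(z):
    # Hypothetical 1-D dynamics (assumption); tanh is smooth and Lipschitz.
    return math.tanh(z)


def df_dz(z):
    # In one dimension, the trace of the Jacobian is just f'(z).
    return 1.0 - math.tanh(z) ** 2


def rhs(z):
    # Augmented dynamics: dz/dt = f(z), d(log p)/dt = -tr(df/dz).
    return f(z), -df_dz(z)


def solve(z0, t1=1.0, steps=200):
    """Integrate (z, delta_logp) from t=0 to t=t1 with classic RK4."""
    h = t1 / steps
    z, delta_logp = z0, 0.0
    for _ in range(steps):
        k1z, k1l = rhs(z)
        k2z, k2l = rhs(z + 0.5 * h * k1z)
        k3z, k3l = rhs(z + 0.5 * h * k2z)
        k4z, k4l = rhs(z + h * k3z)
        z += h * (k1z + 2 * k2z + 2 * k3z + k4z) / 6
        delta_logp += h * (k1l + 2 * k2l + 2 * k3l + k4l) / 6
    return z, delta_logp


z0, eps = 0.3, 1e-5
z1, delta_logp = solve(z0)
# Discrete change of variables: log p(z1) - log p(z0) = -log |dz1/dz0|,
# with the Jacobian of the flow map estimated by central differences.
jac = (solve(z0 + eps)[0] - solve(z0 - eps)[0]) / (2 * eps)
discrete = -math.log(abs(jac))
print(delta_logp, discrete)
```

<p>The two printed values should agree up to integration and finite-difference error, which is the one-dimensional content of the theorem: integrating $-\text{tr}(\partial f / \partial z)$ along the trajectory reproduces the log-determinant of the end-to-end map.</p>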
<h3 id="models">Models</h3>
<p><strong>ODE-Net (MNIST Classification)</strong>:</p>
<ul>
<li><strong>Input</strong>: Downsamples input twice.</li>
<li><strong>Core</strong>: 6 standard residual blocks replaced by a single <strong>ODESolve</strong> module.</li>
<li><strong>Output</strong>: Global average pooling + Fully connected layer.</li>
<li><strong>Solver</strong>: Implicit Adams method.</li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">ODEBlock</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, odefunc):
</span></span><span style="display:flex;"><span>        super(ODEBlock, self)<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>odefunc <span style="color:#f92672">=</span> odefunc
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>tensor([<span style="color:#ae81ff">0</span>, <span style="color:#ae81ff">1</span>])<span style="color:#f92672">.</span>float()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x):
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>integration_time <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>integration_time<span style="color:#f92672">.</span>type_as(x)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Returns [x(t0), x(t1)]; we only want final state x(t1)</span>
</span></span><span style="display:flex;"><span>        out <span style="color:#f92672">=</span> odeint_adjoint(self<span style="color:#f92672">.</span>odefunc, x, self<span style="color:#f92672">.</span>integration_time)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> out[<span style="color:#ae81ff">1</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># ResNet-like architecture with ODE block</span>
</span></span><span style="display:flex;"><span>model <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Conv2d(<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">3</span>, <span style="color:#ae81ff">1</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>ReLU(inplace<span style="color:#f92672">=</span><span style="color:#66d9ef">True</span>),
</span></span><span style="display:flex;"><span>    ODEBlock(ODEFunc(<span style="color:#ae81ff">64</span>)), <span style="color:#75715e"># Continuous-depth layer replacement</span>
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>BatchNorm2d(<span style="color:#ae81ff">64</span>),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>AdaptiveAvgPool2d((<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">1</span>)),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>    nn<span style="color:#f92672">.</span>Linear(<span style="color:#ae81ff">64</span>, <span style="color:#ae81ff">10</span>)
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Latent ODE (Time-Series)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: RNN with 25 hidden units processing data backwards to produce $q(z_0|x)$. It runs backwards so the final RNN state summarizes the entire sequence at $t_0$, parameterizing the initial latent state $z_0$ for the forward-running ODE.</li>
<li><strong>Latent Space</strong>: 4-dimensional latent state $z_0$.</li>
<li><strong>Dynamics ($f$)</strong>: Neural network with one hidden layer of 20 units.</li>
<li><strong>Decoder</strong>: Neural network with one hidden layer of 20 units computing $p(x_{t_i}|z_{t_i})$.</li>
<li><strong>Likelihood</strong>: Gaussian log-likelihood for the spiral reconstruction task. The paper also describes an optional Poisson process likelihood $\lambda(z(t))$ for event-time data (e.g., medical records), but this is not used in the spiral experiment.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Experiment</th>
          <th>Metric</th>
          <th>Baseline (ResNet/RNN)</th>
          <th>ODE Model</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MNIST</td>
          <td>Test Error</td>
          <td>0.41%</td>
          <td>0.42%</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Parameters</td>
          <td>0.60 M</td>
          <td>0.22 M</td>
      </tr>
      <tr>
          <td>MNIST</td>
          <td>Memory</td>
          <td>$O(L)$</td>
          <td>$O(1)$</td>
      </tr>
      <tr>
          <td>Spirals (30 obs)</td>
          <td>RMSE</td>
          <td>0.3937</td>
          <td><strong>0.1642</strong></td>
      </tr>
      <tr>
          <td>Spirals (50 obs)</td>
          <td>RMSE</td>
          <td>0.3202</td>
          <td><strong>0.1502</strong></td>
      </tr>
      <tr>
          <td>Spirals (100 obs)</td>
          <td>RMSE</td>
          <td>0.1813</td>
          <td><strong>0.1346</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Implementation</strong>: Hidden state dynamics evaluated on GPU using <strong>TensorFlow</strong>.</li>
<li><strong>Solvers</strong>: Fortran ODE solvers (LSODE, VODE) from <code>scipy.integrate</code> were used for the actual integration.</li>
<li><strong>Note</strong>: While the original paper used TensorFlow/Scipy, the authors later released <code>torchdiffeq</code> (PyTorch), which has become the standard implementation for this architecture. The code samples above reflect this modern standard.</li>
<li><strong>Interface</strong>: The Python <code>autograd</code> library bridged the TensorFlow dynamics and SciPy solvers.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper identifies several practical limitations of Neural ODEs:</p>
<ul>
<li><strong>Minibatching</strong>: Batching requires concatenating states of each batch element into a combined ODE of dimension $D \times K$. Controlling error on all batch elements together can require more evaluations than solving each system individually, though in practice this overhead was not substantial.</li>
<li><strong>Tolerance tuning</strong>: Users must choose error tolerances for both the forward and reverse passes. The paper used 1.5e-8 for sequence modeling, 1e-3 for classification, and 1e-5 for density estimation.</li>
<li><strong>Backward trajectory reconstruction</strong>: Running the dynamics backwards to reconstruct the forward state trajectory can introduce extra numerical error if the reconstructed trajectory diverges from the original. Checkpointing (storing intermediate states) can address this, though the authors did not find it necessary in practice.</li>
<li><strong>Uniqueness requirements</strong>: The neural network $f$ must be Lipschitz continuous (e.g., using tanh or ReLU activations with finite weights) to guarantee a unique solution via Picard&rsquo;s existence theorem.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/rtqichen/torchdiffeq">torchdiffeq</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with GPU-based ODE solvers</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, R. T. Q., Rubanova, Y., Bettencourt, J., &amp; Duvenaud, D. (2018). Neural ordinary differential equations. <em>Proceedings of the 32nd International Conference on Neural Information Processing Systems</em>, 6572-6583.</p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{chen2018neural,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Neural ordinary differential equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 32nd International Conference on Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{6572--6583}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/rtqichen/torchdiffeq">Official PyTorch Implementation</a></li>
</ul>
]]></content:encoded></item><item><title>Flow Matching for Generative Modeling: Scalable CNFs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</guid><description>A simulation-free framework for training Continuous Normalizing Flows using Conditional Flow Matching and Optimal Transport paths.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, as it introduces &ldquo;Flow Matching&rdquo; (FM), a novel simulation-free paradigm for training Continuous Normalizing Flows (CNFs) at scale. It is supported by a strong <strong>Theory</strong> basis, providing formal theorems that allow the intractable marginal vector field regression to be solved via a tractable conditional objective. It also touches on <strong>Systematization</strong> by showing that existing diffusion paths are specific instances of the proposed Gaussian probability path framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper aims to overcome the scaling limitations of Continuous Normalizing Flows (CNFs).</p>
<ul>
<li><strong>Problem</strong>: Standard Maximum Likelihood training for CNFs requires expensive numerical ODE simulations during training, which scales poorly. Existing simulation-free methods often involve intractable integrals or result in biased gradients.</li>
<li><strong>Gap</strong>: Diffusion models scale well, yet they are restricted to specific, curved probability paths (e.g., VP, VE) that can result in slow sampling and long training times.</li>
<li><strong>Goal</strong>: To develop an efficient, simulation-free training method for CNFs that supports arbitrary probability paths, specifically allowing for straighter, more efficient trajectories like those from Optimal Transport.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is <strong>Flow Matching (FM)</strong> and specifically the <strong>Conditional Flow Matching (CFM)</strong> objective.</p>
<ul>
<li><strong>Direct Vector Field Regression</strong>: The model regresses a target vector field $u_t$ that generates a desired probability path $p_t$.</li>
<li><strong>Conditional Flow Matching (CFM)</strong>: The authors prove that regressing the vector field of <em>conditional</em> paths (e.g., $p_t(x|x_1)$ given a single data point) yields the same gradients as regressing the intractable marginal vector field. This bypasses the need to know the marginal score or vector field.</li>
<li><strong>Optimal Transport Paths</strong>: The framework enables the use of <strong>Optimal Transport (OT)</strong> displacement interpolation for probability paths. OT paths are straight lines with constant speed, leading to faster training and easier sampling.</li>
</ul>
<p><strong>Concurrent work note</strong>: Rectified Flow (Liu et al., 2023) and Stochastic Interpolants (Albergo &amp; Vanden-Eijnden, 2023) were published concurrently at ICLR 2023 with structurally similar contributions under different names. All three independently propose simulation-free training of continuous flows via direct vector field regression; the differences lie in the specific interpolation schemes, theoretical framing, and experimental focus.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<ul>
<li><strong>Domains</strong>: 2D Checkerboard data, CIFAR-10, and ImageNet at resolutions $32 \times 32$, $64 \times 64$, and $128 \times 128$.</li>
<li><strong>Task</strong>: Unconditional generative modeling (density estimation and sample quality) and conditional super-resolution ($64 \times 64 \to 256 \times 256$).</li>
<li><strong>Baselines</strong>: Compared against Diffusion-based methods on the same architecture (U-Net): DDPM, Score Matching (SM), and ScoreFlow.</li>
<li><strong>Ablations</strong>: Specifically compared <strong>FM with Diffusion paths</strong> vs. <strong>FM with Optimal Transport (OT) paths</strong> to isolate the benefit of the training objective vs. the path choice.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Outperforms diffusion baselines</strong>: FM-OT consistently outperforms all diffusion-based methods (DDPM, Score Matching, ScoreFlow) in both Likelihood (NLL) and Sample Quality (FID) across CIFAR-10 and ImageNet, using the same U-Net architecture and training budget. Selected rows from Table 1 (NLL in bits per dimension, BPD; lower is better for all three metrics; &ldquo;FM w/ OT&rdquo; and &ldquo;FM w/ Diffusion&rdquo; refer to FM trained with OT paths and Diffusion paths respectively):</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>NLL (BPD) ↓</th>
          <th>FID ↓</th>
          <th>NFE ↓</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>DDPM</td>
          <td>3.12</td>
          <td>7.48</td>
          <td>274</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>FM w/ OT</td>
          <td><strong>2.99</strong></td>
          <td><strong>6.35</strong></td>
          <td><strong>142</strong></td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>ScoreFlow</td>
          <td>3.36</td>
          <td>24.95</td>
          <td>601</td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>FM w/ OT</td>
          <td><strong>3.31</strong></td>
          <td><strong>14.45</strong></td>
          <td><strong>138</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Training stability</strong>: FM with diffusion paths (FM w/ Diffusion) is itself a more stable alternative to diffusion training than DDPM and Score Matching, as shown by training curves in the paper (Figure 5), even before switching to OT paths. The OT path then provides further gains.</li>
<li><strong>Sampling speed</strong>: The straight trajectories of OT paths allow accurate sampling with significantly fewer function evaluations (NFE) compared to diffusion paths.</li>
<li><strong>Generality</strong>: Diffusion is a specific instance of Gaussian probability paths within FM. OT paths are a better-optimized alternative available within the same framework.</li>
<li><strong>Downstream adoption</strong>: Flow matching has been adopted beyond image generation. <a href="/notes/biology/computational-biology/dynamicflow/">DynamicFlow</a> uses it as the generative backbone for simultaneously generating ligand molecules and transforming protein pockets, extending flow matching to structure-based drug design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Datasets</strong>: CIFAR-10, ImageNet ($32 \times 32$, $64 \times 64$, $128 \times 128$).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Images are center-cropped and resized.</li>
<li>For $32 \times 32$ and $64 \times 64$, the preprocessing follows Chrabaszcz et al. (2017).</li>
<li>Data is transformed via $\varphi(y) = 2^7(y+1)$ mapping $[-1, 1]$ pixel values to $[0, 256]$ for BPD computation.</li>
</ul>
</li>
</ul>
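<p>The effect of the rescaling $\varphi$ on the reported likelihood can be sketched as follows. This is a minimal illustration that assumes the model density is defined on images scaled to $[-1, 1]$ and that log-densities are in nats; the function name <code>bits_per_dim</code> and its arguments are hypothetical, not taken from the paper.</p>

```python
import math


def bits_per_dim(log_p_y, num_dims, bits=7):
    """Convert a log-density (in nats) on [-1, 1]-scaled data to BPD on [0, 256].

    The map phi(y) = 2**bits * (y + 1) stretches every dimension by 2**bits,
    so the density on the stretched scale is p(y) / (2**bits) ** num_dims.
    """
    log_p_x = log_p_y - num_dims * bits * math.log(2.0)
    return -log_p_x / (num_dims * math.log(2.0))


# Hypothetical example: a 32x32 RGB image has 3 * 32 * 32 dimensions.
d = 3 * 32 * 32
example_bpd = bits_per_dim(log_p_y=0.0, num_dims=d)
print(example_bpd)
```

<p>The stretch alone contributes 7 bits per dimension, so a log-density of zero on the $[-1, 1]$ scale corresponds to 7 BPD on the $[0, 256]$ scale.</p>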
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Conditional Flow Matching (CFM) Objective</strong></p>
<p>The practical training objective used is the CFM loss, which bypasses intractable marginalization:</p>
<p>$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} \left\lVert v_t(\psi_t(x_0)) - u_t(\psi_t(x_0) | x_1) \right\rVert^2$$</p>
<p>Where $t \sim \mathcal{U}[0,1]$, $x_1 \sim q(x_1)$ (data), and $x_0 \sim p(x_0)$ (noise).</p>
<p><strong>2. Optimal Transport (OT) Probability Path</strong></p>
<p>The authors recommend the OT path for efficiency.</p>
<ul>
<li><strong>Mean/Std Schedule</strong>: $\mu_t(x_1) = t x_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma_{min})t$.</li>
<li><strong>Conditional Flow Map</strong>: $\psi_t(x) = (1 - (1 - \sigma_{min})t)x + t x_1$.</li>
<li><strong>Target Vector Field</strong>: The closed-form regression target for OT is:
$$u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</li>
</ul>
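<p>The OT flow map and its regression target can be sketched in a few lines. This is a minimal, dependency-free illustration on plain Python lists, not the authors&rsquo; training code; the helper names <code>psi</code>, <code>u_target</code>, and <code>cfm_loss_sample</code>, the fixed samples, and the value of <code>SIGMA_MIN</code> are assumptions made for the example.</p>

```python
SIGMA_MIN = 1e-5  # assumed small value, for illustration only


def psi(t, x0, x1, sigma_min=SIGMA_MIN):
    """OT conditional flow map: (1 - (1 - sigma_min) * t) * x0 + t * x1."""
    scale = 1.0 - (1.0 - sigma_min) * t
    return [scale * a + t * b for a, b in zip(x0, x1)]


def u_target(t, x, x1, sigma_min=SIGMA_MIN):
    """Closed-form OT regression target u_t(x | x1)."""
    denom = 1.0 - (1.0 - sigma_min) * t
    return [(b - (1.0 - sigma_min) * a) / denom for a, b in zip(x, x1)]


def cfm_loss_sample(v, t, x0, x1):
    """Squared error between a model v(t, x) and the target for one (t, x0, x1)."""
    x_t = psi(t, x0, x1)
    target = u_target(t, x_t, x1)
    prediction = v(t, x_t)
    return sum((p - q) ** 2 for p, q in zip(prediction, target))


x0 = [0.8, -1.2, 0.1]   # noise sample (fixed here instead of drawn from N(0, I))
x1 = [0.5, -1.0, 2.0]   # data sample
t = 0.3

# Along the path, the target is the constant velocity x1 - (1 - sigma_min) * x0.
target = u_target(t, psi(t, x0, x1), x1)
velocity = [b - (1.0 - SIGMA_MIN) * a for a, b in zip(x0, x1)]

# A model that returns the target exactly incurs zero loss on this sample.
perfect_loss = cfm_loss_sample(lambda s, x: u_target(s, x, x1), t, x0, x1)
zero_model_loss = cfm_loss_sample(lambda s, x: [0.0] * len(x), t, x0, x1)
print(target, velocity, perfect_loss, zero_model_loss)
```

<p>Evaluating the target at $\psi_t(x_0)$ gives $x_1 - (1 - \sigma_{min}) x_0$ for every $t$, which is the constant-velocity, straight-line property of the OT path. In practice the loss would be averaged over random draws of $t$, $x_0$, and $x_1$.</p>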
<p><strong>3. Sampling</strong></p>
<p>Sampling is performed by solving the ODE $\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))$ from $t=0$ to $t=1$ using the learned vector field $v_t$.</p>
<ul>
<li><strong>Solver</strong>: <code>dopri5</code> (adaptive) is used for robust evaluation. Fixed-step solvers (Euler, Midpoint) are used for low-NFE efficiency tests.</li>
</ul>
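<p>To illustrate why straight paths suit fixed-step solvers, the sketch below runs Euler integration on the <em>conditional</em> OT field for a single data point, used here as a stand-in for a learned $v_t$ (an assumption; a trained marginal field is generally not exactly straight). The names <code>ot_field</code> and <code>euler_sample</code> are hypothetical.</p>

```python
SIGMA_MIN = 1e-5  # assumed value, chosen for illustration


def ot_field(t, x, x1, sigma_min=SIGMA_MIN):
    # Conditional OT vector field u_t(x | x1), used as a stand-in for v_t.
    denom = 1.0 - (1.0 - sigma_min) * t
    return [(b - (1.0 - sigma_min) * a) / denom for a, b in zip(x, x1)]


def euler_sample(x0, x1, num_steps):
    """Integrate dx/dt = u_t(x | x1) from t=0 to t=1 with fixed Euler steps."""
    h = 1.0 / num_steps
    x = list(x0)
    for k in range(num_steps):
        t = k * h
        velocity = ot_field(t, x, x1)
        x = [a + h * v for a, v in zip(x, velocity)]
    return x


x0 = [0.8, -1.2]   # noise sample
x1 = [2.0, 0.5]    # data point
one_step = euler_sample(x0, x1, num_steps=1)
eight_steps = euler_sample(x0, x1, num_steps=8)
exact = [SIGMA_MIN * a + b for a, b in zip(x0, x1)]  # psi_1(x0)
print(one_step, eight_steps, exact)
```

<p>Because the velocity is constant along each conditional trajectory, one Euler step and eight Euler steps land on the same endpoint $\psi_1(x_0) = \sigma_{min} x_0 + x_1$ up to rounding.</p>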
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: U-Net architecture from Dhariwal &amp; Nichol (2021) is used for all image experiments.</li>
<li><strong>Toy Data</strong>: 5-layer MLP with 512 neurons.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$, weight decay=0.0).</li>
<li>Learning Rate: Polynomial decay or constant (see Table 3 in paper).</li>
<li>$\sigma_{min}$: Set to a small value (e.g., $10^{-5}$).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>NLL (BPD)</strong>: Computed using the continuous change of variables formula, with the divergence estimated via the Hutchinson trace estimator to avoid the $d$ backward passes that an exact divergence computation would require.</li>
<li><strong>FID</strong>: Fréchet Inception Distance for sample quality.</li>
<li><strong>NFE</strong>: Number of Function Evaluations required by the solver.</li>
</ul>
</li>
<li><strong>Likelihood Computation</strong>: Requires solving an augmented ODE to track the log-density change:
$$\frac{d}{dt} \begin{bmatrix} \phi_t(x) \\ f(t) \end{bmatrix} = \begin{bmatrix} v_t(\phi_t(x)) \\ -\text{div}(v_t(\phi_t(x))) \end{bmatrix}$$</li>
</ul>
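<p>The Hutchinson estimator used for the divergence can be sketched without an autograd framework. The example below is a minimal illustration under stated assumptions: the toy field <code>v</code> stands in for $v_t$, Rademacher probes are used, and the product $J\epsilon$ is approximated by a central finite difference where a real implementation would use a vector-Jacobian product.</p>

```python
import math
import random

A = [[1.0, 2.0, 0.0],
     [0.0, -0.5, 1.0],
     [3.0, 0.0, 0.25]]


def v(x):
    # Hypothetical vector field v(x) = tanh(A x), standing in for v_t.
    return [math.tanh(sum(a * xj for a, xj in zip(row, x))) for row in A]


def exact_divergence(x):
    # div v = sum_i dv_i/dx_i = sum_i (1 - tanh^2((A x)_i)) * A_ii
    out = v(x)
    return sum((1.0 - out[i] ** 2) * A[i][i] for i in range(len(x)))


def hutchinson_divergence(x, num_probes, rng, h=1e-4):
    """Estimate div v(x) = tr(J) as the mean of eps^T J eps over random probes.

    J eps is approximated by a central finite difference along eps; an
    autograd framework would use a vector-Jacobian product instead.
    """
    total = 0.0
    for _ in range(num_probes):
        eps = [rng.choice((-1.0, 1.0)) for _ in x]  # Rademacher probe
        plus = v([xi + h * e for xi, e in zip(x, eps)])
        minus = v([xi - h * e for xi, e in zip(x, eps)])
        j_eps = [(p - m) / (2.0 * h) for p, m in zip(plus, minus)]
        total += sum(e * je for e, je in zip(eps, j_eps))
    return total / num_probes


x = [0.1, -0.2, 0.3]
estimate = hutchinson_divergence(x, num_probes=20000, rng=random.Random(0))
exact = exact_divergence(x)
print(estimate, exact)
```

<p>Each probe costs a fixed number of field evaluations regardless of dimension, whereas the exact divergence needs one derivative per coordinate. The estimate is unbiased but noisy, so likelihood evaluations typically average over probes or data points.</p>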
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CIFAR-10</strong>: 2 GPUs.</li>
<li><strong>ImageNet-32</strong>: 4 GPUs.</li>
<li><strong>ImageNet-64</strong>: 16 GPUs.</li>
<li><strong>ImageNet-128</strong>: 32 GPUs.</li>
<li><strong>Precision</strong>: Full 32-bit for CIFAR/IM-32; 16-bit mixed precision for IM-64/128.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/flow_matching">flow_matching (PyTorch library)</a></td>
          <td>Code</td>
          <td>CC BY-NC 4.0</td>
          <td>Later official library from Meta; not the original experiment code</td>
      </tr>
  </tbody>
</table>
<p>The paper does not release the original training code or model weights used in the experiments. The <code>facebookresearch/flow_matching</code> library was released later as a general-purpose PyTorch implementation of flow matching algorithms. Standard benchmark datasets (CIFAR-10, ImageNet) are publicly available.</p>
<hr>
<h2 id="theoretical-notes-why-cfm-works">Theoretical Notes: Why CFM Works</h2>
<p>The paper relies on three key theorems to make training tractable.</p>
<p><strong>Theorem 1 (Marginal Generation)</strong>:</p>
<p>Marginalizing conditional vector fields $u_t(x|x_1)$ yields the correct marginal vector field $u_t(x)$ that generates the marginal probability path $p_t(x)$.</p>
<p>$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} dx_1$$</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>To understand why this theorem holds, we have to look at the <strong>Continuity Equation</strong>, which is the fundamental partial differential equation (PDE) that links a probability density path $p_t$ to a vector field $u_t$.</p>
<p>A vector field $u_t$ is said to &ldquo;generate&rdquo; a probability path $p_t$ if and only if they satisfy the continuity equation:</p>
<p>$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$$</p>
<p>The proof of Theorem 1 relies on substituting the definitions of the marginal path and vector field into this equation to see if they balance out.</p>
<p><strong>Step-by-Step Proof:</strong></p>
<ol>
<li>
<p><strong>Start with the time derivative of the marginal path</strong>: We begin by differentiating the marginal probability path $p_t(x)$ with respect to time. By definition, the marginal path is the integral of the conditional paths over the data distribution:
$$\frac{\partial p_t(x)}{\partial t} = \frac{\partial}{\partial t} \int p_t(x|x_1) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Swap derivative and integral</strong>: Assuming standard regularity conditions (Leibniz Rule), we can move the time derivative inside the integral:
$$\frac{\partial p_t(x)}{\partial t} = \int \frac{\partial p_t(x|x_1)}{\partial t} q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Apply the Conditional Continuity Equation</strong>: This is the critical step. We know that the conditional vector field $u_t(x|x_1)$ generates the conditional path $p_t(x|x_1)$. Therefore, for every single sample $x_1$, the pair satisfies the continuity equation:
$$\frac{\partial p_t(x|x_1)}{\partial t} = -\nabla \cdot (p_t(x|x_1) u_t(x|x_1))$$</p>
<p>Substituting this into our integral gives:
$$\frac{\partial p_t(x)}{\partial t} = -\int \nabla \cdot (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Pull the Divergence out</strong>: Since the divergence operator ($\nabla \cdot$) acts on $x$ and the integral is over $x_1$, we can pull the divergence operator outside the integral (by linearity):
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \left( \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1 \right)$$</p>
</li>
<li>
<p><strong>Match with the Marginal Vector Field Definition</strong>: Now, look at the term inside the parentheses. The paper defines the marginal vector field $u_t(x)$ specifically to make this term simpler. Rearranging the definition of $u_t(x)$ provided in the theorem:
$$p_t(x) u_t(x) = \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1$$</p>
<p>Substitute $p_t(x) u_t(x)$ back into our equation from Step 4:
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot (p_t(x) u_t(x))$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: We have just shown that $\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$. This is exactly the continuity equation. Because the marginal path and the aggregated marginal vector field satisfy this equation, the vector field is proven to generate the path.</p></blockquote>
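<p>The argument can also be checked numerically on a toy problem. The sketch below assumes a one-dimensional data distribution with two equally likely points and OT conditional paths, builds the marginal density and the flux $p_t(x) u_t(x)$ from their definitions, and evaluates the continuity-equation residual by central differences. All names and constants are illustrative assumptions.</p>

```python
import math

SIGMA_MIN = 0.1          # large value chosen for a smooth toy example
DATA = [-1.0, 1.0]       # q(x1): two equally likely data points (assumption)


def sigma(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t


def cond_density(t, x, x1):
    # p_t(x | x1) = N(x | t * x1, sigma_t^2) for the OT path
    s = sigma(t)
    return math.exp(-0.5 * ((x - t * x1) / s) ** 2) / (s * math.sqrt(2 * math.pi))


def cond_field(t, x, x1):
    # u_t(x | x1) for the OT path
    return (x1 - (1.0 - SIGMA_MIN) * x) / sigma(t)


def marginal_density(t, x):
    return sum(cond_density(t, x, x1) for x1 in DATA) / len(DATA)


def marginal_flux(t, x):
    # p_t(x) * u_t(x) = integral of u_t(x | x1) p_t(x | x1) q(x1) dx1
    return sum(cond_field(t, x, x1) * cond_density(t, x, x1) for x1 in DATA) / len(DATA)


def continuity_residual(t, x, h=1e-4):
    # d p_t / dt + d (p_t u_t) / dx, both by central differences
    dp_dt = (marginal_density(t + h, x) - marginal_density(t - h, x)) / (2 * h)
    dflux_dx = (marginal_flux(t, x + h) - marginal_flux(t, x - h)) / (2 * h)
    return dp_dt + dflux_dx


residuals = [continuity_residual(0.5, x) for x in (-1.5, -0.3, 0.0, 0.7, 2.0)]
print(residuals)
```

<p>The residuals should vanish up to finite-difference error at every test point, matching the conclusion that the aggregated field generates the marginal path.</p>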
<p><strong>Theorem 2 (Gradient Equivalence)</strong>:</p>
<p>The intractable Flow Matching objective $\mathcal{L}_{FM}$ (which requires $u_t(x)$) has the <strong>same gradients</strong> as the tractable Conditional Flow Matching objective $\mathcal{L}_{CFM}$.</p>
<p>$$\nabla_\theta \mathcal{L}_{FM}(\theta) = \nabla_\theta \mathcal{L}_{CFM}(\theta)$$</p>
<p>This allows the model to learn the marginal vector field by only seeing conditional sample paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The reason Theorem 2 holds is that the &ldquo;Conditional Flow Matching&rdquo; (CFM) objective equals the &ldquo;Flow Matching&rdquo; (FM) objective up to an additive constant that does not depend on $\theta$. When we average over all the conditioning data points $x_1$, the &ldquo;cross-term&rdquo; in the CFM loss reduces exactly to the cross-term with the marginal vector field.</p>
<p><strong>1. Expand the Loss Functions</strong></p>
<p>First, let&rsquo;s look at the squared error in both objectives. Recall that $v_t$ is our neural network (parameterized by $\theta$), $u_t$ is the intractable marginal target, and $u_t(x|x_1)$ is the tractable conditional target.</p>
<p>Expanding the squared norms:</p>
<ul>
<li>
<p><strong>FM Objective</strong>:
$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, p_t(x)} \left[ \lVert v_t(x) \rVert^2 - 2 v_t(x) \cdot u_t(x) + \lVert u_t(x) \rVert^2 \right]$$</p>
</li>
<li>
<p><strong>CFM Objective</strong>:
$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ \lVert v_t(x) \rVert^2 - 2 v_t(x) \cdot u_t(x|x_1) + \lVert u_t(x|x_1) \rVert^2 \right]$$</p>
</li>
</ul>
<p><strong>Key Insight</strong>: When we take the gradient $\nabla_\theta$, the last term in both equations disappears because the targets ($u_t$) are independent of the network weights $\theta$. We only need to show that the expectations of the first two terms match.</p>
<p><strong>2. Matching the First Term ($\lVert v_t(x) \rVert^2$)</strong></p>
<p>This part is straightforward. The expectation of $\lVert v_t(x) \rVert^2$ is the same in both cases because of how the marginal density $p_t(x)$ is defined.</p>
<ul>
<li><strong>FM</strong>: averages over $p_t(x)$.</li>
<li><strong>CFM</strong>: averages over $p_t(x|x_1)q(x_1)$.</li>
</ul>
<p>Since $p_t(x) = \int p_t(x|x_1) q(x_1) dx_1$ (by definition), averaging over the joint distribution is mathematically identical to averaging over the marginal $p_t(x)$.</p>
<p><strong>3. Matching the Cross Term (The &ldquo;Trick&rdquo;)</strong></p>
<p>This is the critical part of the proof. We need to show that the interaction between the network and the marginal field equals the interaction between the network and the conditional field.</p>
<p><strong>The Goal</strong>: Show $\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$.</p>
<p><strong>The Proof</strong>:</p>
<ol>
<li>
<p>Start with the <strong>FM cross-term</strong> (marginal):
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)]$$</p>
</li>
<li>
<p>Substitute the definition of the marginal vector field $u_t(x)$ derived in <strong>Theorem 1</strong>:
$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1$$</p>
</li>
<li>
<p>Write the expectation as an integral over $t$ and $x$ and plug this in. The $p_t(x)$ factor outside the brackets cancels the $p_t(x)$ in the denominator:
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \int_t \int_x p_t(x) v_t(x) \cdot \left[ \int_{x_1} u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 \right] dx \, dt$$</p>
</li>
<li>
<p>This simplifies to:
$$= \int_t \int_x \int_{x_1} v_t(x) \cdot u_t(x|x_1) p_t(x|x_1) q(x_1) dx_1 dx dt$$</p>
</li>
<li>
<p>This is exactly the definition of the expectation in the <strong>CFM objective</strong>:
$$= \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: Because the expectations of all terms involving $\theta$ are identical, the gradients must be identical.</p>
<p>Intuitively, this works like <strong>Denoising Score Matching</strong> or <strong>Stochastic Gradient Descent</strong>: even though each individual conditional vector field $u_t(x|x_1)$ points to a specific data point $x_1$ (which may differ from the true marginal direction), the <em>average</em> of all these pulls equals the true marginal vector field $u_t(x)$.</p></blockquote>
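<p>A small numerical experiment makes the gradient equivalence concrete. The sketch below assumes a one-dimensional toy setup (two equally likely data points, OT conditional paths, a single fixed time, and a hypothetical one-parameter model $v_\theta(x) = \theta x$) and evaluates both objectives by quadrature on a grid. If the two losses differ only by a constant, the gap between them must be the same for every $\theta$.</p>

```python
import math

SIGMA_MIN = 0.1
DATA = [-1.0, 1.0]   # two equally likely data points (assumption)
T = 0.5              # a single fixed time, for simplicity
DX = 0.01
GRID = [-6.0 + DX * i for i in range(1201)]  # quadrature nodes in x


def sigma(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t


def cond_density(t, x, x1):
    # p_t(x | x1) = N(x | t * x1, sigma_t^2)
    s = sigma(t)
    return math.exp(-0.5 * ((x - t * x1) / s) ** 2) / (s * math.sqrt(2 * math.pi))


def cond_field(t, x, x1):
    # OT conditional target u_t(x | x1)
    return (x1 - (1.0 - SIGMA_MIN) * x) / sigma(t)


def v(theta, x):
    # Hypothetical one-parameter model v_theta(x) = theta * x
    return theta * x


def fm_loss(theta):
    total = 0.0
    for x in GRID:
        weights = [cond_density(T, x, x1) / len(DATA) for x1 in DATA]
        p = sum(weights)
        # Marginal field: posterior-weighted average of the conditional fields.
        u = sum(w * cond_field(T, x, x1) for w, x1 in zip(weights, DATA)) / p
        total += p * (v(theta, x) - u) ** 2 * DX
    return total


def cfm_loss(theta):
    total = 0.0
    for x in GRID:
        for x1 in DATA:
            w = cond_density(T, x, x1) / len(DATA)
            total += w * (v(theta, x) - cond_field(T, x, x1)) ** 2 * DX
    return total


gap_a = cfm_loss(0.3) - fm_loss(0.3)
gap_b = cfm_loss(-1.7) - fm_loss(-1.7)
print(gap_a, gap_b)
```

<p>The two gaps should coincide up to rounding, so the objectives share the same gradient with respect to $\theta$. The gap itself is positive: it is the variance of the conditional targets around the marginal field, which no model can remove.</p>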
<p><strong>Theorem 3 (Gaussian Conditional VFs)</strong>:</p>
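<p>For any Gaussian probability path $p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$ with affine flow map $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$, the unique vector field that defines $\psi_t$ (and therefore generates the path) is available in closed form:</p>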
<p>$$u_t(x|x_1) = \frac{\sigma&rsquo;_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu&rsquo;_t(x_1)$$</p>
<p>This theorem allows explicitly defining targets for both Diffusion (curved) and Optimal Transport (straight) paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The derivation of Theorem 3 comes from the direct relationship between a flow map $\psi_t$ and its generating vector field. Because we chose a specific, simple path (Gaussian), we can invert the flow map to find the vector field in closed form.</p>
<p><strong>1. Define the Flow Map $\psi_t$</strong></p>
<p>We start by defining the conditional probability path as a Gaussian:</p>
<p>$$p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$$</p>
<p>The simplest way to &ldquo;push&rdquo; a standard normal distribution (noise) $p_0 = \mathcal{N}(0, I)$ to this Gaussian is using an affine transformation (scaling and shifting). We define the flow map $\psi_t$ as:</p>
<p>$$\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$$</p>
<p>This map takes a noise sample $x_0$ and transforms it into a sample $x$ at time $t$.</p>
<p><strong>2. The Definition of a Generating Vector Field</strong></p>
<p>By definition, a vector field $u_t$ generates a flow $\psi_t$ if the vector field describes the instantaneous velocity of the flow at any point. Mathematically:</p>
<p>$$u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$$</p>
<p>Let $x = \psi_t(x_0)$ be the position of the particle at time $t$. We want to find $u_t(x)$.</p>
<p><strong>3. Invert the Flow Map</strong></p>
<p>To find $u_t(x)$, we must express the equation in terms of $x$ rather than $x_0$. Since our flow map is a simple affine transformation (multiply and add), it is easily invertible (assuming $\sigma_t(x_1) \neq 0$):</p>
<p>$$x_0 = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}$$</p>
<p>We will call this inverse map $\psi_t^{-1}(x)$.</p>
<p><strong>4. Differentiate the Flow Map</strong></p>
<p>Now we calculate the right side of our definition equation (velocity): $\frac{d}{dt}\psi_t(x_0)$.</p>
<p>Taking the time derivative of $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$:</p>
<p>$$\frac{d}{dt}\psi_t(x_0) = \sigma&rsquo;_t(x_1) x_0 + \mu&rsquo;_t(x_1)$$</p>
<p>(Note: $\sigma&rsquo;_t$ and $\mu&rsquo;_t$ denote time derivatives).</p>
<p><strong>5. Substitute and Solve</strong></p>
<p>Now we combine everything. We know $u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$.</p>
<p>Substitute the result from Step 4 into this equation:</p>
<p>$$u_t(\psi_t(x_0)) = \sigma&rsquo;_t(x_1) x_0 + \mu&rsquo;_t(x_1)$$</p>
<p>This expresses the vector field in terms of the initial point $x_0$. We must express it in terms of the current point $x$. So, we plug in the inverse formula for $x_0$ derived in Step 3:</p>
<p>$$u_t(x|x_1) = \sigma&rsquo;_t(x_1) \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \mu&rsquo;_t(x_1)$$</p>
<p>Rearranging terms gives the final closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma&rsquo;_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu&rsquo;_t(x_1)$$</p>
<p><strong>Why is this useful?</strong></p>
<p>This formula means that as long as you can define a mean schedule $\mu_t(x_1)$ and a standard deviation schedule $\sigma_t(x_1)$ (which is easy to do for both Diffusion and Optimal Transport), you immediately get the exact vector field target $u_t(x|x_1)$ needed to train your neural network, bypassing complex ODE solving or score matching approximations.</p></blockquote>
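<p>As a worked instance, take the Optimal Transport schedule $\mu_t(x_1) = t x_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma_{\min}) t$. Plugging into the closed form gives:</p>
<p>$$u_t(x|x_1) = \frac{-(1 - \sigma_{\min})}{1 - (1 - \sigma_{\min}) t}(x - t x_1) + x_1 = \frac{x_1 - (1 - \sigma_{\min}) x}{1 - (1 - \sigma_{\min}) t}$$</p>
<p>The sketch below checks this numerically in one dimension by comparing the general formula, the simplified form, and a finite-difference velocity of the flow map. It is a sanity check under assumptions, not code from the paper: the value of <code>SIGMA_MIN</code>, the sample point, and all function names are hypothetical.</p>

```python
import math

SIGMA_MIN = 0.1  # hypothetical minimum standard deviation


def mu(t, x1):
    """Mean schedule of the assumed Optimal Transport path."""
    return t * x1


def sigma(t, x1):
    """Standard deviation schedule of the assumed Optimal Transport path."""
    return 1.0 - (1.0 - SIGMA_MIN) * t


def flow(t, x0, x1):
    """Affine flow map psi_t(x0) = sigma_t(x1) * x0 + mu_t(x1)."""
    return sigma(t, x1) * x0 + mu(t, x1)


def theorem3_field(t, x, x1):
    """General closed form, with the schedule derivatives written out by hand."""
    d_sigma = -(1.0 - SIGMA_MIN)
    d_mu = x1
    return d_sigma / sigma(t, x1) * (x - mu(t, x1)) + d_mu


def ot_field(t, x, x1):
    """Simplified form of the same field for this particular schedule."""
    return (x1 - (1.0 - SIGMA_MIN) * x) / (1.0 - (1.0 - SIGMA_MIN) * t)


def flow_velocity(t, x0, x1, h=1e-5):
    """Central finite difference of the flow map in time."""
    return (flow(t + h, x0, x1) - flow(t - h, x0, x1)) / (2.0 * h)


t, x0, x1 = 0.3, 0.7, 2.0
x = flow(t, x0, x1)  # position of the particle at time t
print(theorem3_field(t, x, x1), ot_field(t, x, x1), flow_velocity(t, x0, x1))
```

<p>All three values coincide because the field is evaluated at $x = \psi_t(x_0)$, which is exactly the point where the definition of a generating vector field applies.</p>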
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., &amp; Le, M. (2023). Flow Matching for Generative Modeling. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lipmanFlowMatchingGenerative2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow Matching for Generative Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2210.02747">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Building Normalizing Flows with Stochastic Interpolants</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</guid><description>A continuous-time normalizing flow using stochastic interpolants and quadratic loss to bypass costly ODE backpropagation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with significant <strong>Theory</strong> contributions.</p>
<p>The authors propose a specific algorithm (&ldquo;InterFlow&rdquo;) for constructing generative models based on continuous-time normalizing flows. The work is characterized by the derivation of a new training objective (a simple quadratic loss) that bypasses the computational bottlenecks of previous methods. It includes prominent baseline comparisons against continuous flow methods (FFJORD, OT-Flow) and diffusion models. The theoretical component establishes the validity of the interpolant density satisfying the continuity equation (a conservation law governing how probability mass flows) and bounds the Wasserstein-2 distance (a measure of transport cost between distributions, penalizing squared displacement) of the transport.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to overcome the computational inefficiency of training Continuous Normalizing Flows (CNFs) using Maximum Likelihood Estimation (MLE). Standard CNF training requires backpropagating through numerical ODE solvers, which is costly and limits scalability.</p>
<p>Additionally, while score-based diffusion models (SDEs) have achieved high sample quality, they theoretically require infinite time integration and rely on specific noise schedules. The authors aim to establish a method that works strictly with Probability Flow ODEs on finite time intervals, retaining the flexibility to connect arbitrary densities without the complexity of SDEs or the cost of standard ODE adjoint methods.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Stochastic Interpolant</strong> framework:</p>
<ul>
<li><strong>Explicit Interpolant Construction</strong>: The method defines a time-dependent interpolant $x_t = I_t(x_0, x_1)$ (e.g., trigonometric interpolation) that connects samples from the base density $\rho_0$ and target $\rho_1$.</li>
<li><strong>Simulation-Free Training</strong>: The velocity field $v_t(x)$ of the probability flow is learned by minimizing a simple quadratic objective: $G(\hat{v}) = \mathbb{E}[|\hat{v}_t(x_t)|^2 - 2\partial_t x_t \cdot \hat{v}_t(x_t)]$. Because $\partial_t I_t$ is known analytically from the interpolant definition, the expectation can be estimated by sampling $(x_0, x_1, t)$ directly. This avoids ODE integration during training (ODE integration is still required at inference).</li>
<li><strong>Decoupling Path and Optimization</strong>: The choice of path (interpolant) is separated from the optimization of the velocity field, whereas MLE-based methods couple the path to the objective.</li>
<li><strong>Connection to Score-Based Models</strong>: The authors show that for Gaussian base densities and trigonometric interpolants, the learned velocity field is explicitly related to the score function $\nabla \log \rho_t$, providing a theoretical bridge between CNFs and diffusion models.</li>
</ul>
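<p>The reason the quadratic objective recovers the velocity field follows from completing the square. The short derivation below is a sketch of that standard argument in the notation used above:</p>
<p>$$G(\hat{v}) = \mathbb{E}\left[|\hat{v}_t(x_t) - \partial_t x_t|^2\right] - \mathbb{E}\left[|\partial_t x_t|^2\right]$$</p>
<p>The second term does not depend on $\hat{v}$, so minimizing $G$ is a least-squares regression of $\partial_t x_t = \partial_t I_t(x_0, x_1)$ onto functions of $(t, x_t)$. Its minimizer is the conditional expectation:</p>
<p>$$v_t(x) = \mathbb{E}\left[\partial_t I_t(x_0, x_1) \mid x_t = x\right]$$</p>
<p>Individual sample pairs $(x_0, x_1)$ passing through the same point $x$ may move in different directions, and the regression averages over them.</p>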
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors performed validation across synthetic, tabular, and image domains:</p>
<ul>
<li><strong>2D Density Estimation</strong>: Benchmarked on &ldquo;Checkerboard&rdquo;, &ldquo;8 Gaussians&rdquo;, and anisotropic curved densities to visualize mode coverage and transport smoothness.</li>
<li><strong>High-Dimensional Tabular Data</strong>: Evaluated on standard benchmarks (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) comparing Negative Log Likelihood (NLL) against FFJORD, OT-Flow, and others.</li>
<li><strong>Image Generation</strong>: Trained models on CIFAR-10 ($32 \times 32$), ImageNet ($32 \times 32$), and Oxford Flowers ($128 \times 128$) to test scalability.</li>
<li><strong>Ablations</strong>: Investigated optimizing the interpolant path itself (e.g., learning Fourier coefficients for the path) to approach optimal transport and minimize path length.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Performance</strong>: The method matches or surpasses conventional ODE flows (like FFJORD) in terms of NLL while being significantly cheaper to train.</li>
<li><strong>Efficiency</strong>: The training cost per epoch is constant (simulation-free), whereas MLE-based ODE methods see growing costs as the dynamics become more complex.</li>
<li><strong>Scalability</strong>: The method successfully scales to $128 \times 128$ resolution on a single GPU, a resolution that prior ab-initio ODE flows had not demonstrated.</li>
<li><strong>Flexibility</strong>: The framework can connect <em>any</em> two densities (e.g., two different complex 2D distributions) without needing one to be Gaussian.</li>
<li><strong>Optimal Transport</strong>: For a fixed interpolant, minimizing $G(\hat{v})$ over the velocity field recovers the velocity for that specific path. Additionally optimizing over the interpolant family yields a solution to the Benamou-Brenier optimal transport problem.</li>
<li><strong>Limitations</strong>: The authors acknowledge that image FID scores trail dedicated diffusion models, noting that InterFlow was not optimized with standard training tricks such as exponential moving averages, truncation, or learning rate warm-ups. The framework&rsquo;s sample quality could likely improve with these additions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Tabular Datasets</strong>: POWER (6D), GAS (8D), HEPMASS (21D), MINIBOONE (43D), BSDS300 (63D).
<ul>
<li>Training points range from ~30k (MINIBOONE) to ~1.6M (POWER).</li>
</ul>
</li>
<li><strong>Image Datasets</strong>:
<ul>
<li>CIFAR-10 ($32 \times 32$, 50k training points).</li>
<li>ImageNet ($32 \times 32$, ~1.28M training points).</li>
<li>Oxford Flowers ($128 \times 128$, ~315k training points).</li>
</ul>
</li>
<li><strong>Time Sampling</strong>: Time $t$ is sampled from a Beta distribution during training (reweighting) to focus learning near the target.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Interpolant</strong>: The primary interpolant used is trigonometric: $I_t(x_0, x_1) = \cos(\frac{\pi t}{2})x_0 + \sin(\frac{\pi t}{2})x_1$.
<ul>
<li>Alternative linear interpolant: $I_t(x_0, x_1) = a_t x_0 + b_t x_1$, with $a_0 = b_1 = 1$ and $a_1 = b_0 = 0$ so that $I_0 = x_0$ and $I_1 = x_1$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>:
$$G(\hat{v}) = \mathbb{E}_{t, x_0, x_1}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$$
<ul>
<li>The expectation is amenable to empirical estimation using batches of $x_0, x_1, t$.</li>
</ul>
</li>
<li><strong>Sampling</strong>: Numerical integration using Dormand-Prince (Runge-Kutta 4/5).</li>
<li><strong>Optimization</strong>: SGD and Adam variants.</li>
</ul>
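<p>The interpolant and loss above can be sketched in a few lines. This is a minimal one-dimensional illustration, not the authors&rsquo; implementation: the toy densities, the uniform time sampling (the paper reweights $t$ with a Beta distribution), the candidate velocity fields, and all function names are assumptions made for the example.</p>

```python
import math
import random


def interpolant(t, x0, x1):
    """Trigonometric interpolant I_t(x0, x1)."""
    angle = 0.5 * math.pi * t
    return math.cos(angle) * x0 + math.sin(angle) * x1


def interpolant_dt(t, x0, x1):
    """Analytic time derivative of the trigonometric interpolant."""
    angle = 0.5 * math.pi * t
    return 0.5 * math.pi * (-math.sin(angle) * x0 + math.cos(angle) * x1)


def loss_estimate(v_hat, batch):
    """Monte Carlo estimate of G(v_hat) over a batch of (t, x0, x1) triples."""
    total = 0.0
    for t, x0, x1 in batch:
        x_t = interpolant(t, x0, x1)
        v = v_hat(t, x_t)
        total += v * v - 2.0 * interpolant_dt(t, x0, x1) * v
    return total / len(batch)


def sample_batch(n, rng):
    """Toy batch: t uniform on [0, 1], x0 from N(0, 1), x1 from N(3, 0.5^2)."""
    return [(rng.random(), rng.gauss(0.0, 1.0), rng.gauss(3.0, 0.5)) for _ in range(n)]


rng = random.Random(0)
batch = sample_batch(1024, rng)
zero_field = lambda t, x: 0.0      # trivial candidate; its loss is exactly zero
constant_field = lambda t, x: 1.0  # hypothetical candidate velocity
print(loss_estimate(zero_field, batch), loss_estimate(constant_field, batch))
```

<p>No ODE solver appears anywhere in the loss: every quantity is computed directly from the sampled triple $(t, x_0, x_1)$. In a real setup <code>v_hat</code> would be a neural network and the estimate would be differentiated with respect to its parameters.</p>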
<h3 id="models">Models</h3>
<ul>
<li><strong>Tabular Architectures</strong>:
<ul>
<li>Feed-forward networks with 4-5 hidden layers.</li>
<li>Hidden widths: 512 (POWER, GAS, HEPMASS, MINIBOONE) or 1024 (BSDS300).</li>
<li>Activation: ReLU (general) or ELU (BSDS300).</li>
</ul>
</li>
<li><strong>Image Architectures</strong>:
<ul>
<li>U-Net based on the DDPM implementation.</li>
<li>Hidden dimension: 256.</li>
<li>Sinusoidal time embeddings used.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Negative Log Likelihood (NLL) in nats (tabular) or bits per dim (images), Fréchet Inception Distance (FID) for images.</li>
<li><strong>Baselines</strong>: FFJORD, Glow, Real NVP, OT-Flow, ScoreFlow, DDPM.</li>
</ul>
<p><strong>Tabular NLL</strong> (nats, lower is better; Table 2 Left):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>POWER</th>
          <th>GAS</th>
          <th>HEPMASS</th>
          <th>MINIBOONE</th>
          <th>BSDS300</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MADE</td>
          <td>3.08</td>
          <td>-3.56</td>
          <td>20.98</td>
          <td>15.59</td>
          <td>-148.85</td>
      </tr>
      <tr>
          <td>Real NVP</td>
          <td>-0.17</td>
          <td>-8.33</td>
          <td>18.71</td>
          <td>13.55</td>
          <td>-153.28</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>-0.17</td>
          <td>-8.15</td>
          <td>18.92</td>
          <td>11.35</td>
          <td>-155.07</td>
      </tr>
      <tr>
          <td>CPF</td>
          <td>-0.52</td>
          <td>-10.36</td>
          <td>16.93</td>
          <td>10.58</td>
          <td>-154.99</td>
      </tr>
      <tr>
          <td>NSP</td>
          <td>-0.64</td>
          <td>-13.09</td>
          <td>14.75</td>
          <td>9.67</td>
          <td>-157.54</td>
      </tr>
      <tr>
          <td>FFJORD</td>
          <td>-0.46</td>
          <td>-8.59</td>
          <td>14.92</td>
          <td>10.43</td>
          <td>-157.40</td>
      </tr>
      <tr>
          <td>OT-Flow</td>
          <td>-0.30</td>
          <td>-9.20</td>
          <td>17.32</td>
          <td>10.55</td>
          <td>-154.20</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>-0.57</strong></td>
          <td><strong>-12.35</strong></td>
          <td><strong>14.85</strong></td>
          <td><strong>10.42</strong></td>
          <td><strong>-156.22</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation NLL and FID</strong> (Table 2 Right; NLL in bits per dim, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>CIFAR-10 NLL</th>
          <th>CIFAR-10 FID</th>
          <th>ImageNet-32 NLL</th>
          <th>ImageNet-32 FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FFJORD</td>
          <td>3.40</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>3.35</td>
          <td>-</td>
          <td>4.09</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM</td>
          <td>≤3.75</td>
          <td>3.17</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM++ (Song et al., 2021)</td>
          <td>≤3.37</td>
          <td>2.90</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ScoreSDE (Song et al., 2021)</td>
          <td>2.99</td>
          <td>2.92</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>VDM</td>
          <td>≤2.65</td>
          <td>7.41</td>
          <td>≤3.72</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Soft Truncation</td>
          <td>2.88</td>
          <td>3.45</td>
          <td>3.85</td>
          <td>8.42</td>
      </tr>
      <tr>
          <td>ScoreFlow</td>
          <td>2.81</td>
          <td>5.40</td>
          <td>3.76</td>
          <td>10.18</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>2.99</strong></td>
          <td><strong>10.27</strong></td>
          <td><strong>3.48</strong></td>
          <td><strong>8.49</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: DDPM++ is from Song et al. (2021), the same work as ScoreSDE (it is the architecture optimized for VP/sub-VP SDEs). InterFlow matches ScoreSDE on CIFAR-10 NLL (2.99 bits per dim) while being simulation-free. FID is weaker than dedicated image models (10.27 vs 2.92 for ScoreSDE), reflecting the paper&rsquo;s primary focus on tractable likelihood rather than sample quality.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: All models were trained on a single NVIDIA A100 GPU.</li>
<li><strong>Training Time</strong>:
<ul>
<li>Tabular: $10^5$ steps.</li>
<li>Images: $1.5 \times 10^5$ to $6 \times 10^5$ steps.</li>
<li>Speedup: Demonstrated ~400x speedup compared to FFJORD on the MINIBOONE dataset.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>lucidrains/denoising-diffusion-pytorch (link defunct)</td>
          <td>Code</td>
          <td>MIT</td>
          <td>Base U-Net architecture used for image experiments; original GitHub account no longer available</td>
      </tr>
  </tbody>
</table>
<p>No official code release accompanies this paper. All tabular datasets (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) are publicly available from prior work. CIFAR-10 and ImageNet are standard public benchmarks. Oxford Flowers 102 is also publicly available. Hyperparameters and architectures are fully specified in Tables 3 and 4 of the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>The Eleventh International Conference on Learning Representations</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{albergoBuildingNormalizingFlows2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Building {{Normalizing Flows}} with {{Stochastic Interpolants}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Eleventh International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Albergo, Michael Samuel and {Vanden-Eijnden}, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=li7qeBbCR1t}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=li7qeBbCR1t">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/2209.15571">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Translating InChI to IUPAC Names with Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/handsel-inchi-iupac-2021/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/handsel-inchi-iupac-2021/</guid><description>Sequence-to-sequence Transformer translating InChI identifiers to IUPAC names with 91% accuracy on organic compounds.</description><content:encoded><![CDATA[<h2 id="primary-contribution-a-transformer-based-method">Primary Contribution: A Transformer-Based Method</h2>
<p>This is primarily a <strong>Method</strong> paper. It adapts a specific architecture (Transformer) to a specific task (InChI-to-IUPAC translation) and evaluates its performance against both machine learning and commercial baselines. It also has a secondary <strong>Resource</strong> contribution, as the trained model and scripts are released as open-source software.</p>
<h2 id="motivation-the-bottleneck-in-algorithmic-iupac-nomenclature">Motivation: The Bottleneck in Algorithmic IUPAC Nomenclature</h2>
<p>Generating correct IUPAC names is difficult due to the comprehensive but complex rules defined by the International Union of Pure and Applied Chemistry. Commercial software generates names from structures but remains closed-source with opaque methodologies and frequent inter-package disagreements. Open identifiers like InChI and SMILES lack direct human readability. This creates a need for an open, automated method to generate informative IUPAC names from standard identifiers like InChI, which are ubiquitous in online chemical databases.</p>
<h2 id="novelty-treating-chemical-translation-as-a-character-level-sequence">Novelty: Treating Chemical Translation as a Character-Level Sequence</h2>
<p>The key novelty is treating chemical nomenclature translation as a character-level sequence-to-sequence problem using a Transformer architecture, specifically using <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> as the source language.</p>
<ul>
<li>Standard Neural Machine Translation (NMT) uses sub-word tokenization. This model processes InChI and predicts IUPAC names character-by-character.</li>
<li>It demonstrates that character-level tokenization outperforms byte-pair encoding or unigram models for this specific chemical task.</li>
<li>It uses InChI&rsquo;s standardization to avoid the canonicalization issues inherent in SMILES-based approaches.</li>
<li>The attention mechanism allows the decoder to align specific parts of the generated IUPAC name with corresponding structural features in the source InChI string, operating via the standard scaled dot-product attention:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$</li>
</ul>
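<p>The attention formula above can be written out directly. The sketch below is a plain-Python, single-head illustration with hypothetical toy inputs. It is not the OpenNMT-py implementation, which additionally uses multiple heads, learned projections, and masking.</p>

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    top = max(row)
    exps = [math.exp(value - top) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Scaled dot-product attention for lists of query, key, and value vectors."""
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    outputs, weights = [], []
    for q in Q:
        # similarity of this query to every key, scaled by sqrt(d_k)
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in K]
        w = softmax(scores)
        weights.append(w)
        # weighted average of the value vectors
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights


# One decoder query attending over two encoder positions (toy numbers).
out, w = attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]])
print(w[0], out[0])
```

<p>Each row of <code>w</code> sums to one, so it can be read as a soft alignment between one output character and the source characters. The attention visualizations in the paper inspect this kind of quantity.</p>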
<h2 id="methodology--experimental-validation">Methodology &amp; Experimental Validation</h2>
<ul>
<li><strong>Training:</strong> The model was trained on 10 million InChI/IUPAC pairs sampled from PubChem using a character-level objective. The model is supervised using categorical cross-entropy loss across the vocabulary of characters:
$$ \mathcal{L} = -\sum_{i=1}^{N} y_i \log(\hat{y}_i) $$</li>
<li><strong>Ablation Studies:</strong> The authors experimentally validated architecture choices, finding that LSTM models and sub-word tokenization (BPE) performed worse than the Transformer with character tokenization. They also optimized dropout rates.</li>
<li><strong>Performance Benchmarking:</strong> The model was evaluated on a held-out test set of 200,000 samples. Performance was quantified primarily by Whole-Name Accuracy and Normalized Edit Distance (based on the Damerau-Levenshtein distance, scaled by the maximum string length).</li>
<li><strong>Commercial Comparison:</strong> The authors compared their model against four major commercial packages (ACD/I-Labs, ChemAxon, Mestrelab, and PubChem&rsquo;s Lexichem). However, this evaluation used a small test set of only 100 molecules, which limits the statistical confidence of the comparison.</li>
<li><strong>Error Analysis:</strong> They analyzed performance across different chemical classes (organics, charged species, macrocycles, inorganics) and visualized attention coefficients to interpret model focus.</li>
</ul>
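<p>The two headline metrics can be sketched as follows. This is an illustration under assumptions rather than the authors&rsquo; evaluation script: it uses the optimal string alignment variant of the Damerau-Levenshtein distance (the paper may use a different variant), and the function names and example strings are hypothetical.</p>

```python
def damerau_levenshtein(a, b):
    """Edit distance with insertions, deletions, substitutions, and adjacent swaps."""
    rows, cols = len(a) + 1, len(b) + 1
    d = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        d[i][0] = i
    for j in range(cols):
        d[0][j] = j
    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if a[i - 1] == b[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution or match
            if i > 1 and j > 1 and a[i - 1] == b[j - 2] and a[i - 2] == b[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + 1)  # adjacent transposition
    return d[rows - 1][cols - 1]


def normalized_edit_distance(pred, ref):
    """Edit distance scaled to [0, 1] by the length of the longer string."""
    longest = max(len(pred), len(ref))
    if longest == 0:
        return 0.0
    return damerau_levenshtein(pred, ref) / longest


def whole_name_accuracy(preds, refs):
    """Fraction of predictions that match the reference exactly."""
    return sum(p == r for p, r in zip(preds, refs)) / len(refs)


preds = ["propan-1-ol", "ethanol"]
refs = ["propan-2-ol", "ethanol"]
print(whole_name_accuracy(preds, refs))
print([normalized_edit_distance(p, r) for p, r in zip(preds, refs)])
```

<p>The first pair differs by a single locant, so it counts as a miss for whole-name accuracy but has a small normalized edit distance. The two metrics are complementary for this reason.</p>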
<h2 id="key-results-and-the-inorganic-challenge">Key Results and the Inorganic Challenge</h2>
<ul>
<li><strong>High Accuracy on Organics:</strong> The model achieved 91% whole-name accuracy on the test set, performing particularly well on organic compounds.</li>
<li><strong>Comparable to Commercial Tools:</strong> On the limited 100-molecule benchmark, the edit distance between the model&rsquo;s predictions and commercial packages (15-23%) was similar to the variation found <em>between</em> the commercial packages themselves (16-21%).</li>
<li><strong>Limitations on Inorganics:</strong> The model performed poorly on inorganic (14% accuracy) and organometallic compounds (20% accuracy). This is attributed to inherent data limitations in the standard InChI format (which deliberately disconnects metal atoms from their ligands) and low training data coverage for those classes.</li>
<li><strong>Character-Level Superiority:</strong> Character-level tokenization was found to be essential; byte-pair encoding reduced accuracy significantly.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset was derived from <a href="https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Extras/">PubChem&rsquo;s public FTP server</a> (<code>CID-SMILES.gz</code> and <code>CID-IUPAC.gz</code>).</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Raw</strong></td>
          <td>PubChem</td>
          <td>100M pairs</td>
          <td>Filtered for length (InChI &lt; 200 chars, IUPAC &lt; 150 chars). 132k unparseable SMILES dropped.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Subsampled</td>
          <td>10M pairs</td>
          <td>Random sample from the filtered set.</td>
      </tr>
      <tr>
          <td><strong>Validation</strong></td>
          <td>Held-out</td>
          <td>10,000 samples</td>
          <td>Limited to InChI length &gt; 50 chars.</td>
      </tr>
      <tr>
          <td><strong>Test</strong></td>
          <td>Held-out</td>
          <td>200,000 samples</td>
          <td>Limited to InChI length &gt; 50 chars.</td>
      </tr>
      <tr>
          <td><strong>Tokenization</strong></td>
          <td>Vocab</td>
          <td>InChI: 66 chars<br>IUPAC: 70 chars</td>
          <td>Character-level tokenization. Spaces treated as tokens.</td>
      </tr>
  </tbody>
</table>
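<p>Character-level tokenization as described in the table can be sketched in a few lines. The vocabularies here are built from a single hypothetical example pair (acetic acid), so they are far smaller than the 66- and 70-character vocabularies reported in the paper. The function names are illustrative only.</p>

```python
def build_vocab(strings):
    """Map every distinct character (spaces included) to an integer id."""
    chars = sorted(set("".join(strings)))
    return {ch: i for i, ch in enumerate(chars)}


def encode(text, vocab):
    """Turn a string into one token id per character."""
    return [vocab[ch] for ch in text]


def decode(ids, vocab):
    """Invert encode() by looking each id up in the reversed vocabulary."""
    inverse = {i: ch for ch, i in vocab.items()}
    return "".join(inverse[i] for i in ids)


inchi = "InChI=1S/C2H4O2/c1-2(3)4/h1H3,(H,3,4)"  # acetic acid
name = "acetic acid"

# Separate source and target vocabularies, as in the table above.
src_vocab = build_vocab([inchi])
tgt_vocab = build_vocab([name])

src_ids = encode(inchi, src_vocab)
tgt_ids = encode(name, tgt_vocab)
print(len(src_ids), len(tgt_ids))
```

<p>Sequence length equals string length under this scheme, which is why the length filters in the table translate directly into maximum token counts.</p>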
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework</strong>: OpenNMT-py 2.0.0 (using PyTorch). Training scripts and vocabularies are available as supplementary files to the original publication. Pre-trained model weights are hosted on <a href="https://doi.org/10.5281/zenodo.5081159">Zenodo</a>.</li>
<li><strong>Architecture Type</strong>: Transformer Encoder-Decoder.</li>
<li><strong>Optimization</strong>: Adam optimizer ($\beta_1=0.9, \beta_2=0.998$).</li>
<li><strong>Learning Rate</strong>: Linear warmup over 8000 steps to 0.0005, then decayed by inverse square root of iteration.</li>
<li><strong>Regularization</strong>:
<ul>
<li>Dropout: 0.1 (applied to dense and attentional layers).</li>
<li>Label Smoothing: Magnitude 0.1.</li>
</ul>
</li>
<li><strong>Training Strategy</strong>: Teacher forcing used for both training and validation.</li>
<li><strong>Gradient Accumulation</strong>: Gradients accumulated over 4 batches before updating parameters.</li>
<li><strong>Inference</strong>: Beam search with width 10 and length penalty 1.0.</li>
</ul>
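<p>The learning rate schedule described above can be sketched as a single function. The exact functional form is an assumption: this version reaches the stated peak of 0.0005 at step 8000 and then decays with the inverse square root of the step count. OpenNMT-py&rsquo;s own schedule may scale the rate differently, and the function name is hypothetical.</p>

```python
import math


def learning_rate(step, peak=5e-4, warmup=8000):
    """Linear warmup to `peak` over `warmup` steps, then inverse-sqrt decay."""
    if step < 1:
        raise ValueError("step must be a positive integer")
    # Before warmup ends the linear ramp is the smaller factor; afterwards the decay is.
    return peak * min(step / warmup, math.sqrt(warmup / step))


for step in (1, 4000, 8000, 32000):
    print(step, learning_rate(step))
```

<p>With gradients accumulated over 4 batches, one optimizer step (and therefore one schedule step) corresponds to 4 forward and backward passes.</p>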
<h3 id="models">Models</h3>
<ul>
<li><strong>Structure</strong>: 6 layers in encoder, 6 layers in decoder.</li>
<li><strong>Attention</strong>: 8 heads per attention sub-layer.</li>
<li><strong>Dimensions</strong>:
<ul>
<li>Feed-forward hidden state size: 2048.</li>
<li>Embedding vector length: 512.</li>
</ul>
</li>
<li><strong>Initialization</strong>: Glorot&rsquo;s method.</li>
<li><strong>Position</strong>: Positional encoding added to the token (character) embeddings.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics reported include <strong>Whole-Name Accuracy</strong> (percentage of exact matches) and <strong>Normalized Edit Distance</strong> (Damerau-Levenshtein, scale 0-1).</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy (All)</td>
          <td>91%</td>
          <td>N/A</td>
          <td>Test set of 200k samples.</td>
      </tr>
      <tr>
          <td>Accuracy (Inorganic)</td>
          <td>14%</td>
          <td>N/A</td>
          <td>Limited by InChI format and data.</td>
      </tr>
      <tr>
          <td>Accuracy (Organometallic)</td>
          <td>20%</td>
          <td>N/A</td>
          <td>Limited by InChI format and data.</td>
      </tr>
      <tr>
          <td>Accuracy (Charged)</td>
          <td>79%</td>
          <td>N/A</td>
          <td>Test set subset.</td>
      </tr>
      <tr>
          <td>Accuracy (Rajan)</td>
          <td>72%</td>
          <td>N/A</td>
          <td>Comparative ML model (STOUT).</td>
      </tr>
      <tr>
          <td>Edit Dist (Organic)</td>
          <td>$0.02 \pm 0.03$</td>
          <td>N/A</td>
          <td>Very high similarity for organics.</td>
      </tr>
      <tr>
          <td>Edit Dist (Inorganic)</td>
          <td>$0.32 \pm 0.20$</td>
          <td>N/A</td>
          <td>Poor performance on inorganics.</td>
      </tr>
      <tr>
          <td>Edit Dist (Organometallic)</td>
          <td>$0.37 \pm 0.24$</td>
          <td>N/A</td>
          <td>Poor performance on organometallics.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Tesla K80.</li>
<li><strong>Training Time</strong>: 7 days.</li>
<li><strong>Throughput</strong>: ~6000 tokens/sec (InChI) and ~3800 tokens/sec (IUPAC).</li>
<li><strong>Batch Size</strong>: 4096 tokens (approx. 30 compounds).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.5081159">InChI to IUPAC model</a></td>
          <td>Model</td>
          <td>CC BY 4.0</td>
          <td>Pre-trained Transformer weights (551 MB), requires OpenNMT-py 2.0.0</td>
      </tr>
      <tr>
          <td><a href="https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Extras/">PubChem FTP</a></td>
          <td>Dataset</td>
          <td>Public Domain</td>
          <td>Source data: CID-SMILES.gz and CID-IUPAC.gz</td>
      </tr>
      <tr>
          <td>Training scripts &amp; vocabularies</td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Included as supplementary files with the publication</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Handsel, J., Matthews, B., Knight, N. J., &amp; Coles, S. J. (2021). Translating the InChI: Adapting Neural Machine Translation to Predict IUPAC Names from a Chemical Identifier. <em>Journal of Cheminformatics</em>, 13(1), 79. <a href="https://doi.org/10.1186/s13321-021-00535-x">https://doi.org/10.1186/s13321-021-00535-x</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{handselTranslatingInChIAdapting2021a,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Translating the {{InChI}}: Adapting Neural Machine Translation to Predict {{IUPAC}} Names from a Chemical Identifier}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Translating the {{InChI}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Handsel, Jennifer and Matthews, Brian and Knight, Nicola J. and Coles, Simon J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{79}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00535-x}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-20}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{We present a sequence-to-sequence machine learning model for predicting the IUPAC name of a chemical from its standard International Chemical Identifier (InChI). The model uses two stacks of transformers in an encoder-decoder architecture, a setup similar to the neural networks used in state-of-the-art machine translation. Unlike neural machine translation, which usually tokenizes input and output into words or sub-words, our model processes the InChI and predicts the IUPAC name character by character. The model was trained on a dataset of 10 million InChI/IUPAC name pairs freely downloaded from the National Library of Medicine&#39;s online PubChem service. Training took seven days on a Tesla K80 GPU, and the model achieved a test set accuracy of 91\%. The model performed particularly well on organics, with the exception of macrocycles, and was comparable to commercial IUPAC name generation software. The predictions were less accurate for inorganic and organometallic compounds. This can be explained by inherent limitations of standard InChI for representing inorganics, as well as low coverage in the training data.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{Attention,GPU,InChI,IUPAC,seq2seq,Transformer}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Struct2IUPAC: Translating SMILES to IUPAC via Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/struct2iupac-2021/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/struct2iupac-2021/</guid><description>A Transformer-based model for translating between SMILES strings and IUPAC names, trained on 47M PubChem examples, achieving 98.9% accuracy with verification.</description><content:encoded><![CDATA[<h2 id="struct2iupac-as-a-methodological-shift">Struct2IUPAC as a Methodological Shift</h2>
<p>This is primarily a <strong>Method</strong> paper with significant elements of <strong>Position</strong>.</p>
<ul>
<li><strong>Method</strong>: The authors propose a specific neural architecture (Transformer with custom tokenization) and a verification pipeline (round-trip check) to solve the SMILES $\leftrightarrow$ IUPAC translation task. They rigorously benchmark this against rule-based baselines (OPSIN).</li>
<li><strong>Position</strong>: The authors explicitly argue for a paradigm shift, suggesting that &ldquo;heavy&rdquo; neural architectures should replace complex, costly rule-based legacy systems even for &ldquo;exact&rdquo; algorithmic tasks.</li>
</ul>
<h2 id="the-cost-of-rule-based-chemical-naming">The Cost of Rule-Based Chemical Naming</h2>
<ul>
<li><strong>Complexity of Naming</strong>: Generating IUPAC names manually is error-prone and requires deep algorithmic knowledge.</li>
<li><strong>Lack of Open Source Tools</strong>: While open-source tools exist for Name-to-Structure (e.g., OPSIN), there were no open-source tools for the inverse &ldquo;Structure-to-Name&rdquo; conversion at the time of writing.</li>
<li><strong>Cost of Development</strong>: Developing rule-based converters &ldquo;from scratch&rdquo; is prohibitively expensive and time-consuming compared to training a neural model on existing data.</li>
</ul>
<h2 id="struct2iupac-core-innovation">Struct2IUPAC Core Innovation</h2>
<ul>
<li><strong>Struct2IUPAC</strong>: The first effective open-source neural model for <a href="/notes/chemistry/molecular-representations/name-translation/stout-v2/">converting SMILES to IUPAC names</a>, treating chemical translation as a Neural Machine Translation (NMT) problem.</li>
<li><strong>Verification Loop</strong>: A novel inference pipeline that generates multiple candidates via beam search and validates them using a reverse converter (OPSIN) to ensure the generated name maps back to the original structure.</li>
<li><strong>Custom Tokenization</strong>: A manually curated rule-based tokenizer for IUPAC names that handles specific chemical suffixes, prefixes, and stereochemical markers.</li>
</ul>
<h2 id="experimental-setup-and-stress-testing">Experimental Setup and Stress Testing</h2>
<ul>
<li><strong>Accuracy Benchmarking</strong>: The models were tested on a held-out subset of 100,000 molecules from PubChem. The authors measured accuracy across different beam sizes (1, 3, 5).</li>
<li><strong>Comparison to Rules</strong>: The neural IUPAC2Struct model was compared directly against the rule-based OPSIN tool.</li>
<li><strong>Stress Testing</strong>:
<ul>
<li><strong>Sequence Length</strong>: Evaluated performance across varying token lengths, identifying a &ldquo;sweet spot&rdquo; (10-60 tokens) and failure modes for very short (e.g., methane) or long molecules.</li>
<li><strong>Stereochemistry</strong>: Tested on &ldquo;stereo-dense&rdquo; compounds. The authors define a &ldquo;stereo-density&rdquo; index ($I$) as the ratio of stereocenters ($S$) to total tokens ($N$):
$$I = \frac{S}{N}$$
They observed a performance drop for these dense molecules, though the model still handled many stereocenters robustly.</li>
<li><strong>Tautomers</strong>: Verified the model&rsquo;s ability to handle different tautomeric forms (e.g., Guanine and Uracil variants).</li>
</ul>
</li>
<li><strong>Latency Analysis</strong>: Benchmarked inference speeds on CPU vs. GPU relative to output sequence length.</li>
</ul>
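<p>As a small illustration of the stereo-density index $I = S/N$ described above, the sketch below computes the ratio for one molecule. The helper name <code>stereo_density</code> and the example counts are hypothetical and are not taken from the paper.</p>

```python
def stereo_density(num_stereocenters: int, num_tokens: int) -> float:
    """Return I = S / N, the ratio of stereocenters to total tokens."""
    if num_tokens == 0:
        raise ValueError("num_tokens must be positive")
    return num_stereocenters / num_tokens


# Hypothetical counts: 4 stereocenters in a 40-token sequence.
example_index = stereo_density(4, 40)
```

<p>A higher value means more stereocenters per token, which is the regime where the authors report reduced accuracy.</p>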
<h2 id="benchmarks-and-outcomes">Benchmarks and Outcomes</h2>
<ul>
<li><strong>High Accuracy</strong>: The Struct2IUPAC model achieved <strong>98.9% accuracy</strong> (Beam 5 with verification). The reverse model (IUPAC2Struct) achieved <strong>99.1%</strong>, comparable to OPSIN&rsquo;s 99.4%.</li>
<li><strong>Distribution Modeling vs. Intuition</strong>: The authors claim the model infers &ldquo;chemical logic,&rdquo; because it correctly generates multiple valid IUPAC names for single molecules where naming ambiguity exists (e.g., parent group selection). However, this more likely reflects the Transformer modeling the conditional distribution over synonymous names present in the PubChem training data than learning intrinsic chemical rules.</li>
<li><strong>Production Readiness</strong>: Inference on GPU takes less than 0.5 seconds even for long names, making it viable for production use.</li>
<li><strong>Paradigm Shift</strong>: The authors conclude that neural networks are a viable, cost-effective alternative to developing rule-based algorithms for legacy notation conversion.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study utilized the PubChem database.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Total</strong></td>
          <td>PubChem</td>
          <td>~95M</td>
          <td>Filtered for RDKit compatibility</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Split A</td>
          <td>47,312,235</td>
          <td>Random 50% split</td>
      </tr>
      <tr>
          <td><strong>Testing</strong></td>
          <td>Split B</td>
          <td>47,413,850</td>
          <td>Random 50% split</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Cleaning</strong>: Molecules that could not be processed by RDKit were removed. Molecules containing tokens not in the tokenizer (e.g., aromatic selenium) were excluded.</li>
<li><strong>Availability</strong>: A subset of 100,000 test molecules is available on GitHub (<code>data/test_100000.csv</code>) and Zenodo. The full train/test splits are not explicitly provided.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>SMILES</strong>: Character-based tokenization.</li>
<li><strong>IUPAC</strong>: Custom rule-based tokenizer splitting suffixes (<code>-one</code>, <code>-al</code>), prefixes (<code>-oxy</code>, <code>-di</code>), and special symbols (<code>(</code>, <code>)</code>, <code>R(S)</code>).</li>
</ul>
</li>
<li><strong>Verification Step</strong>:
<ol>
<li>Generate $N$ candidate names using beam search ($N=5$).</li>
<li>Reverse-translate each candidate name to a structure using OPSIN.</li>
<li>Check whether the OPSIN structure matches the original input SMILES.</li>
<li>Display the first verified match; if no candidate verifies, report failure.</li>
</ol>
</li>
</ul>
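<p>The verification step above can be sketched as a simple loop over beam candidates. This is a minimal sketch under stated assumptions, not the authors' implementation: <code>first_verified_name</code>, <code>name_to_structure</code>, and <code>canonicalize</code> are hypothetical names, and the toy lookup table stands in for OPSIN and for a real structure canonicalizer so that the example runs on its own.</p>

```python
from typing import Callable, Iterable, Optional


def first_verified_name(
    smiles: str,
    candidates: Iterable[str],
    name_to_structure: Callable[[str], Optional[str]],
    canonicalize: Callable[[str], str],
) -> Optional[str]:
    """Return the first candidate whose reverse translation matches the input."""
    target = canonicalize(smiles)
    for name in candidates:  # assumed to be ordered by beam score
        structure = name_to_structure(name)  # e.g. a wrapper around OPSIN
        if structure is not None and canonicalize(structure) == target:
            return name
    return None  # no candidate verified: report failure


# Toy stand-ins (assumptions) so the sketch runs without OPSIN or a model.
TOY_REVERSE = {"ethanol": "CCO", "methoxymethane": "COC"}


def toy_name_to_structure(name: str) -> Optional[str]:
    return TOY_REVERSE.get(name)


def toy_canonicalize(smiles: str) -> str:
    return smiles.strip()


beam = ["methoxymethane", "ethanol", "not-a-name"]
result = first_verified_name("CCO", beam, toy_name_to_structure, toy_canonicalize)
```

<p>In practice the comparison would be done on canonical structures rather than raw strings, since the same molecule can be written as many different SMILES.</p>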
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Standard Transformer with 6 encoder layers and 6 decoder layers.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Attention Heads: 8</li>
<li>Model Dimension ($d_{\text{model}}$): 512</li>
<li>Feed-Forward Dimension ($d_{\text{ff}}$): 2048</li>
</ul>
</li>
<li><strong>Training Objective</strong>: The models were trained using standard autoregressive cross-entropy loss over the target token sequence $y$ given the input string $x$:
$$\mathcal{L} = - \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, x)$$</li>
<li><strong>Training</strong>: Two separate models were trained: <code>Struct2IUPAC</code> (SMILES $\to$ IUPAC) and <code>IUPAC2Struct</code> (IUPAC $\to$ SMILES).</li>
<li><strong>Availability</strong>: Code for model architecture is provided in the GitHub repository. Pre-trained weights for the IUPAC2Struct model are available, but the Struct2IUPAC model weights are not publicly released, meaning researchers would need to retrain that model on their own PubChem data to reproduce those results.</li>
</ul>
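<p>To make the training objective above concrete, the following sketch evaluates the summed negative log-likelihood for a short target sequence. The function name <code>sequence_nll</code> and the per-step probabilities are hypothetical values chosen for illustration, not model output.</p>

```python
import math


def sequence_nll(step_probs):
    """Sum of -log P(y_t | earlier target tokens, x) over target positions.

    step_probs[t] is the probability assigned to the correct token at
    position t (illustrative numbers only).
    """
    return -sum(math.log(p) for p in step_probs)


# Three-token target: confident on two steps, uncertain on the last.
probs = [0.9, 0.8, 0.5]
loss = sequence_nll(probs)
```

<p>Because the loss is a sum of per-token terms, a single low-probability token can dominate the loss for an otherwise well-predicted sequence.</p>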
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation was performed on a random subset of 100,000 molecules from the test set.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Beam Size</th>
          <th>Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Struct2IUPAC</td>
          <td>1</td>
          <td>96.1%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Struct2IUPAC</td>
          <td>5</td>
          <td>98.9%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>IUPAC2Struct</td>
          <td>1</td>
          <td>96.6%</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>IUPAC2Struct</td>
          <td>5</td>
          <td>99.1%</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Robustness</strong>: Accuracy drops significantly for augmented (non-canonical) SMILES (37.16%) and stereo-enriched compounds (66.52%).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Infrastructure</strong>: 4 $\times$ Tesla V100 GPUs and 36 CPUs.</li>
<li><strong>Training Time</strong>: Approximately 10 days under full load.</li>
<li><strong>Inference Speed</strong>: &lt;0.5s per molecule on GPU; latency scales linearly with output token length.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/sergsb/IUPAC2Struct">IUPAC2Struct (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Transformer code and pre-trained IUPAC2Struct model</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.4280814">Test data (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>100k test molecules, OPSIN failure cases, model failure cases</td>
      </tr>
      <tr>
          <td><a href="https://app.syntelly.com/smiles2iupac">Struct2IUPAC web demo</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online interface for SMILES to IUPAC conversion</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Krasnov, L., Khokhlov, I., Fedorov, M. V., &amp; Sosnin, S. (2021). Transformer-based artificial neural networks for the conversion between chemical notations. <em>Scientific Reports</em>, 11(1), 14798. <a href="https://doi.org/10.1038/s41598-021-94082-y">https://doi.org/10.1038/s41598-021-94082-y</a></p>
<p><strong>Publication</strong>: Scientific Reports 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{krasnovTransformerbasedArtificialNeural2021a,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Transformer-Based Artificial Neural Networks for the Conversion between Chemical Notations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Krasnov, Lev and Khokhlov, Ivan and Fedorov, Maxim V. and Sosnin, Sergey}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{11}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{14798}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41598-021-94082-y}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/sergsb/IUPAC2Struct">GitHub Repository</a></li>
<li><a href="https://app.syntelly.com/smiles2iupac">Web Demo</a></li>
</ul>
]]></content:encoded></item><item><title>STOUT: SMILES to IUPAC Names via Neural Machine Translation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout/</guid><description>A deep-learning neural machine translation approach to translate between SMILES strings and IUPAC names using the STOUT model.</description><content:encoded><![CDATA[<h2 id="contribution-translating-chemistry-as-a-language">Contribution: Translating Chemistry as a Language</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong secondary contribution as a <strong>Resource</strong> paper.</p>
<ul>
<li><strong>Method</strong>: It proposes a neural machine translation (NMT) architecture to approximate the complex, rule-based algorithm of IUPAC naming, treating it as a language translation task.</li>
<li><strong>Resource</strong>: It provides an open-source tool and trained models to the community, addressing a gap where such functionality was previously limited to proprietary software.</li>
</ul>
<h2 id="motivation-democratizing-iupac-nomenclature">Motivation: Democratizing IUPAC Nomenclature</h2>
<p>The International Union of Pure and Applied Chemistry (IUPAC) naming scheme is universally accepted but algorithmically complex. Generating these names correctly is challenging for humans, and automated generation is largely missing from major open-source toolkits like CDK, RDKit, or Open Babel. While reliable commercial tools exist (e.g., ChemAxon&rsquo;s <code>molconvert</code>), there was a lack of open-source alternatives for the scientific community. STOUT aims to fill this gap using a data-driven approach.</p>
<h2 id="core-innovation-sequence-to-sequence-naming">Core Innovation: Sequence-to-Sequence Naming</h2>
<ul>
<li><strong>Language Translation Approach</strong>: The authors treat chemical representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>/<a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>) and IUPAC names as two different languages, applying Neural Machine Translation (NMT) to translate between them.</li>
<li><strong>Use of SELFIES</strong>: The work adopts SELFIES (Self-Referencing Embedded Strings) over SMILES for tokenization in this task, relying on its syntactic robustness: every SELFIES string decodes to a valid molecule.</li>
<li><strong>Hardware Acceleration</strong>: The paper benchmarks GPU versus TPU training and highlights the practical necessity of Tensor Processing Units (TPUs) for training large-scale chemical language models, reducing training time by an order of magnitude.</li>
</ul>
<h2 id="methodology--translation-validation">Methodology &amp; Translation Validation</h2>
<ul>
<li><strong>Data Scale</strong>: The model was trained on datasets of 30 million and 60 million molecules derived from PubChem.</li>
<li><strong>Hardware Benchmarking</strong>: Training efficiency was compared between an NVIDIA Tesla V100 GPU and Google TPU v3-8/v3-32 units.</li>
<li><strong>Bidirectional Translation</strong>: The system was tested on two distinct tasks:
<ol>
<li><strong>Forward</strong>: SELFIES → IUPAC names</li>
<li><strong>Reverse</strong>: IUPAC names → SELFIES</li>
</ol>
</li>
<li><strong>Validation</strong>: Performance was evaluated on a held-out test set of 2.2 million molecules.</li>
</ul>
<h2 id="translation-accuracy--hardware-scaling">Translation Accuracy &amp; Hardware Scaling</h2>
<ul>
<li><strong>High Accuracy</strong>: The model achieved an average BLEU score of ~90% and a Tanimoto similarity index &gt; 0.9 for both translation directions.</li>
<li><strong>Generalization</strong>: Even when predictions were textually mismatched (low BLEU score), the underlying chemical structures often remained highly similar (high Tanimoto similarity), suggesting the system captures fundamental chemical semantics rather than merely memorizing strings.</li>
<li><strong>Impact of Data Size</strong>: Expanding training from 30 million to 60 million molecules yielded consistent performance gains without saturating.</li>
<li><strong>Hardware Necessity</strong>: Training on TPUs proved up to 54 times faster than a single-GPU baseline (Tesla V100), which makes training at this data scale practical.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">STOUT (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Current repo hosts STOUT V2.0 transformer models; V1 RNN code available in earlier commits</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Public Domain</td>
          <td style="text-align: left">Source of 111M molecules; 30M/60M training subsets not directly provided</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The dataset was curated from PubChem (111 million molecules). Note that the specific 30M and 60M subsets are not directly linked in the publication repository, which means a user would have to reconstruct the filtering process.</p>
<p><strong>Preprocessing &amp; Filtering</strong>:</p>
<ul>
<li>Explicit hydrogens removed; converted to canonical SMILES.</li>
<li><strong>Filtering Rules</strong>: MW &lt; 1500 Da, no counter ions, limited element set (C, H, O, N, P, S, F, Cl, Br, I, Se, B), no hydrogen isotopes, 3-40 bonds, no charged groups.</li>
<li><strong>Ground Truth Generation</strong>: ChemAxon&rsquo;s <code>molconvert</code> (Marvin Suite 20.15) was used to generate target IUPAC names for training.</li>
<li><strong>Representation</strong>: All SMILES were converted to SELFIES for training.</li>
</ul>
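<p>The filtering rules above amount to a conjunction of simple checks. The sketch below expresses them as a predicate over precomputed molecule properties. The <code>MoleculeRecord</code> fields and the <code>passes_filters</code> name are hypothetical, the bond range is assumed to be inclusive, and in a real pipeline the properties would come from a cheminformatics toolkit.</p>

```python
from dataclasses import dataclass

ALLOWED_ELEMENTS = frozenset(
    ["C", "H", "O", "N", "P", "S", "F", "Cl", "Br", "I", "Se", "B"]
)


@dataclass
class MoleculeRecord:
    """Precomputed properties; field names are hypothetical."""
    molecular_weight: float
    elements: frozenset
    bond_count: int
    has_counter_ion: bool = False
    has_hydrogen_isotope: bool = False
    has_charged_group: bool = False


def passes_filters(mol: MoleculeRecord) -> bool:
    """Apply the listed rules; 3 to 40 bonds is assumed inclusive."""
    return (
        1500 > mol.molecular_weight
        and mol.elements.issubset(ALLOWED_ELEMENTS)
        and mol.bond_count in range(3, 41)
        and not mol.has_counter_ion
        and not mol.has_hydrogen_isotope
        and not mol.has_charged_group
    )


# Illustrative record roughly shaped like ethanol (8 bonds with hydrogens).
ethanol_like = MoleculeRecord(46.07, frozenset(["C", "H", "O"]), 8)
keep = passes_filters(ethanol_like)
```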
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">PubChem Filtered</td>
          <td style="text-align: left">30M &amp; 60M</td>
          <td style="text-align: left">Two distinct training sets created.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">PubChem Held-out</td>
          <td style="text-align: left">2.2M</td>
          <td style="text-align: left">Molecules not present in training sets; uniform token frequency.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>:
<ul>
<li><strong>SELFIES</strong>: Split iteratively by brackets <code>[</code> and <code>]</code>.</li>
<li><strong>IUPAC</strong>: Split via punctuation (<code>(</code>, <code>)</code>, <code>{</code>, <code>}</code>, <code>[</code>, <code>]</code>, <code>-</code>, <code>.</code>, <code>,</code>) and a discrete set of sub-word chemical morphemes (e.g., <code>methyl</code>, <code>benzene</code>, <code>fluoro</code>).</li>
<li><strong>Padding</strong>: SELFIES padded to 48 tokens; IUPAC padded to 78 tokens. &ldquo;Start&rdquo; and &ldquo;End&rdquo; markers are added to each sequence.</li>
</ul>
</li>
<li><strong>Optimization</strong>: Adam optimizer instantiated with a learning rate of $0.0005$.</li>
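<li><strong>Objective Function</strong>: Sparse categorical cross-entropy. For a single target position, the loss sums over the $V$ vocabulary entries, where $y_i$ is the one-hot target indicator and $\hat{y}_i$ is the predicted probability of entry $i$:
$$ \mathcal{L} = -\sum_{i=1}^{V} y_i \log(\hat{y}_i) $$</li>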
</ul>
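<p>The tokenization and padding scheme above can be sketched with regular expressions. This is an illustrative approximation, not the STOUT code: the three-entry morpheme list is a tiny stand-in for the real vocabulary, the marker and padding strings are hypothetical, and it is an assumption that the start and end markers count toward the padded length.</p>

```python
import re


def tokenize_selfies(selfies: str) -> list:
    """Split a SELFIES string into its bracketed symbols."""
    return re.findall(r"\[[^\]]*\]", selfies)


# Tiny illustrative morpheme list; the real vocabulary is much larger.
MORPHEMES = ["methyl", "benzene", "fluoro"]
PUNCTUATION = "(){}[]-.,"

# Try morphemes first (longest first), then punctuation, then any character.
_IUPAC_PATTERN = re.compile(
    "|".join(re.escape(m) for m in sorted(MORPHEMES, key=len, reverse=True))
    + "|[" + re.escape(PUNCTUATION) + "]|."
)


def tokenize_iupac(name: str) -> list:
    """Split an IUPAC name into morphemes, punctuation, and leftover characters."""
    return _IUPAC_PATTERN.findall(name)


def add_markers_and_pad(tokens, length, pad_token="PAD"):
    """Wrap tokens in start/end markers and right-pad to a fixed length."""
    seq = ["start"] + list(tokens) + ["end"]
    if len(seq) > length:
        raise ValueError("sequence longer than padded length")
    return seq + [pad_token] * (length - len(seq))


selfies_tokens = tokenize_selfies("[C][C][O]")
iupac_tokens = tokenize_iupac("1-fluoro-2-methylbenzene")
padded_selfies = add_markers_and_pad(selfies_tokens, 48)
padded_iupac = add_markers_and_pad(iupac_tokens, 78)
```

<p>Fixed-length padding keeps batch shapes static, which matters for the TPU training described in this note.</p>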
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-decoder sequence-to-sequence network with Bahdanau attention.</li>
<li><strong>Components</strong>:
<ul>
<li><strong>Encoder/Decoder</strong>: Recurrent Neural Networks (RNN) constructed using Gated Recurrent Units (GRU).</li>
<li><strong>Attention</strong>: Bahdanau (additive) soft attention, which computes an alignment score between the previous decoder state $s_{t-1}$ and each encoder hidden state $h_j$; the softmax-normalized scores then weight the encoder states:
$$ e_{tj} = v_a^\top \tanh(W_a s_{t-1} + U_a h_j) $$</li>
<li><strong>Embedding</strong>: Decoder output passes through a continuous embedding layer before concatenating with the attention context vector.</li>
</ul>
</li>
<li><strong>Implementation</strong>: Python 3 backend using TensorFlow 2.3.0. <em>Note: The linked GitHub repository currently defaults to the STOUT V2.0 transformer models, so researchers aiming to reproduce this specific V1 RNN paper should reference the older tag/commit history.</em></li>
</ul>
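<p>The additive attention score above can be illustrated in plain Python. This is a minimal sketch assuming small dense lists instead of tensors; the function name <code>additive_attention</code> and the toy weights are hypothetical, and the softmax and context-vector steps follow the standard Bahdanau formulation, not code from the STOUT repository.</p>

```python
import math


def additive_attention(prev_state, encoder_states, W_a, U_a, v_a):
    """Return (weights, context) for Bahdanau-style additive attention."""
    def matvec(matrix, vector):
        return [sum(m * x for m, x in zip(row, vector)) for row in matrix]

    projected_state = matvec(W_a, prev_state)
    scores = []
    for h_j in encoder_states:
        projected_h = matvec(U_a, h_j)
        hidden = [math.tanh(a + b) for a, b in zip(projected_state, projected_h)]
        scores.append(sum(v * z for v, z in zip(v_a, hidden)))  # e_tj

    # Softmax over encoder positions gives the attention weights.
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]

    # Context vector: attention-weighted sum of encoder hidden states.
    dim = len(encoder_states[0])
    context = [
        sum(w * h[k] for w, h in zip(weights, encoder_states))
        for k in range(dim)
    ]
    return weights, context


# Toy 2-dimensional example with identity projections (illustrative only).
identity = [[1.0, 0.0], [0.0, 1.0]]
weights, context = additive_attention(
    prev_state=[0.5, -0.5],
    encoder_states=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    W_a=identity,
    U_a=identity,
    v_a=[1.0, 1.0],
)
```

<p>The weights sum to one, so the context vector is a convex combination of the encoder states.</p>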
<h3 id="evaluation">Evaluation</h3>
<p>The metrics cover both textual accuracy and structural (cheminformatic) correctness:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Details</th>
          <th style="text-align: left">Result (60M Model)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>BLEU Score</strong></td>
          <td style="text-align: left">NLTK sentence BLEU (unigram to 4-gram)</td>
          <td style="text-align: left">0.94 (IUPAC $\to$ SELFIES)</td>
          <td style="text-align: left">N-gram overlap between predicted and reference token sequences. Serves as a strictly syntactic proxy.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Tanimoto Similarity</strong></td>
          <td style="text-align: left">PubChem fingerprints via CDK</td>
          <td style="text-align: left">0.98 (Valid IUPAC names)</td>
          <td style="text-align: left">Evaluates substructure alignment over bit vectors, $T(A, B) = \frac{\vert A \cap B \vert}{\vert A \cup B \vert}$.</td>
      </tr>
  </tbody>
</table>
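<p>The Tanimoto formula in the table can be checked on toy fingerprints. The sketch below treats each fingerprint as a set of on-bit indices; the function name <code>tanimoto</code>, the example bits, and the convention for two empty fingerprints are assumptions made for illustration, and real PubChem fingerprints are much longer bit vectors.</p>

```python
def tanimoto(bits_a: set, bits_b: set) -> float:
    """Tanimoto similarity between two sets of on-bit indices."""
    union = bits_a.union(bits_b)
    if not union:
        return 1.0  # convention assumed here for two empty fingerprints
    return len(bits_a.intersection(bits_b)) / len(union)


# Two toy fingerprints sharing bits 3 and 4 out of five distinct bits.
similarity = tanimoto({1, 2, 3, 4}, {3, 4, 5})
```

<p>A value of 1.0 means identical fingerprints, which is why high Tanimoto scores can coexist with low BLEU scores when a predicted name differs textually but decodes to a near-identical structure.</p>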
<h3 id="hardware">Hardware</h3>
<p>Comparison of hardware efficiency for training large chemical language models:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Hardware</th>
          <th style="text-align: left">Batch Size</th>
          <th style="text-align: left">Time per Epoch (15M subset)</th>
          <th style="text-align: left">Speedup Factor</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>GPU (Tesla V100)</strong></td>
          <td style="text-align: left">256</td>
          <td style="text-align: left">~27 hours</td>
          <td style="text-align: left">1x</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>TPU v3-8</strong></td>
          <td style="text-align: left">1024 (Global)</td>
          <td style="text-align: left">~2 hours</td>
          <td style="text-align: left">13x</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>TPU v3-32</strong></td>
          <td style="text-align: left">1024 (Global)</td>
          <td style="text-align: left">~0.5 hours</td>
          <td style="text-align: left">54x</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2021). STOUT: SMILES to IUPAC names using neural machine translation. <em>Journal of Cheminformatics</em>, 13(1), 34. <a href="https://doi.org/10.1186/s13321-021-00512-4">https://doi.org/10.1186/s13321-021-00512-4</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanSTOUTSMILESIUPAC2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{STOUT: SMILES to IUPAC Names Using Neural Machine Translation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{STOUT}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00512-4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-09-22}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{Chemical compounds can be identified through a graphical depiction, a suitable string representation, or a chemical name. A universally accepted naming scheme for chemistry was established by the International Union of Pure and Applied Chemistry (IUPAC) based on a set of rules. Due to the complexity of this ruleset a correct chemical name assignment remains challenging for human beings and there are only a few rule-based cheminformatics toolkits available that support this task in an automated manner. Here we present STOUT (SMILES-TO-IUPAC-name translator), a deep-learning neural machine translation approach to generate the IUPAC name for a given molecule from its SMILES string as well as the reverse translation, i.e. predicting the SMILES string from the IUPAC name. In both cases, the system is able to predict with an average BLEU score of about 90% and a Tanimoto similarity index of more than 0.9. Also incorrect predictions show a remarkable similarity between true and predicted compounds.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{Attention mechanism,Chemical language,Deep neural network,DeepSMILES,IUPAC names,Neural machine translation,Recurrent neural network,SELFIES,SMILES}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">GitHub Repository</a></li>
<li><a href="/notes/chemistry/molecular-representations/name-translation/stout-v2/">STOUT V2.0 Note</a></li>
<li><a href="/notes/chemistry/molecular-representations/name-translation/struct2iupac-2021/">Struct2IUPAC Note</a></li>
<li><a href="/notes/chemistry/molecular-representations/name-translation/handsel-inchi-iupac-2021/">HandSEL Note (InChI to IUPAC)</a></li>
</ul>
]]></content:encoded></item><item><title>STOUT V2.0: Transformer-Based SMILES to IUPAC Translation</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout-v2/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/name-translation/stout-v2/</guid><description>A Transformer-based model for translating SMILES to IUPAC names, trained on ~1 billion molecules, achieving ~0.99 BLEU score on benchmarks.</description><content:encoded><![CDATA[<h2 id="paper-contribution--methodological-scope">Paper Contribution &amp; Methodological Scope</h2>
<p><strong>Method (Primary) / Resource (Secondary)</strong></p>
<p>This paper presents a <strong>Methodological</strong> contribution by developing and validating a Transformer-based neural machine translation model (STOUT V2) for bidirectional chemical nomenclature (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> $\leftrightarrow$ IUPAC). It systematically compares this new architecture against previous RNN-based baselines (<a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT V1</a>) and performs ablation studies on tokenization strategies.</p>
<p>It also serves as a significant <strong>Resource</strong> contribution by generating a massive training dataset of nearly 1 billion SMILES-IUPAC pairs (curated via commercial Lexichem software) and releasing the resulting models and code as open-source tools for chemical naming.</p>
<h2 id="the-need-for-robust-open-source-iupac-nomenclature-rules">The Need for Robust Open-Source IUPAC Nomenclature Rules</h2>
<p>Assigning systematic IUPAC names to chemical structures requires applying a complex ruleset, which makes consistent manual naming difficult. Deterministic, rule-based tools such as OpenEye Lexichem and ChemAxon are reliable but commercial. Existing open-source tools like OPSIN focus on the reverse direction, parsing names into structures, which leaves structure-to-name generation without a comparable open-source option.</p>
<p>The previous version of STOUT (V1), based on RNNs/GRUs, achieved ~90% BLEU accuracy, with known limitations in capturing long-distance dependencies required for stereochemistry handling. This work uses the sequence-learning capabilities of Transformers combined with large-scale datasets to create a competitive open-source IUPAC naming tool.</p>
<h2 id="architectural-shift-and-billion-scale-training">Architectural Shift and Billion-Scale Training</h2>
<p>The primary advancements over previous iterations address both architecture and dataset scale:</p>
<ol>
<li><strong>Architecture Shift</strong>: Moving from an RNN-based Seq2Seq model to a <strong>Transformer-based architecture</strong> (4 layers, 8 heads), which captures intricate chemical patterns better than GRUs.</li>
<li><strong>Billion-Scale Training</strong>: Training on a dataset of nearly <strong>1 billion molecules</strong> (combining PubChem and ZINC15), significantly larger than the 60 million used for STOUT V1.</li>
<li><strong>Tokenization Strategy</strong>: Determining that <strong>character-wise tokenization</strong> for IUPAC names is superior to word-wise tokenization in terms of both accuracy and training efficiency (15% faster).</li>
</ol>
<h2 id="experimental-validation-and-scaling-limits">Experimental Validation and Scaling Limits</h2>
<p>The authors conducted three primary experiments to validate bidirectional translation (SMILES $\rightarrow$ IUPAC and IUPAC $\rightarrow$ SMILES):</p>
<ul>
<li><strong>Experiment 1 (Optimization)</strong>: Assessed the impact of dataset size (1M vs 10M vs 50M) and tokenization strategy on SMILES-to-IUPAC performance.</li>
<li><strong>Experiment 2 (Scaling)</strong>: Trained models on 110 million PubChem molecules for <strong>both</strong> forward and reverse translation tasks to test performance on longer sequences.</li>
<li><strong>Experiment 3 (Generalization)</strong>: Trained on the full ~1 billion dataset (PubChem + ZINC15) for both translation directions.</li>
<li><strong>External Validation</strong>: Benchmarked against an external dataset from ChEBI (1,485 molecules) and ChEMBL34 to test generalization to unseen data.</li>
</ul>
<p><strong>Evaluation Metrics</strong>:</p>
<ul>
<li><strong>Textual Accuracy</strong>: BLEU scores (1-4) and Exact String Match.</li>
<li><strong>Chemical Validity</strong>: Retranslation of generated names back to SMILES using OPSIN, followed by Tanimoto similarity checks (PubChem fingerprints) against the original input.</li>
</ul>
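<p>The structural check above comes down to a set comparison between two fingerprints. The following is a minimal sketch, assuming each fingerprint is given as the set of its "on" bit indices; the function name and the example fingerprints are made up for illustration, and a real pipeline would obtain fingerprints from a cheminformatics toolkit.</p>

```python
def tanimoto(a, b):
    """Tanimoto (Jaccard) similarity between two sets of 'on' bit indices."""
    a, b = set(a), set(b)
    union = a | b
    if not union:
        # Convention chosen for this sketch: two empty fingerprints are identical.
        return 1.0
    return len(a & b) / len(union)


# Illustrative fingerprints (not real PubChem fingerprints).
fp_true = {1, 4, 7, 9, 12}
fp_pred = {1, 4, 7, 10}

# 3 shared bits out of 6 distinct bits overall.
print(tanimoto(fp_true, fp_pred))
```

<p>A score of 1.0 means the two fingerprints are identical, and 0.0 means they share no bits.</p>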
<h2 id="translation-accuracy-and-structural-validity">Translation Accuracy and Structural Validity</h2>
<ul>
<li><strong>Superior Performance</strong>: STOUT V2 achieved an average BLEU score of <strong>0.99</strong> (vs 0.94 for V1). While exact string matches varied by experiment (83-89%), the model notably achieved a perfect BLEU score (1.0) on <strong>97.49%</strong> of a specific test set where STOUT V1 only reached 66.65%.</li>
<li><strong>Structural Validity (&ldquo;Near Misses&rdquo;)</strong>: When the generated name differed from the ground truth string, the re-generated structure often remained chemically valid. For these divergent names, the model maintained an average Tanimoto similarity of <strong>0.68</strong>, where for bit-vector fingerprints $A$ and $B$:
$$ T(A,B) = \frac{|A \cap B|}{|A \cup B|} $$
<em>Critique</em>: Note that an average Tanimoto coefficient of 0.68 typically suggests moderate structural similarity/drift, not an almost-identical &ldquo;near miss&rdquo; (which would be $&gt;0.85$). This implies the model constructs chemically related but structurally distinct outputs when it fails exact string matching.</li>
<li><strong>Tokenization</strong>: Character-level splitting for IUPAC names outperformed word-level splitting and was more computationally efficient.</li>
<li><strong>Data Imbalance &amp; Generalization</strong>: The model&rsquo;s drop in performance for sequences &gt;600 characters highlights a systemic issue in open chemical databases: long, highly complex SMILES strings are significantly underrepresented. Even billion-scale training datasets are still bound by the chemical diversity of their source material.</li>
<li><strong>Limitations</strong>:
<ul>
<li><strong>Preferred Names (PINs)</strong>: The model mimics Lexichem&rsquo;s naming conventions, generating valid IUPAC names distinct from strict <em>Preferred IUPAC Names</em> (PINs).</li>
<li><strong>Sequence Length</strong>: Performance degrades for very long SMILES (&gt;600 characters) due to scarcity in the training data.</li>
<li><strong>Algorithmic Distillation Bottleneck</strong>: Because the 1 billion training pairs were generated entirely by OpenEye&rsquo;s Lexichem, STOUT V2 acts as a knowledge distillation of that specific commercial algorithm. The model learns Lexichem’s heuristic mapping, specific dialects, and potential systematic errors, rather than deriving true nomenclature rules from first principles.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data was derived from PubChem and ZINC15. Ground truth IUPAC names were generated using OpenEye Lexichem TK 2.8.1 to ensure consistency.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training (Exp 1)</strong></td>
          <td>PubChem Subset</td>
          <td>1M, 10M, 50M</td>
          <td>Selected via MaxMin algorithm for diversity</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 2)</strong></td>
          <td>PubChem</td>
          <td>110M</td>
          <td>Filtered for SMILES length &lt; 600</td>
      </tr>
      <tr>
          <td><strong>Training (Exp 3)</strong></td>
          <td>PubChem + ZINC15</td>
          <td>~1 Billion</td>
          <td>999,637,326 molecules total</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>ChEBI</td>
          <td>1,485</td>
          <td>External validation set, non-overlapping with training</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>SMILES</strong>: Canonicalized, isomeric, and kekulized using RDKit (v2023.03.1).</li>
<li><strong>Formatting</strong>: Converted to TFRecord format in 100 MB chunks for TPU efficiency.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES Tokenization</strong>: Regex-based splitting. Atoms (e.g., &ldquo;Cl&rdquo;, &ldquo;Au&rdquo;), bonds, brackets, and digits are separate tokens.</li>
<li><strong>IUPAC Tokenization</strong>: <strong>Character-wise split</strong> was selected as the optimal strategy (treating every character as a token).</li>
<li><strong>Optimization</strong>: Adam optimizer with a custom learning rate scheduler based on model dimensions.</li>
<li><strong>Loss Function</strong>: Trained to minimize the Sparse Categorical Cross-Entropy $L$, masking padding tokens. For a target sequence of $N$ positions, with $t_i$ the target token at position $i$ and $p_{i,t_i}$ the predicted probability of that token, the masked loss is:
$$ L = - \sum_{i=1}^{N} m_i \log(p_{i,t_i}) $$
where $m_i \in \{0, 1\}$ is 0 at padded positions and 1 elsewhere.</li>
<li><strong>Code Availability</strong>: The <a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">main STOUT V2 repository</a> contains the inference package. The training pipeline/instructions (originally linked to a separate repo that is currently a 404) can still be found within the <a href="https://doi.org/10.5281/zenodo.6559438">Zenodo archive release</a>.</li>
</ul>
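<p>The tokenization and loss steps above can be sketched in plain Python. This is an illustration under stated assumptions, not the STOUT implementation: the regex only special-cases the element symbols named in this note (a real tokenizer would cover more), and the function names are hypothetical. The loss is left as an unnormalized sum over unmasked positions, since the normalization is not stated here.</p>

```python
import math
import re

# Hypothetical pattern: match a few two-letter element symbols first,
# then fall back to any single character (atoms, bonds, brackets, digits).
SMILES_TOKEN = re.compile(r"Cl|Br|Au|.")


def tokenize_smiles(smiles):
    """Split a SMILES string into atom, bond, bracket, and digit tokens."""
    return SMILES_TOKEN.findall(smiles)


def tokenize_iupac(name):
    """Character-wise split: every character of the name is one token."""
    return list(name)


def masked_cross_entropy(probs, targets, mask):
    """Sum of -log p[target] over positions where mask is 1 (padding is 0)."""
    total = 0.0
    for p, t, m in zip(probs, targets, mask):
        if m:
            total -= math.log(p[t])
    return total


print(tokenize_smiles("ClCC(=O)[Au]"))
print(tokenize_iupac("ethanol"))

# Three positions over a two-token vocabulary; the last position is padding.
probs = [[0.5, 0.5], [0.25, 0.75], [0.9, 0.1]]
loss = masked_cross_entropy(probs, targets=[0, 1, 0], mask=[1, 1, 0])
print(loss)
```

<p>The padded third position contributes nothing to the loss, regardless of the probability the model assigns there.</p>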
<h3 id="models">Models</h3>
<p>The model follows the standard Transformer architecture from &ldquo;Attention is All You Need&rdquo; (Vaswani et al.).</p>
<ul>
<li><strong>Architecture</strong>: 4 Transformer layers (encoder/decoder stack).</li>
<li><strong>Attention</strong>: Multi-head attention with <strong>8 heads</strong>.</li>
<li><strong>Dimensions</strong>: Embedding size ($d_{model}$) = 512; Feed-forward dimension ($d_{ff}$) = 2048.</li>
<li><strong>Regularization</strong>: Dropout rate of 0.1.</li>
<li><strong>Context Window</strong>: Max input length (SMILES) = 600; Max output length (IUPAC) = 700-1000.</li>
<li><strong>Weights</strong>: Model weights for forward and reverse architectures are <a href="https://doi.org/10.5281/zenodo.13318286">available via Zenodo (v3)</a>.</li>
</ul>
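<p>For quick reference, the hyperparameters above can be collected in a small configuration object. This is a sketch with a hypothetical class name, not code from the STOUT repository. The per-head dimension assumes the standard multi-head attention split of $d_{model}$ evenly across heads, as in Vaswani et al.</p>

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TransformerConfig:
    """Hypothetical container for the hyperparameters listed in this note."""

    num_layers: int = 4
    num_heads: int = 8
    d_model: int = 512
    d_ff: int = 2048
    dropout: float = 0.1
    max_smiles_len: int = 600

    @property
    def head_dim(self) -> int:
        # Standard multi-head attention splits d_model evenly across heads.
        if self.d_model % self.num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        return self.d_model // self.num_heads


cfg = TransformerConfig()
print(cfg.head_dim)
```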
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation focused on both string similarity and chemical structural integrity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Scope</th>
          <th>Method</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>BLEU Score</strong></td>
          <td>N-gram overlap</td>
          <td>Compared predicted IUPAC string to Ground Truth.</td>
      </tr>
      <tr>
          <td><strong>Exact Match</strong></td>
          <td>Accuracy</td>
          <td>Binary 1/0 check for identical strings.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto</strong></td>
          <td>Structural Similarity</td>
          <td>Predicted Name $\rightarrow$ OPSIN $\rightarrow$ SMILES $\rightarrow$ Fingerprint comparison to input.</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/egonw/Smiles-TO-iUpac-Translator">STOUT V2 GitHub</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Inference package (PyPI: STOUT-pypi)</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.13318286">Model Weights (Zenodo v3)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Forward and reverse translation weights</td>
      </tr>
      <tr>
          <td><a href="https://zenodo.org/records/6559438">Code Snapshot (Zenodo)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training pipeline archive</td>
      </tr>
      <tr>
          <td><a href="https://stout.decimer.ai">Web Application</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Demo with Ketcher, bulk submission, DECIMER integration</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was conducted entirely on Google Cloud Platform (GCP) TPUs.</p>
<ul>
<li><strong>STOUT V1</strong>: Trained on TPU v3-8.</li>
<li><strong>STOUT V2</strong>: Trained on <strong>TPU v4-128 pod slices</strong> (128 nodes).</li>
<li><strong>Large Scale (Exp 3)</strong>: Trained on <strong>TPU v4-256 pod slice</strong> (256 nodes).</li>
<li><strong>Training Time</strong>: Average of <strong>15 hours and 2 minutes per epoch</strong> for the 1 billion dataset.</li>
<li><strong>Framework</strong>: TensorFlow 2.15.0-pjrt with Keras.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A., &amp; Steinbeck, C. (2024). STOUT V2.0: SMILES to IUPAC name conversion using transformer models. <em>Journal of Cheminformatics</em>, 16(1), 146. <a href="https://doi.org/10.1186/s13321-024-00941-x">https://doi.org/10.1186/s13321-024-00941-x</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanSTOUTV20SMILES2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{STOUT V2}}.0: {{SMILES}} to {{IUPAC}} Name Conversion Using Transformer Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{STOUT V2}}.0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{146}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00941-x}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://stout.decimer.ai">Web Application</a> (Includes Ketcher drawing, bulk submission, and DECIMER integration)</li>
<li><a href="https://decimer.ai">DECIMER Project</a></li>
<li><a href="/notes/chemistry/molecular-representations/name-translation/stout/">STOUT V1 Note</a></li>
<li><a href="https://zenodo.org/records/6559438">Zenodo Archive (Code Snapshot)</a></li>
</ul>
]]></content:encoded></item><item><title>Multimodal Search in Chemical Documents and Reactions</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/shah-multimodal-search-2025/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/shah-multimodal-search-2025/</guid><description>A multimodal search engine that integrates text passages, molecular diagrams, and reaction data to enable passage-level retrieval in chemical literature.</description><content:encoded><![CDATA[<h2 id="contribution-multimodal-synthesis-retrieval">Contribution: Multimodal Synthesis Retrieval</h2>
<p>This paper represents a $\Psi_{\text{Method}}$ projection that proposes a novel architectural pipeline for indexing and searching chemical literature. The framework unifies text, molecular diagrams, and structured reaction records. It also contains a secondary $\Psi_{\text{Resource}}$ projection, providing a functional demonstration tool and curating a specific benchmark dataset for Suzuki coupling reactions.</p>
<h2 id="the-gap-in-passage-level-chemical-retrieval">The Gap in Passage-Level Chemical Retrieval</h2>
<p>Scientific literature documents chemical reactions through a combination of text and visual diagrams. Textual descriptions detail parameters like yield and operational temperature, whereas diagrams graphically model these structural transformations. Existing tools such as SciFinder or <a href="https://en.wikipedia.org/wiki/Reaxys">Reaxys</a> perform document-level or individual compound retrieval. They fail to explicitly link molecular figures to localized textual descriptions. This structure prevents researchers from directly extracting a corresponding reaction diagram alongside the exact textual protocol. Researchers require passage-level retrieval of synthesis protocols to efficiently access complete reaction conditions.</p>
<h2 id="core-innovation-unified-multimodal-indexing">Core Innovation: Unified Multimodal Indexing</h2>
<p>The core methodological innovation is a multimodal passage-level indexing and linking pipeline.</p>
<ul>
<li><strong>Unified Indexing:</strong> The framework processes text and diagrams in parallel and directly links them into a single index structure. This architecture supports search queries utilizing raw text, discrete <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings, or multimodal combinations.</li>
<li><strong>Compound-Passage Linking:</strong> The mechanism applies conflict-resolution logic linking chemical diagrams to specific text citations using two parallel heuristics:
<ol>
<li><strong>Token-based Alignment:</strong> Matching parsed diagram labels against documented text strings (e.g., &ldquo;compound 5&rdquo;) using normalized <a href="https://en.wikipedia.org/wiki/Levenshtein_distance">Levenshtein distance</a>.</li>
<li><strong>Fingerprint-based Alignment:</strong> Matching chemical structures against generated SMILES strings via structural <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto Similarity</a>.</li>
</ol>
</li>
<li><strong>ReactionMiner Integration:</strong> The pipeline parses and incorporates formatted reaction records (reactants, products, catalysts, quantitative yields) directly derived from segmented text passages.</li>
</ul>
<h2 id="methodology--expert-evaluation">Methodology &amp; Expert Evaluation</h2>
<p>The authors evaluated the system utilizing a chemical case study targeting specific synthesis domains alongside qualitative expert assessment.</p>
<ul>
<li><strong>Dataset:</strong> Evaluators processed a corpus of 7 research manuscripts and 6 supplementary data documents detailing <a href="https://en.wikipedia.org/wiki/Suzuki_reaction">Suzuki coupling</a> reactions.</li>
<li><strong>Volume:</strong> From this corpus, the pipeline extracted 1,282 passages (of which 538 were indexed), 383 unique SMILES, and 219 parsed reactions.</li>
<li><strong>Qualitative Evaluation:</strong> Practicing structural chemists developed real-world queries (such as cross-referencing the conceptual &ldquo;Burke group&rdquo; alongside an explicit structural SMARTS pattern) to gauge retrieval capability.</li>
</ul>
<h2 id="key-findings--system-limitations">Key Findings &amp; System Limitations</h2>
<ul>
<li><strong>Diagram-to-Text Linking:</strong> The pipeline accurately paired visual molecular diagrams with structurally derived text details, permitting testers to navigate directly from a molecule query card to the exact origin passage within the source PDF.</li>
<li><strong>Contextual Insight Extraction:</strong> The chemists found the parsed reaction records (yields, catalysts) useful as high-level extractive summaries.</li>
<li><strong>Extrapolative Retrieval:</strong> The architecture permitted the effective retrieval of targeted chemical derivatives (such as benzo[b]thiophen-2-ylboronic acid) via structurally related input queries (dibenzothiophene).</li>
</ul>
<p>The system evaluation highlights several architectural restrictions:</p>
<ul>
<li><strong>Domain-Restricted Validation:</strong> The initial validation is entirely qualitative and bounded to the specific subclass of Suzuki coupling reactions. The evaluation omits standardized quantitative retrieval baselines (e.g., MAP, NDCG) and lacks systematic ablation data for the fusion scoring mechanism.</li>
<li><strong>Algorithmic Transparency:</strong> The multimodal query routing mechanism does not clearly indicate the dominant retrieval feature. This hides whether keyword text or structural similarity actually drove the final result placement. This ambiguity limits operator control.</li>
<li><strong>Optical Processing Brittleness:</strong> The embedded vision inference and primitive parsing pipelines display inherent fragility, producing intermittent failures when associating text passages with correctly parsed molecular diagrams.</li>
<li><strong>Metadata Logging Incompleteness:</strong> Practicing chemists requested additional structured metadata targets (such as specific molar equivalents and parameterized mol% values) to successfully bridge the extracted data stream directly into digital electronic lab notebooks.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.cs.rit.edu/~dprl/reactionminer-demo-landing/">ReactionMiner Demo</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Online demo landing page; source code repository not publicly linked</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source:</strong> The corpus features 7 primary research papers and 6 auxiliary supplementary information documents focusing on Suzuki coupling reactions, sourced from practicing chemists at UIUC. This evaluation dataset is strictly internal and not publicly available.</li>
<li><strong>Preprocessing:</strong>
<ul>
<li>The pipeline converts source PDFs to full-page raster images.</li>
<li>The system extracts localized graphical layout and raw text via <strong>PyTesseract</strong>.</li>
<li>The pipeline segments valid passage chunks emphasizing reaction-related sentences utilizing product-indicative lexicons and topic modeling.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Diagram Extraction:</strong> A <strong>YOLOv8</strong> model identifies and segments molecular regions within structured PDF pages.</li>
<li><strong>Diagram Parsing:</strong> The architecture relies on <strong>ChemScraper</strong> to infer structural semantics from raw diagrams:
<ul>
<li><em>Born-digital PDFs:</em> <strong>SymbolScraper</strong> extracts vector lines and polygons directly from bounding box definitions.</li>
<li><em>Raster images:</em> The system employs the <strong>Line Segment Detector (LSD)</strong> and watershed bounding algorithms to isolate native geometric primitives.</li>
</ul>
</li>
<li><strong>Text Entity Extraction:</strong> The framework deploys <strong>ChemDataExtractor 2.0</strong> to extract explicit molecular aliases. A translation layer maps these entities to string representations via <strong>OPSIN</strong>.</li>
<li><strong>Linking Logic (Fusion Score):</strong>
<ul>
<li><strong>Text Link:</strong> The algorithm computes a normalized Levenshtein ratio between diagram labels and nearby text mentions.</li>
<li><strong>Structure Link:</strong> The algorithm computes the Tanimoto Similarity between 2048-bit Morgan fingerprints of the structures parsed from diagrams and of the SMILES derived from text:
$$ T(A, B) = \frac{A \cdot B}{|A|^{2} + |B|^{2} - A \cdot B} $$
where $A$ and $B$ represent the boolean bit vectors of the respective fingerprint pairs.</li>
<li><strong>Conflict Resolution Protocol:</strong> The system compares the text-based and structure-based link scores and keeps the link with the higher similarity score. During final retrieval, the candidate passages are re-ranked using a combination of the <a href="https://en.wikipedia.org/wiki/Okapi_BM25">BM25</a> score and the number of exact SMILES pattern matches in each passage.</li>
</ul>
</li>
</ul>
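<p>The two linking heuristics and the conflict-resolution step can be sketched as follows. This is a minimal illustration, not the authors' code: the function names are hypothetical, fingerprints are plain 0/1 lists and not 2048-bit Morgan fingerprints, and breaking ties in favor of the text link is an assumption made only for this sketch.</p>

```python
def levenshtein(s, t):
    """Edit distance between two strings (insert, delete, substitute)."""
    prev = list(range(len(t) + 1))
    for i, cs in enumerate(s, start=1):
        curr = [i]
        for j, ct in enumerate(t, start=1):
            cost = 0 if cs == ct else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def levenshtein_ratio(s, t):
    """Similarity in [0, 1]: 1 minus edit distance over the longer length."""
    longest = max(len(s), len(t))
    if longest == 0:
        return 1.0
    return 1.0 - levenshtein(s, t) / longest


def tanimoto_bits(a, b):
    """Tanimoto similarity of two equal-length 0/1 vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    # For 0/1 vectors the squared norm equals the number of set bits.
    denom = sum(a) + sum(b) - dot
    return dot / denom if denom else 1.0


def best_link(text_score, structure_score):
    """Keep the link type with the higher score (ties go to text: assumption)."""
    if text_score >= structure_score:
        return ("text", text_score)
    return ("structure", structure_score)


text_score = levenshtein_ratio("compound 5", "compound 5a")
structure_score = tanimoto_bits([1, 1, 0, 1], [1, 0, 0, 1])
print(best_link(text_score, structure_score))
```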
<h3 id="models">Models</h3>
<ul>
<li><strong>Reaction Extraction Parameters:</strong> Reaction extraction uses a <strong>LLaMA-3.1-8b</strong> model fine-tuned with <strong>LoRA</strong>, using custom tokens that represent reaction entities (compounds, reagents, thermal inputs) pulled from text sub-chunks. The paper does not report the exact prompts, the fine-tuning dataset, or the LoRA hyperparameters.</li>
<li><strong>Diagram Processing Bounds:</strong> The codebase incorporates a segmentation-aware multi-task neural network topology built into ChemScraper to execute low-level raster image parsing tasks.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Search Engine Base:</strong> The authors built their indexing framework on top of <strong>PyTerrier</strong>.</li>
<li><strong>Text Feature Ranking:</strong> Keyword similarity is scored with <strong>BM25</strong>.</li>
<li><strong>Structure Feature Operations:</strong> <strong>RDKit</strong> handles substructure matching and molecular similarity searches.</li>
<li><strong>Multimodal Fusion Processing:</strong>
<ul>
<li>The algorithm filters candidates by combining structural matches (SMILES queries) with lexical relevance (BM25 scores).</li>
<li>The final fusion step gives the strongest weight to retrieved passages that contain many exact matches of the queried SMILES patterns.</li>
</ul>
</li>
</ul>
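<p>The fusion step above can be illustrated with a small pure-Python sketch. The system itself is built on PyTerrier; the code below does not use that library. It implements a textbook Okapi BM25 and then adds a weighted count of SMILES hits per passage. The additive combination, the <code>weight</code> parameter, and the function names are assumptions for illustration, since the exact fusion formula is not given here.</p>

```python
import math
from collections import Counter


def bm25_scores(query_terms, docs, k1=1.2, b=0.75):
    """Okapi BM25 score of each tokenized document for the query terms."""
    n = len(docs)
    avgdl = sum(len(d) for d in docs) / n
    # Document frequency: number of documents containing each term.
    df = Counter(term for d in docs for term in set(d))
    scores = []
    for d in docs:
        tf = Counter(d)
        score = 0.0
        for term in query_terms:
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n - df[term] + 0.5) / (df[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(d) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


def rerank(query_terms, passages, smiles_hits, weight=1.0):
    """Order passage indices by BM25 plus a weighted SMILES-hit count.

    The additive fusion and `weight` are hypothetical choices for this sketch.
    """
    text = bm25_scores(query_terms, passages)
    combined = [s + weight * h for s, h in zip(text, smiles_hits)]
    return sorted(range(len(passages)), key=lambda i: combined[i], reverse=True)


passages = [
    ["suzuki", "coupling", "yield"],
    ["boronic", "acid", "suzuki"],
    ["unrelated", "text", "here"],
]
smiles_hits = [0, 2, 0]  # exact SMILES pattern matches found in each passage
order = rerank(["suzuki"], passages, smiles_hits)
print(order)
```

<p>In this toy example the first two passages get the same text score, so the passage with SMILES hits is ranked first.</p>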
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute Infrastructure:</strong> The hardware and parameter requirements to host the multi-stage vision extractors (YOLOv8, ChemScraper) alongside a local 8B LLM are entirely unspecified in the paper.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Shah, A. K., et al. (2025). Multimodal Search in Chemical Documents and Reactions. In <em>Proceedings of the 48th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR &lsquo;25)</em>. ACM. <a href="https://doi.org/10.48550/arXiv.2502.16865">https://doi.org/10.48550/arXiv.2502.16865</a></p>
<p><strong>Publication</strong>: SIGIR &lsquo;25 (Demo Track), 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{shahMultimodalSearchChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Multimodal {{Search}} in {{Chemical Documents}} and {{Reactions}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Shah, Ayush Kumar and Dey, Abhisek and Luo, Leo and Amador, Bryan and Philippy, Patrick and Zhong, Ming and Ouyang, Siru and Friday, David Mark and Bianchi, David and Jackson, Nick and Zanibbi, Richard and Han, Jiawei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2502.16865}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.cs.rit.edu/~dprl/reactionminer-demo-landing/">Online Demo</a> (Note: While the landing page advertises the system as open-source, the exact repository URL and installation prerequisites are omitted from the official manuscript.)</li>
</ul>
]]></content:encoded></item><item><title>MOFFlow: Flow Matching for MOF Structure Prediction</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mofflow/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/mofflow/</guid><description>A Riemannian flow matching framework for generating Metal-Organic Framework structures by treating building blocks as rigid bodies.</description><content:encoded><![CDATA[<h2 id="methodological-contribution-mofflow-architecture">Methodological Contribution: MOFFlow Architecture</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$).</p>
<p>It introduces <strong>MOFFlow</strong>, a generative architecture and training framework designed specifically for the structure prediction of Metal-Organic Frameworks (MOFs). The paper focuses on the algorithmic innovation of decomposing the problem into rigid-body assembly on a Riemannian manifold, validates this through comparison against existing baselines, and performs ablation studies to justify architectural choices. While it leverages the theory of flow matching, its primary contribution is the application-specific architecture and the handling of modular constraints.</p>
<h2 id="motivation-scaling-limits-of-atom-level-generation">Motivation: Scaling Limits of Atom-Level Generation</h2>
<p>The primary motivation is to overcome the scalability and accuracy limitations of existing methods for MOF structure prediction.</p>
<ul>
<li><strong>Computational Cost of DFT:</strong> Conventional approaches rely on <em>ab initio</em> calculations (DFT) combined with random search, which are computationally prohibitive for large, complex systems like MOFs.</li>
<li><strong>Failure of General CSP:</strong> Existing deep generative models for general Crystal Structure Prediction (CSP) operate on an atom-by-atom basis. They fail to scale to MOFs, which often contain hundreds or thousands of atoms per unit cell, and do not exploit the inherent modular nature (building blocks) of MOFs.</li>
<li><strong>Tunability:</strong> MOFs have applications in carbon capture and drug delivery due to their tunable porosity, making automated design tools valuable.</li>
</ul>
<h2 id="core-innovation-rigid-body-flow-matching-on-se3">Core Innovation: Rigid-Body Flow Matching on SE(3)</h2>
<p>MOFFlow introduces a <strong>hierarchical, rigid-body flow matching framework</strong> tailored for MOFs.</p>
<ul>
<li><strong>Rigid Body Decomposition:</strong> MOFFlow treats metal nodes and organic linkers as rigid bodies, reducing the search space from $3N$ coordinates (for $N$ atoms, as in atom-based methods) to $6M$ (a rotation and a translation for each of $M$ blocks).</li>
<li><strong>Riemannian Flow Matching on $SE(3)$:</strong> It is the first end-to-end model to jointly generate block-level rotations ($SO(3)$), translations ($\mathbb{R}^3$), and lattice parameters using <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">Riemannian flow matching</a>.</li>
<li><strong>MOFAttention:</strong> A custom attention module designed to encode the geometric relationships between building blocks, lattice parameters, and rotational constraints.</li>
<li><strong>Constraint Handling:</strong> It incorporates domain knowledge by operating on a mean-free system for translation invariance and using canonicalized coordinates for rotation invariance.</li>
</ul>
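<p>As a rough illustration of this reduction, consider a unit cell with $N = 500$ atoms grouped into $M = 20$ rigid building blocks. These sizes are hypothetical, chosen only to make the arithmetic easy, and are not taken from the paper:</p>
<p>$$
3N = 3 \times 500 = 1500, \qquad 6M = 6 \times 20 = 120, \qquad \frac{3N}{6M} = 12.5
$$</p>
<p>The count $6M$ covers three rotational and three translational degrees of freedom per block. The lattice parameters add a small, fixed number of further variables in either representation.</p>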
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The authors evaluated MOFFlow on structure prediction accuracy, physical property preservation, and scalability.</p>
<ul>
<li><strong>Dataset:</strong> The <strong>Boyd et al. (2019)</strong> dataset consisting of 324,426 hypothetical MOF structures, decomposed into building blocks using the <strong>MOFid</strong> algorithm. Filtered to structures with &lt;200 blocks, yielding 308,829 structures (247,066 train / 30,883 val / 30,880 test). Structures contain up to approximately 2,400 atoms per unit cell.</li>
<li><strong>Baselines:</strong>
<ul>
<li><em>Optimization-based:</em> Random Search (RS) and Evolutionary Algorithm (EA) using CrySPY and CHGNet.</li>
<li><em>Deep Learning:</em> DiffCSP (deep generative model for general crystals).</li>
<li><em>Self-Assembly:</em> A heuristic algorithm used in MOFDiff (adapted for comparison).</li>
</ul>
</li>
<li><strong>Metrics:</strong>
<ul>
<li><strong>Match Rate (MR):</strong> Percentage of generated structures matching ground truth within tolerance.</li>
<li><strong>RMSE:</strong> Root mean squared displacement normalized by average free length per atom.</li>
<li><strong>Structural Properties:</strong> Volumetric/Gravimetric Surface Area (VSA/GSA), Pore Limiting Diameter (PLD), Void Fraction, etc., calculated via Zeo++.</li>
<li><strong>Scalability:</strong> Performance vs. number of atoms and building blocks.</li>
</ul>
</li>
</ul>
<h2 id="results-and-generative-performance">Results and Generative Performance</h2>
<p>MOFFlow outperformed all baselines in accuracy and efficiency, particularly for large structures.</p>
<ul>
<li><strong>Accuracy:</strong> With a single sample, MOFFlow achieved a <strong>31.69% match rate</strong> (stol=0.5) and <strong>87.46%</strong> (stol=1.0) on the full test set (30,880 structures). With 5 samples, these rose to <strong>44.75%</strong> (stol=0.5) and <strong>100.0%</strong> (stol=1.0). RS and EA (tested on 100 and 15 samples respectively due to computational cost, generating 20 candidates each) achieved 0.00% MR at both tolerance levels. DiffCSP reached 0.09% (stol=0.5) and 23.12% (stol=1.0) with 1 sample.</li>
<li><strong>Speed:</strong> Inference took <strong>1.94 seconds</strong> per structure, compared to 5.37s for DiffCSP, 332s for RS, and 1,959s for EA.</li>
<li><strong>Scalability:</strong> MOFFlow preserved high match rates across all system sizes, while DiffCSP&rsquo;s match rate dropped sharply beyond 200 atoms.</li>
<li><strong>Property Preservation:</strong> The distributions of physical properties (e.g., surface area, void fraction) for MOFFlow-generated structures closely matched the ground truth. DiffCSP frequently reduced volumetric surface area and void fraction to zero.</li>
<li><strong>Self-Assembly Comparison:</strong> In a controlled comparison where the self-assembly (SA) algorithm received MOFFlow&rsquo;s predicted translations and lattice, MOFFlow (MR=31.69%, RMSE=0.2820) outperformed SA (MR=30.04%, RMSE=0.3084), confirming the value of the learned rotational vector fields. In an extended scalability comparison, SA scaled better for structures with many building blocks, but MOFFlow achieved higher overall match rate (31.69% vs. 27.14%).</li>
<li><strong>Batch Implementation:</strong> A refactored Batch version slightly improves on the main-paper results: <strong>32.73% MR</strong> (stol=0.5, versus 31.69%), RMSE of 0.2743 (versus 0.2820), inference in <strong>0.19s</strong> per structure (about 10x faster than 1.94s), and training in roughly 1/3 the GPU hours.</li>
</ul>
<h3 id="limitations">Limitations</h3>
<p>The paper identifies three key limitations:</p>
<ol>
<li><strong>Hypothetical-only evaluation:</strong> All experiments use the Boyd et al. hypothetical database. Evaluation on more challenging real-world datasets remains needed.</li>
<li><strong>Rigid-body assumption:</strong> The model assumes that local building block structures are known, which may be impractical for rare building blocks whose structural information is missing from existing libraries or is inaccurate.</li>
<li><strong>Periodic invariance:</strong> The model is not invariant to periodic transformations of the input. Explicitly modeling periodic invariance could further improve performance.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source:</strong> MOF dataset by Boyd et al. (2019).</li>
<li><strong>Preprocessing:</strong> Structures were decomposed using the metal-oxo decomposition algorithm from <strong>MOFid</strong>.</li>
<li><strong>Filtering:</strong> Structures with fewer than 200 building blocks were used, yielding 308,829 structures.</li>
<li><strong>Splits:</strong> Train/Validation/Test ratio of 8:1:1 (247,066 / 30,883 / 30,880).</li>
<li><strong>Availability:</strong> Pre-processed dataset is available on <a href="https://zenodo.org/records/15187230">Zenodo</a>.</li>
<li><strong>Representations:</strong>
<ul>
<li><em>Atom-level:</em> Tuple $(X, a, l)$ (coordinates, types, lattice).</li>
<li><em>Block-level:</em> Tuple $(\mathcal{B}, q, \tau, l)$ (blocks, rotations, translations, lattice).</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Framework:</strong> Riemannian Flow Matching.</li>
<li><strong>Objective:</strong> Conditional Flow Matching (CFM) loss regressing to clean data $q_1, \tau_1, l_1$.
$$
\begin{aligned}
\mathcal{L}(\theta) = \mathbb{E}_{t, \mathcal{S}^{(1)}} \left[ \frac{1}{(1-t)^2} \left( \lambda_1 |\log_{q_t}(\hat{q}_1) - \log_{q_t}(q_1)|^2 + \dots \right) \right]
\end{aligned}
$$</li>
<li><strong>Priors:</strong>
<ul>
<li>Rotations ($q$): Uniform on $SO(3)$.</li>
<li>Translations ($\tau$): Standard normal on $\mathbb{R}^3$.</li>
<li>Lattice ($l$): Log-normal for lengths, Uniform(60°, 120°) for angles (Niggli reduced).</li>
</ul>
</li>
<li><strong>Inference:</strong> ODE solver with <strong>50 integration steps</strong>.</li>
<li><strong>Local Coordinates:</strong> Defined using PCA axes, corrected for symmetry to ensure consistency.</li>
</ul>
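<p>The $1/(1-t)^2$ weight in the objective can be read off from the usual flow matching construction. The following is a sketch for the rotation term only, assuming a linear time schedule with geodesic interpolation on $SO(3)$, as is common in Riemannian flow matching; the paper&rsquo;s exact parameterization may differ in details. The interpolants $q_t$, $\tau_t$ and the vector fields $u_t$, $v_t$ are notation introduced here for illustration.</p>
<p>$$
\begin{gathered}
q_t = \exp_{q_0}\big(t \log_{q_0}(q_1)\big), \qquad \tau_t = (1-t)\,\tau_0 + t\,\tau_1 \\
u_t(q_t \mid q_1) = \frac{\log_{q_t}(q_1)}{1-t}, \qquad v_t(q_t) = \frac{\log_{q_t}(\hat{q}_1)}{1-t} \\
|v_t(q_t) - u_t(q_t \mid q_1)|^2 = \frac{1}{(1-t)^2} |\log_{q_t}(\hat{q}_1) - \log_{q_t}(q_1)|^2
\end{gathered}
$$</p>
<p>Under these assumptions, regressing the predicted clean rotation $\hat{q}_1$ toward $q_1$ in the tangent space at $q_t$ is equivalent to matching the target vector field, and the $1/(1-t)^2$ factor is the conversion between the two.</p>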
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture:</strong> Hierarchical structure with two key modules.
<ul>
<li><strong>Atom-level Update Layers:</strong> 4-layer EGNN-like structure to encode building block features $h_m$ from atomic graphs (5 Å cutoff).</li>
<li><strong>Block-level Update Layers:</strong> 6 layers that iteratively update $q, \tau, l$ using the <strong>MOFAttention</strong> module.</li>
</ul>
</li>
<li><strong>MOFAttention:</strong> Modified Invariant Point Attention (IPA) that incorporates lattice parameters as offsets to the attention matrix.</li>
<li><strong>Hyperparameters:</strong>
<ul>
<li>Node dimension: 256 (block-level), 64 (atom-level).</li>
<li>Attention heads: 24.</li>
<li>Loss coefficients: $\lambda_1=1.0$ (rot), $\lambda_2=2.0$ (trans), $\lambda_3=0.1$ (lattice).</li>
</ul>
</li>
<li><strong>Checkpoints:</strong> Pre-trained weights and models are openly provided on <a href="https://zenodo.org/records/15187230">Zenodo</a>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics:</strong>
<ul>
<li><strong>Match Rate:</strong> Using <code>StructureMatcher</code> from <code>pymatgen</code>. Tolerances: <code>stol=0.5/1.0</code>, <code>ltol=0.3</code>, <code>angle_tol=10.0</code>.</li>
<li><strong>RMSE:</strong> Normalized by average free length per atom.</li>
</ul>
</li>
<li><strong>Tools:</strong> <strong>Zeo++</strong> for structural property calculations (Surface Area, Pore Diameter, etc.).</li>
</ul>
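<p>For reference, the normalization behind the RMSE metric can be written out explicitly. This is a sketch based on the documented behavior of <code>StructureMatcher</code> in <code>pymatgen</code>, which reports the root mean squared displacement normalized by $(V/N)^{1/3}$; the symbols $d_i$, $V$, and $N$ are notation introduced here.</p>
<p>$$
\text{RMSE} = \frac{\sqrt{\frac{1}{N} \sum_{i=1}^{N} d_i^2}}{(V/N)^{1/3}}
$$</p>
<p>Here $d_i$ is the displacement between the $i$-th pair of matched sites in the predicted and reference structures, $V$ is the unit-cell volume, and $N$ is the number of sites. The normalization makes the value dimensionless, so structures with different cell sizes can be compared on the same scale.</p>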
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">MOFFlow</th>
          <th style="text-align: left">DiffCSP</th>
          <th style="text-align: left">RS (20 cands)</th>
          <th style="text-align: left">EA (20 cands)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">MR (stol=0.5, k=1)</td>
          <td style="text-align: left"><strong>31.69%</strong></td>
          <td style="text-align: left">0.09%</td>
          <td style="text-align: left">0.00%</td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=1.0, k=1)</td>
          <td style="text-align: left"><strong>87.46%</strong></td>
          <td style="text-align: left">23.12%</td>
          <td style="text-align: left">0.00%</td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=0.5, k=5)</td>
          <td style="text-align: left"><strong>44.75%</strong></td>
          <td style="text-align: left">0.34%</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">MR (stol=1.0, k=5)</td>
          <td style="text-align: left"><strong>100.0%</strong></td>
          <td style="text-align: left">38.94%</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">RMSE (stol=0.5, k=1)</td>
          <td style="text-align: left"><strong>0.2820</strong></td>
          <td style="text-align: left">0.3961</td>
          <td style="text-align: left">-</td>
          <td style="text-align: left">-</td>
      </tr>
      <tr>
          <td style="text-align: left">Avg. time per structure</td>
          <td style="text-align: left"><strong>1.94s</strong></td>
          <td style="text-align: left">5.37s</td>
          <td style="text-align: left">332s</td>
          <td style="text-align: left">1,959s</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Hardware:</strong> 8 $\times$ NVIDIA RTX 3090 (24GB VRAM).</li>
<li><strong>Training Time:</strong>
<ul>
<li><em>TimestepBatch version (main paper):</em> ~5 days 15 hours.</li>
<li><em>Batch version:</em> ~1 day 17 hours (332.74 GPU hours). The authors also release this refactored implementation, which achieves comparable performance with faster convergence.</li>
</ul>
</li>
<li><strong>Batch Size:</strong> 160 (capped by $N^2$ where $N$ is the number of atoms, for memory management).</li>
</ul>
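<p>The &ldquo;roughly 1/3 the GPU hours&rdquo; figure can be checked with a short calculation, assuming all eight GPUs were in use for the full TimestepBatch run (an assumption; the notes above report only wall-clock time for that version):</p>
<p>$$
8 \times (5 \times 24 + 15)\ \text{h} = 8 \times 135\ \text{h} = 1080\ \text{GPU hours}, \qquad \frac{332.74}{1080} \approx 0.31
$$</p>
<p>The same assumption applied to the Batch version gives $8 \times 41 = 328$ GPU hours, close to the reported 332.74.</p>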
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/nayoung10/MOFFlow">MOFFlow (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Official implementation built on DiffDock, EGNN, MOFDiff, and protein-frame-flow</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://zenodo.org/records/15187230">Pre-processed dataset and checkpoints (Zenodo)</a></td>
          <td style="text-align: left">Dataset / Model</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Includes pre-processed MOF structures and trained model weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kim, N., Kim, S., Kim, M., Park, J., &amp; Ahn, S. (2025). MOFFlow: Flow Matching for Structure Prediction of Metal-Organic Frameworks. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kimMOFFlowFlowMatching2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MOFFlow: Flow Matching for Structure Prediction of Metal-Organic Frameworks}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Kim, Nayoung and Kim, Seongsu and Kim, Minsu and Park, Jinkyoo and Ahn, Sungsoo}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Thirteenth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=dNT3abOsLo}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=dNT3abOsLo">OpenReview Discussion</a></li>
<li><a href="https://github.com/nayoung10/MOFFlow">Official Code Repository</a></li>
</ul>
]]></content:encoded></item><item><title>InvMSAFold: Generative Inverse Folding with Potts Models</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/invmsafold/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/invmsafold/</guid><description>InvMSAFold generates diverse protein sequences from structure by predicting Potts model parameters, enabling orders-of-magnitude faster sampling.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Methodological ($\Psi_{\text{Method}}$)</strong> paper. It introduces a novel architecture, <strong>InvMSAFold</strong>, which hybridizes deep learning encoders with statistical physics-based decoders (Potts models). The rhetorical structure focuses on architectural innovation (low-rank parameter generation), ablation of speed/diversity against baselines (ESM-IF1), and algorithmic efficiency.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard inverse folding models (like ESM-IF1 or ProteinMPNN) solve a &ldquo;one-to-one&rdquo; mapping: given a structure, predict the <em>single</em> native sequence. However, in nature, folding is &ldquo;many-to-one&rdquo;: many homologous sequences fold into the same structure.</p>
<p>The authors identify two key gaps:</p>
<ol>
<li><strong>Lack of Diversity</strong>: Standard autoregressive models maximize probability for the ground truth sequence, often failing to capture the broad evolutionary landscape of viable homologs.</li>
<li><strong>Slow Inference</strong>: Autoregressive sampling requires a full neural network pass for <em>every amino acid</em>, making high-throughput screening (e.g., millions of candidates) computationally prohibitive.</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is shifting the learning objective from predicting <em>sequences</em> to predicting <em>probability distributions</em>.</p>
<p>InvMSAFold outputs the parameters (couplings $\mathbf{J}$ and fields $\mathbf{h}$) of a <strong>Potts Model</strong> (a pairwise Markov Random Field).</p>
<ul>
<li><strong>Low-Rank Decomposition</strong>: To handle the massive parameter space of pairwise couplings ($L \times L \times q \times q$), the model predicts a low-rank approximation $\mathbf{V}$ ($L \times K \times q$), reducing complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L)$.</li>
<li><strong>One-Shot Generation</strong>: The deep network runs only <em>once</em> to generate the Potts parameters. Sampling sequences from this Potts model is then performed on CPU via MCMC (for the PW variant) or direct autoregressive sampling (for the AR variant), which is orders of magnitude faster than running a Transformer decoder for every step.</li>
</ul>
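<p>To make the saving concrete, the parameter counts can be compared for a single protein. The values $L = 200$ and $q = 21$ (20 amino acids plus a gap symbol) are assumptions chosen for illustration; $K = 48$ is the rank reported in the reproducibility details.</p>
<p>$$
L^2 q^2 = 200^2 \times 21^2 = 17{,}640{,}000, \qquad L K q = 200 \times 48 \times 21 = 201{,}600, \qquad \frac{L^2 q^2}{L K q} = \frac{L q}{K} = 87.5
$$</p>
<p>The ratio $Lq/K$ grows linearly with sequence length, so the benefit of the low-rank form increases for longer proteins.</p>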
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the model on three CATH-based test sets (Inter-cluster, Intra-cluster, MSA) to test generalization at varying levels of homology.</p>
<ul>
<li><strong>Speed Benchmarking</strong>: Compared wall-clock sampling time vs. ESM-IF1 on CPU/GPU.</li>
<li><strong>Covariance Reconstruction</strong>: Checked if generated sequences recover the evolutionary correlations found in natural MSAs (Pearson correlation of covariance matrices).</li>
<li><strong>Structural Fidelity</strong>: Generated sequences with high Hamming distance from native, folded them with AlphaFold 2 (no templates), and measured RMSD to the target structure.</li>
<li><strong>Property Profiling</strong>: Analyzed the distribution of predicted solubility (Protein-Sol) and thermostability (Thermoprot) to show that sequence diversity translates into a wider range of biochemical properties.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Massive Speedup</strong>: InvMSAFold is orders of magnitude faster than ESM-IF1 (CPU vs. GPU; the comparison is not hardware-matched). Because the &ldquo;heavy lifting&rdquo; (generating Potts parameters) happens once, sampling millions of sequences becomes trivial on CPUs.</li>
<li><strong>Better Diversity</strong>: The model captures evolutionary covariances significantly better than ESM-IF1 and ProteinMPNN (whose covariance recovery is similar to that of ESM-IF1). A PCA-based KL-divergence analysis (lower is better; 0 means a perfect match to the natural MSA distribution) shows InvMSAFold-AR scores of $0.49$ (Inter-cluster) and $0.67$ (Intra-cluster), compared to $15.8$ and $11.9$ for ESM-IF1, demonstrating that the generated sequences occupy a distribution much closer to natural MSAs.</li>
<li><strong>Robust Folding</strong>: Sequences generated far from the native sequence (high Hamming distance) still fold into the correct structure (low RMSD), whereas ESM-IF1 struggles to produce diverse valid sequences.</li>
<li><strong>Property Expansion</strong>: The method generates a wider spread of predicted biochemical properties (solubility/thermostability), which could be useful for virtual screening in protein design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Source</strong>: CATH database (40% non-redundant dataset).</p>
<p><strong>Splits</strong>:</p>
<ul>
<li><strong>Training</strong>: ~22k domains.</li>
<li><strong>Inter-cluster Test</strong>: 10% of sequence clusters held out (unseen clusters, many with superfamilies absent from training).</li>
<li><strong>Intra-cluster Test</strong>: Unseen domains from seen clusters.</li>
<li><strong>Augmentation</strong>: MSAs generated using <strong>MMseqs2</strong> against the Uniprot50 database. Training uses random subsamples of these MSAs ($|M_X| = 64$ for PW, $|M_X| = 32$ for AR) to teach the model evolutionary variance.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Architecture</strong>:</p>
<ul>
<li><strong>Encoder</strong>: Pre-trained <strong>ESM-IF1</strong> encoder (GVP-GNN architecture). The encoder is used to pre-compute structure embeddings, with independent Gaussian noise (std = 5% of the embedding std) added during training.</li>
<li><strong>Decoder</strong>: 6-layer Transformer (8 heads) that outputs a latent tensor.</li>
<li><strong>Projection</strong>: Linear layers project latent tensor to fields $\mathbf{h}$ ($L \times q$) and low-rank tensor $\mathbf{V}$ ($L \times K \times q$).</li>
</ul>
<p><strong>Coupling Construction</strong>:
The full coupling tensor $\mathbf{J}$ is approximated via:
$$\mathbf{J}_{i,a,j,b} = \frac{1}{\sqrt{K}} \sum_{k=1}^{K} \mathbf{V}_{i,k,a} \mathbf{V}_{j,k,b}$$
Rank $K=48$ was used.</p>
<p><strong>Loss Functions</strong>:
Two variants were trained:</p>
<ol>
<li><strong>InvMSAFold-PW</strong>: Trained via <strong>Pseudo-Likelihood (PL)</strong>. Computation is optimized to $\mathcal{O}(L)$ time using the low-rank property.</li>
<li><strong>InvMSAFold-AR</strong>: Trained via <strong>Autoregressive Likelihood</strong>. Couplings are masked ($J_{ij} = 0$ if $i &lt; j$) to allow exact likelihood computation and direct sampling without MCMC.</li>
</ol>
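<p>The reason the masked couplings permit exact likelihoods and direct sampling can be seen from the form of the resulting conditionals. The expression below is the standard form for autoregressive Potts-style models and is assumed here to match the AR variant; the notation $h_i(a)$ and $J_{ij}(a, b)$ for the field and coupling entries is introduced for illustration.</p>
<p>$$
p(a_1, \dots, a_L) = \prod_{i=1}^{L} p(a_i \mid a_1, \dots, a_{i-1}), \qquad
p(a_i \mid a_1, \dots, a_{i-1}) = \frac{\exp\big(h_i(a_i) + \sum_{j &lt; i} J_{ij}(a_i, a_j)\big)}{\sum_{b=1}^{q} \exp\big(h_i(b) + \sum_{j &lt; i} J_{ij}(b, a_j)\big)}
$$</p>
<p>Each conditional is normalized over only $q$ states, so the likelihood is a product of cheap softmax terms and a sequence can be sampled position by position. The unmasked PW variant has no such factorization: its normalizing constant sums over all $q^L$ sequences, which is why it is trained with pseudo-likelihood and sampled with MCMC.</p>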
<h3 id="models">Models</h3>
<ul>
<li><strong>InvMSAFold-PW</strong>: Requires MCMC sampling (Metropolis-Hastings) at inference.</li>
<li><strong>InvMSAFold-AR</strong>: Allows direct, fast autoregressive sampling.</li>
<li><strong>Hyperparameters</strong>: AdamW optimizer, lr=$10^{-4}$ (PW) / $3.4 \times 10^{-4}$ (AR), 94 epochs. L2 regularization: $\lambda_h = \lambda_J = 10^{-4}$ (PW); $\lambda_J = 3.2 \times 10^{-6}$, $\lambda_h = 5.0 \times 10^{-5}$ (AR).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>RMSD</strong>: Structure fidelity (AlphaFold2 prediction vs. native structure).</li>
<li><strong>Covariance Pearson Correlation</strong>: Measures recovery of evolutionary pairwise statistics.</li>
<li><strong>KL Divergence</strong>: Between PCA-projected densities of natural and synthetic sequences (Gaussian KDE, kernel size 1.0).</li>
<li><strong>Sampling Speed</strong>: Wall-clock time vs. sequence length/batch size.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: Not specified in the paper. The GitHub repository reports testing on an NVIDIA RTX 3090, with training taking 10-24 hours depending on model variant.</li>
<li><strong>Inference</strong>:
<ul>
<li><strong>ESM-IF1</strong>: NVIDIA GeForce RTX 4060 Laptop (8GB).</li>
<li><strong>InvMSAFold</strong>: Single core of Intel i9-13905H CPU.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/luchinoprince/Potts_Inverse_Folding">Potts_Inverse_Folding</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training and inference code (PyTorch)</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Silva, L. A., Meynard-Piganeau, B., Lucibello, C., &amp; Feinauer, C. (2025). Fast Uncovering of Protein Sequence Diversity from Structure. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://arxiv.org/abs/2406.11975">https://arxiv.org/abs/2406.11975</a></p>
<p><strong>Publication</strong>: ICLR 2025 (Spotlight)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{silvaFastUncoveringProtein2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Fast Uncovering of Protein Sequence Diversity from Structure}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Silva, Luca Alessandro and {Meynard-Piganeau}, Barthelemy and Lucibello, Carlo and Feinauer, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Statistical Mechanics: Theory and Experiment}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{084003}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1088/1742-5468/adf0e7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=1iuaxjssVp}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=1iuaxjssVp">OpenReview Page</a></li>
<li><a href="https://github.com/luchinoprince/Potts_Inverse_Folding">GitHub Repository</a></li>
</ul>
]]></content:encoded></item><item><title>InstructMol: Multi-Modal Molecular LLM for Drug Discovery</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/instructmol/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/instructmol/</guid><description>A multi-modal LLM aligning 2D molecular graphs with text via two-stage instruction tuning for drug discovery tasks.</description><content:encoded><![CDATA[<h2 id="instructmol-framework-overview">InstructMol Framework Overview</h2>
<p><strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong></p>
<p>This work proposes <strong>InstructMol</strong>, a novel multi-modal architecture and training paradigm. It focuses on engineering a system that aligns a pre-trained molecular graph encoder with a general-purpose Large Language Model (LLM). The paper&rsquo;s primary contribution is the <strong>Two-Stage Instruction Tuning</strong> strategy (Alignment Pre-training + Task-Specific Tuning) designed to bridge the modality gap between 2D molecular graphs and natural language.</p>
<h2 id="bridging-specialist-and-generalist-models">Bridging Specialist and Generalist Models</h2>
<p>Current AI approaches in drug discovery typically fall into two categories. Specialist models deliver high accuracy on specific tasks (such as property prediction) but require extensive labeled datasets and lack conversational adaptability. Conversely, generalist LLMs offer strong reasoning and dialogue capabilities but struggle to natively interpret complex structural data, often relying on brittle 1D text representations of molecules like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>.</p>
<p>There is a practical need for a unified &ldquo;Molecular Assistant&rdquo; capable of interpreting molecular graphs directly, reasoning about structure in natural language, and adapting across tasks like synthesis planning and property analysis without training from scratch.</p>
<h2 id="two-stage-modality-alignment">Two-Stage Modality Alignment</h2>
<p>The core novelty lies in the architecture and the <strong>two-stage training pipeline</strong> designed to align differing modalities efficiently:</p>
<ol>
<li><strong>MoleculeSTM Integration</strong>: InstructMol initializes its graph encoder with <strong>MoleculeSTM</strong>, which is already pre-aligned with text via contrastive learning, facilitating easier downstream alignment.</li>
<li><strong>Two-Stage Alignment Strategy</strong>:
<ul>
<li><strong>Stage 1 (Alignment Pre-training)</strong>: Freezes both the LLM and Graph Encoder; trains <em>only</em> a linear projector on ~264K PubChem molecule-description pairs to map graph features into the LLM&rsquo;s token space.</li>
<li><strong>Stage 2 (Task-Specific Instruction Tuning)</strong>: Freezes the Graph Encoder; fine-tunes the Projector and the LLM (using <strong>LoRA</strong>) on specific downstream tasks. This allows the model to adapt its reasoning capabilities while preserving the structural understanding gained in Stage 1.</li>
</ul>
</li>
</ol>
<h2 id="task-evaluation-in-drug-discovery">Task Evaluation in Drug Discovery</h2>
<p>The authors evaluated InstructMol across three distinct categories of drug discovery tasks, comparing it against generalist LLMs (Vicuna, LLaMA, <a href="/notes/chemistry/llm-applications/galactica-large-language-model-for-science/">Galactica</a>) and specialist models (<a href="/notes/chemistry/molecular-representations/encoders/chemberta/">ChemBERTa</a>, MolT5):</p>
<ol>
<li><strong>Property Prediction</strong>:
<ul>
<li><em>Regression</em>: Predicting quantum mechanical properties (HOMO, LUMO, Gap) using the <a href="/notes/chemistry/datasets/qm9/">QM9</a> dataset.</li>
<li><em>Classification</em>: Predicting biological activity (BACE, BBBP, HIV) using <a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a>.</li>
</ul>
</li>
<li><strong>Molecule Description Generation</strong>: Generating natural language descriptions of molecules using the ChEBI-20 dataset.</li>
<li><strong>Chemical Reaction Analysis</strong>:
<ul>
<li><em>Forward Reaction Prediction</em>: Predicting products from reactants.</li>
<li><em>Reagent Prediction</em>: Identifying necessary reagents.</li>
<li><em><a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">Retrosynthesis</a></em>: Suggesting reactants for a given product.</li>
</ul>
</li>
</ol>
<p><strong>Ablation Studies</strong> tested the impact of the projector type (Linear vs. MLP), LLM scale (7B vs 13B), and the necessity of the two-stage training approach.</p>
<h2 id="core-findings-and-limitations">Core Findings and Limitations</h2>
<ul>
<li><strong>Improvement Over Baseline Generalists</strong>: InstructMol significantly outperformed generalist LLMs (like LLaMA and Galactica) on all tasks, demonstrating the value of incorporating explicit graph modalities.</li>
<li><strong>Reducing the Gap with Specialists</strong>: While InstructMol brings versatile reasoning capabilities, it still trails highly optimized specialist models (such as Uni-Mol and MolT5) on tasks like molecule description generation. This remaining gap likely stems from two factors: the relatively small alignment pre-training dataset (~264K PubChem pairs, versus the millions of structures used to train specialist foundation models) and the information bottleneck of a simple linear projector.</li>
<li><strong>Importance of Alignment</strong>: Ablation studies showed that skipping Stage 1 (Alignment Pre-training) degraded performance, indicating that a dedicated phase for projecting graph features into the text embedding space is important.</li>
<li><strong>Limitation</strong>: The model struggles with highly imbalanced datasets (e.g., HIV) and complex reaction mixtures where mapping multiple graph tokens to text becomes ambiguous.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training pipeline uses distinct datasets for the two stages. <strong>Note:</strong> As of the latest repository update, the processed instruction-tuning datasets (e.g., the filtered ~264K PubChem pairs and the instruction-formatted task subsets) are listed as &ldquo;coming soon&rdquo;, so full reproduction requires recreating them manually.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Stage 1</strong> (Alignment)</td>
          <td style="text-align: left"><strong><a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a></strong></td>
          <td style="text-align: left">~264K pairs</td>
          <td style="text-align: left">Molecule-text pairs. Filtered from 330K for invalid descriptions and overlaps with ChEBI-20 test set.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Reg.)</td>
          <td style="text-align: left"><strong>QM9</strong></td>
          <td style="text-align: left">362K samples</td>
          <td style="text-align: left">Quantum mechanics properties (HOMO, LUMO, Gap).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Prop. Class.)</td>
          <td style="text-align: left"><strong>MoleculeNet</strong></td>
          <td style="text-align: left">35K samples</td>
          <td style="text-align: left">BACE, BBBP, HIV datasets. Converted to instruction format (Yes/No answer).</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Generation)</td>
          <td style="text-align: left"><strong>ChEBI-20</strong></td>
          <td style="text-align: left">26.5K samples</td>
          <td style="text-align: left">Molecule description generation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Stage 2</strong> (Reactions)</td>
          <td style="text-align: left"><strong>USPTO</strong></td>
          <td style="text-align: left">~380K samples</td>
          <td style="text-align: left">Combined datasets for Forward (125K), Retrosynthesis (130K), and Reagent (125K) prediction.</td>
      </tr>
  </tbody>
</table>
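<p>The Yes/No conversion noted for the MoleculeNet classification sets can be sketched as follows. This is a minimal illustration only: the prompt wording, the <code>InstructionSample</code> record, and the <code>to_instruction</code> helper are hypothetical and are not taken from the paper or the repository.</p>

```python
from dataclasses import dataclass


@dataclass
class InstructionSample:
    instruction: str
    molecule: str  # SMILES string; the graph input would be built from it
    answer: str


# Hypothetical prompt wording; the paper's exact templates are not reproduced here.
TEMPLATE = "Is this molecule positive for the {task} task? Answer Yes or No."


def to_instruction(smiles: str, label: int, task: str) -> InstructionSample:
    """Convert a binary label (0 or 1) into a Yes/No instruction sample."""
    if label not in (0, 1):
        raise ValueError(f"expected a binary label, got {label!r}")
    return InstructionSample(
        instruction=TEMPLATE.format(task=task),
        molecule=smiles,
        answer="Yes" if label == 1 else "No",
    )


if __name__ == "__main__":
    print(to_instruction("CCO", 1, "BBBP"))
```

<p>Framing classification as text generation lets the same autoregressive objective cover every Stage 2 task, at the cost of needing to parse the generated answer back into a label at evaluation time.</p>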
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Two-Stage Training</strong>:
<ol>
<li><strong>Alignment Pre-training</strong>: Updates only the Projector. The objective maximizes the probability of generating the target description token sequence $\mathbf{X}_A$ given the molecule input $\mathbf{X}_M$ and instruction $\mathbf{X}_I$:
$$p(\mathbf{X}_A | \mathbf{X}_M, \mathbf{X}_I) = \prod_{i=1}^L p_\theta(x_i | \mathbf{X}_G \parallel \mathbf{X}_S, \mathbf{X}_I, \mathbf{X}_{A,&lt;i})$$</li>
<li><strong>Instruction Tuning</strong>: Updates Projector + LLM (via LoRA) using standard autoregressive language modeling on task-specific instructions. The objective minimizes the negative log-likelihood of generating the target response $R$ of length $L$:
$$\mathcal{L}(\theta) = -\sum_{i=1}^L \log p(R_i | I, M, R_{&lt;i}; \theta)$$
where $I$ represents the instruction and $M$ is the multi-modal molecular input.</li>
</ol>
</li>
<li><strong>LoRA (Low-Rank Adaptation)</strong>: Applied to the LLM in Stage 2. Rank $r=64$, Scaling $\alpha=16$.</li>
<li><strong>Optimization</strong>: AdamW optimizer. Learning rate starts at 2e-3 (Stage 1) and 8e-5 (Stage 2) with cosine decay. Warm-up ratio 0.03.</li>
</ul>
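<p>As a rough sketch of what these LoRA settings imply, assume the standard LoRA parameterization (the note does not state which weight matrices are adapted). Each adapted weight $W_0 \in \mathbb{R}^{d \times k}$ stays frozen and receives a trainable low-rank update; the symbols $W_0$, $A$, $B$, $d$, and $k$ are introduced here for illustration:</p>
$$W = W_0 + \frac{\alpha}{r} B A, \qquad B \in \mathbb{R}^{d \times r},\ A \in \mathbb{R}^{r \times k}$$
<p>With $r=64$ and $\alpha=16$, the update is scaled by $\alpha / r = 0.25$. For a hypothetical square projection with $d = k = 4096$ (the hidden size of a 7B LLaMA-family model), the adapter adds $r(d + k) = 64 \times 8192 = 524{,}288$ trainable parameters, roughly 3.1% of the $4096^2 \approx 16.8\text{M}$ frozen weights in that matrix.</p>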
<h3 id="models">Models</h3>
<p><strong>Note:</strong> The official repository currently lists the final fine-tuned <strong>InstructMol weights</strong> as &ldquo;coming soon.&rdquo; Consequently, one must fine-tune the components using the provided scripts. Base model weights (Vicuna-7B and MoleculeSTM) are publicly available via Hugging Face.</p>
<ul>
<li><strong>Graph Encoder ($f_g$)</strong>:
<ul>
<li>Architecture: Graph Isomorphism Network (GIN) with 5 layers.</li>
<li>Hidden Dimension: 300.</li>
<li>Initialization: <strong>MoleculeSTM</strong> checkpoint (pre-trained via contrastive learning).</li>
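<li>Status: <strong>Frozen</strong> in both stages (Stage 1 updates only the Projector; Stage 2 updates the Projector and the LLM&rsquo;s LoRA weights).</li>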
</ul>
</li>
<li><strong>LLM</strong>:
<ul>
<li>Base: <strong>Vicuna-v1.3-7B</strong>.</li>
<li>Status: Frozen in Stage 1; LoRA fine-tuned in Stage 2.</li>
</ul>
</li>
<li><strong>Projector</strong>:
<ul>
<li>Architecture: Linear Layer.</li>
<li>Function: Maps the node-level graph representation $Z_G \in \mathbb{R}^{N \times d}$ into the LLM&rsquo;s word-embedding space, so that each of the $N$ nodes becomes one input token embedding.</li>
</ul>
</li>
</ul>
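<p>A minimal way to write the projector step, assuming a single affine map with a bias term (the symbols $W$, $b$, $H_G$, and $d_{\text{LLM}}$ are introduced here for illustration and are not the paper&rsquo;s notation):</p>
$$H_G = Z_G W + b, \qquad W \in \mathbb{R}^{d \times d_{\text{LLM}}},\ b \in \mathbb{R}^{d_{\text{LLM}}}$$
<p>With the encoder hidden dimension $d = 300$ and $d_{\text{LLM}} = 4096$ for a 7B LLaMA-family model such as Vicuna-7B, the projector has $300 \times 4096 + 4096 = 1{,}232{,}896$ parameters, so under these assumptions Stage 1 trains only about 1.2M parameters.</p>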
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric Libraries</strong>: RDKit for validity/fingerprints, standard NLP libraries for BLEU/ROUGE.</li>
<li><strong>Reaction Metrics</strong>: Fingerprint <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto Similarity</a> (FTS), Exact Match, Levenshtein distance, and validity (via RDKit).</li>
<li><strong>Description Metrics</strong>: BLEU-2, BLEU-4, ROUGE-1, ROUGE-2, ROUGE-L, METEOR.</li>
</ul>
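<p>The reaction metrics above can be illustrated with a small standard-library sketch. It is not the repository&rsquo;s evaluation code: fingerprints are represented here as plain sets of on-bit indices instead of RDKit fingerprint objects, and the function names are chosen for this example.</p>

```python
def tanimoto(bits_a: set[int], bits_b: set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    if not bits_a and not bits_b:
        return 1.0  # convention chosen here for two empty fingerprints
    shared = len(bits_a.intersection(bits_b))
    total = len(bits_a.union(bits_b))
    return shared / total


def levenshtein(a: str, b: str) -> int:
    """Edit distance with unit-cost insert, delete, and substitute."""
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(min(
                previous[j] + 1,         # deletion
                current[j - 1] + 1,      # insertion
                previous[j - 1] + cost,  # substitution or match
            ))
        previous = current
    return previous[-1]


def exact_match(predicted: str, reference: str) -> bool:
    """String-level exact match; real pipelines typically canonicalize first."""
    return predicted == reference


if __name__ == "__main__":
    print(tanimoto({1, 2, 3}, {2, 3, 4}))
    print(levenshtein("CCO", "CCN"))
    print(exact_match("CCO", "CCO"))
```

<p>Exact match is the strictest of the three, while Levenshtein distance and fingerprint similarity give partial credit for near-miss predictions at the string and substructure level respectively.</p>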
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA RTX A6000 (48GB VRAM).</li>
<li><strong>Training Length</strong>:
<ul>
<li>Stage 1: 5 epochs.</li>
<li>Stage 2: 20-50 epochs (Description Generation), 10 epochs (Properties/Reactions).</li>
</ul>
</li>
<li><strong>Batch Size</strong>: 128 for both stages.</li>
</ul>
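<p>To make the optimization settings concrete, the sketch below estimates the Stage 1 step count and evaluates a warm-up plus cosine schedule. It assumes that 128 is the global batch size, that warm-up is linear, that the cosine decays to zero, and that the dataset holds exactly 264,000 pairs; none of these details are confirmed by the note, and the function names are illustrative.</p>

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Number of optimizer steps per epoch, keeping the last partial batch."""
    return math.ceil(num_samples / batch_size)


def lr_at(step: int, total_steps: int, peak_lr: float,
          warmup_ratio: float = 0.03) -> float:
    """Linear warm-up to peak_lr, then cosine decay to zero (assumed shape)."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if warmup_steps > step:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))


if __name__ == "__main__":
    # Stage 1: ~264K pairs (rounded), batch size 128, 5 epochs, peak LR 2e-3.
    total = steps_per_epoch(264_000, 128) * 5
    for step in (0, 100, int(total * 0.03), total // 2, total):
        print(step, lr_at(step, total, peak_lr=2e-3))
```

<p>Under these assumptions Stage 1 runs for roughly 10K optimizer steps, of which about 300 are warm-up.</p>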
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/IDEA-XL/InstructMol">InstructMol (GitHub)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache 2.0 (code), CC BY-NC 4.0 (data)</td>
          <td style="text-align: left">Training/evaluation scripts provided; fine-tuned weights listed as &ldquo;coming soon&rdquo;</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/lmsys/vicuna-7b-v1.3">Vicuna-7B v1.3</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">Non-commercial (LLaMA license)</td>
          <td style="text-align: left">Base LLM; must be downloaded separately</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/chao1224/MoleculeSTM">MoleculeSTM</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Pre-trained graph encoder checkpoint</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Cao, H., Liu, Z., Lu, X., Yao, Y., &amp; Li, Y. (2025). InstructMol: Multi-Modal Integration for Building a Versatile and Reliable Molecular Assistant in Drug Discovery. <em>Proceedings of the 31st International Conference on Computational Linguistics</em>, 354-379.</p>
<p><strong>Publication</strong>: COLING 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{caoInstructMolMultiModalIntegration2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{InstructMol}}: {{Multi-Modal Integration}} for {{Building}} a {{Versatile}} and {{Reliable Molecular Assistant}} in {{Drug Discovery}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{InstructMol}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 31st {{International Conference}} on {{Computational Linguistics}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Cao, He and Liu, Zijing and Lu, Xingyu and Yao, Yuan and Li, Yu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{Rambow, Owen and Wanner, Leo and Apidianaki, Marianna and {Al-Khalifa}, Hend and Eugenio, Barbara Di and Schockaert, Steven}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2025</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{354--379}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://aclanthology.org/2025.coling-main.25/}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Abu Dhabi, UAE}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">abstract</span> = <span style="color:#e6db74">{The rapid evolution of artificial intelligence in drug discovery encounters challenges with generalization and extensive training, yet Large Language Models (LLMs) offer promise in reshaping interactions with complex molecular data. Our novel contribution, InstructMol, a multi-modal LLM, effectively aligns molecular structures with natural language via an instruction-tuning approach, utilizing a two-stage training strategy that adeptly combines limited domain-specific data with molecular and textual information. InstructMol showcases substantial performance improvements in drug discovery-related molecular tasks, surpassing leading LLMs and significantly reducing the gap with specialists, thereby establishing a robust foundation for a versatile and dependable drug discovery assistant.}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/IDEA-XL/InstructMol">Official Repository</a></li>
</ul>
]]></content:encoded></item><item><title>Image-to-Sequence OCSR: A Comparative Analysis</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/benchmarks/image-to-sequence-comparison/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/benchmarks/image-to-sequence-comparison/</guid><description>Comparative analysis of image-to-sequence OCSR methods across architecture, output format, training data, and compute requirements.</description><content:encoded><![CDATA[<h2 id="overview">Overview</h2>
<p>This note provides a comparative analysis of image-to-sequence methods for Optical Chemical Structure Recognition (OCSR). These methods treat molecular structure recognition as an image captioning task, using encoder-decoder architectures to generate sequential molecular representations (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>) directly from pixels.</p>
<p>For the full taxonomy of OCSR approaches including image-to-graph and rule-based methods, see the <a href="/notes/chemistry/optical-structure-recognition/benchmarks/ocsr-methods/">OCSR Methods taxonomy</a>.</p>
<h2 id="architectural-evolution-2019-2025">Architectural Evolution (2019-2025)</h2>
<p>The field has undergone rapid architectural evolution, with clear generational shifts in both encoder and decoder design.</p>
<h3 id="timeline">Timeline</h3>
<table>
  <thead>
      <tr>
          <th>Era</th>
          <th>Encoder</th>
          <th>Decoder</th>
          <th>Representative Methods</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2019-2020</strong></td>
          <td>CNN (Inception V3, ResNet)</td>
          <td>LSTM/GRU with Attention</td>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/staker-deep-learning-2019/">Staker et al.</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer/">DECIMER</a></td>
      </tr>
      <tr>
          <td><strong>2021</strong></td>
          <td>EfficientNet, ViT</td>
          <td>Transformer</td>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-1.0/">DECIMER 1.0</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/img2mol/">Img2Mol</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/vit-inchi-transformer/">ViT-InChI</a></td>
      </tr>
      <tr>
          <td><strong>2022</strong></td>
          <td>Swin Transformer, ResNet</td>
          <td>Transformer</td>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/swinocsr/">SwinOCSR</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/image2smiles/">Image2SMILES</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/micer/">MICER</a></td>
      </tr>
      <tr>
          <td><strong>2023-2024</strong></td>
          <td>EfficientNetV2, SwinV2</td>
          <td>Transformer</td>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-ai/">DECIMER.ai</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/image2inchi/">Image2InChI</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/mmssc-net/">MMSSC-Net</a></td>
      </tr>
      <tr>
          <td><strong>2025</strong></td>
          <td>EfficientViT, VLMs (Qwen2-VL)</td>
          <td>LLM decoders, RL fine-tuning</td>
          <td><a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/molsight/">MolSight</a>, <a href="/notes/chemistry/optical-structure-recognition/vision-language/gtr-mol-vlm/">GTR-CoT</a>, <a href="/notes/chemistry/optical-structure-recognition/vision-language/ocsu/">OCSU</a></td>
      </tr>
  </tbody>
</table>
<h3 id="encoder-architectures">Encoder Architectures</h3>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Methods Using It</th>
          <th>Key Characteristics</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Inception V3</strong></td>
          <td>DECIMER (2020)</td>
          <td>Early CNN approach, 299x299 input</td>
      </tr>
      <tr>
          <td><strong>ResNet-50/101</strong></td>
          <td>IMG2SMI, Image2SMILES, MICER, DGAT</td>
          <td>Strong baseline, well-understood</td>
      </tr>
      <tr>
          <td><strong>EfficientNet-B3</strong></td>
          <td>DECIMER 1.0</td>
          <td>Efficient scaling, compound coefficients</td>
      </tr>
      <tr>
          <td><strong>EfficientNet-V2-M</strong></td>
          <td>DECIMER.ai, DECIMER-Hand-Drawn</td>
          <td>Improved training efficiency</td>
      </tr>
      <tr>
          <td><strong>EfficientViT-L1</strong></td>
          <td>MolSight</td>
          <td>Optimized for deployment</td>
      </tr>
      <tr>
          <td><strong>Swin Transformer</strong></td>
          <td>SwinOCSR, MolParser</td>
          <td>Hierarchical vision transformer</td>
      </tr>
      <tr>
          <td><strong>SwinV2</strong></td>
          <td>MMSSC-Net, Image2InChI</td>
          <td>Improved training stability</td>
      </tr>
      <tr>
          <td><strong>Vision Transformer (ViT)</strong></td>
          <td>ViT-InChI</td>
          <td>Pure attention encoder</td>
      </tr>
      <tr>
          <td><strong>DenseNet</strong></td>
          <td>RFL, Hu et al. RCGD</td>
          <td>Dense connections, feature reuse</td>
      </tr>
      <tr>
          <td><strong>Deep TNT</strong></td>
          <td>ICMDT</td>
          <td>Transformer-in-Transformer</td>
      </tr>
      <tr>
          <td><strong>Qwen2-VL</strong></td>
          <td>OCSU, GTR-CoT</td>
          <td>Vision-language model encoder</td>
      </tr>
  </tbody>
</table>
<h3 id="decoder-architectures">Decoder Architectures</h3>
<table>
  <thead>
      <tr>
          <th>Architecture</th>
          <th>Methods Using It</th>
          <th>Output Format</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>GRU with Attention</strong></td>
          <td>DECIMER, RFL, Hu et al. RCGD</td>
          <td>SMILES, RFL, SSML</td>
      </tr>
      <tr>
          <td><strong>LSTM with Attention</strong></td>
          <td>Staker et al., ChemPix, MICER</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><strong>Transformer</strong></td>
          <td>Most 2021+ methods</td>
          <td>SMILES, SELFIES, InChI</td>
      </tr>
      <tr>
          <td><strong>GPT-2</strong></td>
          <td>MMSSC-Net</td>
          <td>SMILES</td>
      </tr>
      <tr>
          <td><strong>BART</strong></td>
          <td>MolParser</td>
          <td>E-SMILES</td>
      </tr>
      <tr>
          <td><strong>Pre-trained CDDD</strong></td>
          <td>Img2Mol</td>
          <td>Continuous embedding → SMILES</td>
      </tr>
  </tbody>
</table>
<h2 id="output-representation-comparison">Output Representation Comparison</h2>
<p>The choice of molecular string representation significantly impacts model performance. Representations fall into three categories: core molecular formats for single structures, extended formats for molecular families and variable structures (primarily Markush structures in patents), and specialized representations optimizing for specific recognition challenges.</p>
<p>The <a href="/notes/chemistry/optical-structure-recognition/benchmarks/rajan-string-representations-2022/">Rajan et al. 2022 ablation study</a> provides a comparison of core formats.</p>
<h3 id="core-molecular-formats">Core Molecular Formats</h3>
<p>These represent specific, concrete molecular structures.</p>
<table>
  <thead>
      <tr>
          <th>Format</th>
          <th>Validity Guarantee</th>
          <th>Sequence Length</th>
          <th>Key Characteristic</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>SMILES</strong></td>
          <td>No</td>
          <td>Shortest (baseline)</td>
          <td>Standard, highest accuracy</td>
          <td>DECIMER.ai, MolSight, DGAT, most 2023+</td>
      </tr>
      <tr>
          <td><strong>DeepSMILES</strong></td>
          <td>Partial</td>
          <td>~1.1x SMILES</td>
          <td>Reduces non-local dependencies</td>
          <td>SwinOCSR</td>
      </tr>
      <tr>
          <td><strong>SELFIES</strong></td>
          <td>Yes (100%)</td>
          <td>~1.5x SMILES</td>
          <td>Guaranteed valid molecules</td>
          <td>DECIMER 1.0, IMG2SMI</td>
      </tr>
      <tr>
          <td><strong>InChI</strong></td>
          <td>N/A (canonical)</td>
          <td>Variable (long)</td>
          <td>Unique identifiers, layered syntax</td>
          <td>ViT-InChI, ICMDT, Image2InChI</td>
      </tr>
      <tr>
          <td><strong>FG-SMILES</strong></td>
          <td>No</td>
          <td>Similar to SMILES</td>
          <td>Functional group-aware tokenization</td>
          <td>Image2SMILES</td>
      </tr>
  </tbody>
</table>
<h4 id="smiles-and-variants">SMILES and Variants</h4>
<p><strong>SMILES</strong> remains the dominant format due to its compactness and its highest accuracy on clean data. Standard SMILES encodes ring closures and branches with paired symbols (matching digits and parentheses) that may appear far apart in the sequence, which creates long-range dependencies that sequence models must learn.</p>
<p><strong>DeepSMILES</strong> addresses these non-local syntax dependencies by changing how branches and ring closures are encoded, which makes sequences easier for neural models to learn. The change does not shorten the output: DeepSMILES sequences are ~1.1x longer than standard SMILES. The format offers partial validity improvements and, with regex-based tokenization, a compact 76-token vocabulary, providing a middle ground between SMILES accuracy and guaranteed validity.</p>
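<p>The non-local dependency in standard SMILES can be made visible by measuring how far apart matching ring-closure digits sit. The helper below is a simplified sketch written for this note: it handles only single-digit closures, skips digits inside square brackets, and does not parse two-digit <code>%nn</code> closures.</p>

```python
def ring_closure_spans(smiles: str) -> list[tuple[str, int, int]]:
    """Return (digit, open_index, close_index) for single-digit ring closures.

    Simplification: digits inside square brackets (isotopes, charges,
    hydrogen counts) are skipped, and two-digit "%nn" closures are not handled.
    """
    open_at: dict[str, int] = {}
    spans: list[tuple[str, int, int]] = []
    in_bracket = False
    for index, char in enumerate(smiles):
        if char == "[":
            in_bracket = True
        elif char == "]":
            in_bracket = False
        elif char.isdigit() and not in_bracket:
            if char in open_at:
                # Second occurrence closes the ring opened earlier.
                spans.append((char, open_at.pop(char), index))
            else:
                open_at[char] = index
    return spans


if __name__ == "__main__":
    for smi in ("c1ccccc1", "C1CCC2CCCCC2C1"):
        print(smi, ring_closure_spans(smi))
```

<p>For benzene (<code>c1ccccc1</code>) the function returns <code>[("1", 1, 7)]</code>: the decoder must remember the opening digit across six intervening characters, and the gap grows with ring size and nesting.</p>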
<p><strong>SELFIES</strong> guarantees 100% valid molecules by design through a context-free grammar, eliminating invalid outputs entirely. This comes at the cost of ~1.5x longer sequences and a typical 2-5% accuracy drop compared to SMILES on exact-match metrics. The validity guarantee makes SELFIES particularly attractive for generative modeling applications.</p>
<p><strong>InChI</strong> uses a layered canonical syntax fundamentally different from SMILES-based formats. While valuable for unique molecular identification, its complex multi-layer structure (formula, connectivity, stereochemistry, isotopes, etc.) and longer sequences make it less suitable for image-to-sequence learning, resulting in lower recognition accuracy.</p>
<h4 id="key-findings-from-rajan-et-al-2022">Key Findings from Rajan et al. 2022</h4>
<ol>
<li><strong>SMILES achieves highest exact-match accuracy</strong> on clean synthetic data</li>
<li><strong>SELFIES guarantees 100% valid molecules</strong> but at the cost of a ~2-5% accuracy drop</li>
<li><strong>InChI is problematic</strong> due to complex layered syntax and longer sequences</li>
<li><strong>DeepSMILES offers middle ground</strong> with partial validity improvements through modified syntax</li>
</ol>
<h3 id="extended-formats-for-variable-structures">Extended Formats for Variable Structures</h3>
<p><strong>Markush structures</strong> represent families of molecules, using variable groups (R1, R2, etc.) with textual definitions. They are ubiquitous in patent documents for intellectual property protection. Standard SMILES cannot represent these variable structures.</p>
<table>
  <thead>
      <tr>
          <th>Format</th>
          <th>Base Format</th>
          <th>Key Feature</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>E-SMILES</strong></td>
          <td>SMILES + XML annotations</td>
          <td>Backward-compatible with separator token</td>
          <td>MolParser</td>
      </tr>
      <tr>
          <td><strong>CXSMILES</strong></td>
          <td>SMILES + extension block</td>
          <td>Substituent tables, compression</td>
          <td>MarkushGrapher</td>
      </tr>
  </tbody>
</table>
<p><strong>E-SMILES</strong> (Extended SMILES) maintains backward compatibility by using a <code>&lt;sep&gt;</code> token to separate core SMILES from XML-like annotations. Annotations encode Markush substituents (<code>&lt;a&gt;index:group&lt;/a&gt;</code>), polymer structures (<code>&lt;p&gt;polymer_info&lt;/p&gt;</code>), and abstract ring patterns (<code>&lt;r&gt;abstract_ring&lt;/r&gt;</code>). The core structure remains parseable by standard RDKit.</p>
<p><strong>CXSMILES</strong> optimizes representation by moving variable groups directly into the main SMILES string as special atoms with explicit atom indexing (e.g., <code>C:1</code>) to link to an extension block containing substituent tables. This handles both frequency variation and position variation in Markush structures.</p>
<h3 id="specialized-representations">Specialized Representations</h3>
<p>These formats optimize for specific recognition challenges beyond standard single-molecule tasks.</p>
<h4 id="rfl-ring-free-language">RFL: Ring-Free Language</h4>
<p><strong>RFL</strong> fundamentally restructures molecular serialization through hierarchical ring decomposition, addressing a core challenge: standard 1D formats (SMILES, SSML) flatten complex 2D molecular graphs, losing explicit spatial relationships.</p>
<p><strong>Mechanism</strong>: RFL decomposes molecules into three explicit components:</p>
<ul>
<li><strong>Molecular Skeleton (𝒮)</strong>: Main graph with rings &ldquo;collapsed&rdquo;</li>
<li><strong>Ring Structures (ℛ)</strong>: Individual ring components stored separately</li>
<li><strong>Branch Information (ℱ)</strong>: Connectivity between skeleton and rings</li>
</ul>
<p><strong>Technical approach</strong>:</p>
<ol>
<li>Detect all non-nested rings using DFS</li>
<li>Calculate adjacency ($\gamma$) between rings based on shared edges</li>
<li>Merge isolated rings ($\gamma=0$) into <strong>SuperAtoms</strong> (single node placeholders)</li>
<li>Merge adjacent rings ($\gamma&gt;0$) into <strong>SuperBonds</strong> (edge placeholders)</li>
<li>Progressive decoding: predict skeleton first, then conditionally decode rings using stored hidden states</li>
</ol>
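<p>Steps 2 to 4 can be sketched in a few lines. This is an illustration under stated assumptions, not the RFL implementation: rings are given as ordered cycles of atom indices, $\gamma$ is taken to be the number of shared edges between two rings, and the function names are hypothetical.</p>

```python
from itertools import combinations


def ring_edges(ring: list[int]) -> set[frozenset[int]]:
    """Undirected edges of a ring given as an ordered cycle of atom indices."""
    size = len(ring)
    return {frozenset((ring[i], ring[(i + 1) % size])) for i in range(size)}


def ring_adjacency(rings: list[list[int]]) -> dict[tuple[int, int], int]:
    """gamma for each ring pair, taken here as the number of shared edges."""
    edges = [ring_edges(ring) for ring in rings]
    return {
        (i, j): len(edges[i].intersection(edges[j]))
        for i, j in combinations(range(len(rings)), 2)
    }


def split_isolated_and_fused(rings: list[list[int]]) -> tuple[list[int], list[int]]:
    """Ring indices with gamma == 0 to every other ring, and those with gamma > 0."""
    gamma = ring_adjacency(rings)
    fused = {index for pair, g in gamma.items() if g > 0 for index in pair}
    isolated = [index for index in range(len(rings)) if index not in fused]
    return isolated, sorted(fused)


if __name__ == "__main__":
    # Two fused six-membered rings sharing the 4-5 bond, plus one separate ring.
    rings = [[0, 1, 2, 3, 4, 5], [4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14, 15]]
    print(ring_adjacency(rings))
    print(split_isolated_and_fused(rings))
```

<p>In this toy input the first two rings share one edge and would be candidates for a SuperBond, while the third ring has no shared edges and would collapse to a SuperAtom.</p>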
<p><strong>Performance</strong>: RFL achieves state-of-the-art results on both handwritten (95.38% EM) and printed (95.58% EM) structures, and is particularly strong on high-complexity molecules where standard baselines fail completely (0% → ~30% on the hardest tier).</p>
<p><strong>Note</strong>: RFL does not preserve the original drawing orientation; it focuses on computational efficiency through hierarchical decomposition.</p>
<h4 id="ssml-structure-specific-markup-language">SSML: Structure-Specific Markup Language</h4>
<p><strong>SSML</strong> is the primary orientation-preserving format in OCSR. Based on Chemfig (LaTeX chemical drawing package), it provides step-by-step drawing instructions.</p>
<p><strong>Key characteristics</strong>:</p>
<ul>
<li>Describes <em>how to draw</em> the molecule alongside its graph structure</li>
<li>Uses &ldquo;reconnection marks&rdquo; for cyclic structures</li>
<li>Preserves branch angles and spatial relationships</li>
<li>Significantly outperformed SMILES for handwritten recognition: 92.09% vs 81.89% EM (Hu et al. RCGD 2023)</li>
</ul>
<p><strong>Use case</strong>: Particularly valuable for hand-drawn structure recognition where visual alignment between image and reconstruction sequence aids model learning.</p>
<h2 id="training-data-comparison">Training Data Comparison</h2>
<p>Training data scale has grown dramatically, with a shift toward combining synthetic and real-world images.</p>
<h3 id="data-scale-evolution">Data Scale Evolution</h3>
<table>
  <thead>
      <tr>
          <th>Year</th>
          <th>Typical Scale</th>
          <th>Maximum Reported</th>
          <th>Primary Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2019-2020</td>
          <td>1-15M</td>
          <td>57M (Staker)</td>
          <td>Synthetic (RDKit, CDK)</td>
      </tr>
      <tr>
          <td>2021-2022</td>
          <td>5-35M</td>
          <td>35M (DECIMER 1.0)</td>
          <td>Synthetic with augmentation</td>
      </tr>
      <tr>
          <td>2023-2024</td>
          <td>100-150M</td>
          <td>450M+ (DECIMER.ai)</td>
          <td>Synthetic + real patents</td>
      </tr>
      <tr>
          <td>2025</td>
          <td>1-10M + real</td>
          <td>7.7M (MolParser)</td>
          <td>Curated real + synthetic</td>
      </tr>
  </tbody>
</table>
<h3 id="synthetic-vs-real-data">Synthetic vs Real Data</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Training Data</th>
          <th>Real-World Performance Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>450M+ synthetic (RanDepict)</td>
          <td>Strong generalization via domain randomization</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>7.7M with active learning</td>
          <td>Explicitly targets &ldquo;in the wild&rdquo; images</td>
      </tr>
      <tr>
          <td><strong>GTR-CoT</strong></td>
          <td>Real patent/paper images</td>
          <td>Chain-of-thought improves reasoning</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>Multi-stage curriculum</td>
          <td>RL fine-tuning for stereochemistry</td>
      </tr>
  </tbody>
</table>
<h3 id="data-augmentation-strategies">Data Augmentation Strategies</h3>
<p>Common augmentation techniques across methods:</p>
<table>
  <thead>
      <tr>
          <th>Technique</th>
          <th>Purpose</th>
          <th>Used By</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Rotation</strong></td>
          <td>Orientation invariance</td>
          <td>Nearly all methods</td>
      </tr>
      <tr>
          <td><strong>Gaussian blur</strong></td>
          <td>Image quality variation</td>
          <td>DECIMER, MolParser</td>
      </tr>
      <tr>
          <td><strong>Salt-and-pepper noise</strong></td>
          <td>Scan artifact simulation</td>
          <td>DECIMER, Image2SMILES</td>
      </tr>
      <tr>
          <td><strong>Affine transforms</strong></td>
          <td>Perspective variation</td>
          <td>ChemPix, MolParser</td>
      </tr>
      <tr>
          <td><strong>Font/style variation</strong></td>
          <td>Rendering diversity</td>
          <td>RanDepict (DECIMER.ai)</td>
      </tr>
      <tr>
          <td><strong>Hand-drawn simulation</strong></td>
          <td>Sketch-like inputs</td>
          <td>ChemPix, ChemReco, DECIMER-Hand-Drawn</td>
      </tr>
      <tr>
          <td><strong>Background variation</strong></td>
          <td>Document context</td>
          <td>MolParser, DECIMER.ai</td>
      </tr>
  </tbody>
</table>
<h2 id="hardware-and-compute-requirements">Hardware and Compute Requirements</h2>
<p>Hardware requirements span several orders of magnitude, from consumer GPUs to TPU pods.</p>
<h3 id="training-hardware-comparison">Training Hardware Comparison</h3>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Hardware</th>
          <th>Training Time</th>
          <th>Dataset Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Staker et al. (2019)</strong></td>
          <td>8x GPUs</td>
          <td>26 days</td>
          <td>57M</td>
      </tr>
      <tr>
          <td><strong>IMG2SMI (2021)</strong></td>
          <td>1x RTX 2080 Ti</td>
          <td>5 epochs</td>
          <td>~10M</td>
      </tr>
      <tr>
          <td><strong>Image2SMILES (2022)</strong></td>
          <td>4x V100</td>
          <td>2 weeks</td>
          <td>30M</td>
      </tr>
      <tr>
          <td><strong>MICER (2022)</strong></td>
          <td>4x V100</td>
          <td>42 hours</td>
          <td>10M</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0 (2021)</strong></td>
          <td>TPU v3-8</td>
          <td>Not reported</td>
          <td>35M</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai (2023)</strong></td>
          <td>TPU v3-256</td>
          <td>Not reported</td>
          <td>450M+</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR (2022)</strong></td>
          <td>4x RTX 3090</td>
          <td>5 days</td>
          <td>5M</td>
      </tr>
      <tr>
          <td><strong>MolParser (2025)</strong></td>
          <td>8x A100</td>
          <td>Curriculum learning</td>
          <td>7.7M</td>
      </tr>
      <tr>
          <td><strong>MolSight (2025)</strong></td>
          <td>Not specified</td>
          <td>RL fine-tuning (GRPO)</td>
          <td>Multi-stage</td>
      </tr>
  </tbody>
</table>
<h3 id="inference-considerations">Inference Considerations</h3>
<p>Few papers report inference speed consistently. Available data:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Inference Speed</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>4x faster than DECIMER</td>
          <td>TensorFlow Lite optimization</td>
      </tr>
      <tr>
          <td><strong>OSRA</strong> (baseline)</td>
          <td>~1 image/sec</td>
          <td>CPU-based rule system</td>
      </tr>
      <tr>
          <td><strong>MolScribe</strong></td>
          <td>Real-time capable</td>
          <td>Optimized Swin encoder</td>
      </tr>
  </tbody>
</table>
<h3 id="accessibility-tiers">Accessibility Tiers</h3>
<table>
  <thead>
      <tr>
          <th>Tier</th>
          <th>Hardware</th>
          <th>Representative Methods</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Consumer</strong></td>
          <td>1x RTX 2080/3090</td>
          <td>IMG2SMI, ChemPix</td>
      </tr>
      <tr>
          <td><strong>Workstation</strong></td>
          <td>4x V100/A100</td>
          <td>Image2SMILES, MICER, SwinOCSR</td>
      </tr>
      <tr>
          <td><strong>Cloud/HPC</strong></td>
          <td>TPU pods, 8+ A100</td>
          <td>DECIMER.ai, MolParser</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmark-performance">Benchmark Performance</h2>
<h3 id="common-evaluation-datasets">Common Evaluation Datasets</h3>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Type</th>
          <th>Size</th>
          <th>Challenge</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>USPTO</strong></td>
          <td>Patent images</td>
          <td>~5K test</td>
          <td>Real-world complexity</td>
      </tr>
      <tr>
          <td><strong>UOB</strong></td>
          <td>Scanned images</td>
          <td>~5K test</td>
          <td>Scan artifacts</td>
      </tr>
      <tr>
          <td><strong>Staker</strong></td>
          <td>Synthetic</td>
          <td>Variable</td>
          <td>Baseline synthetic</td>
      </tr>
      <tr>
          <td><strong>CLEF</strong></td>
          <td>Patent images</td>
          <td>~1K test</td>
          <td>Markush structures</td>
      </tr>
      <tr>
          <td><strong>JPO</strong></td>
          <td>Japanese patents</td>
          <td>~1K test</td>
          <td>Different rendering styles</td>
      </tr>
  </tbody>
</table>
<h3 id="accuracy-comparison-exact-match-">Accuracy Comparison (Exact Match %)</h3>
<p>Methods are roughly grouped by evaluation era; direct comparison is complicated by different test sets.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>USPTO</th>
          <th>UOB</th>
          <th>Staker</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>OSRA</strong> (baseline)</td>
          <td>~70%</td>
          <td>~65%</td>
          <td>~80%</td>
          <td>Rule-based reference</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>~85%</td>
          <td>~80%</td>
          <td>~90%</td>
          <td>First transformer-based</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR</strong></td>
          <td>~88%</td>
          <td>~82%</td>
          <td>~92%</td>
          <td>Swin encoder advantage</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>~90%</td>
          <td>~85%</td>
          <td>~95%</td>
          <td>Scale + augmentation</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>~92%</td>
          <td>~88%</td>
          <td>~96%</td>
          <td>Real-world focus</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>~93%+</td>
          <td>~89%+</td>
          <td>~97%+</td>
          <td>RL fine-tuning boost</td>
      </tr>
  </tbody>
</table>
<p><em>Note: Numbers are approximate and may vary by specific test split. See individual paper notes for precise figures.</em></p>
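<p>Exact match is typically computed by comparing each predicted structure string with its reference after both have been canonicalized. The sketch below assumes the strings are already canonical (for example, canonical SMILES produced by a cheminformatics toolkit); <code>exact_match_accuracy</code> is a hypothetical helper name, and individual papers may normalize structures differently.</p>

```python
def exact_match_accuracy(predictions, references):
    """Percentage of predictions that equal their reference string exactly.

    Assumes both lists hold already-canonicalized structure strings.
    """
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    hits = sum(1 for p, r in zip(predictions, references) if p == r)
    return 100.0 * hits / len(references)


# Toy example: three of four predictions match their reference.
predictions = ["CCO", "c1ccccc1", "CC(=O)O", "CCN"]
references = ["CCO", "c1ccccc1", "CC(=O)O", "CCC"]
score = exact_match_accuracy(predictions, references)
print(score)
```

<p>Because the metric is all-or-nothing per molecule, a single wrong atom or stereo descriptor counts as a miss, which is one reason reported numbers are sensitive to the test split.</p>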
<h3 id="stereochemistry-recognition">Stereochemistry Recognition</h3>
<p>Stereochemistry remains a persistent challenge across all methods:</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Approach</th>
          <th>Stereo Accuracy</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Most methods</strong></td>
          <td>Standard SMILES</td>
          <td>Lower than non-stereo</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>RL (GRPO) specifically for stereo</td>
          <td>Improved</td>
      </tr>
      <tr>
          <td><strong>MolNexTR</strong></td>
          <td>Graph-based explicit stereo</td>
          <td>Better handling</td>
      </tr>
      <tr>
          <td><strong>Image2InChI</strong></td>
          <td>InChI stereo layers</td>
          <td>Mixed results</td>
      </tr>
  </tbody>
</table>
<h2 id="hand-drawn-recognition">Hand-Drawn Recognition</h2>
<p>A distinct sub-lineage focuses on hand-drawn/sketched chemical structures.</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Target Domain</th>
          <th>Key Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ChemPix (2021)</strong></td>
          <td>Hand-drawn hydrocarbons</td>
          <td>First deep learning for sketches</td>
      </tr>
      <tr>
          <td><strong>Hu et al. RCGD (2023)</strong></td>
          <td>Hand-drawn structures</td>
          <td>Random conditional guided decoder</td>
      </tr>
      <tr>
          <td><strong>ChemReco (2024)</strong></td>
          <td>Hand-drawn C-H-O structures</td>
          <td>EfficientNet + curriculum learning</td>
      </tr>
      <tr>
          <td><strong>DECIMER-Hand-Drawn (2024)</strong></td>
          <td>General hand-drawn</td>
          <td>Enhanced DECIMER architecture</td>
      </tr>
  </tbody>
</table>
<h3 id="hand-drawn-vs-printed-trade-offs">Hand-Drawn vs Printed Trade-offs</h3>
<ul>
<li>Hand-drawn methods sacrifice some accuracy on clean printed images</li>
<li>They require specialized training data (synthetic hand-drawn simulation)</li>
<li>Their training sets are generally smaller because hand-drawn data is difficult to collect</li>
<li>They are better suited to educational and lab notebook applications</li>
</ul>
<h2 id="key-innovations-by-method">Key Innovations by Method</h2>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>Primary Innovation</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Staker et al.</strong></td>
          <td>First end-to-end deep learning OCSR</td>
      </tr>
      <tr>
          <td><strong>DECIMER 1.0</strong></td>
          <td>Transformer decoder + SELFIES</td>
      </tr>
      <tr>
          <td><strong>Img2Mol</strong></td>
          <td>Continuous embedding space (CDDD)</td>
      </tr>
      <tr>
          <td><strong>Image2SMILES</strong></td>
          <td>Functional group-aware SMILES (FG-SMILES)</td>
      </tr>
      <tr>
          <td><strong>SwinOCSR</strong></td>
          <td>Hierarchical vision transformer encoder</td>
      </tr>
      <tr>
          <td><strong>DECIMER.ai</strong></td>
          <td>Massive scale + RanDepict augmentation</td>
      </tr>
      <tr>
          <td><strong>MolParser</strong></td>
          <td>Extended SMILES + active learning</td>
      </tr>
      <tr>
          <td><strong>MolSight</strong></td>
          <td>RL fine-tuning (GRPO) for accuracy</td>
      </tr>
      <tr>
          <td><strong>GTR-CoT</strong></td>
          <td>Chain-of-thought graph traversal</td>
      </tr>
      <tr>
          <td><strong>OCSU</strong></td>
          <td>Multi-task vision-language understanding</td>
      </tr>
      <tr>
          <td><strong>RFL</strong></td>
          <td>Hierarchical ring decomposition with SuperAtoms/SuperBonds</td>
      </tr>
  </tbody>
</table>
<h2 id="open-challenges">Open Challenges</h2>
<ol>
<li><strong>Stereochemistry</strong>: Consistent challenge across all methods; RL approaches (MolSight) show promise</li>
<li><strong>Abbreviations/R-groups</strong>: E-SMILES and Markush-specific methods emerging</li>
<li><strong>Real-world robustness</strong>: Gap between synthetic training and patent/paper images</li>
<li><strong>Inference speed</strong>: Rarely reported; important for production deployment</li>
<li><strong>Memory efficiency</strong>: Almost never documented; limits accessibility</li>
<li><strong>Multi-molecule images</strong>: Most methods assume single isolated structure</li>
</ol>
<h2 id="references">References</h2>
<p>Individual paper notes linked throughout. For the complete method listing, see the <a href="/notes/chemistry/optical-structure-recognition/benchmarks/ocsr-methods/">OCSR Methods taxonomy</a>.</p>
]]></content:encoded></item><item><title>DynamicFlow: Integrating Protein Dynamics into Drug Design</title><link>https://hunterheidenreich.com/notes/biology/computational-biology/dynamicflow/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/biology/computational-biology/dynamicflow/</guid><description>Flow matching model that co-generates ligands and flexible protein pockets, addressing rigid-receptor limitations in structure-based drug design.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a strong <strong>Resource</strong> ($\Psi_{\text{Resource}}$) component.</p>
<ul>
<li><strong>Method</strong>: It proposes <strong>DynamicFlow</strong>, a novel multiscale architecture that combines atom-level SE(3)-equivariant GNNs with residue-level Transformers within a <a href="/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/">flow matching</a> framework to model the joint distribution of ligand generation and protein conformational change. SE(3) is the special Euclidean group in 3D, the set of all 3D rotations and translations; equivariance means that predictions transform consistently under those symmetries.</li>
<li><strong>Resource</strong>: It curates a significant dataset derived from MISATO, pairing AlphaFold2-predicted apo structures with multiple MD-simulated holo states, specifically filtered for flow matching tasks.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Traditional Structure-Based Drug Design (SBDD) methods typically assume the protein target is rigid, which limits their applicability because proteins are dynamic and undergo conformational changes (induced fit) upon ligand binding.</p>
<ul>
<li><strong>Biological Reality</strong>: Proteins exist as ensembles of states; binding often involves transitions from &ldquo;apo&rdquo; (unbound) to &ldquo;holo&rdquo; (bound) <a href="/posts/geom-conformer-generation-dataset/">conformational changes</a>, sometimes revealing cryptic pockets.</li>
<li><strong>Computational Bottleneck</strong>: <a href="/notes/chemistry/molecular-simulation/">Molecular Dynamics (MD)</a> can simulate these changes but is computationally expensive, because crossing the energy barriers between conformational states requires long simulations.</li>
<li><strong>Gap</strong>: <a href="/notes/machine-learning/generative-models/">Existing generative models</a> for SBDD mostly condition on a fixed pocket structure, ignoring the co-adaptation of the protein and ligand.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>simultaneous modeling of ligand generation and protein conformational dynamics</strong> using a unified flow matching framework.</p>
<ul>
<li><strong>DynamicFlow Architecture</strong>: A multiscale model that treats the protein as both full-atom (for interaction) and residue-level frames (for large-scale dynamics), utilizing separate flow matching objectives for backbone frames, side-chain torsions, and ligand atoms.</li>
<li><strong>Stochastic Flow (SDE)</strong>: Introduction of a <a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">stochastic variant</a> (DynamicFlow-SDE) that improves robustness and diversity compared to the deterministic ODE flow.</li>
<li><strong>Coupled Generation</strong>: The model learns to transport the <em>apo</em> pocket distribution to the <em>holo</em> pocket distribution while simultaneously denoising the ligand, advancing beyond rigid pocket docking methods.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method on a curated dataset of 5,692 protein-ligand complexes.</p>
<ul>
<li><strong>Baselines</strong>: Compared against rigid-pocket SBDD methods: Pocket2Mol, TargetDiff, and IPDiff (adapted as TargetDiff* and IPDiff* for fair comparison of atom numbers). Also compared against conformation sampling baselines (Str2Str).</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Ligand Quality</strong>: Vina Score (binding affinity), QED (drug-likeness), SA (synthesizability), Lipinski&rsquo;s rule of 5.</li>
<li><strong>Pocket Quality</strong>: RMSD between generated and ground-truth holo pockets, Cover Ratio (percentage of holo states successfully retrieved), and Pocket Volume distributions.</li>
<li><strong>Interaction</strong>: Protein-Ligand Interaction Profiler (PLIP) to measure specific non-covalent interactions.</li>
</ul>
</li>
<li><strong>Ablations</strong>: Tested the impact of the interaction loss, residue-level Transformer, and SDE vs. ODE formulations.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Improved Affinity</strong>: DynamicFlow-SDE achieved the best (lowest) Vina scores ($-7.65$) compared to baselines like TargetDiff ($-5.09$) and Pocket2Mol ($-5.50$). Note that Vina scores are a computational proxy and do not directly predict experimental binding affinity. Moreover, Vina score optimization is gameable: molecules can achieve strong computed binding energies while remaining synthetically inaccessible. QED and SA scores, which assess drug-likeness and synthesizability respectively, were reported but were not primary optimization targets in the paper, which limits the strength of this affinity claim.</li>
<li><strong>Realistic Dynamics</strong>: The model successfully generated holo-like pocket conformations with volume distributions and interaction profiles closer to ground-truth MD simulations than the initial apo structures.</li>
<li><strong>Enhancing Rigid Methods</strong>: Holo pockets generated by DynamicFlow served as better inputs for rigid-SBDD baselines (e.g., TargetDiff improved from $-5.09$ to $-9.00$ and IPDiff improved from $-7.55$ to $-11.04$ when using &ldquo;Our Pocket&rdquo;), suggesting the method can act as a &ldquo;pocket refiner&rdquo;.</li>
<li><strong>ODE vs. SDE Trade-off</strong>: The deterministic ODE variant achieves better pocket RMSD, while the stochastic SDE variant achieves better Cover Ratio (diversity of holo states captured) and binding affinity. Neither dominates uniformly.</li>
<li><strong>Conformation Sampling Baseline</strong>: Str2Str, a dedicated conformation sampling baseline, performed worse than simply perturbing the apo structure with noise. One interpretation is that this highlights the difficulty of the apo-to-holo prediction task; another is that Str2Str was not designed specifically for apo-to-holo prediction, making it a limited test of its capabilities.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset is derived from <strong>MISATO</strong>, which contains MD trajectories for PDBbind complexes.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td>Curated MISATO</td>
          <td>5,692 complexes</td>
          <td>Filtered for valid MD (<a href="/posts/kabsch-algorithm/">RMSD</a> $&lt; 3\text{\AA}$), clustered to remove redundancy. Contains 46,235 holo-ligand conformations total.</td>
      </tr>
      <tr>
          <td><strong>Apo Structures</strong></td>
          <td>AlphaFold2</td>
          <td>N/A</td>
          <td>Apo structures were obtained by mapping PDB IDs to UniProt and retrieving AlphaFold2 predictions, then aligning to MISATO structures.</td>
      </tr>
      <tr>
          <td><strong>Splits</strong></td>
          <td>Standard</td>
          <td>50 test complexes</td>
          <td>50 complexes with no overlap with the training set were selected for testing. Note: 50 is a small held-out set; results should be interpreted cautiously.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Clustering</strong>: Holo-ligand conformations clustered with RMSD threshold $1.0\text{\AA}$; top 10 clusters kept per complex.</li>
<li><strong>Pocket Definition</strong>: Residues within $7\text{\AA}$ of the ligand.</li>
<li><strong>Alignment</strong>: AlphaFold2-predicted apo structures were aligned to MISATO holo structures using sequence alignment (Smith-Waterman) to identify pocket residues.</li>
</ul>
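<p>The pocket definition above can be sketched in a few lines. This is a minimal illustration that assumes a residue belongs to the pocket when any of its atoms lies within the cutoff of any ligand atom; the source does not state whether the distance is measured atom-to-atom or from residue centers, and <code>pocket_residues</code> and the toy coordinates are hypothetical.</p>

```python
import math


def pocket_residues(residue_atoms, ligand_atoms, cutoff=7.0):
    """Return IDs of residues with at least one atom within `cutoff`
    angstroms of any ligand atom (any-atom criterion, an assumption)."""
    selected = []
    for res_id, atoms in residue_atoms.items():
        is_close = any(
            math.dist(atom, lig) <= cutoff
            for atom in atoms
            for lig in ligand_atoms
        )
        if is_close:
            selected.append(res_id)
    return selected


# Toy coordinates in angstroms, for illustration only.
residues = {
    "ALA1": [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)],  # 3.0 A from the ligand atom
    "GLY2": [(20.0, 0.0, 0.0)],                  # far from the ligand
    "SER3": [(0.0, 6.0, 0.0)],                   # about 6.7 A from the ligand atom
}
ligand = [(0.0, 0.0, 3.0)]
pocket = pocket_residues(residues, ligand)
print(pocket)
```

<p>The double loop is quadratic in the number of atoms, which is acceptable for a single complex; a spatial index would be the usual choice for larger batches.</p>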
<h3 id="algorithms">Algorithms</h3>
<p><strong>Flow Matching Framework</strong>:</p>
<ul>
<li><strong>Continuous Variables</strong> (Pocket translation/rotation/torsions, Ligand positions): Modeled using <strong>Conditional Flow Matching (CFM)</strong>.
<ul>
<li><em>Prior</em>: Apo state for pocket; Normal distribution for ligand positions.</li>
<li><em>Target</em>: Holo state from MD; Ground truth ligand.</li>
<li><em>Interpolant</em>: Linear interpolation for Euclidean variables; Geodesic for rotations ($SO(3)$, the rotation-only subgroup of SE(3) containing all 3D rotations but not translations); Wrapped linear interpolation for torsions (Torus).</li>
</ul>
</li>
<li><strong>Discrete Variables</strong> (Ligand atom/bond types): Modeled using <strong>Discrete Flow Matching</strong> based on Continuous-Time Markov Chains (CTMC).
<ul>
<li><em>Rate Matrix</em>: Interpolates between mask token and data distribution.</li>
</ul>
</li>
<li><strong>Loss Function</strong>: Weighted sum of 7 losses:
<ol>
<li>Translation CFM (Eq 5)</li>
<li>Rotation CFM (Eq 7)</li>
<li>Torsion CFM (Eq 11)</li>
<li>Ligand Position CFM</li>
<li>Ligand Atom Type CTMC (Eq 14)</li>
<li>Ligand Bond Type CTMC</li>
<li><strong>Interaction Loss</strong> (Eq 18): Explicitly penalizes deviations in pairwise distances between protein and ligand atoms for pairs $\leq 3.5\text{\AA}$.</li>
</ol>
</li>
</ul>
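<p>Two of the interpolants listed above can be sketched as follows. This is a minimal illustration in plain Python, assuming the straight-line path $x_t = (1 - t)\,x_0 + t\,x_1$ for Euclidean variables and a shortest-arc path for torsion angles. The function names are hypothetical and the paper&rsquo;s exact parameterization (Eqs. 5 and 11) is not reproduced here; the geodesic interpolant on $SO(3)$ is omitted.</p>

```python
import math


def linear_interpolant(x0, x1, t):
    """Straight-line path between two Euclidean points at time t in [0, 1]."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def wrap_angle(theta):
    """Map an angle in radians to the interval [-pi, pi)."""
    return (theta + math.pi) % (2.0 * math.pi) - math.pi


def torsion_interpolant(a0, a1, t):
    """Interpolate between two torsion angles along the shorter arc."""
    delta = wrap_angle(a1 - a0)  # signed shortest angular difference
    return wrap_angle(a0 + t * delta)


# Euclidean variable (e.g. a ligand atom position): halfway between prior and target.
midpoint = linear_interpolant([0.0, 0.0, 0.0], [2.0, 4.0, 6.0], 0.5)

# Torsion: going from 170 degrees to -170 degrees should cross 180, not 0.
quarter = torsion_interpolant(math.radians(170.0), math.radians(-170.0), 0.25)
print(midpoint, math.degrees(quarter))
```

<p>The wrapping step is what distinguishes the torsion case: a naive linear blend of 170&deg; and -170&deg; would sweep 340&deg; through zero, whereas the wrapped version moves only 20&deg;.</p>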
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: <strong>DynamicFlow</strong> is a multiscale model with 15.9M parameters.</p>
<ol>
<li><strong>Atom-Level SE(3)-Equivariant GNN</strong>:
<ul>
<li><em>Input</em>: Complex graph (k-NN) and Ligand graph (fully connected).</li>
<li><em>Layers</em>: 6 EGNN blocks modified to maintain node and edge hidden states.</li>
<li><em>Function</em>: Updates ligand positions and predicts ligand atom/bond types.</li>
</ul>
</li>
<li><strong>Residue-Level Transformer</strong>:
<ul>
<li><em>Input</em>: Aggregated atom features from the GNN + Residue frames/torsions.</li>
<li><em>Layers</em>: 4 Transformer blocks with <strong>Invariant Point Attention (IPA)</strong>.</li>
<li><em>Function</em>: Updates protein residue frames (translation/rotation) and predicts side-chain torsions.</li>
</ul>
</li>
</ol>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Vina Score</strong>: <code>vina_minimize</code> mode used for binding affinity.</li>
<li><strong>RMSD</strong>: Minimum RMSD between generated pocket and ground-truth holo conformations.</li>
<li><strong>Cover Ratio</strong>: % of ground-truth holo conformations covered by at least one generated sample (threshold $1.42\text{\AA}$).</li>
<li><strong>POVME 3</strong>: For pocket volume calculation.</li>
</ul>
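<p>The minimum RMSD and Cover Ratio metrics can be sketched as below. This is a simplified illustration that assumes all structures are already superimposed and share the same atom ordering (no alignment step is performed), and that a reference counts as covered when the RMSD is at most the threshold. The helper names and toy coordinates are hypothetical.</p>

```python
import math


def rmsd(a, b):
    """RMSD between two equally sized, pre-aligned coordinate lists."""
    squared = sum(math.dist(p, q) ** 2 for p, q in zip(a, b))
    return math.sqrt(squared / len(a))


def min_rmsd(generated, references):
    """Smallest RMSD over all generated/reference pairs."""
    return min(rmsd(g, r) for g in generated for r in references)


def cover_ratio(generated, references, threshold=1.42):
    """Fraction of reference conformations within `threshold` angstroms
    RMSD of at least one generated sample."""
    covered = sum(
        1
        for r in references
        if any(rmsd(g, r) <= threshold for g in generated)
    )
    return covered / len(references)


# Toy two-atom "pockets", coordinates in angstroms, for illustration only.
references = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    [(0.0, 0.0, 0.0), (1.0, 5.0, 0.0)],
]
generated = [
    [(0.0, 0.0, 0.0), (1.0, 1.0, 0.0)],
]
best = min_rmsd(generated, references)
coverage = cover_ratio(generated, references)
print(best, coverage)
```

<p>The two metrics answer different questions: minimum RMSD rewards one accurate sample, while Cover Ratio rewards a sample set that reaches many distinct holo states.</p>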
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Inference Benchmark</strong>: 1x Tesla V100-SXM2-32GB.</li>
<li><strong>Speed</strong>: Generates 10 ligands in ~35-36 seconds (100 NFE), significantly faster than baselines such as the autoregressive Pocket2Mol (980s) or the diffusion-based TargetDiff (156s).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhou, X., Xiao, Y., Lin, H., He, X., Guan, J., Wang, Y., Liu, Q., Zhou, F., Wang, L., &amp; Ma, J. (2025). Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://arxiv.org/abs/2503.03989">https://arxiv.org/abs/2503.03989</a></p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhouIntegratingProteinDynamics2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Integrating Protein Dynamics into Structure-Based Drug Design via Full-Atom Stochastic Flows}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhou, Xiangxin and Xiao, Yi and Lin, Haowei and He, Xinheng and Guan, Jiaqi and Wang, Yang and Liu, Qiang and Zhou, Feng and Wang, Liang and Ma, Jianzhu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://arxiv.org/abs/2503.03989}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2503.03989">arXiv Page</a></li>
<li>Code: no public repository available at time of writing</li>
</ul>
]]></content:encoded></item><item><title>ChemDFM-X: Multimodal Foundation Model for Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemdfm-x/</link><pubDate>Sat, 20 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemdfm-x/</guid><description>Multimodal chemical model integrating 5 modalities (2D graphs, 3D conformations, images, MS2/IR spectra) trained on 7.6M instructions.</description><content:encoded><![CDATA[<h2 id="chemdfm-x-contribution-and-architecture">ChemDFM-X Contribution and Architecture</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> contribution.</p>
<p><strong>Method</strong>: The paper proposes a novel &ldquo;Cross-modal Dialogue Foundation Model&rdquo; architecture that aligns five distinct chemical modalities (2D graphs, 3D conformations, images, MS2 spectra, IR spectra) to a single LLM decoder using separate encoders and projection modules. It reports strong performance across multiple modalities relative to current generalist models.</p>
<p><strong>Resource</strong>: The paper addresses the scarcity of multimodal chemical data by constructing a <strong>7.6M instruction-tuning dataset</strong>. This dataset is largely synthesized from seed SMILES strings using approximate calculations (MMFF94, CFM-ID, Chemprop-IR) and specialist model predictions.</p>
<h2 id="bridging-experimental-data-and-llms">Bridging Experimental Data and LLMs</h2>
<p>Existing chemical AI models generally fall into two distinct categories. Task-specific specialist models achieve high accuracy on singular objectives, such as property prediction or molecular generation, but require strict formatting and lack conversational flexibility. Conversely, early chemical large language models provide natural language interaction but are restricted to text and SMILES strings. ChemDFM-X addresses this gap by enabling large multimodal models to process the experimental characterization data (<a href="https://en.wikipedia.org/wiki/Tandem_mass_spectrometry">MS2 spectra</a> and <a href="https://en.wikipedia.org/wiki/Infrared_spectroscopy">IR spectra</a>) and visual data routinely used in practical chemistry workflows.</p>
<h2 id="synthetic-data-scaling-for-modality-alignment">Synthetic Data Scaling for Modality Alignment</h2>
<p>The core novelty lies in the <strong>&ldquo;Any-to-Text&rdquo; alignment strategy via synthetic data scaling</strong>:</p>
<ol>
<li>
<p><strong>Comprehensive Modality Support</strong>: ChemDFM-X incorporates experimental characterization data (MS2 and IR spectra) alongside 2D graphs, 3D conformations, and images. The non-image modalities are given formal mathematical representations:</p>
<ul>
<li><strong>Molecular Graph</strong>: An undirected graph $G = (\textbf{V}, \textbf{E})$ with atom set $\textbf{V}$ and bond set $\textbf{E}$.</li>
<li><strong>Molecular Conformation</strong>: An undirected graph $G = (\textbf{V}', \textbf{E})$ storing spatial coordinates: $\textbf{v}_i = (x_i, y_i, z_i, a_i)$.</li>
<li><strong>MS2 Spectrum</strong>: Treated as a point sequence of discrete mass-to-charge ratios and intensities, tokenized via a discrete codebook: $\textbf{M} = ((r_1, I_1), (r_2, I_2), \dots, (r_n, I_n))$.</li>
<li><strong>IR Spectrum</strong>: Treated as a dense sequence of continuous wavelengths and absorption intensities, directly reshaped for feature extraction: $\textbf{R} = ((w_1, t_1), (w_2, t_2), \dots, (w_l, t_l))$.</li>
</ul>
<p>The authors trained new Sequence Transformer encoders from scratch for the MS2 and IR modalities since suitable pre-trained models did not exist.</p>
</li>
<li>
<p><strong>Synthetic Data Generation Pipeline</strong>: The authors generated a 7.6M sample dataset by starting with 1.3M seed SMILES and using &ldquo;approximate calculations&rdquo; to generate missing modalities:</p>
<ul>
<li>3D conformations via <a href="https://en.wikipedia.org/wiki/Merck_molecular_force_field">MMFF94</a> force field optimization</li>
<li>MS2 spectra via CFM-ID 4.0 (Competitive Fragmentation Modeling)</li>
<li>IR spectra via Chemprop-IR (Message Passing Neural Network)</li>
</ul>
</li>
<li>
<p><strong>Cross-Modal Synergy</strong>: The model demonstrates that training on reaction images improves recognition performance by leveraging semantic chemical knowledge (reaction rules) to correct visual recognition errors, an emergent capability from multimodal training.</p>
</li>
</ol>
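<p>The four formal representations described in the first item can be written down as plain data structures. This is only a sketch for orientation: the class and field names are hypothetical, $a_i$ is assumed to be the atom type, and the numeric values are made-up placeholders, not real coordinates or spectra. The encoders and codebook tokenization used by ChemDFM-X are not modeled here.</p>

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class MolecularGraph:
    atoms: List[str]               # V: one entry per atom
    bonds: List[Tuple[int, int]]   # E: undirected pairs of atom indices


@dataclass
class ConformerAtom:
    x: float
    y: float
    z: float
    a: str                         # assumed to be the atom type


@dataclass
class MolecularConformation:
    atoms: List[ConformerAtom]     # V': atoms carrying 3D coordinates
    bonds: List[Tuple[int, int]]   # E: same bond set as the 2D graph


# Spectra as point sequences of (position, intensity) pairs.
MS2Spectrum = List[Tuple[float, float]]  # M: (r_i, I_i)
IRSpectrum = List[Tuple[float, float]]   # R: (w_i, t_i)

# Water as a toy molecule; every number below is a placeholder.
water_graph = MolecularGraph(atoms=["O", "H", "H"], bonds=[(0, 1), (0, 2)])
water_conf = MolecularConformation(
    atoms=[
        ConformerAtom(0.0, 0.0, 0.0, "O"),
        ConformerAtom(0.96, 0.0, 0.0, "H"),
        ConformerAtom(-0.24, 0.93, 0.0, "H"),
    ],
    bonds=water_graph.bonds,
)
ms2: MS2Spectrum = [(18.0, 1.0), (17.0, 0.2)]
ir: IRSpectrum = [(1600.0, 0.4), (3400.0, 0.9)]
print(len(water_conf.atoms), len(ms2), len(ir))
```

<p>Keeping the bond set shared between the graph and the conformation mirrors the definitions above, where only the vertex set changes between the 2D and 3D views.</p>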
<h2 id="multimodal-benchmarking-with-chemllmbench">Multimodal Benchmarking with ChemLLMBench</h2>
<p>The model was evaluated using a customized version of <strong><a href="/notes/chemistry/llm-applications/chemllmbench-eight-chemistry-tasks/">ChemLLMBench</a></strong> and <strong><a href="/notes/chemistry/molecular-design/property-prediction/moleculenet-benchmark-molecular-ml/">MoleculeNet</a></strong> across three modality categories:</p>
<ol>
<li>
<p><strong>Structural Modalities</strong> (2D Graphs &amp; 3D Conformations):</p>
<ul>
<li>Molecule recognition and captioning</li>
<li>Property prediction (MoleculeNet: BACE, BBBP, ClinTox, HIV, Tox21)</li>
<li>Compared against specialist models (Mole-BERT, Uni-Mol, MolXPT, MolCA) and generalist models (3D-MoLM, ChemDFM, <a href="/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/">ChemLLM</a>)</li>
</ul>
</li>
<li>
<p><strong>Visual Modalities</strong> (Images):</p>
<ul>
<li>Single molecule image recognition</li>
<li>Reaction image recognition</li>
<li>Compared against GPT-4o, Gemini 1.5 Pro, Qwen-VL, LLaVA, and specialist models <a href="/notes/chemistry/optical-structure-recognition/vision-language/molnextr/">MolNexTR</a> and <a href="/notes/chemistry/optical-structure-recognition/image-to-graph/molscribe/">MolScribe</a></li>
</ul>
</li>
<li>
<p><strong>Characterization Modalities</strong> (MS2 &amp; IR Spectra):</p>
<ul>
<li>Spectral analysis tasks (identifying molecules from spectra)</li>
<li>Contextualized spectral interpretation (combining spectra with reaction context)</li>
<li>Novel evaluation requiring integration of spectroscopic data with reaction knowledge</li>
</ul>
</li>
</ol>
<h2 id="cross-modal-synergy-and-generalist-performance">Cross-Modal Synergy and Generalist Performance</h2>
<p><strong>Key Findings</strong>:</p>
<ol>
<li>
<p><strong>Leading Generalist Performance</strong>: ChemDFM-X outperforms existing generalist models (such as 3D-MoLM and ChemLLM) and matches dedicated specialist models on several multimodal tasks.</p>
</li>
<li>
<p><strong>Failure of General LMMs</strong>: General-purpose multimodal models (GPT-4o, Gemini 1.5 Pro, Qwen-VL, LLaVA, InternLM-XComposer2, DocOwl) failed on chemical image recognition tasks (0% accuracy for most models on molecule and reaction recognition, Table 9), demonstrating that chemical domain knowledge cannot be assumed from general pre-training.</p>
</li>
<li>
<p><strong>Cross-Modal Error Correction</strong>: In reaction image recognition, ChemDFM-X achieved higher accuracy (53.0%) than on single molecules (46.0%) (Table 9). The authors conclude the model uses its internal knowledge of chemical reaction rules to correct recognition errors in the visual modality, an emergent capability from multimodal training.</p>
</li>
<li>
<p><strong>Reliance on Reaction Context for Spectra</strong>: In zero-shot scenarios, ChemDFM-X essentially fails at pure spectral recognition (achieving 0% and 1% top-1 accuracy on MS2 and IR spectra alone, Table 11). However, when SMILES-based reaction context is included, performance rises to 45% (MS2) and 64% (IR) on the reaction prediction task, and 29% (MS2) and 60% (IR) on <a href="https://en.wikipedia.org/wiki/Retrosynthetic_analysis">retrosynthesis</a> (Table 11). This indicates the model uses spectral data as a soft prior to constrain textual deductions. Furthermore, the paper compares ChemDFM-X’s spectral identification performance exclusively against text-only LLMs that cannot process spectra, omitting comparisons against established specialist tools.</p>
</li>
<li>
<p><strong>Surrogate Distillation Trade-offs</strong>: Because the spectral training data relies entirely on outputs from CFM-ID 4.0 and Chemprop-IR, ChemDFM-X effectively distills these surrogate models. Any inherent predictive biases or inaccuracies from these underlying tools are permanently embedded in the new ChemDFM-X encoders.</p>
</li>
</ol>
<p><strong>Main Conclusion</strong>: The &ldquo;separate encoders + unified decoder&rdquo; architecture with synthetic data generation enables effective multimodal chemical understanding, bridging the gap between specialist and generalist AI systems for chemistry.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors constructed a <strong>7.6M-sample instruction-tuning dataset</strong> derived from <strong>1.3M seed SMILES</strong> (sourced from <a href="https://en.wikipedia.org/wiki/PubChem">PubChem</a> and USPTO). <strong>Note</strong>: The final 7.6M multimodal tuning dataset itself isn&rsquo;t publicly available.</p>
<p><strong>Generation Pipeline</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Generation Method</th>
          <th>Tool/Model</th>
          <th>Sample Count</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2D Graphs</strong></td>
          <td>Direct extraction from SMILES</td>
          <td>RDKit</td>
          <td>1.1M</td>
      </tr>
      <tr>
          <td><strong>3D Conformations</strong></td>
          <td>Force field optimization</td>
          <td>RDKit + MMFF94</td>
          <td>1.3M (pseudo-optimal)</td>
      </tr>
      <tr>
          <td><strong>Molecule Images</strong></td>
          <td>Rendering with augmentation</td>
          <td>RDKit, Indigo, <a href="/notes/chemistry/optical-structure-recognition/hand-drawn/chempix/">ChemPix</a></td>
          <td>~1M (including handwritten style)</td>
      </tr>
      <tr>
          <td><strong>Reaction Images</strong></td>
          <td>Rendering from reaction SMILES</td>
          <td>RDKit</td>
          <td>300K</td>
      </tr>
      <tr>
          <td><strong>MS2 Spectra</strong></td>
          <td>Computational prediction</td>
          <td>CFM-ID 4.0</td>
          <td>~700K</td>
      </tr>
      <tr>
          <td><strong>IR Spectra</strong></td>
          <td>Computational prediction</td>
          <td>Chemprop-IR</td>
          <td>~1M</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Augmentation</strong>:</p>
<ul>
<li>Molecule images augmented with &ldquo;handwritten&rdquo; style using the ChemPix pipeline</li>
<li>Multiple rendering styles (RDKit default, Indigo clean)</li>
<li>Spectra generated at multiple energy levels (10 eV, 20 eV, 40 eV for MS2)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Architecture</strong>: &ldquo;Separate Encoders + Unified Decoder&rdquo;</p>
<p><strong>Code Availability</strong>: The authors have only released inference logic. The cross-modal projection training and synthetic data-generation scripts are closed.</p>
<p><strong>Modality Alignment</strong>:</p>
<ul>
<li>Each modality has a dedicated encoder (frozen pre-trained models where available)</li>
<li>For graph, conformation, MS2, and IR modalities: <strong>2-layer MLP projector</strong> (Linear, GELU, Linear) maps encoder features to LLM input space</li>
<li>For images: <strong>H-Reducer</strong> module compresses image tokens by a factor of $n=8$ to handle high-resolution chemical images, then projects to LLM input space</li>
<li>All projected features are concatenated and fed to the unified LLM decoder</li>
</ul>
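<p>The projector described above (Linear, GELU, Linear) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the toy dimensions, the random initialization, the hidden width (set equal to the output width here), and the helper names <code>linear</code>, <code>init_linear</code>, and <code>project</code> are all hypothetical.</p>

```python
import math
import random


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def linear(x, weight, bias):
    # weight has shape [out_dim][in_dim]; returns a list of length out_dim.
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def init_linear(in_dim, out_dim, rng):
    # Small random weights and zero bias (illustrative initialization only).
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return weight, [0.0] * out_dim


def project(features, params):
    # Linear -> GELU -> Linear, applied to a single encoder feature vector.
    (w1, b1), (w2, b2) = params
    hidden = [gelu(h) for h in linear(features, w1, b1)]
    return linear(hidden, w2, b2)


rng = random.Random(0)
ENCODER_DIM, LLM_DIM = 6, 4  # toy sizes (assumption); real sizes are far larger
params = (init_linear(ENCODER_DIM, LLM_DIM, rng), init_linear(LLM_DIM, LLM_DIM, rng))
token = project([0.5] * ENCODER_DIM, params)
print(len(token))  # one projected vector in the (toy) LLM input space
```

<p>In a real system each modality would own its own projector, since the encoder feature widths differ per modality.</p>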
<h3 id="models">Models</h3>
<p><strong>Base LLM</strong>:</p>
<ul>
<li><strong>ChemDFM (13B)</strong>: LLaMA-based model pre-trained on chemical text and SMILES</li>
</ul>
<p><strong>Modality Encoders</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Encoder</th>
          <th>Pre-training Data</th>
          <th>Parameter Count</th>
          <th>Status</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>2D Graph</strong></td>
          <td>Mole-BERT</td>
          <td>2M molecules</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>3D Conformation</strong></td>
          <td>Uni-Mol</td>
          <td>209M conformations</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>Image</strong></td>
          <td>CLIP (ViT)</td>
          <td>General domain</td>
          <td>-</td>
          <td>Frozen</td>
      </tr>
      <tr>
          <td><strong>MS2 Spectrum</strong></td>
          <td>Transformer (SeqT)</td>
          <td>Trained from scratch</td>
          <td>-</td>
          <td><strong>Trainable</strong></td>
      </tr>
      <tr>
          <td><strong>IR Spectrum</strong></td>
          <td>Transformer (SeqT)</td>
          <td>Trained from scratch</td>
          <td>-</td>
          <td><strong>Trainable</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Design Rationale</strong>: The MS2 and IR encoders are trained from scratch as Sequence Transformers that treat spectral peaks as token sequences, since no suitable pre-trained models exist for chemical spectra.</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Accuracy (Acc)</strong> for recognition tasks</li>
<li><strong>BLEU-2/4</strong> and <strong>METEOR</strong> for captioning tasks</li>
<li><strong>AUC-ROC</strong> for property prediction (classification)</li>
</ul>
<p><strong>Code Availability</strong>: The adapted code for evaluating on ChemLLMBench and their custom spectral recognition tasks is closed-source.</p>
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>ChemLLMBench</strong>: Adapted for multimodal inputs across molecule captioning, property prediction, and reaction understanding</li>
<li><strong>MoleculeNet</strong>: Standard molecular property prediction tasks (BACE, BBBP, ClinTox, HIV, Tox21)</li>
<li><strong>USPTO</strong>: Reaction prediction and retrosynthesis tasks</li>
<li><strong>Custom Spectral Tasks</strong>: Novel evaluations requiring spectral interpretation</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Note</strong>: The type and quantity of GPUs used, along with the total training wall-time, were not published.</p>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Total Batch Size</strong>: 256</li>
<li><strong>Epochs</strong>: 3</li>
<li><strong>Optimizer</strong>: AdamW</li>
</ul>
<p><strong>Modality-Specific Learning Rates (Peak)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Modality</th>
          <th>Learning Rate</th>
          <th>Feature Dimension</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Graph</td>
          <td>1e-5</td>
          <td>300</td>
      </tr>
      <tr>
          <td>Conformation</td>
          <td>2e-4</td>
          <td>512</td>
      </tr>
      <tr>
          <td>Image</td>
          <td>2e-3</td>
          <td>1024</td>
      </tr>
      <tr>
          <td>MS2 / IR</td>
          <td>2e-4</td>
          <td>768</td>
      </tr>
  </tbody>
</table>
<p><strong>Note</strong>: Different learning rates reflect the varying degrees of domain adaptation required. Images (general CLIP) need more adaptation than graphs (chemical Mole-BERT).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/OpenDFM/ChemDFM-X">ChemDFM-X (GitHub)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Inference code only; training and data generation scripts are closed</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/OpenDFM/ChemDFM-X-v1.0-13B">ChemDFM-X-v1.0-13B (HuggingFace)</a></td>
          <td>Model</td>
          <td>AGPL-3.0</td>
          <td>13B parameter multimodal model weights</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhao, Z., Chen, B., Li, J., Chen, L., Wen, L., Wang, P., Zhu, Z., Zhang, D., Wan, Z., Li, Y., Dai, Z., Chen, X., &amp; Yu, K. (2024). ChemDFM-X: Towards Large Multimodal Model for Chemistry. <em>Science China Information Sciences</em>, 67(12), 220109. <a href="https://doi.org/10.1007/s11432-024-4243-0">https://doi.org/10.1007/s11432-024-4243-0</a></p>
<p><strong>Publication</strong>: Science China Information Sciences, December 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2409.13194">arXiv Version</a></li>
<li><a href="https://github.com/OpenDFM/ChemDFM-X">Code Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhaoChemDFMXLargeMultimodal2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{ChemDFM-X}}: {{Towards Large Multimodal Model}} for {{Chemistry}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhao, Zihan and Chen, Bo and Li, Jingpiao and Chen, Lu and Wen, Liyang and Wang, Pengyu and Zhu, Zichen and Zhang, Danyang and Wan, Ziping and Li, Yansi and Dai, Zhongyang and Chen, Xin and Yu, Kai}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = dec,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Science China Information Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{220109}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1007/s11432-024-4243-0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2409.13194}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolSight: OCSR with RL and Multi-Granularity Learning</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/molsight/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/molsight/</guid><description>A three-stage OCSR framework using SMILES pretraining, auxiliary bond/coordinate tasks, and reinforcement learning to master stereochemistry recognition.</description><content:encoded><![CDATA[<h2 id="contribution-a-framework-for-optical-chemical-structure-recognition">Contribution: A Framework for Optical Chemical Structure Recognition</h2>
<p>This is primarily a <strong>Method</strong> paper. It proposes a novel three-stage training framework (Pretraining → Fine-tuning → RL Post-training) to improve Optical Chemical Structure Recognition (OCSR). Specifically, it introduces the use of Group Relative Policy Optimization (GRPO) to solve non-differentiable chemical validity issues.</p>
<p>It also has a <strong>Resource</strong> component, as the authors construct and release <em>Stereo-200k</em>, a dataset specifically designed to train models on challenging stereoisomeric molecules.</p>
<h2 id="motivation-resolving-stereochemical-cues">Motivation: Resolving Stereochemical Cues</h2>
<p>Existing OCSR systems struggle to accurately recognize stereochemical information (e.g., chirality, geometric isomerism) because the visual cues distinguishing stereoisomers (such as wedge and dash bonds) are subtle. Current methods often fail to capture the geometric relationships required to distinguish molecules with identical connectivity but different spatial arrangements. Accurate recognition is critical for downstream tasks like drug discovery where stereochemistry determines pharmacological effects.</p>
<h2 id="core-innovations-grpo-and-multi-granularity-learning">Core Innovations: GRPO and Multi-Granularity Learning</h2>
<p>MolSight introduces three key technical innovations:</p>
<ol>
<li><strong>Reinforcement Learning for OCSR</strong>: It is the first OCSR system to incorporate RL (specifically GRPO) to directly optimize for chemical semantic correctness.</li>
<li><strong>Multi-Granularity Learning</strong>: It employs auxiliary heads for chemical bond classification and atom localization. Unlike previous approaches that optimize these jointly, MolSight decouples the coordinate head to prevent interference with SMILES generation.</li>
<li><strong>SMILES-M Notation</strong>: A lightweight extension to SMILES to handle Markush structures (common in patents) without significant sequence length increase.</li>
</ol>
<h2 id="experimental-methodology">Experimental Methodology</h2>
<p>The authors evaluated MolSight using a rigorous mix of real and synthetic benchmarks:</p>
<ul>
<li><strong>Baselines</strong>: Compared against rule-based (OSRA, MolVec, Imago) and deep learning methods (MolScribe, MolGrapher, DECIMER).</li>
<li><strong>Benchmarks</strong>: Evaluated on real-world datasets (USPTO, Maybridge UoB, CLEF-2012, JPO) and synthetic datasets (Staker, ChemDraw, Indigo, Stereo-2k).</li>
<li><strong>Ablation Studies</strong>: Tested the impact of the bond head, coordinate head, and RL stages separately.</li>
<li><strong>Transfer Learning</strong>: Assessed the quality of learned representations by using the frozen encoder for molecular property prediction on MoleculeNet.</li>
</ul>
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>SOTA Performance</strong>: MolSight achieved 85.1% stereochemical accuracy on the USPTO dataset, significantly outperforming the previous SOTA (MolScribe) which achieved 69.0%.</li>
<li><strong>RL Effectiveness</strong>: Reinforcement learning post-training specifically improved performance on stereoisomers, raising Tanimoto similarity and exact match rates on the Stereo-2k test set.</li>
<li><strong>Robustness</strong>: On perturbed USPTO images (random rotations and shearing), MolSight achieved 92.3% exact match accuracy (vs. the original 92.0%), while rule-based methods like OSRA dropped from 83.5% to 6.7%. On the low-resolution Staker dataset, MolSight reached 82.1% exact match.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training pipeline uses three distinct data sources:</p>
<ol>
<li><strong>Pre-training</strong>: <em>MolParser-7M</em>. Contains diverse images but requires the <strong>SMILES-M</strong> extension to handle Markush structures.</li>
<li><strong>Fine-tuning</strong>: <em>PubChem-1M</em> and <em>USPTO-680K</em>. Used for multi-granularity learning with bond and coordinate labels.</li>
<li><strong>RL Post-training</strong>: <em>Stereo-200k</em>. A self-collected dataset from the first 2M compounds in PubChem, filtered for chirality (<code>@</code>) and cis-trans isomerism (<code>/</code>, <code>\</code>). It uses 5 different RDKit drawing styles to ensure robustness.</li>
</ol>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Reinforcement Learning</strong>: Uses <strong>GRPO (Group Relative Policy Optimization)</strong>.
<ul>
<li><strong>Reward Function</strong>: A linear combination of Tanimoto similarity and a graded stereochemistry reward.
$$ R = w_t \cdot r_{\text{tanimoto}} + w_s \cdot r_{\text{stereo}} $$
where $w_t=0.4$ and $w_s=0.6$. The stereochemistry reward $r_{\text{stereo}}$ is 1.0 for an InChIKey exact match, 0.3 if the atom count matches, and 0.1 otherwise.</li>
<li><strong>Sampling</strong>: Samples 4 completions per image with temperature 1.0 during RL training.</li>
</ul>
</li>
<li><strong>Auxiliary Tasks</strong>:
<ul>
<li><strong>Bond Classification</strong>: Concatenates hidden states of two atom queries to predict bond type via MLP.</li>
<li><strong>Atom Localization</strong>: Treated as a classification task (SimCC) but optimized using <strong>Maximum Likelihood Estimation (MLE)</strong> to account for uncertainty.</li>
</ul>
</li>
</ul>
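<p>The reward above can be written out directly. The sketch below uses the stated weights and graded stereo reward; the group-relative advantage step follows the usual GRPO recipe (subtract the group mean, divide by the group standard deviation), which is an assumption here because the summary does not spell out the normalization. The function names and the four example completions are hypothetical.</p>

```python
from statistics import mean, pstdev

W_TANIMOTO, W_STEREO = 0.4, 0.6  # weights w_t and w_s from the text


def stereo_reward(inchikey_match: bool, atom_count_match: bool) -> float:
    # Graded reward: 1.0 for an InChIKey exact match, 0.3 if only the
    # atom count matches, and 0.1 otherwise.
    if inchikey_match:
        return 1.0
    return 0.3 if atom_count_match else 0.1


def total_reward(tanimoto: float, inchikey_match: bool, atom_count_match: bool) -> float:
    # R = w_t * r_tanimoto + w_s * r_stereo
    return W_TANIMOTO * tanimoto + W_STEREO * stereo_reward(inchikey_match, atom_count_match)


def group_relative_advantages(rewards, eps=1e-8):
    # Assumed GRPO-style normalization within one group of sampled completions.
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for one image (made-up similarity values).
group = [
    total_reward(1.0, True, True),
    total_reward(0.9, False, True),
    total_reward(0.5, False, False),
    total_reward(0.9, False, True),
]
advantages = group_relative_advantages(group)
print(group)
print(advantages)
```

<p>Because the stereo term carries the larger weight, a completion with high fingerprint similarity but wrong stereochemistry still scores well below an exact match.</p>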
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder Transformer. Input images are preprocessed to $512 \times 512$ resolution.
<ul>
<li><strong>Encoder</strong>: <strong>EfficientViT-L1</strong> (~53M params), chosen for linear attention efficiency.</li>
<li><strong>Decoder</strong>: 6-layer Transformer with <strong>RoPE</strong>, <strong>SwiGLU</strong>, and <strong>RMSNorm</strong>. Randomly initialized (no LLM weights) due to vocabulary mismatch.</li>
<li><strong>Coordinate Head</strong>: Separated from the main decoder. It adds 2 extra Transformer layers to process atom queries before prediction to improve accuracy.</li>
</ul>
</li>
<li><strong>Parameter Tuning</strong>:
<ul>
<li>Stage 3 (RL) uses <strong>LoRA</strong> (Rank=8, Alpha=16) to optimize the decoder.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Exact Match</strong>: Exact recognition accuracy for the full molecular structure.</li>
<li><strong>Tanimoto Coefficient</strong>: Fingerprint similarity for chemical semantics.</li>
<li><strong>OKS (Object Keypoint Similarity)</strong>: Used specifically for evaluating atom localization accuracy.</li>
</ul>
</li>
<li><strong>Perturbation</strong>: Robustness tested with random rotations [-5°, 5°] and xy-shearing [-0.1, 0.1].</li>
</ul>
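<p>For reference, the Tanimoto coefficient on binary fingerprints is the size of the intersection of the on-bits divided by the size of their union. The sketch below works on plain sets of bit indices; the fingerprint type used in the paper is not specified in this summary, and the convention of returning 1.0 for two empty fingerprints is an assumption.</p>

```python
def tanimoto(bits_a: set, bits_b: set) -> float:
    # Tanimoto (Jaccard) similarity between two sets of on-bit indices.
    if not bits_a and not bits_b:
        return 1.0  # assumed convention for two empty fingerprints
    shared = len(bits_a.intersection(bits_b))
    return shared / len(bits_a.union(bits_b))


predicted = {1, 2, 3, 4}   # hypothetical on-bits of the predicted molecule
reference = {3, 4, 5, 6}   # hypothetical on-bits of the ground truth
print(tanimoto(predicted, reference))
```

<p>Unlike exact match, this score gives partial credit when a prediction differs from the reference by only a few substructures.</p>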
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Training and inference performed on a single node.</li>
<li><strong>Processors</strong>: Intel Xeon Silver 4210R CPU.</li>
<li><strong>Accelerators</strong>: 4x <strong>NVIDIA GeForce RTX 3090/4090</strong> GPUs.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Stage 1: Batch size 512, LR $4 \times 10^{-4}$.</li>
<li>Stage 2: Batch size 256, Bond head LR $4 \times 10^{-4}$, Coord head LR $4 \times 10^{-5}$.</li>
<li>Stage 3 (RL): Batch size 64, Base LR $1 \times 10^{-4}$.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/hustvl/MolSight">MolSight (GitHub)</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch implementation with training and inference code</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, W., Wang, X., Feng, B., &amp; Liu, W. (2025). MolSight: Optical Chemical Structure Recognition with SMILES Pretraining, Multi-Granularity Learning and Reinforcement Learning. In <em>Proceedings of the AAAI Conference on Artificial Intelligence (AAAI 2026)</em>. <a href="https://doi.org/10.48550/arXiv.2511.17300">https://doi.org/10.48550/arXiv.2511.17300</a></p>
<p><strong>Publication</strong>: AAAI 2026</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/hustvl/MolSight">Official Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{zhang2025molsight,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolSight: Optical Chemical Structure Recognition with SMILES Pretraining, Multi-Granularity Learning and Reinforcement Learning}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wenrui Zhang and Xinggang Wang and Bin Feng and Wenyu Liu}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the AAAI Conference on Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2511.17300}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.CV}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2511.17300}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolScribe: Robust Image-to-Graph Molecular Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molscribe/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molscribe/</guid><description>Image-to-graph generation model for OCSR that predicts atoms, bonds, and coordinates jointly to better handle stereochemistry and abbreviations.</description><content:encoded><![CDATA[<h2 id="contribution-generative-image-to-graph-modelling">Contribution: Generative Image-to-Graph Modelling</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$) with a secondary contribution to Resources ($\Psi_{\text{Resource}}$).</p>
<p>It proposes a novel architecture (image-to-graph generation) to solve the Optical Chemical Structure Recognition (OCSR) task, validating it through extensive ablation studies and comparisons against strong baselines like MolVec and DECIMER. It also contributes a new benchmark dataset of annotated images from ACS journals.</p>
<h2 id="motivation-limitations-in-existing-ocsr-pipelines">Motivation: Limitations in Existing OCSR Pipelines</h2>
<p>Translating molecular images into machine-readable graphs (OCSR) is challenging due to the high variance in drawing styles, stereochemistry conventions, and abbreviated structures found in literature.</p>
<p>Existing solutions face structural bottlenecks:</p>
<ul>
<li><strong>Rule-based systems</strong> (e.g., OSRA) rely on rigid heuristics that fail on diverse styles.</li>
<li><strong>Image-to-SMILES neural models</strong> treat the problem as captioning. They struggle with the geometric reasoning that chirality strictly requires, and because they omit explicit atom locations, they cannot easily incorporate chemical constraints or verify correctness.</li>
</ul>
<h2 id="core-innovation-joint-graph-and-coordinate-prediction">Core Innovation: Joint Graph and Coordinate Prediction</h2>
<p>MolScribe introduces an <strong>Image-to-Graph</strong> generation paradigm that combines the flexibility of neural networks with the precision of symbolic constraints. It frames the task probabilistically as:</p>
<p>$$
P(G | I) = P(A | I) P(B | A, I)
$$</p>
<p>Where the model predicts a sequence of atoms $A$ given an image $I$, followed by the bonds $B$ given both the atoms and the image.</p>
<ol>
<li><strong>Explicit Graph Prediction</strong>: It predicts a sequence of atoms (with 2D coordinates) and then predicts bonds between them.</li>
<li><strong>Symbolic Constraints</strong>: It uses the predicted graph structure and coordinates to strictly determine chirality and cis/trans isomerism.</li>
<li><strong>Abbreviation Expansion</strong>: It employs a greedy algorithm to parse and expand &ldquo;superatoms&rdquo; (e.g., &ldquo;CO2Et&rdquo;) into their full atomic structure.</li>
<li><strong>Dynamic Augmentation</strong>: It introduces a data augmentation strategy that randomly substitutes functional groups with abbreviations and adds R-groups during training to improve generalization.</li>
</ol>
<h2 id="methodology-autoregressive-atoms-and-pairwise-bonds">Methodology: Autoregressive Atoms and Pairwise Bonds</h2>
<p>The authors evaluate MolScribe on synthetic and real-world datasets, focusing on <strong>Exact Match Accuracy</strong> of the canonical SMILES string. The model generates atom sequences autoregressively:</p>
<p>$$
P(A | I) = \prod_{i=1}^n P(a_i | A_{&lt;i}, I)
$$</p>
<p>To handle continuous spatial locations, atom coordinates are mapped to discrete bins (e.g., $\hat{x}_i = \lfloor \frac{x_i}{W} \times n_{\text{bins}} \rfloor$) and decoded alongside element labels. Bonds are predicted by a pairwise classifier over the hidden states of every atom pair:</p>
<p>$$
P(B | A, I) = \prod_{i=1}^n \prod_{j=1}^n P(b_{i,j} | A, I)
$$</p>
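<p>The coordinate binning formula can be illustrated with a short sketch. It uses the 64 bins and the 384-pixel input size reported in the reproducibility details; clamping the top edge into the last bin and decoding to the bin center are assumptions added for the example, and the function names are hypothetical.</p>

```python
import math


def bin_coordinate(value: float, extent: float, n_bins: int = 64) -> int:
    # x_hat = floor(x / W * n_bins), where W is the image width (or height for y).
    index = math.floor(value / extent * n_bins)
    # Assumption: clamp so that value == extent falls into the last bin.
    return max(0, min(n_bins - 1, index))


def unbin_coordinate(index: int, extent: float, n_bins: int = 64) -> float:
    # Assumption: map a bin index back to the center of its bin.
    return (index + 0.5) / n_bins * extent


WIDTH = 384  # input images are resized to 384 x 384
x_token = bin_coordinate(100.0, WIDTH)
print(x_token, unbin_coordinate(x_token, WIDTH))
```

<p>With 64 bins over 384 pixels, each bin spans 6 pixels, so the round trip loses at most 3 pixels per axis under the bin-center convention.</p>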
<ul>
<li><strong>Baselines</strong>: Compared against rule-based (MolVec, OSRA) and neural (Img2Mol, DECIMER, SwinOCSR) systems.</li>
<li><strong>Benchmarks</strong>:
<ul>
<li><strong>Synthetic</strong>: Indigo (in-domain) and ChemDraw (out-of-domain).</li>
<li><strong>Realistic</strong>: Five public benchmarks (CLEF, JPO, UOB, USPTO, Staker).</li>
<li><strong>New Dataset</strong>: 331 images from ACS Publications (journal articles).</li>
</ul>
</li>
<li><strong>Ablations</strong>: Tested performance without data augmentation, with continuous vs. discrete coordinates, and without non-atom tokens.</li>
<li><strong>Human Eval</strong>: Measured the time reduction for chemists using MolScribe to digitize molecules vs. drawing from scratch.</li>
</ul>
<h2 id="results-robust-exact-match-accuracy">Results: Robust Exact Match Accuracy</h2>
<ul>
<li><strong>Strong Performance</strong>: MolScribe achieved <strong>76-93% accuracy</strong> across public benchmarks, outperforming baselines on most datasets. On the ACS dataset of journal article images, MolScribe achieved 71.9% compared to the next best 55.3% (OSRA). On the large Staker patent dataset, MolScribe achieved 86.9%, surpassing MSE-DUDL (77.0%) while using far less training data (1.68M vs. 68M examples).</li>
<li><strong>Chirality Verification</strong>: Explicit geometric reasoning allowed MolScribe to predict chiral molecules significantly better than image-to-SMILES baselines. When chirality is ignored, the performance gap narrows (e.g., on Indigo, baseline accuracy rises from 94.1% to 96.3%), isolating MolScribe&rsquo;s primary advantage to geometric reasoning for stereochemistry.</li>
<li><strong>Hand-Drawn Generalization</strong>: The model achieved <strong>11.2% exact match accuracy</strong> on the DECIMER-HDM dataset, despite lacking hand-drawn images in the training set, with many errors limited to a few atomic mismatches.</li>
<li><strong>Robustness</strong>: The model maintained high performance on perturbed images (rotation/shear), whereas rule-based systems degraded severely.</li>
<li><strong>Usability</strong>: The atom-level alignment allows for confidence visualization, and human evaluation showed it reduced digitization time from <strong>137s to 20s</strong> per molecule.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a mix of synthetic and patent data with extensive dynamic augmentation:</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td><strong>PubChem (Synthetic)</strong></td>
          <td>1M</td>
          <td>Molecules randomly sampled from PubChem and rendered via Indigo toolkit; includes atom coords.</td>
      </tr>
      <tr>
          <td>Training</td>
          <td><strong>USPTO (Patents)</strong></td>
          <td>680K</td>
          <td>Patent data lacks exact atom coordinates; relative coordinates normalized from MOLfiles to image dimensions (often introduces coordinate shifts).</td>
      </tr>
  </tbody>
</table>
<p><strong>Molecule Augmentation</strong>:</p>
<ul>
<li><strong>Functional Groups</strong>: Randomly substituted using 53 common substitution rules (e.g., replacing substructures with &ldquo;Et&rdquo; or &ldquo;Ph&rdquo;).</li>
<li><strong>R-Groups</strong>: Randomly added using vocabulary: <code>[R, R1...R12, Ra, Rb, Rc, Rd, X, Y, Z, A, Ar]</code>.</li>
<li><strong>Styles</strong>: Random variation of aromaticity (circle vs. bonds) and explicit hydrogens.</li>
</ul>
<p><strong>Image Augmentation</strong>:</p>
<ul>
<li><strong>Rendering</strong>: Randomized font (Arial, Times, Courier, Helvetica), line width, and label modes during synthetic generation.</li>
<li><strong>Perturbations</strong>: Applied rotation ($\pm 90^{\circ}$), cropping ($1\%$), padding ($40\%$), downscaling, blurring, and Salt-and-Pepper/Gaussian noise.</li>
</ul>
<p><strong>Preprocessing</strong>: Input images are resized to $384 \times 384$.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Atom Prediction (Pix2Seq-style)</strong>:
<ul>
<li>The model generates a sequence of tokens: $S^A = [l_1, \hat{x}_1, \hat{y}_1, \dots, l_n, \hat{x}_n, \hat{y}_n]$.</li>
<li><strong>Discretization</strong>: Coordinates are binned into integer tokens ($n_{\text{bins}} = 64$).</li>
<li><strong>Tokenizer</strong>: Atom-wise tokenizer splits SMILES into atoms; non-atom tokens (parentheses, digits) are kept to help structure learning.</li>
</ul>
</li>
<li><strong>Bond Prediction</strong>:
<ul>
<li>Format: Pairwise classification for every pair of predicted atoms.</li>
<li>Symmetry: For symmetric bonds (single/double), the probability is averaged as:
$$
\hat{P}(b_{i,j} = t) = \frac{1}{2} \big( P(b_{i,j} = t) + P(b_{j,i} = t) \big)
$$
For wedges, directional logic strictly applies instead.</li>
</ul>
</li>
<li><strong>Abbreviation Expansion (Algorithm 1)</strong>:
<ul>
<li>A greedy algorithm connects atoms within an expanded abbreviation (e.g., &ldquo;COOH&rdquo;) until valences are full, avoiding the need for a fixed dictionary.</li>
<li><strong>Carbon Chains</strong>: Splits condensed chains like $C_aX_b$ into explicit sequences ($CX_q \dots CX_{q+r}$).</li>
<li><strong>Nested Formulas</strong>: Recursively parses nested structures like $N(CH_3)_2$ by treating them as superatoms attached to the current backbone.</li>
<li><strong>Valence Handling</strong>: Iterates through common valences first to resolve ambiguities.</li>
</ul>
</li>
</ul>
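<p>The symmetric averaging of bond probabilities can be sketched as follows. The nested-list layout, the three bond types, and the probability values are hypothetical; in the model these probabilities would come from the bond classifier, and directional bond types such as wedges would skip this averaging.</p>

```python
def symmetrize(probs):
    # probs[i][j] maps bond type -> probability for the ordered atom pair (i, j).
    # Returns P_hat(b_ij = t) = 0.5 * (P(b_ij = t) + P(b_ji = t)).
    n = len(probs)
    return [
        [
            {t: 0.5 * (probs[i][j][t] + probs[j][i][t]) for t in probs[i][j]}
            for j in range(n)
        ]
        for i in range(n)
    ]


no_bond = {"none": 1.0, "single": 0.0, "double": 0.0}
probs = [
    [no_bond, {"none": 0.1, "single": 0.7, "double": 0.2}],
    [{"none": 0.2, "single": 0.5, "double": 0.3}, no_bond],
]
averaged = symmetrize(probs)
print(averaged[0][1])
```

<p>After averaging, both orderings of an atom pair agree, so taking the most probable type cannot produce contradictory predictions for the same bond.</p>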
<h3 id="models">Models</h3>
<p>The architecture is an encoder-decoder with a classification head:</p>
<ul>
<li><strong>Encoder</strong>: <strong>Swin Transformer (Swin-B)</strong>, pre-trained on ImageNet-22K (88M params).</li>
<li><strong>Decoder</strong>: 6-layer Transformer, 8 heads, hidden dimension 256.</li>
<li><strong>Bond Predictor</strong>: 2-layer MLP (Feedforward) with ReLU, taking concatenated atom hidden states as input.</li>
<li><strong>Training</strong>: Teacher forcing, Cross-Entropy Loss, Batch size 128, 30 epochs.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric</strong>: Exact Match of Canonical SMILES.</p>
<ul>
<li>Stereochemistry: Must match tetrahedral chirality; cis-trans ignored.</li>
<li>R-groups: Replaced with wildcards <code>*</code> or <code>[d*]</code> for evaluation.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Training performed on a Linux server with <strong>96 CPUs</strong> and <strong>500GB RAM</strong>.</li>
<li><strong>GPUs</strong>: <strong>4x NVIDIA A100 GPUs</strong>.</li>
<li><strong>Training Time</strong>: Unspecified; comparative models on large datasets took &ldquo;more than one day&rdquo;.</li>
<li><strong>Inference</strong>: Requires autoregressive decoding for atoms, followed by a single forward pass for bonds.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/thomas0809/MolScribe">MolScribe (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with training, inference, and evaluation scripts</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/spaces/yujieq/MolScribe">MolScribe (Hugging Face)</a></td>
          <td>Demo</td>
          <td>MIT</td>
          <td>Interactive web demo for molecular image recognition</td>
      </tr>
  </tbody>
</table>
<h3 id="limitations">Limitations</h3>
<ul>
<li>Scoped to single-molecule images only; does not handle multi-molecule diagrams or reaction schemes.</li>
<li>Hand-drawn molecule recognition remains weak (the model was not trained on hand-drawn data).</li>
<li>Complex Markush structures (positional variation, frequency variation) are not supported, as these cannot be represented in SMILES or MOLfiles.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Qian, Y., Guo, J., Tu, Z., Li, Z., Coley, C. W., &amp; Barzilay, R. (2023). MolScribe: Robust Molecular Structure Recognition with Image-To-Graph Generation. <em>Journal of Chemical Information and Modeling</em>, 63(7), 1925-1934. <a href="https://doi.org/10.1021/acs.jcim.2c01480">https://doi.org/10.1021/acs.jcim.2c01480</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://huggingface.co/spaces/yujieq/MolScribe">Hugging Face Space</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{qianMolScribeRobustMolecular2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MolScribe}}: {{Robust Molecular Structure Recognition}} with {{Image-To-Graph Generation}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MolScribe}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Qian, Yujie and Guo, Jiang and Tu, Zhengkai and Li, Zhening and Coley, Connor W. and Barzilay, Regina}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{63}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1925--1934}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/acs.jcim.2c01480}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://pubs.acs.org/doi/10.1021/acs.jcim.2c01480}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolMole: Unified Vision Pipeline for Molecule Mining</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molmole/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molmole/</guid><description>A vision-based deep learning framework that unifies molecule detection, reaction parsing, and OCSR for page-level chemical data extraction.</description><content:encoded><![CDATA[<h2 id="molmoles-dual-contribution-unified-ocsr-method-and-page-level-benchmarks">MolMole&rsquo;s Dual Contribution: Unified OCSR Method and Page-Level Benchmarks</h2>
<p>This is primarily a <strong>Method</strong> paper, with a strong <strong>Resource</strong> contribution.</p>
<p>It functions as a <strong>Method</strong> paper because it introduces &ldquo;MolMole,&rdquo; a unified deep learning framework that integrates molecule detection, reaction diagram parsing, and optical chemical structure recognition (OCSR) into a single pipeline. It validates this method through extensive comparisons against state-of-the-art baselines like DECIMER and OpenChemIE.</p>
<p>It also serves as a <strong>Resource</strong> paper because the authors construct and release a novel page-level benchmark dataset of 550 annotated pages (patents and articles) to address the lack of standardized evaluation metrics for full-page chemical extraction.</p>
<h2 id="addressing-the-limitations-of-fragmented-processing">Addressing the Limitations of Fragmented Processing</h2>
<p>The rapid accumulation of chemical literature has trapped valuable molecular and reaction data in unstructured formats like images and PDFs. Extracting this manually is time-consuming, while existing AI frameworks have significant limitations:</p>
<ul>
<li><strong>DECIMER</strong>: Lacks the ability to process reaction diagrams entirely.</li>
<li><strong>OpenChemIE</strong>: Relies on external layout parser models to crop elements before processing. This dependence often leads to detection failures in documents with complex layouts.</li>
<li><strong>Generative Hallucination</strong>: Existing generative OCSR models (like MolScribe) are prone to &ldquo;hallucinating&rdquo; structures or failing on complex notations like polymers.</li>
</ul>
<h2 id="a-unified-vision-pipeline-for-layout-aware-detection">A Unified Vision Pipeline for Layout-Aware Detection</h2>
<p>MolMole introduces several architectural and workflow innovations:</p>
<ul>
<li><strong>Direct Page-Level Processing</strong>: Unlike OpenChemIE, MolMole processes full document pages directly without requiring an external layout parser, which improves robustness on complex layouts like two-column patents.</li>
<li><strong>Unified Vision Pipeline</strong>: It integrates three specialized vision models into one workflow:
<ul>
<li><strong>ViDetect</strong>: A DINO-based object detector for identifying molecular regions.</li>
<li><strong>ViReact</strong>: An RxnScribe-based model adapted for full-page reaction parsing.</li>
<li><strong>ViMore</strong>: A detection-based OCSR model that explicitly predicts atoms and bonds.</li>
</ul>
</li>
<li><strong>Hallucination Mitigation</strong>: By using a detection-based approach (ViMore), the model avoids hallucinating chemical structures and provides confidence scores.</li>
<li><strong>Advanced Notation Support</strong>: The system explicitly handles &ldquo;wavy bonds&rdquo; (variable attachments in patents) and polymer bracket notations, which confuse standard SMILES-based models.</li>
</ul>
<h2 id="page-level-benchmark-evaluation-and-unified-metrics">Page-Level Benchmark Evaluation and Unified Metrics</h2>
<p>The authors evaluated the framework on both a newly curated benchmark and existing public datasets:</p>
<ul>
<li><strong>New Benchmark Creation</strong>: They curated 550 pages (300 patents, 250 articles) fully annotated with bounding boxes, reaction roles (reactant, product, condition), and MOLfiles.</li>
<li><strong>Baselines</strong>: MolMole was compared against <strong>DECIMER 2.0</strong>, <strong>OpenChemIE</strong>, and <strong>ReactionDataExtractor 2.0</strong>.</li>
<li><strong>OCSR Benchmarking</strong>: ViMore was evaluated against DECIMER, MolScribe, and MolGrapher on four public datasets: <strong>USPTO</strong>, <strong>UOB</strong>, <strong>CLEF</strong>, and <strong>JPO</strong>.</li>
<li><strong>Metric Proposal</strong>: They introduced a combined &ldquo;End-to-End&rdquo; metric that modifies standard object detection Precision/Recall to strictly require correct SMILES conversion for a &ldquo;True Positive&rdquo;.</li>
</ul>
<p>$$ \text{True Positive (End-to-End)} = ( \text{IoU} \geq 0.5 ) \land ( \text{SMILES}_{\text{gt}} == \text{SMILES}_{\text{pred}} ) $$</p>
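<p>A small sketch of how this end-to-end criterion could be scored. It assumes boxes are <code>(x1, y1, x2, y2)</code> tuples, that SMILES strings were canonicalized beforehand, and that predictions are matched greedily one-to-one with ground truth. The function names and the matching strategy are illustrative assumptions, not the authors' evaluation code.</p>

```python
def iou(box_a, box_b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def end_to_end_scores(preds, gts, iou_thresh=0.5):
    """Return (precision, recall, f1) for lists of (box, smiles) pairs.

    A prediction is a true positive only if it overlaps an unmatched
    ground-truth box with IoU >= iou_thresh AND the SMILES strings match.
    """
    matched = set()
    tp = 0
    for p_box, p_smiles in preds:
        for k, (g_box, g_smiles) in enumerate(gts):
            if k in matched:
                continue
            if iou(p_box, g_box) >= iou_thresh and p_smiles == g_smiles:
                matched.add(k)
                tp += 1
                break
    precision = tp / len(preds) if preds else 0.0
    recall = tp / len(gts) if gts else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


gts = [((0, 0, 10, 10), "CCO"), ((20, 20, 30, 30), "c1ccccc1")]
preds = [((1, 1, 10, 10), "CCO"), ((20, 20, 30, 30), "C1CCCCC1")]
scores = end_to_end_scores(preds, gts)
```

<p>In this toy case the second prediction is perfectly localized but has the wrong structure, so it counts as both a false positive and a missed ground truth.</p>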
<h2 id="key-results">Key Results</h2>
<ul>
<li><strong>Page-Level Performance</strong>: On the new benchmark, MolMole achieved F1 scores of <strong>89.1%</strong> (Patents) and <strong>86.8%</strong> (Articles) for the combined detection-to-conversion task, compared to 73.8% and 67.3% for DECIMER and 68.8% and 70.6% for OpenChemIE (Table 4).</li>
<li><strong>Reaction Parsing</strong>: ViReact achieved soft-match F1 scores of <strong>98.0%</strong> on patents and <strong>97.0%</strong> on articles, compared to 82.2% and 82.9% for the next best model, RxnScribe (w/o LP). Hard-match F1 scores were 92.5% (patents) and 84.6% (articles).</li>
<li><strong>Public Benchmarks</strong>: ViMore outperformed competitors on 3 out of 4 public OCSR datasets (CLEF, JPO, USPTO).</li>
<li><strong>Layout Handling</strong>: The authors demonstrated that MolMole successfully handles multi-column reaction diagrams where cropping-based models fail and faithfully preserves layout geometry in generated MOLfiles.</li>
</ul>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://lgai-ddu.github.io/molmole/">MolMole Project Page</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Demo and project information</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training Data</strong>: The models (ViDetect and ViMore) were trained on <strong>private/proprietary datasets</strong>, which is a limitation for full reproducibility from scratch.</li>
<li><strong>Benchmark Data</strong>: The authors introduce a test set of <strong>550 pages</strong> (3,897 molecules, 1,022 reactions) derived from patents and scientific articles. The paper states that this dataset will be made &ldquo;publicly available&rdquo;.</li>
<li><strong>Public Evaluation Data</strong>: Standard OCSR datasets used include USPTO (5,719 images), UOB (5,740 images), CLEF (992 images), and JPO (450 images).</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Pipeline Workflow</strong>: PDF → PNG Images → Parallel execution of <strong>ViDetect</strong> and <strong>ViReact</strong> → Cropping of molecular regions → <strong>ViMore</strong> conversion → Output (JSON/Excel).</li>
<li><strong>Post-Processing</strong>:
<ul>
<li><em>ViDetect</em>: Removes overlapping proposals based on confidence scores and size constraints.</li>
<li><em>ViReact</em>: Refines predictions by correcting duplicates and removing empty entities.</li>
<li><em>ViMore</em>: Assembles detected atom/bond information into structured representations (MOLfile).</li>
</ul>
</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Architecture Basis</th>
          <th>Task</th>
          <th>Key Feature</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ViDetect</strong></td>
          <td>DINO (DETR-based)</td>
          <td>Molecule Detection</td>
          <td>End-to-end training; avoids slow autoregressive methods.</td>
      </tr>
      <tr>
          <td><strong>ViReact</strong></td>
          <td>RxnScribe</td>
          <td>Reaction Parsing</td>
          <td>Operates on full pages; autoregressive decoder for structured sequence generation.</td>
      </tr>
      <tr>
          <td><strong>ViMore</strong></td>
          <td>Custom Vision Model</td>
          <td>OCSR</td>
          <td>Detection-based (predicts atom/bond regions).</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Molecule Detection</strong>: Evaluated using COCO metrics (AP, AR, F1) at IoU thresholds 0.50-0.95.</li>
<li><strong>Molecule Conversion</strong>: Evaluated using SMILES exact match accuracy and Tanimoto similarity.</li>
<li><strong>Combined Metric</strong>: A custom metric where a True Positive requires both IoU $\geq$ 0.5 and a correct SMILES string match where $\text{SMILES}_{\text{gt}} == \text{SMILES}_{\text{pred}}$.</li>
<li><strong>Reaction Parsing</strong>: Evaluated using <strong>Hard Match</strong> (all components correct) and <strong>Soft Match</strong> (molecular entities only, ignoring text labels).</li>
</ul>
<h3 id="missing-components">Missing Components</h3>
<ul>
<li><strong>Source code</strong>: Not publicly released. The paper states the toolkit &ldquo;will be accessible soon through an interactive demo on the LG AI Research website.&rdquo; For commercial use, the authors direct inquiries to <a href="mailto:ddu@lgresearch.ai">ddu@lgresearch.ai</a>.</li>
<li><strong>Training data</strong>: ViDetect and ViMore are trained on proprietary datasets. Training code and data are not available.</li>
<li><strong>Hardware requirements</strong>: Not specified in the paper.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chun, S., Kim, J., Jo, A., Jo, Y., Oh, S., et al. (2025). MolMole: Molecule Mining from Scientific Literature. <em>arXiv preprint arXiv:2505.03777</em>. <a href="https://doi.org/10.48550/arXiv.2505.03777">https://doi.org/10.48550/arXiv.2505.03777</a></p>
<p><strong>Publication</strong>: arXiv 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://lgai-ddu.github.io/molmole/">Project Page</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chun2025molmole,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{MolMole: Molecule Mining from Scientific Literature}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Chun, Sehyun and Kim, Jiye and Jo, Ahra and Jo, Yeonsik and Oh, Seungyul and Lee, Seungjun and Ryoo, Kwangrok and Lee, Jongmin and Kim, Seung Hwan and Kang, Byung Jun and Lee, Soonyoung and Park, Jun Ha and Moon, Chanwoo and Ham, Jiwon and Lee, Haein and Han, Heejae and Byun, Jaeseung and Do, Soojong and Ha, Minju and Kim, Dongyun and Bae, Kyunghoon and Lim, Woohyung and Lee, Edward Hwayoung and Park, Yongmin and Yu, Jeongsang and Jo, Gerrard Jeongwon and Hong, Yeonjung and Yoo, Kyungjae and Han, Sehui and Lee, Jaewan and Park, Changyoung and Jeon, Kijeong and Yi, Sihyuk}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2505.03777}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2505.03777}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2505.03777}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolGrapher: Graph-based Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molgrapher/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/molgrapher/</guid><description>A graph-based deep learning approach for optical chemical structure recognition that outperforms image captioning methods.</description><content:encoded><![CDATA[<h2 id="1-contribution--type">1. Contribution / Type</h2>
<p>This is primarily a <strong>Methodological</strong> paper that proposes a novel neural architecture (MolGrapher), shifting the paradigm of Optical Chemical Structure Recognition (OCSR) from image captioning back to graph reconstruction. It also has a significant <strong>Resource</strong> component, releasing a synthetic data generation pipeline and a new large-scale benchmark (USPTO-30K) to address the scarcity of annotated real-world data.</p>
<h2 id="2-motivation">2. Motivation</h2>
<p>The automatic analysis of chemical literature is critical for accelerating drug and material discovery, but much of this information is locked in 2D images of molecular structures.</p>
<ul>
<li><strong>Problem:</strong> Existing rule-based methods are rigid, while recent deep learning methods based on &ldquo;image captioning&rdquo; (predicting <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings) struggle with complex molecules and fail to exploit the natural graph structure of molecules.</li>
<li><strong>Gap:</strong> There is a lack of diverse, annotated real-world training data, and captioning models suffer from &ldquo;hallucinations&rdquo; where they predict valid SMILES that do not match the image.</li>
</ul>
<h2 id="3-novelty--core-innovation">3. Novelty / Core Innovation</h2>
<p>MolGrapher introduces a <strong>graph-based deep learning pipeline</strong> that explicitly models the molecule&rsquo;s geometry and topology.</p>
<ul>
<li><strong>Supergraph Concept:</strong> It first detects all atom keypoints and builds a &ldquo;supergraph&rdquo; of all plausible bonds.</li>
<li><strong>Hybrid Approach:</strong> It combines a ResNet-based keypoint detector with a Graph Neural Network (GNN) that classifies both atom nodes and bond nodes within the supergraph context. Both atoms and bonds are represented as nodes, with edges only connecting atom nodes to bond nodes.</li>
<li><strong>Synthetic Pipeline:</strong> A data generation pipeline that renders molecules with varying styles (fonts, bond widths) and augmentations (pepper patches, random lines, captions) to simulate real document noise.</li>
</ul>
<p>At the core of the Keypoint Detector&rsquo;s performance is the <strong>Weight-Adaptive Heatmap Regression (WAHR)</strong> loss. Since pixels without an atom drastically outnumber pixels containing an atom, WAHR loss is designed to counter the class imbalance. For ground-truth heatmap $y$ and prediction $p$:</p>
<p>$$ L_{WAHR}(p, y) = \sum_i \alpha_y (p_i - y_i)^2 $$</p>
<p>where $\alpha_y$ dynamically down-weights easily classified background pixels.</p>
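<p>The note above does not spell out the form of $\alpha_y$. The sketch below assumes a weighting commonly used in weight-adaptive heatmap regression, $\alpha_i = y_i^{\gamma} |1 - p_i| + (1 - y_i^{\gamma}) |p_i|$ with a small $\gamma$. Both this form and the default <code>gamma=0.01</code> are assumptions for illustration and may differ from MolGrapher's implementation. Heatmaps are flattened to plain lists.</p>

```python
def wahr_loss(pred, target, gamma=0.01):
    """Weighted squared error over flattened heatmaps.

    Assumed per-pixel weight: alpha = y**gamma * |1 - p| + (1 - y**gamma) * |p|.
    Background pixels (y = 0) that are already predicted near 0 get a tiny
    weight, so they contribute little to the total.
    """
    total = 0.0
    for p, y in zip(pred, target):
        y_g = y ** gamma
        alpha = y_g * abs(1.0 - p) + (1.0 - y_g) * abs(p)
        total += alpha * (p - y) ** 2
    return total


# One confidently correct background pixel vs. one badly wrong background pixel.
easy_background = wahr_loss([0.01], [0.0])
hard_background = wahr_loss([0.5], [0.0])
```

<p>Under this assumed weighting, the easy background pixel is scaled down by its own small prediction, while the hard pixel keeps a much larger share of its squared error.</p>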
<h2 id="4-methodology--experiments">4. Methodology &amp; Experiments</h2>
<p>The authors evaluated MolGrapher against both rule-based (OSRA, MolVec) and deep learning baselines (DECIMER, Img2Mol, Image2Graph).</p>
<ul>
<li><strong>Benchmarks:</strong> Evaluated on standard datasets: USPTO, Maybridge UoB, CLEF-2012, and JPO.</li>
<li><strong>New Benchmark:</strong> Introduced and tested on <strong>USPTO-30K</strong>, split into clean, abbreviated, and large molecule subsets.</li>
<li><strong>Ablations:</strong> Analyzed the impact of synthetic augmentations, keypoint loss functions, supergraph connectivity radius, and GNN layers.</li>
<li><strong>Robustness:</strong> Tested on perturbed images (rotations, shearing) to mimic scanned patent quality.</li>
</ul>
<p>The GNN iteratively updates node embeddings through layers $\lbrace g^k \rbrace_{k \in [1, N]}$, where $e^{k+1} = g^k(e^k)$. Final predictions are obtained via two MLPs (one for atoms, one for bonds): $p_i = \text{MLP}_t(e_i^N)$, where $p_i \in \mathbb{R}^{C_t}$ contains the logits for atom or bond classes.</p>
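<p>To make the update-then-readout flow concrete, here is a toy sketch on a three-atom supergraph. The layer form (mean aggregation over neighbors, followed by a ReLU), the random weights, the linear readout heads standing in for the two MLPs, and the class counts are all assumptions chosen for brevity. The note does not specify MolGrapher's actual layer.</p>

```python
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def gnn_layer(emb, neighbors, W_self, W_nbr):
    """One update e^{k+1} = g^k(e^k) using mean aggregation and ReLU."""
    dim = len(emb[0])
    new_emb = []
    for i, e in enumerate(emb):
        nbrs = neighbors[i]
        if nbrs:
            mean = [sum(emb[j][d] for j in nbrs) / len(nbrs) for d in range(dim)]
        else:
            mean = [0.0] * dim
        h = [a + b for a, b in zip(matvec(W_self, e), matvec(W_nbr, mean))]
        new_emb.append([max(0.0, v) for v in h])
    return new_emb


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
dim, num_layers = 4, 2
# Nodes 0-2 are atoms; nodes 3-4 are bond candidates (0-1 and 1-2).
node_type = ["atom", "atom", "atom", "bond", "bond"]
# Edges only connect atom nodes to bond nodes.
neighbors = {0: [3], 1: [3, 4], 2: [4], 3: [0, 1], 4: [1, 2]}
emb = [[rng.uniform(0.0, 1.0) for _ in range(dim)] for _ in node_type]

for _ in range(num_layers):
    emb = gnn_layer(emb, neighbors,
                    random_matrix(dim, dim, rng), random_matrix(dim, dim, rng))

# One readout head per node type; hypothetical class counts.
num_classes = {"atom": 3, "bond": 4}
heads = {t: random_matrix(c, dim, rng) for t, c in num_classes.items()}
logits = [matvec(heads[t], e) for t, e in zip(node_type, emb)]
```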
<h2 id="5-results--conclusions">5. Results &amp; Conclusions</h2>
<p>MolGrapher achieved the highest accuracy among synthetic-only deep learning methods on most benchmarks tested.</p>
<ul>
<li><strong>Accuracy:</strong> It achieved <strong>91.5%</strong> accuracy on USPTO, outperforming all other synthetic-only deep learning methods including ChemGrapher (80.9%), Graph Generation (67.0%), and DECIMER 2.0 (61.0%).</li>
<li><strong>Large Molecules:</strong> It demonstrated superior scaling, correctly recognizing large molecules (USPTO-10K-L) where image captioning methods like Img2Mol failed completely (0.0% accuracy).</li>
<li><strong>Generalization:</strong> The method proved robust to image perturbations and style variations without requiring fine-tuning on real data. The paper acknowledges that MolGrapher cannot recognize Markush structures (depictions of sets of molecules with positional and frequency variation indicators).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model relies on synthetic data for training due to the scarcity of annotated real-world images.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Synthetic Data</td>
          <td>300,000 images</td>
          <td>Generated from PubChem SMILES using RDKit. Augmentations include pepper patches, random lines, and variable bond styles.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>USPTO-30K</td>
          <td>30,000 images</td>
          <td>Created by authors from USPTO patents (2001-2020). Subsets: 10K clean, 10K abbreviated, 10K large (&gt;70 atoms).</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Standard Benchmarks</td>
          <td>Various</td>
          <td>USPTO (5,719), Maybridge UoB (5,740), CLEF-2012 (992), JPO (450).</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The pipeline consists of three distinct algorithmic stages:</p>
<ol>
<li>
<p><strong>Keypoint Detection</strong>:</p>
<ul>
<li>Predicts a heatmap of atom locations using a CNN.</li>
<li>Thresholds the heatmap at its bottom 10th percentile and extracts local maxima with a $5\times5$ window.</li>
<li>Uses <strong>Weight-Adaptive Heatmap Regression (WAHR)</strong> loss to handle class imbalance (background vs. atoms).</li>
</ul>
</li>
<li>
<p><strong>Supergraph Construction</strong>:</p>
<ul>
<li>Connects every detected keypoint to neighbors within a radius of $3 \times$ the estimated bond length.</li>
<li>Prunes edges that cover no filled pixels or are obstructed by a third keypoint.</li>
<li>Keeps a maximum of 6 bond candidates per atom.</li>
</ul>
</li>
<li>
<p><strong>Superatom Recognition</strong>:</p>
<ul>
<li>Detects &ldquo;superatom&rdquo; nodes (abbreviations like <code>COOH</code>).</li>
<li>Uses <strong>PP-OCR</strong> to transcribe the text at these node locations.</li>
</ul>
</li>
</ol>
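<p>The first two stages above can be sketched as follows. This is a simplified illustration under stated assumptions: the threshold is passed in as a number (the percentile computation is omitted), the pixel-based and obstruction-based pruning steps are skipped, and the six-candidate cap is enforced by adding the shortest edges first. The function names are hypothetical.</p>

```python
import math


def local_maxima(heatmap, threshold, window=5):
    """Return (row, col) peaks that exceed threshold and are maximal in a window x window patch."""
    h, w = len(heatmap), len(heatmap[0])
    r = window // 2
    peaks = []
    for y in range(h):
        for x in range(w):
            v = heatmap[y][x]
            if v <= threshold:
                continue
            patch = [heatmap[yy][xx]
                     for yy in range(max(0, y - r), min(h, y + r + 1))
                     for xx in range(max(0, x - r), min(w, x + r + 1))]
            if v >= max(patch):
                peaks.append((y, x))
    return peaks


def build_supergraph(keypoints, bond_length, radius_factor=3.0, max_candidates=6):
    """Link keypoints closer than radius_factor * bond_length, at most max_candidates edges each."""
    radius = radius_factor * bond_length
    n = len(keypoints)
    pairs = []
    for i in range(n):
        for j in range(i + 1, n):
            d = math.hypot(keypoints[i][0] - keypoints[j][0],
                           keypoints[i][1] - keypoints[j][1])
            if d <= radius:
                pairs.append((d, i, j))
    degree = [0] * n
    edges = []
    for _, i, j in sorted(pairs):  # shortest candidate bonds first
        if degree[i] < max_candidates and degree[j] < max_candidates:
            edges.append((i, j))
            degree[i] += 1
            degree[j] += 1
    return edges


heatmap = [[0.0] * 7 for _ in range(7)]
heatmap[1][1] = 0.9
heatmap[5][5] = 0.8
keypoints = local_maxima(heatmap, threshold=0.5)
edges = build_supergraph(keypoints, bond_length=2.0)
```

<p>Each surviving edge would become a bond-candidate node that the downstream classifier labels as a real bond type or as &ldquo;No Bond&rdquo;.</p>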
<h3 id="models">Models</h3>
<p>The architecture utilizes standard backbones tailored for specific sub-tasks:</p>
<ul>
<li><strong>Keypoint Detector</strong>: <strong>ResNet-18</strong> backbone with $8\times$ dilation to preserve spatial resolution.</li>
<li><strong>Node Classifier</strong>: <strong>ResNet-50</strong> backbone with $2\times$ dilation for extracting visual features at node locations.</li>
<li><strong>Graph Neural Network</strong>: A custom GNN that updates node embeddings based on visual features and neighborhood context. The initial node embedding combines the visual feature vector $v_i$ and a learnable type encoding $w_{t_i}$.</li>
<li><strong>Readout</strong>: MLPs classify nodes into atom types (e.g., C, O, N) and bond types (No Bond, Single, Double, Triple).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Accuracy is defined strictly: the predicted molecule must have an identical <strong><a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a></strong> string to the ground truth. Stereochemistry and Markush structures are excluded from evaluation.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>MolGrapher Score</th>
          <th>Best DL Baseline (Synthetic)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>USPTO</td>
          <td><strong>91.5%</strong></td>
          <td>80.9% (ChemGrapher)</td>
          <td>Full USPTO benchmark</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>USPTO-10K-L</td>
          <td><strong>31.4%</strong></td>
          <td>0.0% (Img2Mol)</td>
          <td>Large molecules (&gt;70 atoms)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td>JPO</td>
          <td><strong>67.5%</strong></td>
          <td>64.0% (DECIMER 2.0)</td>
          <td>Challenging, low-quality images</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: Trained on 3 NVIDIA A100 GPUs.</li>
<li><strong>Training Schedule</strong>: 20 epochs.</li>
<li><strong>Optimization</strong>: Adam optimizer, learning rate 0.0001, decayed by a factor of 0.8 after 5,000 iterations.</li>
<li><strong>Loss Weighting</strong>: Atom classifier loss weighted by 1; bond classifier loss weighted by 3.</li>
</ul>
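<p>The optimization settings above amount to simple arithmetic, sketched here. One assumption is made explicit: the decay is treated as a step schedule that multiplies the rate by 0.8 once every 5,000 iterations. The note only says the rate is decayed &ldquo;after 5000 iterations&rdquo;, so a single decay is an equally valid reading. The helper names are hypothetical.</p>

```python
def learning_rate(iteration, base_lr=1e-4, decay=0.8, step=5000):
    """Step decay: multiply by `decay` once per `step` iterations (assumed schedule)."""
    return base_lr * decay ** (iteration // step)


def total_loss(atom_loss, bond_loss, atom_weight=1.0, bond_weight=3.0):
    """Weighted sum of the atom and bond classifier losses (weights 1 and 3)."""
    return atom_weight * atom_loss + bond_weight * bond_loss


lr_start = learning_rate(0)
lr_after_first_decay = learning_rate(5000)
combined = total_loss(atom_loss=0.5, bond_loss=0.25)
```

<p>The heavier bond weight means that a bond misclassification costs three times as much as an atom misclassification of equal cross-entropy.</p>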
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DS4SD/MolGrapher">DS4SD/MolGrapher</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation with training and inference scripts</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Title</strong>: MolGrapher: Graph-based Visual Recognition of Chemical Structures</p>
<p><strong>Authors</strong>: Lucas Morin, Martin Danelljan, Maria Isabel Agea, Ahmed Nassar, Valéry Weber, Ingmar Meijer, Peter Staar, Fisher Yu</p>
<p><strong>Citation</strong>: Morin, L., Danelljan, M., Agea, M. I., Nassar, A., Weber, V., Meijer, I., Staar, P., &amp; Yu, F. (2023). MolGrapher: Graph-based Visual Recognition of Chemical Structures. <em>Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)</em>, 19552-19561.</p>
<p><strong>Publication</strong>: ICCV 2023</p>
<p><strong>Links</strong>:</p>
<ul>
<li><a href="https://openaccess.thecvf.com/content/ICCV2023/html/Morin_MolGrapher_Graph-based_Visual_Recognition_of_Chemical_Structures_ICCV_2023_paper.html">Paper</a></li>
<li><a href="https://github.com/DS4SD/MolGrapher">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{morinMolGrapherGraphbasedVisual2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MolGrapher}}: {{Graph-based Visual Recognition}} of {{Chemical Structures}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MolGrapher}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the {{IEEE}}/{{CVF International Conference}} on {{Computer Vision}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Danelljan, Martin and Agea, Maria Isabel and Nassar, Ahmed and Weber, Valéry and Meijer, Ingmar and Staar, Peter and Yu, Fisher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{19552--19561}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/ICCV51070.2023.01791}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-10-18}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MMSSC-Net: Multi-Stage Sequence Cognitive Networks</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/mmssc-net/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/mmssc-net/</guid><description>A deep learning model for Optical Chemical Structure Recognition (OCSR) using SwinV2 and GPT-2 to convert molecular images to SMILES.</description><content:encoded><![CDATA[<h2 id="contribution-a-multi-stage-architectural-pipeline">Contribution: A Multi-Stage Architectural Pipeline</h2>
<p><strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong>.
The paper proposes a deep learning architecture (<strong>MMSSC-Net</strong>) for Optical Chemical Structure Recognition (OCSR). It focuses on architectural innovation, specifically combining a SwinV2 visual encoder with a GPT-2 decoder, and validates this method through extensive benchmarking against existing rule-based and deep-learning baselines. It includes ablation studies to justify the choice of the visual encoder.</p>
<h2 id="motivation-addressing-noise-and-rigid-image-recognition">Motivation: Addressing Noise and Rigid Image Recognition</h2>
<ul>
<li><strong>Data Usage Gap</strong>: Drug discovery relies heavily on scientific literature, but molecular structures are often locked in vector graphics or images that computers cannot easily process.</li>
<li><strong>Limitations of Prior Work</strong>: Existing rule-based methods are rigid and sensitive to noise. Previous deep learning approaches (encoder-decoder &ldquo;image captioning&rdquo; styles) often lack precision and interpretability, and they struggle with varying image resolutions or large molecules.</li>
<li><strong>Need for &ldquo;Cognition&rdquo;</strong>: The authors argue that treating the image as a single isolated whole is insufficient; a model needs to &ldquo;perceive&rdquo; fine-grained details (atoms and bonds) to handle noise and varying pixel qualities effectively.</li>
</ul>
<h2 id="novelty-a-fine-grained-perception-pipeline">Novelty: A Fine-Grained Perception Pipeline</h2>
<ul>
<li><strong>Multi-Stage Cognitive Architecture</strong>: MMSSC-Net splits the task into stages:
<ol>
<li><strong>Fine-grained Perception</strong>: Detecting atom and bond sequences (including spatial coordinates) using SwinV2.</li>
<li><strong>Graph Construction</strong>: Assembling these into a molecular graph.</li>
<li><strong>Sequence Evolution</strong>: Converting the graph into a machine-readable format (SMILES).</li>
</ol>
</li>
<li><strong>Hybrid Transformer Model</strong>: It combines a hierarchical vision transformer (<strong>SwinV2</strong>) for encoding with a generative pre-trained transformer (<strong>GPT-2</strong>) and MLPs for decoding atomic and bond targets.</li>
<li><strong>Robustness Mechanisms</strong>: Random noise sequences are included during training to improve generalization to new molecular targets.</li>
</ul>
<h2 id="methodology-and-benchmarks">Methodology and Benchmarks</h2>
<ul>
<li><strong>Baselines</strong>: Compared against 8 other tools:
<ul>
<li><em>Rule-based</em>: MolVec, OSRA.</li>
<li><em>Image-Smiles (DL)</em>: ABC-Net, Img2Mol, MolMiner.</li>
<li><em>Image-Graph-Smiles (DL)</em>: Image-To-Graph, MolScribe, ChemGrapher.</li>
</ul>
</li>
<li><strong>Datasets</strong>: Evaluated on 5 diverse datasets: STAKER (synthetic), USPTO, CLEF, JPO, and UOB (real-world).</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Accuracy</strong>: Exact string match of the predicted SMILES.</li>
<li><strong>Tanimoto Similarity</strong>: Chemical similarity using Morgan fingerprints.</li>
</ul>
</li>
<li><strong>Ablation Study</strong>: Tested different visual encoders (Swin Transformer, ViT-B, ResNet-50) to validate the choice of SwinV2.</li>
<li><strong>Resolution Sensitivity</strong>: Tested model performance across image resolutions from 256px to 2048px.</li>
</ul>
<h2 id="results-and-core-outcomes">Results and Core Outcomes</h2>
<ul>
<li><strong>Strong Performance</strong>: MMSSC-Net achieved 75-98% accuracy across datasets, outperforming baselines on most benchmarks. Accuracy exceeded 94% on the first three datasets, which cover intra-domain and real images.</li>
<li><strong>Resolution Robustness</strong>: The model maintained relatively stable accuracy across varying image resolutions, whereas baselines like Img2Mol showed greater sensitivity to resolution changes (Fig. 4 in the paper).</li>
<li><strong>Efficiency</strong>: The SwinV2 encoder was noted to be more efficient than ViT-B in this context.</li>
<li><strong>Limitations</strong>: The model struggles with stereochemistry, specifically confusing dashed wedge bonds with solid wedge bonds and misclassifying single bonds as solid wedge bonds. It also has difficulty with &ldquo;irrelevant text&rdquo; noise (e.g., unexpected symbols in JPO and DECIMER datasets).</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a combination of PubChem and USPTO data, augmented to handle visual variability.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>PubChem</strong></td>
          <td>1,000,000</td>
          <td>Converted from <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> to SMILES; random sampling.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>USPTO</strong></td>
          <td>600,000</td>
          <td>Patent images; converted from MOL to SMILES.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>STAKER</strong></td>
          <td>40,000</td>
          <td>Synthetic; Avg res $256 \times 256$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>USPTO</strong></td>
          <td>4,862</td>
          <td>Real; Avg res $721 \times 432$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>CLEF</strong></td>
          <td>881</td>
          <td>Real; Avg res $1245 \times 412$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>JPO</strong></td>
          <td>380</td>
          <td>Real; Avg res $614 \times 367$.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>UOB</strong></td>
          <td>5,720</td>
          <td>Real; Avg res $759 \times 416$.</td>
      </tr>
  </tbody>
</table>
<p><strong>Augmentation</strong>:</p>
<ul>
<li><strong>Image</strong>: Random perturbations using RDKit/Indigo (rotation, filling, cropping, bond thickness/length, font size, Gaussian noise).</li>
<li><strong>Molecular</strong>: Introduction of functional group abbreviations and R-substituents (dummy atoms) using SMARTS templates.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Target Sequence Formulation</strong>: The model predicts a sequence containing bounding box coordinates and type labels: $\{y_{\text{min}}, x_{\text{min}}, y_{\text{max}}, x_{\text{max}}, C_{n}\}$.</li>
<li><strong>Loss Function</strong>: Cross-entropy loss with maximum likelihood estimation.
$$ \max \sum_{i=1}^{N} \sum_{j=1}^{L} \omega_{j} \log P(t_{j}^{i} \mid x_{1}^{i}, x_{2}^{i}, \dots, x_{M}^{i}, t_{1}^{i}, \dots, t_{j-1}^{i}) $$</li>
<li><strong>Noise Injection</strong>: A random sequence $T_r$ is appended to the target sequence during training to improve generalization to new molecular targets.</li>
<li><strong>Graph Construction</strong>: Atoms ($v$) and bonds ($e$) are recognized separately; bonds are defined by connecting spatial atomic coordinates.</li>
</ul>
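<p>As a rough illustration of the target sequence formulation and noise injection above, the sketch below serializes bounding boxes into discrete tokens and appends a random noise sequence $T_r$. The coordinate binning scheme, the bin count, the class-token offset, the number of classes, and every function name are assumptions made for illustration; the paper&rsquo;s exact tokenization may differ.</p>

```python
import random

NUM_BINS = 256  # assumed number of coordinate bins (hypothetical)


def quantize(value, size, num_bins=NUM_BINS):
    """Map a pixel coordinate in [0, size] to an integer bin in [0, num_bins - 1]."""
    ratio = min(max(value / size, 0.0), 1.0)
    return min(int(ratio * num_bins), num_bins - 1)


def box_to_tokens(box, class_id, image_size, num_bins=NUM_BINS):
    """Serialize one object as [y_min, x_min, y_max, x_max, C_n]."""
    y_min, x_min, y_max, x_max = box
    height, width = image_size
    return [
        quantize(y_min, height, num_bins),
        quantize(x_min, width, num_bins),
        quantize(y_max, height, num_bins),
        quantize(x_max, width, num_bins),
        num_bins + class_id,  # class tokens are placed after the coordinate bins
    ]


def build_target(objects, image_size, num_noise=2, num_classes=10, rng=None):
    """Concatenate real object tokens, then append a random noise sequence T_r."""
    rng = rng or random.Random(0)
    tokens = []
    for box, class_id in objects:
        tokens.extend(box_to_tokens(box, class_id, image_size))
    for _ in range(num_noise):
        # Noise objects: four random coordinate bins and one random class token
        tokens.extend(rng.randrange(NUM_BINS) for _ in range(4))
        tokens.append(NUM_BINS + rng.randrange(num_classes))
    return tokens


objects = [((10, 20, 30, 40), 0), ((100, 120, 140, 160), 3)]
target = build_target(objects, image_size=(256, 256), num_noise=1)
print(len(target))   # 15 tokens: 2 real objects + 1 noise object, 5 tokens each
print(target[:5])    # [10, 20, 30, 40, 256]
```

<p>Keeping coordinates and class labels in one shared token space is what lets a single autoregressive decoder emit both; how the loss treats the noise tokens is not specified here and is left out of the sketch.</p>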
<h3 id="models">Models</h3>
<ul>
<li><strong>Encoder</strong>: <strong>Swin Transformer V2</strong>.
<ul>
<li>Pre-trained on ImageNet-1K.</li>
<li>Window size: $16 \times 16$.</li>
<li>Parameters: 88M.</li>
<li>Input resolution: $256 \times 256$.</li>
<li>Features: Scaled cosine attention; log-space continuous position bias.</li>
</ul>
</li>
<li><strong>Decoder</strong>: <strong>GPT-2</strong> + <strong>MLP</strong>.
<ul>
<li><strong>GPT-2</strong>: Used for recognizing atom types.
<ul>
<li>Layers: 24.</li>
<li>Attention Heads: 12.</li>
<li>Hidden Dimension: 768.</li>
<li>Dropout: 0.1.</li>
</ul>
</li>
<li><strong>MLP</strong>: Used for classifying bond types (single, double, triple, aromatic, solid wedge, dashed wedge).</li>
</ul>
</li>
<li><strong>Vocabulary</strong>:
<ul>
<li>Standard: 95 common numbers/characters ([0], [C], [=], etc.).</li>
<li>Extended: 2000 SMARTS-based characters for isomers/groups (e.g., &ldquo;[C2F5]&rdquo;, &ldquo;[halo]&rdquo;).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ol>
<li><strong>Accuracy</strong>: Exact match of the generated SMILES string.</li>
<li><strong>Tanimoto Similarity</strong>: Similarity of Morgan fingerprints between predicted and ground truth molecules.</li>
</ol>
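<p>Both metrics are simple to state in code. The sketch below is a minimal, dependency-free version: fingerprints are represented as sets of on-bit indices, and the toy bit sets are hypothetical values, not real Morgan fingerprint output. In practice the fingerprints would come from a cheminformatics toolkit.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto similarity between two fingerprints given as sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    if not union:
        return 1.0  # convention chosen here: two empty fingerprints are identical
    return len(a & b) / len(union)


def exact_match_accuracy(predicted, reference):
    """Fraction of predictions whose SMILES string equals the reference string."""
    if len(predicted) != len(reference):
        raise ValueError("predicted and reference must have the same length")
    hits = sum(p == r for p, r in zip(predicted, reference))
    return hits / len(reference)


# Toy fingerprints (hypothetical on-bit indices, not real Morgan output)
fp_pred = {1, 5, 9, 12}
fp_true = {1, 5, 9, 40}
print(tanimoto(fp_pred, fp_true))  # 3 shared bits / 5 total bits = 0.6

preds = ["CCO", "c1ccccc1", "CC(=O)O"]
refs = ["CCO", "c1ccccc1", "CC(O)=O"]
print(exact_match_accuracy(preds, refs))  # 2 of 3 strings match
```

<p>The third pair encodes the same molecule with two different SMILES strings, which is why exact-match evaluation is usually run on canonicalized SMILES; whether and how the paper canonicalizes is not stated here.</p>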
<p><strong>Key Results (Accuracy)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MMSSC-Net</th>
          <th>MolVec (Rule)</th>
          <th>ABC-Net (DL)</th>
          <th>MolScribe (DL)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Indigo</strong></td>
          <td>98.14</td>
          <td>95.63</td>
          <td>96.4</td>
          <td>97.5</td>
      </tr>
      <tr>
          <td><strong>RDKit</strong></td>
          <td>94.91</td>
          <td>86.7</td>
          <td>98.3</td>
          <td>93.8</td>
      </tr>
      <tr>
          <td><strong>USPTO</strong></td>
          <td>94.24</td>
          <td>88.47</td>
          <td>*</td>
          <td>92.6</td>
      </tr>
      <tr>
          <td><strong>CLEF</strong></td>
          <td>91.26</td>
          <td>81.61</td>
          <td>*</td>
          <td>86.9</td>
      </tr>
      <tr>
          <td><strong>UOB</strong></td>
          <td>92.71</td>
          <td>81.32</td>
          <td>96.1</td>
          <td>87.9</td>
      </tr>
      <tr>
          <td><strong>Staker</strong></td>
          <td>89.44</td>
          <td>4.49</td>
          <td>*</td>
          <td>86.9</td>
      </tr>
      <tr>
          <td><strong>JPO</strong></td>
          <td>75.48</td>
          <td>66.8</td>
          <td>*</td>
          <td>76.2</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Configuration</strong>:
<ul>
<li>Batch Size: 128.</li>
<li>Learning Rate: $4 \times 10^{-5}$.</li>
<li>Epochs: 40.</li>
</ul>
</li>
<li><strong>Inference Speed</strong>: The SwinV2 encoder demonstrated higher efficiency (faster inference time) compared to ViT-B and ResNet-50 baselines during ablation.</li>
</ul>
<h3 id="reproducibility">Reproducibility</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Wzew5Lp/MMSSCNet">MMSSCNet (GitHub)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official implementation; includes training and prediction scripts</td>
      </tr>
  </tbody>
</table>
<p>The paper is published in RSC Advances (open access). Source code is available on GitHub, though the repository has minimal documentation and no explicit license. The training data comes from PubChem (public) and USPTO (public patent data). Pre-trained model weights do not appear to be released. No specific GPU hardware or training time is reported in the paper.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, D., Zhao, D., Wang, Z., Li, J., &amp; Li, J. (2024). MMSSC-Net: multi-stage sequence cognitive networks for drug molecule recognition. <em>RSC Advances</em>, 14(26), 18182-18191. <a href="https://doi.org/10.1039/D4RA02442G">https://doi.org/10.1039/D4RA02442G</a></p>
<p><strong>Publication</strong>: RSC Advances 2024</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhangMMSSCNetMultistageSequence2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{MMSSC-Net: Multi-Stage Sequence Cognitive Networks for Drug Molecule Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{MMSSC-Net}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Dehai and Zhao, Di and Wang, Zhengwu and Li, Junhui and Li, Jin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{RSC Advances}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{26}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{18182--18191}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D4RA02442G}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://pubs.rsc.org/en/content/articlelanding/2024/ra/d4ra02442g}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MarkushGrapher: Multi-modal Markush Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/markushgrapher/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/markushgrapher/</guid><description>Multi-modal transformer combining vision, text, and layout encoding to extract complex Markush structures from patent documents with OCSR.</description><content:encoded><![CDATA[<h2 id="overcoming-unimodal-limitations-for-markush-structures">Overcoming Unimodal Limitations for Markush Structures</h2>
<p>The automated analysis of chemical literature, particularly patents, is critical for drug discovery and material science. A major bottleneck is the extraction of <strong>Markush structures</strong>, which are complex chemical templates that represent families of molecules using a core backbone image and textual variable definitions. Existing methods are limited because they either rely solely on images (OCSR) and miss the textual context, or focus solely on text and miss the structural backbone. This creates a practical need for a unified, multi-modal approach that jointly interprets visual and textual data to accurately extract these structures for prior-art search and database construction. This paper proposes a <strong>Method</strong> and introduces a new <strong>Resource</strong> (M2S dataset) to bridge this gap.</p>
<h2 id="markushgrapher-the-multi-modal-architecture">MarkushGrapher: The Multi-Modal Architecture</h2>
<p>The core innovation is <strong>MarkushGrapher</strong>, a multi-modal architecture that jointly encodes image, text, and layout information. Key contributions include:</p>
<ul>
<li><strong>Dual-Encoder Architecture</strong>: Combines a Vision-Text-Layout (VTL) encoder (based on UDOP) with a specialized, pre-trained Optical Chemical Structure Recognition (OCSR) encoder (MolScribe). Let $E_{\text{VTL}}$ represent the combined sequence embedding and $E_{\text{OCSR}}$ represent the domain-specific visual embeddings.</li>
<li><strong>Joint Recognition</strong>: The model autoregressively generates a sequential graph representation (Optimized CXSMILES) and a substituent table simultaneously. It uses cross-modal dependencies, allowing text to clarify ambiguous visual details like bond types.</li>
<li><strong>Synthetic Data Pipeline</strong>: A comprehensive pipeline generates realistic synthetic Markush structures (images and text) from PubChem data, overcoming the lack of labeled training data.</li>
<li><strong>Optimized Representation</strong>: A compacted version of CXSMILES moves variable groups into the SMILES string and adds explicit atom indexing to handle complex &ldquo;frequency&rdquo; and &ldquo;position&rdquo; variation indicators.</li>
</ul>
<h2 id="experimental-validation-on-the-new-m2s-benchmark">Experimental Validation on the New M2S Benchmark</h2>
<p>The authors validated their approach using the following setup:</p>
<ul>
<li><strong>Baselines</strong>: Compared against image-only chemistry models (DECIMER, MolScribe) and general-purpose multi-modal models (Uni-SMART, GPT-4o, Pixtral, Llama-3.2).</li>
<li><strong>Datasets</strong>: Evaluated on three benchmarks:
<ol>
<li><strong>MarkushGrapher-Synthetic</strong>: 1,000 generated samples.</li>
<li><strong>M2S</strong>: A new benchmark of 103 manually annotated real-world patent images.</li>
<li><strong>USPTO-Markush</strong>: 74 Markush backbone images from USPTO patents.</li>
</ol>
</li>
<li><strong>Ablation Studies</strong>: Analyzed the impact of the OCSR encoder, late fusion strategies, and the optimized CXSMILES format. Late fusion improved USPTO-Markush EM from 23% (VTL only) to 32% (Table 3). Removing R-group compression dropped M2S EM from 38% to 30%, and removing atom indexing dropped USPTO-Markush EM from 32% to 24% (Table 4).</li>
</ul>
<h2 id="key-results">Key Results</h2>
<ul>
<li><strong>Performance</strong>: MarkushGrapher outperformed all baselines. On the M2S benchmark, it achieved 38% Exact Match on CXSMILES (compared to 21% for MolScribe) and 29% Exact Match on tables. On USPTO-Markush, it reached 32% CXSMILES EM versus 7% for MolScribe.</li>
<li><strong>Markush Feature Recognition</strong>: The model can recognize complex Markush features like frequency variation (&lsquo;Sg&rsquo;) and position variation (&lsquo;m&rsquo;) indicators. DECIMER and MolScribe scored 0% on both &lsquo;m&rsquo; and &lsquo;Sg&rsquo; sections (Table 2), while MarkushGrapher achieved 76% on &lsquo;m&rsquo; and 31% on &lsquo;Sg&rsquo; sections on M2S.</li>
<li><strong>Cross-Modal Reasoning</strong>: Qualitative analysis showed the model can correctly infer visual details (such as bond order) that appear ambiguous in the image but become apparent with the text description.</li>
<li><strong>Robustness</strong>: The model generalizes well to real-world data despite being trained purely on synthetic data. On augmented versions of M2S and USPTO-Markush simulating low-quality scanned documents, it maintained 31% and 32% CXSMILES EM respectively (Table 6).</li>
</ul>
<h2 id="limitations">Limitations</h2>
<p>The authors note several limitations:</p>
<ul>
<li>MarkushGrapher does not currently handle abbreviations in chemical structures (e.g., &lsquo;OG&rsquo; for oxygen connected to a variable group).</li>
<li>The model relies on ground-truth OCR cells as input, requiring an external OCR model for practical deployment.</li>
<li>Substituent definitions that combine text with interleaved chemical structure drawings are not supported.</li>
<li>The model is trained to predict &lsquo;m&rsquo; sections connecting to all atoms in a cycle, which can technically violate valence constraints, though the output contains enough information to reconstruct only valid connections.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong></p>
<ul>
<li><strong>Source</strong>: Synthetic dataset generated from PubChem SMILES.</li>
<li><strong>Size</strong>: 210,000 synthetic images.</li>
<li><strong>Pipeline</strong>:
<ol>
<li><strong>Selection</strong>: Sampled SMILES from PubChem based on substructure diversity.</li>
<li><strong>Augmentation</strong>: SMILES augmented to artificial CXSMILES using RDKit (inserting variable groups, frequency indicators).</li>
<li><strong>Rendering</strong>: Images rendered using Chemistry Development Kit (CDK) with randomized drawing parameters (font, bond width, spacing).</li>
<li><strong>Text Generation</strong>: Textual definitions generated using manual templates extracted from patents; 10% were paraphrased using Mistral-7B-Instruct-v0.3 to increase diversity.</li>
<li><strong>OCR</strong>: Bounding boxes extracted via a custom SVG parser aligned with MOL files.</li>
</ol>
</li>
</ul>
<p><strong>Evaluation Data</strong></p>
<ul>
<li><strong>M2S Dataset</strong>: 103 images from USPTO, EPO, and WIPO patents (1999-2023), manually annotated with CXSMILES and substituent tables.</li>
<li><strong>USPTO-Markush</strong>: 74 images from USPTO patents (2010-2016).</li>
<li><strong>MarkushGrapher-Synthetic</strong>: 1,000 samples generated via the pipeline.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimized CXSMILES</strong>:
<ul>
<li><strong>Compression</strong>: Variable groups moved from the extension block to the main SMILES string as special atoms to reduce sequence length.</li>
<li><strong>Indexing</strong>: Atom indices appended to each atom (e.g., <code>C:1</code>) to explicitly link the graph to the extension block (crucial for <code>m</code> and <code>Sg</code> sections).</li>
<li><strong>Vocabulary</strong>: Specific tokens used for atoms and bonds.</li>
</ul>
</li>
<li><strong>Augmentation</strong>: Standard image augmentations (shift, scale, blur, pepper noise, random lines) and OCR text augmentations (character substitution/insertion/deletion).</li>
</ul>
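<p>The OCR text augmentation listed above (character substitution, insertion, and deletion) can be sketched as follows. The per-character probability <code>p</code>, the replacement alphabet, and the function name are illustrative assumptions, not values from the paper.</p>

```python
import random
import string


def augment_ocr_text(text, p=0.1, rng=None, alphabet=string.ascii_letters + string.digits):
    """Randomly substitute, insert, or delete characters to mimic OCR errors."""
    rng = rng or random.Random()
    out = []
    for ch in text:
        if rng.random() >= p:
            out.append(ch)  # keep the character unchanged
            continue
        op = rng.choice(("substitute", "insert", "delete"))
        if op == "substitute":
            out.append(rng.choice(alphabet))
        elif op == "insert":
            out.append(ch)
            out.append(rng.choice(alphabet))
        # "delete": append nothing
    return "".join(out)


clean = "R1 is selected from hydrogen, methyl and ethyl"
noisy = augment_ocr_text(clean, p=0.15, rng=random.Random(42))
print(noisy)
```

<p>Passing an explicit seeded <code>random.Random</code> keeps the corruption reproducible, which helps when comparing training runs.</p>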
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder Transformer.
<ul>
<li><strong>VTL Encoder</strong>: T5-large encoder (initialized from UDOP) that processes image patches, text tokens, and layout (bounding boxes).</li>
<li><strong>OCSR Encoder</strong>: Vision encoder from MolScribe (Swin Transformer), frozen during training.</li>
<li><strong>Text Decoder</strong>: T5-large decoder.</li>
</ul>
</li>
<li><strong>Fusion Strategy</strong>: <strong>Late Fusion</strong>. The two encoders run separately and are merged only before decoding: the VTL encoder output $e_1$ (vision, text, and layout) is concatenated with the MLP-projected OCSR encoder output $e_2$ (image only):
$$ e = e_1(v, t, l) \oplus \text{MLP}(e_2(v)) $$</li>
<li><strong>Parameters</strong>: 831M total (744M trainable).</li>
</ul>
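<p>The late fusion formula can be illustrated with a small dependency-free sketch. It assumes that $\oplus$ concatenates along the sequence axis, so the MLP must project OCSR tokens to the VTL hidden size, and it assumes a two-layer MLP with ReLU. Both choices, the toy dimensions, and all function names are assumptions for illustration only.</p>

```python
import random


def linear(x, weight, bias):
    """Apply y = W x + b to one vector (weight is a list of rows)."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def mlp_project(tokens, w1, b1, w2, b2):
    """Two-layer MLP applied to every OCSR token independently."""
    return [linear(relu(linear(tok, w1, b1)), w2, b2) for tok in tokens]


def late_fusion(e1, e2, w1, b1, w2, b2):
    """e = e1 (+) MLP(e2): concatenate along the sequence axis (assumed)."""
    return e1 + mlp_project(e2, w1, b1, w2, b2)


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_vtl, d_ocsr, d_hidden = 8, 6, 16           # toy sizes, not the paper's
e1 = rand_matrix(5, d_vtl, rng)              # 5 VTL tokens
e2 = rand_matrix(3, d_ocsr, rng)             # 3 OCSR tokens
w1, b1 = rand_matrix(d_hidden, d_ocsr, rng), [0.0] * d_hidden
w2, b2 = rand_matrix(d_vtl, d_hidden, rng), [0.0] * d_vtl

fused = late_fusion(e1, e2, w1, b1, w2, b2)
print(len(fused), len(fused[0]), len(fused[-1]))  # 8 8 8
```

<p>Under this reading the decoder cross-attends over a single sequence of 5 + 3 tokens, and the frozen OCSR encoder contributes only through the trainable projection.</p>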
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>CXSMILES Exact Match (EM)</strong>: Requires perfect match of SMILES string, variable groups, <code>m</code> sections, and <code>Sg</code> sections (ignoring stereochemistry).</li>
<li><strong>Tanimoto Score</strong>: Similarity of RDKit DayLight fingerprints (Markush features removed).</li>
<li><strong>Table Exact Match</strong>: All variable groups and substituents must match.</li>
<li><strong>Table F1-Score</strong>: Aggregated recall and precision of substituents per variable group.</li>
</ul>
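<p>The two table metrics can be sketched as below, with a substituent table represented as a mapping from variable group to a list of substituent strings. Micro-averaging over groups and plain string equality between substituents are assumptions of this sketch; the paper&rsquo;s exact aggregation and matching rules may differ.</p>

```python
def table_exact_match(pred, ref):
    """True only if every variable group and all of its substituents match."""
    return {k: set(v) for k, v in pred.items()} == {k: set(v) for k, v in ref.items()}


def table_f1(pred, ref):
    """Micro-averaged precision, recall, and F1 over substituents per variable group."""
    tp = fp = fn = 0
    for group in set(pred) | set(ref):
        p = set(pred.get(group, ()))
        r = set(ref.get(group, ()))
        tp += len(p & r)   # substituents predicted and present in the reference
        fp += len(p - r)   # predicted but not in the reference
        fn += len(r - p)   # in the reference but missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


reference = {"R1": ["H", "methyl", "ethyl"], "R2": ["F", "Cl"]}
predicted = {"R1": ["H", "methyl"], "R2": ["F", "Cl", "Br"]}

print(table_exact_match(predicted, reference))  # False
print(table_f1(predicted, reference))           # about (0.8, 0.8, 0.8)
```

<p>The example shows why both metrics are reported: one missed and one spurious substituent make exact match fail outright, while F1 still credits the four correct entries.</p>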
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Trained on a single NVIDIA H100 GPU.</li>
<li><strong>Training Config</strong>: 10 epochs, batch size of 10, Adam optimizer, learning rate 5e-4, 100 warmup steps, weight decay 1e-3.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DS4SD/MarkushGrapher">MarkushGrapher</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Morin, L., Weber, V., Nassar, A., Meijer, G. I., Van Gool, L., Li, Y., &amp; Staar, P. (2025). MarkushGrapher: Joint Visual and Textual Recognition of Markush Structures. <em>2025 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)</em>, 14505-14515. <a href="https://doi.org/10.1109/CVPR52734.2025.01352">https://doi.org/10.1109/CVPR52734.2025.01352</a></p>
<p><strong>Publication</strong>: CVPR 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/DS4SD/MarkushGrapher">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{morinMarkushGrapherJointVisual2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{MarkushGrapher: Joint Visual and Textual Recognition of Markush Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{MarkushGrapher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{2025 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Weber, Valéry and Nassar, Ahmed and Meijer, Gerhard Ingmar and Van Gool, Luc and Li, Yawei and Staar, Peter}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jun,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{14505--14515}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1109/CVPR52734.2025.01352}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image2InChI: SwinTransformer for Molecular Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/image2inchi/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/image2inchi/</guid><description>Deep learning model using improved SwinTransformer encoder and attention-based feature fusion to convert molecular images to InChI strings.</description><content:encoded><![CDATA[<h2 id="image2inchi-as-a-methodological-innovation">Image2InChI as a Methodological Innovation</h2>
<p>This is a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong>. It proposes a specific new deep learning architecture (&ldquo;Image2InChI&rdquo;) to solve the task of Optical Chemical Structure Recognition (OCSR). The rhetorical focus is on engineering a system that outperforms baselines on specific metrics (InChI accuracy, MCS accuracy) and providing a valuable reference for future algorithmic work.</p>
<h2 id="bottlenecks-in-chemical-literature-digitization">Bottlenecks in Chemical Literature Digitization</h2>
<p>The accurate digitization of chemical literature is a bottleneck in AI-driven drug discovery. Chemical structures in patents and papers exist as optical images (pixels), but machine learning models require machine-readable string representations (like <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a> or <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>). Efficiently and automatically bridging this gap is a prerequisite for large-scale data mining in chemistry.</p>
<h2 id="hierarchical-swintransformer-and-attention-integration">Hierarchical SwinTransformer and Attention Integration</h2>
<p>The core novelty is the <strong>Image2InChI</strong> architecture, which integrates:</p>
<ol>
<li><strong>Improved SwinTransformer Encoder</strong>: Uses a hierarchical vision transformer to capture image features.</li>
<li><strong>Feature Fusion with Attention</strong>: A novel network designed to integrate image patch features with InChI prediction steps.</li>
<li><strong>End-to-End InChI Prediction</strong>: The architecture frames the problem as image-to-sequence translation that targets InChI strings directly, in contrast to techniques that predict independent graph components. The model is optimized using a standard cross-entropy loss over the token vocabulary:
$$ \mathcal{L}_{\text{CE}} = - \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{X}) $$
where $\mathbf{X}$ represents the input image features, $y_t$ is the target token at step $t$, $y_{&lt;t}$ denotes the tokens before it, and $T$ is the sequence length.</li>
</ol>
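<p>For a single sequence, the cross-entropy loss above reduces to summing the negative log-probability that the model assigns to each target token. The sketch below computes it from toy per-step distributions; the probabilities and the three-token vocabulary are made-up values for illustration.</p>

```python
import math


def sequence_cross_entropy(step_probs, target_ids):
    """Sum of -log P(y_t | previous tokens, X) over one target sequence.

    step_probs[t] is the model's probability distribution over the vocabulary
    at step t, already conditioned on the image and the earlier target tokens.
    """
    loss = 0.0
    for probs, y_t in zip(step_probs, target_ids):
        loss -= math.log(probs[y_t])
    return loss


# Toy example with a 3-token vocabulary and a 2-step target sequence
step_probs = [
    [0.7, 0.2, 0.1],   # step 1 distribution
    [0.1, 0.1, 0.8],   # step 2 distribution
]
target_ids = [0, 2]
print(round(sequence_cross_entropy(step_probs, target_ids), 4))  # -ln(0.7) - ln(0.8)
```

<p>Because the loss is a sum over steps, one low-probability character can dominate the loss of a long InChI string even when every other character is predicted confidently.</p>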
<h2 id="benchmarking-on-the-bms-dataset">Benchmarking on the BMS Dataset</h2>
<ul>
<li><strong>Benchmark Validation</strong>: The model was trained and tested on the <strong>BMS1000 (Bristol-Myers Squibb)</strong> dataset from a Kaggle competition.</li>
<li><strong>Ablation/Comparative Analysis</strong>: The authors compared their method against other models in the supplement.</li>
<li><strong>Preprocessing Validation</strong>: They justified their choice of denoising algorithms (8-neighborhood vs. Gaussian/Mean) to ensure preservation of bond lines while removing &ldquo;spiky point noise&rdquo;.</li>
</ul>
<h2 id="high-inchi-recognition-metrics">High InChI Recognition Metrics</h2>
<ul>
<li><strong>High Accuracy</strong>: The model achieved <strong>99.8% InChI accuracy</strong>, 94.8% Maximum Common Substructure (MCS) accuracy, and 96.2% Longest Common Subsequence (LCS) accuracy on the benchmarked dataset. It remains to be seen how well the model generalizes to heavily degraded real-world patent images.</li>
<li><strong>Effective Denoising</strong>: The authors concluded that <strong>eight-neighborhood filtering</strong> is superior to mean or Gaussian filtering for this specific domain because it removes isolated noise points without blurring the fine edges of chemical bonds.</li>
<li><strong>Open Source</strong>: The authors stated their intention to release the code, though no public repository has been identified.</li>
</ul>
<hr>
<h2 id="artifacts">Artifacts</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.kaggle.com/c/bms-molecular-translation">BMS Dataset (Kaggle)</a></td>
          <td>Dataset</td>
          <td>Competition</td>
          <td>Bristol-Myers Squibb Molecular Translation competition dataset</td>
      </tr>
  </tbody>
</table>
<p>No public code repository has been identified for Image2InChI despite the authors&rsquo; stated intent to release it.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The primary dataset used is the <strong>BMS (Bristol-Myers Squibb) Dataset</strong>.</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Details</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Source</strong></td>
          <td>Kaggle Competition (BMS-Molecular-Translation)</td>
      </tr>
      <tr>
          <td><strong>Total Size</strong></td>
          <td>2.4 million images</td>
      </tr>
      <tr>
          <td><strong>Training Set</strong></td>
          <td>1.8 million images</td>
      </tr>
      <tr>
          <td><strong>Test Set</strong></td>
          <td>0.6 million images</td>
      </tr>
      <tr>
          <td><strong>Content</strong></td>
          <td>Each image corresponds to a unique International Chemical Identifier (<a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>)</td>
      </tr>
  </tbody>
</table>
<p><strong>Other Datasets</strong>: The authors also utilized JPO (Japanese Patent Office), CLEF (CLEF-IP 2012), UOB (MolrecUOB), and USPTO datasets for broader benchmarking.</p>
<p><strong>Preprocessing Pipeline</strong>:</p>
<ol>
<li><strong>Denoising</strong>: <strong>Eight-neighborhood filtering</strong> (threshold &lt; 4 non-white pixels) is used to remove salt-and-pepper noise while preserving bond lines. Mean and Gaussian filtering were rejected due to blurring.</li>
<li><strong>Sequence Padding</strong>:
<ul>
<li>Analysis showed max InChI length &lt; 270.</li>
<li>Fixed sequence length set to <strong>300</strong>.</li>
<li>Tokens: <code>&lt;sos&gt;</code> (190), <code>&lt;eos&gt;</code> (191), <code>&lt;pad&gt;</code> (192) used for padding/framing.</li>
</ul>
</li>
<li><strong>Numerization</strong>: Characters are mapped to integers based on a fixed vocabulary (e.g., &lsquo;C&rsquo; -&gt; 178, &lsquo;H&rsquo; -&gt; 182).</li>
</ol>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Eight-Neighborhood Filtering (Denoising)</strong>:</p>
<p>Pseudocode logic:</p>
<ul>
<li>Iterate through every pixel.</li>
<li>Count non-white neighbors in the 3x3 grid (8 neighbors).</li>
<li>If count &lt; threshold (default 4), treat as noise and remove.</li>
</ul>
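<p>The pseudocode above can be turned into a small runnable sketch. It assumes a grayscale image stored as a list of rows with 255 as white, and it counts neighbors on the original image so that earlier removals do not influence later decisions. Both are choices of this sketch, not details confirmed by the paper.</p>

```python
WHITE = 255  # assumed background value for a grayscale image


def eight_neighborhood_filter(image, threshold=4, white=WHITE):
    """Remove dark pixels that have fewer than `threshold` non-white neighbors."""
    height, width = len(image), len(image[0])
    cleaned = [row[:] for row in image]  # write to a copy, read from the original
    for y in range(height):
        for x in range(width):
            if image[y][x] == white:
                continue
            neighbors = 0
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    if dy == 0 and dx == 0:
                        continue  # skip the pixel itself
                    ny, nx = y + dy, x + dx
                    in_bounds = ny in range(height) and nx in range(width)
                    if in_bounds and image[ny][nx] != white:
                        neighbors += 1
            if threshold > neighbors:
                cleaned[y][x] = white  # treat as noise and remove
    return cleaned


image = [[WHITE] * 7 for _ in range(7)]
for y in (2, 3, 4):
    for x in range(7):
        image[y][x] = 0      # a three-pixel-thick horizontal stroke
image[0][6] = 0              # an isolated noise pixel

cleaned = eight_neighborhood_filter(image)
print(cleaned[0][6])  # 255: the isolated pixel is removed
print(cleaned[3][3])  # 0: the stroke interior is kept
```

<p>In this naive version, a pixel on a one-pixel-wide stroke has only two dark neighbors and would also be removed at a threshold of 4, so the threshold has to be chosen with the stroke thickness of the images in mind.</p>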
<p><strong>InChI Tokenization</strong>:</p>
<ul>
<li>InChI strings are split into character arrays.</li>
<li>Example: Vitamin C <code>InChI=1S/C6H8O6...</code> becomes <code>[&lt;sos&gt;, C, 6, H, 8, O, 6, ..., &lt;eos&gt;, &lt;pad&gt;...]</code>.</li>
<li>Mapped to integer tensor for model input.</li>
</ul>
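<p>The tokenization and padding steps can be sketched as below, working directly with integer ids. The special-token ids (190, 191, 192), the fixed length of 300, and the ids for &lsquo;C&rsquo; and &lsquo;H&rsquo; come from the text above; the remaining vocabulary ids and the function name are hypothetical placeholders.</p>

```python
SOS_ID, EOS_ID, PAD_ID = 190, 191, 192   # special-token ids reported above
MAX_LEN = 300                            # fixed sequence length reported above

# Only 'C' -> 178 and 'H' -> 182 come from the text; the other ids are
# hypothetical placeholders for illustration.
VOCAB = {"C": 178, "H": 182, "O": 1, "6": 2, "8": 3}


def encode_formula(chars, vocab=VOCAB, max_len=MAX_LEN):
    """Map characters to ids, frame with start/end ids, and right-pad to max_len."""
    ids = [SOS_ID] + [vocab[ch] for ch in chars] + [EOS_ID]
    if len(ids) > max_len:
        raise ValueError("sequence longer than max_len")
    return ids + [PAD_ID] * (max_len - len(ids))


encoded = encode_formula("C6H8O6")
print(len(encoded))   # 300
print(encoded[:9])    # [190, 178, 2, 182, 3, 1, 2, 191, 192]
```

<p>A fixed length of 300 leaves headroom over the observed maximum of under 270 characters, so every sequence can be stacked into one integer tensor without truncation.</p>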
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Image2InChI</p>
<ul>
<li><strong>Encoder</strong>: Improved SwinTransformer (Hierarchical Vision Transformer).</li>
<li><strong>Decoder</strong>: Transformer Decoder with patch embedding.</li>
<li><strong>Fusion</strong>: A novel &ldquo;feature fusion network with attention&rdquo; integrates the visual tokens with the sequence generation process.</li>
<li><strong>Framework</strong>: PyTorch 1.8.1.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>InChI Acc</strong>: Exact match accuracy of the predicted InChI string (Reported: 99.8%).</li>
<li><strong>MCS Acc</strong>: Maximum Common Substructure accuracy (structural similarity) (Reported: 94.8%).</li>
<li><strong>LCS Acc</strong>: Longest Common Subsequence accuracy (string similarity) (Reported: 96.2%).</li>
<li><strong>Morgan FP</strong>: Morgan Fingerprint similarity (Reported: 94.1%).</li>
</ul>
<h3 id="hardware">Hardware</h3>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Specification</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>GPU</strong></td>
          <td>NVIDIA Tesla P100 (16GB VRAM)</td>
      </tr>
      <tr>
          <td><strong>Platform</strong></td>
          <td>MatPool cloud platform</td>
      </tr>
      <tr>
          <td><strong>CPU</strong></td>
          <td>Intel Xeon Gold 6271</td>
      </tr>
      <tr>
          <td><strong>RAM</strong></td>
          <td>32GB System Memory</td>
      </tr>
      <tr>
          <td><strong>Driver</strong></td>
          <td>NVIDIA-SMI 440.100</td>
      </tr>
      <tr>
          <td><strong>OS</strong></td>
          <td>Ubuntu 18.04</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, D., Xu, X., Pan, J., Gao, W., &amp; Zhang, S. (2024). Image2InChI: Automated Molecular Optical Image Recognition. <em>Journal of Chemical Information and Modeling</em>, 64(9), 3640-3649. <a href="https://doi.org/10.1021/acs.jcim.3c02082">https://doi.org/10.1021/acs.jcim.3c02082</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling (JCIM) 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.kaggle.com/c/bms-molecular-translation">BMS Dataset (Kaggle)</a></li>
</ul>
<p><strong>Note</strong>: These notes are based on the Abstract and Supporting Information files only.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{li2024image2inchi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Image2InChI: Automated Molecular Optical Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Da-zhou and Xu, Xin and Pan, Jia-heng and Gao, Wei and Zhang, Shi-rui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{64}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{9}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{3640--3649}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{American Chemical Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.3c02082}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Enhanced DECIMER for Hand-Drawn Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/decimer-hand-drawn/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/decimer-hand-drawn/</guid><description>An improved encoder-decoder model (EfficientNetV2 + Transformer) converts hand-drawn chemical structures into SMILES strings using synthetic training data.</description><content:encoded><![CDATA[<h2 id="method-contribution-architectural-optimization">Method Contribution: Architectural Optimization</h2>
<p>This is a <strong>Method</strong> paper. It proposes an enhanced neural network architecture (EfficientNetV2 + Transformer) specifically designed to solve the problem of recognizing hand-drawn chemical structures. The primary contribution is architectural optimization and a data-driven training strategy, validated through ablation studies (comparing encoders) and benchmarked against existing rule-based and deep learning tools.</p>
<h2 id="motivation-digitizing-dark-chemical-data">Motivation: Digitizing &ldquo;Dark&rdquo; Chemical Data</h2>
<p>Chemical information in legacy laboratory notebooks and modern tablet-based inputs often exists as hand-drawn sketches.</p>
<ul>
<li><strong>Gap:</strong> Existing Optical Chemical Structure Recognition (OCSR) tools (particularly rule-based ones) lack robustness and fail when images have variability in style, line thickness, or noise.</li>
<li><strong>Need:</strong> Automated tools are needed to digitize this &ldquo;dark data&rdquo;, both to preserve it and to make it machine-readable and searchable.</li>
</ul>
<h2 id="core-innovation-decoder-only-design-and-synthetic-scaling">Core Innovation: Decoder-Only Design and Synthetic Scaling</h2>
<p>The core novelty is the <strong>architectural enhancement</strong> and <strong>synthetic training strategy</strong>:</p>
<ol>
<li><strong>Decoder-Only Transformer:</strong> Using only the decoder part of the Transformer (instead of a full encoder-decoder Transformer) improved average accuracy across OCSR benchmarks from 61.28% to 69.27% (Table 3 in the paper).</li>
<li><strong>EfficientNetV2 Integration:</strong> Replacing standard CNNs or EfficientNetV1 with <strong>EfficientNetV2-M</strong> provided better feature extraction and roughly 2x faster training than EfficientNetV1-B7.</li>
<li><strong>Scale of Synthetic Data:</strong> The authors demonstrate that scaling synthetic training data (up to 152 million images generated by RanDepict) directly correlates with improved generalization to real-world hand-drawn images, without ever training on real hand-drawn data.</li>
</ol>
<h2 id="experimental-setup-ablation-and-real-world-baselines">Experimental Setup: Ablation and Real-World Baselines</h2>
<ul>
<li><strong>Model Selection (Ablation):</strong> Tested three architectures (EfficientNetV2-M + Full Transformer, EfficientNetV1-B7 + Decoder-only, EfficientNetV2-M + Decoder-only) on standard benchmarks (JPO, CLEF, USPTO, UOB).</li>
<li><strong>Data Scaling:</strong> Trained the best model on four progressively larger datasets (from 4M to 152M images) to measure performance gains.</li>
<li><strong>Real-World Benchmarking:</strong> Validated the final model on the <strong>DECIMER Hand-drawn dataset</strong> (5,088 real images drawn by volunteers) and compared it against nine other tools (OSRA, MolVec, Img2Mol, MolScribe, etc.).</li>
</ul>
<h2 id="results-and-conclusions-strong-accuracy-on-hand-drawn-scans">Results and Conclusions: Strong Accuracy on Hand-Drawn Scans</h2>
<ul>
<li><strong>Strong Performance:</strong> The final DECIMER model achieved <strong>99.72% valid predictions</strong> and <strong>73.25% exact accuracy</strong> on the hand-drawn benchmark. The next best non-DECIMER tool was MolGrapher at 10.81% accuracy, followed by MolScribe at 7.65%.</li>
<li><strong>Robustness:</strong> Deep learning methods outperform rule-based methods (which scored 3% or less accuracy) on hand-drawn data.</li>
<li><strong>Data Saturation:</strong> Quadrupling the dataset from 38M to 152M images yielded only marginal gains (about 3 percentage points in accuracy), suggesting current synthetic data strategies may be hitting a plateau.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER Image Transformer (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official TensorFlow implementation</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10781330">Model Weights (Zenodo)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Pre-trained hand-drawn model weights</td>
      </tr>
      <tr>
          <td><a href="https://pypi.org/project/decimer/">DECIMER PyPI Package</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Installable Python package</td>
      </tr>
      <tr>
          <td><a href="https://github.com/OBrink/RanDepict">RanDepict (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Synthetic hand-drawn image generation toolkit</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The model was trained entirely on <strong>synthetic data</strong> generated using the <a href="https://github.com/OBrink/RanDepict">RanDepict</a> toolkit. No real hand-drawn images were used for training.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Source</th>
          <th>Molecules</th>
          <th>Total Images</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1</td>
          <td>ChEMBL</td>
          <td>2,187,669</td>
          <td>4,375,338</td>
          <td>1 augmented + 1 clean per molecule</td>
      </tr>
      <tr>
          <td>2</td>
          <td>ChEMBL</td>
          <td>2,187,669</td>
          <td>13,126,014</td>
          <td>2 augmented + 4 clean per molecule</td>
      </tr>
      <tr>
          <td>3</td>
          <td>PubChem</td>
          <td>9,510,000</td>
          <td>38,040,000</td>
          <td>1 augmented + 3 clean per molecule</td>
      </tr>
      <tr>
          <td>4</td>
          <td>PubChem</td>
          <td>38,040,000</td>
          <td><strong>152,160,000</strong></td>
          <td>1 augmented + 3 clean per molecule</td>
      </tr>
  </tbody>
</table>
<p>A separate <strong>model selection</strong> experiment used a 1,024,000-molecule subset of ChEMBL to compare the three architectures (Table 1 in the paper). The <strong>DECIMER Hand-Drawn</strong> evaluation dataset consists of 5,088 real hand-drawn images from 23 volunteers.</p>
<p><strong>Preprocessing:</strong></p>
<ul>
<li><a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings are restricted to fewer than 300 characters.</li>
<li>Images resized to $512 \times 512$.</li>
<li>Images generated with and without &ldquo;hand-drawn style&rdquo; augmentations.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization:</strong> SMILES split by heavy atoms, brackets, bond symbols, and special characters. Start <code>&lt;start&gt;</code> and end <code>&lt;end&gt;</code> tokens added; padded with <code>&lt;pad&gt;</code>.</li>
<li><strong>Optimization:</strong> Adam optimizer with a custom learning rate schedule (as specified in the original Transformer paper). A dropout rate of 0.1 was used.</li>
<li><strong>Loss Function:</strong> Trained using focal loss to address class imbalance for rare tokens. The focal loss formulation reduces the relative loss for well-classified examples:
$$
\text{FL}(p_{\text{t}}) = -\alpha_{\text{t}} (1 - p_{\text{t}})^\gamma \log(p_{\text{t}})
$$</li>
<li><strong>Augmentations:</strong> RanDepict applied synthetic distortions to mimic handwriting (wobbly lines, variable thickness, etc.).</li>
</ul>
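<p>To make the focal loss formula concrete, the following NumPy sketch evaluates it for per-token class probabilities. It is an illustration only: the notes do not state the values of $\gamma$ or $\alpha_{\text{t}}$ used in training, so <code>gamma=2.0</code> and a single scalar <code>alpha=0.25</code> are placeholder assumptions, and the function name <code>focal_loss</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import numpy as np

def focal_loss(probs, targets, gamma=2.0, alpha=0.25):
    """Mean focal loss for per-token class probabilities.

    probs:   array of shape (n_tokens, n_classes), rows summing to 1
    targets: integer array of shape (n_tokens,) with the true class ids
    """
    probs = np.asarray(probs, dtype=np.float64)
    targets = np.asarray(targets)
    # p_t: probability the model assigned to the correct token.
    p_t = probs[np.arange(len(targets)), targets]
    p_t = np.clip(p_t, 1e-12, 1.0)   # avoid log(0)
    # FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)
    return float(np.mean(-alpha * (1.0 - p_t) ** gamma * np.log(p_t)))

confident = focal_loss([[0.9, 0.1]], [0])   # well-classified token
uncertain = focal_loss([[0.1, 0.9]], [0])   # misclassified token
print(confident, uncertain)
</code></pre></div>
<p>With <code>gamma=0.0</code> and <code>alpha=1.0</code> the expression reduces to ordinary cross-entropy, which is a convenient sanity check; larger <code>gamma</code> shrinks the contribution of tokens the model already predicts with high probability.</p>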
<h3 id="models">Models</h3>
<p>The final architecture (Model 3) pairs a CNN image encoder with a Transformer decoder:</p>
<ul>
<li><strong>Encoder:</strong> <strong>EfficientNetV2-M</strong> (pretrained ImageNet backbone).
<ul>
<li>Input: $512 \times 512 \times 3$ image.</li>
<li>Output Features: $16 \times 16 \times 512$ (reshaped to sequence length 256, dimension 512).</li>
<li><em>Note:</em> The final fully connected layer of the CNN is removed.</li>
</ul>
</li>
<li><strong>Decoder:</strong> <strong>Transformer (Decoder-only)</strong>.
<ul>
<li>Layers: 6</li>
<li>Attention Heads: 8</li>
<li>Embedding Dimension: 512</li>
</ul>
</li>
<li><strong>Output:</strong> Predicted SMILES string token by token.</li>
</ul>
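<p>The reshaping of the encoder output into a token sequence can be illustrated with NumPy. This is a sketch with random stand-in features; row-major flattening of the spatial grid is an assumption, as the notes do not specify the ordering.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the encoder output of one image: a 16 x 16 grid of 512-d features.
feature_map = rng.normal(size=(16, 16, 512)).astype(np.float32)

# Flatten the two spatial axes into a single sequence axis (row-major order).
sequence = feature_map.reshape(-1, 512)
print(sequence.shape)

# Under this ordering, sequence position k corresponds to grid cell divmod(k, 16).
row, col = divmod(37, 16)
same = np.array_equal(sequence[37], feature_map[row, col])
</code></pre></div>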
<h3 id="evaluation">Evaluation</h3>
<p>Metrics used for evaluation:</p>
<ol>
<li><strong>Valid Predictions (%):</strong> Percentage of outputs that are syntactically valid SMILES.</li>
<li><strong>Exact Match Accuracy (%):</strong> Canonical SMILES string identity.</li>
<li><strong>Tanimoto Similarity:</strong> Fingerprint similarity (PubChem fingerprints) between ground truth and prediction.</li>
</ol>
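<p>For reference, the Tanimoto similarity in metric 3 is the size of the intersection of the two fingerprints&rsquo; on-bits divided by the size of their union. The sketch below uses toy bit sets; computing real PubChem fingerprints requires a cheminformatics toolkit and is not shown. The function name and the convention for two empty fingerprints are choices made here.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def tanimoto(bits_a, bits_b):
    """Tanimoto similarity between fingerprints given as sets of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a.union(b)
    if not union:
        return 1.0   # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)

# Two shared bits out of five distinct bits overall.
score = tanimoto({1, 4, 7, 9}, {1, 4, 8})
print(score)
</code></pre></div>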
<p><strong>Data Scaling Results (Hand-Drawn Dataset, Table 4 in the paper):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Training Images</th>
          <th>Valid Predictions</th>
          <th>Exact Accuracy</th>
          <th>Tanimoto</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1 (ChEMBL)</td>
          <td>4,375,338</td>
          <td>96.21%</td>
          <td>5.09%</td>
          <td>0.490</td>
      </tr>
      <tr>
          <td>2 (ChEMBL)</td>
          <td>13,126,014</td>
          <td>97.41%</td>
          <td>26.08%</td>
          <td>0.690</td>
      </tr>
      <tr>
          <td>3 (PubChem)</td>
          <td>38,040,000</td>
          <td>99.67%</td>
          <td>70.34%</td>
          <td>0.939</td>
      </tr>
      <tr>
          <td>4 (PubChem)</td>
          <td>152,160,000</td>
          <td>99.72%</td>
          <td>73.25%</td>
          <td>0.942</td>
      </tr>
  </tbody>
</table>
<p><strong>Comparison with Other Tools (Hand-Drawn Dataset, Table 5 in the paper):</strong></p>
<table>
  <thead>
      <tr>
          <th>OCSR Tool</th>
          <th>Method</th>
          <th>Valid Predictions</th>
          <th>Exact Accuracy</th>
          <th>Tanimoto</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>DECIMER (Ours)</strong></td>
          <td>Deep Learning</td>
          <td><strong>99.72%</strong></td>
          <td><strong>73.25%</strong></td>
          <td><strong>0.94</strong></td>
      </tr>
      <tr>
          <td>DECIMER.ai</td>
          <td>Deep Learning</td>
          <td>96.07%</td>
          <td>26.98%</td>
          <td>0.69</td>
      </tr>
      <tr>
          <td>MolGrapher</td>
          <td>Deep Learning</td>
          <td>99.94%</td>
          <td>10.81%</td>
          <td>0.51</td>
      </tr>
      <tr>
          <td>MolScribe</td>
          <td>Deep Learning</td>
          <td>95.66%</td>
          <td>7.65%</td>
          <td>0.59</td>
      </tr>
      <tr>
          <td>Img2Mol</td>
          <td>Deep Learning</td>
          <td>98.96%</td>
          <td>5.25%</td>
          <td>0.52</td>
      </tr>
      <tr>
          <td>SwinOCSR</td>
          <td>Deep Learning</td>
          <td>97.37%</td>
          <td>5.11%</td>
          <td>0.64</td>
      </tr>
      <tr>
          <td>ChemGrapher</td>
          <td>Deep Learning</td>
          <td>69.56%</td>
          <td>N/A</td>
          <td>0.09</td>
      </tr>
      <tr>
          <td>Imago</td>
          <td>Rule-based</td>
          <td>43.14%</td>
          <td>2.99%</td>
          <td>0.22</td>
      </tr>
      <tr>
          <td>MolVec</td>
          <td>Rule-based</td>
          <td>71.86%</td>
          <td>1.30%</td>
          <td>0.23</td>
      </tr>
      <tr>
          <td>OSRA</td>
          <td>Rule-based</td>
          <td>54.66%</td>
          <td>0.57%</td>
          <td>0.17</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute:</strong> Google Cloud TPU v4-128 pod slice.</li>
<li><strong>Training Time:</strong>
<ul>
<li>EfficientNetV2-M model trained ~2x faster than EfficientNetV1-B7.</li>
<li>Average training time per epoch: 34 minutes (for Model 3 on 1M dataset subset).</li>
</ul>
</li>
<li><strong>Epochs:</strong> Models trained for 25 epochs.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Brinkhaus, H. O., Zielesny, A., &amp; Steinbeck, C. (2024). Advancements in hand-drawn chemical structure recognition through an enhanced DECIMER architecture. <em>Journal of Cheminformatics</em>, 16(1), 78. <a href="https://doi.org/10.1186/s13321-024-00872-7">https://doi.org/10.1186/s13321-024-00872-7</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://pypi.org/project/decimer/">PyPI Package</a></li>
<li><a href="https://doi.org/10.5281/zenodo.10781330">Model Weights (Zenodo)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanAdvancementsHanddrawnChemical2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Advancements in Hand-Drawn Chemical Structure Recognition through an Enhanced {{DECIMER}} Architecture}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Brinkhaus, Henning Otto and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2024</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{78}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00872-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Dual-Path Global Awareness Transformer (DGAT) for OCSR</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/dgat/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/dgat/</guid><description>A Transformer-based OCSR model introducing dual-path modules (CGFE and SDGLA) to improve global context awareness and complex motif recognition.</description><content:encoded><![CDATA[<h2 id="contribution-type-deep-learning-method-for-ocsr">Contribution Type: Deep Learning Method for OCSR</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>The classification is based on the proposal of a novel deep learning architecture (DGAT) designed to address specific limitations in existing Optical Chemical Structure Recognition (OCSR) systems. The contribution is validated through benchmarking against external baselines (DeepOCSR, DECIMER, SwinOCSR) and ablation studies that isolate the impact of the new modules.</p>
<h2 id="motivation-addressing-global-context-loss">Motivation: Addressing Global Context Loss</h2>
<p>Existing multimodal fusion methods for OCSR suffer from limited awareness of global context.</p>
<ul>
<li><strong>Problem</strong>: Models often generate erroneous sequences when processing complex motifs, such as rings or long chains, due to a disconnect between local feature extraction and global structural understanding.</li>
<li><strong>Gap</strong>: Current architectures struggle to capture the &ldquo;fine-grained differences between global and local features,&rdquo; leading to topological errors.</li>
<li><strong>Practical Need</strong>: Accurate translation of chemical images to machine-readable sequences (SMILES/SELFIES) is critical for materials science and AI-guided chemical research.</li>
</ul>
<h2 id="core-innovation-dual-path-global-awareness-transformer">Core Innovation: Dual-Path Global Awareness Transformer</h2>
<p>The authors propose the <strong>Dual-Path Global Awareness Transformer (DGAT)</strong>, which redesigns the decoder with two novel mechanisms to better handle global context:</p>
<ol>
<li>
<p><strong>Cascaded Global Feature Enhancement (CGFE)</strong>: This module bridges cross-modal gaps by emphasizing global context. It concatenates global visual features with sequence features and processes them through a Cross-Modal Assimilation MLP and an Adaptive Alignment MLP to align multimodal representations. The feature enhancement conceptually computes:</p>
<p>$$ f_{\text{enhanced}} = \text{MLP}_{\text{align}}(\text{MLP}_{\text{assimilate}}([f_{\text{global}}, f_{\text{seq}}])) $$</p>
</li>
<li>
<p><strong>Sparse Differential Global-Local Attention (SDGLA)</strong>: A module that dynamically captures fine-grained differences between global and local features. It uses sequence features (embedded with global info) as queries, while utilizing local and global visual features as keys/values in parallel attention heads to generate initial multimodal features.</p>
</li>
</ol>
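<p>The CGFE formula in item 1 can be traced with a small NumPy sketch. Everything beyond the formula itself is an assumption made for illustration: the toy dimensions, the two-layer ReLU MLPs without biases, the random weights, and the choice to concatenate the global vector with every sequence position along the feature axis.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import numpy as np

rng = np.random.default_rng(0)

def mlp(x, w1, w2):
    """Two-layer perceptron with a ReLU in between (biases omitted for brevity)."""
    return np.maximum(x @ w1, 0.0) @ w2

seq_len, d_model = 5, 8                          # toy sizes, not from the paper
f_seq = rng.normal(size=(seq_len, d_model))      # sequence (token) features
f_global = rng.normal(size=(d_model,))           # pooled global visual feature

# [f_global, f_seq]: attach the same global vector to every sequence position.
tiled = np.broadcast_to(f_global, (seq_len, d_model))
fused = np.concatenate([tiled, f_seq], axis=-1)  # shape (seq_len, 2 * d_model)

# Hypothetical weights for the two MLPs named in the formula.
w_assim = [rng.normal(size=(2 * d_model, d_model)), rng.normal(size=(d_model, d_model))]
w_align = [rng.normal(size=(d_model, d_model)), rng.normal(size=(d_model, d_model))]

# f_enhanced = MLP_align(MLP_assimilate([f_global, f_seq]))
f_enhanced = mlp(mlp(fused, *w_assim), *w_align)
print(f_enhanced.shape)
</code></pre></div>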
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The model was evaluated on a newly constructed dataset and compared against five major baselines.</p>
<ul>
<li><strong>Baselines</strong>: DeepOCSR, DECIMER 1.0, DECIMER V2, SwinOCSR, and MPOCSR.</li>
<li><strong>Ablation Studies</strong>:
<ul>
<li><strong>Layer Depth</strong>: Tested Transformer depths from 1 to 5 layers; 3 layers proved optimal for balancing gradient flow and parameter sufficiency.</li>
<li><strong>Beam Size</strong>: Tested inference beam sizes 1-5; size 3 achieved the best balance between search depth and redundancy.</li>
<li><strong>Module Contribution</strong>: Removing CGFE lowers structural similarity (Tanimoto), which supports the need for pre-fusion alignment.</li>
</ul>
</li>
<li><strong>Robustness Analysis</strong>: Performance broken down by molecule complexity (atom count, ring count, bond count).</li>
<li><strong>Chirality Validation</strong>: Qualitative analysis of attention maps on chiral molecules to verify the model learns stereochemical cues implicitly.</li>
</ul>
<h2 id="results-and-conclusions">Results and Conclusions</h2>
<ul>
<li><strong>Performance Over Baselines</strong>: DGAT outperformed the MPOCSR baseline across all metrics:
<ul>
<li><strong>BLEU-4</strong>: 84.0% (+5.3 percentage points)</li>
<li><strong>ROUGE</strong>: 90.8% (+1.9 percentage points)</li>
<li><strong>Tanimoto Similarity</strong>: 98.8% (+1.2 percentage points)</li>
<li><strong>Exact Match Accuracy</strong>: 54.6% (+10.9% over SwinOCSR)</li>
</ul>
</li>
<li><strong>Chiral Recognition</strong>: The model implicitly recognizes chiral centers (e.g., generating <code>[C@@H1]</code> tokens correctly) based on 2D wedge cues without direct stereochemical supervision.</li>
<li><strong>Limitations</strong>: Performance drops for extreme cases, such as molecules with 4+ rings or 4+ double/triple bonds, due to dataset imbalance. The model still hallucinates branches in highly complex topologies.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is primarily drawn from PubChem and augmented to improve robustness.</p>
<ul>
<li><strong>Augmentation Strategy</strong>: Each sequence generates three images with random rendering parameters.
<ul>
<li><strong>Rotation</strong>: 0, 90, 180, 270, or random [0, 360)</li>
<li><strong>Bond Width</strong>: 1, 2, or 3 pixels</li>
<li><strong>Bond Offset</strong>: Sampled from 0.08-0.18 (inherited from Image2SMILES)</li>
<li><strong>CoordGen</strong>: Enabled with 20% probability</li>
</ul>
</li>
<li><strong>Evaluation Set</strong>: A newly constructed benchmark dataset was used for final reporting.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Training Configuration</strong>:
<ul>
<li><strong>Encoder LR</strong>: $5 \times 10^{-5}$ (Pretrained ResNet-101)</li>
<li><strong>Decoder LR</strong>: $1 \times 10^{-4}$ (Randomly initialized Transformer)</li>
<li><strong>Optimizer</strong>: Not stated explicitly; SGD or Adam is implied (the source mentions momentum 0.9 and weight decay 0.0001)</li>
<li><strong>Batch Size</strong>: 256</li>
</ul>
</li>
<li><strong>Inference</strong>:
<ul>
<li><strong>Beam Search</strong>: A beam size of <strong>3</strong> is used. Larger beam sizes (4-5) degraded BLEU/ROUGE scores due to increased redundancy.</li>
</ul>
</li>
</ul>
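<p>For readers unfamiliar with beam search, the sketch below shows the general technique with a beam size of 3 on a toy next-token table. It is not DGAT&rsquo;s decoder: the function <code>beam_search</code>, the <code>toy_model</code> table, and the token names are hypothetical, and no length normalization is applied.</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">import math

def beam_search(step_fn, start, end, beam_size=3, max_len=20):
    """Return the highest-scoring token sequence found with a fixed-width beam.

    step_fn(prefix) must return a dict mapping each candidate next token
    to its probability given the prefix.
    """
    beams = [([start], 0.0)]   # (tokens, summed log-probability)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == end:   # finished hypotheses are carried over unchanged
                candidates.append((tokens, score))
                continue
            for token, prob in step_fn(tokens).items():
                candidates.append((tokens + [token], score + math.log(prob)))
        # Keep only the `beam_size` best hypotheses.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(tokens[-1] == end for tokens, _ in beams):
            break
    return beams[0][0]

# Toy next-token distributions keyed by prefix ("S" = start, "E" = end).
TABLE = {
    ("S",): {"C": 0.6, "N": 0.4},
    ("S", "C"): {"O": 0.55, "E": 0.45},
    ("S", "C", "O"): {"E": 1.0},
    ("S", "N"): {"E": 1.0},
}

def toy_model(prefix):
    return TABLE[tuple(prefix)]

best = beam_search(toy_model, "S", "E", beam_size=3)
greedy = beam_search(toy_model, "S", "E", beam_size=1)
print(best, greedy)
</code></pre></div>
<p>In this toy table the greedy path (beam size 1) commits to the locally most likely first token and ends with a lower overall probability than the sequence recovered with a beam of 3.</p>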
<h3 id="models">Models</h3>
<ul>
<li><strong>Visual Encoder</strong>:
<ul>
<li><strong>Backbone</strong>: ResNet-101 initialized with ImageNet weights</li>
<li><strong>Structure</strong>: Convolutional layers preserved up to the final module. Classification head removed.</li>
<li><strong>Pooling</strong>: A $7 \times 7$ average pooling layer is used to extract global visual features.</li>
</ul>
</li>
<li><strong>Sequence Decoder</strong>:
<ul>
<li><strong>Architecture</strong>: Transformer-based with CGFE and SDGLA modules.</li>
<li><strong>Depth</strong>: 3 Transformer layers</li>
<li><strong>Dropout</strong>: Not utilized</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance is reported using sequence-level and structure-level metrics.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">DGAT Score</th>
          <th style="text-align: left">Baseline (MPOCSR)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>BLEU-4</strong></td>
          <td style="text-align: left"><strong>84.0%</strong></td>
          <td style="text-align: left">78.7%</td>
          <td style="text-align: left">Measures n-gram precision</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>ROUGE</strong></td>
          <td style="text-align: left"><strong>90.8%</strong></td>
          <td style="text-align: left">88.9%</td>
          <td style="text-align: left">Sequence recall metric</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Tanimoto</strong></td>
          <td style="text-align: left"><strong>98.8%</strong></td>
          <td style="text-align: left">97.6%</td>
          <td style="text-align: left">Structural similarity fingerprint</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Accuracy</strong></td>
          <td style="text-align: left"><strong>54.6%</strong></td>
          <td style="text-align: left">35.7%</td>
          <td style="text-align: left">Exact structure match rate</td>
      </tr>
  </tbody>
</table>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/Drwr97/DGAT">DGAT</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation with training and evaluation scripts</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Wang, R., Ji, Y., Li, Y., &amp; Lee, S.-T. (2025). Dual-Path Global Awareness Transformer for Optical Chemical Structure Recognition. <em>The Journal of Physical Chemistry Letters</em>, 16(50), 12787-12795. <a href="https://doi.org/10.1021/acs.jpclett.5c03057">https://doi.org/10.1021/acs.jpclett.5c03057</a></p>
<p><strong>Publication</strong>: The Journal of Physical Chemistry Letters 2025</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Drwr97/DGAT">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{wang2025dgat,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Dual-Path Global Awareness Transformer for Optical Chemical Structure Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Wang, Rui and Ji, Yujin and Li, Youyong and Lee, Shuit-Tong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{The Journal of Physical Chemistry Letters}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{50}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{12787--12795}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jpclett.5c03057}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DECIMER.ai: Optical Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-ai/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-ai/</guid><description>Open-source OCSR platform combining Mask R-CNN segmentation and Transformer recognition, trained on 450M+ synthetic images from RanDepict.</description><content:encoded><![CDATA[<h2 id="project-scope-and-contribution-type">Project Scope and Contribution Type</h2>
<p>This is primarily a <strong>Resource</strong> paper (Infrastructure Basis) with a significant <strong>Method</strong> component.</p>
<p>The primary contribution is DECIMER.ai, a fully open-source platform (web app and Python packages) for the entire chemical structure mining pipeline, filling a gap where most tools were proprietary or fragmented. It also contributes the RanDepict toolkit for massive synthetic data generation.</p>
<p>The secondary methodological contribution proposes and validates a specific deep learning architecture (EfficientNet-V2 encoder + Transformer decoder) that treats chemical structure recognition as an image-to-text translation task (SMILES generation).</p>
<h2 id="the-scarcity-of-machine-readable-chemical-data">The Scarcity of Machine-Readable Chemical Data</h2>
<p><strong>Data Scarcity</strong>: While the number of chemical publications is increasing, most chemical information is locked in non-machine-readable formats (images in PDFs) and is not available in public databases.</p>
<p><strong>Limitations of Existing Tools</strong>: Prior OCSR (Optical Chemical Structure Recognition) tools were largely rule-based (fragile to noise) or proprietary.</p>
<p><strong>Lack of Integration</strong>: There was no existing open-source system that combined segmentation (finding the molecule on a page), classification (confirming it is a molecule), and recognition (translating it to SMILES) into a single workflow.</p>
<h2 id="decimer-architecture-and-novel-image-to-smiles-approach">DECIMER Architecture and Novel Image-to-SMILES Approach</h2>
<p><strong>Comprehensive Workflow</strong>: It is the first open-source platform to integrate segmentation (Mask R-CNN), classification (EfficientNet), and recognition (Transformer) into a unified pipeline.</p>
<p><strong>Data-Driven Approach</strong>: Unlike tools such as MolScribe, which use intermediate graph representations and rules, DECIMER uses a purely data-driven &ldquo;image-to-SMILES&rdquo; translation approach without hard-coded chemical rules. The core recognition model operates as a sequence-to-sequence generator, formalizing the task as maximizing the conditional probability of a SMILES sequence given an image.</p>
<p><strong>Massive Synthetic Training</strong>: RanDepict was used to generate over 450 million synthetic images, covering diverse depiction styles and augmentations (including Markush structures), to train the model from scratch.</p>
<h2 id="benchmarking-and-evaluation-methodology">Benchmarking and Evaluation Methodology</h2>
<p><strong>Benchmarking</strong>: The system was tested against openly available tools (OSRA, MolVec, Imago, Img2Mol, SwinOCSR, MolScribe) on standard datasets: USPTO, UOB, CLEF, JPO, and a custom &ldquo;Hand-drawn&rdquo; dataset.</p>
<p><strong>Robustness Testing</strong>: Performance was evaluated on both clean images and images with added distortions (rotation, shearing) to test the fragility of rule-based systems vs. DECIMER.</p>
<p><strong>Markush Structure Analysis</strong>: Specific evaluation of the model&rsquo;s ability to interpret Markush structures (generic structures with R-groups).</p>
<p><strong>Comparison of Approaches</strong>: A direct comparison with MolScribe by training DECIMER on MolScribe&rsquo;s smaller training set to isolate the impact of architecture vs. data volume.</p>
<h2 id="performance-outcomes-and-key-findings">Performance Outcomes and Key Findings</h2>
<p><strong>Comparative Performance</strong>: DECIMER Image Transformer consistently produced average Tanimoto similarities above 0.95 on in-domain test data and achieved competitive or leading results across external benchmarks, with extremely low rates of catastrophic failure. Tanimoto similarity is calculated based on molecular fingerprints $A$ and $B$ as:
$$ T(A, B) = \frac{A \cdot B}{|A|^2 + |B|^2 - A \cdot B} $$</p>
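<p>As a minimal sketch of the formula above, the snippet below computes the Tanimoto similarity of two equal-length fingerprint vectors in plain Python. The function name <code>tanimoto</code> and the toy 8-bit fingerprints are illustrative assumptions; the paper computes the metric on PubChem fingerprints using cheminformatics tooling, not this code.</p>

```python
def tanimoto(a, b):
    """Tanimoto similarity T(A, B) = A.B / (|A|^2 + |B|^2 - A.B)."""
    if len(a) != len(b):
        raise ValueError("fingerprints must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a)
    norm_b = sum(y * y for y in b)
    denom = norm_a + norm_b - dot
    # Assumption: two all-zero fingerprints are treated as identical.
    return 1.0 if denom == 0 else dot / denom


# Toy binary fingerprints (illustrative only, not real PubChem fingerprints).
fp_pred = [1, 1, 0, 1, 0, 0, 1, 0]
fp_true = [1, 1, 0, 0, 0, 1, 1, 0]
score = tanimoto(fp_pred, fp_true)  # 3 shared bits / (4 + 4 - 3) = 0.6
```

<p>For binary fingerprints the dot product counts shared on-bits and each squared norm counts the on-bits of one fingerprint, so the expression reduces to intersection over union.</p>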
<p><strong>Data Volume Necessity</strong>: When trained on small datasets, MolScribe (graph/rule-based) outperformed DECIMER. DECIMER&rsquo;s performance advantage relies heavily on its massive training scale (&gt;450M images).</p>
<p><strong>Robustness</strong>: The model showed no performance degradation on distorted images, unlike rule-based legacy tools.</p>
<p><strong>Generalization</strong>: Despite having no hand-drawn images in the training set, the base model recognized 27% of hand-drawn structures perfectly (average Tanimoto 0.69), outperforming all alternative open tools. After fine-tuning with synthetic hand-drawn-like images from RanDepict, perfect predictions increased to 60% (average Tanimoto 0.89).</p>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/OBrink/DECIMER.ai">DECIMER.ai Web App</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Laravel-based web application for the full pipeline</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER Image Transformer</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Core OCSR Python package</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER-Image-Segmentation">DECIMER Image Segmentation</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Mask R-CNN segmentation for chemical structures in documents</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Iagea/DECIMER-Image-Classifier">DECIMER Image Classifier</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>EfficientNet-based chemical structure image classifier</td>
      </tr>
      <tr>
          <td><a href="https://github.com/OBrink/RanDepict">RanDepict</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Synthetic training data generation toolkit</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The models were trained on synthetic data generated from PubChem molecules.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Generation/Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_1</code></td>
          <td>~108M mols</td>
          <td>PubChem molecules (mass &lt; 1500 Da), processed with RanDepict (v1.0.5). Included image augmentations.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_2</code></td>
          <td>~126M mols</td>
          <td>Included Markush structures generated by pseudo-randomly replacing atoms with R-groups. Image size 299x299.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><code>pubchem_3</code></td>
          <td>&gt;453M images</td>
          <td>Re-depicted <code>pubchem_2</code> molecules at <strong>512x512</strong> resolution. Used RanDepict v1.0.8.</td>
      </tr>
      <tr>
          <td><strong>Test</strong></td>
          <td>In-domain</td>
          <td>250,000</td>
          <td>Held-out set generated similarly to training data.</td>
      </tr>
      <tr>
          <td><strong>Benchmark</strong></td>
          <td>External</td>
          <td>Various</td>
          <td>USPTO (5719), UOB (5740), CLEF (992), JPO (450), Indigo (50k), Hand-drawn (5088).</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Generation</strong>:</p>
<ul>
<li><strong>Tool</strong>: RanDepict (uses CDK, RDKit, Indigo, PIKAChU)</li>
<li><strong>Augmentations</strong>: Rotation, shearing, noise, pixelation, curved arrows, text labels</li>
<li><strong>Format</strong>: Data saved as TFRecord files for TPU training</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>SMILES Tokenization</strong>: Regex-based splitting (atoms, brackets, bonds). Added <code>&lt;start&gt;</code>, <code>&lt;end&gt;</code>, and padded with <code>&lt;pad&gt;</code>. <code>&lt;unk&gt;</code> used for unknown tokens.</li>
<li><strong>Markush Token Handling</strong>: To avoid ambiguity, digits following &lsquo;R&rsquo; (e.g., R1) were replaced with unique non-digit characters during training to distinguish them from ring-closure numbers.</li>
<li><strong>Image Augmentation Pipeline</strong>: Custom RanDepict features (v1.1.4) were used to simulate &ldquo;hand-drawn-like&rdquo; styles based on ChemPIX&rsquo;s implementation.</li>
</ul>
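<p>The tokenization and R-group handling described above can be sketched as follows. This is an illustration under stated assumptions, not DECIMER&rsquo;s actual implementation: the regular expression is a simplified SMILES pattern, and the digit-to-symbol table, <code>encode_r_groups</code>, and <code>tokenize</code> are hypothetical names introduced here.</p>

```python
import re

# Hypothetical digit-to-symbol table; the symbols are ones SMILES does not use.
R_INDEX_TABLE = str.maketrans("0123456789", '!?_^{}|;,"')


def encode_r_groups(smiles):
    """Rewrite R-group indices (R1, R12, ...) so they contain no digits."""
    return re.sub(
        r"R(\d+)",
        lambda m: "R" + m.group(1).translate(R_INDEX_TABLE),
        smiles,
    )


# Simplified SMILES token pattern (an assumption, not the paper's regex).
TOKEN_RE = re.compile(
    r"\[[^\]]+\]"            # bracket atoms, e.g. [nH] or [C@@H]
    r'|R[!?_^{}|;,"]+'       # encoded R-group labels
    r"|Br|Cl"                # two-letter atoms
    r"|%\d{2}"               # two-digit ring closures
    r"|[A-Za-z]"             # one-letter atoms
    r"|\d"                   # one-digit ring closures
    r"|[=#\-+\\/().:@*~$]"   # bonds, branches, and other symbols
)


def tokenize(smiles, max_len, vocab=None):
    """Split a SMILES string into tokens, add start/end markers, and pad."""
    # Note: findall silently skips characters the pattern does not cover.
    tokens = TOKEN_RE.findall(encode_r_groups(smiles))
    if vocab is not None:
        tokens = [t if t in vocab else "<unk>" for t in tokens]
    tokens = ["<start>"] + tokens + ["<end>"]
    if len(tokens) > max_len:
        raise ValueError("sequence longer than max_len")
    return tokens + ["<pad>"] * (max_len - len(tokens))


aspirin_tokens = tokenize("CC(=O)Oc1ccccc1", max_len=20)
markush_tokens = TOKEN_RE.findall(encode_r_groups("C1CC1R1"))
```

<p>In the second example the ring-closure digits stay as <code>1</code> tokens while the R-group index becomes a single <code>R?</code> token, which is the disambiguation the Markush handling is after.</p>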
<h3 id="models">Models</h3>
<p>The platform consists of three distinct models:</p>
<ol>
<li>
<p><strong>DECIMER Segmentation</strong>:</p>
<ul>
<li><strong>Architecture</strong>: Mask R-CNN (TensorFlow 2.10.0 implementation)</li>
<li><strong>Purpose</strong>: Detects and cuts chemical structures from full PDF pages</li>
</ul>
</li>
<li>
<p><strong>DECIMER Image Classifier</strong>:</p>
<ul>
<li><strong>Architecture</strong>: EfficientNet-V1-B0</li>
<li><strong>Input</strong>: 224x224 pixels</li>
<li><strong>Training</strong>: Fine-tuned on ~10.9M images (balanced chemical/non-chemical)</li>
<li><strong>Performance</strong>: AUC 0.99 on in-domain test set</li>
</ul>
</li>
<li>
<p><strong>DECIMER Image Transformer (OCSR Engine)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: EfficientNet-V2-M (CNN). Input size <strong>512x512</strong>. 52M parameters</li>
<li><strong>Decoder</strong>: Transformer with 4 encoder blocks and 4 decoder blocks, 8 attention heads, d_model=512, d_ff=2048. 59M parameters</li>
<li><strong>Total Params</strong>: ~111 Million</li>
</ul>
</li>
</ol>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Tanimoto Similarity (calculated on PubChem fingerprints of the predicted vs. ground truth SMILES)</li>
<li><strong>Secondary Metrics</strong>: Exact Match (Identity), BLEU score (for string similarity, esp. Markush)</li>
<li><strong>Failure Analysis</strong>: &ldquo;Catastrophic failure&rdquo; defined as Tanimoto similarity of 0 or invalid SMILES</li>
</ul>
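<p>The failure criterion above can be sketched as a small helper. It assumes each prediction has already been scored, with <code>None</code> standing in for an invalid SMILES string; the function name and that encoding are illustrative choices, not taken from the paper&rsquo;s code.</p>

```python
def catastrophic_failure_rate(scores):
    """Fraction of predictions that are invalid (None) or have Tanimoto 0."""
    if not scores:
        raise ValueError("no predictions to score")
    failures = sum(1 for s in scores if s is None or s == 0.0)
    return failures / len(scores)


# Toy scores: one invalid prediction and one with zero similarity.
example_scores = [1.0, 0.82, None, 0.0, 0.95]
failure_rate = catastrophic_failure_rate(example_scores)  # 2 of 5
```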
<h3 id="hardware">Hardware</h3>
<p>Training was performed on Google Cloud TPUs due to the massive dataset size.</p>
<ul>
<li><strong><code>pubchem_1</code>/<code>pubchem_2</code></strong>: Trained on TPU v3-32 pod slice</li>
<li><strong><code>pubchem_3</code> (Final Model)</strong>: Trained on <strong>TPU v3-256</strong> pod slice</li>
<li><strong>Training Time</strong>:
<ul>
<li>Data generation (512x512): ~2 weeks on cluster (20 threads, 36 cores)</li>
<li>Model Training (EffNet-V2-M): <strong>1 day and 7 hours per epoch</strong> on TPU v3-256</li>
</ul>
</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Brinkhaus, H. O., Agea, M. I., Zielesny, A., &amp; Steinbeck, C. (2023). DECIMER.ai: an open platform for automated optical chemical structure identification, segmentation and recognition in scientific publications. <em>Nature Communications</em>, 14(1), 5045. <a href="https://doi.org/10.1038/s41467-023-40782-0">https://doi.org/10.1038/s41467-023-40782-0</a></p>
<p><strong>Publication</strong>: Nature Communications 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://decimer.ai">Web Application</a></li>
<li><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER Image Transformer GitHub</a></li>
<li><a href="https://github.com/OBrink/RanDepict">RanDepict GitHub</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMERaiOpenPlatform2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{DECIMER.ai: an open platform for automated optical chemical structure identification, segmentation and recognition in scientific publications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Brinkhaus, Henning Otto and Agea, M. Isabel and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Nature Communications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{5045}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41467-023-40782-0}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemVLM: A Multimodal Large Language Model for Chemistry</title><link>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemvlm/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/llm-applications/chemvlm/</guid><description>A 26B parameter multimodal LLM for chemistry, combining InternViT-6B and ChemLLM-20B for molecular structure recognition, property prediction, and reasoning.</description><content:encoded><![CDATA[<h2 id="paper-classification-method-and-resource">Paper Classification: Method and Resource</h2>
<p>This paper is a combination of <strong>Method</strong> (primary) and <strong>Resource</strong> (secondary).</p>
<p>It is primarily a <strong>Method</strong> paper because it proposes <strong>ChemVLM</strong>, a novel multimodal architecture specifically tailored for the chemical domain, utilizing a &ldquo;ViT-MLP-LLM&rdquo; framework. The authors introduce a specific two-stage training strategy to align visual features with chemical text representations.</p>
<p>Secondarily, it is a <strong>Resource</strong> paper as it introduces a comprehensive suite of three new datasets: <strong>ChemOCR</strong>, <strong>MMCR-Bench</strong>, and <strong>MMChemBench</strong>, developed to rigorously evaluate multimodal capabilities in chemistry, covering OCR, reasoning, and property prediction.</p>
<h2 id="bridging-the-visual-gap-in-chemical-llms">Bridging the Visual Gap in Chemical LLMs</h2>
<p>The primary motivation is the limitation of existing models in handling the multimodal nature of chemistry.</p>
<ul>
<li><strong>Visual Data Gap</strong>: Chemical tasks heavily rely on visual information (molecular structures, reactions) which purely text-based chemical LLMs cannot process.</li>
<li><strong>Limitations of Generalist Models</strong>: General multimodal models (like GPT-4V or LLaVA) lack specialized chemical domain knowledge, leading to hallucinations or misinterpretations.</li>
<li><strong>Inadequacy of OCR Tools</strong>: Traditional <a href="/notes/chemistry/optical-structure-recognition/">chemical OCR</a> tools (like <a href="/notes/chemistry/optical-structure-recognition/image-to-graph/molscribe/">MolScribe</a>) excel at modality conversion (Image-to-<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) but fail at complex reasoning tasks.</li>
</ul>
<h2 id="domain-specific-data-curation-and-benchmarking">Domain-Specific Data Curation and Benchmarking</h2>
<ul>
<li><strong>Data-Driven Alignment</strong>: The underlying &ldquo;ViT-MLP-LLM&rdquo; framework is standard in multimodal modeling, paralleling architectures like LLaVA. The core innovation here is the rigorous creation of a bilingual multimodal dataset spanning hand-drawn molecules, reactions, and exam questions augmented with style transfers. The training data pipeline heavily relies on generating synthetic variance using tools like RanDepict and <a href="https://en.wikipedia.org/wiki/RDKit">RDKit</a> to introduce distortions, rotations, and handwritten styles, alongside GPT-4 generated prompts to ensure linguistic diversity.</li>
<li><strong>Model Integration</strong>: ChemVLM merges <strong>InternViT-6B</strong> (a large-scale vision transformer) with <strong><a href="/notes/chemistry/llm-applications/chemllm-chemical-large-language-model/">ChemLLM-20B</a></strong> (a chemical language model). Visual features $X_v$ are mapped into the linguistic embedding space via an MLP projector, producing aligned token sequences alongside text instructions $X_q$. The joint multimodal sequence is trained using standard autoregressive next-token prediction:
$$ \mathcal{L} = -\sum_{i} \log P(y_i \mid X_v, X_q, y_{&lt;i}) $$</li>
<li><strong>Three Custom Benchmarks</strong>: The authors introduce tailored benchmarks to assess distinct competencies:
<ul>
<li><strong>ChemOCR</strong>: For image-to-SMILES conversion.</li>
<li><strong>MMCR-Bench</strong>: College entrance exam questions testing complex logical reasoning.</li>
<li><strong>MMChemBench</strong>: For molecule captioning and zero-shot property prediction.</li>
</ul>
</li>
</ul>
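<p>To make the autoregressive objective concrete, the sketch below evaluates the loss for a toy two-token target, given the model&rsquo;s per-step output distributions. The probabilities are made up for illustration, and conditioning on $X_v$ and $X_q$ is assumed to be already reflected in those distributions.</p>

```python
import math


def sequence_nll(step_probs, target_ids):
    """Sum of -log P(y_i | X_v, X_q, y_<i) over the target tokens."""
    if len(step_probs) != len(target_ids):
        raise ValueError("need one distribution per target token")
    return -sum(math.log(probs[y]) for probs, y in zip(step_probs, target_ids))


# Toy vocabulary of 3 tokens; each row is the predicted distribution at one step.
step_probs = [
    [0.7, 0.2, 0.1],  # step 1: target token 0 has probability 0.7
    [0.1, 0.8, 0.1],  # step 2: target token 1 has probability 0.8
]
loss = sequence_nll(step_probs, target_ids=[0, 1])  # -(ln 0.7 + ln 0.8)
```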
<h2 id="evaluating-chemical-ocr-and-reasoning">Evaluating Chemical OCR and Reasoning</h2>
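<p>The authors benchmarked ChemVLM against both open-source (LLaVA, Qwen-VL, InternVL) and proprietary (GPT-4V) models across three primary domains (items 1-3), plus two supplementary checks on text-only reasoning and cross-domain generalization (items 4-5):</p>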
<ol>
<li><strong>Chemical OCR</strong>: Evaluated on 1,000 image-text pairs from ChemOCR. The primary metric is the <a href="https://en.wikipedia.org/wiki/Jaccard_index">Tanimoto similarity</a> between the Morgan fingerprints of the generated structure ($A$) and the ground-truth SMILES ($B$):
$$ T(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|} $$
They report both the average Tanimoto similarity and the strict exact-match rate (<code>Tanimoto@1.0</code>).</li>
<li><strong>Multimodal Chemical Reasoning (MMCR)</strong>: Tested on MMCR-Bench (1,000 exam questions), ScienceQA, and CMMU. Performance was scored based on accuracy for multiple-choice and fill-in-the-blank questions.</li>
<li><strong>Multimodal Molecule Understanding</strong>: Evaluated on MMChemBench for molecule captioning and property prediction.</li>
<li><strong>Text-Only Reasoning</strong>: Tested on SciBench, a text-only benchmark for university-level science, to ensure the model retains fundamental linguistic reasoning.</li>
<li><strong>Generalization</strong>: Tested on non-chemistry subjects within the CMMU framework (Biology, Physics, Math) to assess cross-domain competence.</li>
</ol>
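<p>The two ChemOCR numbers (average similarity and <code>Tanimoto@1.0</code>) can be sketched with the set form of the metric. The fingerprints here are toy sets of on-bit indices, and <code>summarize_ocr</code> is a hypothetical helper; the released evaluation scripts use RDKit and may differ in detail.</p>

```python
def tanimoto_sets(a, b):
    """Set-form Tanimoto: |A and B| / (|A| + |B| - |A and B|)."""
    a, b = set(a), set(b)
    shared = len(a.intersection(b))
    union = len(a) + len(b) - shared
    # Assumption: two empty fingerprints count as identical.
    return 1.0 if union == 0 else shared / union


def summarize_ocr(pairs):
    """Return (average Tanimoto, fraction of pairs with Tanimoto == 1.0)."""
    if not pairs:
        raise ValueError("no prediction pairs to score")
    scores = [tanimoto_sets(pred, truth) for pred, truth in pairs]
    average = sum(scores) / len(scores)
    exact_rate = sum(1 for s in scores if s == 1.0) / len(scores)
    return average, exact_rate


# Toy on-bit index sets standing in for Morgan fingerprints.
pairs = [
    ({1, 2, 3}, {1, 2, 3}),  # identical fingerprints: 1.0
    ({1, 2}, {2, 3}),        # one shared bit out of three: 1/3
    ({1}, {2}),              # nothing shared: 0.0
]
average, exact_rate = summarize_ocr(pairs)
```

<p>Reporting both numbers matters because a model can score a high average from many near misses while still producing few exactly correct structures.</p>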
<h2 id="performance-gains-and-existing-limitations">Performance Gains and Existing Limitations</h2>
<ul>
<li><strong>Multimodal Reasoning Leadership</strong>: ChemVLM achieved state-of-the-art results on MMCR-Bench (41.7%), surpassing generalist models like GPT-4V (40.1%). However, scoring for portions of these benchmarks relied heavily on an LLM-as-a-judge (the Qwen-max API), which can introduce bias as LLM evaluators often favor structural characteristics and verbosity produced by similar autoregressive models. Furthermore, the model was fine-tuned on 200,000 exam questions and tested on MMCR-Bench (also derived from Chinese college entrance exams). While the authors state the data was deduplicated, the potential for data leakage remains a significant unaddressed confounder.</li>
<li><strong>Superior Understanding</strong>: In molecule captioning and prediction, ChemVLM showed significant improvements over general baseline models, scoring 80.9% on prediction compared to GPT-4V&rsquo;s 38.6%. This is a natural consequence of testing a custom-trained model on domain-specific benchmarks.</li>
<li><strong>OCR Capabilities vs. Dedicated Tools</strong>: ChemVLM outperformed generalist MLLMs in chemical structure recognition, achieving an average Tanimoto similarity of 71.0% (vs. GPT-4V&rsquo;s 15.0%). However, it remains significantly inferior to pure structural OCR tools like MolScribe in strict modality conversion tasks, only achieving an exact structural match (<code>Tanimoto@1.0</code>) of 42.9% compared to MolScribe&rsquo;s 89.1%.</li>
<li><strong>Textual Retention and Generalization Claims</strong>: The authors claim the diverse training strategy imparts broad scientific reasoning, pointing to performance retention on non-chemistry subjects (Biology, Physics, Math) and strong results on the purely textual SciBench benchmark. However, this cross-domain generalization most likely stems from the underlying base model (ChemLLM-20B/InternLM2) or the inclusion of 1.3 million &ldquo;General&rdquo; visual QA pairs in their training blend, rather than emergent general scientific skills originating purely from learning chemistry representations.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training and evaluation data relied on a mix of open-source repositories and custom curation. Many of the curated datasets have been formally released by the authors on Hugging Face (<a href="https://huggingface.co/datasets/di-zhang-fdu/chemvlm-sft-datasets"><code>di-zhang-fdu/chemvlm-sft-datasets</code></a>).</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Source/Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong><a href="/notes/chemistry/optical-structure-recognition/hand-drawn/decimer-hand-drawn/">DECIMER HDM</a></strong></td>
          <td>7,000+ hand-drawn molecular images.</td>
      </tr>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong>MolScribe Data</strong></td>
          <td>Scanned/photographed images from literature.</td>
      </tr>
      <tr>
          <td><strong>Training (Molecule)</strong></td>
          <td><strong>Synthetic</strong></td>
          <td>Generated via ChemDraw, RDKit, and Indigo with style transfer (blurring, rotation, handwritten styles).</td>
      </tr>
      <tr>
          <td><strong>Training (Reaction)</strong></td>
          <td><strong>PEACE &amp; USPTO-50K</strong></td>
          <td>Inorganic and organic reaction schemes.</td>
      </tr>
      <tr>
          <td><strong>Training (Reasoning)</strong></td>
          <td><strong>Exam Questions</strong></td>
          <td>200,000 questions from OpenDataLab (Chinese education level). <a href="https://huggingface.co/collections/di-zhang-fdu/multi-corpus-datasets-for-chemllm">Available on Hugging Face</a>.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>ChemOCR</strong></td>
          <td>1,000 bilingual image-text pairs for SMILES recognition. Released via Google Drive link in repo.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>MMCR-Bench</strong></td>
          <td>1,000 multimodal chemistry exam questions. <strong>Requires emailing authors directly for access.</strong></td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>MMChemBench</strong></td>
          <td>Extension of <a href="/notes/chemistry/llm-applications/chembench-llm-chemistry-evaluation/">ChemBench</a> for captioning and property prediction. Released via Google Drive link in repo.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>: Images were augmented using <strong>RanDepict</strong> for style variation. Text data (SMILES) was validated and cleaned. Prompts were diversified using GPT-4 to generate different linguistic styles.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: &ldquo;ViT-MLP-LLM&rdquo; structure.
<ul>
<li><strong>Vision Encoder</strong>: InternViT-6B, processing images at $448 \times 448$ resolution. Images are segmented into tiles (max 12).</li>
<li><strong>Projector</strong>: Multi-Layer Perceptron (MLP) initialized randomly to map visual features to text embedding space.</li>
<li><strong>LLM</strong>: ChemLLM-20B, a domain-specific model.</li>
</ul>
</li>
<li><strong>Training Strategy</strong>: Two-stage supervised fine-tuning.
<ol>
<li><strong>Modal Alignment</strong>: Freeze LLM and base Vision Encoder weights. Train only the randomly initialized MLP projector and LoRA layers (rank 32) of the Vision Encoder. Uses diverse multimodal data.</li>
<li><strong>Supervised Fine-Tuning (SFT)</strong>: Keep LLM and Vision Encoder base weights frozen, but add LoRA (rank 16) to the LLM and retain LoRA (rank 32) on the Vision Encoder. The MLP projector is fully trained. Data includes specialized chemistry and general corpora.</li>
</ol>
</li>
<li><strong>Optimization</strong>:
<ul>
<li>Optimizer: AdamW</li>
<li>Context Length: 2048 tokens</li>
<li>Chat Template: InternLM2 dialogue schema</li>
</ul>
</li>
</ul>
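<p>The two-stage schedule above can be summarized as a small configuration sketch that records which parameter groups are trainable in each stage. The class, field, and group names are hypothetical; the actual training runs on the InternVL codebase, whose configuration format is not reproduced here.</p>

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageConfig:
    name: str
    train_projector: bool
    vision_lora_rank: int  # 0 means no LoRA adapters on the vision encoder
    llm_lora_rank: int     # 0 means no LoRA adapters on the LLM
    # Base weights of both backbones stay frozen in both stages.
    train_vision_base: bool = False
    train_llm_base: bool = False


STAGES = [
    StageConfig("modal_alignment", train_projector=True,
                vision_lora_rank=32, llm_lora_rank=0),
    StageConfig("supervised_fine_tuning", train_projector=True,
                vision_lora_rank=32, llm_lora_rank=16),
]


def trainable_groups(stage):
    """List the parameter groups that receive gradients in a stage."""
    groups = []
    if stage.train_projector:
        groups.append("mlp_projector")
    if stage.vision_lora_rank > 0:
        groups.append(f"vision_lora_r{stage.vision_lora_rank}")
    if stage.llm_lora_rank > 0:
        groups.append(f"llm_lora_r{stage.llm_lora_rank}")
    if stage.train_vision_base:
        groups.append("vision_base")
    if stage.train_llm_base:
        groups.append("llm_base")
    return groups


stage_one = trainable_groups(STAGES[0])
stage_two = trainable_groups(STAGES[1])
```

<p>Written this way, the only difference between the stages is the rank-16 LoRA adapter added to the LLM in the second stage.</p>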
<h3 id="models">Models</h3>
<ul>
<li><strong>ChemVLM-26B</strong>: The primary model released. It combines the 6B parameter vision encoder and the 20B parameter language model. Weights are fully available at <a href="https://huggingface.co/AI4Chem/ChemVLM-26B-1-2"><code>AI4Chem/ChemVLM-26B-1-2</code></a>. An 8B version is also available.</li>
<li><strong>Baselines</strong>: Comparisons were made against <strong>GPT-4V</strong>, <strong>Qwen-VL-Chat</strong>, <strong>LLaVA-v1.5-13B</strong>, <strong>InternVL-v1.5</strong>, and <strong>Yi-VL-Plus</strong>.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Performance was measured across three distinct task types. Exact <a href="https://github.com/lijunxian111/ChemVlm/tree/master/evaluation">evaluation scripts</a> have been released in the official repository.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Method</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto Similarity</strong></td>
          <td>ChemOCR</td>
          <td>Comparison of generated SMILES vs. ground truth using RDKit. Reports Average Similarity and <code>Tanimoto@1.0</code> (exact match).</td>
      </tr>
      <tr>
          <td><strong>Accuracy</strong></td>
          <td>MMCR (Reasoning)</td>
          <td>+1 point for correct multiple-choice/fill-in-the-blank; 0 otherwise. Scored via Qwen-max API prompting.</td>
      </tr>
      <tr>
          <td><strong>Prediction Score</strong></td>
          <td>Property Prediction</td>
          <td>Evaluated on MMChemBench subsets.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Compute</strong>: Training utilized <strong>16 NVIDIA A100 (80GB)</strong> GPUs.</li>
<li><strong>Configuration</strong>:
<ul>
<li>Batch size: 4 (per GPU, resulting in an effective global batch size of 256)</li>
<li>Gradient Accumulation: 4 iterations</li>
<li>Precision: <strong><a href="https://en.wikipedia.org/wiki/DeepSpeed">DeepSpeed</a> bfloat16 (bf16)</strong> with <strong>ZeRO-3</strong> offloading strategy</li>
<li>Framework: Training runs on the InternVL-v1.5 codebase rather than standalone scripts.</li>
</ul>
</li>
<li><strong>Inference Compute</strong>: Evaluating the 26B model requires at least one 80GB A100 GPU (with Flash Attention + bfloat16). The 8B variant requires a GPU with at least 48GB of VRAM.</li>
</ul>
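<p>The effective global batch size quoted above follows from the per-GPU batch size, the GPU count, and the gradient accumulation steps. A minimal check of that arithmetic, with a hypothetical helper name:</p>

```python
def effective_batch_size(per_gpu_batch, num_gpus, grad_accum_steps):
    """Samples contributing to one optimizer step across all devices."""
    return per_gpu_batch * num_gpus * grad_accum_steps


# Values from the configuration above: 4 per GPU, 16 GPUs, 4 accumulation steps.
global_batch = effective_batch_size(per_gpu_batch=4, num_gpus=16, grad_accum_steps=4)
```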
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemVLM-26B">ChemVLM-26B</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Original 26B model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/AI4Chem/ChemVLM-26B-1-2">ChemVLM-26B-1-2</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>Updated 26B model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/di-zhang-fdu/chemvlm-sft-datasets">chemvlm-sft-datasets</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>SFT training data (~51.7k rows)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/lijunxian111/ChemVlm">ChemVlm (GitHub)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training, evaluation, and inference code</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, J., et al. (2025). ChemVLM: Exploring the Power of Multimodal Large Language Models in Chemistry Area. <em>Proceedings of the AAAI Conference on Artificial Intelligence</em>, 39(1), 415-423. <a href="https://doi.org/10.1609/aaai.v39i1.32020">https://doi.org/10.1609/aaai.v39i1.32020</a></p>
<p><strong>Publication</strong>: AAAI 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{li2025chemvlm,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemVLM: Exploring the Power of Multimodal Large Language Models in Chemistry Area}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Li, Junxian and Zhang, Di and Wang, Xunzhi and Hao, Zeying and Lei, Jingdi and Tan, Qian and Zhou, Cai and Liu, Wei and Yang, Yaotian and Xiong, Xinrui and Wang, Weiyun and Chen, Zhe and Wang, Wenhai and Li, Wei and Su, Mao and Zhang, Shufei and Ouyang, Wanli and Li, Yuqiang and Zhou, Dongzhan}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the AAAI Conference on Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{39}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{415--423}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://doi.org/10.1609/aaai.v39i1.32020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1609/aaai.v39i1.32020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/lijunxian111/ChemVlm">Official Repository</a></li>
</ul>
]]></content:encoded></item><item><title>ChemReco: Hand-Drawn Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/chemreco/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/chemreco/</guid><description>A deep learning method using EfficientNet and Transformer to convert hand-drawn chemical structures into SMILES codes, achieving 96.9% accuracy.</description><content:encoded><![CDATA[<h2 id="research-contribution--classification">Research Contribution &amp; Classification</h2>
<p>This is a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a significant <strong>Resource ($\Psi_{\text{Resource}}$)</strong> component.</p>
<ul>
<li><strong>Method</strong>: The primary contribution is &ldquo;ChemReco,&rdquo; a specific deep learning pipeline (EfficientNet + Transformer) designed to solve the Optical Chemical Structure Recognition (OCSR) task for hand-drawn images. The authors conduct extensive ablation studies on architecture and data mixing ratios to validate performance.</li>
<li><strong>Resource</strong>: The authors explicitly state that &ldquo;the primary focus of this paper is constructing datasets&rdquo; due to the scarcity of hand-drawn molecular data. They introduce a comprehensive synthetic data generation pipeline involving RDKit modifications and image degradation to create training data.</li>
</ul>
<h2 id="motivation-digitizing-hand-drawn-chemical-sketches">Motivation: Digitizing Hand-Drawn Chemical Sketches</h2>
<p>Hand-drawing is the most intuitive method for chemists and students to record molecular structures. However, digitizing these drawings into machine-readable formats (like <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) usually requires time-consuming manual entry or specialized software.</p>
<ul>
<li><strong>Gap</strong>: Existing OCSR tools and rule-based methods often fail on hand-drawn sketches due to diverse writing styles, poor image quality, and the absence of labeled data.</li>
<li><strong>Application</strong>: Automated recognition enables efficient chemical research and allows for automatic grading in educational settings.</li>
</ul>
<h2 id="core-innovation-synthetic-pipeline-and-hybrid-architecture">Core Innovation: Synthetic Pipeline and Hybrid Architecture</h2>
<p>The paper introduces <strong>ChemReco</strong>, an end-to-end system for recognizing C-H-O structures. Key novelties include:</p>
<ol>
<li><strong>Synthetic Data Pipeline</strong>: A multi-stage generation method that modifies RDKit source code to randomize bond/angle parameters, followed by OpenCV-based augmentation, degradation, and background addition to simulate realistic hand-drawn artifacts.</li>
<li><strong>Architectural Choice</strong>: The specific application of <strong>EfficientNet</strong> (encoder) combined with a <strong>Transformer</strong> (decoder) for this domain, which the authors demonstrate outperforms the more common ResNet+LSTM baselines.</li>
<li><strong>Hybrid Training Strategy</strong>: The authors find that a mix of 90% synthetic and 10% real data yields the best performance, outperforming training on either dataset alone.</li>
</ol>
<h2 id="methodology--ablation-studies">Methodology &amp; Ablation Studies</h2>
<p>The authors performed a series of ablation studies and comparisons:</p>
<ul>
<li><strong>Synthesis Ablation</strong>: Evaluated the impact of each step in the generation pipeline (RDKit only $\rightarrow$ Augmentation $\rightarrow$ Degradation $\rightarrow$ Background) on validation loss and accuracy.</li>
<li><strong>Dataset Size Ablation</strong>: Tested model performance when trained on synthetic datasets ranging from 100,000 to 1,000,000 images.</li>
<li><strong>Real/Synthetic Ratio</strong>: Investigated the optimal mixing ratio of synthetic to real hand-drawn images (100:0, 90:10, 50:50, 10:90, 0:100), finding that the 90:10 ratio achieved 93.81% exact match, compared to 63.33% for synthetic-only and 65.83% for real-only.</li>
<li><strong>Architecture Comparison</strong>: Benchmarked four encoder-decoder combinations: ResNet vs. EfficientNet encoders paired with LSTM vs. Transformer decoders.</li>
<li><strong>Baseline Comparison</strong>: Compared results against ChemPix, a related study that uses a CNN+LSTM framework.</li>
</ul>
<h2 id="results--interpretations">Results &amp; Interpretations</h2>
<ul>
<li><strong>Best Performance</strong>: The EfficientNet + Transformer model trained on a 90:10 synthetic-to-real ratio achieved a <strong>96.90% Exact Match</strong> rate on the test set.</li>
<li><strong>Background Robustness</strong>: When training on synthetic data alone (no real images), the best accuracy on background-free test images was approximately 46% (using RDKit-aug-deg), while background test images reached approximately 53% (using RDKit-aug-bkg-deg). Adding random backgrounds during training helped prevent the model from overfitting to clean white backgrounds.</li>
<li><strong>Data Volume</strong>: Increasing the synthetic dataset size from 100k to 1M consistently improved accuracy (average exact match: 49.40% at 100k, 54.29% at 200k, 61.31% at 500k, 63.33% at 1M, all without real images in training).</li>
<li><strong>Encoder-Decoder Comparison</strong> (at 90:10 mix with 1M images):</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Encoder</th>
          <th style="text-align: left">Decoder</th>
          <th style="text-align: left">Avg. Exact Match (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">ResNet</td>
          <td style="text-align: left">LSTM</td>
          <td style="text-align: left">93.81</td>
      </tr>
      <tr>
          <td style="text-align: left">ResNet</td>
          <td style="text-align: left">Transformer</td>
          <td style="text-align: left">94.76</td>
      </tr>
      <tr>
          <td style="text-align: left">EfficientNet</td>
          <td style="text-align: left">LSTM</td>
          <td style="text-align: left">96.31</td>
      </tr>
      <tr>
          <td style="text-align: left">EfficientNet</td>
          <td style="text-align: left">Transformer</td>
          <td style="text-align: left"><strong>96.90</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Superiority over Baselines</strong>: The model outperformed the cited CNN+LSTM baseline from ChemPix (93% vs 76% on the ChemPix test set).</li>
</ul>
<h2 id="limitations">Limitations</h2>
<ul>
<li><strong>Restricted atom types</strong>: The system only handles molecules composed of carbon, hydrogen, and oxygen (C-H-O), excluding nitrogen, sulfur, halogens, and other heteroatoms commonly found in organic chemistry.</li>
<li><strong>Structural complexity</strong>: Only structures with at most one ring are supported. Complex multi-ring systems and fused ring structures are not covered.</li>
<li><strong>Dataset availability</strong>: The real hand-drawn dataset (2,598 images) is not publicly released and is only available upon request from the corresponding author.</li>
<li><strong>Future directions</strong>: The authors suggest expanding to more heteroatoms, complex ring structures, and applications in automated grading of chemistry exams.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/a-die/hdr-DeepLearning">hdr-DeepLearning</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation in PyTorch</td>
      </tr>
      <tr>
          <td style="text-align: left">Paper</td>
          <td style="text-align: left">Publication</td>
          <td style="text-align: left">CC-BY-4.0</td>
          <td style="text-align: left">Open access via Nature</td>
      </tr>
  </tbody>
</table>
<p>The real hand-drawn dataset (2,598 images) is available only upon request from the corresponding author and cannot be downloaded publicly. The synthetic data generation pipeline is described in detail; it relies on modified RDKit source code, which is included in the repository.</p>
<h3 id="data">Data</h3>
<p>The study utilizes a combination of collected SMILES data, real hand-drawn images, and generated synthetic images.</p>
<ul>
<li><strong>Source Data</strong>: SMILES codes collected from PubChem, ZINC, <a href="/notes/chemistry/datasets/gdb-11/">GDB-11</a>, and <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a>. Filtered for C, H, O atoms and max 1 ring.</li>
<li><strong>Real Dataset</strong>: 670 selected SMILES codes drawn by multiple volunteers, totaling <strong>2,598 images</strong>.</li>
<li><strong>Synthetic Dataset</strong>: Generated up to <strong>1,000,000 images</strong> using the pipeline below.</li>
<li><strong>Training Mix</strong>: The optimal training set used 1 million images with a <strong>90:10 ratio</strong> of synthetic to real images.</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Dataset Type</th>
          <th style="text-align: left">Source</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Real</strong></td>
          <td style="text-align: left">Volunteer Drawings</td>
          <td style="text-align: left">2,598 images</td>
          <td style="text-align: left">Used for mixed training and testing</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Synthetic</strong></td>
          <td style="text-align: left">Generated</td>
          <td style="text-align: left">100k - 1M</td>
          <td style="text-align: left">Generated via modified RDKit + OpenCV augmentation/degradation; optionally enhanced with Stable Diffusion</td>
      </tr>
  </tbody>
</table>
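<p>The 90:10 training mix described above can be sketched as a simple sampling routine. This is a minimal illustration, not the authors' code: the function name <code>build_mixed_training_set</code> and the toy data are hypothetical, and drawing the scarce real images with replacement is an assumption (the notes do not say how 2,598 real images fill 10% of a 1M-image set).</p>

```python
import random

def build_mixed_training_set(synthetic, real, total, synthetic_fraction=0.9, seed=0):
    """Return a shuffled list of `total` samples with the requested synthetic share."""
    rng = random.Random(seed)
    n_synth = round(total * synthetic_fraction)
    n_real = total - n_synth
    # Synthetic images are plentiful, so draw them without replacement.
    synth_part = rng.sample(synthetic, n_synth)
    # Assumption: real images are scarce, so they are drawn with replacement.
    real_part = rng.choices(real, k=n_real)
    mixed = synth_part + real_part
    rng.shuffle(mixed)
    return mixed

# Toy stand-ins for image records (hypothetical data).
synthetic = [("synthetic", i) for i in range(1000)]
real = [("real", i) for i in range(26)]
mixed = build_mixed_training_set(synthetic, real, total=1000)
n_real = sum(1 for kind, _ in mixed if kind == "real")
print(len(mixed), n_real)  # 1000 100
```

<p>Changing <code>synthetic_fraction</code> to 1.0, 0.5, 0.1, or 0.0 would reproduce the other ratios in the ablation, provided enough synthetic samples are available for the requested share.</p>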
<h3 id="algorithms">Algorithms</h3>
<p>The <strong>Synthetic Image Generation Pipeline</strong> is critical for reproduction:</p>
<ol>
<li><strong>RDKit Modification</strong>: Modify source code to introduce random keys, character width, length, and bond angles.</li>
<li><strong>Augmentation (OpenCV)</strong>: Apply sequence: Resize ($p=0.5$), Blur ($p=0.4$), Erode/Dilate ($p=0.2$), Distort ($p=0.8$), Flip ($p=0.5$), Affine ($p=0.7$).</li>
<li><strong>Degradation</strong>: Apply sequence: Salt+pepper noise ($p=0.1$), Contrast ($p=0.7$), Sharpness ($p=0.5$), Invert ($p=0.3$).</li>
<li><strong>Background Addition</strong>: Random backgrounds are augmented (Crop, Distort, Flip) and added to the molecular image to prevent background overfitting.</li>
<li><strong>Diffusion Enhancement</strong>: Stable Diffusion (v1-4) is used for image-to-image enhancement to better simulate hand-drawn styles (prompt: &ldquo;A pencil sketch of [Formula]&hellip; without charge distribution&rdquo;).</li>
</ol>
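<p>The augmentation and degradation stages above share one pattern: each transform is applied independently with its own probability. The sketch below shows that pattern only. The transforms are pure-Python stand-ins operating on nested lists of grayscale pixels (the paper uses OpenCV), the helper names are hypothetical, and the noise <code>amount</code> is an assumed value; the probabilities for Flip, Salt+pepper, and Invert are the ones listed above.</p>

```python
import random

def flip_horizontal(img):
    """Mirror each pixel row."""
    return [row[::-1] for row in img]

def invert(img):
    """Invert 8-bit grayscale intensities."""
    return [[255 - px for px in row] for row in img]

def salt_and_pepper(img, rng, amount=0.05):
    """Set a random fraction of pixels to pure black or white (amount is an assumption)."""
    out = [row[:] for row in img]
    for row in out:
        for x in range(len(row)):
            if rng.random() < amount:
                row[x] = rng.choice((0, 255))
    return out

def apply_pipeline(img, steps, rng):
    """Apply each (probability, transform) step in order; skip it with probability 1 - p."""
    for p, transform in steps:
        if rng.random() < p:
            img = transform(img)
    return img

rng = random.Random(0)
steps = [
    (0.5, flip_horizontal),                      # Flip, p=0.5 (augmentation stage)
    (0.1, lambda im: salt_and_pepper(im, rng)),  # Salt+pepper noise, p=0.1 (degradation stage)
    (0.3, invert),                               # Invert, p=0.3 (degradation stage)
]
image = [[255, 255, 0], [255, 0, 255]]
augmented = apply_pipeline(image, steps, rng)
```

<p>Because every step is gated independently, a single source image can yield many distinct variants across epochs, which is the property the synthetic pipeline relies on.</p>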
<h3 id="models">Models</h3>
<p>The system uses an encoder-decoder architecture:</p>
<ul>
<li><strong>Encoder</strong>: <strong>EfficientNet</strong> (pre-trained on ImageNet). The last layer is removed, and features are extracted into a NumPy array.</li>
<li><strong>Decoder</strong>: <strong>Transformer</strong>. Utilizes self-attention to generate the SMILES sequence. Chosen over LSTM for better handling of long-range dependencies.</li>
<li><strong>Output</strong>: Canonical SMILES string.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: <strong>Exact Match (EM)</strong>. A strict binary evaluation checking whether the complete generated SMILES perfectly replicates the target string.</li>
<li><strong>Other Metrics</strong>: <strong>Levenshtein Distance</strong> measures edit-level character proximity, while the <strong>Tanimoto coefficient</strong> evaluates structural similarity based on chemical fingerprints. Both were monitored during validation ablation runs.</li>
</ul>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Metric</th>
          <th style="text-align: left">Value</th>
          <th style="text-align: left">Baseline (CNN+LSTM)</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Exact Match</strong></td>
          <td style="text-align: left"><strong>96.90%</strong></td>
          <td style="text-align: left">76%</td>
          <td style="text-align: left">Tested on the provided test set</td>
      </tr>
  </tbody>
</table>
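<p>The three metrics named above are straightforward to compute. The following is a minimal sketch with hypothetical function names. It assumes SMILES strings have already been canonicalized, and it represents a fingerprint as a set of on-bit indices; in practice the fingerprints would come from a cheminformatics toolkit.</p>

```python
def exact_match(pred, target):
    """Strict string equality between predicted and target SMILES."""
    return pred == target

def levenshtein(a, b):
    """Minimum number of single-character insertions, deletions, or substitutions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]

def tanimoto(fp_a, fp_b):
    """Tanimoto coefficient of two fingerprints given as collections of on-bit indices."""
    a, b = set(fp_a), set(fp_b)
    union = a.union(b)
    if not union:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(union)

print(exact_match("CCO", "CCO"))        # True
print(levenshtein("CCO", "CCCO"))       # 1
print(tanimoto({1, 2, 3}, {2, 3, 4}))   # 0.5
```

<p>Exact Match is the strictest of the three: a prediction one character away from the target scores zero, while Levenshtein distance and Tanimoto similarity still give partial credit.</p>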
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CPU</strong>: Intel(R) Xeon(R) Gold 6130 (40 GB RAM).</li>
<li><strong>GPU</strong>: NVIDIA Tesla V100 (32 GB video memory).</li>
<li><strong>Framework</strong>: PyTorch 1.9.1.</li>
<li><strong>Training Configuration</strong>:
<ul>
<li>Optimizer: Adam (learning rate 1e-4).</li>
<li>Batch size: 32.</li>
<li>Epochs: 100.</li>
</ul>
</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Ouyang, H., Liu, W., Tao, J., et al. (2024). ChemReco: automated recognition of hand-drawn carbon-hydrogen-oxygen structures using deep learning. <em>Scientific Reports</em>, 14, 17126. <a href="https://doi.org/10.1038/s41598-024-67496-7">https://doi.org/10.1038/s41598-024-67496-7</a></p>
<p><strong>Publication</strong>: Scientific Reports 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/a-die/hdr-DeepLearning">Official Code Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{ouyangChemRecoAutomatedRecognition2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{ChemReco: Automated Recognition of Hand-Drawn Carbon--Hydrogen--Oxygen Structures Using Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Ouyang, Hengjie and Liu, Wei and Tao, Jiajun and Luo, Yanghong and Zhang, Wanjia and Zhou, Jiayu and Geng, Shuqi and Zhang, Chengpeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Scientific Reports}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{14}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{17126}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Nature Publishing Group}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1038/s41598-024-67496-7}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>AtomLenz: Atom-Level OCSR with Limited Supervision</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/atomlenz/</link><pubDate>Fri, 19 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/atomlenz/</guid><description>Weakly supervised OCSR framework combining object detection and graph construction to recognize chemical structures from hand-drawn images using SMILES.</description><content:encoded><![CDATA[<h2 id="dual-contribution-method-and-data-resource">Dual Contribution: Method and Data Resource</h2>
<p>The paper proposes an architecture (AtomLenz) and training framework (ProbKT* + Edit-Correction) to solve the problem of Optical Chemical Structure Recognition (OCSR) in data-sparse domains. It also releases a curated, relabeled dataset of hand-drawn molecules with atom-level bounding box annotations.</p>
<h2 id="overcoming-annotation-bottlenecks-in-ocsr">Overcoming Annotation Bottlenecks in OCSR</h2>
<p>Optical Chemical Structure Recognition (OCSR) is critical for digitizing chemical literature and lab notes. However, existing methods face three main limitations:</p>
<ol>
<li><strong>Generalization Limits:</strong> They struggle with sparse or stylistically unique domains, such as hand-drawn images, where massive datasets for pretraining are unavailable.</li>
<li><strong>Annotation Cost:</strong> &ldquo;Atom-level&rdquo; methods (which detect individual atoms and bonds) require expensive bounding box annotations, which are rarely available for real-world sketch data.</li>
<li><strong>Lack of Interpretability/Localization:</strong> Pure &ldquo;Image-to-SMILES&rdquo; models (like DECIMER) work well but fail to localize the atoms or bonds in the original image, limiting human-in-the-loop review and mechanistic interpretability.</li>
</ol>
<h2 id="atomlenz-probkt-and-graph-edit-correction">AtomLenz, ProbKT*, and Graph Edit-Correction</h2>
<p>The core contribution is <strong>AtomLenz</strong>, an OCSR framework that achieves atom-level entity detection using <strong>only SMILES supervision</strong> on target domains. The authors construct an explicit object detection pipeline using Faster R-CNN trained via a composite multi-task loss. The objective aims to optimize a multi-class log loss $L_{cls}$ for predicted class $\hat{c}$ and a regression loss $L_{reg}$ for predicted bounding box coordinates $\hat{b}$:</p>
<p>$$ \mathcal{L} = L_{cls}(c, \hat{c}) + L_{reg}(b, \hat{b}) $$</p>
<p>To bridge the gap between image inputs and the weakly supervised SMILES labels, the system leverages:</p>
<ul>
<li><strong>ProbKT* (Probabilistic Knowledge Transfer):</strong> Uses probabilistic logic and Hungarian matching to align predicted objects with the &ldquo;ground truth&rdquo; derived from the SMILES strings, enabling backpropagation without explicit bounding boxes.</li>
<li><strong>Graph Edit-Correction:</strong> Generates pseudo-labels by solving an optimization problem that finds the smallest edit on the predicted graph such that the corrected graph and the ground-truth SMILES graph become isomorphic, which forces fine-tuning on less frequent atom types. The combination of ProbKT* and Edit-Correction is abbreviated as <strong>EditKT*</strong>.</li>
<li><strong>ChemExpert:</strong> A chemically sound ensemble strategy that cascades predictions from multiple models (e.g., passing through DECIMER, then AtomLenz), halting at the first output that clears basic RDKit chemical validity checks.</li>
</ul>
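<p>The matching step in ProbKT* can be illustrated with a toy assignment problem: given class probabilities for each detected object and the multiset of atom labels implied by a SMILES string, find the one-to-one assignment with the lowest total cost. This sketch is not the ProbKT* implementation. It swaps the Hungarian algorithm for a brute-force search over permutations (same optimum, only feasible for a handful of objects), and the negative log-probability cost is an assumption chosen for illustration; the paper's probabilistic-logic objective is not reproduced here.</p>

```python
import math
from itertools import permutations

def match_predictions_to_labels(pred_probs, labels):
    """Return (assignment, cost), where assignment[i] is the label given to prediction i."""
    if len(pred_probs) != len(labels):
        raise ValueError("this sketch assumes equal counts of predictions and labels")
    best_cost, best_perm = math.inf, None
    for perm in permutations(labels):
        # Assumed cost: negative log-probability of each assigned label.
        cost = sum(-math.log(max(p.get(lbl, 0.0), 1e-12)) for p, lbl in zip(pred_probs, perm))
        if cost < best_cost:
            best_cost, best_perm = cost, perm
    return list(best_perm), best_cost

# Hypothetical detector outputs for three objects.
pred_probs = [
    {"C": 0.7, "O": 0.2, "N": 0.1},
    {"C": 0.3, "O": 0.6, "N": 0.1},
    {"C": 0.8, "O": 0.1, "N": 0.1},
]
labels = ["C", "C", "O"]  # atom multiset implied by the SMILES "CCO"
assignment, cost = match_predictions_to_labels(pred_probs, labels)
print(assignment)  # ['C', 'O', 'C']
```

<p>For realistic object counts, a polynomial-time solver such as SciPy's <code>scipy.optimize.linear_sum_assignment</code> would replace the permutation loop.</p>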
<h2 id="data-efficiency-and-domain-adaptation-experiments">Data Efficiency and Domain Adaptation Experiments</h2>
<p>The authors evaluated the model specifically on domain adaptation and sample efficiency, treating hand-drawn molecules as the primary low-data target distribution:</p>
<ul>
<li><strong>Pretraining:</strong> Initially trained on ~214k synthetic images from ChEMBL explicitly labeled with bounding boxes (generated via RDKit).</li>
<li><strong>Target Domain Adaptation:</strong> Fine-tuned on the Brinkhaus hand-drawn dataset (4,070 images) using purely SMILES supervision.</li>
<li><strong>Evaluation Sets:</strong>
<ul>
<li><strong>Hand-drawn test set</strong>: 1,018 images.</li>
<li><strong>ChemPix</strong>: 614 out-of-domain hand-drawn images.</li>
<li><strong>Atom Localization set</strong>: 1,000 synthetic images to evaluate precise bounding box capabilities.</li>
</ul>
</li>
<li><strong>Baselines:</strong> Compared against leading OCSR methods, including DECIMER (v2.2.0), Img2Mol, MolScribe, ChemGrapher, and OSRA.</li>
</ul>
<h2 id="state-of-the-art-ensembles-vs-standalone-limitations">State-of-the-Art Ensembles vs. Standalone Limitations</h2>
<ul>
<li><strong>SOTA Ensemble Performance:</strong> The <strong>ChemExpert</strong> module (combining AtomLenz and DECIMER) achieved state-of-the-art accuracy on both hand-drawn (63.5%) and ChemPix (51.8%) test sets.</li>
<li><strong>Data Efficiency under Bottleneck Regimes:</strong> AtomLenz effectively bypassed the massive data constraints of competing models. When all methods were retrained from scratch on the same 4,070-sample hand-drawn training set (enriched with atom-level annotations from EditKT*), AtomLenz achieved 33.8% exact accuracy, outperforming baselines like Img2Mol (0.0%), MolScribe (1.3%), and DECIMER (0.1%), illustrating its sample efficiency.</li>
<li><strong>Localization Success:</strong> The base framework achieved strong localization (mAP 0.801), a capability not provided by end-to-end transformers like DECIMER.</li>
<li><strong>Methodological Tradeoffs:</strong> While AtomLenz is highly sample efficient, its standalone accuracy after fine-tuning on the target domain (33.8%) falls well short of fine-tuned DECIMER (62.2%), which is trained on a much larger dataset. AtomLenz reaches state-of-the-art results primarily as part of the ChemExpert ensemble alongside DECIMER: errors from the two approaches tend to occur on different samples, so the models complement each other.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/molden/atomlenz">Official Repository (AtomLenz)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Complete pipeline for AtomLenz, ProbKT*, and Graph Edit-Correction.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/molden/atomlenz/tree/main/models">Pre-trained Models</a></td>
          <td style="text-align: left">Model</td>
          <td style="text-align: left">MIT</td>
          <td style="text-align: left">Downloadable weights for Faster R-CNN detection backbones.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://dx.doi.org/10.6084/m9.figshare.24599412">Hand-drawn Dataset (Brinkhaus)</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Images and SMILES used for target domain fine-tuning and evaluation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://dx.doi.org/10.6084/m9.figshare.24599172">Relabeled Hand-drawn Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">1,417 images with bounding box annotations generated via EditKT*.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://huggingface.co/spaces/moldenhof/atomlenz">AtomLenz Web Demo</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Interactive Hugging Face space for testing model inference.</td>
      </tr>
  </tbody>
</table>
<h3 id="data">Data</h3>
<p>The study utilizes a mix of large synthetic datasets and smaller curated hand-drawn datasets.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Pretraining</strong></td>
          <td>Synthetic ChEMBL</td>
          <td>~214,000</td>
          <td>Generated via RDKit/Indigo. Annotated with atoms, bonds, charges, stereocenters.</td>
      </tr>
      <tr>
          <td><strong>Fine-tuning</strong></td>
          <td>Hand-drawn (Brinkhaus)</td>
          <td>4,070</td>
          <td>Used for weakly supervised adaptation (SMILES only).</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Hand-drawn Test</td>
          <td>1,018</td>
          <td></td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>ChemPix</td>
          <td>614</td>
          <td>Out-of-distribution hand-drawn images.</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td>Atom Localization</td>
          <td>1,000</td>
          <td>Synthetic images with ground truth bounding boxes.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Molecular Graph Constructor (Algorithm 1):</strong> A rule-based system to assemble the graph from detected objects:
<ol>
<li><strong>Filtering:</strong> Removes overlapping atom boxes (IoU threshold).</li>
<li><strong>Node Creation:</strong> Merges overlapping charge and stereocenter objects with their corresponding atom objects.</li>
<li><strong>Edge Creation:</strong> Iterates over bond objects; if a bond overlaps with exactly two atoms, an edge is added. If &gt;2, it selects the most probable pair.</li>
<li><strong>Validation:</strong> Checks valency constraints; removes bonds iteratively if constraints are violated.</li>
</ol>
</li>
<li><strong>Weakly Supervised Training:</strong>
<ul>
<li><strong>ProbKT*:</strong> Uses Hungarian matching to align predicted objects with the &ldquo;ground truth&rdquo; implied by the SMILES string, allowing backpropagation without explicit boxes.</li>
<li><strong>Graph Edit-Correction:</strong> Finds the smallest edit on the predicted graph such that the corrected and true SMILES graphs become isomorphic, then uses the correction to generate pseudo-labels for retraining.</li>
</ul>
</li>
</ul>
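<p>The filtering and edge-creation steps of the graph constructor can be sketched as follows. This is a simplified illustration under stated assumptions, not Algorithm 1 itself: the IoU threshold of 0.5 is an assumed value, the &ldquo;most probable pair&rdquo; rule is approximated by taking the two highest-scoring atoms, and charge/stereocenter merging and valency validation are omitted. The dictionary layout and function names are hypothetical.</p>

```python
def iou(a, b):
    """Intersection-over-union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def filter_atoms(atoms, threshold=0.5):
    """Keep the highest-scoring atom among boxes whose IoU exceeds the threshold."""
    kept = []
    for atom in sorted(atoms, key=lambda a: a["score"], reverse=True):
        if all(iou(atom["box"], k["box"]) <= threshold for k in kept):
            kept.append(atom)
    return kept

def build_edges(atoms, bonds):
    """Add an edge for each bond box that overlaps two atoms."""
    edges = []
    for bond in bonds:
        touching = [i for i, atom in enumerate(atoms) if iou(atom["box"], bond["box"]) > 0]
        if len(touching) > 2:
            # Assumption: approximate the most probable pair by the two highest-scoring atoms.
            touching = sorted(touching, key=lambda i: atoms[i]["score"], reverse=True)[:2]
        if len(touching) == 2:
            edges.append((min(touching), max(touching), bond["label"]))
    return edges

detections = [
    {"label": "C", "score": 0.9, "box": (0, 0, 10, 10)},
    {"label": "C", "score": 0.6, "box": (1, 1, 11, 11)},   # near-duplicate of the first atom
    {"label": "O", "score": 0.8, "box": (30, 0, 40, 10)},
]
bonds = [{"label": "single", "box": (8, 3, 32, 7)}]
kept = filter_atoms(detections)
edges = build_edges(kept, bonds)
print([a["label"] for a in kept], edges)  # ['C', 'O'] [(0, 1, 'single')]
```

<p>The duplicate carbon box is suppressed because its IoU with the higher-scoring box is about 0.68, and the single bond box then connects the two surviving atoms.</p>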
<h3 id="models">Models</h3>
<ul>
<li><strong>Object Detection Backbone:</strong> <strong>Faster R-CNN</strong>.
<ul>
<li>Four distinct models are trained for different entity types: Atoms ($O^a$), Bonds ($O^b$), Charges ($O^c$), and Stereocenters ($O^s$).</li>
<li><strong>Loss Function:</strong> Multi-task loss combining Multi-class Log Loss ($L_{cls}$) and Regression Loss ($L_{reg}$).</li>
</ul>
</li>
<li><strong>ChemExpert:</strong> An ensemble wrapper that prioritizes models based on user preference (e.g., DECIMER first, then AtomLenz). It accepts the first prediction that passes RDKit chemical validity checks.</li>
</ul>
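<p>The ChemExpert cascade described above amounts to a first-valid-wins loop over an ordered list of models. The sketch below uses hypothetical stand-in predictors and a toy validity check so that it runs without dependencies; in a real setup the predictors would wrap DECIMER and AtomLenz, and the validity check would be RDKit-based (for instance, rejecting any string for which <code>Chem.MolFromSmiles</code> returns <code>None</code>).</p>

```python
def chem_expert(image, models, is_valid):
    """Return (model_name, smiles) for the first valid prediction, or (None, None)."""
    for name, predict in models:
        smiles = predict(image)
        if smiles is not None and is_valid(smiles):
            return name, smiles
    return None, None

# Hypothetical stand-ins for the real models.
def fake_decimer(image):
    return "C1CC"    # unclosed ring, rejected by the toy checker below

def fake_atomlenz(image):
    return "C1CC1"   # cyclopropane

def toy_is_valid(smiles):
    """Toy check only: each ring-closure digit must appear an even number of times."""
    return all(smiles.count(d) % 2 == 0 for d in "123456789")

models = [("DECIMER", fake_decimer), ("AtomLenz", fake_atomlenz)]
winner, smiles = chem_expert(None, models, toy_is_valid)
print(winner, smiles)  # AtomLenz C1CC1
```

<p>The order of <code>models</code> encodes the user preference: a later model is consulted only when every earlier one produces an invalid structure.</p>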
<h3 id="evaluation">Evaluation</h3>
<p>Primary metrics focused on structural correctness and localization accuracy.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (Hand-drawn)</th>
          <th>Baseline (DECIMER FT)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (T=1)</strong></td>
          <td>33.8% (AtomLenz+EditKT*)</td>
          <td>62.2%</td>
          <td>Exact ECFP6 fingerprint match.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto Sim.</strong></td>
          <td>0.484</td>
          <td>0.727</td>
          <td>Average similarity.</td>
      </tr>
      <tr>
          <td><strong>mAP</strong></td>
          <td>0.801</td>
          <td>N/A</td>
          <td>Localization accuracy (IoU 0.05-0.35).</td>
      </tr>
      <tr>
          <td><strong>Ensemble Acc.</strong></td>
          <td><strong>63.5%</strong></td>
          <td>62.2%</td>
          <td>ChemExpert (DECIMER + AtomLenz).</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute:</strong> Experiments utilized the Flemish Supercomputer Center (VSC) resources.</li>
<li><strong>Note:</strong> Specific GPU models (e.g., A100/V100) are not explicitly detailed in the text, but Faster R-CNN training is standard on consumer or enterprise GPUs.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Oldenhof, M., De Brouwer, E., Arany, Á., &amp; Moreau, Y. (2024). Atom-Level Optical Chemical Structure Recognition with Limited Supervision. In <em>Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)</em>, 2024.</p>
<p><strong>Publication venue/year</strong>: CVPR 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/molden/atomlenz">Official Repository</a></li>
<li><a href="https://dx.doi.org/10.6084/m9.figshare.24599412">Hand-drawn Dataset on Figshare</a></li>
</ul>
<p><strong>BibTeX</strong>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{oldenhofAtomLevelOpticalChemical2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Atom-Level Optical Chemical Structure Recognition with Limited Supervision}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Oldenhof, Martijn and De Brouwer, Edward and Arany, {\&#39;A}d{\&#39;a}m and Moreau, Yves}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2404.01743}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs.CV}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SwinOCSR: End-to-End Chemical OCR with Swin Transformers</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/swinocsr/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/swinocsr/</guid><description>Deep learning model using Swin Transformer and Focal Loss for OCSR, achieving 98.58% accuracy on synthetic benchmarks.</description><content:encoded><![CDATA[<h2 id="contribution-methodological-architecture-and-datasets">Contribution: Methodological Architecture and Datasets</h2>
<p>This is a <strong>Methodological Paper</strong> with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architecture (Swin Transformer backbone) and a specific loss function optimization (Focal Loss) for the task of Optical Chemical Structure Recognition (OCSR).</li>
<li><strong>Resource</strong>: It constructs a large-scale synthetic dataset of 5 million molecules, specifically designing it to cover complex cases like substituents and aromatic rings.</li>
</ul>
<h2 id="motivation-addressing-visual-context-and-data-imbalance">Motivation: Addressing Visual Context and Data Imbalance</h2>
<ul>
<li><strong>Problem</strong>: OCSR (converting images of chemical structures to <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>) is difficult due to complex chemical patterns and long sequences. Existing deep learning methods (often CNN-based) struggle to achieve satisfactory recognition rates.</li>
<li><strong>Technical Gap</strong>: Standard CNN backbones (like ResNet or EfficientNet) focus on local feature extraction and miss global dependencies required for interpreting complex molecular diagrams.</li>
<li><strong>Data Imbalance</strong>: Chemical strings suffer from severe class imbalance (e.g., &lsquo;C&rsquo; and &lsquo;H&rsquo; are frequent; &lsquo;Br&rsquo; or &lsquo;Cl&rsquo; are rare), which causes standard Cross Entropy loss to underperform.</li>
</ul>
<h2 id="core-innovation-swin-transformers-and-focal-loss">Core Innovation: Swin Transformers and Focal Loss</h2>
<ul>
<li><strong>Swin Transformer Backbone</strong>: SwinOCSR replaces the standard CNN backbone with a <strong>Swin Transformer</strong>, using shifted window attention to capture both local and global image features more effectively.</li>
<li><strong>Multi-label Focal Loss (MFL)</strong>: The paper adapts Focal Loss to OCSR, which the authors describe as the first explicit attempt to address token imbalance in this task. The loss down-weights tokens the model already predicts confidently, so training concentrates on hard and rare tokens in the &ldquo;long-tail&rdquo; distribution of chemical elements. The standard Focal Loss formulation is:
$$
\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)
$$
where $p_t$ is the predicted probability of the true class, $\alpha_t$ is a class-balancing weight, and $\gamma \ge 0$ controls how strongly well-classified examples are down-weighted.</li>
<li><strong>Structured Synthetic Dataset</strong>: The authors build a dataset explicitly balanced across four structural categories: Kekule rings, aromatic rings, and each of these combined with substituents.</li>
</ul>
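<p>To make the effect of the modulating factor $(1 - p_t)^\gamma$ concrete, the following minimal sketch compares the standard Focal Loss with plain cross-entropy for one easy and one hard prediction. The values <code>alpha_t = 0.25</code> and <code>gamma = 2.0</code> are illustrative assumptions (common defaults for Focal Loss), not settings reported for SwinOCSR.</p>

```python
import math


def cross_entropy(p_t: float) -> float:
    """Cross-entropy for the probability assigned to the true class."""
    return -math.log(p_t)


def focal_loss(p_t: float, alpha_t: float = 0.25, gamma: float = 2.0) -> float:
    """Standard focal loss: FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t).

    alpha_t and gamma are illustrative defaults, not values from the paper.
    """
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# Ratio FL / CE equals alpha_t * (1 - p_t)**gamma.
easy_ratio = focal_loss(0.9) / cross_entropy(0.9)  # confident, correct prediction
hard_ratio = focal_loss(0.1) / cross_entropy(0.1)  # poorly predicted true class
print(easy_ratio, hard_ratio)
```

<p>Under these assumed values the easy example keeps about 0.25% of its cross-entropy loss while the hard example keeps about 20%, which is the sense in which the loss shifts weight toward hard tokens.</p>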
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<ul>
<li><strong>Backbone Comparison</strong>: The authors benchmarked SwinOCSR against the backbones of leading competitors: ResNet-50 (used in Image2SMILES) and EfficientNet-B3 (used in DECIMER 1.0).</li>
<li><strong>Loss Function Ablation</strong>: They compared the performance of standard Cross Entropy (CE) loss against their proposed Multi-label Focal Loss (MFL).</li>
<li><strong>Category Stress Test</strong>: Performance was evaluated separately on molecules with/without substituents and with/without aromaticity to test robustness.</li>
<li><strong>Real-world Evaluation</strong>: The model was tested on 100 images manually extracted from the literature (with manually labeled SMILES), and separately on 100 CDK-generated images from those same SMILES, to measure the domain gap between synthetic and real-world data.</li>
</ul>
<h2 id="results-and-limitations">Results and Limitations</h2>
<ul>
<li><strong>Synthetic test set performance</strong>: With Multi-label Focal Loss (MFL), SwinOCSR achieved <strong>98.58% accuracy</strong> on the synthetic test set, compared to 97.36% with standard CE loss. Both ResNet-50 (89.17%) and EfficientNet-B3 (86.70%) backbones scored lower when using CE loss (Table 3).</li>
<li><strong>Handling of long sequences</strong>: The model maintained high accuracy (94.76%) even on very long DeepSMILES strings (76-100 characters), indicating effective global feature extraction.</li>
<li><strong>Per-category results</strong>: Performance was consistent across molecule categories: Category 1 (Kekule, 98.20%), Category 2 (Aromatic, 98.46%), Category 3 (Kekule + Substituents, 98.76%), Category 4 (Aromatic + Substituents, 98.89%). The model performed slightly better on molecules with substituents and aromatic rings.</li>
<li><strong>Domain shift</strong>: While performance on synthetic data was strong, accuracy dropped to <strong>25%</strong> on 100 real-world literature images. On 100 CDK-generated images of the same SMILES strings, accuracy was 94%, indicating that the gap stems from differences between CDK-rendered and real-world images. The authors attribute the drop to noise, low resolution, and drawing variations such as condensed structural formulas and abbreviations.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: The first 8.5 million structures from <strong>PubChem</strong> were downloaded, yielding ~6.9 million unique SMILES.</li>
<li><strong>Generation Pipeline</strong>:
<ul>
<li><strong>Tools</strong>: <strong>CDK</strong> (Chemistry Development Kit) for image rendering; <strong>RDKit</strong> for SMILES canonicalization.</li>
<li><strong>Augmentation</strong>: To ensure diversity, the dataset was split into 4 categories (1.25M each): (1) Kekule, (2) Aromatic, (3) Kekule + Substituents, (4) Aromatic + Substituents. Substituents were randomly added from a list of 224 common patent substituents.</li>
<li><strong>Preprocessing</strong>: Images were rendered as binary, resized to <strong>224x224</strong>, and replicated across 3 channels to mimic RGB input.</li>
</ul>
</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>4,500,000</td>
          <td>18:1:1 split (Train/Val/Test)</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>250,000</td>
          <td></td>
      </tr>
      <tr>
          <td>Test</td>
          <td>Synthetic (PubChem-derived)</td>
          <td>250,000</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Loss Function</strong>: <strong>Multi-label Focal Loss (MFL)</strong>. The single-label classification task was cast as multi-label to apply Focal Loss, using a sigmoid activation on logits.</li>
<li><strong>Optimization</strong>:
<ul>
<li><strong>Optimizer</strong>: <strong>Adam</strong> with initial learning rate <code>5e-4</code>.</li>
<li><strong>Schedulers</strong>: Cosine decay for the Swin Transformer backbone; Step decay for the Transformer encoder/decoder.</li>
<li><strong>Regularization</strong>: Dropout rate of <code>0.1</code>.</li>
</ul>
</li>
</ul>
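<p>One way to read the multi-label casting described above is sketched below: the single true token becomes a one-hot target, every vocabulary entry is scored independently through a sigmoid, and a binary focal term is summed over the vocabulary. This is a hedged illustration, not the authors&rsquo; exact formulation; the function name, the weighting of negative classes by <code>1 - alpha</code>, and the values of <code>alpha</code> and <code>gamma</code> are all assumptions.</p>

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def multilabel_focal_loss(logits, target_index, alpha=0.25, gamma=2.0):
    """Sum of per-class binary focal losses for one decoding step.

    Hypothetical sketch: alpha, gamma, and the (1 - alpha) weight on
    negative classes are assumptions, not values from the paper.
    """
    total = 0.0
    for k, z in enumerate(logits):
        p = sigmoid(z)
        if k == target_index:
            p_t, a_t = p, alpha              # positive label
        else:
            p_t, a_t = 1.0 - p, 1.0 - alpha  # negative label
        p_t = max(p_t, 1e-12)                # guard against log(0)
        total += -a_t * (1.0 - p_t) ** gamma * math.log(p_t)
    return total


# Three-token toy vocabulary; the true token is index 0.
confident = multilabel_focal_loss([4.0, -4.0, -4.0], target_index=0)
mistaken = multilabel_focal_loss([-4.0, 4.0, -4.0], target_index=0)
print(confident, mistaken)
```

<p>In this toy case the confident, correct prediction contributes almost nothing to the loss, while the mistaken one is penalized both for missing the true token and for firing on a wrong one.</p>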
<h3 id="models">Models</h3>
<ul>
<li><strong>Backbone (Encoder 1)</strong>: <strong>Swin Transformer</strong>.
<ul>
<li>Patch size: $4 \times 4$.</li>
<li>Linear embedding dimension: 192.</li>
<li>Structure: 4 stages with Swin Transformer Blocks (Window MSA + Shifted Window MSA).</li>
<li>Output: Flattened patch sequence $S_b$.</li>
</ul>
</li>
<li><strong>Transformer Encoder (Encoder 2)</strong>: 6 standard Transformer encoder layers. Uses Positional Embedding + Multi-Head Attention + MLP.</li>
<li><strong>Transformer Decoder</strong>: 6 standard Transformer decoder layers. Uses Masked Multi-Head Attention (to prevent look-ahead) + Multi-Head Attention (connecting to encoder output $S_e$).</li>
<li><strong>Tokenization</strong>: Output strings use the <strong>DeepSMILES</strong> format (syntactically more robust than SMILES). Vocabulary size: <strong>76 tokens</strong>, one per unique character found in the dataset. Embedding dimension: 256.</li>
</ul>
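<p>The backbone settings above imply a specific sequence of feature-map sizes. The sketch below works through that arithmetic, assuming the patch-merging layout of the original Swin Transformer (each later stage halves the grid side and doubles the channel width). Whether SwinOCSR&rsquo;s flattened sequence $S_b$ has exactly the final shape shown here is an assumption that follows from this layout, not a figure quoted from the paper.</p>

```python
def swin_stage_shapes(image_size=224, patch_size=4, embed_dim=192, num_stages=4):
    """Return (grid_side, num_tokens, channels) for each stage.

    Assumes patch merging before stages 2-4 halves the grid side and
    doubles the channel width, as in the original Swin Transformer.
    """
    side = image_size // patch_size
    dim = embed_dim
    shapes = []
    for stage in range(num_stages):
        if stage > 0:
            side //= 2
            dim *= 2
        shapes.append((side, side * side, dim))
    return shapes


shapes = swin_stage_shapes()
print(shapes)
# [(56, 3136, 192), (28, 784, 384), (14, 196, 768), (7, 49, 1536)]
```

<p>Under this assumption, a $224 \times 224$ image enters the first stage as 3,136 patch tokens and leaves the last stage as 49 tokens of width 1,536.</p>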
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Accuracy (Exact Match), Tanimoto Similarity (PubChem fingerprints), BLEU, ROUGE.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SwinOCSR (CE)</th>
          <th>SwinOCSR (MFL)</th>
          <th>ResNet-50 (CE)</th>
          <th>EfficientNet-B3 (CE)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td>97.36%</td>
          <td><strong>98.58%</strong></td>
          <td>89.17%</td>
          <td>86.70%</td>
      </tr>
      <tr>
          <td>Tanimoto</td>
          <td>99.65%</td>
          <td><strong>99.77%</strong></td>
          <td>98.79%</td>
          <td>98.46%</td>
      </tr>
      <tr>
          <td>BLEU</td>
          <td>99.46%</td>
          <td><strong>99.59%</strong></td>
          <td>98.62%</td>
          <td>98.37%</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>99.64%</td>
          <td><strong>99.78%</strong></td>
          <td>98.87%</td>
          <td>98.66%</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Trained on <strong>NVIDIA Tesla V100-PCIE</strong>.</li>
<li><strong>Training Duration</strong>: 30 epochs.</li>
<li><strong>Batch Size</strong>: 256 images ($224 \times 224$ pixels).</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/suanfaxiaohuo/SwinOCSR">SwinOCSR</a></td>
          <td>Code + Data</td>
          <td>Unknown</td>
          <td>Official implementation with dataset and trained models</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xu, Z., Li, J., Yang, Z. et al. (2022). SwinOCSR: end-to-end optical chemical structure recognition using a Swin Transformer. <em>Journal of Cheminformatics</em>, 14(41). <a href="https://doi.org/10.1186/s13321-022-00624-5">https://doi.org/10.1186/s13321-022-00624-5</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/suanfaxiaohuo/SwinOCSR">GitHub Repository</a></li>
</ul>
]]></content:encoded></item><item><title>String Representations for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/benchmarks/rajan-string-representations-2022/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/benchmarks/rajan-string-representations-2022/</guid><description>Ablation study comparing SMILES, DeepSMILES, SELFIES, and InChI for OCSR. SMILES achieves highest accuracy; SELFIES guarantees validity.</description><content:encoded><![CDATA[<h2 id="empirical-focus-and-resource-contributions">Empirical Focus and Resource Contributions</h2>
<p>This is an <strong>Empirical Paper</strong> ($\Psi_{\text{Empirical}}$) with a secondary contribution as a <strong>Resource Paper</strong> ($\Psi_{\text{Resource}}$).</p>
<p>It functions as a systematic ablation study, keeping the model architecture (EfficientNet-B3 + Transformer) constant while varying the target string representation (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, DeepSMILES, <a href="/notes/chemistry/molecular-representations/notations/selfies/">SELFIES</a>, <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>) to determine which format yields the best performance for Optical Chemical Structure Recognition (OCSR). It also contributes large-scale benchmarking datasets derived from ChEMBL and PubChem.</p>
<h2 id="the-syntax-challenge-in-chemical-image-recognition">The Syntax Challenge in Chemical Image Recognition</h2>
<p>Optical Chemical Structure Recognition (OCSR) is essential for extracting chemical information buried in scientific literature and patents. While deep learning offers a promising alternative to rule-based approaches, neural networks struggle with the syntax of standard chemical representations like SMILES. Specifically, the tokenization of SMILES strings (where ring closures and branches are marked by single characters potentially far apart in the sequence) creates learning difficulties for sequence-to-sequence models. Newer representations like DeepSMILES and SELFIES were developed to address these syntax issues, but their comparative performance in image-to-text tasks had not been rigorously benchmarked.</p>
<h2 id="isolating-string-representation-variables">Isolating String Representation Variables</h2>
<p>The core novelty is the <strong>comparative isolation of the string representation variable</strong> in an OCSR context. Previous approaches often selected a representation (usually SMILES) without validating if it was optimal for the learning task. This study specifically tests the hypothesis that syntax-robust representations (like SELFIES) improve deep learning performance compared to standard SMILES. It provides empirical evidence on the trade-off between <em>validity</em> (guaranteed by SELFIES) and <em>accuracy</em> (highest with SMILES).</p>
<h2 id="large-scale-image-to-text-translation-experiments">Large-Scale Image-to-Text Translation Experiments</h2>
<p>The authors performed a large-scale image-to-text translation experiment:</p>
<ul>
<li><strong>Task</strong>: Converting 2D chemical structure images into text strings.</li>
<li><strong>Data</strong>:
<ul>
<li><strong>ChEMBL</strong>: ~1.6M molecules, split into two datasets (with and without stereochemistry).</li>
<li><strong>PubChem</strong>: ~3M molecules, split similarly, to test performance scaling with data size.</li>
</ul>
</li>
<li><strong>Representations</strong>: The same chemical structures were converted into four formats: SMILES, DeepSMILES, SELFIES, and InChI.</li>
<li><strong>Metric</strong>: The models were evaluated on:
<ul>
<li><strong>Validity</strong>: Can the predicted string be decoded back to a molecule?</li>
<li><strong>Exact Match</strong>: Is the predicted string identical to the ground truth?</li>
<li><strong>Tanimoto Similarity</strong>: How chemically similar is the prediction to the ground truth (using PubChem fingerprints)? The similarity $\mathcal{T}$ between two molecular fingerprints $A$ and $B$ is calculated as:
$$ \mathcal{T}(A, B) = \frac{A \cdot B}{|A|^2 + |B|^2 - A \cdot B} $$</li>
</ul>
</li>
</ul>
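<p>For binary fingerprints, $A \cdot B$ is the number of bits set in both fingerprints and $|A|^2$ is the number of bits set in $A$, so the Tanimoto formula above reduces to shared bits over the union of bits. A minimal sketch, representing each fingerprint as a set of on-bit indices (the function name and the choice to return 1.0 for two empty fingerprints are assumptions for illustration):</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto similarity for binary fingerprints given as on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    shared = len(a.intersection(b))          # A . B
    denom = len(a) + len(b) - shared         # |A|^2 + |B|^2 - A . B
    # Convention assumed here: two empty fingerprints count as identical.
    return shared / denom if denom else 1.0


same = tanimoto({1, 2, 3, 4}, {1, 2, 3, 4})
partial = tanimoto({1, 2, 3, 4}, {3, 4, 5, 6})
disjoint = tanimoto({1, 2}, {3, 4})
print(same, partial, disjoint)
```

<p>Here the partially overlapping pair shares 2 of 6 distinct bits, giving a similarity of one third, while identical and disjoint fingerprints give 1 and 0.</p>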
<h2 id="comparative-performance-and-validity-trade-offs">Comparative Performance and Validity Trade-offs</h2>
<ul>
<li><strong>SMILES is the most accurate</strong>: Contrary to the hypothesis that syntax-robust formats would learn better, SMILES consistently achieved the highest exact match accuracy (up to 88.62% on PubChem data) and average Tanimoto similarity (0.98). This is likely due to SMILES having shorter string lengths and fewer unique tokens compared to SELFIES.</li>
<li><strong>SELFIES guarantees validity</strong>: While slightly less accurate in direct translation, SELFIES achieved 100% structural validity (every prediction could be decoded), whereas SMILES predictions occasionally contained syntax errors.</li>
<li><strong>InChI is unsuitable</strong>: InChI performed significantly worse (approx. 64% exact match) due to extreme maximum string lengths (up to 273 characters).</li>
<li><strong>Stereochemistry adds difficulty</strong>: Including stereochemistry reduced accuracy across all representations due to increased token count and visual complexity.</li>
<li><strong>Recommendation</strong>: Use SMILES for maximum accuracy; use SELFIES if generating valid structures is the priority (e.g., generative tasks).</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used curated subsets from ChEMBL and PubChem. Images were generated synthetically.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL (Dataset 1/2)</td>
          <td>~1.5M</td>
          <td>Filtered for MW &lt; 1500, specific elements (C,H,O,N,P,S,F,Cl,Br,I,Se,B).</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>PubChem (Dataset 3/4)</td>
          <td>~3.0M</td>
          <td>Same filtering rules, used to test scaling.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Test Split</td>
          <td>~120k - 250k</td>
          <td>Created using RDKit MaxMin algorithm to ensure chemical diversity.</td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation</strong>:</p>
<ul>
<li><strong>Tool</strong>: CDK Structure Diagram Generator (SDG).</li>
<li><strong>Specs</strong>: $300 \times 300$ pixels, rotated by random angles ($0-360^{\circ}$), saved as 8-bit PNG.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Tokenization Rules</strong> (Critical for replication):</p>
<ul>
<li><strong>SELFIES</strong>: Split at every <code>][</code> (e.g., <code>[C][N]</code> $\rightarrow$ <code>[C]</code>, <code>[N]</code>).</li>
<li><strong>SMILES / DeepSMILES</strong>: Regex-based splitting:
<ul>
<li>Every heavy atom (e.g., <code>C</code>, <code>N</code>).</li>
<li>Every bracket <code>(</code> and <code>)</code>.</li>
<li>Every bond symbol <code>=</code> and <code>#</code>.</li>
<li>Every single-digit number.</li>
<li>Everything inside square brackets <code>[]</code> is kept as a single token.</li>
</ul>
</li>
<li><strong>InChI</strong>: The prefix <code>InChI=1S/</code> was treated as a single token and removed during training, then re-added for evaluation.</li>
</ul>
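<p>The tokenization rules above can be sketched with a short regular expression. The pattern below is an assumption written to match the listed rules, not the regex from the authors&rsquo; repository: it keeps bracketed groups whole, treats <code>Cl</code> and <code>Br</code> as single atoms, and, as a fallback not stated in the rules, emits any other character as its own token.</p>

```python
import re

# Hypothetical pattern: bracketed groups, two-letter halogens, single-letter
# atoms (upper case or aromatic lower case), parentheses, bond symbols,
# single digits, then a catch-all so no character is silently dropped.
SMILES_TOKEN = re.compile(r"\[[^\]]+\]|Br|Cl|[A-Za-z]|[()]|[=#]|\d|.")


def tokenize_smiles(s):
    """Split a SMILES or DeepSMILES string into tokens."""
    return SMILES_TOKEN.findall(s)


def tokenize_selfies(s):
    """Split a SELFIES string at every bracket boundary."""
    if not s:
        return []
    # Drop the outer brackets, split on the inner boundaries, restore brackets.
    return ["[" + part + "]" for part in s.strip("[]").split("][")]


print(tokenize_smiles("CC(=O)Cl"))
print(tokenize_smiles("C[NH3+]"))
print(tokenize_selfies("[C][N]"))
```

<p>Putting the bracketed-group and halogen alternatives first matters: with a different ordering, <code>Cl</code> would be split into a carbon followed by a stray character.</p>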
<h3 id="models">Models</h3>
<p>The model follows the <strong>DECIMER</strong> architecture.</p>
<ul>
<li><strong>Encoder</strong>: EfficientNet-B3 (pre-trained with &ldquo;Noisy Student&rdquo; weights).
<ul>
<li>Output: Image feature vectors of shape $10 \times 10 \times 1536$.</li>
</ul>
</li>
<li><strong>Decoder</strong>: Transformer (similar to the &ldquo;Base&rdquo; model from <em>Attention Is All You Need</em>).
<ul>
<li>Layers: 4 encoder-decoder layers.</li>
<li>Attention Heads: 8.</li>
<li>Dimension ($d_{\text{model}}$): 512.</li>
<li>Feed-forward ($d_{\text{ff}}$): 2048.</li>
<li>Dropout: 10%.</li>
</ul>
</li>
<li><strong>Loss</strong>: Sparse categorical cross-entropy.</li>
<li><strong>Optimizer</strong>: Adam with custom learning rate scheduler.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics were calculated after converting all predictions back to standard SMILES.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Baseline (SMILES)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Identical Match</strong></td>
          <td>88.62% (PubChem)</td>
          <td>Strict character-for-character equality.</td>
      </tr>
      <tr>
          <td><strong>Valid Structure</strong></td>
          <td>99.78%</td>
          <td>SMILES had rare syntax errors; SELFIES achieved 100%.</td>
      </tr>
      <tr>
          <td><strong>Tanimoto (Avg)</strong></td>
          <td>0.98</td>
          <td>Calculated using PubChem fingerprints via CDK.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: Google Cloud TPUs (v3-8).</li>
<li><strong>Format</strong>: Data converted to TFRecords (128 image/text pairs per record) for TPU efficiency.</li>
<li><strong>Batch Size</strong>: 1024.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER_Short_Communication">DECIMER Short Communication</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Training and evaluation scripts (Python, Java)</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.5155037">Datasets on Zenodo</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>SMILES data and processing scripts</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Steinbeck, C., &amp; Zielesny, A. (2022). Performance of chemical structure string representations for chemical image recognition using transformers. <em>Digital Discovery</em>, 1(2), 84-90. <a href="https://doi.org/10.1039/D1DD00013F">https://doi.org/10.1039/D1DD00013F</a></p>
<p><strong>Publication</strong>: Digital Discovery 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://chemrxiv.org/doi/pdf/10.26434/chemrxiv-2021-7c9wf">ChemRxiv Preprint (PDF)</a></li>
<li><a href="https://github.com/Kohulan/DECIMER_Short_Communication">Official Code Repository</a></li>
<li><a href="https://doi.org/10.5281/zenodo.5155037">Data on Zenodo</a></li>
<li>Related work: <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer/">DECIMER</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-1.0/">DECIMER 1.0</a>, <a href="/notes/chemistry/optical-structure-recognition/image-to-sequence/img2smi/">IMG2SMI</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanPerformanceChemicalStructure2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Performance of Chemical Structure String Representations for Chemical Image Recognition Using Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Steinbeck, Christoph and Zielesny, Achim}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Digital Discovery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{84--90}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1039/D1DD00013F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>One Strike, You're Out: Detecting Markush Structures</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/jurriaans-markush-detection-2023/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/markush/jurriaans-markush-detection-2023/</guid><description>Patch-based CNN method for detecting Markush structures in chemical images, addressing low signal-to-noise ratios in OCSR.</description><content:encoded><![CDATA[<h2 id="methodology-and-classification">Methodology and Classification</h2>
<p>This is a <strong>Method</strong> paper (Classification: $\Psi_{\text{Method}}$).</p>
<p>It proposes a patch-based classification pipeline to solve a technical failure mode in Optical Chemical Structure Recognition (OCSR). Distinct rhetorical indicators include a baseline comparison (CNN vs. traditional ORB), ablation studies (architecture, pretraining), and a focus on evaluating the filtering efficacy against a known failure mode.</p>
<h2 id="the-markush-structure-challenge">The Markush Structure Challenge</h2>
<p><strong>The Problem</strong>: Optical Chemical Structure Recognition (OCSR) tools convert 2D images of molecules into machine-readable formats. These tools struggle with &ldquo;Markush structures,&rdquo; generic structural templates used frequently in patents that contain variables rather than specific atoms (e.g., $R$, $X$, $Y$).</p>
<p><strong>The Gap</strong>: Markush structures are difficult to detect because they often appear as small indicators (a single &ldquo;R&rdquo; or variable) within a large image, resulting in a very low Signal-to-Noise Ratio (SNR). Existing OCSR research pipelines typically bypass this by manually excluding these structures from their datasets.</p>
<p><strong>The Goal</strong>: To build an automated filter that can identify images containing Markush structures so they can be removed from OCSR pipelines, improving overall database quality without requiring manual data curation.</p>
<h2 id="patch-based-classification-pipeline">Patch-Based Classification Pipeline</h2>
<p>The core technical contribution is an end-to-end deep learning pipeline tailored for low-SNR chemical images where standard global resizing or cropping fails due to large variations in image resolution and pixel scales.</p>
<ul>
<li><strong>Patch Generation</strong>: The system slices input images into overlapping patches generated from two offset grids, ensuring that variables falling on boundaries are fully captured in at least one crop.</li>
<li><strong>Targeted Annotation</strong>: The labels rely on pixel-level bounding boxes around Markush indicators, minimizing the noise that would otherwise overwhelm a full-image classification attempt.</li>
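<li><strong>Inference Strategy</strong>: During inference, the query image is split into patches, each patch is classified individually, and the patch scores $x_1, \dots, x_n$ are aggregated into an image-level score with a maximum rule, $X = \max_{i=1}^{n} \{ x_i \}$, so a single high-scoring patch is enough to flag the image.</li>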
<li><strong>Evaluation</strong>: Provides the first systematic comparison between fixed-feature extraction (ORB + XGBoost) and end-to-end deep learning for this specific domain.</li>
</ul>
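<p>The maximum aggregation rule is simple enough to state in a few lines. In this sketch, patch scores are assumed to be per-patch probabilities of the Markush class, and the decision threshold of 0.5 is an illustrative assumption rather than a value from the paper.</p>

```python
def image_score(patch_scores):
    """Image-level Markush score: the maximum over patch-level scores."""
    return max(patch_scores)


def is_markush(patch_scores, threshold=0.5):
    """Flag the image if any single patch scores at or above the threshold.

    The threshold is a hypothetical choice for illustration.
    """
    return image_score(patch_scores) >= threshold


flagged = is_markush([0.02, 0.01, 0.97, 0.03])  # one strong patch
clean = is_markush([0.10, 0.20, 0.05])          # no strong patch
print(flagged, clean)
```

<p>A single confident patch flags the whole image, which matches the &ldquo;one strike&rdquo; framing but also means one false-positive patch is enough to reject a valid structure.</p>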
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The authors compared two distinct paradigms on a manually annotated dataset:</p>
<ol>
<li>
<p><strong>Fixed-Feature Baseline</strong>: Used <strong>ORB</strong> (Oriented FAST and Rotated BRIEF) to detect keypoints and match them against a template bank of known Markush symbols. Features (match counts, Hamming distances) were fed into an <strong>XGBoost</strong> model.</p>
</li>
<li>
<p><strong>Deep Learning Method</strong>: Fine-tuned <strong>ResNet18</strong> and <strong>Inception V3</strong> models on the generated image patches.</p>
<ul>
<li><strong>Ablations</strong>: Contrasted pretraining sources, evaluating general domain (ImageNet) against chemistry-specific domain (USPTO images).</li>
<li><strong>Fine-tuning</strong>: Compared full-network fine-tuning against freezing all but the fully connected layers.</li>
</ul>
</li>
</ol>
<p>To handle significant class imbalance, the primary evaluation metric was the Macro F1 score, defined as:</p>
<p>$$ \text{Macro F1} = \frac{1}{N} \sum_{i=1}^{N} \frac{2 \cdot \text{precision}_i \cdot \text{recall}_i}{\text{precision}_i + \text{recall}_i} $$</p>
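<p>The following sketch computes Macro F1 directly from the definition above and shows why it is preferred under class imbalance. The toy labels are invented for illustration and are not data from the paper.</p>

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    classes = sorted(set(y_true).union(y_pred))
    scores = []
    for c in classes:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        total = precision + recall
        scores.append(2 * precision * recall / total if total else 0.0)
    return sum(scores) / len(scores)


# Toy imbalanced case: 8 complete structures (0) and 2 Markush images (1),
# with one Markush image missed.
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 8 + [1, 0]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
score = macro_f1(y_true, y_pred)
print(accuracy, score)
```

<p>Accuracy on this toy case is 0.9, but the Macro F1 is roughly 0.80 because the missed minority-class example counts as heavily as the majority class.</p>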
<h2 id="performance-outcomes">Performance Outcomes</h2>
<ul>
<li>
<p><strong>CNN vs. ORB</strong>: Deep learning architectures outperformed the fixed-feature baseline. The best model (<strong>Inception V3</strong> pretrained on ImageNet) achieved a patch-level Macro F1 of <strong>0.917</strong> and an image-level Macro F1 of <strong>0.928</strong>, compared to an image-level Macro F1 of <strong>0.701</strong> for the ORB baseline.</p>
</li>
<li>
<p><strong>The Pretraining Surprise</strong>: Counterintuitively, ImageNet pretraining consistently outperformed the domain-specific USPTO pretraining. The authors hypothesize that the filters learned from ImageNet pretraining generalize well outside the ImageNet domain, though why the USPTO-pretrained filters underperform remains unclear.</p>
</li>
<li>
<p><strong>Full Model Tuning</strong>: Unfreezing the entire network yielded higher performance than tuning only the classifier head, indicating that standard low-level visual filters require substantial adaptation to reliably distinguish chemical line drawings.</p>
</li>
<li>
<p><strong>Limitations and Edge Cases</strong>: The best CNN achieved an ROC AUC of <strong>0.97</strong> on the primary patch test set, while the ORB baseline scored <strong>0.81</strong> on the auxiliary dataset (the paper notes these ROC curves are not directly comparable due to different evaluation sets). The aggregation rule ($X = \max \{ x_i \}$) is naive and has not been optimized. Furthermore, the patching approach creates inherent label noise when a Markush indicator is bisected by a patch edge, potentially forcing the network to learn from incomplete visual features.</p>
</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used a primary dataset labeled by domain experts and a larger auxiliary dataset for evaluation.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Val</strong></td>
          <td><strong>Primary Dataset</strong></td>
          <td>272 Images</td>
          <td>Manually annotated with bounding boxes for Markush indicators. Split 60/20/20 (train/val/test).</td>
      </tr>
      <tr>
          <td><strong>Evaluation</strong></td>
          <td><strong>Auxiliary Dataset</strong></td>
          <td>~5.4k Images</td>
          <td>5117 complete structures, 317 Markush. Used for image-level testing only (no bbox).</td>
      </tr>
  </tbody>
</table>
<p><strong>Patch Generation</strong>:</p>
<ul>
<li>Images are cropped into patches of size <strong>224x224</strong> (ResNet) or <strong>299x299</strong> (Inception).</li>
<li>Patches are generated from 2 grids offset by half the patch width/height to ensure annotations aren&rsquo;t lost on edges.</li>
<li><strong>Labeling Rule</strong>: A patch is labeled &ldquo;Markush&rdquo; if &gt;50% of an annotation&rsquo;s pixels fall inside it.</li>
</ul>
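<p>The two-grid patching and the 50% labeling rule can be sketched as follows. This is a minimal illustration under several assumptions: boxes are <code>(x0, y0, x1, y1)</code> tuples, bounding-box area stands in for the annotation&rsquo;s pixel count, patches that extend past the image edge are assumed to be padded elsewhere, and all function names are hypothetical.</p>

```python
def patch_boxes(width, height, patch=224):
    """Patch boxes (x0, y0, x1, y1) from two grids offset by half a patch."""
    boxes = []
    for offset in (0, patch // 2):
        for y0 in range(offset, height, patch):
            for x0 in range(offset, width, patch):
                boxes.append((x0, y0, x0 + patch, y0 + patch))
    return boxes


def overlap_fraction(annotation, box):
    """Fraction of the annotation box area that falls inside the patch box."""
    ax0, ay0, ax1, ay1 = annotation
    bx0, by0, bx1, by1 = box
    w = max(0, min(ax1, bx1) - max(ax0, bx0))
    h = max(0, min(ay1, by1) - max(ay0, by0))
    area = (ax1 - ax0) * (ay1 - ay0)
    return (w * h) / area if area else 0.0


def label_patch(annotations, box, min_fraction=0.5):
    """Return 1 (Markush) if more than half of any annotation lies in the patch."""
    return int(any(overlap_fraction(a, box) > min_fraction for a in annotations))


# A 20x20 annotation cut exactly in half by the first grid's boundary at x = 224.
annotation = (214, 150, 234, 170)
boxes = patch_boxes(448, 448)
labels = {box: label_patch([annotation], box) for box in boxes}
print(len(boxes), sum(labels.values()))
```

<p>In this example neither first-grid patch holds more than half of the annotation, so both are labeled negative, while the offset-grid patch at <code>(112, 112, 336, 336)</code> contains it fully and is labeled positive. This is the case the second grid is meant to cover.</p>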
<h3 id="algorithms">Algorithms</h3>
<p><strong>ORB (Baseline)</strong>:</p>
<ul>
<li>Matches query images against a bank of template patches containing Markush indicators.</li>
<li><strong>Features</strong>: Number of keypoints, number of matches, Hamming distance of best 5 matches.</li>
<li><strong>Classifier</strong>: XGBoost trained on these features.</li>
<li><strong>Hyperparameters</strong>: Search over number of features (500-2000) and template patches (50-250).</li>
</ul>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Framework</strong>: PyTorch with Optuna for optimization.</li>
<li><strong>Optimization</strong>: 25 trials per configuration.</li>
<li><strong>Augmentations</strong>: Random perspective shift, posterization, sharpness/blur.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main architectures were compared.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Input Size</th>
          <th>Parameters</th>
          <th>Pretraining Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ResNet18</strong></td>
          <td>224x224</td>
          <td>11.5M</td>
          <td>ImageNet</td>
      </tr>
      <tr>
          <td><strong>Inception V3</strong></td>
          <td>299x299</td>
          <td>23.8M</td>
          <td>ImageNet &amp; USPTO</td>
      </tr>
  </tbody>
</table>
<p><strong>Best Configuration</strong>: Inception V3, ImageNet weights, Full Model fine-tuning (all layers unfrozen).</p>
<h3 id="evaluation">Evaluation</h3>
<p>Primary metric was <strong>Macro F1</strong> due to class imbalance.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best CNN (Inception V3)</th>
          <th>Baseline (ORB)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Patch Test F1</strong></td>
          <td>$0.917 \pm 0.014$</td>
          <td>N/A</td>
          <td>ORB does not support patch-level</td>
      </tr>
      <tr>
          <td><strong>Image Test F1</strong></td>
          <td>$0.928 \pm 0.035$</td>
          <td>$0.701 \pm 0.052$</td>
          <td>CNN aggregates patch predictions</td>
      </tr>
      <tr>
          <td><strong>Aux Test F1</strong></td>
          <td>0.914</td>
          <td>0.533</td>
          <td>Evaluation on large secondary dataset</td>
      </tr>
      <tr>
          <td><strong>ROC AUC</strong></td>
          <td>0.97</td>
          <td>0.81</td>
          <td></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Tesla V100-SXM2-16GB</li>
<li><strong>CPU</strong>: Intel Xeon E5-2686 @ 2.30GHz</li>
<li><strong>RAM</strong>: 64 GB</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Thomasjurriaans/markush-recognition-msc-thesis">GitHub Repository</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>MSc thesis code: CNN training, ORB baseline, evaluation scripts</td>
      </tr>
  </tbody>
</table>
<p>The primary dataset was manually annotated by Elsevier domain experts and is not publicly available. The auxiliary dataset (from Elsevier) is also not public. Pre-trained model weights are not released in the repository.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Jurriaans, T., Szarkowska, K., Nalisnick, E., Schwörer, M., Thorne, C., &amp; Akhondi, S. (2023). One Strike, You&rsquo;re Out: Detecting Markush Structures in Low Signal-to-Noise Ratio Images. <em>arXiv preprint arXiv:2311.14633</em>. <a href="https://doi.org/10.48550/arXiv.2311.14633">https://doi.org/10.48550/arXiv.2311.14633</a></p>
<p><strong>Publication</strong>: arXiv 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Thomasjurriaans/markush-recognition-msc-thesis">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{jurriaansOneStrikeYoure2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{One {{Strike}}, {{You}}&#39;re {{Out}}: {{Detecting Markush Structures}} in {{Low Signal-to-Noise Ratio Images}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{One {{Strike}}, {{You}}&#39;re {{Out}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Jurriaans, Thomas and Szarkowska, Kinga and Nalisnick, Eric and Schwoerer, Markus and Thorne, Camilo and Akhondi, Saber}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = nov,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2311.14633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MICER: Molecular Image Captioning with Transfer Learning</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/micer/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/micer/</guid><description>Encoder-decoder model using pre-trained ResNet and attention-based LSTM to translate molecular images into SMILES strings, reaching 97.54% sequence accuracy.</description><content:encoded><![CDATA[<h2 id="micers-contribution-to-optical-structure-recognition">MICER&rsquo;s Contribution to Optical Structure Recognition</h2>
<p>This is a <strong>Method</strong> paper according to the AI for Physical Sciences taxonomy. It proposes MICER, an encoder-decoder architecture that integrates transfer learning (fine-tuning pre-trained models) and attention mechanisms for Optical Chemical Structure Recognition (OCSR). The study includes rigorous benchmarking comparing MICER against three rule-based tools (OSRA, MolVec, Imago) and existing deep learning methods (DECIMER). The authors conduct extensive factor comparison experiments to isolate the effects of stereochemistry, molecular complexity, data volume, and encoder backbone choices.</p>
<h2 id="the-challenge-of-generalizing-in-ocsr">The Challenge of Generalizing in OCSR</h2>
<p>Chemical structures in scientific literature are valuable for drug discovery, but they are locked in image formats that are difficult to mine automatically. Traditional OCSR tools (like OSRA) rely on hand-crafted rules and expert knowledge. They are brittle, struggle with stylistic variations, and have low generalization ability. While deep learning has been applied (e.g., DECIMER), previous attempts often used frozen pre-trained feature extractors (without fine-tuning) or failed to fully exploit transfer learning, leading to suboptimal performance. The goal of this work is to build an end-to-end &ldquo;image captioning&rdquo; system that translates molecular images directly into <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a> strings without intermediate segmentation steps.</p>
<h2 id="integrating-fine-tuning-and-attention-for-chemistry">Integrating Fine-Tuning and Attention for Chemistry</h2>
<p>The core novelty lies in the specific architectural integration of transfer learning with fine-tuning for the chemical domain. Unlike DECIMER, which used a frozen network, MICER fine-tunes a pre-trained ResNet on molecular images. This allows the encoder to adapt from general object recognition to specific chemical feature extraction.</p>
<p>The model incorporates an attention mechanism into the LSTM decoder, allowing the model to focus on specific image regions (atoms and bonds) when generating each character of the SMILES string. The paper explicitly analyzes &ldquo;intrinsic features&rdquo; of molecular data (stereochemistry, complexity) to guide the design of the training dataset, combining multiple chemical toolkits (Indigo, RDKit) to generate diverse styles.</p>
<h2 id="experimental-setup-and-ablation-studies">Experimental Setup and Ablation Studies</h2>
<p>The authors performed two types of experiments: Factor Comparison (ablations) and Benchmarking.</p>
<p><strong>Factor Comparisons</strong>: They evaluated how performance is affected by:</p>
<ul>
<li><strong>Stereochemistry (SI)</strong>: Comparing models trained on data with and without stereochemical information.</li>
<li><strong>Molecular Complexity (MC)</strong>: Analyzing performance across 5 molecular weight intervals.</li>
<li><strong>Data Volume (DV)</strong>: Training on datasets ranging from 0.64 million to 10 million images.</li>
<li><strong>Pre-trained Models (PTMs)</strong>: Comparing 8 different backbones (e.g., ResNet, VGG, Inception, MobileNet) versus a base CNN.</li>
</ul>
<p><strong>Benchmarking</strong>:</p>
<ul>
<li><strong>Baselines</strong>: OSRA, MolVec, Imago (rule-based); Base CNN, DECIMER (deep learning).</li>
<li><strong>Datasets</strong>: Four test sets (100k images each, except UOB): Uni-style, Multi-style, Noisy, and Real-world (UOB dataset).</li>
<li><strong>Metrics</strong>: Sequence Accuracy (SA, exact match), Average Levenshtein Distance (ALD), and Tanimoto similarity of molecular fingerprints (AMFTS).</li>
</ul>
<h2 id="results-and-core-insights">Results and Core Insights</h2>
<p>MICER achieved 97.54% Sequence Accuracy on uni-style data and 82.33% on the real-world UOB dataset, outperforming rule-based and deep learning baselines across all four test sets.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>SA (%)</th>
          <th>AMFTS (%)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Uni-style</td>
          <td>OSRA</td>
          <td>23.14</td>
          <td>56.83</td>
      </tr>
      <tr>
          <td>Uni-style</td>
          <td>DECIMER</td>
          <td>35.32</td>
          <td>86.92</td>
      </tr>
      <tr>
          <td>Uni-style</td>
          <td><strong>MICER</strong></td>
          <td><strong>97.54</strong></td>
          <td><strong>99.74</strong></td>
      </tr>
      <tr>
          <td>Multi-style</td>
          <td>OSRA</td>
          <td>15.68</td>
          <td>44.50</td>
      </tr>
      <tr>
          <td>Multi-style</td>
          <td><strong>MICER</strong></td>
          <td><strong>95.09</strong></td>
          <td><strong>99.28</strong></td>
      </tr>
      <tr>
          <td>Noisy</td>
          <td><strong>MICER</strong></td>
          <td><strong>94.95</strong></td>
          <td><strong>99.25</strong></td>
      </tr>
      <tr>
          <td>UOB (real-world)</td>
          <td>OSRA</td>
          <td>80.24</td>
          <td>91.17</td>
      </tr>
      <tr>
          <td>UOB (real-world)</td>
          <td>DECIMER</td>
          <td>21.75</td>
          <td>65.15</td>
      </tr>
      <tr>
          <td>UOB (real-world)</td>
          <td><strong>MICER</strong></td>
          <td><strong>82.33</strong></td>
          <td><strong>94.47</strong></td>
      </tr>
  </tbody>
</table>
<p>ResNet101 was identified as the most effective encoder (87.58% SA in preliminary tests on 0.8M images), outperforming deeper (DenseNet121 at 81.41%) and lighter (MobileNetV2 at 39.83%) networks. Performance saturates around 6 million training samples, reaching 98.84% SA. Including stereochemical information lowers accuracy by approximately 6.1 percentage points (from 87.61% to 81.50%), indicating that wedge and dash bonds are harder to recognize. Attention-map visualizations showed the model attending to the relevant atoms (e.g., focusing on &lsquo;S&rsquo; or &lsquo;Cl&rsquo; pixels) when generating the corresponding character.</p>
<h2 id="limitations">Limitations</h2>
<p>The authors acknowledge several limitations. MICER struggles with superatoms, R-groups, text labels, and uncommon atoms (e.g., Sn) that were not seen during training. On noisy data, noise spots near Cl atoms can cause misclassification as O atoms. Complex molecular images with noise lead to misrecognition of noise points as single bonds and wedge-shaped bonds as double bonds. All methods, including MICER, have substantial room for improvement on real-world datasets that contain these challenging elements.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data was curated from the <strong>ZINC20</strong> database.</p>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Filtering</strong>: Removed organometallics, mixtures, and invalid molecules.</li>
<li><strong>Standardization</strong>: SMILES were canonicalized and de-duplicated.</li>
<li><strong>Generation</strong>: Images generated using <strong>Indigo</strong> and <strong>RDKit</strong> toolkits to vary styles.</li>
</ul>
<p><strong>Dataset Size</strong>:</p>
<ul>
<li><strong>Total</strong>: 10 million images selected for the final model.</li>
<li><strong>Composition</strong>: 6 million &ldquo;default style&rdquo; (Indigo) + 4 million &ldquo;multi-style&rdquo; (Indigo + RDKit).</li>
<li><strong>Splits</strong>: 8:1:1 ratio for Training/Validation/Test.</li>
</ul>
<p><strong>Vocabulary</strong>: A token dictionary of 39 SMILES characters plus 3 special tokens: <code>[pad]</code>, <code>[sos]</code>, <code>[eos]</code>, <code>[0]</code>-<code>[9]</code>, <code>[C]</code>, <code>[l]</code>, <code>[c]</code>, <code>[O]</code>, <code>[N]</code>, <code>[n]</code>, <code>[F]</code>, <code>[H]</code>, <code>[o]</code>, <code>[S]</code>, <code>[s]</code>, <code>[B]</code>, <code>[r]</code>, <code>[I]</code>, <code>[i]</code>, <code>[P]</code>, <code>[p]</code>, <code>(</code>, <code>)</code>, <code>[</code>, <code>]</code>, <code>@</code>, <code>=</code>, <code>#</code>, <code>/</code>, <code>-</code>, <code>+</code>, <code>\</code>, <code>%</code>. Two-letter atoms like &lsquo;Br&rsquo; are tokenized as distinct characters <code>[B]</code>, <code>[r]</code>, and &lsquo;Cl&rsquo; as <code>[C]</code>, <code>[l]</code>.</p>
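<p>A minimal sketch of this character-level vocabulary follows. The character inventory is copied from the list above; the id ordering, the padding behaviour, and the helper names <code>encode</code> and <code>decode</code> are assumptions for illustration, not details taken from the MICER code.</p>

```python
SPECIAL = ["[pad]", "[sos]", "[eos]"]
# 10 digits + 17 letters + 12 symbols = 39 SMILES characters.
CHARS = list("0123456789") + list("ClcONnFHoSsBrIiPp") + list("()[]@=#/-+\\%")
TOKEN_TO_ID = {tok: i for i, tok in enumerate(SPECIAL + CHARS)}
ID_TO_TOKEN = {i: tok for tok, i in TOKEN_TO_ID.items()}


def encode(smiles, max_len=None):
    """Map a SMILES string to ids, one id per character."""
    ids = [TOKEN_TO_ID["[sos]"]]
    ids.extend(TOKEN_TO_ID[ch] for ch in smiles)
    ids.append(TOKEN_TO_ID["[eos]"])
    if max_len is not None:
        # Right-pad to a fixed length (assumed convention).
        ids.extend([TOKEN_TO_ID["[pad]"]] * (max_len - len(ids)))
    return ids


def decode(ids):
    """Invert encode, stopping at the first [eos]."""
    chars = []
    for i in ids:
        tok = ID_TO_TOKEN[i]
        if tok == "[eos]":
            break
        if tok not in SPECIAL:
            chars.append(tok)
    return "".join(chars)


example = "CC(Cl)Br"
ids = encode(example, max_len=16)
print(ids)
print(decode(ids))
```

<p>Note that <code>Cl</code> and <code>Br</code> each consume two ids here, so the decoder has to learn that <code>l</code> only follows <code>C</code> and <code>r</code> only follows <code>B</code>.</p>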
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Tokenization</strong>: Character-level tokenization (not atom-level); the model learns to assemble &lsquo;C&rsquo; and &lsquo;l&rsquo; into &lsquo;Cl&rsquo;.</li>
<li><strong>Attention Mechanism</strong>: Uses a soft attention mechanism where the decoder calculates an attention score between the encoder&rsquo;s feature map ($8 \times 8 \times 512$) and the current hidden vector. Formula:
$$
\begin{aligned}
\text{att\_score} &amp;= \text{softmax}(L_a(\tanh(L_f(F) + L_b(b_t))))
\end{aligned}
$$</li>
<li><strong>Training Configuration</strong>:
<ul>
<li><strong>Loss Function</strong>: Cross-entropy loss</li>
<li><strong>Optimizer</strong>: Adam optimizer</li>
<li><strong>Learning Rate</strong>: 2e-5</li>
<li><strong>Batch Size</strong>: 256</li>
<li><strong>Epochs</strong>: 15</li>
</ul>
</li>
</ul>
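<p>The attention formula above can be sketched in plain Python as follows. This is a toy illustration, not the authors' implementation: $L_f$, $L_b$, and $L_a$ are modelled as bias-free linear maps with random weights, the feature dimension is shrunk from 512 to 8 for readability, and the context vector (the attention-weighted sum of region features) is the standard soft-attention step, assumed here rather than stated in the notes.</p>

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def soft_attention(F, b_t, L_f, L_b, L_a):
    """F: list of region feature vectors; b_t: current decoder hidden vector."""
    proj_b = matvec(L_b, b_t)
    energies = []
    for f in F:
        proj_f = matvec(L_f, f)
        hidden = [math.tanh(pf + pb) for pf, pb in zip(proj_f, proj_b)]
        # L_a maps the hidden vector to a single scalar energy per region.
        energies.append(sum(a * h for a, h in zip(L_a, hidden)))
    # Numerically stable softmax over the regions.
    peak = max(energies)
    exps = [math.exp(e - peak) for e in energies]
    total = sum(exps)
    att_score = [e / total for e in exps]
    # Assumed follow-up step: attention-weighted sum of the region features.
    context = [sum(a * f[d] for a, f in zip(att_score, F)) for d in range(len(F[0]))]
    return att_score, context


rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


# 64 regions (the flattened 8 x 8 grid); other sizes are toy values.
REGIONS, FEAT, HID, ATT = 64, 8, 6, 5
F = rand_matrix(REGIONS, FEAT)
b_t = rand_matrix(1, HID)[0]
L_f = rand_matrix(ATT, FEAT)
L_b = rand_matrix(ATT, HID)
L_a = rand_matrix(1, ATT)[0]

att_score, context = soft_attention(F, b_t, L_f, L_b, L_a)
print(len(att_score), round(sum(att_score), 6), len(context))
```

<p>The scores form a probability distribution over the 64 image regions, which is what makes the attention maps described in the results interpretable as heatmaps.</p>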
<h3 id="models">Models</h3>
<p><strong>Encoder</strong>:</p>
<ul>
<li><strong>Backbone</strong>: Pre-trained <strong>ResNet101</strong> (trained on ImageNet).</li>
<li><strong>Modifications</strong>: The final layer is removed to output a Feature Map of size $8 \times 8 \times 512$.</li>
<li><strong>Flattening</strong>: Reshaped to a $64 \times 512$ feature matrix for the decoder.</li>
</ul>
<p><strong>Decoder</strong>:</p>
<ul>
<li><strong>Type</strong>: Long Short-Term Memory (LSTM) with Attention.</li>
<li><strong>Dropout</strong>: 0.3 applied to minimize overfitting.</li>
</ul>
<p>The encoder uses a pilot network (for universal feature extraction), a max-pooling layer, and multiple feature extraction layers containing convolutional blocks (CBs), feeding into the attention LSTM.</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>SA (Sequence Accuracy)</strong>: Strict exact match of SMILES strings.</li>
<li><strong>ALD (Average Levenshtein Distance)</strong>: Edit distance for character-level error analysis.</li>
<li><strong>AMFTS / MFTS@1.0</strong>: Tanimoto similarity of ECFP4 fingerprints to measure structural similarity.</li>
</ul>
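<p>As a rough illustration of these three metrics, the sketch below implements exact-match accuracy, average Levenshtein distance, and Tanimoto similarity in plain Python. The SMILES pairs and fingerprint bit sets are invented examples; a real evaluation would derive ECFP4 fingerprints with a cheminformatics toolkit, which is outside the scope of this sketch.</p>

```python
def levenshtein(a, b):
    """Edit distance between two strings (insert, delete, substitute)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def sequence_accuracy(preds, targets):
    """Fraction of predictions that match the target string exactly."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def average_levenshtein(preds, targets):
    return sum(levenshtein(p, t) for p, t in zip(preds, targets)) / len(targets)


def tanimoto(bits_a, bits_b):
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = bits_a.union(bits_b)
    return len(bits_a.intersection(bits_b)) / len(union) if union else 1.0


# Hypothetical predictions: the third one drops a '=' character.
targets = ["CCO", "c1ccccc1", "CC(=O)O"]
preds = ["CCO", "c1ccccc1", "CC(O)O"]

sa = sequence_accuracy(preds, targets)
ald = average_levenshtein(preds, targets)
sim = tanimoto({1, 2, 3, 4}, {3, 4, 5})
print(f"SA={sa:.3f} ALD={ald:.3f} Tanimoto={sim:.2f}")
```

<p>SA penalizes the third prediction fully even though it is one edit away, which is why ALD and the fingerprint similarity are reported alongside it.</p>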
<p><strong>Test Sets</strong>:</p>
<ul>
<li><strong>Uni-style</strong>: 100,000 images (Indigo default).</li>
<li><strong>Multi-style</strong>: 100,000 images (&gt;10 styles).</li>
<li><strong>Noisy</strong>: 100,000 images with noise added.</li>
<li><strong>UOB</strong>: 5,575 real-world images from literature.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA Tesla V100 GPUs</li>
<li><strong>Training Time</strong>: Approximately 42 hours for the final model</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Jiacai-Yi/MICER">MICER</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<p>The training data (generated from ZINC20) and pre-trained model weights are not publicly released. The repository contains code but has minimal documentation (2 commits, no description).</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yi, J., Wu, C., Zhang, X., Xiao, X., Qiu, Y., Zhao, W., Hou, T., &amp; Cao, D. (2022). MICER: a pre-trained encoder-decoder architecture for molecular image captioning. <em>Bioinformatics</em>, 38(19), 4562-4572. <a href="https://doi.org/10.1093/bioinformatics/btac545">https://doi.org/10.1093/bioinformatics/btac545</a></p>
<p><strong>Publication</strong>: Bioinformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Jiacai-Yi/MICER">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{yiMICERPretrainedEncoder2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{MICER}}: A Pre-Trained Encoder--Decoder Architecture for Molecular Image Captioning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{MICER}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Yi, Jiacai and Wu, Chengkun and Zhang, Xiaochen and Xiao, Xinyi and Qiu, Yanlong and Zhao, Wentao and Hou, Tingjun and Cao, Dongsheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{38}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{19}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{4562--4572}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1367-4811}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1093/bioinformatics/btac545}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image2SMILES: Transformer OCSR with Synthetic Data Pipeline</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/image2smiles/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/image2smiles/</guid><description>Transformer-based OCSR using a novel synthetic data generation pipeline for robust molecular image interpretation across diverse drawing styles.</description><content:encoded><![CDATA[<h2 id="contribution-image2smiles-as-a-method-and-resource">Contribution: Image2SMILES as a Method and Resource</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a specific neural architecture (ResNet backbone and Transformer Decoder) to solve the Optical Chemical Structure Recognition (OCSR) task, answering &ldquo;How well does this work?&rdquo; with extensive benchmarks against rule-based systems like OSRA.</li>
<li><strong>Resource</strong>: A core contribution is the &ldquo;Generate and Train!&rdquo; paradigm, where the authors release a comprehensive synthetic data generator to overcome the lack of labeled training data in the field.</li>
</ul>
<h2 id="motivation-bottlenecks-in-recognizing-trapped-chemical-structures">Motivation: Bottlenecks in Recognizing Trapped Chemical Structures</h2>
<p>Retrieving chemical structure data from legacy scientific literature is a major bottleneck in cheminformatics.</p>
<ul>
<li><strong>Problem</strong>: Chemical structures are often &ldquo;trapped&rdquo; in image formats (PDFs, scans). Manual extraction is slow, and existing rule-based tools (e.g., OSRA) are brittle when facing diverse drawing styles, &ldquo;Markush&rdquo; structures (templates), or visual contamination.</li>
<li><strong>Gap</strong>: Deep learning approaches require massive datasets, but no large-scale annotated dataset of chemical figures exists.</li>
<li><strong>Goal</strong>: To create a robust, data-driven recognition engine that can handle the messiness of real-world chemical publications (e.g., text overlays, arrows, partial overlaps).</li>
</ul>
<h2 id="core-innovation-the-generate-and-train-pipeline-and-fg-smiles">Core Innovation: The &ldquo;Generate and Train!&rdquo; Pipeline and FG-SMILES</h2>
<ul>
<li><strong>&ldquo;Generate and Train!&rdquo; Paradigm</strong>: The authors assert that architecture is secondary to data simulation. They developed an advanced augmentation pipeline that simulates geometry (rotation, bonds) alongside specific chemical drawing artifacts like &ldquo;Markush&rdquo; variables ($R_1$, $R_2$), functional group abbreviations (e.g., -OMe, -Ph), and visual &ldquo;contamination&rdquo; (stray text, arrows).</li>
<li><strong>FG-SMILES</strong>: A modified SMILES syntax designed to handle functional groups and Markush templates as single tokens (pseudo-atoms), allowing the model to predict generalized scaffolds.</li>
<li><strong>Encoder-Free Architecture</strong>: The authors found that a standard Transformer Encoder was unnecessary. They feed the flattened feature map from a ResNet backbone directly into the Transformer Decoder, which improved performance.</li>
</ul>
<h2 id="methodology-and-benchmarking-against-osra">Methodology and Benchmarking Against OSRA</h2>
<ul>
<li><strong>Training</strong>: The model was trained on 10 million synthetically generated images derived from PubChem structures, selected via a complexity-biased sampling algorithm.</li>
<li><strong>Validation (Synthetic)</strong>: Evaluated on a hold-out set of 1M synthetic images.</li>
<li><strong>Validation (Real World)</strong>:
<ul>
<li><strong>Dataset A</strong>: 332 manually cropped structures from 10 specific articles, excluding reaction schemes.</li>
<li><strong>Dataset B</strong>: 296 structures systematically extracted from <em>Journal of Organic Chemistry</em> (one paper per issue from 2020) to reduce selection bias.</li>
</ul>
</li>
<li><strong>Comparison</strong>: Benchmarked against OSRA (v2.11), a widely used rule-based OCSR tool.</li>
</ul>
<h2 id="results-high-precision-extraction-and-key-limitations">Results: High-Precision Extraction and Key Limitations</h2>
<ul>
<li><strong>Performance</strong>:
<ul>
<li><strong>Synthetic</strong>: 90.7% exact match accuracy.</li>
<li><strong>Real Data (Dataset A)</strong>: Image2SMILES achieved <strong>79.2%</strong> accuracy compared to OSRA&rsquo;s <strong>62.1%</strong>.</li>
<li><strong>Real Data (Dataset B)</strong>: Image2SMILES achieved <strong>62.5%</strong> accuracy compared to OSRA&rsquo;s <strong>24.0%</strong>.</li>
</ul>
</li>
<li><strong>Confidence Correlation</strong>: There is a strong correlation between the model&rsquo;s confidence score and prediction validity. Thresholding at 0.995 yields 99.85% accuracy while ignoring 22.5% of data, enabling high-precision automated pipelines.</li>
<li><strong>Key Failures</strong>: The model struggles with functional groups absent from its training dictionary (e.g., $\text{NMe}_2$, Ms), confuses R-group indices ($R'$ vs $R_1$), and mishandles explicit hydrogens rendered as groups.</li>
</ul>
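<p>The confidence-thresholding idea can be sketched as below: predictions under the threshold are set aside, and accuracy is measured only on what remains. The records are invented toy values, and <code>threshold_report</code> is a hypothetical helper, not part of the Image2SMILES code.</p>

```python
def threshold_report(records, threshold):
    """records: list of (confidence, is_correct) pairs.

    Returns (accuracy on kept predictions, fraction of predictions rejected).
    """
    kept = [ok for conf, ok in records if conf >= threshold]
    rejected = (len(records) - len(kept)) / len(records)
    accuracy = sum(kept) / len(kept) if kept else float("nan")
    return accuracy, rejected


# Toy records (hypothetical): confidence score and whether the prediction was right.
records = [
    (0.9990, True),
    (0.9985, True),
    (0.9970, True),
    (0.9960, True),
    (0.9900, False),
    (0.9500, True),
    (0.7000, False),
    (0.4000, False),
]

for t in (0.0, 0.995):
    acc, rej = threshold_report(records, t)
    print(f"threshold={t}: accuracy={acc:.3f}, rejected={rej:.1%}")
```

<p>Raising the threshold trades coverage for precision; the rejected items would typically be routed to manual review.</p>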
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: A subset of 10 million molecules sampled from PubChem.</li>
<li><strong>Selection Logic</strong>: Bias towards complex/rare structures using a &ldquo;Full Coefficient&rdquo; (FC) probability metric based on molecule size and ring/atom rarity.
<ul>
<li>Formula: $BC=0.1+1.2\left(\frac{n_{\max}-n}{n_{\max}}\right)^{3}$ where $n_{\max}=60$.</li>
</ul>
</li>
<li><strong>Generation</strong>: Uses RDKit for rendering with augmentations: rotation, font size, line thickness, whitespace, and CoordGen (20% probability).</li>
<li><strong>Contamination</strong>: &ldquo;Visual noise&rdquo; is stochastically added, including parts of other structures, labels, and arrows cropped from real documents.</li>
<li><strong>Target Format</strong>: <strong>FG-SMILES</strong> (Functional Group SMILES). Replaces common functional groups with pseudo-atoms (e.g., [Me], [Ph], [NO2]) and supports variable R-group positions using a <code>v</code> token.</li>
</ul>
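<p>The size-dependent term in the formula above can be evaluated directly. In this sketch $n$ is assumed to be the molecule size (the notes do not define it precisely), the relationship between $BC$ and the &ldquo;Full Coefficient&rdquo; is not specified here and is left out, and rejecting $n$ outside $[0, n_{\max}]$ is a choice made for the sketch rather than documented behaviour.</p>

```python
def base_coefficient(n, n_max=60):
    """Evaluate BC = 0.1 + 1.2 * ((n_max - n) / n_max) ** 3."""
    if n > n_max or 0 > n:
        # Behaviour outside this range is not described in the notes.
        raise ValueError("n must lie between 0 and n_max")
    return 0.1 + 1.2 * ((n_max - n) / n_max) ** 3


for n in (0, 15, 30, 45, 60):
    print(n, round(base_coefficient(n), 4))
```

<p>The term falls monotonically from 1.3 at $n=0$ to 0.1 at $n=n_{\max}$, so any bias toward rare rings or atoms must come from the other factors the notes mention.</p>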
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Contamination Augmentation</strong>: A dedicated algorithm simulates visual noise (arrows, text) touching or overlapping the main molecule to force robustness.</li>
<li><strong>Functional Group Resolution</strong>: An algorithm identifies overlapping functional group templates (SMARTS) and resolves them to prevent nested group conflicts (e.g., resolving Methyl vs Methoxy).</li>
<li><strong>Markush Support</strong>: Stochastic replacement of substituents with R-group labels ($R_1$, $R'$, etc.) based on a defined probability table (e.g., $P(R)=0.2$, $P(R_1)=0.15$).</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: &ldquo;Image-to-Sequence&rdquo; hybrid model.
<ul>
<li><strong>Backbone</strong>: ResNet-50, but with the last two residual blocks removed. Output shape: $512 \times 48 \times 48$.</li>
<li><strong>Neck</strong>: No Transformer Encoder. CNN features are flattened and passed directly to the Decoder.</li>
<li><strong>Decoder</strong>: Standard Transformer Decoder with parameters from the original Transformer architecture.</li>
</ul>
</li>
<li><strong>Input</strong>: Images resized to $384 \times 384 \times 3$.</li>
<li><strong>Output</strong>: Sequence of FG-SMILES tokens.</li>
</ul>
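<p>The shapes quoted above imply how long the sequence passed to the decoder is. The sketch below does that arithmetic; the total stride of 8 is inferred from 384 / 48 and is an assumption, as is the helper name <code>decoder_memory_shape</code>.</p>

```python
def decoder_memory_shape(height, width, stride, channels):
    """Shape of the flattened CNN feature map handed to the decoder."""
    grid_h, grid_w = height // stride, width // stride
    # Each spatial position becomes one element of the decoder's memory sequence.
    return grid_h * grid_w, channels


seq_len, dim = decoder_memory_shape(384, 384, stride=8, channels=512)
print(seq_len, dim)
```

<p>With these values the decoder cross-attends over 2304 positions of 512 features each, one per cell of the $48 \times 48$ grid.</p>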
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric</strong>: Binary &ldquo;Exact Match&rdquo; (valid/invalid).
<ul>
<li>Strict criteria: Stereo and R-group indices must match exactly (e.g., $R'$ vs $R_1$ is a failure).</li>
</ul>
</li>
<li><strong>Datasets</strong>:
<ul>
<li><strong>Internal</strong>: 5% random split of generated data (500k samples).</li>
<li><strong>External (Dataset A &amp; B)</strong>: Manually cropped real-world images from specified journals.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4 $\times$ Nvidia V100 GPUs + 36 CPU cores.</li>
<li><strong>Duration</strong>: ~2 weeks for training (5 epochs, ~63 hours/epoch). Data generation took 3 days on 80 CPUs.</li>
<li><strong>Optimizer</strong>: RAdam with learning rate $3 \cdot 10^{-4}$.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/syntelly/img2smiles_generator">Data Generator (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Synthetic training data generator</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.5069806">1M Generated Samples (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Randomly generated image-SMILES pairs</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.5356500">Real-World Test Images (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Cropped structures from real papers with target FG-SMILES</td>
      </tr>
      <tr>
          <td><a href="https://app.syntelly.com/pdf2smiles">Syntelly Demo</a></td>
          <td>Other</td>
          <td>Proprietary</td>
          <td>Web demo for PDF-to-SMILES extraction</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Khokhlov, I., Krasnov, L., Fedorov, M. V., &amp; Sosnin, S. (2022). Image2SMILES: Transformer-Based Molecular Optical Recognition Engine. <em>Chemistry-Methods</em>, 2(1), e202100069. <a href="https://doi.org/10.1002/cmtd.202100069">https://doi.org/10.1002/cmtd.202100069</a></p>
<p><strong>Publication</strong>: Chemistry-Methods 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/syntelly/img2smiles_generator">Official Code (Data Generator)</a></li>
<li><a href="https://app.syntelly.com/pdf2smiles">Syntelly Demo</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{khokhlovImage2SMILESTransformerBasedMolecular2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Image2SMILES: Transformer-Based Molecular Optical Recognition Engine}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Image2SMILES}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Khokhlov, Ivan and Krasnov, Lev and Fedorov, Maxim V. and Sosnin, Sergey}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Chemistry-Methods}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{e202100069}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2628-9725}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1002/cmtd.202100069}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://chemistry-europe.onlinelibrary.wiley.com/doi/10.1002/cmtd.202100069}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Image-to-Graph Transformers for Chemical Structures</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/image-to-graph-transformers/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/image-to-graph-transformers/</guid><description>A deep learning model that converts molecular images directly into graph structures, enabling recognition of abbreviated non-atomic symbols.</description><content:encoded><![CDATA[<h2 id="contribution-and-taxonomic-classification">Contribution and Taxonomic Classification</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel deep learning architecture designed to extract molecular structures from images by directly predicting the graph topology. The paper validates this approach through ablation studies (comparing ResNet-only baselines to the Transformer-augmented model) and extensive benchmarking against existing tools.</p>
<h2 id="the-challenge-with-smiles-and-non-atomic-symbols">The Challenge with SMILES and Non-Atomic Symbols</h2>
<ul>
<li><strong>Handling Abbreviations:</strong> Chemical structures in scientific literature often use non-atomic symbols (superatoms like &ldquo;R&rdquo; or &ldquo;Ph&rdquo;) to reduce complexity. Standard tools that generate SMILES strings fail here because SMILES syntax does not support arbitrary non-atomic symbols.</li>
<li><strong>Robustness to Style:</strong> Existing rule-based tools are brittle to the diverse drawing styles found in literature.</li>
<li><strong>Data Utilization:</strong> Pixel-wise graph recognition tools (like ChemGrapher) require expensive pixel-level labeling. An end-to-end approach can utilize massive amounts of image-molecule pairs (like USPTO data) without needing exact coordinate labels.</li>
</ul>
<h2 id="the-image-to-graph-i2g-architecture">The Image-to-Graph (I2G) Architecture</h2>
<p>The core novelty is the <strong>Image-to-Graph (I2G)</strong> architecture that bypasses string representations entirely:</p>
<ul>
<li><strong>Hybrid Encoder:</strong> Combines a ResNet backbone (for locality) with a Transformer encoder (for global context), allowing the model to capture relationships between atoms that are far apart in the image.</li>
<li><strong>Graph Decoder (GRAT):</strong> A modified Transformer decoder that generates the graph auto-regressively. It uses feature-wise transformations to modulate attention weights based on edge information (bond types).</li>
<li><strong>Coordinate-Aware Training:</strong> The model is trained to predict the exact 2D coordinates of atoms in the source image. Combined with auxiliary losses, this raises SMI accuracy from 0.009 to 0.567 in the UoB ablation (Table 1 in the paper).</li>
</ul>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<ul>
<li><strong>Baselines:</strong> The model was compared against OSRA (rule-based), MolVec (rule-based), and ChemGrapher (deep learning pixel-wise).</li>
<li><strong>Benchmarks:</strong> Evaluated on four standard datasets: UoB, USPTO, CLEF, and JPO. Images were converted to PDF and back to simulate degradation.</li>
<li><strong>Large Molecule Test:</strong> A custom dataset (<strong>OLED</strong>) was created from 12 journal papers (434 images) to test performance on larger, more complex structures (average 52.8 atoms).</li>
<li><strong>Ablations:</strong> The authors tested the impact of the Transformer encoder, auxiliary losses, and coordinate prediction.</li>
</ul>
<h2 id="empirical-results-and-robustness">Empirical Results and Robustness</h2>
<ul>
<li><strong>Benchmark Performance:</strong> The proposed model outperformed existing models with a 17.1% relative improvement on benchmark datasets.</li>
<li><strong>Robustness:</strong> On large molecules (OLED dataset), it achieved a 12.8% relative improvement over MolVec (and 20.0% over OSRA).</li>
<li><strong>Data Scaling:</strong> Adding real-world USPTO data to the synthetic training set improved performance by 20.5%, demonstrating that the model can learn from noisy real-world images that lack coordinate labels.</li>
<li><strong>Handling Superatoms:</strong> The model successfully recognized pseudo-atoms (e.g., $R_1$, $R_2$, $R_3$) as distinct nodes. OSRA, which outputs SMILES, collapsed them into generic &ldquo;Any&rdquo; atoms since SMILES does not support non-atomic symbols. MolVec could not recognize them properly at all.</li>
</ul>
<h2 id="limitations-and-error-analysis">Limitations and Error Analysis</h2>
<p>The paper identifies two main failure modes on the USPTO, CLEF, and JPO benchmarks:</p>
<ol>
<li><strong>Unrecognized superatoms:</strong> The model struggles with complex multi-character superatoms not seen during training (e.g., NHNHCOCH$_3$ or H$_3$CO$_2$S). The authors propose character-level atom decoding as a future solution.</li>
<li><strong>Caption interference:</strong> The model sometimes misidentifies image captions as atoms, particularly on the JPO dataset. Data augmentation with arbitrary caption text or a dedicated image segmentation step could mitigate this.</li>
</ol>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors used a combination of synthetic and real-world data for training.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td><strong>PubChem</strong></td>
          <td>4.6M</td>
          <td>Synthetic images generated using RDKit. Random superatoms (e.g., $CF_3$, $NO_2$) were substituted to simulate abbreviations.</td>
      </tr>
      <tr>
          <td>Training</td>
          <td><strong>USPTO</strong></td>
          <td>2.5M</td>
          <td>Real image-molecule pairs from patents. Used for robustness; lacks coordinate labels.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td><strong>Benchmarks</strong></td>
          <td>~5.7k</td>
          <td>UoB, USPTO, CLEF, JPO. Average ~15.8 atoms per molecule.</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td><strong>OLED</strong></td>
          <td>434</td>
          <td>Manually segmented from 12 journal papers. Large molecules (avg 52.8 atoms).</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing:</strong></p>
<ul>
<li>Input resolution is fixed at $800 \times 800$ pixels.</li>
<li>Images are virtually split into a $25 \times 25$ grid (625 patches total), where each patch is $32 \times 32$ pixels.</li>
</ul>
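<p>The grid arithmetic above can be sketched as follows. This is a minimal illustration, assuming non-overlapping patches, a top-left origin, and row-major flattening; the name <code>patch_index</code> and the constants are illustrative and not taken from the paper.</p>

```python
IMAGE_SIZE = 800                 # input resolution stated above
PATCH_SIZE = 32                  # side length of one patch in pixels
GRID = IMAGE_SIZE // PATCH_SIZE  # 25 patches per side


def patch_index(x: int, y: int) -> tuple:
    """Map a pixel coordinate to (row, col, flat index) in the patch grid.

    Assumes the origin is the top-left pixel and row-major flattening,
    which is an illustrative choice rather than a detail from the paper.
    """
    if x not in range(IMAGE_SIZE) or y not in range(IMAGE_SIZE):
        raise ValueError("pixel lies outside the image")
    row = y // PATCH_SIZE
    col = x // PATCH_SIZE
    return row, col, row * GRID + col
```

<p>Under these assumptions, the flat index is the position of a patch in the 625-element sequence that the Transformer encoder consumes.</p>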
<h3 id="algorithms">Algorithms</h3>
<p><strong>Encoder Logic:</strong></p>
<ul>
<li><strong>Grid Serialization:</strong> The $25 \times 25$ grid is flattened into a 1D sequence. 2D position information is concatenated to ResNet features before the Transformer.</li>
<li><strong>Auxiliary Losses:</strong> To aid convergence, classifiers on the encoder predict three things <em>per patch</em>: (1) number of atoms, (2) characters in atom labels, and (3) edge-sharing neighbors. These losses decrease to zero during training.</li>
</ul>
<p><strong>Decoder Logic:</strong></p>
<ul>
<li><strong>Auto-regressive Generation:</strong> At step $t$, the decoder generates a new node and connects it to existing nodes.</li>
<li><strong>Attention Modulation:</strong> Attention weights are transformed using bond information:
$$
\begin{aligned}
\text{Att}(Q, K, V) = \text{softmax} \left( \frac{\Gamma \odot (QK^T) + B}{\sqrt{d_k}} \right) V
\end{aligned}
$$
where $(\gamma_{ij}, \beta_{ij}) = f(e_{ij})$, with $e_{ij}$ being the edge type (in one-hot representation) between nodes $i$ and $j$, and $f$ is a multi-layer perceptron. $\Gamma$ and $B$ are matrices whose elements at position $(i, j)$ are $\gamma_{ij}$ and $\beta_{ij}$, respectively.</li>
<li><strong>Coordinate Prediction:</strong> The decoder outputs coordinates for each atom, which acts as a mechanism to track attention history.</li>
</ul>
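<p>The attention modulation formula can be illustrated with a small pure-Python sketch. It is a single-head toy version, assuming that <code>gamma</code> and <code>beta</code> are already available as $n \times n$ lists (in the model they come from the MLP $f$ applied to edge types); the function names are hypothetical.</p>

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def modulated_attention(Q, K, V, gamma, beta):
    """Single-head attention with a per-pair scale (gamma) and shift (beta).

    Q, K, V are lists of row vectors; gamma and beta are n x n lists that
    stand in for the MLP outputs f(e_ij) described above.
    """
    d_k = len(K[0])
    width = len(V[0])
    out = []
    for i, q in enumerate(Q):
        scores = []
        for j, k in enumerate(K):
            dot = sum(a * b for a, b in zip(q, k))
            scores.append((gamma[i][j] * dot + beta[i][j]) / math.sqrt(d_k))
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(width)])
    return out
```

<p>Setting every <code>gamma</code> entry to 1 and every <code>beta</code> entry to 0 recovers standard scaled dot-product attention, which makes the edge-dependent terms easy to isolate when debugging.</p>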
<h3 id="models">Models</h3>
<ul>
<li><strong>Image Encoder:</strong> ResNet-34 backbone followed by a Transformer encoder.</li>
<li><strong>Graph Decoder:</strong> A &ldquo;Graph-Aware Transformer&rdquo; (GRAT) that outputs nodes (atom labels, coordinates) and edges (bond types).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Metrics focus on structural identity, as standard string matching (SMILES) is insufficient for graphs with superatoms.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>SMI</strong></td>
          <td>Canonical SMILES Match</td>
          <td>Correct if predicted SMILES is identical to ground truth.</td>
      </tr>
      <tr>
          <td><strong>TS 1</strong></td>
          <td>Tanimoto Similarity = 1.0</td>
          <td>Ratio of predictions with perfect fingerprint overlap.</td>
      </tr>
      <tr>
          <td><strong>Sim.</strong></td>
          <td>Average Tanimoto Similarity</td>
          <td>Measures average structural overlap across all predictions.</td>
      </tr>
  </tbody>
</table>
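<p>The two Tanimoto-based metrics can be sketched as below. Fingerprints are modeled as plain Python sets of "on" bit indices; in practice they would come from a cheminformatics toolkit, and the fingerprint type is not specified here. The helper names are illustrative.</p>

```python
def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto (Jaccard) similarity between two sets of 'on' fingerprint bits."""
    if not fp_a and not fp_b:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(fp_a.intersection(fp_b)) / len(fp_a.union(fp_b))


def summarize(pairs):
    """Compute TS 1 and Sim. over (predicted, reference) fingerprint pairs."""
    scores = [tanimoto(pred, ref) for pred, ref in pairs]
    return {
        "TS 1": sum(1 for s in scores if s == 1.0) / len(scores),
        "Sim.": sum(scores) / len(scores),
    }
```

<p>TS 1 rewards only perfect fingerprint overlap, while Sim. gives partial credit for near misses, so the two numbers can diverge noticeably on hard datasets.</p>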
<h2 id="reproducibility">Reproducibility</h2>
<p>The paper does not release source code, pre-trained models, or the custom OLED evaluation dataset. The training data sources (PubChem, USPTO) are publicly available, but the specific image generation pipeline (modified RDKit with coordinate extraction and superatom substitution) is not released. Key architectural details (ResNet-34 backbone, Transformer encoder/decoder configuration) and training techniques are described, but the paper reports only a limited set of exact hyperparameters, which makes full reproduction difficult.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td>Dataset</td>
          <td>Public Domain</td>
          <td>Source of 4.6M molecules for synthetic image generation</td>
      </tr>
      <tr>
          <td><a href="https://www.uspto.gov/">USPTO</a></td>
          <td>Dataset</td>
          <td>Public Domain</td>
          <td>2.5M real image-molecule pairs from patents</td>
      </tr>
      <tr>
          <td><a href="https://www.rdkit.org/">RDKit</a></td>
          <td>Code</td>
          <td>BSD-3-Clause</td>
          <td>Used (with modifications) for synthetic image generation</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Yoo, S., Kwon, O., &amp; Lee, H. (2022). Image-to-Graph Transformers for Chemical Structure Recognition. <em>ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)</em>, 3393-3397. <a href="https://doi.org/10.1109/ICASSP43922.2022.9746088">https://doi.org/10.1109/ICASSP43922.2022.9746088</a></p>
<p><strong>Publication</strong>: ICASSP 2022</p>
]]></content:encoded></item><item><title>ICMDT: Automated Chemical Structure Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/icmdt/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/icmdt/</guid><description>A Transformer-based model (ICMDT) for converting chemical structure images into InChI text strings using a novel Deep TNT block.</description><content:encoded><![CDATA[<h2 id="contribution-image-to-text-translation-for-chemical-structures">Contribution: Image-to-Text Translation for Chemical Structures</h2>
<p>This is a <strong>Method</strong> paper.</p>
<p>It proposes a novel neural network architecture, the <strong>Image Captioning Model based on Deep TNT (ICMDT)</strong>, to solve the specific problem of &ldquo;molecular translation&rdquo; (image-to-text). The classification is supported by the following rhetorical indicators:</p>
<ul>
<li><strong>Novel Mechanism:</strong> It introduces the &ldquo;Deep TNT block&rdquo; to improve upon the existing TNT architecture by fusing features at three levels (pixel, small patch, large patch).</li>
<li><strong>Baseline Comparison:</strong> The authors explicitly compare their model against four other architectures (CNN+RNN and CNN+Transformer variants).</li>
<li><strong>Ablation Study:</strong> Section 4.3 is dedicated to ablating specific components (position encoding, patch fusion) to prove their contribution to the performance gain.</li>
</ul>
<h2 id="motivation-digitizing-historical-chemical-literature">Motivation: Digitizing Historical Chemical Literature</h2>
<p>The primary motivation is to speed up chemical research by digitizing historical chemical literature.</p>
<ul>
<li><strong>Problem:</strong> Historical sources often contain corrupted or noisy images, making automated recognition difficult.</li>
<li><strong>Gap:</strong> Existing models like the standard TNT (Transformer in Transformer) function primarily as encoders for classification and fail to effectively integrate local pixel-level information required for precise structure generation.</li>
<li><strong>Goal:</strong> To build a dependable generative model that can accurately translate these noisy images into <strong><a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a></strong> (International Chemical Identifier) text strings.</li>
</ul>
<h2 id="novelty-multi-level-feature-fusion-with-deep-tnt">Novelty: Multi-Level Feature Fusion with Deep TNT</h2>
<p>The core contribution is the <strong>Deep TNT block</strong> and the resulting <strong>ICMDT</strong> architecture.</p>
<ul>
<li><strong>Deep TNT Block:</strong> The Deep TNT block expands upon standard local and global modeling by stacking three transformer blocks to process information at three granularities:
<ol>
<li><strong>Internal Transformer:</strong> Processes pixel embeddings.</li>
<li><strong>Middle Transformer:</strong> Processes small patch embeddings.</li>
<li><strong>Exterior Transformer:</strong> Processes large patch embeddings.</li>
</ol>
</li>
<li><strong>Multi-level Fusion:</strong> The model fuses pixel-level features into small patches, and small patches into large patches, allowing for finer integration of local details.</li>
<li><strong>Position Encoding:</strong> Shared position encodings are applied to small patches and pixels, while large patches use a learnable 1D encoding.</li>
</ul>
<h2 id="methodology-benchmarking-on-the-bms-dataset">Methodology: Benchmarking on the BMS Dataset</h2>
<p>The authors evaluated the model on the <strong>Bristol-Myers Squibb Molecular Translation</strong> dataset.</p>
<ul>
<li><strong>Baselines:</strong> They constructed four comparative models:
<ul>
<li>EfficientNetb0 + RNN (Bi-LSTM)</li>
<li>ResNet50d + RNN (Bi-LSTM)</li>
<li>EfficientNetb0 + Transformer</li>
<li>ResNet101d + Transformer</li>
</ul>
</li>
<li><strong>Ablation:</strong> They tested the impact of removing the large patch position encoding (ICMDT*), reverting the encoder to a standard TNT-S (TNTD), and setting the patch size to 32 directly on TNT-S without the exterior transformer block (TNTD-B).</li>
<li><strong>Pre-processing Study:</strong> They experimented with denoising ratios and cropping strategies.</li>
</ul>
<h2 id="results--conclusions-improved-inchi-translation-accuracy">Results &amp; Conclusions: Improved InChI Translation Accuracy</h2>
<ul>
<li><strong>Performance:</strong> ICMDT achieved the lowest <strong>Levenshtein distance (0.69)</strong> among all five models tested (Table 3). The best-performing baseline was ResNet101d+Transformer.</li>
<li><strong>Convergence:</strong> The model converged significantly faster than the baselines, outperforming others as early as epoch 6.7.</li>
<li><strong>Ablation Results:</strong> The full Deep TNT block reduced error by nearly half compared to the standard TNT encoder (0.69 vs 1.29 Levenshtein distance). Removing large patch position encoding (ICMDT*) degraded performance to 1.04, and directly using patch size 32 on TNT-S (TNTD-B) scored 1.37.</li>
<li><strong>Limitations:</strong> The model struggles with <strong>stereochemical layers</strong> (e.g., identifying clockwise neighbors or +/- signs) compared to non-stereochemical layers.</li>
<li><strong>Inference &amp; Fusion:</strong> The multi-model inference and fusion pipeline (beam search, TTA, step-wise logit ensemble, and voting) reduced Levenshtein distance by 0.24 to 2.5 relative to single models.</li>
<li><strong>Future Work:</strong> Integrating full object detection to predict atom/bond coordinates to better resolve 3D stereochemical information.</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<p><strong>Status: Partially Reproducible.</strong> The dataset is publicly available through Kaggle, and the paper provides detailed hyperparameters and architecture specifications. However, no source code or pretrained model weights have been released.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.kaggle.com/c/bms-molecular-translation">BMS Molecular Translation (Kaggle)</a></td>
          <td>Dataset</td>
          <td>Competition Terms</td>
          <td>Training/test images with InChI labels</td>
      </tr>
  </tbody>
</table>
<p><strong>Missing components:</strong> No official code repository or pretrained weights. Reimplementation requires reconstructing the Deep TNT block, training pipeline, and inference/fusion strategy from the paper description alone.</p>
<p><strong>Hardware/compute requirements:</strong> Not explicitly stated in the paper.</p>
<h3 id="data">Data</h3>
<p>The experiments used the <strong>Bristol-Myers Squibb Molecular Translation</strong> dataset from Kaggle.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>BMS Training Set</td>
          <td>2,424,186 images</td>
          <td>Supervised; contains noise and blur</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>BMS Test Set</td>
          <td>1,616,107 images</td>
          <td>Higher noise variation than training set</td>
      </tr>
  </tbody>
</table>
<p><strong>Pre-processing Strategy</strong>:</p>
<ul>
<li><strong>Effective:</strong> Padding resizing (reshaping to square using the longer edge, padding insufficient parts with pixels from the middle of the image).</li>
<li><strong>Ineffective:</strong> Smart cropping (removing white borders degraded performance).</li>
<li><strong>Augmentation:</strong> GaussNoise, Blur, RandomRotate90, and PepperNoise ($SNR=0.996$).</li>
<li><strong>Denoising:</strong> Best results found by mixing denoised and original data (Ratio 2:13) during training.</li>
</ul>
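<p>As an illustration of the pepper-noise augmentation, the sketch below sets a small fraction of pixels to black. It assumes that the SNR value denotes the share of pixels left untouched, which is a common convention but is not confirmed by the paper; the function name and the list-of-rows image format are hypothetical.</p>

```python
import random


def add_pepper_noise(image, snr=0.996, seed=None):
    """Return a copy of a grayscale image with a fraction (1 - snr) of pixels set to 0.

    `image` is a list of rows of integer intensities. Interpreting SNR as the
    share of untouched pixels is an assumption, not a detail from the paper.
    """
    rng = random.Random(seed)
    height, width = len(image), len(image[0])
    total = height * width
    n_noisy = round(total * (1.0 - snr))
    noisy = [row[:] for row in image]  # copy so the input stays unchanged
    for flat in rng.sample(range(total), n_noisy):
        noisy[flat // width][flat % width] = 0
    return noisy
```

<p>With $SNR=0.996$, roughly 4 in every 1,000 pixels are blackened, which mimics the speckle found in scanned documents without obscuring bonds or atom labels.</p>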
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer:</strong> Lookahead ($\alpha=0.5, k=5$) and RAdam ($\beta_1=0.9, \beta_2=0.99$).</li>
<li><strong>Loss Function:</strong> Anti-Focal loss ($\gamma=0.5$) combined with Label Smoothing. Standard Focal Loss multiplies cross-entropy by a modulating factor $(1-p_t)^\gamma$, which down-weights easy examples so that training focuses on hard ones. Anti-Focal Loss (Raunak et al., 2020) modifies this factor to reduce the disparity between training and inference distributions in Seq2Seq models.</li>
<li><strong>Training Schedule:</strong>
<ul>
<li>Initial resolution: $224 \times 224$</li>
<li>Fine-tuning: Resolution $384 \times 384$ for samples whose label length exceeds 150.</li>
<li>Batch size: Dynamic, increasing from 16 to 1024 (with proportional learning rate scaling).</li>
<li>Noisy Labels: Randomly replacing chemical elements in labels with a certain probability to improve robustness during inference.</li>
</ul>
</li>
<li><strong>Inference Strategy:</strong>
<ul>
<li>Beam Search ($k=16$ initially, $k=64$ if failing InChI validation).</li>
<li>Test Time Augmentation (TTA): Rotations of $90^\circ$.</li>
<li>Ensemble: Step-wise logit ensemble and voting based on Levenshtein distance scores.</li>
</ul>
</li>
</ul>
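<p>The difference between focal and anti-focal weighting can be sketched per token as below. The focal factor $(1-p_t)^\gamma$ follows the description above; the anti-focal factor $(1+\alpha p_t)^\gamma$ is an assumed form that should be checked against Raunak et al. (2020), and <code>alpha</code> is a hypothetical parameter introduced for illustration.</p>

```python
import math


def cross_entropy(p_t: float) -> float:
    """Plain cross-entropy for the probability assigned to the target token."""
    return -math.log(p_t)


def focal_loss(p_t: float, gamma: float = 0.5) -> float:
    """Focal loss: down-weights tokens the model already predicts confidently."""
    return (1.0 - p_t) ** gamma * cross_entropy(p_t)


def anti_focal_loss(p_t: float, gamma: float = 0.5, alpha: float = 1.0) -> float:
    """Anti-focal variant; the (1 + alpha * p_t) ** gamma factor is an assumed form."""
    return (1.0 + alpha * p_t) ** gamma * cross_entropy(p_t)
```

<p>Under this assumed form, confident predictions keep a larger share of the loss than they would under focal loss, and setting $\gamma=0$ reduces both variants to plain cross-entropy.</p>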
<h3 id="models">Models</h3>
<p><strong>ICMDT Architecture:</strong></p>
<ul>
<li><strong>Encoder (Deep TNT)</strong> (Depth: 12 layers):
<ul>
<li><strong>Internal Block:</strong> Dim 160, Heads 4, Hidden size 640, MLP act GELU, Pixel patch size 4.</li>
<li><strong>Middle Block:</strong> Dim 10, Heads 6, Hidden size 128, MLP act GELU, Small patch size 16.</li>
<li><strong>Exterior Block:</strong> Dim 2560, Heads 10, Hidden size 5120, MLP act GELU, Large patch size 32.</li>
</ul>
</li>
<li><strong>Decoder (Vanilla Transformer)</strong>:
<ul>
<li>Decoder dim: 2560, FFN dim: 1024.</li>
<li>Depth: 3 layers, Heads: 8.</li>
<li>Vocab size: 193 (InChI tokens), text_dim: 384.</li>
</ul>
</li>
</ul>
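<p>To get a feel for the three granularities, the sketch below counts patches per level at the two training resolutions. It assumes each level tiles the image with non-overlapping square patches of the listed size, which is an interpretation of the sizes above and not a confirmed implementation detail; the names are illustrative.</p>

```python
PATCH_SIZES = {"pixel": 4, "small": 16, "large": 32}  # sizes listed above


def patch_counts(resolution: int) -> dict:
    """Number of non-overlapping square patches per level for a square image."""
    counts = {}
    for level, size in PATCH_SIZES.items():
        if resolution % size != 0:
            raise ValueError(f"{resolution} is not divisible by {size}")
        per_side = resolution // size
        counts[level] = per_side * per_side
    return counts
```

<p>Under this tiling assumption, moving from $224 \times 224$ to $384 \times 384$ roughly triples the number of patches at every level, which is one reason to reserve the higher resolution for fine-tuning.</p>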
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric:</strong> Levenshtein distance: the minimum number of single-character edits (insertions, deletions, or substitutions) needed to turn the generated InChI string into the ground truth.</p>
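<p>For reference, the metric can be computed with the standard dynamic-programming recurrence. This is a generic sketch, not the evaluation code used by the authors, and <code>mean_levenshtein</code> is an illustrative helper name.</p>

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions, or substitutions."""
    previous = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(min(
                previous[j] + 1,         # delete ch_a
                current[j - 1] + 1,      # insert ch_b
                previous[j - 1] + cost,  # substitute or match
            ))
        previous = current
    return previous[-1]


def mean_levenshtein(predictions, references):
    """Average distance over paired predicted and reference strings."""
    pairs = list(zip(predictions, references))
    return sum(levenshtein(p, r) for p, r in pairs) / len(pairs)
```

<p>Fractional scores arise naturally once the distance is averaged over many images, since most predictions are exact matches with distance 0.</p>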
<p><strong>Ablation Results (Table 3 from paper):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params (M)</th>
          <th>Levenshtein Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ICMDT</strong></td>
          <td>138.16</td>
          <td><strong>0.69</strong></td>
      </tr>
      <tr>
          <td>ICMDT*</td>
          <td>138.16</td>
          <td>1.04</td>
      </tr>
      <tr>
          <td>TNTD</td>
          <td>114.36</td>
          <td>1.29</td>
      </tr>
      <tr>
          <td>TNTD-B</td>
          <td>114.36</td>
          <td>1.37</td>
      </tr>
  </tbody>
</table>
<p><strong>Baseline Comparison (from convergence curves, Figure 9):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Params (M)</th>
          <th>Convergence (Epochs)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>ICMDT</strong></td>
          <td>138.16</td>
          <td>~9.76</td>
      </tr>
      <tr>
          <td>ResNet101d + Transformer</td>
          <td>302.02</td>
          <td>14+</td>
      </tr>
      <tr>
          <td>EfficientNetb0 + Transformer</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ResNet50d + RNN</td>
          <td>90.6</td>
          <td>14+</td>
      </tr>
      <tr>
          <td>EfficientNetb0 + RNN</td>
          <td>46.3</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Li, Y., Chen, G., &amp; Li, X. (2022). Automated Recognition of Chemical Molecule Images Based on an Improved TNT Model. <em>Applied Sciences</em>, 12(2), 680. <a href="https://doi.org/10.3390/app12020680">https://doi.org/10.3390/app12020680</a></p>
<p><strong>Publication</strong>: MDPI Applied Sciences 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.kaggle.com/c/bms-molecular-translation">Kaggle Competition: BMS Molecular Translation</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{liAutomatedRecognitionChemical2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Automated {{Recognition}} of {{Chemical Molecule Images Based}} on an {{Improved TNT Model}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Li, Yanchi and Chen, Guanyu and Li, Xiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2022</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jan,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Applied Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{680}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Multidisciplinary Digital Publishing Institute}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{2076-3417}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.3390/app12020680}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Handwritten Chemical Structure Recognition with RCGD</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hu-handwritten-rcgd-2023/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hu-handwritten-rcgd-2023/</guid><description>An end-to-end framework (RCGD) and unambiguous markup language (SSML) for recognizing complex handwritten chemical structures with guided graph traversal.</description><content:encoded><![CDATA[<h2 id="contribution-and-methodological-framework">Contribution and Methodological Framework</h2>
<p>This is primarily a <strong>Method</strong> paper with a significant <strong>Resource</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architectural framework (<strong>RCGD</strong>) and a new representation syntax (<strong>SSML</strong>) to solve the specific problem of handwritten chemical structure recognition.</li>
<li><strong>Resource</strong>: It introduces a new benchmark dataset, <strong>EDU-CHEMC</strong>, containing 52,987 handwritten images to address the lack of public data in this domain.</li>
</ul>
<h2 id="the-ambiguity-of-handwritten-chemical-structures">The Ambiguity of Handwritten Chemical Structures</h2>
<p>Recognizing handwritten chemical structures is significantly harder than printed ones due to:</p>
<ol>
<li><strong>Inherent Ambiguity</strong>: Handwritten atoms and bonds vary greatly in appearance.</li>
<li><strong>Projection Complexity</strong>: Converting 2D projected layouts (like Natta or Fischer projections) into linear strings is difficult.</li>
<li><strong>Limitations of Existing Formats</strong>: Standard formats like SMILES require domain knowledge (valence rules) and have a high semantic gap with the visual image. They often fail to represent &ldquo;invalid&rdquo; structures commonly found in educational/student work.</li>
</ol>
<h2 id="bridging-the-semantic-gap-with-ssml-and-rcgd">Bridging the Semantic Gap with SSML and RCGD</h2>
<p>The paper introduces two core contributions to bridge the semantic gap between image and markup:</p>
<ol>
<li>
<p><strong>Structure-Specific Markup Language (SSML)</strong>: An extension of Chemfig that provides an unambiguous, visual-based graph representation. Unlike SMILES, it describes <em>how to draw</em> the molecule step-by-step, making it easier for models to learn visual alignments. It supports &ldquo;reconnection marks&rdquo; to handle cyclic structures explicitly.</p>
</li>
<li>
<p><strong>Random Conditional Guided Decoder (RCGD)</strong>: A decoder that treats recognition as a graph traversal problem. It introduces three novel mechanisms:</p>
<ul>
<li><strong>Conditional Attention Guidance</strong>: Uses branch angle directions to guide the attention mechanism, preventing the model from getting lost in complex structures.</li>
<li><strong>Memory Classification</strong>: A module that explicitly stores and classifies &ldquo;unexplored&rdquo; branch points to handle ring closures (reconnections).</li>
<li><strong>Path Selection</strong>: A training strategy that randomly samples traversal paths to prevent overfitting to a specific serialization order.</li>
</ul>
</li>
</ol>
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p><strong>Datasets</strong>:</p>
<ul>
<li><strong>Mini-CASIA-CSDB</strong> (Printed): A subset of 97,309 printed molecular structure images, upscaled to $500 \times 500$ resolution.</li>
<li><strong>EDU-CHEMC</strong> (Handwritten): A new dataset of 52,987 images collected from educational settings (cameras, scanners, screens), including erroneous/non-existent structures.</li>
</ul>
<p><strong>Baselines</strong>:</p>
<ul>
<li>Compared against standard <strong>String Decoders (SD)</strong> (based on DenseWAP), tested with both SMILES and SSML on Mini-CASIA-CSDB and exclusively with SSML on EDU-CHEMC.</li>
<li>Compared against <strong>BTTR</strong> and <strong>ABM</strong> (recent mathematical expression recognition models) adapted for the chemical structure task, both using SSML on EDU-CHEMC.</li>
<li>On Mini-CASIA-CSDB, also compared against <strong>WYGIWYS</strong> (a SMILES-based string decoder at 300x300 resolution).</li>
</ul>
<p><strong>Ablation Studies</strong>:</p>
<ul>
<li>Evaluated the impact of removing Path Selection (PS) and Memory Classification (MC) mechanisms on EDU-CHEMC.</li>
<li>Tested robustness to image rotation ($180^{\circ}$) on Mini-CASIA-CSDB.</li>
</ul>
<h2 id="recognition-performance-and-robustness">Recognition Performance and Robustness</h2>
<ul>
<li><strong>Superiority of SSML</strong>: Models trained with SSML significantly outperformed those trained with SMILES (92.09% vs 81.89% EM on printed data), a gain attributed to the reduced semantic gap between image and markup.</li>
<li><strong>Best Performance</strong>: RCGD achieved the highest Exact Match (EM) scores on both datasets:
<ul>
<li><strong>Mini-CASIA-CSDB</strong>: 95.01% EM.</li>
<li><strong>EDU-CHEMC</strong>: 62.86% EM.</li>
</ul>
</li>
<li><strong>EDU-CHEMC Baselines</strong>: On the handwritten dataset, SD (DenseWAP) achieved 61.35% EM, outperforming both BTTR (58.21% EM) and ABM (58.78% EM). The authors note that BTTR and ABM&rsquo;s reverse training mode, which helps in regular formula recognition, does not transfer well to graph-structured molecular data.</li>
<li><strong>Ablation Results</strong> (Table 5, EDU-CHEMC): Removing Path Selection alone dropped EM from 62.86% to 62.15%. Removing both Path Selection and Memory Classification dropped EM further to 60.31%, showing that memory classification has a larger impact.</li>
<li><strong>Robustness</strong>: RCGD showed minimal performance drop (0.85%) on rotated images compared to SMILES-based methods (10.36% drop). The SD with SSML dropped by 2.19%, confirming that SSML itself improves rotation invariance.</li>
<li><strong>Educational Utility</strong>: The method can recognize and reconstruct chemically invalid structures (e.g., a Carbon atom with 5 bonds), making it applicable for correcting and revising handwritten answers in chemistry education.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>1. EDU-CHEMC (Handwritten)</strong></p>
<ul>
<li><strong>Total Size</strong>: 52,987 images.</li>
<li><strong>Splits</strong>: Training (48,998), Validation (999), Test (2,992).</li>
<li><strong>Characteristics</strong>: Real-world educational data, mixture of isolated molecules and reaction equations, includes invalid chemical structures.</li>
</ul>
<p><strong>2. Mini-CASIA-CSDB (Printed)</strong></p>
<ul>
<li><strong>Total Size</strong>: 97,309 images.</li>
<li><strong>Splits</strong>: Training (80,781), Validation (8,242), Test (8,286).</li>
<li><strong>Preprocessing</strong>: Original $300 \times 300$ images were upscaled to $500 \times 500$ RGB to resolve blurring issues.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. SSML Generation</strong></p>
<p>To convert a molecular graph to SSML:</p>
<ol>
<li><strong>Traverse</strong>: Start from the left-most atom.</li>
<li><strong>Bonds/Atoms</strong>: Output atom text and bond format <code>&lt;bond&gt;[:&lt;angle&gt;]</code>.</li>
<li><strong>Branches</strong>: At branch points, use phantom symbols <code>(</code> and <code>)</code> to enclose branches, ordered by ascending bond angle.</li>
<li><strong>Reconnections</strong>: Use <code>?[tag]</code> and <code>?[tag, bond]</code> to mark start/end of ring closures.</li>
</ol>
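<p>The traversal steps can be sketched for the simplest case. The code below is a heavily simplified, SSML-like serializer, not the authors' grammar: it assumes an acyclic molecule (so no reconnection marks), coordinates with the y-axis pointing up, angles rounded to whole degrees, and that every branch except the last one at a branch point is wrapped in parentheses. All names are hypothetical.</p>

```python
import math


def to_ssml_like(atoms, bonds):
    """Serialize an acyclic molecular graph into a Chemfig-style string.

    atoms: {atom_id: (label, x, y)}; bonds: [(atom_id, atom_id, bond_symbol)].
    """
    adjacency = {atom_id: [] for atom_id in atoms}
    for i, j, symbol in bonds:
        adjacency[i].append((j, symbol))
        adjacency[j].append((i, symbol))

    def angle(i, j):
        _, xi, yi = atoms[i]
        _, xj, yj = atoms[j]
        return round(math.degrees(math.atan2(yj - yi, xj - xi))) % 360

    visited = set()

    def walk(i):
        visited.add(i)
        text = atoms[i][0]
        # Unvisited neighbours, ordered by ascending bond angle.
        children = sorted(
            (angle(i, j), symbol, j) for j, symbol in adjacency[i] if j not in visited
        )
        for position, (degrees, symbol, j) in enumerate(children):
            piece = f"{symbol}[:{degrees}]" + walk(j)
            is_last = position == len(children) - 1
            text += piece if is_last else f"({piece})"
        return text

    start = min(atoms, key=lambda atom_id: atoms[atom_id][1])  # left-most atom
    return walk(start)
```

<p>For a three-carbon-like toy graph with one branch, the output has the shape <code>C-[:0]C(-[:0]N)=[:90]O</code>: each bond carries its drawing angle, so the string describes how to draw the structure and does not rely on chemical valence rules.</p>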
<p><strong>2. RCGD Specifics</strong></p>
<ul>
<li><strong>RCGD-SSML</strong>: Modified version of SSML for the decoder. Removes <code>(</code> <code>)</code> delimiters; adds <code>\eob</code> (end of branch). Maintains a dynamic <strong>Branch Angle Set ($M$)</strong>.</li>
<li><strong>Path Selection</strong>: During training, when multiple branches exist in $M$, the model randomly selects one to traverse next. During inference, it uses beam search to score candidate paths.</li>
<li><strong>Loss Function</strong>:
$$
\begin{aligned}
L_{\text{total}} = L_{\text{ce}} + L_{\text{bc}}
\end{aligned}
$$
<ul>
<li>$L_{\text{ce}}$: Cross-entropy loss for character sequence generation.</li>
<li>$L_{\text{bc}}$: Multi-label classification loss for the memory module (predicting reconnection bond types for stored branch states).</li>
</ul>
</li>
</ul>
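<p>As a rough illustration of how the two terms combine, the NumPy sketch below assumes that $L_{\text{ce}}$ is a mean token-level cross-entropy and that $L_{\text{bc}}$ is a mean binary cross-entropy over independent sigmoid outputs (one per reconnection bond type). The function names, tensor shapes, and unweighted sum are illustrative assumptions, not the authors' implementation.</p>

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean token-level cross-entropy. logits: (T, V); targets: (T,) integer ids."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def multilabel_bce(logits, labels):
    """Mean binary cross-entropy over independent sigmoid outputs (assumed form of L_bc)."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    eps = 1e-12  # avoids log(0)
    return -(labels * np.log(probs + eps) + (1 - labels) * np.log(1 - probs + eps)).mean()

def total_loss(char_logits, char_targets, mem_logits, mem_labels):
    """L_total = L_ce + L_bc, summed without extra weighting (an assumption)."""
    return cross_entropy(char_logits, char_targets) + multilabel_bce(mem_logits, mem_labels)
```

<p>With all-zero logits, the first term reduces to the log of the vocabulary size and the second to the log of 2, which gives a quick sanity check for the sketch.</p>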
<h3 id="models">Models</h3>
<p><strong>Encoder</strong>: DenseNet</p>
<ul>
<li><strong>Structure</strong>: 3 dense blocks.</li>
<li><strong>Growth Rate</strong>: 24.</li>
<li><strong>Depth</strong>: 32 per block.</li>
<li><strong>Output</strong>: High-dimensional feature map $x \in \mathbb{R}^{d_x \times h \times w}$.</li>
</ul>
<p><strong>Decoder</strong>: GRU with Attention</p>
<ul>
<li><strong>Hidden State Dimension</strong>: 256.</li>
<li><strong>Embedding Dimension</strong>: 256.</li>
<li><strong>Attention Projection</strong>: 128.</li>
<li><strong>Memory Classification Projection</strong>: 256.</li>
</ul>
<p><strong>Training Config</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: Adam.</li>
<li><strong>Learning Rate</strong>: 2e-4 with multi-step decay (gamma 0.5).</li>
<li><strong>Dropout</strong>: 15%.</li>
<li><strong>Strategy</strong>: Teacher-forcing used for validation selection.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Exact Match (EM)</strong>: Percentage of samples where the predicted graph structure perfectly matches the label. For SMILES, string comparison; for SSML, converted to graph for isomorphism check.</li>
<li><strong>Structure EM</strong>: Auxiliary metric for samples with mixed content (text + molecules), counting samples where <em>all</em> molecular structures are correct.</li>
</ul>
<p><strong>Artifacts</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/iFLYTEK-CV/EDU-CHEMC">EDU-CHEMC</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Dataset annotations and download links (actual data hosted on Google Drive)</td>
      </tr>
  </tbody>
</table>
<p><strong>Missing Components</strong>:</p>
<ul>
<li>No training or inference code is publicly released; only the dataset is available.</li>
<li>Pre-trained model weights are not provided.</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hu, J., Wu, H., Chen, M., Liu, C., Wu, J., Yin, S., Yin, B., Yin, B., Liu, C., Du, J., &amp; Dai, L. (2023). Handwritten Chemical Structure Image to Structure-Specific Markup Using Random Conditional Guided Decoder. <em>Proceedings of the 31st ACM International Conference on Multimedia</em> (pp. 8114-8124). <a href="https://doi.org/10.1145/3581783.3612573">https://doi.org/10.1145/3581783.3612573</a></p>
<p><strong>Publication</strong>: ACM Multimedia 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/iFLYTEK-CV/EDU-CHEMC">GitHub Repository / EDU-CHEMC Dataset</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{huHandwrittenChemicalStructure2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Handwritten Chemical Structure Image to Structure-Specific Markup Using Random Conditional Guided Decoder}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 31st ACM International Conference on Multimedia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Hu, Jinshui and Wu, Hao and Chen, Mingjun and Liu, Chenyu and Wu, Jiajia and Yin, Shi and Yin, Baocai and Yin, Bing and Liu, Cong and Du, Jun and Dai, Lirong}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{8114--8124}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{ACM}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Ottawa ON Canada}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1145/3581783.3612573}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{979-8-4007-0108-5}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>End-to-End Transformer for Molecular Image Captioning</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/vit-inchi-transformer/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/vit-inchi-transformer/</guid><description>Vision Transformer encoder with Transformer decoder for molecular image-to-InChI translation, outperforming CNN baselines on noisy molecular datasets.</description><content:encoded><![CDATA[<h2 id="methodological-contribution">Methodological Contribution</h2>
<p>This is a <strong>Methodological Paper</strong>. It proposes a novel architectural approach to molecular image translation by replacing the standard CNN encoder with a Vision Transformer (ViT). The authors validate this method through comparative benchmarking against standard CNN+RNN baselines (e.g., ResNet+LSTM) and provide optimizations for inference speed.</p>
<h2 id="motivation-and-problem-statement">Motivation and Problem Statement</h2>
<p>The core problem is that existing molecular translation methods, which extract chemical structures from images into the computer-readable InChI format, rely heavily on rule-based systems or CNN+RNN architectures. These approaches often underperform on noisy images (common in scans of old journals) or images with few distinguishable features. Drug discovery has a significant need to digitize and analyze legacy experimental data locked in image form within scientific publications.</p>
<h2 id="core-innovations-end-to-end-vit-encoder">Core Innovations: End-to-End ViT Encoder</h2>
<p>The primary contribution is the use of a completely convolution-free Vision Transformer (ViT) as the encoder, allowing the model to utilize long-range dependencies among image patches from the very beginning via self-attention:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
The architecture is a pure Transformer (Encoder-Decoder), treating the molecular image similarly to a sequence of tokens (patches). Furthermore, the authors implement a specific caching strategy for the decoder to avoid recomputing embeddings for previously decoded tokens, reducing the time complexity of the decoding step.</p>
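<p>The attention formula above can be sketched in a few lines of NumPy. This is a single-head, unbatched version for illustration only; the function name and shapes are assumptions, and the paper's model adds multiple heads and learned projections.</p>

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v). Returns (output, weights)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                       # QK^T / sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stabilize the softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```

<p>Each row of <code>weights</code> sums to 1, so every output row is a convex combination of the value vectors; with image patches as keys and values, any patch can contribute to any other from the first layer onward.</p>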
<h2 id="experimental-setup-and-baselines">Experimental Setup and Baselines</h2>
<p>The model was compared against a standard CNN+RNN baseline and ResNet (18, 34, 50) + LSTM models with attention. Ablation studies varied the number of transformer layers (3, 6, 12, 24) and the image resolution (224x224 vs 384x384). The model was trained on a large combined dataset that includes Bristol Myers Squibb data, SMILES, GDB-13, and synthetically augmented images containing noise and artifacts. Performance was evaluated using the Levenshtein distance, the minimum number of single-character edits needed to transform the predicted string into the ground truth.</p>
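<p>For reference, the Levenshtein distance used for evaluation can be computed with the standard dynamic-programming recurrence. The sketch below is a generic implementation, not the authors' evaluation script, and the function name is illustrative.</p>

```python
def levenshtein(pred, truth):
    """Minimum number of single-character insertions, deletions, or substitutions."""
    prev = list(range(len(truth) + 1))  # distances from the empty prefix of pred
    for i, p in enumerate(pred, start=1):
        curr = [i]
        for j, t in enumerate(truth, start=1):
            cost = 0 if p == t else 1
            curr.append(min(prev[j] + 1,          # delete p
                            curr[j - 1] + 1,      # insert t
                            prev[j - 1] + cost))  # substitute or match
        prev = curr
    return prev[-1]
```

<p>For example, <code>levenshtein("InChI=1S/CH4/h1H4", "InChI=1S/CH4/h1H3")</code> is 1, since the strings differ by a single substituted character.</p>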
<h2 id="performance-outcomes-and-capabilities">Performance Outcomes and Capabilities</h2>
<p>The proposed 24-layer ViT model (input size 384) achieved the lowest Levenshtein distance of <strong>6.95</strong>, outperforming the ResNet50+LSTM baseline (7.49) and the standard CNN+RNN (103.7). Increasing the number of layers had a strong positive impact, with the 24-layer model becoming competitive with existing approaches. The authors note the model was evaluated on noisy datasets with few distinguishable features, where the ViT encoder&rsquo;s self-attention over all patches from the first layer helped capture relevant structure. The proposed caching optimization reduced the total decoding time complexity from $O(MN^2 + N^3)$ to $O(MN + N^2)$ for $N$ timesteps, by reducing the per-timestep cost to $O(M + N)$.</p>
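<p>The two complexity bounds can be recovered with a short counting argument. It assumes (this is not spelled out above) that $M$ is the number of encoder outputs attended to, and that at timestep $t$ each processed token attends to $M$ encoder outputs and $t$ decoded tokens; constant factors, layer count, and feature dimension are ignored. Without caching, all $t$ tokens are reprocessed at step $t$:
$$ \sum_{t=1}^{N} t\,(M + t) = M\,\frac{N(N+1)}{2} + \frac{N(N+1)(2N+1)}{6} = O(MN^2 + N^3) $$
With caching, only the newest token is processed at step $t$:
$$ \sum_{t=1}^{N} (M + t) = MN + \frac{N(N+1)}{2} = O(MN + N^2) $$</p>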
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The model was trained on a combined dataset randomly split into 70% training, 10% test, and 20% validation.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Bristol Myers Squibb</strong></td>
          <td>~2.4 million synthetic images with InChI labels.</td>
          <td>Provided by Bristol Myers Squibb (BMS), a global biopharmaceutical company.</td>
      </tr>
      <tr>
          <td><strong>SMILES</strong></td>
          <td>Kaggle contest data converted to InChI.</td>
          <td>Images generated using RDKit.</td>
      </tr>
      <tr>
          <td><strong><a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a></strong></td>
          <td>Subset of 977 million small organic molecules (up to 13 atoms).</td>
          <td>Converted from SMILES using RDKit.</td>
      </tr>
      <tr>
          <td><strong>Augmented Images</strong></td>
          <td>Synthetic images with salt/pepper noise, dropped atoms, and bond modifications.</td>
          <td>Used to improve robustness against noise.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Training Objective</strong>: Cross-entropy loss minimization.</li>
<li><strong>Inference Decoding</strong>: Autoregressive decoding predicting the next character of the InChI string.</li>
<li><strong>Positional Encoding</strong>: Standard sine and cosine functions of different frequencies.</li>
<li><strong>Optimization</strong>:
<ul>
<li><strong>Caching</strong>: Caches the output of each layer during decoding to avoid recomputing embeddings for already decoded tokens.</li>
<li><strong>JIT</strong>: PyTorch JIT compiler used for graph optimization (1.2-1.5x speed increase on GPU).</li>
<li><strong>Self-Critical Training</strong>: Finetuning performed using self-critical sequence training (SCST).</li>
</ul>
</li>
</ul>
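<p>The sine/cosine positional encoding mentioned above is commonly written as in the sketch below, which follows the standard Transformer formulation (base 10000, even dimensions use sine, odd dimensions use cosine). The function name is illustrative, and an even <code>d_model</code> is assumed.</p>

```python
import numpy as np

def sinusoidal_positional_encoding(num_positions, d_model):
    """Returns a (num_positions, d_model) array; d_model is assumed to be even."""
    positions = np.arange(num_positions)[:, None]   # (P, 1)
    dims = np.arange(0, d_model, 2)[None, :]        # even dimension indices 2i
    angles = positions / np.power(10000.0, dims / d_model)
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(angles)  # even dimensions
    pe[:, 1::2] = np.cos(angles)  # odd dimensions
    return pe
```

<p>The encoding is added to the token embeddings before the first decoder layer, so position information needs no learned parameters.</p>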
<h3 id="models">Models</h3>
<ul>
<li><strong>Encoder (Vision Transformer)</strong>:
<ul>
<li>Input: Flattened 2D patches of the image. Patch size: $16 \times 16$.</li>
<li>Projection: Trainable linear projection to latent vector size $D$.</li>
<li>Structure: Alternating layers of Multi-Head Self-Attention (MHSA) and MLP blocks.</li>
</ul>
</li>
<li><strong>Decoder (Vanilla Transformer)</strong>:
<ul>
<li>Input: Tokenized InChI string + sinusoidal positional embedding.</li>
<li>Vocabulary: 275 tokens (including <code>&lt;SOS&gt;</code>, <code>&lt;PAD&gt;</code>, <code>&lt;EOS&gt;</code>).</li>
</ul>
</li>
<li><strong>Hyperparameters (Best Model)</strong>:
<ul>
<li>Image Size: $384 \times 384$.</li>
<li>Layers: 24.</li>
<li>Feature Dimension: 512.</li>
<li>Attention Heads: 12.</li>
<li>Optimizer: Adam.</li>
<li>Learning Rate: $3 \times 10^{-5}$ (decayed by 0.5 in last 2 epochs).</li>
<li>Batch Size: Varied [64-512].</li>
</ul>
</li>
</ul>
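<p>To make the encoder input concrete, the sketch below splits a $384 \times 384$ image into $16 \times 16$ patches and applies a linear projection. It assumes a 3-channel image and takes the latent size $D$ to be the listed feature dimension of 512; the random matrix <code>W</code> stands in for the trainable projection, and all names are hypothetical.</p>

```python
import numpy as np

def patchify(image, patch=16):
    """Split an (H, W, C) image into flattened, non-overlapping patch vectors."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0
    grid = image.reshape(h // patch, patch, w // patch, patch, c)
    grid = grid.transpose(0, 2, 1, 3, 4)  # (rows, cols, patch, patch, C)
    return grid.reshape(-1, patch * patch * c)

rng = np.random.default_rng(0)
image = rng.random((384, 384, 3))   # assumed RGB input
patches = patchify(image)           # 24 * 24 = 576 patches of length 768
W = rng.normal(size=(768, 512))     # placeholder for the trainable projection
tokens = patches @ W                # (576, 512) sequence fed to the encoder
```

<p>Under these assumptions the encoder sees a sequence of 576 tokens, compared with 196 for a $224 \times 224$ input.</p>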
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Levenshtein Distance (lower is better).</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Image Size</th>
          <th>Layers</th>
          <th>Epochs</th>
          <th>Levenshtein Dist.</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Standard CNN+RNN</td>
          <td>224</td>
          <td>3</td>
          <td>10</td>
          <td>103.7</td>
      </tr>
      <tr>
          <td>ResNet18 + LSTM</td>
          <td>224</td>
          <td>4</td>
          <td>10</td>
          <td>75.03</td>
      </tr>
      <tr>
          <td>ResNet34 + LSTM</td>
          <td>224</td>
          <td>4</td>
          <td>10</td>
          <td>45.72</td>
      </tr>
      <tr>
          <td>ResNet50 + LSTM</td>
          <td>224</td>
          <td>5</td>
          <td>10</td>
          <td>7.49</td>
      </tr>
      <tr>
          <td>ViT Transformers</td>
          <td>224</td>
          <td>3</td>
          <td>5</td>
          <td>79.82</td>
      </tr>
      <tr>
          <td>ViT Transformers</td>
          <td>224</td>
          <td>6</td>
          <td>5</td>
          <td>54.58</td>
      </tr>
      <tr>
          <td>ViT Transformers</td>
          <td>224</td>
          <td>12</td>
          <td>5</td>
          <td>31.30</td>
      </tr>
      <tr>
          <td>ViT Transformers (Best)</td>
          <td>384</td>
          <td>24</td>
          <td>10</td>
          <td><strong>6.95</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>System</strong>: 70GB GPU system.</li>
<li><strong>Framework</strong>: PyTorch and PyTorch Lightning.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Sundaramoorthy, C., Kelvin, L. Z., Sarin, M., &amp; Gupta, S. (2021). End-to-End Attention-based Image Captioning. <em>arXiv preprint arXiv:2104.14721</em>. <a href="https://doi.org/10.48550/arXiv.2104.14721">https://doi.org/10.48550/arXiv.2104.14721</a></p>
<p><strong>Publication</strong>: arXiv 2021 (preprint)</p>
<p><strong>Note</strong>: This is an arXiv preprint and has not undergone formal peer review.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{sundaramoorthyEndtoEndAttentionbasedImage2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{End-to-{{End Attention-based Image Captioning}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Sundaramoorthy, Carola and Kelvin, Lin Ziwen and Sarin, Mahak and Gupta, Shubham}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2021</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = apr,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{arXiv:2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.48550/arXiv.2104.14721}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DECIMER 1.0: Transformers for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-1.0/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer-1.0/</guid><description>Transformer-based approach for Optical Chemical Structure Recognition converting chemical images to SELFIES strings with 96% accuracy.</description><content:encoded><![CDATA[<h2 id="evaluating-the-contribution-a-methodological-shift">Evaluating the Contribution: A Methodological Shift</h2>
<p><strong>Method (Dominant)</strong> with strong <strong>Resource</strong> elements.</p>
<p>This is primarily a <strong>Method</strong> paper because it proposes a specific architectural evolution. It replaces CNN-RNN/Encoder-Decoder models with a <strong>Transformer-based network</strong> to solve the problem of image-to-structure translation. It validates this methodological shift through rigorous ablation studies comparing feature extractors (InceptionV3 vs. EfficientNet) and decoder architectures.</p>
<p>It also serves as a <strong>Resource</strong> contribution by releasing the open-source software, trained models, and describing the curation of a massive synthetic training dataset (&gt;35 million molecules).</p>
<h2 id="motivation-inaccessible-chemical-knowledge">Motivation: Inaccessible Chemical Knowledge</h2>
<ul>
<li><strong>Data Inaccessibility</strong>: A vast amount of chemical knowledge (pre-1990s) is locked in printed or scanned literature and is not machine-readable.</li>
<li><strong>Manual Bottlenecks</strong>: Manual curation and extraction of this data is tedious, slow, and error-prone.</li>
<li><strong>Limitations of Prior Tools</strong>: Existing Optical Chemical Structure Recognition (OCSR) tools are often rule-based or struggle with the noise and variability of full-page scanned articles. Previous deep learning attempts were not publicly accessible or robust enough.</li>
</ul>
<h2 id="key-innovation-transformer-based-molecular-translation">Key Innovation: Transformer-Based Molecular Translation</h2>
<ul>
<li><strong>Transformer Architecture</strong>: Shifts from the standard CNN-RNN (Encoder-Decoder) approach to a <strong>Transformer-based decoder</strong>, significantly improving accuracy.</li>
<li><strong>EfficientNet Backbone</strong>: Replaces the standard InceptionV3 feature extractor with <strong>EfficientNet-B3</strong>, which improved feature extraction quality for chemical images.</li>
<li><strong>SELFIES Representation</strong>: Uses <a href="/notes/chemistry/molecular-representations/notations/selfies/"><strong>SELFIES</strong></a> (SELF-referencing Embedded Strings) as the target output. Every SELFIES string decodes to a valid molecule, which eliminates the &ldquo;invalid SMILES&rdquo; problem common in generative models.</li>
<li><strong>Massive Scaling</strong>: Trains on synthetic datasets derived from PubChem (up to <strong>39 million molecules</strong> total, with the largest training subset at ~35 million), demonstrating that scaling data size directly correlates with improved model performance.</li>
</ul>
<h2 id="methodology-and-experimental-validation">Methodology and Experimental Validation</h2>
<ul>
<li><strong>Feature Extractor Ablation</strong>: Compared InceptionV3 vs. EfficientNet-B3 (and B7) on a 1-million molecule subset to determine the optimal image encoder.</li>
<li><strong>Architecture Comparison</strong>: Benchmarked the Encoder-Decoder (CNN+RNN) against the Transformer model using Tanimoto similarity metrics. The structural similarity between predicted and ground truth molecules was measured via Tanimoto similarity over molecular fingerprints:
$$ T(\mathbf{A}, \mathbf{B}) = \frac{\mathbf{A} \cdot \mathbf{B}}{|\mathbf{A}|^2 + |\mathbf{B}|^2 - \mathbf{A} \cdot \mathbf{B}} $$</li>
<li><strong>Data Scaling</strong>: Evaluated performance across increasing training set sizes (1M, 10M, 15M, 35M) to observe scaling laws.</li>
<li><strong>Stereochemistry &amp; Ions</strong>: Tested the model&rsquo;s ability to handle complex stereochemical information and charged groups (ions), creating separate datasets for these tasks.</li>
<li><strong>Augmentation Robustness</strong>: Evaluated the model on augmented images (blur, noise, varying contrast) to simulate real-world scanned document conditions.</li>
</ul>
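<p>The Tanimoto formula above translates directly into code. The sketch below operates on plain fingerprint vectors; the fingerprint type used in the paper is not specified here, and returning 1.0 for two all-zero vectors is a convention chosen for this example.</p>

```python
import numpy as np

def tanimoto(a, b):
    """Tanimoto similarity: A.B / (|A|^2 + |B|^2 - A.B)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    dot = float(a @ b)
    denom = float(a @ a + b @ b - dot)
    # Two all-zero vectors are treated as identical (an assumption of this sketch).
    return dot / denom if denom else 1.0
```

<p>For binary fingerprints this reduces to the number of shared on-bits divided by the number of bits set in either vector, so a score of 1.0 means the fingerprints are identical.</p>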
<h2 id="results-and-scaling-observations">Results and Scaling Observations</h2>
<ul>
<li><strong>Architecture Comparison</strong>: The Transformer model with EfficientNet-B3 features outperformed the Encoder-Decoder baseline by a wide margin. On the 1M dataset, the Transformer achieved <strong>74.57%</strong> exact matches (Tanimoto 1.0) compared to only <strong>7.03%</strong> for the Encoder-Decoder (Table 4 in the paper).</li>
<li><strong>High Accuracy at Scale</strong>: With the full 35-million molecule training set (Dataset 1), the model achieved a <strong>Tanimoto 1.0 score of 96.47%</strong> and an average Tanimoto similarity of 0.99.</li>
<li><strong>Isomorphism</strong>: 99.75% of predictions with a Tanimoto score of 1.0 were confirmed to be structurally isomorphic to the ground truth (checked via <a href="/notes/chemistry/molecular-representations/notations/inchi-2013/">InChI</a>).</li>
<li><strong>Stereochemistry Costs</strong>: Including stereochemistry and ions increased the token count and difficulty, resulting in slightly lower accuracy (~89.87% exact match on Dataset 2).</li>
<li><strong>Hardware Efficiency</strong>: Training on TPUs (v3-8) was ~4x faster than Nvidia V100 GPUs. For the 1M molecule model, convergence took ~8h 41min on TPU v3-8 vs ~29h 48min on V100 GPU. The largest model (35M) took less than 14 days on TPU.</li>
<li><strong>Augmentation Robustness (Dataset 3)</strong>: When trained on augmented images and tested on non-augmented images, the model achieved 86.43% Tanimoto 1.0. Using a pre-trained model from Dataset 2 and refitting on augmented images improved this to 88.04% on non-augmented test images and 80.87% on augmented test images, retaining above 97% isomorphism rates.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors generated synthetic data from PubChem.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 1 (Clean)</td>
          <td>39M total (35M train)</td>
          <td>No stereo/ions. Filtered for MW &lt; 1500, bond count 3-40, SMILES len &lt; 40.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 2 (Complex)</td>
          <td>37M total (33M train)</td>
          <td>Includes stereochemistry and charged groups (ions).</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td>Dataset 3 (Augmented)</td>
          <td>37M total (33M train)</td>
          <td>Dataset 2 with image augmentations applied.</td>
      </tr>
      <tr>
          <td><strong>Preprocessing</strong></td>
          <td>N/A</td>
          <td>N/A</td>
          <td>Molecules converted to <strong>SELFIES</strong>. Images generated via CDK Structure Diagram Generator (SDG) as $299 \times 299$ 8-bit PNGs.</td>
      </tr>
      <tr>
          <td><strong>Format</strong></td>
          <td>TFRecords</td>
          <td>75 MB chunks</td>
          <td>128 Data points (image vector + tokenized string) per record.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Text Representation</strong>: <strong>SELFIES</strong> used to avoid invalid intermediate strings. Tokenized via Keras tokenizer.
<ul>
<li><em>Dataset 1 Tokens</em>: 27 unique tokens. Max length 47.</li>
<li><em>Dataset 2/3 Tokens</em>: 61 unique tokens (due to stereo/ion tokens).</li>
</ul>
</li>
<li><strong>Augmentation</strong>: Implemented using the <code>imgaug</code> Python package. Random application of:
<ul>
<li>Gaussian/Average Blur, Additive Gaussian Noise, Salt &amp; Pepper, Coarse Dropout, Gamma Contrast, Sharpen, Brightness.</li>
</ul>
</li>
<li><strong>Optimization</strong>: Adam optimizer with a custom learning rate scheduler (following the &ldquo;Attention is all you need&rdquo; paper).</li>
</ul>
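<p>The learning rate schedule from &ldquo;Attention is all you need&rdquo; warms up linearly and then decays with the inverse square root of the step. A minimal sketch follows; <code>d_model=512</code> matches the dimension size listed for this model, while <code>warmup_steps=4000</code> is the default from the original Transformer paper and only an assumption for DECIMER.</p>

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    step = max(step, 1)  # guard against step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

<p>The rate peaks at <code>step == warmup_steps</code> and is lower on either side of that point.</p>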
<h3 id="models">Models</h3>
<p>The final architecture is an <strong>Image-to-SELFIES Transformer</strong>.</p>
<ul>
<li><strong>Encoder (Feature Extractor)</strong>:
<ul>
<li><strong>EfficientNet-B3</strong> (pre-trained on Noisy-student).</li>
<li>Input: $299 \times 299 \times 3$ images (normalized -1 to 1).</li>
<li>Output Feature Vector: $10 \times 10 \times 1536$.</li>
</ul>
</li>
<li><strong>Decoder (Transformer)</strong>:
<ul>
<li>4 Encoder-Decoder layers.</li>
<li>8 Parallel Attention Heads.</li>
<li>Dimension size: 512.</li>
<li>Feed-forward size: 2048.</li>
<li>Dropout: 0.1.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation was performed on a held-out test set (10% of the total data) selected with the RDKit MaxMin algorithm for diversity.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto 1.0</strong></td>
          <td><strong>96.47%</strong></td>
          <td>74.57% (1M subset)</td>
          <td>Percentage of predictions with perfect fingerprint match (Dataset 1, 35M training).</td>
      </tr>
      <tr>
          <td><strong>Avg Tanimoto</strong></td>
          <td><strong>0.9923</strong></td>
          <td>0.9371 (1M subset)</td>
          <td>Average similarity score (Dataset 1, 35M training).</td>
      </tr>
      <tr>
          <td><strong>Isomorphism</strong></td>
          <td><strong>99.75%</strong></td>
          <td>-</td>
          <td>Percentage of Tanimoto 1.0 predictions that are structurally identical (checked via InChI).</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Hardware</strong>: TPU v3-8 (Google Cloud). TPU v3-32 was tested but v3-8 was chosen for cost-effectiveness.</li>
<li><strong>Comparison Hardware</strong>: Nvidia Tesla V100 (32GB GPU).</li>
<li><strong>Performance</strong>:
<ul>
<li>TPU v3-8 was ~4x faster than V100 GPU.</li>
<li>1 Million molecule model convergence: 8h 41min on TPU vs ~29h 48min on GPU.</li>
<li>Largest model (35M) took less than 14 days on TPU.</li>
</ul>
</li>
</ul>
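<p>As a quick arithmetic check, the speedup implied by the 1M-molecule convergence times quoted above can be computed directly. The helper name is illustrative.</p>

```python
def to_minutes(hours, minutes):
    """Convert an hours/minutes duration to total minutes."""
    return 60 * hours + minutes

tpu = to_minutes(8, 41)    # 521 minutes on TPU v3-8
gpu = to_minutes(29, 48)   # 1788 minutes on V100
speedup = gpu / tpu
print(f"{speedup:.2f}x")   # prints 3.43x
```

<p>For this single data point the ratio is about 3.4x, in the same range as the ~4x figure reported for TPU v3-8 versus V100.</p>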
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<p>The paper is open-access, and both code and data are publicly available.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">DECIMER-TPU (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation using TensorFlow and TPU training</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.4730515">Code Archive (Zenodo)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Archival snapshot of the codebase</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.4766251">Training Data (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>SMILES data used for training (images generated via CDK SDG)</td>
      </tr>
      <tr>
          <td><a href="https://decimer.ai/">DECIMER Project Page</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Project landing page</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Hardware Requirements</strong>: Training requires TPU v3-8 (Google Cloud) or Nvidia V100 GPU. The largest model (35M molecules) took less than 14 days on TPU v3-8.</li>
<li><strong>Missing Components</strong>: None identified. Augmentation parameters are documented in the paper (Table 14), and pre-trained model weights are available through the GitHub repository.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A. &amp; Steinbeck, C. (2021). DECIMER 1.0: deep learning for chemical image recognition using transformers. <em>Journal of Cheminformatics</em>, 13(1), 61. <a href="https://doi.org/10.1186/s13321-021-00538-8">https://doi.org/10.1186/s13321-021-00538-8</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2021</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Kohulan/DECIMER-Image_Transformer">GitHub Repository</a></li>
<li><a href="https://decimer.ai/">DECIMER Project Page</a></li>
<li><a href="https://doi.org/10.5281/zenodo.4730515">Code Archive (Zenodo)</a></li>
<li><a href="https://doi.org/10.5281/zenodo.4766251">Training Data (Zenodo)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMER10Deep2021,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{DECIMER 1.0: Deep Learning for Chemical Image Recognition Using Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{DECIMER 1.0}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = <span style="color:#e6db74">{aug}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{13}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{61}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-021-00538-8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1186/s13321-021-00538-8}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemPix: Hand-Drawn Hydrocarbon Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/chempix/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/chempix/</guid><description>Deep learning framework using CNN-LSTM image captioning to convert hand-drawn hydrocarbon structures into SMILES strings with 76% accuracy.</description><content:encoded><![CDATA[<h2 id="paper-classification-and-core-contribution">Paper Classification and Core Contribution</h2>
<p>This is primarily a <strong>Method</strong> paper, with a secondary contribution as a <strong>Resource</strong> paper.</p>
<p>The paper&rsquo;s core contribution is the <strong>ChemPix architecture and training strategy</strong> using neural image captioning (CNN-LSTM) to convert hand-drawn chemical structures to SMILES. The extensive ablation studies on synthetic data generation (augmentation, degradation, backgrounds) and ensemble learning strategies confirm the methodological focus. The secondary resource contribution includes releasing a curated dataset of hand-drawn hydrocarbons and code for generating synthetic training data.</p>
<h2 id="the-structural-input-bottleneck-in-computational-chemistry">The Structural Input Bottleneck in Computational Chemistry</h2>
<p>Inputting molecular structures into computational chemistry software for quantum calculations is often a bottleneck, requiring domain expertise and cumbersome manual entry in drawing software. While optical chemical structure recognition (OCSR) tools exist, they typically struggle with the noise and variability of hand-drawn sketches. There is a practical need for a tool that allows chemists to simply photograph a hand-drawn sketch and immediately convert it into a machine-readable format (<a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>), making computational workflows more accessible.</p>
<h2 id="cnn-lstm-image-captioning-and-synthetic-generalization">CNN-LSTM Image Captioning and Synthetic Generalization</h2>
<ol>
<li><strong>Image Captioning Paradigm</strong>: The authors treat the problem as <strong>neural image captioning</strong>, using an encoder-decoder (CNN-LSTM) framework to &ldquo;translate&rdquo; an image directly to a SMILES string. This avoids the complexity of explicit atom/bond detection and graph assembly.</li>
<li><strong>Synthetic Data Engineering</strong>: The paper introduces a rigorous synthetic data generation pipeline that transforms clean RDKit-generated images into &ldquo;pseudo-hand-drawn&rdquo; images via randomized backgrounds, degradation, and heavy augmentation. This allows the model to achieve &gt;50% accuracy on real hand-drawn data without ever seeing it during training.</li>
<li><strong>Ensemble Uncertainty Estimation</strong>: The method utilizes a &ldquo;committee&rdquo; (ensemble) of networks to improve accuracy and estimate confidence based on vote agreement, providing users with reliability indicators for predictions.</li>
</ol>
<h2 id="extensive-ablation-and-real-world-evaluation">Extensive Ablation and Real-World Evaluation</h2>
<ol>
<li><strong>Ablation Studies on Data Pipeline</strong>: The authors trained models on datasets generated at different stages of the pipeline (Clean RDKit $\rightarrow$ Augmented $\rightarrow$ Backgrounds $\rightarrow$ Degraded) to quantify the value of each transformation in bridging the synthetic-to-real domain gap.</li>
<li><strong>Sample Size Scaling</strong>: They analyzed performance scaling by training on synthetic dataset sizes ranging from 10,000 to 500,000 images to understand data requirements.</li>
<li><strong>Real-world Validation</strong>: The model was evaluated on a held-out test set of hand-drawn images collected via a custom web app, providing genuine out-of-distribution testing.</li>
<li><strong>Fine-tuning Experiments</strong>: Comparisons of synthetic-only training versus fine-tuning with a small fraction of real hand-drawn data to assess the value of limited real-world supervision.</li>
</ol>
<h2 id="state-of-the-art-hand-drawn-ocsr-performance">State-of-the-Art Hand-Drawn OCSR Performance</h2>
<ol>
<li>
<p><strong>Pipeline Efficacy</strong>: Augmentation and image degradation were the most critical factors for generalization, achieving over 50% accuracy on hand-drawn data when training with 500,000 synthetic images. Adding backgrounds had a negligible effect on accuracy compared to degradation.</p>
</li>
<li>
<p><strong>State-of-the-Art Performance</strong>: The final ensemble model (5 out of 17 trained NNs, selected for achieving &gt;50% individual accuracy) achieved <strong>76% accuracy</strong> (top-1) and <strong>85.5% accuracy</strong> (top-3) on the hand-drawn test set, a significant improvement over the best single model&rsquo;s 67.5%.</p>
</li>
<li>
<p><strong>Synthetic Generalization</strong>: A model trained on 500,000 synthetic images achieved &gt;50% accuracy on real hand-drawn data without any fine-tuning, validating the synthetic data generation strategy as a viable alternative to expensive manual labeling.</p>
</li>
<li>
<p><strong>Ensemble Benefits</strong>: The voting committee approach improved accuracy and provided interpretable uncertainty estimates through vote distributions. When all five committee members agree ($V=5$), the confidence value reaches 98%.</p>
</li>
</ol>
<h2 id="limitations">Limitations</h2>
<p>The authors acknowledge several limitations of the current system:</p>
<ul>
<li><strong>Hydrocarbons only</strong>: The model is restricted to hydrocarbon structures and does not handle heteroatoms or functional groups.</li>
<li><strong>No conjoined rings</strong>: Molecules with multiple conjoined rings are excluded due to limitations of RDKit&rsquo;s image generation, which depicts bridges differently from standard chemistry drawing conventions.</li>
<li><strong>Resonance hybrid notation</strong>: The network struggles with benzene rings drawn in the resonance hybrid style (with a circle) compared to the Kekule structure, since the RDKit training images use exclusively Kekule representations.</li>
<li><strong>Challenging backgrounds</strong>: Lined and squared paper increase recognition difficulty, and structures bleeding through from the opposite side of the page can confuse the network.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study relies on two primary data sources: a massive synthetic dataset generated procedurally and a smaller collected dataset of real drawings.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td>Synthetic (RDKit)</td>
          <td>500,000 images</td>
          <td>Generated via RDKit with &ldquo;heavy&rdquo; augmentation: rotation ($0^\circ$ to $360^\circ$), blur, salt+pepper noise, and background texture addition.</td>
      </tr>
      <tr>
          <td><strong>Fine-tuning</strong></td>
          <td>Hand-Drawn (Real)</td>
          <td>613 images</td>
          <td>Crowdsourced via a web app from over 100 unique users; split into 200-image test set and 413 training/validation images.</td>
      </tr>
      <tr>
          <td><strong>Backgrounds</strong></td>
          <td>Texture Images</td>
          <td>1,052 images</td>
          <td>A pool of unlabeled texture photos (paper, desks, shadows) used to generate synthetic backgrounds.</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Generation Parameters</strong>:</p>
<ul>
<li><strong>Augmentations</strong>: Rotation, Resize ($200$-$300$ px), Blur, Dilate, Erode, Aspect Ratio, Affine transform ($\pm 20$ px), Contrast, Quantize, Sharpness</li>
<li><strong>Backgrounds</strong>: Randomly translated $\pm 100$ pixels and reflected</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Ensemble Voting</strong><br>
A committee of networks casts votes for the predicted SMILES string. The final prediction is the one with the highest vote count. Validity of SMILES is checked using RDKit.</p>
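<p>The voting step can be sketched as follows. This is a minimal illustration, not the ChemPix implementation: <code>committee_vote</code> and its <code>is_valid</code> argument are hypothetical names, votes are counted on exact string matches, and the validity check is a placeholder for the RDKit check mentioned above.</p>

```python
from collections import Counter


def committee_vote(predictions, is_valid=lambda smiles: True):
    """Return (winning SMILES, vote count, agreement fraction) for a committee.

    predictions: one predicted SMILES string per committee member.
    is_valid: hypothetical validity check (the paper uses RDKit for this step).
    """
    valid = [s for s in predictions if is_valid(s)]
    if not valid:
        return None, 0, 0.0
    # most_common orders equal counts by first occurrence in the input.
    winner, votes = Counter(valid).most_common(1)[0]
    return winner, votes, votes / len(predictions)


# Hypothetical outputs of a five-member committee for one drawing.
members = ["C1CCCCC1", "C1CCCCC1", "C1CCCC1", "C1CCCCC1", "C1CCCCC1"]
winner, votes, agreement = committee_vote(members)
print(winner, votes, agreement)
```

<p>The agreement fraction here is only the raw vote share. The confidence values reported in the paper are measured empirically per vote count, so the two should not be read as the same quantity.</p>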
<p><strong>Beam Search</strong><br>
Used during decoding with a beam width of $k=5$ to explore multiple candidate SMILES strings. It approximates the sequence $\mathbf{\hat{y}}$ that maximizes the joint probability, written here as a sum of per-token log-probabilities:</p>
<p>$$ \mathbf{\hat{y}} = \arg\max_{\mathbf{y}} \sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$</p>
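<p>A minimal beam search sketch under stated assumptions: <code>step_log_probs</code> is a hypothetical stand-in for the LSTM decoder, the end-of-sequence symbol <code>$</code> is arbitrary, and the toy distribution exists only to make the example runnable.</p>

```python
import math


def beam_search(step_log_probs, beam_width=5, max_len=10, eos="$"):
    """Keep the beam_width highest-scoring prefixes at each decoding step.

    step_log_probs(prefix) is a hypothetical decoder interface: it returns
    a dict mapping each candidate next token to its log-probability.
    """
    beams = [((), 0.0)]  # (token tuple, summed log-probability)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == eos:
                candidates.append((tokens, score))  # finished: carry forward
                continue
            for token, logp in step_log_probs(tokens).items():
                candidates.append((tokens + (token,), score + logp))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(tokens[-1] == eos for tokens, _ in beams):
            break
    return beams


def toy_decoder(prefix):
    # Toy model: two free choices between "C" and "=", then a forced stop.
    if len(prefix) >= 2:
        return {"$": 0.0}
    return {"C": math.log(0.6), "=": math.log(0.4)}


best_tokens, best_score = beam_search(toy_decoder, beam_width=5)[0]
print("".join(best_tokens), math.exp(best_score))
```

<p>Each beam's score is the running sum of log-probabilities from the equation above, so sorting by score ranks prefixes by their joint probability.</p>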
<p><strong>Optimization</strong>:</p>
<ul>
<li>
<p><strong>Optimizer</strong>: Adam</p>
</li>
<li>
<p><strong>Learning Rate</strong>: $1 \times 10^{-4}$</p>
</li>
<li>
<p><strong>Batch Size</strong>: 20</p>
</li>
<li>
<p><strong>Loss Function</strong>: Cross-entropy loss across the sequence of $T$ tokens, computed as:</p>
<p>$$ \mathcal{L} = -\sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$</p>
<p>where $\mathbf{x}$ is the image representation and $y_t$ is the ground-truth SMILES character at step $t$. For validation, the loss is reported as perplexity.</p>
</li>
</ul>
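<p>The relationship between the sequence loss and perplexity can be illustrated with a short calculation. The per-token normalization below is the conventional definition of perplexity and is an assumption about how it is computed here; the probabilities are made up for the example.</p>

```python
import math


def sequence_cross_entropy(token_probs):
    """Negative log-likelihood summed over the T ground-truth tokens."""
    return -sum(math.log(p) for p in token_probs)


def perplexity(token_probs):
    """Exponential of the mean per-token cross-entropy (conventional form)."""
    return math.exp(sequence_cross_entropy(token_probs) / len(token_probs))


# Hypothetical probabilities the model assigns to each ground-truth token.
probs = [0.5, 0.25, 0.5, 0.25]
print(sequence_cross_entropy(probs))
print(perplexity(probs))
```

<p>A perplexity of 1 would mean the model assigns probability 1 to every correct token; larger values indicate more uncertainty per token.</p>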
<h3 id="models">Models</h3>
<p>The architecture is a standard image captioning model (Show, Attend and Tell style) adapted for chemical structures.</p>
<p><strong>Encoder (CNN)</strong>:</p>
<ul>
<li><strong>Input</strong>: 256x256 pixel PNG images</li>
<li><strong>Structure</strong>: 4 blocks of Conv2D + MaxPool
<ul>
<li>Block 1: 64 filters, (3,3) kernel</li>
<li>Block 2: 128 filters, (3,3) kernel</li>
<li>Block 3: 256 filters, (3,3) kernel</li>
<li>Block 4: 512 filters, (3,3) kernel</li>
</ul>
</li>
<li><strong>Activation</strong>: ReLU throughout</li>
</ul>
<p><strong>Decoder (LSTM)</strong>:</p>
<ul>
<li><strong>Hidden Units</strong>: 512</li>
<li><strong>Embedding Dimension</strong>: 80</li>
<li><strong>Attention</strong>: Attention mechanism with an intermediate vector dimension of 512</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Primary Metric</strong>: Exact SMILES match accuracy (character-by-character identity between predicted and ground truth SMILES)</li>
<li><strong>Perplexity</strong>: Used to select model checkpoints (lower validation perplexity is better)</li>
<li><strong>Top-k Accuracy</strong>: Reported for $k=1$ (76%) and $k=3$ (85.5%)</li>
</ul>
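<p>Exact-match top-k accuracy can be computed as in the sketch below. The function name and the example strings are hypothetical, and the comparison is a plain string match as described above.</p>

```python
def top_k_accuracy(ranked_predictions, targets, k):
    """Fraction of samples whose target string is among the first k predictions."""
    hits = sum(
        1
        for ranked, target in zip(ranked_predictions, targets)
        if target in ranked[:k]
    )
    return hits / len(targets)


# Hypothetical ranked candidates for three drawings.
ranked = [
    ["CCC", "CC=C", "C1CC1"],
    ["C=C", "CC", "C#C"],
    ["C1CCCC1", "C1CCCCC1", "CCCCC"],
]
truth = ["CCC", "CC", "CCCCCC"]
print(top_k_accuracy(ranked, truth, k=1))
print(top_k_accuracy(ranked, truth, k=3))
```

<p>In this toy case only the first drawing is correct at $k=1$, while the second is recovered at $k=3$, which is the same effect that separates the reported top-1 and top-3 figures.</p>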
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mtzgroup/ChemPixCH">ChemPixCH</a></td>
          <td>Code + Dataset</td>
          <td>Apache-2.0</td>
          <td>Official implementation with synthetic data generation pipeline and collected hand-drawn dataset</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Weir, H., Thompson, K., Woodward, A., Choi, B., Braun, A., &amp; Martínez, T. J. (2021). ChemPix: Automated Recognition of Hand-Drawn Hydrocarbon Structures Using Deep Learning. <em>Chemical Science</em>, 12(31), 10622-10633. <a href="https://doi.org/10.1039/D1SC02957F">https://doi.org/10.1039/D1SC02957F</a></p>
<p><strong>Publication</strong>: Chemical Science 2021</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/mtzgroup/ChemPixCH">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{weir2021chempix,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemPix: Automated Recognition of Hand-Drawn Hydrocarbon Structures Using Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Weir, Hayley and Thompson, Keiran and Woodward, Amelia and Choi, Benjamin and Braun, Augustin and Mart{\&#39;i}nez, Todd J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Chemical Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{31}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{10622--10633}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Royal Society of Chemistry}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1039/D1SC02957F}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ABC-Net: Keypoint-Based Molecular Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/abc-net/</link><pubDate>Thu, 18 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/abc-net/</guid><description>Deep learning OCSR model using keypoint estimation to detect atom and bond centers for graph-based molecular structure recognition.</description><content:encoded><![CDATA[<h2 id="contribution-and-paper-type">Contribution and Paper Type</h2>
<p><strong>Method</strong>. The paper proposes a novel architectural framework (ABC-Net) for Optical Chemical Structure Recognition (OCSR). It reformulates the problem from image captioning (sequence generation) to keypoint estimation (pixel-wise detection), backed by ablation studies on noise and comparative benchmarks against state-of-the-art tools.</p>
<h2 id="motivation-for-keypoint-based-ocsr">Motivation for Keypoint-Based OCSR</h2>
<ul>
<li><strong>Inefficiency of Rule-Based Methods</strong>: Traditional tools (OSRA, MolVec) rely on hand-coded rules that are brittle, require domain expertise, and fail to handle the wide variance in molecular drawing styles.</li>
<li><strong>Data Inefficiency of Captioning Models</strong>: Recent Deep Learning approaches (like DECIMER, Img2mol) treat OCSR as image captioning (Image-to-SMILES). This is data-inefficient because canonical SMILES require learning traversal orders, necessitating millions of training examples.</li>
<li><strong>Goal</strong>: To create a scalable, data-efficient model that predicts graph structures directly by detecting atomic/bond primitives.</li>
</ul>
<h2 id="abc-nets-divide-and-conquer-architecture">ABC-Net&rsquo;s Divide-and-Conquer Architecture</h2>
<ul>
<li><strong>Divide-and-Conquer Strategy</strong>: ABC-Net breaks the problem down into detecting <strong>atom centers</strong> and <strong>bond centers</strong> as independent keypoints.</li>
<li><strong>Keypoint Estimation</strong>: A Fully Convolutional Network (FCN) generates heatmaps for object centers. This is inspired by computer vision techniques like CornerNet and CenterNet.</li>
<li><strong>Angle-Based Bond Detection</strong>: To handle overlapping bonds, the model classifies bond angles into 60 distinct bins (spanning $0^\circ$ to $360^\circ$) at detected bond centers, allowing separation of intersecting bonds.</li>
<li><strong>Implicit Hydrogen Prediction</strong>: The model explicitly predicts the number of implicit hydrogens for heterocyclic atoms to resolve ambiguity in dearomatization.</li>
</ul>
<h2 id="experimental-setup-and-synthetic-data">Experimental Setup and Synthetic Data</h2>
<ul>
<li><strong>Dataset Construction</strong>: Synthetic dataset of 100,000 molecules from ChEMBL, rendered using two different engines (RDKit and Indigo) to ensure style diversity.</li>
<li><strong>Baselines</strong>: Compared against two rule-based methods (MolVec, OSRA) and one deep learning method (Img2mol).</li>
<li><strong>Robustness Testing</strong>: Evaluated on the external UOB dataset (real-world images) and synthetic images with varying levels of salt-and-pepper noise (up to $p=0.6$).</li>
<li><strong>Data Efficiency</strong>: Analyzed performance scaling with training set size (10k to 160k images).</li>
</ul>
<h2 id="results-generalization-and-noise-robustness">Results, Generalization, and Noise Robustness</h2>
<ul>
<li><strong>Superior Accuracy</strong>: ABC-Net achieved <strong>94-98% accuracy</strong> across all test sets (Table 1), outperforming MolVec (12-45% on synthetic data, ~83% on UOB), OSRA (26-62% on synthetic, ~82% on UOB), and Img2mol (78-93% on non-stereo subsets).</li>
<li><strong>Generalization</strong>: On the external UOB benchmark, ABC-Net achieved <strong>&gt;95% accuracy</strong>, whereas the deep learning baseline (Img2mol) dropped to 78.2%, indicating better generalization.</li>
<li><strong>Data Efficiency</strong>: The model reached ~95% performance with only 80,000 training images, roughly two orders of magnitude less data than captioning-based models like Img2mol (which use millions of training examples).</li>
<li><strong>Noise Robustness</strong>: Performance remained stable (&lt;2% drop) with noise levels up to $p=0.1$. Even at extreme noise ($p=0.6$), Tanimoto similarity remained high, suggesting the model recovers most substructures even when exact matches fail.</li>
</ul>
<h2 id="limitations">Limitations</h2>
<ul>
<li><strong>Drawing style coverage</strong>: The synthetic training data covers only styles available through RDKit and Indigo renderers. Many real-world styles (e.g., hand-drawn structures, atomic group abbreviations) are not represented.</li>
<li><strong>No stereo baseline from Img2mol</strong>: The Img2mol comparison only covers non-stereo subsets because stereo results were not available from the original Img2mol paper.</li>
<li><strong>Scalability to large molecules</strong>: Molecules with more than 50 non-hydrogen atoms are excluded from the dataset, and performance on such large structures is untested.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/zhang-xuan1314/ABC-Net">ABC-Net Repository</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Apache-2.0</td>
          <td style="text-align: left">Official implementation. Missing requirements.txt and pre-trained weights.</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Status: Partially Reproducible</strong>. The code is provided, but key components like the pre-trained weights, exact training environment dependencies, and the generated synthetic datasets are missing from the open-source release, making exact reproduction difficult.</p>
<h3 id="data">Data</h3>
<p>The authors constructed a synthetic dataset because labeled pixel-wise OCSR data is unavailable.</p>
<ul>
<li><strong>Source</strong>: ChEMBL database</li>
<li><strong>Filtering</strong>: Excluded molecules with &gt;50 non-H atoms or rare atom types/charges (&lt;1000 occurrences).</li>
<li><strong>Sampling</strong>: 100,000 unique SMILES selected such that every atom type/charge appears in at least 1,000 compounds.</li>
<li><strong>Generation</strong>: Images generated via <strong>RDKit</strong> and <strong>Indigo</strong> libraries.
<ul>
<li><em>Augmentation</em>: Varied bond thickness, label mode, orientation, and aromaticity markers.</li>
<li><em>Resolution</em>: $512 \times 512$ pixels.</li>
<li><em>Noise</em>: Salt-and-pepper noise added during training ($P$ = probability of a background flip, $Q = 50P$).</li>
</ul>
</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>ChEMBL (RDKit/Indigo)</td>
          <td>80k</td>
          <td>8:1:1 split (Train/Val/Test)</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>UOB Dataset</td>
          <td>~5.7k images</td>
          <td>External benchmark from Univ. of Birmingham</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Keypoint Detection (Heatmaps)</strong></p>
<ul>
<li>
<p><strong>Down-sampling</strong>: Input $512 \times 512$ → Output $128 \times 128$ (stride 4).</p>
</li>
<li>
<p><strong>Label Softening</strong>: To handle discretization error, ground truth peaks are set to 1, first-order neighbors to 0.95, others to 0.</p>
</li>
<li>
<p><strong>Loss Function</strong>: Penalty-reduced pixel-wise binary focal loss (a variant of the CornerNet loss), given by:</p>
<p>$$ L_{det} = - \frac{1}{N} \sum_{x,y} \begin{cases} (1 - \hat{A}_{x,y})^{\alpha} \log(\hat{A}_{x,y}) &amp; \text{if } A_{x,y} = 1 \\ (1 - A_{x,y}) (\hat{A}_{x,y})^{\alpha} \log(1 - \hat{A}_{x,y}) &amp; \text{otherwise} \end{cases} $$</p>
<ul>
<li>$\alpha=2$ (focal parameter). The $(1 - A_{x,y})$ term reduces the penalty for first-order neighbors of ground truth locations.</li>
<li>Property classification losses use a separate focal parameter $\beta=2$ with weight balancing: classes with &lt;10% frequency are weighted 10x.</li>
</ul>
</li>
</ul>
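<p>The label softening and detection loss above can be sketched in plain Python. Several details are assumptions made for illustration: first-order neighbors are taken to be the 8 surrounding pixels, $N$ is taken to be the number of ground-truth keypoints (as in CornerNet), and predictions are clipped to keep the logarithm finite.</p>

```python
import math


def soften_labels(height, width, centers, neighbor_value=0.95):
    """Target heatmap: 1 at keypoints, neighbor_value at adjacent pixels, else 0."""
    target = [[0.0] * width for _ in range(height)]
    for cy, cx in centers:
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                y, x = cy + dy, cx + dx
                if 0 <= y < height and 0 <= x < width:
                    target[y][x] = max(target[y][x], neighbor_value)
    for cy, cx in centers:
        target[cy][cx] = 1.0
    return target


def detection_loss(pred, target, alpha=2.0, eps=1e-12):
    """Penalty-reduced focal loss, normalized by the number of keypoints."""
    total, n_keypoints = 0.0, 0
    for pred_row, target_row in zip(pred, target):
        for p, a in zip(pred_row, target_row):
            p = min(max(p, eps), 1.0 - eps)  # keep log() finite
            if a == 1.0:
                total += (1.0 - p) ** alpha * math.log(p)
                n_keypoints += 1
            else:
                # (1 - a) shrinks the penalty near a true keypoint.
                total += (1.0 - a) * p ** alpha * math.log(1.0 - p)
    return -total / max(n_keypoints, 1)


target = soften_labels(3, 3, [(1, 1)])
good = [[0.01, 0.01, 0.01], [0.01, 0.99, 0.01], [0.01, 0.01, 0.01]]
flat = [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
print(detection_loss(good, target), detection_loss(flat, target))
```

<p>A sharp, correctly placed peak yields a much smaller loss than a flat heatmap, and a false positive next to a true keypoint costs only 5% of what it would cost far from any keypoint.</p>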
<p><strong>2. Bond Direction Classification</strong></p>
<ul>
<li><strong>Angle Binning</strong>: $360^\circ$ divided into 60 intervals of $6^\circ$ each.</li>
<li><strong>Inference</strong>: A bond is detected if the angle probability is a local maximum and exceeds a threshold.</li>
<li><strong>Non-Maximum Suppression (NMS)</strong>: Required for opposite angles (e.g., $30^\circ$ and $210^\circ$) representing the same non-stereo bond.</li>
</ul>
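<p>The three steps above can be sketched as follows. This is an illustration rather than the ABC-Net code: the function names, the threshold of 0.5, and the rule of suppressing only the exactly opposite bin are all assumptions.</p>

```python
def angle_to_bin(angle_deg, n_bins=60):
    """Map an angle in degrees to one of n_bins equal-width intervals."""
    return int((angle_deg % 360.0) // (360.0 / n_bins))


def detect_bond_bins(probs, threshold=0.5):
    """Return bins that are circular local maxima above the threshold."""
    n = len(probs)
    return [
        i
        for i in range(n)
        if probs[i] > threshold
        and probs[i] >= probs[(i - 1) % n]
        and probs[i] >= probs[(i + 1) % n]
    ]


def suppress_opposites(bins, probs):
    """Keep only the stronger of two detections that lie 180 degrees apart."""
    n = len(probs)
    kept = []
    for i in sorted(bins, key=lambda b: probs[b], reverse=True):
        if (i + n // 2) % n not in kept:
            kept.append(i)
    return sorted(kept)


# Hypothetical angle probabilities at one detected bond center.
probs = [0.0] * 60
probs[angle_to_bin(30)] = 0.9
probs[angle_to_bin(210)] = 0.8  # opposite direction of the same bond
probs[angle_to_bin(90)] = 0.7   # a second, crossing bond
detected = detect_bond_bins(probs)
print(detected, suppress_opposites(detected, probs))
```

<p>With 60 bins, opposite directions are always 30 bins apart, which is what makes the suppression step a simple index lookup.</p>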
<p><strong>3. Multi-Task Weighting</strong></p>
<ul>
<li>Uses Kendall&rsquo;s uncertainty weighting to balance 8 different loss terms (atom det, bond det, atom type, charge, H-count, bond angle, bond type, bond length).</li>
</ul>
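<p>One widely used form of uncertainty weighting combines task losses $L_i$ as $\sum_i \left( \tfrac{1}{2} e^{-s_i} L_i + \tfrac{1}{2} s_i \right)$ with a learnable $s_i = \log \sigma_i^2$ per task. The sketch below shows that general form with fixed values; whether ABC-Net uses exactly this variant is not stated here, and the loss values are placeholders.</p>

```python
import math


def uncertainty_weighted_loss(task_losses, log_variances):
    """Combine task losses as sum(0.5 * exp(-s_i) * L_i + 0.5 * s_i).

    s_i = log(sigma_i ** 2) would be a learnable scalar per task; a large
    s_i down-weights its task but is itself penalized by the additive term.
    """
    return sum(
        0.5 * math.exp(-s) * loss + 0.5 * s
        for loss, s in zip(task_losses, log_variances)
    )


task_names = ["atom det", "bond det", "atom type", "charge",
              "H-count", "bond angle", "bond type", "bond length"]
losses = [1.0] * len(task_names)    # hypothetical per-task loss values
log_vars = [0.0] * len(task_names)  # s_i = 0 corresponds to sigma_i = 1
print(uncertainty_weighted_loss(losses, log_vars))
```

<p>In training, the $s_i$ values would be optimized jointly with the network weights, so tasks with noisier losses end up with smaller effective weights.</p>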
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: ABC-Net (Custom U-Net / FCN)</p>
<ul>
<li><strong>Input</strong>: $512 \times 512 \times 1$ (Grayscale).</li>
<li><strong>Contracting Path</strong>: 5 steps. Each step has conv-blocks + $2 \times 2$ MaxPool.</li>
<li><strong>Expansive Path</strong>: 3 steps. Transpose-Conv upsampling + Concatenation (Skip Connections).</li>
<li><strong>Heads</strong>: Separate $1 \times 1$ convs for each task map (Atom Heatmap, Bond Heatmap, Property Maps).</li>
<li><strong>Output Dimensions</strong>:
<ul>
<li>Heatmaps: $(1, 128, 128)$</li>
<li>Bond Angles: $(60, 128, 128)$</li>
</ul>
</li>
<li><strong>Pre-trained Weights</strong>: Not included in the public <a href="https://github.com/zhang-xuan1314/ABC-Net">GitHub repository</a>. The paper&rsquo;s availability statement mentions code and training datasets but not weights.</li>
</ul>
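<p>The listed path depths are consistent with the stride-4 output. A quick check, assuming every contracting step halves the spatial resolution and every expansive step doubles it (the helper name is hypothetical):</p>

```python
def output_size(input_size, down_steps, up_steps, factor=2):
    """Spatial size after down_steps reductions and up_steps expansions."""
    size = input_size
    for _ in range(down_steps):
        size //= factor
    for _ in range(up_steps):
        size *= factor
    return size


side = output_size(512, down_steps=5, up_steps=3)
print(side, 512 // side)  # prints "128 4"
```

<p>Five poolings take $512$ down to $16$, and three upsampling steps bring it back to $128$, matching the heatmap dimensions listed above.</p>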
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Detection</strong>: Precision &amp; Recall (Object detection level).</li>
<li><strong>Regression</strong>: Mean Absolute Error (MAE) for bond lengths.</li>
<li><strong>Structure Recovery</strong>:
<ul>
<li><em>Accuracy</em>: Exact SMILES match rate.</li>
<li><em>Tanimoto</em>: ECFP similarity (fingerprint overlap).</li>
</ul>
</li>
</ul>
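<p>Tanimoto similarity is the ratio of shared to total on-bits between two fingerprints. The sketch below uses plain sets of bit indices; in practice the ECFP bits would come from a cheminformatics toolkit, and the indices shown are invented for the example.</p>

```python
def tanimoto(bits_a, bits_b):
    """Tanimoto similarity between two collections of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    if not a and not b:
        return 1.0  # convention chosen here for two empty fingerprints
    return len(a.intersection(b)) / len(a.union(b))


# Hypothetical fingerprints for a predicted and a reference molecule.
predicted_bits = {1, 5, 9, 12}
reference_bits = {1, 5, 9, 20}
print(tanimoto(predicted_bits, reference_bits))  # prints 0.6
```

<p>This is why Tanimoto can stay high when exact-match accuracy drops: a prediction with one wrong substituent still shares most of its substructure bits with the reference.</p>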
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>ABC-Net</th>
          <th>Img2mol (Baseline)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (UOB)</strong></td>
          <td><strong>96.1%</strong></td>
          <td>78.2%</td>
          <td>Non-stereo subset</td>
      </tr>
      <tr>
          <td><strong>Accuracy (Indigo)</strong></td>
          <td><strong>96.4%</strong></td>
          <td>89.5%</td>
          <td>Non-stereo subset</td>
      </tr>
      <tr>
          <td><strong>Tanimoto (UOB)</strong></td>
          <td><strong>0.989</strong></td>
          <td>0.953</td>
          <td>Higher substructure recovery</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Configuration</strong>: 15 epochs, Batch size 64.</li>
<li><strong>Optimization</strong>: Adam Optimizer. LR $2.5 \times 10^{-4}$ (first 5 epochs) → $2.5 \times 10^{-5}$ (last 10).</li>
<li><strong>Repetition</strong>: Every experiment was repeated 3 times with random dataset splitting; mean values are reported.</li>
<li><strong>Compute</strong>: High-Performance Computing Center of Central South University. Specific GPU model not listed.</li>
</ul>
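<p>The two-stage learning rate schedule can be written as a simple step function. The zero-based epoch index and the function name are assumptions made for this sketch.</p>

```python
def learning_rate(epoch, base_lr=2.5e-4, decay_epoch=5, decay_factor=0.1):
    """Step schedule: base_lr for the first decay_epoch epochs, then 10x lower."""
    return base_lr if epoch < decay_epoch else base_lr * decay_factor


# Zero-based epochs 0..14 cover the 15 training epochs.
schedule = [learning_rate(epoch) for epoch in range(15)]
print(schedule[0], schedule[4], schedule[5], schedule[14])
```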
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Zhang, X.-C., Yi, J.-C., Yang, G.-P., Wu, C.-K., Hou, T.-J., &amp; Cao, D.-S. (2022). ABC-Net: A divide-and-conquer based deep learning architecture for SMILES recognition from molecular images. <em>Briefings in Bioinformatics</em>, 23(2), bbac033. <a href="https://doi.org/10.1093/bib/bbac033">https://doi.org/10.1093/bib/bbac033</a></p>
<p><strong>Publication</strong>: Briefings in Bioinformatics 2022</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/zhang-xuan1314/ABC-Net">GitHub Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{zhangABCNetDivideandconquerBased2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{ABC-Net: A Divide-and-Conquer Based Deep Learning Architecture for {SMILES} Recognition from Molecular Images}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Zhang, Xiao-Chen and Yi, Jia-Cai and Yang, Guo-Ping and Wu, Cheng-Kun and Hou, Ting-Jun and Cao, Dong-Sheng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Briefings in Bioinformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{bbac033}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Oxford University Press}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1093/bib/bbac033}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Img2Mol: Accurate SMILES Recognition from Depictions</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2mol/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2mol/</guid><description>Two-stage CNN approach for converting molecular images to SMILES using CDDD embeddings and extensive data augmentation.</description><content:encoded><![CDATA[<h2 id="method-classification">Method Classification</h2>
<p>This is a <strong>method paper</strong> that introduces Img2Mol, a deep learning system for Optical Chemical Structure Recognition (OCSR). The work focuses on building a fast, accurate, and robust system for converting molecular structure depictions into machine-readable SMILES strings.</p>
<h2 id="systematization-and-motivation">Systematization and Motivation</h2>
<p>Vast amounts of chemical knowledge exist only as images in scientific literature and patents, making this data inaccessible for computational analysis, database searches, or machine learning pipelines. Manually extracting this information is slow and error-prone, creating a bottleneck for drug discovery and chemical research.</p>
<p>While rule-based OCSR systems like OSRA, MolVec, and Imago exist, they are brittle. Small variations in drawing style or image quality can cause them to fail. The authors argue that a deep learning approach, trained on diverse synthetic data, can generalize better across different depiction styles and handle the messiness of real-world images more reliably.</p>
<h2 id="two-stage-architecture-and-core-novelty">Two-Stage Architecture and Core Novelty</h2>
<p>The novelty lies in a two-stage architecture that separates perception from decoding, combined with aggressive data augmentation to ensure robustness. The key contributions are:</p>
<p><strong>1. Two-Stage Architecture with CDDD Embeddings</strong></p>
<p>Img2Mol predicts SMILES from pixels through an intermediate representation. A <strong>custom CNN encoder</strong> maps the input image to a 512-dimensional <strong>Continuous and Data-Driven Molecular Descriptor (CDDD)</strong> embedding, a pre-trained, learned molecular representation that smoothly captures chemical similarity. A <strong>pre-trained decoder</strong> then converts this CDDD vector into the final canonical SMILES string.</p>
<p>This two-stage design has several advantages:</p>
<ul>
<li>The CDDD space is continuous and chemically meaningful, so nearby embeddings correspond to structurally similar molecules. This makes the regression task easier than learning discrete token sequences directly.</li>
<li>The decoder is pre-trained and fixed, so the CNN only needs to learn the image → CDDD mapping. This decouples the visual recognition problem from the sequence generation problem.</li>
<li>CDDD embeddings naturally enforce chemical validity constraints, reducing the risk of generating nonsensical structures.</li>
</ul>
<p><strong>2. Extensive Data Augmentation for Robustness</strong></p>
<p>The model was trained on 11.1 million unique molecules from ChEMBL and PubChem, but the critical insight is how the training images were generated. To expose the CNN to maximum variation in depiction styles, the authors:</p>
<ul>
<li>Used <strong>three different cheminformatics libraries</strong> (RDKit, OEChem, Indigo) to render images, each with its own drawing conventions</li>
<li>Applied <strong>wide-ranging augmentations</strong>: varying bond thickness, font size, rotation, resolution (originally 192-256 px; expanded to 190-2500 px in the final model), and other stylistic parameters</li>
<li><strong>Over-sampled larger molecules</strong> to improve performance on complex structures, which are underrepresented in chemical databases</li>
</ul>
<p>This ensures the network rarely sees the same depiction of a molecule twice, forcing it to learn invariant features.</p>
<p><strong>3. Fast Inference</strong></p>
<p>Because the architecture is a simple CNN followed by a fixed decoder, inference is very fast, especially compared to rule-based systems that rely on iterative graph construction algorithms. This makes Img2Mol practical for large-scale document mining.</p>
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The evaluation focused on demonstrating that Img2Mol is more accurate, robust, and generalizable than existing rule-based systems:</p>
<ol>
<li>
<p><strong>Benchmark Comparisons</strong>: Img2Mol was tested on several standard OCSR benchmarks, including USPTO (patent images), University of Birmingham (UoB), CLEF, and JPO (Japanese Patent Office) datasets, against three open-source baselines: <strong>OSRA, MolVec, and Imago</strong>. No deep learning baselines were available at the time for comparison.</p>
</li>
<li>
<p><strong>Resolution and Molecular Size Analysis</strong>: The initial model, <code>Img2Mol(no aug.)</code>, was evaluated across different image resolutions and molecule sizes (measured by number of atoms) to understand failure modes. This revealed that:</p>
<ul>
<li>Performance degraded for molecules with &gt;35 atoms</li>
<li>Very high-resolution images lost detail when downscaled to the fixed input size</li>
<li>Low-resolution images (where rule-based methods failed completely) were handled well</li>
</ul>
</li>
<li>
<p><strong>Data Augmentation Ablation</strong>: A final model, <strong>Img2Mol</strong>, was trained with the full augmentation pipeline (wider resolution range, over-sampling of large molecules). Performance was compared to the initial version to quantify the effect of augmentation.</p>
</li>
<li>
<p><strong>Depiction Library Robustness</strong>: The model was tested on images generated by each of the three rendering libraries separately to confirm that training on diverse styles improved generalization.</p>
</li>
<li>
<p><strong>Input Perturbation for Benchmark Fairness</strong>: For the smaller benchmark datasets (USPTO, UoB, CLEF, JPO), the authors applied slight random rotation (within +/-5 degrees) and shearing to each image five times to detect potential overfitting of rule-based methods to well-known benchmarks.</p>
</li>
<li>
<p><strong>Generalization Tests</strong>: Img2Mol was evaluated on real-world patent images from the <strong>STAKER</strong> dataset, which were not synthetically generated. This tested whether the model could transfer from synthetic training data to real documents.</p>
</li>
<li>
<p><strong>Hand-Drawn Molecule Recognition</strong>: As an exploratory test, the authors evaluated performance on hand-drawn molecular structures, a task the model was never trained for, to see if the learned features could generalize to completely different visual styles.</p>
</li>
<li>
<p><strong>Speed Benchmarking</strong>: Inference time was measured and compared to rule-based baselines to demonstrate the practical efficiency of the approach.</p>
</li>
</ol>
<h2 id="results-conclusions-and-limitations">Results, Conclusions, and Limitations</h2>
<p>Key benchmark results from Table 1 of the paper (accuracy / Tanimoto similarity, in %):</p>
<table>
  <thead>
      <tr>
          <th>Benchmark</th>
          <th>Img2Mol</th>
          <th>MolVec 0.9.8</th>
          <th>Imago 2.0</th>
          <th>OSRA 2.1</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Img2Mol test set</td>
          <td>88.25 / 95.27</td>
          <td>2.59 / 13.03</td>
          <td>0.02 / 4.74</td>
          <td>2.59 / 13.03</td>
      </tr>
      <tr>
          <td>STAKER</td>
          <td>64.33 / 83.76</td>
          <td>5.32 / 31.78</td>
          <td>0.07 / 5.06</td>
          <td>5.23 / 26.98</td>
      </tr>
      <tr>
          <td>USPTO</td>
          <td>42.29 / 73.07</td>
          <td>30.68 / 65.50</td>
          <td>5.07 / 7.28</td>
          <td>6.37 / 44.21</td>
      </tr>
      <tr>
          <td>UoB</td>
          <td>78.18 / 88.51</td>
          <td>75.01 / 86.88</td>
          <td>5.12 / 7.19</td>
          <td>70.89 / 85.27</td>
      </tr>
      <tr>
          <td>CLEF</td>
          <td>48.84 / 78.04</td>
          <td>44.48 / 76.61</td>
          <td>26.72 / 41.29</td>
          <td>17.04 / 58.84</td>
      </tr>
      <tr>
          <td>JPO</td>
          <td>45.14 / 69.43</td>
          <td>49.48 / 66.46</td>
          <td>23.18 / 37.47</td>
          <td>33.04 / 49.62</td>
      </tr>
  </tbody>
</table>
<p>Per-library accuracy on a 5,000-compound subset (depicted five times each):</p>
<table>
  <thead>
      <tr>
          <th>Library</th>
          <th>Img2Mol</th>
          <th>MolVec</th>
          <th>Imago</th>
          <th>OSRA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RDKit</td>
          <td>93.4%</td>
          <td>3.7%</td>
          <td>0.3%</td>
          <td>4.4%</td>
      </tr>
      <tr>
          <td>OEChem</td>
          <td>89.5%</td>
          <td>33.4%</td>
          <td>12.3%</td>
          <td>26.3%</td>
      </tr>
      <tr>
          <td>Indigo</td>
          <td>79.0%</td>
          <td>22.2%</td>
          <td>4.2%</td>
          <td>22.6%</td>
      </tr>
  </tbody>
</table>
<ul>
<li>
<p><strong>Substantial Performance Gains</strong>: Img2Mol outperformed all three rule-based baselines on nearly every benchmark. The one exception was JPO, where MolVec achieved higher exact-match accuracy (49.48% vs. 45.14%), although Img2Mol still had the higher Tanimoto similarity (69.43% vs. 66.46%). Performance was measured both as exact SMILES match accuracy and as <strong>Tanimoto similarity</strong> (using ECFP6 1024-bit fingerprints). Even when Img2Mol did not predict the exact molecule, it often predicted a chemically similar one.</p>
</li>
<li>
<p><strong>Robustness Across Conditions</strong>: The full Img2Mol model (with aggressive augmentation) showed consistent performance across all image resolutions and molecule sizes. In contrast, rule-based systems were &ldquo;brittle&rdquo; - performance dropped sharply with minor perturbations to image quality or style.</p>
</li>
<li>
<p><strong>Depiction Library Invariance</strong>: Img2Mol&rsquo;s accuracy remained high across all three rendering libraries (93.4% on RDKit, 89.5% on OEChem, 79.0% on Indigo), supporting the multi-library training strategy. Rule-based methods struggled particularly with RDKit-generated images, where none exceeded 4.4% accuracy.</p>
</li>
<li>
<p><strong>Strong Generalization to Real-World Data</strong>: Despite being trained exclusively on synthetic images, Img2Mol performed well on real patent images from the STAKER dataset. This suggests the augmentation strategy successfully captured the diversity of real-world depictions.</p>
</li>
<li>
<p><strong>Overfitting in Baselines</strong>: Rule-based methods were far more competitive on older benchmarks (USPTO, UoB, CLEF; MolVec, for example, reached 75.01% accuracy on UoB) than on newer datasets (Img2Mol&rsquo;s test set, STAKER), where no baseline exceeded 5.32% accuracy. This suggests they may be implicitly tuned to specific drawing conventions in legacy datasets.</p>
</li>
<li>
<p><strong>Limited Hand-Drawn Recognition</strong>: Img2Mol could recognize simple hand-drawn structures but struggled with complex or large molecules. This is unsurprising given the lack of hand-drawn data in training, but it highlights a potential avenue for future work.</p>
</li>
<li>
<p><strong>Speed Advantage</strong>: Img2Mol processed 5,000 images in approximately 4 minutes at the smallest input size, with compute time mostly independent of input resolution due to the fixed 224x224 rescaling. Rule-based methods showed sharply increasing compute times at higher resolutions.</p>
</li>
</ul>
<p>The work establishes that deep learning can outperform traditional rule-based OCSR systems when combined with a principled two-stage architecture and comprehensive data augmentation. The CDDD embedding acts as a bridge between visual perception and chemical structure, providing a chemically meaningful intermediate representation that improves both accuracy and robustness. The focus on synthetic data diversity proves to be an effective strategy for generalizing to real-world documents.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Custom 8-layer Convolutional Neural Network (CNN) encoder</p>
<ul>
<li><strong>Input</strong>: $224 \times 224$ pixel grayscale images</li>
<li><strong>Backbone Structure</strong>: 8 convolutional layers organized into 3 stacks, followed by 3 fully connected layers
<ul>
<li><strong>Stack 1</strong>: 3 Conv layers ($7 \times 7$ filters, stride 3, padding 4) + Max Pooling</li>
<li><strong>Stack 2</strong>: 2 Conv layers + Max Pooling</li>
<li><strong>Stack 3</strong>: 3 Conv layers + Max Pooling</li>
<li><strong>Head</strong>: 3 fully connected layers</li>
</ul>
</li>
<li><strong>Output</strong>: 512-dimensional CDDD embedding vector</li>
</ul>
<p><strong>Decoder</strong>: Pre-trained CDDD decoder (from Winter et al.) - fixed during training, not updated</p>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>: Mean Squared Error (MSE) regression minimizing the distance between the predicted and true embeddings:</p>
<p>$$
l(d) = l(\text{cddd}_{\text{true}} - \text{cddd}_{\text{predicted}})
$$</p>
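<p>A minimal sketch of this objective, assuming the loss is the plain mean of squared per-dimension differences between the two embeddings. The function name <code>mse_loss</code> and the toy vectors are illustrative and do not come from the paper or its repository:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from typing import Sequence


def mse_loss(cddd_true: Sequence[float], cddd_predicted: Sequence[float]) -> float:
    """Mean squared error between a target and a predicted embedding."""
    if len(cddd_true) != len(cddd_predicted):
        raise ValueError("embeddings must have the same dimensionality")
    # d is the element-wise difference between target and prediction.
    d = [t - p for t, p in zip(cddd_true, cddd_predicted)]
    return sum(x * x for x in d) / len(d)


# Toy 4-dimensional example; the real CDDD embedding has 512 dimensions.
target = [0.0, 1.0, -1.0, 0.5]
prediction = [0.0, 0.5, -1.0, 1.0]
loss = mse_loss(target, prediction)
</code></pre></div>
<p>In an actual training loop this would be computed per batch by the deep learning framework; the pure-Python version only makes the arithmetic explicit.</p>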
<p><strong>Optimizer</strong>: AdamW with initial learning rate $10^{-4}$</p>
<p><strong>Training Schedule</strong>:</p>
<ul>
<li>Batch size: 256</li>
<li>Training duration: 300 epochs</li>
<li>Plateau scheduler: Multiplies learning rate by 0.7 if validation loss plateaus for 10 epochs</li>
<li>Early stopping: Triggered if no improvement in validation loss for 50 epochs</li>
</ul>
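<p>The interaction between the plateau scheduler and early stopping can be sketched as follows. This is a hypothetical helper that replays a list of per-epoch validation losses; it assumes the learning rate is decayed on the 10th consecutive non-improving epoch and that training stops on the 50th. Library schedulers may count patience slightly differently, so treat it as an illustration of the schedule and not as the authors&rsquo; code:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">def run_schedule(val_losses, lr=1e-4, factor=0.7, lr_patience=10, stop_patience=50):
    """Replay validation losses and return (final_lr, epochs_run).

    Hypothetical helper: names and exact patience semantics are assumptions.
    """
    best = float("inf")
    since_best = 0   # non-improving epochs since the best validation loss
    since_decay = 0  # non-improving epochs since the last improvement or decay
    epochs_run = 0
    for loss in val_losses:
        epochs_run += 1
        if best > loss:
            # Validation loss improved: reset both counters.
            best = loss
            since_best = 0
            since_decay = 0
            continue
        since_best += 1
        since_decay += 1
        if since_decay >= lr_patience:
            lr *= factor  # plateau detected: shrink the learning rate
            since_decay = 0
        if since_best >= stop_patience:
            break  # early stopping
    return lr, epochs_run


# A loss curve that never improves after the first epoch.
final_lr, epochs_run = run_schedule([1.0] * 61)
</code></pre></div>
<p>The 300-epoch cap from the list above corresponds to passing at most 300 loss values.</p>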
<p><strong>Noise Tolerance</strong>: The decoder requires the CNN to predict embeddings with noise level $\sigma \le 0.15$ to achieve &gt;90% accuracy</p>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong>: 11.1 million unique molecules from ChEMBL and PubChem</p>
<p><strong>Splits</strong>: Approximately 50,000 examples each for validation and test sets</p>
<p><strong>Synthetic Image Generation</strong>:</p>
<ul>
<li>Three cheminformatics libraries: RDKit, OEChem, and Indigo</li>
<li>Augmentations: Resolution (190-2500 pixels), rotation, bond thickness, font size</li>
<li>Salt stripping: Keep only the largest fragment</li>
<li>Over-sampling: Larger molecules (&gt;35 atoms) over-sampled to improve performance</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li>Exact SMILES match accuracy</li>
<li>Tanimoto similarity (chemical fingerprint-based structural similarity)</li>
</ul>
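<p>For reference, Tanimoto similarity between two binary fingerprints is the size of their intersection divided by the size of their union. The sketch below assumes each fingerprint is represented as a set of on-bit indices; computing real ECFP6 fingerprints would require a cheminformatics toolkit, which is omitted here, and the function name is illustrative:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from typing import Set


def tanimoto(bits_a: Set[int], bits_b: Set[int]) -> float:
    """Tanimoto (Jaccard) similarity between two sets of on-bit indices."""
    union = len(bits_a.union(bits_b))
    if union == 0:
        # Convention chosen for this sketch: two empty fingerprints are identical.
        return 1.0
    return len(bits_a.intersection(bits_b)) / union


# Toy fingerprints: indices of bits set to 1 in a 1024-bit vector.
fp_true = {1, 5, 9, 100}
fp_pred = {1, 5, 200}
similarity = tanimoto(fp_true, fp_pred)
</code></pre></div>
<p>A prediction that matches the target exactly scores 1.0, so this metric gives partial credit to near-misses that exact SMILES matching counts as failures.</p>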
<p><strong>Benchmarks</strong>:</p>
<ul>
<li>Img2Mol test set (25,000 synthetic images at 224x224 px)</li>
<li>STAKER (30,000 real-world USPTO patent images at 256x256 px)</li>
<li>USPTO (4,852 patent images, avg. 649x417 px)</li>
<li>UoB (5,716 images from University of Birmingham, avg. 762x412 px)</li>
<li>CLEF (711 images, avg. 1243x392 px)</li>
<li>JPO (365 Japanese Patent Office images, avg. 607x373 px)</li>
<li>Hand-drawn molecular structures (exploratory, no defined benchmark)</li>
</ul>
<p><strong>Baselines</strong>: OSRA, MolVec, Imago (rule-based systems)</p>
<h3 id="hardware">Hardware</h3>
<p>⚠️ <strong>Unspecified in paper or supplementary materials.</strong> Inference speed reported as ~4 minutes for 5000 images; training hardware (GPU model, count) is undocumented.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/bayer-science-for-a-better-life/Img2Mol">Img2Mol GitHub</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official implementation</td>
      </tr>
      <tr>
          <td><a href="https://github.com/bayer-science-for-a-better-life/Img2Mol">Img2Mol model weights</a></td>
          <td>Model</td>
          <td>CC BY-NC 4.0</td>
          <td>Non-commercial use only</td>
      </tr>
  </tbody>
</table>
<h3 id="known-limitations">Known Limitations</h3>
<p><strong>Molecular Size</strong>: Performance degrades for molecules with &gt;35 atoms. This is partly a property of the CDDD latent space itself: for larger molecules, the &ldquo;volume of decodable latent space&rdquo; shrinks, making the decoder more sensitive to small noise perturbations in the predicted embedding.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Clevert, D.-A., Le, T., Winter, R., &amp; Montanari, F. (2021). Img2Mol &ndash; accurate SMILES recognition from molecular graphical depictions. <em>Chemical Science</em>, 12(42), 14174&ndash;14181. <a href="https://doi.org/10.1039/d1sc01839f">https://doi.org/10.1039/d1sc01839f</a></p>
<p><strong>Publication</strong>: Chemical Science (2021)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/bayer-science-for-a-better-life/Img2Mol">GitHub Repository</a></li>
<li><a href="https://doi.org/10.1039/d1sc01839f">Paper on Royal Society of Chemistry</a></li>
</ul>
]]></content:encoded></item><item><title>Handwritten Chemical Ring Recognition with Neural Networks</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hewahi-ring-recognition-2008/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/hand-drawn/hewahi-ring-recognition-2008/</guid><description>A two-phase Classifier-Recognizer neural network pipeline for recognizing 23 types of handwritten heterocyclic chemical rings, achieving ~94% accuracy.</description><content:encoded><![CDATA[<h2 id="contribution-recognition-architecture-for-heterocyclic-rings">Contribution: Recognition Architecture for Heterocyclic Rings</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$).</p>
<p>It proposes a specific algorithmic architecture (the &ldquo;Classifier-Recognizer Approach&rdquo;) to solve a pattern recognition problem. The rhetorical structure centers on defining three variations of a method, performing ablation-like comparisons between them (Whole Image vs. Lower Part), and demonstrating superior performance metrics (~94% accuracy) for the proposed technique.</p>
<h2 id="motivation-enabling-sketch-based-chemical-search">Motivation: Enabling Sketch-Based Chemical Search</h2>
<p>The authors identify a gap in existing OCR and handwriting recognition research, which typically focuses on alphanumeric characters or whole words.</p>
<ul>
<li><strong>Missing Capability</strong>: Recognition of specific <em>heterocyclic chemical rings</em> (23 types) had not been performed previously.</li>
<li><strong>Practical Utility</strong>: Existing chemical search engines require text-based queries (names); this work enables &ldquo;backward&rdquo; search where a user can draw a ring to find its information.</li>
<li><strong>Educational/Professional Aid</strong>: Useful for chemistry departments and mobile applications where chemists can sketch formulas on screens.</li>
</ul>
<h2 id="innovation-the-classifier-recognizer-pipeline">Innovation: The Classifier-Recognizer Pipeline</h2>
<p>The core novelty is the <strong>two-phase &ldquo;Classifier-Recognizer&rdquo; architecture</strong> designed to handle the visual similarity of heterocyclic rings:</p>
<ol>
<li><strong>Phase 1 (Classifier)</strong>: A neural network classifies the ring into one of four broad categories (S, N, O, Others) based solely on the <em>upper part</em> of the image (40x15 pixels).</li>
<li><strong>Phase 2 (Recognizer)</strong>: A class-specific neural network identifies the exact ring.</li>
<li><strong>Optimization</strong>: The most successful variation (&ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo;) uses only the <em>lower part</em> of the image and <em>odd rows</em> (half-grid) to reduce input dimensionality and computation time while improving accuracy. This effectively subsamples the input grid matrix $M \in \mathbb{R}^{H \times W}$ to a reduced matrix $M_{\text{sub}}$:
$$ M_{\text{sub}} = \{ m_{i,j} \in M \mid i \text{ is odd} \} $$</li>
</ol>
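<p>The half-grid subsampling can be sketched in a few lines. This assumes rows are numbered from 1, so the &ldquo;odd rows&rdquo; correspond to list indices 0, 2, 4, and so on; the paper&rsquo;s indexing convention is not stated in this summary, and the names below are illustrative:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from typing import List

Grid = List[List[int]]


def odd_rows(grid: Grid) -> Grid:
    """Keep rows 1, 3, 5, ... (1-indexed), i.e. list indices 0, 2, 4, ..."""
    return [row[:] for row in grid[::2]]


# A blank 40x40 binary grid stands in for a scaled ring drawing.
full = [[0] * 40 for _ in range(40)]
half = odd_rows(full)
n_inputs = len(half) * len(half[0])  # number of input nodes after subsampling
</code></pre></div>
<p>Applied to the full 40x40 grid, this halves the row count and therefore the number of network inputs.</p>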
<h2 id="failed-preliminary-approaches">Failed Preliminary Approaches</h2>
<p>Before arriving at the Classifier-Recognizer architecture, the authors tried three simpler methods that all failed:</p>
<ol>
<li><strong>Ordinary NN</strong>: A single neural network with 1600 inputs (40x40 grid), 1600 hidden units, and 23 outputs. This standard approach achieved only 7% accuracy.</li>
<li><strong>Row/Column pixel counts</strong>: Using the number of black pixels per row and per column as features ($N_c + N_r$ inputs), which dramatically reduced dimensionality. This performed even worse, below 1% accuracy.</li>
<li><strong>Midline crossing count</strong>: Drawing a horizontal midline and counting the number of line crossings. This failed because the crossing count varies between writers for the same ring.</li>
</ol>
<p>These failures motivated the two-phase Classifier-Recognizer design.</p>
<h2 id="experimental-setup-and-network-variations">Experimental Setup and Network Variations</h2>
<p>The authors conducted a comparative study of three methodological variations:</p>
<ol>
<li><strong>Whole Image Recognizer</strong>: Uses the full image.</li>
<li><strong>Whole Image (Half Size Grid)</strong>: Uses only odd rows ($20 \times 40$ pixels).</li>
<li><strong>Lower Part (Half Size Grid)</strong>: Uses the lower part of the image with odd rows (the proposed method).</li>
</ol>
<p><strong>Setup</strong>:</p>
<ul>
<li><strong>Dataset</strong>: 23 types of heterocyclic rings.</li>
<li><strong>Training</strong>: 1500 samples (distributed across S, N, O, and Others classes).</li>
<li><strong>Testing</strong>: 1150 samples.</li>
<li><strong>Metric</strong>: Recognition accuracy (Performance %) and Error %.</li>
</ul>
<h2 id="results-high-accuracy-via-dimension-reduction">Results: High Accuracy via Dimension Reduction</h2>
<ul>
<li><strong>Superior Method</strong>: The &ldquo;Lower Part Image Recognizer with Half Size Grid&rdquo; achieved the best performance (~94% overall).</li>
<li><strong>High Classifier Accuracy</strong>: The first phase (classification into S/N/O/Other) achieves 100% accuracy for class S, 98.67% for O, 97.75% for N, and 97.67% for Others (Table 3).</li>
<li><strong>Class &lsquo;Others&rsquo; Difficulty</strong>: The &lsquo;Others&rsquo; class showed lower performance (~90-93%) compared to S/N/O due to the higher complexity and similarity of rings in that category.</li>
<li><strong>Efficiency</strong>: The half-grid approach reduced training time from ~53 hours (Whole Image) to ~35 hours (Lower Part Half Size Grid) while improving accuracy from 87% to 94%.</li>
</ul>
<p><strong>Training/Testing comparison across the three Classifier-Recognizer variations (Table 2)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Method</th>
          <th style="text-align: left">Hidden Nodes</th>
          <th style="text-align: left">Iterations</th>
          <th style="text-align: left">Training Time (hrs)</th>
          <th style="text-align: left">Error</th>
          <th style="text-align: left">Performance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">Whole Image</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~53</td>
          <td style="text-align: left">13.0%</td>
          <td style="text-align: left">87.0%</td>
      </tr>
      <tr>
          <td style="text-align: left">Whole Image (Half Grid)</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~41</td>
          <td style="text-align: left">9.0%</td>
          <td style="text-align: left">91.0%</td>
      </tr>
      <tr>
          <td style="text-align: left">Lower Part (Half Grid)</td>
          <td style="text-align: left">50</td>
          <td style="text-align: left">1000</td>
          <td style="text-align: left">~35</td>
          <td style="text-align: left">6.0%</td>
          <td style="text-align: left">94.0%</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The dataset consists of handwritten samples of 23 specific heterocyclic rings.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Purpose</th>
          <th style="text-align: left">Dataset</th>
          <th style="text-align: left">Size</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>Training</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1500 samples</td>
          <td style="text-align: left">Split: 300 (S), 400 (N), 400 (O), 400 (Others)</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Testing</strong></td>
          <td style="text-align: left">Heterocyclic Rings</td>
          <td style="text-align: left">1150 samples</td>
          <td style="text-align: left">Split: 150 (S), 300 (O), 400 (N), 300 (Others)</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing Steps</strong>:</p>
<ol>
<li><strong>Monochrome Conversion</strong>: Convert image to monochrome bitmap.</li>
<li><strong>Grid Scaling</strong>: Convert drawing area (regardless of original size) to a fixed <strong>40x40</strong> grid.</li>
<li><strong>Bounding</strong>: Scale the ring shape itself to fit the 40x40 grid.</li>
</ol>
<h3 id="algorithms">Algorithms</h3>
<p><strong>The &ldquo;Lower Part with Half Size&rdquo; Pipeline</strong>:</p>
<ol>
<li><strong>Cut Point</strong>: A horizontal midline is defined; the algorithm separates the &ldquo;Upper Part&rdquo; and &ldquo;Lower Part&rdquo;.</li>
<li><strong>Phase 1 Input</strong>: The <strong>Upper Part</strong> (rows 0-15 approx, scaled) is fed to the Classifier NN to determine the class (S, N, O, or Others).</li>
<li><strong>Phase 2 Input</strong>:
<ul>
<li>For classes <strong>S, N, O</strong>: The <strong>Lower Part</strong> of the image is used.</li>
<li>For class <strong>Others</strong>: The <strong>Whole Ring</strong> is used.</li>
</ul>
</li>
<li><strong>Dimensionality Reduction</strong>: For the recognizer networks, only <strong>odd rows</strong> are used (effectively a 20x40 input grid) to reduce inputs from 1600 to 800.</li>
</ol>
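<p>The steps above can be sketched as a small dispatch function. The networks are replaced by stub callables, and the cut row of 15 and the application of the half-grid to the cropped region are assumptions made for illustration; the paper&rsquo;s exact cropping and scaling may differ:</p>
<div class="highlight"><pre tabindex="0"><code class="language-python" data-lang="python">from typing import Callable, Dict, List

Grid = List[List[int]]


def flatten(grid: Grid) -> List[int]:
    """Turn a 2D pixel grid into a flat input vector."""
    return [px for row in grid for px in row]


def recognize(
    grid: Grid,
    classifier: Callable[[List[int]], str],
    recognizers: Dict[str, Callable[[List[int]], str]],
    cut_row: int = 15,  # hypothetical position of the upper/lower cut
) -> str:
    """Two-phase recognition: choose a class, then a class-specific ring."""
    upper, lower = grid[:cut_row], grid[cut_row:]
    ring_class = classifier(flatten(upper))  # Phase 1: S, N, O, or Others
    # Phase 2 input: whole ring for Others, lower part for S, N, O.
    region = grid if ring_class == "Others" else lower
    half = region[::2]  # keep every other row (half-size grid)
    return recognizers[ring_class](flatten(half))


def stub_classifier(pixels: List[int]) -> str:
    # Placeholder for the trained Classifier NN.
    return "S" if sum(pixels) > 0 else "Others"


# Placeholder recognizers that report their class and input size.
stub_recognizers = {
    name: (lambda pixels, name=name: f"{name}:{len(pixels)}")
    for name in ("S", "N", "O", "Others")
}

blank = [[0] * 40 for _ in range(40)]
result = recognize(blank, stub_classifier, stub_recognizers)
</code></pre></div>
<p>Keeping one recognizer per class means each Phase 2 network only has to separate rings within its own category, which is the motivation the paper gives for the two-phase design.</p>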
<h3 id="models">Models</h3>
<p>The system uses multiple distinct Feed-Forward Neural Networks (Backpropagation is implied by &ldquo;training&rdquo; and &ldquo;epochs&rdquo; context, though not explicitly named as the algorithm):</p>
<ul>
<li><strong>Structure</strong>: 1 Classifier NN + 4 Recognizer NNs (one for each class).</li>
<li><strong>Hidden Layers</strong>: The preliminary &ldquo;ordinary method&rdquo; experiment used 1600 hidden units. The Classifier-Recognizer methods all used 50 hidden nodes per Table 2. The paper also notes that the ordinary approach tried various hidden layer sizes.</li>
<li><strong>Input Nodes</strong>:
<ul>
<li>Standard: 1600 (40x40).</li>
<li>Optimized: ~800 (20x40 via half-grid).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Classifier Phase Testing Results (Table 3)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Class</th>
          <th style="text-align: left">Samples</th>
          <th style="text-align: left">Correct</th>
          <th style="text-align: left">Accuracy</th>
          <th style="text-align: left">Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>S</strong></td>
          <td style="text-align: left">150</td>
          <td style="text-align: left">150</td>
          <td style="text-align: left"><strong>100.00%</strong></td>
          <td style="text-align: left">0.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>O</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">296</td>
          <td style="text-align: left"><strong>98.67%</strong></td>
          <td style="text-align: left">1.33%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>N</strong></td>
          <td style="text-align: left">400</td>
          <td style="text-align: left">391</td>
          <td style="text-align: left"><strong>97.75%</strong></td>
          <td style="text-align: left">2.25%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Others</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">293</td>
          <td style="text-align: left"><strong>97.67%</strong></td>
          <td style="text-align: left">2.33%</td>
      </tr>
  </tbody>
</table>
<p><strong>Recognizer Phase Testing Results (Lower Part Image Recognizer with Half Size Grid, Table 4)</strong>:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Class</th>
          <th style="text-align: left">Samples</th>
          <th style="text-align: left">Correct</th>
          <th style="text-align: left">Accuracy</th>
          <th style="text-align: left">Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><strong>S</strong></td>
          <td style="text-align: left">150</td>
          <td style="text-align: left">147</td>
          <td style="text-align: left"><strong>98.00%</strong></td>
          <td style="text-align: left">2.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>O</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">289</td>
          <td style="text-align: left"><strong>96.33%</strong></td>
          <td style="text-align: left">3.67%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>N</strong></td>
          <td style="text-align: left">400</td>
          <td style="text-align: left">386</td>
          <td style="text-align: left"><strong>96.50%</strong></td>
          <td style="text-align: left">3.50%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Others</strong></td>
          <td style="text-align: left">300</td>
          <td style="text-align: left">279</td>
          <td style="text-align: left"><strong>93.00%</strong></td>
          <td style="text-align: left">7.00%</td>
      </tr>
      <tr>
          <td style="text-align: left"><strong>Overall</strong></td>
          <td style="text-align: left"><strong>1150</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
          <td style="text-align: left"><strong>~94.0%</strong></td>
          <td style="text-align: left"><strong>-</strong></td>
      </tr>
  </tbody>
</table>
<h3 id="reproducibility-assessment">Reproducibility Assessment</h3>
<p>No source code, trained models, or datasets were released with this paper. The handwritten ring samples were collected by the authors, and the software described (a desktop application) is not publicly available. The neural network architecture details (50 hidden nodes, 1000 iterations) and preprocessing pipeline are described in sufficient detail for reimplementation, but reproducing results would require collecting a new handwritten dataset of heterocyclic rings.</p>
<p><strong>Status</strong>: Closed (no public code, data, or models).</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hewahi, N., Nounou, M. N., Nassar, M. S., Abu-Hamad, M. I., &amp; Abu-Hamad, H. I. (2008). Chemical Ring Handwritten Recognition Based on Neural Networks. <em>Ubiquitous Computing and Communication Journal</em>, 3(3).</p>
<p><strong>Publication</strong>: Ubiquitous Computing and Communication Journal 2008</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{hewahiCHEMICALRINGHANDWRITTEN2008,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{CHEMICAL RING HANDWRITTEN RECOGNITION BASED ON NEURAL NETWORKS}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Hewahi, Nabil and Nounou, Mohamed N and Nassar, Mohamed S and Abu-Hamad, Mohamed I and Abu-Hamad, Husam I}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2008}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Ubiquitous Computing and Communication Journal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{3}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Deep Learning for Molecular Structure Extraction (2019)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/staker-deep-learning-2019/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/staker-deep-learning-2019/</guid><description>An end-to-end deep learning approach using U-Net segmentation and a CNN encoder with GridLSTM decoder to predict chemical structures from document images.</description><content:encoded><![CDATA[<h2 id="contribution-type-method-and-resource">Contribution Type: Method and Resource</h2>
<p>This is primarily a <strong>methodological</strong> paper with a secondary <strong>resource</strong> contribution.</p>
<p><strong>Method</strong>: It proposes a novel end-to-end deep learning architecture (Segmentation U-Net + Recognition Encoder-Decoder) to replace traditional rule-based optical chemical structure recognition (OCSR) systems.</p>
<p><strong>Resource</strong>: It details a pipeline for generating large-scale synthetic datasets (images overlaying patent/journal backgrounds) necessary to train the deep learning models.</p>
<h2 id="motivation-overcoming-brittle-rule-based-systems">Motivation: Overcoming Brittle Rule-Based Systems</h2>
<p>Existing tools for extracting chemical structures from literature (e.g., OSRA, CLIDE) rely on complex, handcrafted rules and heuristics (edge detection, vectorization). These systems suffer from:</p>
<ol>
<li><strong>Brittleness</strong>: They fail when image quality is low (low resolution, noise) or when artistic styles vary (wavy bonds, crossing lines).</li>
<li><strong>Maintenance difficulty</strong>: Improvements require manual codification of new rules for every edge case, which is difficult to scale.</li>
<li><strong>Data volume</strong>: The volume of published life science papers (2000+ per day in Medline) exceeds what human curators can process, creating a need for automated, robust curation tools.</li>
</ol>
<h2 id="core-innovation-end-to-end-pixel-to-smiles-recognition">Core Innovation: End-to-End Pixel-to-SMILES Recognition</h2>
<p>The authors present an <strong>end-to-end deep learning approach</strong> for this task that operates directly on raw pixels without explicit subcomponent recognition (e.g., detecting atoms and bonds separately). Key innovations include:</p>
<ol>
<li><strong>Pixel-to-SMILES</strong>: Treating structure recognition as an image captioning problem using an encoder-decoder architecture with attention, generating SMILES directly.</li>
<li><strong>Low-Resolution Robustness</strong>: The model is trained on aggressively downsampled images (~60 dpi for segmentation, 256x256 for prediction), making it robust to poor quality and noisy inputs from legacy PDF extractions.</li>
<li><strong>Implicit Superatom Handling</strong>: The model learns to recognize and generate sequences for superatoms (e.g., &ldquo;OTBS&rdquo;) contextually.</li>
</ol>
<h2 id="experimental-setup-and-large-scale-synthetic-data">Experimental Setup and Large-Scale Synthetic Data</h2>
<p>The authors validated their approach using a mix of large-scale synthetic training sets and real-world test sets:</p>
<ol>
<li><strong>Synthetic Generation</strong>: They created a segmentation dataset by overlaying USPTO molecules onto &ldquo;whited-out&rdquo; journal pages.</li>
<li><strong>Ablation/Training</strong>: Metrics were tracked on Indigo (synthetic) and USPTO (real patent images) datasets.</li>
<li><strong>External Validation</strong>:
<ul>
<li><strong>Valko Dataset</strong>: A standard benchmark of 454 heterogeneous images from literature.</li>
<li><strong>Proprietary Dataset</strong>: A collection of images from 47 articles and 5 patents to simulate real-world drug discovery curation.</li>
</ul>
</li>
<li><strong>Stress Testing</strong>: They analyzed how performance varies with molecular weight, heavy atom count, and the presence of rare elements (e.g., uranium, vanadium).</li>
</ol>
<h2 id="results-and-limitations-in-complex-structures">Results and Limitations in Complex Structures</h2>
<ul>
<li><strong>High Accuracy on Standard Sets</strong>: The model achieved <strong>82% accuracy</strong> on the Indigo validation set and <strong>77%</strong> on the USPTO validation set. No apparent overfitting was observed on the Indigo data (57M training examples), though some overfitting occurred on the smaller USPTO set (1.7M training examples).</li>
<li><strong>Real-World Viability</strong>: It achieved <strong>83% accuracy</strong> on the proprietary internal test set. With validation and proprietary accuracies all falling between 77% and 83%, the training sets appear to approximate real drug discovery data reasonably well.</li>
<li><strong>Segmentation Quality</strong>: Low segmentation error rates were observed: only 3.3% of the Valko dataset and 6.6% of the proprietary images failed to segment properly.</li>
<li><strong>Limitations on Complexity</strong>: Performance dropped to <strong>41% on the Valko test set</strong>. Superatoms were the single largest contributor to prediction errors, with 21% of Valko samples containing one or more incorrectly predicted superatoms. Only 6.6% of total training images contained any superatom, limiting the model&rsquo;s exposure.</li>
<li><strong>Stereochemistry Challenges</strong>: 60% of compounds with incorrectly predicted stereochemistry had explicit stereochemistry in both the ground truth and the prediction, but with wrong configurations assigned (e.g., predicting R instead of S). The model often correctly identified which atoms have stereocenters but assigned the wrong direction, suggesting the architecture may not incorporate sufficient spatial context for configuration assignment.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors used three primary sources of training data for the prediction model, plus synthetic pages for the segmentation model. All inputs were aggressively downsampled to improve robustness.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>Indigo Set</strong></td>
          <td>57M</td>
          <td>PubChem molecules rendered via Indigo (256x256).</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>USPTO Set</strong></td>
          <td>1.7M</td>
          <td>Image/SMILES pairs from public patent data.</td>
      </tr>
      <tr>
          <td><strong>Training</strong></td>
          <td><strong>OS X Indigo</strong></td>
          <td>10M</td>
          <td>Additional Indigo renders from Mac OS for style diversity.</td>
      </tr>
      <tr>
          <td><strong>Segmentation</strong></td>
          <td><strong>Synthetic Pages</strong></td>
          <td>N/A</td>
          <td>Generated by overlaying USPTO images on text-cleared PDF pages.</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Segmentation Inputs</strong>: Grayscale, downsampled to ~60 dpi.</li>
<li><strong>Prediction Inputs</strong>: Resized to 256x256 such that bond lengths are approximately 3-12 pixels.</li>
<li><strong>Augmentation</strong>: Random affine transforms, brightness scaling, and binarization applied during training.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Segmentation Pipeline</strong>:</p>
<ul>
<li><strong>Multi-scale Inference</strong>: Masks generated at resolutions from 30 to 60 dpi (3 dpi increments) and averaged for the final mask.</li>
<li><strong>Post-processing</strong>: Hough transform used to remove long straight lines (table borders). Mask blobs filtered by pixel count thresholds.</li>
</ul>
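<p>As a rough illustration of the multi-scale inference step, the sketch below enumerates resolutions from 30 to 60 dpi in 3 dpi steps and averages per-pixel mask probabilities. It assumes every mask has already been resampled to a common grid; the helper names and the toy masks are hypothetical and are not taken from the paper.</p>

```python
def inference_resolutions(start_dpi=30, stop_dpi=60, step_dpi=3):
    """Resolutions used for multi-scale inference (both ends included)."""
    return list(range(start_dpi, stop_dpi + 1, step_dpi))


def average_masks(masks):
    """Average per-pixel probabilities across masks of identical shape.

    Assumes each mask (a nested list) was resampled to a common grid.
    """
    n = len(masks)
    rows, cols = len(masks[0]), len(masks[0][0])
    return [
        [sum(m[r][c] for m in masks) / n for c in range(cols)]
        for r in range(rows)
    ]


resolutions = inference_resolutions()
# Hypothetical 1x2 masks standing in for three of the per-resolution outputs.
toy_masks = [[[0.0, 1.0]], [[0.5, 1.0]], [[1.0, 1.0]]]
averaged = average_masks(toy_masks)
```

<p>With the default arguments this yields eleven resolutions, so the final mask would be the mean of eleven predictions.</p>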
<p><strong>Prediction Pipeline</strong>:</p>
<ul>
<li><strong>Sequence Generation</strong>: SMILES generated character-by-character via greedy decoding. During inference, predictions are made at several low resolutions and the sequence with the highest confidence (product of per-character softmax outputs) is returned.</li>
<li><strong>Attention-based Verification</strong>: Attention weights used to re-project predicted atoms back into 2D space to visually verify alignment with the input image.</li>
</ul>
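<p>The confidence-based selection described above can be sketched as follows. The candidate decodes and their probabilities are invented for illustration, and summing log-probabilities in place of multiplying raw probabilities is a numerical-stability choice made here, not something the paper specifies.</p>

```python
import math


def sequence_log_confidence(char_probs):
    """Log of the product of per-character softmax probabilities."""
    return sum(math.log(p) for p in char_probs)


def pick_most_confident(candidates):
    """Return the (smiles, char_probs) candidate with the highest confidence.

    `candidates` holds one greedy decode per input resolution.
    """
    return max(candidates, key=lambda c: sequence_log_confidence(c[1]))


# Hypothetical decodes of the same image at three resolutions.
candidates = [
    ("CCO", [0.9, 0.8, 0.9]),
    ("CCN", [0.9, 0.6, 0.7]),
    ("CCO", [0.95, 0.9, 0.9]),
]
best_smiles, best_probs = pick_most_confident(candidates)
```

<p>Because the logarithm is monotonic, ranking by summed log-probability gives the same winner as ranking by the raw product.</p>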
<h3 id="models">Models</h3>
<p><strong>1. Segmentation Model (U-Net Variant)</strong>:</p>
<ul>
<li><strong>Architecture</strong>: U-Net style with skip connections.</li>
<li><strong>Input</strong>: 128x128x1 grayscale image.</li>
<li><strong>Layers</strong>: Alternating 3x3 Conv and 2x2 Max Pool.</li>
<li><strong>Activation</strong>: Parametric ReLU (pReLU).</li>
<li><strong>Parameters</strong>: ~380,000.</li>
</ul>
<p><strong>2. Structure Prediction Model (Encoder-Decoder)</strong>:</p>
<ul>
<li><strong>Encoder</strong>: CNN with 5x5 convolutions, 2x2 Max Pooling, pReLU. No pooling in first layers to preserve fine features.</li>
<li><strong>Decoder</strong>: 3 layers of <strong>GridLSTM</strong> cells.</li>
<li><strong>Attention</strong>: Soft/Global attention mechanism conditioned on the encoder state.</li>
<li><strong>Input</strong>: 256x256x1 image.</li>
<li><strong>Output</strong>: Sequence of characters (vocab size 65).</li>
<li><strong>Parameters</strong>: ~46.3 million.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>Evaluation required an exact string match of the Canonical SMILES (including stereochemistry) to the ground truth.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Dataset</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Accuracy</td>
          <td><strong>82%</strong></td>
          <td>Indigo Val</td>
          <td>Synthetic validation set</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>77%</strong></td>
          <td>USPTO Val</td>
          <td>Real patent images</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>83%</strong></td>
          <td>Proprietary</td>
          <td>Internal pharma dataset (real world)</td>
      </tr>
      <tr>
          <td>Accuracy</td>
          <td><strong>41%</strong></td>
          <td>Valko Test</td>
          <td>External benchmark; difficult due to superatoms</td>
      </tr>
  </tbody>
</table>
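<p>A minimal sketch of the exact-match criterion is shown below. It assumes predictions and references were already canonicalized by the same toolkit, which is what makes plain string equality meaningful; the function name and the example strings are illustrative only.</p>

```python
def exact_match_accuracy(predicted, reference):
    """Fraction of predictions whose string equals the reference exactly."""
    if len(predicted) != len(reference):
        raise ValueError("predicted and reference must have equal length")
    hits = sum(p == r for p, r in zip(predicted, reference))
    return hits / len(reference)


# Toy example: the second pair differs only in stereochemistry (@ vs @@),
# so it counts as a miss under this metric.
predicted = ["CCO", "C[C@H](N)C(=O)O", "c1ccccc1"]
reference = ["CCO", "C[C@@H](N)C(=O)O", "c1ccccc1"]
accuracy = exact_match_accuracy(predicted, reference)
```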
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Segmentation Training</strong>: 1 GPU, ~4 days (650k steps).</li>
<li><strong>Prediction Training</strong>: 8 NVIDIA Pascal GPUs, ~26 days (1M steps).</li>
<li><strong>Framework</strong>: TensorFlow.</li>
<li><strong>Optimizer</strong>: Adam.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>No public code, pre-trained models, or generated datasets were released with this paper. The training pipeline relies on publicly available molecular databases (PubChem, USPTO) and open-source rendering tools (Indigo), but the specific training sets, model weights, and inference code remain unavailable.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Staker, J., Marshall, K., Abel, R., &amp; McQuaw, C. (2019). Molecular Structure Extraction From Documents Using Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 59(3), 1017-1029. <a href="https://doi.org/10.1021/acs.jcim.8b00669">https://doi.org/10.1021/acs.jcim.8b00669</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling (JCIM) 2019</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.schrodinger.com/publications/">Schrödinger Publication Page</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{stakerMolecularStructureExtraction2019,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Molecular Structure Extraction From Documents Using Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Staker, Joshua and Marshall, Kyle and Abel, Robert and McQuaw, Carolyn}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2019}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = <span style="color:#e6db74">{feb}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{59}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{3}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1017--1029}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/acs.jcim.8b00669}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1021/acs.jcim.8b00669}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DECIMER: Deep Learning for Chemical Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/decimer/</guid><description>Deep learning method for optical chemical structure recognition using image captioning networks trained on millions of synthetic molecular images.</description><content:encoded><![CDATA[<h2 id="contribution-method-for-optical-chemical-entity-recognition">Contribution: Method for Optical Chemical Entity Recognition</h2>
<p>This is primarily a <strong>Method ($\Psi_{\text{Method}}$)</strong> paper with a strong <strong>Resource ($\Psi_{\text{Resource}}$)</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes a novel architecture (DECIMER) that repurposes &ldquo;show-and-tell&rdquo; image captioning networks for Optical Chemical Entity Recognition (OCER), providing an alternative to traditional rule-based segmentation pipelines.</li>
<li><strong>Resource</strong>: It establishes a framework for generating large-scale synthetic training data using open-source cheminformatics tools (CDK) and databases (PubChem), circumventing the scarcity of manually annotated chemical images.</li>
</ul>
<h2 id="motivation-brittleness-of-heuristic-pipelines">Motivation: Brittleness of Heuristic Pipelines</h2>
<p>The extraction of chemical structures from scientific literature (OCER) is critical for populating open-access databases. Traditional OCER systems (like OSRA or CLiDE) rely on complex multi-step pipelines involving vectorization, character recognition, and graph compilation. These systems are brittle, and incorporating new structural features requires laborious engineering. Inspired by the success of deep neural network approaches like AlphaGo Zero, the authors sought to formulate an end-to-end deep learning approach that learns directly from data with minimal prior assumptions.</p>
<h2 id="novelty-image-captioning-for-molecular-graphs">Novelty: Image Captioning for Molecular Graphs</h2>
<ul>
<li><strong>Image-to-Text Formulation</strong>: The paper frames chemical structure recognition as an image captioning problem, translating a bitmap image directly into a SMILES string using an encoder-decoder network. This bypasses explicit segmentation of atoms and bonds entirely.</li>
<li><strong>Synthetic Data Strategy</strong>: The authors generate synthetic images from PubChem using the CDK Structure Diagram Generator, scaling the dataset size to 15 million.</li>
<li><strong>Robust String Representations</strong>: The study performs key ablation experiments on string representations, comparing standard SMILES against DeepSMILES to evaluate how syntactic validity affects the network&rsquo;s learning capability.</li>
</ul>
<h2 id="experimental-setup-and-validation-strategies">Experimental Setup and Validation Strategies</h2>
<ul>
<li><strong>Data Scaling</strong>: Models were trained on dataset sizes ranging from 54,000 to 15 million synthetic images to observe empirical scaling laws regarding accuracy and compute time.</li>
<li><strong>Representation Comparison</strong>: The authors compared the validity of predicted strings and recognition accuracy when training on SMILES versus DeepSMILES. The cross-entropy loss formulation for sequence generation can be represented as:
$$ \mathcal{L} = -\sum_{t=1}^{T} \log P(y_t \mid y_{&lt;t}, \mathbf{x}) $$
where $\mathbf{x}$ is the image representation and $y_t$ are the tokens of the SMILES/DeepSMILES string.</li>
<li><strong>Metric Evaluation</strong>: Performance was measured using Validity (syntactic correctness) and Tanimoto Similarity $T$, computed on molecular fingerprints to capture partial correctness even if the exact string prediction failed:
$$ T(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|} $$</li>
</ul>
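<p>Both formulas above are small enough to compute by hand. The sketch below evaluates them on toy inputs; the token probabilities and fingerprint bit indices are made up, and returning 1.0 for two empty fingerprints is a convention chosen here, not one stated in the paper.</p>

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of a sequence.

    `token_probs[t]` is the model probability assigned to the correct
    token at step t, given the image and all earlier tokens.
    """
    return -sum(math.log(p) for p in token_probs)


def tanimoto(a, b):
    """Tanimoto similarity between two sets of 'on' fingerprint bits."""
    a, b = set(a), set(b)
    shared = len(a.intersection(b))
    denom = len(a) + len(b) - shared
    return shared / denom if denom else 1.0


loss = sequence_nll([0.5, 0.25])         # hypothetical per-token probabilities
score = tanimoto({1, 2, 3}, {2, 3, 4})   # hypothetical bit indices
```

<p>Here two of the four distinct bits are shared, so the similarity is 0.5 even though the two fingerprints are not identical.</p>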
<h2 id="results-and-critical-conclusions">Results and Critical Conclusions</h2>
<ul>
<li><strong>Data Representation</strong>: DeepSMILES proved superior to standard SMILES for training stability and output validity. Preliminary tests suggested SELFIES performs even better (0.78 Tanimoto vs 0.53 for DeepSMILES at 6M images).</li>
<li><strong>Scaling Behavior</strong>: Accuracy improves linearly with dataset size. The authors extrapolate that near-perfect detection would require training on 50 to 100 million structures.</li>
<li><strong>Current Limitations</strong>: At the reported training scale (up to 15M), the model does not yet rival traditional heuristic approaches, but the learning curve suggests it is a viable trajectory given sufficient compute and data.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The training data is synthetic, generated using the Chemistry Development Kit (CDK) Structure Diagram Generator (SDG) based on molecules from PubChem.</p>
<p><strong>Curation Rules</strong> (applied to PubChem data):</p>
<ul>
<li>Molecular weight &lt; 1500 Daltons.</li>
<li>Elements restricted to: C, H, O, N, P, S, F, Cl, Br, I, Se, B.</li>
<li>No counter ions or charged groups.</li>
<li>No isotopes (e.g., D, T).</li>
<li>Bond count between 5 and 40.</li>
<li>SMILES length &lt; 40 characters.</li>
<li>Implicit hydrogens only (except in functional groups).</li>
</ul>
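<p>The curation rules above amount to a simple predicate over each molecule. The sketch below applies them to a hypothetical precomputed record; the dict fields are assumptions made for illustration, the bond-count bounds are treated as inclusive, and the implicit-hydrogen rule is not checked. A real pipeline would derive these fields with a cheminformatics toolkit such as CDK.</p>

```python
ALLOWED_ELEMENTS = {"C", "H", "O", "N", "P", "S", "F", "Cl", "Br", "I", "Se", "B"}


def passes_curation(mol):
    """Check a precomputed molecule record against the listed rules.

    `mol` is a hypothetical dict; the implicit-hydrogen rule is omitted.
    """
    return (
        mol["mol_weight"] < 1500
        and set(mol["elements"]).issubset(ALLOWED_ELEMENTS)
        and not mol["has_charge"]
        and not mol["has_isotope"]
        and 5 <= mol["bond_count"] <= 40
        and len(mol["smiles"]) < 40
    )


aspirin = {
    "smiles": "CC(=O)Oc1ccccc1C(=O)O",
    "mol_weight": 180.16,
    "elements": ["C", "H", "O"],
    "has_charge": False,
    "has_isotope": False,
    "bond_count": 13,
}
# Same record, but with a sodium counter ion and a charged group.
sodium_salt = dict(aspirin, elements=["C", "H", "O", "Na"], has_charge=True)

keep_aspirin = passes_curation(aspirin)
keep_salt = passes_curation(sodium_salt)
```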
<p><strong>Preprocessing</strong>:</p>
<ul>
<li><strong>Images</strong>: Generated as 299x299 bitmaps to match Inception V3 input requirements.</li>
<li><strong>Augmentation</strong>: One random rotation applied per molecule; no noise or blurring added in this iteration.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic (PubChem)</td>
          <td>54k - 15M</td>
          <td>Scaled across 12 experiments</td>
      </tr>
      <tr>
          <td>Testing</td>
          <td>Independent Set</td>
          <td>6k - 1.6M</td>
          <td>10% of training size</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Architecture</strong>: &ldquo;Show, Attend and Tell&rdquo; (attention-based image captioning).</li>
<li><strong>Optimization</strong>: Adam optimizer with learning rate 0.0005.</li>
<li><strong>Loss Function</strong>: Sparse Categorical Crossentropy.</li>
<li><strong>Training Loop</strong>: Trained for 25 epochs per model. Batch size of 640 images.</li>
</ul>
<h3 id="models">Models</h3>
<p>The network is implemented in TensorFlow 2.0.</p>
<ul>
<li><strong>Encoder</strong>: Inception V3 (Convolutional NN), used unaltered. Extracts feature vectors saved as NumPy arrays.</li>
<li><strong>Decoder</strong>: Gated Recurrent Unit (GRU) based Recurrent Neural Network (RNN) with soft attention mechanism.</li>
<li><strong>Embeddings</strong>: Image embedding dimension size of 600.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The primary metric is Tanimoto similarity (Jaccard index) on PubChem fingerprints, which is robust for measuring structural similarity even when exact identity is not reached.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Definition</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Tanimoto 1.0</strong></td>
          <td>Percentage of predictions that are chemically identical to ground truth (isomorphic).</td>
      </tr>
      <tr>
          <td><strong>Average Tanimoto</strong></td>
          <td>Mean similarity score across the test set (captures partial correctness).</td>
      </tr>
      <tr>
          <td><strong>Validity</strong></td>
          <td>Percentage of predicted strings that are valid DeepSMILES/SMILES.</td>
      </tr>
  </tbody>
</table>
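<p>The three metrics in the table can be aggregated from per-prediction scores, as in the sketch below. It assumes an invalid prediction is recorded as <code>None</code>, that the percentages are taken over the whole test set, and that the average covers only valid predictions; the paper may use different denominators, so treat these choices as assumptions.</p>

```python
def summarize(scores):
    """Aggregate per-prediction Tanimoto scores.

    `scores` holds one float per prediction, or None when the predicted
    string could not be parsed (an invalid SMILES/DeepSMILES).
    """
    valid = [s for s in scores if s is not None]
    n = len(scores)
    return {
        "validity_pct": 100.0 * len(valid) / n,
        "tanimoto_1_pct": 100.0 * sum(s == 1.0 for s in valid) / n,
        "avg_tanimoto": sum(valid) / len(valid) if valid else 0.0,
    }


# Hypothetical test set of four predictions, one of them invalid.
metrics = summarize([1.0, 0.5, None, 1.0])
```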
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER">DECIMER (Java utilities)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>CDK-based data generation and conversion tools</td>
      </tr>
      <tr>
          <td><a href="https://github.com/Kohulan/DECIMER-Image-to-SMILES">DECIMER-Image-to-SMILES</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>TensorFlow training and inference scripts (archived)</td>
      </tr>
      <tr>
          <td><a href="https://pubchem.ncbi.nlm.nih.gov/">PubChem</a></td>
          <td>Dataset</td>
          <td>Public Domain</td>
          <td>Source of molecular structures for synthetic training data</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Training was performed on a single node.</p>
<ul>
<li><strong>GPU</strong>: 1x NVIDIA Tesla V100.</li>
<li><strong>CPU</strong>: 2x Intel Xeon Gold 6230.</li>
<li><strong>RAM</strong>: 384 GB.</li>
<li><strong>Compute Time</strong>:
<ul>
<li>Linear scaling with data size.</li>
<li>15 million structures took ~27 days (91,881 s per epoch over 25 epochs).</li>
<li>Projected time for 100M structures: ~4 months on a single GPU.</li>
</ul>
</li>
</ul>
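<p>The ~27 day figure for the 15M-structure run follows from the reported per-epoch time and the 25 epochs used per model. The arithmetic is sketched below, with variable names chosen here for illustration.</p>

```python
SECONDS_PER_DAY = 86_400

seconds_per_epoch = 91_881   # reported for the 15M-structure run
epochs = 25                  # epochs per model, from the training setup

total_seconds = seconds_per_epoch * epochs
total_days = total_seconds / SECONDS_PER_DAY
```

<p>This gives 2,297,025 seconds, or roughly 26.6 days.</p>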
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rajan, K., Zielesny, A. &amp; Steinbeck, C. (2020). DECIMER: towards deep learning for chemical image recognition. <em>Journal of Cheminformatics</em>, 12(1), 65. <a href="https://doi.org/10.1186/s13321-020-00469-w">https://doi.org/10.1186/s13321-020-00469-w</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2020</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/Kohulan/DECIMER">Official GitHub Repository</a></li>
<li><a href="https://github.com/Kohulan/DECIMER-Image-to-SMILES">DECIMER Image-to-SMILES Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{rajanDECIMERDeepLearning2020,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{{DECIMER}}: Towards Deep Learning for Chemical Image Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{{{DECIMER}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Rajan, Kohulan and Zielesny, Achim and Steinbeck, Christoph}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2020</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = oct,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{12}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{65}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{1758-2946}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-020-00469-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>ChemGrapher: Deep Learning for Chemical Graph OCSR</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/chemgrapher-2020/</link><pubDate>Wed, 17 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-graph/chemgrapher-2020/</guid><description>Deep learning OCSR method using semantic segmentation and classification CNNs to reconstruct chemical graphs with improved stereochemistry.</description><content:encoded><![CDATA[<h2 id="classifying-the-methodology">Classifying the Methodology</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel deep learning architecture and a specific graph-reconstruction algorithm to solve the problem of Optical Chemical Structure Recognition (OCSR). It validates this method by comparing it against the existing standard tool (OSRA), demonstrating superior performance on specific technical challenges like stereochemistry.</p>
<h2 id="the-ocr-stereochemistry-challenge">The OCR Stereochemistry Challenge</h2>
<p>Chemical knowledge is frequently locked in static images within scientific publications. Extracting this structure into machine-readable formats (graphs, SMILES) is essential for drug discovery and database querying. Existing tools, such as OSRA, rely on optical character recognition (OCR) combined with expert systems or hand-coded rules. These tools struggle with bond multiplicity and stereochemical information, often missing atoms or misinterpreting 3D cues (wedges and dashes). A machine learning approach, by contrast, can be improved by training on more data instead of hand-writing new rules.</p>
<h2 id="decoupled-semantic-segmentation-and-classification-pipeline">Decoupled Semantic Segmentation and Classification Pipeline</h2>
<p>The core novelty is the <strong>segmentation-classification pipeline</strong> which decouples object detection from type assignment:</p>
<ol>
<li><strong>Semantic Segmentation</strong>: The model first predicts pixel-wise maps for atoms, bonds, and charges using a Dense Prediction Convolutional Network built on dilated convolutions.</li>
<li><strong>Graph Building Algorithm</strong>: A specific algorithm iterates over the segmentation maps to generate candidate locations for atoms and bonds.</li>
<li><strong>Refinement via Classification</strong>: Dedicated classification networks take cutouts of the original image combined with the segmentation mask to verify and classify each candidate (e.g., distinguishing a single bond from a double bond, or a wedge from a dash).</li>
</ol>
<p>Additionally, the authors developed a novel method for <strong>synthetic data generation</strong> by modifying the source code of RDKit to output pixel-wise labels during the image drawing process. This solves the lack of labeled training data.</p>
<h2 id="evaluating-synthetics-and-benchmarks">Evaluating Synthetics and Benchmarks</h2>
<ul>
<li><strong>Synthetic Benchmarking</strong>: The authors generated test sets in 3 different stylistic variations. For each style, they tested on both stereo (complex 3D information) and non-stereo compounds.</li>
<li><strong>Baseline Comparison</strong>: They compared the error rates of ChemGrapher against <strong>OSRA</strong> (Optical Structure Recognition Application).</li>
<li><strong>Component-level Evaluation</strong>: They analyzed the F1 scores of the segmentation networks versus the classification networks independently to understand where errors propagated.</li>
<li><strong>Real-world Case Study</strong>: They manually curated 61 images cut from journal articles to test performance on real, non-synthetic data.</li>
</ul>
<h2 id="advancements-over-osra">Advancements Over OSRA</h2>
<ul>
<li><strong>Superior Accuracy</strong>: ChemGrapher consistently achieved lower error rates than OSRA across all synthetic styles, particularly for stereochemical information (wedge and dash bonds).</li>
<li><strong>Component Performance</strong>: The classification networks showed higher F1 scores than the segmentation networks across all prediction types (Figure 4 in the paper). This suggests the two-stage approach allows the classifier to correct segmentation noise.</li>
<li><strong>Real-world Viability</strong>: In the manual case study, ChemGrapher correctly predicted 46 of 61 images, compared to 42 of 61 for OSRA.</li>
<li><strong>Limitations</strong>: The model struggles with thick bond lines in real-world images. Performance is stronger on carbon-only compounds, where no letters appear in the image.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors created a custom synthetic dataset using ChEMBL and RDKit, as no pixel-wise labeled dataset existed.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Source</strong></td>
          <td>ChEMBL</td>
          <td>1.9M</td>
          <td>Split into training pool (1.5M), val/train pool (300K), and test pools (35K each).</td>
      </tr>
      <tr>
          <td><strong>Segmentation Train</strong></td>
          <td>Synthetic</td>
          <td>~114K</td>
          <td>Sampled from ChEMBL pool such that every atom type appears in &gt;1000 compounds.</td>
      </tr>
      <tr>
          <td><strong>Labels</strong></td>
          <td>Pixel-wise</td>
          <td>N/A</td>
          <td>Generated by modifying <strong>RDKit</strong> source code to output label masks (atom type, bond type, charge) during drawing.</td>
      </tr>
      <tr>
          <td><strong>Candidates (Val)</strong></td>
          <td>Cutouts</td>
          <td>~27K (Atom)<br>~55K (Bond)</td>
          <td>Validation candidates generated from ~450 compounds for evaluating the classification networks.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Algorithm 1: Graph Building</strong></p>
<ol>
<li><strong>Segment</strong>: Apply segmentation network $s(x)$ to get maps $S^a$ (atoms), $S^b$ (bonds), $S^c$ (charges).</li>
<li><strong>Atom Candidates</strong>: Identify candidate blobs in $S^a$.</li>
<li><strong>Classify Atoms</strong>: For each candidate, crop the input image and segmentation map. Feed to $c_A$ and $c_C$ to predict Atom Type and Charge. Add to Vertex set $V$ if valid.</li>
<li><strong>Bond Candidates</strong>: Generate all pairs of nodes in $V$ within $2 \times$ bond length distance.</li>
<li><strong>Classify Bonds</strong>: For each pair, create a candidate mask (two rectangles meeting in the middle to encode directionality). Feed to $c_B$ to predict Bond Type (single, double, wedge, etc.). Add to Edge set $E$.</li>
</ol>
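<p>Step 4 of the algorithm above is essentially a distance filter over pairs of accepted atoms. The sketch below shows one way to express it; the pixel coordinates are hypothetical, and treating &ldquo;within&rdquo; as an inclusive bound is an assumption. It uses <code>math.dist</code>, which requires Python 3.8 or later.</p>

```python
import math
from itertools import combinations


def bond_candidates(atoms, bond_length):
    """Return index pairs of atoms within 2 x bond_length of each other.

    `atoms` is a list of (x, y) pixel centers of accepted atom candidates.
    """
    max_dist = 2 * bond_length
    pairs = []
    for (i, a), (j, b) in combinations(enumerate(atoms), 2):
        if math.dist(a, b) <= max_dist:
            pairs.append((i, j))
    return pairs


# Hypothetical atom centers on a line; the last one is far from the rest.
atoms = [(0, 0), (10, 0), (20, 0), (45, 0)]
pairs = bond_candidates(atoms, bond_length=10)
```

<p>Each surviving pair would then be passed to the bond classifier, which can still reject it (for example, the pair of atoms 0 and 2 here, which are two bond lengths apart).</p>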
<h3 id="models">Models</h3>
<p>The pipeline uses four distinct Convolutional Neural Networks (CNNs).</p>
<p><strong>1. Semantic Segmentation Network ($s$)</strong></p>
<ul>
<li><strong>Architecture</strong>: 8 convolutional layers (3x3) plus a final 1x1 linear layer (Dense Prediction Convolutional Network).</li>
<li><strong>Kernels</strong>: $3 \times 3$ for all convolutional layers; $1 \times 1$ for the final linear layer.</li>
<li><strong>Dilation</strong>: Uses dilated convolutions to expand receptive field without losing resolution. Six of the eight convolutional layers use dilation (factors: 2, 4, 8, 8, 4, 2); the first and last convolutional layers have no dilation.</li>
<li><strong>Input</strong>: Binary B/W image.</li>
<li><strong>Output</strong>: Multi-channel probability maps for Atom Types ($S^a$), Bond Types ($S^b$), and Charges ($S^c$).</li>
</ul>
<p><strong>2. Classification Networks ($c_A, c_B, c_C$)</strong></p>
<ul>
<li><strong>Purpose</strong>: Refines predictions on small image patches.</li>
<li><strong>Architecture</strong>: 5 convolutional layers, followed by a MaxPool layer and a final linear (1x1) layer.
<ul>
<li>Layer 1: <strong>Depthwise separable convolution</strong> (no dilation).</li>
<li>Layers 2-4: Dilated convolutions (factors 2, 4, 8).</li>
<li>Layer 5: Standard convolution (no dilation).</li>
<li>MaxPool: $124 \times 124$.</li>
<li>Final: 1x1 linear layer.</li>
</ul>
</li>
<li><strong>Inputs</strong>:
<ul>
<li>Crop of the binary image ($x^{cut}$).</li>
<li>Crop of the segmentation map ($S^{cut}$).</li>
<li>&ldquo;Highlight&rdquo; mask ($h_L$) indicating the specific candidate location (e.g., a dot for atoms, two rectangles for bonds).</li>
</ul>
</li>
</ul>
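<p>To see why dilation matters, the receptive field of a stack of dilated convolutions can be computed directly from the dilation factors listed above. The sketch assumes stride 1 throughout and 3x3 kernels in the classification networks as well (their kernel size is not stated above); the 1x1 final layers do not enlarge the receptive field.</p>

```python
def receptive_field(dilations, kernel=3):
    """Receptive field (in pixels) of stacked stride-1 dilated convolutions."""
    rf = 1
    for d in dilations:
        rf += (kernel - 1) * d
    return rf


# Dilation factors as listed above; 1 means no dilation.
segmentation_rf = receptive_field([1, 2, 4, 8, 8, 4, 2, 1])
classifier_rf = receptive_field([1, 2, 4, 8, 1])
```

<p>Under these assumptions the segmentation network sees a 61x61 pixel window per output pixel and each classifier sees 33x33, without any pooling between the convolutions.</p>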
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric</strong>: <strong>F1 Score</strong> for individual network performance (segmentation pixels and classification accuracy).</li>
<li><strong>Metric</strong>: <strong>Error Rate</strong> (percentage of incorrect graphs) for overall system. A graph is &ldquo;incorrect&rdquo; if there is at least one mistake in atoms or bonds.</li>
<li><strong>Baselines</strong>: Compared against <strong>OSRA</strong>.</li>
</ul>
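<p>Both metrics are straightforward to compute. The sketch below uses invented counts; the function names are illustrative and the zero-division fallbacks are conventions chosen here.</p>

```python
def f1_score(tp, fp, fn):
    """Harmonic mean of precision and recall from raw counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def error_rate(mistakes_per_graph):
    """Percentage of graphs with at least one atom or bond mistake."""
    wrong = sum(1 for m in mistakes_per_graph if m > 0)
    return 100.0 * wrong / len(mistakes_per_graph)


f1 = f1_score(tp=8, fp=2, fn=2)      # hypothetical counts
rate = error_rate([0, 0, 3, 1])      # hypothetical mistakes per graph
```

<p>Note that the error rate is strict: a graph with a single wrong bond counts the same as one with many mistakes.</p>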
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Training and inference performed on a single <strong>NVIDIA Titan Xp</strong> (donated by NVIDIA).</li>
</ul>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<p><strong>Closed.</strong> The authors did not release source code, pre-trained models, or the synthetic dataset. The data generation pipeline requires modifications to RDKit&rsquo;s internal drawing code, which are not publicly available. The ChEMBL source compounds are public, but the pixel-wise labeling procedure cannot be reproduced without the modified RDKit code.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Oldenhof, M., Arany, Á., Moreau, Y., &amp; Simm, J. (2020). ChemGrapher: Optical Graph Recognition of Chemical Compounds by Deep Learning. <em>Journal of Chemical Information and Modeling</em>, 60(10), 4506-4517. <a href="https://doi.org/10.1021/acs.jcim.0c00459">https://doi.org/10.1021/acs.jcim.0c00459</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Modeling 2020 (arXiv preprint Feb 2020)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2002.09914">arXiv Page</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{oldenhof2020chemgrapher,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{ChemGrapher: Optical Graph Recognition of Chemical Compounds by Deep Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Oldenhof, Martijn and Arany, Ádám and Moreau, Yves and Simm, Jaak}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Chemical Information and Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{60}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{10}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{4506--4517}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{ACS Publications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.1021/acs.jcim.0c00459}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kekulé-1 System for Chemical Structure Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/kekule-1996/</link><pubDate>Mon, 15 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/kekule-1996/</guid><description>Foundational OCSR method combining neural OCR with chemical rule-based post-processing for automated structure interpretation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: McDaniel, J. R., &amp; Balmuth, J. R. (1996). Automatic Interpretation of Chemical Structure Diagrams. <em>Graphics Recognition. Methods and Applications</em>, 148-158. <a href="https://doi.org/10.1007/3-540-61226-2_13">https://doi.org/10.1007/3-540-61226-2_13</a></p>
<p><strong>Publication</strong>: Lecture Notes in Computer Science (LNCS), Vol. 1072, Springer, 1996.</p>
<h2 id="system-architecture-and-contribution">System Architecture and Contribution</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel software architecture (&ldquo;Kekulé-1&rdquo;) designed to solve the specific technical problem of converting rasterized chemical diagrams into machine-readable connection tables. The paper is characterized by:</p>
<ul>
<li><strong>Algorithmic Specification</strong>: It details specific algorithms for vectorization, polygon approximation, and character recognition.</li>
<li><strong>Performance Metrics</strong>: It validates the method using quantitative accuracy (98.9%) and speed comparisons against manual entry.</li>
<li><strong>System Architecture</strong>: It describes the integration of typically disparate components (OCR, vectorization, chemical rules) into a cohesive pipeline.</li>
</ul>
<h2 id="motivation-the-chemical-data-entry-bottleneck">Motivation: The Chemical Data Entry Bottleneck</h2>
<p>Chemical structure diagrams are the primary medium for communication between chemists, but computers cannot natively &ldquo;read&rdquo; these raster images.</p>
<ul>
<li><strong>Efficiency Gap</strong>: Manual redrawing of structures into chemical databases takes 6 to 10 minutes per structure.</li>
<li><strong>Technical Challenge</strong>: Existing commercial OCR systems failed on chemical diagrams because they could not handle the mix of graphics (bonds) and text (atom labels), nor could they recognize small fonts (3-7 points) or chemical symbols accurately.</li>
<li><strong>Goal</strong>: To create an &ldquo;Optical Chemical Structure Recognition&rdquo; (OCSR) system that reduces processing time to seconds while handling complex notation like stereochemistry and group formulas.</li>
</ul>
<h2 id="core-innovations-in-chemical-ocr">Core Innovations in Chemical OCR</h2>
<p>Kekulé-1 represents the &ldquo;first successful attempt&rdquo; to integrate image processing, OCR, and structure editing into a single workflow. Key innovations include:</p>
<ul>
<li><strong>Context-Aware OCR</strong>: Unlike standard OCR, Kekulé-1 uses &ldquo;chemical spell checking&rdquo; by applying valence rules and chemical context to correct raw character recognition errors (e.g., distinguishing &lsquo;5&rsquo; from &lsquo;S&rsquo; based on bonding).</li>
<li><strong>Adaptive Polygon Approximation</strong>: A modified vectorization algorithm that partitions objects at the farthest node to prevent artifact nodes in U-shaped structures.</li>
<li><strong>Hybrid Parsing</strong>: It treats the diagram as a graph where nodes can be explicit atoms or geometric intersections, using rule-based logic to parse &ldquo;group formulas&rdquo; (like $COOH$) recursively.</li>
</ul>
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The authors evaluated the system on a private test set to validate robustness and speed.</p>
<ul>
<li><strong>Dataset</strong>: 524 chemical structures chosen from a &ldquo;wide variety of sources&rdquo; specifically to test the system&rsquo;s limits.</li>
<li><strong>Metrics</strong>: Success rate (percentage of structures processed with minimal editing) and processing time per structure.</li>
<li><strong>Comparators</strong>: Performance was compared against the &ldquo;manual redrawing&rdquo; baseline.</li>
</ul>
<h2 id="results-performance-and-conclusions">Results, Performance, and Conclusions</h2>
<ul>
<li><strong>High Accuracy</strong>: 98.9% of the test structures were successfully processed (with an average of 0.74 user prompts per structure).</li>
<li><strong>Speedup</strong>: Processing took 7 to 30 seconds per structure, roughly 12 to 85 times faster than the 6 to 10 minute manual baseline.</li>
<li><strong>Robustness</strong>: The system successfully handled pathological cases like broken characters, skew (rotation), and touching characters.</li>
<li><strong>Impact</strong>: The authors conclude that the techniques are generalizable to other domains like electrical circuits and utility maps.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Training/Test Data</strong>: The evaluation used 524 chemical structures. These were not released publicly but were selected to represent &ldquo;limit&rdquo; cases.</li>
<li><strong>Input format</strong>: Scanned images at 300-400 dpi. The authors note that higher resolutions do not add information due to ink wicking and paper limitations.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p>The paper details several specific algorithmic implementations:</p>
<p><strong>Vectorization (Polygon Approximation)</strong>:</p>
<ul>
<li>Standard thinning and raster-to-vector translation are used.</li>
<li><strong>Innovation</strong>: The algorithm searches for the node <em>farthest</em> from the current start node to partition the object. This prevents artifact nodes in curved lines.</li>
<li><strong>Threshold Formula</strong>: The allowed deviation ($dist$) from a straight line is adaptive based on segment length ($length$):</li>
</ul>
<p>$$dist = \max(1, \frac{length}{10.0} + 0.4)$$</p>
<p>(Units in pixels)</p>
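<p>As a quick check of how the threshold behaves, the formula can be evaluated directly. This is a minimal sketch; the function name <code>max_deviation</code> is an illustrative choice and does not come from the paper.</p>

```python
def max_deviation(length_px):
    """Allowed deviation (in pixels) from a straight line for a segment of this length."""
    return max(1.0, length_px / 10.0 + 0.4)

# Short segments hit the 1-pixel floor; longer ones get a proportional tolerance.
for length_px in (2, 6, 20, 100):
    print(length_px, max_deviation(length_px))
```

<p>The floor of 1 pixel applies to segments up to 6 pixels long, since $length / 10 + 0.4$ only reaches 1 at $length = 6$. Beyond that, the tolerance grows by one pixel for every ten pixels of length.</p>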
<p><strong>Rotation Correction</strong>:</p>
<ul>
<li>The system computes the angle of all &ldquo;long&rdquo; line segments modulo 15 degrees.</li>
<li>It bins these angles; the bin with the highest count (representing &lt; 4 degrees rotation) is treated as the scan skew and corrected.</li>
</ul>
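<p>The binning step can be sketched as follows. This is one possible reading of the description above, not the paper&rsquo;s implementation: the minimum length for a &ldquo;long&rdquo; segment, the 1-degree bin width, and the folding of residuals into a signed range are all assumptions made for illustration.</p>

```python
import math
from collections import Counter

def estimate_skew_degrees(segments, min_length=20.0, bin_width=1.0):
    """Estimate scan skew from (x0, y0, x1, y1) line segments.

    min_length and bin_width are illustrative values, not taken from the paper.
    """
    counts = Counter()
    for x0, y0, x1, y1 in segments:
        if math.hypot(x1 - x0, y1 - y0) >= min_length:
            angle = math.degrees(math.atan2(y1 - y0, x1 - x0))
            residual = angle % 15.0          # offset from a multiple of 15, in [0, 15)
            if residual >= 7.5:
                residual -= 15.0             # fold to a signed offset in [-7.5, 7.5)
            counts[round(residual / bin_width)] += 1
    if not counts:
        return 0.0                           # no long segments: assume no skew
    best_bin, _ = counts.most_common(1)[0]
    return best_bin * bin_width


def segment(angle_deg, length=50.0):
    """Demo helper: a segment from the origin at the given angle."""
    a = math.radians(angle_deg)
    return (0.0, 0.0, length * math.cos(a), length * math.sin(a))

# Bonds drawn at 0, 30, 60 and 90 degrees, scanned with a 2-degree skew,
# plus one unrelated segment at 9 degrees.
demo = [segment(a + 2.0) for a in (0, 30, 60, 90)] + [segment(9.0)]
print(estimate_skew_degrees(demo))
```

<p>Reducing angles modulo 15 degrees makes bonds drawn at different orientations vote for the same bin, so a consistent offset shared by many segments stands out as the skew.</p>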
<p><strong>Optical Character Recognition (OCR)</strong>:</p>
<ul>
<li>Uses a neural network with linked/shared weights (similar to Convolutional Neural Networks, though not named as such) acting as a feature detector.</li>
<li><strong>Training</strong>: Trained on specific chemical fonts.</li>
<li><strong>Inference</strong>: Outputs are ranked; if multiple characters (e.g., &lsquo;5&rsquo; and &lsquo;S&rsquo;) exceed a threshold, both are kept, and chemical context resolves the ambiguity later.</li>
</ul>
<p><strong>Chemical Parsing</strong>:</p>
<ul>
<li>Group formulas (e.g., $COOH$) are parsed left-to-right by subtracting valences.</li>
<li>Example: For $COOH$, the external bond reduces carbon&rsquo;s free valence from 4 to 3. The first oxygen takes 2 (a double bond), leaving 1. The second oxygen takes that last carbon valence and keeps 1 of its own, which the hydrogen then fills (attaching to the oxygen).</li>
</ul>
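<p>The $COOH$ walk-through can be reproduced with a small greedy routine. This is a simplified, hypothetical reading of the valence-subtraction idea and is not the parser described in the paper. The valence table, the function name <code>parse_group</code>, and the rule &ldquo;attach to the earliest atom that still has free valence&rdquo; are assumptions.</p>

```python
# Default valences for a few elements (an assumption for this sketch).
VALENCE = {"C": 4, "N": 3, "O": 2, "H": 1}

def parse_group(symbols, external_bonds=1):
    """Bond the atoms of a group formula greedily, left to right.

    Returns (i, j, order) tuples indexing into symbols.
    """
    free = [VALENCE[s] for s in symbols]
    free[0] -= external_bonds                # bond joining the group to the structure
    bonds = []
    for j in range(1, len(symbols)):
        # Attach to the earliest preceding atom that still has free valence.
        i = next((k for k in range(j) if free[k] != 0), None)
        if i is None:
            raise ValueError("no free valence left for " + symbols[j])
        order = min(free[i], free[j])        # use as much valence as both atoms allow
        free[i] -= order
        free[j] -= order
        bonds.append((i, j, order))
    return bonds

print(parse_group(["C", "O", "O", "H"]))
```

<p>For $COOH$ this yields a double bond from the carbon to the first oxygen, a single bond from the carbon to the second oxygen, and a single bond from that oxygen to the hydrogen, matching the example. A real parser would need more rules than this greedy pass, for instance for charges, ring closures, and repeat counts.</p>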
<h3 id="models">Models</h3>
<ul>
<li><strong>OCR Model</strong>: A neural network with a &ldquo;shared weights&rdquo; paradigm, effectively creating a learned convolution map. It achieves ~99.9% raw accuracy on isolated test sets of chemical fonts.</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: The evaluation was performed on an <strong>80486 processor at 33 MHz</strong>.</li>
<li><strong>Time</strong>: Average processing time was 9 seconds per structure.</li>
</ul>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{mcdanielAutomaticInterpretationChemical1996,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Automatic Interpretation of Chemical Structure Diagrams}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Graphics Recognition. Methods and Applications}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{McDaniel, Joe R. and Balmuth, Jason R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">editor</span> = <span style="color:#e6db74">{O&#39;Gorman, Lawrence and Kasturi, Rangachar}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span> = <span style="color:#e6db74">{Lecture Notes in Computer Science}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{1072}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{148--158}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{1996}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Springer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1007/3-540-61226-2_13}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Party Matters: Enhancing Legislative Vote Embeddings</title><link>https://hunterheidenreich.com/notes/interdisciplinary/social-science/party-matters-hiptm/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/interdisciplinary/social-science/party-matters-hiptm/</guid><description>A method for improving legislative vote prediction across sessions by augmenting bill text embeddings with sponsor metadata.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel neural architecture that modifies how bill embeddings are constructed by explicitly incorporating sponsor metadata alongside text. The authors validate this method by comparing it against text-only baselines (MWE and CNN) and demonstrating superior performance in a newly defined &ldquo;out-of-session&rdquo; evaluation setting.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Existing models for predicting legislative roll-call votes rely heavily on text or voting history within a single session. However, these models fail to generalize across sessions because the underlying data generation process changes. Specifically, the ideological position of bills on similar topics shifts depending on which party is in power. A model trained on a single session learns an implicit ideological prior that becomes inaccurate when the political context changes in subsequent sessions.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is a neural architecture that augments bill text representations with sponsor ideology, specifically the percentages of Republican and Democratic sponsors.</p>
<ul>
<li><strong>Sponsor-Weighted Embeddings</strong>: They compute a composite embedding where the text representation is weighted by party sponsorship percentages ($p_{r}, p_{d}$) and party-specific influence vectors ($a_{r}, a_{d}$).</li>
<li><strong>Out-of-Session Evaluation</strong>: They introduce a rigorous evaluation setting where models trained on past sessions (e.g., 2005-2012) are tested on future sessions (e.g., 2013-2014) to test generalization, which previous work had ignored.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors evaluated their models using a dataset of U.S. Congressional bills from 2005 to 2016.</p>
<ul>
<li><strong>Models Tested</strong>: They compared text-only models (MWE (Mean Word Embedding), CNN) against metadata-augmented versions (MWE+Meta, CNN+Meta) and a &ldquo;Meta-Only&rdquo; baseline (using dummy text).</li>
<li><strong>Settings</strong>:
<ul>
<li><strong>In-Session</strong>: 5-fold cross-validation on 2005-2012 data.</li>
<li><strong>Out-of-Session</strong>: Training on 2005-2012 and testing on 2013-2014 and 2015-2016.</li>
</ul>
</li>
<li><strong>Baselines</strong>: Comparisons included a &ldquo;Guess Yes&rdquo; baseline and an SVM trained on bag-of-words summaries with sponsor indicators.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Metadata is Critical</strong>: Augmenting text with sponsor metadata consistently outperformed text-only models. The <code>CNN+Meta</code> model achieved the highest accuracy in-session (86.21% vs. 83.24% for CNN) and on 2013-2014 out-of-session (83.59%), while <code>MWE+Meta</code> achieved the best 2015-2016 accuracy (71.90%).</li>
<li><strong>Generalization</strong>: Text-only models degraded significantly in out-of-session testing. For example, CNN dropped from 83.24% in-session to 77.49% on 2013-2014 and 69.63% on 2015-2016, confirming that text alone fails to capture shifting ideological contexts.</li>
<li><strong>Sponsor Signal</strong>: The <code>Meta-Only</code> model (using no text) outperformed text-only models in the 2013-2014 out-of-session test (82.28% vs. 77.57% for MWE), suggesting that in some contexts, the sponsors&rsquo; party affiliation provides a stronger predictive signal than the bill&rsquo;s content.</li>
<li><strong>2015-2016 Difficulty</strong>: All models performed worse on the 2015-2016 session, where intra-party divisions within the House Republican caucus disrupted typical voting dynamics.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Source</strong>: Collected from GovTrack. The paper text references the &ldquo;106th to 111th&rdquo; Congressional sessions, but the data tables show coverage from 2005 to 2016, which corresponds to the 109th through 114th sessions.</li>
<li><strong>Content</strong>: Non-unanimous roll call votes, full text of bills/resolutions, and Congressional Research Service (CRS) summaries.</li>
<li><strong>Filtering</strong>: Bills with unanimous votes were excluded.</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Text lowercased and stop-words removed.</li>
<li>Summaries truncated to $N=400$ words; full text truncated to $N=2000$ words (80th percentile lengths).</li>
</ul>
</li>
<li><strong>Splits</strong>:
<ul>
<li><em>Training</em>: Sessions 2005-2012 (1718 bills).</li>
<li><em>Testing</em>: Sessions 2013-2014 (360 bills) and 2015-2016 (382 bills).</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Bill Representation ($v_{B}$)</strong>:
$$v_{B}=((a_{r}p_{r})\cdot T_{r})+((a_{d}p_{d})\cdot T_{d})$$
Where $T$ is the text embedding (CNN or MWE), $p$ is the percentage of sponsors from a party, and $a$ is a learnable party influence vector. $T_{r}$ and $T_{d}$ are Republican and Democratic copies of the same bill&rsquo;s text representation, each weighted by the corresponding party&rsquo;s sponsorship proportion.</li>
<li><strong>Vote Prediction</strong>:
<ul>
<li>Project bill embedding to legislator space: $v_{BL}=W_{B}v_{B}+b_{B}$.</li>
<li>Alignment score: $W_{v}(v_{BL}\odot v_{L})+b_{v}$ (using element-wise multiplication).</li>
<li>Output: Sigmoid activation.</li>
</ul>
</li>
<li><strong>Optimization</strong>: AdaMax algorithm with binary cross-entropy loss.</li>
</ul>
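<p>The forward pass above can be sketched in plain Python. This is a minimal illustration, not the authors&rsquo; code: the function names are hypothetical, and the product between the influence vector and the text embedding is assumed to be element-wise so that $v_{B}$ remains a vector.</p>

```python
import math

def bill_vector(text_vec, p_r, p_d, a_r, a_d):
    """v_B = (a_r * p_r) * T_r + (a_d * p_d) * T_d, with T_r = T_d = text_vec.

    The products are taken element-wise (an assumption of this sketch).
    """
    return [(ar * p_r) * t + (ad * p_d) * t for ar, ad, t in zip(a_r, a_d, text_vec)]

def affine(weights, bias, vec):
    """Return W v + b for a matrix W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weights, bias)]

def vote_probability(v_b, v_l, w_b, b_b, w_v, b_v):
    """Probability of a yes vote for legislator embedding v_l on bill embedding v_b."""
    v_bl = affine(w_b, b_b, v_b)                      # project into legislator space
    interaction = [x * y for x, y in zip(v_bl, v_l)]  # element-wise product
    score = sum(w * z for w, z in zip(w_v, interaction)) + b_v
    return 1.0 / (1.0 + math.exp(-score))             # sigmoid

# Toy 2-dimensional example with 75% Republican and 25% Democratic sponsors.
v_b = bill_vector([1.0, 2.0], p_r=0.75, p_d=0.25, a_r=[1.0, 1.0], a_d=[1.0, 1.0])
identity = [[1.0, 0.0], [0.0, 1.0]]
prob = vote_probability(v_b, v_l=[0.5, -0.25], w_b=identity, b_b=[0.0, 0.0],
                        w_v=[1.0, 1.0], b_v=0.0)
print(v_b, prob)
```

<p>In the toy example both influence vectors are all ones, so the sponsorship weights sum to 1 and $v_{B}$ equals the text embedding. During training the two vectors can diverge, which lets the same text map to different embeddings depending on who sponsors the bill.</p>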
<h3 id="models">Models</h3>
<ul>
<li><strong>Text Encoders</strong>:
<ul>
<li><strong>CNN</strong>: 4-grams with 400 filter maps.</li>
<li><strong>MWE</strong>: <a href="/posts/intro-to-word-embeddings/">Mean Word Embedding</a>.</li>
</ul>
</li>
<li><strong>Embeddings</strong>:
<ul>
<li>Initialized with 50-dimensional GloVe vectors.</li>
<li>Embeddings are non-static (updated during training).</li>
<li>Legislator embedding size ($v_{L}$): 25 dimensions.</li>
</ul>
</li>
<li><strong>Initialization</strong>: Weights initialized with Glorot uniform distribution.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Accuracy.</li>
<li><strong>Comparison</strong>:
<ul>
<li><strong>In-session</strong>: 5-fold cross-validation.</li>
<li><strong>Out-of-session</strong>: Train on past sessions, predict future sessions.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training Config</strong>: Models trained for 50 epochs with mini-batches of size 50. No specific GPU or compute requirements are reported.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://www.govtrack.us/">GovTrack</a></td>
          <td>Dataset</td>
          <td>Public</td>
          <td>Source for bill texts and roll-call votes</td>
      </tr>
  </tbody>
</table>
<p>No official code repository or pretrained models were released with this paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Kornilova, A., Argyle, D., &amp; Eidelman, V. (2018). Party Matters: Enhancing Legislative Embeddings with Author Attributes for Vote Prediction. <em>Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)</em>, 510-515. <a href="https://doi.org/10.18653/v1/p18-2081">https://doi.org/10.18653/v1/p18-2081</a></p>
<p><strong>Publication</strong>: ACL 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{kornilovaPartyMattersEnhancing2018,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Party {{Matters}}: {{Enhancing Legislative Embeddings}} with {{Author Attributes}} for {{Vote Prediction}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Party {{Matters}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Kornilova, Anastassia and Argyle, Daniel and Eidelman, Vlad}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 56th {{Annual Meeting}} of the {{Association}} for {{Computational Linguistics}} ({{Volume}} 2: {{Short Papers}})}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{510--515}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computational Linguistics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{Melbourne, Australia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.18653/v1/p18-2081}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{1805.08182}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Mixture Density Networks: Modeling Multimodal Distributions</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/mixture-density-networks/</guid><description>A 1994 technical report introducing Mixture Density Networks (MDNs) to model arbitrary conditional probability distributions using neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper.</p>
<p>It identifies a specific failure mode in existing neural network methodologies (least-squares regression on multi-valued inverse problems) and proposes a novel architecture (combining MLPs with Mixture Models) to solve it. It derives the mathematical framework for training this architecture via standard back-propagation and validates it against the established baseline.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Standard neural networks trained with sum-of-squares (MSE) or cross-entropy error functions approximate the <strong>conditional average</strong> of the target data, $\langle t|x \rangle$.</p>
<p>While optimal for single-valued functions or classification, this produces completely erroneous results for <strong>inverse problems</strong> where the mapping is multi-valued (one input has multiple valid outputs). For example, in robot inverse kinematics, &ldquo;elbow-up&rdquo; and &ldquo;elbow-down&rdquo; configurations can achieve the same hand position. An MSE-trained network will average these two valid angles, resulting in an invalid configuration (the paper shows this produces end-effector positions at the outer boundary of the accessible region, corresponding to $\theta_2 = \pi$).</p>
<figure class="post-figure center ">
    <img src="/img/notes/single-gaussian-mse-prediction.webp"
         alt="Single Gaussian MSE prediction averaging multimodal distribution"
         title="Single Gaussian MSE prediction averaging multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MSE-trained networks predict the mean, which averages across modes and produces invalid outputs for inverse problems.</figcaption>
    
</figure>

<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The introduction of the <strong>Mixture Density Network (MDN)</strong>.</p>
<p>The neural network predicts the <strong>parameters</strong> (mixing coefficients, means, and variances) of a kernel mixture distribution (typically Gaussian).</p>
<figure class="post-figure center ">
    <img src="/img/notes/gaussian-mixture-mdn-prediction.webp"
         alt="Gaussian mixture model prediction capturing multimodal distribution"
         title="Gaussian mixture model prediction capturing multimodal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MDNs predict mixture parameters to capture the full conditional probability density, representing all modes.</figcaption>
    
</figure>

<p>Key technical contributions include:</p>
<ol>
<li><strong>Architecture</strong>: Mapping network outputs to mixture parameters using specific activation functions to satisfy constraints (Softmax for the mixing coefficients $\alpha$, Exponential for the width parameters $\sigma$).</li>
<li><strong>Training</strong>: Deriving the error function as the negative log-likelihood of the mixture model.</li>
<li><strong>Optimization</strong>: Deriving the exact derivatives (gradients) of the error with respect to network outputs, allowing the mixture model parameters to be learned via standard back-propagation.</li>
</ol>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>Bishop validated the method on two tasks, comparing an MDN against a standard MLP trained with least-squares:</p>
<ol>
<li><strong>Toy Inverse Problem</strong>: A sinusoidal mapping $x = t + 0.3\sin(2\pi t) + \epsilon$. The forward problem ($t \to x$) is single-valued, but the inverse ($x \to t$) is multi-valued.</li>
<li><strong>Robot Kinematics</strong>: A 2-link robot arm simulation. The task is to map end-effector Cartesian coordinates $(x_1, x_2)$ back to joint angles $(\theta_1, \theta_2)$.</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Toy Problem</strong>: The standard least-squares network failed completely, drawing a smooth curve through the average of the multiple branches, which did not correspond to valid data. The MDN correctly modeled the tri-modal density and discontinuous jumps in the most probable solution.</li>
<li><strong>Robot Kinematics</strong>: The MDN reduced the RMS positioning error by an order of magnitude compared to the standard network (0.0053 vs 0.0578).</li>
<li><strong>Generality</strong>: The paper concludes that MDNs provide a complete description of the conditional probability density, allowing users to calculate any statistic (mean, mode, variance) needed for the application.</li>
</ul>
<h2 id="extracting-predictions">Extracting Predictions</h2>
<p>Once trained, the MDN outputs a full conditional density $p(t|x)$, from which several useful statistics can be derived:</p>
<ul>
<li><strong>Conditional mean</strong>: $\langle t|x \rangle = \sum_i \alpha_i(x) \mu_i(x)$, equivalent to the standard least-squares network output.</li>
<li><strong>Conditional variance</strong>: $s^2(x) = \sum_i \alpha_i(x) \left\{ \sigma_i(x)^2 + \left\| \mu_i(x) - \sum_j \alpha_j(x) \mu_j(x) \right\|^2 \right\}$, which is input-dependent (more general than the constant-variance least-squares assumption).</li>
<li><strong>Most probable branch</strong>: Select the kernel $i$ with the largest mixing coefficient $\alpha_i(x)$, then use its center $\mu_i$ as the prediction. This yields a discontinuous but accurate mapping for multi-valued problems.</li>
</ul>
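<p>The three statistics above can be computed directly from the mixture parameters at a given input. The sketch below uses plain Python lists; the function names are illustrative, and the parameter values in the example are made up rather than taken from the paper.</p>

```python
def conditional_mean(alphas, mus):
    """Mixture mean: sum_i alpha_i * mu_i, computed per target dimension."""
    dim = len(mus[0])
    return [sum(a * mu[k] for a, mu in zip(alphas, mus)) for k in range(dim)]

def conditional_variance(alphas, mus, sigmas):
    """s^2 = sum_i alpha_i * (sigma_i^2 + squared distance of mu_i from the mean)."""
    mean = conditional_mean(alphas, mus)
    total = 0.0
    for a, mu, s in zip(alphas, mus, sigmas):
        sq_dist = sum((m - c) ** 2 for m, c in zip(mu, mean))
        total += a * (s ** 2 + sq_dist)
    return total

def most_probable_branch(alphas, mus):
    """Center of the kernel with the largest mixing coefficient."""
    best = max(range(len(alphas)), key=lambda i: alphas[i])
    return mus[best]

# Two well-separated branches of a one-dimensional target (made-up values).
alphas = [0.6, 0.4]
mus = [[0.2], [0.8]]
sigmas = [0.05, 0.05]
print(conditional_mean(alphas, mus))
print(conditional_variance(alphas, mus, sigmas))
print(most_probable_branch(alphas, mus))
```

<p>With these values the conditional mean lands at 0.44, between the two branches where little data would lie, while the most probable branch returns 0.2, the center of the dominant kernel.</p>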
<h2 id="limitations">Limitations</h2>
<ul>
<li><strong>Model order selection</strong>: The number of mixture components $m$ must be chosen in advance. The paper acknowledges this as an open problem and suggests cross-validation or Bayesian model comparison as potential approaches.</li>
<li><strong>Computational overhead</strong>: The number of network outputs grows as $(c + 2) \times m$, where $c$ is the target dimensionality. For high-dimensional targets or many kernels, this can become significant.</li>
<li><strong>Isotropic kernels</strong>: The paper uses a single variance parameter $\sigma_i$ per kernel (shared across target dimensions), which assumes isotropic covariance. The paper notes this can be generalized to full covariance matrices at the cost of additional parameters.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>1. Toy Inverse Problem</strong></p>
<ul>
<li><strong>Function</strong>: $x = t + 0.3\sin(2\pi t) + \epsilon$</li>
<li><strong>Noise</strong>: $\epsilon \sim U(-0.1, 0.1)$</li>
<li><strong>Sampling</strong>: 1,000 points generated by sampling $t$ at equal intervals in range $(0, 1)$.</li>
<li><strong>Task</strong>: Inverse mapping (predict $t$ given $x$).</li>
</ul>
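<p>A dataset of this form can be regenerated in a few lines. This is a sketch under stated assumptions: the exact placement of the equally spaced $t$ values inside $(0, 1)$ and the random seed are choices made here, not details from the paper.</p>

```python
import math
import random

def make_toy_data(n=1000, seed=0):
    """Return (xs, ts) for x = t + 0.3 sin(2 pi t) + eps, eps ~ U(-0.1, 0.1)."""
    rng = random.Random(seed)
    # Equally spaced t values strictly inside (0, 1); the half-step offset is an assumption.
    ts = [(i + 0.5) / n for i in range(n)]
    xs = [t + 0.3 * math.sin(2.0 * math.pi * t) + rng.uniform(-0.1, 0.1) for t in ts]
    return xs, ts

xs, ts = make_toy_data()
print(len(xs), min(xs), max(xs))
```

<p>For the inverse problem, <code>xs</code> are the network inputs and <code>ts</code> are the targets, so some values of $x$ correspond to several distinct values of $t$.</p>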
<p><strong>2. Robot Kinematics</strong></p>
<ul>
<li><strong>System</strong>: 2-link arm with lengths $L_1=0.8, L_2=0.2$.</li>
<li><strong>Forward Kinematics</strong>:
<ul>
<li>$x_1 = L_1 \cos(\theta_1) - L_2 \cos(\theta_1 + \theta_2)$</li>
<li>$x_2 = L_1 \sin(\theta_1) - L_2 \sin(\theta_1 + \theta_2)$</li>
</ul>
</li>
<li><strong>Constraints</strong>: $\theta_1 \in (0.3, 1.2)$, $\theta_2 \in (\pi/2, 3\pi/2)$.</li>
<li><strong>Dataset</strong>: 1,000 training points, 1,000 test points.</li>
</ul>
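<p>The multi-valued nature of the inverse can be checked numerically. The sketch below implements the forward kinematics as written above and constructs the mirrored joint configuration that reaches the same point. The helper <code>mirror_solution</code> is derived here from the geometry of a two-link arm and is not taken from the paper.</p>

```python
import math

L1, L2 = 0.8, 0.2

def forward(theta1, theta2):
    """End-effector position (x1, x2) for joint angles (theta1, theta2)."""
    x1 = L1 * math.cos(theta1) - L2 * math.cos(theta1 + theta2)
    x2 = L1 * math.sin(theta1) - L2 * math.sin(theta1 + theta2)
    return x1, x2

def mirror_solution(theta1, theta2):
    """Second joint configuration that reaches the same end-effector position."""
    phi = theta2 - math.pi                   # elbow bend relative to a straight arm
    beta = math.atan2(L2 * math.sin(phi), L1 + L2 * math.cos(phi))
    return theta1 + 2.0 * beta, 2.0 * math.pi - theta2

theta1, theta2 = 0.8, 2.0
alt1, alt2 = mirror_solution(theta1, theta2)
print(forward(theta1, theta2))
print(forward(alt1, alt2))
```

<p>The two values of $\theta_2$ are $\theta_2$ and $2\pi - \theta_2$, whose average is exactly $\pi$. This is why a least-squares network that averages the two branches pushes its predictions toward the fully extended arm. Note that the mirrored $\theta_1$ can fall outside $(0.3, 1.2)$ for some inputs, in which case only one configuration is valid under the stated constraints.</p>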
<h3 id="algorithms">Algorithms</h3>
<p><strong>Mixture Model Definition</strong></p>
<p>The conditional density is defined as:</p>
<p>$$p(t|x) = \sum_{i=1}^{m} \alpha_i(x) \phi_i(t|x)$$</p>
<p>Where the kernels $\phi_i$ are Gaussians with centers $\mu_i(x)$ and standard deviations $\sigma_i(x)$ (so the variances are $\sigma_i(x)^2$).</p>
<p><strong>Network Output Mappings</strong></p>
<p>If the network produces raw outputs $z$, they are mapped to parameters as follows to satisfy probability constraints:</p>
<ul>
<li><strong>Mixing Coefficients ($\alpha$)</strong>: Softmax. $\alpha_i = \frac{\exp(z_i^\alpha)}{\sum_j \exp(z_j^\alpha)}$</li>
<li><strong>Widths ($\sigma$)</strong>: Exponential, which keeps them positive. $\sigma_i = \exp(z_i^\sigma)$</li>
<li><strong>Means ($\mu$)</strong>: Linear/Identity. $\mu_{ik} = z_{ik}^\mu$</li>
</ul>
<p><strong>Loss Function</strong></p>
<p>Negative Log Likelihood:</p>
<p>$$E^q = - \ln \left\{ \sum_{i=1}^{m} \alpha_i(x^q) \phi_i(t^q|x^q) \right\}$$</p>
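<p>The output mappings and the per-pattern loss can be sketched in plain Python. This is a minimal illustration rather than a training-ready implementation. It assumes an isotropic Gaussian kernel in $c$ dimensions, $\phi_i(t|x) = (2\pi)^{-c/2} \sigma_i^{-c} \exp(-\|t - \mu_i\|^2 / (2\sigma_i^2))$. The function names are hypothetical, and subtracting the maximum inside the softmax is a standard numerical-stability step added here.</p>

```python
import math

def mdn_parameters(z_alpha, z_sigma, z_mu):
    """Map raw network outputs to (alphas, sigmas, mus)."""
    shift = max(z_alpha)                     # stabilizes the softmax; does not change it
    exps = [math.exp(z - shift) for z in z_alpha]
    total = sum(exps)
    alphas = [e / total for e in exps]       # softmax: positive, sums to 1
    sigmas = [math.exp(z) for z in z_sigma]  # exponential: always positive
    return alphas, sigmas, z_mu              # means are the raw outputs

def gaussian_kernel(t, mu, sigma):
    """Isotropic Gaussian density of target vector t around center mu."""
    c = len(t)
    sq_dist = sum((a - b) ** 2 for a, b in zip(t, mu))
    norm = (2.0 * math.pi) ** (c / 2.0) * sigma ** c
    return math.exp(-sq_dist / (2.0 * sigma ** 2)) / norm

def negative_log_likelihood(t, alphas, sigmas, mus):
    """Error for one pattern: minus the log of the mixture density at t."""
    density = sum(a * gaussian_kernel(t, mu, s) for a, mu, s in zip(alphas, mus, sigmas))
    return -math.log(density)

alphas, sigmas, mus = mdn_parameters([0.0, 0.0], [0.0, 0.0], [[0.0], [0.0]])
print(alphas, sigmas)
print(negative_log_likelihood([0.0], alphas, sigmas, mus))
```

<p>In the example both kernels are unit Gaussians centered on the target, so the loss reduces to $\frac{1}{2}\ln(2\pi)$. For training, the total error is the sum of $E^q$ over all patterns, and a log-sum-exp formulation would be safer when kernel densities underflow.</p>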
<h3 id="models">Models</h3>
<p><strong>1. Toy Problem Configuration</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 1 input ($x$), 1 hidden layer.</li>
<li><strong>Hidden Units</strong>: 20 units (tanh activation).</li>
<li><strong>Outputs</strong>: 9 units.
<ul>
<li>$m=3$ Gaussian kernels.</li>
<li>Parameters per kernel: 1 $\alpha$, 1 $\sigma$, 1 $\mu$. Total = $3 \times 3 = 9$.</li>
</ul>
</li>
<li><strong>Training</strong>: 1,000 cycles of BFGS.</li>
</ul>
<p><strong>2. Robot Kinematics Configuration (Least-Squares Baseline)</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 2 inputs ($x_1, x_2$), 2 linear outputs ($\theta_1, \theta_2$).</li>
<li><strong>Hidden Units</strong>: Best result with 20 units (tanh activation), tested with 5, 10, 15, 20, 25, 30.</li>
<li><strong>Training</strong>: 3,000 cycles of BFGS.</li>
</ul>
<p><strong>3. Robot Kinematics Configuration (MDN)</strong></p>
<ul>
<li><strong>Structure</strong>: MLP with 2 inputs ($x_1, x_2$).</li>
<li><strong>Hidden Units</strong>: 10 units (tanh activation).</li>
<li><strong>Outputs</strong>: 8 units.
<ul>
<li>$m=2$ Gaussian kernels.</li>
<li>Target dimension $c=2$ (predicting $\theta_1, \theta_2$).</li>
<li>Parameters per kernel: 1 $\alpha$ + 1 $\sigma$ (common variance) + 2 $\mu$ (means for $\theta_1, \theta_2$).</li>
<li>Total = $2 \times (1 + 1 + 2) = 8$.</li>
</ul>
</li>
</ul>
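<p>The output-unit counts in the two MDN configurations follow one rule: each kernel needs one mixing coefficient, one shared width, and one mean per target dimension. A small sketch (the function name is hypothetical):</p>

```python
def mdn_output_units(num_kernels, target_dim):
    """Outputs = kernels x (1 alpha + 1 sigma + target_dim means)."""
    return num_kernels * (1 + 1 + target_dim)


toy_outputs = mdn_output_units(num_kernels=3, target_dim=1)    # 9
robot_outputs = mdn_output_units(num_kernels=2, target_dim=2)  # 8
```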
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metric</strong>: RMS Euclidean distance between the desired end-effector position and the achieved position (calculated by plugging predicted angles back into forward kinematics).</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Hidden Units</th>
          <th>Kernels</th>
          <th>RMS Error</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Least Squares</td>
          <td>20</td>
          <td>N/A</td>
          <td>0.0578</td>
      </tr>
      <tr>
          <td>MDN</td>
          <td>10</td>
          <td>2</td>
          <td>0.0053</td>
      </tr>
  </tbody>
</table>
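<p>The metric can be sketched as below: predicted angles are pushed back through the forward kinematics and compared to the desired positions. This is an illustrative sketch; the function names are hypothetical, the link lengths are supplied by the caller, and how a single angle pair is chosen from an MDN&rsquo;s predicted distribution is left outside the sketch.</p>

```python
import math


def forward_kinematics(theta1, theta2, l1, l2):
    """Map joint angles to the end-effector position (x1, x2)."""
    x1 = l1 * math.cos(theta1) - l2 * math.cos(theta1 + theta2)
    x2 = l1 * math.sin(theta1) - l2 * math.sin(theta1 + theta2)
    return x1, x2


def rms_position_error(targets, predicted_angles, l1, l2):
    """RMS Euclidean distance between desired and achieved positions."""
    squared = []
    for (x1, x2), (th1, th2) in zip(targets, predicted_angles):
        a1, a2 = forward_kinematics(th1, th2, l1, l2)
        squared.append((x1 - a1) ** 2 + (x2 - a2) ** 2)
    return math.sqrt(sum(squared) / len(squared))
```

<p>Measuring error in position space rather than angle space matters here: any angle pair that reaches the target counts as correct, even if it lies on a different branch of the inverse mapping than the training label.</p>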
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bishop, C. M. (1994). Mixture Density Networks. <em>Neural Computing Research Group Report: NCRG/94/004</em>, Aston University.</p>
<p><strong>Publication</strong>: Neural Computing Research Group Technical Report 1994</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{bishopMixtureDensityNetworks1994,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Mixture {{Density Networks}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Bishop, Christopher M.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1994</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = feb,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{NCRG/94/004}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span> = <span style="color:#e6db74">{Aston University}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Kekulé: OCR-Optical Chemical Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/kekule-1992/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/kekule-1992/</guid><description>A seminal 1992 system for Optical Chemical Structure Recognition (OCSR) using neural networks and heuristic graph compilation.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: McDaniel, J. R., &amp; Balmuth, J. R. (1992). Kekulé: OCR-Optical Chemical (Structure) Recognition. <em>Journal of Chemical Information and Computer Sciences</em>, 32(4), 373-378. <a href="https://doi.org/10.1021/ci00008a018">https://doi.org/10.1021/ci00008a018</a></p>
<p><strong>Publication</strong>: Journal of Chemical Information and Computer Sciences, 1992</p>
<h2 id="system-architecture-and-methodological-approach">System Architecture and Methodological Approach</h2>
<p>This is a <strong>Methodological Paper</strong> ($\Psi_{\text{Method}}$). It proposes a novel software architecture (&ldquo;Kekulé&rdquo;) designed to solve a specific technical problem: the automatic conversion of printed chemical structure diagrams into computer-readable connection tables. The paper focuses on the &ldquo;how&rdquo; of the system by detailing the seven-step pipeline from scanning to graph compilation, validating the method through performance testing on a specific dataset.</p>
<h2 id="motivation-bridging-visual-diagrams-and-connection-tables">Motivation: Bridging Visual Diagrams and Connection Tables</h2>
<p>The primary motivation is to bridge the gap between how chemists communicate (structural diagrams) and how chemical databases store information (connection tables like MOLfiles).</p>
<ul>
<li><strong>Inefficiency of Manual Entry</strong>: Manual compilation of structural descriptions is &ldquo;tedious and highly prone to error&rdquo;.</li>
<li><strong>Redrawing Costs</strong>: Even using drawing programs (predecessors of tools like ChemDraw) to capture connectivity is inefficient; redrawing a complex molecule like vitamin $B_{12}$ takes ~20 minutes.</li>
<li><strong>Lack of Existing Solutions</strong>: Existing OCR systems at the time failed on chemical diagrams because they could not handle the mix of graphics (bonds) and text (atom labels), and struggled with small, mixed fonts.</li>
</ul>
<h2 id="novelty-a-hybrid-ocr-and-heuristic-approach">Novelty: A Hybrid OCR and Heuristic Approach</h2>
<p>Kekulé represents the first successful attempt to integrate all of the required elements of image processing, OCR, structure editing, and database communication into a complete system.</p>
<ul>
<li><strong>Hybrid OCR Approach</strong>: Unlike commercial OCR of the time, it used a custom implementation combining rotation correction (for skew) with a <strong>multilayer perceptron neural network</strong> trained specifically on small fonts (down to 3.2 points).</li>
<li><strong>Heuristic Feature Extraction</strong>: The authors developed specific heuristics to handle chemical artifacts, such as an exhaustive search for dashed lines, explicitly rejecting Hough transforms as unreliable for short segments.</li>
<li><strong>Contextual &ldquo;Spell Checking&rdquo;</strong>: The system uses chemical context to verify OCR results, such as checking atom symbols against a valid list and using bond connections to disambiguate characters.</li>
</ul>
<h2 id="experimental-setup-and-dataset-validation">Experimental Setup and Dataset Validation</h2>
<p>The authors performed a validation study on a diverse set of chemical structures to stress-test the system:</p>
<ul>
<li><strong>Dataset</strong>: 444 chemical structures were selected from a wide variety of sources, including the <em>Merck Index</em>, <em>Aldrich Handbook</em>, and <em>ACS Nomenclature Guide</em>, specifically chosen to &ldquo;test Kekulé&rsquo;s limits&rdquo;.</li>
<li><strong>Metrics</strong>:
<ul>
<li><strong>Processing Success</strong>: Percentage of structures processed.</li>
<li><strong>User Intervention</strong>: Average number of prompts per structure for verification.</li>
<li><strong>Editing Time</strong>: Time required to correct interpretation errors (arbitrary &ldquo;good&rdquo; limit set at 30 seconds).</li>
</ul>
</li>
</ul>
<h2 id="results-and-system-performance">Results and System Performance</h2>
<ul>
<li><strong>High Success Rate</strong>: 98.9% of the 444 structures were processed successfully.</li>
<li><strong>Performance Speed</strong>: The average processing time was 9 seconds per structure on an 80486 (33 MHz) processor.</li>
<li><strong>Error Modes</strong>: The primary bottleneck was broken characters in scanned images (e.g., breaks in &lsquo;H&rsquo; or &lsquo;N&rsquo; crossbars), which slowed down the OCR significantly.</li>
<li><strong>Impact</strong>: The system demonstrated that automated interpretation was faster and less error-prone than manual redrawing.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The following details outline the specific technical implementation described in the 1992 paper.</p>
<h3 id="data">Data</h3>
<p>The authors did not release a public dataset but described their test set sources in detail.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Evaluation</td>
          <td>Mixed Chemical Sources</td>
          <td>444 structures</td>
          <td>Sourced from <em>Merck Index</em>, <em>Aldrich Handbook</em>, <em>ACS Nomenclature Guide</em>, etc.</td>
      </tr>
      <tr>
          <td>Training (OCR)</td>
          <td>Font Exemplars</td>
          <td>Unknown</td>
          <td>&ldquo;Exemplars of characters from numerous serif and sanserif fonts&rdquo;.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The paper details a 7-step pipeline. Key algorithmic choices include:</p>
<ul>
<li>
<p><strong>Vectorization</strong>:</p>
<ul>
<li>Images are reduced to 1-pixel width using <strong>thinning</strong> and <strong>raster-to-vector translation</strong>.</li>
<li>An <strong>adaptive smoothing algorithm</strong> is applied to remove pixel-level jitter.</li>
</ul>
</li>
<li>
<p><strong>Feature Extraction (Dashed Lines)</strong>:</p>
<ul>
<li><strong>Hough Transforms</strong> were rejected due to poor performance on short line segments.</li>
<li><strong>Slope sorting</strong> was rejected due to variance in short dashes.</li>
<li><strong>Chosen Method</strong>: Exhaustive search/testing of all features that <em>might</em> be dashed lines (subset of features).</li>
</ul>
</li>
<li>
<p><strong>Graph Compilation</strong>:</p>
<ul>
<li><strong>Character Grouping</strong>: Characters are assembled into strings based on XY adjacency.</li>
<li><strong>Node Creation</strong>: Character strings become nodes. Vectors with endpoints &ldquo;too far&rdquo; from strings create new nodes.</li>
<li><strong>Heuristics</strong>: Circles are converted to alternating single-double bonds; &ldquo;thick&rdquo; bonds between wedges are automatically generated.</li>
</ul>
</li>
</ul>
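<p>The character-grouping step can be illustrated with a small sketch. The paper only states that grouping is based on XY adjacency, so everything concrete here is an assumption: the tuple layout, the thresholds <code>max_gap</code> and <code>max_dy</code>, and the greedy left-to-right strategy are hypothetical choices for illustration.</p>

```python
def group_characters(chars, max_gap=2.0, max_dy=2.0):
    """Group recognized characters into strings by XY adjacency.

    chars: list of (symbol, x_left, x_right, y_center) tuples.
    Thresholds are hypothetical and expressed in pixels.
    """
    groups = []
    # Process characters left to right so strings grow in reading order.
    for symbol, x_left, x_right, y in sorted(chars, key=lambda c: c[1]):
        for group in groups:
            same_line = max_dy >= abs(group["y"] - y)
            adjacent = max_gap >= abs(x_left - group["x_right"])
            if same_line and adjacent:
                group["text"] += symbol
                group["x_right"] = x_right
                break
        else:
            # No existing string is close enough: start a new one.
            groups.append({"text": symbol, "x_right": x_right, "y": y})
    return [group["text"] for group in groups]


labels = group_characters([
    ("C", 0, 5, 10), ("H", 6, 11, 10), ("3", 12, 15, 11),
    ("O", 40, 45, 10), ("N", 0, 5, 50),
])
```

<p>In this example the three neighboring glyphs merge into one string, while the distant &ldquo;O&rdquo; and the &ldquo;N&rdquo; on another line stay separate; each resulting string would then become a graph node.</p>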
<h3 id="models">Models</h3>
<p>The core machine learning component is the OCR engine.</p>
<ul>
<li><strong>Architecture</strong>: A <strong>multilayer perceptron neural network</strong> (fully connected).</li>
<li><strong>Input</strong>: Normalized characters. Normalization involves rotation (for skew), scaling, under-sampling, and contrast/density adjustments.</li>
<li><strong>Output</strong>: Ranked probability matches. Outputs above an experimental threshold are retained. If a character is ambiguous (e.g., &lsquo;5&rsquo; vs &lsquo;S&rsquo;), both are kept and resolved via chemical context.</li>
<li><strong>Performance</strong>: Raw accuracy ~96% on small fonts (compared to ~85% for commercial OCR of the era).</li>
</ul>
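<p>The thresholding and context-based disambiguation can be sketched as follows. This is a hypothetical illustration: the symbol list is a small assumed subset, the scores are made up, and summing scores to rank combinations is a choice made for the sketch, not the paper&rsquo;s stated rule. It only covers element symbols, not full labels with subscripts.</p>

```python
from itertools import product

# Hypothetical subset of valid atom symbols, for illustration only.
VALID_ATOM_SYMBOLS = {"C", "H", "N", "O", "S", "P", "F", "Cl", "Br", "I"}


def retain_candidates(scores, threshold):
    """Keep characters scoring above the threshold, ranked best first."""
    kept = [(ch, s) for ch, s in scores.items() if s > threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


def resolve_label(candidates_per_position, valid=VALID_ATOM_SYMBOLS):
    """Pick the best-scoring character combination that is a valid symbol."""
    best = None
    for combo in product(*candidates_per_position):
        text = "".join(ch for ch, _ in combo)
        score = sum(s for _, s in combo)
        if text in valid and (best is None or score > best[1]):
            best = (text, score)
    return best[0] if best else None


# The OCR slightly prefers "5", but only "S" is a valid atom symbol.
ambiguous = retain_candidates({"5": 0.62, "S": 0.58, "B": 0.10}, threshold=0.5)
resolved = resolve_label([ambiguous])
```

<p>Returning <code>None</code> when no combination is valid leaves room for the user prompt that the system falls back on when automatic interpretation is uncertain.</p>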
<h3 id="hardware">Hardware</h3>
<p>The system was developed and tested on hardware typical of the early 1990s.</p>
<ul>
<li><strong>Processor</strong>: Intel 80486 at 33 MHz.</li>
<li><strong>Scanners</strong>: Hewlett-Packard ScanJet (300 dpi) and Logitech ScanMan (400 dpi hand-held).</li>
<li><strong>Platform</strong>: Microsoft Windows.</li>
</ul>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{mcdanielKekuleOCRopticalChemical1992,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Kekulé: {{OCR-optical}} Chemical (Structure) Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Kekulé}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{McDaniel, Joe R. and Balmuth, Jason R.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">1992</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Computer Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{32}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{373--378}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{0095-2338, 1520-5142}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/ci00008a018}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>IMG2SMI: Translating Molecular Structure Images to SMILES</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2smi/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/img2smi/</guid><description>Campos &amp; Ji's method for converting 2D molecular images to SMILES strings using Transformers and SELFIES representation.</description><content:encoded><![CDATA[<h2 id="contributions--taxonomy">Contributions &amp; Taxonomy</h2>
<p>This is both a <strong>Method</strong> and <strong>Resource</strong> paper:</p>
<ul>
<li><strong>Method</strong>: It adapts standard image captioning architectures (encoder-decoder) to the domain of Optical Chemical Structure Recognition (OCSR), treating molecule recognition as a translation task.</li>
<li><strong>Resource</strong>: It introduces <strong>MOLCAP</strong>, a large-scale dataset of 81 million molecules aggregated from public chemical databases, addressing the data scarcity that previously hindered deep learning approaches to OCSR.</li>
</ul>
<h2 id="the-bottleneck-in-chemical-literature-translation">The Bottleneck in Chemical Literature Translation</h2>
<p>Chemical literature is &ldquo;full of recipes written in a language computers cannot understand&rdquo; because molecules are depicted as 2D images. This creates a fundamental bottleneck:</p>
<ul>
<li><strong>The Problem</strong>: Chemists must manually redraw molecular structures to search for related compounds or reactions. This is slow, error-prone, and makes large-scale literature mining impossible.</li>
<li><strong>Existing Tools</strong>: Legacy systems like OSRA (Optical Structure Recognition Application) rely on handcrafted rules and often require human correction, making them unfit for unsupervised, high-throughput processing.</li>
<li><strong>The Goal</strong>: An automated system that can translate structure images directly to machine-readable strings (SMILES/SELFIES) without human supervision, enabling large-scale knowledge extraction from decades of chemistry literature and patents.</li>
</ul>
<h2 id="core-innovation-selfies-and-image-captioning">Core Innovation: SELFIES and Image Captioning</h2>
<p>The core novelty is demonstrating that <strong>how you represent the output text is as important as the model architecture itself</strong>. Key contributions:</p>
<ol>
<li>
<p><strong>Image Captioning Framework</strong>: Applies modern encoder-decoder architectures (ResNet-101 + Transformer) to OCSR, treating it as an image-to-text translation problem with a standard cross-entropy loss objective over the generation sequence:
$$ \mathcal{L} = -\sum\limits_{t=1}^{T} \log P(y_t \mid y_1, \ldots, y_{t-1}, x) $$</p>
</li>
<li>
<p><strong>SELFIES as Target Representation</strong>: The key mechanism relies on using <strong>SELFIES</strong> (Self-Referencing Embedded Strings) as the output format. SELFIES is based on a formal grammar where every possible string corresponds to a valid molecule, eliminating the syntactic invalidity problems (unmatched parentheses, invalid characters) that plague SMILES generation.</p>
</li>
<li>
<p><strong>MOLCAP Dataset</strong>: Created a comprehensive dataset of 81 million unique molecules from PubChem, ChEMBL, <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a>, and other sources. Generated 256x256 pixel images using RDKit for 1 million training samples and 5,000 validation samples.</p>
</li>
<li>
<p><strong>Task-Specific Evaluation</strong>: Demonstrated that traditional NLP metrics (BLEU) are poor indicators of scientific utility. Introduced evaluation based on <strong>molecular fingerprints</strong> (MACCS, RDK, Morgan) and <strong>Tanimoto similarity</strong>:
$$ T(a, b) = \frac{c}{a + b - c} $$
where $c$ is the number of common fingerprint bits, and $a$ and $b$ are the number of set bits in each respective molecule&rsquo;s fingerprint. This formulation reliably measures functional chemical similarity.</p>
</li>
</ol>
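<p>The Tanimoto formula can be made concrete with a short sketch that treats each fingerprint as a set of on-bit indices. The function name and the handling of two empty fingerprints are assumptions for illustration; in practice the fingerprints themselves (MACCS, RDK, Morgan) would come from a cheminformatics toolkit.</p>

```python
def tanimoto(fp_a, fp_b):
    """Tanimoto similarity between two fingerprints given as sets of on bits."""
    a = len(fp_a)          # bits set in the first fingerprint
    b = len(fp_b)          # bits set in the second fingerprint
    c = len(fp_a & fp_b)   # bits set in both
    if a + b - c == 0:
        # Assumed convention: two empty fingerprints count as identical.
        return 1.0
    return c / (a + b - c)


similarity = tanimoto({1, 2, 3, 4}, {3, 4, 5})  # c=2, a=4, b=3 -> 2 / 5
```

<p>Because the score depends only on shared substructure bits, two molecules that differ in one bond placement can still score close to 1, which is why it is more forgiving than exact string match.</p>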
<h2 id="experimental-setup-and-ablation-studies">Experimental Setup and Ablation Studies</h2>
<p>The evaluation focused on comparing IMG2SMI to existing systems and identifying which design choices matter most:</p>
<ol>
<li>
<p><strong>Baseline Comparisons</strong>: Benchmarked against OSRA (rule-based system) and DECIMER (first deep learning approach) on the MOLCAP dataset to establish whether modern architectures could surpass traditional methods.</p>
</li>
<li>
<p><strong>Ablation Studies</strong>: Extensive ablations isolating key factors:</p>
<ul>
<li><strong>Decoder Architecture</strong>: Transformer vs. RNN/LSTM decoders</li>
<li><strong>Encoder Fine-tuning</strong>: Fine-tuned vs. frozen pre-trained ResNet weights</li>
<li><strong>Output Representation</strong>: SELFIES vs. character-level SMILES vs. BPE-tokenized SMILES (the most critical ablation)</li>
</ul>
</li>
</ol>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>MACCS FTS</th>
          <th>Valid Captions</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RNN + Fixed Encoder</td>
          <td>0.1526</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>RNN + Fine-tuned Encoder</td>
          <td>0.4180</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Transformer + Fixed Encoder</td>
          <td>0.7674</td>
          <td>61.1%</td>
      </tr>
      <tr>
          <td>Transformer + Fine-tuned Encoder</td>
          <td>0.9475</td>
          <td>99.4%</td>
      </tr>
      <tr>
          <td>Character-level SMILES (fine-tuned)</td>
          <td>N/A</td>
          <td>2.1%</td>
      </tr>
      <tr>
          <td>BPE SMILES (2000 vocab, fine-tuned)</td>
          <td>N/A</td>
          <td>20.0%</td>
      </tr>
      <tr>
          <td>SELFIES (fine-tuned)</td>
          <td>0.9475</td>
          <td>99.4%</td>
      </tr>
  </tbody>
</table>
<ol start="3">
<li><strong>Metric Analysis</strong>: Systematic comparison of evaluation metrics including BLEU, ROUGE, Levenshtein distance, exact match accuracy, and molecular fingerprint-based similarity measures.</li>
</ol>
<h2 id="results-findings-and-limitations">Results, Findings, and Limitations</h2>
<p><strong>Performance Gains</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI</th>
          <th>OSRA</th>
          <th>DECIMER</th>
          <th>Random Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>0.0000</td>
          <td>0.3378</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9020</td>
          <td>0.2790</td>
          <td>0.0000</td>
          <td>0.2229</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8707</td>
          <td>0.2677</td>
          <td>0.0000</td>
          <td>0.1081</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>0.0000</td>
          <td>0.0422</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>0.00%</td>
          <td>0.00%</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<ul>
<li>163% improvement over OSRA on MACCS Tanimoto similarity.</li>
<li>Roughly 9x improvement on ROUGE scores (0.6240 vs. 0.0684).</li>
<li>Average Tanimoto similarity exceeds 0.85 for all three fingerprint types, indicating functionally similar molecules even when predictions are not exact matches.</li>
</ul>
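<p>The headline comparisons can be recomputed directly from the table values. A small sketch (helper names are hypothetical):</p>

```python
def relative_improvement(new, baseline):
    """Percentage improvement of new over baseline."""
    return (new - baseline) / baseline * 100.0


def ratio(new, baseline):
    """Multiplicative factor of new over baseline."""
    return new / baseline


maccs_gain = relative_improvement(0.9475, 0.3600)  # about 163%
rouge_factor = ratio(0.6240, 0.0684)               # about 9.1x
```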
<p><strong>Key Findings</strong>:</p>
<ul>
<li><strong>SELFIES is Critical</strong>: Using SELFIES yields <strong>99.4% valid molecules</strong>, compared to only ~2% validity for character-level SMILES.</li>
<li><strong>Architecture Matters</strong>: Transformer decoder significantly outperforms RNN/LSTM approaches. Fine-tuning the ResNet encoder (vs. frozen weights) yields substantial performance gains (e.g., MACCS FTS: 0.7674 to 0.9475).</li>
<li><strong>Metric Insights</strong>: BLEU is a poor metric for this task. Molecular fingerprint-based Tanimoto similarity is most informative because it measures functional chemical similarity.</li>
</ul>
<p><strong>Limitations</strong>:</p>
<ul>
<li><strong>Low Exact Match</strong>: Only <strong>7.24%</strong> exact matches. The model captures the overarching functional groups and structure but misses fine details like exact double bond placement.</li>
<li><strong>Complexity Bias</strong>: Trained on large molecules (average length &gt;40 tokens), so it performs poorly on very simple structures where OSRA still excels.</li>
</ul>
<p><strong>Conclusion</strong>: The work shows that modern encoder-decoder architectures combined with valid-by-construction molecular representations (SELFIES) can outperform traditional rule-based systems by large margins on fingerprint-based similarity metrics. The system is useful for literature mining where functional similarity matters more than exact matches, though 7.24% exact match accuracy and poor performance on simple molecules indicate clear directions for future work.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Image captioning system based on DETR (Detection Transformer) framework.</p>
<p><strong>Visual Encoder</strong>:</p>
<ul>
<li><strong>Backbone</strong>: ResNet-101 pre-trained on ImageNet</li>
<li><strong>Feature Extraction</strong>: 4th layer extraction (convolutions only)</li>
<li><strong>Output</strong>: 2048-dimensional dense feature vector</li>
</ul>
<p><strong>Caption Decoder</strong>:</p>
<ul>
<li><strong>Type</strong>: Transformer encoder-decoder</li>
<li><strong>Layers</strong>: 3 stacked encoder layers, 3 stacked decoder layers</li>
<li><strong>Attention Heads</strong>: 8</li>
<li><strong>Hidden Dimensions</strong>: 2048 (feed-forward networks)</li>
<li><strong>Dropout</strong>: 0.1</li>
<li><strong>Layer Normalization Epsilon</strong>: 1e-12</li>
</ul>
<p><strong>Training Configuration</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: AdamW</li>
<li><strong>Learning Rate</strong>: 5e-5 (selected after sweep from 1e-4 to 1e-6)</li>
<li><strong>Weight Decay</strong>: 1e-4</li>
<li><strong>Batch Size</strong>: 32</li>
<li><strong>Epochs</strong>: 5</li>
<li><strong>Codebase</strong>: Built on open-source DETR implementation</li>
</ul>
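<p>For reference, the reported hyperparameters can be collected into a small configuration object. This is a sketch, not the authors&rsquo; code: the class and field names are hypothetical, and the step counts assume one full pass over the 1 million training images per epoch.</p>

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfig:
    """Reported IMG2SMI training settings (field names are hypothetical)."""
    optimizer: str = "AdamW"
    learning_rate: float = 5e-5
    weight_decay: float = 1e-4
    batch_size: int = 32
    epochs: int = 5
    train_examples: int = 1_000_000

    def steps_per_epoch(self) -> int:
        # Assumes every example is seen once per epoch.
        return math.ceil(self.train_examples / self.batch_size)

    def total_steps(self) -> int:
        return self.steps_per_epoch() * self.epochs


config = TrainingConfig()
steps = config.steps_per_epoch()   # 31,250 optimizer steps per epoch
total = config.total_steps()       # 156,250 steps over 5 epochs
```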
<h3 id="data">Data</h3>
<p><strong>MOLCAP Dataset</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Property</th>
          <th>Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Total Size</td>
          <td>81,230,291 molecules</td>
          <td>Aggregated from PubChem, ChEMBL, GDB-13</td>
      </tr>
      <tr>
          <td>Training Split</td>
          <td>1,000,000 molecules</td>
          <td>Randomly selected unique molecules</td>
      </tr>
      <tr>
          <td>Validation Split</td>
          <td>5,000 molecules</td>
          <td>Randomly selected for evaluation</td>
      </tr>
      <tr>
          <td>Image Resolution</td>
          <td>256x256 pixels</td>
          <td>Generated using RDKit</td>
      </tr>
      <tr>
          <td>Median SELFIES Length</td>
          <td>&gt;45 characters</td>
          <td>More complex than typical benchmarks</td>
      </tr>
      <tr>
          <td>Full Dataset Storage</td>
          <td>~16.24 TB</td>
          <td>Necessitated use of 1M subset</td>
      </tr>
      <tr>
          <td>Augmentation</td>
          <td>None</td>
          <td>No cropping, rotation, or other augmentation</td>
      </tr>
  </tbody>
</table>
<p><strong>Preprocessing</strong>:</p>
<ul>
<li>Images generated using RDKit at 256x256 resolution</li>
<li>Molecules converted to canonical representations</li>
<li>SELFIES tokenization for model output</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metrics</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>IMG2SMI Value</th>
          <th>OSRA Baseline</th>
          <th>Purpose</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MACCS FTS</td>
          <td>0.9475</td>
          <td>0.3600</td>
          <td>Fingerprint Tanimoto Similarity (functional groups)</td>
      </tr>
      <tr>
          <td>RDK FTS</td>
          <td>0.9020</td>
          <td>0.2790</td>
          <td>RDKit fingerprint similarity</td>
      </tr>
      <tr>
          <td>Morgan FTS</td>
          <td>0.8707</td>
          <td>0.2677</td>
          <td>Morgan fingerprint similarity (circular)</td>
      </tr>
      <tr>
          <td>ROUGE</td>
          <td>0.6240</td>
          <td>0.0684</td>
          <td>Text overlap metric</td>
      </tr>
      <tr>
          <td>Exact Match</td>
          <td>7.24%</td>
          <td>0.04%</td>
          <td>Structural identity (strict)</td>
      </tr>
      <tr>
          <td>Valid Captions</td>
          <td>99.4%</td>
          <td>65.2%</td>
          <td>Syntactic validity (with SELFIES)</td>
      </tr>
      <tr>
          <td>Levenshtein Distance</td>
          <td>21.13</td>
          <td>32.76</td>
          <td>String edit distance (lower is better)</td>
      </tr>
  </tbody>
</table>
<p><strong>Secondary Metrics</strong> (shown to be less informative for chemical tasks):</p>
<ul>
<li>BLEU, ROUGE (better suited for natural language)</li>
<li>Levenshtein distance (doesn&rsquo;t capture chemical similarity)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU</strong>: Single NVIDIA GeForce RTX 2080 Ti</li>
<li><strong>Training Time</strong>: ~5 hours per epoch, approximately 25 hours total for 5 epochs</li>
<li><strong>Memory</strong>: Sufficient for batch size 32 with ResNet-101 + Transformer architecture</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<p>The paper mentions releasing both code and the MOLCAP dataset, but no public repository or download link has been confirmed as available.</p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MOLCAP dataset</td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>81M molecules; claimed released but no public URL found</td>
      </tr>
      <tr>
          <td>IMG2SMI code</td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Built on DETR; claimed released but no public URL found</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Campos, D., &amp; Ji, H. (2021). IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System (No. arXiv:2109.04202). arXiv. <a href="https://doi.org/10.48550/arXiv.2109.04202">https://doi.org/10.48550/arXiv.2109.04202</a></p>
<p><strong>Publication</strong>: arXiv preprint (2021)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.48550/arXiv.2109.04202">Paper on arXiv</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{campos2021img2smi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{IMG2SMI: Translating Molecular Structure Images to Simplified Molecular-input Line-entry System}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Campos, Daniel and Ji, Heng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2109.04202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.48550/arXiv.2109.04202}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Distributed Representations: A Foundational Theory</title><link>https://hunterheidenreich.com/notes/machine-learning/model-architectures/distributed-representations/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/model-architectures/distributed-representations/</guid><description>Hinton's 1984 technical report establishing the theoretical efficiency of distributed representations over local encoding in neural networks.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Theory</strong> paper, with strong secondary elements of <strong>Method</strong> and <strong>Position</strong>.</p>
<p>It is a theoretical work because its core contribution is the formal mathematical derivation of the encoding accuracy and error properties of distributed schemes (coarse coding) compared to local schemes. It serves as a position paper by challenging the &ldquo;grandmother cell&rdquo; (local representation) intuition prevalent in AI at the time and advocating for the &ldquo;constructive&rdquo; view of memory.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The motivation is to overcome the inefficiency of <strong>local representations</strong>, where one hardware unit corresponds to exactly one entity, and to challenge traditional metaphors of memory.</p>
<ul>
<li><strong>Inefficiency</strong>: In local representations, high accuracy requires an exponential number of units (accuracy $\propto \sqrt[k]{n}$ for $k$ dimensions).</li>
<li><strong>Brittleness</strong>: Local representations lack natural support for generalization; learning a fact about one concept (e.g., &ldquo;chimps like onions&rdquo;) requires extra machinery to transfer to similar concepts (e.g., &ldquo;gorillas&rdquo;).</li>
<li><strong>Hardware Mismatch</strong>: Massive parallelism is wasted if units are rarely active. A unit that is active 50% of the time conveys 1 bit of information, whereas a sparsely active local unit conveys almost none.</li>
<li><strong>The &ldquo;Filing Cabinet&rdquo; Metaphor</strong>: The paper challenges the standard view of memory as a storage system of literal copies. It motivates a shift toward understanding memory as a reconstructive inference process.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The paper introduces formal mechanisms that explain <em>why</em> distributed representations are superior:</p>
<ol>
<li><strong>Coarse Coding Efficiency</strong>: Hinton proves that using broad, overlapping receptive fields (&ldquo;coarse coding&rdquo;) yields higher accuracy for a fixed number of units than non-overlapping local fields. For a $k$-dimensional feature space with $n$ units of receptive field radius $r$, accuracy scales as $a \propto n \cdot r^{k-1}$. This is far superior to local encoding, where accuracy scales as $a \propto n^{1/k}$.</li>
<li><strong>Automatic Generalization</strong>: It demonstrates that generalization is an emergent property of vector overlap. Modifying weights for one pattern automatically affects similar patterns (conspiracy effect).</li>
<li><strong>Memory as Reconstruction</strong>: It posits that memory is a reconstructive process where items are created afresh from fragments using plausible inference rules (connection strengths). This blurs the line between veridical recall and confabulation.</li>
<li><strong>Gradual Concept Formation</strong>: Distributed representations allow new concepts to emerge gradually through weight modifications that progressively differentiate existing concepts. This avoids the discrete decisions and spare hardware units required by local representations.</li>
<li><strong>Solution to the Binding Problem</strong>: It proposes that true part/whole hierarchies are formed by fusing the identity of a part with its role to produce a single, new subpattern. The representation of the whole is then the sum of these combined identity/role representations.</li>
</ol>
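<p>The two scaling laws in the first item can be compared numerically. The snippet below is a minimal sketch that sets every proportionality constant to 1 (an assumption, since the note only gives proportionalities), so only the trend is meaningful. The function names are illustrative and do not come from the paper.</p>

```python
def local_accuracy(n, k):
    """Relative accuracy of n non-overlapping local units in k dimensions (a ~ n^(1/k))."""
    return n ** (1.0 / k)


def coarse_accuracy(n, r, k):
    """Relative accuracy of n coarse-coded units with field radius r (a ~ n * r^(k-1))."""
    return n * r ** (k - 1)


if __name__ == "__main__":
    # Same unit budget, two dimensions, hypothetical radius of 5 (arbitrary units).
    n, k, r = 10_000, 2, 5
    print("local :", local_accuracy(n, k))
    print("coarse:", coarse_accuracy(n, r, k))
```

<p>With these illustrative values the local scheme gives a relative accuracy of 100 and the coarse scheme 50,000. The gap widens as $k$ grows, because the local exponent $1/k$ shrinks while the coarse term $r^{k-1}$ grows.</p>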















<figure class="post-figure center ">
    <img src="/img/notes/distributed-representations-binding.svg"
         alt="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         title="Diagram showing distributed representations with three pools of units (AGENT, RELATIONSHIP, PATIENT) connected via role/identity bindings"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The binding problem solution: true hierarchies require creating unique subpatterns that fuse an identity with its role, where the whole is represented as the sum of these combined representations.</figcaption>
    
</figure>

<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The paper performs analytical derivations and two specific computer simulations:</p>
<ol>
<li><strong>Arbitrary Mapping Simulation</strong>: A 3-layer network trained to map 20 grapheme strings (e.g., words) to 20 unrelated semantic vectors.</li>
<li><strong>Damage &amp; Recovery Analysis</strong>:
<ul>
<li><strong>Lesioning</strong>: Removing a single word-set unit to observe error patterns. This produced &ldquo;Deep Dyslexia&rdquo;-like semantic errors (e.g., reading &ldquo;PEACH&rdquo; as &ldquo;APRICOT&rdquo;), where the clean-up effect settles on a similar but incorrect meaning.</li>
<li><strong>Noise Injection</strong>: Adding noise to all connections involving word-set units, reducing performance from 99.3% to 64.3%.</li>
<li><strong>Retraining</strong>: Measuring the speed of relearning after noise damage (&ldquo;spontaneous recovery&rdquo;), where unrehearsed items recover alongside rehearsed ones due to shared weights.</li>
</ul>
</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ol>
<li><strong>Accuracy Scaling</strong>: For a $k$-dimensional feature space, the accuracy $a$ of a distributed representation scales as $a \propto n \cdot r^{k-1}$ (where $r$ is the receptive field radius), vastly outperforming local schemes.</li>
<li><strong>Reliability</strong>: Distributed systems exhibit graceful degradation. Removing units causes slight noise across many items instead of the complete loss of any single item.</li>
<li><strong>Spontaneous Recovery</strong>: When retraining a damaged network on a subset of items, the network &ldquo;spontaneously&rdquo; recovers unrehearsed items due to weight sharing, which is a qualitative signature of distributed representations.</li>
<li><strong>Limitations of Coarse Coding</strong>: The paper identifies that coarse coding requires relatively sparse features. Crowding too many feature-points together causes receptive fields to contain too many features, preventing the activity pattern from discriminating between combinations.</li>
<li><strong>Sequential Processing Constraint</strong>: When constituent structure is represented using identity/role bindings, only one structure can be represented at a time. Hinton argues this matches the empirical observation that people are, to a first approximation, sequential symbol processors.</li>
<li><strong>Learning Problem Deferred</strong>: The paper acknowledges that discovering which sets of items should correspond to single units is a difficult search problem, and defers the learning question to separate work (Hinton, Sejnowski, and Ackley, 1984).</li>
</ol>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>The following details are extracted from Section 5 (&ldquo;Implementing an Arbitrary Mapping&rdquo;) to facilitate reproduction of the &ldquo;Deep Dyslexia&rdquo; and &ldquo;Arbitrary Mapping&rdquo; simulations.</p>
<h3 id="data">Data</h3>
<p>The simulation uses synthetic data representing words and meanings.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>Synthetic Grapheme/Sememe Pairs</td>
          <td>20 pairs</td>
          <td>20 different grapheme strings mapped to random semantic vectors.</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Input (Graphemes)</strong>: 30 total units.
<ul>
<li>Structure: Divided into 3 groups of 10 units each.</li>
<li>Encoding: Each &ldquo;word&rdquo; (3 letters) activates exactly 1 unit in each group (sparse binary).</li>
</ul>
</li>
<li><strong>Output (Sememes)</strong>: 30 total units.
<ul>
<li>Structure: Binary units.</li>
<li>Encoding: Meanings are random vectors where each unit is active with probability $p=0.2$.</li>
</ul>
</li>
</ul>
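<p>The encoding described above can be sketched with the standard library. This is an illustration under stated assumptions, not the original data: the seed, the helper name <code>make_dataset</code>, and the use of rejection sampling to obtain 20 distinct three-letter words are all hypothetical choices.</p>

```python
import random

N_WORDS = 20      # grapheme/sememe pairs
N_GROUPS = 3      # letter positions
GROUP_SIZE = 10   # grapheme units per position
N_SEMEMES = 30    # binary semantic units
P_ACTIVE = 0.2    # probability that a sememe is part of a meaning


def make_dataset(seed=0):
    """Return a list of (grapheme_vector, sememe_vector) pairs."""
    rng = random.Random(seed)
    # Draw distinct three-letter "words"; each letter is an index in 0..9.
    words = set()
    while len(words) != N_WORDS:
        words.add(tuple(rng.randrange(GROUP_SIZE) for _ in range(N_GROUPS)))
    pairs = []
    for word in sorted(words):
        graphemes = [0] * (N_GROUPS * GROUP_SIZE)
        for group, letter in enumerate(word):
            graphemes[group * GROUP_SIZE + letter] = 1  # one active unit per group
        sememes = [1 if P_ACTIVE > rng.random() else 0 for _ in range(N_SEMEMES)]
        pairs.append((graphemes, sememes))
    return pairs
```

<p>Each grapheme vector has exactly three active units out of 30, and each sememe vector has about six active units on average ($30 \times 0.2$).</p>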
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Learning Rule</strong>: The paper cites &ldquo;Hinton, Sejnowski &amp; Ackley (1984)&rdquo; (Boltzmann Machines) for the specific learning algorithm used to set weights.</li>
<li><strong>False Positive Analysis</strong>: The probability $f$ that a semantic feature is incorrectly activated is derived as:</li>
</ul>
<p>$$f = (1 - (1-p)^{(w-1)})^u$$</p>
<p>Where:</p>
<ul>
<li>$p$: Probability of a sememe being in a word meaning ($0.2$).</li>
<li>$w$: Number of words in a &ldquo;word-set&rdquo; (cluster).</li>
<li>$u$: Number of active &ldquo;word-set&rdquo; units per word.</li>
</ul>
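<p>The formula is easy to evaluate directly. The sketch below is a plain transcription of it; the function name and the specific values of $w$ and $u$ in the loop are illustrative assumptions, not settings reported in the paper.</p>

```python
def false_positive_probability(p, w, u):
    """f = (1 - (1 - p)^(w - 1))^u for sememe probability p, word-set size w, u active units."""
    return (1.0 - (1.0 - p) ** (w - 1)) ** u


if __name__ == "__main__":
    # Hypothetical settings: larger word-sets raise f, more active units lower it.
    for w, u in [(2, 1), (5, 1), (5, 3), (5, 6)]:
        print(w, u, false_positive_probability(0.2, w, u))
```

<p>The trade-off is visible in the structure of the formula: a larger word-set makes each unit less selective (the base grows toward 1), while requiring agreement among more active units (a larger exponent) suppresses false positives.</p>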
<h3 id="models">Models</h3>
<p>The simulation uses a specific three-layer architecture.</p>
<table>
  <thead>
      <tr>
          <th>Layer</th>
          <th>Type</th>
          <th>Count</th>
          <th>Connectivity</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Input</strong></td>
          <td>Grapheme Units</td>
          <td>30</td>
          <td>Connected to all hidden (word-set) units, with no direct link to Output.</td>
      </tr>
      <tr>
          <td><strong>Hidden</strong></td>
          <td>&ldquo;Word-Set&rdquo; Units</td>
          <td>20</td>
          <td>Fully connected to Input and Output.</td>
      </tr>
      <tr>
          <td><strong>Output</strong></td>
          <td>Sememe Units</td>
          <td>30</td>
          <td>Connected to all hidden (word-set) units. Includes lateral inhibition (implied for &ldquo;clean up&rdquo;).</td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Weights</strong>: Binary/Integer logic in theoretical analysis, but &ldquo;stochastic&rdquo; weights in the Boltzmann simulation.</li>
<li><strong>Thresholds</strong>: Sememe units have variable thresholds dynamically adjusted to be slightly less than the number of active word-set units.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The simulation evaluated the robustness of the mapping.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Accuracy (Clean)</strong></td>
          <td>99.9%</td>
          <td>N/A</td>
          <td>Correct pattern produced 99.9% of the time after learning.</td>
      </tr>
      <tr>
          <td><strong>Lesion Error Rate</strong></td>
          <td>1.4%</td>
          <td>N/A</td>
          <td>140 errors in 10,000 tests after removing 1 word-set unit.</td>
      </tr>
      <tr>
          <td><strong>Semantic Errors</strong></td>
          <td>~60% of errors</td>
          <td>N/A</td>
          <td>83 of the 140 lesion errors were &ldquo;Deep Dyslexia&rdquo; errors (producing a valid but wrong semantic pattern).</td>
      </tr>
      <tr>
          <td><strong>Post-Noise Accuracy</strong></td>
          <td>64.3%</td>
          <td>99.3%</td>
          <td>Performance after adding noise to all connections involving word-set units. The 99.3% baseline (reported separately from the 99.9% clean accuracy above) reflects the pre-noise measurement at the time of this specific experiment.</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: Minimal. The original simulation ran on 1980s hardware (likely VAX-11 or similar).</li>
<li><strong>Replication</strong>: Reproducible on any modern CPU in milliseconds.</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hinton, G. E. (1984). Distributed Representations. <em>Technical Report CMU-CS-84-157</em>, Carnegie-Mellon University.</p>
<p><strong>Publication</strong>: CMU Computer Science Department Technical Report, October 1984</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@techreport</span>{hinton1984distributed,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Distributed representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hinton, Geoffrey E}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{1984}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">institution</span>=<span style="color:#e6db74">{Carnegie-Mellon University}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{CMU-CS-84-157}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Chemical Machine Vision</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/chemical-machine-vision/</link><pubDate>Sun, 14 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/rule-based/chemical-machine-vision/</guid><description>Machine vision approach using Gabor wavelets and Kohonen networks to classify chemical raster images and extract structural metadata.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Gkoutos, G. V., Rzepa, H., Clark, R. M., Adjei, O., &amp; Johal, H. (2003). Chemical Machine Vision: Automated Extraction of Chemical Metadata from Raster Images. <em>Journal of Chemical Information and Computer Sciences</em>, 43(5), 1342-1355. <a href="https://doi.org/10.1021/ci034017n">https://doi.org/10.1021/ci034017n</a></p>
<p><strong>Publication</strong>: J. Chem. Inf. Comput. Sci. 2003</p>
<h2 id="paper-classification-methodological-approach">Paper Classification: Methodological Approach</h2>
<p>This is a <strong>Method</strong> paper. It proposes a novel architectural pipeline applying &ldquo;machine vision&rdquo; techniques (Gabor wavelets and Kohonen networks) to the problem of identifying chemical diagrams in low-resolution raster images. The paper focuses on the &ldquo;how&rdquo; (the algorithm and its parameters) and validates the method through quantitative experiments optimizing feature vectors and masks.</p>
<h2 id="motivation-extracting-legacy-chemical-data">Motivation: Extracting Legacy Chemical Data</h2>
<p>The primary motivation is to unlock the &ldquo;large amount of data&rdquo; trapped in legacy raster images (GIF, JPEG) on the Web that lack semantic metadata.</p>
<ul>
<li><strong>Legacy Data Problem</strong>: Most chemical structural information on the Web is embedded in raster images, not machine-readable formats like Molfiles.</li>
<li><strong>Limitations of Existing Tools</strong>: Previous tools like Kekule and CLiDE acted as &ldquo;Chemical OCR,&rdquo; attempting to reconstruct exact atom-bond connections. This required high-resolution images (&gt;300 dpi) and human intervention, making them unsuitable for automated Web crawling of low-resolution (72-96 dpi) images.</li>
<li><strong>Goal</strong>: To create a low-cost, automated tool for a &ldquo;robot-based Internet resource discovery tool&rdquo; that can classify images (e.g., &ldquo;is this a molecule?&rdquo;).</li>
</ul>
<h2 id="core-innovation-texture-recognition-over-structural-ocr">Core Innovation: Texture Recognition over Structural OCR</h2>
<p>The core novelty is the shift from &ldquo;Optical Character Recognition&rdquo; (exact reconstruction) to <strong>&ldquo;Texture Recognition&rdquo;</strong> (classification).</p>
<ul>
<li><strong>Texture-Based Approach</strong>: The authors treat chemical diagrams as textures. They use <strong>Gabor wavelets</strong> to extract texture features. <strong>Crucially, this system does not recognize specific chemical structures</strong> (i.e., atom-bond connectivity tables, <a href="/notes/chemistry/molecular-representations/notations/smiles/">SMILES</a>, or Molfiles). It only classifies images into broad categories.</li>
<li><strong>Incremental Learning</strong>: The system uses a <strong>Kohonen Self-Organizing Feature Map (KSOFM)</strong> combined with Class Boundary Analysis (CBA). This allows for &ldquo;incremental learning,&rdquo; where new classes (e.g., aromatic vs. non-aromatic) can be added without retraining the entire system.</li>
<li><strong>Optimization for Chemistry</strong>: The authors identify specific parameters (frequency channels, mask sizes) that are optimal for the &ldquo;texture&rdquo; of chemical diagrams.</li>
<li><strong>Integration with ChemDig</strong>: The method was designed to feed into ChemDig, a robot-based index engine for automated web crawling and metadata generation.</li>
</ul>
<h2 id="experimental-setup-parameter-optimization">Experimental Setup: Parameter Optimization</h2>
<p>The authors performed optimization and validation experiments using a dataset of <strong>300 images</strong> divided into three classes: Ring Systems, Non-Ring Systems, and Non-Chemistry (textures, biological figures, etc.).</p>
<ol>
<li><strong>Parameter Optimization</strong>: They systematically varied hyperparameters to find the optimal configuration:
<ul>
<li><strong>Feature Vector Size</strong>: Tested sizes from 100 to 4000 elements.</li>
<li><strong>Energy Mask Size</strong>: Tested windows from $3 \times 3$ to $15 \times 15$ pixels.</li>
<li><strong>Frequency Channels</strong>: Tested seven spatial frequencies ($\sqrt{2}$ to $64\sqrt{2}$).</li>
</ul>
</li>
<li><strong>Classification Performance</strong>: Evaluated the system&rsquo;s ability to classify unseen test images using a 50:50 training/test split.</li>
<li><strong>Comparison</strong>: Qualitatively compared the approach against vectorization tools (Autotrace, CR2V).</li>
</ol>
<h2 id="results-robust-classification-of-low-resolution-images">Results: Robust Classification of Low-Resolution Images</h2>
<ul>
<li><strong>Optimal Configuration</strong>: The system performed best with a feature vector size of ~1500 elements, a $9 \times 9$ energy mask, and frequency channel $4\sqrt{2}$.</li>
<li><strong>High Accuracy</strong>: Achieved a recognition rate of <strong>91%</strong> with a 50:50 training/test split, and up to <strong>92%</strong> with a 70:30 split.</li>
<li><strong>Robustness</strong>: The system successfully distinguished between chemical and non-chemical images (zero false negatives for chemical images).</li>
<li><strong>Limitations</strong>: Misclassifications occurred between &ldquo;ring&rdquo; and &ldquo;non-ring&rdquo; systems when structures had similar visual &ldquo;textures&rdquo; (e.g., similar density or layout).</li>
<li><strong>Impact</strong>: The method is viable for automating metadata generation (e.g., <code>alt</code> tags) for web crawlers, functioning as a coarse-grained filter before more expensive processing.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The study used a custom dataset of raster images collected from the Web.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td><strong>Custom Web Dataset</strong></td>
          <td>300 images</td>
          <td>Split into 3 classes: Ring Systems, Non-Ring Systems, Non-Chemistry.</td>
      </tr>
      <tr>
          <td>Resolution</td>
          <td><strong>Low-Res Web Images</strong></td>
          <td>72-96 dpi</td>
          <td>Deliberately chosen to mimic Web conditions where OCR fails.</td>
      </tr>
      <tr>
          <td>Format</td>
          <td><strong>Raster</strong></td>
          <td>GIF, JPEG</td>
          <td>Typical web formats.</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>The core pipeline consists of a <strong>Gabor Transform Unit</strong> followed by a <strong>Training/Classification Unit</strong>.</p>
<ul>
<li><strong>Gabor Wavelets</strong>: Used for feature extraction. The 2D Gabor wavelet equation is:
$$h(x,y)=\exp\left\{-\frac{1}{2}\left[\frac{x^{2}}{\sigma_{x}^{2}}+\frac{y^{2}}{\sigma_{y}^{2}}\right]\right\}\cos(2\pi\mu_{\sigma}x+\phi)$$
<ul>
<li><strong>Bank Structure</strong>: 28 filters total (4 orientations $\times$ 7 radial frequencies).</li>
<li><strong>Orientations</strong>: $0^{\circ}, 45^{\circ}, 90^{\circ}, 135^{\circ}$.</li>
<li><strong>Frequencies</strong>: 1 octave apart, specifically $1\sqrt{2}, \dots, 64\sqrt{2}$.</li>
<li><strong>Selected Frequency</strong>: $4\sqrt{2}$ was found to be optimal for chemistry.</li>
</ul>
</li>
<li><strong>Preprocessing</strong>:
<ul>
<li><strong>Buffer Mounting</strong>: Images are mounted in a buffer (set to 0) to handle edge artifacts.</li>
<li><strong>Look-Up-Tables (LUT/LUF)</strong>: A binary Look-Up-Frame (LUF) indicates Regions of Interest (ROI) to avoid computing empty space; values are stored in a Look-Up-Table (LUT) to prevent re-computation of overlapping windows.</li>
</ul>
</li>
<li><strong>Feature Extraction</strong>:
<ul>
<li><strong>Non-linear Thresholding</strong>: $\psi(t) = \tanh(\alpha t)$ with $\alpha = 0.25$.</li>
<li><strong>Energy Function</strong>: Calculated as average absolute deviation from the mean using a window $W_{xy}$.
$$e_{k}(x,y)=\frac{1}{M^{2}}\sum_{(a,b)\in W_{xy}}|\psi(r_{k}(a,b))|$$</li>
<li><strong>Optimal Window</strong>: $9 \times 9$ pixels.</li>
</ul>
</li>
</ul>
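<p>The feature-extraction steps listed above (Gabor kernel, $\tanh$ non-linearity, windowed energy) can be sketched in plain Python. This is a rough illustration, not the authors&rsquo; implementation: it follows the energy formula as written, assumes $M$ is the side length of the window, assumes the window lies fully inside the response map, and omits filter orientation. All function and parameter names are hypothetical.</p>

```python
import math


def gabor(x, y, sigma_x, sigma_y, freq, phase=0.0):
    """Gaussian envelope times a cosine carrier along x, as in the 2D Gabor equation."""
    envelope = math.exp(-0.5 * (x * x / sigma_x ** 2 + y * y / sigma_y ** 2))
    return envelope * math.cos(2.0 * math.pi * freq * x + phase)


def threshold(t, alpha=0.25):
    """Non-linear thresholding psi(t) = tanh(alpha * t)."""
    return math.tanh(alpha * t)


def local_energy(response, cx, cy, size=9):
    """Mean of |psi(r)| over a size-by-size window centred on column cx, row cy.

    `response` is a list of rows holding one filter's output r_k.
    """
    half = size // 2
    total = 0.0
    for row in range(cy - half, cy + half + 1):
        for col in range(cx - half, cx + half + 1):
            total += abs(threshold(response[row][col]))
    return total / (size * size)
```

<p>A full pipeline would convolve the image with each of the 28 kernels and compute one energy map per filter. That step is left out here because the note does not specify the kernel sizes or the $\sigma$ values.</p>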
<h3 id="models">Models</h3>
<p>The classification model relies on competitive learning.</p>
<ul>
<li><strong>Architecture</strong>: <strong>Kohonen Self-Organizing Feature Map (KSOFM)</strong>.</li>
<li><strong>Training</strong>:
<ul>
<li><strong>Learning Rate</strong>: Starts at 1.0, decreases to 0.1.</li>
<li><strong>Class Boundary Analysis (CBA)</strong>: Computes the centroid (mean) and variance of each cluster. The variance defines the class boundary.</li>
</ul>
</li>
<li><strong>Classification Metric</strong>: <strong>Euclidean Distance Norm</strong>. An unknown vector is classified based on the shortest distance to a cluster center, provided it falls within the variance boundary.
$$D_{ij}=||x_{i}-x_{j}||$$</li>
</ul>
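<p>The classification rule can be sketched as a nearest-centroid test with a rejection boundary. The note does not say exactly how the variance is turned into a boundary, so this sketch assumes a radius of <code>scale</code> times the root-mean-square distance of the training vectors from their centroid. That choice, the <code>scale</code> parameter, and the function names are assumptions made for illustration.</p>

```python
import math


def fit_boundaries(clusters, scale=1.0):
    """Map each label to (centroid, radius) from a dict of label -> feature vectors."""
    model = {}
    for label, vectors in clusters.items():
        dim = len(vectors[0])
        centroid = [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]
        # Variance taken as the mean squared Euclidean distance to the centroid.
        variance = sum(math.dist(v, centroid) ** 2 for v in vectors) / len(vectors)
        model[label] = (centroid, scale * math.sqrt(variance))
    return model


def classify(x, model):
    """Return the nearest cluster label, or None if x lies outside that cluster's boundary."""
    label = min(model, key=lambda name: math.dist(x, model[name][0]))
    centroid, radius = model[label]
    return label if radius >= math.dist(x, centroid) else None
```

<p>Because each class is summarised only by its own centroid and boundary, a new class can be added by calling <code>fit_boundaries</code> on its vectors alone. This matches the incremental-learning property described earlier, although the real system derives its clusters from the KSOFM, not from labelled lists.</p>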
<h3 id="evaluation">Evaluation</h3>
<p>Performance was measured using recognition rate ($R_s$) and misclassification error ($E_s$).</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value</th>
          <th>Baseline</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Recognition Rate</td>
          <td><strong>91%</strong></td>
          <td>N/A</td>
          <td>Achieved with 50:50 split. 92% with 70:30 split.</td>
      </tr>
      <tr>
          <td>Feature Size</td>
          <td><strong>~1500</strong></td>
          <td>4000</td>
          <td>Reducing vector size from 4000 to 1500 maintained ~80% accuracy while improving speed.</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{gkoutosChemicalMachineVision2003,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Chemical {{Machine Vision}}: {{Automated Extraction}} of {{Chemical Metadata}} from {{Raster Images}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{Chemical {{Machine Vision}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Gkoutos, Georgios V. and Rzepa, Henry and Clark, Richard M. and Adjei, Osei and Johal, Harpal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2003</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = sep,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Chemical Information and Computer Sciences}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{43}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{5}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1342--1355}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">issn</span> = <span style="color:#e6db74">{0095-2338}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1021/ci034017n}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">urldate</span> = <span style="color:#e6db74">{2025-12-15}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">langid</span> = <span style="color:#e6db74">{english}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Invalid SMILES Benefit Chemical Language Models: A Study</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/invalid-smiles-help/</link><pubDate>Tue, 02 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-representations/notations/invalid-smiles-help/</guid><description>Skinnider (2024) shows that generating invalid SMILES actually improves chemical language model performance through quality filtering.</description><content:encoded><![CDATA[<h2 id="core-contribution-repurposing-invalid-smiles">Core Contribution: Repurposing Invalid SMILES</h2>
<p>This is an <strong>Empirical</strong> paper that challenges a fundamental assumption in the field of chemical language models. Skinnider provides both empirical evidence and mechanistic explanations for why the ability to generate &ldquo;invalid&rdquo; SMILES strings is beneficial for model performance.</p>
<h2 id="the-problem-with-absolute-validity-in-chemical-lms">The Problem with Absolute Validity in Chemical LMs</h2>
<p>Prior research attempted to eliminate invalid generations using constrained representations like SELFIES. This paper demonstrates that invalid outputs serve as low-likelihood samples whose removal acts as an implicit quality filter, improving distribution learning.</p>
<h2 id="invalid-generation-as-an-implicit-quality-filter">Invalid Generation as an Implicit Quality Filter</h2>
<p>The central insight is counterintuitive: <strong>invalid SMILES generation acts as a built-in quality control mechanism</strong>. The key contributions are:</p>
<ol>
<li>
<p><strong>Empirical Evidence</strong>: Direct comparisons showing that SMILES-based models consistently outperform SELFIES-based models across multiple metrics, with performance gains strongly correlated with the proportion of invalid outputs generated.</p>
</li>
<li>
<p><strong>Mechanistic Explanation</strong>: Invalid SMILES are demonstrated to be low-likelihood samples from the model&rsquo;s probability distribution. When these are filtered out, it&rsquo;s equivalent to removing the model&rsquo;s least confident predictions, a form of automatic quality control.</p>
</li>
<li>
<p><strong>Causal Evidence</strong>: By modifying SELFIES to allow invalid generation (through relaxed constraints), the author shows that performance improves when models can generate and discard invalid outputs, which directly supports a causal relationship.</p>
</li>
<li>
<p><strong>Bias Analysis</strong>: SELFIES models are shown to introduce systematic structural biases (fewer aromatic rings, more aliphatic rings) due to their validity constraints, limiting their ability to explore chemical space naturally.</p>
</li>
</ol>
<h2 id="experimental-design-and-causal-interventions">Experimental Design and Causal Interventions</h2>
<p>The paper uses a multi-pronged approach to establish both correlation and causation:</p>
<p><strong>Performance Comparisons</strong>: SMILES and SELFIES models were trained on identical datasets and evaluated using distribution-learning metrics like Fréchet ChemNet distance. The comparison was robust across different architectures, training set sizes, and chemical databases.</p>
<p><strong>Loss Analysis</strong>: The relationship between SMILES validity and model confidence was examined by analyzing the sequence loss. For a given SMILES string $S$ composed of tokens $t_1, t_2, \ldots, t_N$, the negative log-likelihood acts as a proxy for the model&rsquo;s uncertainty:</p>
<p>$$ \text{NLL}(S) = -\sum_{i=1}^N \log P(t_i \mid t_1, \ldots, t_{i-1}) $$</p>
<p>Invalid SMILES strings consistently register higher $\text{NLL}$ scores, meaning they represent the model&rsquo;s least confident predictions. Filtering them effectively acts as automatic quality control, providing the mechanistic explanation for why invalid filtering improves performance.</p>
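<p>As a minimal sketch of the quantity above, the function below sums negative log-probabilities over the tokens of one sequence. It assumes the per-token conditional probabilities have already been read off a trained model. The function name and the two example probability lists are made up for illustration and are not values from the paper.</p>

```python
import math


def sequence_nll(token_probs):
    """Negative log-likelihood of a sequence from its per-token conditional probabilities."""
    return -sum(math.log(p) for p in token_probs)


if __name__ == "__main__":
    confident = [0.9, 0.8, 0.95, 0.9]   # hypothetical probabilities for a fluent string
    uncertain = [0.9, 0.1, 0.3, 0.05]   # hypothetical probabilities for a shaky string
    print(sequence_nll(confident), sequence_nll(uncertain))
```

<p>A single low-probability token is enough to raise the total noticeably, which is consistent with invalid strings clustering at the high-loss end of the distribution.</p>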
<p><strong>Causal Intervention</strong>: A key experiment involved modifying the SELFIES valency constraints at two levels: first allowing pentavalent carbons (&ldquo;Texas SELFIES&rdquo;), then removing all constraints entirely (&ldquo;unconstrained SELFIES&rdquo;). This allowed direct testing of whether the ability to generate invalid outputs (which are then discarded) causally improves performance.</p>
<p><strong>Structural Bias Analysis</strong>: Generated molecules were analyzed for chemical features like ring types and bond patterns to quantify how validity constraints systematically distort the model&rsquo;s exploration of chemical space.</p>
<p><strong>Generalization Testing</strong>: Models were trained on subsets of chemical databases and tested on their ability to reproduce the broader chemical space, measuring how validity constraints affect generalization.</p>
<p><strong>Practical Application</strong>: The approach was tested on structure elucidation, using models to identify unknown molecules from minimal experimental data like mass spectrometry.</p>
<h2 id="key-findings-on-validity-constraints-and-bias">Key Findings on Validity Constraints and Bias</h2>
<p><strong>Superior Performance Across the Board</strong>: SMILES-based models consistently outperformed SELFIES models on distribution-learning tasks. Using metrics like Fréchet ChemNet distance, SMILES models generated molecules that more closely matched the statistical properties of their training data. This performance advantage was directly correlated with the proportion of invalid SMILES generated. Models that produced more invalid outputs performed better after filtering.</p>
<p><strong>Invalid SMILES Are Low-Confidence Predictions</strong>: The analysis revealed that invalid SMILES consistently have higher loss values than valid ones, meaning they represent the model&rsquo;s least confident predictions. This suggests that validity checking acts as an automatic confidence filter, removing low-quality samples without requiring explicit uncertainty estimation.</p>
<p><strong>Causal Evidence Through Unconstrained SELFIES</strong>: The causal test came from modifying SELFIES to allow invalid generation. When &ldquo;unconstrained SELFIES&rdquo; models could generate and discard invalid molecules, their performance improved, approaching that of SMILES models. This indicates that the ability to generate invalid outputs is what drives the performance gains.</p>
<p><strong>Validity Constraints Introduce Systematic Bias</strong>: SELFIES models showed clear structural biases compared to both training data and SMILES outputs. They generated fewer aromatic rings and more aliphatic structures, systematic distortions caused by the valency constraints used to ensure validity. These biases limit the model&rsquo;s ability to faithfully represent chemical space.</p>
<p><strong>Reduced Generalization</strong>: When trained on subsets of chemical databases, SMILES models could reproduce a larger portion of the complete chemical space compared to SELFIES models. Although SELFIES generated more valid molecules in absolute terms, their structural biases constrained exploration and limited generalization beyond the training set.</p>
<p><strong>Real-World Application Benefits</strong>: In structure elucidation tasks, identifying unknown molecules from experimental data like mass spectrometry, SMILES-based models significantly outperformed SELFIES models. This demonstrates that the benefits extend beyond academic benchmarks to practical applications.</p>
<p><strong>CASMI 2022 Benchmark</strong>: The language model trained on the LOTUS database was benchmarked against 19 submissions to the CASMI 2022 competition for structure elucidation of unknown compounds. Using only accurate mass as input (no MS/MS data), the model achieved competitive performance, highlighting the practical utility of the sampling-frequency-based approach for de novo structure elucidation.</p>
<p><strong>Computational Efficiency</strong>: Filtering invalid SMILES is computationally trivial. Parsing ten million SMILES strings with RDKit takes approximately 7.5 minutes on a single CPU, making the post-processing overhead negligible compared to model training and inference costs.</p>
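<p>The filtering step can be sketched as a single pass over the sampled strings. This is a minimal illustration, not the paper&rsquo;s code: <code>filter_valid</code> and <code>toy_parse</code> are hypothetical names, and <code>toy_parse</code> is only a stand-in that checks balanced parentheses and paired ring-closure digits. In practice the parser would be RDKit&rsquo;s <code>Chem.MolFromSmiles</code>, which returns <code>None</code> for strings it cannot parse.</p>

```python
def filter_valid(samples, parse):
    """Keep only strings that `parse` accepts; return them with the validity rate."""
    samples = list(samples)
    valid = [s for s in samples if parse(s) is not None]
    rate = len(valid) / len(samples) if samples else 0.0
    return valid, rate


def toy_parse(smiles):
    """Hypothetical stand-in for a real SMILES parser (NOT real chemistry).

    Only checks that parentheses are balanced and ring-closure digits are paired.
    """
    depth = 0
    open_rings = set()
    for ch in smiles:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == -1:  # closing bracket with nothing open
                return None
        elif ch.isdigit():
            open_rings ^= {ch}  # toggle: first digit opens a ring, second closes it
    ok = depth == 0 and not open_rings
    return smiles if ok else None


sampled = ["c1ccccc1", "CC(=O)O", "C1CC", "CC(C"]
valid, validity_rate = filter_valid(sampled, toy_parse)
```

<p>Because invalid strings are simply dropped, the filter needs no uncertainty estimate of its own. The validity rate it reports is the same quantity listed under the secondary metrics.</p>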
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="models">Models</h3>
<p><strong>Primary Architecture (LSTM):</strong> The main results rely on a Recurrent Neural Network (RNN) using Long Short-Term Memory (LSTM) units.</p>
<ul>
<li><strong>Structure:</strong> Three-layer LSTM with a hidden layer size of 1,024 dimensions</li>
<li><strong>Embedding:</strong> An embedding layer of 128 dimensions</li>
<li><strong>Decoder:</strong> A linear decoder layer outputs token probabilities</li>
</ul>
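<p>A minimal PyTorch sketch of a model with these dimensions is shown below. The class name, <code>vocab_size=40</code>, <code>batch_first=True</code>, and the absence of dropout are assumptions for illustration; the archived training scripts remain the authoritative implementation.</p>

```python
import torch
import torch.nn as nn


class SmilesLSTM(nn.Module):
    """Sketch of the described LSTM language model (names are hypothetical)."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=1024, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, tokens, state=None):
        # tokens: [batch, seq_len] integer token ids
        hidden, state = self.lstm(self.embedding(tokens), state)
        # logits: [batch, seq_len, vocab_size]; a softmax over the last dim gives token probabilities
        return self.decoder(hidden), state


model = SmilesLSTM(vocab_size=40)        # 40 is an arbitrary placeholder vocabulary size
tokens = torch.randint(0, 40, (2, 16))   # two dummy sequences of 16 token ids
logits, _ = model(tokens)

# Next-token cross-entropy: positions 0..T-2 predict tokens 1..T-1
loss = nn.functional.cross_entropy(
    logits[:, :-1].reshape(-1, 40), tokens[:, 1:].reshape(-1)
)
```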
<p><strong>Secondary Architecture (Transformer/GPT):</strong> To confirm robustness across architectures, the author also used a Generative Pretrained Transformer (GPT) architecture adapted from MolGPT.</p>
<ul>
<li><strong>Structure:</strong> Eight transformer blocks</li>
<li><strong>Internals:</strong> Each block contains eight masked self-attention heads and a feed-forward network (1,024 dimensions) using GELU activation</li>
<li><strong>Embedding:</strong> 256 dimensions, concatenated with learned positional encodings</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Optimizer:</strong> Adam optimizer for both architectures with $\beta_1=0.9$ and $\beta_2=0.999$.</p>
<p><strong>Learning Rate:</strong></p>
<ul>
<li>LSTM: 0.001</li>
<li>Transformer: 0.0005</li>
</ul>
<p><strong>Batch Size:</strong> 64</p>
<p><strong>Loss Function:</strong> Cross-entropy loss of next-token prediction.</p>
<p><strong>Stopping Criteria:</strong> Early stopping using a validation set (10% of training data) with patience of 50,000 minibatches.</p>
<h3 id="data">Data</h3>
<p><strong>Primary Source:</strong> ChEMBL database (version 28).</p>
<p><strong>Preprocessing Pipeline:</strong></p>
<ul>
<li><strong>Cleaning:</strong> Removal of duplicate SMILES, salts, and solvents (retaining heavy fragments with $\geq 3$ heavy atoms)</li>
<li><strong>Filtering:</strong> Molecules with atoms other than {Br, C, Cl, F, H, I, N, O, P, S} were removed</li>
<li><strong>Normalization:</strong> Charged molecules were neutralized and converted to canonical SMILES</li>
</ul>
<p><strong>Training Subsets:</strong> Models were trained on random samples of 30,000, 100,000, and 300,000 molecules to test scalability.</p>
<p><strong>Generalization Data:</strong> To test generalization, models were also trained on the <a href="/notes/chemistry/datasets/gdb-13/">GDB-13</a> database (enumerating drug-like molecules up to 13 heavy atoms).</p>
<p><strong>Structure Elucidation Data:</strong> For practical application tasks, models were trained on natural products (LOTUS, COCONUT), food compounds (FooDB), and environmental contaminants (NORMAN).</p>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metric:</strong> Fréchet ChemNet Distance (FCD), measuring chemical similarity between generated molecules and the training set (lower is better).</p>
<p><strong>Secondary Metrics:</strong></p>
<ul>
<li><strong>Validity:</strong> Percentage of outputs parseable by RDKit</li>
<li><strong>Scaffold Similarity:</strong> Jensen-Shannon distances between Murcko scaffold compositions</li>
<li><strong>Physical Properties:</strong> Comparisons of molecular weight, LogP, topological polar surface area (TPSA), and ring counts (aromatic vs. aliphatic)</li>
<li><strong>Structure Elucidation:</strong> &ldquo;Top-k accuracy,&rdquo; the proportion of held-out molecules where the correct structure appeared in the model&rsquo;s top $k$ ranked outputs</li>
</ul>
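<p>The top-$k$ accuracy metric can be sketched as follows, assuming candidates are ranked by how often the model sampled them and that structures are compared as canonical strings. The function names and toy data are hypothetical, and the accurate-mass matching step is omitted for brevity.</p>

```python
from collections import Counter


def rank_by_frequency(sampled):
    """Rank distinct sampled structures from most to least frequently generated."""
    return [s for s, _ in Counter(sampled).most_common()]


def top_k_accuracy(ranked_lists, truths, k):
    """Fraction of held-out molecules whose true structure is among the top k candidates."""
    hits = sum(1 for ranked, truth in zip(ranked_lists, truths) if truth in ranked[:k])
    return hits / len(truths)


# Toy data: one list of sampled strings per held-out molecule
samples_per_query = [
    ["CCO", "CCN", "CCO", "CCC", "CCO", "CCN"],
    ["c1ccccc1", "CC=O", "c1ccccc1"],
    ["CCCl", "CCBr", "CCCl"],
]
truths = ["CCN", "CC=O", "CCI"]

ranked_lists = [rank_by_frequency(s) for s in samples_per_query]
accuracy = {k: top_k_accuracy(ranked_lists, truths, k) for k in (1, 2, 3)}
```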
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute Nodes:</strong> Dell EMC C4140 GPU compute nodes</li>
<li><strong>GPUs:</strong> NVIDIA Tesla V100</li>
<li><strong>Compute Time:</strong> Parsing 10 million SMILES took ~7.5 minutes on a single CPU; SELFIES models required an average of 0.6 hours longer to train than SMILES models</li>
</ul>
<h3 id="replicability">Replicability</h3>
<p><strong>Code Availability:</strong> Source code and intermediate data are available via <a href="https://doi.org/10.5281/zenodo.10680855">Zenodo</a>. Pre-trained model weights are not provided in the archive, requiring researchers to train models from scratch using the included scripts to fully replicate the study.</p>
<p><strong>Data Availability:</strong> Training datasets and generated molecule samples (10 million from ChEMBL/GDB-13 models, 100 million from LOTUS/COCONUT/FooDB/NORMAN cross-validation folds) are available via <a href="https://doi.org/10.5281/zenodo.8321735">Zenodo</a>.</p>
<p><strong>Software Libraries:</strong></p>
<ul>
<li><strong>PyTorch:</strong> LSTM and Transformer implementations</li>
<li><strong>RDKit:</strong> SMILES parsing, validity checking, and property calculation</li>
<li><strong>SELFIES:</strong> Version 2.1.1 for conversion</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.10680855">Source code (Zenodo)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Training scripts, analysis code, and intermediate data</td>
      </tr>
      <tr>
          <td><a href="https://doi.org/10.5281/zenodo.8321735">Training and generated molecules (Zenodo)</a></td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Preprocessed training sets and sampled molecules</td>
      </tr>
  </tbody>
</table>
<h2 id="implications-and-takeaways">Implications and Takeaways</h2>
<p>This work reframes how we think about &ldquo;errors&rdquo; in generative models. The key insight is that model outputs appearing incorrect often represent low-likelihood samples whose removal improves overall performance.</p>
<p>The findings suggest that the field&rsquo;s drive toward guaranteed validity leads to systematic biases. Letting models fail informatively and using those failures as quality signals can yield better distribution learning. This is relevant as the field moves toward larger, more capable models where such self-correction mechanisms become increasingly valuable.</p>
<p>For practitioners, the takeaway is to consider the role of invalid outputs before eliminating them. Filtering low-confidence generations provides automatic quality control that improves final results.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Skinnider, M. A. (2024). Invalid SMILES are beneficial rather than detrimental to chemical language models. Nature Machine Intelligence, 6(4), 437-448. <a href="https://doi.org/10.1038/s42256-024-00821-x">https://doi.org/10.1038/s42256-024-00821-x</a></p>
<p><strong>Publication</strong>: Nature Machine Intelligence (2024)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{skinnider2024invalid,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Invalid SMILES are beneficial rather than detrimental to chemical language models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Skinnider, Michael A}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Nature Machine Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{6}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{4}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{437--448}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{Nature Publishing Group UK London}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples improves performance when paired with the correct objective function. Averaging the loss over $k$ samples yields minimal gains. Changing the objective itself is where the real gain comes from. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a different objective that optimizes a tighter bound on the log-likelihood.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, &hellip;, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and the difference matters.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1..h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
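<p>This means the IWAE is optimizing a <strong>tighter lower bound</strong> on the true log-likelihood of the data: $\log p(x) \ge \mathcal{L}_k(x) \ge \mathcal{L}_1(x)$. The bound is never looser than the ELBO, and the two coincide only when the importance weights are constant, which happens when $q(h|x)$ matches the true posterior.</p>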
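<p>A small numeric check makes the inequality tangible. The log-weights below are made-up values for a single input $x$, and the helper names are hypothetical. The log of the average is computed with the usual max-shift so that exponentiating very negative log-weights does not underflow.</p>

```python
import math


def average_of_logs(log_w):
    """Multi-sample VAE estimate: (1/k) * sum_i log w_i."""
    return sum(log_w) / len(log_w)


def log_of_average(log_w):
    """IWAE estimate: log((1/k) * sum_i w_i), computed stably in log space."""
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w)) - math.log(len(log_w))


# Hypothetical log importance weights for k = 5 samples
log_w = [-90.0, -95.0, -120.0, -92.5, -101.0]

vae_term = average_of_logs(log_w)   # dragged down by the worst sample (-120)
iwae_term = log_of_average(log_w)   # dominated by the best samples
```

<p>With these numbers the average of logs is $-99.7$, while the log of the average stays within about $\log 5$ of the best sample. That gap is Jensen&rsquo;s inequality at work.</p>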
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model toward a better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to many latent dimensions becoming inactive (the paper reports the number of &ldquo;active units&rdquo; per model).</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat <code>x</code> to get <code>x_repeated</code> of shape <code>[B*k, D]</code>. We do all our forward passes on this large tensor.</p>
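<p>One detail worth spelling out is the ordering of the repeated rows, because a later <code>.view(B, k)</code> only groups samples correctly if the $k$ copies of each input are adjacent. The sketch below assumes <code>repeat_interleave</code> is used for the repeat; the toy shapes and the <code>per_row</code> stand-in are illustrative.</p>

```python
import torch

B, D, k = 3, 4, 5
x = torch.arange(B * D, dtype=torch.float32).view(B, D)

# Each input row is repeated k times in a row: [x0, x0, ..., x1, x1, ...]
x_repeated = x.repeat_interleave(k, dim=0)   # shape [B*k, D]

# Stand-in for any per-row quantity such as log_w, shape [B*k]
per_row = x_repeated.sum(dim=1)

# Row b of `grouped` now holds the k values that belong to input b
grouped = per_row.view(B, k)

# Tiling instead ([x0, x1, x2, x0, x1, x2, ...]) mixes inputs under .view(B, k)
tiled_grouped = x.repeat(k, 1).sum(dim=1).view(B, k)
```

<p>If the batch is tiled with <code>x.repeat(k, 1)</code> instead, the matching reshape is <code>.view(k, B)</code> with the reduction over <code>dim=0</code>.</p>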
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
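<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>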
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
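<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> math
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a Gaussian</span>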
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, torch<span style="color:#f92672">.</span>zeros_like(z_samples), torch<span style="color:#f92672">.</span>zeros_like(z_samples)) <span style="color:#75715e"># mu=0, logvar=0 as tensors</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code> (average of the individual losses)</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code> (log of the averaged weights)</li>
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To activate more latent dimensions or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is often &ldquo;yes&rdquo;. You get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The next time you use more samples with your VAE, switch to the IWAE objective to get the full benefit of the computational cost of $k &gt; 1$.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires computing exact log-probabilities to evaluate the specific samples. The result is a fundamentally more powerful model using the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders (IWAE) for Tighter Bounds</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</guid><description>Summary of Burda, Grosse &amp; Salakhutdinov's ICLR 2016 paper introducing Importance Weighted Autoencoders for tighter variational bounds</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a generative model that shares the same architecture as the Variational Autoencoder (VAE) but is trained with a different objective function. The key innovation is using importance weighting to derive a log-likelihood lower bound that is at least as tight as the standard VAE objective and tightens as more samples are used.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The standard VAE has several limitations that motivated this work:</p>
<ul>
<li><strong>Strong assumptions</strong>: VAEs typically assume the posterior distribution is simple (e.g., approximately factorial) and that its parameters can be easily approximated from observations.</li>
<li><strong>Simplified representations</strong>: The VAE objective can force models to learn overly simplified representations that underutilize the network&rsquo;s full modeling capacity.</li>
<li><strong>Harsh penalization</strong>: The VAE objective harshly penalizes approximate posterior samples that are poor explanations for the data, which can be overly restrictive.</li>
<li><strong>Inactive units</strong>: VAEs tend to learn latent spaces with effective dimensions far below their capacity, where many latent units are ignored (a phenomenon later termed <strong>posterior collapse</strong>, where the approximate posterior collapses to the prior and conveys no information). The authors wanted to investigate whether a new objective could address this issue.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>IWAE objective function</strong>, denoted as $\mathcal{L}_{k}$.</p>
<ul>
<li>
<p><strong>VAE ($\mathcal{L}_{1}$ Bound)</strong>: The standard VAE maximizes $\mathcal{L}(x)=\mathbb{E}_{q(h|x)}[\log\frac{p(x,h)}{q(h|x)}]$. This is equivalent to the new bound when $k=1$.</p>
</li>
<li>
<p><strong>IWAE ($\mathcal{L}_{k}$ Bound)</strong>: The IWAE maximizes a tighter bound that uses $k$ samples drawn from the recognition model $q(h|x)$:</p>
</li>
</ul>
<p>$$\mathcal{L}_{k}(x)=\mathbb{E}_{h_{1},&hellip;,h_{k}\sim q(h|x)}\left[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,h_{i})}{q(h_{i}|x)}\right]$$</p>
<ul>
<li>
<p><strong>Tighter Bound</strong>: The authors prove that the bound never gets looser as $k$ grows ($\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k}$), so it is always at least as tight as the VAE bound $\mathcal{L}_{1}$. If the importance weights $p(x,h)/q(h|x)$ are bounded, $\mathcal{L}_{k}$ approaches the true log-likelihood $\log p(x)$ as $k$ approaches infinity.</p>
</li>
<li>
<p><strong>Increased Flexibility</strong>: Using multiple samples gives the IWAE additional flexibility to learn generative models whose posterior distributions are complex and violate the VAE&rsquo;s simplifying assumptions.</p>
</li>
</ul>
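<p>The ordering of the bounds can be checked numerically on a toy model where $\log p(x)$ is known in closed form. The sketch below is an illustration under stated assumptions, not an experiment from the paper: the prior is $h \sim \mathcal{N}(0, 1)$, the likelihood is $x|h \sim \mathcal{N}(h, 1)$ (so $p(x) = \mathcal{N}(0, 2)$), and the recognition model $q(h|x) = \mathcal{N}(x/2, 1)$ is deliberately wider than the true posterior $\mathcal{N}(x/2, 1/2)$. All function names are hypothetical.</p>

```python
import math
import random


def log_normal(v, mean, var):
    """Log-density of a univariate Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)


def log_mean_exp(values):
    """Stable log of the arithmetic mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def estimate_bound(x, k, trials, rng):
    """Monte Carlo estimate of L_k(x) for the toy model described above."""
    total = 0.0
    for _ in range(trials):
        log_w = []
        for _ in range(k):
            h = rng.gauss(0.5 * x, 1.0)                                  # h ~ q(h|x)
            log_joint = log_normal(h, 0.0, 1.0) + log_normal(x, h, 1.0)  # log p(x, h)
            log_w.append(log_joint - log_normal(h, 0.5 * x, 1.0))        # log w
        total += log_mean_exp(log_w)
    return total / trials


rng = random.Random(0)
x = 1.0
log_px = log_normal(x, 0.0, 2.0)  # exact log-likelihood of the toy model
bounds = {k: estimate_bound(x, k, 10_000, rng) for k in (1, 5, 50)}
```

<p>In this setup the $k=1$ estimate sits below $\log p(x)$ by roughly the KL divergence between $q$ and the true posterior, and the estimates for $k=5$ and $k=50$ climb toward $\log p(x)$ without exceeding it, up to Monte Carlo noise.</p>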
<h3 id="key-concept-averaging-inside-vs-outside-the-log">Key Concept: Averaging Inside vs. Outside the Log</h3>
<p>A crucial distinction exists between how VAE and IWAE use $k$ samples. Understanding it explains why increasing $k$ tightens the bound in an IWAE, whereas in a VAE it only reduces the variance of the estimate.</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-vae-vs-importance-weighted-autoencoder-iwae.webp"
         alt="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         title="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">VAE vs IWAE computation flow. The key difference is where the log operation occurs: VAE computes log(w_i) for each sample then averages. IWAE averages the weights first then applies log to the result.</figcaption>
    
</figure>

<p><strong>VAE (Average of Logs):</strong></p>
<p>For a VAE, using $k$ samples approximates:</p>
<p>$$\mathbb{E}\left[ \frac{1}{k} \sum_{i=1}^k \log w_i \right] \approx \text{ELBO}$$</p>
<p>where $w_i = p(x, h_i) / q(h_i | x)$. Increasing $k$ here only reduces the variance of the gradient estimator; the model still targets the same ELBO bound, so performance gains saturate quickly.</p>
<p><strong>IWAE (Log of Average):</strong></p>
<p>IWAE performs the averaging <em>inside</em> the logarithm:</p>
<p>$$\mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] = \mathcal{L}_k$$</p>
<p>By Jensen&rsquo;s Inequality ($\log(\mathbb{E}[X]) \geq \mathbb{E}[\log(X)]$, since $\log$ is concave), this bound is guaranteed to be at least as tight as the VAE bound. Each increase in $k$ defines a new lower bound on the log-likelihood that is at least as tight as the previous one ($\mathcal{L}_{k+1} \geq \mathcal{L}_{k}$).</p>
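<p>The difference between the two estimators can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors&rsquo; implementation: the function names and the log-weights are hypothetical, and the log-weights stand in for $\log w_i = \log p(x, h_i) - \log q(h_i | x)$ computed by a real model.</p>

```python
import math

def log_mean_exp(log_w):
    """Numerically stable log of the mean of exp(log_w)."""
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w) / len(log_w))

def vae_estimate(log_w):
    """Average of logs: (1/k) * sum_i log w_i."""
    return sum(log_w) / len(log_w)

def iwae_estimate(log_w):
    """Log of average: log((1/k) * sum_i w_i)."""
    return log_mean_exp(log_w)

# Hypothetical log-weights for k = 4 samples (illustrative values only).
log_w = [-90.0, -85.0, -120.0, -88.0]
print(vae_estimate(log_w))   # average of logs
print(iwae_estimate(log_w))  # log of average, never smaller than the line above
```

<p>Working in log space with the max-subtraction step avoids underflow, since the raw weights $w_i$ are typically far too small to represent directly.</p>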
<p><strong>Why This Matters for Gradients:</strong></p>
<p>In IWAE, the gradient weights are normalized importance weights $\tilde{w}_i = w_i / \sum_j w_j$. This means &ldquo;bad&rdquo; samples (those with low $w_i$) contribute very little to the gradient update since they vanish from the weighted sum. VAE uses unweighted samples, so a single sample with extremely low probability produces a massive negative log value that can dominate the loss and harshly penalize the model. IWAE&rsquo;s formulation allows the model to focus learning on the samples that explain the data well.</p>
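<p>As a rough sketch of the weighting described above, the normalized weights $\tilde{w}_i$ are a softmax over the log-weights. The helper name and the numbers are assumptions chosen for illustration.</p>

```python
import math

def normalized_weights(log_w):
    """Return w_i / sum_j w_j, computed stably from log-weights."""
    m = max(log_w)
    unnorm = [math.exp(v - m) for v in log_w]
    total = sum(unnorm)
    return [u / total for u in unnorm]

# Hypothetical log-weights; the last sample explains x poorly.
log_w = [-85.0, -88.0, -120.0]
w_tilde = normalized_weights(log_w)
for lw, wt in zip(log_w, w_tilde):
    print(lw, wt)
```

<p>The poorly matching sample receives a weight close to zero, so its gradient contribution is negligible, whereas in the average-of-logs objective its large negative log-weight would enter at full strength.</p>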
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors compared VAE and IWAE on density estimation tasks using the MNIST and Omniglot datasets. They evaluated two main network architectures: one with a single stochastic layer and another with two stochastic layers. The models were trained with varying numbers of importance samples ($k \in \{1, 5, 50\}$) to observe the effect on performance and latent space utilization. The primary metrics for evaluation were the test log-likelihood (estimated using 5000 samples) and the number of &ldquo;active&rdquo; latent units, which quantifies the richness of the learned representations.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-vs-vae-active-latent-units-comparison.webp"
         alt="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         title="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Active latent units for VAE vs IWAE (1 stochastic layer). VAE active units remain flat. IWAE increases with k. Data from Table 1 of Burda et al. (2016).</figcaption>
    
</figure>

<ul>
<li>
<p><strong>Better Performance</strong>: IWAE achieved higher log-likelihoods than VAEs across all configurations. On MNIST with two stochastic layers and $k=50$, IWAE reached $-82.90$ nats compared to $-84.78$ for VAE. On Omniglot, the best IWAE achieved $-103.38$ nats versus $-106.30$ for VAE. IWAE performance improved consistently with increasing $k$, while VAE performance benefited only slightly from using more samples ($k&gt;1$).</p>
</li>
<li>
<p><strong>Richer Representations</strong>: In all experiments with $k&gt;1$, IWAE learned more active latent dimensions than VAE, suggesting richer latent representations.</p>
</li>
<li>
<p><strong>Objective Drives Representation</strong>: The authors found that latent dimension inactivation is driven by the objective function. They demonstrated this through an &ldquo;objective swap&rdquo; experiment:</p>
</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-objective-swap-experiment.webp"
         alt="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         title="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Objective swap experiment on MNIST (1 stochastic layer). Switching a trained VAE to the IWAE objective improves both metrics. Switching IWAE to VAE degrades them. Data from Table 2 of Burda et al. (2016).</figcaption>
    
</figure>

<p>This experiment provides evidence that the objective function itself influences latent utilization:</p>
<ul>
<li><strong>VAE → IWAE</strong>: A converged VAE model, when fine-tuned with the IWAE objective ($k=50$), gained 3 active units (19 → 22) and improved test NLL from 86.76 to 84.88.</li>
<li><strong>IWAE → VAE</strong>: A converged IWAE model fine-tuned with the VAE objective lost 2 active units (25 → 23) and worsened test NLL from 84.78 to 86.02.</li>
</ul>
<p>These results suggest that inactivation of latent dimensions is driven primarily by the objective function rather than by optimization issues. Optimization still plays some role, however: the authors note that the swap results do not exactly match training from scratch.</p>
<ul>
<li>
<p><strong>Comparison to Other Models</strong>: On MNIST, the best IWAE ($-82.90$ nats) outperformed deep belief networks ($-84.55$ nats) and deep autoregressive networks ($-84.13$ nats), though DRAW ($-80.97$ nats), which exploits spatial structure, achieved better results. On Omniglot, the best IWAE ($-103.38$ nats) fell slightly behind RBMs trained with persistent contrastive divergence ($-100.46$ nats).</p>
</li>
<li>
<p><strong>Conclusion</strong>: IWAEs learn richer latent representations and achieve better generative performance than VAEs with equivalent architectures and training time.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: $28 \times 28$ binarized handwritten digits (60,000 training / 10,000 test).</li>
<li><strong>Omniglot</strong>: $28 \times 28$ binarized handwritten characters from various alphabets (24,345 training / 8,070 test).</li>
<li><strong>Binarization</strong>: Dynamic sampling where binary values are sampled with expectations equal to the real pixel intensities (following Salakhutdinov &amp; Murray, 2008).</li>
<li><strong>Fixed Binarization</strong>: Results on a fixed binarization of MNIST (Larochelle, 2011) confirm that IWAE outperforms VAE across preprocessing methods. Models trained on the fixed binarization overfit notably more than those trained with dynamic sampling.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main network architectures were tested:</p>
<ol>
<li>One stochastic layer (50 units) with two deterministic layers (200 units each).</li>
<li>Two stochastic layers (100 and 50 units). Between x and h1 were two deterministic layers with 200 units each. Between h1 and h2 were two deterministic layers with 100 units each.</li>
</ol>
<ul>
<li><strong>Activations</strong>: <code>tanh</code> for deterministic layers; <code>exp</code> applied to variance predictions to ensure positivity.</li>
<li><strong>Distributions</strong>: Gaussian latent layers with diagonal covariance; Bernoulli observation layer.</li>
<li><strong>Initialization</strong>: Glorot &amp; Bengio (2010) heuristic.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-4}$).</li>
<li><strong>Batch Size</strong>: 20.</li>
<li><strong>Learning Rate Schedule</strong>: Annealed rate of $0.001 \cdot 10^{-i/7}$ for $3^i$ epochs (where $i=0&hellip;7$), totaling 3,280 passes over the data.</li>
<li><strong>Variance Control</strong>: A common concern with importance sampling is high variance. The authors prove that the Mean Absolute Deviation of their estimator is bounded by $2 + 2\delta$, where $\delta$ is the gap between the bound and true log-likelihood. As the bound tightens, variance remains controlled.</li>
<li><strong>Computational trick</strong>: In the basic IWAE implementation, both forward and backward passes must be done independently for each of the $k$ samples, so the cost scales linearly with $k$. However, the authors describe an optional optimization: stochastically approximate the gradient sum by sampling a single $\epsilon_i$ proportional to its normalized weight $\tilde{w}_i$, then computing only that one backward pass. This reduces the cost to $k$ forward passes and one backward pass. Since the backward pass costs roughly twice the forward pass, this yields approximately a 3x speedup for large $k$ at the cost of increased gradient variance.</li>
</ul>
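<p>The learning rate schedule listed above can be checked with a short sketch. The function name and the list-of-pairs layout are assumptions; only the formula $0.001 \cdot 10^{-i/7}$ and the $3^i$ epoch counts come from the description.</p>

```python
def iwae_schedule(num_stages=8, base_lr=1e-3):
    """Return (learning_rate, epochs) pairs for stages i = 0 .. num_stages - 1."""
    return [(base_lr * 10 ** (-i / 7), 3 ** i) for i in range(num_stages)]

schedule = iwae_schedule()
for lr, epochs in schedule:
    print(lr, epochs)

# 1 + 3 + 9 + ... + 3**7 passes over the data.
total_epochs = sum(epochs for _, epochs in schedule)
print(total_epochs)  # 3280
```

<p>The final stage alone accounts for 2,187 of the 3,280 epochs and runs at a rate of about $10^{-4}$, so most of the training budget is spent at the smallest learning rate.</p>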
<p><strong>Relationship to Reweighted Wake-Sleep (RWS):</strong> Both IWAE and Reweighted Wake-Sleep (Bornschein &amp; Bengio, 2015) use importance-weighted samples and have closely related generative model updates. The key difference is that IWAE derives a single unified lower bound $\mathcal{L}_k$ and uses the reparameterization trick to train the recognition network jointly. RWS instead uses separate wake and sleep phases for the recognition network, which are not derived from $\mathcal{L}_k$.</p>
<h3 id="evaluation">Evaluation</h3>
<ol>
<li><strong>Test Log-Likelihood</strong>: Primary measure of generative performance, estimated as the mean of $\mathcal{L}_{5000}$ (5000 samples) on the test set.</li>
<li><strong>Active Units</strong>: To quantify latent space richness, the authors measured &ldquo;active&rdquo; latent dimensions. A unit $u$ was defined as active if its activity statistic $A_{u}=\text{Cov}_{x}(\mathbb{E}_{u\sim q(u|x)}[u])$ exceeded $10^{-2}$. The $10^{-2}$ threshold is justified by a bimodal distribution of the log activity statistic, showing clear separation between active and inactive units.</li>
</ol>
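<p>A minimal sketch of the active-unit count follows, assuming the posterior means $\mathbb{E}_{q(u|x)}[u]$ have already been computed for every datapoint. The function name, the toy data, and the use of the population variance are assumptions made for illustration.</p>

```python
from statistics import pvariance

def count_active_units(posterior_means, threshold=1e-2):
    """posterior_means[n][u] holds E_q[u | x_n].

    A unit counts as active when the variance of its posterior mean
    across the dataset exceeds the threshold.
    """
    num_units = len(posterior_means[0])
    activity = [pvariance([row[u] for row in posterior_means])
                for u in range(num_units)]
    return sum(a > threshold for a in activity), activity

# Toy posterior means for 4 datapoints and 2 latent units (hypothetical values).
means = [[0.9, 0.001], [-1.1, 0.002], [0.3, -0.001], [-0.4, 0.0]]
n_active, activity = count_active_units(means)
print(n_active, activity)
```

<p>In this toy data the first unit&rsquo;s mean varies with the input while the second stays near zero, which is the signature of a unit that has collapsed to the prior.</p>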
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: GPU-based implementation using mini-batch replication to parallelize the $k$ samples. Specific GPU type and training times are not reported.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yburda/iwae">yburda/iwae</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official Theano implementation for MNIST and Omniglot</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Burda, Y., Grosse, R., &amp; Salakhutdinov, R. (2016). Importance Weighted Autoencoders. <em>International Conference on Learning Representations (ICLR) 2016</em>. <a href="https://arxiv.org/abs/1509.00519">https://arxiv.org/abs/1509.00519</a></p>
<p><strong>Publication</strong>: ICLR 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{burda2016importance,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Importance Weighted Autoencoders}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yuri Burda and Roger Grosse and Ruslan Salakhutdinov}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1509.00519}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1509.00519">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Auto-Encoding Variational Bayes: VAE Paper Summary</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</guid><description>Summary of Kingma &amp; Welling's 2013 VAE paper introducing the reparameterization trick and variational autoencoders.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces a generative mechanism (the VAE) and an optimization technique (the reparameterization trick), with formal theoretical derivation. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the <strong>variational auto-encoder (VAE)</strong> when neural networks are used as the recognition model.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors address two central intractabilities in directed probabilistic models with continuous latent variables:</p>















<figure class="post-figure center ">
    <img src="/img/notes/autoencoding-variational-bayes-figure-1-model-diagram.webp"
         alt="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         title="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Figure 1 from the paper: The directed graphical model. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$, dashed lines denote the variational approximation $q_\phi(z|x)$. The variational parameters $\phi$ are learned jointly with the generative parameters $\theta$.</figcaption>
    
</figure>

<ol>
<li>
<p><strong>Intractable Posteriors</strong>: In models with continuous latent variables (like those with non-linear hidden layers), the true posterior $p_{\theta}(z|x)$ cannot be calculated analytically, preventing the use of standard EM algorithms.</p>
</li>
<li>
<p><strong>Large Datasets</strong>: Sampling-based solutions like Monte Carlo EM (MCEM) require expensive sampling loops per datapoint. This makes them too slow for large datasets where batch optimization is too costly and efficient minibatch updates are required.</p>
</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<h3 id="the-reparameterization-trick-sgvb-estimator">The Reparameterization Trick (SGVB Estimator)</h3>
<p>The core innovation is the <strong>Stochastic Gradient Variational Bayes (SGVB)</strong> estimator. The authors avoid the high variance of the naive Monte Carlo gradient estimator by &ldquo;reparameterizing&rdquo; the random variable $\tilde{z}$.</p>
<p>They express $z$ as a deterministic function of the input $x$ and an auxiliary noise variable $\epsilon$:</p>
<p>$$\tilde{z} = g_{\phi}(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-reparameterization-trick.webp"
         alt="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         title="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reparameterization trick. (A) Standard stochastic nodes block gradient flow during backpropagation. (B) By expressing $z = \mu + \sigma \odot \epsilon$ with external noise $\epsilon \sim \mathcal{N}(0,1)$, gradients can flow through the deterministic path to the parameters $\phi$.</figcaption>
    
</figure>

<ul>
<li><strong>Mechanism</strong>: For a Gaussian posterior, $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.</li>
<li><strong>Impact</strong>: This makes the Monte Carlo estimate differentiable with respect to the variational parameters $\phi$, allowing the variational lower bound to be optimized via standard stochastic gradient ascent (like SGD or Adagrad).</li>
</ul>
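<p>The Gaussian case can be sketched without any deep learning framework. The helper below is hypothetical and only shows the sampling path; in an autodiff library the same expression is what lets gradients reach $\mu$ and $\sigma$, because the randomness is confined to $\epsilon$.</p>

```python
import math
import random

def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I), elementwise.

    sigma is recovered from the log-variance as exp(0.5 * log_var).
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, log_var)]

rng = random.Random(0)                   # fixed seed, for repeatability only
mu, log_var = [0.5, -1.0], [0.0, -2.0]   # hypothetical encoder outputs
z = reparameterize(mu, log_var, rng)
print(z)
```

<p>Because $z$ is a deterministic function of $(\mu, \sigma, \epsilon)$, shrinking $\sigma$ toward zero makes the sample collapse onto $\mu$.</p>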
<h3 id="the-aevb-algorithm-the-vae">The AEVB Algorithm (The VAE)</h3>
<p>The <strong>Auto-Encoding VB (AEVB)</strong> algorithm amortizes inference by learning a global recognition model (encoder) $q_{\phi}(z|x)$ jointly with the generative model (decoder) $p_{\theta}(x|z)$.</p>
<p><strong>Objective Function</strong>: Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x^{(i)})$:</p>
<p>$$\mathcal{L} \simeq -D_{KL}\left(q_\phi(z|x^{(i)}) \,\|\, p_\theta(z)\right) + \frac{1}{L} \sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})$$</p>
<ul>
<li><strong>First Term (Regularizer)</strong>: Forces the approximate posterior to match the prior (integrated analytically for Gaussians).</li>
<li><strong>Second Term (Reconstruction Error)</strong>: The expected negative reconstruction error (estimated via sampling).</li>
</ul>
<p>This mirrors the standard auto-encoder objective, adding a variational regularizer.</p>
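<p>Both terms of the objective can be written down directly for a diagonal Gaussian encoder, a standard normal prior, and a Bernoulli decoder. The sketch below assumes those choices and uses hypothetical function names; the clamping constant <code>eps</code> is an added numerical safeguard, not something taken from the paper.</p>

```python
import math

def kl_to_standard_normal(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) in closed form."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv)
                      for m, lv in zip(mu, log_var))

def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p(x | z) for binary x given decoder output probabilities."""
    total = 0.0
    for xi, pi in zip(x, probs):
        pi = min(max(pi, eps), 1.0 - eps)  # keep log() finite
        total += xi * math.log(pi) + (1 - xi) * math.log(1.0 - pi)
    return total

def lower_bound(x, probs, mu, log_var):
    """Single-sample (L = 1) estimate: -KL + log p(x | z)."""
    return -kl_to_standard_normal(mu, log_var) + bernoulli_log_likelihood(x, probs)

# Hypothetical encoder and decoder outputs for a two-pixel input.
print(lower_bound([1, 0], [0.8, 0.1], [0.2, -0.3], [-0.5, 0.1]))
```

<p>The KL term is zero exactly when the encoder outputs the prior ($\mu = 0$, $\log \sigma^2 = 0$), which is why it acts as a regularizer pulling the posterior toward $\mathcal{N}(0, I)$.</p>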
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was benchmarked against the <strong>Wake-Sleep</strong> algorithm and <strong>Monte Carlo EM (MCEM)</strong> using the <strong>MNIST</strong> (digits) and <strong>Frey Face</strong> (continuous faces) datasets.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li>
<p><strong>Efficiency</strong>: AEVB converged faster and reached a better lower bound than Wake-Sleep (Figure 2). It scaled efficiently to the full MNIST dataset. MCEM&rsquo;s per-datapoint sampling cost made it impractical at full dataset scale, so comparisons were limited to small subsets (Figure 3).</p>
</li>
<li>
<p><strong>Regularization</strong>: The KL-divergence term had a regularizing effect: increasing the number of latent dimensions ($N_z$) did not lead to more overfitting.</p>
</li>
<li>
<p><strong>Manifold Learning</strong>: The model successfully learned smooth 2D latent manifolds (visualized in Appendix A), grouping similar digits/faces together.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Evaluation Data</strong>: For the marginal likelihood comparison (Figure 3), the paper used MNIST with $N_{\text{train}} = 1000$ and $N_{\text{train}} = 50000$ to compare data efficiency (marginal log-likelihood vs. training samples seen) across algorithms. A smaller network (100 hidden units, 3 latent variables) was used for this comparison because the marginal likelihood estimator only works reliably in low-dimensional latent spaces.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Algorithm</strong>: Stochastic gradient ascent with <strong>Adagrad</strong> (global stepsizes chosen from $\{0.01, 0.02, 0.1\}$).</li>
<li><strong>Regularization</strong>: The objective included a weight decay term corresponding to a prior $p(\theta)=\mathcal{N}(0,I)$.</li>
<li><strong>Minibatches</strong>: Size $M=100$ with $L=1$ sample per datapoint.</li>
<li><strong>Initialization</strong>: Parameters sampled from $\mathcal{N}(0, 0.01)$.</li>
</ul>
<h3 id="models">Models</h3>
<p>The original VAE used simple Multi-Layered Perceptrons (MLPs):</p>
<ul>
<li><strong>Symmetry</strong>: The encoder and decoder were symmetric, having an equal number of hidden units.</li>
<li><strong>Hidden Units</strong>: 500 units for MNIST, 200 for Frey Face (to prevent overfitting on the smaller dataset).</li>
<li><strong>Activations</strong>: <strong>Tanh</strong> activation functions for the hidden layers.</li>
<li><strong>Latent Space</strong>: Experimented with $N_z$ ranging from 2 to 200.</li>
<li><strong>Outputs</strong>:
<ul>
<li><em>MNIST</em>: <strong>Bernoulli</strong> MLP (sigmoid output).</li>
<li><em>Frey Face</em>: <strong>Gaussian</strong> MLP, with means constrained to $(0,1)$ via sigmoid.</li>
</ul>
</li>
<li><strong>Encoder Architecture</strong>: For the Gaussian encoder, the mean $\mu$ and log-variance $\log(\sigma^2)$ are linear outputs from the shared hidden layer (they share the hidden layer weights and have separate output weights).</li>
<li><strong>Log-Variance</strong>: The encoder predicted $\log(\sigma^2)$ for numerical stability.</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p>The paper distinguishes between two metrics:</p>
<ul>
<li><strong>Variational Lower Bound</strong>: Used as the training objective (what the model optimizes).</li>
<li><strong>Marginal Likelihood</strong>: Used for final evaluation (Figure 3). The true marginal likelihood $p_\theta(x)$ was estimated with the estimator detailed in Appendix D: posterior samples $z^{(l)}$ are drawn via Hybrid Monte Carlo (HMC), a density estimator $q(z)$ is fit to them, and fresh posterior samples are plugged into $p_{\theta}(x^{(i)}) \simeq \left(\frac{1}{L}\sum_{l=1}^{L} \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}$.</li>
</ul>
<p>This distinction is critical: the training metric (lower bound) differs from the evaluation metric (estimated marginal likelihood).</p>
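<p>To make the estimator concrete, here is a sketch on a toy model where everything is known in closed form: $z \sim \mathcal{N}(0, 1)$ and $x | z \sim \mathcal{N}(z, 1)$, so the posterior is $\mathcal{N}(x/2, 1/2)$ and the marginal is $\mathcal{N}(0, 2)$. The toy model, the function names, and the use of exact posterior samples in place of HMC are all assumptions for illustration.</p>

```python
import math
import random

def log_normal(v, mean, var):
    """Log density of a univariate Gaussian."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)

def log_marginal_estimate(x, z_samples, log_q):
    """log of ((1/L) * sum_l q(z_l) / (p(z_l) p(x | z_l)))^(-1)
    for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    log_ratios = [log_q(z) - log_normal(z, 0.0, 1.0) - log_normal(x, z, 1.0)
                  for z in z_samples]
    m = max(log_ratios)
    log_mean = m + math.log(sum(math.exp(r - m) for r in log_ratios)
                            / len(log_ratios))
    return -log_mean

x = 1.3
rng = random.Random(0)
# Exact posterior samples stand in for HMC samples in this toy setting.
z_samples = [rng.gauss(x / 2.0, math.sqrt(0.5)) for _ in range(100)]
# q is set to the exact posterior density, so every ratio equals 1 / p(x).
estimate = log_marginal_estimate(x, z_samples,
                                 lambda z: log_normal(z, x / 2.0, 0.5))
exact = log_normal(x, 0.0, 2.0)
print(estimate, exact)
```

<p>When $q$ matches the posterior exactly, each term in the sum equals $1/p(x)$ and the estimate is exact. In practice $q$ is only a fitted approximation, which is one reason the estimate degrades as the latent dimensionality grows.</p>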
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: Trained on a standard Intel Xeon CPU (approx. 40 GFLOPS); no GPUs were used.</li>
<li><strong>Training Time</strong>: Approximately 20-40 minutes per million training samples.</li>
</ul>
<h3 id="key-implementation-details-from-appendices">Key Implementation Details from Appendices</h3>
<ul>
<li><strong>Appendix A</strong>: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets.</li>
<li><strong>Appendix B</strong>: Closed-form solution for the KL divergence of two Gaussians, essential for implementing the efficient version of the estimator (Equation 10).</li>
<li><strong>Appendix C</strong>: Exact MLP equations with tanh hidden layers, including the output-layer specifications for <strong>Bernoulli MLPs</strong> (binary data) and <strong>Gaussian MLPs</strong> (real-valued data).</li>
<li><strong>Appendix D</strong>: Marginal likelihood estimation protocol using Hybrid Monte Carlo (HMC) and importance sampling for evaluation (Figure 3).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Diederik P. Kingma and Max Welling. &ldquo;Auto-Encoding Variational Bayes.&rdquo; arXiv:1312.6114 [stat.ML], 2013. <a href="https://doi.org/10.48550/arXiv.1312.6114">https://doi.org/10.48550/arXiv.1312.6114</a></p>
<p><strong>Publication</strong>: ICLR 2014 (arXiv preprint December 2013)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kingma2022autoencodingvariationalbayes,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Variational Bayes}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Diederik P Kingma and Max Welling}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2013}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1312.6114}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{stat.ML}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1312.6114}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Variational_autoencoder">Wikipedia: Variational Autoencoder</a> - General overview</li>
<li><a href="https://openreview.net/forum?id=33X9fd2-9FyZd">OpenReview</a> - Original peer review with author responses</li>
<li><a href="/posts/modern-variational-autoencoder-in-pytorch/">Modern VAE in PyTorch</a> - Implementation tutorial on this site</li>
</ul>
]]></content:encoded></item><item><title>αExtractor: Chemical Info from Biomedical Literature</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/alpha-extractor/</link><pubDate>Sat, 11 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/alpha-extractor/</guid><description>αExtractor uses ResNet-Transformer to extract chemical structures from literature images, including noisy and hand-drawn molecules.</description><content:encoded><![CDATA[<h2 id="methodological-contribution-a-robust-optical-recognition-system">Methodological Contribution: A Robust Optical Recognition System</h2>
<p>This is primarily a <strong>Method</strong> ($\Psi_{\text{Method}}$) paper with a significant secondary <strong>Resource</strong> ($\Psi_{\text{Resource}}$) contribution (see the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> for more on these categories).</p>
<p>The dominant methodological contribution is the ResNet-Transformer recognition architecture that outperforms existing OCSR tools across multiple benchmarks through robustness engineering. It specifically focuses on training on 20 million synthetic images with aggressive augmentation to handle degraded image conditions. The work answers the core methodological question &ldquo;How well does this work?&rdquo; through extensive benchmarking against existing OCSR tools and ablation studies validating architectural choices.</p>
<p>The secondary resource contribution comes from releasing αExtractor as a freely available web service, correcting labeling errors in standard benchmarks (CLEF, UOB, JPO), and providing an end-to-end document processing pipeline for biomedical literature mining.</p>
<h2 id="motivation-extracting-visual-chemical-knowledge-from-biomedical-literature">Motivation: Extracting Visual Chemical Knowledge from Biomedical Literature</h2>
<p>The motivation addresses a familiar pain point in chemical informatics within a biomedical context. Vast amounts of chemical knowledge in biomedical literature exist only as images, such as molecular structures embedded in figures, chemical synthesis schemes, and compound diagrams. This visual knowledge remains effectively invisible to computational methods, which creates a massive bottleneck for drug discovery research, systematic reviews, and large-scale chemical database construction.</p>
<p>Existing OCSR tools face two critical problems when applied to biomedical literature:</p>
<ol>
<li>
<p><strong>Real-world image quality</strong>: Biomedical papers often contain low-resolution figures, images with complex backgrounds, noise from scanning/digitization, and inconsistent drawing styles across different journals and decades of publications.</p>
</li>
<li>
<p><strong>End-to-end extraction</strong>: Most OCSR systems assume the presence of clean, cropped molecular images. In practice, you need to first find the molecular structures within multi-panel figures, reaction schemes, and dense document layouts before you can recognize them.</p>
</li>
</ol>
<p>The authors argue that a practical literature mining system needs to solve both problems simultaneously via robust recognition under noisy conditions and automated detection of molecular images within complex documents.</p>
<h2 id="core-innovation-robust-resnet-transformer-architecture">Core Innovation: Robust ResNet-Transformer Architecture</h2>
<p>The core innovation lies in combining a competition-winning recognition architecture with extensive robustness engineering and end-to-end document processing. The key contributions include:</p>
<ol>
<li>
<p><strong>ResNet-Transformer Recognition Model</strong>: The core recognition system uses a <strong>Residual Neural Network (ResNet)</strong> encoder paired with a <strong>Transformer decoder</strong> in an image-captioning framework. This architecture won first place in a Kaggle molecular translation competition, which provided a strong foundation for the recognition task. Let the input image be $I$. The model is trained by minimizing the negative log-likelihood of the SMILES tokens $T$ and the coordinate sequences $X, Y$, with the coordinate terms weighted by $\lambda$:
$$
\begin{aligned}
\mathcal{L}_{\text{total}} = - \sum_{i=1}^{L} \log P(T_i \mid I, T_{&lt;i}) - \lambda \sum_{i=1}^{L} \big(\log P(X_i \mid I, X_{&lt;i}) + \log P(Y_i \mid I, Y_{&lt;i})\big)
\end{aligned}
$$
Here, the continuous $X$ and $Y$ atom coordinates are discretized into 200 bins, so coordinate prediction becomes a standard classification task alongside SMILES generation.</p>
</li>
<li>
<p><strong>Enhanced Molecular Representation</strong>: The model produces an augmented representation that encompasses:</p>
<ul>
<li>Standard molecular connectivity information</li>
<li><strong>Bond type tokens</strong> (solid wedge bonds, dashed bonds, etc.) that preserve 3D stereochemical information</li>
<li><strong>Atom coordinate predictions</strong> that allow reconstruction of the exact molecular pose from the original image</li>
</ul>
<p>This dual prediction of discrete structure and atom coordinates keeps the output faithful to the source depiction and enables better quality assessment.</p>
</li>
<li>
<p><strong>Massive Synthetic Training Dataset</strong>: The model was trained on approximately <strong>20 million synthetic molecular images</strong> generated from PubChem SMILES with aggressive data augmentation. The augmentation strategy randomized visual styles, image quality, and rendering parameters to create maximum diversity, ensuring the network rarely saw the same molecular depiction twice. This forces the model to learn robust, style-invariant features.</p>
</li>
<li>
<p><strong>End-to-End Document Processing Pipeline</strong>: αExtractor integrates <strong>object detection</strong> and <strong>structure recognition</strong> into a complete document mining system:</p>
<ul>
<li>An object detection model automatically locates molecular images within PDF documents</li>
<li>The recognition model converts detected images to structured representations</li>
<li>A web service interface makes the entire pipeline accessible to researchers without machine learning expertise</li>
</ul>
</li>
<li>
<p><strong>Robustness-First Design</strong>: The system was explicitly designed to handle degraded image conditions that break traditional OCSR tools, including low resolution, background interference, color variations, and scanning artifacts commonly found in legacy biomedical literature.</p>
</li>
</ol>
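<p>The coordinate discretization mentioned in the first item above can be sketched as follows. Only the bin count of 200 comes from the description; the uniform binning over the image extent, the function names, and the image size are assumptions, and the actual scheme used by αExtractor may differ.</p>

```python
def coordinate_to_bin(value, extent, num_bins=200):
    """Map a coordinate in [0, extent] to an integer bin in [0, num_bins - 1]."""
    frac = min(max(value / extent, 0.0), 1.0)
    return min(int(frac * num_bins), num_bins - 1)

def bin_to_coordinate(index, extent, num_bins=200):
    """Return the centre of a bin, used to reconstruct a coordinate."""
    return (index + 0.5) / num_bins * extent

width = 384.0  # hypothetical image width in pixels
x_bin = coordinate_to_bin(100.0, width)
print(x_bin, bin_to_coordinate(x_bin, width))
```

<p>Under this assumed scheme, each bin index becomes one class label for the decoder, and the reconstruction error is at most half a bin width.</p>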
<h2 id="experimental-methodology-stress-testing-under-real-world-conditions">Experimental Methodology: Stress Testing under Real-World Conditions</h2>
<p>The evaluation focused on demonstrating robust performance across diverse image conditions, from pristine benchmarks to challenging real-world scenarios:</p>
<ol>
<li>
<p><strong>Benchmark Dataset Evaluation</strong>: αExtractor was tested on four standard OCSR benchmarks:</p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
</ul>
<p>Performance was measured using exact SMILES match accuracy.</p>
</li>
<li>
<p><strong>Error Analysis and Dataset Correction</strong>: During evaluation, the researchers discovered numerous labeling errors in the original benchmark datasets. They systematically identified and corrected these errors, then re-evaluated all methods on the cleaned datasets to get more accurate performance measurements.</p>
</li>
<li>
<p><strong>Robustness Stress Testing</strong>: The system was evaluated on two challenging datasets specifically designed to test robustness:</p>
<ul>
<li><strong>Color background images</strong> (200 samples): Molecular structures on complex, colorful backgrounds that simulate real figure conditions</li>
<li><strong>Low-quality images</strong> (200 samples): Degraded images with noise, blur, and artifacts typical of scanned documents</li>
</ul>
<p>These tests compared αExtractor against three open-source tools (OSRA, MolVec, and Imago) under realistic degradation conditions.</p>
</li>
<li>
<p><strong>Generalization Testing</strong>: In the most challenging experiment, αExtractor was tested on the <strong>DECIMER hand-drawn molecule images dataset</strong> (Brinkhaus et al., 2022), a visual domain entirely absent from the training data. This tested whether the learned features could generalize beyond digital rendering styles to human-drawn chemistry.</p>
</li>
<li>
<p><strong>End-to-End Document Extraction</strong>: The complete pipeline was evaluated on 50 PDF files containing 2,336 molecular images. This tested both the object detection component (finding molecules in complex documents) and the recognition component (converting them to SMILES) in a realistic literature mining scenario.</p>
</li>
<li>
<p><strong>Speed Benchmarking</strong>: Inference time was measured to demonstrate the practical efficiency needed for large-scale document processing.</p>
</li>
</ol>
<h2 id="results--conclusions-strong-performance-on-degraded-images">Results &amp; Conclusions: Strong Performance on Degraded Images</h2>
<ul>
<li>
<p><strong>Substantial Accuracy Gains</strong>: On the four benchmark datasets, αExtractor achieved accuracies of 91.83% (CLEF), 98.47% (UOB), 88.67% (JPO), and 93.64% (USPTO), compared to previous best results of 84.6%, 90.0%, 72.2%, and 89.9% respectively. After correcting dataset labeling errors, the true accuracies were even higher, reaching <strong>95.77% on CLEF, 99.86% on UOB, and 92.44% on JPO</strong>.</p>
</li>
<li>
<p><strong>Robustness on Degraded Images</strong>: Open-source competitors struggled on degraded images (achieving 5.5% accuracy at best). αExtractor maintained <strong>over 90% accuracy</strong> on both color background and low-quality image datasets, demonstrating the effectiveness of the synthetic training strategy.</p>
</li>
<li>
<p><strong>Generalization to Hand-Drawn Molecules</strong>: On hand-drawn molecules, a domain completely absent from training data, αExtractor achieved <strong>61.4% accuracy</strong> while other tools scored between 0.69% and 2.93%. This suggests the model learned genuinely chemical features rather than style-specific patterns.</p>
</li>
<li>
<p><strong>Practical End-to-End Performance</strong>: In the complete document processing evaluation, αExtractor detected <strong>95.1% of molecular images</strong> (2,221 out of 2,336) and correctly recognized <strong>94.5% of detected structures</strong> (2,098 correct predictions). This demonstrates the system&rsquo;s readiness for real-world literature mining applications.</p>
</li>
<li>
<p><strong>Ablation Results</strong>: Ablation experiments confirmed that each architectural component (ResNet backbone, Transformer encoder, Transformer decoder) contributes to performance, with the Transformer decoder having the largest impact. Replacing the Transformer decoder with an LSTM decoder substantially reduced accuracy (Table S6 in the paper).</p>
</li>
<li>
<p><strong>Dataset Quality Issues</strong>: The systematic discovery of labeling errors in standard benchmarks highlights a broader problem in OCSR evaluation. The corrected datasets provide more reliable baselines for future method development.</p>
</li>
<li>
<p><strong>Spatial Layout Limitation</strong>: αExtractor correctly identifies molecular connectivity, but the re-rendered structures may have different spatial layouts than the originals. This could complicate visual verification for complex molecules, even if the chemical information remains accurate.</p>
</li>
<li>
<p><strong>Non-Standard Depiction Handling</strong>: For images with non-standard bond depictions or atomic valences, αExtractor correctly identifies and normalizes them to standard representations. While chemically accurate, this means the re-rendered structure may visually differ from the original image.</p>
</li>
</ul>
<p>Overall, αExtractor combines accurate recognition (over 90% on degraded images), end-to-end document processing, and strong generalization across image conditions. It targets large-scale literature mining tasks where previous tools struggled with degraded inputs. The focus on real-world robustness over benchmark optimization reflects a practical approach to deploying machine learning in scientific workflows.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>This paper is <strong>Partially Reproducible</strong>. The authors detail the model architectures and training techniques, but the source code, training dataset (20M synthetic images), and pre-trained weights remain proprietary and unreleased. The authors did release corrected benchmark labels and samples of their test data, and they host an online web server for running inference.</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/CLEF_corrected">Corrected CLEF Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the CLEF benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/UOB_corrected">Corrected UOB Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the UOB benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/JPO_corrected">Corrected JPO Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Authors&rsquo; corrected version of the JPO benchmark.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Colored_Background">Color Background Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of molecular structures on complex, colorful backgrounds.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/Low_Quality">Low Quality Dataset</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">200 samples of degraded images with noise, blur, and artifacts.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://github.com/jiachengxiong/alpha-Extractor/tree/main/PDF">PDF Test Set</a></td>
          <td style="text-align: left">Dataset</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Sample PDF files for end-to-end document extraction evaluation.</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://extractor.alphama.com.cn/csr">αExtractor Web Server</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Online service for running inference using the proprietary system.</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Backbone:</strong> ResNet50 producing output of shape $2048 \times 19 \times 19$, projected to 512 channels via a feed-forward layer</li>
<li><strong>Transformer Architecture:</strong> 3 encoder layers and 3 decoder layers with hidden dimension of 512</li>
<li><strong>Output Format:</strong> Generates SMILES tokens plus two auxiliary coordinate sequences (X-axis and Y-axis) that are length-aligned with the SMILES tokens via padding</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Architecture:</strong> DETR (Detection Transformer) with ResNet101 backbone</li>
<li><strong>Transformer Architecture:</strong> 6 encoder layers and 6 decoder layers with hidden dimension of 256</li>
<li><strong>Purpose:</strong> Locates molecular images within PDF pages before recognition</li>
</ul>
<p><strong>Coordinate Prediction:</strong></p>
<ul>
<li>Continuous X/Y coordinates are discretized into <strong>200 discrete bins</strong></li>
<li>Padding tokens are added to the coordinate sequences so that they align one-to-one with the SMILES token sequence, enabling simultaneous structure and pose prediction</li>
</ul>
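<p>The binning and padding described above can be sketched as follows. This is a minimal illustration, not the authors' code: it assumes coordinates are normalized to the range 0 to 1, that only atom tokens carry a position, and that a hypothetical <code>[PAD]</code> symbol fills every other slot.</p>

```python
N_BINS = 200   # number of coordinate bins stated in the text
PAD = "[PAD]"  # hypothetical padding symbol


def to_bin(value, n_bins=N_BINS):
    """Map a normalized coordinate in [0, 1] to an integer bin index."""
    clipped = min(max(value, 0.0), 1.0)
    # A value of exactly 1.0 would land in bin n_bins, so clamp to the last bin.
    return min(int(clipped * n_bins), n_bins - 1)


def align_coordinates(tokens, atom_xy):
    """Build X and Y sequences with the same length as the SMILES tokens.

    `atom_xy` maps the index of each atom token to its (x, y) position;
    every other token (bonds, branches, ring closures) receives PAD.
    """
    xs, ys = [], []
    for i, _token in enumerate(tokens):
        if i in atom_xy:
            x, y = atom_xy[i]
            xs.append(to_bin(x))
            ys.append(to_bin(y))
        else:
            xs.append(PAD)
            ys.append(PAD)
    return xs, ys


tokens = ["C", "C", "(", "=", "O", ")", "O"]  # acetic acid, CC(=O)O
atom_xy = {0: (0.125, 0.5), 1: (0.25, 0.5), 4: (0.5, 0.25), 6: (0.5, 1.0)}
xs, ys = align_coordinates(tokens, atom_xy)
```

<p>Because all three sequences share one index, a decoder can emit a token, an X bin, and a Y bin at every step, which is the alignment property the padding is meant to guarantee.</p>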
<h3 id="data">Data</h3>
<p><strong>Training Data:</strong></p>
<ul>
<li><strong>Synthetic Generation:</strong> Python script rendering PubChem SMILES into 2D images</li>
<li><strong>Dataset Size:</strong> Approximately 20.3 million synthetic molecular images from PubChem</li>
<li><strong>Superatom Handling:</strong> 50% of molecules had functional groups replaced with superatoms (e.g., &ldquo;COOH&rdquo;) or generic labels (R1, X1) to match literature drawing conventions</li>
<li><strong>Rendering Augmentation:</strong> Randomized bond thickness, bond spacing, font size, font color, and padding size</li>
</ul>
<p><strong>Geometric Augmentation:</strong></p>
<ul>
<li>Shear along x-axis: $\pm 15^\circ$</li>
<li>Rotation: $\pm 15^\circ$</li>
<li>Piecewise affine scaling</li>
</ul>
<p><strong>Noise Injection:</strong></p>
<ul>
<li>Pepper noise: 0-2%</li>
<li>Salt noise: 0-40%</li>
<li>Gaussian noise: scale 0-0.16</li>
</ul>
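<p>As a rough sketch of the salt and pepper settings above, the snippet below samples a noise level per image and applies it pixel by pixel. Reading the percentages as per-pixel probabilities drawn uniformly from the listed ranges is an assumption, as are the function name and the flat gray placeholder image; Gaussian noise is left out for brevity.</p>

```python
import random


def salt_and_pepper(image, salt_frac, pepper_frac, rng):
    """Return a copy of a grayscale image with salt and pepper noise.

    `image` is a list of rows of ints in 0..255. Each pixel becomes white
    with probability `salt_frac`, black with probability `pepper_frac`,
    and is otherwise left unchanged.
    """
    noisy = []
    for row in image:
        new_row = []
        for pixel in row:
            u = rng.random()
            if salt_frac > u:
                new_row.append(255)
            elif salt_frac + pepper_frac > u:
                new_row.append(0)
            else:
                new_row.append(pixel)
        noisy.append(new_row)
    return noisy


rng = random.Random(0)
# Assumed sampling scheme: one noise level per image from the listed ranges.
salt_frac = rng.uniform(0.0, 0.40)
pepper_frac = rng.uniform(0.0, 0.02)
image = [[128] * 64 for _ in range(64)]  # flat gray placeholder image
noisy = salt_and_pepper(image, salt_frac, pepper_frac, rng)
```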
<p><strong>Destructive Augmentation:</strong></p>
<ul>
<li>JPEG compression: severity levels 2-5</li>
<li>Random masking</li>
</ul>
<p><strong>Evaluation Datasets:</strong></p>
<ul>
<li><strong>CLEF</strong>: Chemical structure recognition challenge dataset</li>
<li><strong>UOB</strong>: University of Birmingham patent images</li>
<li><strong>JPO</strong>: Japan Patent Office molecular diagrams</li>
<li><strong>USPTO</strong>: US Patent and Trademark Office structures</li>
<li><strong>Color background images</strong>: 200 samples</li>
<li><strong>Low-quality images</strong>: 200 samples</li>
<li><strong>Hand-drawn structures</strong>: Test set for generalization</li>
<li><strong>End-to-end document extraction</strong>: 50 PDFs (567 pages, 2,336 molecular images)</li>
</ul>
<h3 id="training">Training</h3>
<p><strong>Image Recognition Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 100</li>
<li><strong>Epochs:</strong> 5</li>
<li><strong>Loss Function:</strong> Cross-entropy loss for both SMILES prediction and coordinate prediction</li>
</ul>
<p><strong>Object Detection Model:</strong></p>
<ul>
<li><strong>Optimizer:</strong> Adam with learning rate of 1e-4</li>
<li><strong>Batch Size:</strong> 24</li>
<li><strong>Training Strategy:</strong> Pre-trained on synthetic &ldquo;Lower Quality&rdquo; data for 5 epochs, then fine-tuned on annotated real &ldquo;High Quality&rdquo; data for 30 epochs</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics:</strong></p>
<ul>
<li><strong>Recognition</strong>: SMILES accuracy (exact match)</li>
<li><strong>End-to-End Pipeline</strong>:
<ul>
<li><strong>Recall</strong>: 95.1% for detection</li>
<li><strong>Accuracy</strong>: 94.5% for recognition</li>
</ul>
</li>
</ul>
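<p>The two pipeline metrics follow directly from the counts reported in the results (2,336 images, 2,221 detected, 2,098 recognized correctly). The short check below recomputes them; the combined end-to-end yield is derived here from those counts and is not a figure quoted from the paper.</p>

```python
total_images = 2336   # molecular images in the 50 PDFs
detected = 2221       # images found by the detection model
correct = 2098        # detected images recognized correctly

detection_recall = detected / total_images
recognition_accuracy = correct / detected
# Fraction of all images that come out of the full pipeline correct.
end_to_end_yield = correct / total_images

print(f"recall={detection_recall:.1%} "
      f"accuracy={recognition_accuracy:.1%} "
      f"yield={end_to_end_yield:.1%}")
```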
<h3 id="hardware">Hardware</h3>
<p><strong>Inference Hardware:</strong></p>
<ul>
<li>Cloud CPU server (8 CPUs, 64 GB RAM)</li>
<li><strong>Throughput:</strong> Processed 50 PDFs (567 pages) in 40 minutes</li>
</ul>
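<p>For capacity planning, the reported throughput can be converted into per-unit costs. This is simple arithmetic on the numbers above; it assumes the 40 minutes cover the whole pipeline (detection plus recognition) and says nothing about how the time splits between stages.</p>

```python
pdf_count = 50
page_count = 567
image_count = 2336        # molecular images in the evaluation PDFs
wall_clock_s = 40 * 60    # 40 minutes on the 8-CPU server

seconds_per_pdf = wall_clock_s / pdf_count
seconds_per_page = wall_clock_s / page_count
seconds_per_image = wall_clock_s / image_count

print(f"{seconds_per_pdf:.1f} s/PDF, {seconds_per_page:.2f} s/page, "
      f"{seconds_per_image:.2f} s/image")
```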
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Xiong, J., Liu, X., Li, Z., Xiao, H., Wang, G., Niu, Z., Fei, C., Zhong, F., Wang, G., Zhang, W., Fu, Z., Liu, Z., Chen, K., Jiang, H., &amp; Zheng, M. (2023). αExtractor: a system for automatic extraction of chemical information from biomedical literature. <em>Science China Life Sciences</em>, 67(3), 618-621. <a href="https://doi.org/10.1007/s11427-023-2388-x">https://doi.org/10.1007/s11427-023-2388-x</a></p>
<p><strong>Publication</strong>: Science China Life Sciences (2023)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://doi.org/10.1007/s11427-023-2388-x">Paper on Springer</a></li>
</ul>
]]></content:encoded></item><item><title>MolNexTR: A Dual-Stream Molecular Image Recognition</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molnextr/</link><pubDate>Sat, 04 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molnextr/</guid><description>Dual-stream encoder combining ConvNext and ViT for robust optical chemical structure recognition across diverse molecular drawing styles.</description><content:encoded><![CDATA[<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chen, Y., Leung, C. T., Huang, Y., Sun, J., Chen, H., &amp; Gao, H. (2024). MolNexTR: a generalized deep learning model for molecular image recognition. <em>Journal of Cheminformatics</em>, 16(1), 141. <a href="https://doi.org/10.1186/s13321-024-00926-w">https://doi.org/10.1186/s13321-024-00926-w</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics 2024</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/CYF2000127/MolNexTR">GitHub Repository</a></li>
<li><a href="https://huggingface.co/datasets/CYF200127/MolNexTR/tree/main">HuggingFace Dataset/Model</a></li>
</ul>
<h2 id="methodology-overview-and-taxonomic-classification">Methodology Overview and Taxonomic Classification</h2>
<p>This is a <strong>Method</strong> paper ($\Psi_{\text{Method}}$). It proposes a neural network architecture (MolNexTR) that integrates ConvNext and Vision Transformers to solve the Optical Chemical Structure Recognition (OCSR) task. The paper validates this method through ablation studies and benchmarking against existing methods including MolScribe and DECIMER.</p>
<h2 id="the-challenge-of-domain-specific-drawing-styles-in-ocsr">The Challenge of Domain-Specific Drawing Styles in OCSR</h2>
<p>Converting molecular images from chemical literature into machine-readable formats (SMILES) is critical but challenging due to the high variance in drawing styles, fonts, and conventions (e.g., Markush structures, abbreviations). Existing methods have limitations:</p>
<ul>
<li>CNN-based and ViT-based models often struggle to generalize across diverse, non-standard drawing styles found in real literature.</li>
<li>Pure ViT methods lack translation invariance and local feature representation, while pure CNNs struggle with global dependencies.</li>
<li>Many models predict SMILES strings directly, making it difficult to enforce chemical validity or resolve complex stereochemistry and abbreviations.</li>
</ul>
<h2 id="core-innovation-dual-stream-encoding-and-image-contamination">Core Innovation: Dual-Stream Encoding and Image Contamination</h2>
<p>MolNexTR introduces three main innovations:</p>
<ol>
<li><strong>Dual-Stream Encoder</strong>: A hybrid architecture processing images simultaneously through a ConvNext stream (for local features) and a Vision Transformer stream (for long-range dependencies), fusing them to capture multi-scale information.</li>
<li><strong>Image Contamination Augmentation</strong>: A specialized data augmentation algorithm that simulates real-world &ldquo;noise&rdquo; found in literature, such as overlapping text, arrows, and partial molecular fragments, to improve robustness.</li>
<li><strong>Graph-Based Decoding with Post-Processing</strong>: Unlike pure image-to-SMILES translation, it predicts atoms and bonds (graph generation) and uses a stereochemical discrimination and abbreviation self-correction module to enforce chemical rules (e.g., chirality) and resolve superatoms (e.g., &ldquo;Ph&rdquo;, &ldquo;Bn&rdquo;).</li>
</ol>
<p>The prediction of atom labels and coordinates is formulated as a conditional autoregressive generation task, optimized via a cross-entropy loss:
$$ \mathcal{L}_{\text{atom}} = -\sum_{t=1}^{T} \log P(x_t \mid \text{Image}, x_{&lt;t}) $$</p>
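<p>To make the loss concrete, the sketch below evaluates it for a single target sequence. The per-step probabilities are made-up values standing in for the model's softmax output at the correct token; the function name is hypothetical.</p>

```python
import math


def sequence_nll(step_probs):
    """Negative log-likelihood of one target sequence.

    `step_probs[t]` is the probability the model assigned to the correct
    token at step t, already conditioned on the image and earlier tokens.
    """
    return -sum(math.log(p) for p in step_probs)


# Hypothetical per-step probabilities for a four-token target.
probs = [0.9, 0.8, 0.95, 0.5]
loss = sequence_nll(probs)
```

<p>Summing log-probabilities is equivalent to taking the log of their product, so a single low-confidence step (here 0.5) dominates the loss for the whole sequence.</p>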
<h2 id="experimental-setup-benchmarking-on-synthetic-and-real-data">Experimental Setup: Benchmarking on Synthetic and Real Data</h2>
<p>The model was trained on synthetic data (PubChem) and real patent data (USPTO). It was evaluated on nine benchmarks (three synthetic, six real-world):</p>
<ul>
<li><strong>Synthetic</strong>: Indigo, ChemDraw, RDKit (rendered from 5,719 molecules)</li>
<li><strong>Real-World</strong>: CLEF, UOB, JPO, USPTO, Staker, and a newly curated ACS dataset (diverse styles)</li>
</ul>
<p><strong>Baselines</strong>: Compared against rule-based (OSRA, MolVec) and deep learning models (MolScribe, DECIMER, SwinOCSR, Img2Mol).</p>
<p><strong>Ablations</strong>: Tested the impact of the dual-stream encoder vs. single streams, and the contribution of individual augmentation strategies.</p>
<h2 id="empirical-results-and-robustness-findings">Empirical Results and Robustness Findings</h2>
<ul>
<li><strong>Performance</strong>: MolNexTR achieved 81.9-97.8% accuracy across the nine test sets, outperforming the second-best method (often MolScribe) by margins ranging from 0.3 percentage points to 10.0 percentage points (on the difficult ACS dataset).</li>
<li><strong>Perturbation resilience</strong>: The model maintained higher accuracy under image perturbations (rotation, noise) and &ldquo;curved arrow&rdquo; noise common in reaction mechanisms compared to MolScribe and DECIMER (Table 3).</li>
<li><strong>Ablation Results</strong>: The dual-stream encoder consistently outperformed single CNN or ViT baselines, and the image contamination algorithm significantly boosted performance on noisy real-world data (ACS).</li>
<li><strong>Limitations</strong>: The model still struggles with extremely complex hand-drawn molecules and mechanism diagrams where arrows or text are conflated with structure. The authors also note that R-group information in real literature often appears in separate text or tables, which the model does not incorporate.</li>
</ul>
<p><strong>Key Results (Table 2, SMILES exact match accuracy %)</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MolScribe</th>
          <th>MolNexTR</th>
          <th>Improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Indigo</td>
          <td>97.5</td>
          <td>97.8</td>
          <td>+0.3</td>
      </tr>
      <tr>
          <td>ChemDraw</td>
          <td>93.8</td>
          <td>95.1</td>
          <td>+1.3</td>
      </tr>
      <tr>
          <td>RDKit</td>
          <td>94.6</td>
          <td>96.4</td>
          <td>+1.8</td>
      </tr>
      <tr>
          <td>CLEF</td>
          <td>88.3</td>
          <td>90.4</td>
          <td>+2.1</td>
      </tr>
      <tr>
          <td>UOB</td>
          <td>87.9</td>
          <td>88.5</td>
          <td>+0.6</td>
      </tr>
      <tr>
          <td>JPO</td>
          <td>77.7</td>
          <td>82.1</td>
          <td>+4.4</td>
      </tr>
      <tr>
          <td>USPTO</td>
          <td>92.6</td>
          <td>93.8</td>
          <td>+1.2</td>
      </tr>
      <tr>
          <td>Staker</td>
          <td>86.9</td>
          <td>88.3</td>
          <td>+1.4</td>
      </tr>
      <tr>
          <td>ACS</td>
          <td>71.9</td>
          <td>81.9</td>
          <td>+10.0</td>
      </tr>
  </tbody>
</table>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data</strong>:</p>
<ul>
<li><strong>Synthetic</strong>: ~1M molecules randomly selected from PubChem, rendered using RDKit and Indigo with varied styles (thickness, fonts, bond width)</li>
<li><strong>Real</strong>: 0.68M images from USPTO, with coordinates normalized from MOLfiles</li>
</ul>
<p><strong>Augmentation</strong>:</p>
<ul>
<li><strong>Render Augmentation</strong>: Randomized drawing styles (line width, font size, label modes)</li>
<li><strong>Image Augmentation</strong>: Rotation, cropping, blurring, noise (Gaussian, salt-and-pepper)</li>
<li><strong>Molecular Augmentation</strong>: Randomly replacing functional groups with abbreviations (from a list of &gt;100) or complex chains (e.g., CH3CH2NH2); adding R-groups</li>
<li><strong>Image Contamination</strong>: Adding &ldquo;noise&rdquo; objects (arrows, lines, text, partial structures) at a minimum distance from the main molecule to simulate literature artifacts</li>
</ul>
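<p>The minimum-distance constraint in the image contamination step can be illustrated with rejection sampling. This is a sketch under assumptions, not the paper's algorithm: the rectangular bounding boxes, the 15-pixel gap, the canvas size, and the function names are all hypothetical.</p>

```python
import math
import random


def box_gap(a, b):
    """Euclidean gap between two boxes (x0, y0, x1, y1); 0 if they touch or overlap."""
    dx = max(a[0] - b[2], b[0] - a[2], 0)
    dy = max(a[1] - b[3], b[1] - a[3], 0)
    return math.hypot(dx, dy)


def place_contaminant(canvas, mol_box, size, min_gap, rng, max_tries=100):
    """Sample a box of `size` on the canvas at least `min_gap` px from `mol_box`.

    Returns the box, or None when no valid position is found.
    """
    width, height = canvas
    w, h = size
    for _ in range(max_tries):
        x0 = rng.randint(0, width - w)
        y0 = rng.randint(0, height - h)
        box = (x0, y0, x0 + w, y0 + h)
        if box_gap(box, mol_box) >= min_gap:
            return box
    return None


rng = random.Random(0)
mol_box = (120, 120, 260, 260)  # hypothetical molecule bounding box
box = place_contaminant((384, 384), mol_box, (40, 20), 15, rng)
```

<p>Keeping the noise object outside the molecule's margin means the label stays valid: the distractor appears in the image but never overwrites atoms or bonds.</p>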
<h3 id="algorithms">Algorithms</h3>
<p><strong>Dual-Stream Encoder</strong>:</p>
<ul>
<li><strong>CNN Stream</strong>: ConvNext backbone (pre-trained on ImageNet), generating feature maps at scales $H/4$ to $H/32$</li>
<li><strong>ViT Stream</strong>: Parallel transformer blocks receiving patches of sizes $p=4, 8, 16, 32$. Uses Multi-Head Self-Attention (MHSA) and Feed-Forward Networks (FFN)</li>
<li><strong>Fusion</strong>: Outputs from both streams are concatenated</li>
</ul>
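<p>A quick calculation shows how the two streams line up spatially. It assumes the 384 by 384 input size listed in the Models section, that the CNN scales between $H/4$ and $H/32$ are the usual strides 4, 8, 16, and 32, and that the ViT patches do not overlap; none of these details is confirmed beyond what the lists state.</p>

```python
IMAGE_SIZE = 384  # input resolution listed under Models

# Assumption: a feature map at scale H/s has side length H // s, and the
# ViT stream cuts the image into non-overlapping p x p patches.
cnn_strides = [4, 8, 16, 32]
vit_patch_sizes = [4, 8, 16, 32]

cnn_grids = [IMAGE_SIZE // s for s in cnn_strides]
vit_grids = [IMAGE_SIZE // p for p in vit_patch_sizes]
tokens_per_scale = [g * g for g in vit_grids]

for p, g, n in zip(vit_patch_sizes, vit_grids, tokens_per_scale):
    print(f"patch {p:>2}: {g} x {g} grid, {n} tokens")
```

<p>Under these assumptions each patch size yields the same grid as the matching CNN scale, which is one plausible reason the two sets of features can be concatenated scale by scale.</p>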
<p><strong>Decoder (Graph Generation)</strong>:</p>
<ul>
<li><strong>Transformer Decoder</strong>: 6 layers, 8 heads, hidden dim 256</li>
<li><strong>Task 1 (Atoms)</strong>: Autoregressive prediction of atom tokens $(l, x, y)$ (label + coordinates)</li>
<li><strong>Task 2 (Bonds)</strong>: Prediction of bond types between atom pairs (None, Single, Double, Triple, Aromatic, Solid Wedge, Dashed Wedge)</li>
</ul>
<p><strong>Post-Processing</strong>:</p>
<ul>
<li><strong>Stereochemistry</strong>: Uses predicted coordinates and bond types (wedge/dash) to resolve chirality using RDKit logic</li>
<li><strong>Abbreviation Correction</strong>: Matches superatoms to a dictionary; if unknown, attempts to greedily connect atoms based on valence or finds the nearest match ($\sigma=0.8$ similarity threshold)</li>
</ul>
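<p>The dictionary lookup with a nearest-match fallback might look like the sketch below. The similarity measure is not specified in this summary, so <code>difflib.SequenceMatcher</code> from the Python standard library is used as a stand-in; the tiny abbreviation table and its SMILES fragments are illustrative, and the valence-based greedy connection branch is omitted.</p>

```python
import difflib

# Illustrative table only; the real list has more than 100 entries.
ABBREVIATIONS = {
    "Ph": "c1ccccc1",
    "Bn": "Cc1ccccc1",
    "OMe": "OC",
    "COOH": "C(=O)O",
    "COOCH3": "C(=O)OC",
}


def resolve_abbreviation(label, table=ABBREVIATIONS, threshold=0.8):
    """Return (matched_key, fragment) for a predicted superatom label, or None."""
    if label in table:
        return label, table[label]
    best_key, best_score = None, 0.0
    for key in table:
        # Stand-in similarity: ratio of matching characters, between 0 and 1.
        score = difflib.SequenceMatcher(None, label, key).ratio()
        if score > best_score:
            best_key, best_score = key, score
    if best_key is not None and best_score >= threshold:
        return best_key, table[best_key]
    return None


exact = resolve_abbreviation("Ph")
fuzzy = resolve_abbreviation("COOCH8")   # one misread character
unknown = resolve_abbreviation("Xyz")
```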
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: Encoder-Decoder (ConvNext + ViT Encoder -&gt; Transformer Decoder)</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam (max lr 3e-4, linear warmup for 5% of steps)</li>
<li>Batch Size: 256</li>
<li>Image Size: $384 \times 384$</li>
<li>Dropout: 0.1</li>
</ul>
</li>
<li><strong>Training</strong>: Fine-tuned CNN backbone for 40 epochs on 10 NVIDIA RTX 3090 GPUs</li>
</ul>
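<p>The warmup portion of the learning-rate schedule can be written in a few lines. Only the peak rate (3e-4) and the 5% linear warmup come from the list above; holding the rate constant afterwards is an assumption made to keep the sketch short, since the decay rule is not stated here.</p>

```python
def learning_rate(step, total_steps, max_lr=3e-4, warmup_frac=0.05):
    """Linear warmup to `max_lr`, then hold (post-warmup behavior is assumed)."""
    warmup_steps = max(1, round(total_steps * warmup_frac))
    if warmup_steps > step:
        # Ramp linearly so the final warmup step reaches max_lr.
        return max_lr * (step + 1) / warmup_steps
    return max_lr


total_steps = 1000
schedule = [learning_rate(s, total_steps) for s in range(total_steps)]
```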
<h3 id="evaluation">Evaluation</h3>
<p><strong>Primary Metric</strong>: SMILES sequence exact matching accuracy (canonicalized)</p>
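<p>A minimal sketch of the exact-match metric is shown below. The canonicalization step is passed in as a function because it would normally come from a cheminformatics toolkit; the <code>toy_canonicalize</code> placeholder only strips whitespace and is an assumption for illustration, as are the function names.</p>

```python
def exact_match_accuracy(predictions, references, canonicalize):
    """Percentage of predictions whose canonical form equals the reference's."""
    if not references or len(predictions) != len(references):
        raise ValueError("need equally sized, non-empty prediction and reference lists")
    hits = 0
    for pred, ref in zip(predictions, references):
        canon_pred = canonicalize(pred)
        # A prediction that cannot be parsed (None) never counts as a match.
        if canon_pred is not None and canon_pred == canonicalize(ref):
            hits += 1
    return 100.0 * hits / len(references)


def toy_canonicalize(smiles):
    """Placeholder only: strips whitespace and rejects empty strings."""
    cleaned = smiles.strip()
    return cleaned or None


preds = ["CCO", " c1ccccc1", "CC(=O)O", ""]
refs = ["CCO", "c1ccccc1", "CC(O)=O", "C"]
score = exact_match_accuracy(preds, refs, toy_canonicalize)
```

<p>The third pair is the same molecule written two ways, and the placeholder scores it as a miss. That is the reason the metric is defined on canonicalized SMILES: without real canonicalization, equivalent strings would be counted as errors.</p>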
<p><strong>Benchmarks</strong>:</p>
<ul>
<li><strong>Synthetic</strong>: Indigo (5,719), ChemDraw (5,719), RDKit (5,719)</li>
<li><strong>Real</strong>: CLEF (992), UOB (5,740), JPO (450), USPTO (5,719), Staker (50,000), ACS (331)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPUs</strong>: 10 NVIDIA RTX 3090 GPUs</li>
<li><strong>Cluster</strong>: HPC3 Cluster at HKUST (ITSC)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/CYF2000127/MolNexTR">MolNexTR GitHub</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official implementation (PyTorch, Jupyter notebooks)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/CYF200127/MolNexTR">MolNexTR HuggingFace</a></td>
          <td>Dataset/Model</td>
          <td>Apache-2.0</td>
          <td>Training data and model checkpoint</td>
      </tr>
  </tbody>
</table>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{chenMolNexTRGeneralizedDeep2024,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{{MolNexTR}: a generalized deep learning model for molecular image recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chen, Yufan and Leung, Ching Ting and Huang, Yong and Sun, Jianwei and Chen, Hao and Gao, Hanyu}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{16}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{141}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-024-00926-w}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>MolParser-7M &amp; WildMol: Large-Scale OCSR Datasets</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molparser_7m-wildmol/</link><pubDate>Fri, 03 Oct 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/molparser_7m-wildmol/</guid><description>MolParser-7M is the largest open-source OCSR dataset with 7.7M image-SMILES pairs including 400k real-world annotated samples.</description><content:encoded><![CDATA[<h2 id="dataset-examples">Dataset Examples</h2>
<figure class="post-figure center ">
    <img src="/img/molparser-markush-example.webp"
         alt="Example of a complex Markush structure"
         title="Example of a complex Markush structure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">An example of a complex Markush structure that can be represented by the E-SMILES format but not by standard SMILES or FG-SMILES.</figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/molparser-low-quality-example.webp"
         alt="Sample from the WildMol benchmark"
         title="Sample from the WildMol benchmark"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A sample from the WildMol benchmark, showing a low-quality, noisy molecular image cropped from real-world literature that challenges OCSR systems.</figcaption>
    
</figure>
<figure class="post-figure center ">
    <img src="/img/molparser-colored-example.webp"
         alt="Colored molecule with annotations"
         title="Colored molecule with annotations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A colored molecule with annotations, representing the diverse drawing styles found in scientific papers that OCSR models must handle.</figcaption>
    
</figure>

<h2 id="dataset-subsets">Dataset Subsets</h2>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Count</th>
          <th>Description</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>MolParser-7M (Training Set)</strong></td>
          <td>7,740,871</td>
          <td>A large-scale dataset for training OCSR models, split into pre-training and fine-tuning stages.</td>
      </tr>
      <tr>
          <td><strong>WildMol (Test Set)</strong></td>
          <td>20,000</td>
          <td>A benchmark of 20,000 human-annotated samples cropped from real PDF files to evaluate OCSR models in &lsquo;in-the-wild&rsquo; scenarios. Comprises WildMol-10k (10k ordinary molecules) and WildMol-10k-M (10k Markush structures).</td>
      </tr>
  </tbody>
</table>
<h2 id="benchmarks">Benchmarks</h2>

<div class="benchmarks-content">
  <div class="benchmark-section">
    <h3 id="wildmol-10k-accuracy">WildMol-10k Accuracy<a hidden class="anchor" aria-hidden="true" href="#wildmol-10k-accuracy">#</a></h3>
    <p class="benchmark-description">Evaluation of OCSR models on 10,000 real-world molecular images cropped from scientific literature and patents</p>
    <table class="benchmark-table">
      <thead>
        <tr>
          <th>Rank</th>
          <th>Model</th>
          <th>Accuracy (%)</th>
        </tr>
      </thead>
      <tbody>
        <tr class="top-result">
          <td>🥇 1</td>
          <td>
            <strong>MolParser-Base</strong><br><small>End-to-end visual recognition trained on MolParser-7M</small>
          </td>
          <td>76.9</td>
        </tr>
        <tr class="top-result">
          <td>🥈 2</td>
          <td>
            <strong>MolScribe</strong><br><small>Transformer-based OCSR system</small>
          </td>
          <td>66.4</td>
        </tr>
        <tr class="top-result">
          <td>🥉 3</td>
          <td>
            <strong>DECIMER 2.7</strong><br><small>Deep learning for chemical image recognition</small>
          </td>
          <td>56</td>
        </tr>
        <tr>
          <td>4</td>
          <td>
            <strong>MolGrapher</strong><br><small>Graph-based molecular structure recognition</small>
          </td>
          <td>45.5</td>
        </tr>
        <tr>
          <td>5</td>
          <td>
            <strong>MolVec 0.9.7</strong><br><small>Vector-based structure recognition</small>
          </td>
          <td>26.4</td>
        </tr>
        <tr>
          <td>6</td>
          <td>
            <strong>OSRA 2.1</strong><br><small>Optical Structure Recognition Application</small>
          </td>
          <td>26.3</td>
        </tr>
        <tr>
          <td>7</td>
          <td>
            <strong>Img2Mol</strong><br><small>Image-to-molecule translation</small>
          </td>
          <td>24.4</td>
        </tr>
        <tr>
          <td>8</td>
          <td>
            <strong>Imago 2.0</strong><br><small>Chemical structure recognition toolkit</small>
          </td>
          <td>6.9</td>
        </tr>
      </tbody>
    </table>
  </div>
</div>

<h2 id="key-contribution">Key Contribution</h2>
<p>Introduces MolParser-7M, the largest open-source Optical Chemical Structure Recognition (OCSR) dataset, uniquely combining diverse synthetic data with a large volume of manually-annotated, &ldquo;in-the-wild&rdquo; images from real scientific documents to improve model robustness. Also introduces WildMol, a new challenging benchmark for evaluating OCSR performance on real-world data, including Markush structures.</p>
<h2 id="overview">Overview</h2>
<p>The MolParser project addresses the challenge of recognizing molecular structures from images found in real-world scientific documents. Unlike existing OCSR datasets that rely primarily on synthetically generated images, MolParser-7M incorporates 400,000 manually annotated images cropped from actual patents and scientific papers, making it the first large-scale dataset to bridge the gap between synthetic training data and real-world deployment scenarios.</p>
<h2 id="strengths">Strengths</h2>
<ul>
<li>Largest open-source OCSR dataset with over 7.7 million pairs</li>
<li>The only large-scale OCSR training set that includes a significant amount (400k) of &ldquo;in-the-wild&rdquo; data cropped from real patents and literature</li>
<li>High diversity of molecular structures from numerous sources (PubChem, ChEMBL, polymers, etc.)</li>
<li>Introduces the WildMol benchmark for evaluating performance on challenging, real-world data, including Markush structures</li>
<li>The &ldquo;in-the-wild&rdquo; fine-tuning data (MolParser-SFT-400k) was curated via an efficient active learning data engine with human-in-the-loop validation</li>
</ul>
<h2 id="limitations">Limitations</h2>
<ul>
<li>The E-SMILES format cannot represent certain complex cases, such as coordination bonds, dashed abstract rings, Markush structures depicted with special patterns, and replication of long structural segments on the skeleton</li>
<li>The model and data do not yet fully exploit molecular chirality, which is critical for chemical properties</li>
<li>Performance could be further improved by scaling up the amount of real annotated training data</li>
</ul>
<h2 id="technical-notes">Technical Notes</h2>
<h3 id="synthetic-data-generation">Synthetic Data Generation</h3>
<p>To ensure diversity, molecular structures were collected from databases like ChEMBL, PubChem, and Kaggle BMS. A significant number of Markush, polymer, and fused-ring structures were also randomly generated. Images were rendered using RDKit and epam.indigo with randomized parameters (e.g., bond width, font size, rotation) to increase visual diversity. The pretraining dataset is composed of the following subsets:</p>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Ratio</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Markush-3M</td>
          <td>40%</td>
          <td>Random groups replacement from PubChem</td>
      </tr>
      <tr>
          <td>ChEMBL-2M</td>
          <td>27%</td>
          <td>Molecules selected from ChEMBL</td>
      </tr>
      <tr>
          <td>Polymer-1M</td>
          <td>14%</td>
          <td>Randomly generated polymer molecules</td>
      </tr>
      <tr>
          <td>PAH-600k</td>
          <td>8%</td>
          <td>Randomly generated fused-ring molecules</td>
      </tr>
      <tr>
          <td>BMS-360k</td>
          <td>5%</td>
          <td>Molecules with long carbon chains from BMS</td>
      </tr>
      <tr>
          <td>MolGrapher-300k</td>
          <td>4%</td>
          <td>Training data from MolGrapher</td>
      </tr>
      <tr>
          <td>Pauling-100k</td>
          <td>2%</td>
          <td>Pauling-style images drawn using epam.indigo</td>
      </tr>
  </tbody>
</table>
<h3 id="in-the-wild-data-engine-molparser-sft-400k">In-the-Wild Data Engine (MolParser-SFT-400k)</h3>
<p>A YOLO11 object detection model (MolDet) located and cropped over 20 million molecule images from 1.22 million real PDFs (patents and papers). After de-duplication via p-hash similarity, 4 million unique images remained.</p>
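<p>The de-duplication step can be sketched as a greedy filter over perceptual hashes. This is a minimal illustration, assuming each image has already been reduced to an integer hash; the function names and the 4-bit Hamming threshold are hypothetical and are not taken from the paper.</p>

```python
def hamming_distance(a: int, b: int) -> int:
    """Number of differing bits between two integer hashes."""
    return bin(a ^ b).count("1")


def deduplicate(hashes: list[int], max_distance: int = 4) -> list[int]:
    """Return indices of images kept after greedy near-duplicate removal.

    An image is dropped when its hash is within `max_distance` bits of an
    already-kept image. The 4-bit threshold is an illustrative assumption.
    """
    kept: list[int] = []
    for idx, h in enumerate(hashes):
        if all(hamming_distance(h, hashes[k]) > max_distance for k in kept):
            kept.append(idx)
    return kept


# Toy 16-bit hashes: the second differs from the first by a single bit.
example_hashes = [0b1010101010101010, 0b1010101010101011, 0b0101010101010101]
kept_indices = deduplicate(example_hashes)
```

<p>The pairwise comparison here is quadratic in the number of kept images, so a pipeline handling millions of crops would likely need an index structure or hash bucketing instead of this direct loop.</p>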
<p>An active learning algorithm was used to select the most informative samples for annotation, targeting images where an ensemble of 5-fold models showed moderate confidence (0.6-0.9 Tanimoto similarity), indicating they were challenging but learnable.</p>
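<p>The selection rule can be sketched as follows. This is a rough illustration that assumes the ensemble confidence is the mean pairwise Tanimoto similarity between the five models' predictions, each represented as a set of fingerprint bits; the function names and the toy data are hypothetical.</p>

```python
from itertools import combinations


def tanimoto(a: frozenset, b: frozenset) -> float:
    """Tanimoto (Jaccard) similarity between two sets of fingerprint bits."""
    if not a and not b:
        return 1.0
    return len(a.intersection(b)) / len(a.union(b))


def ensemble_confidence(predictions: list) -> float:
    """Mean pairwise Tanimoto similarity across the ensemble's predictions."""
    pairs = list(combinations(predictions, 2))
    return sum(tanimoto(a, b) for a, b in pairs) / len(pairs)


def select_for_annotation(candidates: dict, low: float = 0.6, high: float = 0.9) -> list:
    """Keep image ids whose ensemble confidence falls within [low, high]."""
    return [
        image_id
        for image_id, predictions in candidates.items()
        if low <= ensemble_confidence(predictions) <= high
    ]


# Toy fingerprints for three images, each predicted by five models.
agree = frozenset({1, 2, 3, 4})
candidates = {
    "easy": [agree] * 5,                                       # full agreement
    "informative": [agree] * 4 + [frozenset({1, 2, 3, 5})],    # one dissenter
    "hard": [frozenset({i}) for i in range(5)],                # no agreement
}
selected = select_for_annotation(candidates)
```

<p>In this toy case only the image with partial disagreement is selected: full agreement suggests the sample is already learned, and no agreement suggests it is too hard to be useful yet.</p>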
<p>This active learning approach with model pre-annotations reduced manual annotation time per molecule to 30 seconds, approximately 90% savings compared to annotating from scratch. In the final fine-tuning dataset, 56.04% of annotations directly utilized raw model pre-annotations, 20.97% passed review after a single manual correction, 13.87% were accepted after a second round of annotation, and 9.13% required three or more rounds.</p>
<p>The fine-tuning dataset is composed of:</p>
<table>
  <thead>
      <tr>
          <th>Subset</th>
          <th>Ratio</th>
          <th>Source</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MolParser-SFT-400k</td>
          <td>66%</td>
          <td>Manually annotated data obtained via data engine</td>
      </tr>
      <tr>
          <td>MolParser-Gen-200k</td>
          <td>32%</td>
          <td>Synthetic data selected from pretraining stage</td>
      </tr>
      <tr>
          <td>Handwrite-5k</td>
          <td>1%</td>
          <td>Handwritten molecules selected from Img2Mol</td>
      </tr>
  </tbody>
</table>
<h3 id="e-smiles-specification">E-SMILES Specification</h3>
<p>To accommodate complex patent structures that standard SMILES cannot support, the authors introduced an Extended SMILES format (<code>SMILES&lt;sep&gt;EXTENSION</code>). The <code>EXTENSION</code> component uses XML-like tokens to manage complexities:</p>
<ul>
<li><code>&lt;a&gt;...&lt;/a&gt;</code> encapsulates Markush R-groups and abbreviation groups.</li>
<li><code>&lt;r&gt;...&lt;/r&gt;</code> denotes ring attachments with uncertain positions.</li>
<li><code>&lt;c&gt;...&lt;/c&gt;</code> defines abstract rings.</li>
<li><code>&lt;dum&gt;</code> identifies a connection point.</li>
</ul>
<p>This format enables Markush-molecule matching and LLM integration, while retaining RDKit compatibility for the standard SMILES portion.</p>
<h2 id="reproducibility">Reproducibility</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/datasets/UniParser/MolParser-7M">MolParser-7M</a></td>
          <td>Dataset</td>
          <td>CC-BY-NC-SA-4.0</td>
          <td>Training and test data on HuggingFace. SFT subset is partially released.</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/UniParser/MolDet">MolDet (YOLO11)</a></td>
          <td>Model</td>
          <td>Unknown</td>
          <td>Molecule detection model on HuggingFace</td>
      </tr>
      <tr>
          <td><a href="https://ocsr.dp.tech/">MolParser Demo</a></td>
          <td>Other</td>
          <td>N/A</td>
          <td>Online OCSR demo using MolParser-Base</td>
      </tr>
  </tbody>
</table>
<p>The dataset is publicly available on HuggingFace under a CC-BY-NC-SA-4.0 (non-commercial) license. The MolParser-SFT-400k subset is only partially released. The YOLO11-based MolDet detection model is also available on HuggingFace. No public code repository is provided for the MolParser recognition model itself. All experiments were conducted on 8 NVIDIA RTX 4090D GPUs, and throughput benchmarks were measured on a single RTX 4090D GPU.</p>
]]></content:encoded></item><item><title>DenoiseVAE: Adaptive Noise for Molecular Pre-training</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/</link><pubDate>Sun, 24 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/denoise-vae/</guid><description>Liu et al.'s ICLR 2025 paper introducing DenoiseVAE, which learns adaptive, atom-specific noise distributions for better molecular force fields.</description><content:encoded><![CDATA[<h2 id="paper-contribution-type">Paper Contribution Type</h2>
<p>This is a <strong>method paper</strong> with a supporting theoretical component. It introduces a new pre-training framework, DenoiseVAE, that challenges the standard practice of using fixed, hand-crafted noise distributions in denoising-based molecular representation learning.</p>
<h2 id="motivation-the-inter--and-intra-molecular-variations-problem">Motivation: The Inter- and Intra-molecular Variations Problem</h2>
<p>The motivation is to create a more physically principled denoising pre-training task for 3D molecules. The core idea of denoising is to learn molecular force fields by corrupting an equilibrium conformation with noise and then learning to recover it. However, existing methods use a single, hand-crafted noise strategy (e.g., Gaussian noise of a fixed scale) for all atoms across all molecules. This is physically unrealistic for two main reasons:</p>
<ol>
<li><strong>Inter-molecular differences</strong>: Different molecules have unique Potential Energy Surfaces (PES), meaning the space of low-energy (i.e., physically plausible) conformations is highly molecule-specific.</li>
<li><strong>Intra-molecular differences (Anisotropy)</strong>: Within a single molecule, different atoms have different degrees of freedom. For instance, an atom in a rigid functional group can move much less than one connected by a single, rotatable bond.</li>
</ol>
<p>The authors argue that this &ldquo;one-size-fits-all&rdquo; noise approach leads to inaccurate force field learning because it samples many physically improbable conformations.</p>
<h2 id="novelty-a-learnable-atom-specific-noise-generator">Novelty: A Learnable, Atom-Specific Noise Generator</h2>
<p>The core novelty is a framework that learns to generate noise tailored to each specific molecule and atom. This is achieved through three key innovations:</p>
<ol>
<li><strong>Learnable Noise Generator</strong>: The authors introduce a Noise Generator module (a 4-layer Equivariant Graph Neural Network) that takes a molecule&rsquo;s equilibrium conformation $X$ as input and outputs a unique, atom-specific Gaussian noise distribution (i.e., a different variance $\sigma_i^2$ for each atom $i$). This directly addresses the issues of PES specificity and force field anisotropy.</li>
<li><strong>Variational Autoencoder (VAE) Framework</strong>: The Noise Generator (encoder) and a Denoising Module (a 7-layer EGNN decoder) are trained jointly within a VAE paradigm. The noisy conformation is sampled using the reparameterization trick:
$$
\begin{aligned}
\tilde{x}_i &amp;= x_i + \epsilon \sigma_i
\end{aligned}
$$</li>
<li><strong>Principled Optimization Objective</strong>: The training loss balances two competing goals:
$$
\begin{aligned}
\mathcal{L}_{DenoiseVAE} &amp;= \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}
\end{aligned}
$$
<ul>
<li>A denoising reconstruction loss ($\mathcal{L}_{Denoise}$) encourages the Noise Generator to produce physically plausible perturbations from which the original conformation can be recovered. This implicitly constrains the noise to respect the molecule&rsquo;s underlying force fields.</li>
<li>A KL divergence regularization term ($\mathcal{L}_{KL}$) pushes the generated noise distributions towards a predefined prior. This prevents the trivial solution of generating zero noise and encourages the model to explore a diverse set of low-energy conformations.</li>
</ul>
</li>
</ol>
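<p>The sampling step and the two loss terms can be sketched in plain Python. This is a minimal sketch under stated assumptions: the prior is taken to be an isotropic zero-mean Gaussian with standard deviation 0.1 so that the KL term has a closed form, the denoising loss is taken to be a mean squared error on the applied displacement, and the per-atom $\sigma_i$ values are hypothetical stand-ins for the Noise Generator output. It is not the authors' implementation.</p>

```python
import math
import random


def sample_noisy_conformation(coords, sigmas, rng):
    """Reparameterization: x_tilde_i = x_i + eps * sigma_i with eps ~ N(0, I)."""
    noisy, displacements = [], []
    for xyz, sigma in zip(coords, sigmas):
        delta = [rng.gauss(0.0, 1.0) * sigma for _ in xyz]
        noisy.append([x + d for x, d in zip(xyz, delta)])
        displacements.append(delta)
    return noisy, displacements


def kl_to_prior(sigmas, prior_sigma=0.1, dims=3):
    """KL(N(0, sigma_i^2 I) || N(0, prior_sigma^2 I)), averaged over atoms."""
    total = 0.0
    for sigma in sigmas:
        ratio = (sigma / prior_sigma) ** 2
        total += 0.5 * dims * (ratio - 1.0 - math.log(ratio))
    return total / len(sigmas)


def denoise_loss(predicted, target):
    """Mean squared error between predicted and applied displacements."""
    errors = [(p - t) ** 2 for pv, tv in zip(predicted, target) for p, t in zip(pv, tv)]
    return sum(errors) / len(errors)


def denoise_vae_loss(predicted, target, sigmas, lam=1.0):
    """L = L_Denoise + lambda * L_KL."""
    return denoise_loss(predicted, target) + lam * kl_to_prior(sigmas)


rng = random.Random(0)
coords = [[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]
sigmas = [0.1, 0.05]  # hypothetical per-atom noise scales
noisy, displacements = sample_noisy_conformation(coords, sigmas, rng)
predicted = displacements  # stand-in for a perfect denoiser
loss = denoise_vae_loss(predicted, displacements, sigmas)
```

<p>Under this closed form the KL term grows without bound as any $\sigma_i$ approaches zero, which is one way to see why the regularizer rules out the zero-noise solution.</p>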
<p>The authors also provide a theoretical analysis showing that optimizing their objective is equivalent to maximizing the Evidence Lower Bound (ELBO) on the log-likelihood of observing physically realistic conformations.</p>
<h2 id="methodology--experimental-baselines">Methodology &amp; Experimental Baselines</h2>
<p>The model was pretrained on the PCQM4Mv2 dataset (approximately 3.4 million organic molecules) and then evaluated on a comprehensive suite of downstream tasks to test the quality of the learned representations:</p>
<ol>
<li><strong>Molecular Property Prediction (<a href="/notes/chemistry/datasets/qm9/">QM9</a>)</strong>: The model was evaluated on 12 quantum chemical property prediction tasks for small molecules (134k molecules; 100k train, 18k val, 13k test split). DenoiseVAE achieved state-of-the-art or second-best performance on 11 of the 12 tasks, with particularly significant gains on $C_v$ (heat capacity), indicating better capture of vibrational modes.</li>
<li><strong>Force Prediction (MD17)</strong>: The task was to predict atomic forces from molecular dynamics trajectories for 8 different small molecules (9,500 train, 500 val split). DenoiseVAE was the top performer on 5 of the 8 molecules (Aspirin, Benzene, Ethanol, Naphthalene, Toluene), though it underperformed Frad on Malonaldehyde, Salicylic Acid, and Uracil by significant margins.</li>
<li><strong>Ligand Binding Affinity (PDBBind v2019)</strong>: On the PDBBind dataset with 30% and 60% protein sequence identity splits, the model showed strong generalization, outperforming baselines like Uni-Mol particularly on the more stringent 30% split across RMSE, Pearson correlation, and Spearman correlation.</li>
<li><strong>PCQM4Mv2 Validation</strong>: DenoiseVAE achieved a validation MAE of 0.0777 on the PCQM4Mv2 HOMO-LUMO gap prediction task with only 1.44M parameters, competitive with models 10-40x larger (e.g., GPS++ at 44.3M params achieves 0.0778).</li>
<li><strong>Ablation Studies</strong>: The authors analyzed the sensitivity to key hyperparameters, namely the prior&rsquo;s standard deviation ($\sigma$) and the KL-divergence weight ($\lambda$), confirming that $\lambda=1$ and $\sigma=0.1$ are optimal. Removing the KL term leads to trivial solutions (near-zero noise). An additional ablation on the Noise Generator depth found 4 EGNN layers optimal over 2 layers. A comparison of independent (diagonal) versus non-independent (full covariance) noise sampling showed comparable results, suggesting the EGNN already captures inter-atomic dependencies implicitly.</li>
<li><strong>Case Studies</strong>: Visualizations of the learned noise variances for different molecules confirmed that the model learns chemically intuitive noise patterns. For example, it applies smaller perturbations to atoms in a rigid bicyclic norcamphor derivative and larger ones to atoms in flexible functional groups of a cyclopropane derivative. Even identical functional groups (e.g., hydroxyl) receive different noise scales in different molecular contexts.</li>
</ol>
<h2 id="key-findings-on-force-field-learning">Key Findings on Force Field Learning</h2>
<ul>
<li><strong>Primary Conclusion</strong>: Learning a <strong>molecule-adaptive and atom-specific</strong> noise distribution is a superior strategy for denoising-based pre-training compared to using fixed, hand-crafted heuristics. This more physically-grounded approach leads to representations that better capture molecular force fields.</li>
<li><strong>Strong Benchmark Performance</strong>: DenoiseVAE achieves best or second-best results on 11 of 12 QM9 tasks, 5 of 8 MD17 molecules, and leads on the stringent 30% LBA split. Performance is mixed on some MD17 molecules (Malonaldehyde, Salicylic Acid, Uracil), where it trails Frad.</li>
<li><strong>Effective Framework</strong>: The proposed VAE-based framework, which jointly trains a Noise Generator and a Denoising Module, is an effective and theoretically sound method for implementing this adaptive noise strategy. The interplay between the reconstruction loss and the KL-divergence regularization is key to its success.</li>
<li><strong>Limitation and Future Direction</strong>: The method is based on classical force field assumptions. The authors note that integrating more accurate force fields represents a promising direction for future work.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/Serendipity-r/DenoiseVAE">Serendipity-r/DenoiseVAE</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official implementation</td>
      </tr>
  </tbody>
</table>
<h3 id="reproducibility-status">Reproducibility Status</h3>
<ul>
<li><strong>Source Code</strong>: The authors have released their code at <a href="https://github.com/Serendipity-r/DenoiseVAE">Serendipity-r/DenoiseVAE</a> on GitHub. No license is specified in the repository.</li>
<li><strong>Implementation</strong>: Hyperparameters and architectures are detailed in the paper&rsquo;s appendix (A.14), and the repository provides reference implementations.</li>
</ul>
<h3 id="data">Data</h3>
<ul>
<li><strong>Pre-training Dataset</strong>: <a href="https://ogb.stanford.edu/docs/lsc/pcqm4mv2/">PCQM4Mv2</a> (approximately 3.4 million organic molecules)</li>
<li><strong>Property Prediction</strong>: <a href="https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.QM9.html">QM9 dataset</a> (134k molecules; 100k train, 18k val, 13k test split) for 12 quantum chemical properties</li>
<li><strong>Force Prediction</strong>: <a href="http://www.sgdml.org/#datasets">MD17 dataset</a> (9,500 train, 500 val split) for 8 different small molecules</li>
<li><strong>Ligand Binding Affinity</strong>: PDBBind v2019 (4,463 protein-ligand complexes) with 30% and 60% sequence identity splits</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise Generator</strong>: 4-layer Equivariant Graph Neural Network (EGNN) that outputs atom-specific Gaussian noise distributions</li>
<li><strong>Denoising Module</strong>: 7-layer EGNN decoder</li>
<li><strong>Training Objective</strong>: $\mathcal{L}_{DenoiseVAE} = \mathcal{L}_{Denoise} + \lambda \mathcal{L}_{KL}$ with $\lambda=1$</li>
<li><strong>Noise Sampling</strong>: Reparameterization trick with $\tilde{x}_i = x_i + \epsilon \sigma_i$</li>
<li><strong>Prior Distribution</strong>: Standard deviation $\sigma=0.1$</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Model Size</strong>: 1.44M parameters total</li>
<li><strong>Fine-tuning Protocol</strong>: Noise Generator discarded after pre-training; only the pre-trained Denoising Module (7-layer EGNN) is retained for downstream fine-tuning</li>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate decay (max LR of 0.0005)</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>System Training</strong>: Fine-tuned end-to-end for specific tasks; force prediction involves computing the gradient of the predicted energy</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Ablation Studies</strong>: Sensitivity analysis confirmed $\lambda=1$ and $\sigma=0.1$ as optimal hyperparameters; removing the KL term leads to trivial solutions (near-zero noise)</li>
<li><strong>Noise Generator Depth</strong>: 4 EGNN layers outperformed 2 layers across both QM9 and MD17 benchmarks</li>
<li><strong>Covariance Structure</strong>: Full covariance matrix (non-independent noise sampling) yielded comparable results to diagonal variance (independent sampling), likely because the EGNN already integrates neighboring atom information</li>
<li><strong>O(3) Invariance</strong>: The method satisfies O(3) probabilistic invariance, meaning the noise distribution is unchanged under rotations and reflections</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>GPU Configuration</strong>: Experiments were conducted on a single RTX 3090 GPU; 6 GPUs with 144GB total memory are sufficient for full reproduction</li>
<li><strong>CPU</strong>: Intel Xeon Gold 5318Y @ 2.10GHz</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, Y., Chen, J., Jiao, R., Li, J., Huang, W., &amp; Su, B. (2025). DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training. <em>The Thirteenth International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liu2025denoisevae,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DenoiseVAE: Learning Molecule-Adaptive Noise Distributions for Denoising-based 3D Molecular Pre-training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yurou Liu and Jiahao Chen and Rui Jiao and Jiangmeng Li and Wenbing Huang and Bing Su}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{The Thirteenth International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://openreview.net/forum?id=ym7pr83XQr}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://iclr.cc/virtual/2025/poster/27701">ICLR 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=ym7pr83XQr">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=ym7pr83XQr">PDF on OpenReview</a></li>
</ul>
]]></content:encoded></item><item><title>eSEN: Smooth Interatomic Potentials (ICML Spotlight)</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/learning-smooth-interatomic-potentials/</guid><description>Fu et al. propose energy conservation as a key MLIP diagnostic and introduce eSEN, bridging test accuracy and real performance.</description><content:encoded><![CDATA[<h2 id="paper-overview">Paper Overview</h2>
<p>This is a <strong>method paper</strong>. It addresses a critical disconnect in the evaluation of Machine Learning Interatomic Potentials (MLIPs) and introduces a novel architecture, <strong>eSEN</strong>, designed based on insights from this analysis. The paper proposes a new standard for evaluating MLIPs beyond simple test-set errors.</p>
<h2 id="the-energy-conservation-gap-in-mlip-evaluation">The Energy Conservation Gap in MLIP Evaluation</h2>
<p>The paper targets a well-known but under-addressed problem in the field: improvements in standard MLIP metrics (lower energy/force MAE on static test sets) do not reliably translate to better performance on complex downstream tasks like molecular dynamics (MD) simulations, materials stability prediction, or phonon calculations. The authors seek to understand why this gap exists and how to design models that are both accurate on test sets and physically reliable in practical scientific workflows.</p>
<h2 id="the-esen-architecture-and-continuous-representation">The eSEN Architecture and Continuous Representation</h2>
<p>The novelty is twofold, spanning both a conceptual framework for evaluation and a new model architecture:</p>
<ol>
<li>
<p><strong>Energy Conservation as a Diagnostic Test</strong>: The core conceptual contribution is using an MLIP&rsquo;s ability to conserve energy in out-of-distribution MD simulations as a crucial diagnostic test. The authors demonstrate that for models passing this test, a strong correlation between test-set error and downstream task performance is restored.</p>
</li>
<li>
<p><strong>The eSEN Architecture</strong>: The paper introduces the <strong>equivariant Smooth Energy Network (eSEN)</strong>, designed with specific choices to ensure a smooth and well-behaved Potential Energy Surface (PES):</p>
<ul>
<li><strong>Strictly Conservative Forces</strong>: Forces are computed exclusively as the negative gradient of energy ($F = -\nabla E$), using conservative force prediction instead of faster direct-force prediction heads.</li>
<li><strong>Continuous Representations</strong>: Maintains strict equivariance and smoothness by using equivariant gated non-linearities instead of discretizing spherical harmonic representations during nodewise processing.</li>
<li><strong>Smooth PES Construction</strong>: Critical design choices include using distance cutoffs, polynomial envelope functions that ensure derivatives go to zero at the cutoff, and a limited number of radial basis functions to avoid an overly sensitive PES.</li>
</ul>
</li>
<li>
<p><strong>Efficient Training Strategy</strong>: A two-stage training regimen with fast pre-training using a non-conservative direct-force model, followed by fine-tuning to enforce energy conservation. This captures the efficiency of direct-force training while ensuring physical robustness.</p>
</li>
</ol>
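<p>The energy conservation diagnostic can be illustrated on a toy system. This sketch assumes a 2D harmonic oscillator with unit mass integrated by velocity Verlet. The "direct" force adds a small rotational term that is not the gradient of any energy, as a stand-in for a non-conservative force head. None of this is the eSEN model; the names and the <code>swirl</code> parameter are hypothetical.</p>

```python
def conservative_force(pos):
    """F = -grad E for the toy energy E(x, y) = 0.5 * (x^2 + y^2)."""
    x, y = pos
    return (-x, -y)


def direct_force(pos, swirl=0.05):
    """Same force plus a small rotational term with no underlying energy."""
    x, y = pos
    return (-x - swirl * y, -y + swirl * x)


def total_energy(pos, vel):
    """Kinetic plus potential energy for unit mass."""
    return 0.5 * (vel[0] ** 2 + vel[1] ** 2) + 0.5 * (pos[0] ** 2 + pos[1] ** 2)


def max_energy_drift(force_fn, steps=10_000, dt=0.01):
    """Run velocity Verlet; return the largest relative deviation from E(0)."""
    pos, vel = (1.0, 0.0), (0.0, 1.0)
    force = force_fn(pos)
    e0 = total_energy(pos, vel)
    worst = 0.0
    for _ in range(steps):
        half = (vel[0] + 0.5 * dt * force[0], vel[1] + 0.5 * dt * force[1])
        pos = (pos[0] + dt * half[0], pos[1] + dt * half[1])
        force = force_fn(pos)
        vel = (half[0] + 0.5 * dt * force[0], half[1] + 0.5 * dt * force[1])
        worst = max(worst, abs(total_energy(pos, vel) - e0) / e0)
    return worst


drift_conservative = max_energy_drift(conservative_force)
drift_direct = max_energy_drift(direct_force)
```

<p>With gradient-derived forces the total energy stays within a small bounded error set by the integrator, while the non-conservative variant steadily pumps energy into the system even though the two force fields differ only slightly at any single point.</p>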
<h2 id="evaluating-ood-energy-conservation-and-physical-properties">Evaluating OOD Energy Conservation and Physical Properties</h2>
<p>The paper presents a comprehensive experimental validation:</p>
<ol>
<li>
<p><strong>Ablation Studies on Energy Conservation</strong>: MD simulations on out-of-distribution systems (TM23 and MD22 datasets) systematically tested key design choices (direct-force vs. conservative, representation discretization, neighbor limits, envelope functions). This empirically demonstrated which choices lead to energy drift despite negligible impact on test-set MAE.</p>
</li>
<li>
<p><strong>Physical Property Prediction Benchmarks</strong>: The eSEN model was evaluated on challenging downstream tasks:</p>
<ul>
<li><strong>Matbench-Discovery</strong>: Materials stability and thermal conductivity prediction, where eSEN achieved the highest F1 score among compliant models and excelled at both metrics simultaneously.</li>
<li><strong>MDR Phonon Benchmark</strong>: Predicting phonon properties, which tests the accuracy of second- and third-order derivatives of the PES. eSEN achieved state-of-the-art results, notably outperforming direct-force models.</li>
<li><strong>SPICE-MACE-OFF</strong>: Standard energy and force prediction on organic molecules, demonstrating that physical plausibility design choices enhanced raw accuracy.</li>
</ul>
</li>
<li>
<p><strong>Correlation Analysis</strong>: Explicit plots of test-set energy MAE versus performance on downstream benchmarks showed weak overall correlation that becomes strong and predictive when restricted to models passing the energy conservation test.</p>
</li>
</ol>
<h2 id="outcomes-and-conclusions">Outcomes and Conclusions</h2>
<ul>
<li>
<p><strong>Primary Conclusion</strong>: Energy conservation is a critical, practical property for MLIPs. Using it as a filter re-establishes test-set error as a reliable proxy for model development, dramatically accelerating the innovation cycle. Models that are not conservative, even with low test error, are unreliable for many critical scientific applications.</p>
</li>
<li>
<p><strong>Model Performance</strong>: The eSEN architecture outperforms baseline models across diverse tasks, from energy/force prediction to geometry optimization, phonon calculations, and thermal conductivity prediction.</p>
</li>
<li>
<p><strong>Actionable Design Principles</strong>: The paper provides experimentally-validated architectural choices that promote physical plausibility. Seemingly minor details, like how atomic neighbors are selected, can have profound impacts on a model&rsquo;s utility in simulations.</p>
</li>
<li>
<p><strong>Efficient Path to Robust Models</strong>: The direct-force pre-training plus conservative fine-tuning strategy offers a practical method for developing physically robust models without incurring the full computational cost of conservative training from scratch.</p>
</li>
</ul>
<hr>
<h2 id="reproducibility">Reproducibility</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/fairchem">fairchem (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation within FAIR Chemistry framework</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/facebook/OMAT24">OMAT24 (Hugging Face)</a></td>
          <td>Model</td>
          <td>FAIR Acceptable Use Policy</td>
          <td>Pre-trained eSEN-30M-MP and eSEN-30M-OAM checkpoints</td>
      </tr>
      <tr>
          <td><a href="https://openreview.net/forum?id=R0PBjxIbgm">OpenReview</a></td>
          <td>Paper</td>
          <td>CC BY 4.0</td>
          <td>ICML 2025 camera-ready paper</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>The eSEN architecture builds on components from <strong>eSCN</strong> (Equivariant Spherical Channel Network) and <strong>Equiformer</strong>, combining them with design choices that prioritize smoothness and energy conservation. The implementation integrates into the standard <code>fairchem</code> Open Catalyst experimental framework.</p>
<h4 id="layer-structure">Layer Structure</h4>
<ul>
<li><strong>Edgewise Convolution</strong>: Uses <code>SO2</code> convolution layers (from eSCN) with an envelope function applied. Source and target embeddings are concatenated before convolution.</li>
<li><strong>Nodewise Feed-Forward</strong>: Two equivariant linear layers with an intermediate <strong>SiLU-based gated non-linearity</strong> (from Equiformer).</li>
<li><strong>Normalization</strong>: Equivariant Layer Normalization (from Equiformer).</li>
</ul>
<h4 id="smoothness-design-choices">Smoothness Design Choices</h4>
<p>Several architectural decisions distinguish eSEN from prior work:</p>
<ul>
<li><strong>No Grid Projection</strong>: eSEN performs operations directly in the spherical harmonic space to maintain equivariance and energy conservation, bypassing the projection of spherical harmonics to spatial grids for non-linearity.</li>
<li><strong>Distance Cutoff for Graph Construction</strong>: Uses a strict distance cutoff (6 Å for MPTrj models, 5 Å for SPICE models). Neighbor limits introduce discontinuities that break energy conservation.</li>
<li><strong>Polynomial Envelope Functions</strong>: Ensures derivatives go to zero smoothly at the cutoff radius.</li>
</ul>
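<p>A polynomial envelope of the kind described above can be sketched as follows. This uses the polynomial form common in DimeNet- and GemNet-style models; the exponent of 5 and the function name are assumptions for illustration and may not match the settings used in eSEN.</p>

```python
def polynomial_envelope(distance: float, cutoff: float, exponent: int = 5) -> float:
    """Smooth cutoff: equals 1 at distance 0 and decays to 0 at the cutoff.

    The coefficients are chosen so that the value and its first two
    derivatives vanish at distance == cutoff.
    """
    d = distance / cutoff
    if d >= 1.0:
        return 0.0
    p = exponent
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * d**p + b * d ** (p + 1) + c * d ** (p + 2)


# Envelope values for a 6 Angstrom cutoff, as used for the MPTrj models.
cutoff = 6.0
values = [polynomial_envelope(r, cutoff) for r in (0.0, 1.0, 3.0, 5.0, 6.0)]
```

<p>Multiplying each edge message by this factor means an atom crossing the cutoff radius contributes nothing at the boundary and enters smoothly, so the energy and forces remain continuous as the neighbor list changes.</p>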
<h3 id="algorithms">Algorithms</h3>
<h4 id="two-stage-training-esen-30m-mp">Two-Stage Training (eSEN-30M-MP)</h4>
<ol>
<li><strong>Direct-Force Pre-training</strong> (60 epochs): Uses <strong>DeNS</strong> (Denoising Non-equilibrium Structures) to reduce overfitting. This stage is fast because it does not require backpropagation through energy gradients.</li>
<li><strong>Conservative Fine-tuning</strong> (40 epochs): The direct-force head is removed, and forces are calculated via gradients ($F = -\nabla E$). This enforces energy conservation.</li>
</ol>
<p><strong>Important</strong>: DeNS is used exclusively during the direct-force pre-training stage, with a noising probability of 0.5, a standard deviation of 0.1 Å for the added Gaussian noise, and a DeNS loss coefficient of 10. This two-stage strategy reduces the wall-clock time for model training by 40%.</p>
<h4 id="optimization">Optimization</h4>
<ul>
<li><strong>Optimizer</strong>: AdamW with cosine learning rate scheduler</li>
<li><strong>Max Learning Rate</strong>: $4 \times 10^{-4}$</li>
<li><strong>Batch Size</strong>: 512 (for MPTrj models)</li>
<li><strong>Weight Decay</strong>: $1 \times 10^{-3}$</li>
<li><strong>Gradient Clipping</strong>: Norm of 100</li>
<li><strong>Warmup</strong>: 0.1 epochs with a factor of 0.2</li>
</ul>
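<p>A minimal sketch of a schedule matching these settings is shown below. It assumes that "warmup factor 0.2" means the learning rate starts at 0.2 times the maximum and ramps linearly to the maximum over the first 0.1 epochs, and that the cosine phase decays toward <code>min_lr_factor * max_lr</code>; both the interpretation and the <code>min_lr_factor</code> parameter are assumptions, since the note does not state the final learning rate.</p>

```python
import math

def learning_rate(step, steps_per_epoch, total_epochs,
                  max_lr=4e-4, warmup_epochs=0.1, warmup_factor=0.2,
                  min_lr_factor=0.0):
    """Linear warmup followed by cosine decay (illustrative sketch)."""
    warmup_steps = max(1, int(round(warmup_epochs * steps_per_epoch)))
    total_steps = int(total_epochs * steps_per_epoch)
    if step in range(warmup_steps):
        # Linear ramp from warmup_factor * max_lr up to max_lr.
        frac = step / warmup_steps
        return max_lr * (warmup_factor + (1.0 - warmup_factor) * frac)
    # Cosine decay over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return max_lr * (min_lr_factor + (1.0 - min_lr_factor) * cosine)

# Hypothetical run: 1000 optimizer steps per epoch, 60 epochs.
lr_start = learning_rate(0, 1000, 60)
lr_peak = learning_rate(100, 1000, 60)
lr_end = learning_rate(60000, 1000, 60)
```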
<h4 id="loss-function">Loss Function</h4>
<p>A composite loss combining per-atom energy MAE, force $L_2$ loss, and stress MAE:</p>
<p>$$
\begin{aligned}
\mathcal{L} = \lambda_{\text{e}} \frac{1}{N} \sum_{i=1}^N \lvert E_{i} - \hat{E}_{i} \rvert + \lambda_{\text{f}} \frac{1}{3N} \sum_{i=1}^N \lVert \mathbf{F}_{i} - \hat{\mathbf{F}}_{i} \rVert_2^2 + \lambda_{\text{s}} \lVert \mathbf{S} - \hat{\mathbf{S}} \rVert_1
\end{aligned}
$$</p>
<p>For MPTrj-30M, the weighting coefficients are set to $\lambda_{\text{e}} = 20$, $\lambda_{\text{f}} = 20$, and $\lambda_{\text{s}} = 5$.</p>
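<p>The loss above translates directly into a few lines of NumPy. The sketch below follows the formula term by term for a single structure; the function name, argument layout, and the absence of batching are assumptions made for illustration.</p>

```python
import numpy as np

def esen_style_loss(e_pred, e_ref, f_pred, f_ref, s_pred, s_ref,
                    lam_e=20.0, lam_f=20.0, lam_s=5.0):
    """Composite energy / force / stress loss, following the formula above."""
    e_pred, e_ref = np.asarray(e_pred, float), np.asarray(e_ref, float)
    f_pred, f_ref = np.asarray(f_pred, float), np.asarray(f_ref, float)
    s_pred, s_ref = np.asarray(s_pred, float), np.asarray(s_ref, float)
    n_atoms = f_ref.shape[0]
    # Mean absolute error over the energy entries.
    energy_term = np.mean(np.abs(e_ref - e_pred))
    # Squared L2 norm of the force error, averaged over 3N components.
    force_term = np.sum((f_ref - f_pred) ** 2) / (3.0 * n_atoms)
    # L1 norm of the stress error.
    stress_term = np.sum(np.abs(s_ref - s_pred))
    return lam_e * energy_term + lam_f * force_term + lam_s * stress_term

# Toy inputs: two atoms, uniform errors.
loss = esen_style_loss(
    e_pred=[0.1, -0.1], e_ref=[0.0, 0.0],
    f_pred=np.full((2, 3), 0.1), f_ref=np.zeros((2, 3)),
    s_pred=np.full((3, 3), 0.01), s_ref=np.zeros((3, 3)),
)
```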
<h3 id="data">Data</h3>
<h4 id="training-data">Training Data</h4>
<ul>
<li><strong>Inorganic</strong>: MPTrj (Materials Project Trajectory) dataset</li>
<li><strong>Organic</strong>: SPICE-MACE-OFF dataset</li>
</ul>
<h4 id="test-data-construction">Test Data Construction</h4>
<ul>
<li><strong>MPTrj Testing</strong>: Since MPTrj lacks an official test split, the authors created a test set using 5,000 random samples from the <strong>subsampled Alexandria (sAlex)</strong> dataset to ensure fair comparison.</li>
<li><strong>Out-of-Distribution Conservation Testing</strong>:
<ul>
<li><em>Inorganic</em>: <strong>TM23</strong> dataset (transition metal defects). Simulation: 100 ps, 5 fs timestep.</li>
<li><em>Organic</em>: <strong>MD22</strong> dataset (large molecules). Simulation: 100 ps, 1 fs timestep.</li>
</ul>
</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Training primarily uses <strong>80GB NVIDIA A100 GPUs</strong>.</p>
<h4 id="inference-efficiency">Inference Efficiency</h4>
<p>For a periodic system of <strong>216 atoms</strong> on a single A100 (PyTorch 2.4.0, CUDA 12.1, no compile/torchscript), the 2-layer eSEN models achieve approximately <strong>0.8 million steps per day</strong> (3.2M parameters) and <strong>0.4 million steps per day</strong> (6.5M parameters), so the smaller model is comparable to MACE-OFF-L at 0.7 million steps per day.</p>
<h3 id="evaluation">Evaluation</h3>
<p>The paper evaluated eSEN across three major benchmark tasks. Key evaluation metrics included energy MAE (meV/atom), force MAE (meV/Å), stress MAE (meV/Å/atom), F1 score for stability prediction, $\kappa_{\text{SRME}}$ for thermal conductivity, and phonon frequency accuracy.</p>
<h4 id="ablation-test-set-mae-table-1">Ablation Test-Set MAE (Table 1)</h4>
<p>Design choices that dramatically affect energy conservation have negligible impact on static test-set MAE, which is precisely why test-set error alone is misleading. All models are 2-layer with 3.2M parameters, $L_{\text{max}} = 2$, $M_{\text{max}} = 2$:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Energy MAE</th>
          <th>Force MAE</th>
          <th>Stress MAE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>eSEN (default)</td>
          <td>17.02</td>
          <td>43.96</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, direct-force</td>
          <td>18.66</td>
          <td>43.62</td>
          <td>0.16</td>
      </tr>
      <tr>
          <td>eSEN, neighbor limit</td>
          <td>17.30</td>
          <td>44.11</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, no envelope</td>
          <td>17.60</td>
          <td>44.69</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, $N_{\text{basis}} = 512$</td>
          <td>19.87</td>
          <td>48.29</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, Bessel</td>
          <td>17.65</td>
          <td>44.83</td>
          <td>0.15</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=6</td>
          <td>17.05</td>
          <td>43.10</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=10</td>
          <td>17.11</td>
          <td>43.13</td>
          <td>0.14</td>
      </tr>
      <tr>
          <td>eSEN, discrete, res=14</td>
          <td>17.12</td>
          <td>43.09</td>
          <td>0.14</td>
      </tr>
  </tbody>
</table>
<p>Energy MAE in meV/atom. Force MAE in meV/Å. Stress MAE in meV/Å/atom.</p>
<h4 id="matbench-discovery-tables-2-and-3">Matbench-Discovery (Tables 2 and 3)</h4>
<p><strong>Compliant models</strong> (trained only on MPTrj or its subset), unique prototype split:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>DAF</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>0.831</strong></td>
          <td><strong>5.260</strong></td>
          <td><strong>0.340</strong></td>
          <td><strong>0.0752</strong></td>
      </tr>
      <tr>
          <td>eqV2-S-DeNS</td>
          <td>0.815</td>
          <td>5.042</td>
          <td>1.676</td>
          <td>0.0757</td>
      </tr>
      <tr>
          <td>MatRIS-MP</td>
          <td>0.809</td>
          <td>5.049</td>
          <td>0.861</td>
          <td>0.0773</td>
      </tr>
      <tr>
          <td>AlphaNet-MP</td>
          <td>0.799</td>
          <td>4.863</td>
          <td>1.31</td>
          <td>0.1067</td>
      </tr>
      <tr>
          <td>DPA3-v2-MP</td>
          <td>0.786</td>
          <td>4.822</td>
          <td>0.959</td>
          <td>0.0823</td>
      </tr>
      <tr>
          <td>ORB v2 MPtrj</td>
          <td>0.765</td>
          <td>4.702</td>
          <td>1.725</td>
          <td>0.1007</td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>0.760</td>
          <td>4.629</td>
          <td>0.550</td>
          <td>0.0847</td>
      </tr>
      <tr>
          <td>GRACE-2L-MPtrj</td>
          <td>0.691</td>
          <td>4.163</td>
          <td>0.525</td>
          <td>0.0897</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>0.669</td>
          <td>3.777</td>
          <td>0.647</td>
          <td>0.0915</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>0.613</td>
          <td>3.361</td>
          <td>1.717</td>
          <td>0.0949</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>0.569</td>
          <td>2.882</td>
          <td>1.412</td>
          <td>0.1117</td>
      </tr>
  </tbody>
</table>
<p>eSEN-30M-MP excels at both F1 and $\kappa_{\text{SRME}}$ simultaneously, while all previous models only achieve SOTA on one or the other.</p>
<p><strong>Non-compliant models</strong> (trained on additional datasets):</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>F1</th>
          <th>$\kappa_{\text{SRME}}$</th>
          <th>RMSD</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-OAM</strong></td>
          <td><strong>0.925</strong></td>
          <td><strong>0.170</strong></td>
          <td><strong>0.0608</strong></td>
      </tr>
      <tr>
          <td>eqV2-M-OAM</td>
          <td>0.917</td>
          <td>1.771</td>
          <td>0.0691</td>
      </tr>
      <tr>
          <td>ORB v3</td>
          <td>0.905</td>
          <td>0.210</td>
          <td>0.0750</td>
      </tr>
      <tr>
          <td>SevenNet-MF-ompa</td>
          <td>0.901</td>
          <td>0.317</td>
          <td>0.0639</td>
      </tr>
      <tr>
          <td>DPA3-v2-OpenLAM</td>
          <td>0.890</td>
          <td>0.687</td>
          <td>0.0679</td>
      </tr>
      <tr>
          <td>GRACE-2L-OAM</td>
          <td>0.880</td>
          <td>0.294</td>
          <td>0.0666</td>
      </tr>
      <tr>
          <td>MatterSim-v1-5M</td>
          <td>0.862</td>
          <td>0.574</td>
          <td>0.0733</td>
      </tr>
      <tr>
          <td>MACE-MPA-0</td>
          <td>0.852</td>
          <td>0.412</td>
          <td>0.0731</td>
      </tr>
  </tbody>
</table>
<p>The eSEN-30M-OAM model is pre-trained on the OMat24 dataset, then fine-tuned on the subsampled Alexandria (sAlex) dataset and MPTrj dataset.</p>
<h4 id="mdr-phonon-benchmark-table-4">MDR Phonon Benchmark (Table 4)</h4>
<p>Metrics: maximum phonon frequency MAE($\omega_{\text{max}}$) in K, vibrational entropy MAE($S$) in J/K/mol, Helmholtz free energy MAE($F$) in kJ/mol, heat capacity MAE($C_V$) in J/K/mol.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>MAE($\omega_{\text{max}}$)</th>
          <th>MAE($S$)</th>
          <th>MAE($F$)</th>
          <th>MAE($C_V$)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>eSEN-30M-MP</strong></td>
          <td><strong>21</strong></td>
          <td><strong>13</strong></td>
          <td><strong>5</strong></td>
          <td><strong>4</strong></td>
      </tr>
      <tr>
          <td>SevenNet-13i5</td>
          <td>26</td>
          <td>28</td>
          <td>10</td>
          <td>5</td>
      </tr>
      <tr>
          <td>GRACE-2L (r6)</td>
          <td>40</td>
          <td>25</td>
          <td>9</td>
          <td>5</td>
      </tr>
      <tr>
          <td>SevenNet-0</td>
          <td>40</td>
          <td>48</td>
          <td>19</td>
          <td>9</td>
      </tr>
      <tr>
          <td>MACE</td>
          <td>61</td>
          <td>60</td>
          <td>24</td>
          <td>13</td>
      </tr>
      <tr>
          <td>CHGNet</td>
          <td>89</td>
          <td>114</td>
          <td>45</td>
          <td>21</td>
      </tr>
      <tr>
          <td>M3GNet</td>
          <td>98</td>
          <td>150</td>
          <td>56</td>
          <td>22</td>
      </tr>
  </tbody>
</table>
<p>Direct-force models show dramatically worse performance at the standard 0.01 Å displacement (e.g., eqV2-S-DeNS: 280/224/54/94) but improve at larger displacements (0.2 Å: 58/26/8/8), revealing that their PES is rough near energy minima.</p>
<h4 id="spice-mace-off-table-5">SPICE-MACE-OFF (Table 5)</h4>
<p>Test set MAE for organic molecule energy/force prediction. Energy MAE in meV/atom, force MAE in meV/Å:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>MACE-4.7M (E/F)</th>
          <th>EscAIP-45M* (E/F)</th>
          <th>eSEN-3.2M (E/F)</th>
          <th>eSEN-6.5M (E/F)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PubChem</td>
          <td>0.88 / 14.75</td>
          <td>0.53 / 5.86</td>
          <td>0.22 / 6.10</td>
          <td><strong>0.15</strong> / <strong>4.21</strong></td>
      </tr>
      <tr>
          <td>DES370K M.</td>
          <td>0.59 / 6.58</td>
          <td>0.41 / 3.48</td>
          <td>0.17 / 1.85</td>
          <td><strong>0.13</strong> / <strong>1.24</strong></td>
      </tr>
      <tr>
          <td>DES370K D.</td>
          <td>0.54 / 6.62</td>
          <td>0.38 / 2.18</td>
          <td>0.20 / 2.77</td>
          <td><strong>0.15</strong> / <strong>2.12</strong></td>
      </tr>
      <tr>
          <td>Dipeptides</td>
          <td>0.42 / 10.19</td>
          <td>0.31 / 5.21</td>
          <td>0.10 / 3.04</td>
          <td><strong>0.07</strong> / <strong>2.00</strong></td>
      </tr>
      <tr>
          <td>Sol. AA</td>
          <td>0.98 / 19.43</td>
          <td>0.61 / 11.52</td>
          <td>0.30 / 5.76</td>
          <td><strong>0.25</strong> / <strong>3.68</strong></td>
      </tr>
      <tr>
          <td>Water</td>
          <td>0.83 / 13.57</td>
          <td>0.72 / 10.31</td>
          <td>0.24 / 3.88</td>
          <td><strong>0.15</strong> / <strong>2.50</strong></td>
      </tr>
      <tr>
          <td>QMugs</td>
          <td>0.45 / 16.93</td>
          <td>0.41 / 8.74</td>
          <td>0.16 / 5.70</td>
          <td><strong>0.12</strong> / <strong>3.78</strong></td>
      </tr>
  </tbody>
</table>
<p>*EscAIP-45M is a direct-force model. eSEN-6.5M outperforms MACE-OFF-L and EscAIP on all test splits. The smaller eSEN-3.2M has inference efficiency comparable to MACE-4.7M while achieving lower MAE.</p>
<hr>
<h2 id="why-these-design-choices-matter">Why These Design Choices Matter</h2>
<h3 id="bounded-energy-derivatives-and-the-verlet-integrator">Bounded Energy Derivatives and the Verlet Integrator</h3>
<p>The theoretical foundation for why smoothness matters comes from Theorem 5.1 of Hairer et al. (2003). For the Verlet integrator (the standard NVE integrator), the total energy drift satisfies:</p>
<p>$$
|E(\mathbf{r}_T, \mathbf{a}) - E(\mathbf{r}_0, \mathbf{a})| \leq C \Delta t^2 + C_N \Delta t^N T
$$</p>
<p>where $T$ is the total simulation time ($T \leq \Delta t^{-N}$), $N$ is the highest order for which the $N$th derivative of $E$ is continuously differentiable with bounded derivative, and $C$, $C_N$ are constants independent of $T$ and $\Delta t$. The first term is a time-independent fluctuation of $O(\Delta t^2)$; the second term governs long-term conservation. This means the PES must be continuously differentiable to high order, with bounded derivatives, for energy conservation in long-time simulations.</p>
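<p>The $O(\Delta t^2)$ fluctuation term is easy to observe numerically. The sketch below runs velocity Verlet on a one-dimensional harmonic oscillator (a perfectly smooth toy potential chosen for illustration, not a system from the paper) and compares the largest energy deviation for two timesteps.</p>

```python
def max_energy_deviation(dt, total_time=20.0):
    """Velocity Verlet on E = v^2/2 + x^2/2; returns max |E(t) - E(0)|."""
    x, v = 1.0, 0.0
    e0 = 0.5 * v * v + 0.5 * x * x
    worst = 0.0
    for _ in range(int(round(total_time / dt))):
        v_half = v - 0.5 * dt * x   # half kick; the force is -x
        x = x + dt * v_half         # drift
        v = v_half - 0.5 * dt * x   # second half kick with the new force
        energy = 0.5 * v * v + 0.5 * x * x
        worst = max(worst, abs(energy - e0))
    return worst

coarse = max_energy_deviation(0.1)
fine = max_energy_deviation(0.05)
# Halving the timestep shrinks the fluctuation by roughly a factor of four.
ratio = coarse / fine
```

<p>For a smooth potential the deviation stays bounded and scales with $\Delta t^2$; the second term of the bound is what grows when the potential lacks bounded high-order derivatives.</p>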
<h3 id="architectural-choices-that-break-conservation">Architectural Choices That Break Conservation</h3>
<p>The authors provide theoretical justification for why specific architectural choices break energy conservation:</p>
<ul>
<li><strong>Max Neighbor Limit (KNN)</strong>: Introduces discontinuity in the PES. If a neighbor at distance $r$ moves to $r + \epsilon$ and drops out of the top-$K$, the energy changes discontinuously.</li>
<li><strong>Grid Discretization</strong>: Projecting spherical harmonics to a spatial grid introduces discretization errors in energy gradients that break conservation. This can be mitigated with higher-resolution grids but not eliminated.</li>
<li><strong>Direct-Force Prediction</strong>: Imposes no mathematical constraint that forces must be the gradient of an energy scalar field. In other words, $\nabla \times \mathbf{F} \neq 0$ is permitted, violating the requirement for a conservative force field.</li>
</ul>
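<p>The last point can be illustrated with a toy two-dimensional field. A force that is the gradient of an energy does zero net work around any closed loop, while a field with a rotational component does not, so a trajectory can gain or lose energy on every cycle. The fields and the <code>eps</code> parameter below are hypothetical and unrelated to any trained model.</p>

```python
import numpy as np

def conservative_force(xy):
    # F = -grad E for E = 0.5 * (x^2 + y^2).
    return -xy

def nonconservative_force(xy, eps=0.1):
    # Same field plus a small rotational part with nonzero curl.
    x, y = xy[..., 0], xy[..., 1]
    return -xy + eps * np.stack([-y, x], axis=-1)

def work_around_circle(force_fn, n=2000):
    """Midpoint-rule line integral of F . dl around the unit circle."""
    theta = np.linspace(0.0, 2.0 * np.pi, n + 1)
    points = np.stack([np.cos(theta), np.sin(theta)], axis=-1)
    midpoints = 0.5 * (points[1:] + points[:-1])
    steps = points[1:] - points[:-1]
    return float(np.sum(force_fn(midpoints) * steps))

w_cons = work_around_circle(conservative_force)    # numerically zero
w_rot = work_around_circle(nonconservative_force)  # close to 2 * pi * eps
```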
<h3 id="displacement-sensitivity-in-phonon-calculations">Displacement Sensitivity in Phonon Calculations</h3>
<p>An important empirical finding concerns how displacement values affect phonon predictions. Conservative models (eSEN, MACE) show convergent phonon band structures as displacement decreases toward zero. In contrast, direct-force models (eqV2-S-DeNS) fail to converge, exhibiting missing acoustic branches and spurious imaginary frequencies at small displacements. While direct-force models achieve competitive thermodynamic property accuracy at large displacements (0.2 Å), this is deceptive: the underlying phonon band structures remain inaccurate, and the apparent accuracy comes from Boltzmann-weighted integrals smoothing over errors.</p>
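<p>A one-dimensional toy model shows why small displacements expose a rough force field. Phonon force constants are finite differences of forces, so any short-wavelength error in the force is divided by the displacement. The ripple amplitude and wavelength below are invented for illustration and are not fitted to any model in the paper.</p>

```python
import numpy as np

K_TRUE = 1.0  # hypothetical harmonic force constant

def smooth_force(x):
    return -K_TRUE * x

def rough_force(x, amp=0.01, length=0.005):
    # Small short-wavelength ripple standing in for a non-smooth force model.
    return -K_TRUE * x + amp * np.sin(x / length)

def force_constant(force_fn, delta):
    # Central finite difference of the force around the minimum at x = 0.
    return -(force_fn(delta) - force_fn(-delta)) / (2.0 * delta)

for delta in (0.01, 0.2):
    k_smooth = force_constant(smooth_force, delta)
    k_rough = force_constant(rough_force, delta)
    print(f"delta={delta}: smooth={k_smooth:.3f}, rough={k_rough:.3f}")
```

<p>In this sketch the smooth force recovers the true constant at both displacements, whereas the rough force is far off at the small displacement and much closer at the large one, mirroring the trend reported for direct-force models.</p>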
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Fu, X., Wood, B. M., Barroso-Luque, L., Levine, D. S., Gao, M., Dzamba, M., &amp; Zitnick, C. L. (2025). Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>, PMLR 267:17875–17893.</p>
<p><strong>Publication</strong>: ICML 2025 (Spotlight)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{fu2025learning,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Learning Smooth and Expressive Interatomic Potentials for Physical Property Prediction}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fu, Xiang and Wood, Brandon M. and Barroso-Luque, Luis and Levine, Daniel S. and Gao, Meng and Dzamba, Misko and Zitnick, C. Lawrence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{17875--17893}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45302">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=R0PBjxIbgm">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=R0PBjxIbgm">PDF on OpenReview</a></li>
<li><a href="https://huggingface.co/facebook/OMAT24">OMAT24 model on Hugging Face</a></li>
<li><a href="https://github.com/facebookresearch/fairchem">Code on GitHub (fairchem)</a></li>
</ul>
]]></content:encoded></item><item><title>Efficient DFT Hamiltonian Prediction via Adaptive Sparsity</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/efficient-dft-hamiltonian-predicton-sphnet/</guid><description>Luo et al. introduce SPHNet, using adaptive sparsity to achieve up to 7x speedup in SE(3)-equivariant Hamiltonian prediction.</description><content:encoded><![CDATA[<h2 id="core-innovation-adaptive-sparsity-in-se3-networks">Core Innovation: Adaptive Sparsity in SE(3) Networks</h2>
<p>This is a <strong>methodological paper</strong> introducing a novel architecture and training curriculum to solve efficiency bottlenecks in Geometric Deep Learning. It directly tackles the primary computational bottleneck in modern SE(3)-equivariant graph neural networks (the tensor product operation) and proposes a generalizable solution through adaptive network sparsification.</p>
<h2 id="the-computational-bottleneck-in-dft-hamiltonian-prediction">The Computational Bottleneck in DFT Hamiltonian Prediction</h2>
<p>SE(3)-equivariant networks are accurate but unscalable for DFT Hamiltonian prediction due to two key bottlenecks:</p>
<ul>
<li><strong>Atom Scaling</strong>: Tensor Product (TP) operations grow quadratically with atoms ($N^2$).</li>
<li><strong>Basis Set Scaling</strong>: Computational complexity grows with the sixth power of the angular momentum order ($L^6$). Larger basis sets (e.g., def2-TZVP) require higher orders ($L=6$), making them prohibitively slow.</li>
</ul>
<p>Existing SE(3)-equivariant models cannot handle large molecules (40-100 atoms) with high-quality basis sets, limiting their practical applicability in computational chemistry.</p>
<h2 id="sphnet-architecture-and-the-three-phase-sparsity-scheduler">SPHNet Architecture and the Three-Phase Sparsity Scheduler</h2>
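<p><strong>SPHNet</strong> introduces <strong>Adaptive Sparsity</strong> to prune redundant computations at two levels (atom pairs and tensor-product triplets), together with a training scheduler that stabilizes the pruning:</p>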
<ol>
<li><strong>Sparse Pair Gate</strong>: Learns which atom pairs to include in message passing, adapting the interaction graph based on importance.</li>
<li><strong>Sparse TP Gate</strong>: Filters which spherical harmonic triplets $(l_1, l_2, l_3)$ are computed in tensor product operations, pruning higher-order combinations that contribute less to accuracy.</li>
<li><strong>Three-Phase Sparsity Scheduler</strong>: A training curriculum (Random → Adaptive → Fixed) that enables stable convergence to high-performing sparse subnetworks.</li>
</ol>
<p>Key insight: The Sparse Pair Gate learns to preserve long-range interactions (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are abundant and easier to learn, while rare long-range interactions require more samples for accurate representation, making them more critical to retain.</p>
<h2 id="benchmarks-and-ablation-studies">Benchmarks and Ablation Studies</h2>
<p>The authors evaluated SPHNet on three datasets (MD17, QH9, and PubChemQH) with varying molecule sizes and basis set complexities. Baselines include SchNOrb, PhiSNet, QHNet, and WANet. SchNOrb and PhiSNet results are limited to MD17, as those models are designed for trajectory datasets. WANet was not open-sourced, so only partial metrics from its paper are reported.</p>
<h3 id="evaluation-metrics">Evaluation Metrics</h3>
<ul>
<li><strong>Hamiltonian MAE ($H$)</strong>: Mean absolute error between predicted and DFT-computed Hamiltonian matrices, in Hartrees ($E_h$)</li>
<li><strong>Occupied Orbital Energy MAE ($\epsilon$)</strong>: Mean absolute error of all occupied molecular orbital energies derived from the predicted Hamiltonian</li>
<li><strong>Orbital Coefficient Similarity ($\psi$)</strong>: Cosine similarity of occupied molecular orbital coefficients between predicted and reference wavefunctions</li>
</ul>
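<p>The three metrics can be sketched in NumPy as below. The sketch assumes an orthonormal basis, so the overlap matrix is the identity and a plain symmetric eigendecomposition suffices; with a real atomic-orbital basis the orbital energies come from the generalized eigenproblem $HC = SC\epsilon$. The function name and the convention of taking the lowest <code>n_occ</code> eigenpairs as occupied are illustrative assumptions.</p>

```python
import numpy as np

def hamiltonian_metrics(h_pred, h_ref, n_occ):
    """Return (Hamiltonian MAE, occupied-energy MAE, orbital similarity)."""
    h_pred, h_ref = np.asarray(h_pred, float), np.asarray(h_ref, float)
    h_mae = float(np.mean(np.abs(h_pred - h_ref)))
    # Assumption: orthonormal basis, so the overlap matrix is the identity.
    eps_pred, c_pred = np.linalg.eigh(h_pred)
    eps_ref, c_ref = np.linalg.eigh(h_ref)
    eps_mae = float(np.mean(np.abs(eps_pred[:n_occ] - eps_ref[:n_occ])))
    # Column-wise cosine similarity; abs() removes the arbitrary eigenvector sign.
    occ_pred, occ_ref = c_pred[:, :n_occ], c_ref[:, :n_occ]
    dots = np.abs(np.sum(occ_pred * occ_ref, axis=0))
    norms = np.linalg.norm(occ_pred, axis=0) * np.linalg.norm(occ_ref, axis=0)
    psi = float(np.mean(dots / norms))
    return h_mae, eps_mae, psi

h_ref = np.array([[-1.0, 0.1], [0.1, 0.5]])
h_pred = h_ref + 1e-3  # uniform error on every matrix element
h_mae, eps_mae, psi = hamiltonian_metrics(h_pred, h_ref, n_occ=1)
```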
<h3 id="ablation-studies">Ablation Studies</h3>
<p><strong>Sparse Gates</strong> (on PubChemQH):</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Both gates</td>
          <td>97.31</td>
          <td>5.62</td>
          <td>7.09x</td>
      </tr>
      <tr>
          <td>Pair Gate only</td>
          <td>87.70</td>
          <td>6.98</td>
          <td>2.73x</td>
      </tr>
      <tr>
          <td>TP Gate only</td>
          <td>94.31</td>
          <td>8.04</td>
          <td>3.98x</td>
      </tr>
      <tr>
          <td>Neither gate</td>
          <td>86.35</td>
          <td>10.91</td>
          <td>1.73x</td>
      </tr>
  </tbody>
</table>
<p>The Sparse Pair Gate contributes a 78% speedup with 30% memory reduction. The Sparse TP Gate (pruning 70% of combinations) yields a 160% speedup. Both gates together achieve the highest speedup, though accuracy slightly decreases compared to no gating.</p>
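<p>The quoted percentages are not obvious from the table at first glance. One reading that reproduces them is to compare the full model against the variant that lacks the gate in question; this interpretation is an inference from the table values, not a statement from the paper.</p>

```python
speedup = {"both": 7.09, "pair_only": 2.73, "tp_only": 3.98, "neither": 1.73}
memory_gb = {"both": 5.62, "pair_only": 6.98, "tp_only": 8.04, "neither": 10.91}

# Effect of adding the Pair Gate on top of the TP Gate.
pair_gain = speedup["both"] / speedup["tp_only"] - 1.0          # about 0.78
pair_mem_cut = 1.0 - memory_gb["both"] / memory_gb["tp_only"]   # about 0.30

# Effect of adding the TP Gate on top of the Pair Gate.
tp_gain = speedup["both"] / speedup["pair_only"] - 1.0          # about 1.60
```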
<p><strong>Three-Phase Scheduler</strong>: Removing the random phase causes convergence to local optima ($112.68 \pm 10.75$ vs $97.31 \pm 0.52$). Removing the adaptive phase increases variance and lowers accuracy ($122.79 \pm 19.02$). Removing the fixed phase has minimal accuracy impact but reduces speedup from 7.09x to 5.45x due to dynamic graph overhead.</p>
<p><strong>Sparsity Rate</strong>: The critical sparsity threshold scales with system complexity: 30% for MD17 (small molecules), 40% for QH9 (medium), and 70% for PubChemQH (large). Beyond the threshold, MAE increases sharply. Computational cost decreases approximately linearly with sparsity rate.</p>
<h3 id="transferability-to-other-models">Transferability to Other Models</h3>
<p>To demonstrate the speedup is architecture-agnostic, the authors applied the Sparse Pair Gate and Sparse TP Gate to the QHNet baseline on PubChemQH:</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet baseline</td>
          <td>123.74</td>
          <td>22.50</td>
          <td>1.00x</td>
      </tr>
      <tr>
          <td>+ TP Gate</td>
          <td>128.16</td>
          <td>12.68</td>
          <td>2.04x</td>
      </tr>
      <tr>
          <td>+ Pair Gate</td>
          <td>126.27</td>
          <td>10.07</td>
          <td>1.66x</td>
      </tr>
      <tr>
          <td>+ Both gates</td>
          <td>128.89</td>
          <td>8.46</td>
          <td>3.30x</td>
      </tr>
  </tbody>
</table>
<p>The gates reduced QHNet&rsquo;s memory by 62% and improved speed by 3.3x with modest accuracy trade-off, confirming the gates are portable modules applicable to other SE(3)-equivariant architectures.</p>
<h2 id="performance-results">Performance Results</h2>
<h3 id="qh9-134k-molecules-leq-20-atoms">QH9 (134k molecules, $\leq$ 20 atoms)</h3>
<p>SPHNet achieves 3.3x to 4.0x speedup over QHNet across all four QH9 splits, with improved Hamiltonian MAE and orbital energy MAE. Memory drops to 0.23 GB/sample (33% of QHNet&rsquo;s 0.70 GB). On the stable-iid split, Hamiltonian MAE improves from 76.31 to 45.48 ($10^{-6} E_h$).</p>
<h3 id="pubchemqh-50k-molecules-40-100-atoms">PubChemQH (50k molecules, 40-100 atoms)</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>$H$ [$10^{-6} E_h$] $\downarrow$</th>
          <th>$\epsilon$ [$E_h$] $\downarrow$</th>
          <th>$\psi$ [$10^{-2}$] $\uparrow$</th>
          <th>Memory [GB] $\downarrow$</th>
          <th>Speedup $\uparrow$</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>QHNet</td>
          <td>123.74</td>
          <td>3.33</td>
          <td>2.32</td>
          <td>22.5</td>
          <td>1.0x</td>
      </tr>
      <tr>
          <td>WANet</td>
          <td>99.98</td>
          <td><strong>1.17</strong></td>
          <td><strong>3.13</strong></td>
          <td>15.0</td>
          <td>2.4x</td>
      </tr>
      <tr>
          <td>SPHNet</td>
          <td><strong>97.31</strong></td>
          <td>2.16</td>
          <td>2.97</td>
          <td><strong>5.62</strong></td>
          <td><strong>7.1x</strong></td>
      </tr>
  </tbody>
</table>
<p>SPHNet achieves the best Hamiltonian MAE and efficiency, though WANet outperforms on orbital energy MAE and coefficient similarity. The higher speedup on PubChemQH (vs QH9) reflects greater computational redundancy in larger systems with higher-order basis sets ($L_{max} = 6$ for def2-TZVP vs $L_{max} = 4$ for def2-SVP).</p>
<h3 id="md17-small-molecule-trajectories">MD17 (Small Molecule Trajectories)</h3>
<p>SPHNet achieves accuracy comparable to QHNet and PhiSNet on four MD17 molecules (water, ethanol, malondialdehyde, uracil; 3-12 atoms). MD17 represents a simpler task where baseline models already perform well, leaving limited room for improvement. For water (3 atoms), the number of interaction combinations is inherently small, limiting the benefit of adaptive sparsification.</p>
<h3 id="scaling-limit">Scaling Limit</h3>
<p>SPHNet can train on systems with approximately 3000 atomic orbitals on a single A6000 GPU; the QHNet baseline runs out of memory at approximately 1800 orbitals. Memory consumption scales more favorably as molecule size increases.</p>
<h3 id="key-findings">Key Findings</h3>
<ul>
<li><strong>Adaptive sparsity scales with system complexity</strong>: The method is most effective for large systems where redundancy is high. For small molecules (e.g., water with only 3 atoms), every interaction is critical, so pruning hurts accuracy and yields negligible speedup.</li>
<li><strong>Long-range pair preservation</strong>: The Sparse Pair Gate selects long-range pairs (16-25 Angstrom) at higher rates than short-range ones. Short-range pairs are numerous and easier to learn, while rare long-range interactions are harder to represent and thus more critical to retain.</li>
<li><strong>Generalizable components</strong>: The sparsification techniques are portable modules, demonstrated by successful integration into QHNet with 3.3x speedup.</li>
<li><strong>Architecture ablation</strong>: Removing one Vectorial Node Interaction block or Spherical Node Interaction block significantly hurts accuracy, confirming the importance of the progressive order-increase design. Removing one Pair Construction block has less impact, suggesting room for further speedup.</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/microsoft/SPHNet">SPHNet (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation; archived by Microsoft (Dec 2025), read-only</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/EperLuo/PubChemQH">PubChemQH (Hugging Face)</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>50k molecules, 40-100 atoms, def2-TZVP basis</td>
      </tr>
  </tbody>
</table>
<p>No pre-trained model weights are provided. MD17 and QH9 are publicly available community datasets. Training requires 4x NVIDIA A100 (80GB) GPUs; benchmarking uses a single NVIDIA RTX A6000 (46GB).</p>
<h3 id="data">Data</h3>
<p>The experiments evaluated SPHNet on three datasets with different molecular sizes and basis set complexities. All datasets use DFT calculations as ground truth, with MD17 using the PBE exchange-correlation functional and QH9/PubChemQH using B3LYP.</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Molecules</th>
          <th>Molecule Size</th>
          <th>Basis Set</th>
          <th>$L_{max}$</th>
          <th>Functional</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MD17</td>
          <td>4 systems</td>
          <td>3-12 atoms (water, ethanol, malondialdehyde, uracil)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>PBE</td>
      </tr>
      <tr>
          <td>QH9</td>
          <td>134k</td>
          <td>$\leq$ 20 atoms (Stable/Dynamic splits)</td>
          <td>def2-SVP</td>
          <td>4</td>
          <td>B3LYP</td>
      </tr>
      <tr>
          <td>PubChemQH</td>
          <td>50k</td>
          <td>40-100 atoms</td>
          <td>def2-TZVP</td>
          <td>6</td>
          <td>B3LYP</td>
      </tr>
  </tbody>
</table>
<p><strong>Data Availability</strong>:</p>
<ul>
<li><strong>MD17 &amp; QH9</strong>: Publicly available</li>
<li><strong>PubChemQH</strong>: Publicly available on Hugging Face (<a href="https://huggingface.co/datasets/EperLuo/PubChemQH">EperLuo/PubChemQH</a>)</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Loss Function</strong>:</p>
<p>The model learns the <strong>residual</strong> $\Delta H$:</p>
<p>$$
\begin{aligned}
\Delta H &amp;= H_{\text{ref}} - H_{\text{init}} \\
\mathcal{L} &amp;= \text{MAE}(H_{\text{ref}}, H_{\text{pred}}) + \text{MSE}(H_{\text{ref}}, H_{\text{pred}})
\end{aligned}
$$</p>
<p>where $H_{\text{init}}$ is a computationally inexpensive initial guess computed via PySCF.</p>
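<p>A minimal sketch of this objective follows. It assumes the network outputs the residual and that the predicted Hamiltonian is formed as $H_{\text{pred}} = H_{\text{init}} + \Delta H_{\text{pred}}$; that composition and the function name are assumptions for illustration.</p>

```python
import numpy as np

def residual_hamiltonian_loss(delta_pred, h_init, h_ref):
    """MAE + MSE between the reference and the reconstructed Hamiltonian."""
    delta_pred = np.asarray(delta_pred, float)
    h_init, h_ref = np.asarray(h_init, float), np.asarray(h_ref, float)
    h_pred = h_init + delta_pred  # assumed: prediction = initial guess + residual
    diff = h_ref - h_pred
    return float(np.mean(np.abs(diff)) + np.mean(diff ** 2))

h_init = np.zeros((2, 2))
h_ref = np.full((2, 2), 0.1)
loss_untrained = residual_hamiltonian_loss(np.zeros((2, 2)), h_init, h_ref)
loss_perfect = residual_hamiltonian_loss(h_ref - h_init, h_init, h_ref)
```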
<p><strong>Hyperparameters</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>PubChemQH</th>
          <th>QH9</th>
          <th>MD17</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Batch Size</td>
          <td>8</td>
          <td>32</td>
          <td>10 (uracil: 5)</td>
      </tr>
      <tr>
          <td>Training Steps</td>
          <td>300k</td>
          <td>260k</td>
          <td>200k</td>
      </tr>
      <tr>
          <td>Warmup Steps</td>
          <td>1k</td>
          <td>1k</td>
          <td>1k</td>
      </tr>
      <tr>
          <td>Learning Rate</td>
          <td>1e-3</td>
          <td>1e-3</td>
          <td>5e-4</td>
      </tr>
      <tr>
          <td>Sparsity Rate</td>
          <td>0.7</td>
          <td>0.4</td>
          <td>0.1-0.3</td>
      </tr>
      <tr>
          <td>TSS Epoch $t$</td>
          <td>3</td>
          <td>3</td>
          <td>3</td>
      </tr>
  </tbody>
</table>
<p><strong>Sparse Pair Gate</strong>: Adapts the interaction graph. It concatenates zero-order features and inner products of atom pairs, then passes them through a linear layer $F_p$ with sigmoid activation to learn a weight $W_p^{ij}$ for every pair. Pairs are kept only if selected by the scheduler ($U_p^{TSS}$). The overhead comes primarily from the linear layer $F_p$.</p>
<p><strong>Sparse TP Gate</strong>: Filters triplets $(l_1, l_2, l_3)$ inside the TP operation. Higher-order combinations are more likely to be pruned. Complexity: $\mathcal{O}(L^3)$.</p>
<p><strong>Three-Phase Sparsity Scheduler</strong>: Training curriculum designed to optimize the sparse gates effectively:</p>
<ul>
<li><strong>Phase 1 (Random)</strong>: Each element is kept at random with probability $1-k$, where $k$ is the sparsity rate, to ensure unbiased weight updates. Complexity: $\mathcal{O}(|U|)$.</li>
<li><strong>Phase 2 (Adaptive)</strong>: Keeps the top $(1-k)$ fraction of elements ranked by learned weight magnitude. Complexity: $\mathcal{O}(|U|\log|U|)$.</li>
<li><strong>Phase 3 (Fixed)</strong>: Freezes the connectivity mask for maximum inference speed. No overhead.</li>
</ul>
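<p>The three phases can be sketched as a single mask-selection function. This is a simplified illustration, not the authors' code: the function name, the string phase labels, and the rounding of the keep count are assumptions, and the logic that decides when to switch phases (controlled by the TSS epoch $t$) is left to the caller.</p>

```python
import numpy as np

def tss_mask(weights, sparsity, phase, rng=None, frozen_mask=None):
    """Return a boolean keep-mask over the candidate set U.

    weights  : learned gate weights, one per candidate (pair or TP triplet)
    sparsity : sparsity rate k, so roughly a (1 - k) fraction is kept
    phase    : "random", "adaptive", or "fixed" (hypothetical labels)
    """
    n = weights.shape[0]
    if phase == "random":
        # Keep each candidate independently with probability 1 - k: O(|U|).
        return rng.random(n) >= sparsity
    if phase == "adaptive":
        # Keep the top (1 - k) fraction by magnitude; the sort is O(|U| log |U|).
        n_keep = max(1, int(round((1.0 - sparsity) * n)))
        order = np.argsort(-np.abs(weights))
        mask = np.zeros(n, dtype=bool)
        mask[order[:n_keep]] = True
        return mask
    if phase == "fixed":
        # Reuse a previously selected mask; no selection overhead.
        return frozen_mask.copy()
    raise ValueError(f"unknown phase: {phase}")

weights = np.array([0.9, 0.1, 0.5, 0.7, 0.2])
rng = np.random.default_rng(0)
mask_random = tss_mask(weights, 0.4, "random", rng=rng)
mask_adaptive = tss_mask(weights, 0.4, "adaptive")
mask_fixed = tss_mask(weights, 0.4, "fixed", frozen_mask=mask_adaptive)
```

<p>In this toy case with $k = 0.4$, the adaptive phase keeps the three largest weights, and the fixed phase simply reuses that mask.</p>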
<p><strong>Weight Initialization</strong>: The learnable sparsity weights ($W$) are initialized as an all-ones vector.</p>
<h3 id="models">Models</h3>
<p>The model predicts the Hamiltonian matrix $H$ from atomic numbers $Z$ and coordinates $r$.</p>
<p><strong>Inputs</strong>: Atomic numbers ($Z$) and 3D coordinates.</p>
<p><strong>Backbone Structure</strong>:</p>
<ol>
<li><strong>Vectorial Node Interaction (x4)</strong>: Uses long-short range message passing. Extracts vectorial representations ($l=1$) without high-order TPs to save cost.</li>
<li><strong>Spherical Node Interaction (x2)</strong>: Projects features to high-order spherical harmonics (up to $L_{max}$). The first block increases the maximum order from 0 to $L_{max}$ without the Sparse Pair Gate; the second block applies the <strong>Sparse Pair Gate</strong> to filter node pairs.</li>
<li><strong>Pair Construction Block (x2)</strong>: Splits into <strong>Diagonal</strong> (self-interaction) and <strong>Non-Diagonal</strong> (cross-interaction) blocks. Both use the <strong>Sparse TP Gate</strong> to prune cross-order combinations $(l_1, l_2, l_3)$. The Non-Diagonal blocks also use the <strong>Sparse Pair Gate</strong> to filter atom pairs. The two Pair Construction blocks receive representations from the two Spherical Node Interaction blocks respectively, and their outputs are summed.</li>
<li><strong>Expansion Block</strong>: Reconstructs the full Hamiltonian matrix from the sparse irreducible representations, exploiting symmetry ($H_{ji} = H_{ij}^T$) to halve computations.</li>
</ol>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Training</strong>: 4x NVIDIA A100 (80GB)</li>
<li><strong>Benchmarking</strong>: Single NVIDIA RTX A6000 (46GB)</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Luo, E., Wei, X., Huang, L., Li, Y., Yang, H., Xia, Z., Wang, Z., Liu, C., Shao, B., &amp; Zhang, J. (2025). Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity. <em>Proceedings of the 42nd International Conference on Machine Learning</em>, PMLR 267:41368&ndash;41390.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{luo2025efficient,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Luo, Erpai and Wei, Xinran and Huang, Lin and Li, Yunyang and Yang, Han and Xia, Zaishuo and Wang, Zun and Liu, Chang and Shao, Bin and Zhang, Jia}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{41368--41390}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45656">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/forum?id=K3lykWhXON">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=K3lykWhXON">PDF on OpenReview</a></li>
<li><a href="https://github.com/microsoft/SPHNet">GitHub Repository</a> <em>(Note: The official repository was archived by Microsoft in December 2025. It is available for reference but no longer actively maintained.)</em></li>
</ul>
]]></content:encoded></item><item><title>Dark Side of Forces: Non-Conservative ML Force Models</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/dark-side-of-forces/</guid><description>Bigi et al. critique non-conservative force models in ML potentials, showing their simulation failures and proposing hybrid solutions.</description><content:encoded><![CDATA[<h2 id="contribution-systematic-assessment-of-non-conservative-ml-force-models">Contribution: Systematic Assessment of Non-Conservative ML Force Models</h2>
<p>This is a <strong>Systematization</strong> paper. It systematically catalogs the exact failure modes of existing non-conservative force approaches, quantifies them with a new diagnostic metric, and proposes a hybrid Multiple Time-Stepping solution combining the speed benefits of direct force prediction with the physical correctness of conservative models.</p>
<h2 id="motivation-the-speed-accuracy-trade-off-in-ml-force-fields">Motivation: The Speed-Accuracy Trade-off in ML Force Fields</h2>
<p>Many recent machine learning interatomic potential (MLIP) architectures predict forces directly ($F_\theta(r)$). This &ldquo;non-conservative&rdquo; approach avoids the computational overhead of automatic differentiation, yielding faster inference (typically 2-3x speedup) and faster training (up to 3x). However, it sacrifices energy conservation and rotational constraints, potentially destabilizing molecular dynamics simulations. The field lacks rigorous quantification of when this trade-off breaks down and how to mitigate the failures.</p>
<h2 id="novelty-jacobian-asymmetry-and-hybrid-architectures">Novelty: Jacobian Asymmetry and Hybrid Architectures</h2>
<p>Four key contributions:</p>
<ol>
<li>
<p><strong>Jacobian Asymmetry Metric ($\lambda$):</strong> A quantitative diagnostic for non-conservation. Since conservative forces derive from a scalar field, their Jacobian (the Hessian of energy) must be symmetric. The normalized norm of the antisymmetric part quantifies the degree of violation:
$$ \lambda = \frac{\lVert \mathbf{J}_{\text{anti}} \rVert_F}{\lVert \mathbf{J} \rVert_F} $$
where $\mathbf{J}_{\text{anti}} = (\mathbf{J} - \mathbf{J}^\top)/2$. Measured values range from $\lambda \approx 0.004$ (PET-NC) to $\lambda \approx 0.032$ (SOAP-BPNN-NC), with ORB at 0.015 and EquiformerV2 at 0.017. Notably, the pairwise $\lambda_{ij}$ approaches 1 at large interatomic distances, meaning non-conservative artifacts disproportionately affect long-range and collective interactions.</p>
</li>
<li>
<p><strong>Systematic Failure Mode Catalog:</strong> First comprehensive demonstration that non-conservative models cause runaway heating in NVE ensembles (temperature drifts of ~7,000 billion K/s for PET-NC and ~10x larger for ORB) and equipartition violations in NVT ensembles where different atom types equilibrate to different temperatures, a physical impossibility.</p>
</li>
<li>
<p><strong>Theoretical Analysis of Force vs. Energy Training:</strong> Force-only training overemphasizes high-frequency vibrational modes because force labels carry per-atom gradients that are dominated by stiff, short-range interactions. Energy labels provide a more balanced representation across the frequency spectrum. Additionally, conservative models benefit from backpropagation extending the effective receptive field to approximately 2x the interaction cutoff, while direct-force models are limited to the nominal cutoff radius.</p>
</li>
<li>
<p><strong>Hybrid Training and Inference Protocol:</strong> A practical workflow that combines fast direct-force prediction with conservative corrections:</p>
<ul>
<li><strong>Training:</strong> Pre-train on direct forces, then fine-tune on energy gradients (2-4x faster than training conservative models from scratch)</li>
<li><strong>Inference:</strong> Multiple Time-Stepping (MTS) where fast non-conservative forces are periodically corrected by slower conservative forces</li>
</ul>
</li>
</ol>
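<p>To make the Jacobian asymmetry metric concrete, the sketch below evaluates $\lambda$ for two toy two-dimensional force fields using a central finite-difference Jacobian. The force functions and helper names are hypothetical, and the finite-difference Jacobian is an illustrative choice rather than the procedure used in the paper.</p>

```python
import numpy as np

def jacobian_asymmetry(jac):
    """lambda = ||J_anti||_F / ||J||_F with J_anti = (J - J^T) / 2."""
    anti = 0.5 * (jac - jac.T)
    return np.linalg.norm(anti, "fro") / np.linalg.norm(jac, "fro")

def numerical_jacobian(force_fn, x, eps=1e-5):
    """Central finite-difference Jacobian dF_i/dx_k of a force function."""
    n = x.size
    jac = np.zeros((n, n))
    for k in range(n):
        step = np.zeros(n)
        step[k] = eps
        jac[:, k] = (force_fn(x + step) - force_fn(x - step)) / (2.0 * eps)
    return jac

def conservative_force(x):
    # Gradient of the harmonic potential 0.5 * |x|^2, so the Jacobian is symmetric.
    return -x

def non_conservative_force(x):
    # Same spring plus a small rotational (curl) term that has no potential.
    return -x + 0.1 * np.array([-x[1], x[0]])

x0 = np.array([0.3, -0.2])
lam_c = jacobian_asymmetry(numerical_jacobian(conservative_force, x0))
lam_nc = jacobian_asymmetry(numerical_jacobian(non_conservative_force, x0))
```

<p>The conservative field gives $\lambda \approx 0$ up to finite-difference noise, while the rotational term yields $\lambda = \sqrt{0.02 / 2.02} \approx 0.0995$ for this toy example.</p>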
<h2 id="methodology-systematic-failure-mode-analysis">Methodology: Systematic Failure Mode Analysis</h2>
<p>The evaluation systematically tests multiple state-of-the-art models across diverse simulation scenarios:</p>
<p><strong>Models tested:</strong></p>
<ul>
<li><strong>PET-C/PET-NC</strong> (Point Edge Transformer, conservative and non-conservative variants)</li>
<li><strong>PET-M</strong> (hybrid variant jointly predicting both conservative and non-conservative forces)</li>
<li><strong>ORB-v2</strong> (non-conservative, trained on Alexandria/MPtrj)</li>
<li><strong>EquiformerV2</strong> (non-conservative equivariant Transformer)</li>
<li><strong>MACE-MP-0</strong> (conservative message-passing)</li>
<li><strong>SevenNet</strong> (conservative message-passing)</li>
<li><strong>SOAP-BPNN-C/SOAP-BPNN-NC</strong> (descriptor-based baseline, both conservative and non-conservative variants)</li>
</ul>
<p><strong>Test scenarios:</strong></p>
<ol>
<li><strong>NVE stability tests</strong> on bulk liquid water, graphene, amorphous carbon, and FCC aluminum</li>
<li><strong>Thermostat artifact analysis</strong> with Langevin and GLE thermostats</li>
<li><strong>Geometry optimization</strong> on water snapshots and <a href="/notes/chemistry/datasets/qm9/">QM9</a> molecules using FIRE and L-BFGS</li>
<li><strong>MTS validation</strong> on OC20 catalysis dataset</li>
<li><strong>Species-resolved temperature measurements</strong> for equipartition testing</li>
</ol>
<p><strong>Key metrics:</strong></p>
<ul>
<li>Jacobian asymmetry ($\lambda$)</li>
<li>Kinetic temperature drift in NVE</li>
<li>Velocity-velocity correlations</li>
<li>Radial distribution functions</li>
<li>Species-resolved temperatures</li>
<li>Inference speed benchmarks</li>
</ul>
<h2 id="results-simulation-instability-and-hybrid-solutions">Results: Simulation Instability and Hybrid Solutions</h2>
<p>Purely non-conservative models are <strong>unsuitable for production simulations</strong> due to uncontrollable unphysical artifacts that no thermostat can correct. Key findings:</p>
<p><strong>Performance failures:</strong></p>
<ul>
<li>Non-conservative models exhibited catastrophic temperature drift in NVE simulations: ~7,000 billion K/s for PET-NC and ~70,000 billion K/s for ORB, with EquiformerV2 comparable to PET-NC</li>
<li>Strong Langevin thermostats ($\tau=10$ fs) damped diffusion by ~5x, negating the speed benefits of non-conservative models</li>
<li>Advanced GLE thermostats also failed to control non-conservative drift (ORB reached 1181 K vs. 300 K target)</li>
<li>Equipartition violations: under stochastic velocity rescaling, O and H atoms equilibrated at different temperatures. For ORB, H atoms reached 336 K and O atoms 230 K against a 300 K target. For PET-NC, deviations were smaller but still significant (H at 296 K, O at 310 K).</li>
<li>Geometry optimization was more fragile with non-conservative forces: inaccurate NC models (SOAP-BPNN-NC) failed catastrophically, while more accurate ones (PET-NC) could converge with FIRE but showed large force fluctuations with L-BFGS. Non-conservative models consistently had lower success rates across water and QM9 benchmarks.</li>
</ul>
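<p>The equipartition check in the list above relies on a kinetic temperature computed separately for each species. A minimal sketch, assuming SI units and $3N_s$ degrees of freedom per species (no centre-of-mass or constraint corrections); the function name and the synthetic velocities are illustrative and are not data from the paper:</p>

```python
import numpy as np

K_B = 1.380649e-23        # Boltzmann constant, J/K
AMU = 1.66053906660e-27   # atomic mass unit, kg

def species_temperatures(masses, velocities, species):
    """Kinetic temperature per species: T_s = sum(m v^2) / (3 N_s k_B)."""
    temps = {}
    for s in sorted(set(species.tolist())):
        sel = species == s
        kinetic = 0.5 * np.sum(masses[sel, None] * velocities[sel] ** 2)
        temps[s] = 2.0 * kinetic / (3.0 * np.count_nonzero(sel) * K_B)
    return temps

# Synthetic water-like system sampled at 300 K, so equipartition holds by construction.
rng = np.random.default_rng(0)
species = np.array(["O"] * 20000 + ["H"] * 40000)
masses = np.where(species == "O", 15.999, 1.008) * AMU
sigma = np.sqrt(K_B * 300.0 / masses)                 # Maxwell-Boltzmann width per atom
velocities = rng.normal(size=(species.size, 3)) * sigma[:, None]

temps = species_temperatures(masses, velocities, species)
```

<p>For a correctly thermalized system both entries of <code>temps</code> agree with the target within sampling noise; a persistent gap between species, such as the one reported for ORB, signals an equipartition violation.</p>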
<p><strong>Hybrid solution success:</strong></p>
<ul>
<li>MTS with non-conservative forces corrected every 8 steps ($M=8$) achieved conservative stability with only ~20% overhead compared to a purely non-conservative trajectory. Results were essentially indistinguishable from fully conservative simulations. Higher stride values ($M=16$) became unstable due to resonances between fast degrees of freedom and integration errors.</li>
<li>Conservative fine-tuning achieved the accuracy of from-scratch training in about 1/3 the total training time (2-4x resource reduction)</li>
<li>Validated on OC20 catalysis benchmark</li>
</ul>
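<p>The MTS idea can be sketched with a RESPA-style splitting: cheap direct forces drive $M$ inner velocity-Verlet steps, and the difference between conservative and direct forces is applied as a half-kick at the start and end of each outer step. The toy 2D oscillator, the rotational error term, and all names below are assumptions chosen for illustration; the paper's simulations use i-PI and are not reproduced here.</p>

```python
import numpy as np

A_ROT = 0.05  # strength of the toy non-conservative error (hypothetical)

def f_conservative(x):
    # Force from the potential 0.5 * |x|^2.
    return -x

def f_direct(x):
    # "Direct" force: the conservative force plus a small rotational error.
    return -x + A_ROT * np.array([-x[1], x[0]])

def mts_step(x, v, mass, dt, m_stride, f_fast, f_slow):
    """One outer MTS step made of m_stride inner steps of size dt."""
    outer = m_stride * dt
    v = v + 0.5 * outer * (f_slow(x) - f_fast(x)) / mass   # correction half-kick
    for _ in range(m_stride):                              # inner velocity Verlet
        v = v + 0.5 * dt * f_fast(x) / mass
        x = x + dt * v
        v = v + 0.5 * dt * f_fast(x) / mass
    v = v + 0.5 * outer * (f_slow(x) - f_fast(x)) / mass   # correction half-kick
    return x, v

def energy(x, v):
    return 0.5 * float(v @ v) + 0.5 * float(x @ x)

def run(f_slow, n_outer=625, dt=0.01, m_stride=8):
    x, v = np.array([1.0, 0.0]), np.array([0.0, 1.0])
    for _ in range(n_outer):
        x, v = mts_step(x, v, 1.0, dt, m_stride, f_direct, f_slow)
    return energy(x, v)

e_nc = run(f_direct)         # correction vanishes: plain non-conservative dynamics
e_mts = run(f_conservative)  # conservative correction every 8 inner steps
```

<p>In this toy setting the uncorrected trajectory gains energy steadily, while the corrected one stays close to its initial energy of 1. In a real model the conservative force is only evaluated on the outer steps, which is where the cost saving comes from.</p>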
<p><strong>Scaling caveat:</strong> The authors note that as training datasets grow and models become more expressive, non-conservative artifacts should diminish because accurate models naturally exhibit less non-conservative behavior. However, they argue the best path forward is hybrid approaches rather than waiting for scale to solve the problem.</p>
<p><strong>Recommendation:</strong> The optimal production path is hybrid architectures using direct forces for acceleration (via MTS and pre-training) while anchoring models in conservative energy surfaces. This captures computational benefits without sacrificing physical reliability.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Primary training/evaluation:</strong></p>
<ul>
<li><strong>Bulk Liquid Water</strong> (Cheng et al., 2019): revPBE0-D3 calculations with over 250,000 force/energy targets, chosen for rigorous thermodynamic testing</li>
</ul>
<p><strong>Generalization tests:</strong></p>
<ul>
<li>Graphene, amorphous carbon, FCC aluminum (tested with general-purpose foundation models)</li>
</ul>
<p><strong>Benchmarks:</strong></p>
<ul>
<li><strong>QM9</strong>: Geometry optimization tests</li>
<li><strong>OC20</strong> (Open Catalyst): Oxygen on alloy surfaces for MTS validation</li>
</ul>
<p>All datasets publicly available through cited sources.</p>
<h3 id="models">Models</h3>
<p><strong>Point Edge Transformer (PET)</strong> variants:</p>
<ul>
<li><strong>PET-C (Conservative)</strong>: Forces via energy backpropagation</li>
<li><strong>PET-NC (Non-Conservative)</strong>: Direct force prediction head, slightly higher parameter count</li>
<li><strong>PET-M (Hybrid)</strong>: Jointly predicts both conservative and non-conservative forces, accuracy within ~10% of the best single-task models</li>
</ul>
<p><strong>Baseline comparisons:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Type</th>
          <th>Training Data</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>ORB-v2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Rotationally unconstrained</td>
      </tr>
      <tr>
          <td>EquiformerV2</td>
          <td>Non-conservative</td>
          <td>Alexandria/MPtrj</td>
          <td>Equivariant Transformer</td>
      </tr>
      <tr>
          <td>MACE-MP-0</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>Conservative</td>
          <td>MPtrj</td>
          <td>Equivariant message-passing</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-C</td>
          <td>Conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
      <tr>
          <td>SOAP-BPNN-NC</td>
          <td>Non-conservative</td>
          <td>Bulk water</td>
          <td>Descriptor-based baseline</td>
      </tr>
  </tbody>
</table>
<p><strong>Training details:</strong></p>
<ul>
<li><strong>Loss functions</strong>: PET-C uses joint Energy + Force $L^2$ loss; PET-NC uses Force-only $L^2$ loss</li>
<li><strong>Fine-tuning protocol</strong>: PET-NC converted to conservative via energy head fine-tuning</li>
<li><strong>MTS configuration</strong>: Non-conservative forces with conservative corrections every 8 steps ($M=8$)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics &amp; Software:</strong>
Molecular dynamics evaluations were performed using <strong>i-PI</strong>, while geometry optimizations used <strong>ASE (Atomic Simulation Environment)</strong>. Note that primary code reproducibility is provided via an archived Zenodo snapshot; the authors did not link a live, public GitHub repository.</p>
<ol>
<li><strong>Jacobian asymmetry</strong> ($\lambda$): Quantifies non-conservation via antisymmetric component</li>
<li><strong>Temperature drift</strong>: NVE ensemble stability</li>
<li><strong>Velocity-velocity correlation</strong> ($\hat{c}_{vv}(\omega)$): Thermostat artifact detection</li>
<li><strong>Radial distribution functions</strong> ($g(r)$): Structural accuracy</li>
<li><strong>Species-resolved temperature</strong>: Equipartition testing</li>
<li><strong>Inference speed</strong>: Wall-clock time per MD step</li>
</ol>
<p><strong>Key results:</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Speed (ms/step)</th>
          <th>NVE Stability</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>PET-NC</td>
          <td>8.58</td>
          <td>Failed</td>
          <td>~7,000 billion K/s drift</td>
      </tr>
      <tr>
          <td>PET-C</td>
          <td>19.4</td>
          <td>Stable</td>
          <td>2.3x slower than PET-NC</td>
      </tr>
      <tr>
          <td>SevenNet</td>
          <td>52.8</td>
          <td>Stable</td>
          <td>Conservative baseline</td>
      </tr>
      <tr>
          <td><strong>PET Hybrid (MTS)</strong></td>
          <td><strong>~10.3</strong></td>
          <td><strong>Stable</strong></td>
          <td><strong>~20% overhead vs. pure NC</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Thermostat artifacts:</strong></p>
<ul>
<li>Langevin ($\tau=10$ fs) damped diffusion by ~5x (weaker coupling at $\tau=100$ fs reduced diffusion by ~1.5x)</li>
<li>GLE thermostats also failed to control non-conservative drift</li>
<li>Equipartition violations under SVR: ORB showed H at 336 K and O at 230 K (target 300 K); PET-NC showed smaller but significant species-resolved deviations</li>
</ul>
<p><strong>Optimization failures:</strong></p>
<ul>
<li>Non-conservative models showed lower geometry optimization success rates across water and QM9 benchmarks, with inaccurate NC models failing catastrophically</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p><strong>Compute resources:</strong></p>
<ul>
<li><strong>Training</strong>: From-scratch baseline models were trained on 4x NVIDIA H100 GPUs for about two days.</li>
<li><strong>Fine-Tuning</strong>: Conservative fine-tuning used a single NVIDIA H100 GPU for one day.</li>
<li>This hybrid fine-tuning approach achieved a 2-4x reduction in computational resources compared to training conservative models from scratch.</li>
</ul>
<p><strong>Reproduction resources:</strong></p>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://zenodo.org/records/14778891">Zenodo repository</a></td>
          <td>Code/Data</td>
          <td>Unknown</td>
          <td>Code and data to reproduce all results</td>
      </tr>
      <tr>
          <td><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS inference tutorial</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Multiple time-stepping dynamics tutorial</td>
      </tr>
      <tr>
          <td><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative fine-tuning tutorial</a></td>
          <td>Other</td>
          <td>Unknown</td>
          <td>Fine-tuning workflow tutorial</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Bigi, F., Langer, M. F., &amp; Ceriotti, M. (2025). The dark side of the forces: assessing non-conservative force models for atomistic machine learning. <em>Proceedings of the 42nd International Conference on Machine Learning</em>, PMLR 267.</p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{bigi2025dark,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{The dark side of the forces: assessing non-conservative force models for atomistic machine learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Bigi, Filippo and Langer, Marcel F and Ceriotti, Michele}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span>=<span style="color:#e6db74">{Vancouver, Canada}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://icml.cc/virtual/2025/poster/45458">ICML 2025 poster page</a></li>
<li><a href="https://openreview.net/pdf?id=OEl3L8osas">PDF on OpenReview</a></li>
<li><a href="https://zenodo.org/records/14778891">Zenodo repository</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-mad-nc/pet-mad-nc.html">MTS Inference Tutorial</a></li>
<li><a href="https://atomistic-cookbook.org/examples/pet-finetuning/pet-ft-nc.html">Conservative Fine-Tuning Tutorial</a></li>
</ul>
]]></content:encoded></item><item><title>Beyond Atoms: 3D Space Modeling for Molecular Pretraining</title><link>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/beyond-atoms/</link><pubDate>Sat, 23 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/molecular-simulation/ml-potentials/beyond-atoms/</guid><description>Lu et al. introduce SpaceFormer, a Transformer that models entire 3D molecular space including atoms for superior representations.</description><content:encoded><![CDATA[<h2 id="paper-typology-and-contribution">Paper Typology and Contribution</h2>
<p>This is a <strong>Method</strong> paper. It challenges the atom-centric paradigm of molecular representation learning by proposing a novel framework that models the continuous 3D space surrounding atoms. The core contribution is <strong>SpaceFormer</strong>, a Transformer-based architecture that discretizes molecular space into grids to capture physical phenomena (electron density, electromagnetic fields) often missed by traditional point-cloud models.</p>
<h2 id="the-physical-intuition-modeling-empty-space">The Physical Intuition: Modeling &ldquo;Empty&rdquo; Space</h2>
<p><strong>The Gap</strong>: Prior 3D molecular representation models, such as Uni-Mol, treat molecules as discrete sets of atoms, essentially point clouds in 3D space. However, from a quantum physics perspective, the &ldquo;empty&rdquo; space between atoms is far from empty. It is permeated by electron density distributions and electromagnetic fields that determine molecular properties.</p>
<p><strong>The Hypothesis</strong>: Explicitly modeling this continuous 3D space alongside discrete atom positions yields superior representations for downstream tasks, particularly for computational properties that depend on electronic structure, such as HOMO/LUMO energies and energy gaps.</p>
<h2 id="a-surprising-observation-virtual-points-improve-representations">A Surprising Observation: Virtual Points Improve Representations</h2>
<p>Before proposing SpaceFormer, the authors present a simple yet revealing experiment. They augment Uni-Mol by adding randomly sampled virtual points (VPs) from the 3D space within the circumscribed cuboid of each molecule. These VPs carry no chemical information whatsoever: they are purely random noise points.</p>
<p>The result is surprising: adding just 10 random VPs already yields a noticeable improvement in validation loss. The improvement remains consistent and gradually increases as the number of VPs grows, eventually reaching a plateau. This observation holds across downstream tasks as well, with Uni-Mol + VPs improving on several quantum property predictions (LUMO, E1-CC2, E2-CC2) compared to vanilla Uni-Mol.</p>
<p>The implication is that even uninformative spatial context helps the model learn better representations, motivating a principled framework for modeling the full 3D molecular space.</p>
<h2 id="spaceformer-voxelization-and-3d-positional-encodings">SpaceFormer: Voxelization and 3D Positional Encodings</h2>
<p>The key innovation is treating the molecular representation problem as <strong>3D space modeling</strong>. SpaceFormer follows these core steps:</p>
<ol>
<li><strong>Voxelizes the entire 3D space</strong> into a grid with cells of $0.49\text{\AA}$ (based on O-H bond length to ensure at most one atom per cell).</li>
<li><strong>Uses adaptive multi-resolution grids</strong> to efficiently handle empty space, keeping it fine-grained near atoms and coarse-grained far away.</li>
<li><strong>Applies Transformers to 3D spatial tokens</strong> with custom positional encodings that achieve linear complexity.</li>
</ol>
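<p>A minimal sketch of the voxelization step, assuming a uniform grid anchored at the minimum corner of the molecule and offsets measured from cell centres (both are illustrative choices; the adaptive multi-resolution merging is omitted). The diagonal of a $0.49\text{\AA}$ cube is $0.49\sqrt{3} \approx 0.85\text{\AA}$, shorter than a typical O-H bond of roughly $0.96\text{\AA}$, which is why two bonded atoms should not share a cell. The water geometry below is approximate.</p>

```python
import numpy as np

CELL = 0.49  # grid cell edge length in angstrom

def voxelize(coords, cell=CELL):
    """Map atom coordinates to integer cell indices and in-cell offsets."""
    origin = coords.min(axis=0)                      # assumed grid anchor
    rel = coords - origin
    idx = np.floor(rel / cell).astype(int)           # cell index per atom
    offset = rel - (idx + 0.5) * cell                # offset from the cell centre
    return idx, offset

# Approximate water geometry (angstrom): O, H, H.
coords = np.array([
    [0.000,  0.000,  0.117],
    [0.000,  0.757, -0.469],
    [0.000, -0.757, -0.469],
])

idx, offset = voxelize(coords)
n_occupied = len({tuple(i) for i in idx.tolist()})   # number of distinct occupied cells
cell_diagonal = CELL * np.sqrt(3.0)
```

<p>Each atom lands in its own cell, and the offset carries the sub-cell position that a grid index alone would lose.</p>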
<p>Specifically, the model utilizes two forms of 3D Positional Encoding:</p>
<p><strong>3D Directional PE (RoPE Extension)</strong>
They extend Rotary Positional Encoding (RoPE) to 3D continuous space by splitting the Query and Key vectors into three blocks (one for each spatial axis). The directional attention mechanism takes the form:</p>
<p>$$
\begin{aligned}
\mathbf{q}_{i}^{\top} \mathbf{k}_{j} = \sum_{s=1}^{3} \mathbf{q}_{i,s}^{\top} \mathbf{R}(c_{j,s} - c_{i,s}) \mathbf{k}_{j,s}
\end{aligned}
$$</p>
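<p>The directional encoding can be sketched by rotating each axis block of the query and key by an angle proportional to the corresponding coordinate, so that their dot product depends only on the coordinate difference. The frequency schedule, head dimension, and function names below are assumptions for illustration, not the paper's configuration.</p>

```python
import numpy as np

def rope_1d(x, pos, freqs):
    """Rotate consecutive pairs of x by angles pos * freqs (standard RoPE)."""
    pairs = x.reshape(-1, 2)
    ang = pos * freqs
    cos, sin = np.cos(ang), np.sin(ang)
    out = np.stack([pairs[:, 0] * cos - pairs[:, 1] * sin,
                    pairs[:, 0] * sin + pairs[:, 1] * cos], axis=-1)
    return out.reshape(-1)

def rope_3d(x, coord, freqs):
    """Split x into three blocks (x, y, z axes) and rotate each by its coordinate."""
    blocks = np.split(x, 3)
    return np.concatenate([rope_1d(b, c, freqs) for b, c in zip(blocks, coord)])

n_freq = 4
freqs = 1.0 / (100.0 ** (np.arange(n_freq) / n_freq))   # hypothetical schedule
rng = np.random.default_rng(0)
q = rng.normal(size=3 * 2 * n_freq)
k = rng.normal(size=3 * 2 * n_freq)
c_i = np.array([0.3, -1.2, 2.0])
c_j = np.array([1.1, 0.4, -0.7])
shift = np.array([5.0, -3.0, 2.5])

score = rope_3d(q, c_i, freqs) @ rope_3d(k, c_j, freqs)
score_shifted = rope_3d(q, c_i + shift, freqs) @ rope_3d(k, c_j + shift, freqs)
score_relative = q @ rope_3d(k, c_j - c_i, freqs)        # explicit R(c_j - c_i) form
```

<p>All three scores agree up to floating-point error, which is the translation invariance implied by the relative rotation $\mathbf{R}(c_{j,s} - c_{i,s})$.</p>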
<p><strong>3D Distance PE (RFF Approximation)</strong>
To compute invariant geometric distance without incurring quadratic memory overhead, they use Random Fourier Features (RFF) to approximate a Gaussian kernel of pairwise distances:</p>
<p>$$
\begin{aligned}
\exp \left( - \frac{\lVert \mathbf{c}_i - \mathbf{c}_j \rVert_2^2}{2\sigma^2} \right) &amp;\approx z(\mathbf{c}_i)^\top z(\mathbf{c}_j) \\
z(\mathbf{c}_i) &amp;= \sqrt{\frac{2}{d}} \cos(\sigma^{-1} \mathbf{c}_i^\top \boldsymbol{\omega} + \mathbf{b})
\end{aligned}
$$</p>
<p>This approach enables the model to natively encode complex field-like phenomena without computing exhaustive $O(N^2)$ distance matrices.</p>
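<p>A small numerical check of the RFF approximation, under the standard assumptions that the columns of $\boldsymbol{\omega}$ are drawn from $\mathcal{N}(0, I)$ and $\mathbf{b}$ from $\mathrm{Uniform}(0, 2\pi)$; the notes above do not state the sampling distributions, so treat these as assumptions. The feature dimension is chosen large here only to make the approximation error visibly small.</p>

```python
import numpy as np

def rff_features(coords, omega, b, sigma):
    """z(c) = sqrt(2/d) * cos(c^T omega / sigma + b) for each row c of coords."""
    d = omega.shape[1]
    return np.sqrt(2.0 / d) * np.cos(coords @ omega / sigma + b)

rng = np.random.default_rng(0)
d, sigma = 20000, 1.5
omega = rng.normal(size=(3, d))                 # assumed: omega ~ N(0, I)
b = rng.uniform(0.0, 2.0 * np.pi, size=d)       # assumed: b ~ U(0, 2*pi)
coords = rng.uniform(-2.0, 2.0, size=(5, 3))    # toy cell positions

z = rff_features(coords, omega, b, sigma)
approx = z @ z.T                                # kernel from feature inner products

sq_dist = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
exact = np.exp(-sq_dist / (2.0 * sigma ** 2))   # Gaussian kernel of pairwise distances
max_err = np.abs(approx - exact).max()
```

<p>Because the kernel factorizes as $z(\mathbf{c}_i)^\top z(\mathbf{c}_j)$, the features can be attached to each token once, and the pairwise kernel never has to be materialized as an explicit matrix during attention; the dense <code>exact</code> matrix is built here only to measure the error.</p>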
<h2 id="experimental-setup-and-downstream-tasks">Experimental Setup and Downstream Tasks</h2>
<p><strong>Pretraining Data</strong>: 19 million unlabeled molecules from the same dataset used by Uni-Mol.</p>
<p><strong>Downstream Benchmarks</strong>: The authors propose a new benchmark of 15 tasks, motivated by known limitations of MoleculeNet: invalid structures, inconsistent chemical representations, data curation errors, and an inability to adequately distinguish model performance. The tasks split into two categories:</p>
<ol>
<li>
<p><strong>Computational Properties (Quantum Mechanics)</strong></p>
<ul>
<li>Subsets of <a href="/notes/chemistry/datasets/gdb-17/">GDB-17</a> (HOMO, LUMO, GAP energy prediction, 20K samples; E1-CC2, E2-CC2, f1-CC2, f2-CC2, 21.7K samples)</li>
<li>Cata-condensed polybenzenoid hydrocarbons (Dipole moment, adiabatic ionization potential, D3 dispersion correction, 8,678 samples)</li>
<li>Metric: Mean Absolute Error (MAE)</li>
</ul>
</li>
<li>
<p><strong>Experimental Properties (Pharma/Bio)</strong></p>
<ul>
<li>MoleculeNet tasks (BBBP, BACE for drug discovery)</li>
<li>Biogen ADME tasks (HLM, MME, Solubility)</li>
<li>Metrics: AUC for classification, MAE for regression</li>
</ul>
</li>
</ol>
<p><strong>Splitting Strategy</strong>: All datasets use 8:1:1 train/validation/test ratio with <strong>scaffold splitting</strong> to test out-of-distribution generalization.</p>
<p><strong>Training Setup</strong>:</p>
<ul>
<li><strong>Objective</strong>: Masked Auto-Encoder (MAE) with 30% random masking. The model predicts whether a cell contains an atom and, if so, predicts the atom type and regresses the precise offset position.</li>
<li><strong>Hardware</strong>: ~50 hours on 8 NVIDIA A100 GPUs</li>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9, \beta_2=0.99$)</li>
<li><strong>Learning Rate</strong>: Peak 1e-4 with linear decay and 0.01 warmup ratio</li>
<li><strong>Batch Size</strong>: 128</li>
<li><strong>Total Updates</strong>: 1 million</li>
</ul>
<p><strong>Baseline Comparisons</strong>: GROVER (2D graph-based MPR), GEM (2D graph enhanced with 3D information), 3D Infomax (GNN with 3D information), Uni-Mol (3D MPR, primary baseline using the same pretraining dataset), and Mol-AE (extends Uni-Mol with atom-based MAE pretraining).</p>
<h2 id="results-and-analysis">Results and Analysis</h2>
<p><strong>Strong Contextual Performance</strong>: SpaceFormer ranked 1st in 10 of 15 tasks and in the top 2 for 14 of 15 tasks. It surpassed the runner-up models by approximately 20% on quantum property tasks (HOMO, LUMO, GAP, E1-CC2, Dipmom), validating that modeling non-atom space captures electronic structure better than atom-only regimes.</p>
<h3 id="key-results-on-quantum-properties">Key Results on Quantum Properties</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>GROVER</th>
          <th>GEM</th>
          <th>3D Infomax</th>
          <th>Uni-Mol</th>
          <th>Mol-AE</th>
          <th><strong>SpaceFormer</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>HOMO (Ha)</td>
          <td>0.0075</td>
          <td>0.0068</td>
          <td>0.0065</td>
          <td>0.0052</td>
          <td>0.0050</td>
          <td><strong>0.0042</strong></td>
      </tr>
      <tr>
          <td>LUMO (Ha)</td>
          <td>0.0086</td>
          <td>0.0080</td>
          <td>0.0070</td>
          <td>0.0060</td>
          <td>0.0057</td>
          <td><strong>0.0040</strong></td>
      </tr>
      <tr>
          <td>GAP (Ha)</td>
          <td>0.0109</td>
          <td>0.0107</td>
          <td>0.0095</td>
          <td>0.0081</td>
          <td>0.0080</td>
          <td><strong>0.0064</strong></td>
      </tr>
      <tr>
          <td>E1-CC2 (eV)</td>
          <td>0.0101</td>
          <td>0.0090</td>
          <td>0.0089</td>
          <td>0.0067</td>
          <td>0.0070</td>
          <td><strong>0.0058</strong></td>
      </tr>
      <tr>
          <td>Dipmom (Debye)</td>
          <td>0.0752</td>
          <td>0.0289</td>
          <td>0.0291</td>
          <td>0.0106</td>
          <td>0.0113</td>
          <td><strong>0.0083</strong></td>
      </tr>
  </tbody>
</table>
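<p>As a quick arithmetic check on the table, the relative gap to the best baseline can be computed per task. The snippet below only re-uses the numbers shown above and assumes lower values are better; the helper name <code>relative_reduction</code> is illustrative.</p>

```python
# Values copied from the table above, in column order:
# GROVER, GEM, 3D Infomax, Uni-Mol, Mol-AE, SpaceFormer.
table = {
    "HOMO":   [0.0075, 0.0068, 0.0065, 0.0052, 0.0050, 0.0042],
    "LUMO":   [0.0086, 0.0080, 0.0070, 0.0060, 0.0057, 0.0040],
    "GAP":    [0.0109, 0.0107, 0.0095, 0.0081, 0.0080, 0.0064],
    "E1-CC2": [0.0101, 0.0090, 0.0089, 0.0067, 0.0070, 0.0058],
    "Dipmom": [0.0752, 0.0289, 0.0291, 0.0106, 0.0113, 0.0083],
}


def relative_reduction(row):
    """Fractional reduction of SpaceFormer's value relative to the best baseline."""
    *baselines, spaceformer = row
    runner_up = min(baselines)
    return (runner_up - spaceformer) / runner_up


reductions = {task: relative_reduction(row) for task, row in table.items()}
mean_reduction = sum(reductions.values()) / len(reductions)

for task, value in reductions.items():
    print(f"{task}: {value:.1%}")
print(f"mean: {mean_reduction:.1%}")
```

<p>The per-task reductions range from about 13% (E1-CC2) to about 30% (LUMO), and their mean is about 20%.</p>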
<p>SpaceFormer&rsquo;s advantage is most pronounced on computational properties that depend on electronic structure. On experimental biological tasks (e.g., BBBP), where measurements are noisy, the advantage narrows or reverses: Uni-Mol achieves 0.9066 AUC on BBBP compared to SpaceFormer&rsquo;s 0.8605.</p>
<h3 id="ablation-studies">Ablation Studies</h3>
<p>The authors present several ablations that isolate the source of SpaceFormer&rsquo;s improvements:</p>
<p><strong>MAE vs. Denoising</strong>: SpaceFormer with MAE pretraining outperforms SpaceFormer with denoising on all four ablation tasks. The MAE objective requires predicting <em>whether</em> an atom exists in a masked voxel, which forces the model to learn global structural dependencies. In the denoising variant, only atom cells are masked so the model never needs to predict atom existence, reducing the task to coordinate regression.</p>
<p><strong>FLOPs Control</strong>: A SpaceFormer-Large model (4x width, atom-only) trained with comparable FLOPs still falls short of SpaceFormer with 1000 non-atom cells on most downstream tasks. This confirms the improvement comes from modeling 3D space, not from additional compute.</p>
<p><strong>Virtual Points vs. SpaceFormer</strong>: Adding up to 200 random virtual points to Uni-Mol improves some tasks but leaves a significant gap compared to SpaceFormer, demonstrating that principled space discretization outperforms naive point augmentation.</p>
<p><strong>Efficiency Validation</strong>: The Adaptive Grid Merging method reduces the number of cells by roughly 10x with virtually no performance degradation. The 3D positional encodings scale linearly with the number of cells, while Uni-Mol&rsquo;s pretraining cost scales quadratically.</p>
<h3 id="scope-and-future-directions">Scope and Future Directions</h3>
<p>SpaceFormer does not incorporate built-in SE(3) equivariance, relying instead on data augmentation (random rotations and random boundary padding) during training. The authors identify extending SpaceFormer to force field tasks and larger systems such as proteins and complexes as promising future directions.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="code-and-data-availability">Code and Data Availability</h3>
<ul>
<li><strong>Source Code</strong>: As of March 2026, the authors have not released the official source code or pre-trained weights.</li>
<li><strong>Datasets</strong>: Pretraining utilized the same 19M unlabeled molecule dataset as Uni-Mol. Downstream tasks use a newly curated internal benchmark built from subsets of GDB-17, MoleculeNet, and Biogen ADME. The exact customized scaffold splits for these evaluations are pending the official code release.</li>
<li><strong>Compute</strong>: Pretraining the base SpaceFormer encoder (~67.8M parameters, configured with merge level 3) required approximately 50 hours on 8 NVIDIA A100 GPUs.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Source code</td>
          <td>Code</td>
          <td>N/A</td>
          <td>Not publicly released as of March 2026</td>
      </tr>
      <tr>
          <td>Pre-trained weights</td>
          <td>Model</td>
          <td>N/A</td>
          <td>Not publicly released</td>
      </tr>
      <tr>
          <td>Pretraining data (19M molecules)</td>
          <td>Dataset</td>
          <td>Unknown</td>
          <td>Same dataset as Uni-Mol; not independently released</td>
      </tr>
      <tr>
          <td>Downstream benchmark splits</td>
          <td>Dataset</td>
          <td>N/A</td>
          <td>Custom scaffold splits pending code release</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>The model treats a molecule as a 3D &ldquo;image&rdquo; via voxelization, processed by a Transformer.</p>
<p><strong>Input Representation</strong>:</p>
<ul>
<li><strong>Discretization</strong>: 3D space divided into grid cells with length <strong>$0.49\text{\AA}$</strong> (based on O-H bond length to ensure at most one atom per cell)</li>
<li><strong>Tokenization</strong>: Tokens are pairs $(t_i, c_i)$ where $t_i$ is atom type (or NULL) and $c_i$ is the coordinate</li>
<li><strong>Embeddings</strong>: Continuous embeddings with dimension 512. Inner-cell positions discretized with $0.01\text{\AA}$ precision</li>
</ul>
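<p>The discretization step can be illustrated with a short NumPy sketch. It is not the authors&rsquo; implementation: anchoring the grid at the minimum corner of the molecule, the function name <code>voxelize</code>, and the approximate water geometry are all assumptions made for the example. Only the 0.49 Å cell length and 0.01 Å inner-cell precision come from the description above.</p>

```python
import numpy as np

CELL = 0.49        # grid cell length in angstroms (from the text)
PRECISION = 0.01   # inner-cell position resolution in angstroms (from the text)
BINS = round(CELL / PRECISION)  # 49 offset bins per axis


def voxelize(coords):
    """Map atom coordinates (N, 3) to integer cell indices and in-cell offset bins."""
    origin = coords.min(axis=0)              # assumption: grid anchored at the min corner
    shifted = coords - origin
    cell_idx = np.floor(shifted / CELL).astype(int)
    offset = np.clip(shifted - cell_idx * CELL, 0.0, None)   # guard against float noise
    offset_bin = np.clip(np.floor(offset / PRECISION).astype(int), 0, BINS - 1)
    return cell_idx, offset_bin


# Approximate water geometry in angstroms (illustrative values): O, H, H.
water = np.array([[0.000, 0.000, 0.0],
                  [0.757, 0.586, 0.0],
                  [-0.757, 0.586, 0.0]])
cell_idx, offset_bin = voxelize(water)
print(cell_idx)
print(offset_bin)
```

<p>In this example each of the three atoms lands in a different cell, which is the property the 0.49 Å cell length is chosen to guarantee.</p>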
<p><strong>Transformer Specifications</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Component</th>
          <th>Layers</th>
          <th>Attention Heads</th>
          <th>Embedding Dim</th>
          <th>FFN Dim</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Encoder</strong></td>
          <td>16</td>
          <td>8</td>
          <td>512</td>
          <td>2048</td>
      </tr>
      <tr>
          <td><strong>Decoder</strong> (MAE)</td>
          <td>4</td>
          <td>4</td>
          <td>256</td>
          <td>1024</td>
      </tr>
  </tbody>
</table>
<p><strong>Attention Mechanism</strong>: FlashAttention for efficient handling of large sequence lengths.</p>
<p><strong>Positional Encodings</strong>:</p>
<ol>
<li><strong>3D Directional PE</strong>: Extension of Rotary Positional Embedding (RoPE) to 3D continuous space, capturing relative directionality</li>
<li><strong>3D Distance PE</strong>: Random Fourier Features (RFF) to approximate Gaussian kernel of pairwise distances with linear complexity</li>
</ol>
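<p>The distance encoding rests on the standard Random Fourier Features construction, in which the dot product of per-point feature vectors approximates a Gaussian kernel of the distance between the points. The sketch below shows that generic construction only; the feature count, bandwidth <code>sigma</code>, and function name are assumptions, and how SpaceFormer injects the features into attention is not reproduced here.</p>

```python
import numpy as np


def rff_features(points, num_features=4096, sigma=1.0, seed=0):
    """Random Fourier Features whose dot products approximate exp(-d^2 / (2 sigma^2))."""
    rng = np.random.default_rng(seed)
    # Frequencies drawn from the Fourier transform of the Gaussian kernel.
    omega = rng.normal(scale=1.0 / sigma, size=(points.shape[1], num_features))
    phase = rng.uniform(0.0, 2.0 * np.pi, size=num_features)
    return np.sqrt(2.0 / num_features) * np.cos(points @ omega + phase)


points = np.array([[0.0, 0.0, 0.0],
                   [0.5, 0.0, 0.0],
                   [0.0, 2.0, 0.0]])
phi = rff_features(points)                     # one feature vector per point
approx = phi @ phi.T                           # kernel estimate from dot products
dist2 = ((points[:, None] - points[None]) ** 2).sum(-1)
exact = np.exp(-dist2 / 2.0)                   # Gaussian kernel with sigma = 1
print(np.abs(approx - exact).max())
```

<p>Because each point gets its own feature vector, the kernel never has to be evaluated explicitly for every pair, which is where the linear-complexity claim comes from.</p>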
<h4 id="visualizing-rff-and-rope">Visualizing RFF and RoPE</h4>
<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-rff-rope-visualization.webp"
         alt="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         title="Four-panel visualization showing RFF distance encoding and RoPE directional encoding mechanisms"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Visual intuition for SpaceFormer&rsquo;s positional encodings: Top row shows RFF distance encoding (Gaussian-like attention decay and high-frequency feature fingerprints). Bottom row shows RoPE directional encoding (vector rotation fields and resulting attention patterns).</figcaption>
    
</figure>

<p><strong>Top Row (Distance / RFF):</strong> Shows how the model learns &ldquo;closeness.&rdquo; Distance is represented by a complex &ldquo;fingerprint&rdquo; of waves that creates a Gaussian-like force field.</p>
<ul>
<li><strong>Top Left (The Force Field):</strong> The attention score (dot product) naturally forms a Gaussian curve. It is high when atoms are close and decays to zero as they move apart. This mimics physical forces without the model needing to learn that math from scratch.</li>
<li><strong>Top Right (The Fingerprint):</strong> Each dimension oscillates at a different frequency. A specific distance (e.g., $d=2$) has a unique combination of high and low values across these dimensions, creating a unique &ldquo;fingerprint&rdquo; for that exact distance.</li>
</ul>
<p><strong>Bottom Row (Direction / RoPE):</strong> Shows how the model learns &ldquo;relative position.&rdquo; It visualizes the vector rotation and how that creates a grid-like attention pattern.</p>
<ul>
<li><strong>Bottom Left (The Rotation):</strong> This visualizes the &ldquo;X-axis chunk&rdquo; of the vector. As you move from left ($x=-3$) to right ($x=3$), the arrows rotate. The model compares angles between atoms to determine relative positions.</li>
<li><strong>Bottom Right (The Grid):</strong> The resulting attention pattern when combining X-rotations and Y-rotations. The red/blue regions show where the model pays attention relative to the center, forming a grid-like interference pattern that distinguishes relative positions (e.g., &ldquo;top-right&rdquo; vs &ldquo;bottom-left&rdquo;).</li>
</ul>
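<p>The rotation idea can be made concrete with a toy 3D rotary encoding: the feature vector is split into an X, Y, and Z chunk, and each 2D pair inside a chunk is rotated by an angle proportional to the coordinate on that axis. This is a sketch under assumptions (the chunk layout, the frequency schedule, and the name <code>rope_3d</code> are illustrative and may differ from the paper). It demonstrates the key property that the attention score depends only on relative position.</p>

```python
import numpy as np


def rope_3d(vec, pos, base=10000.0):
    """Rotate a 1D feature vector by angles proportional to a continuous 3D position."""
    dim = vec.shape[-1]
    assert dim % 6 == 0, "need three axis chunks made of 2D pairs"
    chunk = dim // 3
    pairs = chunk // 2
    freqs = base ** (-np.arange(pairs) / pairs)   # assumed frequency schedule
    out = np.empty(dim, dtype=float)
    for axis in range(3):
        part = vec[axis * chunk:(axis + 1) * chunk].reshape(-1, 2)
        angle = pos[axis] * freqs
        cos, sin = np.cos(angle), np.sin(angle)
        rotated = np.stack([part[:, 0] * cos - part[:, 1] * sin,
                            part[:, 0] * sin + part[:, 1] * cos], axis=-1)
        out[axis * chunk:(axis + 1) * chunk] = rotated.reshape(-1)
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(12), rng.standard_normal(12)
a = np.array([0.3, -1.2, 2.0])
b = np.array([1.1, 0.4, -0.7])
shift = np.array([5.0, -3.0, 0.25])

score = rope_3d(q, a) @ rope_3d(k, b)
shifted_score = rope_3d(q, a + shift) @ rope_3d(k, b + shift)
print(score, shifted_score)
```

<p>Translating both positions by the same vector leaves the score unchanged, so the encoding carries relative rather than absolute position.</p>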
<h4 id="adaptive-grid-merging">Adaptive Grid Merging</h4>
<p>To make the 3D grid approach computationally tractable, two key strategies are employed:</p>
<ol>
<li><strong>Grid Sampling</strong>: Randomly selecting 10-20% of empty cells during training</li>
<li><strong>Adaptive Grid Merging</strong>: Recursively merging $2 \times 2 \times 2$ blocks of empty cells into larger &ldquo;coarse&rdquo; cells, creating a multi-resolution view that is fine-grained near atoms and coarse-grained in empty space (merging set to Level 3)</li>
</ol>
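<p>The merging step can be sketched as an octree-style pass over a boolean occupancy grid. Splitting non-empty blocks top-down yields the same cells as recursively merging aligned $2 \times 2 \times 2$ empty blocks bottom-up. This is an illustrative sketch, not the authors&rsquo; code: the function name <code>merge_cells</code>, the cubic grid, and the atom positions are assumptions, and grid sampling is not included.</p>

```python
import numpy as np


def merge_cells(occupied, max_level=3):
    """Cover a cubic boolean occupancy grid with (origin, size) cells.

    Empty aligned blocks are kept as one coarse cell of side up to 2**max_level;
    any block containing an atom is split until its cells have side 1.
    """
    n = occupied.shape[0]
    top = 2 ** max_level
    assert occupied.shape == (n, n, n) and n % top == 0
    cells = []

    def visit(x, y, z, size):
        block = occupied[x:x + size, y:y + size, z:z + size]
        if size == 1 or not block.any():
            cells.append(((x, y, z), size))
            return
        half = size // 2
        for dx in (0, half):
            for dy in (0, half):
                for dz in (0, half):
                    visit(x + dx, y + dy, z + dz, half)

    for x in range(0, n, top):
        for y in range(0, n, top):
            for z in range(0, n, top):
                visit(x, y, z, top)
    return cells


occupied = np.zeros((16, 16, 16), dtype=bool)
for atom_cell in [(7, 7, 7), (8, 8, 7), (6, 8, 7)]:   # hypothetical atom positions
    occupied[atom_cell] = True

cells = merge_cells(occupied, max_level=3)
print(len(cells), "cells instead of", occupied.size)
```

<p>The merged cells tile exactly the same volume as the dense grid, with size-1 cells only where atoms sit or where a neighbouring atom prevents merging.</p>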
<p><strong>Visualizing Adaptive Grid Merging</strong>:</p>
<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-merging.webp"
         alt="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         title="2D simulation of adaptive grid merging for an H2O molecule showing multi-resolution cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging demonstrated on H₂O. Red cells (Level 0) contain atoms and remain at full resolution. Progressively darker blue cells represent merged empty regions at higher levels, covering the same volume with fewer tokens.</figcaption>
    
</figure>

<p>The adaptive grid process compresses empty space around molecules while maintaining high resolution near atoms:</p>
<ul>
<li><strong>Red Cells (Level 0):</strong> The smallest squares ($0.49\text{\AA}$) containing atoms. These are kept at highest resolution because electron density changes rapidly here.</li>
<li><strong>Light Blue Cells (Level 0/1):</strong> Small empty regions close to atoms.</li>
<li><strong>Darker Blue Cells (Level 2/3):</strong> Large blocks of empty space further away.</li>
</ul>
<p>If we used a naive uniform grid, we would have to process thousands of empty &ldquo;Level 0&rdquo; cells containing almost zero information. By merging them into larger blocks (the dark blue squares), the model covers the same volume with significantly fewer input tokens, reducing the number of tokens by roughly <strong>10x</strong> compared to a dense grid.</p>
<figure class="post-figure center ">
    <img src="/img/notes/spaceformer-adaptive-grid-benzene.webp"
         alt="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         title="Adaptive grid merging visualization for benzene molecule showing hexagonal ring with multi-resolution grid cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Adaptive grid merging for benzene (C₆H₆). The model maintains maximum resolution (red Level 0 cells) only where atoms exist, while merging vast empty regions into large blocks (dark blue L3/L4 cells). This allows the model to focus computational power on chemically active zones.</figcaption>
    
</figure>

<p>The benzene example above demonstrates how this scales to larger molecules. The characteristic hexagonal ring of 6 carbon atoms (black) and 6 hydrogen atoms (white) occupies a small fraction of the total grid. The dark blue corners (L3, L4) represent massive merged blocks of empty space, so far fewer tokens are spent on empty regions and most of the model&rsquo;s computation goes to the red &ldquo;active&rdquo; zones where chemistry actually happens.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lu, S., Ji, X., Zhang, B., Yao, L., Liu, S., Gao, Z., Zhang, L., &amp; Ke, G. (2025). Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling. <em>Proceedings of the 42nd International Conference on Machine Learning (ICML)</em>, 267, 40491-40504. <a href="https://proceedings.mlr.press/v267/lu25e.html">https://proceedings.mlr.press/v267/lu25e.html</a></p>
<p><strong>Publication</strong>: ICML 2025</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lu2025beyond,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Beyond Atoms: Enhancing Molecular Pretrained Representations with 3D Space Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Lu, Shuqi and Ji, Xiaohong and Zhang, Bohang and Yao, Lin and Liu, Siyuan and Gao, Zhifeng and Zhang, Linfeng and Ke, Guolin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Proceedings of the 42nd International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{40491--40504}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{267}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span>=<span style="color:#e6db74">{Proceedings of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span>=<span style="color:#e6db74">{PMLR}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=Wd9KPQCKwq">OpenReview forum</a></li>
<li><a href="https://openreview.net/pdf?id=Wd9KPQCKwq">PDF on OpenReview</a></li>
<li><a href="https://icml.cc/virtual/2025/poster/45004">ICML 2025 poster page</a></li>
</ul>
]]></content:encoded></item><item><title>Contrastive Learning for Variational Autoencoder Priors</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</link><pubDate>Sun, 17 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</guid><description>Aneja et al.'s NeurIPS 2021 paper introducing Noise Contrastive Priors (NCPs) to address VAE's 'prior hole' problem with energy-based priors.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a training approach for Variational Autoencoders (VAEs) to address fundamental limitations in their generative quality through improved prior learning.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by a critical limitation in Variational Autoencoders known as the <strong>&ldquo;prior hole&rdquo; problem</strong>, where the prior distribution $p(z)$ fails to match the aggregate approximate posterior $q(z)$. This mismatch leaves regions of the latent space that have high density under the prior but do not map to realistic data samples, resulting in poorer generative quality than GANs and other generative models.</p>
<figure class="post-figure center ">
    <img src="/img/notes/vae-prior-hole-problem-illustrated.webp"
         alt="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         title="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;prior hole&rsquo; problem: the standard Gaussian prior (red dashed contours) assigns highest probability to the center, but the aggregate posterior (blue dots) forms a ring with no data in that region.</figcaption>
    
</figure>

<p>The figure above illustrates this mismatch. The blue dots represent where a trained encoder actually places data in the latent space (the aggregate posterior $q(z)$), which often forms complex, non-Gaussian shapes. The red dashed contours show the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$, which assumes data is centered at the origin. When generating new samples, we draw from this prior, making it likely to sample from the empty &ldquo;hole&rdquo; where the decoder has never seen training data, producing unrealistic outputs.</p>
<p>A natural question arises: the prior $p(z)$ is used for <em>sampling</em> at inference time, so why does learning a better prior also improve <em>likelihood</em> (NLL)? The answer lies in the VAE objective. VAEs maximize the Evidence Lower Bound (ELBO):</p>
<p>$$ \log p(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(z|x) \parallel p(z))}_{\text{Regularization}} $$</p>
<p>The KL divergence term penalizes the mismatch between each data point&rsquo;s approximate posterior $q(z|x)$ and the prior $p(z)$. When the prior is a simple Gaussian but the aggregate posterior forms a complex shape (as in the figure above), this KL term remains unnecessarily high for every data point.</p>
<p>Replacing the simple prior with a learned $p_{\text{NCP}}(z)$ that matches the aggregate posterior lowers the KL penalty, which tightens the ELBO and improves NLL. The learned prior thus provides a <strong>unified solution</strong>: better likelihood during training (tighter bound) and better sampling at inference (no &ldquo;holes&rdquo;).</p>
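<p>One way to make this link explicit is a standard decomposition of the averaged KL term. It is a well-known identity, not a result taken from the paper. The symbols $p_{\text{data}}(x)$ (the data distribution) and $I_q(x; z)$ (the mutual information between $x$ and $z$ under the encoder) are introduced here for illustration, with $q(z) = \mathbb{E}_{p_{\text{data}}(x)}[q(z|x)]$:</p>
<p>$$ \mathbb{E}_{p_{\text{data}}(x)}\left[\text{KL}(q(z|x) \parallel p(z))\right] = I_q(x; z) + \text{KL}(q(z) \parallel p(z)) $$</p>
<p>Only the second term depends on the prior. With the encoder held fixed, moving $p(z)$ toward the aggregate posterior $q(z)$ is therefore the only way to reduce the average KL penalty, and the same change removes the regions where the prior places mass that the encoder never uses.</p>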
<p>The OpenReview discussion contains a significant theoretical debate regarding the paper&rsquo;s core premise. Reviewers argued that the &ldquo;prior hole&rdquo; problem is actually a failure of the posterior to match the prior, or a failure of the encoder. The authors defended their approach by noting that even with a perfect posterior, a simple Normal prior might fail because the decoder lacks capacity to map a simple distribution to complex data without dropping modes. This justifies fixing the prior by making it learned and complex.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The authors propose an <strong>energy-based model (EBM) prior</strong> that is trained using <strong>Noise Contrastive Estimation (NCE)</strong>, which they term a <strong>Noise Contrastive Prior (NCP)</strong>. The key innovations are:</p>
<ul>
<li><strong>Two-Stage Training Process</strong>: First, a standard VAE is trained with a simple base prior. Then, the VAE weights are frozen and a binary classifier learns to distinguish between samples from the aggregate posterior $q(z)$ and the base prior $p(z)$.</li>
<li><strong>Reweighting Strategy</strong>: The core idea is to reweight a base prior distribution $p(z)$ with a learned reweighting factor $r(z)$ to make the resulting prior $p_{\text{NCP}}(z)$ better match the aggregate posterior $q(z)$.</li>
<li><strong>NCE for EBM Training</strong>: The method frames EBM training as a binary classification task to avoid computationally expensive MCMC sampling.</li>
<li><strong>Scalability to Hierarchical Models</strong>: For hierarchical VAEs with multiple latent groups, the NCP approach can be applied independently and in parallel to each group&rsquo;s conditional prior.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling</li>
<li><strong>CelebA 64x64</strong>: Applied to both standard VAE architectures and more advanced VAEs with GMM priors (RAE model)</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>The hierarchical NVAE models used 30 latent groups for CIFAR-10 and CelebA-64, 20 groups for CelebA-HQ-256, and 10 groups of $4 \times 4$ latent variables for MNIST (deliberately small to enable reliable partition function estimation). The experiments compared FID scores, likelihood metrics, and qualitative sample quality between baseline VAEs and NCP-enhanced versions, with particular focus on hierarchical VAEs (NVAE).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The proposed NCP method demonstrated improvements in generative quality across evaluated datasets, with modest gains on standard VAEs and particularly large gains on hierarchical models like NVAE:</p>
<ul>
<li><strong>CelebA-64</strong>: NCP improved FID scores from 48.12 to 41.28 for standard VAEs, and from 40.95 to 39.00 for RAE models with GMM priors.</li>
<li><strong>Hierarchical Models (NVAE)</strong>: The impact was particularly pronounced on hierarchical VAEs:
<ul>
<li><strong>CIFAR-10</strong>: FID improved from 51.71 to 24.08</li>
<li><strong>CelebA-64</strong>: FID improved from 13.48 to 5.25, making it competitive with GANs</li>
<li><strong>CelebA HQ 256x256</strong>: FID reduced from 40.26 to 24.79</li>
</ul>
</li>
<li><strong>Likelihood Performance</strong>: On MNIST, NCP-VAE achieved 78.10 nats NLL vs. baseline NVAE&rsquo;s 78.67 nats</li>
</ul>
<p>On CIFAR-10 and CelebA-HQ-256, the concurrent VAEBM method (which forms an EBM on the data space rather than the latent space) outperforms NCP-VAE. However, the authors argue the two approaches are complementary: NCP-VAE targets the latent space while VAEBM operates in data space, and combining them could yield further gains. NCP-VAE also has the advantage of applicability to discrete data (e.g., binarized MNIST) and simpler setup since it only requires training binary classifiers rather than MCMC-based training and sampling.</p>
<p>The key conclusions are that <strong>two-stage training with noise contrastive estimation</strong> provides an effective framework for learning expressive energy-based priors that addresses the prior hole problem while scaling efficiently to hierarchical models.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code (Google Drive)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation; hosted on Google Drive (may become inaccessible)</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">Reviews, author responses, and supplementary material</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<h4 id="the-reweighting-mechanism">The Reweighting Mechanism</h4>
<p>The core innovation is defining the NCP prior as $p_{\text{NCP}}(z) \propto p(z)r(z)$. The reweighting factor $r(z)$ is derived from the binary classifier $D(z)$ using the <strong>likelihood ratio trick</strong>:</p>
<p>$$ r(z) \approx \frac{D(z)}{1 - D(z)} $$</p>
<p>Here, $D(z)$ is the sigmoid output of the trained discriminator, representing the probability that sample $z$ came from the aggregate posterior $q(z)$ (&ldquo;real&rdquo;). For an optimal discriminator $D^*(z)$, this ratio exactly equals $\frac{q(z)}{p(z)}$, allowing the model to approximate the density ratio without explicit density estimation.</p>
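<p>The likelihood ratio trick is easy to check numerically on a toy 1D case where both densities are known. In the sketch below, the two Gaussians stand in for the base prior and the aggregate posterior (an assumption for illustration only), and the optimal classifier is written in closed form instead of being trained. Since $D$ is a sigmoid, $D/(1-D)$ simplifies to the exponential of the classifier logit.</p>

```python
import numpy as np


def reweighting_factor(logits):
    """r(z) = D(z) / (1 - D(z)) with D = sigmoid(logits), which equals exp(logits)."""
    return np.exp(logits)


def normal_pdf(z, mean, std):
    return np.exp(-0.5 * ((z - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


z = np.linspace(-3.0, 3.0, 7)
p = normal_pdf(z, 0.0, 1.0)      # base prior p(z)
q = normal_pdf(z, 1.0, 0.5)      # toy stand-in for the aggregate posterior q(z)

d_opt = q / (q + p)              # optimal classifier for balanced classes
logits = np.log(d_opt) - np.log1p(-d_opt)
r = reweighting_factor(logits)

print(np.allclose(r, q / p))
```

<p>Working with logits instead of forming $D/(1-D)$ directly avoids dividing by values close to zero when the classifier is confident.</p>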
<figure class="post-figure center ">
    <img src="/img/notes/ncp-vae-reweighting-the-prior-posterior.webp"
         alt="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         title="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reweighting mechanism: the learned factor $r(z)$ (bottom) reweights the simple Gaussian prior $p(z)$ (middle) to approximate the complex aggregate posterior $q(z)$ (top). Where $q(z)$ has high density but $p(z)$ is low, $r(z)$ compensates with high values.</figcaption>
    
</figure>

<h4 id="hierarchical-architecture-strategy">Hierarchical Architecture Strategy</h4>
<p>For hierarchical models (like NVAE), the method trains $K$ binary classifiers in parallel (one for each latent group). Crucially, to ensure efficiency, the classifiers reuse the <strong>context feature</strong> $c(z_{&lt;k})$ extracted by the frozen VAE&rsquo;s prior network. This architectural choice provides significant computational savings.</p>
<h4 id="test-time-sampling-inference">Test-Time Sampling (Inference)</h4>
<p>Since $p_{\text{NCP}}(z)$ is an energy-based model with an unknown normalizing constant, it cannot be sampled from directly. The paper employs two methods to generate samples:</p>
<ul>
<li><strong>Sampling-Importance-Resampling (SIR):</strong> Used for most results. It draws $M$ samples (e.g., $M=5000$) from the base prior $p(z)$ and resamples them based on weights $w^{(m)} = r(z^{(m)})$.</li>
<li><strong>Langevin Dynamics (LD):</strong> An iterative refinement method using the gradient of the energy function $E(z) = -\log r(z) - \log p(z)$.</li>
</ul>
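<p>Sampling-Importance-Resampling can be sketched in a few lines for a single latent group. This is a toy illustration, not the paper&rsquo;s hierarchical sampler: the standard normal base prior, the closed-form <code>toy_log_r</code> (which replaces a trained classifier), and the function names are assumptions.</p>

```python
import numpy as np


def sir_sample(log_r, num_proposals=5000, dim=2, seed=0):
    """Draw one sample from p(z) * r(z) via Sampling-Importance-Resampling."""
    rng = np.random.default_rng(seed)
    proposals = rng.standard_normal((num_proposals, dim))   # z^(m) ~ p(z) = N(0, I)
    log_w = log_r(proposals)                                # log r(z^(m))
    weights = np.exp(log_w - log_w.max())                   # stabilised, unnormalised
    weights /= weights.sum()
    index = rng.choice(num_proposals, p=weights)            # resample by weight
    return proposals[index]


def toy_log_r(z):
    # log q(z) - log p(z) up to a constant, with q = N(1, 0.5^2 I) and p = N(0, I).
    return -0.5 * np.sum((z - 1.0) ** 2, axis=1) / 0.25 + 0.5 * np.sum(z ** 2, axis=1)


samples = np.array([sir_sample(toy_log_r, seed=s) for s in range(200)])
print(samples.mean(axis=0))
```

<p>The resampled points concentrate around the toy target mean of 1 in each dimension even though every proposal was drawn from the zero-centred base prior. Additive constants in <code>log_r</code> cancel when the weights are normalised, so the unknown normalizing constant is never needed.</p>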
<h3 id="models">Models</h3>
<h4 id="decoder-architecture">Decoder Architecture</h4>
<p>For RGB datasets (CIFAR-10, CelebA), the output likelihood must be changed from <strong>Discretized Logistic</strong> (standard NVAE) to a <strong>Normal distribution</strong>. The authors note this change alone led to &ldquo;significant improvements in the base model performance.&rdquo; Using the standard NVAE decoder will result in a weaker baseline than reported.</p>
<h4 id="discriminator-architecture">Discriminator Architecture</h4>
<p>The binary classifier uses a ResNet-style architecture with <strong>Squeeze-and-Excitation (SE)</strong> blocks:</p>
<ul>
<li><strong>Activation:</strong> Swish</li>
<li><strong>Normalization:</strong> Batch Normalization</li>
<li><strong>Optimization:</strong> Adam with Cosine Annealing (learning rate: $10^{-3} \to 10^{-7}$)</li>
</ul>
<p>The SE blocks help the model focus on channel-wise feature recalibration, which is important for distinguishing subtle differences between prior and aggregate posterior in high-dimensional latent spaces.</p>
<h3 id="hardware">Hardware</h3>
<p>The main paper is vague on training time, but the OpenReview rebuttal explicitly lists hardware costs:</p>
<ul>
<li><strong>Hardware:</strong> NVIDIA Tesla V100 (32GB) GPUs</li>
<li><strong>Per-Discriminator Training:</strong> ~13 hours for 100 epochs</li>
<li><strong>Parallelization:</strong> Because latent groups are independent, all discriminators can train in parallel</li>
<li><strong>Total Cost (CelebA-64):</strong> ~8.1 GPU-days</li>
<li><strong>Comparison:</strong> The authors argue this is efficient compared to VDVAE, which requires ~560 GPU-days</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<h4 id="inference-speed-vs-quality-trade-off">Inference Speed vs. Quality Trade-off</h4>
<p>Reviewers flagged that SIR sampling can be prohibitively slow. The authors clarified the exact trade-off:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Proposal Samples ($M$)</th>
          <th style="text-align: left">Time per Image</th>
          <th style="text-align: left">FID (CelebA-64)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">5,000 (paper default)</td>
          <td style="text-align: left">~10.11 seconds</td>
          <td style="text-align: left">5.25</td>
      </tr>
      <tr>
          <td style="text-align: left">500 (practical)</td>
          <td style="text-align: left">~1.25 seconds</td>
          <td style="text-align: left">6.76</td>
      </tr>
  </tbody>
</table>
<p>The quality gain from 500 to 5,000 samples is modest (FID difference of 1.51) while inference time increases roughly 8x, suggesting $M=500$ may be a practical default.</p>
<h4 id="hyperparameters">Hyperparameters</h4>
<ul>
<li><strong>FID Calculation:</strong> 50,000 samples</li>
<li><strong>SIR Proposals:</strong> 5,000 samples (paper default) or 500 (practical)</li>
<li><strong>MNIST:</strong> Dynamically binarized version used for likelihood evaluation</li>
<li><strong>Optimizers:</strong> The study largely adopts hyperparameters from baseline papers (e.g., Lawson et al. for MNIST, Ghosh et al. for RAE)</li>
</ul>
<h4 id="debugging-benchmark-25-gaussians">Debugging Benchmark: 25-Gaussians</h4>
<p>The supplement provides a toy experiment ideal for verifying a new implementation before running on expensive image datasets:</p>
<ul>
<li><strong>Setup:</strong> Synthetic dataset of 25 2D-Gaussians arranged on a grid</li>
<li><strong>Target Log-Likelihood:</strong> ~-0.954 nats (NCP) vs. ~-2.753 nats (Vanilla VAE); higher is better</li>
<li><strong>Success Criterion:</strong> Samples should avoid low-density regions between grid points. A standard VAE will generate samples in these &ldquo;prior holes,&rdquo; while a working NCP implementation should cleanly remove these artifacts.</li>
</ul>
<h4 id="implementation-warnings">Implementation Warnings</h4>
<ul>
<li><strong>SIR Failure Mode:</strong> If the learned prior $p_{\text{NCP}}$ deviates too far from the base prior, SIR sampling collapses (low effective sample size). The paper shows a strong correlation between the NCE classification loss and the effective sample size (Fig. 5(b)), indicating that SIR reliability depends on how well the base prior matches the aggregate posterior.</li>
<li><strong>Temperature Scaling:</strong> The qualitative images in the paper use reduced temperature for improved visual sharpness (Section 5.3). The FID tables do not specify a temperature, so results may or may not use $T=1.0$.</li>
</ul>
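<p>The SIR failure mode can be monitored with the effective sample size of the importance weights. The sketch below uses the common definition $(\sum_m w^{(m)})^2 / \sum_m (w^{(m)})^2$, which is assumed here to match the paper&rsquo;s usage. The function name is illustrative.</p>

```python
import numpy as np


def effective_sample_size(log_weights):
    """ESS = (sum w)^2 / sum(w^2), computed from log-weights for stability."""
    w = np.exp(log_weights - np.max(log_weights))
    return w.sum() ** 2 / (w ** 2).sum()


uniform = np.zeros(5000)                       # r(z) constant: prior already matches
collapsed = np.array([0.0, -1000.0, -1000.0])  # one proposal dominates

print(effective_sample_size(uniform))
print(effective_sample_size(collapsed))
```

<p>An ESS close to the number of proposals means the resampling step has many useful candidates to choose from. An ESS near 1 means a single proposal carries almost all the weight, so the resampled outputs lose diversity.</p>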
<h3 id="data">Data</h3>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling (32x32 RGB images)</li>
<li><strong>CelebA 64x64</strong>: Face generation task with moderate resolution</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>All datasets use standard train/test splits from the computer vision literature.</p>
<h4 id="additional-metrics">Additional Metrics</h4>
<p>Beyond FID and NLL, the paper uses:</p>
<ul>
<li><strong>Effective Sample Size (ESS):</strong> Validates reliability of the SIR sampling procedure</li>
<li><strong>Maximum Mean Discrepancy (MMD):</strong> Measures distance between aggregate posterior and NCP prior distributions</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Aneja, J., Schwing, A. G., Kautz, J., &amp; Vahdat, A. (2021). A contrastive learning approach for training variational autoencoder priors. <em>Advances in Neural Information Processing Systems</em>, 34, 29604-29616. <a href="https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html">https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{aneja2021contrastive,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Contrastive Learning Approach for Training Variational Autoencoder Priors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Aneja, Jyoti and Schwing, Alexander G and Kautz, Jan and Vahdat, Arash}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{29604--29616}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview Discussion</a></li>
<li><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code Repository</a> (Google Drive; link may become inaccessible over time)</li>
</ul>
]]></content:encoded></item><item><title>SubGrapher: Visual Fingerprinting of Chemical Structures</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/subgrapher/</link><pubDate>Mon, 28 Apr 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/vision-language/subgrapher/</guid><description>SubGrapher creates molecular fingerprints directly from chemical structure images through functional group segmentation for database retrieval.</description><content:encoded><![CDATA[<h2 id="paper-classification-and-taxonomy">Paper Classification and Taxonomy</h2>
<p>This is primarily a <strong>Methodological Paper ($\Psi_{\text{Method}}$)</strong> with a secondary <strong>Resource ($\Psi_{\text{Resource}}$)</strong> contribution. Using the <a href="/notes/interdisciplinary/research-methods/ai-physical-sciences-paper-taxonomy/">AI and Physical Sciences paper taxonomy</a> framework:</p>
<p><strong>Primary Classification: Method</strong></p>
<p>The dominant basis vector is Methodological because SubGrapher introduces an architecture that replaces the two-step OCSR workflow (image, then structure, then fingerprint) with single-step fingerprinting (image to visual fingerprint). The paper validates this approach through systematic comparison against state-of-the-art methods (MolGrapher, OSRA, DECIMER, MolScribe), demonstrating superior performance on specific tasks like retrieval and robustness to image quality degradation.</p>
<p><strong>Secondary Classification: Resource</strong></p>
<p>The paper makes non-negligible resource contributions by releasing:</p>
<ul>
<li>Code and model weights on <a href="https://github.com/DS4SD/SubGrapher">GitHub</a> and <a href="https://huggingface.co/docling-project/SubGrapher">HuggingFace</a></li>
<li>Five new visual fingerprinting benchmark datasets for molecule retrieval tasks</li>
<li>Comprehensive functional group knowledge base (1,534 substructures)</li>
</ul>
<h2 id="motivation-extracting-complex-structures-from-noisy-images">Motivation: Extracting Complex Structures from Noisy Images</h2>
<p>SubGrapher tackles a fundamental challenge in chemical informatics: extracting molecular information from the vast body of unstructured scientific literature, particularly patents. Millions of molecular structures exist only as images in these documents, making them inaccessible for computational analysis, database searches, or machine learning applications.</p>
<p>Traditional Optical Chemical Structure Recognition (OCSR) tools attempt to fully reconstruct molecular graphs from images, converting them into machine-readable formats like SMILES. However, these approaches face two critical limitations:</p>
<ol>
<li><strong>Brittleness to image quality</strong>: Poor resolution, noise, or unconventional drawing styles frequently degrade recognition accuracy</li>
<li><strong>Limited handling of complex structures</strong>: Markush structures, generic molecular templates with variable R-groups commonly used in patents, are poorly supported by most conventional OCSR methods</li>
</ol>
<p>The key insight driving SubGrapher is that full molecular reconstruction may be unnecessary for many applications. For tasks like database searching, similarity analysis, or document retrieval, a molecular fingerprint (a vectorized representation capturing structural features) is often sufficient. This realization opens up a new approach: bypass the fragile reconstruction step and create fingerprints directly from visual information.</p>
<h2 id="key-innovation-direct-visual-fingerprinting">Key Innovation: Direct Visual Fingerprinting</h2>
<p>SubGrapher takes a different approach to extracting chemical information from images. It creates &ldquo;visual fingerprints&rdquo; through functional group recognition. The key innovations are:</p>
<ol>
<li>
<p><strong>Direct Image-to-Fingerprint Pipeline</strong>: SubGrapher eliminates the traditional two-step process (image → structure → fingerprint) by generating fingerprints directly from pixels. This single-stage approach avoids error accumulation from failed structure reconstructions and can handle images where conventional OCSR tools produce invalid outputs.</p>
</li>
<li>
<p><strong>Dual Instance Segmentation Architecture</strong>: The system employs two specialized Mask-RCNN networks working in parallel:</p>
<ul>
<li><strong>Functional group detector</strong>: Trained to identify 1,534 expert-defined functional groups using pixel-level segmentation masks</li>
<li><strong>Carbon backbone detector</strong>: Recognizes 27 common carbon chain patterns to capture the molecular scaffold</li>
</ul>
<p>Using instance segmentation provides detailed spatial information and higher accuracy through richer supervision during training.</p>
</li>
<li>
<p><strong>Extensive Functional Group Knowledge Base</strong>: The method uses one of the most comprehensive open-source collections of functional groups, encompassing 1,534 substructures. These were systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Additional halogen substituents and organometallic groups relevant to EUV photoresists</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li>
<p><strong>Substructure-Graph Construction</strong>: After detecting functional groups and carbon backbones, SubGrapher builds a connectivity graph where:</p>
<ul>
<li>Each node represents an identified substructure</li>
<li>Edges connect substructures whose bounding boxes overlap (with 10% margin expansion)</li>
<li>This graph captures both the chemical components and their spatial relationships</li>
</ul>
</li>
<li>
<p><strong>Substructure-based Visual Molecular Fingerprint (SVMF)</strong>: The final output is a continuous, count-based fingerprint formally defined as a matrix $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$ (1,534 functional groups + 27 carbon backbones). The matrix is stored in compressed upper triangular form:</p>
<p><strong>Diagonal elements</strong> ($i = j$): Weighted count of substructure $i$ plus self-intersection
$$SVMF_{ii}(m) = h_1 \cdot n_i + g_{ii}$$
where $h_1 = 10$ is the diagonal weight hyperparameter, $n_i$ is the instance count, and $g_{ii}$ is the self-intersection coefficient.</p>
<p><strong>Off-diagonal elements</strong> ($i \neq j$): Intersection coefficient based on shortest path distance $d$ in the substructure graph
$$SVMF_{ij}(m) = h_2(d) \cdot \text{intersection}(s_i, s_j)$$
where the distance decay function $h_2(d)$ is:</p>
<ul>
<li>$d \leq 1$: weight = 2</li>
<li>$d = 2$: weight = 2/4 = 0.5</li>
<li>$d = 3$: weight = 2/16 = 0.125</li>
<li>$d = 4$: weight = $2/256 \approx 0.0078$</li>
<li>$d &gt; 4$: weight = 0</li>
</ul>
<p><strong>Key properties</strong>:</p>
<ul>
<li>Carbon chain intersection coefficients are divided by 2, giving functional groups higher effective weight</li>
<li>Similarity between fingerprints calculated using a normalized Euclidean distance (ratio of L2 norm of difference to L2 norm of sum)</li>
<li>Resulting fingerprints are highly sparse (average 0.001% non-zero elements)</li>
<li>Compressed storage enables efficient database searches</li>
</ul>
</li>
<li>
<p><strong>Markush Structure Compatibility</strong>: SubGrapher processes Markush structures by recognizing their constituent functional groups and creating meaningful fingerprints for similarity searches, achieving higher accuracy than existing OCSR methods on the USPTO-Markush benchmark (S-F1: 88).</p>
</li>
</ol>
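<p>The SVMF entry formulas above can be sketched as follows. This is a minimal illustration, not the released SubGrapher code: the function names are hypothetical, and the intersection coefficients are taken as given inputs because their exact computation is not detailed in these notes.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"># Distance-decay weights h_2(d) as listed above; larger distances get weight 0.
H2 = {0: 2.0, 1: 2.0, 2: 2.0 / 4, 3: 2.0 / 16, 4: 2.0 / 256}
H1 = 10.0  # diagonal weight hyperparameter h_1

def h2(d):
    """Decay weight for shortest-path distance d in the substructure graph."""
    return H2.get(d, 0.0)

def diagonal_entry(count, self_intersection):
    """SVMF_ii = h_1 * n_i + g_ii."""
    return H1 * count + self_intersection

def off_diagonal_entry(d, intersection):
    """SVMF_ij = h_2(d) * intersection(s_i, s_j)."""
    return h2(d) * intersection

def sparse_key(i, j):
    """Index pair for upper-triangular storage (assumed layout: row index first)."""
    return (min(i, j), max(i, j))
</code></pre></div>
<p>Storing only the non-zero entries under keys such as <code>sparse_key(i, j)</code> is one simple way to exploit the symmetry and sparsity described above; the actual storage format used by the authors may differ.</p>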
<h2 id="experimental-validation-and-benchmarks">Experimental Validation and Benchmarks</h2>
<p>The evaluation focused on demonstrating SubGrapher&rsquo;s effectiveness across two critical tasks: accurate substructure detection and robust molecule retrieval from diverse image collections.</p>
<h4 id="substructure-detection-performance">Substructure Detection Performance</h4>
<p>SubGrapher&rsquo;s ability to identify functional groups was tested on three challenging benchmarks that expose different failure modes of OCSR systems:</p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Size</th>
          <th>Description</th>
          <th>Key Challenge</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>JPO</strong></td>
          <td>341 images</td>
          <td>Japanese Patent Office images (molecules with abbreviations removed)</td>
          <td>Low quality, noise, artifacts, non-standard drawing styles</td>
      </tr>
      <tr>
          <td><strong>USPTO-10K-L</strong></td>
          <td>1,000 images</td>
          <td>Large molecules (&gt;70 atoms)</td>
          <td>Scale variation, structural complexity, many functional groups</td>
      </tr>
      <tr>
          <td><strong>USPTO-Markush</strong></td>
          <td>74 images</td>
          <td>Generic Markush structures</td>
          <td>Variable R-groups, abstract patterns, template representation</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings:</strong></p>
<ol>
<li>
<p><strong>JPO Dataset (Low-Quality Patent Images)</strong>: SubGrapher achieved the highest Molecule Exact Match rate (83%), demonstrating robustness to image quality degradation where rule-based methods like OSRA scored lower (67% M-EM).</p>
</li>
<li>
<p><strong>USPTO-10K-L (Large Molecules)</strong>: SubGrapher achieved an S-F1 of 97, matching the rule-based OSRA and outperforming all other learning-based methods (MolScribe: 90, DECIMER: 86, MolGrapher: 56). The instance segmentation approach handled scale variation better than other deep-learning OCSR tools on these large molecules.</p>
</li>
<li>
<p><strong>USPTO-Markush (Generic Structures)</strong>: SubGrapher achieved the highest Substructure F1-score (88) on this benchmark, outperforming MolScribe (86), OSRA (74), and DECIMER (10). While other OCSR tools can attempt these images, they have limited support for Markush features. SubGrapher&rsquo;s instance segmentation approach handles complex Markush structures more effectively by focusing on relevant image regions.</p>
</li>
</ol>
<p>Qualitative analysis revealed that SubGrapher correctly identified functional groups in scenarios where other methods failed completely: images with captions, unconventional drawing styles, or significant quality degradation.</p>
<h4 id="visual-fingerprinting-for-molecule-retrieval">Visual Fingerprinting for Molecule Retrieval</h4>
<p>The core application was evaluated using a retrieval task designed to simulate real-world database searching:</p>
<ol>
<li>
<p><strong>Benchmark Creation</strong>: Five benchmark datasets were constructed around structurally similar molecules (adenosine, camphor, cholesterol, limonene, and pyridine), each containing 500 molecules sampled from PubChem with at least 90% Tanimoto similarity to the reference molecule, rendered as augmented images.</p>
</li>
<li>
<p><strong>Retrieval Task</strong>: Given a SMILES string as a query, the goal was to find the corresponding molecular image within the dataset of 500 visually similar structures. This tests whether the visual fingerprint can distinguish between closely related molecules.</p>
</li>
<li>
<p><strong>Performance Comparison</strong>: SubGrapher significantly outperformed baseline methods, retrieving the correct molecule at an average rank of 95 out of 500. The key advantage was robustness: SubGrapher generates a unique fingerprint for every image, even with partial or uncertain predictions. In contrast, OCSR-based methods frequently fail to produce valid SMILES, making them unable to generate fingerprints for comparison.</p>
</li>
<li>
<p><strong>Real-World Case Study</strong>: A practical demonstration involved searching a 54-page patent document containing 356 chemical images for a specific Markush structure. SubGrapher successfully located the target structure, highlighting its utility for large-scale document mining.</p>
</li>
</ol>
<h4 id="training-data-generation">Training Data Generation</h4>
<p>Since no public datasets existed with the required pixel-level mask annotations for functional groups, the researchers developed a comprehensive synthetic data generation pipeline:</p>
<ol>
<li>
<p><strong>Extended MolDepictor</strong>: They enhanced existing molecular rendering tools to create images from SMILES strings and generate corresponding segmentation masks for all substructures present in each molecule.</p>
</li>
<li>
<p><strong>Markush Structure Rendering</strong>: The pipeline was extended to handle complex generic structures using CXSMILES representations and the CDK library for rendering, creating training data for molecular templates with structural, positional, and frequency variation indicators.</p>
</li>
<li>
<p><strong>Diverse Molecular Sources</strong>: Training molecules were sourced from PubChem to ensure broad chemical diversity and coverage of different structural families.</p>
</li>
</ol>
<h2 id="results-impact-and-limitations">Results, Impact, and Limitations</h2>
<ul>
<li><strong>Superior Robustness to Image Quality</strong>: On the degraded JPO patent images, SubGrapher achieved the highest M-EM (83), narrowly ahead of MolScribe (82) and well ahead of the rule-based OSRA (67), although MolScribe scored slightly higher on S-F1 (94 vs. 92). SubGrapher&rsquo;s learned representations proved more resilient to noise, artifacts, and unconventional drawing styles than rule-based alternatives like OSRA.</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>SubGrapher</th>
          <th>MolScribe</th>
          <th>OSRA</th>
          <th>DECIMER</th>
          <th>MolGrapher</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>S-F1</strong> (JPO)</td>
          <td>92</td>
          <td><strong>94</strong></td>
          <td>81</td>
          <td>86</td>
          <td>89</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (JPO)</td>
          <td><strong>83</strong></td>
          <td>82</td>
          <td>67</td>
          <td>79</td>
          <td>80</td>
      </tr>
      <tr>
          <td><strong>S-F1</strong> (USPTO-10K-L)</td>
          <td><strong>97</strong></td>
          <td>90</td>
          <td><strong>97</strong></td>
          <td>86</td>
          <td>56</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (USPTO-10K-L)</td>
          <td>55</td>
          <td>55</td>
          <td><strong>75</strong></td>
          <td>66</td>
          <td>31</td>
      </tr>
      <tr>
          <td><strong>S-F1</strong> (USPTO-Markush)</td>
          <td><strong>88</strong></td>
          <td>86</td>
          <td>74</td>
          <td>10</td>
          <td>35</td>
      </tr>
      <tr>
          <td><strong>M-EM</strong> (USPTO-Markush)</td>
          <td>82</td>
          <td><strong>86</strong></td>
          <td>70</td>
          <td>11</td>
          <td>30</td>
      </tr>
      <tr>
          <td><strong>Avg Retrieval Rank</strong></td>
          <td><strong>95/500</strong></td>
          <td>181-241/500</td>
          <td>138-185/500</td>
          <td>N/A</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<p>Note: Retrieval rank ranges reflect the best and worst fingerprint method pairing for each OCSR model (RDKit Daylight or MHFP).</p>
<ul>
<li>
<p><strong>Effective Handling of Scale and Complexity</strong>: The instance segmentation approach managed large molecules and complex structures where several graph-reconstruction methods struggled. On USPTO-10K-L, SubGrapher&rsquo;s S-F1 (97) tied OSRA and exceeded the other learning-based methods, although its M-EM (55) trailed OSRA (75) and DECIMER (66). On USPTO-Markush, it achieved the highest S-F1 (88), while MolScribe led on M-EM (86 vs. 82).</p>
</li>
<li>
<p><strong>Markush Structure Processing</strong>: SubGrapher achieves the highest Substructure F1-score on Markush structures (88 vs. MolScribe&rsquo;s 86 and OSRA&rsquo;s 74). While other OCSR methods can attempt Markush images, they support only limited features such as abbreviation-based variable groups. SubGrapher handles complex Markush features more effectively, expanding the scope of automatically extractable chemical information from patent literature.</p>
</li>
<li>
<p><strong>Robust Molecule Retrieval Performance</strong>: The visual fingerprinting approach achieved reliable retrieval performance (average rank 95/500) across diverse molecular families. The key advantage was consistency: SubGrapher generates meaningful fingerprints even from partial or uncertain predictions, while OCSR-based methods often fail to produce any usable output.</p>
</li>
<li>
<p><strong>Practical Document Mining Capability</strong>: The successful identification of specific Markush structures within large patent documents (54 pages, 356 images) demonstrates real-world applicability for large-scale literature mining and intellectual property analysis.</p>
</li>
<li>
<p><strong>Single-Stage Architecture Benefits</strong>: By eliminating the traditional image → structure → fingerprint pipeline, SubGrapher avoids error accumulation from failed molecular reconstructions. Every input image produces a fingerprint, making the system more reliable for batch processing of diverse document collections.</p>
</li>
<li>
<p><strong>Limitations and Scope</strong>: The method remains focused on common organic functional groups and may struggle with inorganic chemistry, organometallic complexes, or highly specialized molecular classes not well-represented in the training data. The 1,534 functional groups, while extensive, represent a curated subset of chemical space. SubGrapher also cannot distinguish enantiomers, as the detected substructures lack stereochemistry information. Additionally, the method currently cannot recognize substructures in abbreviations or single-atom fragments.</p>
</li>
</ul>
<p>The work demonstrates that direct fingerprint generation can be more robust and practical than traditional structure reconstruction approaches. SubGrapher&rsquo;s ability to handle Markush structures and degraded images makes it particularly valuable for patent analysis and large-scale document mining, where traditional OCSR methods frequently fail. The approach suggests that task-specific learning (fingerprints for retrieval) can outperform general-purpose reconstruction methods in many practical applications.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Training Data Generation</strong>: The paper developed a custom synthetic data pipeline since no public datasets existed with pixel-level mask annotations for functional groups:</p>
<ul>
<li><strong>Extended MolDepictor</strong>: Enhanced molecular rendering tool to generate both images and corresponding segmentation masks for all substructures</li>
<li><strong>Markush Structure Rendering</strong>: Pipeline extended to handle complex generic structures</li>
<li><strong>Source Molecules</strong>: PubChem for broad chemical diversity</li>
</ul>
<p><strong>Evaluation Benchmarks</strong>:</p>
<ul>
<li><strong>JPO Dataset</strong>: Real patent images with poor resolution, noise, and artifacts</li>
<li><strong>USPTO-10K-L</strong>: Large complex molecular structures</li>
<li><strong>USPTO-Markush</strong>: Generic structures with variable R-groups</li>
<li><strong>Retrieval Benchmarks</strong>: Five datasets (adenosine, camphor, cholesterol, limonene, pyridine), each with 500 similar molecular images</li>
</ul>
<h3 id="models">Models</h3>
<p><strong>Architecture</strong>: Dual instance segmentation system using Mask-RCNN</p>
<ul>
<li><strong>Functional Group Detector</strong>: Mask-RCNN trained to identify 1,534 expert-defined functional groups</li>
<li><strong>Carbon Backbone Detector</strong>: Mask-RCNN trained to recognize 27 common carbon chain patterns</li>
<li><strong>Backbone Network</strong>: Not specified in the paper</li>
</ul>
<p><strong>Functional Group Knowledge Base</strong>: 1,534 substructures systematically defined by:</p>
<ul>
<li>Starting with chemically logical atom combinations (C, O, S, N, B, P)</li>
<li>Expanding to include relevant subgroups and variations</li>
<li>Filtering based on frequency (appearing ~1,000+ times in PubChem)</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Functional Group Definition</strong>:</p>
<ul>
<li><strong>1,534 Functional Groups</strong>: Defined by manually curated SMARTS patterns
<ul>
<li>Must contain heteroatoms (O, N, S, P, B)</li>
<li>Frequency threshold: ~1,000+ occurrences in PubChem</li>
<li>Systematically constructed from chemically logical atom combinations</li>
<li>Manual curation with SMILES, SMARTS, and descriptive names</li>
</ul>
</li>
<li><strong>27 Carbon Backbones</strong>: Patterns of 3-6 carbon atoms (rings and chains) to capture molecular scaffolds</li>
</ul>
<p><strong>Substructure-Graph Construction</strong>:</p>
<ol>
<li>Detect functional groups and carbon backbones using Mask-RCNN models</li>
<li>Build connectivity graph:
<ul>
<li>Each node represents an identified substructure instance</li>
<li>Edges connect substructures whose bounding boxes overlap</li>
<li>Bounding boxes expanded by 10% of smallest box&rsquo;s diagonal to ensure connectivity between adjacent groups</li>
<li>Carbon chain intersection coefficients divided by 2, giving functional groups higher effective weight</li>
</ul>
</li>
</ol>
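<p>The graph-construction steps above can be sketched as follows. This is a hedged illustration, not the official implementation: the function names are hypothetical, boxes are assumed to be <code>(x0, y0, x1, y1)</code> tuples, and the margin is assumed to be 10% of the diagonal of the smallest detected box overall (the description could also be read as the smaller box of each pair).</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
from collections import deque

def expand(box, margin):
    """Grow an (x0, y0, x1, y1) box by margin on every side."""
    x0, y0, x1, y1 = box
    return (x0 - margin, y0 - margin, x1 + margin, y1 + margin)

def intersection_area(a, b):
    """Area shared by two axis-aligned boxes; 0.0 when they are disjoint."""
    width = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    height = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    return width * height

def build_graph(boxes, ratio=0.10):
    """Adjacency sets linking boxes that overlap after margin expansion."""
    # Assumption: margin = ratio * diagonal of the smallest box overall.
    margin = ratio * min(math.hypot(b[2] - b[0], b[3] - b[1]) for b in boxes)
    grown = [expand(b, margin) for b in boxes]
    graph = {i: set() for i in range(len(boxes))}
    for i in range(len(boxes)):
        for j in range(i + 1, len(boxes)):
            if intersection_area(grown[i], grown[j]) != 0.0:
                graph[i].add(j)
                graph[j].add(i)
    return graph

def shortest_path_lengths(graph, source):
    """Breadth-first search distances (in edges) from source to reachable nodes."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nxt in graph[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return dist

# Two adjacent boxes separated by a small gap, plus one far-away box.
example = build_graph([(0, 0, 10, 10), (11, 0, 21, 10), (40, 0, 50, 10)])
</code></pre></div>
<p>In the example, the margin closes the one-unit gap between the first two boxes, so they become neighbors, while the distant third box stays isolated. The breadth-first distances are the shortest-path values $d$ that feed the decay weight $h_2(d)$.</p>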
<p><strong>SVMF Fingerprint Generation</strong>:</p>
<ul>
<li>Matrix form: $SVMF(m) \in \mathbb{R}^{n \times n}$ where $n=1561$</li>
<li>Stored as compressed sparse upper triangular matrix</li>
<li><strong>Diagonal elements</strong>: $SVMF_{ii} = h_1 \cdot n_i + g_{ii}$ where $h_1 = 10$</li>
<li><strong>Off-diagonal elements</strong>: $SVMF_{ij} = h_2(d) \cdot \text{intersection}(s_i, s_j)$ where:
<ul>
<li>$h_2(d) = 2$ for $d = 0, 1$</li>
<li>$h_2(2) = 2/4$, $h_2(3) = 2/16$, $h_2(4) = 2/256$</li>
<li>$h_2(d) = 0$ for $d &gt; 4$</li>
</ul>
</li>
<li>Average sparsity: 0.001% non-zero elements</li>
<li>Similarity metric: Normalized Euclidean distance (L2 norm of difference divided by L2 norm of sum)</li>
</ul>
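<p>The similarity metric above can be sketched for sparse fingerprints as follows. This is an illustrative sketch under stated assumptions: fingerprints are represented as dictionaries mapping <code>(i, j)</code> index pairs to non-zero values, and the function name is hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def normalized_euclidean(fp_a, fp_b):
    """||a - b||_2 / ||a + b||_2 for sparse fingerprints stored as {(i, j): value}."""
    keys = set(fp_a) | set(fp_b)
    diff = math.sqrt(sum((fp_a.get(k, 0.0) - fp_b.get(k, 0.0)) ** 2 for k in keys))
    total = math.sqrt(sum((fp_a.get(k, 0.0) + fp_b.get(k, 0.0)) ** 2 for k in keys))
    # Two empty fingerprints are treated as identical (an assumption of this sketch).
    return diff / total if total else 0.0

same = normalized_euclidean({(0, 0): 10.0}, {(0, 0): 10.0})
disjoint = normalized_euclidean({(0, 0): 10.0}, {(1, 1): 10.0})
</code></pre></div>
<p>For fingerprints with non-negative entries, the distance ranges from 0 (identical fingerprints) to 1 (no shared non-zero entries), so candidates can be ranked by ascending distance to the query.</p>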
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>Substructure F1-score (S-F1)</strong>: Harmonic mean of precision and recall for individual substructure detection across all molecules in the dataset</li>
<li><strong>Molecule Exact Match (M-EM)</strong>: Percentage of molecules where S-F1 = 1.0 (all substructures correctly identified)</li>
<li><strong>Retrieval Rank</strong>: Average rank of ground truth molecule in candidate list of 500 similar structures when querying with SMILES fingerprint, averaged across 50 queries per benchmark</li>
</ul>
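<p>The three metrics above can be sketched as follows. This is a simplified illustration rather than the paper&rsquo;s evaluation code: it assumes each molecule&rsquo;s substructures are compared as a multiset of labels and computes F1 per molecule, whereas the paper aggregates S-F1 across the dataset. All function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">from collections import Counter

def substructure_f1(predicted, reference):
    """F1 over substructure labels, treating each molecule's labels as a multiset."""
    pred, ref = Counter(predicted), Counter(reference)
    # True positives: label instances present in both multisets.
    tp = sum(min(pred[k], ref[k]) for k in pred)
    if tp == 0:
        return 0.0
    precision = tp / sum(pred.values())
    recall = tp / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

def molecule_exact_match(pairs):
    """Percentage of (predicted, reference) pairs with a perfect F1 of 1.0."""
    hits = sum(1 for pred, ref in pairs if substructure_f1(pred, ref) == 1.0)
    return 100.0 * hits / len(pairs)

def retrieval_rank(distances, target_index):
    """1-based rank of the target when candidates are sorted by ascending distance."""
    order = sorted(range(len(distances)), key=lambda i: distances[i])
    return order.index(target_index) + 1
</code></pre></div>
<p>Averaging <code>retrieval_rank</code> over the queries of a benchmark gives the average retrieval rank reported in the results table, where lower is better.</p>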
<p><strong>Baselines</strong>: Compared against SOTA OCSR methods:</p>
<ul>
<li>Deep learning: MolScribe, MolGrapher, DECIMER</li>
<li>Rule-based: OSRA</li>
<li>Fingerprint methods: RDKit Daylight, MHFP (applied to OCSR outputs)</li>
</ul>
<h3 id="hardware">Hardware</h3>
<p>Not specified in the paper. Training and inference hardware details are not provided in the main text; if reported at all, they would be in the supplementary materials.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/DS4SD/SubGrapher">SubGrapher (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official inference code</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/docling-project/SubGrapher">SubGrapher (HuggingFace)</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>Pre-trained model weights</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/docling-project/SubGrapher-Datasets">SubGrapher-Datasets (HuggingFace)</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Visual fingerprinting benchmark datasets</td>
      </tr>
  </tbody>
</table>
<h3 id="implementation-gaps">Implementation Gaps</h3>
<p>The following details are not available in the paper and would require access to the code repository or supplementary information:</p>
<ul>
<li>Specific backbone architecture for Mask-RCNN (ResNet variant, Swin Transformer, etc.)</li>
<li>Optimizer type (AdamW, SGD, etc.)</li>
<li>Learning rate and scheduler</li>
<li>Batch size and number of training epochs</li>
<li>Loss function weights (box loss vs. mask loss)</li>
<li>GPU/TPU specifications used for training</li>
<li>Training time and computational requirements</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Morin, L., Meijer, G. I., Weber, V., Van Gool, L., &amp; Staar, P. W. J. (2025). SubGrapher: Visual fingerprinting of chemical structures. Journal of Cheminformatics, 17(1), 149. <a href="https://doi.org/10.1186/s13321-025-01091-4">https://doi.org/10.1186/s13321-025-01091-4</a></p>
<p><strong>Publication</strong>: Journal of Cheminformatics (2025)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{morinSubGrapherVisualFingerprinting2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{SubGrapher: Visual Fingerprinting of Chemical Structures}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{SubGrapher}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Morin, Lucas and Meijer, Gerhard Ingmar and Weber, Valéry and Van Gool, Luc and Staar, Peter W. J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Journal of Cheminformatics}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{17}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{1}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{149}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1186/s13321-025-01091-4}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>3D Steerable CNNs: Rotationally Equivariant Features</title><link>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/3d-steerable-cnns/</link><pubDate>Thu, 16 Jan 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/geometric-deep-learning/3d-steerable-cnns/</guid><description>Weiler et al.'s NeurIPS 2018 paper introducing SE(3)-equivariant CNNs for volumetric data using group theory and spherical harmonics.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a novel neural network architecture, the 3D Steerable CNN. It provides a comprehensive theoretical derivation for the architecture grounded in group representation theory and demonstrates its practical application.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by the prevalence of <strong>symmetry</strong> in problems from the natural sciences. Standard 3D CNNs are equivariant to translations but lack inherent equivariance to 3D rotations. Together, translations and rotations form the SE(3) group of rigid-body motions, a fundamental symmetry of many scientific datasets such as molecular or protein structures. Building this symmetry directly into the model architecture as an <strong>inductive bias</strong> is expected to yield more data-efficient, generalizable, and physically meaningful models.</p>
<figure class="post-figure center ">
    <img src="/img/notes/3d-cnn-versus-3d-steerable-cnn.webp"
         alt="Comparison of standard 3D CNN versus 3D Steerable CNN for handling rotational symmetry. Panel A shows how standard CNNs produce distorted outputs when inputs are rotated, requiring data augmentation. Panel B shows how Steerable CNNs use spherical harmonic kernel bases to produce equivariant geometric field outputs that transform predictably under rotation."
         title="Comparison of standard 3D CNN versus 3D Steerable CNN for handling rotational symmetry. Panel A shows how standard CNNs produce distorted outputs when inputs are rotated, requiring data augmentation. Panel B shows how Steerable CNNs use spherical harmonic kernel bases to produce equivariant geometric field outputs that transform predictably under rotation."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Standard 3D CNNs (Panel A) produce inconsistent feature maps when inputs are rotated, requiring expensive data augmentation. 3D Steerable CNNs (Panel B) use analytically-derived spherical harmonic kernels to produce geometric field outputs that transform equivariantly under rotation.</figcaption>
    
</figure>

<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the rigorous and practical construction of a CNN architecture that is equivariant to 3D rigid body motions (SE(3) group). The key contributions are:</p>
<ul>
<li><strong>Geometric Feature Representation</strong>: Features are modeled as geometric <strong>fields</strong> (collections of scalars, vectors, and higher-order tensors) defined over $\mathbb{R}^{3}$. Each type of feature transforms according to an <strong>irreducible representation (irrep)</strong> of the rotation group SO(3).</li>
<li><strong>General Equivariant Convolution</strong>: The paper proves that the most general form of an SE(3)-equivariant linear map between these fields is a convolution with a <strong>rotation-steerable kernel</strong>.</li>
<li><strong>Analytical Kernel Basis</strong>: The main theoretical breakthrough is the analytical derivation of a complete basis for these steerable kernels. They solve the kernel&rsquo;s equivariance constraint, $\kappa(rx) = D^{j}(r)\kappa(x)D^{l}(r)^{-1}$, showing the solutions are functions whose angular components are <strong>spherical harmonics</strong>. The network&rsquo;s kernels are then parameterized as a learnable linear combination of these pre-computed basis functions, making the implementation a minor modification to standard 3D convolutions.</li>
</ul>
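<p>The kernel constraint in the list above can be checked numerically in its simplest non-trivial case. The sketch below is an illustration, not the paper&rsquo;s implementation: it assumes a hypothetical kernel <code>kernel_vector_to_scalar</code> that maps a vector field ($l=1$) to a scalar field ($j=0$), built from a Gaussian radial profile times $x^T$, and it assumes the Cartesian basis in which $D^1(r)$ is the rotation matrix itself and $D^0(r) = 1$.</p>

```python
import numpy as np

def random_rotation(rng):
    # QR decomposition of a Gaussian matrix gives an orthogonal matrix.
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    # Flip one column if needed so that the determinant is +1 (proper rotation).
    q[:, 0] *= np.sign(np.linalg.det(q))
    return q

def kernel_vector_to_scalar(x, sigma=1.0):
    # Hypothetical kernel from an l=1 field to a j=0 field:
    # a Gaussian radial profile times the row vector x^T (shape 1 x 3).
    radial = np.exp(-0.5 * np.dot(x, x) / sigma**2)
    return radial * x.reshape(1, 3)

rng = np.random.default_rng(0)
r = random_rotation(rng)
x = rng.normal(size=3)

# Constraint: kappa(r x) = D^j(r) kappa(x) D^l(r)^{-1},
# with D^0(r) = 1 and D^1(r)^{-1} = r^T in the Cartesian basis.
lhs = kernel_vector_to_scalar(r @ x)
rhs = 1.0 * kernel_vector_to_scalar(x) @ r.T
constraint_holds = bool(np.allclose(lhs, rhs))
print(constraint_holds)
```

<p>The check passes because the radial profile depends only on $|x|$, which rotations preserve, while the angular part $x^T$ picks up exactly the factor $r^T$ that the constraint requires.</p>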
<figure class="post-figure center ">
    <img src="/img/notes/spherical-harmonics.webp"
         alt="Spherical harmonics visualization showing the angular basis functions organized by degree l (rows) and order m (columns). Row 0 shows the single s-type orbital (l=0), row 1 shows three p-type orbitals (l=1), row 2 shows five d-type orbitals (l=2), and row 3 shows seven f-type orbitals (l=3)."
         title="Spherical harmonics visualization showing the angular basis functions organized by degree l (rows) and order m (columns). Row 0 shows the single s-type orbital (l=0), row 1 shows three p-type orbitals (l=1), row 2 shows five d-type orbitals (l=2), and row 3 shows seven f-type orbitals (l=3)."
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spherical harmonics $Y_l^m$ organized by degree $l$ (rows) and order $m$ (columns). These functions form the angular basis for steerable kernels: $l=0$ (scalar), $l=1$ (vector/p-orbital), $l=2$ (rank-2 tensor/d-orbital), $l=3$ (rank-3 tensor/f-orbital). Each degree $l$ has $2l+1$ components.</figcaption>
    
</figure>

<ul>
<li><strong>Equivariant Nonlinearity</strong>: A novel <strong>gated nonlinearity</strong> is proposed for non-scalar features. It preserves equivariance by multiplying a feature field by a separately computed, learned scalar field (the gate).</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The model&rsquo;s performance was evaluated on a series of tasks with inherent rotational symmetry:</p>
<ol>
<li><strong>Tetris Classification</strong>: A toy problem to empirically validate the model&rsquo;s rotational equivariance by training on aligned blocks and testing on randomly rotated ones.</li>
<li><strong>SHREC17 3D Model Classification</strong>: A benchmark for classifying complex 3D shapes that are arbitrarily rotated.</li>
<li><strong>Amino Acid Propensity Prediction</strong>: A scientific application to predict amino acid types from their 3D atomic environments.</li>
<li><strong>CATH Protein Structure Classification</strong>: A challenging task on a new dataset introduced by the authors, requiring classification of global protein architecture, a problem with full SE(3) invariance.</li>
</ol>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The 3D Steerable CNN demonstrated clear advantages due to its built-in equivariance:</p>
<ul>
<li>It was empirically confirmed to be <strong>rotationally equivariant</strong>, achieving $99 \pm 2\%$ test accuracy on the rotated Tetris dataset (averaged over 17 runs), compared to a standard 3D CNN&rsquo;s $27 \pm 7\%$ accuracy.</li>
<li>On the amino acid prediction task the model achieves 0.58 accuracy, compared to 0.50 (regular-grid) and 0.56 (concentric-grid) baselines, using roughly half the parameters. On SHREC17 it reaches a total score (micro + macro mAP) of 1.11, compared to 1.13 for the leading contemporary system.</li>
<li>On the CATH protein classification task, it <strong>outperformed a deep 3D CNN baseline</strong> while using ~110x fewer parameters. This performance gap widened as the training data was reduced, highlighting the model&rsquo;s superior <strong>data efficiency</strong>.</li>
</ul>
<p>The paper concludes that 3D Steerable CNNs provide a universal and effective framework for incorporating SE(3) symmetry into deep learning models, leading to improved accuracy and efficiency for tasks involving volumetric data, particularly in scientific domains.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Format</strong>: All inputs must be voxelized. Point clouds require voxelization before use.
<ul>
<li><strong>Proteins (CATH)</strong>: $50^3$ grid, 0.2 nm voxel size. Simplified to $C_\alpha$ atoms only; Gaussian density placed at each atom position.</li>
<li><strong>3D Objects (SHREC17)</strong>: $64^3$ voxel grids.</li>
<li><strong>Tetris</strong>: $36^3$ input grid.</li>
</ul>
</li>
<li><strong>Splitting Strategy</strong>: CATH used a 10-fold split (7 train, 1 val, 2 test) strictly separated by &ldquo;superfamily&rdquo; level to prevent data leakage (&lt;40% sequence identity).</li>
</ul>
<h3 id="models">Models</h3>
<p><strong>Kernel Basis Construction</strong>:</p>
<ul>
<li>Constructed from <strong>Spherical Harmonics</strong> multiplied by <strong>Gaussian Radial Shells</strong>: $\exp\left(-\frac{1}{2}(|x|-m)^{2}/\sigma^{2}\right)$</li>
<li><strong>Anti-aliasing</strong>: A radially dependent angular frequency cutoff ($J_{\max}$) is applied to prevent aliasing near the origin.</li>
</ul>
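<p>As a rough illustration of the basis construction above, the sketch below builds $l=1$ basis elements on a small voxel grid by multiplying Gaussian radial shells with unit direction vectors (which are proportional to the real $l=1$ spherical harmonics). The grid size, the shell radii, and the value of <code>sigma</code> are assumptions chosen for the example, and the anti-aliasing cutoff is not modeled.</p>

```python
import numpy as np

def radial_shell(radius, m, sigma=0.6):
    # Gaussian shell exp(-(|x| - m)^2 / (2 sigma^2)) centred at radius m.
    return np.exp(-0.5 * (radius - m) ** 2 / sigma**2)

size = 5  # hypothetical 5x5x5 kernel support
axis = np.arange(size) - size // 2
grid = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), axis=-1).astype(float)
radius = np.linalg.norm(grid, axis=-1)

# l = 1 angular part: unit direction vectors (zero at the origin).
safe_radius = np.where(radius == 0, 1.0, radius)
angular_l1 = grid / safe_radius[..., None]

# One basis element per shell radius m; shape (num_shells, 5, 5, 5, 3).
shell_radii = [1.0, 2.0]
basis = np.stack(
    [radial_shell(radius, m)[..., None] * angular_l1 for m in shell_radii]
)
print(basis.shape)
```

<p>A learned kernel would then be a weighted sum of such basis elements, with one learnable coefficient per element.</p>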
<p><strong>Normalization</strong>: Uses <strong>Equivariant Batch Norm</strong>. Non-scalar fields are divided by the square root of their average squared norm, with no mean subtraction.</p>
<p><strong>Downsampling</strong>: Standard strided convolution breaks equivariance. The architecture uses <strong>low-pass filtering</strong> (Gaussian blur) before downsampling to maintain equivariance.</p>
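<p>A minimal sketch of blur-then-subsample downsampling, using only NumPy. The blur width <code>sigma</code>, the truncation radius, and the zero padding at the borders are assumptions made for this example; the paper&rsquo;s exact filter settings are not reproduced here.</p>

```python
import numpy as np

def gaussian_kernel_1d(sigma, truncate=3.0):
    # Normalized 1D Gaussian weights on integer offsets.
    half = int(truncate * sigma + 0.5)
    offsets = np.arange(-half, half + 1)
    weights = np.exp(-0.5 * (offsets / sigma) ** 2)
    return weights / weights.sum()

def blur_then_subsample(volume, stride=2, sigma=1.0):
    # Separable Gaussian blur along each spatial axis (zero padded),
    # then keep every `stride`-th voxel.
    kernel = gaussian_kernel_1d(sigma)
    blurred = volume
    for ax in range(3):
        blurred = np.apply_along_axis(
            lambda line: np.convolve(line, kernel, mode="same"), ax, blurred
        )
    return blurred[::stride, ::stride, ::stride]

volume = np.random.default_rng(0).normal(size=(8, 8, 8))
coarse = blur_then_subsample(volume)
print(coarse.shape)
```

<p>Removing high frequencies before subsampling reduces aliasing, which is what makes the strided output depend less on how the input happens to be aligned with the voxel grid.</p>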
<p><strong>Exact Architecture Configurations</strong>:</p>
<p><strong>Tetris Architecture</strong> (4 layers):</p>
<table>
  <thead>
      <tr>
          <th>Layer</th>
          <th>Field Types</th>
          <th>Spatial Size</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Input</td>
          <td>1 scalar</td>
          <td>$36^3$</td>
      </tr>
      <tr>
          <td>Layer 1</td>
          <td>4 scalars, 4 vectors ($l=1$), 4 tensors ($l=2$), 1 tensor ($l=3$)</td>
          <td>$40^3$</td>
      </tr>
      <tr>
          <td>Layer 2</td>
          <td>16 scalars, 16 vectors, 16 tensors ($l=2$)</td>
          <td>$22^3$ (stride 2)</td>
      </tr>
      <tr>
          <td>Layer 3</td>
          <td>32 scalars, 16 vectors, 16 tensors ($l=2$)</td>
          <td>$13^3$ (stride 2)</td>
      </tr>
      <tr>
          <td>Layer 4</td>
          <td>128 scalars</td>
          <td>$17^3$</td>
      </tr>
      <tr>
          <td>Output</td>
          <td>8 scalars (global average pool)</td>
          <td>$1$</td>
      </tr>
  </tbody>
</table>
<p><strong>SHREC17 Architecture</strong> (8 layers):</p>
<table>
  <thead>
      <tr>
          <th>Layers</th>
          <th>Field Types</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1-2</td>
          <td>8 scalars, 4 vectors, 2 tensors ($l=2$)</td>
      </tr>
      <tr>
          <td>3-4</td>
          <td>16 scalars, 8 vectors, 4 tensors</td>
      </tr>
      <tr>
          <td>5-7</td>
          <td>32 scalars, 16 vectors, 8 tensors</td>
      </tr>
      <tr>
          <td>8</td>
          <td>512 scalars</td>
      </tr>
      <tr>
          <td>Output</td>
          <td>55 scalars (classes)</td>
      </tr>
  </tbody>
</table>
<p><strong>CATH Architecture</strong> (ResNet34-inspired):</p>
<p>Block progression: <code>(2,2,2,2)</code>, <code>(4,4,4,4)</code>, <code>(8,8,8,8)</code>, <code>(16,16,16,16)</code></p>
<p>Notation: <code>(a,b,c,d)</code> = $a$ scalars ($l=0$), $b$ vectors ($l=1$), $c$ rank-2 tensors ($l=2$), $d$ rank-3 tensors ($l=3$).</p>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Parameter Counts</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Model</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CATH</td>
          <td>3D Steerable CNN</td>
          <td>143,560</td>
      </tr>
      <tr>
          <td>CATH</td>
          <td>Baseline (ResNet34-style)</td>
          <td>15,878,764</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>3D Steerable CNN</td>
          <td>~32,600,000</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>Regular grid baseline</td>
          <td>~61,100,000</td>
      </tr>
      <tr>
          <td>Amino Acid</td>
          <td>Concentric grid baseline</td>
          <td>~75,300,000</td>
      </tr>
  </tbody>
</table>
<p>Note: The concentric grid baseline groups voxels by distance from the molecular center, reflecting that atomic interactions are primarily distance-dependent (Torng, W., &amp; Altman, R. B. (2017). 3D deep convolutional neural networks for amino acid environment similarity analysis. <em>BMC Bioinformatics</em>, 18, 302). Amino acid parameter counts are rounded figures as reported in the paper.</p>
<p><strong>Hyperparameters &amp; Training</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Parameter</th>
          <th>Value</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Optimizer</strong></td>
          <td>Adam</td>
      </tr>
      <tr>
          <td><strong>LR Scheduler</strong></td>
          <td>Exponential decay (0.94/epoch) after 40 epoch burn-in</td>
      </tr>
      <tr>
          <td><strong>Dropout</strong> (CATH)</td>
          <td>0.1 (Capsule-wide convolutional dropout)</td>
      </tr>
      <tr>
          <td><strong>Weight Decay</strong> (CATH)</td>
          <td>L1 &amp; L2 regularization: $10^{-8.5}$</td>
      </tr>
  </tbody>
</table>
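<p>The learning-rate schedule in the table can be sketched as a small function. The base learning rate <code>base_lr</code> is a placeholder (it is not given above), and the convention that the first decayed epoch is the one right after the burn-in is an assumption.</p>

```python
def learning_rate(epoch, base_lr=1e-3, burn_in=40, decay=0.94):
    # Constant during the burn-in, then multiplied by `decay` once per epoch.
    return base_lr * decay ** max(0, epoch - burn_in)

schedule = [learning_rate(e) for e in (0, 40, 41, 50)]
print(schedule)
```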
<p><strong>Mathematical Formulations for Equivariance</strong>:</p>
<p>Standard operations like Batch Normalization and ReLU break rotational equivariance. The paper derives equivariant alternatives:</p>
<p><strong>Equivariant Batch Normalization</strong>:</p>
<p>Standard BN subtracts a mean, which introduces a preferred direction and breaks symmetry. <strong>Norm-based normalization</strong> instead divides each feature field by the square root of its average squared norm, taken over the batch and over space:</p>
<p>$$f_{i}(x) \mapsto f_{i}(x) \left( \frac{1}{|B|} \sum_{j \in B} \frac{1}{V} \int dx |f_{j}(x)|^{2} + \epsilon \right)^{-1/2}$$</p>
<p>This rescales each field so that its mean squared norm is approximately one. Because no mean is subtracted, directional information and equivariance are preserved.</p>
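<p>The normalization formula can be sketched in NumPy for a single vector field. The array layout <code>(batch, X, Y, Z, 3)</code> and the value of <code>eps</code> are assumptions for illustration. The check at the end rotates only the vector components; a spatial rotation would merely permute voxels and leave the average unchanged.</p>

```python
import numpy as np

def norm_batch_norm(fields, eps=1e-5):
    # fields: (batch, X, Y, Z, 2l+1) samples of one field type.
    # Mean squared norm over batch and space; no mean is subtracted.
    mean_sq_norm = np.mean(np.sum(fields**2, axis=-1))
    return fields / np.sqrt(mean_sq_norm + eps)

rng = np.random.default_rng(0)
vectors = rng.normal(size=(4, 6, 6, 6, 3))

# A rotation acting on the vector components (90 degrees about z).
rot = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
rotated_then_normed = norm_batch_norm(vectors @ rot.T)
normed_then_rotated = norm_batch_norm(vectors) @ rot.T
commutes = bool(np.allclose(rotated_then_normed, normed_then_rotated))
print(commutes)
```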
<p><strong>Equivariant Nonlinearities</strong>:</p>
<p>Applying ReLU to each vector component independently breaks equivariance, because the result depends on the choice of coordinate frame. The paper considers two alternatives:</p>
<ol>
<li>
<p><strong>Norm Nonlinearity</strong> (geometric shrinking): Acts on magnitude, preserves direction. Shrinks vectors shorter than learned bias $\beta$ to zero:
$$f(x) \mapsto \text{ReLU}(|f(x)| - \beta) \frac{f(x)}{|f(x)|}$$
<em>Note: Found to converge slowly; omitted from final models.</em></p>
</li>
<li>
<p><strong>Gated Nonlinearity</strong> (used in practice): A separate scalar field $s(x)$ passes through sigmoid to create a gate $\sigma(s(x))$, which multiplies the geometric field:
$$f_{\text{out}}(x) = f_{\text{in}}(x) \cdot \sigma(s(x))$$
<em>Architecture implication: Requires extra scalar channels ($l=0$) specifically for gating higher-order channels ($l&gt;0$).</em></p>
</li>
</ol>
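<p>Both nonlinearities above can be written in a few lines and compared against a component-wise ReLU. This is a sketch on a batch of sample vectors, not the paper&rsquo;s code; the bias value 0.5 and the test rotation are arbitrary choices.</p>

```python
import numpy as np

def norm_nonlinearity(f, beta):
    # Shrink each vector's length by beta (clamped at zero), keeping its direction.
    norm = np.linalg.norm(f, axis=-1, keepdims=True)
    scale = np.maximum(norm - beta, 0.0) / np.maximum(norm, 1e-12)
    return scale * f

def gated_nonlinearity(f, s):
    # Multiply each vector by a sigmoid gate computed from a scalar field s.
    gate = 1.0 / (1.0 + np.exp(-s))
    return f * gate[..., None]

rng = np.random.default_rng(0)
f = rng.normal(size=(100, 3))  # 100 sample vectors (l=1 features)
s = rng.normal(size=100)       # matching scalar (l=0) gate inputs
# Rotation by 90 degrees about the z axis.
rot = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

norm_ok = np.allclose(norm_nonlinearity(f @ rot.T, 0.5), norm_nonlinearity(f, 0.5) @ rot.T)
gate_ok = np.allclose(gated_nonlinearity(f @ rot.T, s), gated_nonlinearity(f, s) @ rot.T)
relu_ok = np.allclose(np.maximum(f @ rot.T, 0.0), np.maximum(f, 0.0) @ rot.T)
print(norm_ok, gate_ok, relu_ok)
```

<p>The first two checks succeed because the norm and the scalar gate are unchanged by rotation, so the nonlinearity commutes with it. The component-wise ReLU check fails, since clamping individual coordinates gives different vectors in different frames.</p>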
<p><strong>Voxelization Details</strong>:</p>
<p>For CATH protein inputs, Gaussian density is placed at each atom position with standard deviation equal to <strong>half the voxel width</strong> ($0.5 \times 0.2\text{ nm} = 0.1\text{ nm}$).</p>
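<p>A minimal voxelization sketch with the stated grid ($50^3$ voxels of 0.2 nm) and a Gaussian of standard deviation 0.1 nm per atom. The atom coordinates are made up, the Gaussians are left unnormalized, and centring the grid on the origin is an assumption.</p>

```python
import numpy as np

def voxelize(positions_nm, grid_size=50, voxel_nm=0.2):
    sigma = 0.5 * voxel_nm  # half the voxel width, i.e. 0.1 nm
    # Voxel centre coordinates along one axis, grid centred on the origin
    # (the centring convention is an assumption).
    centres = (np.arange(grid_size) - (grid_size - 1) / 2.0) * voxel_nm
    density = np.zeros((grid_size, grid_size, grid_size))
    for px, py, pz in positions_nm:
        # An isotropic 3D Gaussian factorizes into three 1D Gaussians.
        gx = np.exp(-0.5 * ((centres - px) / sigma) ** 2)
        gy = np.exp(-0.5 * ((centres - py) / sigma) ** 2)
        gz = np.exp(-0.5 * ((centres - pz) / sigma) ** 2)
        density += gx[:, None, None] * gy[None, :, None] * gz[None, None, :]
    return density

# Two hypothetical C-alpha positions in nanometres.
atoms = np.array([[0.1, 0.1, 0.1], [-1.3, 0.5, 2.1]])
density = voxelize(atoms)
print(density.shape)
```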
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Task</th>
          <th>Metric</th>
          <th>Steerable CNN</th>
          <th>Baseline</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Tetris (rotated test)</td>
          <td>Accuracy</td>
          <td>$99 \pm 2\%$</td>
          <td>$27 \pm 7\%$ (standard 3D CNN)</td>
      </tr>
      <tr>
          <td>Amino Acid Propensity</td>
          <td>Accuracy</td>
          <td><strong>0.58</strong> (32.6M params)</td>
          <td>0.50 (regular grid, 61.1M params); 0.56 (concentric grid, 75.3M params)</td>
      </tr>
      <tr>
          <td>SHREC17</td>
          <td>micro + macro mAP (higher is better)</td>
          <td>1.11</td>
          <td>1.13 (SOTA)</td>
      </tr>
      <tr>
          <td>CATH</td>
          <td>Accuracy</td>
          <td>Higher across all training set sizes (see Figure 4; not reported as a single value) (143,560 params)</td>
          <td>Deep 3D CNN (15,878,764 params; ~110x more)</td>
      </tr>
  </tbody>
</table>
<p>Note: On SHREC17, the total score is micro mAP + macro mAP combined (higher is better). From Table 4 in the supplementary material: Steerable CNN micro mAP = 0.661, macro mAP = 0.449, total = 1.11. On CATH, the steerable CNN outperformed the baseline with ~110x fewer parameters, a gap that widened as training data was reduced.</p>
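<p>The two derived figures quoted above follow directly from the reported numbers, as this short calculation shows.</p>

```python
steerable_params = 143_560
baseline_params = 15_878_764

# Parameter ratio between the CATH baseline and the steerable CNN.
ratio = baseline_params / steerable_params

# SHREC17 total score = micro mAP + macro mAP.
shrec_total = 0.661 + 0.449

print(round(ratio, 1), round(shrec_total, 2))
```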
<h2 id="historical-context-from-peer-reviews">Historical Context (From Peer Reviews)</h2>
<p>The NeurIPS peer reviews reveal important context about the paper&rsquo;s structure and claims:</p>
<ul>
<li>
<p><strong>Evolution of Experiments</strong>: The <strong>SHREC17</strong> experiment and the <strong>arbitrary rotation</strong> test in Tetris were added during the rebuttal phase to address reviewer concerns about the lack of standard computer vision benchmarks. This explains why SHREC17 feels somewhat disconnected from the paper&rsquo;s &ldquo;AI for Science&rdquo; narrative.</p>
</li>
<li>
<p><strong>Continuous vs. Discrete Rotations</strong>: The Tetris experiment validates equivariance to <strong>continuous</strong> ($SO(3)$) rotations alongside discrete 90-degree turns. This distinction is crucial and separates Steerable CNNs from earlier Group CNNs that handled discrete rotation groups exclusively.</p>
</li>
<li>
<p><strong>Terminology Warning</strong>: The use of terms like &ldquo;fiber&rdquo; and &ldquo;induced representation&rdquo; was critiqued by reviewers as being denser than necessary and inconsistent with related work (e.g., Tensor Field Networks). If you find Section 3 difficult, you are not alone: this is a known barrier to reading the paper. Focus on the resulting kernel constraints.</p>
</li>
<li>
<p><strong>Parameter Efficiency Quantified</strong>: Reviewers highlighted that the main practical win is <strong>parameter efficiency</strong>. Standard 3D CNNs hit diminishing returns around $10^7$ parameters, while Steerable CNNs achieve better results with ~110x fewer parameters ($10^5$).</p>
</li>
</ul>
<div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden;">
      <iframe allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share; fullscreen" loading="eager" referrerpolicy="strict-origin-when-cross-origin" src="https://www.youtube-nocookie.com/embed/ENLJACPHSEA?autoplay=0&amp;controls=1&amp;end=0&amp;loop=0&amp;mute=0&amp;start=0" style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; border:0;" title="YouTube video"></iframe>
    </div>

<h2 id="artifacts">Artifacts</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/mariogeiger/se3cnn">se3cnn (GitHub)</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Original implementation; superseded by <a href="https://github.com/e3nn/e3nn">e3nn</a> for point cloud applications</td>
      </tr>
      <tr>
          <td><a href="https://github.com/wouterboomsma/cath_datasets">CATH Datasets (GitHub)</a></td>
          <td>Dataset</td>
          <td>CC-BY-4.0</td>
          <td>Protein structure classification dataset introduced in this paper</td>
      </tr>
  </tbody>
</table>
<p>Pre-trained model weights are not publicly released. Hardware and compute requirements are not specified in the paper.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Weiler, M., Geiger, M., Welling, M., Boomsma, W., &amp; Cohen, T. S. (2018). 3D steerable CNNs: Learning rotationally equivariant features in volumetric data. <em>Advances in Neural Information Processing Systems</em>, 31. <a href="https://proceedings.neurips.cc/paper/2018/hash/488e4104520c6aab692863cc1dba45af-Abstract.html">https://proceedings.neurips.cc/paper/2018/hash/488e4104520c6aab692863cc1dba45af-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2018</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{weiler20183d,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{3D Steerable CNNs: Learning Rotationally Equivariant Features in Volumetric Data}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Weiler, Maurice and Geiger, Mario and Welling, Max and Boomsma, Wouter and Cohen, Taco S}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{31}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2018}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/mariogeiger/se3cnn">GitHub Repository</a></li>
<li><a href="https://www.youtube.com/watch?v=ENLJACPHSEA">YouTube Video</a></li>
<li><a href="https://github.com/wouterboomsma/cath_datasets">CATH Dataset</a></li>
</ul>
]]></content:encoded></item><item><title>RFL: Simplifying Chemical Structure Recognition (AAAI 2025)</title><link>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/rfl/</link><pubDate>Thu, 19 Dec 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/chemistry/optical-structure-recognition/image-to-sequence/rfl/</guid><description>Ring-Free Language (RFL) and Molecular Skeleton Decoder (MSD) for improved optical chemical structure recognition from molecular images.</description><content:encoded><![CDATA[<h2 id="methodological-contribution">Methodological Contribution</h2>
<p>This is a <strong>Methodological</strong> paper ($\Psi_{\text{Method}}$). It introduces a novel representation system (Ring-Free Language) and a specialized neural architecture (Molecular Skeleton Decoder) designed to solve specific limitations in converting 2D images to 1D chemical strings. The paper validates this method through direct comparison with existing baselines and ablation studies.</p>
<h2 id="motivation-limitations-of-1d-serialization">Motivation: Limitations of 1D Serialization</h2>
<p>Current Optical Chemical Structure Recognition (OCSR) methods typically rely on &ldquo;unstructured modeling,&rdquo; where 2D molecular graphs are flattened into 1D strings like SMILES or SSML. While simple, these linear formats struggle to explicitly capture complex spatial relationships, particularly in molecules with multiple rings and branches. End-to-end models often fail to &ldquo;understand&rdquo; the graph structure when forced to predict these implicit 1D sequences, leading to error accumulation in complex scenarios.</p>
<h2 id="innovation-ring-free-language-rfl-and-molecular-skeleton-decoder-msd">Innovation: Ring-Free Language (RFL) and Molecular Skeleton Decoder (MSD)</h2>
<p>The authors propose two primary contributions to decouple spatial complexity:</p>
<ol>
<li><strong>Ring-Free Language (RFL)</strong>: A divide-and-conquer representation that splits a molecular graph $G$ into three explicit components: a molecular skeleton $\mathcal{S}$, individual ring structures $\mathcal{R}$, and branch information $\mathcal{F}$. This allows rings to be collapsed into &ldquo;SuperAtoms&rdquo; or &ldquo;SuperBonds&rdquo; during initial parsing.</li>
<li><strong>Molecular Skeleton Decoder (MSD)</strong>: A hierarchical architecture that progressively predicts the skeleton first, then the individual rings (using SuperAtom features as conditions), and finally classifies the branch connections.</li>
</ol>
<h2 id="methodology-and-experiments">Methodology and Experiments</h2>
<p>The method was evaluated on both handwritten and printed chemical structures against two baselines: DenseWAP (Zhang et al. 2018) and RCGD (Hu et al. 2023).</p>
<ul>
<li><strong>Datasets</strong>:
<ul>
<li><strong>EDU-CHEMC</strong>: ~49k handwritten samples (challenging, diverse styles)</li>
<li><strong>Mini-CASIA-CSDB</strong>: ~89k printed samples (from ChEMBL)</li>
<li><strong>Synthetic Complexity Dataset</strong>: A custom split of ChEMBL data grouped by a structural complexity score (a weighted sum of atom, bond, and ring counts) to test generalization</li>
</ul>
</li>
<li><strong>Ablation Studies</strong> (Table 2, on EDU-CHEMC with MSD-DenseWAP): Without MSD or <code>[conn]</code>, EM=38.70%. Adding <code>[conn]</code> alone raised EM to 44.02%. Adding MSD alone raised EM to 52.76%. Both together achieved EM=64.96%, confirming each component&rsquo;s contribution.</li>
</ul>
<h2 id="outcomes-and-conclusions">Outcomes and Conclusions</h2>
<ul>
<li><strong>New best results</strong>: MSD-RCGD achieved 65.39% EM on EDU-CHEMC (handwritten) and 95.23% EM on Mini-CASIA-CSDB (printed), outperforming the RCGD baseline (62.86% and 95.01%, respectively). MSD-DenseWAP surpassed the previous best on EDU-CHEMC by 2.06% EM (64.92% vs. 62.86%).</li>
<li><strong>Universal improvement</strong>: Applying MSD/RFL to DenseWAP improved its accuracy from 61.35% to 64.92% EM on EDU-CHEMC and from 92.09% to 94.10% EM on Mini-CASIA-CSDB, demonstrating the method is model-agnostic.</li>
<li><strong>Complexity handling</strong>: When trained on low-complexity molecules only (levels 1-2), MSD-DenseWAP still recognized higher-complexity unseen structures, while standard DenseWAP could hardly recognize them at all (Figure 6 in the paper).</li>
</ul>
<p>The authors note that this is the first end-to-end solution that decouples and models chemical structures in a structured form. Future work aims to extend structure-based modeling to other tasks such as tables, flowcharts, and diagrams.</p>
<hr>
<h2 id="artifacts">Artifacts</h2>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/JingMog/RFL-MSD">RFL-MSD</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official PyTorch implementation</td>
      </tr>
  </tbody>
</table>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The authors utilized one handwritten and one printed dataset, plus a synthetic set for stress-testing complexity.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td><strong>EDU-CHEMC</strong></td>
          <td>48,998 Train / 2,992 Test</td>
          <td>Handwritten images from educational scenarios</td>
      </tr>
      <tr>
          <td><strong>Training/Test</strong></td>
          <td><strong>Mini-CASIA-CSDB</strong></td>
          <td>89,023 Train / 8,287 Test</td>
          <td>Printed images rendered from ChEMBL using RDKit</td>
      </tr>
      <tr>
          <td><strong>Generalization</strong></td>
          <td><strong>ChEMBL Subset</strong></td>
          <td>5 levels of complexity</td>
          <td>Custom split based on Eq: $N_{atom} + N_{bond} + 12 \times N_{ring}$</td>
      </tr>
  </tbody>
</table>
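<p>The complexity score in the last table row is a simple weighted count. The sketch below evaluates it for two familiar molecules, counting heavy atoms only (an assumption); the thresholds that separate the five complexity levels are not given here and are not reproduced.</p>

```python
def complexity(n_atom, n_bond, n_ring):
    # Score = N_atom + N_bond + 12 * N_ring, as in the table above.
    return n_atom + n_bond + 12 * n_ring

benzene = complexity(n_atom=6, n_bond=6, n_ring=1)
naphthalene = complexity(n_atom=10, n_bond=11, n_ring=2)
print(benzene, naphthalene)
```

<p>The factor of 12 means that each ring contributes as much as a dozen extra atoms or bonds, so ring count dominates the score for polycyclic molecules.</p>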
<h3 id="algorithms">Algorithms</h3>
<p><strong>RFL Splitting (Encoding)</strong>:</p>
<ol>
<li><strong>Detect Rings</strong>: Use DFS to find all non-nested rings $\mathcal{R}$.</li>
<li><strong>Determine Adjacency ($\gamma$)</strong>: Calculate shared edges between rings.</li>
<li><strong>Merge</strong>:
<ul>
<li>If $\gamma(r_i) = 0$ (isolated), merge ring into a <strong>SuperAtom</strong> node.</li>
<li>If $\gamma(r_i) &gt; 0$ (adjacent), merge ring into a <strong>SuperBond</strong> edge.</li>
</ul>
</li>
<li><strong>Update</strong>: Record connection info in $\mathcal{F}$ and remove ring details from the main graph to form Skeleton $\mathcal{S}$.</li>
</ol>
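<p>Steps 2 and 3 above can be sketched once the rings are known. The example assumes rings are given as lists of atom indices in cyclic order (ring detection itself is skipped) and takes $\gamma$ to be the number of edges a ring shares with any other ring, which is one plausible reading of the description rather than the paper&rsquo;s exact definition. The function names are hypothetical.</p>

```python
def shared_edge_count(rings):
    # rings: list of rings, each a list of atom indices in cyclic order.
    edge_sets = [
        {frozenset((ring[i], ring[(i + 1) % len(ring)])) for i in range(len(ring))}
        for ring in rings
    ]
    gamma = []
    for i, edges in enumerate(edge_sets):
        # Union of the edges of every other ring.
        others = set().union(*(e for j, e in enumerate(edge_sets) if j != i))
        gamma.append(len(edges.intersection(others)))
    return gamma

def classify_rings(rings):
    # Isolated rings collapse to SuperAtoms, fused rings to SuperBonds.
    return ["SuperAtom" if g == 0 else "SuperBond" for g in shared_edge_count(rings)]

# Two fused six-membered rings (sharing the bond 4-5) plus one isolated ring.
rings = [
    [0, 1, 2, 3, 4, 5],
    [4, 5, 6, 7, 8, 9],
    [10, 11, 12, 13, 14, 15],
]
gamma = shared_edge_count(rings)
labels = classify_rings(rings)
print(gamma, labels)
```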
<p><strong>MSD Decoding</strong>:</p>
<ul>
<li><strong>Hierarchical Prediction</strong>: The model predicts the Skeleton $\mathcal{S}$ first.</li>
<li><strong>Contextual Ring Prediction</strong>: When a SuperAtom/Bond token is predicted, its hidden state $f^s$ is stored. After the skeleton is finished, $f^s$ is used as a condition to autoregressively decode the specific ring structure.</li>
<li><strong>Token <code>[conn]</code></strong>: A special token separates connected ring bonds from unconnected ones to sparsify the branch classification task.</li>
</ul>
<h3 id="models">Models</h3>
<p>The architecture follows a standard Image-to-Sequence pattern but with a forked decoder.</p>
<ul>
<li><strong>Encoder</strong>: DenseNet (Growth rate=24, Depth=32 per block)</li>
<li><strong>Decoder (MSD)</strong>:
<ul>
<li><strong>Core</strong>: GRU with Attention (Hidden dim=256, Embedding dim=256, Dropout=0.15)</li>
<li><strong>Skeleton Module</strong>: Autoregressively predicts sequence tokens. Uses Maxout activation.</li>
<li><strong>Branch Module</strong>: A binary classifier (MLP) taking concatenated features of skeleton bonds $f_{bs}$ and ring bonds $f_{br}$ to predict connectivity matrix $\mathcal{F}$.</li>
</ul>
</li>
<li><strong>Loss Function</strong>: $\mathcal{O} = \lambda_1 \mathcal{L}_{ce} + \lambda_2 \mathcal{L}_{cls}$ (where $\lambda_1 = \lambda_2 = 1$)</li>
</ul>
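<p>The combined objective can be sketched in NumPy. Here $\mathcal{L}_{ce}$ is taken to be token-level cross-entropy and $\mathcal{L}_{cls}$ binary cross-entropy on the connectivity predictions; the latter is an assumption based on the branch module being described as a binary classifier. All function and argument names are hypothetical.</p>

```python
import numpy as np

def cross_entropy(logits, targets):
    # Numerically stable log-softmax followed by the mean negative log-likelihood.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def binary_cross_entropy(probs, labels, eps=1e-12):
    probs = np.clip(probs, eps, 1.0 - eps)
    return -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)).mean()

def total_loss(token_logits, token_targets, conn_probs, conn_labels, lam1=1.0, lam2=1.0):
    # O = lambda_1 * L_ce + lambda_2 * L_cls, with both weights set to 1 by default.
    return lam1 * cross_entropy(token_logits, token_targets) + lam2 * binary_cross_entropy(
        conn_probs, conn_labels
    )

# Uninformative predictions: uniform over 4 tokens, probability 0.5 per connection.
token_logits = np.zeros((3, 4))
token_targets = np.array([0, 2, 3])
conn_probs = np.full(5, 0.5)
conn_labels = np.array([1, 0, 0, 1, 0])
loss = total_loss(token_logits, token_targets, conn_probs, conn_labels)
print(loss)
```

<p>With uniform token predictions and 0.5 connection probabilities, the loss reduces to $\log 4 + \log 2$, which is a convenient sanity check for an implementation.</p>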
<h3 id="evaluation">Evaluation</h3>
<p>Metrics measure whether the predicted molecular graph exactly matches the ground truth.</p>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Description</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>EM (Exact Match)</strong></td>
          <td>% of images where predicted graph exactly matches ground truth.</td>
          <td>Primary metric</td>
      </tr>
      <tr>
          <td><strong>Struct-EM</strong></td>
          <td>% of correctly identified chemical structures (ignoring non-chemical text).</td>
          <td>Auxiliary metric</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: 4 x NVIDIA Tesla V100 (32 GB GPU memory)</li>
<li><strong>Training Configuration</strong>:
<ul>
<li>Batch size: 8 (Handwritten), 32 (Printed)</li>
<li>Epochs: 50</li>
<li>Optimizer: Adam ($lr=2\times10^{-4}$, decayed by 0.5 via MultiStepLR)</li>
</ul>
</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Chang, Q., Chen, M., Pi, C., Hu, P., Zhang, Z., Ma, J., Du, J., Yin, B., &amp; Hu, J. (2025). RFL: Simplifying Chemical Structure Recognition with Ring-Free Language. In <em>Proceedings of the AAAI Conference on Artificial Intelligence</em>, 39(2), 2007-2015. <a href="https://doi.org/10.1609/aaai.v39i2.32197">https://doi.org/10.1609/aaai.v39i2.32197</a></p>
<p><strong>Publication</strong>: AAAI 2025 (Oral)</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/JingMog/RFL-MSD">Official Code Repository</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{changRFLSimplifyingChemical2025,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{RFL: Simplifying Chemical Structure Recognition with Ring-Free Language}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">shorttitle</span> = <span style="color:#e6db74">{RFL}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Chang, Qikai and Chen, Mingjun and Pi, Changpeng and Hu, Pengfei and Zhang, Zhenrong and Ma, Jiefeng and Du, Jun and Yin, Baocai and Hu, Jinshui}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2025}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the AAAI Conference on Artificial Intelligence}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{39}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{2}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{2007--2015}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span> = <span style="color:#e6db74">{2412.07594}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span> = <span style="color:#e6db74">{cs}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1609/aaai.v39i2.32197}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span> = <span style="color:#e6db74">{arXiv}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Optimizing Sequence Models for Dynamical Systems</title><link>https://hunterheidenreich.com/research/deconstructing-recurrence-attention-gating/</link><pubDate>Tue, 01 Oct 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/research/deconstructing-recurrence-attention-gating/</guid><description>Ablation study deconstructing sequence models. Attention-augmented Recurrent Highway Networks outperform Transformers on chaotic systems.</description><content:encoded><![CDATA[<h2 id="abstract">Abstract</h2>
<p>Advanced neural network architectures developed for tasks like natural language processing are often transferred to spatiotemporal forecasting without a deep understanding of which components drive their performance. This can lead to suboptimal results and reinforces the view of these models as &ldquo;black boxes&rdquo;. In this work, we deconstruct the core mechanisms of Transformers and Recurrent Neural Networks (RNNs): attention, gating, and recurrence. We then build and test novel hybrid architectures to identify which components are most effective. A key finding is that while adding recurrence is detrimental to Transformers, augmenting RNNs with attention and neural gating consistently improves their forecasting accuracy. Our study reveals that a seldom-used architecture, the Recurrent Highway Network (RHN) enhanced with these mechanisms, emerges as the top-performing model for forecasting high-dimensional chaotic systems.</p>
<h2 id="key-contributions">Key Contributions</h2>
<ul>
<li><strong>Systematic Ablation</strong>: Deconstructed Transformers and RNNs into core mechanisms (attention, gating, recurrence) to isolate performance drivers</li>
<li><strong>Novel Hybrid Architectures</strong>: Synthesized and tested new combinations of neural primitives for spatiotemporal forecasting</li>
<li><strong>RHN advantage on chaotic systems</strong>: Demonstrated that attention-augmented Recurrent Highway Networks outperform standard Transformers on high-dimensional chaotic systems</li>
<li><strong>Robustness Analysis</strong>: Validated models across both clean physics simulations and noisy real-world industrial datasets</li>
</ul>
<h2 id="motivation">Motivation</h2>
<p>In modern ML, architectures are often transferred from one domain (like NLP) to another (like physical forecasting) without understanding the underlying mechanics. This &ldquo;black box&rdquo; approach leads to suboptimal compute usage and performance ceilings.</p>
<p>Our goal was to break these architectures down. We treated the core mechanisms of <strong>Transformers</strong> and <strong>RNNs</strong> (<strong>Gating, Attention, and Recurrence</strong>) as orthogonal basis vectors. By decoupling these components, we could synthesize and test hybrid architectures to find the best configuration for spatiotemporal forecasting.</p>
<h2 id="methodological-approach">Methodological Approach</h2>
<p>We built a modular framework to mix and match neural primitives. We systematically evaluated:</p>
<ol>
<li><strong>Gating Mechanisms:</strong> Testing Additive, Learned Rate, Input-Dependent, and Coupled Input-Dependent variants</li>
<li><strong>Attention:</strong> Implementing multi-headed attention with relative positional biases</li>
<li><strong>Recurrence:</strong> Testing standard cells (LSTM, GRU) against deeper transition cells like Recurrent Highway Networks (RHN)</li>
</ol>















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/neural-gates.webp"
         alt="Neural gating mechanisms: Additive, Learned Rate, Dependent-Coupled, and Dependent variants"
         title="Neural gating mechanisms: Additive, Learned Rate, Dependent-Coupled, and Dependent variants"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The hierarchy of neural gating mechanisms we tested, from simple additive to fully input-dependent.</figcaption>
    
</figure>
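<p>As a rough illustration of the four gating variants named above, the sketch below combines a residual-stream value <code>x</code> with a sublayer output <code>fx</code>. The exact parameterizations are given in the paper. The forms here (learned scalar rates, sigmoid gates computed from <code>x</code>, and a coupled gate whose two coefficients sum to one) are assumptions chosen to match the variant names, and every function and parameter name is hypothetical. Scalars stand in for vectors to keep the example short.</p>

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def additive_gate(x: float, fx: float) -> float:
    # Plain residual connection: no learned gate parameters.
    return x + fx


def learned_rate_gate(x: float, fx: float, alpha: float, beta: float) -> float:
    # Input-independent learned scalars rescale each branch.
    return alpha * x + beta * fx


def input_dependent_gate(x: float, fx: float, w_carry: float, w_update: float) -> float:
    # Two separate gates, each computed from the input itself.
    carry = sigmoid(w_carry * x)
    update = sigmoid(w_update * x)
    return carry * x + update * fx


def coupled_input_dependent_gate(x: float, fx: float, w_gate: float) -> float:
    # One gate g; the carry coefficient is tied to (1 - g).
    g = sigmoid(w_gate * x)
    return (1.0 - g) * x + g * fx
```

<p>Under these assumptions, the coupled variant always returns a convex combination of <code>x</code> and <code>fx</code>, while the uncoupled variants can rescale the two branches independently.</p>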
















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/rnn-cell-types.webp"
         alt="RNN cell architectures: Elman, LSTM, GRU, and RHN cells"
         title="RNN cell architectures: Elman, LSTM, GRU, and RHN cells"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Recurrent cell types compared in our study. The RHN (d) extends processing depth within each timestep.</figcaption>
    
</figure>
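<p>To make &ldquo;processing depth within each timestep&rdquo; concrete, here is a minimal scalar sketch of one Recurrent Highway Network step in the coupled-gate form described by Zilly et al. (carry gate equal to one minus the transform gate). It is a toy under stated assumptions and not necessarily the exact cell used in our experiments: the state is a single float, and the weight names and values are hypothetical.</p>

```python
import math


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def rhn_step(x: float, s_prev: float, params: list) -> float:
    """One timestep of a scalar RHN; len(params) is the recurrence depth."""
    s = s_prev
    for layer, p in enumerate(params):
        # The external input only enters the first highway layer.
        x_in = x if layer == 0 else 0.0
        h = math.tanh(p["w_h"] * x_in + p["r_h"] * s + p["b_h"])  # candidate state
        t = sigmoid(p["w_t"] * x_in + p["r_t"] * s + p["b_t"])    # transform gate
        s = t * h + (1.0 - t) * s                                  # carry gate = 1 - t
    return s


# Hypothetical weights for a depth-3 cell, applied to a short input sequence.
params = [{"w_h": 0.8, "r_h": 0.5, "b_h": 0.0,
           "w_t": 0.3, "r_t": 0.2, "b_t": -1.0} for _ in range(3)]
state = 0.0
for x_t in [0.1, -0.4, 0.7]:
    state = rhn_step(x_t, state, params)
```

<p>An LSTM or GRU applies one gated update per timestep. In this sketch the inner loop applies several, which is the extra depth the figure refers to.</p>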

<p>This ablation isolated exactly <em>which</em> mathematical operation was driving the performance gain.</p>
<h2 id="key-findings">Key Findings</h2>
<h3 id="recurrent-highway-networks-on-chaotic-systems">Recurrent Highway Networks on Chaotic Systems</h3>
<p>For high-dimensional chaotic systems like the Multiscale Lorenz-96 shown below, we found that a <strong>Recurrent Highway Network (RHN)</strong> augmented with <strong>Attention and Neural Gating</strong> was the top-performing architecture. This hybrid exceeded the forecasting accuracy of standard Transformers, suggesting that deeper recurrence (processing depth per timestep) matters for complex dynamics.</p>















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/multiscale-lorenz.webp"
         alt="Forecasting comparison on Multiscale Lorenz-96 system"
         title="Forecasting comparison on Multiscale Lorenz-96 system"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Forecasting the Multiscale Lorenz-96 system. The top row shows the &lsquo;texture&rsquo; of the chaotic evolution. Notice how the RHN (far right) maintains the coherent wave-like structures for nearly a full Lyapunov time, holding structure longer than the Transformer variants (the plotted window spans two Lyapunov times).</figcaption>
    
</figure>

<h3 id="transformers-recurrence-hurts-gating-helps">Transformers: Recurrence Hurts, Gating Helps</h3>
<p>We attempted to force recurrence into Transformers to give them &ldquo;memory,&rdquo; but it consistently hurt performance. However, <strong>Neural Gating</strong> significantly improved Transformer robustness. For real-world, noisy data (traffic, weather), the <strong>Pre-Layer Normalization (PreLN) Transformer</strong> with added gating proved to be the most robust model.</p>
<h3 id="adding-attention-to-lstms-and-grus">Adding Attention to LSTMs and GRUs</h3>
<p>We tested on the Kuramoto-Sivashinsky (K-S) equation, a model of turbulence and flame fronts. Standard LSTMs and GRUs turn out to be under-optimized for this setting: adding <strong>attention</strong> to these cells improved their valid-prediction time several-fold, with the best attention-augmented LSTM and GRU reaching roughly 4x and 6.6x their baseline valid-prediction time, respectively (the paper reports the top RNNs at 2-7x baseline on K-S). On the partially observed Multiscale Lorenz-96 system, the same attention-plus-gating gain is smaller, at more than 40%.</p>
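<p>This page does not define valid-prediction time. One common definition in the chaotic-forecasting literature is the elapsed time before the normalized RMSE of a forecast first exceeds a threshold, expressed in Lyapunov times. The sketch below follows that definition. The threshold of 0.5, the function name, and the toy data are all assumptions for illustration, not values taken from the paper.</p>

```python
import math


def valid_prediction_time(pred, truth, std, dt, lyapunov_time, threshold=0.5):
    """Time (in Lyapunov times) before the normalized RMSE first exceeds threshold.

    pred, truth: one list of floats per timestep.
    std: per-dimension standard deviation of the true signal, used to normalize.
    """
    for step, (p, y) in enumerate(zip(pred, truth)):
        sq = [((pi - yi) / si) ** 2 for pi, yi, si in zip(p, y, std)]
        nrmse = math.sqrt(sum(sq) / len(sq))
        if nrmse > threshold:
            return step * dt / lyapunov_time
    # The forecast never crossed the threshold inside the window.
    return len(truth) * dt / lyapunov_time


# Toy usage: the forecast drifts away from the truth by 0.2 per step.
truth = [[0.0, 0.0] for _ in range(10)]
pred = [[0.2 * k, 0.2 * k] for k in range(10)]
vpt = valid_prediction_time(pred, truth, std=[1.0, 1.0], dt=0.5, lyapunov_time=2.0)
```

<p>A statement such as &ldquo;4x baseline&rdquo; is then a ratio of two such times, computed on the same test trajectories.</p>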















<figure class="post-figure center ">
    <img src="/img/deconstructing-sequence-prediction/kuramoto-sivashinksy.webp"
         alt="Forecasting comparison on Kuramoto-Sivashinsky system"
         title="Forecasting comparison on Kuramoto-Sivashinsky system"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Forecasting the Kuramoto-Sivashinsky system. The error heatmaps (bottom row) show how prediction quality degrades over time (lighter means larger error). The RHN maintains structural fidelity longer than competing architectures.</figcaption>
    
</figure>

<h3 id="robustness-on-real-world-datasets">Robustness on Real-World Datasets</h3>
<p>While chaotic systems test the limits of theory, we also validated our models on seven standard real-world datasets: the four <strong>Electricity Transformer Temperature (ETT)</strong> subsets plus <strong>Traffic</strong>, <strong>Electricity</strong>, and <strong>Weather</strong>.</p>
<p>Unlike the clean physics simulations, these datasets contain real-world noise and irregularities. In this environment, the <strong>Pre-Layer Normalization (PreLN) Transformer</strong> proved to be the most robust architecture. While it didn&rsquo;t always beat the RHN on pure chaos, its stability makes it a strong default choice for general time-series forecasting tasks where training speed and reliability are paramount.</p>
<h2 id="why-this-matters">Why This Matters</h2>
<p>This work treats architectural components as independently tunable choices rather than fixed defaults, and that framing surfaces a concrete trade-off. Transformers train in only 25-50% of the time the RNNs require (roughly 2-4x faster), while the attention-augmented RNNs give better inference accuracy on the chaotic physical systems. Which mechanism to select depends on whether the training budget or the forecast precision is the binding constraint, and the ablation makes that an informed choice rather than a default one.</p>
<p>The ablation framework here, treating architectural components as independently tunable factors and measuring their marginal contribution, shaped how later evaluation work is structured. The same principle of isolating variables rather than comparing end-to-end black boxes appears in the document processing research, from benchmark construction in page stream segmentation to grounded evaluation in GutenOCR.</p>
<h2 id="related-work">Related Work</h2>
<p>The methodology here shares a design philosophy with <a href="/research/eigennoise-contrastive-prior/">EigenNoise</a>,
which similarly decomposes a neural mechanism (word vector initialization) into theoretically
grounded components to isolate what drives performance. Both papers treat model components as
testable hypotheses rather than fixed architectural choices.</p>
<p>For broader context on where this fits in the portfolio&rsquo;s Scientific Machine Learning arc,
see the <a href="/research/">Research</a> overview.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{heidenreich2024deconstructingrecurrenceattentiongating,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Deconstructing Recurrence, Attention, and Gating: Investigating the transferability of Transformers and Gated Recurrent Neural Networks in forecasting of dynamical systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hunter S. Heidenreich and Pantelis R. Vlachas and Petros Koumoutsakos}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2024}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2410.02654}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{cs.LG}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2410.02654}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are used for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Filling in missing or corrupted data (e.g., inpainting or denoising images).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (think: optimization).</p>
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total Loss (negative ELBO, with the KL term scaled by kl_weight)</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(z, mu, std, x_recon, loss, recon_loss, kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has a position, a velocity, and interactions with its neighbors, all of which are computationally intractable to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. The temperature and pressure are &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. This probabilistic step is what pushes the latent space to be continuous and structured: because each input maps to a region of latent space instead of a single point, similar inputs land in nearby, overlapping regions, which enables smooth interpolation and generation.</p>
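<p>The sampling step can be sketched without any deep learning library. The snippet below mirrors the <code>reparameterize</code> method from the TL;DR code in plain Python: it draws $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$ with $\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. The values of <code>mu</code> and <code>sigma</code> are made-up stand-ins for encoder outputs, and <code>sample_latent</code> is a hypothetical helper name.</p>

```python
import random


def sample_latent(mu, sigma, rng):
    """Draw z ~ N(mu, diag(sigma^2)) as z = mu + sigma * eps, eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


rng = random.Random(0)          # fixed seed so the sketch is repeatable
mu = [0.5, -1.0]                # hypothetical encoder means
sigma = [0.1, 2.0]              # hypothetical encoder standard deviations

samples = [sample_latent(mu, sigma, rng) for _ in range(20000)]
mean_z0 = sum(z[0] for z in samples) / len(samples)
mean_z1 = sum(z[1] for z in samples) / len(samples)
```

<p>Averaged over many draws, the samples center on <code>mu</code>, with spread set by <code>sigma</code>. Writing the draw as a deterministic function of <code>mu</code>, <code>sigma</code>, and external noise is what lets gradients reach the encoder in the PyTorch version.</p>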
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. During training, the goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful for searching for data points with desired properties, like in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because it ranges over a high-dimensional latent space and the likelihood $p_{\theta}(\mathbf{x} | \mathbf{z})$ is defined by a neural network. In general, no closed-form solution exists, and numerical integration is computationally prohibitive.</p>
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It forces the VAE to be good at its job. The process goes:</p>
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}&rsquo;$ from $\mathbf{z}$, $\mathbf{x}&rsquo; \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}&rsquo;$.</li>
</ol>
<p>The reconstruction loss measures the difference between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE).</li>
<li><strong>For inputs in $[0, 1]$</strong> (like MNIST pixel intensities, which after <code>ToTensor()</code> are continuous values in $[0, 1]$), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli variable whose target is its intensity. The decoder outputs the <em>logits</em> for each pixel, and the BCE-with-logits loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
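<p>To make the BCE case concrete, here is a minimal sketch, assuming a hypothetical batch of four six-pixel &ldquo;images&rdquo; and random stand-in decoder logits (neither comes from the real model). It checks that the logits-based loss matches the Bernoulli negative log-likelihood written out by hand.</p>

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a batch of 4 "images" with 6 pixels each.
x = torch.rand(4, 6)          # targets: continuous intensities in [0, 1]
x_logits = torch.randn(4, 6)  # what a decoder would output before any sigmoid

# Numerically stable form: operates on logits, summed over pixels per image.
nll_stable = F.binary_cross_entropy_with_logits(
    x_logits, x, reduction="none"
).sum(dim=-1)

# Same quantity written out as the Bernoulli negative log-likelihood.
p = torch.sigmoid(x_logits)
nll_manual = -(x * torch.log(p) + (1 - x) * torch.log(1 - p)).sum(dim=-1)

print(nll_stable)
print(nll_manual)
```

<p>The two tensors should agree up to floating-point error. The manual version can overflow or return infinities when the logits are large in magnitude, which is why the logits-based call is usually preferred.</p>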
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We force the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to be close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. This prior is almost always a standard normal distribution because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
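<p>For a diagonal Gaussian encoder and a standard normal prior, this KL term has a well-known closed form, $\frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right)$. The sketch below assumes hypothetical encoder outputs <code>mu</code> and <code>std</code> (random values, not from a trained model) and cross-checks the formula against <code>torch.distributions.kl_divergence</code>.</p>

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 3 inputs and a 2-D latent space.
mu = torch.randn(3, 2)
std = torch.rand(3, 2) + 0.1  # strictly positive standard deviations

# Closed form for KL(N(mu, std^2) || N(0, 1)), summed over latent dimensions.
kl_closed = 0.5 * (mu.pow(2) + std.pow(2) - 1.0 - 2.0 * torch.log(std)).sum(dim=-1)

# Cross-check against PyTorch's built-in Normal/Normal KL.
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
kl_library = kl_divergence(Normal(mu, std), prior).sum(dim=-1)

# The penalty vanishes when the encoder outputs exactly the prior.
standard = Normal(torch.zeros(2), torch.ones(2))
kl_at_prior = kl_divergence(standard, standard).sum()

print(kl_closed, kl_library, kl_at_prior)
```

<p>The penalty is zero only when $\mathbf{\mu} = \mathbf{0}$ and $\mathbf{\sigma} = \mathbf{1}$, and grows as the encoder&rsquo;s distribution moves away from the prior in either mean or spread.</p>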
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}&rsquo;$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. We draw the randomness from a fixed noise distribution and then compute $\mathbf{z}$ as a deterministic function of that noise and the encoder&rsquo;s outputs:</p>
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
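<p>A small sketch of this gradient flow, assuming hand-picked values for $\mathbf{\mu}$ and $\mathbf{\sigma}$ and a toy squared-norm loss standing in for the decoder and reconstruction term (both are illustrative choices, not part of the actual model):</p>

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs, tracked by autograd.
mu = torch.tensor([0.5, -1.0], requires_grad=True)
std = torch.tensor([1.0, 2.0], requires_grad=True)

# All randomness lives in epsilon, which does not depend on mu or std.
epsilon = torch.randn_like(std)
z = mu + std * epsilon

# Toy scalar loss standing in for "decoder + reconstruction loss".
loss = z.pow(2).sum()
loss.backward()

# Chain rule: dz/dmu = 1 and dz/dstd = epsilon.
expected_mu_grad = 2 * z.detach()
expected_std_grad = 2 * z.detach() * epsilon

print(mu.grad, expected_mu_grad)
print(std.grad, expected_std_grad)
```

<p>Both parameters receive gradients that match the chain rule. In <code>torch.distributions</code>, <code>Normal(mu, std).rsample()</code> applies the same reparameterization, while <code>.sample()</code> draws without propagating gradients.</p>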
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. This is non-negative, and unfortunately we cannot compute it.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
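<p>The bound can be checked numerically on a toy model where everything is tractable. The sketch below assumes a one-dimensional model chosen purely for illustration: prior $z \sim \mathcal{N}(0, 1)$ and likelihood $x \mid z \sim \mathcal{N}(z, 1)$, so that $x \sim \mathcal{N}(0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. The function names are hypothetical.</p>

```python
import math


def log_marginal(x):
    """log p(x) for z ~ N(0, 1) and x | z ~ N(z, 1), which gives x ~ N(0, 2)."""
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0


def elbo(x, m, s):
    """ELBO for the approximation q(z | x) = N(m, s^2) in the same toy model."""
    # E_q[log N(x; z, 1)] has a closed form because z is Gaussian under q.
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s**2)
    # KL(N(m, s^2) || N(0, 1)).
    kl = 0.5 * (m**2 + s**2 - 1.0 - 2.0 * math.log(s))
    return expected_log_lik - kl


x = 1.0
# A poor approximation: q is just the prior, ignoring x entirely.
gap_prior = log_marginal(x) - elbo(x, m=0.0, s=1.0)
# The best approximation: q equals the true posterior N(x / 2, 1 / 2).
gap_posterior = log_marginal(x) - elbo(x, m=x / 2, s=math.sqrt(0.5))

print(gap_prior, gap_posterior)
```

<p>The gap between $\log p_{\theta}(\mathbf{x})$ and the ELBO is exactly the KL divergence to the true posterior: it is positive for the poor approximation and shrinks to zero (up to floating-point error) when $q$ matches the posterior.</p>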
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (which is the same as maximizing the expected log-likelihood of the data under the decoder).</li>
<li>It <strong>minimizes the KL Divergence</strong>, forcing the encoder to match the prior.</li>
</ol>
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
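<p>My VAE implementation uses two dataclasses (<code>VAEConfig</code> for the model configuration and <code>VAEOutput</code> for the forward-pass results) and a <code>VAE</code> class extending <code>nn.Module</code>.</p>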
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, what was used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to also return sigmoid(x_logits) as x_recon
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is just the raw output, need to use std directly: σ</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over their respective dimensions (pixels for reconstruction, latent dimensions for KL) and averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over every element in the tensor.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalization term. Each dimension of the latent space has the potential to add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data, and it should scale with input dimensionality (this can bias the model toward better reconstruction for higher-dimensional data).</li>
</ul>
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it tiny compared to the KL loss. The KL term would dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
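<p>To make the scaling concrete, here is a minimal sketch in plain Python (no PyTorch, so the arithmetic stays visible). The helper <code>bce_with_logits</code>, the random stand-in data, and the batch size of 8 are assumptions for illustration, not part of the model code above. It compares the sum-over-pixels, mean-over-batch convention with a mean over every element.</p>

```python
import math
import random


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable per-element binary cross-entropy on a logit."""
    # Equivalent to -[t * log(sig(l)) + (1 - t) * log(1 - sig(l))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


random.seed(0)
batch_size, n_pixels = 8, 784  # MNIST-sized inputs; the batch size is arbitrary

# Stand-in data: random binary "pixels" and random decoder logits.
targets = [
    [float(random.random() < 0.5) for _ in range(n_pixels)]
    for _ in range(batch_size)
]
logits = [
    [random.gauss(0.0, 1.0) for _ in range(n_pixels)]
    for _ in range(batch_size)
]

per_element = [
    bce_with_logits(l, t)
    for row_l, row_t in zip(logits, targets)
    for l, t in zip(row_l, row_t)
]

# Convention recommended in this post: sum over pixels, average over the batch.
recon_sum_over_pixels = sum(per_element) / batch_size
# What reduction="mean" computes: average over every element in the tensor.
recon_mean_over_elements = sum(per_element) / (batch_size * n_pixels)

ratio = recon_sum_over_pixels / recon_mean_over_elements
print(f"sum over pixels / batch: {recon_sum_over_pixels:.2f}")
print(f"mean over all elements:  {recon_mean_over_elements:.4f}")
print(f"ratio:                   {ratio:.1f}")
```

<p>The two values differ by exactly the number of pixels (784 for MNIST), which is the factor by which <code>reduction=&quot;mean&quot;</code> would shrink the reconstruction term relative to the KL term.</p>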
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This happens when the KL term dominates the loss early in training. The model quickly learns to just output the prior distribution ($q(z|x) \approx p(z)$) to drive the KL loss to zero, so the latent code $z$ carries no information about the input. The decoder then behaves like an unconditional generative model that ignores $z$.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This allows the model to focus purely on reconstruction first (using the full latent capacity), and then slowly adds the regularization pressure.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
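</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> total_steps <span style="color:#f92672">&lt;=</span> <span style="color:#ae81ff">0</span>:  <span style="color:#75715e"># no warmup requested: use the full weight</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> max_val
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val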
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop:</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> kl_loss
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since $\mu \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). This type of constrained optimization is challenging for neural networks.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
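<p>The KL divergence formula simplifies nicely with this parameterization: substituting $\sigma_i^2 = e^{s_i}$ turns the log term into $s_i$ itself.</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + e^{s_i} - 1 - s_i)
$$</p>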
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
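<p>As a rough illustration of why the activation choice and the epsilon both matter, the sketch below compares <code>ReLU</code> and <code>softplus</code> in plain Python. The function names and the value of <code>EPS</code> are assumptions for this example (the actual value of <code>DEFAULT_EPS</code> is not shown in this post).</p>

```python
import math

EPS = 1e-6  # hypothetical stand-in for DEFAULT_EPS


def relu(s: float) -> float:
    return max(s, 0.0)


def softplus(s: float) -> float:
    """Numerically stable log(1 + exp(s))."""
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


for s in (-800.0, -20.0, -2.0, 0.0, 2.0):
    print(f"s={s:7.1f}  relu={relu(s):.3e}  "
          f"softplus={softplus(s):.3e}  softplus+eps={softplus(s) + EPS:.3e}")

# ReLU is exactly zero for every non-positive input, so sigma = 0 and
# log(sigma^2) is undefined.
# softplus is positive in exact arithmetic, but in floating point it
# underflows to 0.0 for very negative inputs; the epsilon keeps sigma > 0.
```

<p>The underflow threshold depends on the floating-point precision in use, so the input at which <code>softplus</code> reaches exactly zero will differ between this double-precision sketch and a single-precision training run.</p>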
<p>The KL divergence formula becomes slightly more complex:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces mapping to $(0, \infty)$ with mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, at the cost of limiting the expressiveness of the latent distribution. As with <code>softplus</code>, adding a small epsilon ensures numerical stability by keeping $\sigma$ away from exactly zero.</p>
<p><strong>Gradient Behavior</strong>
All three parameterizations can work well in practice, but they have different gradient behaviors. Think of $g(s)$ as the transformation from the raw network output $s$ to $\sigma$: in the log-variance case, $g(s) = \exp(s/2)$ (equivalently, $\sigma^2 = \exp(s)$); in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$; and in the bounded case, $g(s) = b \cdot \text{sig}(s) + \epsilon$ for a maximum value $b$.</p>
<p>The gradient of the loss with respect to these outputs can be written using the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial \sigma} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with an exponential transformation whose derivative is proportional to itself, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. Gradients are always bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the derivative approaches zero, leading to vanishing gradients. Adding a small epsilon keeps $\sigma$ bounded away from zero, which protects the $\log \sigma^2$ term numerically, but it does not change the derivative, so learning can still slow down.</p>
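<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. (The derivative of <code>sigmoid</code> can be written in terms of itself: $\text{sig}'(x) = \text{sig}(x)(1 - \text{sig}(x))$, with a maximum of $0.25$ at $x=0$, so the scaled transformation has a maximum slope of $0.25 \times \text{bound}$.) The derivative vanishes in both tails, so learning can slow down when $\sigma$ sits near zero or near the bound.</p>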
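<p>The three derivatives discussed above can be evaluated side by side. This is a small sketch in plain Python; the function names are made up for this example, and the bound of 10 is chosen to match the bounded configuration used in the experiments below.</p>

```python
import math


def sigmoid(s: float) -> float:
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def d_log_variance(s: float) -> float:
    # sigma = exp(s / 2)  ->  d(sigma)/ds = 0.5 * exp(s / 2)
    return 0.5 * math.exp(0.5 * s)


def d_softplus(s: float) -> float:
    # sigma = softplus(s) + eps  ->  d(sigma)/ds = sigmoid(s)
    return sigmoid(s)


def d_bounded(s: float, bound: float = 10.0) -> float:
    # sigma = bound * sigmoid(s) + eps  ->  d(sigma)/ds = bound * sig * (1 - sig)
    p = sigmoid(s)
    return bound * p * (1.0 - p)


print(f"{'s':>6} {'log-var':>12} {'softplus':>12} {'bounded':>12}")
for s in (-10.0, -5.0, 0.0, 5.0, 10.0):
    print(f"{s:6.1f} {d_log_variance(s):12.4e} "
          f"{d_softplus(s):12.4e} {d_bounded(s):12.4e}")
```

<p>The log-variance column grows without limit as $s$ increases, the softplus column saturates below 1, and the bounded column peaks at $s = 0$ and decays in both directions.</p>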

<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying a similar amount of weight decay, though AdamW applies it as a decoupled step separate from the gradient update, so the effect is not identical</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
<p>This results in a network with 807,700 parameters.
I train each model for 150 epochs at most and highlight the best based on the reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
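<p>As a sanity check on the parameter count, the sketch below tallies weights and biases for one plausible layout: a single 512-unit hidden layer in both the encoder and the decoder, with separate linear heads for the mean and the sigma parameter. That layout is an assumption on my part for this example, though it does reproduce the figure quoted above.</p>

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


input_dim, hidden_dim, latent_dim = 28 * 28, 512, 2

encoder_hidden = linear_params(input_dim, hidden_dim)
mu_head = linear_params(hidden_dim, latent_dim)
sigma_head = linear_params(hidden_dim, latent_dim)
decoder_hidden = linear_params(latent_dim, hidden_dim)
decoder_output = linear_params(hidden_dim, input_dim)

total = encoder_hidden + mu_head + sigma_head + decoder_hidden + decoder_output
print(f"{total:,}")  # 807,700
```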
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>LR Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>

<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>

<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>

<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> Performance skyrockets in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly as the model settles on a trade-off between reconstruction and regularization. Notice that the test and train KL curves track each other closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Summing them gives the negative ELBO, which is the loss we actually train on. This is common and effective, though a single scalar can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
<p>The training process is a search for the best compromise.</p>
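<p>To make the two extremes concrete, here is a minimal, framework-free sketch of both terms for a single example. It assumes a diagonal Gaussian posterior and Bernoulli pixel likelihoods, which is what the BCE and KL terms above imply. The function names <code>gaussian_kl</code> and <code>bernoulli_bce</code> and all input values are illustrative, not taken from the tutorial&rsquo;s code.</p>

```python
import math


def gaussian_kl(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over latent dimensions."""
    return sum(0.5 * (s * s + m * m - 1.0) - math.log(s) for m, s in zip(mu, sigma))


def bernoulli_bce(x, p, eps=1e-7):
    """Binary cross-entropy between pixels x and decoder probabilities p, summed over pixels."""
    total = 0.0
    for x_i, p_i in zip(x, p):
        p_i = min(max(p_i, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= x_i * math.log(p_i) + (1.0 - x_i) * math.log(1.0 - p_i)
    return total


# Extreme 1: the encoder ignores the input and outputs the prior, so KL is zero.
kl_prior = gaussian_kl(mu=[0.0, 0.0], sigma=[1.0, 1.0])

# Extreme 2: a "pointy" posterior far from the origin pays a large KL cost.
kl_pointy = gaussian_kl(mu=[2.0, -1.5], sigma=[0.01, 0.01])

# Hypothetical 4-pixel image: a confident, correct decoder vs. an "average" one.
pixels = [1.0, 0.0, 1.0, 0.0]
bce_sharp = bernoulli_bce(pixels, [0.99, 0.01, 0.99, 0.01])
bce_average = bernoulli_bce(pixels, [0.5, 0.5, 0.5, 0.5])

print(f"KL prior={kl_prior:.3f}, KL pointy={kl_pointy:.3f}")
print(f"BCE sharp={bce_sharp:.3f}, BCE average={bce_average:.3f}")
```

<p>The sharp decoder wins on BCE, but only a posterior that carries information (and therefore pays KL) can tell the decoder which image to draw. The &ldquo;average&rdquo; decoder is what remains when the KL is forced to zero.</p>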















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also untrained. Its output distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ barely depend on the input yet, and on average they happen to sit close to the prior $p_{\theta}(\mathbf{z})$, so the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey toward, and then along, the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
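<p>The notion of a Pareto frontier can be checked mechanically. The sketch below keeps only the checkpoints that no other checkpoint beats on both objectives at once. The helper <code>pareto_front</code> and the <code>checkpoints</code> values are made up for illustration; they are not numbers from the experiments above.</p>

```python
def pareto_front(points):
    """Return the (kl, bce) pairs that no other pair dominates on both objectives."""
    front = []
    for kl_i, bce_i in points:
        # A point is dominated if some different point is at least as good on both axes.
        dominated = any(
            kl_i >= kl_j and bce_i >= bce_j and (kl_j, bce_j) != (kl_i, bce_i)
            for kl_j, bce_j in points
        )
        if not dominated:
            front.append((kl_i, bce_i))
    return sorted(front)


# Made-up (KL, BCE) checkpoints for illustration only; not values from the experiments.
checkpoints = [
    (0.5, 540.0),
    (1.0, 200.0),
    (3.0, 170.0),
    (5.0, 150.0),
    (6.5, 141.0),
    (7.0, 142.0),
]
front = pareto_front(checkpoints)
print(front)
```

<p>In this toy list the last checkpoint is dropped because an earlier one has both lower KL and lower BCE. Note that the lowest-KL point always survives, even with terrible reconstruction, which is why the frontier alone does not pick a model: the ELBO does.</p>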
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, 1)</code> and decode those samples, we get a variety of digit-like images. Almost all digits appear in this random sample, so the prior region covers a rich range of digit shapes. Again, we see the characteristic blurriness.</p>
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two data points at random (here, two zeros), embed them into our latent space, and then walk across that space to interpolate between them.
Here, the walk takes us from a zero that is askew to one that is more upright.</p>
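<p>A minimal sketch of such a walk, assuming straight-line (linear) interpolation between the two encoded means. The post does not state which interpolation scheme was used, and <code>interpolate_latents</code>, <code>z_a</code>, and <code>z_b</code> are hypothetical names and values. Each point on the path would then be passed through the trained decoder, which is omitted here.</p>

```python
def interpolate_latents(z_start, z_end, steps=8):
    """Evenly spaced points on the straight line from z_start to z_end, endpoints included."""
    steps = max(steps, 2)  # need at least the two endpoints
    path = []
    for i in range(steps):
        t = i / (steps - 1)
        path.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path


# Hypothetical 2D latent means for two encoded zeros.
z_a = [-1.2, 0.4]
z_b = [0.3, 1.1]
path = interpolate_latents(z_a, z_b, steps=5)
for z in path:
    print([round(v, 3) for v in z])
```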















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
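<p>The sweep itself is simple to construct: vary one coordinate and pin the rest. The sketch below assumes a range of -3 to 3 (about three prior standard deviations) and a base value of 0 for the held coordinate. Both are assumptions, as is the helper name <code>latent_sweep</code>; the figure above may use different settings.</p>

```python
def latent_sweep(dim, num_latent=2, low=-3.0, high=3.0, steps=7, base=0.0):
    """Vary one latent coordinate over [low, high] while holding the others at `base`."""
    steps = max(steps, 2)  # need at least both ends of the range
    codes = []
    for i in range(steps):
        value = low + (high - low) * i / (steps - 1)
        z = [base] * num_latent
        z[dim] = value
        codes.append(z)
    return codes


sweep_z1 = latent_sweep(dim=0)  # top row: vary z_1, hold z_2
sweep_z2 = latent_sweep(dim=1)  # bottom row: vary z_2, hold z_1
print(sweep_z1)
print(sweep_z2)
```

<p>Decoding each code in <code>sweep_z1</code> and <code>sweep_z2</code> with the trained decoder would produce one row of images per dimension.</p>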
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why the model confuses 4s and 9s so often: their region of the latent space is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, the log-variance values are similarly concentrated across digit classes. Some digits are more concentrated than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? Beyond 2D, we must apply dimensionality reduction to visualize the latent space, which gives only an approximate sense of how it is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
<p>Something odd happens when we jump from 16 to 32 latent dimensions: only 24 of the 32 dimensions stay active (KL &gt; 0.1), and the total test loss is essentially unchanged (100.21 vs. 100.25). The remaining dimensions become degenerate and stop encoding useful information.
This could be an indication we need to choose our hyperparameters a little more cautiously. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality needed for this dataset past which it&rsquo;s not really helpful to keep scaling the latent dimension.</p>
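<p>The last two columns of the table are simple to derive. The sketch below recomputes &ldquo;KL per Dim&rdquo; from the table&rsquo;s KL totals and shows how an active-dimension count works with the 0.1 threshold. The per-dimension profile <code>example_profile</code> is hypothetical, chosen only so that it sums to roughly the 32D total; the real per-dimension values are not listed in this post.</p>

```python
# Test KL totals from the table above, keyed by latent dimensionality.
test_kl = {2: 6.67, 4: 10.61, 8: 16.84, 16: 23.65, 32: 25.59}


def kl_per_dim(total_kl, latent_dim):
    """Average KL contribution of a single latent dimension."""
    return total_kl / latent_dim


def count_active(per_dim_kl, threshold=0.1):
    """Number of dimensions whose average KL exceeds the activity threshold."""
    return sum(1 for kl in per_dim_kl if kl > threshold)


per_dim_table = {d: kl_per_dim(kl, d) for d, kl in test_kl.items()}
for d, value in per_dim_table.items():
    print(f"D={d:2d}: KL per dim = {value:.2f}")

# Hypothetical per-dimension KL profile for a 32D model: 24 informative, 8 collapsed.
example_profile = [1.06] * 24 + [0.02] * 8
active = count_active(example_profile)
print(f"active dims: {active} of {len(example_profile)}")
```

<p>A collapsed dimension has a posterior that is nearly identical to the prior for every input, so it contributes almost nothing to the KL total and carries almost no information to the decoder.</p>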
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. The KL term drops in staircase-like steps without a clear gain in reconstruction ability.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce their dimensionality to 2D. A 2D PCA projection captures less and less of the total variance as the latent dimensionality grows. The 4D and 8D plots suggest increasingly better separation of the numeric classes. However, the 16D and 32D plots capture only 10-20% of the variance, so the overlap they appear to show is misleading.</p>
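<p>A rough back-of-the-envelope check shows why the projection degrades. If the latent codes spread their variance evenly over all $D$ axes, as the <code>N(0, 1)</code> prior encourages, then any two directions capture only $2/D$ of the total: 12.5% for $D=16$ and 6.25% for $D=32$. The sketch below computes that share from a list of principal-axis variances. The even spread is an idealized assumption and <code>explained_fraction</code> is an illustrative helper, not part of the experiment code.</p>

```python
def explained_fraction(variances, k=2):
    """Share of total variance captured by the k largest principal-axis variances."""
    ordered = sorted(variances, reverse=True)
    return sum(ordered[:k]) / sum(ordered)


# Idealized case (an assumption): variance spread evenly over all D latent axes.
for latent_dim in (4, 8, 16, 32):
    share = explained_fraction([1.0] * latent_dim)
    print(f"D={latent_dim:2d}: top-2 share = {share:.1%}")
```

<p>Real latent codes are not perfectly isotropic, so the measured share can sit above $2/D$, but the trend is the same: the more evenly a model uses its dimensions, the less a 2D linear projection can show.</p>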
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>IQCRNN: Certified Stability for Neural Networks</title><link>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</link><pubDate>Wed, 11 May 2022 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/projects/iqcrnn-pytorch/</guid><description>PyTorch IQCRNN enforcing stability guarantees on RNNs via Integral Quadratic Constraints and semidefinite programming.</description><content:encoded><![CDATA[<p>This project is a PyTorch re-implementation of <strong>IQCRNN</strong>, a method that enforces strict stability guarantees on Recurrent Neural Networks used in control systems.</p>
<h2 id="overview">Overview</h2>
<p>Standard Reinforcement Learning agents can behave unpredictably in unseen states. This approach forces the agent&rsquo;s weights to satisfy <strong>Integral Quadratic Constraints (IQC)</strong> via a projection step. Effectively, it solves a convex optimization problem (Semidefinite Program) inside the gradient descent loop to ensure the controller never violates Lyapunov stability criteria.</p>
<p>The method bridges classic <strong>Robust Control Theory</strong> (1990s) with <strong>Deep Reinforcement Learning</strong> (2020s), providing mathematical certificates of safety for neural network controllers.</p>
<h2 id="features">Features</h2>
<ul>
<li><strong>Hybrid Optimization:</strong> Interleaved standard Gradient Descent (PyTorch) with Convex Optimization (<code>cvxpy</code> + <code>MOSEK</code>) to project weights onto the &ldquo;safe&rdquo; manifold after each training step.</li>
<li><strong>Complex Constraints:</strong> Implemented the &ldquo;Tilde&rdquo; parametrization from the original paper to convexify the non-convex stability conditions of the RNN dynamics, transforming an intractable problem into a solvable Linear Matrix Inequality (LMI).</li>
<li><strong>Safety-Critical Domains:</strong> Applied the controller across six control systems (cartpole, inverted pendulum, nonlinear pendulum, pendubot, power grid, and vehicle dynamics), including unstable plants where &ldquo;crashing&rdquo; during training is unacceptable.</li>
</ul>
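<p>The overall loop shape (gradient step, then projection back into the safe set) can be sketched without any solver. The example below deliberately swaps the SDP projection for a much simpler stand-in, projection onto a Euclidean ball, so it illustrates only the interleaving and not the IQC constraints themselves. All names and values are hypothetical; the real projection is solved with <code>cvxpy</code> and <code>MOSEK</code>.</p>

```python
import math


def project_onto_ball(w, radius):
    """Stand-in projection: scale w back onto the Euclidean ball of the given radius."""
    norm = math.sqrt(sum(x * x for x in w))
    if norm > radius:
        return [x * radius / norm for x in w]
    return list(w)


def projected_gradient_descent(grad_fn, w0, radius, lr=0.1, steps=100):
    """Alternate an unconstrained gradient step with a projection onto the feasible set."""
    w = list(w0)
    for _ in range(steps):
        g = grad_fn(w)
        w = [w_i - lr * g_i for w_i, g_i in zip(w, g)]  # ordinary gradient step
        w = project_onto_ball(w, radius)                # pull weights back into the safe set
    return w


# Toy objective: squared distance to a target that lies outside the unit ball.
target = [3.0, 4.0]


def grad(w):
    return [2.0 * (w_i - t_i) for w_i, t_i in zip(w, target)]


w_final = projected_gradient_descent(grad, [0.0, 0.0], radius=1.0)
print(w_final)
```

<p>The iterate ends on the boundary of the feasible set at the point closest to the unconstrained optimum. In IQCRNN the feasible set is defined by an LMI instead of a ball, which is what makes each projection expensive.</p>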
<h2 id="usage">Usage</h2>
<p>The repository includes training scripts for the inverted pendulum and power grid environments, demonstrating the stability guarantees in practice.</p>
<h2 id="results">Results</h2>
<p>This project was a deep dive into the tension between <strong>Safety</strong> and <strong>Speed</strong>.</p>
<ul>
<li><strong>The Bottleneck:</strong> Solving an SDP every few training steps is computationally expensive (interior-point SDP solvers scale steeply, roughly $O(n^6)$ in the matrix dimension). While the projection provided mathematical certificates of safety, the cost highlighted why these methods haven&rsquo;t yet overtaken standard PPO/SAC in production: the &ldquo;safety tax&rdquo; on training time is steep.</li>
<li><strong>The Lesson:</strong> It taught me that &ldquo;theoretical guarantees&rdquo; often come with &ldquo;engineering fine print.&rdquo; If I were to redo this today, I would look into <strong>differentiable convex optimization layers</strong> (like <code>cvxpylayers</code>) to make the projection end-to-end differentiable.</li>
<li><strong>The &ldquo;Rough Edges&rdquo;:</strong> The codebase has artifacts of its research origins (e.g., the <code>reqs.txt</code> dependency dump). Reading a dense control theory paper (Gu et al., 2021) and implementing the math correctly was the primary focus.</li>
</ul>
<h2 id="citation">Citation</h2>
<p>Credit to the original authors:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{gu2021recurrentneuralnetworkcontrollers,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Recurrent Neural Network Controllers Synthesis with Stability Guarantees for Partially Observed Systems}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Fangda Gu and He Yin and Laurent El Ghaoui and Murat Arcak and Peter Seiler and Ming Jin}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2109.03861}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{eess.SY}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/2109.03861}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><h2 id="related-work">Related Work</h2>
<ul>
<li><a href="/research/deconstructing-recurrence-attention-gating/">Deconstructing Recurrence and Attention Gating</a>: research on recurrent network architectures, providing context for why stability guarantees on RNNs matter</li>
</ul>
]]></content:encoded></item><item><title>GPT-2 Susceptibility to Universal Adversarial Triggers</title><link>https://hunterheidenreich.com/research/gpt2-adversarial-triggers/</link><pubDate>Sat, 01 May 2021 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/research/gpt2-adversarial-triggers/</guid><description>Investigation into whether universal adversarial triggers can control both topic and stance of GPT-2's generated text and security implications.</description><content:encoded><![CDATA[<blockquote>
<p><strong>Historical context:</strong> This paper was published at AIES 2021, the AAAI/ACM Conference on AI, Ethics, and Society, predating the modern red-teaming practices and adversarial robustness benchmarks that emerged with instruction-tuned and RLHF-trained models. GPT-2 is now a historical baseline, but the core methodology and findings remain a relevant foundation for current adversarial robustness work.</p></blockquote>
<h2 id="abstract">Abstract</h2>
<p>This work investigates universal adversarial triggers (UATs), a method for disrupting language models using input-agnostic token sequences. We investigated whether it is possible to use these triggers to control the <strong>topic</strong> and the <strong>stance</strong> of text generated by GPT-2. Across four controversial topics, we demonstrated success in identifying triggers that guide the model to produce text on a targeted subject and influence the position it takes. Our goal is to raise awareness that even deployed models are susceptible to this influence and to advocate for immediate safeguards.</p>
<h2 id="key-findings--contributions">Key Findings &amp; Contributions</h2>
<ul>
<li><strong>Topic and Stance Control</strong>: We were the first to systematically explore using UATs to control both the topic and the stance of a language model&rsquo;s output. We found that controlling the topic is highly feasible, and controlling the stance is also possible.</li>
<li><strong>The &ldquo;Filter Bubble&rdquo; Hypothesis</strong>: We observed that triggers for fringe topics (e.g., Flat Earth) were harder to find but offered a higher degree of stance control than broader topics. We posit this may reflect &ldquo;filter bubbles&rdquo; in the training data, where fringe viewpoints use distinct linguistic patterns.</li>
<li><strong>Ethical &amp; Security Analysis</strong>: We highlighted the security risks of deployed models being manipulated by external adversaries without internal model access. To be responsible, we withheld the most sensitive triggers we discovered.</li>
<li><strong>Constructive Applications</strong>: Beyond a security flaw, we proposed that UATs could be used constructively as a <strong>diagnostic tool</strong> to audit models for bias or as a method for <strong>bot detection</strong> on social media.</li>
</ul>
<h2 id="significance--why-this-matters">Significance &amp; Why This Matters</h2>
<p>This work extended early research on UATs by moving beyond single-issue attacks (like generating toxic content) to a nuanced analysis of topic and stance control. It demonstrated that a <strong>gradient-based search process (adapting HotFlip)</strong> is effective at manipulating model outputs, emphasizing a critical vulnerability for any organization deploying large language models.</p>
<p>For ML practitioners and security researchers, this highlights the importance of robust safeguards against input-agnostic attacks. It also opens the door to using these same adversarial techniques constructively: as diagnostic tools to audit models for hidden biases or to detect automated bot activity on social media platforms.</p>
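<p>The core of a HotFlip-style search step is a first-order ranking: for one trigger position, the change in loss from swapping in a new token is approximated by the dot product between that token&rsquo;s embedding and the loss gradient at the current trigger embedding. The sketch below shows only that ranking, on a made-up four-token vocabulary. The function name, embeddings, and gradient are hypothetical; a real search would take the gradient from GPT-2 over a batch of target texts and re-score the top candidates with actual forward passes.</p>

```python
def hotflip_candidates(grad, embeddings, top_k=3):
    """Rank vocabulary tokens by the first-order estimate of the loss after swapping them in.

    grad:       gradient of the loss w.r.t. the current trigger token's embedding
    embeddings: one embedding vector per vocabulary token
    Returns the ids of the top_k tokens with the lowest estimated loss.
    """
    scores = []
    for token_id, emb in enumerate(embeddings):
        # The current token's own term is the same for every candidate, so it is dropped.
        estimate = sum(e * g for e, g in zip(emb, grad))
        scores.append((estimate, token_id))
    scores.sort()
    return [token_id for _, token_id in scores[:top_k]]


# Made-up 2D embeddings for a four-token vocabulary, and a made-up gradient.
toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
toy_grad = [1.0, 0.5]
candidates = hotflip_candidates(toy_grad, toy_embeddings, top_k=2)
print(candidates)
```

<p>Because the ranking is only a linear approximation, the candidates it returns are proposals to be verified, not guaranteed improvements.</p>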
<h2 id="related-work">Related Work</h2>
<p>The constructive bot-detection application proposed here connects directly to empirical work on coordinated inauthentic behavior. <a href="/research/coordinated-social-targeting/">Coordinated Social Targeting on Twitter</a> documents real-world follower-manipulation patterns on high-profile accounts, illustrating the kind of automated adversarial activity that UAT-based detection methods could help identify.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{10.1145/3461702.3462578,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Heidenreich, Hunter Scott and Williams, Jake Ryland}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{The Earth Is Flat and the Sun Is Not a Star: The Susceptibility of GPT-2 to Universal Adversarial Triggers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">isbn</span> = <span style="color:#e6db74">{9781450384735}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">{Association for Computing Machinery}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">{New York, NY, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://doi.org/10.1145/3461702.3462578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1145/3461702.3462578}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the 2021 AAAI/ACM Conference on AI, Ethics, and Society}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{566--573}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">numpages</span> = <span style="color:#e6db74">{8}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">keywords</span> = <span style="color:#e6db74">{adversarial attacks, bias, language modeling, natural language processing}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">location</span> = <span style="color:#e6db74">{Virtual Event, USA}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">series</span> = <span style="color:#e6db74">{AIES &#39;21}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>A Guide to Neuroevolution: NEAT and HyperNEAT</title><link>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</link><pubDate>Wed, 02 Jan 2019 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/neuroevolution-neat-and-hyperneat/</guid><description>Explore the evolution of neural network topologies with NEAT and how HyperNEAT scales this approach using geometric patterns and indirect encoding.</description><content:encoded><![CDATA[<h2 id="automating-neural-architecture-design">Automating Neural Architecture Design</h2>
<p>Designing neural network architectures is typically a manual, iterative process. Researchers experiment with different layer configurations, activation functions, and connection patterns, often guided by intuition and empirical results. Evolution offers an automated alternative to this design process.</p>
<p><a href="https://nn.cs.utexas.edu/downloads/papers/stanley.ec02.pdf">NEAT (NeuroEvolution of Augmenting Topologies)</a>, introduced in 2002, optimizes network weights and evolves the network structure itself, starting from minimal topologies and growing complexity only when beneficial.</p>
<p>NEAT&rsquo;s core innovations solved fundamental problems that had plagued earlier attempts at topology evolution. Its solutions for genetic encoding, structural crossover, and innovation protection remain influential today, especially as neural architecture search and automated ML gain prominence.</p>
<h2 id="the-core-challenges-of-neat">The Core Challenges of NEAT</h2>
<p>Evolving neural network topologies presents several fundamental challenges that NEAT elegantly addressed. Understanding these problems helps explain why NEAT&rsquo;s solutions were so influential.</p>
<h3 id="genetic-encoding-how-to-represent-networks">Genetic Encoding: How to Represent Networks</h3>
<p>Evolutionary algorithms require a genetic representation, a way to encode individuals that enables meaningful selection, mutation, and crossover. For neural networks, this choice is critical.</p>
<p><strong>Direct encoding</strong> explicitly represents each network component. Genes directly correspond to nodes and connections. This approach is intuitive and readable, and it works well for smaller networks.</p>
<p><strong>Indirect encoding</strong> specifies construction rules or processes. These encodings are more compact and can generate highly complex structures from simple rules.</p>
<p>NEAT chose direct encoding with a simple two-part structure: separate gene lists for nodes and connections. This balances simplicity with the flexibility needed for evolutionary operations.</p>
<figure class="post-figure center ">
    <img src="/img/neat_genomes.webp"
         alt="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         title="NEAT genome encoding showing node genes and connection genes with innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT&rsquo;s direct encoding: node genes (top) and connection genes (bottom) with historical markings</figcaption>
    
</figure>

<p>Connection genes specify the source and target nodes, weight, enabled status, and an innovation number for historical tracking. Input and output nodes are fixed; only hidden nodes evolve.</p>
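<p>The two-part genome described above can be sketched as a pair of small record types. This is a minimal illustration, not the reference implementation: the class names, field names, and the <code>minimal_genome</code> helper are assumptions made for this example.</p>

```python
import random
from dataclasses import dataclass, field
from typing import List


@dataclass
class NodeGene:
    node_id: int
    kind: str  # "input", "hidden", or "output"


@dataclass
class ConnectionGene:
    in_node: int      # source node id
    out_node: int     # target node id
    weight: float
    enabled: bool
    innovation: int   # historical marking


@dataclass
class Genome:
    nodes: List[NodeGene] = field(default_factory=list)
    connections: List[ConnectionGene] = field(default_factory=list)


def minimal_genome(n_inputs: int, n_outputs: int, seed: int = 0) -> Genome:
    """Hypothetical helper: inputs wired directly to outputs, no hidden nodes."""
    rng = random.Random(seed)
    nodes = [NodeGene(i, "input") for i in range(n_inputs)]
    nodes += [NodeGene(n_inputs + j, "output") for j in range(n_outputs)]
    connections = []
    for i in range(n_inputs):
        for j in range(n_outputs):
            # Innovation numbers are simply assigned in creation order here.
            connections.append(
                ConnectionGene(i, n_inputs + j, rng.uniform(-1.0, 1.0), True, len(connections))
            )
    return Genome(nodes, connections)
```

<p>Calling <code>minimal_genome(2, 1)</code> yields three node genes and two enabled connection genes, which is the kind of minimal starting point NEAT grows from.</p>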
<h3 id="structural-mutations-growing-complexity">Structural Mutations: Growing Complexity</h3>
<p>NEAT employs two categories of mutations to evolve both weights and structure:</p>
<p><strong>Weight mutations</strong> adjust existing connection strengths using standard perturbation methods, the familiar approach from traditional neuroevolution.</p>
<p><strong>Structural mutations</strong> add new network components:</p>
<ul>
<li><strong>Add connection</strong>: Creates a new link between existing nodes with a random initial weight</li>
<li><strong>Add node</strong>: Splits an existing connection by inserting a new node. The original connection is disabled and two new connections replace it: the connection into the new node gets a weight of 1.0, and the connection out of it inherits the original weight</li>
</ul>
<p>This node-splitting approach minimizes disruption. With these weights the new node initially passes the signal through nearly unchanged (exactly so only if its activation is linear), giving the new structure time to prove useful before selection pressure intensifies.</p>
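<p>The add-node mutation can be sketched in a few lines. The <code>Connection</code> record and the <code>add_node</code> signature are assumptions for this example; a full implementation would also track node genes and reuse innovation numbers for identical mutations within a generation.</p>

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Connection:
    in_node: int
    out_node: int
    weight: float
    enabled: bool
    innovation: int


def add_node(connections: List[Connection], index: int,
             new_node_id: int, next_innovation: int) -> int:
    """Split connections[index] with a new hidden node, in place.

    Returns the next unused innovation number.
    """
    old = connections[index]
    old.enabled = False  # the original gene is kept in the genome but disabled
    # The link into the new node gets weight 1.0 and the link out of it keeps
    # the old weight, so the signal reaching the old target changes very little.
    connections.append(Connection(old.in_node, new_node_id, 1.0, True, next_innovation))
    connections.append(Connection(new_node_id, old.out_node, old.weight, True, next_innovation + 1))
    return next_innovation + 2
```

<p>Keeping the disabled gene in the genome matters: it preserves the historical marking so later crossover can still align against it.</p>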
<h3 id="solving-the-competing-conventions-problem">Solving the Competing Conventions Problem</h3>
<p>Performing crossover between networks with different structures presents a fundamental challenge. Consider two networks that solve the same problem using different internal organizations. Naive crossover between them typically produces broken offspring.</p>
<figure class="post-figure center ">
    <img src="/img/competing_conventions.webp"
         alt="Two neural networks performing the same function but with different internal structures"
         title="Two neural networks performing the same function but with different internal structures"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The competing conventions problem: identical functions, different implementations</figcaption>
    
</figure>

<p>NEAT&rsquo;s solution draws inspiration from biology through <strong>historical markings</strong>. Each structural innovation (adding a node or connection) receives a unique innovation number, a timestamp of when that change first appeared in the population.</p>
<p>During crossover, genes with matching innovation numbers are aligned and combined. This biological concept of homology enables meaningful recombination between networks of different sizes and structures.</p>
<figure class="post-figure center ">
    <img src="/img/neat_crossover.webp"
         alt="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         title="Diagram showing how NEAT aligns genes during crossover using innovation numbers"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">NEAT crossover using historical markings for gene alignment</figcaption>
    
</figure>
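<p>A minimal sketch of gene alignment during crossover follows. It assumes each genome is reduced to a dictionary from innovation number to connection weight and that one parent is strictly fitter. Matching genes are picked at random from either parent, while disjoint and excess genes come from the fitter parent. The function name and data layout are illustrative only.</p>

```python
import random
from typing import Dict


def crossover(fitter: Dict[int, float], weaker: Dict[int, float],
              rng: random.Random) -> Dict[int, float]:
    """Each genome maps innovation number to connection weight."""
    child = {}
    for innovation, weight in fitter.items():
        if innovation in weaker and rng.random() > 0.5:
            # Matching gene, inherited from the weaker parent this time.
            child[innovation] = weaker[innovation]
        else:
            # Matching gene from the fitter parent, or a disjoint/excess gene.
            child[innovation] = weight
    return child


fitter = {1: 0.1, 2: 0.2, 4: 0.4}
weaker = {1: 1.1, 3: 1.3, 5: 1.5}
child = crossover(fitter, weaker, random.Random(0))
# child carries innovations 1, 2, and 4; genes 3 and 5 are dropped
```

<p>Because alignment is by innovation number, no topological analysis is needed to decide which genes correspond.</p>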

<h3 id="protecting-innovation-through-speciation">Protecting Innovation Through Speciation</h3>
<p>New structural innovations face a harsh reality: they usually perform worse initially. Adding nodes or connections typically decreases performance before optimization can improve the new structure. Without protection, these innovations disappear before realizing their potential.</p>
<p>NEAT addresses this through <strong>speciation</strong>: dividing the population into species based on structural and weight similarity. The historical markings that enable crossover also measure compatibility between individuals.</p>
<p>Crucially, individuals only compete within their species. This gives new structural innovations time to optimize without immediately competing against established, well-tuned networks.</p>
<p><strong>Explicit fitness sharing</strong> enhances this protection: each individual&rsquo;s fitness is divided by the number of members in its species, so a species cannot grow without lowering the adjusted fitness of its members. This prevents any single species from dominating the population and maintains diversity for continued exploration.</p>
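<p>In the original NEAT paper, compatibility is a weighted sum of the number of excess genes $E$, the number of disjoint genes $D$, and the mean weight difference of matching genes $\overline{W}$, normalized by the size $N$ of the larger genome:</p>
<p>$$\delta = \frac{c_1 E}{N} + \frac{c_2 D}{N} + c_3 \overline{W}$$</p>
<p>The sketch below computes this distance and the shared fitness for genomes stored as innovation-to-weight dictionaries. The function names and the default coefficients are assumptions for illustration, and the code expects non-empty genomes.</p>

```python
from typing import Dict


def compatibility(a: Dict[int, float], b: Dict[int, float],
                  c1: float = 1.0, c2: float = 1.0, c3: float = 0.4) -> float:
    """Distance between two non-empty genomes given as {innovation: weight}."""
    matching = set(a).intersection(b)
    unmatched = set(a).symmetric_difference(b)
    cutoff = min(max(a), max(b))  # genes newer than this are "excess"
    excess = sum(1 for i in unmatched if i > cutoff)
    disjoint = len(unmatched) - excess
    n = max(len(a), len(b))
    if matching:
        mean_weight_diff = sum(abs(a[i] - b[i]) for i in matching) / len(matching)
    else:
        mean_weight_diff = 0.0
    return c1 * excess / n + c2 * disjoint / n + c3 * mean_weight_diff


def adjusted_fitness(raw_fitness: float, species_size: int) -> float:
    """Explicit fitness sharing: divide by the size of the organism's species."""
    return raw_fitness / species_size
```

<p>An individual joins the first species whose representative is within a chosen distance threshold; otherwise it founds a new species.</p>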
<h3 id="complexification-starting-minimal">Complexification: Starting Minimal</h3>
<p>NEAT begins with the simplest possible networks (just input and output nodes connected by random weights). No hidden layers exist initially. Complexity emerges only when mutations that add structure prove beneficial.</p>
<p>This complexification approach builds efficient solutions that solve problems with minimal structure. Combined with speciation, it tends to produce highly optimized architectures.</p>
<h2 id="scaling-up-hyperneat">Scaling Up: HyperNEAT</h2>
<p>NEAT uses direct encoding: each gene explicitly specifies a node or a connection, so the genome grows with the network. That does not scale to networks approaching the brain&rsquo;s trillions of connections, where searching over every connection individually becomes impractical. Reaching that scale calls for indirect encoding.</p>
<p><a href="https://doi.org/10.1162/artl.2009.15.2.15202">HyperNEAT</a> introduces <strong>indirect encoding</strong> through geometric principles. HyperNEAT evolves geometric patterns that generate connections based on spatial relationships. This enables the evolution of large networks with biological regularities: symmetry, repetition, and locality.</p>
<p>The key insight is leveraging Compositional Pattern Producing Networks (CPPNs) to map coordinates to connection weights, exploiting the geometric organization found in natural neural networks.</p>
<h3 id="biological-motivation">Biological Motivation</h3>
<p>The human brain exhibits remarkable organizational principles:</p>
<ul>
<li><strong>Scale</strong>: ~86 billion neurons with ~100 trillion connections</li>
<li><strong>Repetition</strong>: Structural patterns reused across regions</li>
<li><strong>Symmetry</strong>: Mirrored structures like bilateral visual processing</li>
<li><strong>Locality</strong>: Spatial proximity influences connectivity and function</li>
</ul>
<p>HyperNEAT aims to evolve networks that capture these biological regularities, with the goal of more compact and regular architectures than connection-by-connection search tends to find.</p>
<h3 id="compositional-pattern-producing-networks">Compositional Pattern Producing Networks</h3>
<p>CPPNs are the foundation of HyperNEAT&rsquo;s indirect encoding. Think of them as pattern generators that create complex spatial structures from simple coordinate inputs.</p>
<p>DNA exemplifies indirect encoding: roughly 20,000 protein-coding genes specify a brain with trillions of connections. This massive compression ratio suggests that simple rules can generate complex structures through developmental processes.</p>
<p>CPPNs abstract this concept, using compositions of mathematical functions to create patterns in coordinate space. The same genetic program (function composition) can be reused across different locations and scales, just like how developmental genes control pattern formation throughout an organism.</p>
<figure class="post-figure center ">
    <img src="/img/hyperneat_cppns.webp"
         alt="Various symmetric and repetitive patterns created by CPPNs"
         title="Various symmetric and repetitive patterns created by CPPNs"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Complex patterns generated by CPPNs through function composition</figcaption>
    
</figure>

<h3 id="pattern-generation-through-function-composition">Pattern Generation Through Function Composition</h3>
<p>CPPNs generate patterns by composing simple mathematical functions. Key function types include:</p>
<ul>
<li><strong>Gaussian functions</strong>: Create symmetric patterns and gradients</li>
<li><strong>Trigonometric functions</strong>: Generate periodic/repetitive structures</li>
<li><strong>Linear functions</strong>: Produce gradients and asymmetric patterns</li>
<li><strong>Sigmoid functions</strong>: Create sharp transitions and boundaries</li>
</ul>
<p>By combining these functions, CPPNs can encode complex regularities that would require many explicit rules in direct encoding.</p>
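<p>As a small illustration of function composition, the hand-written pattern below combines a Gaussian, a sine, and a sigmoid output. In HyperNEAT the structure and weights of such a composition would be evolved; the function names and constants here are assumptions chosen to make the symmetry and repetition easy to check.</p>

```python
import math


def gaussian(x: float) -> float:
    return math.exp(-x * x)


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def cppn_pattern(x: float, y: float) -> float:
    """Hand-written stand-in for an evolved CPPN over 2D coordinates."""
    symmetric = gaussian(x)                 # same value at x and -x
    periodic = math.sin(4.0 * math.pi * y)  # repeats every 0.5 along y
    return sigmoid(symmetric + periodic)    # squashing output node
```

<p>Evaluating <code>cppn_pattern</code> on a grid of points gives an image that is mirrored left to right and striped along the vertical axis, two regularities obtained from just three composed functions.</p>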
<h3 id="evolution-of-cppns">Evolution of CPPNs</h3>
<p>HyperNEAT uses NEAT to evolve the CPPN structure. This brings several advantages:</p>
<ul>
<li><strong>Complexification</strong>: CPPNs start simple and grow more complex only when beneficial</li>
<li><strong>Historical markings</strong>: Enable proper crossover between different CPPN topologies</li>
<li><strong>Speciation</strong>: Protects innovative CPPN patterns during evolution</li>
</ul>
<p>Additional activation functions beyond standard neural networks are crucial:</p>
<ul>
<li>Gaussian functions for symmetry</li>
<li>Sine/cosine for repetition</li>
<li>Specialized functions for specific geometric patterns</li>
</ul>
<h2 id="the-hyperneat-process">The HyperNEAT Process</h2>
<h3 id="substrates-geometric-organization">Substrates: Geometric Organization</h3>
<p>A <strong>substrate</strong> defines the spatial arrangement of neurons. Substrates embed neurons in geometric space (2D grids, 3D volumes, etc.), providing an alternative to layer-based connectivity rules.</p>
<p>The CPPN maps from coordinates to connection weights:</p>
<p>$$\text{CPPN}(x_1, y_1, x_2, y_2) = w$$</p>
<p>Where $(x_1, y_1)$ and $(x_2, y_2)$ are the coordinates of two neurons, and $w$ determines their connection weight.</p>
<figure class="post-figure center ">
    <img src="/img/hyperneat_cppn_basics.webp"
         alt="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         title="Diagram showing CPPN taking four coordinate inputs and outputting connection weight"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Basic CPPN architecture mapping coordinates to connection weights</figcaption>
    
</figure>

<p>This geometric approach enables several key properties:</p>
<ul>
<li><strong>Locality</strong>: Nearby neurons tend to have similar connectivity patterns</li>
<li><strong>Symmetry</strong>: Patterns can be mirrored across spatial axes</li>
<li><strong>Repetition</strong>: Periodic functions create repeating motifs</li>
<li><strong>Scalability</strong>: The same pattern can be applied at different resolutions</li>
</ul>
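<p>Querying a CPPN for every pair of substrate neurons can be sketched as follows. The <code>example_cppn</code> function is a hand-written stand-in for an evolved CPPN, and the grid layout and helper names are assumptions for this example.</p>

```python
import math
from typing import Callable, Dict, List, Tuple

Point = Tuple[float, float]


def grid(n: int) -> List[Point]:
    """n x n neurons with coordinates evenly spaced in [-1, 1]."""
    axis = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    return [(x, y) for x in axis for y in axis]


def example_cppn(x1: float, y1: float, x2: float, y2: float) -> float:
    """Stand-in for an evolved CPPN: weight decays with squared distance."""
    return math.exp(-((x1 - x2) ** 2 + (y1 - y2) ** 2))


def build_weights(cppn: Callable[[float, float, float, float], float],
                  neurons: List[Point]) -> Dict[Tuple[Point, Point], float]:
    """Query the CPPN once per ordered pair of neurons."""
    return {(a, b): cppn(a[0], a[1], b[0], b[1]) for a in neurons for b in neurons}


neurons = grid(3)
weights = build_weights(example_cppn, neurons)
```

<p>With this stand-in, neighbouring neurons receive stronger connections than distant ones, which is the locality property in its simplest form. The genome being evolved is only the CPPN, not the 81 weights it generates.</p>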
<h3 id="emergent-regularities">Emergent Regularities</h3>
<p>The geometric encoding naturally produces the desired biological patterns:</p>
<p><strong>Symmetry</strong> emerges from symmetric functions. A Gaussian centered at the origin returns the same value for $x$ and $-x$, so a CPPN that passes a coordinate through it produces connectivity that is mirrored across that axis.</p>
<p><strong>Repetition</strong> arises from periodic functions like sine and cosine. These create repeating connectivity motifs across the substrate.</p>
<p><strong>Locality</strong> results from functions that vary smoothly across space. Nearby coordinates produce similar outputs, leading to local connectivity patterns.</p>
<p><strong>Imperfect regularity</strong> occurs when these patterns are modulated by additional coordinate dependencies, creating biological-like variation within the basic structure.</p>
<h3 id="substrate-configurations">Substrate Configurations</h3>
<p>The choice of substrate geometry critically influences network behavior. Several standard configurations exist:</p>
<figure class="post-figure center ">
    <img src="/img/hyperneat_substrate_configurations.webp"
         alt="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         title="Various substrate layouts including grids, 3D arrangements, and circular patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Common substrate geometries for different problem types</figcaption>
    
</figure>

<p><strong>2D Grid</strong>: A simple planar arrangement; the CPPN takes four coordinates $(x_1, y_1, x_2, y_2)$.</p>
<p><strong>3D Volume</strong>: Extends the grid to three dimensions; the CPPN takes six coordinates $(x_1, y_1, z_1, x_2, y_2, z_2)$.</p>
<p><strong>Sandwich</strong>: The input layer connects only to the output layer, which is useful for sensory-motor tasks.</p>
<p><strong>Circular</strong>: A radial geometry that enables rotation-invariant patterns and cyclic behaviors.</p>
<p>The substrate must be chosen before evolution begins, making domain knowledge important for success.</p>
<h3 id="exploiting-input-output-geometry">Exploiting Input-Output Geometry</h3>
<p>HyperNEAT exploits the spatial organization of inputs and outputs. For visual tasks, pixel coordinates provide meaningful geometric information. For control problems, sensor and actuator layouts can guide connectivity patterns.</p>
<figure class="post-figure center ">
    <img src="/img/hyperneat_inputs_outputs.webp"
         alt="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         title="Visual representation of how HyperNEAT maps spatial input arrangements to output patterns"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Spatial organization of inputs and outputs enables geometric exploitation</figcaption>
    
</figure>

<p>This spatial awareness allows HyperNEAT to:</p>
<ul>
<li>Develop receptive fields similar to biological vision systems</li>
<li>Create locally connected patterns for spatial processing</li>
<li>Generate symmetric motor control patterns</li>
<li>Scale across different input resolutions</li>
</ul>
<h3 id="resolution-independence">Resolution Independence</h3>
<p>A distinctive advantage of HyperNEAT is <strong>substrate resolution independence</strong>. Because the CPPN maps coordinates to weights instead of storing individual connections, a CPPN evolved on a low-resolution substrate can be queried again on a higher-resolution one to generate a larger network without further evolution.</p>
<p>This property suggests that evolved patterns capture fundamental spatial relationships, providing a key insight for scalable neural architecture design.</p>
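<p>Resolution independence can be illustrated on a one-dimensional substrate. The <code>cppn</code> function below is a hand-written stand-in for an evolved CPPN, and the helper names are assumptions for this sketch. The same function is queried at two resolutions.</p>

```python
import math
from typing import List


def cppn(x1: float, x2: float) -> float:
    """Stand-in for an evolved CPPN on a 1D substrate."""
    return math.sin(math.pi * x1) * math.exp(-(x1 - x2) ** 2)


def coords(n: int) -> List[float]:
    """n neuron positions evenly spaced in [-1, 1]."""
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def weight_matrix(n: int) -> List[List[float]]:
    xs = coords(n)
    return [[cppn(a, b) for b in xs] for a in xs]


low = weight_matrix(3)   # 3 neurons, 9 connections
high = weight_matrix(5)  # 5 neurons, 25 connections from the same CPPN
```

<p>Neurons that sit at the same coordinates in both substrates receive identical weights, and the extra neurons in the larger substrate are filled in by the same underlying pattern.</p>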
<h2 id="impact-and-future-directions">Impact and Future Directions</h2>
<p>NEAT and HyperNEAT demonstrated that evolution could design neural network topologies and scale them through indirect encoding. Their key insights (exploiting geometry, generating patterns through function composition, and scaling across resolutions) continue to influence modern research.</p>
<p>Extensions like ES-HyperNEAT add even more sophisticated capabilities by evolving the substrate itself. As neural architecture search becomes increasingly important, these principles find new applications in hybrid approaches that combine evolutionary pattern generation with gradient-based optimization.</p>
<p>The emphasis on spatial organization and regularity also connects to contemporary work on geometric deep learning and equivariant networks, suggesting that evolution and hand-design converge on similar organizing principles for building structured, efficient neural architectures.</p>
]]></content:encoded></item><item><title>QuAC: Question Answering in Context Dataset</title><link>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</link><pubDate>Wed, 31 Oct 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/quac-question-answering-in-context/</guid><description>Analysis of QuAC's conversational QA through student-teacher interactions, featuring ~100K context-dependent questions and coreference challenges.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://aclanthology.org/D18-1241/">QuAC dataset</a> (Question Answering in Context) frames conversational question answering as a student-teacher interaction. Published at EMNLP 2018, this work by Choi et al. addresses how systems can understand dialogue context, resolve references across conversation turns, and handle the ambiguity of natural conversation. Most earlier datasets, by contrast, treated each question independently.</p>
<p>The dataset addresses limitations in question answering research by incorporating real-world information-seeking dialogue complexities, where questions build upon previous exchanges and context drives understanding.</p>
<p>For comparison with related work, see my analysis of <a href="/posts/coqa-conversation-question-answering/">CoQA</a>.</p>
<h2 id="the-student-teacher-framework">The Student-Teacher Framework</h2>
<p>QuAC models information-seeking dialogue through a student-teacher setup:</p>
<ul>
<li><strong>Teacher</strong>: Has complete access to information (Wikipedia passage)</li>
<li><strong>Student</strong>: Seeks knowledge through questioning with limited initial context</li>
<li><strong>Interaction</strong>: Includes context-dependent questions, abstract inquiries, and unanswerable requests</li>
</ul>
<p>This framework mirrors real-world scenarios where one party has expertise while another seeks to learn through dialogue. AI systems must act as effective teachers, using available information to provide helpful responses despite ambiguous or incomplete questions.</p>
<p>The dataset contains roughly 100K questions across ~14K dialogues (precisely 98,407 questions and 13,594 dialogues), providing substantial scale for training and evaluation.</p>
<figure class="post-figure center ">
    <img src="/img/quac_stats.webp"
         alt="QuAC dataset statistics and scale"
         title="QuAC dataset statistics and scale"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">QuAC dataset statistics and scale</figcaption>
    
</figure>

<h2 id="dataset-construction">Dataset Construction</h2>
<p>QuAC was built using Amazon Mechanical Turk with a two-person dialogue setup:</p>
<p><strong>Teacher role</strong>: Has access to the complete Wikipedia passage and provides answers extracted directly from the text</p>
<p><strong>Student role</strong>: Sees only the article title, introduction paragraph, and section heading, then asks questions to learn about the content</p>
<p>This asymmetric information design ensures student questions naturally differ from the passage content, creating realistic information-seeking scenarios. The extractive answer requirement maintains objective evaluation while simplifying scoring.</p>
<p><strong>Dialogue termination</strong>:</p>
<ul>
<li>12 questions answered</li>
<li>Manual termination by either participant</li>
<li>Two consecutive unanswerable questions</li>
</ul>
<figure class="post-figure center ">
    <img src="/img/quac_convo.webp"
         alt="Example QuAC conversation showing student-teacher interaction"
         title="Example QuAC conversation showing student-teacher interaction"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Example QuAC conversation showing student-teacher interaction</figcaption>
    
</figure>

<h3 id="content-selection">Content Selection</h3>
<p>QuAC focuses on Wikipedia biographical articles for several practical reasons:</p>
<ul>
<li><strong>Reduced complexity</strong>: People-focused content requires less specialized domain knowledge</li>
<li><strong>Natural question flow</strong>: Biographical information lends itself to sequential questioning</li>
<li><strong>Quality control</strong>: Articles filtered to include only subjects with 100+ incoming links, ensuring content depth</li>
</ul>
<p>This focused scope enables consistent evaluation while maintaining broad coverage through diverse biographical subjects across fields and time periods.</p>
<h2 id="key-dataset-characteristics">Key Dataset Characteristics</h2>
<p>QuAC introduces several features that distinguish it from existing question answering benchmarks:</p>
<figure class="post-figure center ">
    <img src="/img/quac_comparison.webp"
         alt="Comparative analysis of QuAC against other QA datasets"
         title="Comparative analysis of QuAC against other QA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Comparative analysis of QuAC against other QA datasets</figcaption>
    
</figure>

<p><strong>Notable features</strong>:</p>
<ul>
<li><strong>High contextual dependency</strong>: a large majority of questions depend on the conversation context, and a substantial share require coreference resolution</li>
<li><strong>Non-factoid focus</strong>: 54% of questions go beyond simple fact retrieval</li>
<li><strong>Extended answers</strong>: Responses are longer and more detailed</li>
<li><strong>Unanswerable questions</strong>: Realistic scenarios where information isn&rsquo;t available</li>
</ul>
<figure class="post-figure center ">
    <img src="/img/quac_dist.webp"
         alt="Distribution of question types in QuAC"
         title="Distribution of question types in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of question types in QuAC</figcaption>
    
</figure>

<h3 id="the-coreference-resolution-challenge">The Coreference Resolution Challenge</h3>
<p>QuAC&rsquo;s complexity stems from its heavy reliance on coreference resolution across multiple contexts:</p>
<p><strong>Reference types</strong>:</p>
<ul>
<li><strong>Passage references</strong>: Pronouns and references to entities in the source text</li>
<li><strong>Dialogue references</strong>: References to previously discussed topics</li>
<li><strong>Abstract references</strong>: Challenging cases like &ldquo;what else?&rdquo; that require inferring the inquiry scope</li>
</ul>
<figure class="post-figure center ">
    <img src="/img/quac_coref.webp"
         alt="Types and distribution of coreferences in QuAC"
         title="Types and distribution of coreferences in QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Types and distribution of coreferences in QuAC</figcaption>
    
</figure>

<p>The prevalence of coreference resolution makes QuAC particularly challenging, as this remains an active research problem in NLP. Models must understand passage content, track dialogue history, and resolve complex referential expressions simultaneously.</p>
<h2 id="performance-results">Performance Results</h2>
<p>Models face substantial challenges on QuAC, with significant gaps between human and machine performance:</p>
<figure class="post-figure center ">
    <img src="/img/quac_performance.webp"
         alt="Baseline model performance comparison on QuAC"
         title="Baseline model performance comparison on QuAC"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Baseline model performance comparison on QuAC</figcaption>
    
</figure>

<p><strong>Performance summary</strong>:</p>
<ul>
<li><strong>Human performance</strong>: 81.1% F1 score</li>
<li><strong>Best baseline</strong>: BiDAF++ with context achieves 60.2% F1</li>
<li><strong>Performance gap</strong>: 20+ point difference shows room for improvement</li>
</ul>
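<p>The F1 scores above measure word overlap between a predicted answer span and a reference span. The sketch below is a simplified version that assumes lowercasing and whitespace tokenization only. The official evaluation script applies further normalization and handles multiple references, so treat this as an illustration of the metric, not a replacement for it.</p>

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """Word-overlap F1 between two answer strings (simplified)."""
    pred = prediction.lower().split()
    ref = reference.lower().split()
    pred_counts = Counter(pred)
    ref_counts = Counter(ref)
    # Count each shared token at most as often as it appears in both strings.
    overlap = sum(min(count, ref_counts[tok]) for tok, count in pred_counts.items())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)
```

<p>A prediction that contains the full reference plus one extra word, such as three predicted tokens against a two-token reference, has perfect recall but a precision of 2/3.</p>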
<h3 id="human-equivalence-metrics">Human Equivalence Metrics</h3>
<p>QuAC introduces evaluation metrics beyond traditional F1 scores:</p>
<p><strong>HEQ-Q (Human Equivalence Question-level)</strong>: Percentage of questions where the model&rsquo;s F1 matches or exceeds human F1</p>
<p><strong>HEQ-D (Human Equivalence Dialogue-level)</strong>: Percentage of dialogues where the model&rsquo;s F1 matches or exceeds human F1 on every question</p>
<p><strong>Current results</strong>:</p>
<ul>
<li>Human baseline: 100% HEQ-Q, 100% HEQ-D (by definition)</li>
<li>Best model: 55.1% HEQ-Q, 5.2% HEQ-D</li>
</ul>
<p>Alongside average F1, these metrics capture how consistently a model matches human performance across individual questions and whole conversations, which matters for practical dialogue systems.</p>
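<p>Given per-question F1 scores for a model and for humans, both HEQ metrics reduce to simple counting. The function name and the nested-list input layout are assumptions for this sketch.</p>

```python
from typing import List, Tuple


def heq(model_f1: List[List[float]],
        human_f1: List[List[float]]) -> Tuple[float, float]:
    """Each inner list holds per-question F1 for one dialogue.

    Returns (HEQ-Q, HEQ-D) as percentages.
    """
    question_hits = 0
    question_total = 0
    dialogue_hits = 0
    for model_dialogue, human_dialogue in zip(model_f1, human_f1):
        hits = [m >= h for m, h in zip(model_dialogue, human_dialogue)]
        question_hits += sum(hits)
        question_total += len(hits)
        dialogue_hits += all(hits)  # counts only if every question is a hit
    return (100.0 * question_hits / question_total,
            100.0 * dialogue_hits / len(model_f1))


model = [[0.9, 0.5], [0.8, 0.8]]
human = [[0.8, 0.6], [0.8, 0.7]]
heq_q, heq_d = heq(model, human)
# 3 of 4 questions and 1 of 2 dialogues reach human level
```

<p>The example shows why HEQ-D is so much harder than HEQ-Q: a single miss anywhere in a dialogue disqualifies the whole dialogue.</p>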
<h2 id="research-impact">Research Impact</h2>
<p>QuAC represents an important step in question answering research by introducing realistic conversational dynamics that existing datasets lack. The student-teacher framework captures natural information-seeking behavior while maintaining extractive evaluation for objective assessment.</p>
<p><strong>Key contributions</strong>:</p>
<ul>
<li><strong>Conversational realism</strong>: Context-dependent questions that mirror dialogue patterns</li>
<li><strong>Coreference complexity</strong>: Integration of challenging NLP problems into QA evaluation</li>
<li><strong>Evaluation metrics</strong>: HEQ scores that measure consistency alongside average performance</li>
<li><strong>Large-scale framework</strong>: Substantial dataset enabling robust model training and evaluation</li>
</ul>
<p>The dataset&rsquo;s <a href="https://quac.ai/">leaderboard</a> provides researchers with a challenging benchmark for developing conversational AI systems. As models improve on QuAC, we can expect progress in dialogue agents, virtual assistants, and educational AI systems that engage in more natural, context-aware conversations.</p>
<p>QuAC&rsquo;s focus on dialogue context and reference resolution pushes the field toward AI systems that can engage in genuine conversation and understand complex dialogue flows.</p>
<h2 id="a-builders-perspective-quac-and-modern-instruction-tuning">A Builder&rsquo;s Perspective: QuAC and Modern Instruction Tuning</h2>
<p>Looking at QuAC through the lens of modern production ML, the student-teacher framework maps directly onto how we now train and evaluate assistants. Today, we train foundation models using Reinforcement Learning from Human Feedback (RLHF) and instruction tuning, which rely heavily on multi-turn, context-aware interactions.</p>
<p>When building a system like GutenOCR, users rarely ask perfectly formulated, context-free questions. They ask follow-ups, use pronouns, and expect the system to act as a knowledgeable &ldquo;teacher&rdquo; guiding them through the document. QuAC was an early dataset to formalize this asymmetric information dynamic. It highlighted the necessity of handling unanswerable questions gracefully, a critical feature for preventing hallucinations in today&rsquo;s production LLMs.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{choi-etal-2018-quac,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">&#34;{Q}u{AC}: Question Answering in Context&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">&#34;Choi, Eunsol and He, He and Iyyer, Mohit and Yatskar, Mark and Yih, Wen-tau and Choi, Yejin and Liang, Percy and Zettlemoyer, Luke&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">&#34;Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">month</span> = oct # <span style="color:#e6db74">&#34;-&#34;</span> # nov,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">&#34;2018&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">address</span> = <span style="color:#e6db74">&#34;Brussels, Belgium&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">publisher</span> = <span style="color:#e6db74">&#34;Association for Computational Linguistics&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">&#34;https://aclanthology.org/D18-1241/&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">&#34;10.18653/v1/D18-1241&#34;</span>,
</span></span><span style="display:flex;"><span>    <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">&#34;2174--2184&#34;</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>CoQA Dataset: Advancing Conversational Question Answering</title><link>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</link><pubDate>Thu, 23 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/coqa-conversation-question-answering/</guid><description>Analysis of CoQA, a conversational QA dataset with multi-turn dialogue, coreference resolution, and natural answers for QA research.</description><content:encoded><![CDATA[<h2 id="introduction">Introduction</h2>
<p>The <a href="https://doi.org/10.1162/tacl_a_00266">CoQA dataset</a> (Reddy et al., 2019) introduces conversational dynamics to question answering research. Where previous datasets focused on isolated question-answer pairs, CoQA requires models to maintain context across multi-turn conversations while reading and reasoning about text passages.</p>
<p>This dataset addresses a gap in conversational AI research by providing a benchmark for systems that must understand dialogue flow and implicit references. These are key components of natural human conversation.</p>
<p>For related work on conversational question answering, see my analysis of <a href="/posts/quac-question-answering-in-context/">QuAC</a>.</p>
<h2 id="what-makes-conversational-qa-different">What Makes Conversational QA Different</h2>
<p>Conversational question answering introduces challenges beyond traditional reading comprehension:</p>
<ol>
<li><strong>Context dependency</strong>: Questions rely on previous dialogue turns for meaning</li>
<li><strong>Coreference resolution</strong>: Understanding pronouns and implicit references</li>
<li><strong>Abstractive answering</strong>: Rephrasing information to generate natural responses</li>
<li><strong>Multi-turn reasoning</strong>: Maintaining coherent dialogue across multiple exchanges</li>
</ol>
<p>These requirements differentiate CoQA from existing question answering datasets that treat each question independently.</p>
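<p>To make the context dependency concrete, here is a minimal sketch of one common way to feed dialogue history to a model: prepend the most recent turns to the current question. The <code>Turn</code> class, the <code>build_model_input</code> helper, the <code>q:</code>/<code>a:</code> markers, and the toy passage are all invented for illustration and are not part of the CoQA release.</p>

```python
from dataclasses import dataclass


@dataclass
class Turn:
    question: str
    answer: str


def build_model_input(passage, history, question, max_history=2):
    """Flatten the passage, the most recent turns, and the current question."""
    # history[-0:] would return the whole list, so handle 0 explicitly.
    recent = history[-max_history:] if max_history else []
    parts = [f"passage: {passage}"]
    for turn in recent:
        parts.append(f"q: {turn.question}")
        parts.append(f"a: {turn.answer}")
    parts.append(f"q: {question}")
    return " ".join(parts)


# Toy conversation (invented): "it" only makes sense given the earlier turns.
passage = "Mia adopted a cat named Pixel. Pixel sleeps on the windowsill."
history = [Turn("Who adopted a cat?", "Mia"), Turn("What is its name?", "Pixel")]
prompt = build_model_input(passage, history, "Where does it sleep?")
print(prompt)
```

<p>Without the earlier turns, the final question contains no entity that a model could look up in the passage, which is the dependency the list above describes.</p>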
<h2 id="why-coqa-matters">Why CoQA Matters</h2>
<p>Question answering systems typically excel at finding specific information in text. However, they often struggle with natural conversation. Human communication involves building on previous exchanges, using pronouns and implicit references, and expressing ideas in varied ways.</p>
<p>CoQA addresses this by creating a large-scale dataset for conversational question answering with three primary characteristics:</p>
<ol>
<li>
<p><strong>Conversation-dependent questions</strong>: After the first question in a conversation, every subsequent question depends on the dialogue history. The dataset contains 127,000 questions across 8,000 conversations.</p>
</li>
<li>
<p><strong>Natural, abstractive answers</strong>: CoQA requires rephrased responses that sound natural in conversation. The answerer first highlighted the relevant text span, then rephrased the information.</p>
</li>
<li>
<p><strong>Domain diversity</strong>: Training covers 5 domains with testing on 7 domains, including 2 unseen during training</p>
</li>
</ol>
<p>The performance gap is notable: humans achieve 88.8% F1 score while the best models at the time reached 65.1% F1, indicating substantial room for improvement.</p>
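<p>The F1 scores quoted here are word-overlap scores between a predicted answer and a reference answer. The sketch below shows the core computation only. As far as I understand, official evaluation scripts also normalize punctuation and articles and compare against multiple human references, which this simplified <code>token_f1</code> helper (a name chosen for illustration) omits.</p>

```python
from collections import Counter


def token_f1(prediction, reference):
    """Word-overlap F1 between a predicted and a reference answer."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        # Two empty answers match exactly; one empty answer scores zero.
        return float(pred_tokens == ref_tokens)
    pred_counts = Counter(pred_tokens)
    ref_counts = Counter(ref_tokens)
    overlap = sum(min(count, ref_counts[tok]) for tok, count in pred_counts.items())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = token_f1("in the garden", "the garden")
print(round(score, 3))
```

<p>Because CoQA answers are free-form, a partial-credit metric like this is more forgiving than exact match: a rephrased answer that shares most words with the reference still scores well.</p>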
<h2 id="dataset-construction">Dataset Construction</h2>
<p>CoQA was constructed using Amazon Mechanical Turk, pairing workers in a question-answer dialogue setup. One worker asked questions about a given passage while another provided answers. The answerer first highlighted the relevant text span, then rephrased the information using different words to create natural, abstractive responses.</p>
<p>Because answers are rephrased instead of copied verbatim, they sound conversational, which makes the dataset more realistic for dialogue applications.</p>
<h3 id="domain-coverage">Domain Coverage</h3>
<p>CoQA spans diverse text types to ensure evaluation across different writing styles and topics:</p>
<p><strong>Training domains (5):</strong></p>
<ul>
<li>Children&rsquo;s stories from <a href="https://web.archive.org/web/20180829214346/https://uclmr.github.io/ai4exams/data.html#mctest">MCTest</a></li>
<li>Literature from <a href="https://www.gutenberg.org/">Project Gutenberg</a></li>
<li>Educational content from <a href="https://www.cs.cmu.edu/~glai1/data/race/">RACE</a> (middle/high school English)</li>
<li>CNN news articles</li>
<li>Wikipedia articles</li>
</ul>
<p><strong>Test-only domains (2):</strong></p>
<ul>
<li>Science articles from <a href="http://data.allenai.org/ai2-science-questions/">AI2 Science Questions</a></li>
<li>Creative writing from <a href="https://www.reddit.com/r/WritingPrompts/">Reddit WritingPrompts</a></li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_domains.webp"
         alt="Domain distribution in the CoQA dataset"
         title="Domain distribution in the CoQA dataset"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Domain distribution in the CoQA dataset</figcaption>
    
</figure>

<p>The inclusion of test-only domains provides a rigorous evaluation of model generalization to unseen text types.</p>
<h2 id="comparison-with-existing-datasets">Comparison with Existing Datasets</h2>
<p>Prior to CoQA, the dominant question answering benchmark was <a href="https://rajpurkar.github.io/SQuAD-explorer/">SQuAD (Stanford Question Answering Dataset)</a>. SQuAD established foundations for reading comprehension and presented specific constraints:</p>
<ul>
<li><strong>SQuAD 1.0</strong>: 100,000+ questions requiring exact text extraction from Wikipedia passages</li>
<li><strong>SQuAD 2.0</strong>: Added 50,000+ unanswerable questions to test when no answer exists</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_size.webp"
         alt="Scale comparison between SQuAD and CoQA datasets"
         title="Scale comparison between SQuAD and CoQA datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Scale comparison between SQuAD and CoQA datasets</figcaption>
    
</figure>

<p>SQuAD treats each question independently and requires only extractive answers. CoQA addresses these constraints through conversational context and abstractive responses.</p>
<h3 id="question-and-answer-analysis">Question and Answer Analysis</h3>
<p>The differences between SQuAD and CoQA extend beyond conversational context:</p>
<p><strong>Question diversity</strong>: SQuAD heavily favors &ldquo;what&rdquo; questions (~50%). CoQA shows a more balanced distribution across question types, reflecting natural conversation patterns.</p>















<figure class="post-figure center ">
    <img src="/img/squad_v_coqa.webp"
         alt="Question type distribution comparison between SQuAD and CoQA"
         title="Question type distribution comparison between SQuAD and CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Question type distribution comparison between SQuAD and CoQA</figcaption>
    
</figure>

<p><strong>Context dependence</strong>: CoQA includes challenging single-word questions like &ldquo;who?&rdquo;, &ldquo;where?&rdquo;, or &ldquo;why?&rdquo; that depend entirely on dialogue history.</p>
<p><strong>Answer characteristics</strong>: CoQA answers vary significantly in length and style. SQuAD primarily features extractive spans.</p>















<figure class="post-figure center ">
    <img src="/img/squad_coqa_answers.webp"
         alt="Answer length distribution in SQuAD vs CoQA"
         title="Answer length distribution in SQuAD vs CoQA"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Answer length distribution in SQuAD vs CoQA</figcaption>
    
</figure>

<h2 id="the-coreference-challenge">The Coreference Challenge</h2>
<p>CoQA&rsquo;s difficulty stems largely from its reliance on coreference resolution (determining when different expressions refer to the same entity). This remains a challenging research problem in NLP.</p>
<p><strong>Coreference types in CoQA</strong>:</p>
<ul>
<li><strong>Explicit coreferences</strong> (~50% of questions): Clear indicators like pronouns (&ldquo;him,&rdquo; &ldquo;it,&rdquo; &ldquo;her,&rdquo; &ldquo;that&rdquo;)</li>
<li><strong>Implicit coreferences</strong> (~20% of questions): Context-dependent references requiring inference (e.g., asking &ldquo;where?&rdquo; without specifying what)</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/coqa_coreferences.webp"
         alt="Distribution of coreference types in CoQA questions"
         title="Distribution of coreference types in CoQA questions"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Distribution of coreference types in CoQA questions</figcaption>
    
</figure>

<p>These linguistic phenomena make CoQA more difficult than traditional reading comprehension, as models must resolve references across dialogue turns while maintaining conversational coherence.</p>
<h2 id="performance-benchmarks">Performance Benchmarks</h2>
<p>Models faced significant challenges on CoQA, with substantial room for improvement:</p>















<figure class="post-figure center ">
    <img src="/img/coqa_scores.webp"
         alt="Performance comparison on CoQA across different model types"
         title="Performance comparison on CoQA across different model types"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Performance comparison on CoQA across different model types</figcaption>
    
</figure>

<p>The performance gap between human and machine capabilities highlighted conversational question answering as a challenging frontier in NLP research.</p>
<h2 id="research-impact-and-future-directions">Research Impact and Future Directions</h2>
<p>CoQA represents a step toward more natural conversational AI systems. By requiring models to handle dialogue context, coreference resolution, and abstractive reasoning simultaneously, it challenges current NLP system capabilities.</p>
<p>The dataset&rsquo;s <a href="https://stanfordnlp.github.io/coqa/">leaderboard</a> provides a benchmark for measuring progress on this task. As models improve on CoQA, we can expect advances in conversational AI applications, from chatbots to virtual assistants that engage in more natural, context-aware dialogue.</p>
<p>CoQA aims to do for conversational question answering what ImageNet did for computer vision: provide a challenging, well-constructed benchmark that drives research toward more capable AI systems.</p>
<h2 id="a-builders-perspective-coqa-in-the-era-of-llms">A Builder&rsquo;s Perspective: CoQA in the Era of LLMs</h2>
<p>Looking back at CoQA from the perspective of modern production systems, the dataset anticipated where the field went. The challenges it introduced, such as multi-turn reasoning, coreference resolution, and abstractive answering, are the exact capabilities we now expect from instruction-tuned Large Language Models (LLMs).</p>
<p>Production document-processing pipelines rarely extract isolated facts. Users want to chat with their documents, asking follow-up questions like, &ldquo;What does that mean for the Q3 budget?&rdquo; Resolving &ldquo;that&rdquo; to a previous turn&rsquo;s context is exactly the problem CoQA formalized. Datasets like CoQA shifted the field&rsquo;s focus from simple extraction toward dialogue comprehension, the foundation modern conversational document interfaces are built on.</p>
<h2 id="references">References</h2>
<p>Reddy, S., Chen, D., &amp; Manning, C. D. (2019). CoQA: A conversational question answering challenge. <em>Transactions of the Association for Computational Linguistics</em>, 7, 249-266.</p>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers. The adversarial training dynamics and objective functions introduced by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how the field thinks about loss-function design and training stability today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>This is equivalent to minimizing the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler (KL) divergence</a> from the true data distribution to our estimated distribution. KL divergence measures how different two distributions are, although it is asymmetric and therefore not a true distance. Maximizing log-likelihood and minimizing this KL divergence select the same $\theta^*$.</p>
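<p>A short sketch of why the two views agree, writing $p_{\text{data}}$ for the true data distribution and $H$ for its entropy: taking logs turns the product into a sum, and for large $n$ the average log-likelihood approaches an expectation under $p_{\text{data}}$.</p>
<p>$$
\frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i) \approx \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)] = -H(p_{\text{data}}) - \text{KL}(p_{\text{data}} \,\|\, p_\theta)
$$</p>
<p>The entropy term does not depend on $\theta$, so the $\theta$ that maximizes the left-hand side is the one that minimizes $\text{KL}(p_{\text{data}} \,\|\, p_\theta)$.</p>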
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; to the intractable posterior, used to optimize a lower bound on the likelihood</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models learn to generate data by indirectly sampling from a learned distribution. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>















<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>















<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent space is continuous, and a well-trained generator tends to map small changes in $z$ to smooth, often meaningful changes in the generated output.</p>
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(x \text{ is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>A criminal forger tries to create fake masterpieces, while an art critic must identify authentic works. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: Criminal brings a crayon drawing of a dollar bill. Even a new teller spots this fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability for real data.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
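<p>As a rough illustration of the two terms, the value function can be estimated from a batch of discriminator outputs. The <code>gan_value</code> helper and the sample probabilities below are made up for this sketch; a real implementation would compute the same averages over network outputs.</p>

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs in (0, 1)."""
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A confident discriminator: high D(x) on real data, low D(G(z)) on fakes.
confident = gan_value(d_real=[0.9, 0.95], d_fake=[0.1, 0.05])
# A fully fooled discriminator: 0.5 everywhere, giving V = -log 4.
fooled = gan_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])
print(confident, fooled)
```

<p>The discriminator pushes this value toward 0 (its maximum), while the generator pulls it down by making $D(G(z))$ larger.</p>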
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
<p><strong>The Solution</strong>: Practitioners typically train the generator to <strong>maximize</strong> $\log(D(G(z)))$ to provide stronger gradients early in training. This non-saturating heuristic prevents the learning process from stalling.</p>
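<p>The difference between the two generator objectives is easiest to see in their gradients. The sketch below assumes the discriminator ends in a sigmoid, so $D(G(z)) = \sigma(a)$ for some pre-activation value $a$ (the <code>logit</code> argument, a name introduced here), and compares the derivative of each objective with respect to $a$.</p>

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(logit):
    """d/d(logit) of log(1 - sigmoid(logit)): the original minimax generator term."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of log(sigmoid(logit)): the non-saturating generator objective."""
    return 1.0 - sigmoid(logit)


# Very negative logits mean the discriminator confidently rejects the fake.
for logit in (-8.0, -2.0, 0.0):
    print(logit, abs(saturating_grad(logit)), abs(non_saturating_grad(logit)))
```

<p>When the discriminator is confident ($a$ very negative), the minimax gradient shrinks toward zero while the non-saturating gradient stays close to 1, which is why the heuristic keeps the generator learning early on.</p>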
<h3 id="the-training-process">The Training Process</h3>
<p>The beauty of GANs lies in their alternating optimization:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Update the discriminator against the current generator (to optimality in theory, for a few gradient steps in practice)</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue until reaching Nash equilibrium</li>
</ol>
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At convergence, the discriminator outputs $D(x) = 0.5$ for all samples, meaning it can&rsquo;t distinguish between real and fake data. This indicates that $p_{\text{generator}} = p_{\text{data}}$. Our generator has learned the true data distribution.</p>
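<p>The 0.5 value follows from the result in the original GAN paper that, for a fixed generator, the optimal discriminator is $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_{\text{generator}}(x))$. The sketch below evaluates that formula on a small discrete distribution with made-up probabilities, assuming every point has nonzero mass under at least one of the two distributions.</p>

```python
def optimal_discriminator(p_data, p_gen):
    """Pointwise D*(x) = p_data(x) / (p_data(x) + p_gen(x)) on a discrete support."""
    return [real / (real + fake) for real, fake in zip(p_data, p_gen)]


p_data = [0.1, 0.4, 0.5]
mismatched = optimal_discriminator(p_data, [0.5, 0.4, 0.1])
matched = optimal_discriminator(p_data, p_data)
print(mismatched)
print(matched)
```

<p>While the generator&rsquo;s distribution differs from the data, the optimal discriminator leans above or below 0.5 wherever one distribution has more mass. Once the two match, it outputs 0.5 at every point.</p>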
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
<p>The foundational GAN minimizes the Jensen-Shannon Divergence:</p>
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \,\|\, M) + \frac{1}{2} \text{KL}(Q \,\|\, M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
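<p>The vanishing-gradient limitation can be seen directly from the definition. The following sketch computes JSD for small discrete distributions (in nats) and shows that once two distributions stop overlapping, the value is stuck at $\log 2$ no matter how far apart they are. The helper names and example histograms are illustrative.</p>

```python
import math


def kl(p, q):
    """KL(p || q) in nats; terms with p[i] == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd(p, q):
    """Jensen-Shannon divergence via the mixture M = (P + Q) / 2."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


identical = jsd([0.5, 0.5, 0.0], [0.5, 0.5, 0.0])
disjoint_near = jsd([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
disjoint_far = jsd([1.0, 0.0, 0.0], [0.0, 0.0, 1.0])
print(identical, disjoint_near, disjoint_far)
```

<p>Both disjoint cases give the same value, so moving the generated mass closer to the real mass does not change the divergence, and a generator trained against it receives no signal about which direction to move.</p>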
<h3 id="wasserstein-gan-wgan">Wasserstein GAN (WGAN)</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> replaced Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance, which gives meaningful gradients even when the real and generated distributions do not overlap.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass × distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
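W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
<p>Here $\Gamma(\mu, \nu)$ is the set of joint distributions (transport plans) on $M \times M$ whose marginals are $\mu$ and $\nu$, and $d$ is the distance on the underlying space $M$. WGAN uses the $p = 1$ case.</p>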
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
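<p>In one dimension the Earth-Mover distance has a simple closed form: it equals the area between the two cumulative distribution functions. The sketch below uses that fact for histograms on a shared, evenly spaced grid. The <code>wasserstein_1d</code> helper and its <code>bin_width</code> parameter are names chosen for this example.</p>

```python
def wasserstein_1d(p, q, bin_width=1.0):
    """W1 between two histograms on the same evenly spaced 1-D grid.

    In one dimension W1 equals the area between the two CDFs.
    """
    cdf_gap = 0.0
    total = 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        total += abs(cdf_gap) * bin_width
    return total


near = wasserstein_1d([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
far = wasserstein_1d([1.0, 0.0, 0.0], [0.0, 0.0, 1.0])
print(near, far)
```

<p>A unit of mass moved one bin costs 1 and moved two bins costs 2. The distance keeps growing as the distributions separate, so it still tells the generator how far off it is when the two distributions do not overlap.</p>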
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Since we can&rsquo;t compute Wasserstein distance directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be 1-Lipschitz (using weight clipping)</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>
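<p>The three steps above can be sketched with plain Python lists standing in for critic scores and weights. This is a simplified illustration, not the paper&rsquo;s training code: the function names are invented, the scores are made-up numbers, and the clipping threshold <code>c=0.01</code> simply mirrors the value used in the snippet later in this post.</p>

```python
def mean(values):
    return sum(values) / len(values)


def critic_loss(scores_real, scores_fake):
    """Loss the critic minimizes; its negation estimates the Wasserstein distance."""
    return mean(scores_fake) - mean(scores_real)


def generator_loss(scores_fake):
    """The generator tries to raise the critic's scores on generated samples."""
    return -mean(scores_fake)


def clip_weights(weights, c=0.01):
    """Weight clipping: a crude way to bound the critic's Lipschitz constant."""
    return [max(-c, min(c, w)) for w in weights]


loss = critic_loss(scores_real=[1.2, 0.8], scores_fake=[-0.5, -0.7])
clipped = clip_weights([0.5, -0.003, -2.0])
print(loss, clipped)
```

<p>Unlike the discriminator in the original GAN, the critic outputs unbounded scores instead of probabilities, so there is no logarithm or sigmoid in either loss.</p>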















<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to maintain the 1-Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = E_{\tilde{x} \sim P_g}[D(\tilde{x})] - E_{x \sim P_r}[D(x)] + \lambda E_{\hat{x}}[(||\nabla_{\hat{x}} D(\hat{x})||_2 - 1)^2]
$$</p>
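<p>Here $P_r$ and $P_g$ are the real and generated data distributions, $\lambda$ is the penalty weight, and $\hat{x}$ are points sampled uniformly along straight lines between pairs of real and generated data points.</p>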
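<p>A minimal sketch of the penalty term, under a simplifying assumption: instead of using automatic differentiation, the caller supplies a <code>critic_grad</code> function (hypothetical, for illustration) that returns the critic&rsquo;s gradient at a point. The default $\lambda = 10$ follows the value commonly used with WGAN-GP.</p>

```python
import math
import random


def gradient_penalty(real, fake, critic_grad, lam=10.0, seed=0):
    """Average of lam * (||grad D(x_hat)||_2 - 1)^2 over interpolated points."""
    rng = random.Random(seed)
    penalties = []
    for x_real, x_fake in zip(real, fake):
        eps = rng.random()  # uniform position on the segment between the pair
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(x_real, x_fake)]
        grad = critic_grad(x_hat)
        norm = math.sqrt(sum(g * g for g in grad))
        penalties.append(lam * (norm - 1.0) ** 2)
    return sum(penalties) / len(penalties)


# Hypothetical linear critic D(x) = w . x, whose gradient is w everywhere.
w = [3.0, 4.0]
penalty = gradient_penalty(
    real=[[1.0, 2.0], [0.0, 1.0]],
    fake=[[0.5, 0.5], [2.0, 2.0]],
    critic_grad=lambda x_hat: w,
)
print(penalty)
```

<p>The penalty is zero only when the gradient norm is exactly 1 at the sampled points, which softly encourages the 1-Lipschitz behavior that weight clipping tried to force. In a deep-learning framework the gradient would come from autograd instead of a hand-written function.</p>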
<p><strong>Advantages</strong>:</p>
<ul>
<li>Avoids the capacity restrictions that clipping imposes</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes proportionally to distance</strong>:</p>
<p>$$
\min_D V_{LSGAN}(D) = \frac{1}{2}E_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{LSGAN}(G) = \frac{1}{2}E_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
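<p>The two objectives translate almost directly into code. The NumPy sketch below assumes the discriminator outputs are already available as arrays; the function names <code>lsgan_d_loss</code> and <code>lsgan_g_loss</code> are illustrative, not taken from the paper.</p>

```python
import numpy as np


def lsgan_d_loss(d_real, d_fake, a=0.0, b=1.0):
    """Discriminator loss: push D(x) toward b and D(G(z)) toward a."""
    d_real, d_fake = np.asarray(d_real), np.asarray(d_fake)
    return 0.5 * np.mean((d_real - b) ** 2) + 0.5 * np.mean((d_fake - a) ** 2)


def lsgan_g_loss(d_fake, c=1.0):
    """Generator loss: push D(G(z)) toward c."""
    return 0.5 * np.mean((np.asarray(d_fake) - c) ** 2)


# A fake scored 0.9 is nearly at the target c = 1; one scored 0.1 is far from it,
# so the second generator loss is much larger than the first.
print(lsgan_g_loss([0.9]), lsgan_g_loss([0.1]))
```

<p>Unlike log loss, the penalty keeps growing with the distance between the score and its target, so samples far from the target still produce a useful gradient.</p>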
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>















<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

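<p><strong>Key insight</strong>: When the labels satisfy $b - c = 1$ and $b - a = 2$ (for example $a = -1$, $b = 1$, $c = 0$), minimizing the LSGAN objective minimizes the Pearson χ² divergence between $p_{data} + p_g$ and $2p_g$, a different target than the JSD of the original GAN. The 0-1 labels shown above do not meet this condition exactly; the paper reports similar performance for both codings.</p>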
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: RWGAN introduces an asymmetric approach that provides better balance.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with other Wasserstein-based methods</li>
</ul>
<p><strong>Key insight</strong>: RWGAN parameterized with KL divergence shows excellent performance while maintaining the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>Moment-matching objectives like this are conceptually related to settings where aligning statistical moments matters, such as matching a generated distribution to a target physical distribution (e.g., molecular conformations). McGAN itself, however, was introduced and demonstrated as an IPM method for image generation.</p>
<p><strong>Limitation</strong>: Relies on weight clipping like original WGAN.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> eliminate the discriminator entirely and directly minimize the <strong>Maximum Mean Discrepancy (MMD)</strong> between generated and real samples.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by their means in a high-dimensional feature space:</p>
<p>$$
\text{MMD}^2(X, Y) = ||E[\phi(x)] - E[\phi(y)]||^2
$$</p>
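<p>In practice the feature map $\phi$ is never computed explicitly. Expanding the squared norm leaves only inner products, which a kernel $k(x, y)$ evaluates directly. The NumPy sketch below uses a single Gaussian kernel with an assumed bandwidth <code>sigma</code> and the simple biased estimator; function names are illustrative.</p>

```python
import numpy as np


def gaussian_kernel(a, b, sigma=1.0):
    """k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2)) for all pairs."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma**2))


def mmd2(x, y, sigma=1.0):
    """Biased sample estimate of MMD^2 using the kernel trick."""
    k_xx = gaussian_kernel(x, x, sigma).mean()
    k_yy = gaussian_kernel(y, y, sigma).mean()
    k_xy = gaussian_kernel(x, y, sigma).mean()
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
same = mmd2(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = mmd2(rng.normal(size=(200, 2)), rng.normal(loc=3.0, size=(200, 2)))
print(same, shifted)  # the shifted pair should give the larger value
```

<p>Every sample is compared with every other sample, so the cost grows quadratically with batch size.</p>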
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves on GMMN by <strong>learning the kernel adversarially</strong> in place of GMMN&rsquo;s fixed Gaussian kernels.</p>
<p><strong>Innovation</strong>: Combine GAN adversarial training with MMD objective for the best of both worlds.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
<p><strong>The Problem</strong>: The Cramer GAN paper lists three properties that an ideal distance for learning should have. The Wasserstein distance has the first two but not the third:</p>
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties:</p>
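<p>$$
d_C^2(\mu, \nu) = \int_{-\infty}^{\infty} (F_\mu(x) - F_\nu(x))^2 \, dx
$$</p>
<p>Here $F_\mu$ and $F_\nu$ are the cumulative distribution functions of $\mu$ and $\nu$. For multivariate data, Cramer GAN works with the energy distance, which generalizes this quantity.</p>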
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
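<p><strong>Distance</strong>: Approximates a symmetric <strong>Chi-square distance</strong> as critic capacity increases. Because the constraint averages the critic&rsquo;s second moment over both distributions, the denominator is the mixture $\frac{1}{2}(P + Q)$:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{\frac{1}{2}(P(x) + Q(x))} dx
$$</p>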
<p>With a linear critic over a fixed feature map, the Fisher GAN objective reduces to a Mahalanobis distance between the mean features of the two distributions, which accounts for correlations between features. The second-moment constraint keeps the critic bounded, and as the critic&rsquo;s capacity increases, the objective approaches the Chi-square distance.</p>
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>How it works</strong>:</p>
<ol>
<li>The autoencoder (acting as the discriminator) learns to reconstruct real data with low error</li>
<li>Generator creates samples</li>
<li>Poor generated samples have high reconstruction loss</li>
<li>This loss drives generator improvement</li>
</ol>
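<p>The steps above can be sketched with the margin-based losses from the EBGAN paper, where the reconstruction error plays the role of an energy. The function name <code>ebgan_losses</code> and the scalar inputs are assumptions made for illustration; in a real model the energies would be batch-averaged autoencoder reconstruction errors.</p>

```python
def ebgan_losses(energy_real, energy_fake, margin):
    """EBGAN losses from reconstruction errors (energies) of real and fake data."""
    # Discriminator: low energy on real data, at least `margin` on fakes
    d_loss = energy_real + max(0.0, margin - energy_fake)
    # Generator: produce samples the autoencoder reconstructs well
    g_loss = energy_fake
    return d_loss, g_loss
```

<p>Once a generated sample&rsquo;s energy exceeds the margin, the hinge term is zero and the discriminator stops pushing that sample further away.</p>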
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
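<p><strong>Discriminator loss and equilibrium update</strong> (here $L(\cdot)$ is the autoencoder&rsquo;s reconstruction loss, $\gamma$ is the target ratio between fake and real reconstruction losses, and $\lambda$ is the step size for $k_t$):</p>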
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
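<p>A minimal sketch of this bookkeeping is shown below, assuming the two reconstruction losses have already been computed as scalars. The function name <code>began_step</code> and the default <code>gamma</code> are illustrative; clamping $k_t$ to $[0, 1]$, the small step size, and the convergence measure $L(x) + |\gamma L(x) - L(G(z))|$ follow the BEGAN paper.</p>

```python
def began_step(loss_real, loss_fake, k, gamma=0.5, lambda_k=0.001):
    """One BEGAN bookkeeping step from scalar reconstruction losses."""
    d_loss = loss_real - k * loss_fake  # L_D = L(x) - k_t * L(G(z))
    g_loss = loss_fake  # L_G = L(G(z))
    balance = gamma * loss_real - loss_fake
    # Proportional control on k_t, kept inside [0, 1]
    k_next = min(max(k + lambda_k * balance, 0.0), 1.0)
    # Scalar that can be monitored to judge training progress
    convergence = loss_real + abs(balance)
    return d_loss, g_loss, k_next, convergence
```

<p>When generated samples are reconstructed too easily relative to real ones, <code>balance</code> is positive and $k_t$ rises, which makes the discriminator put more weight on separating fakes.</p>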
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: A good balance of stability and performance for many applications. The penalty coefficient $\lambda$ can still be sensitive to tune in practice.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work training VLMs on terabyte-scale multimodal data and forecasting chaotic physical systems, these foundational dynamics still matter. Most generation today runs on diffusion models or autoregressive transformers, but the loss-design and training-stability lessons that came out of GAN research carry over. The choice of objective function shapes generation quality, training stability, and compute cost.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item><item><title>Word Embeddings in NLP: An Introduction</title><link>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</link><pubDate>Sun, 05 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/intro-to-word-embeddings/</guid><description>Learn about word embeddings in NLP: from basic one-hot encoding to contextual models like ELMo. Guide with examples.</description><content:encoded><![CDATA[<h2 id="understanding-word-embeddings">Understanding Word Embeddings</h2>
<p>A word embedding maps words to real-valued vectors:</p>
<p>$$
\text{word} \rightarrow \mathbb{R}^n
$$</p>
<p>where $n$ represents the dimensionality of the embedding space.</p>
<p>The goal is simple: position semantically similar words close together in vector space. This dense representation typically uses a few hundred dimensions, a massive reduction from the vocabulary-sized vectors (tens of thousands to millions of dimensions) required by one-hot encoding.</p>
<p>Word embeddings are grounded in <a href="https://en.wikipedia.org/wiki/Distributional_semantics">Zellig Harris&rsquo; distributional hypothesis</a>: words appearing in similar contexts tend to have similar meanings. This forms the foundation of distributional semantics.</p>















<figure class="post-figure center ">
    <img src="/img/distributional_semantics-50.webp"
         alt="Distributional semantics visualization"
         title="Distributional semantics visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Words embedded in three-dimensional space, organized by semantic similarity</figcaption>
    
</figure>

<p>Different embedding algorithms capture various aspects of this distributional principle. This post explores the main methods for creating word embeddings and their applications in natural language processing.</p>
<p>While modern foundation models and large Vision-Language Models rely on subword tokenizers (like BPE) and Transformer embedding layers, the goal is the same: mapping discrete text to a continuous vector space where math can capture meaning. These foundational techniques build the intuition for the embedding layers in today&rsquo;s models.</p>
<h2 id="why-word-embeddings-matter-in-nlp">Why Word Embeddings Matter in NLP</h2>
<p>Computers require numerical representations to apply machine learning algorithms to text. Word embeddings bridge this gap by converting text into dense vectors that preserve semantic and syntactic relationships.</p>
<p><strong>Key advantages:</strong></p>
<ol>
<li><strong>Dense representation</strong>: Hundreds of dimensions provide a compact alternative to vocabulary-sized sparse vectors.</li>
<li><strong>Semantic preservation</strong>: Similar words cluster together in vector space.</li>
<li><strong>Mathematical operations</strong>: Enable analogical reasoning ($\text{king} - \text{man} + \text{woman} \approx \text{queen}$).</li>
<li><strong>Transfer learning</strong>: Pre-trained embeddings work across multiple tasks and domains.</li>
</ol>
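<p>To make the similarity and analogy properties concrete, here is a toy sketch with hand-made two-dimensional vectors. The vectors are invented for illustration (one axis loosely stands for &ldquo;royalty&rdquo;, the other for &ldquo;gender&rdquo;); real embeddings are learned from data and have hundreds of dimensions.</p>

```python
import numpy as np

# Hand-made 2-D vectors for illustration only: axis 0 ~ "royalty", axis 1 ~ "gender"
vectors = {
    "king": np.array([1.0, 1.0]),
    "queen": np.array([1.0, -1.0]),
    "man": np.array([0.0, 1.0]),
    "woman": np.array([0.0, -1.0]),
    "apple": np.array([-1.0, 0.2]),
}


def cosine(u, v):
    """Cosine similarity between two vectors."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


def analogy(a, b, c):
    """Return the word closest to vec(a) - vec(b) + vec(c), excluding the inputs."""
    target = vectors[a] - vectors[b] + vectors[c]
    candidates = [w for w in vectors if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(vectors[w], target))


print(analogy("king", "man", "woman"))  # queen
```

<p>Excluding the query words from the candidates mirrors common practice in analogy evaluation, since the nearest vector to the target is often one of the inputs.</p>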
<p>Modern deep learning architectures leverage these properties extensively. The development of universal, pre-trained embeddings was a significant step forward. We can use versatile embeddings that generalize across applications, eliminating the need to train task-specific representations from scratch.</p>
<h2 id="word-embedding-approaches">Word Embedding Approaches</h2>
<h3 id="one-hot-encoding-and-count-vectorization">One-Hot Encoding and Count Vectorization</h3>
<p>One-hot encoding represents the simplest approach to word vectorization. Each word gets a unique dimension in a vocabulary-sized vector, marked with 1 for presence and 0 elsewhere. Count vectorization extends this by counting the occurrences of each word in a document.</p>















<figure class="post-figure center ">
    <img src="/img/word_vector_onehot-50.webp"
         alt="One-hot encoding visualization"
         title="One-hot encoding visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">One-hot encoding creates sparse vectors with single active dimensions</figcaption>
    
</figure>

<p><strong>Characteristics:</strong></p>
<ul>
<li><strong>High dimensionality</strong>: Vector length equals vocabulary size.</li>
<li><strong>Extreme sparsity</strong>: Most dimensions contain zeros.</li>
<li><strong>No relationships</strong>: Treats all words as equally distant.</li>
<li><strong>Simplicity</strong>: Easy to implement and understand.</li>
</ul>
<p>While lacking semantic information, count vectorization serves as a foundation for more complex methods. Let&rsquo;s look at a practical implementation using scikit-learn&rsquo;s <code>CountVectorizer</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize the vectorizer</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Sample text for demonstration</span>
</span></span><span style="display:flex;"><span>sample_text <span style="color:#f92672">=</span> [<span style="color:#e6db74">&#34;One of the most basic ways we can numerically represent words &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;is through the one-hot encoding method (also sometimes called &#34;</span>
</span></span><span style="display:flex;"><span>               <span style="color:#e6db74">&#34;count vectorizing).&#34;</span>]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Fit the vectorizer to our text data</span>
</span></span><span style="display:flex;"><span>vectorizer<span style="color:#f92672">.</span>fit(sample_text)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Examine the vocabulary and word indices</span>
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Vocabulary:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vectorizer<span style="color:#f92672">.</span>vocabulary_)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform text to vectors</span>
</span></span><span style="display:flex;"><span>vector <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(sample_text)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">&#39;Full vector:&#39;</span>)
</span></span><span style="display:flex;"><span>print(vector<span style="color:#f92672">.</span>toarray())
</span></span></code></pre></div><p>At scale, count vectorization introduces engineering challenges. With millions of documents, the vocabulary grows large, and the sparse matrices become expensive to store and compute on. In these scaling scenarios, practitioners often turn to the <strong>Hashing Trick</strong> (via <code>HashingVectorizer</code>) to bound the dimensionality, or they move entirely to the dense embeddings discussed later in this post.</p>
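<p>The idea behind the hashing trick fits in a few lines. The sketch below is a plain-Python illustration, not scikit-learn&rsquo;s implementation: the function name <code>hashed_counts</code>, the use of <code>zlib.crc32</code>, and the tiny <code>n_features</code> are assumptions chosen to keep the example readable.</p>

```python
import re
import zlib


def hashed_counts(text, n_features=16):
    """Map tokens to a fixed number of buckets; no vocabulary is stored."""
    vector = [0] * n_features
    for token in re.findall(r"\w+", text.lower()):
        # crc32 is stable across runs, unlike Python's built-in hash() for str
        index = zlib.crc32(token.encode("utf-8")) % n_features
        vector[index] += 1
    return vector


vec = hashed_counts("the cat sat on the mat")
print(len(vec), sum(vec))  # 16 6
```

<p>The vector length stays fixed no matter how many distinct words appear. The price is that different words can collide in the same bucket, and the mapping cannot be inverted to recover the original tokens.</p>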
<p>We can see count vectorization in action with a real dataset, building a simple text classifier for the <a href="https://www.kaggle.com/datasets/crawford/20-newsgroups">20 Newsgroups dataset</a>:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.datasets <span style="color:#f92672">import</span> fetch_20newsgroups
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.feature_extraction.text <span style="color:#f92672">import</span> CountVectorizer
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn.naive_bayes <span style="color:#f92672">import</span> MultinomialNB
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> sklearn <span style="color:#f92672">import</span> metrics
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Load train and test splits, removing metadata for a cleaner signal</span>
</span></span><span style="display:flex;"><span>newsgroups_train <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;train&#39;</span>,
</span></span><span style="display:flex;"><span>                                      remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>newsgroups_test <span style="color:#f92672">=</span> fetch_20newsgroups(subset<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;test&#39;</span>,
</span></span><span style="display:flex;"><span>                                     remove<span style="color:#f92672">=</span>(<span style="color:#e6db74">&#39;headers&#39;</span>, <span style="color:#e6db74">&#39;footers&#39;</span>, <span style="color:#e6db74">&#39;quotes&#39;</span>))
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Initialize and fit vectorizer on training data</span>
</span></span><span style="display:flex;"><span>vectorizer <span style="color:#f92672">=</span> CountVectorizer()
</span></span><span style="display:flex;"><span>X_train <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>fit_transform(newsgroups_train<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Build and train classifier</span>
</span></span><span style="display:flex;"><span>classifier <span style="color:#f92672">=</span> MultinomialNB(alpha<span style="color:#f92672">=</span><span style="color:#ae81ff">0.01</span>)
</span></span><span style="display:flex;"><span>classifier<span style="color:#f92672">.</span>fit(X_train, newsgroups_train<span style="color:#f92672">.</span>target)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Transform test data and make predictions</span>
</span></span><span style="display:flex;"><span>X_test <span style="color:#f92672">=</span> vectorizer<span style="color:#f92672">.</span>transform(newsgroups_test<span style="color:#f92672">.</span>data)
</span></span><span style="display:flex;"><span>y_pred <span style="color:#f92672">=</span> classifier<span style="color:#f92672">.</span>predict(X_test)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Evaluate performance</span>
</span></span><span style="display:flex;"><span>accuracy <span style="color:#f92672">=</span> metrics<span style="color:#f92672">.</span>accuracy_score(newsgroups_test<span style="color:#f92672">.</span>target, y_pred)
</span></span><span style="display:flex;"><span>print(<span style="color:#e6db74">f</span><span style="color:#e6db74">&#39;Accuracy: </span><span style="color:#e6db74">{</span>accuracy<span style="color:#e6db74">:</span><span style="color:#e6db74">.3f</span><span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;</span>)
</span></span></code></pre></div><p>This provides a solid baseline. To capture actual semantic meaning and reduce dimensionality, we must move beyond simple counting.</p>
<h3 id="tf-idf-term-frequency-inverse-document-frequency">TF-IDF (Term Frequency-Inverse Document Frequency)</h3>
<p><a href="https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html">TF-IDF</a> extends count vectorization by weighting terms based on their importance across a document collection. It combines:</p>
<ul>
<li><strong>Term Frequency (TF)</strong>: How often a word appears in a document</li>
<li><strong>Inverse Document Frequency (IDF)</strong>: How rare a word is across all documents</li>
</ul>
<p>This weighting scheme reduces the impact of common words (like &ldquo;the&rdquo; or &ldquo;and&rdquo;) while emphasizing distinctive terms that appear frequently in specific documents but rarely elsewhere.</p>
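<p>A small worked example makes the weighting concrete. The sketch below uses the textbook definition, $\text{tf} \times \log(N / \text{df})$, on a three-document toy corpus; the function name <code>tf_idf</code> and whitespace tokenization are simplifications. scikit-learn&rsquo;s <code>TfidfVectorizer</code> uses a smoothed IDF and L2 normalization by default, so its numbers will differ.</p>

```python
import math
from collections import Counter


def tf_idf(docs):
    """Textbook TF-IDF: (count / doc length) * log(N / document frequency)."""
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears
    df = Counter(term for tokens in tokenized for term in set(tokens))
    weights = []
    for tokens in tokenized:
        counts = Counter(tokens)
        weights.append(
            {t: (c / len(tokens)) * math.log(n_docs / df[t]) for t, c in counts.items()}
        )
    return weights


docs = ["the cat sat", "the dog barked", "the cat purred"]
for term, weight in tf_idf(docs)[0].items():
    print(term, round(weight, 3))
```

<p>In the first document, &ldquo;the&rdquo; appears in every document and receives weight zero, while &ldquo;sat&rdquo; (unique to this document) outweighs &ldquo;cat&rdquo; (shared with one other document).</p>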
<p><strong>Advantages:</strong></p>
<ul>
<li>Captures document-level importance</li>
<li>Reduces impact of stop words</li>
<li>Effective for information retrieval tasks</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>Still high-dimensional and sparse</li>
<li>No semantic relationships between terms</li>
<li>Context-independent representation</li>
</ul>
<h3 id="co-occurrence-matrices">Co-Occurrence Matrices</h3>
<p>Co-occurrence matrices capture word relationships by recording which terms appear together within defined contexts (sentences, paragraphs, or fixed windows). The resulting matrix has one row and one column per vocabulary word, so its size grows with the square of the vocabulary, and each entry stores a co-occurrence frequency.</p>















<figure class="post-figure center ">
    <img src="/img/Word_co-occurrence_network_%28range_3_words%29_-_ENG-50.webp"
         alt="Co-occurrence network visualization"
         title="Co-occurrence network visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Co-occurrence relationships within a three-word window</figcaption>
    
</figure>

<p><strong>Key properties:</strong></p>
<ul>
<li><strong>Global statistics</strong>: Captures corpus-wide word relationships</li>
<li><strong>Symmetric relationships</strong>: Mutual co-occurrence patterns</li>
<li><strong>Extreme dimensionality</strong>: Vocabulary size squared creates storage challenges</li>
<li><strong>Sparse representation</strong>: Most word pairs never co-occur</li>
</ul>
<p>While computationally expensive to store and process, co-occurrence matrices form the foundation for advanced methods like GloVe that compress this information into dense representations.</p>
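<p>A fixed-window co-occurrence count can be sketched in a few lines. The function name <code>cooccurrence_counts</code> and the choice to store only observed pairs in a <code>Counter</code> (a simple sparse representation) are illustrative assumptions.</p>

```python
from collections import Counter


def cooccurrence_counts(tokens, window=2):
    """Count (word, context) pairs within `window` positions of each other."""
    counts = Counter()
    for i, word in enumerate(tokens):
        start = max(0, i - window)
        end = min(len(tokens), i + window + 1)
        for j in range(start, end):
            if j != i:
                counts[(word, tokens[j])] += 1
    return counts


tokens = "the cat sat on the mat".split()
counts = cooccurrence_counts(tokens, window=1)
print(counts[("the", "cat")], counts[("cat", "the")])  # 1 1
```

<p>Because each pair is counted from both sides, the counts are symmetric. Pairs that never co-occur are simply absent, which is what keeps storage manageable compared with a dense vocabulary-by-vocabulary array.</p>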
<h2 id="neural-network-based-embeddings">Neural Network-Based Embeddings</h2>
<h3 id="neural-probabilistic-language-models">Neural Probabilistic Language Models</h3>
<p><a href="https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf">Neural probabilistic models</a> pioneered the use of neural networks for learning word embeddings. These models learn dense representations as a byproduct of language modeling, predicting the next word in a sequence.</p>















<figure class="post-figure center ">
    <img src="/img/bengio-npm-50.webp"
         alt="Neural probabilistic model diagram"
         title="Neural probabilistic model diagram"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Architecture of neural probabilistic language models</figcaption>
    
</figure>

<p><strong>Training process:</strong></p>
<ol>
<li>Initialize random dense embeddings for each vocabulary word</li>
<li>Use embeddings as inputs to predict language modeling objectives</li>
<li>Update embeddings through backpropagation based on prediction errors</li>
<li>Resulting embeddings capture patterns useful for the training task</li>
</ol>
<p>This approach demonstrated that task-specific embeddings could be learned jointly with model objectives, establishing the foundation for modern embedding methods.</p>
<h3 id="word2vec">Word2Vec</h3>
<p><a href="https://code.google.com/archive/p/word2vec/">Word2Vec</a> made word embeddings practical at scale by introducing efficient training algorithms for massive corpora. It popularized compelling vector arithmetic properties, enabling analogical reasoning like the famous &ldquo;$\text{king} - \text{man} + \text{woman} \approx \text{queen}$&rdquo; example (a vector-offset regularity first reported by Mikolov, Yih &amp; Zweig (2013) on recurrent-network language-model embeddings).</p>















<figure class="post-figure center ">
    <img src="/img/Word_vector_illustration.webp"
         alt="Word2Vec vector arithmetic visualization"
         title="Word2Vec vector arithmetic visualization"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Word2Vec demonstrates analogical relationships through vector arithmetic</figcaption>
    
</figure>

<p><strong>Two training architectures:</strong></p>
<h4 id="continuous-bag-of-words-cbow">Continuous Bag-of-Words (CBOW)</h4>
<p>Predicts target words from surrounding context words. Given a window of context words, the model learns to predict the central word.</p>
<h4 id="skip-gram">Skip-Gram</h4>
<p>Predicts context words from target words. Given a central word, the model learns to predict surrounding words within a defined window.</p>
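<p>The difference between the two architectures is easiest to see in the training examples each one builds from the same sentence. The sketch below only generates those examples and does not train a model; the function names are illustrative.</p>

```python
def skipgram_pairs(tokens, window=2):
    """(center, context) pairs: the center word predicts each neighbor."""
    pairs = []
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs


def cbow_examples(tokens, window=2):
    """(context words, center) examples: the neighbors predict the center word."""
    examples = []
    for i, center in enumerate(tokens):
        context = tokens[max(0, i - window):i] + tokens[i + 1:i + window + 1]
        examples.append((context, center))
    return examples


sentence = "the quick brown fox".split()
print(skipgram_pairs(sentence, window=1)[:3])
print(cbow_examples(sentence, window=1)[1])
```

<p>Skip-gram produces one example per (center, neighbor) pair, while CBOW produces one example per position with all neighbors grouped as the input.</p>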
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>Computational efficiency</strong>: Much faster than neural probabilistic models</li>
<li><strong>Scalable training</strong>: Can process billion-word corpora effectively</li>
<li><strong>Quality embeddings</strong>: Captures semantic and syntactic relationships</li>
<li><strong>Flexible context</strong>: Window size controls topical vs. functional similarity</li>
</ul>
<p>The choice of window size significantly impacts learned relationships. Larger windows capture topical associations, while smaller windows focus on syntactic and functional similarities.</p>
<h3 id="glove-global-vectors">GloVe (Global Vectors)</h3>
<p><a href="https://nlp.stanford.edu/projects/glove/">GloVe</a> combines the best aspects of matrix factorization methods (which capture global corpus statistics) and local context window approaches like Word2Vec. Matrix factorization methods excel at global patterns but struggle with analogical reasoning, while Word2Vec captures local relationships but may miss global structure.</p>
<p><strong>Key innovation:</strong>
GloVe trains on a global word-context co-occurrence matrix, incorporating corpus-wide statistical information while maintaining the analogical reasoning capabilities that made Word2Vec successful.</p>
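<p>As a rough sketch of the two ingredients involved, the code below builds a distance-weighted co-occurrence table and implements the weighting function $f(x) = (x / x_{\max})^{\alpha}$ that GloVe applies to each squared error term $(w_i^\top \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij})^2$. The defaults $x_{\max} = 100$ and $\alpha = 0.75$ follow the GloVe paper; the function names and the tiny corpus are illustrative assumptions, and no optimization step is shown.</p>

```python
from collections import defaultdict

def cooccurrence_counts(sentences, window=2):
    """Symmetric co-occurrence counts; a pair d tokens apart contributes 1/d."""
    counts = defaultdict(float)
    for tokens in sentences:
        for i, word in enumerate(tokens):
            for j in range(max(0, i - window), i):
                weight = 1.0 / (i - j)
                counts[(word, tokens[j])] += weight
                counts[(tokens[j], word)] += weight
    return dict(counts)

def glove_weight(x, x_max=100.0, alpha=0.75):
    """Down-weights rare pairs and caps the influence of very frequent ones."""
    return (x / x_max) ** alpha if x < x_max else 1.0

# Hypothetical two-sentence corpus, for illustration only.
corpus = [["ice", "is", "cold"], ["steam", "is", "hot"]]
X = cooccurrence_counts(corpus, window=2)
print(X[("ice", "is")], X[("ice", "cold")])  # 1.0 0.5
```

<p>Because the table is computed once over the whole corpus, every training step sees aggregated statistics instead of a single sampled window.</p>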
<p><strong>Advantages over Word2Vec:</strong></p>
<ul>
<li><strong>Global optimization</strong>: Leverages entire corpus statistics</li>
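<li><strong>Better performance</strong>: The GloVe paper reports better results than Word2Vec on word analogy and similarity tasks, though later comparisons suggest the gap depends heavily on hyperparameter tuning</li>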
<li><strong>Stable training</strong>: More consistent convergence due to global objective function</li>
</ul>
<p>The result is embeddings that capture both local syntactic patterns and global semantic relationships more effectively.</p>
<h2 id="contextual-embedding-methods">Subword and Geometric Embedding Methods</h2>
<h3 id="fasttext">FastText</h3>
<p><a href="https://github.com/facebookresearch/fastText">FastText</a> addresses a critical limitation of previous methods: handling out-of-vocabulary (OOV) words. By incorporating subword information, FastText can generate meaningful representations for previously unseen words.</p>
<p><strong>Subword approach:</strong></p>
<ul>
<li>Decomposes words into character n-grams (typically 3-6 characters)</li>
<li>Represents words as sums of their component n-grams</li>
<li>Trains with the skip-gram objective and negative sampling</li>
</ul>
<p><strong>Key advantages:</strong></p>
<ul>
<li><strong>OOV handling</strong>: Can embed unseen words using known subword components</li>
<li><strong>Morphological awareness</strong>: Captures relationships between related word forms</li>
<li><strong>Multilingual support</strong>: Facebook released pre-trained embeddings for 294 languages</li>
<li><strong>Robust performance</strong>: Particularly effective for morphologically rich languages</li>
</ul>
<p>For example, if the model knows &ldquo;navigate,&rdquo; it can provide a meaningful representation for &ldquo;circumnavigate&rdquo; by leveraging shared subword components, even if &ldquo;circumnavigate&rdquo; wasn&rsquo;t in the training data.</p>
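<p>The subword overlap behind this example can be made concrete. The sketch below extracts character n-grams of length 3 to 6 after wrapping the word in the boundary markers &lt; and &gt;, as FastText does, then counts how many n-grams of &ldquo;navigate&rdquo; reappear in &ldquo;circumnavigate&rdquo;. It is a simplified illustration: the real library also keeps a vector for the whole word and hashes n-grams into a fixed number of buckets, and <code>char_ngrams</code> is an illustrative name.</p>

```python
def char_ngrams(word, n_min=3, n_max=6):
    """Set of character n-grams of the word wrapped in boundary markers."""
    marked = "<" + word + ">"
    grams = set()
    for n in range(n_min, n_max + 1):
        for start in range(len(marked) - n + 1):
            grams.add(marked[start:start + n])
    return grams

known = char_ngrams("navigate")
unseen = char_ngrams("circumnavigate")
shared = known.intersection(unseen)

# Fraction of the known word's n-grams that the unseen word reuses.
coverage = len(shared) / len(known)
print(len(known), len(shared), round(coverage, 2))
```

<p>Since an unseen word&rsquo;s vector is built from its n-gram vectors, every shared n-gram pulls the two representations toward each other.</p>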
<h3 id="poincaré-embeddings">Poincaré Embeddings</h3>
<p><a href="https://radimrehurek.com/gensim/models/poincare.html">Poincaré embeddings</a> introduce a novel approach by learning representations in hyperbolic space. This geometric innovation specifically targets hierarchical relationships in data.</p>
<p><strong>Hyperbolic geometry advantages:</strong></p>
<ul>
<li><strong>Natural hierarchy encoding</strong>: Distance represents similarity, while norm encodes hierarchical level</li>
<li><strong>Efficient representation</strong>: Requires fewer dimensions for hierarchical data</li>
<li><strong>Mathematical elegance</strong>: Leverages properties of hyperbolic space for embedding optimization</li>
</ul>
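<p>The interplay of distance and norm comes from the Poincaré ball distance, $d(u, v) = \operatorname{arcosh}\left(1 + 2 \frac{\lVert u - v \rVert^2}{(1 - \lVert u \rVert^2)(1 - \lVert v \rVert^2)}\right)$. A minimal sketch in plain Python follows; the three points are hypothetical coordinates chosen to mimic a root and two leaves, not learned embeddings.</p>

```python
import math

def poincare_distance(u, v):
    """Distance between two points strictly inside the unit ball."""
    sq_norm_u = sum(x * x for x in u)
    sq_norm_v = sum(x * x for x in v)
    sq_diff = sum((a - b) ** 2 for a, b in zip(u, v))
    ratio = 2.0 * sq_diff / ((1.0 - sq_norm_u) * (1.0 - sq_norm_v))
    return math.acosh(1.0 + ratio)

# Hypothetical points: a general concept near the origin, two specific ones near the boundary.
root = [0.0, 0.0]
leaf_a = [0.9, 0.0]
leaf_b = [0.0, 0.9]

print(poincare_distance(root, leaf_a))
print(poincare_distance(leaf_a, leaf_b))
```

<p>The two leaves are much farther from each other than either is from the root, even though all three Euclidean gaps are of similar size. Distances grow rapidly near the boundary, which is what leaves room for the exponentially many nodes at deeper levels of a tree.</p>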
<p><strong>Applications:</strong>
Particularly effective for data with inherent hierarchical structure, such as:</p>
<ul>
<li>WordNet taxonomies</li>
<li>Organizational charts</li>
<li>Computer network topologies</li>
<li>Knowledge graphs</li>
</ul>
<p>The <a href="https://arxiv.org/abs/1705.08039">original paper</a> shows that low-dimensional Poincaré embeddings reconstruct the WordNet noun hierarchy more accurately than Euclidean embeddings with far more dimensions.</p>
<h2 id="contextual-embeddings">Contextual Embeddings</h2>
<h3 id="elmo-embeddings-from-language-models">ELMo (Embeddings from Language Models)</h3>
<p><a href="https://github.com/allenai/allennlp-models">ELMo</a> represents a paradigm shift toward contextual word representations. Where static methods assign one fixed vector to each word, ELMo computes a representation for each occurrence of a word from the full sentence it appears in.</p>
<p><strong>Architecture:</strong></p>
<ul>
<li><strong>Bidirectional LSTM</strong>: Processes text in both forward and backward directions</li>
<li><strong>Character-level input</strong>: Handles OOV words and captures morphological patterns</li>
<li><strong>Multi-layer representations</strong>: Combines different abstraction levels</li>
</ul>
<p><strong>Layer specialization:</strong></p>
<ul>
<li><strong>Lower layers</strong>: Excel at syntactic tasks (POS tagging, parsing)</li>
<li><strong>Higher layers</strong>: Capture semantic relationships (word sense disambiguation)</li>
<li><strong>Combined layers</strong>: A task-specific weighted combination of all layers typically outperforms using the top layer alone</li>
</ul>
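<p>The layer combination can be written as $\gamma \sum_j s_j h_j$, where the $s_j$ are softmax-normalized weights learned for the downstream task and $\gamma$ is a learned scale. Below is a minimal sketch for a single token, assuming the per-layer vectors are already computed; the layer values are hypothetical and <code>scalar_mix</code> is an illustrative name.</p>

```python
import math

def scalar_mix(layers, raw_weights, gamma=1.0):
    """Softmax-weighted sum of per-layer vectors for one token, scaled by gamma."""
    peak = max(raw_weights)
    exps = [math.exp(w - peak) for w in raw_weights]  # subtract max for stability
    total = sum(exps)
    s = [e / total for e in exps]
    dim = len(layers[0])
    return [gamma * sum(s[j] * layers[j][d] for j in range(len(layers)))
            for d in range(dim)]

# Hypothetical 2-d outputs for one token from three layers.
token_layers = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Equal raw weights reduce to a plain average of the layers.
print(scalar_mix(token_layers, [0.0, 0.0, 0.0]))
```

<p>A task that leans on syntax can learn larger weights for lower layers, while a more semantic task can shift weight toward higher layers.</p>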
<p><strong>Key innovation:</strong>
ELMo embeddings vary by context. The word &ldquo;bank&rdquo; receives different representations in &ldquo;river bank&rdquo; versus &ldquo;financial bank,&rdquo; addressing polysemy directly through contextual awareness.</p>
<p>Adding ELMo representations to existing task models improved results on a range of benchmarks, including question answering, textual entailment, named entity recognition, and sentiment analysis.</p>
<h3 id="probabilistic-fasttext">Probabilistic FastText</h3>
<p><a href="https://github.com/benathi/multisense-prob-fasttext">Probabilistic FastText</a> addresses polysemy (words with multiple meanings) through probabilistic modeling. Traditional embeddings conflate different word senses into single representations, limiting their precision.</p>
<p><strong>The polysemy problem:</strong>
Consider &ldquo;rock&rdquo; which can mean:</p>
<ul>
<li>Rock music (genre)</li>
<li>A stone (geological object)</li>
<li>Rocking motion (verb)</li>
</ul>
<p>Standard embeddings average these meanings, producing representations that may not capture any sense precisely.</p>
<p><strong>Probabilistic approach:</strong>
Probabilistic FastText represents words as Gaussian mixture models: probability distributions that can capture multiple distinct meanings as separate components.</p>
<p><strong>Advantages:</strong></p>
<ul>
<li><strong>Multi-sense representation</strong>: Each word sense gets its own distribution</li>
<li><strong>Context sensitivity</strong>: Can select appropriate sense based on usage context</li>
<li><strong>Uncertainty quantification</strong>: Probabilistic framework captures embedding confidence</li>
</ul>
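<p>To illustrate how separate components keep senses apart, the sketch below stores each word as a list of mixture components and compares two words by the most similar pair of component means. This is a simplified illustration under stated assumptions: the 2-dimensional means are hypothetical, variances are omitted, and maximum cosine similarity is one simple comparison rule, not necessarily the one used by the reference implementation.</p>

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Each word is a list of (mixture weight, component mean); values are hypothetical.
ROCK = [(0.5, [1.0, 0.0]),   # music-like sense
        (0.5, [0.0, 1.0])]   # stone-like sense
JAZZ = [(1.0, [0.9, 0.1])]
BASALT = [(1.0, [0.1, 0.9])]

def best_component_pair(word_a, word_b):
    """Indices and score of the most similar pair of component means.

    Mixture weights are ignored here; only the means are compared.
    """
    scored = [(cosine(mean_a, mean_b), i, j)
              for i, (_, mean_a) in enumerate(word_a)
              for j, (_, mean_b) in enumerate(word_b)]
    score, i, j = max(scored)
    return i, j, score

print(best_component_pair(ROCK, JAZZ)[:2])    # (0, 0): the music-like component matches
print(best_component_pair(ROCK, BASALT)[:2])  # (1, 0): the stone-like component matches
```

<p>A single averaged vector for &ldquo;rock&rdquo; would sit between the two senses and match neither neighbor well, which is the failure mode the mixture avoids.</p>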
<p>This approach provides a more nuanced treatment of lexical ambiguity, particularly valuable for words with distinct, context-dependent meanings.</p>
<h2 id="summary-and-future-directions">Summary and Future Directions</h2>
<p>Word embeddings have evolved from simple one-hot encodings to contextual representations that capture nuanced linguistic relationships. Each approach offers distinct advantages:</p>
<p><strong>Static embeddings</strong> (Word2Vec, GloVe, FastText) provide:</p>
<ul>
<li>Computational efficiency for large-scale applications</li>
<li>Pre-trained models available for numerous languages</li>
<li>Clear analogical reasoning capabilities</li>
<li>Good performance on many downstream tasks</li>
</ul>
<p><strong>Contextual embeddings</strong> (ELMo, BERT, GPT) offer:</p>
<ul>
<li>Dynamic representations based on sentence context</li>
<li>Better handling of polysemy and word sense disambiguation</li>
<li>Strong performance on complex NLP tasks</li>
<li>Ability to capture subtle contextual nuances</li>
</ul>
<p><strong>Choosing the right approach</strong> depends on:</p>
<ul>
<li><strong>Task requirements</strong>: Static embeddings for efficiency, contextual for accuracy</li>
<li><strong>Data availability</strong>: Pre-trained models vs. domain-specific training</li>
<li><strong>Computational constraints</strong>: Static embeddings require less processing power</li>
<li><strong>Language coverage</strong>: Consider availability of pre-trained models for target languages</li>
</ul>
<p>The field continues advancing toward more efficient contextual models, better multilingual representations, and embeddings that capture increasingly complex linguistic phenomena.</p>
<p>For a from-scratch Word2Vec implementation in PyTorch (Skip-gram and CBOW, with hierarchical softmax and negative sampling) that takes these concepts further, see the <a href="/projects/modern-word2vec/">PyTorch Word2Vec project</a>.</p>
]]></content:encoded></item></channel></rss>