<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Language Models on Hunter Heidenreich | ML Research Scientist</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/</link><description>Recent content in Language Models on Hunter Heidenreich | ML Research Scientist</description><image><title>Hunter Heidenreich | ML Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Thu, 09 Apr 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/notes/natural-language-processing/language-models/index.xml" rel="self" type="application/rss+xml"/><item><title>T5: Exploring Transfer Learning Limits</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/</guid><description>Raffel et al. systematically study transfer learning for NLP with a text-to-text framework, ablating architectures, objectives, data, and multi-task mixing.</description><content:encoded><![CDATA[<h2 id="a-systematic-study-of-nlp-transfer-learning">A systematic study of NLP transfer learning</h2>
<p>This is a <strong>systematization paper</strong> that provides a comprehensive empirical survey of transfer learning techniques for NLP. Rather than proposing a single new method, T5 introduces a unified text-to-text framework and uses it as a testbed to systematically compare pre-training objectives, architectures, unlabeled data sources, transfer approaches, and multi-task mixing strategies. The scale of the ablation study (covering dozens of configurations) and the release of C4, pre-trained models, and code make it both a reference guide and a resource.</p>
<h2 id="unifying-nlp-tasks-as-text-to-text">Unifying NLP tasks as text-to-text</h2>
<p>The core design decision is to cast every NLP task as a text-to-text problem: both the input and output are text strings, with a task-specific prefix. Classification, regression, summarization, translation, and question answering all use the same model, loss function (cross-entropy on output tokens), and decoding procedure. This simplicity enables fair comparison across tasks and training strategies.</p>
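<p>As a concrete illustration, the shared interface can be sketched as below. This is a hedged sketch: the prefix strings ("translate English to German:", "cola sentence:", "summarize:") are from the paper, but the function and example fields are invented for this note.</p>

```python
# Hedged sketch of the text-to-text format: every task becomes
# (prefixed input string, target string). Prefixes are from the paper;
# the function and example dicts are illustrative.
def to_text_to_text(task, example):
    if task == "wmt_en_de":
        return "translate English to German: " + example["en"], example["de"]
    if task == "cola":
        # Classification targets are literal label strings, decoded as text.
        return "cola sentence: " + example["sentence"], example["label"]
    if task == "cnn_dm":
        return "summarize: " + example["article"], example["summary"]
    raise ValueError(f"unknown task: {task}")
```

<p>Every task then shares the same cross-entropy loss over output tokens; even STS-B regression fits this mold by rounding scores to the nearest 0.2 and emitting them as strings.</p>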
<p>The model architecture is a standard encoder-decoder Transformer. The paper finds that this form outperforms decoder-only (language model) and encoder-only (BERT-style) variants in the text-to-text setting, while having similar computational cost to decoder-only models despite twice the parameters (each token passes through only the encoder stack or only the decoder stack, so doubling the parameters does not double the compute per token).</p>
<h2 id="multi-task-mixing-strategies-and-findings">Multi-task mixing: strategies and findings</h2>
<p>The most thesis-relevant contribution is the systematic ablation of multi-task mixing strategies (Section 3.5.2). When training on multiple tasks simultaneously (which in the text-to-text framework simply means mixing data from different sources), the central question is how to set the proportion of data from each task.</p>
<h3 id="three-mixing-strategies">Three mixing strategies</h3>
<p><strong>Examples-proportional mixing.</strong> Sample in proportion to each dataset&rsquo;s size, with an artificial cap $K$ on the maximum dataset size. Without the cap, the unsupervised pre-training data (orders of magnitude larger) would dominate all batches. The mixing rate for task $m$ is:</p>
<p>$$
r_{m} = \frac{\min(e_{m}, K)}{\sum_{n} \min(e_{n}, K)}
$$</p>
<p>where $e_{m}$ is the number of examples in task $m$&rsquo;s dataset.</p>
<p><strong>Temperature-scaled mixing.</strong> Raise each mixing rate $r_{m}$ to the power $1/T$ and renormalize. At $T=1$ this equals examples-proportional mixing; as $T$ increases, proportions approach equal mixing. Uses a large cap $K = 2^{21}$.</p>
<p><strong>Equal mixing.</strong> Sample uniformly from all tasks. Included as a negative reference: the model overfits on low-resource tasks and underfits on high-resource tasks.</p>
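<p>The three strategies can be sketched in a few lines (a minimal illustration of the formulas above; the dataset sizes and $K$ values used in the comments are arbitrary, not the paper's experimental settings):</p>

```python
# Sketch of the three mixing strategies from the T5 ablation.
def examples_proportional(sizes, K):
    # r_m = min(e_m, K) / sum_n min(e_n, K); the cap K keeps huge
    # datasets from dominating every batch.
    capped = [min(e, K) for e in sizes]
    total = sum(capped)
    return [c / total for c in capped]

def temperature_scaled(sizes, T, K=2 ** 21):
    # Raise each rate to 1/T and renormalize; T=1 recovers
    # examples-proportional, large T approaches equal mixing.
    rates = examples_proportional(sizes, K)
    scaled = [r ** (1.0 / T) for r in rates]
    total = sum(scaled)
    return [s / total for s in scaled]

def equal_mixing(sizes):
    # Uniform over tasks, regardless of dataset size.
    return [1.0 / len(sizes)] * len(sizes)
```
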
<h3 id="results">Results</h3>
<table>
  <thead>
      <tr>
          <th>Mixing strategy</th>
          <th>GLUE</th>
          <th>CNN/DM</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
          <th>EnDe</th>
          <th>EnFr</th>
          <th>EnRo</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Baseline (pre-train/fine-tune)</td>
          <td>83.28</td>
          <td>19.24</td>
          <td>80.88</td>
          <td>71.36</td>
          <td>26.98</td>
          <td>39.82</td>
          <td>27.65</td>
      </tr>
      <tr>
          <td>Equal</td>
          <td>76.13</td>
          <td>19.02</td>
          <td>76.51</td>
          <td>63.37</td>
          <td>23.89</td>
          <td>34.31</td>
          <td>26.78</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{18}$</td>
          <td>81.67</td>
          <td>19.07</td>
          <td>78.17</td>
          <td>67.94</td>
          <td>24.57</td>
          <td>35.19</td>
          <td>27.39</td>
      </tr>
      <tr>
          <td>Examples-proportional, $K=2^{19}$</td>
          <td>81.42</td>
          <td>19.24</td>
          <td>79.78</td>
          <td>67.30</td>
          <td>25.21</td>
          <td>36.30</td>
          <td>27.76</td>
      </tr>
      <tr>
          <td>Temperature-scaled, $T=2$</td>
          <td>81.90</td>
          <td>19.28</td>
          <td>79.42</td>
          <td>69.92</td>
          <td>25.42</td>
          <td>36.72</td>
          <td>27.20</td>
      </tr>
  </tbody>
</table>
<p><strong>Key findings on mixing:</strong></p>
<ol>
<li>
<p><strong>Multi-task training underperforms pre-train-then-fine-tune on most tasks.</strong> No mixing strategy matches the baseline of unsupervised pre-training followed by task-specific fine-tuning.</p>
</li>
<li>
<p><strong>Equal mixing is worst.</strong> It dramatically degrades performance, confirming that proportions matter.</p>
</li>
<li>
<p><strong>There exists a task-specific sweet spot for the cap $K$.</strong> Most tasks have an optimal $K$ value; larger or smaller values hurt. The exception is very high-resource tasks (WMT English-French) that always benefit from higher mixing proportions.</p>
</li>
<li>
<p><strong>Temperature scaling at $T=2$ provides the best single compromise.</strong> It achieves reasonable performance across all tasks without requiring per-task tuning of $K$.</p>
</li>
<li>
<p><strong>Multi-task pre-training followed by fine-tuning closes the gap.</strong> When multi-task training is used as pre-training (not as the final training stage), followed by task-specific fine-tuning, performance becomes comparable to unsupervised pre-training alone. This suggests that multi-task exposure during pre-training provides useful early signal without the negative effects of forcing a single model to perform all tasks simultaneously.</p>
</li>
<li>
<p><strong>&ldquo;Leave-one-out&rdquo; training works.</strong> Pre-training on a multi-task mixture that excludes a target task, then fine-tuning on it, produces only slightly worse results. This indicates that multi-task pre-training builds general capabilities that transfer to unseen tasks without dramatic task interference.</p>
</li>
</ol>
<h2 id="data-repetition-degrades-performance">Data repetition degrades performance</h2>
<p>The paper also systematically tests the effect of pre-training dataset size by truncating C4 and training repeatedly over the smaller data:</p>
<table>
  <thead>
      <tr>
          <th>Unique tokens</th>
          <th>Repeats</th>
          <th>GLUE</th>
          <th>SQuAD</th>
          <th>SuperGLUE</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Full dataset</td>
          <td>0</td>
          <td>83.28</td>
          <td>80.88</td>
          <td>71.36</td>
      </tr>
      <tr>
          <td>$2^{29}$</td>
          <td>64</td>
          <td>82.87</td>
          <td>80.97</td>
          <td>72.03</td>
      </tr>
      <tr>
          <td>$2^{27}$</td>
          <td>256</td>
          <td>82.62</td>
          <td>79.78</td>
          <td>69.97</td>
      </tr>
      <tr>
          <td>$2^{25}$</td>
          <td>1,024</td>
          <td>79.55</td>
          <td>76.27</td>
          <td>64.76</td>
      </tr>
      <tr>
          <td>$2^{23}$</td>
          <td>4,096</td>
          <td>76.34</td>
          <td>70.92</td>
          <td>59.29</td>
      </tr>
  </tbody>
</table>
<p>Performance degrades as data shrinks, with 64 repeats showing limited effects but 1,024+ repeats causing significant degradation. Training loss curves confirm memorization at high repetition counts. The paper recommends using large, diverse pre-training datasets whenever possible.</p>
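<p>The repeat counts follow from dividing a fixed pre-training budget by the truncated dataset size. A small sanity check, assuming the baseline budget of $2^{35}$ tokens ($2^{19}$ steps at $2^{16}$ tokens per batch, per the paper's baseline setup):</p>

```python
# Repeats = total pre-training tokens / unique tokens. The 2^35 budget
# (2^19 steps x 2^16 tokens per batch) is the assumed baseline setup.
TOTAL_TOKENS = 2 ** 19 * 2 ** 16  # = 2^35, about 34B tokens

for unique_exp in (29, 27, 25, 23):
    repeats = TOTAL_TOKENS // 2 ** unique_exp
    print(f"2^{unique_exp} unique tokens -> {repeats} repeats")
# -> 64, 256, 1024, 4096, matching the table above
```
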
<h2 id="scaling-and-final-configuration">Scaling and final configuration</h2>
<p>The paper compares scaling strategies: more data, larger models, and ensembles. Training a larger model for fewer steps generally outperforms training a smaller model on more data. Ensembles of independently pre-trained and fine-tuned models provide orthogonal gains.</p>
<p>The final T5-11B model combines the best choices from all ablations: encoder-decoder architecture, span corruption objective, C4 pre-training data, multi-task pre-training followed by fine-tuning, and scaling to 11B parameters trained on over 1 trillion tokens. It achieves state-of-the-art results on GLUE (90.3 average), SuperGLUE (88.9, near human performance of 89.8), SQuAD, and CNN/Daily Mail. It does not achieve state-of-the-art on WMT translation tasks, where methods using backtranslation and cross-lingual pre-training retain the lead.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The T5 paper&rsquo;s multi-task mixing findings are its most enduring contribution beyond the model itself. The core lessons: proportions matter enormously (equal mixing fails), examples-proportional mixing with a cap is a reasonable default, temperature scaling provides a single-knob alternative, and multi-task pre-training followed by fine-tuning can match pure unsupervised pre-training.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>All ablations use the same encoder-decoder architecture. Findings may not transfer to decoder-only models that dominate current practice.</li>
<li>The multi-task mixing experiments treat each task as a separate &ldquo;domain.&rdquo; Interactions between similar tasks (e.g., multiple classification tasks) are not isolated.</li>
<li>The paper does not provide a principled method for choosing $K$ or $T$; both require empirical search.</li>
<li>C4 has known quality issues (templated text, noisy content) that have been addressed in later datasets.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, pre-trained models, and the C4 dataset are all publicly released.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pre-training</td>
          <td>C4 (Colossal Clean Crawled Corpus)</td>
          <td>~750 GB</td>
          <td>Heuristically cleaned Common Crawl</td>
      </tr>
      <tr>
          <td>Downstream</td>
          <td>GLUE, SuperGLUE, SQuAD, CNN/DM, WMT (EnDe, EnFr, EnRo)</td>
          <td>Standard splits</td>
          <td>Text-to-text format</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>Encoder-decoder Transformer. Sizes: Small (60M), Base (220M), Large (770M), 3B, 11B. Baseline uses Base size. SentencePiece vocabulary with 32K tokens. Pre-trained for $2^{19}$ steps, fine-tuned for $2^{18}$ steps on individual tasks.</p>
<h3 id="algorithms">Algorithms</h3>
<p>Multi-task mixing: examples-proportional with cap $K \in \{2^{16}, \ldots, 2^{21}\}$, temperature-scaled with $T \in \{2, 4, 8\}$, and equal mixing. Unsupervised objective: span corruption (mean span length 3, 15% corruption rate). Training with Adafactor optimizer, inverse square root learning rate schedule.</p>
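<p>A simplified sketch of the span-corruption objective (15% corruption rate, mean span length 3). T5's actual span sampling and sentinel-token vocabulary handling differ in detail; this illustrates only the input/target construction:</p>

```python
import random

def span_corrupt(tokens, corruption_rate=0.15, mean_span_len=3, seed=0):
    """Mask contiguous spans, replace each with a sentinel in the input,
    and build a target that spells the masked spans back out."""
    rng = random.Random(seed)
    n = len(tokens)
    budget = max(1, round(n * corruption_rate))        # tokens to mask
    num_spans = max(1, round(budget / mean_span_len))  # spans to draw
    starts = sorted(rng.sample(range(n), num_spans))
    spans, prev_end = [], 0
    for s in starts:
        if s < prev_end or budget <= 0:
            continue  # skip overlapping starts (simplified sampling)
        length = min(budget, mean_span_len, n - s)
        spans.append((s, s + length))
        prev_end = s + length
        budget -= length
    inputs, targets, cursor = [], [], 0
    for i, (s, e) in enumerate(spans):
        sentinel = f"<extra_id_{i}>"
        inputs.extend(tokens[cursor:s])
        inputs.append(sentinel)
        targets.append(sentinel)
        targets.extend(tokens[s:e])
        cursor = e
    inputs.extend(tokens[cursor:])
    targets.append(f"<extra_id_{len(spans)}>")  # end-of-targets sentinel
    return inputs, targets
```
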
<h3 id="hardware">Hardware</h3>
<p>All models trained using Mesh TensorFlow on TPU slices. T5-11B pre-trained for 1M steps with batch size $2^{11}$ sequences of length 512 (~1 trillion tokens total). Exact TPU pod configurations per experiment not detailed.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer">T5 Code</a></td>
          <td>Code</td>
          <td>Apache 2.0</td>
          <td>Official TensorFlow implementation (JAX successor: T5X)</td>
      </tr>
      <tr>
          <td><a href="https://github.com/google-research/text-to-text-transfer-transformer#released-model-checkpoints">T5 Models</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Pre-trained checkpoints (Small through 11B)</td>
      </tr>
      <tr>
          <td><a href="https://www.tensorflow.org/datasets/catalog/c4">C4 Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>~750 GB cleaned Common Crawl, via TensorFlow Datasets</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{raffel2020exploring,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{Journal of Machine Learning Research}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{21}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span>=<span style="color:#e6db74">{140}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{1--67}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2020}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>SlimPajama-DC: Data Combinations for LLM Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/slimpajama-dc-data-combinations/</guid><description>Shen et al. study how global deduplication and domain combinations in SlimPajama affect LLM training, finding diversity after dedup is key.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-data-domain-combinations">An empirical study of data domain combinations</h2>
<p>This is a <strong>discovery paper</strong> that empirically investigates how different combinations and proportions of data domains affect language model pretraining. Using the SlimPajama dataset (a globally deduplicated, 627B token refinement of RedPajama), the study trains seven 1.3B model configurations with varying domain mixtures to identify which combinations and deduplication strategies produce the best downstream performance.</p>
<h2 id="why-data-combination-strategy-matters">Why data combination strategy matters</h2>
<p>Multi-source pretraining datasets combine data from web crawls, code repositories, books, academic papers, and other sources. Two underexplored questions drive this work: (1) Does deduplication within each source (local) versus across all sources (global) meaningfully affect model quality? (2) When sources are thoroughly deduplicated, how does the combination and proportion of domains affect downstream performance? Most open-source LLM training datasets (RedPajama, The Pile) perform only local deduplication, leaving cross-source redundancy unaddressed.</p>
<h2 id="global-deduplication-and-the-slimpajama-dataset">Global deduplication and the SlimPajama dataset</h2>
<p>SlimPajama applies global MinHashLSH deduplication (Jaccard similarity threshold 0.8, 13-gram signatures) across all seven data sources simultaneously. This reduces RedPajama&rsquo;s 1.2T tokens to 627B tokens, a roughly 48% reduction. The heaviest deduplication hits CommonCrawl and GitHub, which had the most cross-source overlap.</p>
<p>The key processing steps:</p>
<ol>
<li><strong>Low-length document filtering</strong>: Remove documents below a minimum length threshold.</li>
<li><strong>Global deduplication</strong>: MinHashLSH across all sources simultaneously, requiring 64 CPU cores and 1.4TB peak memory. This removes both within-source and between-source duplicates.</li>
</ol>
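<p>A toy sketch of the MinHash idea behind step 2 (not the production pipeline: salted MD5 stands in for random hash permutations, and a real implementation buckets signatures with banded LSH rather than comparing all pairs):</p>

```python
import hashlib

def shingles(text, n=13):
    """13-word shingles, matching the signature granularity described above."""
    words = text.lower().split()
    if len(words) < n:
        return {" ".join(words)}
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}

def minhash(shingle_set, num_perm=64):
    """One minimum per simulated permutation; equal minima between two
    documents occur with probability equal to their Jaccard similarity."""
    return [
        min(int(hashlib.md5(f"{p}:{s}".encode()).hexdigest(), 16)
            for s in shingle_set)
        for p in range(num_perm)
    ]

def estimated_jaccard(sig_a, sig_b):
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)
```

<p>Documents whose estimated similarity exceeds the 0.8 threshold would be treated as duplicates and collapsed to a single copy.</p>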
<p>The resulting dataset composition:</p>
<table>
  <thead>
      <tr>
          <th>Source</th>
          <th>SlimPajama</th>
          <th>RedPajama</th>
          <th>LLaMA 1</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>52.2% (333B)</td>
          <td>72.6% (878B)</td>
          <td>67.0%</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>26.7% (170B)</td>
          <td>14.4% (175B)</td>
          <td>15.0%</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>5.2% (33B)</td>
          <td>4.9% (59B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>4.2% (27B)</td>
          <td>2.1% (26B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>4.6% (29B)</td>
          <td>2.3% (28B)</td>
          <td>2.5%</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>3.8% (24B)</td>
          <td>2.0% (24B)</td>
          <td>4.5%</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>3.3% (21B)</td>
          <td>1.7% (20B)</td>
          <td>2.0%</td>
      </tr>
  </tbody>
</table>
<h2 id="seven-domain-combination-configurations">Seven domain combination configurations</h2>
<p>All configurations train 1.3B parameter models on 330B tokens with identical architecture and hyperparameters. The configurations systematically vary domain diversity:</p>
<ul>
<li><strong>DC-1</strong>: CommonCrawl only (single source)</li>
<li><strong>DC-2</strong>: CommonCrawl + C4 (two web sources)</li>
<li><strong>DC-3</strong>: CommonCrawl + C4 with adjusted proportions</li>
<li><strong>DC-4</strong>: Wikipedia + Books + GitHub + ArXiv + StackExchange (no web crawl)</li>
<li><strong>DC-5</strong>: CommonCrawl + C4 + Wikipedia + Books (four sources, no code/academic)</li>
<li><strong>DC-6</strong>: All seven SlimPajama sources (maximum diversity)</li>
<li><strong>DC-7</strong>: RefinedWeb CommonCrawl (external single-source baseline)</li>
</ul>
<p>The experimental design probes: incremental diversity (DC-1 to DC-2 to DC-5 to DC-6), proportion sensitivity (DC-2 vs DC-3), source importance (DC-3 vs DC-4), and specialization vs generalization (individual vs combined).</p>
<h2 id="diversity-after-global-deduplication-drives-performance">Diversity after global deduplication drives performance</h2>
<h3 id="hugging-face-leaderboard-results">Hugging Face leaderboard results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Average</th>
          <th>ARC</th>
          <th>HellaSwag</th>
          <th>MMLU</th>
          <th>TruthfulQA</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>RedPajama-1.3B</td>
          <td>38.0</td>
          <td>37.2</td>
          <td>55.8</td>
          <td>24.9</td>
          <td>34.3</td>
      </tr>
      <tr>
          <td>DC-1 (CC only)</td>
          <td>38.5</td>
          <td>36.3</td>
          <td>56.0</td>
          <td>27.0</td>
          <td>34.8</td>
      </tr>
      <tr>
          <td>DC-4 (no web)</td>
          <td>37.6</td>
          <td>33.4</td>
          <td>53.3</td>
          <td>26.0</td>
          <td>37.6</td>
      </tr>
      <tr>
          <td>DC-6 (all sources)</td>
          <td>40.0</td>
          <td>33.7</td>
          <td>61.0</td>
          <td>26.9</td>
          <td>38.4</td>
      </tr>
      <tr>
          <td>DC-7 (RefinedWeb)</td>
          <td>41.0</td>
          <td>35.1</td>
          <td>64.7</td>
          <td>26.2</td>
          <td>37.9</td>
      </tr>
  </tbody>
</table>
<p><strong>Key patterns:</strong></p>
<ol>
<li>
<p><strong>More domain diversity improves average performance.</strong> The progression DC-1 (38.5) to DC-2 (38.4) to DC-5 (38.6) to DC-6 (40.0) shows that adding domains lifts average accuracy overall, though not strictly monotonically, once global deduplication has removed cross-source redundancy.</p>
</li>
<li>
<p><strong>Global deduplication enables clean combination.</strong> All SlimPajama configurations except DC-4 outperform RedPajama-1.3B (38.0), which uses local deduplication only. The elimination of cross-source overlap means adding sources contributes genuinely new information.</p>
</li>
<li>
<p><strong>Removing web crawl data hurts.</strong> DC-4 (no CommonCrawl/C4) scores lowest (37.6), demonstrating that web text provides essential breadth even when specialized sources are included.</p>
</li>
<li>
<p><strong>Individual domains excel at specific tasks.</strong> DC-1 (CC only) achieves the highest ARC and MMLU scores. DC-4 leads on Winogrande. DC-5 leads on WSC273. No single combination dominates all tasks, reinforcing that diversity trades specialization for generalization.</p>
</li>
<li>
<p><strong>Findings transfer to 7B scale.</strong> Insights from the best 1.3B configurations were applied to a 7B model trained with large batch sizes, achieving 63.4 average accuracy across the extended benchmark suite.</p>
</li>
</ol>
<h3 id="training-loss-patterns">Training loss patterns</h3>
<p>DC-6 (all sources) achieves the lowest training loss among SlimPajama configurations, consistent with the downstream results. DC-4 (no web crawl) shows the highest training loss, confirming that the large, diverse web crawl data is the most important single component.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The central finding is that <strong>diversity matters most after deduplication</strong>. When cross-source redundancy is removed, each additional source contributes genuinely new signal. Without global deduplication, adding sources may just increase redundancy without proportional benefit.</p>
<p><strong>Limitations:</strong></p>
<ul>
<li>Only seven fixed configurations are tested. No systematic search over continuous mixture proportions (contrast with <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> or <a href="/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/">Data Mixing Laws</a>).</li>
<li>The configurations are not independent: DC-6 includes all sources from DC-1 through DC-5, making it difficult to isolate the contribution of any single addition.</li>
<li>Only 1.3B and 7B scales tested. Whether the diversity benefit continues scaling is unverified.</li>
<li>English-only. Cross-lingual diversity effects are not studied.</li>
<li>The paper is a technical report without formal peer review.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> All 1.3B models and datasets are publicly released under MIT license on HuggingFace.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>SlimPajama</td>
          <td>627B tokens</td>
          <td>Globally deduplicated from 1.2T RedPajama</td>
      </tr>
      <tr>
          <td>Training</td>
          <td>RefinedWeb</td>
          <td>600B tokens</td>
          <td>External CC-only baseline</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>HF Leaderboard (ARC, HellaSwag, MMLU, TruthfulQA)</td>
          <td>Standard</td>
          <td>4 benchmarks</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>Extended suite</td>
          <td>12 additional benchmarks</td>
          <td>Zero and few-shot</td>
      </tr>
  </tbody>
</table>
<h3 id="models">Models</h3>
<p>1.3B parameter Cerebras-GPT architecture with ALiBi positional encoding and SwiGLU activation. All configurations trained on 330B tokens. 7B model trained with large batch-size (LBS) strategy on Cerebras 16x CS-2 cluster (80 PFLOP/s in bf16).</p>
<h3 id="hardware">Hardware</h3>
<p>Cerebras 16x CS-2 cluster, 80 PFLOP/s in bf16 mixed precision.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://huggingface.co/MBZUAI-LLM/SlimPajama-DC">SlimPajama-DC Models</a></td>
          <td>Model</td>
          <td>MIT</td>
          <td>All 1.3B DC configurations (select via revision)</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/datasets/MBZUAI-LLM/SlimPajama-627B-DC">SlimPajama-627B-DC Dataset</a></td>
          <td>Dataset</td>
          <td>-</td>
          <td>Source-split version of SlimPajama-627B</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{shen2023slimpajamadc,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{SlimPajama-DC: Understanding Data Combinations for LLM Training}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Shen, Zhiqiang and Tao, Tianhua and Ma, Liqun and Neiswanger, Willie and Liu, Zhengzhong and Wang, Hongyi and Tan, Bowen and Hestness, Joel and Vassilieva, Natalia and Soboleva, Daria and Xing, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span>=<span style="color:#e6db74">{arXiv preprint arXiv:2309.10818}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Scaling Data-Constrained Language Models</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/scaling-data-constrained-language-models/</guid><description>Muennighoff et al. extend Chinchilla scaling laws to repeated data, finding up to 4 epochs cause negligible loss and 16 epochs mark diminishing returns.</description><content:encoded><![CDATA[<h2 id="an-empirical-study-of-scaling-under-data-constraints">An empirical study of scaling under data constraints</h2>
<p>This is a <strong>discovery paper</strong> that systematically investigates what happens when language models are trained for multiple epochs on repeated data. It extends the Chinchilla scaling laws to the data-constrained regime by proposing a new scaling formula that accounts for the diminishing value of repeated tokens, validated across 400+ training runs ranging from 10M to 9B parameters and up to 1500 epochs.</p>
<h2 id="running-out-of-unique-training-data">Running out of unique training data</h2>
<p>The Chinchilla scaling laws assume unlimited unique data: for a given compute budget, there exists an optimal balance of model parameters and training tokens. But extrapolating these laws to larger models implies data requirements that exceed what is available. Villalobos et al. estimated that high-quality English text would be exhausted by 2024 under Chinchilla-optimal scaling. Most prior large language models trained for a single epoch, and some work explicitly warned against data reuse. The Galactica models (trained for 4.25 epochs) showed that multi-epoch training could work, but no systematic study had quantified the tradeoff between repeated data and fresh data, or how to allocate compute optimally when data is finite.</p>
<h2 id="effective-data-with-exponential-decay-for-repetition">Effective data with exponential decay for repetition</h2>
<p>The paper generalizes the Chinchilla scaling law by replacing raw token count $D$ with an effective data term $D'$ that accounts for the diminishing value of repeated tokens:</p>
<p>$$
L(N, D) = \frac{A}{N'^{\alpha}} + \frac{B}{D'^{\beta}} + E
$$</p>
<p>where the effective data is:</p>
<p>$$
D' = U_{D} + U_{D} R_{D}^{*} \left(1 - e^{-R_{D}/R_{D}^{*}}\right)
$$</p>
<p>Here $U_{D}$ is the number of unique tokens, $R_{D}$ is the number of repetitions (epochs minus 1), and $R_{D}^{*}$ is a learned constant representing the &ldquo;half-life&rdquo; of data repetition. When $R_{D} = 0$ (single epoch), $D' = U_{D} = D$ and the formula reduces to standard Chinchilla. When $R_{D} \ll R_{D}^{*}$, repeated data is worth almost the same as fresh data. As $R_{D}$ grows large, the value of repeated tokens decays to zero, and $D'$ saturates at $U_{D}(1 + R_{D}^{*})$, meaning no amount of repetition can substitute for more than $R_{D}^{*}$ epochs&rsquo; worth of fresh data.</p>
<p>A symmetric formula handles excess parameters:</p>
<p>$$
N' = U_{N} + U_{N} R_{N}^{*} \left(1 - e^{-R_{N}/R_{N}^{*}}\right)
$$</p>
<p>where $U_{N}$ is the compute-optimal parameter count for $U_{D}$ unique tokens and $R_{N}$ measures how much the model exceeds that count. The fitted values are $R_{D}^{<em>} \approx 15.0$ (data repetition half-life at ~16 epochs) and $R_{N}^{</em>} \approx 5.3$ (excess parameters decay faster than repeated data).</p>
<h2 id="experiments-across-400-models">Experiments across 400+ models</h2>
<p><strong>Scale.</strong> Models from 10M to 9B parameters, trained for up to 1500 epochs. Three experimental protocols: fixed unique data (100M, 400M, 1.5B tokens), fixed FLOPs, and parametric fitting across all runs. Training on C4 (English web text) with GPT-2 architecture decoder-only transformers.</p>
<h3 id="resource-allocation-epochs-scale-faster-than-parameters">Resource allocation: epochs scale faster than parameters</h3>
<p>With fixed unique data, results show that more than 50% loss reduction is possible by training beyond one epoch and increasing model size beyond the single-epoch optimum. The data-constrained efficient frontier recommends allocating most additional compute to more epochs rather than more parameters, because excess parameters decay faster ($R_{N}^{*} &lt; R_{D}^{*}$). This contrasts with Chinchilla, which recommends scaling both equally.</p>
<p>A concrete validation: training the data-constrained compute-optimal model for $9.3 \times 10^{21}$ FLOPs with 25B unique tokens, the recommended allocation (27% fewer parameters, more epochs) achieves better loss and downstream performance than the Chinchilla-optimal allocation.</p>
<h3 id="resource-return-the-4-epoch-safe-zone-and-16-epoch-half-life">Resource return: the 4-epoch safe zone and 16-epoch half-life</h3>
<table>
  <thead>
      <tr>
          <th>Epochs</th>
          <th>Loss impact</th>
          <th>Downstream impact</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>1 (baseline)</td>
          <td>Optimal</td>
          <td>Optimal</td>
      </tr>
      <tr>
          <td>Up to 4</td>
          <td>Negligible (+0.5% loss)</td>
          <td>No significant difference</td>
      </tr>
      <tr>
          <td>~16 ($R_{D}^{*}$)</td>
          <td>Diminishing returns begin sharply</td>
          <td>Measurable degradation</td>
      </tr>
      <tr>
          <td>Beyond 16</td>
          <td>Returns decay to near zero</td>
          <td>Significant degradation</td>
      </tr>
      <tr>
          <td>Extreme (44+)</td>
          <td>Training can diverge</td>
          <td>Failure</td>
      </tr>
  </tbody>
</table>
<p>The 8.7B parameter model trained for 4 epochs ($D_{C} = 44$B unique tokens) finishes with only 0.5% higher validation loss than the single-epoch model ($D_{C} = 178$B unique tokens). At the half-life ($R_{D} = R_{D}^{*}$, roughly 16 epochs), the cumulative repetition term $1 - e^{-R_{D}/R_{D}^{*}}$ has reached $1 - 1/e \approx 63\%$ of its asymptote, and a marginal repeated token is worth only $e^{-1} \approx 37\%$ of a fresh one.</p>
<h3 id="complementary-strategies-code-augmentation-and-filtering">Complementary strategies: code augmentation and filtering</h3>
<p>When data is limited, two strategies can extend the effective dataset:</p>
<p><strong>Code augmentation.</strong> Mixing Python code from The Stack with natural language data. Up to 50% code (42B tokens) shows no degradation on natural language benchmarks, effectively providing a 2x increase in useful training data. Some tasks (WebNLG generation, bAbI reasoning) actually improve with code, possibly because code trains long-range state-tracking capabilities.</p>
<p><strong>Filtering relaxation.</strong> Perplexity filtering (keeping the 25% lowest-perplexity samples) is effective on noisy datasets, but deduplication filtering does not improve downstream performance (though it may reduce memorization). The recommendation: reserve aggressive filtering for noisy data sources; for clean datasets, more data through reduced filtering is better than less data through strict filtering.</p>
<p><strong>Combined strategy</strong>: doubling available data with code and then repeating for 4 epochs yields 8x more training tokens with performance expected to match 8x more unique data.</p>
<h2 id="key-findings-and-limitations">Key findings and limitations</h2>
<p><strong>Key findings:</strong></p>
<ul>
<li>Multi-epoch training is beneficial, not harmful, up to moderate repetition counts.</li>
<li>The data-constrained scaling law accurately predicts loss under repetition using an exponential decay formulation.</li>
<li>Compute should be allocated to epochs faster than parameters when data is constrained.</li>
<li>Code augmentation and selective filtering extend effective data without quality degradation.</li>
</ul>
<p><strong>Limitations:</strong></p>
<ul>
<li>All experiments use the GPT-2 transformer architecture; applicability to other architectures or modalities is untested.</li>
<li>Only the entire dataset is repeated uniformly. Selectively repeating subsets (e.g., high-value data for more epochs) is not modeled.</li>
<li>Hyperparameter sensitivity (learning rate, dropout) to epoch count is unexplored. Higher learning rates may cause earlier onset of diminishing returns.</li>
<li>Focused on English text. Cross-lingual augmentation effects are not studied.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p><strong>Status: Highly Reproducible.</strong> Code, models, datasets, and hyperparameters are all publicly released under Apache 2.0.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training</td>
          <td>C4 (English)</td>
          <td>Varies by experiment</td>
          <td>Fixed unique data: 100M, 400M, 1.5B tokens</td>
      </tr>
      <tr>
          <td>Code augmentation</td>
          <td>The Stack (Python)</td>
          <td>Up to 42B tokens</td>
          <td>Mixed with natural language</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>19 NL tasks</td>
          <td>Standard splits</td>
          <td>Zero to five-shot, 114 scores per model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data-constrained scaling law: $D' = U_{D} + U_{D} R_{D}^{*}(1 - e^{-R_{D}/R_{D}^{*}})$ with $R_{D}^{*} \approx 15.0$, $R_{N}^{*} \approx 5.3$. Fitted using the methodology of Hoffmann et al. (2022) adapted for the repetition terms. 400+ training runs used for fitting.</p>
<h3 id="models">Models</h3>
<p>GPT-2 architecture decoder-only transformers with GPT-2 tokenizer. Sizes: 10M to 8.7B parameters. Cosine learning rate schedule (max 2e-4, decay to 2e-5), Adam optimizer ($\beta_2 = 0.999$), dropout 0.1, weight decay 0.1, gradient clipping at 1.0. bfloat16 precision. Trained using Megatron-DeepSpeed.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Data-Constrained Optimal</th>
          <th>Chinchilla Optimal</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Validation loss (9.3e21 FLOPs, 25B unique)</td>
          <td>Lower</td>
          <td>Higher</td>
          <td>27% fewer parameters</td>
      </tr>
      <tr>
          <td>Downstream (4 epochs vs 1)</td>
          <td>No significant difference</td>
          <td>Baseline</td>
          <td>8.7B params, 44B unique tokens</td>
      </tr>
      <tr>
          <td>Code augmentation (50% code)</td>
          <td>No NL degradation</td>
          <td>Baseline</td>
          <td>Some tasks improve</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Trained on the LUMI supercomputer (Finland) using AMD Instinct MI250X GPUs with data, tensor, and pipeline parallelism. Up to 256 GPUs (64 nodes) per run, with up to 2,200 nodes (~8,800 GPUs) used in parallel across all concurrent runs. Total compute: approximately 3 million GPU hours. The cluster runs on 100% renewable hydroelectric energy.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/huggingface/datablations">datablations</a></td>
          <td>Code + Models + Data</td>
          <td>Apache 2.0</td>
          <td>All 400+ models, datasets, and training code</td>
      </tr>
      <tr>
          <td><a href="https://github.com/TurkuNLP/Megatron-DeepSpeed">Megatron-DeepSpeed fork</a></td>
          <td>Code</td>
          <td>-</td>
          <td>Training framework adapted for AMD ROCm</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{muennighoff2023scaling,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Scaling Data-Constrained Language Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Muennighoff, Niklas and Rush, Alexander M. and Barak, Boaz and Le Scao, Teven and Piktus, Aleksandra and Tazi, Nouamane and Pyysalo, Sampo and Wolf, Thomas and Raffel, Colin}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>DoReMi: Optimizing Data Mixtures for LM Pretraining</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/</guid><description>DoReMi uses a small proxy model with distributionally robust optimization to learn domain weights that speed up large-scale language model pretraining by 2.6x.</description><content:encoded><![CDATA[<h2 id="a-method-for-automatic-domain-reweighting">A method for automatic domain reweighting</h2>
<p>This is a <strong>method paper</strong> that introduces Domain Reweighting with Minimax Optimization (DoReMi), an algorithm for automatically tuning the mixture proportions of pretraining data domains. Rather than relying on heuristics or expensive downstream-task-based tuning, DoReMi uses a small proxy model trained with <a href="https://en.wikipedia.org/wiki/Robust_optimization">group distributionally robust optimization (Group DRO)</a> to produce domain weights that transfer to much larger models.</p>
<h2 id="why-data-mixture-proportions-matter">Why data mixture proportions matter</h2>
<p>Language model pretraining datasets combine text from many domains: web crawls, Wikipedia, books, code, academic papers, and others. The mixture proportions (how much of each domain to include) significantly affect downstream performance, but existing approaches either set them by hand (<a href="https://en.wikipedia.org/wiki/The_Pile_(dataset)">The Pile</a> uses heuristic weights) or tune them against downstream tasks (GLaM/PaLM), which is expensive and risks overfitting to a specific evaluation set. No principled, task-agnostic method existed for determining mixture proportions.</p>
<h2 id="minimax-optimization-over-domain-excess-loss">Minimax optimization over domain excess loss</h2>
<p>DoReMi&rsquo;s core insight is to frame data mixture optimization as a minimax problem: find domain weights that minimize the worst-case excess loss across all domains. The algorithm has three steps.</p>
<p><strong>Step 1</strong>: Train a small reference model (280M parameters) on some default domain weights $\alpha_{\text{ref}}$ (e.g., proportional to raw token count).</p>
<p><strong>Step 2</strong>: Train a small proxy model $p_{\theta}$ using Group DRO, which solves the minimax objective:</p>
<p>$$
\min_{\theta} \max_{\alpha \in \Delta^{k}} \sum_{i=1}^{k} \alpha_{i} \cdot \frac{1}{\sum_{x \in D_{i}} |x|} \sum_{x \in D_{i}} \left[ \ell_{\theta}(x) - \ell_{\text{ref}}(x) \right]
$$</p>
<p>where $\ell_{\theta}(x) = -\log p_{\theta}(x)$ and $\ell_{\text{ref}}(x) = -\log p_{\text{ref}}(x)$. The excess loss $\ell_{\theta}(x) - \ell_{\text{ref}}(x)$ measures how much headroom the proxy has to improve on each example relative to the reference. The inner maximization upweights domains with high excess loss via exponentiated gradient ascent, while the outer minimization trains the proxy on those upweighted domains.</p>
<p>At each training step, the domain weights update as:</p>
<p>$$
\alpha_{t}' \leftarrow \alpha_{t-1} \exp(\eta \lambda_{t})
$$</p>
<p>where $\lambda_{t}[i]$ is the per-domain excess loss (clipped at zero), followed by renormalization and smoothing with a uniform component: $\alpha_{t} \leftarrow (1-c)\frac{\alpha_{t}'}{\sum_{i} \alpha_{t}'[i]} + cu$, with $c = 10^{-3}$.</p>
<p>The final domain weights are the average over all training steps: $\bar{\alpha} = \frac{1}{T}\sum_{t=1}^{T} \alpha_{t}$.</p>
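<p>A minimal sketch of this weight update, run on made-up per-domain excess losses (illustrative only; in the real algorithm $\lambda_{t}$ comes from minibatch losses of the proxy and reference models):</p>

```python
import numpy as np

def doremi_update(alpha_prev, excess_loss, eta=1.0, c=1e-3):
    """One exponentiated-gradient step on domain weights.

    alpha_prev: current weights over k domains (sums to 1).
    excess_loss: per-domain proxy-minus-reference loss.
    """
    lam = np.clip(excess_loss, 0.0, None)        # clip negative excess loss at zero
    alpha = alpha_prev * np.exp(eta * lam)       # exponentiated gradient ascent
    alpha = alpha / alpha.sum()                  # renormalize to the simplex
    k = len(alpha)
    return (1 - c) * alpha + c * np.ones(k) / k  # smooth with a uniform component

# The final mixture is the average of the weights over all steps.
alphas, alpha = [], np.ones(3) / 3
for lam in ([0.5, 0.1, 0.0], [0.3, 0.2, 0.0], [0.2, 0.2, 0.1]):
    alpha = doremi_update(alpha, np.array(lam))
    alphas.append(alpha)
alpha_bar = np.mean(alphas, axis=0)
```

<p>Domains with persistently high excess loss (here, domain 0) accumulate weight, while domains the proxy already matches the reference on decay toward the uniform floor set by $c$.</p>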
<p><strong>Step 3</strong>: Resample data according to $\bar{\alpha}$ and train the full-scale model using standard procedures.</p>
<p><strong>Iterated DoReMi</strong> extends this by running multiple rounds, using the previous round&rsquo;s optimized weights as the next round&rsquo;s reference weights. This converges within 3 rounds on the GLaM dataset.</p>
<h2 id="experiments-across-the-pile-and-glam-datasets">Experiments across The Pile and GLaM datasets</h2>
<p><strong>Datasets.</strong> The Pile (22 domains, 800GB) and the GLaM dataset (8 domains, also used for PaLM). On The Pile, baseline weights come from the dataset defaults. On GLaM, baseline weights are uniform, with downstream-tuned oracle weights available for comparison.</p>
<p><strong>Setup.</strong> Transformer decoder-only LMs trained with next-token prediction. All models use batch size 512 and sequence length 1024. Proxy and reference models are 280M parameters. Main models are 8B parameters (30x larger). Training runs: 200K steps (Pile) or 300K steps (GLaM). The domain weight optimization cost (training two 280M models) is 8% of the compute for the 8B main model.</p>
<p><strong>Evaluation.</strong> Per-domain held-out perplexity and one-shot generative accuracy on five tasks: TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, and LAMBADA.</p>
<h3 id="key-domain-weight-shifts">Key domain weight shifts</h3>
<p>On The Pile, DoReMi (280M) dramatically upweights diverse web text (Pile-CC: 0.112 to 0.606) while downweighting specialized domains like ArXiv (0.105 to 0.004), PubMed Central (0.107 to 0.005), and StackExchange (0.093 to 0.015). Smaller, underrepresented domains like YouTubeSubtitles and PhilPapers receive proportionally large increases.</p>
<h3 id="scaling-behavior">Scaling behavior</h3>
<p>DoReMi was tested with matched proxy/main model sizes (280M through 1B) and with varying proxy sizes (70M through 1B) feeding into an 8B main model.</p>
<table>
  <thead>
      <tr>
          <th>Configuration</th>
          <th>Speedup to baseline accuracy</th>
          <th>Downstream improvement</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>DoReMi (280M to 280M)</td>
          <td>4x</td>
          <td>+2% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (280M to 8B)</td>
          <td>2.6x</td>
          <td>+6.5% avg accuracy</td>
      </tr>
      <tr>
          <td>DoReMi (150M to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
      <tr>
          <td>DoReMi (1B to 8B)</td>
          <td>~2x</td>
          <td>Significant</td>
      </tr>
  </tbody>
</table>
<p>Improvements are consistent across all tested model scales (280M to 1B matched), with no sign of diminishing returns at larger sizes.</p>
<h2 id="perplexity-improves-everywhere-even-on-downweighted-domains">Perplexity improves everywhere, even on downweighted domains</h2>
<p>The most striking finding is that DoReMi improves perplexity on all 22 domains in The Pile, including domains it downweights. The proposed explanation: the lowest-entropy domains need few samples to learn (they&rsquo;re statistically simple), while the highest-entropy domains have token distributions close to the uniform initialization and also need fewer samples. Reallocating weight to medium-entropy domains generates positive transfer that lifts all domains.</p>
<p>On The Pile, DoReMi reaches the baseline&rsquo;s downstream accuracy in 75K steps versus 200K for the baseline (2.6x speedup) and achieves a 6.5% absolute improvement in average one-shot accuracy at 200K steps.</p>
<p>On the GLaM dataset, iterated DoReMi (round 2) matches the performance of domain weights that were tuned directly on downstream task performance, despite having no knowledge of downstream tasks. Domain weights converge within 3 iterations.</p>
<h3 id="ablations">Ablations</h3>
<p>Using only the proxy model&rsquo;s loss (prefer hardest domains) or only the negative reference loss (prefer easiest domains) both underperform the full excess loss formulation. Both components are necessary: the excess loss identifies domains where the proxy has room to improve relative to what is learnable.</p>
<p>The proxy model itself typically underperforms the main model trained on its weights, and this gap grows at larger proxy scales. A 1B proxy model underperforms the 1B baseline, yet its domain weights still improve 1B main model training by over 2x. This suggests the domain weight signal is robust even when the proxy model itself is not well-trained.</p>
<h3 id="limitations">Limitations</h3>
<p>The domain weight landscape may have multiple local optima: a 280M proxy puts most weight on Pile-CC, while a 1B proxy favors OpenWebText2 instead. Both configurations improve over baseline, but the optimal weights are not unique.</p>
<p>The granularity of &ldquo;domains&rdquo; matters. DoReMi works better with more domains (22 on The Pile versus 8 on GLaM). Domains are defined by data provenance, which is coarse-grained. Fine-grained domain definitions (e.g., via clustering) could improve results but also risk DRO putting all weight on a small set of worst-case examples.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>800 GB, 22 domains</td>
          <td>Default heuristic weights as baseline</td>
      </tr>
      <tr>
          <td>Pretraining</td>
          <td>GLaM dataset</td>
          <td>8 domains</td>
          <td>Uniform weights as baseline; downstream-tuned oracle available</td>
      </tr>
      <tr>
          <td>Evaluation</td>
          <td>TriviaQA, NaturalQuestions, WebQuestions, SQuADv2, LAMBADA</td>
          <td>Standard splits</td>
          <td>One-shot generative evaluation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Group DRO with exponentiated gradient ascent for domain weight updates. Step size $\eta = 1$, smoothing $c = 10^{-3}$. Per-token excess loss clipped at zero. Domain weights averaged over all training steps. Iterated DoReMi converges when $|\bar{\alpha} - \alpha_{\text{ref}}|_{\infty} &lt; 10^{-3}$.</p>
<h3 id="models">Models</h3>
<p>Vanilla Transformer decoder-only models with 256K vocabulary. Sizes: 70M (3 layers), 150M (6 layers), 280M (12 layers), 510M (12 layers), 760M (12 layers), 1B (16 layers), 8B (32 layers). All use 64-dim attention heads except 8B (128-dim).</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>DoReMi (280M to 8B)</th>
          <th>Baseline (8B)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Avg one-shot accuracy</td>
          <td>+6.5% over baseline</td>
          <td>Reference</td>
          <td>5 generative tasks</td>
      </tr>
      <tr>
          <td>Worst-case log-perplexity</td>
          <td>1.46</td>
          <td>1.71</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Avg log-perplexity</td>
          <td>1.40</td>
          <td>1.64</td>
          <td>Across 22 Pile domains</td>
      </tr>
      <tr>
          <td>Domains beating baseline</td>
          <td>22/22</td>
          <td>0/22</td>
          <td>Per-domain perplexity</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>Proxy and reference models (under 1B) trained on TPUv3. Models at 1B and 8B trained on TPUv4. Domain weight optimization (two 280M runs) costs 8% of 8B training FLOPs.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{xie2023doremi,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Xie, Sang Michael and Pham, Hieu and Dong, Xuanyi and Du, Nan and Liu, Hanxiao and Lu, Yifeng and Liang, Percy and Le, Quoc V. and Ma, Tengyu and Yu, Adams Wei}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{36}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Data Mixing Laws for LM Pretraining Optimization</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</link><pubDate>Wed, 08 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/data-mixing-laws-pretraining/</guid><description>Ye et al. discover that LM loss follows an exponential law over domain mixture proportions, enabling cheap prediction and optimization of data mixtures.</description><content:encoded><![CDATA[<h2 id="an-empirical-discovery-of-predictable-mixture-loss-relationships">An empirical discovery of predictable mixture-loss relationships</h2>
<p>This is a <strong>discovery paper</strong> that identifies a quantitative, functional relationship between pretraining data mixture proportions and language model loss. The key finding is that domain-specific validation loss follows an exponential law over the linear combination of training domain proportions, and this law composes with standard scaling laws to enable cheap prediction of large-model performance under arbitrary mixtures.</p>
<h2 id="the-missing-quantitative-link-between-data-mixtures-and-performance">The missing quantitative link between data mixtures and performance</h2>
<p>Pretraining data for large language models combines text from many domains (web, code, academic, books, etc.), and mixture proportions significantly affect model quality. Existing approaches either set proportions by hand without disclosed criteria (LLaMA, Baichuan) or use algorithmic methods like <a href="/notes/natural-language-processing/language-models/doremi-data-mixture-optimization/">DoReMi</a> that optimize qualitatively but cannot predict the quantitative effect of a specific mixture before training. Scaling laws exist for model size and data quantity, but no equivalent existed for mixture proportions. This paper fills that gap.</p>
<h2 id="the-exponential-data-mixing-law">The exponential data mixing law</h2>
<p>The core finding: for a model of fixed size trained for a fixed number of steps, the validation loss on domain $i$ as a function of the training mixture proportions $r_{1 \dots M}$ follows:</p>
<p>$$
L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right)
$$</p>
<p>where $c_{i}$, $k_{i}$, and $t_{ij}$ are fitted parameters. The constant $c_{i}$ represents the irreducible loss (not affected by mixture changes). The interaction coefficients $t_{ij}$ capture how training domain $j$ affects validation loss on domain $i$: negative $t_{ij}$ means domain $j$ helps domain $i$, positive means it hurts.</p>
<p>This was discovered progressively:</p>
<ol>
<li><strong>Two domains</strong>: Log-reducible-loss is linear in domain proportion (univariate exponential).</li>
<li><strong>Three domains</strong>: The exponential generalizes to a linear combination over all domain proportions (Eq. above), outperforming alternatives with comparable parameter count.</li>
<li><strong>General validation</strong>: For a validation set composed of $K$ domains with proportions $s_{1 \dots K}$, the overall loss is:</li>
</ol>
<p>$$
L(r_{1 \dots M}) = \sum_{i=1}^{K} s_{i} \left[ c_{i} + k_{i} \exp\left(\sum_{j=1}^{M} t_{ij} r_{j}\right) \right]
$$</p>
<p>When the validation set composition is unknown, implicit domain aggregation treats $s_{i}$ as learnable parameters. Setting the number of implicit domains larger than the true number works well and is robust to overestimation.</p>
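<p>Once fitted, the law is cheap to evaluate over candidate mixtures. A sketch with toy coefficients (the real $c_{i}$, $k_{i}$, $t_{ij}$ must be fitted from sampled training runs):</p>

```python
import numpy as np

def mixing_law_loss(r, c, k, T, s=None):
    """L(r) = sum_i s_i * (c_i + k_i * exp(sum_j t_ij * r_j)).

    r: training mixture proportions (length M, sums to 1).
    c, k: per-validation-domain constants (length K).
    T: K x M interaction matrix; T[i, j] < 0 means training domain j
       helps validation domain i.
    s: validation-domain proportions (length K); uniform if omitted.
    """
    r, c, k, T = map(np.asarray, (r, c, k, T))
    per_domain = c + k * np.exp(T @ r)  # per-validation-domain losses L_i
    s = np.ones(len(c)) / len(c) if s is None else np.asarray(s)
    return float(s @ per_domain)

# Toy coefficients: t[0][0] < 0, so training domain 0 helps validation domain 0.
c, k = [1.0], [1.0]
T = [[-2.0, 0.0]]
loss_helpful = mixing_law_loss([1.0, 0.0], c, k, T)  # 1 + e^{-2}
loss_neutral = mixing_law_loss([0.0, 1.0], c, k, T)  # 1 + e^{0} = 2
```
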
<h3 id="domain-interaction-patterns">Domain interaction patterns</h3>
<p>Visualizing the fitted $t_{ij}$ coefficients across 5 coarse Pile domains reveals three relationship types: most domain pairs are <strong>unrelated</strong> (sparse interaction matrix where each domain&rsquo;s loss is dominated by its own training proportion), some show <strong>facilitation</strong> (e.g., dialogue data helps internet text), and some show <strong>conflict</strong> (e.g., symbolic data hurts prose). This sparsity explains why the law can be fitted with fewer samples than the quadratic parameter count would suggest.</p>
<h2 id="nested-scaling-pipeline-for-cheap-prediction">Nested scaling pipeline for cheap prediction</h2>
<p>Fitting data mixing laws directly at target scale is too expensive (requires many full training runs at different mixtures). The paper proposes nesting three scaling laws:</p>
<p><strong>Step 1</strong>: For each mixture $r_{i}$ and each small model size $N_{j}$, train for $S_{0}$ steps. Fit a <a href="https://en.wikipedia.org/wiki/Power_law">power law</a> $L(S) = E_{1} + B/S^{\beta}$ over steps to extrapolate to the target step count $S_{\text{target}}$.</p>
<p><strong>Step 2</strong>: With the step-extrapolated losses for each mixture, fit a power law $L(N) = E_{2} + A/N^{\alpha}$ over model sizes to extrapolate to the target model size $N_{\text{target}}$.</p>
<p><strong>Step 3</strong>: With the predicted losses at $(N_{\text{target}}, S_{\text{target}})$ for all sampled mixtures, fit the data mixing law and search for the optimal mixture.</p>
<p>This pipeline requires only training small models (70M to 410M) for short runs (30B tokens) to predict performance of a 1B model trained for 100B tokens.</p>
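<p>The paper fits these power laws with Huber loss and L-BFGS; the sketch below substitutes a simple grid-plus-log-linear fit to show the extrapolation mechanics of Steps 1 and 2 on synthetic data:</p>

```python
import numpy as np

def fit_step_law(steps, losses, E_grid):
    """Fit L(S) = E + B / S**beta.

    Stand-in for the paper's Huber/L-BFGS fit: grid-search the irreducible
    loss E, then solve (B, beta) by log-linear least squares.
    """
    best = None
    logS = np.log(steps)
    for E in E_grid:
        resid = losses - E
        if np.any(resid <= 0):
            continue  # E must sit below every observed loss
        slope, intercept = np.polyfit(logS, np.log(resid), 1)
        B, beta = np.exp(intercept), -slope
        err = np.sum((E + B / steps**beta - losses) ** 2)
        if best is None or err < best[0]:
            best = (err, E, B, beta)
    return best[1], best[2], best[3]

def extrapolate(S_target, E, B, beta):
    return E + B / S_target**beta

# Synthetic check: recover E=2.0, B=5.0, beta=0.5 from short "runs".
steps = np.array([1e3, 2e3, 4e3, 8e3, 1.6e4])
losses = 2.0 + 5.0 / steps**0.5
E, B, beta = fit_step_law(steps, losses, np.linspace(1.5, 2.5, 101))
```

<p>The same machinery applies to Step 2 by swapping step counts for model sizes ($L(N) = E_{2} + A/N^{\alpha}$).</p>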
<h3 id="mixture-sampling-strategy">Mixture sampling strategy</h3>
<p>To get informative samples efficiently, the paper uses double-diminishing proportions: for each domain, enumerate proportions by halving from the maximum available. This distributes losses evenly across the exponential law&rsquo;s range. From 40 candidate mixtures trained at the smallest scale (70M), 20 are selected based on which subset minimizes data mixing law fitting error.</p>
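<p>One way to realize the double-diminishing enumeration (a sketch under the stated halving rule; the paper&rsquo;s exact candidate generation and 20-of-40 selection by fitting error are not reproduced here):</p>

```python
from itertools import product

def halved_levels(max_prop, levels=3):
    """Candidate proportions for one domain: halve down from its maximum."""
    return [max_prop / 2**i for i in range(levels)]

def candidate_mixtures(max_props, levels=3):
    """Cartesian product of per-domain halved proportions, renormalized to sum to 1."""
    grids = [halved_levels(m, levels) for m in max_props]
    return [tuple(p / sum(combo) for p in combo) for combo in product(*grids)]

# 3 domains x 2 halving levels each -> 2^3 = 8 candidate mixtures.
mixtures = candidate_mixtures([0.6, 0.3, 0.1], levels=2)
```

<p>Halving spreads the sampled proportions across orders of magnitude, which distributes the observed losses evenly over the exponential law&rsquo;s range.</p>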
<h2 id="experiments-on-redpajama-and-continual-pretraining">Experiments on RedPajama and continual pretraining</h2>
<p><strong>Main experiment.</strong> Models trained on RedPajama, validated on the Pile (mimicking the common scenario where validation data comes from a different distribution than training). Small models: 70M, 160M, 305M, 410M trained for 30B tokens. Target: 1B model for 100B tokens.</p>
<p>The optimized mixture dramatically redistributes weight compared to RedPajama defaults:</p>
<table>
  <thead>
      <tr>
          <th>Domain</th>
          <th>Default</th>
          <th>Optimized</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CommonCrawl</td>
          <td>0.670</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>C4</td>
          <td>0.150</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>GitHub</td>
          <td>0.045</td>
          <td>0.141</td>
      </tr>
      <tr>
          <td>ArXiv</td>
          <td>0.045</td>
          <td>0.250</td>
      </tr>
      <tr>
          <td>Books</td>
          <td>0.045</td>
          <td>0.094</td>
      </tr>
      <tr>
          <td>StackExchange</td>
          <td>0.025</td>
          <td>0.125</td>
      </tr>
      <tr>
          <td>Wikipedia</td>
          <td>0.020</td>
          <td>0.016</td>
      </tr>
  </tbody>
</table>
<p>The optimized mixture reaches the default mixture&rsquo;s final performance in 73% of the training steps and eventually achieves performance equivalent to 48% more training on the default mixture.</p>
<p><strong>Comparison to DoReMi and DoGE.</strong> Data mixing laws outperform both: the predicted-optimal mixture achieves lower validation loss than DoReMi and DoGE (both universal and OOD settings) for 1B models trained for 100B tokens on RedPajama.</p>
<p><strong>Continual pretraining.</strong> The law extends to continual pretraining (Pythia-70M on Pile + Python code). It accurately predicts the critical mixture proportion that avoids <a href="https://en.wikipedia.org/wiki/Catastrophic_interference">catastrophic forgetting</a> on the original domain while improving the target domain. This suggests data mixing laws could guide dynamic data schedules across multi-stage pretraining.</p>
<h2 id="implications-and-limitations">Implications and limitations</h2>
<p>The data mixing law provides a predictive framework rather than just an optimization algorithm. Key implications:</p>
<ul>
<li>The interaction coefficients $t_{ij}$ make domain relationships quantitatively observable before full-scale training, identifying facilitation and conflict pairs.</li>
<li>The nested pipeline&rsquo;s cost is dominated by the small-model training runs (40 mixtures at 70M scale), which is orders of magnitude cheaper than even a single target-scale run.</li>
<li>The continual pretraining application opens the door to optimizing dynamic data schedules, where mixture proportions change across training stages.</li>
</ul>
<p><strong>Limitations</strong>: The &ldquo;domain&rdquo; concept remains loosely defined (provenance-based). The nested scaling laws introduce compounding errors at each step, and predictions tend to slightly underestimate actual loss. The number of required fitting samples, while subquadratic in practice due to sparsity, still scales with the number of domains. No theoretical justification for the exponential form is provided; it is a purely empirical finding.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training (pilot)</td>
          <td>The Pile (GitHub, Pile-CC, Books3)</td>
          <td>30B tokens</td>
          <td>2-domain and 3-domain experiments</td>
      </tr>
      <tr>
          <td>Training (main)</td>
          <td>RedPajama</td>
          <td>100B tokens</td>
          <td>7 domains</td>
      </tr>
      <tr>
          <td>Validation</td>
          <td>The Pile validation set</td>
          <td>Standard split</td>
          <td>Out-of-distribution relative to RedPajama</td>
      </tr>
      <tr>
          <td>Continual pretraining</td>
          <td>Pile + Python code</td>
          <td>10B tokens</td>
          <td>Pythia-70M base model</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<p>Data mixing law: $L_{i}(r_{1 \dots M}) = c_{i} + k_{i} \exp(\sum_{j} t_{ij} r_{j})$, fitted via an AdaBoost regressor on sampled mixtures. Step scaling law: $L(S) = E_{1} + B/S^{\beta}$. Model size scaling law: $L(N) = E_{2} + A/N^{\alpha}$. Both scaling laws are fitted via Huber loss minimization with L-BFGS, decomposing the Chinchilla-style joint law into separate fits for stability. 40 candidate mixtures are sampled via double-diminishing proportions, of which 20 are selected for the final pipeline.</p>
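<p>As a concrete illustration of the first step, the per-domain law can be fitted with ordinary nonlinear least squares on synthetic data. This is a minimal sketch, not the paper's pipeline: the coefficients are invented, only three training domains are used, and the last interaction coefficient is fixed to zero as a gauge choice (on the simplex, a common shift of all $t_{ij}$ is absorbed into $k_i$).</p>

```python
import numpy as np
from scipy.optimize import curve_fit

# Illustrative fit of the per-domain mixing law
#   L_i(r) = c_i + k_i * exp(sum_j t_ij r_j)
# on synthetic 3-domain mixtures. All coefficients below are invented;
# t_3 is fixed to 0 as a gauge choice.

rng = np.random.default_rng(0)

def mixing_law(R, c, k, t1, t2):
    # R: (n_mixtures, 3) mixture proportions summing to 1; gauge t3 = 0
    return c + k * np.exp(t1 * R[:, 0] + t2 * R[:, 1])

true = (1.8, 0.9, -2.0, 0.8)              # synthetic "ground truth"
R = rng.dirichlet(np.ones(3), size=40)    # 40 candidate mixtures
L = mixing_law(R, *true)                  # noise-free losses for a clean check

popt, _ = curve_fit(mixing_law, R, L, p0=[1.0, 1.0, -1.0, 0.0], maxfev=10000)
best = R[np.argmin(mixing_law(R, *popt))]  # predicted-best candidate mixture
print(np.round(popt, 3))  # should recover (1.8, 0.9, -2.0, 0.8)
```

<p>The <code>best</code> line mirrors the selection step: once the law is fitted, the predicted-best mixture among the candidates can be read off without any further training at that scale.</p>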
<h3 id="models">Models</h3>
<p>Transformer decoder-only LMs. Pilot: 70M, 160M. Main pipeline: 70M, 160M, 305M, 410M (for fitting), 1B (target). Batch size: 1M tokens. Cosine learning rate decay with 2K step warmup, decaying to 0.1x at 100K steps.</p>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Optimized Mixture</th>
          <th>Default Mixture</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Steps to match default final loss</td>
          <td>73K (73%)</td>
          <td>100K (100%)</td>
          <td>27% training reduction</td>
      </tr>
      <tr>
          <td>Equivalent extra training</td>
          <td>+48%</td>
          <td>Baseline</td>
          <td>Estimated via step scaling law</td>
      </tr>
      <tr>
          <td>Validation loss (1B, 100B)</td>
          <td>Lowest</td>
          <td>Higher than optimized</td>
          <td>Also beats DoReMi and DoGE</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>8 A100 GPUs. Training times per 30B-token run: 3.5 hours (70M), 8 hours (160M), 16 hours (305M), 21 hours (410M).</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>MIT</td>
          <td>Pilot and validation data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/togethercomputer/RedPajama-Data">RedPajama</a></td>
          <td>Dataset</td>
          <td>Apache 2.0</td>
          <td>Main training data</td>
      </tr>
      <tr>
          <td><a href="https://github.com/EleutherAI/pythia">Pythia Suite</a></td>
          <td>Model</td>
          <td>Apache 2.0</td>
          <td>Model architecture configs; Pythia-70M checkpoint for continual pretraining</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility status: Partially Reproducible.</strong> Datasets and base model checkpoints are public. No official code release for the data mixing law fitting pipeline, mixture sampling, or the nested scaling law prediction workflow.</p>
<hr>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{ye2025datamixinglaws,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Data Mixing Laws: Optimizing Data Mixtures by Predicting Language Modeling Performance}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Ye, Jiasheng and Liu, Peiju and Sun, Tianxiang and Zhan, Jun and Zhou, Yunhua and Qiu, Xipeng}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2025}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>RWKV: Linear-Cost RNN with Transformer Training</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/</guid><description>RWKV combines parallelizable transformer training with constant-cost RNN inference using linear attention and channel-wise decay.</description><content:encoded><![CDATA[<h2 id="a-new-architecture-bridging-rnns-and-transformers">A New Architecture Bridging RNNs and Transformers</h2>
<p>This is a <strong>Method</strong> paper that introduces RWKV (Receptance Weighted Key Value), a novel sequence model architecture that combines the parallelizable training of Transformers with the efficient $O(Td)$ inference of RNNs. RWKV can be formulated equivalently as either a Transformer (for parallel training) or an RNN (for sequential inference), achieving the lowest computational and memory complexity among comparable architectures while matching Transformer-level performance. The authors scale RWKV to 14 billion parameters, making it the largest dense RNN ever trained at the time of publication.</p>
<h2 id="the-quadratic-cost-of-self-attention">The Quadratic Cost of Self-Attention</h2>
<p>Transformers have become the dominant architecture for NLP, powering models like GPT-3, LLaMA, and Chinchilla. Their self-attention mechanism captures both local and long-range dependencies while supporting parallelized training. However, self-attention scales quadratically with sequence length in both time ($O(T^2d)$) and space ($O(T^2 + Td)$), making it computationally and memory intensive for long sequences and resource-constrained deployment.</p>
<p>RNNs, by contrast, offer linear scaling in memory and computation, but suffer from the vanishing gradient problem and cannot parallelize across the time dimension during training. This limits their scalability and makes them unable to match Transformer performance in practice.</p>
<p>Prior work on efficient Transformers (Reformer, Performer, Linformer, AFT, MEGA) has attempted to reduce this quadratic cost, often at the expense of model expressivity. RWKV aims to combine the best of both worlds: Transformer-grade training efficiency with RNN-grade inference cost, without any approximation to the attention mechanism.</p>
<h2 id="linear-attention-via-channel-wise-decay">Linear Attention via Channel-Wise Decay</h2>
<p>RWKV is built on four core vectors that interact multiplicatively at each timestep:</p>
<ul>
<li><strong>R</strong> (Receptance): receives past information, acting as a gating signal</li>
<li><strong>W</strong> (Weight): a trainable positional weight decay vector</li>
<li><strong>K</strong> (Key): analogous to keys in standard attention</li>
<li><strong>V</strong> (Value): analogous to values in standard attention</li>
</ul>
<p>The architecture consists of stacked residual blocks, each containing a <strong>time-mixing</strong> sub-block and a <strong>channel-mixing</strong> sub-block.</p>
<h3 id="token-shift">Token Shift</h3>
<p>All linear projection vectors are produced by interpolating between the current input $x_t$ and the previous input $x_{t-1}$, creating a token shift mechanism:</p>
<p>$$
r_t = W_r \cdot (\mu_r \odot x_t + (1 - \mu_r) \odot x_{t-1})
$$</p>
<p>$$
k_t = W_k \cdot (\mu_k \odot x_t + (1 - \mu_k) \odot x_{t-1})
$$</p>
<p>$$
v_t = W_v \cdot (\mu_v \odot x_t + (1 - \mu_v) \odot x_{t-1})
$$</p>
<p>where $\mu_r$, $\mu_k$, $\mu_v$ are learnable interpolation parameters. This is implemented efficiently as a simple offset in the temporal dimension.</p>
<h3 id="the-wkv-operator">The WKV Operator</h3>
<p>The core attention-like computation replaces standard dot-product attention with a channel-wise weighted sum using exponential decay:</p>
<p>$$
wkv_t = \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} \odot v_i + e^{u + k_t} \odot v_t}{\sum_{i=1}^{t-1} e^{-(t-1-i)w + k_i} + e^{u + k_t}}
$$</p>
<p>Here $w$ is the channel-wise time decay vector and $u$ is a separate bonus vector that attends specifically to the current token. Unlike AFT where $W$ is a pairwise matrix, RWKV treats $W$ as a channel-wise vector modified by relative position, enabling the recurrent formulation.</p>
<h3 id="output-gating">Output Gating</h3>
<p>The receptance vector gates the WKV output through a sigmoid:</p>
<p>$$
o_t = W_o \cdot (\sigma(r_t) \odot wkv_t)
$$</p>
<p>The channel-mixing block uses a similar gating mechanism with squared ReLU activation:</p>
<p>$$
o'_t = \sigma(r'_t) \odot (W'_v \cdot \max(k'_t, 0)^2)
$$</p>
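<p>A minimal NumPy sketch of the channel-mixing computation, with arbitrary dimensions and the token-shift interpolation omitted for brevity:</p>

```python
import numpy as np

# Channel-mixing sub-block: sigmoid receptance gate over a squared-ReLU MLP.
# Dimensions and weights are illustrative; token shift is omitted.

rng = np.random.default_rng(0)
d, d_ff = 8, 32
x = rng.normal(size=d)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

W_r = rng.normal(size=(d, d))
W_k = rng.normal(size=(d_ff, d))
W_v = rng.normal(size=(d, d_ff))

r = W_r @ x                       # receptance projection
k = W_k @ x                       # key projection into the hidden dimension
o = sigmoid(r) * (W_v @ np.maximum(k, 0.0) ** 2)  # squared ReLU, then gate
print(o.shape)  # (8,)
```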
<h3 id="dual-mode-operation">Dual-Mode Operation</h3>
<p>During <strong>training</strong>, RWKV operates in time-parallel mode. The matrix multiplications ($W_\lambda$ for $\lambda \in \{r, k, v, o\}$) dominate at $O(BTd^2)$ and parallelize identically to standard Transformers. The element-wise WKV computation is $O(BTd)$ and parallelizes along batch and channel dimensions.</p>
<p>During <strong>inference</strong>, RWKV switches to time-sequential mode. Each timestep updates a fixed-size state vector, giving constant $O(d)$ memory and $O(Td)$ total time for generating $T$ tokens, compared to $O(T^2d)$ for standard Transformers.</p>
<h3 id="optimizations">Optimizations</h3>
<p>Three additional design choices improve training:</p>
<ol>
<li><strong>Custom CUDA kernels</strong> for the sequential WKV computation, fusing it into a single kernel on training accelerators</li>
<li><strong>Small init embedding</strong>: initializing the embedding matrix with small values plus an additional LayerNorm, accelerating convergence</li>
<li><strong>Custom initialization</strong>: most weights initialized to zero with no biases, following identity-mapping principles from residual network design</li>
</ol>
<h2 id="scaling-to-14b-parameters-and-benchmark-evaluation">Scaling to 14B Parameters and Benchmark Evaluation</h2>
<h3 id="model-scaling">Model Scaling</h3>
<p>The authors train six RWKV models from 169M to 14B parameters, all for one epoch (330B tokens) on the Pile:</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>Dimension</th>
          <th>Parameters</th>
          <th>FLOP/Token</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>12</td>
          <td>768</td>
          <td>$1.69 \times 10^8$</td>
          <td>$2.61 \times 10^8$</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>24</td>
          <td>1024</td>
          <td>$4.30 \times 10^8$</td>
          <td>$7.57 \times 10^8$</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>24</td>
          <td>2048</td>
          <td>$1.52 \times 10^9$</td>
          <td>$2.82 \times 10^9$</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>32</td>
          <td>2560</td>
          <td>$2.99 \times 10^9$</td>
          <td>$5.71 \times 10^9$</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>32</td>
          <td>4096</td>
          <td>$7.39 \times 10^9$</td>
          <td>$1.44 \times 10^{10}$</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>40</td>
          <td>5120</td>
          <td>$1.42 \times 10^{10}$</td>
          <td>$2.78 \times 10^{10}$</td>
      </tr>
  </tbody>
</table>
<p>The parameter count follows: $\text{params} = 2VD + 13D^2L + D(11L + 4)$, where $V = 50277$ is vocabulary size, $D$ is model dimension, and $L$ is layers. FLOPs match the standard transformer formula: $\text{FLOP} = 6 \cdot [\text{tokens}] \cdot [\text{params}]$.</p>
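<p>The parameter-count formula can be checked directly against the table, for instance for the smallest and largest models:</p>

```python
# Parameter-count formula from the paper:
#   params = 2*V*D + 13*D^2*L + D*(11*L + 4), with V = 50277.

V = 50277  # vocabulary size

def rwkv_params(D, L):
    return 2 * V * D + 13 * D * D * L + D * (11 * L + 4)

print(rwkv_params(768, 12))    # 169342464 -> the 169M model
print(rwkv_params(5120, 40))   # 14148597760 -> the 14B model
```

<p>The first value matches the table's $1.69 \times 10^8$ exactly; the second agrees with the $1.42 \times 10^{10}$ entry to within about half a percent, plausibly due to rounding or vocabulary padding in the released checkpoint.</p>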
<h3 id="scaling-laws">Scaling Laws</h3>
<p>Training 45 RWKV models across varied (dataset, parameters) pairs, the authors find that RWKV follows the same log-log linear scaling law established for Transformers. The linear fit to Pareto-optimal points achieves $r^2 = 0.994$, and extrapolating an additional order of magnitude still yields $r^2 = 0.875$. This contrasts with prior claims that LSTMs do not follow transformer-like scaling.</p>
<h3 id="nlp-benchmarks">NLP Benchmarks</h3>
<p>RWKV is compared against similarly-sized models trained on comparable token budgets: Pythia, OPT, and BLOOM (all FLOP-matched). Results span twelve benchmarks: ARC (Easy/Challenge), BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, and Winogrande.</p>
<p>RWKV performs competitively with Transformers across all model sizes. On average across benchmarks, RWKV tracks closely with Pythia and outperforms OPT and BLOOM at comparable scales.</p>
<h3 id="long-context-and-extended-finetuning">Long Context and Extended Finetuning</h3>
<p>RWKV can extend its context length after pretraining through progressive finetuning: doubling from 1024 to 2048 (10B tokens), then to 4096 (100B tokens), and finally to 8192 (100B tokens). Each doubling reduces test loss on the Pile, indicating effective use of longer context.</p>
<p>On the Long Range Arena (LRA) benchmark, which tests sequences from 1,000 to 16,000 tokens, RWKV performs second only to S4 across the five datasets.</p>
<h3 id="inference-efficiency">Inference Efficiency</h3>
<p>Benchmarking text generation on CPU (x86) and GPU (NVIDIA A100 80GB) at float32 precision shows that RWKV exhibits linear scaling in generation time, while Transformers scale quadratically. This advantage grows with sequence length: for long outputs, RWKV completes generation substantially faster at equivalent model sizes.</p>
<h2 id="competitive-performance-with-key-caveats">Competitive Performance with Key Caveats</h2>
<p>RWKV demonstrates that RNN-class models can match Transformer performance at scale, while maintaining $O(Td)$ time and $O(d)$ memory during inference. The key findings are:</p>
<ol>
<li><strong>Scaling laws hold</strong>: RWKV follows the same compute-optimal scaling as Transformers ($r^2 = 0.994$), contradicting earlier claims about RNN scaling behavior</li>
<li><strong>Competitive NLP performance</strong>: Across twelve benchmarks, RWKV matches similarly-sized Transformers trained on comparable data</li>
<li><strong>Linear inference cost</strong>: Generation time scales linearly rather than quadratically, with constant memory regardless of sequence length</li>
<li><strong>Context extension</strong>: Progressive finetuning effectively extends the context window post-training</li>
</ol>
<h3 id="limitations">Limitations</h3>
<p>The authors identify two primary limitations:</p>
<p><strong>Information compression</strong>: Linear attention funnels all past information through a single fixed-size state vector. For tasks requiring recall of specific details over very long contexts, this is mechanistically more constrained than full self-attention, which maintains direct access to all previous tokens.</p>
<p><strong>Prompt sensitivity</strong>: RWKV is more sensitive to prompt engineering than standard Transformers. The linear attention mechanism limits how much prompt information carries forward, making the order of information in the prompt particularly important. Reordering prompts improved F1 from 44.2% to 74.8% on one task.</p>
<h3 id="future-directions">Future Directions</h3>
<p>The authors suggest several avenues: applying parallel scan to reduce WKV cost to $O(B \log(T) d)$, extending RWKV to encoder-decoder and multimodal architectures, leveraging hidden states for interpretability and safety, and increasing internal state size to improve long-range recall.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official PyTorch training and inference implementation</td>
      </tr>
      <tr>
          <td><a href="https://huggingface.co/BlinkDL/rwkv-4-pile-14b">Pre-trained weights (169M to 14B)</a></td>
          <td>Model</td>
          <td>Apache-2.0</td>
          <td>All six Pile-trained sizes on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>)</td>
      </tr>
      <tr>
          <td><a href="https://pile.eleuther.ai/">The Pile</a></td>
          <td>Dataset</td>
          <td>Mixed</td>
          <td>825 GiB pretraining corpus; component licenses vary by source</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility classification</strong>: Highly Reproducible. Training code (Apache-2.0), pre-trained weights for all six model sizes, the full training corpus, and complete hyperparameters (Appendix G) are all publicly available. The only missing detail is the specific GPU cluster configuration used for pretraining.</p>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Pretraining</td>
          <td>The Pile</td>
          <td>330B tokens</td>
          <td>One full epoch for all model sizes</td>
      </tr>
      <tr>
          <td>Context extension</td>
          <td>The Pile</td>
          <td>210B additional tokens</td>
          <td>Progressive doubling: 1024 to 8192</td>
      </tr>
      <tr>
          <td>NLP evaluation</td>
          <td>ARC, BoolQ, COPA, HeadQA, HellaSwag, LAMBADA, OpenBookQA, PIQA, ReCoRD, SciQ, Winogrande</td>
          <td>Various</td>
          <td>Zero-shot evaluation</td>
      </tr>
      <tr>
          <td>Long-range evaluation</td>
          <td>Long Range Arena (LRA)</td>
          <td>1K-16K tokens</td>
          <td>Five sub-tasks</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adam ($\beta = (0.9, 0.99)$), no weight decay</li>
<li>Precision: bfloat16</li>
<li>Training context length: 1024 tokens</li>
<li>Learning rate: constant warmup, then exponential decay</li>
<li>Auxiliary loss from PaLM (softmax normalizer regularization)</li>
<li>Batch size: 128 or 256 sequences (dynamically switched)</li>
<li>Training organized into mini-epochs of 40,320 samples each (8,043 mini-epochs per Pile epoch)</li>
</ul>
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Init LR</th>
          <th>Warmup Mini-Epochs</th>
          <th>End LR</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>169M</td>
          <td>6e-4</td>
          <td>361</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>430M</td>
          <td>4e-4</td>
          <td>411</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>1.5B</td>
          <td>3e-4</td>
          <td>443</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>3B</td>
          <td>1.5e-4</td>
          <td>451</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>7B</td>
          <td>1.5e-4</td>
          <td>465</td>
          <td>1e-5</td>
      </tr>
      <tr>
          <td>14B</td>
          <td>1e-4</td>
          <td>544</td>
          <td>7e-6</td>
      </tr>
  </tbody>
</table>
<p>All pretrained models (169M to 14B) are publicly released on HuggingFace (<code>BlinkDL/rwkv-4-pile-*</code>) under Apache-2.0. Training code is at <a href="https://github.com/BlinkDL/RWKV-LM">BlinkDL/RWKV-LM</a> (Apache-2.0).</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li>All NLP benchmarks evaluated in zero-shot setting</li>
<li>FLOP-matched comparison against Pythia, OPT, BLOOM</li>
<li>Inference benchmarked on CPU (x86) and GPU (NVIDIA A100 80GB) at float32</li>
</ul>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Inference experiments: NVIDIA A100 80GB GPU</li>
<li>Training hardware details not fully specified; FLOP budgets reported per model</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Peng, B., Alcaide, E., Anthony, Q., Albalak, A., Arcadinho, S., Biderman, S., &hellip; &amp; Zhu, R.-J. (2023). RWKV: Reinventing RNNs for the Transformer Era. In <em>Findings of the Association for Computational Linguistics: EMNLP 2023</em>, pp. 14048-14077.</p>
<p><strong>Publication</strong>: Findings of EMNLP 2023</p>
<p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/BlinkDL/RWKV-LM">GitHub Repository (Apache-2.0)</a></li>
</ul>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{peng2023rwkv,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{RWKV: Reinventing RNNs for the Transformer Era}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Peng, Bo and Alcaide, Eric and Anthony, Quentin and Albalak, Alon and Arcadinho, Samuel and Biderman, Stella and Cao, Huanqi and Cheng, Xin and Chung, Michael and Derczynski, Leon and Du, Xingjian and Grella, Matteo and GV, Kranthi Kiran and He, Xuzheng and Hou, Haowen and Kazienko, Przemys{\l}aw and Koco{\&#39;n}, Jan and Kong, Jiaming and Koptyra, Bart{\l}omiej and Lau, Hayden and Lin, Jiaju and Mantri, Krishna Sri Ipsit and Mom, Ferdinand and Saito, Atsushi and Song, Guangyu and Tang, Xiangru and Wind, Johan S. and Wo{\&#39;z}niak, Stanis{\l}aw and Zhang, Zhenyuan and Zhou, Qinghua and Zhu, Jian and Zhu, Rui-Jie}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Findings of the Association for Computational Linguistics: EMNLP 2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{14048--14077}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span>=<span style="color:#e6db74">{10.18653/v1/2023.findings-emnlp.936}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item><item><title>Block-Recurrent Transformers for Long Sequences</title><link>https://hunterheidenreich.com/notes/natural-language-processing/language-models/block-recurrent-transformers/</link><pubDate>Tue, 07 Apr 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/natural-language-processing/language-models/block-recurrent-transformers/</guid><description>Block-Recurrent Transformers combine attention and recurrence for linear-complexity language modeling on long documents like books and code.</description><content:encoded><![CDATA[<h2 id="a-method-for-combining-attention-with-block-level-recurrence">A Method for Combining Attention with Block-Level Recurrence</h2>
<p>This is a <strong>Method</strong> paper that introduces the Block-Recurrent Transformer, a model architecture that integrates recurrence into the transformer framework at the block level. Rather than processing tokens one at a time (as in traditional RNNs) or attending over entire sequences (as in standard transformers), this approach applies a transformer layer recurrently across blocks of tokens. The result is a model with linear complexity in sequence length that maintains the parallelism benefits of transformers during training. A related approach, <a href="/notes/natural-language-processing/language-models/rwkv-rnn-transformer-architecture/">RWKV</a>, later explored similar ideas using linear attention with channel-wise decay.</p>
<h2 id="why-transformers-struggle-with-long-documents">Why Transformers Struggle with Long Documents</h2>
<p>Transformers have largely replaced RNNs for sequence modeling tasks, but their quadratic self-attention cost limits the length of sequences they can process. A transformer with a window size of 512 tokens cannot see information beyond that window, making it blind to long-range dependencies in books, technical papers, or source code repositories.</p>
<p>Prior approaches to this problem fall into several categories: sparse attention patterns (BigBird, Routing Transformers, Reformer), sequence compression (Linformer, Funnel Transformers), and linearized attention approximations. These methods either sacrifice the expressiveness of full softmax attention or introduce implementation complexity.</p>
<p>Traditional RNNs like LSTMs offer linear complexity but suffer from three key limitations: sequential processing prevents parallelism on modern hardware, a single state vector bottlenecks information capacity, and vanishing gradients limit effective memory to a few hundred tokens.</p>
<h2 id="block-level-recurrence-with-lstm-style-gates">Block-Level Recurrence with LSTM-Style Gates</h2>
<p>The core innovation is applying a standard transformer layer in a recurrent fashion along the sequence, operating on blocks of $W$ tokens rather than individual tokens. The recurrent cell maintains $S$ state vectors (typically $S = W = 512$) that are updated at each block boundary.</p>
<h3 id="the-recurrent-cell">The Recurrent Cell</h3>
<p>The cell has two processing directions:</p>
<ul>
<li><strong>Vertical direction</strong>: An ordinary transformer layer with self-attention over input tokens and cross-attention to recurrent states, producing output embeddings.</li>
<li><strong>Horizontal direction</strong>: Self-attention over current state vectors and cross-attention to input tokens, producing updated state vectors. Residual connections are replaced with gates.</li>
</ul>
<p>Self-attention and cross-attention are computed in parallel (not sequentially), with results concatenated and fed into a linear projection. Keys and values are shared between directions, while queries are separate, yielding four query sets: $Q_e^v$, $Q_s^v$ (vertical) and $Q_s^h$, $Q_e^h$ (horizontal).</p>
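<p>The shared-key/value, separate-query arrangement can be sketched compactly. This is a simplified single-head sketch with random toy inputs: the key, value, and query projection matrices are omitted (the inputs stand in for their own keys, values, and queries), and the final linear projection over the concatenated results is not shown.</p>

```python
import numpy as np

# Vertical and horizontal directions of the recurrent cell: keys/values are
# shared between directions, queries are separate. Single head, toy shapes.

rng = np.random.default_rng(0)
W, S, d = 6, 4, 8                  # block width, state count, head dimension

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def attend(Q, K, V):
    return softmax(Q @ K.T / np.sqrt(K.shape[1])) @ V

x = rng.normal(size=(W, d))        # token embeddings in the current block
s = rng.normal(size=(S, d))        # recurrent state vectors

K_e, V_e = x.copy(), x.copy()      # token keys/values (shared by both directions)
K_s, V_s = s.copy(), s.copy()      # state keys/values (shared by both directions)

# Vertical: tokens self-attend and cross-attend to states, results concatenated
vert = np.concatenate([attend(x, K_e, V_e), attend(x, K_s, V_s)], axis=-1)
# Horizontal: states self-attend and cross-attend to tokens
horiz = np.concatenate([attend(s, K_s, V_s), attend(s, K_e, V_e)], axis=-1)
print(vert.shape, horiz.shape)  # (6, 16) (4, 16)
```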
<h3 id="gating-mechanisms">Gating Mechanisms</h3>
<p>Two gate types are explored. The <strong>fixed gate</strong> uses a learned convex combination:</p>
<p>$$
g = \sigma(b_g)
$$</p>
<p>$$
c_{t+1} = c_t \odot g + z_t \odot (1 - g)
$$</p>
<p>where $g$ is constant after training, implementing an <a href="https://en.wikipedia.org/wiki/Moving_average">exponential moving average</a>.</p>
<p>The <strong>LSTM gate</strong> uses input and forget gates:</p>
<p>$$
i_t = \sigma(W_i h_t + b_i - 1)
$$</p>
<p>$$
f_t = \sigma(W_f h_t + b_f + 1)
$$</p>
<p>$$
c_{t+1} = c_t \odot f_t + z_t \odot i_t
$$</p>
<p>The bias offsets ($-1$ for input, $+1$ for forget) initialize the model to &ldquo;remember&rdquo; by default, which is critical for training stability. Without careful initialization, the model can fall into a local optimum where it ignores the recurrent state entirely. This echoes the <a href="/notes/machine-learning/model-architectures/can-recurrent-neural-networks-warp-time/">gate initialization challenges studied by Tallec and Ollivier</a>, who derived chrono initialization for LSTMs from time-warping invariance.</p>
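<p>A toy NumPy sketch of the two gate updates, with invented shapes and small random weights; note how the $\pm 1$ bias offsets make the LSTM gate start out retaining most of the old state:</p>

```python
import numpy as np

# The two gate types for the horizontal (state) direction. Shapes, weights,
# and inputs are illustrative, not the paper's configuration.

rng = np.random.default_rng(0)
S, d = 4, 8                        # number of state vectors, dimension

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

c = rng.normal(size=(S, d))        # current recurrent state
z = rng.normal(size=(S, d))        # candidate update from attention
h = rng.normal(size=(S, d))        # input to the gate projections

# Fixed gate: learned per-channel EMA coefficient, constant after training
b_g = np.zeros(d)                  # zero bias -> g = 0.5 everywhere
g = sigmoid(b_g)
c_fixed = c * g + z * (1.0 - g)

# LSTM-style gate: input gate biased low (-1), forget gate biased high (+1)
W_i = rng.normal(size=(d, d)) * 0.01
W_f = rng.normal(size=(d, d)) * 0.01
i_t = sigmoid(h @ W_i.T - 1.0)     # starts near 0.27: admit little new info
f_t = sigmoid(h @ W_f.T + 1.0)     # starts near 0.73: retain most old state
c_lstm = c * f_t + z * i_t
print(c_fixed.shape, c_lstm.shape)  # (4, 8) (4, 8)
```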
<h3 id="gate-configurations">Gate Configurations</h3>
<p>Three configurations are tested: <strong>dual</strong> (gates on both attention and MLP outputs), <strong>single</strong> (gate only on MLP output), and <strong>skip</strong> (gate only on attention output, no MLP). The skip configuration removes the large MLP from the recurrent layer entirely.</p>
<h3 id="learned-state-ids">Learned State IDs</h3>
<p>Since the same weights are applied to all state vectors, learned &ldquo;state IDs&rdquo; (analogous to position embeddings) are added so each state vector can issue distinct queries. <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style relative position bias is used for token self-attention, with no position bias for state-token cross-attention.</p>
<h2 id="language-modeling-on-pg19-arxiv-and-github">Language Modeling on PG19, arXiv, and GitHub</h2>
<h3 id="experimental-setup">Experimental Setup</h3>
<p>The base model is a 12-layer transformer with 150M parameters (8 heads of size 128, embedding dimension 1024, MLP hidden size 4096). The recurrent layer is placed at layer 10 with segment length $N = 4096$ and window size $W = 512$. The architecture is evaluated on three long-document datasets:</p>
<ul>
<li><strong>PG19</strong>: Full-length books from <a href="https://en.wikipedia.org/wiki/Project_Gutenberg">Project Gutenberg</a> (pre-1919)</li>
<li><strong>arXiv</strong>: Mathematics papers in LaTeX</li>
<li><strong>GitHub</strong>: Concatenated source code from open-source repositories</li>
</ul>
<p>All models report bits-per-token ($\log_2$ perplexity, lower is better).</p>
<h3 id="baselines">Baselines</h3>
<p>Five baselines are compared: Transformer-XL with window sizes of 512, 1024, and 2048, plus 12-layer and 13-layer sliding-window models. The 13-layer sliding-window model (Slide:13L) is the primary comparison, as it matches the recurrent models in computation cost and parameter count.</p>
<h3 id="main-results">Main Results</h3>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Step Time</th>
          <th>PG19 (bytes)</th>
          <th>PG19 (tokens)</th>
          <th>arXiv</th>
          <th>GitHub</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>XL:512</td>
          <td>0.88</td>
          <td>1.01</td>
          <td>3.62</td>
          <td>1.45</td>
          <td>1.21</td>
      </tr>
      <tr>
          <td>XL:2048</td>
          <td>2.11</td>
          <td>0.990</td>
          <td>3.58</td>
          <td>1.31</td>
          <td>1.01</td>
      </tr>
      <tr>
          <td>Slide:13L</td>
          <td>1.00</td>
          <td>0.989</td>
          <td>3.58</td>
          <td>1.42</td>
          <td>1.17</td>
      </tr>
      <tr>
          <td>Rec:fixed:skip</td>
          <td>0.99</td>
          <td>0.952</td>
          <td>3.53</td>
          <td>1.24</td>
          <td>0.976</td>
      </tr>
      <tr>
          <td>Rec:fixed:dual</td>
          <td>1.01</td>
          <td>0.957</td>
          <td>3.52</td>
          <td>1.27</td>
          <td>0.991</td>
      </tr>
      <tr>
          <td>Feedback:fixed:skip</td>
          <td>1.35</td>
          <td>0.935</td>
          <td>3.49</td>
          <td>1.24</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Memorizing Trans. 64k</td>
          <td>1.94</td>
          <td>0.950</td>
          <td>3.53</td>
          <td>1.22</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>The Rec:fixed:skip configuration achieves the best overall results while being slightly faster than the 13-layer baseline. It outperforms XL:2048, which runs over 2x slower. The block feedback variant (allowing all layers to cross-attend to recurrent states) improves perplexity further at ~35-40% higher step time.</p>
<h3 id="scaling-behavior">Scaling Behavior</h3>
<p>Models from 40M to 1.3B parameters show that the benefit of recurrence is <a href="/notes/machine-learning/model-architectures/scaling-laws-vs-model-architectures/">consistent across scales</a> and increases with model size. At larger sizes, adding recurrence provides a benefit greater than doubling the number of parameters. The 1.3B parameter model achieves 26.50 word-level perplexity on PG19, setting a new state of the art at the time of publication.</p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Layers</th>
          <th>PG19 Perplexity</th>
          <th>Parameters</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Compressive Transformer</td>
          <td>36</td>
          <td>33.6</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Routing Transformer</td>
          <td>22</td>
          <td>33.2</td>
          <td>490M</td>
      </tr>
      <tr>
          <td>Perceiver AR</td>
          <td>60</td>
          <td>28.9</td>
          <td>974.6M</td>
      </tr>
      <tr>
          <td>Block-Recurrent Transformer</td>
          <td>24</td>
          <td>26.50</td>
          <td>1.3B</td>
      </tr>
  </tbody>
</table>
<h3 id="ablations">Ablations</h3>
<ul>
<li><strong>Multiple recurrent layers</strong>: Two adjacent layers (9, 10) provide no benefit. Two separated layers (4, 10) help but no more than adding another non-recurrent layer.</li>
<li><strong>Number of states</strong>: Improvement up to 1024 states, degradation at 2048.</li>
<li><strong>Window size reduction</strong>: Reducing the sliding window hurts Transformer-XL dramatically but has a smaller impact on the recurrent model, which compensates via recurrence.</li>
<li><strong>Gate type</strong>: The fixed gate consistently outperforms the LSTM gate despite being theoretically less expressive.</li>
</ul>
<h3 id="qualitative-analysis">Qualitative Analysis</h3>
<p>Comparing per-token predictions against Transformer-XL on PG19 books, the recurrent model&rsquo;s advantage comes overwhelmingly from predicting proper names (17/20 top-improvement tokens). In 19/20 cases, the predicted word was outside the attention window, confirming it was stored in recurrent state. The model can remember book titles and authors across 60,000+ tokens.</p>
<h2 id="findings-limitations-and-future-directions">Findings, Limitations, and Future Directions</h2>
<p>The Block-Recurrent Transformer demonstrates that recurrence at the block level is a cost-effective way to improve language modeling on long sequences. The fixed:skip configuration (the simplest variant) performs best, suggesting the model primarily uses recurrence for long-range name lookup rather than complex reasoning. The fact that removing the MLP from the recurrent layer has minimal impact further supports this interpretation.</p>
<p>Key limitations: the model was evaluated only on language-modeling perplexity (no downstream tasks); the LSTM gate underperforms the simpler fixed gate, suggesting untapped potential for more expressive recurrence; and the authors acknowledge that training the recurrent layer to fully exploit its capacity for knowledge extraction will require further advances.</p>
<p>The authors note that evaluating on downstream tasks requiring long-range context (book summarization, long-document QA, code completion) is an important direction for future work.</p>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Training/Eval</td>
          <td>PG19</td>
          <td>~29k books</td>
          <td>Public domain, freely available</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>arXiv</td>
          <td>Mathematics papers</td>
          <td>Obtained via private channels, not redistributable</td>
      </tr>
      <tr>
          <td>Training/Eval</td>
          <td>GitHub</td>
          <td>Open-source repos</td>
          <td>Obtained via private channels, not redistributable</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>Optimizer: Adafactor</li>
<li>Learning rate: 1.0 with inverse square root decay (initial experiments), cosine decay with max 0.01 (scaling experiments)</li>
<li>Warmup: 1000 steps</li>
<li>Dropout: 0.05</li>
<li>Vocabulary: 32k SentencePiece (T5 pretrained for initial, custom for scaling)</li>
<li>Gate initialization: bias of $+1$ for forget gate, $-1$ for input gate to ensure initial &ldquo;remember&rdquo; behavior</li>
</ul>
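<p>To make the gate initialization concrete, here is a minimal sketch of one plausible reading of the fixed gate described above: per-channel learned biases passed through a sigmoid, combining the old state and the new candidate. Names, shapes, and the exact update form are illustrative assumptions, not taken from the released Meliad code.</p>

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def fixed_gate_update(state, candidate, b_forget, b_input):
    """Elementwise gated update of the recurrent state.

    In the 'fixed' gate, the gates depend only on learned per-channel
    biases, not on the current input (unlike an LSTM-style gate).
    """
    return state * sigmoid(b_forget) + candidate * sigmoid(b_input)

d = 8  # illustrative state width
state = np.ones(d)       # stand-in for the carried recurrent state
candidate = np.zeros(d)  # stand-in for the cross-attention update

# Initialization from the notes above: forget bias +1, input bias -1,
# so sigmoid(+1) ~ 0.73 of the old state is kept and only
# sigmoid(-1) ~ 0.27 of the new candidate is mixed in.
b_forget = np.full(d, +1.0)
b_input = np.full(d, -1.0)

new_state = fixed_gate_update(state, candidate, b_forget, b_input)
print(new_state[0])  # ~0.731: the initial behavior is "remember"
```

<p>With these biases, the state decays slowly at initialization, which matches the intended "remember by default" behavior before training adjusts the gates.</p>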
<h3 id="models">Models</h3>
<table>
  <thead>
      <tr>
          <th>Variant</th>
          <th>Layers</th>
          <th>Parameters</th>
          <th>Recurrent Layers</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Base</td>
          <td>12 (+1 recurrent)</td>
          <td>~151-164M</td>
          <td>Layer 10</td>
      </tr>
      <tr>
          <td>Large</td>
          <td>24 (+2 recurrent)</td>
          <td>650M</td>
          <td>Layers 10, 20</td>
      </tr>
      <tr>
          <td>XL</td>
          <td>24 (+2 recurrent)</td>
          <td>1.3B</td>
          <td>Layers 10, 20</td>
      </tr>
  </tbody>
</table>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Best Model</th>
          <th>PG19 (tokens)</th>
          <th>arXiv</th>
          <th>GitHub</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Bits-per-token</td>
          <td>Rec:fixed:skip</td>
          <td>3.53</td>
          <td>1.24</td>
          <td>0.976</td>
      </tr>
      <tr>
          <td>Word-level PPL</td>
          <td>1.3B model</td>
          <td>26.50</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p>Error bars on PG19 are between 0.002 and 0.007 (3 runs with different seeds).</p>
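<p>The two metrics in the table are related: word-level perplexity follows from bits-per-token once the average number of subword tokens per word is known, since total bits are conserved across tokenizations. A hedged sketch of the conversion (the tokens-per-word ratio below is a made-up illustration, not the paper's actual PG19 statistic):</p>

```python
def word_level_ppl(bits_per_token, tokens_per_word):
    """Convert subword bits-per-token to word-level perplexity.

    bits/word = bits/token * tokens/word, and perplexity = 2 ** (bits/word).
    """
    return 2.0 ** (bits_per_token * tokens_per_word)

# Illustrative only: at ~1.33 subword tokens per word, a model at
# 3.53 bits per token would land near 26 word-level perplexity.
print(round(word_level_ppl(3.53, 1.33), 2))  # ~25.9
```

<p>This is why the two numbers in the table are consistent with each other despite looking very different in magnitude.</p>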
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training: 32 Google TPU v4 replicas</li>
<li>Training time: ~48 hours for 500k steps on PG19</li>
<li>Batch size: 32 (segment length 4096) or 256 (segment length 512), adjusted so each model sees the same tokens per step</li>
</ul>
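<p>The batch-size pairing above keeps the tokens-per-step budget identical across segment lengths, so every model sees the same amount of data per step. A quick arithmetic check:</p>

```python
# Both configurations process the same number of tokens per step.
long_segments = 32 * 4096    # batch 32, segment length 4096
short_segments = 256 * 512   # batch 256, segment length 512
assert long_segments == short_segments

# Over the 500k-step PG19 run, that is a fixed total token budget.
total_tokens = long_segments * 500_000
print(long_segments)   # 131072
print(total_tokens)    # 65536000000
```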
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Available</th>
          <th>License</th>
          <th>URL</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Code (Meliad)</td>
          <td>Yes</td>
          <td>Apache 2.0</td>
          <td><a href="https://github.com/google-research/meliad">github.com/google-research/meliad</a></td>
      </tr>
      <tr>
          <td>PG19 Dataset</td>
          <td>Yes</td>
          <td>Public Domain</td>
          <td>Public</td>
      </tr>
      <tr>
          <td>arXiv Dataset</td>
          <td>No</td>
          <td>Not redistributable</td>
          <td>Private</td>
      </tr>
      <tr>
          <td>GitHub Dataset</td>
          <td>No</td>
          <td>Not redistributable</td>
          <td>Private</td>
      </tr>
      <tr>
          <td>Pretrained Models</td>
          <td>No</td>
          <td>-</td>
          <td>-</td>
      </tr>
  </tbody>
</table>
<p><strong>Reproducibility Assessment</strong>: Partially Reproducible. Source code is available under Apache 2.0 and the PG19 dataset is public. However, two of three evaluation datasets (arXiv, GitHub) were obtained via private channels and are not redistributable. No pretrained model checkpoints are released.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Hutchins, D., Schlag, I., Wu, Y., Dyer, E., &amp; Neyshabur, B. (2022). Block-Recurrent Transformers. <em>Advances in Neural Information Processing Systems 35 (NeurIPS 2022)</em>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{hutchins2022block,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Block-Recurrent Transformers}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Hutchins, DeLesley and Schlag, Imanol and Wu, Yuhuai and Dyer, Ethan and Neyshabur, Behnam}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2022}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{2203.07852}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">archiveprefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">primaryclass</span>=<span style="color:#e6db74">{cs.LG}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div>]]></content:encoded></item></channel></rss>