<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:content="http://purl.org/rss/1.0/modules/content/"><channel><title>Generative Modeling on Hunter Heidenreich | Senior AI Research Scientist</title><link>https://hunterheidenreich.com/categories/generative-modeling/</link><description>Recent content in Generative Modeling on Hunter Heidenreich | Senior AI Research Scientist</description><image><title>Hunter Heidenreich | Senior AI Research Scientist</title><url>https://hunterheidenreich.com/img/avatar.webp</url><link>https://hunterheidenreich.com/img/avatar.webp</link></image><generator>Hugo -- 0.147.7</generator><language>en-US</language><copyright>2026 Hunter Heidenreich</copyright><lastBuildDate>Thu, 09 Apr 2026 00:00:00 +0000</lastBuildDate><atom:link href="https://hunterheidenreich.com/categories/generative-modeling/index.xml" rel="self" type="application/rss+xml"/><item><title>Latent Diffusion Models for High-Res Image Synthesis</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/latent-diffusion-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/latent-diffusion-models/</guid><description>Latent Diffusion Models train diffusion in a compressed latent space, enabling high-res image synthesis with cross-attention conditioning at reduced compute.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It introduces Latent Diffusion Models (LDMs), which train denoising diffusion models in the latent space of pretrained autoencoders rather than directly in pixel space. The key insight is that separating perceptual compression from generative learning enables high-resolution image synthesis at a fraction of the computational cost of pixel-based diffusion. The paper also introduces a cross-attention conditioning mechanism for flexible multi-modal generation.</p>
<h2 id="computational-cost-of-pixel-space-diffusion">Computational Cost of Pixel-Space Diffusion</h2>
<p>Training diffusion models directly in pixel space is computationally expensive (150 to 1000 V100 GPU-days for leading models at the time) because the model must process high-dimensional RGB data at every denoising step. Much of this compute is spent modeling imperceptible high-frequency details. The authors observe that learning can be split into two stages: a perceptual compression stage that removes high-frequency detail, and a semantic compression stage where the generative model learns the conceptual composition. Prior two-stage approaches (VQGAN, DALL-E) relied on aggressive compression and autoregressive modeling in discrete latent spaces, trading off reconstruction quality for tractability.</p>
<h2 id="core-innovation-diffusion-in-latent-space">Core Innovation: Diffusion in Latent Space</h2>
<p>LDMs decompose image synthesis into two phases:</p>
<p><strong>Phase 1: Perceptual Compression.</strong> A pretrained autoencoder (encoder $\mathcal{E}$, decoder $\mathcal{D}$) maps images $x \in \mathbb{R}^{H \times W \times 3}$ to a lower-dimensional latent representation $z = \mathcal{E}(x) \in \mathbb{R}^{h \times w \times c}$ with spatial downsampling factor $f = H/h$. The autoencoder is trained with a perceptual loss (matching deep features from a pretrained VGG network) and a patch-based adversarial objective, with either KL or VQ regularization on the latent space.</p>
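<p>To make the downsampling factor concrete, the following minimal sketch computes the latent shape $h \times w \times c$ and the resulting reduction in element count. The helper names and the choice $c = 4$ are assumptions made for illustration, not values taken from the text above.</p>

```python
def latent_shape(height, width, f, c):
    """Shape (h, w, c) of z = E(x) for an H x W x 3 image with f = H / h."""
    if height % f or width % f:
        raise ValueError("H and W must be divisible by the downsampling factor f")
    return (height // f, width // f, c)


def compression_ratio(height, width, f, c):
    """Number of pixel-space values divided by number of latent values."""
    h, w, _ = latent_shape(height, width, f, c)
    return (height * width * 3) / (h * w * c)


# Hypothetical setting: 256x256 RGB image, f = 8, c = 4 latent channels.
shape = latent_shape(256, 256, f=8, c=4)
ratio = compression_ratio(256, 256, f=8, c=4)
print(shape, ratio)
```

<p>Under these assumed settings the diffusion model sees a 32x32x4 tensor instead of a 256x256x3 image, which is where the compute savings come from.</p>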
<p><strong>Phase 2: Latent Diffusion.</strong> A standard denoising diffusion model operates in this latent space. The training objective becomes:</p>
<p>$$L_{\text{LDM}} := \mathbb{E}_{\mathcal{E}(x), \epsilon \sim \mathcal{N}(0,1), t} \left[ \left\| \epsilon - \epsilon_\theta(z_t, t) \right\|_2^2 \right]$$</p>
<p>where $z_t$ is the noised latent at timestep $t$, and $\epsilon_\theta$ is a time-conditional UNet.</p>
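<p>The objective can be sketched for a single sample as follows. This is an illustration, not the authors&rsquo; implementation: it assumes the usual DDPM forward process $z_t = \sqrt{\bar\alpha_t}\, z_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$, uses a hypothetical value for $\bar\alpha_t$, and replaces the UNet with an oracle that recovers $\epsilon$ exactly.</p>

```python
import math
import random


def noise_latent(z0, eps, alpha_bar_t):
    """z_t = sqrt(alpha_bar_t) * z0 + sqrt(1 - alpha_bar_t) * eps, elementwise."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [a * z + b * e for z, e in zip(z0, eps)]


def ldm_loss(eps, eps_pred):
    """Squared L2 norm of (eps - eps_pred) for one sample."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))


def oracle_eps(z_t, z0, alpha_bar_t):
    """Recovers eps from (z_t, z0); a stand-in for a perfect eps_theta."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [(zt - a * z) / b for zt, z in zip(z_t, z0)]


rng = random.Random(0)
z0 = [0.5, -1.0, 2.0, 0.0]                # stand-in for a flattened latent E(x)
eps = [rng.gauss(0.0, 1.0) for _ in z0]   # eps ~ N(0, 1)
alpha_bar_t = 0.7                         # hypothetical schedule value at step t
z_t = noise_latent(z0, eps, alpha_bar_t)

loss_perfect = ldm_loss(eps, oracle_eps(z_t, z0, alpha_bar_t))
loss_zero_guess = ldm_loss(eps, [0.0] * len(eps))
```

<p>A perfect noise predictor drives the loss to zero (up to floating-point error), while a predictor that always outputs zero pays the full squared norm of $\epsilon$.</p>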
<p><strong>Cross-Attention Conditioning.</strong> To enable conditioning on text, semantic maps, or other modalities, the authors introduce cross-attention layers into the UNet. A domain-specific encoder $\tau_\theta$ maps conditioning input $y$ to an intermediate representation $\tau_\theta(y) \in \mathbb{R}^{M \times d_\tau}$, which interacts with the UNet features via:</p>
<p>$$Q = W_Q^{(i)} \cdot \varphi_i(z_t), \quad K = W_K^{(i)} \cdot \tau_\theta(y), \quad V = W_V^{(i)} \cdot \tau_\theta(y)$$</p>
<p>The conditional objective then becomes:</p>
<p>$$L_{\text{LDM}} := \mathbb{E}_{\mathcal{E}(x), y, \epsilon \sim \mathcal{N}(0,1), t} \left[ \left\| \epsilon - \epsilon_\theta(z_t, t, \tau_\theta(y)) \right\|_2^2 \right]$$</p>
<p>Both $\tau_\theta$ and $\epsilon_\theta$ are optimized jointly.</p>
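<p>The cross-attention layer can be sketched in plain Python as below. The sketch assumes standard scaled dot-product attention, $\text{softmax}(QK^\top / \sqrt{d})\, V$, and a row-vector convention in which the projections are applied on the right. The tiny matrices are made up for illustration.</p>

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(phi, tau, w_q, w_k, w_v):
    """phi: (N, d_phi) UNet features; tau: (M, d_tau) conditioning tokens."""
    q = matmul(phi, w_q)                        # queries from the UNet, (N, d)
    k = matmul(tau, w_k)                        # keys from the conditioning, (M, d)
    v = matmul(tau, w_v)                        # values from the conditioning, (M, d)
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]        # transpose of k, (d, M)
    scores = matmul(q, k_t)                     # (N, M)
    weights = [softmax([s / math.sqrt(d) for s in row]) for row in scores]
    return matmul(weights, v), weights


# Made-up example: N = 3 spatial positions, M = 2 conditioning tokens.
phi = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
tau = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
w_v = [[2.0, 0.0], [0.0, 3.0], [0.0, 0.0]]
out, weights = cross_attention(phi, tau, w_q, w_k, w_v)
```

<p>Each spatial position produces one row of attention weights over the $M$ conditioning tokens, so the same layer works for any modality that $\tau_\theta$ can turn into a token sequence.</p>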
<h2 id="experimental-setup-and-results">Experimental Setup and Results</h2>
<p>The authors evaluate across multiple tasks and datasets:</p>
<p><strong>Perceptual compression tradeoffs.</strong> Downsampling factors $f \in \{1, 2, 4, 8, 16, 32\}$ are compared on ImageNet class-conditional generation. LDM-1 (pixel-based) trains slowly; LDM-32 loses too much information. LDM-4 and LDM-8 achieve the best balance, with LDM-8 outperforming pixel-based diffusion by 38 FID points after 2M training steps on a single A100.</p>
<p><strong>Unconditional image synthesis</strong> on CelebA-HQ 256, FFHQ 256, LSUN Churches/Bedrooms 256: LDM-4 achieves FID 5.11 on CelebA-HQ (state of the art at the time), outperforming LSGM, GANs, and other likelihood-based models. On LSUN-Bedrooms, LDM-4 achieves FID 2.95, close to ADM (1.90) with half the parameters and roughly 4x less training compute (see Appendix E.3.5).</p>
<p><strong>Text-to-image synthesis</strong> on MS-COCO: A 1.45B-parameter LDM-KL-8 model trained on LAION-400M achieves FID 12.63 with classifier-free guidance at scale s=1.5. Classifier-free guidance amplifies the conditioning signal at the cost of diversity by extrapolating from the unconditional prediction toward the conditional one. The result is on par with GLIDE (FID 12.24, 6B params) and Make-A-Scene (FID 11.84, 4B params) with substantially fewer parameters.</p>
<p><strong>Class-conditional ImageNet 256:</strong> LDM-4-G achieves FID 3.60, IS 247.67, outperforming ADM-G (FID 4.59) with fewer parameters and less compute.</p>
<p><strong>Super-resolution:</strong> LDM-4 (big) achieves FID 2.4 on ImageNet 64-to-256 upscaling (validation split), outperforming SR3 in FID.</p>
<p><strong>Inpainting</strong> on Places: LDM-4 (big, w/ ft) achieves FID 1.50, setting a new state of the art on image inpainting.</p>
<h2 id="key-findings-and-limitations">Key Findings and Limitations</h2>
<ul>
<li>LDM-4 and LDM-8 offer the best tradeoff between perceptual compression and generation quality.</li>
<li>The autoencoder only needs to be trained once and can be reused across different diffusion models and tasks.</li>
<li>Cross-attention conditioning generalizes to text, semantic layouts, and bounding boxes without architecture changes.</li>
<li>Convolutional sampling enables generation at resolutions higher than the training resolution (up to 1024x1024).</li>
<li>Sequential sampling remains slower than GANs. The autoencoder reconstruction can become a bottleneck for tasks requiring pixel-level precision.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Unconditional</td>
          <td>CelebA-HQ, FFHQ, LSUN</td>
          <td>256x256</td>
          <td>Standard benchmarks</td>
      </tr>
      <tr>
          <td>Class-conditional</td>
          <td>ImageNet</td>
          <td>256x256</td>
          <td>1000 classes</td>
      </tr>
      <tr>
          <td>Text-to-image</td>
          <td>LAION-400M</td>
          <td>256x256</td>
          <td>400M image-text pairs</td>
      </tr>
      <tr>
          <td>Inpainting</td>
          <td>Places</td>
          <td>256x256, 512x512</td>
          <td>Following LaMa protocol</td>
      </tr>
      <tr>
          <td>Super-resolution</td>
          <td>ImageNet</td>
          <td>64 to 256</td>
          <td>Following SR3 pipeline</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Autoencoder regularization</strong>: KL-reg (KL penalty toward standard normal, weighted by ~$10^{-6}$) or VQ-reg (vector quantization layer on the latent space with a learned codebook)</li>
<li><strong>Diffusion</strong>: Standard DDPM denoising with reweighted objective</li>
<li><strong>Sampling</strong>: DDIM sampler with configurable steps (100 to 500 depending on task)</li>
<li><strong>Guidance</strong>: Classifier-free diffusion guidance with scale $s$ (1.5 for class-conditional and text-to-image quantitative evaluation; 10.0 for qualitative text-to-image samples)</li>
</ul>
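<p>The guidance step listed above can be sketched as a single combination of two noise predictions. The form $\hat\epsilon = \epsilon_{\text{uncond}} + s\,(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}})$ is assumed here as the convention behind the scale $s$. Other write-ups parameterize the same idea with $1 + w$ instead.</p>

```python
def guided_eps(eps_cond, eps_uncond, s):
    """Classifier-free guidance: move from the unconditional prediction
    toward (and, for s > 1, past) the conditional prediction."""
    return [u + s * (c - u) for c, u in zip(eps_cond, eps_uncond)]


eps_cond = [1.0, 2.0]      # made-up conditional noise prediction
eps_uncond = [0.0, 1.0]    # made-up unconditional noise prediction

no_guidance = guided_eps(eps_cond, eps_uncond, s=1.0)
unconditional = guided_eps(eps_cond, eps_uncond, s=0.0)
quantitative = guided_eps(eps_cond, eps_uncond, s=1.5)
```

<p>With $s = 1$ the conditional prediction is returned unchanged, and larger scales push further along the direction that separates the two predictions.</p>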
<h3 id="models">Models</h3>
<ul>
<li><strong>Autoencoder</strong>: Based on VQGAN architecture with perceptual + adversarial loss</li>
<li><strong>UNet backbone</strong>: Time-conditional with cross-attention layers at multiple resolutions</li>
<li><strong>Text encoder</strong>: BERT-tokenizer with transformer $\tau_\theta$ for LAION text-to-image model</li>
<li><strong>LDM-4-G</strong>: 400M parameters, $f=4$ downsampling</li>
<li><strong>LDM-KL-8 (text)</strong>: 1.45B parameters, $f=8$ downsampling, KL-regularized</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Task</th>
          <th>Best Value</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CelebA-HQ unconditional</td>
          <td>5.11</td>
          <td>500 DDIM steps</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet class-conditional</td>
          <td>3.60</td>
          <td>LDM-4-G, cfg s=1.5</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>MS-COCO text-to-image</td>
          <td>12.63</td>
          <td>LDM-KL-8-G, 250 steps, cfg s=1.5</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>Places inpainting</td>
          <td>1.50</td>
          <td>LDM-4 big, w/ ft</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet 4x super-resolution</td>
          <td>2.4</td>
          <td>LDM-4 big, 100 steps</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Perceptual compression tradeoff experiments: single NVIDIA A100</li>
<li>Inpainting model trained on eight V100</li>
<li>Training at least 2.7x faster than pixel-based diffusion at equal parameters</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/CompVis/latent-diffusion">CompVis/latent-diffusion</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained models</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Rombach, R., Blattmann, A., Lorenz, D., Esser, P., &amp; Ommer, B. (2022). High-Resolution Image Synthesis with Latent Diffusion Models. <em>CVPR 2022</em>. <a href="https://arxiv.org/abs/2112.10752">https://arxiv.org/abs/2112.10752</a></p>
<p><strong>Publication</strong>: CVPR 2022</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{rombach2022highresolution,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{High-Resolution Image Synthesis with Latent Diffusion Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj{\&#34;o}rn}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>     = <span style="color:#e6db74">{10684--10695}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2022}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/CompVis/latent-diffusion">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>D3PM: Discrete Denoising Diffusion Probabilistic Models</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/discrete-diffusion-models/</guid><description>D3PMs extend diffusion models to discrete data with structured transition matrices, connecting diffusion to masked language models.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It extends denoising diffusion probabilistic models (DDPMs) from continuous to discrete state-spaces by introducing structured Markov transition matrices for the corruption process. The paper unifies several corruption strategies, draws a formal connection between absorbing-state diffusion and masked language models, and demonstrates competitive results on both image and text generation.</p>
<h2 id="diffusion-beyond-continuous-spaces">Diffusion Beyond Continuous Spaces</h2>
<p>Standard DDPMs operate in continuous state-spaces (e.g., pixel values treated as real numbers) and use Gaussian noise for corruption. Many important data types are inherently discrete: text (tokens from a vocabulary), quantized images (discrete pixel values), molecular structures, and segmentation maps. Prior work by Hoogeboom et al. extended binary diffusion to multinomial diffusion with uniform transition probabilities, but this limits the structure of the corruption process. D3PMs generalize this by allowing arbitrary transition matrices that encode domain-specific inductive biases.</p>
<h2 id="core-innovation-structured-transition-matrices">Core Innovation: Structured Transition Matrices</h2>
<p>D3PMs define a forward corruption process over discrete variables $\mathbf{x} \in \{1, \ldots, K\}^D$ using transition matrices $\mathbf{Q}_t \in \mathbb{R}^{K \times K}$:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_{t-1} \mathbf{Q}_t)$$</p>
<p>where $\mathbf{x}_{t-1}$ is a one-hot row vector. The cumulative transition after $t$ steps is $\overline{\mathbf{Q}}_t = \mathbf{Q}_1 \mathbf{Q}_2 \cdots \mathbf{Q}_t$, giving:</p>
<p>$$q(\mathbf{x}_t | \mathbf{x}_0) = \text{Cat}(\mathbf{x}_t; \mathbf{p} = \mathbf{x}_0 \overline{\mathbf{Q}}_t)$$</p>
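<p>As a minimal sketch of these two formulas, the code below accumulates $\overline{\mathbf{Q}}_t$ as a running matrix product and reads off $q(\mathbf{x}_t | \mathbf{x}_0)$ as one row of it. It uses uniform transition matrices of the form $(1 - \beta_t)\,\mathbf{I} + \beta_t / K$ as the example. The $\beta_t$ values and $K = 4$ are made up for illustration.</p>

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def uniform_q(beta, k):
    """Q with entries (1 - beta) * 1[i == j] + beta / k."""
    return [[(1.0 - beta) * (i == j) + beta / k for j in range(k)] for i in range(k)]


def cumulative(qs):
    """Return [Q_1, Q_1 Q_2, ..., Q_1 ... Q_T]."""
    k = len(qs[0])
    q_bar = [[float(i == j) for j in range(k)] for i in range(k)]
    out = []
    for q in qs:
        q_bar = matmul(q_bar, q)
        out.append(q_bar)
    return out


def marginal(x0_index, q_bar_t):
    """A one-hot row vector times a matrix selects one row of that matrix."""
    return q_bar_t[x0_index]


betas = [0.1, 0.2, 0.3]                       # made-up schedule
q_bars = cumulative([uniform_q(b, 4) for b in betas])
p = marginal(0, q_bars[-1])                   # q(x_3 | x_0 = state 0)
```

<p>For the uniform case the product has a closed form: the probability of keeping the original state is $\prod_t (1 - \beta_t)$ plus an equal share of the remaining mass, which gives $0.504 + 0.496 / 4 = 0.628$ here.</p>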
<p>The paper explores several transition matrix designs:</p>
<p><strong>Uniform diffusion:</strong> $[\mathbf{Q}_t]_{ij} = (1 - \beta_t) \mathbf{1}_{i=j} + \beta_t / K$. With probability $\beta_t$ the state is resampled uniformly from all $K$ states, so the stationary distribution is uniform.</p>
<p><strong>Absorbing state:</strong> $[\mathbf{Q}_t]_{ij} = (1-\beta_t)\mathbf{1}_{i=j\neq m} + \beta_t \mathbf{1}_{j=m,\, i\neq m} + \mathbf{1}_{i=j=m}$. Each token that is not yet absorbed transitions to a designated absorbing state $m$ (e.g., [MASK] for text, gray pixel for images) with probability $\beta_t$ per step, while tokens already at $m$ remain there. This establishes a direct connection to masked language models like BERT.</p>
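<p>A small sketch of the absorbing-state matrix follows, written so that every row is a valid probability distribution: a non-mask state stays put with probability $1 - \beta_t$ and moves to $m$ with probability $\beta_t$, while the mask row is one-hot on $m$. The vocabulary size and $\beta_t$ values are made up for illustration.</p>

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def absorbing_q(beta, k, m):
    """Transition matrix with absorbing state m over k states."""
    q = [[0.0] * k for _ in range(k)]
    for i in range(k):
        if i == m:
            q[i][m] = 1.0            # the mask state never leaves
        else:
            q[i][i] = 1.0 - beta     # keep the token
            q[i][m] = beta           # or replace it with the mask
    return q


# Made-up example: two ordinary tokens (0, 1) plus a mask state m = 2.
q1 = absorbing_q(0.5, k=3, m=2)
q2 = absorbing_q(0.5, k=3, m=2)
q_bar_2 = matmul(q1, q2)             # cumulative transition after two steps
p_masked_after_two = q_bar_2[0][2]   # 1 - (1 - 0.5) * (1 - 0.5)
```

<p>Because mass only ever flows into $m$, the probability that a token is still unmasked after $t$ steps is $\prod_{s \le t} (1 - \beta_s)$, which is what makes the process look like progressive masking.</p>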
<p><strong>Discretized Gaussian:</strong> Transition probabilities decay as a function of the distance $|i-j|$ between states, mimicking Gaussian diffusion on ordinal data like pixel values.</p>
<p><strong>Embedding-based nearest neighbor:</strong> For text, transitions are weighted by proximity in a pretrained word embedding space, so corruption preferentially swaps words with semantically similar ones.</p>
<p><strong>Training objective.</strong> The reverse process $p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t)$ is parameterized by predicting $\tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$ and computing the posterior:</p>
<p>$$p_\theta(\mathbf{x}_{t-1} | \mathbf{x}_t) \propto \sum_{\tilde{\mathbf{x}}_0} q(\mathbf{x}_{t-1} | \mathbf{x}_t, \tilde{\mathbf{x}}_0) \, \tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$$</p>
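<p>The posterior inside the sum follows from Bayes&rsquo; rule and the Markov property: $q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) \propto q(\mathbf{x}_t | \mathbf{x}_{t-1})\, q(\mathbf{x}_{t-1} | \mathbf{x}_0)$. The sketch below computes it for a single categorical variable and then mixes over a predicted $\tilde{p}_\theta(\tilde{\mathbf{x}}_0 | \mathbf{x}_t)$. The matrices and the predicted distribution are made-up values, and the network is replaced by a fixed probability vector.</p>

```python
def uniform_q(beta, k):
    """Q with entries (1 - beta) * 1[i == j] + beta / k."""
    return [[(1.0 - beta) * (i == j) + beta / k for j in range(k)] for i in range(k)]


def posterior(x_t, x0, q_t, q_bar_prev):
    """q(x_{t-1} = j | x_t, x_0) for every j, via Bayes' rule."""
    k = len(q_t)
    unnorm = [q_t[j][x_t] * q_bar_prev[x0][j] for j in range(k)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def reverse_step(x_t, p_x0, q_t, q_bar_prev):
    """p_theta(x_{t-1} | x_t): posterior averaged over the predicted x_0."""
    k = len(q_t)
    mix = [0.0] * k
    for x0, weight in enumerate(p_x0):
        if weight == 0.0:
            continue                  # skip x_0 values with no predicted mass
        post = posterior(x_t, x0, q_t, q_bar_prev)
        for j in range(k):
            mix[j] += weight * post[j]
    total = sum(mix)
    return [v / total for v in mix]


q_t = uniform_q(0.3, 3)               # Q_2 (made-up)
q_bar_prev = uniform_q(0.2, 3)        # Q-bar_1 = Q_1 (made-up)
p_prev = reverse_step(0, [0.7, 0.2, 0.1], q_t, q_bar_prev)
p_one_hot = reverse_step(0, [0.0, 1.0, 0.0], q_t, q_bar_prev)
p_exact = posterior(0, 1, q_t, q_bar_prev)
```

<p>When the prediction is one-hot on a single $\tilde{\mathbf{x}}_0$, the reverse step reduces to the exact forward posterior for that value.</p>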
<p>The combined loss $L_\lambda$ adds an auxiliary cross-entropy term to the variational lower bound (VLB):</p>
<p>$$L_\lambda = L_{\text{VLB}} + \lambda \, L_{\text{CE}}$$</p>
<p>where $L_{\text{CE}}$ is a reweighted cross-entropy loss on the $\mathbf{x}_0$ prediction that stabilizes training and improves sample quality. The VLB decomposes into per-timestep KL divergences between the true and predicted reverse transitions.</p>
<h2 id="experiments-and-results">Experiments and Results</h2>
<p><strong>Image generation (CIFAR-10):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>Loss</th>
          <th>IS</th>
          <th>FID</th>
          <th>NLL (bpd)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM uniform</td>
          <td>$L_{\text{VLB}}$</td>
          <td>5.99</td>
          <td>51.27</td>
          <td>5.08</td>
      </tr>
      <tr>
          <td>D3PM absorbing</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>6.78</td>
          <td>30.97</td>
          <td>4.40</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_{\text{VLB}}$</td>
          <td>7.75</td>
          <td>15.30</td>
          <td>3.97</td>
      </tr>
      <tr>
          <td>D3PM Gauss</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.54</td>
          <td>8.34</td>
          <td>3.98</td>
      </tr>
      <tr>
          <td>D3PM Gauss + logistic</td>
          <td>$L_\lambda$ ($\lambda{=}0.001$)</td>
          <td>8.56</td>
          <td>7.34</td>
          <td>3.44</td>
      </tr>
      <tr>
          <td>DDPM $L_{\text{simple}}$ (continuous)</td>
          <td>&ndash;</td>
          <td>9.46</td>
          <td>3.17</td>
          <td>3.75</td>
      </tr>
  </tbody>
</table>
<p>The best discrete D3PM variant is D3PM Gauss + logistic, which achieves FID 7.34 and NLL 3.44 bpd using the combined $L_\lambda$ loss with a truncated logistic parameterization. The truncated logistic parameterization replaces the standard softmax output with a discretized logistic distribution over pixel values, assigning probability mass to each discrete bin based on a continuous logistic CDF. This provides a smoother output distribution that better captures the ordinal structure of pixel intensities. This variant exceeds the continuous DDPM in log-likelihood (3.44 vs. 3.75 bpd) while approaching its sample quality (FID 7.34 vs. 3.17).</p>
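<p>The idea of assigning bin probabilities from a continuous logistic CDF can be sketched as follows. This is a simplified illustration rather than the paper&rsquo;s exact parameterization: it assumes bins centered on the integers $0, \ldots, K-1$ with edges at $\pm 0.5$, a location and scale that a network would predict (fixed by hand here), and renormalization over the $K$ bins to implement the truncation.</p>

```python
import math


def logistic_cdf(x, mu, s):
    """CDF of a logistic distribution, written to avoid overflow in exp."""
    z = (x - mu) / s
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def truncated_logistic_probs(mu, s, k):
    """Probability of each of k ordinal bins, renormalized to sum to one."""
    masses = [logistic_cdf(i + 0.5, mu, s) - logistic_cdf(i - 0.5, mu, s)
              for i in range(k)]
    total = sum(masses)
    return [m / total for m in masses]


# Hand-picked location and scale standing in for network outputs.
probs = truncated_logistic_probs(mu=100.0, s=4.0, k=256)
mode = max(range(256), key=lambda i: probs[i])
```

<p>Compared with a free softmax over 256 classes, neighboring intensities automatically receive similar probability, which is the ordinal structure the paragraph above refers to.</p>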
<p><strong>Text generation (text8, character-level, 1000 steps):</strong></p>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>bpc</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>D3PM absorbing ($L_\lambda$)</td>
          <td>1.45</td>
      </tr>
      <tr>
          <td>D3PM NN ($L_{\text{VLB}}$)</td>
          <td>1.59</td>
      </tr>
      <tr>
          <td>D3PM uniform</td>
          <td>1.61</td>
      </tr>
      <tr>
          <td>Discrete Flow (Tran et al.)</td>
          <td>1.23</td>
      </tr>
  </tbody>
</table>
<p>Among the D3PM variants and baselines evaluated, D3PM absorbing achieves the best bpc on text8 apart from Discrete Flow (Tran et al., 2019). On LM1B (sentencepiece vocabulary of 8192 tokens), D3PM absorbing achieves a perplexity of 76.9 at 1000 steps, compared to 137.9 for D3PM uniform and 43.6 for a comparable autoregressive transformer, demonstrating that discrete diffusion scales to large vocabularies.</p>
<p><strong>Ablation findings:</strong></p>
<ul>
<li>The auxiliary cross-entropy term in $L_\lambda$ is critical: for D3PM Gauss, it improves FID from 15.30 ($L_{\text{VLB}}$) to 8.34 ($L_\lambda$, $\lambda{=}0.001$). Adding the truncated logistic parameterization further improves FID to 7.34.</li>
<li>Discretized Gaussian transitions outperform both uniform and absorbing-state transitions on CIFAR-10 across all metrics.</li>
<li>For text, the absorbing-state (mask) model outperforms uniform and nearest-neighbor models. Nearest-neighbor diffusion provides only marginal improvement over uniform, a surprising negative result.</li>
<li>The $\mathbf{x}_0$-parameterization ensures the learned reverse distribution has the correct sparsity pattern dictated by the transition matrix $\mathbf{Q}_t$.</li>
</ul>
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>The choice of transition matrix is an important design decision that encodes domain-specific inductive biases. Discretized Gaussian transitions work best for ordinal image data; absorbing-state transitions work best for text.</li>
<li>D3PMs formally unify diffusion models and masked language models: absorbing-state diffusion with a [MASK] token is equivalent to a reweighted BERT-style training objective.</li>
<li>The combined VLB + auxiliary loss ($L_\lambda$) achieves better density estimation (3.44 bpd) than continuous DDPMs (3.75 bpd) while producing competitive samples.</li>
<li>Sample quality (best FID 7.34 for D3PM Gauss + logistic) still lags behind continuous-space DDPMs (FID 3.17) on CIFAR-10, though the gap narrows with structured transitions and the auxiliary loss.</li>
<li>Scaling to very large numbers of categories $K$ requires special techniques (low-rank corruption or matrix exponentials) to manage the $O(K^2 T)$ memory cost of storing transition matrices.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Image generation</td>
          <td>CIFAR-10</td>
          <td>32x32, 256 categories</td>
          <td>Quantized to 256 ordinal values per channel</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>text8</td>
          <td>Character-level</td>
          <td>27 character vocabulary, sequences of length 256</td>
      </tr>
      <tr>
          <td>Text generation</td>
          <td>LM1B</td>
          <td>Subword-level</td>
          <td>Sentencepiece vocabulary of 8192 tokens, sequence length 128</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Noise schedules</strong>: Linear schedule for D3PM Gauss, cosine schedule for D3PM uniform, and a novel mutual information schedule for absorbing and nearest-neighbor models</li>
<li><strong>Reverse parameterization</strong>: $\mathbf{x}_0$-parameterization with posterior computation via Bayes&rsquo; rule</li>
<li><strong>Loss</strong>: $L_{\text{VLB}} + \lambda L_{\text{CE}}$ with $\lambda = 0.001$ for images and $\lambda = 0.01$ for text absorbing models</li>
<li><strong>Scaling</strong>: Low-rank corruption (absorbing, uniform) scales as $O(r^2 T)$; matrix exponentials for nearest-neighbor transitions</li>
</ul>
<h3 id="models">Models</h3>
<ul>
<li><strong>Image models</strong>: Modified U-Net architecture from Ho et al. (2020) adapted for categorical output via softmax over $K$ classes</li>
<li><strong>Text models</strong>: 12-layer <a href="/notes/natural-language-processing/language-models/t5-text-to-text-transfer-transformer/">T5</a>-style transformer encoder with 70M parameters (12 heads, MLP dim 3072, QKV dim 768)</li>
<li><strong>Timesteps</strong>: $T = 1000$ for both images and text, though text models can be evaluated with fewer steps (e.g., 256 or 20)</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>Best D3PM</th>
          <th>Continuous DDPM</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>7.34 (Gauss + logistic)</td>
          <td>3.17</td>
      </tr>
      <tr>
          <td>NLL (bpd)</td>
          <td>CIFAR-10</td>
          <td>3.44 (Gauss + logistic)</td>
          <td>3.75</td>
      </tr>
      <tr>
          <td>BPC</td>
          <td>text8 (char)</td>
          <td>1.45 (absorbing, $L_\lambda$)</td>
          <td>N/A</td>
      </tr>
      <tr>
          <td>Perplexity</td>
          <td>LM1B</td>
          <td>76.9 (absorbing)</td>
          <td>N/A</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>All models trained for 1M steps with batch size 512 on TPUv2 or TPUv3</li>
<li>Text models: 12-layer transformer encoder (T5 architecture), 70M parameters</li>
<li>Image models: Modified U-Net architecture from Ho et al. (2020)</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/google-research/google-research/tree/master/d3pm">google-research/d3pm</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX/Flax implementation for image and text experiments</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Austin, J., Johnson, D. D., Ho, J., Tarlow, D., &amp; van den Berg, R. (2021). Structured Denoising Diffusion Models in Discrete State-Spaces. <em>NeurIPS 2021</em>. <a href="https://arxiv.org/abs/2107.03006">https://arxiv.org/abs/2107.03006</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{austin2021structured,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Structured Denoising Diffusion Models in Discrete State-Spaces}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Austin, Jacob and Johnson, Daniel D. and Ho, Jonathan and Tarlow, Daniel and van den Berg, Rianne}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>Consistency Models: Fast One-Step Diffusion Generation</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</link><pubDate>Sun, 15 Mar 2026 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/consistency-models/</guid><description>Consistency models enable one-step generation by learning to map any point on a diffusion ODE trajectory to its origin, achieving FID 3.55 on CIFAR-10.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper. It proposes consistency models, a new class of generative models designed for fast one-step (or few-step) generation. The models can be trained either by distilling pretrained diffusion models (consistency distillation) or as standalone generative models from scratch (consistency training). The paper provides theoretical analysis of both training modes and achieves FID 3.55 on CIFAR-10 for single-step non-adversarial generation (state of the art at the time of publication).</p>
<h2 id="the-slow-sampling-problem-in-diffusion">The Slow Sampling Problem in Diffusion</h2>
<p>Diffusion models produce high-quality samples but require iterating through many denoising steps (often tens to hundreds), making generation slow compared to GANs or VAEs. Previous approaches to speed up sampling include faster ODE/SDE solvers (DDIM, DPM-Solver) and progressive distillation. These either still require multiple steps or depend on a complex multi-stage distillation pipeline. The goal is a model that can generate high-quality samples in a single forward pass while optionally allowing more steps for better quality.</p>
<h2 id="core-innovation-the-self-consistency-property">Core Innovation: The Self-Consistency Property</h2>
<p>The key idea builds on the Probability Flow (PF) ODE from the score-based SDE framework. The PF ODE describes a deterministic trajectory that converts noise into data, governed by the learned score function. For the VE-SDE parameterization used by EDM (Karras et al., 2022), this takes the form:</p>
<p>$$\frac{d\mathbf{x}_t}{dt} = -t \, s_\phi(\mathbf{x}_t, t)$$</p>
<p>where $s_\phi$ is a pretrained score model. A <strong>consistency function</strong> $f(\mathbf{x}_t, t)$ maps any point on an ODE trajectory to the trajectory&rsquo;s origin $\mathbf{x}_\epsilon$. The defining property is self-consistency:</p>
<p>$$f(\mathbf{x}_t, t) = f(\mathbf{x}_{t'}, t') \quad \text{for all } t, t' \in [\epsilon, T]$$</p>
<p>for any points $\mathbf{x}_t$ and $\mathbf{x}_{t'}$ on the same PF ODE trajectory.</p>
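<p>Self-consistency can be checked numerically on a toy case where everything is known in closed form. The sketch below assumes 1-D Gaussian data with standard deviation $\sigma_{\text{data}}$, so that the noised marginal is $\mathcal{N}(0, \sigma_{\text{data}}^2 + t^2)$ and the exact score is $-x / (\sigma_{\text{data}}^2 + t^2)$. The constants, the Euler integrator, and the closed-form consistency function are part of this illustration, not of the paper&rsquo;s experiments.</p>

```python
import math

SIGMA_DATA = 0.5   # std of the toy 1-D Gaussian data distribution (assumption)
EPS = 0.002        # smallest time on the trajectory
T_MAX = 5.0        # largest time used in this toy example


def score(x, t):
    """Exact score of N(0, SIGMA_DATA^2 + t^2); stands in for s_phi."""
    return -x / (SIGMA_DATA ** 2 + t ** 2)


def consistency_fn(x, t):
    """Closed-form map from (x_t, t) back to x_eps for this Gaussian toy case."""
    return x * math.sqrt((SIGMA_DATA ** 2 + EPS ** 2) / (SIGMA_DATA ** 2 + t ** 2))


def trajectory(x_eps, n_steps=50000):
    """Euler-integrate dx/dt = -t * score(x, t) from EPS up to T_MAX."""
    dt = (T_MAX - EPS) / n_steps
    x, t = x_eps, EPS
    points = [(x, t)]
    for _ in range(n_steps):
        x += dt * (-t * score(x, t))
        t += dt
        points.append((x, t))
    return points


points = trajectory(x_eps=0.3)
# Evaluate f at six points spread along the same trajectory.
origins = [consistency_fn(x, t) for x, t in points[::10000]]
```

<p>Every entry of <code>origins</code> agrees with the starting value 0.3 up to the integration error, which is what the self-consistency property demands of $f$.</p>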
<p><strong>Parameterization.</strong> The model enforces the boundary condition $f(\mathbf{x}_\epsilon, \epsilon) = \mathbf{x}_\epsilon$ using skip connections:</p>
<p>$$f_\theta(\mathbf{x}, t) = c_{\text{skip}}(t) \, \mathbf{x} + c_{\text{out}}(t) \, F_\theta(\mathbf{x}, t)$$</p>
<p>where $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$, ensuring the boundary condition is satisfied by construction.</p>
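<p>One concrete choice of coefficients that satisfies these two conditions is sketched below. The EDM-style formulas and the constants $\sigma_{\text{data}} = 0.5$ and $\epsilon = 0.002$ are assumptions of this sketch rather than values stated above, and <code>network</code> is a hypothetical stand-in for $F_\theta$.</p>

```python
import math

SIGMA_DATA = 0.5   # assumed data standard deviation
EPS = 0.002        # assumed smallest time


def c_skip(t):
    """Weight on the input x; equals 1 at t = EPS."""
    return SIGMA_DATA ** 2 / ((t - EPS) ** 2 + SIGMA_DATA ** 2)


def c_out(t):
    """Weight on the network output; equals 0 at t = EPS."""
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA ** 2 + t ** 2)


def f_theta(x, t, network):
    """Skip-connection parameterization of the consistency model."""
    return [c_skip(t) * xi + c_out(t) * fi for xi, fi in zip(x, network(x, t))]


def network(x, t):
    """Hypothetical F_theta; any function works for the boundary check."""
    return [3.0 * xi - t for xi in x]


x = [0.25, -1.5]
at_boundary = f_theta(x, EPS, network)
away_from_boundary = f_theta(x, 1.0, network)
```

<p>At $t = \epsilon$ the output equals the input regardless of what the network computes, so the boundary condition never has to be learned.</p>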
<p><strong>Consistency Distillation (CD).</strong> Given a pretrained diffusion model, CD trains a consistency model by enforcing self-consistency between adjacent timesteps:</p>
<p>$$\mathcal{L}_{\text{CD}}^N(\theta, \theta^-; \phi) = \mathbb{E}\left[\lambda(t_n) \, d\!\left(f_\theta(\mathbf{x}_{t_{n+1}}, t_{n+1}), \, f_{\theta^-}(\hat{\mathbf{x}}_{t_n}^\phi, t_n)\right)\right]$$</p>
<p>where $\hat{\mathbf{x}}_{t_n}^\phi$ is obtained by running one step of the ODE solver using the pretrained score model, $\theta^-$ is an exponential moving average (EMA) of $\theta$, and $d(\cdot, \cdot)$ is a distance metric. The use of a target network $\theta^-$ (updated via EMA) parallels techniques from deep Q-learning and momentum contrastive learning.</p>
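<p>A single term of the distillation loss can be sketched as below. The sketch makes several simplifying assumptions: a first-order Euler solver for the PF ODE $d\mathbf{x}/dt = -t\, s_\phi(\mathbf{x}, t)$, squared L2 as the distance $d$, a constant weight $\lambda(t_n) = 1$, and a toy Gaussian score with its closed-form consistency function in place of trained networks.</p>

```python
import math

SIGMA_DATA = 0.5   # std of the toy Gaussian data (assumption)
EPS = 0.002        # smallest time (assumption)


def toy_score(x, t):
    """Exact score of N(0, SIGMA_DATA^2 + t^2); stands in for s_phi."""
    return [-xi / (SIGMA_DATA ** 2 + t ** 2) for xi in x]


def toy_consistency(x, t):
    """Closed-form consistency function for the toy Gaussian case."""
    scale = math.sqrt((SIGMA_DATA ** 2 + EPS ** 2) / (SIGMA_DATA ** 2 + t ** 2))
    return [scale * xi for xi in x]


def euler_step(x, t_next, t_cur, score_fn):
    """One Euler step of dx/dt = -t * s(x, t), from t_next down to t_cur."""
    s = score_fn(x, t_next)
    return [xi + (t_cur - t_next) * (-t_next * si) for xi, si in zip(x, s)]


def squared_l2(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def cd_loss(x_next, t_next, t_cur, f_online, f_target, score_fn, weight=1.0):
    """weight * d(f_online(x_{t_{n+1}}, t_{n+1}), f_target(x_hat_{t_n}, t_n))."""
    x_cur_hat = euler_step(x_next, t_next, t_cur, score_fn)
    return weight * squared_l2(f_online(x_next, t_next), f_target(x_cur_hat, t_cur))


def ema_update(target, online, mu):
    """theta_minus = mu * theta_minus + (1 - mu) * theta, per parameter."""
    return [mu * tp + (1.0 - mu) * op for tp, op in zip(target, online)]


x_next = [1.0, -2.0]
loss_exact = cd_loss(x_next, 1.0, 0.9, toy_consistency, toy_consistency, toy_score)
loss_identity = cd_loss(x_next, 1.0, 0.9, lambda x, t: x, lambda x, t: x, toy_score)
new_target = ema_update([1.0], [0.0], mu=0.9)
```

<p>The true consistency function leaves only the small residual caused by the Euler step, whereas a function that ignores $t$ is penalized for mapping the two adjacent points to different outputs.</p>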
<p><strong>Consistency Training (CT).</strong> CT eliminates the need for a pretrained diffusion model. It replaces the ODE solver step with a score estimate derived from the denoising score matching identity:</p>
<p>$$\nabla_{\mathbf{x}_t} \log p_t(\mathbf{x}_t) = \mathbb{E}\left[\frac{\mathbf{x} - \mathbf{x}_t}{t^2} \,\middle|\, \mathbf{x}_t\right]$$</p>
<p>This identity gives an unbiased score estimate from a clean sample $\mathbf{x}$ and its noised version $\mathbf{x}_t$, with no pretrained model required, so the ODE update can be computed from training samples alone. CT therefore trains on pairs $(\mathbf{x}, \mathbf{x} + t\mathbf{z})$ where $\mathbf{z} \sim \mathcal{N}(0, I)$.</p>
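<p>To see why this works, substitute the one-sample estimate $(\mathbf{x} - \mathbf{x}_t)/t^2$ into an Euler step of the PF ODE: starting from $\mathbf{x} + t_{n+1}\mathbf{z}$, the step lands exactly on $\mathbf{x} + t_n\mathbf{z}$ with the same noise draw. The sketch below checks this numerically for a scalar example. The function names and the chosen values are hypothetical.</p>

```python
import random


def ct_pair(x, t_n, t_next, rng):
    """Build the two noisy inputs CT compares, sharing one noise draw z."""
    z = rng.gauss(0.0, 1.0)
    return x + t_next * z, x + t_n * z


def euler_with_sample_score(x_next, x, t_next, t_n):
    """Euler step of dx/dt = -t * score, using the estimate (x - x_t) / t^2."""
    score_estimate = (x - x_next) / t_next**2
    return x_next + (t_n - t_next) * (-t_next * score_estimate)


rng = random.Random(0)
x = 0.7                    # a scalar "data point" for illustration
t_n, t_next = 1.0, 1.5     # two adjacent noise levels
x_next, x_n = ct_pair(x, t_n, t_next, rng)
x_hat = euler_with_sample_score(x_next, x, t_next, t_n)
# x_hat coincides with x_n up to floating-point rounding.
```

<p>This is why the CT loss needs only a data sample and a single Gaussian draw per training pair.</p>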
<p><strong>Theoretical guarantee.</strong> If CD achieves zero loss, the consistency model error is bounded by $O((\Delta t)^p)$ where $\Delta t$ is the maximum timestep gap and $p$ is the order of the ODE solver.</p>
<h2 id="experiments-and-benchmarks">Experiments and Benchmarks</h2>
<p><strong>Datasets:</strong> CIFAR-10 (32x32), ImageNet 64x64, LSUN Bedroom 256x256, LSUN Cat 256x256.</p>
<p><strong>Architecture:</strong> All models use the NCSN++/EDM architecture. CD distills from pretrained EDM models.</p>
<p><strong>Key results for consistency distillation (CD):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>3.55</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>2.93</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>6.20</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>4.70</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>7.80</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>2</td>
          <td>5.22</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>11.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>2</td>
          <td>8.84</td>
      </tr>
  </tbody>
</table>
<p>CD outperforms progressive distillation (PD) across all datasets and sampling steps, with the exception of single-step generation on Bedroom 256x256 where CD with $\ell_2$ slightly underperforms PD with $\ell_2$.</p>
<p><strong>Key results for consistency training (CT):</strong></p>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Steps</th>
          <th>FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>1</td>
          <td>8.70</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>2</td>
          <td>5.83</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>1</td>
          <td>13.0</td>
      </tr>
      <tr>
          <td>ImageNet 64x64</td>
          <td>2</td>
          <td>11.1</td>
      </tr>
      <tr>
          <td>LSUN Bedroom 256</td>
          <td>1</td>
          <td>16.0</td>
      </tr>
      <tr>
          <td>LSUN Cat 256</td>
          <td>1</td>
          <td>20.7</td>
      </tr>
  </tbody>
</table>
<p>CT outperforms existing single-step non-adversarial models (VAEs, normalizing flows), e.g., improving over DC-VAE&rsquo;s FID of 17.90 on CIFAR-10. Samples from CT share structural similarity with EDM samples from the same initial noise, suggesting CT does not suffer from mode collapse.</p>
<p><strong>Zero-shot editing:</strong> Consistency models support colorization, super-resolution, inpainting, stroke-guided generation, interpolation, and denoising at test time without task-specific training, by modifying the multi-step sampling algorithm.</p>
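<p>The multi-step sampler that these editing procedures modify alternates between re-noising and applying the consistency function. A minimal sketch follows, assuming scalar Gaussian data so that the exact consistency function can stand in for a trained model. The intermediate times in <code>taus</code> are arbitrary illustrative values.</p>

```python
import math
import random

SIGMA_DATA, EPS, T = 0.5, 0.002, 80.0   # assumed toy constants


def exact_f(x, t):
    """Closed-form consistency function when the data are N(0, SIGMA_DATA^2)."""
    return x * math.sqrt((SIGMA_DATA**2 + EPS**2) / (SIGMA_DATA**2 + t**2))


def multistep_sample(f, taus, rng):
    """One-step sample from noise at T, then alternate re-noising and denoising."""
    x = f(rng.gauss(0.0, T), T)
    for tau in taus:
        x_tau = x + math.sqrt(tau**2 - EPS**2) * rng.gauss(0.0, 1.0)
        x = f(x_tau, tau)
    return x


rng = random.Random(0)
samples = [multistep_sample(exact_f, [5.0, 0.5], rng) for _ in range(5000)]
sample_std = math.sqrt(sum(s * s for s in samples) / len(samples))
```

<p>With an empty <code>taus</code> list this reduces to single-step generation. In this toy setting the sample standard deviation stays close to <code>SIGMA_DATA</code> regardless of how many extra steps are taken, since the stand-in function is exact.</p>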
<h2 id="findings-and-limitations">Findings and Limitations</h2>
<ul>
<li>Consistency distillation achieved state-of-the-art FID for one-step generation at the time of publication (3.55 on CIFAR-10, 6.20 on ImageNet 64x64).</li>
<li>Multi-step sampling provides a smooth quality-compute tradeoff: more steps yield better FID.</li>
<li>CT produces competitive results without any pretrained diffusion model, making consistency models a standalone generative model family.</li>
<li>The LPIPS distance metric $d(\cdot, \cdot)$ generally outperforms $\ell_1$ and $\ell_2$ for training consistency models.</li>
<li>At higher resolutions (LSUN 256x256), the gap between CD/CT and full EDM sampling widens.</li>
<li>CT currently underperforms CD, suggesting room for improvement in the standalone training paradigm.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Primary benchmark</td>
          <td>CIFAR-10</td>
          <td>32x32, 50K train</td>
          <td>FID on 50K samples</td>
      </tr>
      <tr>
          <td>Scaling benchmark</td>
          <td>ImageNet 64x64</td>
          <td>64x64, 1.28M</td>
          <td>Unconditional generation</td>
      </tr>
      <tr>
          <td>High-res benchmark</td>
          <td>LSUN Bedroom, Cat</td>
          <td>256x256</td>
          <td>Unconditional generation</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>ODE solver for CD</strong>: Euler and Heun (2nd order) solvers on the empirical PF ODE</li>
<li><strong>EMA for target network</strong>: Decay rate $\mu$ scheduled as a function of training step</li>
<li><strong>Schedule functions</strong>: $N$ (number of discretization steps) and $\mu$ (EMA rate) increase over training following specific schedules (see Appendix C of the paper)</li>
<li><strong>Distance metric</strong>: LPIPS performs best; $\ell_2$ and $\ell_1$ also evaluated</li>
</ul>
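<p>The difference between the two solvers can be illustrated on a toy PF ODE. The sketch below assumes scalar Gaussian data, for which the ODE has a closed-form solution, and compares one Euler step with one Heun step. The constants and step endpoints are hypothetical.</p>

```python
import math

SIGMA_DATA = 0.5   # assumed toy data standard deviation


def ode_rhs(x, t):
    """PF ODE right-hand side -t * score(x, t) for data ~ N(0, SIGMA_DATA^2)."""
    return t * x / (SIGMA_DATA**2 + t**2)


def euler_step(x, t_from, t_to):
    """First-order step using only the slope at the starting point."""
    return x + (t_to - t_from) * ode_rhs(x, t_from)


def heun_step(x, t_from, t_to):
    """Second-order step: average the slope at the start and at the Euler prediction."""
    h = t_to - t_from
    slope_start = ode_rhs(x, t_from)
    slope_end = ode_rhs(x + h * slope_start, t_to)
    return x + 0.5 * h * (slope_start + slope_end)


def exact_step(x, t_from, t_to):
    """Closed-form solution, available only because the toy data are Gaussian."""
    return x * math.sqrt((SIGMA_DATA**2 + t_to**2) / (SIGMA_DATA**2 + t_from**2))


x, t_from, t_to = 1.0, 1.0, 0.8
euler_error = abs(euler_step(x, t_from, t_to) - exact_step(x, t_from, t_to))
heun_error = abs(heun_step(x, t_from, t_to) - exact_step(x, t_from, t_to))
```

<p>Heun costs two score evaluations per step instead of one, in exchange for a smaller per-step error.</p>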
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: NCSN++/EDM architecture from Karras et al. (2022)</li>
<li><strong>CD teacher</strong>: Pretrained EDM models</li>
<li><strong>Parameterization</strong>: Skip-connection formulation with $c_{\text{skip}}(t)$ and $c_{\text{out}}(t)$ from EDM</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Dataset</th>
          <th>CD 1-step</th>
          <th>CT 1-step</th>
          <th>EDM (full)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FID</td>
          <td>CIFAR-10</td>
          <td>3.55</td>
          <td>8.70</td>
          <td>2.04</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>ImageNet 64</td>
          <td>6.20</td>
          <td>13.0</td>
          <td>2.44</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Bedroom</td>
          <td>7.80</td>
          <td>16.0</td>
          <td>3.57</td>
      </tr>
      <tr>
          <td>FID</td>
          <td>LSUN Cat</td>
          <td>11.0</td>
          <td>20.7</td>
          <td>6.69</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<ul>
<li>Training details follow EDM conventions</li>
<li>CD and CT use the same batch sizes and learning rate schedules as EDM training</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/openai/consistency_models">openai/consistency_models</a></td>
          <td>Code</td>
          <td>MIT</td>
          <td>Official implementation with pretrained checkpoints</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Dhariwal, P., Chen, M., &amp; Sutskever, I. (2023). Consistency Models. <em>ICML 2023</em>. <a href="https://arxiv.org/abs/2303.01469">https://arxiv.org/abs/2303.01469</a></p>
<p><strong>Publication</strong>: ICML 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2023consistency,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Consistency Models}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Machine Learning}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>    = <span style="color:#e6db74">{202}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2023}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://arxiv.org/abs/2303.01469}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/openai/consistency_models">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-based-generative-modeling-sde/">Score-Based Generative Modeling with SDEs</a></li>
</ul>
]]></content:encoded></item><item><title>Score-Based Generative Modeling with SDEs (Song 2021)</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-based-generative-modeling-sde/</guid><description>Unified SDE framework for score-based generative models, introducing Predictor-Corrector samplers and setting CIFAR-10 records with FID 2.20 and 2.99 bits/dim.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper. It proposes a unified framework that generalizes previous discrete score-based models (SMLD and DDPM) into continuous-time Stochastic Differential Equations (SDEs). The paper introduces algorithms for sampling (Predictor-Corrector) and likelihood computation (Probability Flow ODE), validated by setting new records on CIFAR-10 (FID 2.20, IS 9.89 at the time of publication). It also contains elements of <strong>Systematization</strong> by showing how existing methods are special cases of this broader framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>Prior successful generative models, specifically Score Matching with Langevin Dynamics (SMLD) and Denoising Diffusion Probabilistic Models (DDPM), operate by sequentially corrupting data with slowly increasing noise and learning to reverse the process. Both methods treat the noise scales as a finite set of discrete steps. The authors aim to generalize this to a continuum of noise scales by modeling the diffusion process as a Stochastic Differential Equation (SDE). This continuous formulation enables:</p>
<ul>
<li><strong>Flexible sampling:</strong> Use of general-purpose SDE solvers.</li>
<li><strong>Exact likelihood computation:</strong> Via connection to Neural ODEs.</li>
<li><strong>Controllable generation:</strong> Solving inverse problems (inpainting, colorization) without retraining.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>SDE framework</strong> for score-based generative modeling:</p>
<ul>
<li><strong>Continuous Generalization:</strong> Proving that SMLD and DDPM noise perturbations correspond to discretizations of Variance Exploding (VE) SDEs and Variance Preserving (VP) SDEs, respectively.</li>
<li><strong>Reverse-Time SDE:</strong> Leveraging Anderson&rsquo;s (1982) result on the time reversal of diffusion processes: the reverse of a diffusion is itself a diffusion, with the forward drift reversed and a correction term involving the score (the gradient of the log marginal density).</li>
<li><strong>Predictor-Corrector (PC) Samplers:</strong> A hybrid sampling strategy where a numerical SDE solver (Predictor) estimates the next step, and a score-based MCMC approach (Corrector) corrects the marginal distribution.</li>
<li><strong>Probability Flow ODE:</strong> Deriving a deterministic ODE that shares the same marginal densities as the SDE, enabling near-exact likelihood computation (accuracy is limited by both numerical ODE solver discretization and variance of the unbiased Hutchinson trace estimator) and latent space manipulation.</li>
<li><strong>Sub-VP SDE:</strong> A new SDE class whose variance is bounded above by that of the corresponding VP SDE at every time, proposed to improve likelihoods.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the framework on standard image benchmarks:</p>
<ul>
<li><strong>Datasets:</strong> CIFAR-10 (32x32), CelebA (64x64), LSUN (Bedroom, Church), and CelebA-HQ (256x256 and 1024x1024).</li>
<li><strong>Ablation Studies:</strong> Comparing samplers (Ancestral vs. Reverse Diffusion vs. Probability Flow vs. PC) and SDE types (VE, VP, sub-VP).</li>
<li><strong>Architecture Search:</strong> Exploring improvements like FIR up/downsampling, rescaling skip connections, and increasing depth (leading to NCSN++ and DDPM++ architectures).</li>
<li><strong>Likelihood Evaluation:</strong> Computing Negative Log-Likelihood (NLL) in bits/dim using the Probability Flow ODE.</li>
<li><strong>Inverse Problems:</strong> Testing class-conditional generation, inpainting, and colorization using the conditional reverse-time SDE.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Record Performance:</strong> The <strong>NCSN++ cont. (deep, VE)</strong> model achieved an Inception Score of 9.89 and FID of 2.20 on CIFAR-10 (as of ICLR 2021).</li>
<li><strong>High-Fidelity Generation:</strong> First score-based model to generate 1024x1024 images (CelebA-HQ).</li>
<li><strong>Competitive Likelihoods:</strong> The <strong>DDPM++ cont. (deep, sub-VP)</strong> model achieved 2.99 bits/dim on uniformly dequantized CIFAR-10, a record at the time.</li>
<li><strong>Sampling Efficiency:</strong> PC samplers consistently outperformed predictor-only methods (like standard ancestral sampling) for the same computational cost.</li>
<li><strong>Controllable Generation:</strong> Successful application to inpainting and colorization using a single unconditional model.</li>
<li><strong>Limitations:</strong> Sampling remains slower than GANs on the same datasets. The breadth of available samplers introduces many hyperparameters (SDE type, predictor, corrector, signal-to-noise ratio, number of steps) that require tuning.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>CIFAR-10</strong>: Used for main benchmarking (FID, Inception Score, NLL).</li>
<li><strong>CelebA-HQ</strong>: Used for high-resolution experiments at 256x256 and 1024x1024.</li>
<li><strong>LSUN</strong>: Bedroom and Church Outdoor categories (256x256) used for sampler comparison and controllable generation (inpainting, colorization).</li>
<li><strong>Preprocessing</strong>: CIFAR-10 images are 32x32; CelebA pre-processed to 64x64 following Song &amp; Ermon (2020). Data is typically scaled to $[0, 1]$ or standardized depending on the specific SDE config.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>Forward SDEs</strong>:</p>
<p>Here $dw$ denotes a Wiener process increment (a small, independent Gaussian noise burst at each timestep).</p>
<ul>
<li><strong>VE SDE (Variance Exploding)</strong>: $dx = \sqrt{\frac{d[\sigma^2(t)]}{dt}} dw$. Corresponds to SMLD. Used with $\sigma_{\min}=0.01$ and $\sigma_{\max}$ chosen via heuristics.</li>
<li><strong>VP SDE (Variance Preserving)</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)} dw$. Corresponds to DDPM.</li>
<li><strong>Sub-VP SDE</strong>: $dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)(1 - e^{-2\int_0^t \beta(s)ds})} dw$. Bounded variance, good for likelihoods.</li>
</ul>
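<p>Each of these SDEs has a Gaussian perturbation kernel, so noising a clean sample only requires a mean scale and a standard deviation at time $t$. The sketch below assumes a geometric $\sigma(t)$ for the VE SDE and a linear $\beta(t)$ for the VP and sub-VP SDEs. The constants <code>SIGMA_MAX</code>, <code>BETA_MIN</code>, and <code>BETA_MAX</code> are assumed values, and the VE standard deviation ignores the small $\sigma(0)$ offset.</p>

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # VE noise range (SIGMA_MAX is an assumed value)
BETA_MIN, BETA_MAX = 0.1, 20.0      # assumed linear beta(t) endpoints for VP / sub-VP


def ve_std(t):
    """VE: x_t = x_0 + sigma(t) * z with a geometric sigma(t); the mean is unchanged."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def beta_integral(t):
    """Integral of beta(s) = BETA_MIN + s * (BETA_MAX - BETA_MIN) from 0 to t."""
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def vp_mean_scale(t):
    """VP and sub-VP both shrink the clean signal by exp(-0.5 * integral of beta)."""
    return math.exp(-0.5 * beta_integral(t))


def vp_std(t):
    """VP: variance 1 - exp(-integral of beta)."""
    return math.sqrt(1.0 - math.exp(-beta_integral(t)))


def sub_vp_std(t):
    """Sub-VP: variance (1 - exp(-integral of beta))^2, never above the VP variance."""
    return 1.0 - math.exp(-beta_integral(t))


table = [(t, ve_std(t), vp_std(t), sub_vp_std(t)) for t in (0.1, 0.5, 1.0)]
```

<p>The VE noise level grows without bound in $\sigma_{\max}$, while the VP and sub-VP standard deviations stay at or below 1, with the sub-VP value always the smaller of the two.</p>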
<p><strong>Reverse-Time SDE Solver (Predictor)</strong>:</p>
<ul>
<li>Discretized via <strong>Reverse Diffusion Sampling</strong>, which matches the forward discretization.</li>
<li><strong>Euler-Maruyama</strong> solver used for continuously-trained models.</li>
</ul>
<p><strong>Corrector Algorithm</strong>:</p>
<ul>
<li><strong>Langevin MCMC</strong>: Applies annealed Langevin dynamics: adds noise and takes a score-guided gradient step to correct the marginal distribution at each timestep.</li>
<li><strong>PC Sampling</strong>: Alternates between one step of the Predictor and one step of the Corrector.</li>
<li><strong>Signal-to-Noise Ratio ($r$)</strong>: A hyperparameter for the corrector step size. Tuned values: $r \approx 0.16$ for VE SDEs on CIFAR-10.</li>
</ul>
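<p>The predictor and corrector steps can be sketched together for the VE SDE. The example below assumes toy data drawn from an isotropic Gaussian so that the exact score can stand in for a trained $s_\theta$. The dimension, noise range, and number of steps are illustrative values. Only the signal-to-noise ratio $r = 0.16$ comes from the text above.</p>

```python
import math
import random

DIM, MEAN, DATA_STD = 16, 2.0, 0.5          # toy data: N(MEAN, DATA_STD^2 I)
SIGMA_MIN, SIGMA_MAX, N_STEPS, SNR = 0.01, 10.0, 200, 0.16


def score(x, sigma):
    """Exact score of the noised toy data, standing in for a trained s_theta."""
    var = DATA_STD**2 + sigma**2
    return [-(xi - MEAN) / var for xi in x]


def norm(v):
    return math.sqrt(sum(vi * vi for vi in v))


def pc_sample(rng):
    sigmas = [SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** (i / (N_STEPS - 1))
              for i in range(N_STEPS)]
    x = [rng.gauss(0.0, SIGMA_MAX) for _ in range(DIM)]
    for i in range(N_STEPS - 1, 0, -1):
        # Predictor: reverse diffusion step for the VE SDE, sigma_i to sigma_{i-1}.
        gap = sigmas[i] ** 2 - sigmas[i - 1] ** 2
        g = score(x, sigmas[i])
        x = [xi + gap * gi + math.sqrt(gap) * rng.gauss(0.0, 1.0)
             for xi, gi in zip(x, g)]
        # Corrector: one Langevin step with step size set by the SNR r.
        z = [rng.gauss(0.0, 1.0) for _ in range(DIM)]
        g = score(x, sigmas[i - 1])
        step = 2.0 * (SNR * norm(z) / norm(g)) ** 2
        x = [xi + step * gi + math.sqrt(2.0 * step) * zi
             for xi, gi, zi in zip(x, g, z)]
    return x


rng = random.Random(0)
samples = [pc_sample(rng) for _ in range(50)]
flat = [v for s in samples for v in s]
sample_mean = sum(flat) / len(flat)
```

<p>Setting the step size from the ratio of noise norm to score norm keeps the corrector&rsquo;s update proportionate to the current noise level, so the same $r$ can be reused across all timesteps.</p>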
<h3 id="models">Models</h3>
<ul>
<li><strong>NCSN++</strong>: Optimized architecture for VE SDEs. Key features:
<ul>
<li>4 residual blocks per resolution.</li>
<li>BigGAN-type residual blocks.</li>
<li>Rescaling skip connections by $1/\sqrt{2}$.</li>
<li>FIR (Finite Impulse Response) up/downsampling.</li>
<li>&ldquo;Residual&rdquo; progressive architecture for input, no progressive growing for output.</li>
</ul>
</li>
<li><strong>DDPM++</strong>: Optimized architecture for VP/sub-VP SDEs. Similar to NCSN++ but without FIR upsampling or progressive growing.</li>
<li><strong>Deep Variants</strong>: &ldquo;cont. (deep)&rdquo; models double the depth (from 4 to 8 blocks per resolution) for the best reported results.</li>
<li><strong>Conditioning</strong>: Continuous-time models are conditioned on time $t$ via random Fourier feature embeddings (scale 16).</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<p><strong>Metrics</strong>:</p>
<ul>
<li><strong>FID (Fréchet Inception Distance)</strong>: Computed on 50k samples.</li>
<li><strong>Inception Score</strong>: Reported for CIFAR-10.</li>
<li><strong>NLL (Negative Log-Likelihood)</strong>: Reported in bits/dim on uniformly dequantized data using the Probability Flow ODE.</li>
</ul>
<p><strong>Denoising</strong>: A single denoising step using Tweedie&rsquo;s formula is applied at the end of sampling to remove residual noise, which significantly improves FID.</p>
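<p>Tweedie&rsquo;s formula gives the posterior mean of the clean sample as $\mathbf{x} + \sigma^2 \nabla_{\mathbf{x}} \log p_\sigma(\mathbf{x})$. The sketch below assumes scalar Gaussian data so the exact score can stand in for the trained model, and uses a deliberately large noise level so the effect is easy to see. All constants are hypothetical.</p>

```python
import random

DATA_MEAN, DATA_STD, SIGMA = 2.0, 0.5, 0.3   # assumed toy constants


def score(x, sigma):
    """Exact score of N(DATA_MEAN, DATA_STD^2 + sigma^2), standing in for s_theta."""
    return -(x - DATA_MEAN) / (DATA_STD**2 + sigma**2)


def tweedie_denoise(x, sigma):
    """Posterior mean E[x_0 | x] = x + sigma^2 * score(x, sigma)."""
    return x + sigma**2 * score(x, sigma)


def mse(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)


rng = random.Random(0)
clean = [rng.gauss(DATA_MEAN, DATA_STD) for _ in range(5000)]
noisy = [c + SIGMA * rng.gauss(0.0, 1.0) for c in clean]
denoised = [tweedie_denoise(x, SIGMA) for x in noisy]

mse_noisy = mse(noisy, clean)
mse_denoised = mse(denoised, clean)
```

<p>The denoised samples sit closer to the clean ones in mean squared error, at the cost of one extra score evaluation per sample.</p>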
<h3 id="hardware">Hardware</h3>
<p><strong>Training</strong>:</p>
<ul>
<li>Batch size: 128 for CIFAR-10, 64 for LSUN, 8 for high-res CelebA-HQ.</li>
<li>Iterations: Discrete-objective models trained for 1.3M iterations during architecture exploration. Continuous-objective models (cont.) trained for 0.95M iterations. High-res CelebA-HQ (1024x1024) trained for approximately 2.4M iterations.</li>
<li><strong>EMA</strong>: Exponential Moving Average rate of 0.999 used for VE models, 0.9999 for VP models.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yang-song/score_sde">yang-song/score_sde</a></td>
          <td>Code</td>
          <td>Apache-2.0</td>
          <td>Official JAX and PyTorch implementation with pretrained checkpoints</td>
      </tr>
  </tbody>
</table>
<p>All datasets used (CIFAR-10, CelebA-HQ, LSUN) are publicly available. Pretrained model checkpoints for CIFAR-10, CelebA-HQ, and FFHQ are provided in the repository. Specific hardware requirements (GPU type, training time) are not detailed in the paper.</p>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., &amp; Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. <em>ICLR 2021</em>. <a href="https://arxiv.org/abs/2011.13456">https://arxiv.org/abs/2011.13456</a></p>
<p><strong>Publication</strong>: ICLR 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{song2021scorebased,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>     = <span style="color:#e6db74">{Score-Based Generative Modeling through Stochastic Differential Equations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>    = <span style="color:#e6db74">{Song, Yang and Sohl-Dickstein, Jascha and Kingma, Diederik P and Kumar, Abhishek and Ermon, Stefano and Poole, Ben}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>      = <span style="color:#e6db74">{2021}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>       = <span style="color:#e6db74">{https://openreview.net/forum?id=PxTIG12RRHS}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/yang-song/score_sde">GitHub Repository</a></li>
<li><a href="/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/">Score Matching and Denoising Autoencoders</a></li>
</ul>
]]></content:encoded></item><item><title>Score Matching and Denoising Autoencoders: A Connection</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/score-matching-denoising-autoencoders/</guid><description>Theoretical paper proving the equivalence between training Denoising Autoencoders and performing Score Matching on a Parzen density estimator.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Theory Paper</strong>.</p>
<p>Its primary contribution is a formal mathematical derivation connecting two previously distinct techniques: Score Matching (SM) and Denoising Autoencoders (DAE). It provides the &ldquo;why&rdquo; behind the empirical success of DAEs by grounding them in the probabilistic framework of energy-based models. It relies on proofs and equivalence relations (e.g., $J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$).</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper bridges a gap between two successful but disconnected approaches in unsupervised learning:</p>
<ol>
<li><strong>Denoising Autoencoders (DAE):</strong> Empirically successful for pre-training deep networks. They previously lacked a clear probabilistic interpretation.</li>
<li><strong>Score Matching (SM):</strong> A theoretically sound method for estimating unnormalized density models that avoids the partition function problem but requires computing expensive second derivatives.</li>
</ol>
<p>By connecting them, the authors aim to define a proper probabilistic model for DAEs (allowing sampling/ranking) and find a simpler way to apply score matching that avoids second derivatives.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Denoising Score Matching (DSM)</strong> framework and the proof of its equivalence to DAEs. Key contributions include:</p>
<ul>
<li><strong>Equivalence Proof:</strong> Showing that training a DAE with Gaussian noise is equivalent to matching the score of a model against a non-parametric Parzen density estimator of the data.</li>
<li><strong>Denoising Score Matching ($J_{DSM}$):</strong> A new objective that learns a score function by trying to denoise corrupted samples. This avoids the explicit second derivatives required by standard Implicit Score Matching ($J_{ISM}$).</li>
<li><strong>Explicit Energy Function:</strong> Deriving the specific energy function $E(x;\theta)$ that corresponds to the standard sigmoid DAE architecture.</li>
<li><strong>Justification for Tied Weights:</strong> Providing a theoretical justification for tying encoder and decoder weights, which arises naturally from differentiating the energy function.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The validation in this theoretical paper is purely mathematical and focuses on formal proofs:</p>
<ul>
<li><strong>Derivation of Equivalence:</strong> The paper formally proves the chain of equivalences:
$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$
where $q_{\sigma}$ is the Parzen density estimate.</li>
<li><strong>Appendix Proof:</strong> A detailed proof is provided to show that Explicit Score Matching ($J_{ESM}$) on the Parzen density is equivalent to the proposed Denoising Score Matching ($J_{DSM}$) objective.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Theoretical Unification:</strong> DAE training is formally equivalent to Score Matching on a smoothed data distribution ($q_{\sigma}$).</li>
<li><strong>New Training Objective:</strong> The $J_{DSM}$ objective offers a computationally efficient way to perform score matching (no Hessian required) by using a denoising objective.</li>
<li><strong>Probabilistic Interpretation:</strong> DAEs can now be understood as Energy-Based Models (EBMs), allowing for operations like sampling (via Hybrid Monte Carlo) and likelihood ranking, which were previously ill-defined for standard autoencoders.</li>
<li><strong>Regularization Insight:</strong> The smoothing kernel width $\sigma$ in the Parzen estimator corresponds to the noise level in the DAE. This suggests that DAEs are learning a regularized version of the score, which may explain their robustness.</li>
<li><strong>Connection to Regularized Score Matching:</strong> The paper notes that Kingma and LeCun (2010) independently proposed a regularized score matching criterion $J_{ISMreg}$ derived by approximating $J_{ISMq_{\sigma}}$. The four $q_{\sigma}$-based objectives in this work (including the DAE objective) can be seen as approximation-free forms of regularized score matching, with the additional advantage that $J_{DSMq_{\sigma}}$ does not require second derivatives.</li>
</ul>
<hr>
<h2 id="key-concepts-explained">Key Concepts Explained</h2>
<h3 id="1-score-and-score-matching">1. &ldquo;Score&rdquo; and &ldquo;Score Matching&rdquo;</h3>
<p><strong>What does &ldquo;score&rdquo; actually mean?</strong></p>
<p>In this paper (and probabilistic modeling generally), the <strong>score</strong> is the gradient of the log-density with respect to the <em>data vector</em> $x$.</p>
<ul>
<li><strong>Definition:</strong> $\psi(x) = \nabla_x \log p(x)$.</li>
<li><strong>Intuition:</strong> It is a vector field pointing in the direction of highest probability increase. Crucially, calculating the score avoids the intractable partition function $Z$, because $\nabla_x \log p(x) = \nabla_x \log \tilde{p}(x) - \nabla_x \log Z = \nabla_x \log \tilde{p}(x)$. The constant $Z$ vanishes upon differentiation.</li>
</ul>
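<p>The cancellation of $Z$ is easy to confirm numerically. The sketch below assumes a one-dimensional Gaussian-shaped density whose partition function is known, and compares finite-difference scores of the unnormalized and normalized log-densities. The function names are hypothetical.</p>

```python
import math


def log_unnormalized(x):
    """log of the unnormalized density exp(-(x - 1)^2 / 2)."""
    return -0.5 * (x - 1.0) ** 2


def log_normalized(x):
    """Same density with its partition function Z = sqrt(2 * pi) included."""
    return log_unnormalized(x) - math.log(math.sqrt(2.0 * math.pi))


def numerical_score(log_density, x, h=1e-5):
    """Central-difference estimate of d/dx log p(x)."""
    return (log_density(x + h) - log_density(x - h)) / (2.0 * h)


x = 0.3
score_unnormalized = numerical_score(log_unnormalized, x)
score_normalized = numerical_score(log_normalized, x)
analytic_score = -(x - 1.0)   # derivative of the log-density by hand
```

<p>Both estimates agree with the analytic value, since subtracting the constant $\log Z$ shifts the log-density without changing its slope.</p>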
<p><strong>What is Score Matching?</strong></p>
<p>Score Matching is a training objective for unnormalized models. It minimizes the squared Euclidean distance between the model&rsquo;s score $\psi(x;\theta)$ and the data&rsquo;s true score $\nabla_x \log q(x)$.</p>
<h3 id="2-the-parzen-density-estimator">2. The Parzen Density Estimator</h3>
<p><strong>What is it?</strong></p>
<p>It is a non-parametric method for estimating a probability density function from finite data. It places a smooth kernel (here, a Gaussian) centered at every data point in the training set $D_n$.</p>
<ul>
<li><strong>Formula:</strong> $q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n \mathcal{N}(\tilde{x}; x^{(t)}, \sigma^2 I)$.</li>
</ul>
<p><strong>Why smooth the data?</strong></p>
<ol>
<li>
<p><strong>To define the score:</strong> The empirical data distribution is a set of Dirac deltas (spikes). The gradient (score) of a Dirac delta is undefined. Smoothing creates a differentiable surface, allowing a valid target score $\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x})$ to be computed.</p>
</li>
<li>
<p><strong>To model corruption:</strong> The Parzen estimator with Gaussian kernels mathematically models the process of taking a clean data point $x$ and adding Gaussian noise, which is exactly the procedure used in Denoising Autoencoders.</p>
</li>
</ol>
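<p>Both the Parzen density and its score have simple closed forms. The sketch below assumes a small one-dimensional dataset and checks the analytic score against a finite-difference derivative of $\log q_{\sigma}$. The data values and kernel width are hypothetical.</p>

```python
import math


def gaussian_pdf(x, mean, sigma):
    return math.exp(-0.5 * ((x - mean) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


def parzen_density(x_tilde, data, sigma):
    """Average of Gaussians centred on each training point."""
    return sum(gaussian_pdf(x_tilde, x, sigma) for x in data) / len(data)


def parzen_score(x_tilde, data, sigma):
    """Derivative of log q_sigma: a responsibility-weighted pull toward the data."""
    weights = [gaussian_pdf(x_tilde, x, sigma) for x in data]
    total = sum(weights)
    return sum(w * (x - x_tilde) for w, x in zip(weights, data)) / (total * sigma**2)


data = [-1.0, 0.2, 0.5, 2.0]   # toy 1-D training set
sigma = 0.4
x_tilde = 0.9
h = 1e-5
finite_difference = (math.log(parzen_density(x_tilde + h, data, sigma))
                     - math.log(parzen_density(x_tilde - h, data, sigma))) / (2.0 * h)
analytic = parzen_score(x_tilde, data, sigma)
```

<p>The score at any point is a weighted average of the vectors pointing from that point to each training example, with nearby examples weighted most heavily.</p>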
<h3 id="3-why-avoiding-second-derivatives-matters">3. Why avoiding second derivatives matters</h3>
<p>Standard <strong>Implicit Score Matching (ISM)</strong> eliminates the need for the unknown data score, but introduces a new cost: it requires computing the trace of the Hessian (the sum of second partial derivatives) of the log-density.</p>
<ul>
<li><strong>The Cost:</strong> For high-dimensional data (like images) and deep networks, computing second derivatives of the log-density is computationally expensive.</li>
<li>This paper shows that <strong>Denoising Score Matching (DSM)</strong> allows you to bypass Hessian computation entirely. By using the Parzen target, the objective simplifies to matching a first-order vector, making it scalable to deep neural networks.</li>
</ul>
<h3 id="4-the-equivalence-chain---why-each-step">4. The equivalence chain - why each step?</h3>
<p>The chain $J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ connects the concepts.</p>
<ul>
<li>
<p><strong>$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}}$ (Implicit $\to$ Explicit):</strong>
<strong>Why:</strong> Integration by parts, following Hyvärinen&rsquo;s original proof (2005). The explicit objective contains a cross term that pairs the model score $\psi$ with the score of the density $q_{\sigma}$. Integrating by parts moves the derivative from $q_{\sigma}$ onto $\psi$, trading the density&rsquo;s score for the trace of the Jacobian of $\psi$. The boundary term vanishes because $q_{\sigma}$ decays to zero at infinity (one of Hyvärinen&rsquo;s regularity conditions). The implicit form therefore involves only the model&rsquo;s score and its Jacobian, not the score of the target density.</p>
</li>
<li>
<p><strong>$J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}}$ (Explicit $\to$ Denoising):</strong>
<strong>Why:</strong> The score of the Gaussian corruption kernel is known in closed form. When $x$ is perturbed to $\tilde{x} = x + \epsilon$ with $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$, the gradient of the conditional log-density $\log q_{\sigma}(\tilde{x} \mid x)$ with respect to $\tilde{x}$ is exactly $\frac{1}{\sigma^2}(x - \tilde{x})$, a vector pointing from the corrupted point back to the clean one. Minimizing the error against the Parzen score is then equivalent to minimizing the error against this restoration vector.</p>
</li>
<li>
<p><strong>$J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$ (Denoising $\to$ Autoencoder):</strong>
<strong>Why:</strong> Algebraic substitution. If you define the model&rsquo;s score $\psi(\tilde{x};\theta)$ to be proportional to the reconstruction error ($\propto x^r - \tilde{x}$), the score matching loss $J_{DSM}$ becomes proportional to the standard autoencoder squared loss $\|x^r - x\|^2$.</p>
</li>
</ul>
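<p>The middle link of the chain can be checked numerically: the explicit and denoising objectives should differ only by a constant that does not depend on the model parameters. The sketch below assumes a one-dimensional toy dataset, a hypothetical one-parameter score model $\psi(\tilde{x};\theta) = -\theta\tilde{x}$, and a plain Riemann sum over $\tilde{x}$ in place of the expectations.</p>

```python
import math


def gaussian_pdf(x, mean, sigma):
    return math.exp(-0.5 * ((x - mean) / sigma) ** 2) / (sigma * math.sqrt(2.0 * math.pi))


DATA = [-1.0, 0.2, 0.5, 2.0]                    # toy 1-D training set
SIGMA = 0.4
DX = 0.01
GRID = [-6.0 + DX * i for i in range(1401)]     # quadrature nodes for x~


def model_score(x_tilde, theta):
    """Hypothetical one-parameter score model psi(x~; theta) = -theta * x~."""
    return -theta * x_tilde


def j_esm(theta):
    """Explicit score matching against the Parzen score, by a Riemann sum."""
    total = 0.0
    for xt in GRID:
        kernels = [gaussian_pdf(xt, x, SIGMA) for x in DATA]
        q = sum(kernels) / len(DATA)
        target = sum(k * (x - xt) for k, x in zip(kernels, DATA)) / (sum(kernels) * SIGMA**2)
        total += q * 0.5 * (model_score(xt, theta) - target) ** 2 * DX
    return total


def j_dsm(theta):
    """Denoising score matching: per-pair target (x - x~) / sigma^2."""
    total = 0.0
    for xt in GRID:
        for x in DATA:
            target = (x - xt) / SIGMA**2
            weight = gaussian_pdf(xt, x, SIGMA) / len(DATA)
            total += weight * 0.5 * (model_score(xt, theta) - target) ** 2 * DX
    return total


# The gap between the two objectives should be the same for every theta.
gaps = [j_esm(theta) - j_dsm(theta) for theta in (0.5, 1.0, 3.0)]
```

<p>Since the gap is constant in $\theta$, both objectives share the same minimizer, which is what allows the cheap per-pair target to replace the Parzen score during training.</p>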
<h3 id="5-energy-based-models-ebms-connection">5. Energy-Based Models (EBMs) connection</h3>
<p><strong>What is an EBM?</strong></p>
<p>An EBM defines a probability distribution via an energy function $E(x;\theta)$, where $p(x;\theta) \propto e^{-E(x;\theta)}$.</p>
<p><strong>Why standard autoencoders lack probabilistic interpretation:</strong></p>
<p>A standard autoencoder acts as a deterministic map $x \to x^r$, providing a reconstruction error. It lacks a normalization constant or a defined density function to support sampling or probability queries.</p>
<p><strong>What does this enable?</strong></p>
<p>By proving the equivalence, the DAE is formally defined as an EBM. This enables:</p>
<ol>
<li><strong>Sampling:</strong> Using MCMC methods (like Hybrid Monte Carlo) to generate new data from the DAE.</li>
<li><strong>Ranking:</strong> Calculating the energy of inputs to determine which are more &ldquo;likely&rdquo; or &ldquo;normal&rdquo; (useful for anomaly detection).</li>
</ol>
<h3 id="6-the-specific-energy-function-form">6. The specific energy function form</h3>
<p>The function is:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p><strong>Why does it have that specific form?</strong></p>
<p>It was derived via integration to ensure its derivative matches the DAE architecture. The authors worked backward from the DAE&rsquo;s reconstruction function (sigmoid + linear) to find the scalar field that generates it.</p>
<p><strong>Where does the quadratic term come from?</strong></p>
<p>The score (negative energy gradient) needs to look like $\psi(x) \propto c - x + W^T\text{sigmoid}(Wx + b)$.</p>
<ul>
<li>The term $-x$ in the score arises because $\nabla_x(-\frac{1}{2}|x|^2) = -x$. Including $-\frac{1}{2}|x|^2$ inside the parentheses of the energy produces this linear term after differentiation.</li>
</ul>
<p><strong>How does differentiating it recover the DAE reconstruction?</strong></p>
<ul>
<li>$\nabla_x \sum_j \text{softplus}(\langle W_j, x \rangle + b_j) = W^T \text{sigmoid}(Wx + b) = W^T h$ (the encoder output passed through the tied decoder; the derivative of softplus is the sigmoid).</li>
<li>$\nabla_x \langle c, x \rangle = c$ (The bias).</li>
<li>$\nabla_x (-\frac{1}{2}|x|^2) = -x$ (The input subtraction).</li>
<li>Result: $-\nabla_x E \propto c + W^T h - x = x^r - x$.</li>
</ul>
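<p>The differentiation steps above can be checked numerically. The sketch below, under assumed toy dimensions and random parameters, compares a finite-difference estimate of $-\nabla_x E$ with $\frac{1}{\sigma^2}(x^r - x)$ for the tied-weight reconstruction. The function names are hypothetical and chosen only for this illustration.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def softplus(a):
    # log(1 + exp(a)), computed in a numerically stable way
    return np.logaddexp(0.0, a)


def energy(x, W, b, c, sigma):
    """Energy E(x; W, b, c) with the 1 / sigma^2 scaling."""
    inner = c @ x - 0.5 * (x @ x) + np.sum(softplus(W @ x + b))
    return -inner / sigma**2


def reconstruct(x, W, b, c):
    """Tied-weight reconstruction x^r = W^T sigmoid(W x + b) + c."""
    return W.T @ sigmoid(W @ x + b) + c


def finite_difference_grad(f, z, h=1e-5):
    """Central-difference gradient of a scalar function f at the point z."""
    grad = np.zeros_like(z)
    for i in range(z.shape[0]):
        step = np.zeros_like(z)
        step[i] = h
        grad[i] = (f(z + step) - f(z - step)) / (2.0 * h)
    return grad


rng = np.random.default_rng(1)
d, d_h, sigma = 5, 3, 0.5                  # assumed toy sizes
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
c = rng.normal(size=d)
x = rng.normal(size=d)

neg_grad_energy = -finite_difference_grad(lambda z: energy(z, W, b, c, sigma), x)
scaled_residual = (reconstruct(x, W, b, c) - x) / sigma**2
energy_gap = float(np.max(np.abs(neg_grad_energy - scaled_residual)))
print(energy_gap)
```

<p>A small <code>energy_gap</code> indicates that the score of this energy is the scaled reconstruction residual, with the proportionality constant $1/\sigma^2$.</p>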
<h3 id="7-tied-weights-justification">7. &ldquo;Tied weights&rdquo; justification</h3>
<p><strong>What does it mean for weights to be &ldquo;tied&rdquo;?</strong></p>
<p>The decoder matrix is the transpose of the encoder matrix ($W^T$).</p>
<p><strong>Why is this theoretically justified?</strong></p>
<p>Because the reconstruction function is interpreted as the <strong>gradient</strong> of an energy function. A vector field can only be the gradient of a scalar field if its Jacobian is symmetric.</p>
<ul>
<li>In the DAE energy derivative, the encoder contributes $W^T \text{sigmoid}(Wx + b)$. If the decoder used a separate matrix $U$, the resulting vector field would not be a valid gradient of any scalar energy function (unless $U = W^T$).</li>
<li>Therefore, for a DAE to correspond to a valid probabilistic Energy-Based Model, the weights <em>must</em> be tied.</li>
</ul>
<p><strong>The necessity of tied weights:</strong></p>
<p>Within this parametrization, tied weights are a mathematical necessity: a separate decoder matrix $U \neq W^T$ would make the reconstruction function an invalid gradient of any scalar energy, breaking the EBM correspondence.</p>
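<p>The symmetry argument can be illustrated with a small check. The sketch below builds the Jacobian of $x \mapsto U\,\text{sigmoid}(Wx + b)$, which is $U\,\text{diag}(s(1-s))\,W$ with $s = \text{sigmoid}(Wx + b)$, and measures how far it is from symmetric for tied ($U = W^T$) and untied (random $U$) decoders. All names, sizes, and the random untied matrix are assumptions made for illustration.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def decoder_jacobian(x, W, U, b):
    """Jacobian of x -> U sigmoid(W x + b), i.e. U diag(s * (1 - s)) W."""
    s = sigmoid(W @ x + b)
    return U @ np.diag(s * (1.0 - s)) @ W


def asymmetry(J):
    """Largest absolute entry of J - J^T (zero for a symmetric matrix)."""
    return float(np.max(np.abs(J - J.T)))


rng = np.random.default_rng(2)
d, d_h = 5, 3                              # assumed toy sizes
W = rng.normal(size=(d_h, d))
U = rng.normal(size=(d, d_h))              # hypothetical untied decoder
b = rng.normal(size=d_h)
x = rng.normal(size=d)

tied_gap = asymmetry(decoder_jacobian(x, W, W.T, b))
untied_gap = asymmetry(decoder_jacobian(x, W, U, b))
print(tied_gap, untied_gap)
```

<p>With tied weights the Jacobian is symmetric up to rounding, so the field can be a gradient. A generic untied $U$ breaks that symmetry.</p>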
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<p>Since this is a theoretical paper, the &ldquo;reproducibility&rdquo; lies in the mathematical formulations derived.</p>
<h3 id="data">Data</h3>
<ul>
<li><strong>Input Data ($D_n$):</strong> The theory assumes a set of training examples $D_n = \lbrace x^{(1)}, \ldots, x^{(n)} \rbrace$ drawn from an unknown true pdf $q(x)$.</li>
<li><strong>Parzen Density Estimate ($q_{\sigma}$):</strong> The theoretical targets are derived from a kernel-smoothed empirical distribution:
$$q_{\sigma}(\tilde{x}) = \frac{1}{n} \sum_{t=1}^n q_{\sigma}(\tilde{x}|x^{(t)})$$
where the kernel is an isotropic Gaussian of variance $\sigma^2$.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Denoising Score Matching (DSM) Objective</strong></p>
<p>The paper proposes this objective as a tractable alternative to standard score matching. It minimizes the distance between the model score and the gradient of the log-noise density:</p>
<p>$$J_{DSMq_{\sigma}}(\theta) = \mathbb{E}_{q_{\sigma}(x,\tilde{x})} \left[ \frac{1}{2} \left| \psi(\tilde{x};\theta) - \frac{\partial \log q_{\sigma}(\tilde{x}|x)}{\partial \tilde{x}} \right|^2 \right]$$</p>
<p>For Gaussian noise, the target score is simply $\frac{1}{\sigma^2}(x - \tilde{x})$.</p>
<p><strong>2. Equivalence Chain</strong></p>
<p>The central result connects four objectives:</p>
<p>$$J_{ISMq_{\sigma}} \sim J_{ESMq_{\sigma}} \sim J_{DSMq_{\sigma}} \sim J_{DAE\sigma}$$</p>
<p>This implies that minimizing the DAE reconstruction error is equivalent to minimizing a score matching objective.</p>
<h3 id="models">Models</h3>
<p><strong>1. The Denoising Autoencoder (DAE)</strong></p>
<ul>
<li><strong>Corruption:</strong> Additive isotropic Gaussian noise $\tilde{x} = x + \epsilon, \epsilon \sim \mathcal{N}(0, \sigma^2 I)$.</li>
<li><strong>Encoder:</strong> $h = \text{sigmoid}(W\tilde{x} + b)$.</li>
<li><strong>Decoder:</strong> $x^r = W^T h + c$ (Tied weights $W$).</li>
<li><strong>Loss:</strong> Squared reconstruction error $|x^r - x|^2$. (The equivalence with DSM introduces a $\frac{1}{2\sigma^4}$ scaling factor.)</li>
</ul>
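<p>The $\frac{1}{2\sigma^4}$ factor can be seen directly by evaluating both losses on the same batch. The sketch below assumes the model score is defined as $\psi(\tilde{x}) = \frac{1}{\sigma^2}(x^r - \tilde{x})$ and uses random, untrained parameters. It is meant only to show the algebraic relation, and the variable names are illustrative.</p>

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def dae_reconstruct(x_tilde, W, b, c):
    """Batch version of the tied-weight DAE: rows of x_tilde are samples."""
    h = sigmoid(x_tilde @ W.T + b)         # encoder
    return h @ W + c                       # decoder with tied weights


rng = np.random.default_rng(3)
n, d, d_h, sigma = 64, 5, 3, 0.5           # assumed toy sizes
X = rng.normal(size=(n, d))
X_tilde = X + sigma * rng.normal(size=(n, d))
W = rng.normal(size=(d_h, d))
b = rng.normal(size=d_h)
c = rng.normal(size=d)

X_r = dae_reconstruct(X_tilde, W, b, c)

# Squared reconstruction error of the denoising autoencoder.
dae_loss = float(np.mean(np.sum((X_r - X) ** 2, axis=1)))

# Denoising score matching loss with the Gaussian target score.
psi = (X_r - X_tilde) / sigma**2
target = (X - X_tilde) / sigma**2
dsm_loss = float(np.mean(0.5 * np.sum((psi - target) ** 2, axis=1)))

ratio = dsm_loss / dae_loss
print(ratio, 1.0 / (2.0 * sigma**4))
```

<p>Because $\tilde{x}$ cancels in $\psi - \text{target}$, the ratio of the two losses is the constant $\frac{1}{2\sigma^4}$ regardless of the parameter values.</p>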
<p><strong>2. The Corresponding Energy Function</strong></p>
<p>To make the DAE equivalent to Score Matching, the underlying Energy-Based Model $p(x;\theta) \propto e^{-E(x;\theta)}$ must have the following energy function:</p>
<p>$$E(x; W, b, c) = - \frac{1}{\sigma^2} \left( \langle c, x \rangle - \frac{1}{2}|x|^2 + \sum_{j=1}^{d_h} \text{softplus}(\langle W_j, x \rangle + b_j) \right)$$</p>
<p>Note the scaling by $1/\sigma^2$ and the quadratic term $|x|^2$.</p>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metric:</strong> Theoretical Equivalence ($\sim$).</li>
<li><strong>Condition:</strong> The equivalence holds provided $\sigma &gt; 0$ and the density $q_{\sigma}$ is differentiable and vanishes at infinity (Hyvärinen&rsquo;s 2005 regularity condition for Implicit Score Matching).</li>
</ul>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Vincent, P. (2011). A Connection Between Score Matching and Denoising Autoencoders. <em>Neural Computation</em>, 23(7), 1661-1674. <a href="https://doi.org/10.1162/NECO_a_00142">https://doi.org/10.1162/NECO_a_00142</a></p>
<p><strong>Publication</strong>: Neural Computation 2011</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@article</span>{vincentConnectionScoreMatching2011,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{A {{Connection Between Score Matching}} and {{Denoising Autoencoders}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Vincent, Pascal}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2011</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">month</span> = jul,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">journal</span> = <span style="color:#e6db74">{Neural Computation}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span> = <span style="color:#e6db74">{23}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">number</span> = <span style="color:#e6db74">{7}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span> = <span style="color:#e6db74">{1661--1674}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">doi</span> = <span style="color:#e6db74">{10.1162/NECO_a_00142}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://www.iro.umontreal.ca/~vincentp/Publications/smdae_techreport.pdf">Official PDF</a></li>
</ul>
]]></content:encoded></item><item><title>Rectified Flow: Learning to Generate and Transfer Data</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/rectified-flow/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/rectified-flow/</guid><description>A unified ODE-based framework for generative modeling and domain transfer that learns straight paths for fast 1-step generation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with a significant <strong>Theory</strong> component.</p>
<ul>
<li><strong>Method</strong>: It proposes &ldquo;Rectified Flow,&rdquo; a novel generative framework that learns ordinary differential equations (ODEs) to transport distributions via straight paths. It introduces the &ldquo;Reflow&rdquo; algorithm to iteratively straighten these paths.</li>
<li><strong>Theory</strong>: It provides rigorous proofs connecting the method to Optimal Transport, showing that the rectification process yields a coupling with non-increasing convex transport costs and that recursive reflow reduces the curvature of trajectories.</li>
</ul>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work addresses two main challenges in unsupervised learning: generative modeling (generating data from noise) and domain transfer (mapping between two observed distributions).</p>
<ul>
<li><strong>Inefficiency of ODE/SDE Models</strong>: Continuous-time models (like Score-based Generative Models and DDPMs) require simulating diffusions over many steps, resulting in high computational costs during inference.</li>
<li><strong>Complexity of GANs</strong>: GANs provide fast (one-step) generation but suffer from training instability and mode collapse.</li>
<li><strong>Disconnection</strong>: Generative modeling and domain transfer are often treated as separate tasks requiring different techniques.</li>
</ul>
<p>The authors aim to unify these tasks into a single &ldquo;transport mapping&rdquo; problem while bridging the gap between high-quality continuous models and fast one-step models.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Rectified Flow</strong> framework and the <strong>Reflow</strong> procedure.</p>
<ul>
<li><strong>Straight-Line ODEs</strong>: Rectified Flow learns an ODE drift $v$ to follow the straight line connecting data pairs $(X_0, X_1)$, providing an alternative to diffusion models that rely on stochastic paths or specific forward processes. This is achieved via a simple least-squares optimization problem.</li>
<li><strong>Reflow (Iterative Straightening)</strong>: The authors introduce a recursive training procedure where a new flow is trained on the data pairs $(Z_0, Z_1)$ generated by the previous flow. Theoretical analysis shows this reduces the &ldquo;transport cost&rdquo; and straightens the trajectories, allowing for accurate 1-step simulation (effectively converting the ODE into a one-step model).</li>
<li><strong>Unified Framework</strong>: The method uses the exact same algorithm for generation ($\pi_0$ is Gaussian) and domain transfer ($\pi_0$ is a source dataset), removing the need for adversarial losses or cycle-consistency constraints.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors validated the method across image generation, translation, and domain adaptation tasks.</p>
<ul>
<li><strong>Unconditioned Image Generation</strong>:
<ul>
<li><strong>Dataset</strong>: CIFAR-10 ($32\times32$).</li>
<li><strong>Baselines</strong>: Compared against GANs (StyleGAN2, TDPM), Diffusion/SDE Models (VP SDE, sub-VP SDE, VE SDE), ODE methods (VP ODE, sub-VP ODE, VE ODE), and distilled methods (DDIM Distillation).</li>
<li><strong>High-Res</strong>: Validated on LSUN Bedroom/Church, CelebA-HQ, and AFHQ ($256\times256$).</li>
</ul>
</li>
<li><strong>Image-to-Image Translation</strong>:
<ul>
<li><strong>Datasets</strong>: AFHQ (Cat $\leftrightarrow$ Dog/Wild), MetFace $\leftrightarrow$ CelebA-HQ.</li>
<li><strong>Setup</strong>: Transferring styles while preserving semantic identity (using a classifier-based feature mapping metric).</li>
</ul>
</li>
<li><strong>Domain Adaptation</strong>:
<ul>
<li><strong>Datasets</strong>: DomainNet, Office-Home.</li>
<li><strong>Metric</strong>: Classification accuracy on the transferred testing data.</li>
</ul>
</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Superior 1-Step Generation</strong>: On CIFAR-10 with a single Euler step, the distilled 2-Rectified Flow achieved an FID of <strong>4.85</strong>, beating TDPM (FID 8.91), a truncated diffusion model that uses a GAN and was the best one-step U-Net model at the time of publication (ICLR 2023). The distilled 3-Rectified Flow reached a Recall of <strong>0.51</strong>, beating the GAN baseline StyleGAN2+ADA (Recall 0.49).</li>
<li><strong>Straightening Effect</strong>: The &ldquo;Reflow&rdquo; procedure was empirically shown to reduce the &ldquo;straightness&rdquo; error and transport costs, validating the theoretical claims. &ldquo;Straightness&rdquo; is measured as $S(Z) = \mathbb{E}[\int_0^1 |\dot{Z}_t - (Z_1 - Z_0)|^2 \, dt]$ (zero means perfectly straight); &ldquo;transport cost&rdquo; is $\mathbb{E}[c(Z_1 - Z_0)]$ for a convex cost $c$, and the theory guarantees that rectification does not increase it for any convex cost.</li>
<li><strong>High-Quality Transfer</strong>: The model successfully performed image translation (e.g., Cat to Wild Animal) without paired data or cycle-consistency losses.</li>
<li><strong>Strong Full-Simulation Results</strong>: With RK45 adaptive ODE solving, 1-Rectified Flow achieves FID 2.58 and Recall 0.57 on CIFAR-10 (Table 1a), the best among ODE methods and comparable to fully simulated SDEs (VP SDE: FID 2.55).</li>
<li><strong>Fast Simulation</strong>: The method allows for extremely coarse time discretization (e.g., $N=1$) without significant quality loss after reflow, effectively solving the slow inference speed of standard ODE models.</li>
<li><strong>Domain Adaptation</strong>: On Office-Home, Rectified Flow achieves 69.2% accuracy, outperforming Deep CORAL (68.7%) and other baselines. On DomainNet, it achieves 41.4%, comparable to Deep CORAL (41.5%) and MLDG (41.2%).</li>
</ul>
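<p>The straightness measure $S(Z)$ can be approximated from discretized trajectories. The sketch below is one possible estimator, assuming trajectories are stored on a uniform time grid and using finite differences for $\dot{Z}_t$. The function name, array layout, and the sinusoidal test curve are assumptions for illustration and are not the paper&rsquo;s evaluation code.</p>

```python
import numpy as np


def straightness(traj):
    """Estimate S(Z) from trajectories of shape (N + 1, n_samples, d).

    The first axis is a uniform time grid over [0, 1].
    """
    n_intervals = traj.shape[0] - 1
    dt = 1.0 / n_intervals
    velocity = np.diff(traj, axis=0) / dt              # (N, n_samples, d)
    displacement = traj[-1] - traj[0]                  # (n_samples, d)
    sq_dev = np.sum((velocity - displacement) ** 2, axis=-1)
    # Riemann sum over time, then average over samples.
    return float(np.mean(np.sum(sq_dev, axis=0) * dt))


rng = np.random.default_rng(4)
z0 = rng.normal(size=(16, 2))
z1 = rng.normal(size=(16, 2))
t = np.linspace(0.0, 1.0, 51)[:, None, None]

straight = (1.0 - t) * z0 + t * z1                     # straight-line paths
curved = straight + np.sin(np.pi * t)                  # same endpoints, bent paths

s_straight = straightness(straight)
s_curved = straightness(curved)
print(s_straight, s_curved)
```

<p>Straight-line paths give a value near zero, while the bent paths with the same endpoints give a clearly positive value.</p>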
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p>The paper utilizes several standard computer vision benchmarks.</p>
<table>
  <thead>
      <tr>
          <th>Purpose</th>
          <th>Dataset</th>
          <th>Size/Resolution</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Generation</td>
          <td><strong>CIFAR-10</strong></td>
          <td>32x32</td>
          <td>Standard split</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td><strong>LSUN</strong> (Bedroom, Church)</td>
          <td>256x256</td>
          <td>High-res evaluation</td>
      </tr>
      <tr>
          <td>Generation</td>
          <td><strong>CelebA-HQ</strong></td>
          <td>256x256</td>
          <td>High-res evaluation</td>
      </tr>
      <tr>
          <td>Gen/Transfer</td>
          <td><strong>AFHQ</strong> (Cat, Dog, Wild)</td>
          <td>512x512</td>
          <td>256x256 for generation, 512x512 for transfer</td>
      </tr>
      <tr>
          <td>Transfer</td>
          <td><strong>MetFace</strong></td>
          <td>1024x1024</td>
          <td>Resized to 512x512 for experiments</td>
      </tr>
      <tr>
          <td>Adaptation</td>
          <td><strong>DomainNet</strong></td>
          <td>Mixed</td>
          <td>345 categories, 6 domains</td>
      </tr>
      <tr>
          <td>Adaptation</td>
          <td><strong>Office-Home</strong></td>
          <td>Mixed</td>
          <td>65 categories, 4 domains</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li>
<p><strong>Objective Function</strong>:
The drift $v$ of the ODE $dZ_t = v(Z_t, t) dt$ is trained by minimizing a least-squares regression objective:
$$\min_{v} \int_{0}^{1} \mathbb{E}[|(X_1 - X_0) - v(X_t, t)|^2] dt$$
where $X_t = tX_1 + (1-t)X_0$ is the linear interpolation.</p>
</li>
<li>
<p><strong>Reflow Procedure</strong>:
Iteratively updates the flow. Let $Z^k$ be the $k$-th rectified flow.</p>
<ol>
<li>Generate 4 million data pairs $(Z_0^k, Z_1^k)$ by simulating the current flow.</li>
<li>Fine-tune the $k$-rectified flow model for 300,000 steps on these pairs to obtain the $(k+1)$-rectified flow.</li>
</ol>
</li>
<li>
<p><strong>Distillation</strong>:
For 1-step distillation ($k=1$), the L2 loss is replaced with LPIPS perceptual similarity, which empirically yields better image quality. For multi-step distillation, training samples $t$ from $\lbrace 0, 1/k, \ldots, (k-1)/k \rbrace$ rather than the full $[0, 1]$ interval.</p>
</li>
<li>
<p><strong>ODE Solver</strong>:</p>
<ul>
<li>Training: Analytical linear interpolation.</li>
<li>Inference: Euler method (constant step size $1/N$) or RK45 (adaptive).</li>
</ul>
</li>
</ul>
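<p>The objective and the Euler sampler above can be sketched in a few lines of NumPy. The example uses a hypothetical toy coupling in which every $X_1$ is a fixed translation of $X_0$, so the ideal drift is a known constant and no network training is needed. The function names and the toy setup are assumptions for illustration, not the authors&rsquo; implementation.</p>

```python
import numpy as np


def rectified_flow_loss(v, x0, x1, t):
    """Monte Carlo estimate of E |(X1 - X0) - v(X_t, t)|^2."""
    xt = t[:, None] * x1 + (1.0 - t[:, None]) * x0     # linear interpolation
    residual = (x1 - x0) - v(xt, t)
    return float(np.mean(np.sum(residual ** 2, axis=1)))


def euler_sample(v, z0, n_steps):
    """Integrate dZ = v(Z, t) dt from t = 0 to t = 1 with step 1 / n_steps."""
    z = z0.copy()
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = np.full(z.shape[0], k * dt)
        z = z + dt * v(z, t)
    return z


rng = np.random.default_rng(5)
shift = np.array([2.0, -1.0])                          # toy translation
x0 = rng.normal(size=(256, 2))
x1 = x0 + shift
t = rng.uniform(size=256)


def constant_drift(x, t):
    # Stand-in for a trained network; exact for this toy coupling.
    return np.broadcast_to(shift, x.shape)


loss = rectified_flow_loss(constant_drift, x0, x1, t)
one_step = euler_sample(constant_drift, x0, n_steps=1)
one_step_error = float(np.max(np.abs(one_step - x1)))
print(loss, one_step_error)
```

<p>When the paths are exactly straight, as in this toy case, a single Euler step already lands on the target. Reflow aims to approach this regime for learned flows.</p>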
<h3 id="models">Models</h3>
<ul>
<li>
<p><strong>Architecture</strong>:</p>
<ul>
<li>Uses the <strong>DDPM++ U-Net</strong> architecture (from Song et al., 2020) across experiments. Implementation is modified from the open-source code of Song et al.</li>
</ul>
</li>
<li>
<p><strong>Optimization</strong>:</p>
<ul>
<li><strong>Optimizer</strong>: Adam (CIFAR-10) or AdamW (Transfer/Adaptation).</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>LR: $2 \times 10^{-4}$ (CIFAR), Grid search for transfer.</li>
<li>EMA: 0.999999 (CIFAR), 0.9999 (Transfer).</li>
<li>Batch Size: 4 (Transfer), 16 (Domain Adaptation).</li>
<li>Dropout: 0.15 (CIFAR), 0.1 (Transfer).</li>
</ul>
</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<table>
  <thead>
      <tr>
          <th>Metric</th>
          <th>Value (CIFAR-10, N=1)</th>
          <th>Baseline (Best 1-step)</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>FID</strong></td>
          <td><strong>4.85</strong> (2-Rectified + Distill)</td>
          <td>8.91 (TDPM)</td>
          <td>Lower is better</td>
      </tr>
      <tr>
          <td><strong>Recall</strong></td>
          <td><strong>0.51</strong> (3-Rectified + Distill)</td>
          <td>0.49 (StyleGAN2+ADA)</td>
          <td>Higher is better</td>
      </tr>
  </tbody>
</table>
<h3 id="hardware">Hardware</h3>
<p>The paper does not specify GPU models or training times. The DDPM++ U-Net architecture used in the experiments typically requires multi-GPU setups for training on high-resolution datasets.</p>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/gnobitab/RectifiedFlow">RectifiedFlow (GitHub)</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official PyTorch implementation with CIFAR-10 and high-res training code, plus pre-trained checkpoints</td>
      </tr>
  </tbody>
</table>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Liu, X., Gong, C., &amp; Liu, Q. (2023). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow. <em>International Conference on Learning Representations (ICLR)</em>. <a href="https://openreview.net/forum?id=XVjTT1nw5z">https://openreview.net/forum?id=XVjTT1nw5z</a></p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{liuFlowStraightFast2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow {{Straight}} and {{Fast}}: {{Learning}} to {{Generate}} and {{Transfer Data}} with {{Rectified Flow}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Liu, Xingchao and Gong, Chengyue and Liu, Qiang}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=XVjTT1nw5z}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://github.com/gnobitab/RectifiedFlow">Official Code Repository</a></li>
<li><a href="https://openreview.net/forum?id=XVjTT1nw5z">OpenReview Page</a></li>
</ul>
]]></content:encoded></item><item><title>Flow Matching for Generative Modeling: Scalable CNFs</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/flow-matching-for-generative-modeling/</guid><description>A simulation-free framework for training Continuous Normalizing Flows using Conditional Flow Matching and Optimal Transport paths.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, as it introduces &ldquo;Flow Matching&rdquo; (FM), a novel simulation-free paradigm for training Continuous Normalizing Flows (CNFs) at scale. It is supported by a strong <strong>Theory</strong> basis, providing formal theorems that allow the intractable marginal vector field regression to be solved via a tractable conditional objective. It also touches on <strong>Systematization</strong> by showing that existing diffusion paths are specific instances of the proposed Gaussian probability path framework.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The paper aims to overcome the scaling limitations of Continuous Normalizing Flows (CNFs).</p>
<ul>
<li><strong>Problem</strong>: Standard Maximum Likelihood training for CNFs requires expensive numerical ODE simulations during training, which scales poorly. Existing simulation-free methods often involve intractable integrals or result in biased gradients.</li>
<li><strong>Gap</strong>: Diffusion models scale well, yet they are restricted to specific, curved probability paths (e.g., VP, VE) that can result in slow sampling and long training times.</li>
<li><strong>Goal</strong>: To develop an efficient, simulation-free training method for CNFs that supports arbitrary probability paths, specifically allowing for straighter, more efficient trajectories like those from Optimal Transport.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is <strong>Flow Matching (FM)</strong> and specifically the <strong>Conditional Flow Matching (CFM)</strong> objective.</p>
<ul>
<li><strong>Direct Vector Field Regression</strong>: The model regresses a target vector field $u_t$ that generates a desired probability path $p_t$.</li>
<li><strong>Conditional Flow Matching (CFM)</strong>: The authors prove that regressing the vector field of <em>conditional</em> paths (e.g., $p_t(x|x_1)$ given a single data point) yields the same gradients as regressing the intractable marginal vector field. This bypasses the need to know the marginal score or vector field.</li>
<li><strong>Optimal Transport Paths</strong>: The framework enables the use of <strong>Optimal Transport (OT)</strong> displacement interpolation for probability paths. OT paths are straight lines with constant speed, leading to faster training and easier sampling.</li>
</ul>
<p><strong>Concurrent work note</strong>: Rectified Flow (Liu et al., 2023) and Stochastic Interpolants (Albergo &amp; Vanden-Eijnden, 2023) were published concurrently at ICLR 2023 with structurally similar contributions under different names. All three independently propose simulation-free training of continuous flows via direct vector field regression; the differences lie in the specific interpolation schemes, theoretical framing, and experimental focus.</p>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<ul>
<li><strong>Domains</strong>: 2D Checkerboard data, CIFAR-10, and ImageNet at resolutions $32 \times 32$, $64 \times 64$, and $128 \times 128$.</li>
<li><strong>Task</strong>: Unconditional generative modeling (density estimation and sample quality) and conditional super-resolution ($64 \times 64 \to 256 \times 256$).</li>
<li><strong>Baselines</strong>: Compared against Diffusion-based methods on the same architecture (U-Net): DDPM, Score Matching (SM), and ScoreFlow.</li>
<li><strong>Ablations</strong>: Specifically compared <strong>FM with Diffusion paths</strong> vs. <strong>FM with Optimal Transport (OT) paths</strong> to isolate the benefit of the training objective vs. the path choice.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Outperforms diffusion baselines</strong>: FM-OT consistently outperforms all diffusion-based methods (DDPM, Score Matching, ScoreFlow) in both Likelihood (NLL) and Sample Quality (FID) across CIFAR-10 and ImageNet, using the same U-Net architecture and training budget. Selected rows from Table 1 (NLL in bits per dimension, BPD; lower is better for all three metrics; &ldquo;FM w/ OT&rdquo; and &ldquo;FM w/ Diffusion&rdquo; refer to FM trained with OT paths and Diffusion paths respectively):</li>
</ul>
<table>
  <thead>
      <tr>
          <th>Dataset</th>
          <th>Method</th>
          <th>NLL (BPD) ↓</th>
          <th>FID ↓</th>
          <th>NFE ↓</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>CIFAR-10</td>
          <td>DDPM</td>
          <td>3.12</td>
          <td>7.48</td>
          <td>274</td>
      </tr>
      <tr>
          <td>CIFAR-10</td>
          <td>FM w/ OT</td>
          <td><strong>2.99</strong></td>
          <td><strong>6.35</strong></td>
          <td><strong>142</strong></td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>ScoreFlow</td>
          <td>3.36</td>
          <td>24.95</td>
          <td>601</td>
      </tr>
      <tr>
          <td>ImageNet 64×64</td>
          <td>FM w/ OT</td>
          <td><strong>3.31</strong></td>
          <td><strong>14.45</strong></td>
          <td><strong>138</strong></td>
      </tr>
  </tbody>
</table>
<ul>
<li><strong>Training stability</strong>: FM with diffusion paths (FM w/ Diffusion) is itself a more stable alternative to diffusion training than DDPM and Score Matching, as shown by training curves in the paper (Figure 5), even before switching to OT paths. The OT path then provides further gains.</li>
<li><strong>Sampling speed</strong>: The straight trajectories of OT paths allow accurate sampling with significantly fewer function evaluations (NFE) compared to diffusion paths.</li>
<li><strong>Generality</strong>: Diffusion is a specific instance of Gaussian probability paths within FM. OT paths are an alternative choice within the same framework that yields straighter trajectories.</li>
<li><strong>Downstream adoption</strong>: Flow matching has been adopted beyond image generation. <a href="/notes/biology/computational-biology/dynamicflow/">DynamicFlow</a> uses it as the generative backbone for simultaneously generating ligand molecules and transforming protein pockets, extending flow matching to structure-based drug design.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Datasets</strong>: CIFAR-10, ImageNet ($32 \times 32$, $64 \times 64$, $128 \times 128$).</li>
<li><strong>Preprocessing</strong>:
<ul>
<li>Images are center-cropped and resized.</li>
<li>For $32 \times 32$ and $64 \times 64$, the preprocessing follows Chrabaszcz et al. (2017).</li>
<li>Data is transformed via $\varphi(y) = 2^7(y+1)$ mapping $[-1, 1]$ pixel values to $[0, 256]$ for BPD computation.</li>
</ul>
</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<p><strong>1. Conditional Flow Matching (CFM) Objective</strong></p>
<p>The practical training objective used is the CFM loss, which bypasses intractable marginalization:</p>
<p>$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p(x_0)} | v_t(\psi_t(x_0)) - u_t(\psi_t(x_0) | x_1) |^2$$</p>
<p>Where $t \sim \mathcal{U}[0,1]$, $x_1 \sim q(x_1)$ (data), and $x_0 \sim p(x_0)$ (noise).</p>
<p><strong>2. Optimal Transport (OT) Probability Path</strong></p>
<p>The authors recommend the OT path for efficiency.</p>
<ul>
<li><strong>Mean/Std Schedule</strong>: $\mu_t(x) = t x_1$ and $\sigma_t(x) = 1 - (1 - \sigma_{min})t$.</li>
<li><strong>Conditional Flow Map</strong>: $\psi_t(x) = (1 - (1 - \sigma_{min})t)x + t x_1$.</li>
<li><strong>Target Vector Field</strong>: The closed-form regression target for OT is:
$$u_t(x|x_1) = \frac{x_1 - (1 - \sigma_{min})x}{1 - (1 - \sigma_{min})t}$$</li>
</ul>
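<p>As a hedged sketch of how these formulas fit together, the snippet below implements the OT conditional flow map and target vector field, then checks that evaluating $u_t$ along the flow gives the time-independent regression target $x_1 - (1 - \sigma_{min})x_0$, which is also the time derivative of $\psi_t(x_0)$. The function names and the value of <code>SIGMA_MIN</code> are assumptions for illustration.</p>

```python
import numpy as np

SIGMA_MIN = 1e-5  # assumed small value, in line with the example in these notes


def psi(t, x0, x1, sigma_min=SIGMA_MIN):
    """OT conditional flow map psi_t(x0) for a data point x1."""
    return (1.0 - (1.0 - sigma_min) * t) * x0 + t * x1


def u_cond(t, x, x1, sigma_min=SIGMA_MIN):
    """OT conditional vector field u_t(x | x1)."""
    return (x1 - (1.0 - sigma_min) * x) / (1.0 - (1.0 - sigma_min) * t)


def cfm_target(x0, x1, sigma_min=SIGMA_MIN):
    """Value of u_t(psi_t(x0) | x1); it does not depend on t."""
    return x1 - (1.0 - sigma_min) * x0


rng = np.random.default_rng(6)
x0 = rng.normal(size=3)                                # noise sample
x1 = rng.normal(size=3)                                # stand-in data sample
t = 0.3

along_flow = u_cond(t, psi(t, x0, x1), x1)
target = cfm_target(x0, x1)
field_gap = float(np.max(np.abs(along_flow - target)))

# Central difference in time of the flow map.
h = 1e-6
dpsi_dt = (psi(t + h, x0, x1) - psi(t - h, x0, x1)) / (2.0 * h)
derivative_gap = float(np.max(np.abs(dpsi_dt - target)))
print(field_gap, derivative_gap)
```

<p>In a training loop, the CFM loss for one sample would then be the squared distance between the network output at $(\psi_t(x_0), t)$ and <code>cfm_target(x0, x1)</code>.</p>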
<p><strong>3. Sampling</strong></p>
<p>Sampling is performed by solving the ODE $\frac{d}{dt}\phi_t(x) = v_t(\phi_t(x))$ from $t=0$ to $t=1$ using the learned vector field $v_t$.</p>
<ul>
<li><strong>Solver</strong>: <code>dopri5</code> (adaptive) is used for robust evaluation. Fixed-step solvers (Euler, Midpoint) are used for low-NFE efficiency tests.</li>
</ul>
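<p>For intuition about fixed-step solvers, here is a minimal sketch of Euler and midpoint integration over $[0, 1]$. The test field $v_t(x) = -x$ is a toy stand-in for a learned vector field (its exact solution is $x_0 e^{-1}$), and the <code>integrate</code> helper is hypothetical rather than part of any library used in the paper.</p>

```python
import numpy as np


def integrate(v, x0, n_steps, method="euler"):
    """Fixed-step integration of dx/dt = v(x, t) from t = 0 to t = 1."""
    x = np.array(x0, dtype=float)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        if method == "euler":
            x = x + dt * v(x, t)                       # 1 evaluation per step
        elif method == "midpoint":
            x_mid = x + 0.5 * dt * v(x, t)             # 2 evaluations per step
            x = x + dt * v(x_mid, t + 0.5 * dt)
        else:
            raise ValueError(f"unknown method: {method}")
    return x


def toy_field(x, t):
    # Linear decay field; not a trained model.
    return -x


x_start = np.array([1.0, -2.0])
exact = x_start * np.exp(-1.0)

euler_out = integrate(toy_field, x_start, n_steps=4, method="euler")
midpoint_out = integrate(toy_field, x_start, n_steps=4, method="midpoint")

euler_err = float(np.max(np.abs(euler_out - exact)))
midpoint_err = float(np.max(np.abs(midpoint_out - exact)))
print(euler_err, midpoint_err)
```

<p>With the same number of steps, midpoint uses twice the function evaluations of Euler but is more accurate on this curved toy problem. On straight trajectories the gap between solvers shrinks, which is why OT paths suit low-NFE sampling.</p>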
<h3 id="models">Models</h3>
<ul>
<li><strong>Architecture</strong>: U-Net architecture from Dhariwal &amp; Nichol (2021) is used for all image experiments.</li>
<li><strong>Toy Data</strong>: 5-layer MLP with 512 neurons.</li>
<li><strong>Hyperparameters</strong>:
<ul>
<li>Optimizer: Adam ($\beta_1=0.9, \beta_2=0.999$, weight decay=0.0).</li>
<li>Learning Rate: Polynomial decay or constant (see Table 3 in paper).</li>
<li>$\sigma_{min}$: Set to a small value (e.g., $10^{-5}$).</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>:
<ul>
<li><strong>NLL (BPD)</strong>: Computed using the continuous change of variables formula, with the divergence estimated via the Hutchinson trace estimator to avoid forming the full Jacobian that an exact divergence would require.</li>
<li><strong>FID</strong>: Frechet Inception Distance for sample quality.</li>
<li><strong>NFE</strong>: Number of Function Evaluations required by the solver.</li>
</ul>
</li>
<li><strong>Likelihood Computation</strong>: Requires solving an augmented ODE to track the log-density change:
$$\frac{d}{dt} \begin{bmatrix} \phi_t(x) \\ f(t) \end{bmatrix} = \begin{bmatrix} v_t(\phi_t(x)) \\ -\text{div}(v_t(\phi_t(x))) \end{bmatrix}$$</li>
</ul>
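<p>The divergence term in the augmented ODE is what the Hutchinson estimator approximates. The sketch below estimates $\text{div}(v) = \text{tr}(\partial v / \partial x)$ as the average of $\epsilon^T J \epsilon$ over Rademacher probes $\epsilon$. It uses a finite-difference Jacobian-vector product and a linear toy field $v(x) = Ax$ so the exact answer $\text{tr}(A)$ is known. Both choices are assumptions for illustration; an automatic-differentiation product would typically replace the finite difference.</p>

```python
import numpy as np


def hutchinson_divergence(v, x, n_probes, rng, h=1e-5):
    """Estimate div v(x) as the mean of eps^T J(x) eps over Rademacher probes."""
    total = 0.0
    for _ in range(n_probes):
        eps = rng.choice([-1.0, 1.0], size=x.shape)
        # Finite-difference approximation of the product J(x) eps.
        jvp = (v(x + h * eps) - v(x - h * eps)) / (2.0 * h)
        total += float(eps @ jvp)
    return total / n_probes


rng = np.random.default_rng(7)
d = 6                                                  # assumed toy dimension
A = rng.normal(size=(d, d))


def linear_field(x):
    # Toy vector field with constant Jacobian A.
    return A @ x


x = rng.normal(size=d)
exact_div = float(np.trace(A))
estimated_div = hutchinson_divergence(linear_field, x, n_probes=5000, rng=rng)
print(exact_div, estimated_div)
```

<p>Each probe costs one Jacobian-vector product, so the estimate never materializes the $d \times d$ Jacobian. Accuracy improves with the number of probes.</p>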
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>CIFAR-10</strong>: 2 GPUs.</li>
<li><strong>ImageNet-32</strong>: 4 GPUs.</li>
<li><strong>ImageNet-64</strong>: 16 GPUs.</li>
<li><strong>ImageNet-128</strong>: 32 GPUs.</li>
<li><strong>Precision</strong>: Full 32-bit for CIFAR/IM-32; 16-bit mixed precision for IM-64/128.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/facebookresearch/flow_matching">flow_matching (PyTorch library)</a></td>
          <td>Code</td>
          <td>CC BY-NC 4.0</td>
          <td>Later official library from Meta; not the original experiment code</td>
      </tr>
  </tbody>
</table>
<p>The paper does not release the original training code or model weights used in the experiments. The <code>facebookresearch/flow_matching</code> library was released later as a general-purpose PyTorch implementation of flow matching algorithms. Standard benchmark datasets (CIFAR-10, ImageNet) are publicly available.</p>
<hr>
<h2 id="theoretical-notes-why-cfm-works">Theoretical Notes: Why CFM Works</h2>
<p>The paper relies on three key theorems to make training tractable.</p>
<p><strong>Theorem 1 (Marginal Generation)</strong>:</p>
<p>Marginalizing conditional vector fields $u_t(x|x_1)$ yields the correct marginal vector field $u_t(x)$ that generates the marginal probability path $p_t(x)$.</p>
<p>$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)} dx_1$$</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>To understand why this theorem holds, we have to look at the <strong>Continuity Equation</strong>, which is the fundamental partial differential equation (PDE) that links a probability density path $p_t$ to a vector field $u_t$.</p>
<p>A vector field $u_t$ is said to &ldquo;generate&rdquo; a probability path $p_t$ if and only if they satisfy the continuity equation:</p>
<p>$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$$</p>
<p>The proof of Theorem 1 relies on substituting the definitions of the marginal path and vector field into this equation to see if they balance out.</p>
<p><strong>Step-by-Step Proof:</strong></p>
<ol>
<li>
<p><strong>Start with the time derivative of the marginal path</strong>: We begin by differentiating the marginal probability path $p_t(x)$ with respect to time. By definition, the marginal path is the integral of the conditional paths over the data distribution:
$$\frac{\partial p_t(x)}{\partial t} = \frac{\partial}{\partial t} \int p_t(x|x_1) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Swap derivative and integral</strong>: Assuming standard regularity conditions (Leibniz Rule), we can move the time derivative inside the integral:
$$\frac{\partial p_t(x)}{\partial t} = \int \frac{\partial p_t(x|x_1)}{\partial t} q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Apply the Conditional Continuity Equation</strong>: This is the critical step. We know that the conditional vector field $u_t(x|x_1)$ generates the conditional path $p_t(x|x_1)$. Therefore, for every single sample $x_1$, the pair satisfies the continuity equation:
$$\frac{\partial p_t(x|x_1)}{\partial t} = -\nabla \cdot (p_t(x|x_1) u_t(x|x_1))$$</p>
<p>Substituting this into our integral gives:
$$\frac{\partial p_t(x)}{\partial t} = -\int \nabla \cdot (p_t(x|x_1) u_t(x|x_1)) q(x_1) dx_1$$</p>
</li>
<li>
<p><strong>Pull the Divergence out</strong>: Since the divergence operator ($\nabla \cdot$) acts on $x$ and the integral is over $x_1$, we can pull the divergence operator outside the integral (by linearity):
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot \left( \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1 \right)$$</p>
</li>
<li>
<p><strong>Match with the Marginal Vector Field Definition</strong>: Now, look at the term inside the parentheses. The paper defines the marginal vector field $u_t(x)$ specifically to make this term simpler. Rearranging the definition of $u_t(x)$ provided in the theorem:
$$p_t(x) u_t(x) = \int p_t(x|x_1) u_t(x|x_1) q(x_1) dx_1$$</p>
<p>Substitute $p_t(x) u_t(x)$ back into our equation from Step 4:
$$\frac{\partial p_t(x)}{\partial t} = -\nabla \cdot (p_t(x) u_t(x))$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: We have just shown that $\frac{\partial p_t(x)}{\partial t} + \nabla \cdot (p_t(x) u_t(x)) = 0$. This is exactly the continuity equation. Because the marginal path and the aggregated marginal vector field satisfy this equation, the vector field is proven to generate the path.</p></blockquote>
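<p>The argument can be checked numerically in one dimension. The sketch below assumes a hypothetical two-point data distribution and the OT conditional path with a deliberately large $\sigma_{min}$, builds the marginal density and flux from the conditionals, and measures the continuity-equation residual with central finite differences.</p>

```python
import numpy as np

SIGMA_MIN = 0.1
DATA = np.array([-1.0, 2.0])       # two "data points" x1 (hypothetical toy dataset)
WEIGHTS = np.array([0.3, 0.7])     # q(x1)


def sigma(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t


def cond_density(x, x1, t):
    """p_t(x | x1) = N(x; t * x1, sigma(t)^2)."""
    s = sigma(t)
    return np.exp(-0.5 * ((x - t * x1) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))


def cond_field(x, x1, t):
    """u_t(x | x1) for the OT path."""
    return (x1 - (1.0 - SIGMA_MIN) * x) / sigma(t)


def marginal_density(x, t):
    return sum(w * cond_density(x, x1, t) for w, x1 in zip(WEIGHTS, DATA))


def marginal_flux(x, t):
    """p_t(x) u_t(x) = sum over x1 of q(x1) p_t(x | x1) u_t(x | x1)."""
    return sum(w * cond_density(x, x1, t) * cond_field(x, x1, t)
               for w, x1 in zip(WEIGHTS, DATA))


xs = np.linspace(-3.0, 4.0, 201)
t, h = 0.5, 1e-4
dp_dt = (marginal_density(xs, t + h) - marginal_density(xs, t - h)) / (2 * h)
dflux_dx = (marginal_flux(xs + h, t) - marginal_flux(xs - h, t)) / (2 * h)
residual = np.max(np.abs(dp_dt + dflux_dx))   # should vanish up to discretization error
```

<p>The residual is limited only by the finite-difference step, while the individual terms are of order one, which is what Theorem 1 predicts.</p>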
<p><strong>Theorem 2 (Gradient Equivalence)</strong>:</p>
<p>The intractable Flow Matching objective $\mathcal{L}_{FM}$ (which requires $u_t(x)$) has the <strong>same gradients</strong> as the tractable Conditional Flow Matching objective $\mathcal{L}_{CFM}$.</p>
<p>$$\nabla_\theta \mathcal{L}_{FM}(\theta) = \nabla_\theta \mathcal{L}_{CFM}(\theta)$$</p>
<p>This allows the model to learn the marginal vector field by only seeing conditional sample paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The reason Theorem 2 holds is that the &ldquo;Conditional Flow Matching&rdquo; (CFM) objective is essentially an unbiased estimator of the &ldquo;Flow Matching&rdquo; (FM) objective (up to a constant). When we average over all the conditional data points $x_1$, the &ldquo;cross-term&rdquo; in the loss function aligns perfectly with the marginal vector field.</p>
<p><strong>1. Expand the Loss Functions</strong></p>
<p>First, let&rsquo;s look at the squared error in both objectives. Recall that $v_t$ is our neural network (parameterized by $\theta$), $u_t$ is the intractable marginal target, and $u_t(x|x_1)$ is the tractable conditional target.</p>
<p>Expanding the squared norms:</p>
<ul>
<li>
<p><strong>FM Objective</strong>:
$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t, p_t(x)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x) + |u_t(x)|^2 \right]$$</p>
</li>
<li>
<p><strong>CFM Objective</strong>:
$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} \left[ |v_t(x)|^2 - 2v_t(x) \cdot u_t(x|x_1) + |u_t(x|x_1)|^2 \right]$$</p>
</li>
</ul>
<p><strong>Key Insight</strong>: When we take the gradient $\nabla_\theta$, the last term in both equations disappears because the targets ($u_t$) are independent of the network weights $\theta$. We only need to show that the expectations of the first two terms match.</p>
<p><strong>2. Matching the First Term ($|v_t(x)|^2$)</strong></p>
<p>This part is straightforward. The expectation of $|v_t(x)|^2$ is the same in both cases because of how the marginal density $p_t(x)$ is defined.</p>
<ul>
<li><strong>FM</strong>: averages over $p_t(x)$.</li>
<li><strong>CFM</strong>: averages over $p_t(x|x_1)q(x_1)$.</li>
</ul>
<p>Since $p_t(x) = \int p_t(x|x_1) q(x_1) dx_1$ (by definition), averaging over the joint distribution is mathematically identical to averaging over the marginal $p_t(x)$.</p>
<p><strong>3. Matching the Cross Term (The &ldquo;Trick&rdquo;)</strong></p>
<p>This is the critical part of the proof. We need to show that the interaction between the network and the marginal field equals the interaction between the network and the conditional field.</p>
<p><strong>The Goal</strong>: Show $\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$.</p>
<p><strong>The Proof</strong>:</p>
<ol>
<li>
<p>Start with the <strong>FM cross-term</strong> (marginal):
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)]$$</p>
</li>
<li>
<p>Substitute the definition of the marginal vector field $u_t(x)$ derived in <strong>Theorem 1</strong>:
$$u_t(x) = \int u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1$$</p>
</li>
<li>
<p>Write the expectation as an integral over $t$ and $x$ and plug in this definition. The $p_t(x)$ factors cancel:
$$\mathbb{E}_{t, p_t(x)} [v_t(x) \cdot u_t(x)] = \int_t \int_x p_t(x) v_t(x) \cdot \left[ \int_{x_1} u_t(x|x_1) \frac{p_t(x|x_1) q(x_1)}{p_t(x)} dx_1 \right] dx\, dt$$</p>
</li>
<li>
<p>This simplifies to:
$$= \int_t \int_x \int_{x_1} v_t(x) \cdot u_t(x|x_1) p_t(x|x_1) q(x_1) dx_1 dx dt$$</p>
</li>
<li>
<p>This is exactly the definition of the expectation in the <strong>CFM objective</strong>:
$$= \mathbb{E}_{t, q(x_1), p_t(x|x_1)} [v_t(x) \cdot u_t(x|x_1)]$$</p>
</li>
</ol>
<p><strong>Conclusion</strong>: Because the expectations of all terms involving $\theta$ are identical, the gradients must be identical.</p>
<p>Intuitively, this works like <strong>Denoising Score Matching</strong> or <strong>Stochastic Gradient Descent</strong>: even though each individual conditional vector field $u_t(x|x_1)$ points to a specific data point $x_1$ (which may differ from the true marginal direction), the <em>average</em> of all these pulls equals the true marginal vector field $u_t(x)$.</p></blockquote>
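<p>A small quadrature experiment makes the "same gradients" claim concrete. The sketch assumes a hypothetical one-dimensional setup (two data points, the OT path with a large $\sigma_{min}$ to keep the grid well-conditioned, and a linear model $v_t(x) = \theta x$) and evaluates both objectives for several values of $\theta$. If the theorem holds, the two losses differ by a constant that does not depend on $\theta$.</p>

```python
import numpy as np

SIGMA_MIN = 0.5
DATA = np.array([-1.0, 2.0])       # hypothetical toy dataset
WEIGHTS = np.array([0.3, 0.7])     # q(x1)
XS = np.linspace(-5.0, 6.0, 1101)  # quadrature grid in x
TS = np.linspace(0.0, 1.0, 51)     # quadrature grid in t
DX = XS[1] - XS[0]


def cond_density(x, x1, t):
    s = 1.0 - (1.0 - SIGMA_MIN) * t
    return np.exp(-0.5 * ((x - t * x1) / s) ** 2) / (s * np.sqrt(2.0 * np.pi))


def cond_field(x, x1, t):
    return (x1 - (1.0 - SIGMA_MIN) * x) / (1.0 - (1.0 - SIGMA_MIN) * t)


def losses(theta):
    """Return (L_FM, L_CFM) for the linear model v_t(x) = theta * x."""
    fm, cfm = 0.0, 0.0
    for t in TS:
        p_cond = [w * cond_density(XS, x1, t) for w, x1 in zip(WEIGHTS, DATA)]
        u_cond = [cond_field(XS, x1, t) for x1 in DATA]
        p = sum(p_cond)
        u = sum(pc * uc for pc, uc in zip(p_cond, u_cond)) / p   # marginal field (Theorem 1)
        v = theta * XS
        fm += np.sum(p * (v - u) ** 2) * DX
        cfm += sum(np.sum(pc * (v - uc) ** 2) * DX for pc, uc in zip(p_cond, u_cond))
    return fm / len(TS), cfm / len(TS)


thetas = [-1.0, 0.0, 0.5, 2.0]
gaps = [cfm - fm for fm, cfm in (losses(th) for th in thetas)]
```

<p>The entries of <code>gaps</code> agree to rounding error, so the derivative of each loss with respect to $\theta$ is the same. The gap itself is positive: it is the average conditional variance of the targets, which CFM cannot remove but which does not affect the optimum.</p>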
<p><strong>Theorem 3 (Gaussian Conditional VFs)</strong>:</p>
<p>For any Gaussian probability path $p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$, the unique vector field that defines the affine flow $\psi_t(x) = \sigma_t(x_1) x + \mu_t(x_1)$, and therefore generates the path, is available in closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p>This theorem allows explicitly defining targets for both Diffusion (curved) and Optimal Transport (straight) paths.</p>
<blockquote>
<p><strong>Understanding the Proof:</strong></p>
<p>The derivation of Theorem 3 comes from the direct relationship between a flow map $\psi_t$ and its generating vector field. Because we chose a specific, simple path (Gaussian), we can invert the flow map to find the vector field in closed form.</p>
<p><strong>1. Define the Flow Map $\psi_t$</strong></p>
<p>We start by defining the conditional probability path as a Gaussian:</p>
<p>$$p_t(x|x_1) = \mathcal{N}(x | \mu_t(x_1), \sigma_t(x_1)^2 I)$$</p>
<p>The simplest way to &ldquo;push&rdquo; a standard normal distribution (noise) $p_0 = \mathcal{N}(0, I)$ to this Gaussian is using an affine transformation (scaling and shifting). We define the flow map $\psi_t$ as:</p>
<p>$$\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$$</p>
<p>This map takes a noise sample $x_0$ and transforms it into a sample $x$ at time $t$.</p>
<p><strong>2. The Definition of a Generating Vector Field</strong></p>
<p>By definition, a vector field $u_t$ generates a flow $\psi_t$ if the vector field describes the instantaneous velocity of the flow at any point. Mathematically:</p>
<p>$$u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$$</p>
<p>Let $x = \psi_t(x_0)$ be the position of the particle at time $t$. We want to find $u_t(x)$.</p>
<p><strong>3. Invert the Flow Map</strong></p>
<p>To find $u_t(x)$, we must express the equation in terms of $x$ rather than $x_0$. Since our flow map is a simple affine transformation (multiply and add), it is easily invertible (assuming $\sigma_t(x_1) \neq 0$):</p>
<p>$$x_0 = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)}$$</p>
<p>We will call this inverse map $\psi_t^{-1}(x)$.</p>
<p><strong>4. Differentiate the Flow Map</strong></p>
<p>Now we calculate the right side of our definition equation (the velocity): $\frac{d}{dt}\psi_t(x_0)$.</p>
<p>Taking the time derivative of $\psi_t(x_0) = \sigma_t(x_1) x_0 + \mu_t(x_1)$:</p>
<p>$$\frac{d}{dt}\psi_t(x_0) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>(Note: $\sigma'_t$ and $\mu'_t$ denote time derivatives.)</p>
<p><strong>5. Substitute and Solve</strong></p>
<p>Now we combine everything. We know $u_t(\psi_t(x_0)) = \frac{d}{dt}\psi_t(x_0)$.</p>
<p>Substitute the result from Step 4 into this equation:</p>
<p>$$u_t(\psi_t(x_0)) = \sigma'_t(x_1) x_0 + \mu'_t(x_1)$$</p>
<p>This expresses the vector field in terms of the initial point $x_0$. We must express it in terms of the current point $x$. So, we plug in the inverse formula for $x_0$ derived in Step 3:</p>
<p>$$u_t(x|x_1) = \sigma'_t(x_1) \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} + \mu'_t(x_1)$$</p>
<p>Rearranging terms gives the final closed form:</p>
<p>$$u_t(x|x_1) = \frac{\sigma'_t(x_1)}{\sigma_t(x_1)}(x - \mu_t(x_1)) + \mu'_t(x_1)$$</p>
<p><strong>Why is this useful?</strong></p>
<p>This formula means that as long as you can define a mean schedule $\mu_t(x_1)$ and a standard deviation schedule $\sigma_t(x_1)$ (which is easy to do for both Diffusion and Optimal Transport), you immediately get the exact vector field target $u_t(x|x_1)$ needed to train your neural network, bypassing complex ODE solving or score matching approximations.</p></blockquote>
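<p>The closed form is easy to sanity-check. The sketch below implements the Theorem 3 formula for arbitrary schedules and compares it with a finite-difference velocity of the affine flow. The trigonometric schedules $\mu_t = \sin(\frac{\pi t}{2}) x_1$ and $\sigma_t = \cos(\frac{\pi t}{2})$ are an illustrative assumption, not the paper's diffusion or OT schedules, and are valid for $t &lt; 1$ where $\sigma_t &gt; 0$.</p>

```python
import numpy as np


def gaussian_conditional_vf(x, mu, sigma, dmu, dsigma):
    """Theorem 3: u_t(x | x1) = (sigma'_t / sigma_t) * (x - mu_t) + mu'_t."""
    return (dsigma / sigma) * (x - mu) + dmu


# Illustrative schedules (an assumption for this sketch, not taken from the paper).
def mu(t, x1):
    return np.sin(0.5 * np.pi * t) * x1


def sigma(t):
    return np.cos(0.5 * np.pi * t)


def dmu(t, x1):
    return 0.5 * np.pi * np.cos(0.5 * np.pi * t) * x1


def dsigma(t):
    return -0.5 * np.pi * np.sin(0.5 * np.pi * t)


def flow(t, x0, x1):
    """Affine flow map psi_t(x0) = sigma_t * x0 + mu_t."""
    return sigma(t) * x0 + mu(t, x1)


x0 = np.array([0.3, -1.2])
x1 = np.array([2.0, 0.5])
t, h = 0.4, 1e-5

x = flow(t, x0, x1)                                   # current position of the particle
u = gaussian_conditional_vf(x, mu(t, x1), sigma(t), dmu(t, x1), dsigma(t))
velocity_fd = (flow(t + h, x0, x1) - flow(t - h, x0, x1)) / (2 * h)
```

<p>Swapping in a different pair of schedules only requires changing the four schedule functions; the target formula itself stays the same.</p>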
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Lipman, Y., Chen, R. T. Q., Ben-Hamu, H., Nickel, M., &amp; Le, M. (2023). Flow Matching for Generative Modeling. <em>International Conference on Learning Representations (ICLR)</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{lipmanFlowMatchingGenerative2023,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Flow Matching for Generative Modeling}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#e6db74">{2023}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/2210.02747">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Building Normalizing Flows with Stochastic Interpolants</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</link><pubDate>Sun, 21 Dec 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/stochastic-interpolants/</guid><description>A continuous-time normalizing flow using stochastic interpolants and quadratic loss to bypass costly ODE backpropagation.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is primarily a <strong>Method</strong> paper, with significant <strong>Theory</strong> contributions.</p>
<p>The authors propose a specific algorithm (&ldquo;InterFlow&rdquo;) for constructing generative models based on continuous-time normalizing flows. The work is characterized by the derivation of a new training objective (a simple quadratic loss) that bypasses the computational bottlenecks of previous methods. It includes prominent baseline comparisons against continuous flow methods (FFJORD, OT-Flow) and diffusion models. The theoretical component establishes the validity of the interpolant density satisfying the continuity equation (a conservation law governing how probability mass flows) and bounds the Wasserstein-2 distance (a measure of transport cost between distributions, penalizing squared displacement) of the transport.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The primary motivation is to overcome the computational inefficiency of training Continuous Normalizing Flows (CNFs) using Maximum Likelihood Estimation (MLE). Standard CNF training requires backpropagating through numerical ODE solvers, which is costly and limits scalability.</p>
<p>Additionally, while score-based diffusion models (SDEs) have achieved high sample quality, they theoretically require infinite time integration and rely on specific noise schedules. The authors aim to establish a method that works strictly with Probability Flow ODEs on finite time intervals, retaining the flexibility to connect arbitrary densities without the complexity of SDEs or the cost of standard ODE adjoint methods.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>Stochastic Interpolant</strong> framework:</p>
<ul>
<li><strong>Explicit Interpolant Construction</strong>: The method defines a time-dependent interpolant $x_t = I_t(x_0, x_1)$ (e.g., trigonometric interpolation) that connects samples from the base density $\rho_0$ and target $\rho_1$.</li>
<li><strong>Simulation-Free Training</strong>: The velocity field $v_t(x)$ of the probability flow is learned by minimizing a simple quadratic objective: $G(\hat{v}) = \mathbb{E}[|\hat{v}_t(x_t)|^2 - 2\partial_t x_t \cdot \hat{v}_t(x_t)]$. Because $\partial_t I_t$ is known analytically from the interpolant definition, the expectation can be estimated by sampling $(x_0, x_1, t)$ directly. This avoids ODE integration during training (ODE integration is still required at inference).</li>
<li><strong>Decoupling Path and Optimization</strong>: The choice of path (interpolant) is separated from the optimization of the velocity field. MLE-based methods, by contrast, couple the path and the objective.</li>
<li><strong>Connection to Score-Based Models</strong>: The authors show that for Gaussian base densities and trigonometric interpolants, the learned velocity field is explicitly related to the score function $\nabla \log \rho_t$, providing a theoretical bridge between CNFs and diffusion models.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors performed validation across synthetic, tabular, and image domains:</p>
<ul>
<li><strong>2D Density Estimation</strong>: Benchmarked on &ldquo;Checkerboard&rdquo;, &ldquo;8 Gaussians&rdquo;, and anisotropic curved densities to visualize mode coverage and transport smoothness.</li>
<li><strong>High-Dimensional Tabular Data</strong>: Evaluated on standard benchmarks (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) comparing Negative Log Likelihood (NLL) against FFJORD, OT-Flow, and others.</li>
<li><strong>Image Generation</strong>: Trained models on CIFAR-10 ($32 \times 32$), ImageNet ($32 \times 32$), and Oxford Flowers ($128 \times 128$) to test scalability.</li>
<li><strong>Ablations</strong>: Investigated optimizing the interpolant path itself (e.g., learning Fourier coefficients for the path) to approach optimal transport and minimize path length.</li>
</ul>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li><strong>Performance</strong>: The method matches or surpasses conventional ODE flows (like FFJORD) in terms of NLL while being significantly cheaper to train.</li>
<li><strong>Efficiency</strong>: The training cost per epoch is constant (simulation-free), whereas MLE-based ODE methods see growing costs as the dynamics become more complex.</li>
<li><strong>Scalability</strong>: The method successfully scales to $128 \times 128$ resolution on a single GPU, a resolution that prior ab-initio ODE flows had not demonstrated.</li>
<li><strong>Flexibility</strong>: The framework can connect <em>any</em> two densities (e.g., two different complex 2D distributions) without needing one to be Gaussian.</li>
<li><strong>Optimal Transport</strong>: For a fixed interpolant, minimizing $G(\hat{v})$ over the velocity field recovers the velocity for that specific path. Additionally optimizing over the interpolant family yields a solution to the Benamou-Brenier optimal transport problem.</li>
<li><strong>Limitations</strong>: The authors acknowledge that image FID scores trail dedicated diffusion models, noting that InterFlow was not optimized with standard training tricks such as exponential moving averages, truncation, or learning rate warm-ups. The framework&rsquo;s sample quality could likely improve with these additions.</li>
</ul>
<hr>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>Tabular Datasets</strong>: POWER (6D), GAS (8D), HEPMASS (21D), MINIBOONE (43D), BSDS300 (63D).
<ul>
<li>Training points range from ~30k (MINIBOONE) to ~1.6M (POWER).</li>
</ul>
</li>
<li><strong>Image Datasets</strong>:
<ul>
<li>CIFAR-10 ($32 \times 32$, 50k training points).</li>
<li>ImageNet ($32 \times 32$, ~1.28M training points).</li>
<li>Oxford Flowers ($128 \times 128$, ~315k training points).</li>
</ul>
</li>
<li><strong>Time Sampling</strong>: Time $t$ is sampled from a Beta distribution during training (reweighting) to focus learning near the target.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Interpolant</strong>: The primary interpolant used is trigonometric: $I_t(x_0, x_1) = \cos(\frac{\pi t}{2})x_0 + \sin(\frac{\pi t}{2})x_1$.
<ul>
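<li>General linear form: $I_t = a_t x_0 + b_t x_1$ with $a_0 = b_1 = 1$ and $a_1 = b_0 = 0$, so that $I_0 = x_0$ and $I_1 = x_1$. The trigonometric interpolant is the special case $a_t = \cos(\frac{\pi t}{2})$, $b_t = \sin(\frac{\pi t}{2})$.</li>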
</ul>
</li>
<li><strong>Loss Function</strong>:
$$G(\hat{v}) = \mathbb{E}_{t, x_0, x_1}[|\hat{v}_t(x_t)|^2 - 2\partial_t I_t(x_0, x_1) \cdot \hat{v}_t(x_t)]$$
<ul>
<li>The expectation is amenable to empirical estimation using batches of $x_0, x_1, t$.</li>
</ul>
</li>
<li><strong>Sampling</strong>: Numerical integration using Dormand-Prince (Runge-Kutta 4/5).</li>
<li><strong>Optimization</strong>: SGD/Adam variants.</li>
</ul>
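<p>The interpolant and loss above fit in a short sketch. This is an illustration under stated assumptions rather than the authors' implementation (no official code was released): <code>v_hat</code> is a placeholder for the neural network, the target samples are synthetic, and $t$ is drawn uniformly for simplicity instead of from the Beta distribution noted earlier.</p>

```python
import numpy as np


def interpolant(x0, x1, t):
    """Trigonometric interpolant I_t(x0, x1)."""
    return np.cos(0.5 * np.pi * t) * x0 + np.sin(0.5 * np.pi * t) * x1


def interpolant_dt(x0, x1, t):
    """Analytic time derivative of I_t."""
    return 0.5 * np.pi * (-np.sin(0.5 * np.pi * t) * x0 + np.cos(0.5 * np.pi * t) * x1)


def interflow_loss(v_hat, x0, x1, t):
    """Batch estimate of G(v_hat) = E[|v_hat|^2 - 2 dI/dt . v_hat]."""
    xt = interpolant(x0, x1, t)
    v = v_hat(xt, t)
    d_interp = interpolant_dt(x0, x1, t)
    return np.mean(np.sum(v ** 2, axis=1) - 2.0 * np.sum(d_interp * v, axis=1))


rng = np.random.default_rng(0)
x0 = rng.standard_normal((256, 2))                 # base samples
x1 = rng.standard_normal((256, 2)) * 0.5 + 2.0     # stand-in target samples
t = rng.uniform(0.0, 1.0, size=(256, 1))           # uniform t for simplicity
v_hat = lambda x, t: 1.5 - x                       # placeholder for a neural network
loss = interflow_loss(v_hat, x0, x1, t)
```

<p>Completing the square shows that $G(\hat{v})$ equals the mean squared error between $\hat{v}_t(x_t)$ and $\partial_t I_t$ minus a term that does not depend on $\hat{v}$, so minimizing $G$ is an ordinary regression problem.</p>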
<h3 id="models">Models</h3>
<ul>
<li><strong>Tabular Architectures</strong>:
<ul>
<li>Feed-forward networks with 4-5 hidden layers.</li>
<li>Hidden widths: 512 (POWER, GAS, HEPMASS, MINIBOONE) or 1024 (BSDS300).</li>
<li>Activation: ReLU (general) or ELU (BSDS300).</li>
</ul>
</li>
<li><strong>Image Architectures</strong>:
<ul>
<li>U-Net based on the DDPM implementation.</li>
<li>Hidden dimension: 256.</li>
<li>Sinusoidal time embeddings used.</li>
</ul>
</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<ul>
<li><strong>Metrics</strong>: Negative Log Likelihood (NLL) in nats (tabular) or bits per dim (images), Frechet Inception Distance (FID) for images.</li>
<li><strong>Baselines</strong>: FFJORD, Glow, Real NVP, OT-Flow, ScoreFlow, DDPM.</li>
</ul>
<p><strong>Tabular NLL</strong> (nats, lower is better; Table 2 Left):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>POWER</th>
          <th>GAS</th>
          <th>HEPMASS</th>
          <th>MINIBOONE</th>
          <th>BSDS300</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>MADE</td>
          <td>3.08</td>
          <td>-3.56</td>
          <td>20.98</td>
          <td>15.59</td>
          <td>-148.85</td>
      </tr>
      <tr>
          <td>Real NVP</td>
          <td>-0.17</td>
          <td>-8.33</td>
          <td>18.71</td>
          <td>13.55</td>
          <td>-153.28</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>-0.17</td>
          <td>-8.15</td>
          <td>18.92</td>
          <td>11.35</td>
          <td>-155.07</td>
      </tr>
      <tr>
          <td>CPF</td>
          <td>-0.52</td>
          <td>-10.36</td>
          <td>16.93</td>
          <td>10.58</td>
          <td>-154.99</td>
      </tr>
      <tr>
          <td>NSP</td>
          <td>-0.64</td>
          <td>-13.09</td>
          <td>14.75</td>
          <td>9.67</td>
          <td>-157.54</td>
      </tr>
      <tr>
          <td>FFJORD</td>
          <td>-0.46</td>
          <td>-8.59</td>
          <td>14.92</td>
          <td>10.43</td>
          <td>-157.40</td>
      </tr>
      <tr>
          <td>OT-Flow</td>
          <td>-0.30</td>
          <td>-9.20</td>
          <td>17.32</td>
          <td>10.55</td>
          <td>-154.20</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>-0.57</strong></td>
          <td><strong>-12.35</strong></td>
          <td><strong>14.85</strong></td>
          <td><strong>10.42</strong></td>
          <td><strong>-156.22</strong></td>
      </tr>
  </tbody>
</table>
<p><strong>Image Generation NLL and FID</strong> (Table 2 Right; NLL in bits per dim, lower is better):</p>
<table>
  <thead>
      <tr>
          <th>Method</th>
          <th>CIFAR-10 NLL</th>
          <th>CIFAR-10 FID</th>
          <th>ImageNet-32 NLL</th>
          <th>ImageNet-32 FID</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>FFJORD</td>
          <td>3.40</td>
          <td>-</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Glow</td>
          <td>3.35</td>
          <td>-</td>
          <td>4.09</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM</td>
          <td>≤3.75</td>
          <td>3.17</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>DDPM++ (Song et al., 2021)</td>
          <td>≤3.37</td>
          <td>2.90</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>ScoreSDE (Song et al., 2021)</td>
          <td>2.99</td>
          <td>2.92</td>
          <td>-</td>
          <td>-</td>
      </tr>
      <tr>
          <td>VDM</td>
          <td>≤2.65</td>
          <td>7.41</td>
          <td>≤3.72</td>
          <td>-</td>
      </tr>
      <tr>
          <td>Soft Truncation</td>
          <td>2.88</td>
          <td>3.45</td>
          <td>3.85</td>
          <td>8.42</td>
      </tr>
      <tr>
          <td>ScoreFlow</td>
          <td>2.81</td>
          <td>5.40</td>
          <td>3.76</td>
          <td>10.18</td>
      </tr>
      <tr>
          <td><strong>Ours</strong></td>
          <td><strong>2.99</strong></td>
          <td><strong>10.27</strong></td>
          <td><strong>3.48</strong></td>
          <td><strong>8.49</strong></td>
      </tr>
  </tbody>
</table>
<p>Note: DDPM++ is from Song et al. (2021), the same work as ScoreSDE (it is the architecture optimized for VP/sub-VP SDEs). InterFlow matches ScoreSDE on CIFAR-10 NLL (2.99 bits per dim) while being simulation-free. FID is weaker than dedicated image models (10.27 vs 2.92 for ScoreSDE), reflecting the paper&rsquo;s primary focus on tractable likelihood rather than sample quality.</p>
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Compute</strong>: All models were trained on a single NVIDIA A100 GPU.</li>
<li><strong>Training Time</strong>:
<ul>
<li>Tabular: $10^5$ steps.</li>
<li>Images: $1.5 \times 10^5$ to $6 \times 10^5$ steps.</li>
<li>Speedup: Demonstrated ~400x speedup compared to FFJORD on the MINIBOONE dataset.</li>
</ul>
</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>lucidrains/denoising-diffusion-pytorch (link defunct)</td>
          <td>Code</td>
          <td>MIT</td>
          <td>Base U-Net architecture used for image experiments; original GitHub account no longer available</td>
      </tr>
  </tbody>
</table>
<p>No official code release accompanies this paper. All tabular datasets (POWER, GAS, HEPMASS, MINIBOONE, BSDS300) are publicly available from prior work. CIFAR-10 and ImageNet are standard public benchmarks. Oxford Flowers 102 is also publicly available. Hyperparameters and architectures are fully specified in Tables 3 and 4 of the paper.</p>
<hr>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Albergo, M. S., &amp; Vanden-Eijnden, E. (2023). Building Normalizing Flows with Stochastic Interpolants. <em>The Eleventh International Conference on Learning Representations</em>.</p>
<p><strong>Publication</strong>: ICLR 2023</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{albergoBuildingNormalizingFlows2022,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span> = <span style="color:#e6db74">{Building {{Normalizing Flows}} with {{Stochastic Interpolants}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span> = <span style="color:#e6db74">{The {{Eleventh International Conference}} on {{Learning Representations}}}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span> = <span style="color:#e6db74">{Albergo, Michael Samuel and {Vanden-Eijnden}, Eric}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span> = <span style="color:#ae81ff">2023</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span> = <span style="color:#e6db74">{https://openreview.net/forum?id=li7qeBbCR1t}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=li7qeBbCR1t">OpenReview</a></li>
<li><a href="https://arxiv.org/abs/2209.15571">arXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders: Beyond the Standard VAE</title><link>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/importance-weighted-autoencoders/</guid><description>The key difference between multi-sample VAEs and IWAEs: how log-of-averages creates a tighter bound on log-likelihood.</description><content:encoded><![CDATA[<p>If you&rsquo;ve worked with Variational Autoencoders (VAEs), you&rsquo;ve almost certainly used the standard $\mathcal{L}_1$ objective, or ELBO. It&rsquo;s trained by taking <em>one</em> sample ($k=1$) from the recognition network to calculate the loss.</p>
<p>A natural question follows: &ldquo;What if I use more samples? Won&rsquo;t that make it better?&rdquo;</p>
<p>Using more samples improves performance when paired with the correct objective function. Averaging the loss over $k$ samples yields minimal gains. Changing the objective itself is where the real gain comes from. This post explores the difference between a &ldquo;multi-sample VAE&rdquo; and the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a model that uses the <em>same architecture</em> as a VAE but is trained with a different objective that optimizes a tighter bound on the log-likelihood.</p>
<p>All ideas here are based on the fantastic paper: <a href="https://arxiv.org/abs/1509.00519">&ldquo;Importance Weighted Autoencoders&rdquo;</a> by Burda, Grosse, and Salakhutdinov.</p>
<h2 id="the-two-ways-to-use-k-samples">The Two Ways to Use $k$ Samples</h2>
<p>Let&rsquo;s say we have our encoder $q(h|x)$ and decoder $p(x,h)$. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.</p>
<h3 id="option-1-the-multi-sample-vae-the-naive-way">Option 1: The &ldquo;Multi-Sample VAE&rdquo; (The Naive Way)</h3>
<p>This is the most straightforward idea. For each input $x$ in our batch:</p>
<ol>
<li>Draw 5 samples ($h_1, \ldots, h_5$) from $q(h|x)$.</li>
<li>Calculate the standard VAE $\mathcal{L}_1$ loss for <em>each</em> sample.</li>
<li>Average these 5 losses together.</li>
</ol>
<p>This is an <strong>average of logs</strong>. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is &ldquo;only slightly&rdquo; better. You&rsquo;re paying a 5x computational cost for a marginal gain because you&rsquo;re still optimizing the same &ldquo;loose&rdquo; $\mathcal{L}_1$ bound.</p>
<h3 id="option-2-the-importance-weighted-autoencoder-iwae-the-right-way">Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)</h3>
<p>The IWAE takes a different approach. For each input $x$:</p>
<ol>
<li>Draw 5 samples ($h_1, \ldots, h_5$) from $q(h|x)$.</li>
<li>Calculate an &ldquo;importance weight&rdquo; $w_i$ for each sample.</li>
<li>Average these 5 <em>weights</em> together.</li>
<li>Take the <em>logarithm</em> of that average.</li>
</ol>
<p>This is a <strong>log of an average</strong>, and the difference matters.</p>
<h2 id="the-math-average-of-logs-vs-log-of-averages">The Math: Average-of-Logs vs. Log-of-Averages</h2>
<p>Let&rsquo;s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:</p>
<p>$$
\mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right]
$$</p>
<p>A <strong>multi-sample VAE</strong> simply gets a better estimate of this same value:</p>
<p>$$
\mathcal{L} _{\text{VAE}, k}(x) \approx  \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)}
$$</p>
<p>The <strong>IWAE</strong> objective, $\mathcal{L}_k$, is fundamentally different:</p>
<p>$$
\mathcal{L} _k (x) = \mathbb{E} _{h_1..h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right]
$$</p>
<p>In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):</p>
<p>$$
\mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right)
$$</p>
<p>Because the logarithm is a concave function, Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is <em>always</em> greater than or equal to the &ldquo;average of logs.&rdquo;</p>
<p>$$
\mathcal{L}_k(x) \ge \mathcal{L}_1(x)
$$</p>
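<p>Since $\mathcal{L}_k(x) \le \log p(x)$ still holds, the IWAE is optimizing a lower bound on the true log-likelihood that is <strong>at least as tight</strong> as the standard VAE bound, and the paper shows it keeps tightening as $k$ grows.</p>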
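<p>A quick numerical check makes the inequality tangible. The sketch below uses five made-up importance weights (illustrative values, not taken from the paper) and plain Python in place of PyTorch:</p>

```python
import math

# Hypothetical importance weights w_i = p(x, h_i) / q(h_i | x) for k = 5 samples.
weights = [0.02, 0.5, 1.3, 0.004, 0.9]

# Multi-sample VAE: take the log of each weight, then average.
average_of_logs = sum(math.log(w) for w in weights) / len(weights)

# IWAE: average the weights, then take the log.
log_of_average = math.log(sum(weights) / len(weights))

print(f"average of logs: {average_of_logs:.4f}")
print(f"log of average:  {log_of_average:.4f}")
```

<p>The two very small weights drag the average of logs down sharply, while they barely move the log of the average. That is the same effect the rest of this post attributes to the two objectives.</p>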
<h2 id="why-does-this-log-of-average-matter">Why Does This &ldquo;Log-of-Average&rdquo; Matter?</h2>
<p>This mathematical property provides two practical benefits.</p>
<h3 id="1-better-density-estimation">1. Better Density Estimation</h3>
<p>Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model toward a better generative distribution. The paper shows that IWAEs achieve &ldquo;significantly higher log-likelihoods&rdquo; than VAEs.</p>
<h3 id="2-richer-latent-representations">2. Richer Latent Representations</h3>
<p>This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective &ldquo;harshly penalizes&rdquo; the model if its <em>one</em> sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be &ldquo;overly simplified&rdquo; to avoid bad samples, which can lead to many latent dimensions becoming inactive (the paper reports the number of &ldquo;active units&rdquo; per model).</p>
<p>The IWAE objective is more flexible. It only needs <em>one</em> of the $k$ samples to be good. This &ldquo;increased flexibility&rdquo; allows the model to learn far more complex posterior distributions and &ldquo;richer latent space representations.&rdquo; The paper&rsquo;s experiments confirm this, showing IWAEs learn to use many more &ldquo;active units&rdquo; in their latent space.</p>
<h2 id="what-this-looks-like-in-code-pytorch">What This Looks Like in Code (PyTorch)</h2>
<p>The implementation difference makes this crystal clear.</p>
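<p>First, the &ldquo;k-sample&rdquo; trick: for a batch <code>x</code> of shape <code>[B, D]</code> and <code>k=5</code> samples, we repeat each row of <code>x</code> <code>k</code> times to get <code>x_repeated</code> of shape <code>[B*k, D]</code>, and we do all our forward passes on this large tensor. The copies of each input must sit next to each other (for example via <code>x.repeat_interleave(k, dim=0)</code>), because the IWAE code below regroups them with <code>view(B, k)</code>.</p>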
<h3 id="vae-multi-sample-k--1-loss">VAE (Multi-Sample, k &gt; 1) Loss</h3>
<p>Here, we can still use the analytical KL divergence, which is a big simplification.</p>
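<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated has shape [B*k, 784]</span>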
</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu, logvar have shape [B*k, latent_dim]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_x has shape [B*k, 784]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># recon_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>recon_loss_all <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># kl_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># We use the simple, analytical KL term!</span>
</span></span><span style="display:flex;"><span>kl_loss_all <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> logvar <span style="color:#f92672">-</span> mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> logvar<span style="color:#f92672">.</span>exp(), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># total_loss_all shape: [B*k]</span>
</span></span><span style="display:flex;"><span>total_loss_all <span style="color:#f92672">=</span> recon_loss_all <span style="color:#f92672">+</span> kl_loss_all
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Just average all B*k losses. This is the &#34;average of logs&#34;.</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> total_loss_all<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="iwae-k--1-loss">IWAE (k &gt; 1) Loss</h3>
<p>Here, we must compute the exact log-probabilities of the <em>specific samples</em> we drew.</p>
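<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> math
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Helper function to compute log-prob of a sample from a Gaussian</span>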
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">log_prob_gaussian</span>(sample, mu, logvar):
</span></span><span style="display:flex;"><span>    const <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sample<span style="color:#f92672">.</span>shape[<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>log(<span style="color:#ae81ff">2</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>tensor(math<span style="color:#f92672">.</span>pi))
</span></span><span style="display:flex;"><span>    log_det <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(logvar, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    log_exp <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum((sample <span style="color:#f92672">-</span> mu)<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">/</span> torch<span style="color:#f92672">.</span>exp(logvar), dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> const <span style="color:#f92672">+</span> log_det <span style="color:#f92672">+</span> log_exp
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Get the 3 log-prob components ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># all have a first dimension of [B*k]</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 1. log p(x|h_i): Log-Reconstruction Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_x_given_h shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_p_x_given_h <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>F<span style="color:#f92672">.</span>binary_cross_entropy(recon_x, x_repeated, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 2. log p(h_i): Log-Prior Probability (under N(0, I))</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_p_h shape: [B*k]</span>
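</span></span><span style="display:flex;"><span><span style="color:#75715e"># mu and logvar must be tensors here: torch.sum / torch.exp reject Python floats</span>
</span></span><span style="display:flex;"><span>zeros <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>zeros_like(z_samples)
</span></span><span style="display:flex;"><span>log_p_h <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, zeros, zeros) <span style="color:#75715e"># mu=0, logvar=0</span>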
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># 3. log q(h_i|x): Log-Encoder Probability</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_q_h_given_x shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_q_h_given_x <span style="color:#f92672">=</span> log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- The Key Step ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Combine to get the log-importance-weight</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log_w shape: [B*k]</span>
</span></span><span style="display:flex;"><span>log_w <span style="color:#f92672">=</span> log_p_x_given_h <span style="color:#f92672">+</span> log_p_h <span style="color:#f92672">-</span> log_q_h_given_x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># Reshape to [B, k] to group samples by their original input</span>
</span></span><span style="display:flex;"><span>log_w_matrix <span style="color:#f92672">=</span> log_w<span style="color:#f92672">.</span>view(B, k) <span style="color:#75715e"># B is original batch size</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># This is the &#34;log of the average&#34;</span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)</span>
</span></span><span style="display:flex;"><span>log_iwae_bound_per_x <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>logsumexp(log_w_matrix, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>) <span style="color:#f92672">-</span> math<span style="color:#f92672">.</span>log(k)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># The objective is to MAXIMIZE this bound, so the loss is its negative</span>
</span></span><span style="display:flex;"><span>loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span>log_iwae_bound_per_x<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="the-critical-implementation-detail">The Critical Implementation Detail</h3>
<p>Notice the key difference in the final step:</p>
<ul>
<li><strong>VAE</strong>: <code>loss = total_loss_all.mean()</code> (average of individual losses)</li>
<li><strong>IWAE</strong>: <code>loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean()</code> (log of averaged weights)</li>
</ul>
<p>This seemingly small change implements the fundamental mathematical difference between optimizing an &ldquo;average of logs&rdquo; versus a &ldquo;log of averages.&rdquo;</p>
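<p>The log-sum-exp form is not just cosmetic. Log-weights for image data are often hundreds of nats below zero, so exponentiating them directly underflows. The following is a minimal plain-Python sketch of the max-shift idea behind a stable log-mean-exp; the helper name <code>log_mean_exp</code> and the log-weight values are hypothetical, chosen only for illustration:</p>

```python
import math

def log_mean_exp(log_ws):
    """Stable log((1/k) * sum_i exp(log_w_i)) using a max shift."""
    m = max(log_ws)
    shifted_sum = sum(math.exp(lw - m) for lw in log_ws)
    return m + math.log(shifted_sum) - math.log(len(log_ws))

# Hypothetical log-weights on a scale where exp() underflows.
log_ws = [-1000.0, -1001.0, -1002.0]

# Naive route: every exp(...) underflows to 0.0, so log(mean) cannot be computed.
naive_sum = sum(math.exp(lw) for lw in log_ws)

print("naive sum of weights:", naive_sum)
print("stable log-mean-exp: ", log_mean_exp(log_ws))
```

<p>Subtracting the maximum before exponentiating keeps the largest term at exactly 1, so the sum is always representable. In the PyTorch code above, <code>torch.logsumexp</code> provides a numerically stabilized version of this computation.</p>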
<h2 id="when-to-use-each-approach">When to Use Each Approach</h2>
<table>
  <thead>
      <tr>
          <th>Model</th>
          <th>When to Use</th>
          <th>Key Benefit</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>VAE ($k=1$)</strong></td>
          <td>Your <strong>default baseline</strong>. It&rsquo;s fast, simple, and often &ldquo;good enough&rdquo; for many tasks.</td>
          <td>Speed and simplicity.</td>
      </tr>
      <tr>
          <td><strong>Multi-Sample VAE ($k&gt;1$)</strong></td>
          <td>When you want slightly more stable gradients but aren&rsquo;t ready for the full IWAE complexity.</td>
          <td>Marginal improvement with minimal code changes.</td>
      </tr>
      <tr>
          <td><strong>IWAE ($k&gt;1$)</strong></td>
          <td>When your baseline VAE is <strong>insufficient</strong>. Specifically, if you need:<br>1. The best possible log-likelihood.<br>2. To activate more latent dimensions or learn richer representations.</td>
          <td>Better performance and richer latents, at the cost of compute (scales linearly with $k$).</td>
      </tr>
  </tbody>
</table>
<h2 id="the-computational-trade-off">The Computational Trade-off</h2>
<p>Both approaches scale linearly with $k$. If you use $k=5$ samples, you&rsquo;re doing roughly 5x the computation. The question is whether you get 5x the benefit.</p>
<p>For multi-sample VAEs, the answer is usually &ldquo;no&rdquo;. You get more stable gradients but only marginal performance improvements.</p>
<p>For IWAEs, the answer is more often &ldquo;yes&rdquo;. In the paper&rsquo;s experiments, log-likelihoods and the number of active latent units both improve as $k$ grows, which can be worth the computational cost.</p>
<h2 id="conclusion">Conclusion</h2>
<p>The next time you use more samples with your VAE, switch to the IWAE objective to get the full benefit of the computational cost of $k &gt; 1$.</p>
<p>The mathematical insight is simple but powerful: Jensen&rsquo;s Inequality tells us that the &ldquo;log of an average&rdquo; is always greater than or equal to the &ldquo;average of logs.&rdquo; By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.</p>
<p>The implementation requires evaluating the log-probabilities $\log p(x, h_i)$ and $\log q(h_i|x)$ at the specific samples drawn, instead of relying on the analytical KL term. The result is a more powerful model using the exact same architecture.</p>
<p><strong>Want to dive deeper?</strong> Check out the <a href="https://arxiv.org/abs/1509.00519">original IWAE paper</a> for experimental results and theoretical analysis, or explore my <a href="/posts/modern-variational-autoencoder-in-pytorch/">VAE tutorial</a> for hands-on implementation details.</p>
]]></content:encoded></item><item><title>Importance Weighted Autoencoders (IWAE) for Tighter Bounds</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/importance-weighted-autoencoders/</guid><description>Summary of Burda, Grosse &amp; Salakhutdinov's ICLR 2016 paper introducing Importance Weighted Autoencoders for tighter variational bounds</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces the <strong>Importance Weighted Autoencoder (IWAE)</strong>, a generative model that shares the same architecture as the Variational Autoencoder (VAE) but uses a different, tighter objective function. The key innovation is using importance weighting to derive a strictly tighter log-likelihood lower bound than the standard VAE objective.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The standard VAE has several limitations that motivated this work:</p>
<ul>
<li><strong>Strong assumptions</strong>: VAEs typically assume the posterior distribution is simple (e.g., approximately factorial) and that its parameters can be easily approximated from observations.</li>
<li><strong>Simplified representations</strong>: The VAE objective can force models to learn overly simplified representations that underutilize the network&rsquo;s full modeling capacity.</li>
<li><strong>Harsh penalization</strong>: The VAE objective harshly penalizes approximate posterior samples that are poor explanations for the data, which can be overly restrictive.</li>
<li><strong>Inactive units</strong>: VAEs tend to learn latent spaces with effective dimensions far below their capacity, where many latent units are ignored (a phenomenon later termed <strong>posterior collapse</strong>, where the approximate posterior collapses to the prior and conveys no information). The authors wanted to investigate whether a new objective could address this issue.</li>
</ul>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The core novelty is the <strong>IWAE objective function</strong>, denoted as $\mathcal{L}_{k}$.</p>
<ul>
<li>
<p><strong>VAE ($\mathcal{L}_{1}$ Bound)</strong>: The standard VAE maximizes $\mathcal{L}(x)=\mathbb{E}_{q(h|x)}[\log\frac{p(x,h)}{q(h|x)}]$. This is equivalent to the new bound when $k=1$.</p>
</li>
<li>
<p><strong>IWAE ($\mathcal{L}_{k}$ Bound)</strong>: The IWAE maximizes a tighter bound that uses $k$ samples drawn from the recognition model $q(h|x)$:</p>
</li>
</ul>
<p>$$\mathcal{L}_{k}(x)=\mathbb{E}_{h_{1},&hellip;,h_{k}\sim q(h|x)}\left[\log\frac{1}{k}\sum_{i=1}^{k}\frac{p(x,h_{i})}{q(h_{i}|x)}\right]$$</p>
<ul>
<li>
<p><strong>Tighter Bound</strong>: The authors prove that the bound never loosens as $k$ grows ($\log p(x) \geq \mathcal{L}_{k+1} \geq \mathcal{L}_{k}$), so it is always at least as tight as the VAE bound $\mathcal{L}_{1}$. If the weights $p(x,h)/q(h|x)$ are bounded, $\mathcal{L}_{k}$ approaches the true log-likelihood $\log p(x)$ as $k$ approaches infinity.</p>
</li>
<li>
<p><strong>Increased Flexibility</strong>: Using multiple samples gives the IWAE additional flexibility to learn generative models whose posterior distributions are complex and violate the VAE&rsquo;s simplifying assumptions.</p>
</li>
</ul>
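<p>The tightening of $\mathcal{L}_{k}$ can be checked on a toy model where $\log p(x)$ is known in closed form. The sketch below is an illustration under stated assumptions, not the paper&rsquo;s setup: it assumes a prior $h \sim \mathcal{N}(0, 1)$, a likelihood $x|h \sim \mathcal{N}(h, 1)$ (so $p(x) = \mathcal{N}(0, 2)$), and a deliberately too-wide recognition model $q(h|x) = \mathcal{N}(x/2, 1)$. All function names are hypothetical.</p>

```python
import math
import random

def log_normal(value, mean, var):
    """Log-density of a univariate Gaussian N(mean, var) at `value`."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)

def log_weight(x, h):
    """log w = log p(h) + log p(x|h) - log q(h|x) for the toy model."""
    log_joint = log_normal(h, 0.0, 1.0) + log_normal(x, h, 1.0)
    return log_joint - log_normal(h, x / 2, 1.0)

def iwae_bound(x, k, trials, rng):
    """Monte Carlo estimate of L_k: mean over trials of log((1/k) sum_i w_i)."""
    total = 0.0
    for _ in range(trials):
        # Draw k latents from q(h|x) = N(x/2, 1) and score each one.
        log_ws = [log_weight(x, rng.gauss(x / 2, 1.0)) for _ in range(k)]
        m = max(log_ws)  # max shift for a stable log-mean-exp
        total += m + math.log(sum(math.exp(lw - m) for lw in log_ws) / k)
    return total / trials

rng = random.Random(0)
x = 1.0
true_log_px = log_normal(x, 0.0, 2.0)  # marginal of the toy model is N(0, 2)
bounds = {k: iwae_bound(x, k, trials=10000, rng=rng) for k in (1, 5, 50)}

print(f"log p(x) = {true_log_px:.4f}")
for k, value in bounds.items():
    print(f"L_{k} ~ {value:.4f}")
```

<p>For this toy model the $k=1$ gap equals the KL divergence between $q(h|x)$ and the true posterior $\mathcal{N}(x/2, 1/2)$, about 0.15 nats, and the estimates for $k=5$ and $k=50$ should sit progressively closer to $\log p(x)$, up to Monte Carlo noise.</p>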
<h3 id="key-concept-averaging-inside-vs-outside-the-log">Key Concept: Averaging Inside vs. Outside the Log</h3>
<p>A crucial distinction exists between how VAE and IWAE utilize $k$ samples. Understanding this difference explains why increasing $k$ improves the bound in IWAE, whereas in a VAE it only reduces the variance of the estimate.</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-vae-vs-importance-weighted-autoencoder-iwae.webp"
         alt="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         title="Flowchart comparing VAE and IWAE computation: VAE takes the log of each weight then averages (average of logs). IWAE averages the weights first then takes the log (log of average)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">VAE vs IWAE computation flow. The key difference is where the log operation occurs: VAE computes log(w_i) for each sample then averages. IWAE averages the weights first then applies log to the result.</figcaption>
    
</figure>

<p><strong>VAE (Average of Logs):</strong></p>
<p>For a VAE, averaging over $k$ samples gives an estimator whose expectation is still the ELBO:</p>
<p>$$\mathbb{E}\left[ \frac{1}{k} \sum_{i=1}^k \log w_i \right] = \mathcal{L}_1 = \text{ELBO}$$</p>
<p>where $w_i = p(x, h_i) / q(h_i | x)$. Increasing $k$ here only reduces the variance of the gradient estimator; the model still targets the same ELBO bound, so performance gains saturate quickly.</p>
<p><strong>IWAE (Log of Average):</strong></p>
<p>IWAE performs the averaging <em>inside</em> the logarithm:</p>
<p>$$\mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] = \mathcal{L}_k$$</p>
<p>By Jensen&rsquo;s Inequality ($\log(\mathbb{E}[X]) \geq \mathbb{E}[\log(X)]$, because $\log$ is concave), this bound is mathematically guaranteed to be at least as tight as the VAE bound. Each increase in $k$ defines a new lower bound on the log-likelihood that is at least as tight as the previous one.</p>
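<p>The same inequality also shows why $\mathcal{L}_k$ never exceeds the log-likelihood. As a short sketch using only the definitions above, and assuming $q(h|x) &gt; 0$ wherever $p(x,h) &gt; 0$ so that each weight is an unbiased estimate of $p(x)$:</p>
<p>$$\mathcal{L}_k = \mathbb{E}\left[ \log \left( \frac{1}{k} \sum_{i=1}^k w_i \right) \right] \leq \log \mathbb{E}\left[ \frac{1}{k} \sum_{i=1}^k w_i \right] = \log \mathbb{E}_{h \sim q(h|x)}\left[ \frac{p(x,h)}{q(h|x)} \right] = \log \int p(x,h)\, dh = \log p(x)$$</p>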
<p><strong>Why This Matters for Gradients:</strong></p>
<p>In IWAE, the gradient weights are normalized importance weights $\tilde{w}_i = w_i / \sum_j w_j$. This means &ldquo;bad&rdquo; samples (those with low $w_i$) contribute very little to the gradient update since they vanish from the weighted sum. VAE uses unweighted samples, so a single sample with extremely low probability produces a massive negative log value that can dominate the loss and harshly penalize the model. IWAE&rsquo;s formulation allows the model to focus learning on the samples that explain the data well.</p>
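<p>The role of the normalized weights can be verified numerically. The sketch below (plain Python, hypothetical log-weight values and helper names) checks by finite differences that the derivative of $\log \frac{1}{k}\sum_i w_i$ with respect to each $\log w_i$ equals $\tilde{w}_i$, and contrasts it with the uniform $1/k$ weighting of the average-of-logs objective. It covers only this outer factor; the full gradient also multiplies each term by $\nabla \log w_i$.</p>

```python
import math

def log_mean_exp(log_ws):
    """Stable log((1/k) * sum_i exp(log_w_i))."""
    m = max(log_ws)
    return m + math.log(sum(math.exp(lw - m) for lw in log_ws) / len(log_ws))

def normalized_weights(log_ws):
    """w_i / sum_j w_j, computed as a softmax over the log-weights."""
    m = max(log_ws)
    exps = [math.exp(lw - m) for lw in log_ws]
    total = sum(exps)
    return [e / total for e in exps]

def finite_difference_grad(log_ws, eps=1e-6):
    """Central-difference derivative of the bound w.r.t. each log-weight."""
    grads = []
    for i in range(len(log_ws)):
        up = list(log_ws)
        down = list(log_ws)
        up[i] += eps
        down[i] -= eps
        grads.append((log_mean_exp(up) - log_mean_exp(down)) / (2 * eps))
    return grads

# Hypothetical log-weights; the last sample is a very poor explanation of x.
log_ws = [-90.0, -95.0, -120.0]

iwae_grad = finite_difference_grad(log_ws)
w_tilde = normalized_weights(log_ws)
vae_grad = [1.0 / len(log_ws)] * len(log_ws)  # average-of-logs: equal weight per sample

print("IWAE d bound / d log w_i:", [f"{g:.6f}" for g in iwae_grad])
print("normalized weights:      ", [f"{w:.6f}" for w in w_tilde])
print("VAE  d bound / d log w_i:", [f"{g:.6f}" for g in vae_grad])
```

<p>With these values, almost all of the IWAE gradient mass lands on the first sample and the poor third sample is effectively ignored, while the average-of-logs objective gives each sample the same weight regardless of how well it explains the data.</p>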
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The authors compared VAE and IWAE on density estimation tasks using the MNIST and Omniglot datasets. They evaluated two main network architectures: one with a single stochastic layer and another with two stochastic layers. The models were trained with varying numbers of importance samples ($k \in \{1, 5, 50\}$) to observe the effect on performance and latent space utilization. The primary metrics for evaluation were the test log-likelihood (estimated using 5000 samples) and the number of &ldquo;active&rdquo; latent units, which quantifies the richness of the learned representations.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-vs-vae-active-latent-units-comparison.webp"
         alt="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         title="Bar chart comparing active latent units between VAE and IWAE across different k values on MNIST and Omniglot datasets"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Active latent units for VAE vs IWAE (1 stochastic layer). VAE active units remain flat. IWAE increases with k. Data from Table 1 of Burda et al. (2016).</figcaption>
    
</figure>

<ul>
<li>
<p><strong>Better Performance</strong>: IWAE achieved higher log-likelihoods than VAEs across all configurations. On MNIST with two stochastic layers and $k=50$, IWAE reached $-82.90$ nats compared to $-84.78$ for VAE. On Omniglot, the best IWAE achieved $-103.38$ nats versus $-106.30$ for VAE. IWAE performance improved consistently with increasing $k$, while VAE performance benefited only slightly from using more samples ($k&gt;1$).</p>
</li>
<li>
<p><strong>Richer Representations</strong>: In all experiments with $k&gt;1$, IWAE learned more active latent dimensions than VAE, suggesting richer latent representations.</p>
</li>
<li>
<p><strong>Objective Drives Representation</strong>: The authors found that latent dimension inactivation is driven by the objective function. They demonstrated this through an &ldquo;objective swap&rdquo; experiment:</p>
</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/notes/iwae-objective-swap-experiment.webp"
         alt="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         title="Bar charts showing the objective swap experiment results with active units and NLL changes when switching between VAE and IWAE objectives"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Objective swap experiment on MNIST (1 stochastic layer). Switching a trained VAE to the IWAE objective improves both metrics. Switching IWAE to VAE degrades them. Data from Table 2 of Burda et al. (2016).</figcaption>
    
</figure>

<p>This experiment provides evidence that the objective function itself influences latent utilization:</p>
<ul>
<li><strong>VAE → IWAE</strong>: A converged VAE model, when fine-tuned with the IWAE objective ($k=50$), gained 3 active units (19 → 22) and improved test NLL from 86.76 to 84.88.</li>
<li><strong>IWAE → VAE</strong>: A converged IWAE model fine-tuned with the VAE objective lost 2 active units (25 → 23) and worsened test NLL from 84.78 to 86.02.</li>
</ul>
<p>These results suggest that inactivation of latent dimensions is driven largely by the objective function and is not purely an artifact of optimization. The authors note that optimization also plays a role, as the swap results do not exactly match training from scratch.</p>
<ul>
<li>
<p><strong>Comparison to Other Models</strong>: On MNIST, the best IWAE ($-82.90$ nats) outperformed deep belief networks ($-84.55$ nats) and deep autoregressive networks ($-84.13$ nats), though DRAW ($-80.97$ nats), which exploits spatial structure, achieved better results. On Omniglot, the best IWAE ($-103.38$ nats) fell slightly behind RBMs trained with persistent contrastive divergence ($-100.46$ nats).</p>
</li>
<li>
<p><strong>Conclusion</strong>: IWAEs learn richer latent representations and achieve better generative performance than VAEs with equivalent architectures and training time.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<ul>
<li><strong>MNIST</strong>: $28 \times 28$ binarized handwritten digits (60,000 training / 10,000 test).</li>
<li><strong>Omniglot</strong>: $28 \times 28$ binarized handwritten characters from various alphabets (24,345 training / 8,070 test).</li>
<li><strong>Binarization</strong>: Dynamic sampling where binary values are sampled with expectations equal to the real pixel intensities (following Salakhutdinov &amp; Murray, 2008).</li>
<li><strong>Fixed Binarization</strong>: Results on a fixed binarization of MNIST (Larochelle, 2011) confirm that IWAE outperforms VAE across preprocessing methods. Training on the fixed binarization exhibits notably more overfitting than training with dynamic sampling.</li>
</ul>
<h3 id="models">Models</h3>
<p>Two main network architectures were tested:</p>
<ol>
<li>One stochastic layer (50 units) with two deterministic layers (200 units each).</li>
<li>Two stochastic layers (100 and 50 units). Between x and h1 were two deterministic layers with 200 units each. Between h1 and h2 were two deterministic layers with 100 units each.</li>
</ol>
<ul>
<li><strong>Activations</strong>: <code>tanh</code> for deterministic layers; <code>exp</code> applied to variance predictions to ensure positivity.</li>
<li><strong>Distributions</strong>: Gaussian latent layers with diagonal covariance; Bernoulli observation layer.</li>
<li><strong>Initialization</strong>: Glorot &amp; Bengio (2010) heuristic.</li>
</ul>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Optimizer</strong>: Adam ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-4}$).</li>
<li><strong>Batch Size</strong>: 20.</li>
<li><strong>Learning Rate Schedule</strong>: Annealed rate of $0.001 \cdot 10^{-i/7}$ for $3^i$ epochs (where $i=0&hellip;7$), totaling 3,280 passes over the data.</li>
<li><strong>Variance Control</strong>: A common concern with importance sampling is high variance. The authors prove that the mean absolute deviation of their estimator is bounded by $2 + 2\delta$, where $\delta$ is the gap between the bound and the true log-likelihood. The estimator&rsquo;s spread therefore stays bounded, and the guarantee improves as the bound tightens.</li>
<li><strong>Computational trick</strong>: In the basic IWAE implementation, both forward and backward passes must be done independently for each of the $k$ samples, so the cost scales linearly with $k$. However, the authors describe an optional optimization: stochastically approximate the gradient sum by sampling a single $\epsilon_i$ proportional to its normalized weight $\tilde{w}_i$, then computing only that one backward pass. This reduces the cost to $k$ forward passes and one backward pass. Since the backward pass costs roughly twice the forward pass, this yields approximately a 3x speedup for large $k$ at the cost of increased gradient variance.</li>
</ul>
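<p>The learning rate schedule listed above is easy to misread, so here is a small sketch that expands it. It follows the stated formula directly; the helper <code>learning_rate_at</code> and the 0-indexed epoch convention are assumptions made for illustration, not part of the authors&rsquo; code.</p>

```python
# Stage i uses learning rate 0.001 * 10^(-i/7) for 3^i epochs, i = 0..7.
schedule = [(0.001 * 10 ** (-i / 7), 3 ** i) for i in range(8)]
total_epochs = sum(epochs for _, epochs in schedule)

def learning_rate_at(epoch):
    """Learning rate for a 0-indexed epoch under the staged schedule."""
    boundary = 0
    for rate, epochs in schedule:
        boundary += epochs
        if epoch < boundary:
            return rate
    raise ValueError("epoch is past the end of the schedule")

for stage, (rate, epochs) in enumerate(schedule):
    print(f"stage {stage}: lr = {rate:.6f} for {epochs} epochs")
print("total epochs:", total_epochs)
```

<p>The stage lengths sum to $(3^8 - 1)/2 = 3{,}280$ epochs, matching the total quoted above, and the learning rate decays by one order of magnitude overall, from $10^{-3}$ to $10^{-4}$.</p>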
<p><strong>Relationship to Reweighted Wake-Sleep (RWS):</strong> Both IWAE and Reweighted Wake-Sleep (Bornschein &amp; Bengio, 2015) use importance-weighted samples and have closely related generative model updates. The key difference is that IWAE derives a single unified lower bound $\mathcal{L}_k$ and uses the reparameterization trick to train the recognition network jointly. RWS instead uses separate wake and sleep phases for the recognition network, which are not derived from $\mathcal{L}_k$.</p>
<h3 id="evaluation">Evaluation</h3>
<ol>
<li><strong>Test Log-Likelihood</strong>: Primary measure of generative performance, estimated as the mean of $\mathcal{L}_{5000}$ (5000 samples) on the test set.</li>
<li><strong>Active Units</strong>: To quantify latent space richness, the authors measured &ldquo;active&rdquo; latent dimensions. A unit $u$ was defined as active if its activity statistic $A_{u}=\text{Cov}_{x}(\mathbb{E}_{u\sim q(u|x)}[u])$ exceeded $10^{-2}$. The $10^{-2}$ threshold is justified by a bimodal distribution of the log activity statistic, showing clear separation between active and inactive units.</li>
</ol>
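<p>The active-units statistic can be sketched in a few lines. The example below assumes the encoder&rsquo;s posterior means have already been collected into a list of rows (one row per input, one column per latent unit); the function name, the toy data, and the use of population variance are illustrative choices, not the authors&rsquo; implementation.</p>

```python
from statistics import pvariance

def count_active_units(posterior_means, threshold=1e-2):
    """Count latent units whose posterior mean varies across inputs.

    posterior_means[n][u] is the encoder mean of unit u for input n,
    i.e. E_{u ~ q(u|x_n)}[u]. A unit is active when the variance of
    that mean over the inputs exceeds `threshold`.
    """
    num_units = len(posterior_means[0])
    activity = [
        pvariance([row[u] for row in posterior_means])
        for u in range(num_units)
    ]
    num_active = sum(a > threshold for a in activity)
    return num_active, activity

# Hypothetical posterior means for 4 inputs and 3 latent units.
# The middle unit barely responds to the input, so it should be inactive.
means = [
    [0.9, 0.01, -1.2],
    [-1.1, 0.00, 0.8],
    [0.4, -0.01, 0.1],
    [-0.3, 0.02, -0.5],
]

num_active, activity = count_active_units(means)
print("activity per unit:", [f"{a:.5f}" for a in activity])
print("active units:", num_active)
```

<p>In practice the rows would come from running the recognition network over the dataset and keeping the Gaussian mean of each latent unit, so the statistic measures how much each unit&rsquo;s mean depends on the input.</p>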
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: GPU-based implementation using mini-batch replication to parallelize the $k$ samples. Specific GPU type and training times are not reported.</li>
</ul>
<h3 id="artifacts">Artifacts</h3>
<table>
  <thead>
      <tr>
          <th>Artifact</th>
          <th>Type</th>
          <th>License</th>
          <th>Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><a href="https://github.com/yburda/iwae">yburda/iwae</a></td>
          <td>Code</td>
          <td>Unknown</td>
          <td>Official Theano implementation for MNIST and Omniglot</td>
      </tr>
  </tbody>
</table>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Burda, Y., Grosse, R., &amp; Salakhutdinov, R. (2016). Importance Weighted Autoencoders. <em>International Conference on Learning Representations (ICLR) 2016</em>. <a href="https://arxiv.org/abs/1509.00519">https://arxiv.org/abs/1509.00519</a></p>
<p><strong>Publication</strong>: ICLR 2016</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{burda2016importance,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Importance Weighted Autoencoders}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Yuri Burda and Roger Grosse and Ruslan Salakhutdinov}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{International Conference on Learning Representations}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2016}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1509.00519}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://arxiv.org/abs/1509.00519">ArXiv</a></li>
</ul>
]]></content:encoded></item><item><title>Auto-Encoding Variational Bayes: VAE Paper Summary</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</link><pubDate>Wed, 05 Nov 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/autoencoding-variational-bayes/</guid><description>Summary of Kingma &amp; Welling's 2013 VAE paper introducing the reparameterization trick and variational autoencoders.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>Method</strong> paper that introduces a generative mechanism (the VAE) and an optimization technique (the reparameterization trick), with formal theoretical derivation. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the <strong>variational auto-encoder (VAE)</strong> when neural networks are used as the recognition model.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The authors address two central intractabilities in directed probabilistic models with continuous latent variables:</p>















<figure class="post-figure center ">
    <img src="/img/notes/autoencoding-variational-bayes-figure-1-model-diagram.webp"
         alt="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         title="VAE graphical model showing latent variable z, observed variable x, and parameters phi and theta"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Figure 1 from the paper: The directed graphical model. Solid lines denote the generative model $p_\theta(z)p_\theta(x|z)$, dashed lines denote the variational approximation $q_\phi(z|x)$. The variational parameters $\phi$ are learned jointly with the generative parameters $\theta$.</figcaption>
    
</figure>

<ol>
<li>
<p><strong>Intractable Posteriors</strong>: In models with continuous latent variables (like those with non-linear hidden layers), the true posterior $p_{\theta}(z|x)$ cannot be calculated analytically, preventing the use of standard EM algorithms.</p>
</li>
<li>
<p><strong>Large Datasets</strong>: Sampling-based solutions like Monte Carlo EM (MCEM) require an expensive sampling loop per datapoint. This makes them too slow for large datasets, where full-batch optimization is too costly and efficient minibatch updates are required.</p>
</li>
</ol>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<h3 id="the-reparameterization-trick-sgvb-estimator">The Reparameterization Trick (SGVB Estimator)</h3>
<p>The core innovation is the <strong>Stochastic Gradient Variational Bayes (SGVB)</strong> estimator. The authors address the high variance of the naive Monte Carlo gradient estimator by &ldquo;reparameterizing&rdquo; the random variable $\tilde{z}$.</p>
<p>They express $\tilde{z}$ as a deterministic function of the input $x$ and an auxiliary noise variable $\epsilon$:</p>
<p>$$\tilde{z} = g_{\phi}(\epsilon, x) \quad \text{with} \quad \epsilon \sim p(\epsilon)$$</p>















<figure class="post-figure center ">
    <img src="/img/notes/variational-autoencoder-reparameterization-trick.webp"
         alt="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         title="Comparison of standard stochastic node vs reparameterization trick showing gradient flow"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reparameterization trick. (A) Standard stochastic nodes block gradient flow during backpropagation. (B) By expressing $z = \mu + \sigma \odot \epsilon$ with external noise $\epsilon \sim \mathcal{N}(0, I)$, gradients can flow through the deterministic path to the parameters $\phi$.</figcaption>
    
</figure>

<ul>
<li><strong>Mechanism</strong>: For a Gaussian posterior, $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$.</li>
<li><strong>Impact</strong>: This makes the Monte Carlo estimate differentiable with respect to the variational parameters $\phi$, allowing the variational lower bound to be optimized via standard stochastic gradient ascent (like SGD or Adagrad).</li>
</ul>
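<p>The Gaussian case can be sketched directly. This is a minimal stdlib illustration under stated assumptions: the encoder outputs are given as plain lists <code>mu</code> and <code>log_var</code>, and the function name <code>reparameterize</code> is hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I), elementwise."""
    z = []
    for m, lv in zip(mu, log_var):
        sigma = math.exp(0.5 * lv)   # convert log-variance to standard deviation
        eps = rng.gauss(0.0, 1.0)    # noise does not depend on the encoder parameters
        z.append(m + sigma * eps)
    return z

rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, -2.0]
samples = [reparameterize(mu, log_var, rng) for _ in range(20000)]
mean_z0 = sum(s[0] for s in samples) / len(samples)
print(mean_z0)
</code></pre></div>
<p>Because the randomness enters only through $\epsilon$, the sample is an ordinary differentiable function of $\mu$ and $\sigma$ ($\partial z / \partial \mu = 1$ and $\partial z / \partial \sigma = \epsilon$), which is what an autodiff framework exploits.</p>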
<h3 id="the-aevb-algorithm-the-vae">The AEVB Algorithm (The VAE)</h3>
<p>The <strong>Auto-Encoding VB (AEVB)</strong> algorithm amortizes inference by learning a global recognition model (encoder) $q_{\phi}(z|x)$ jointly with the generative model (decoder) $p_{\theta}(x|z)$.</p>
<p><strong>Objective Function</strong>: Maximize the variational lower bound $\mathcal{L}(\theta, \phi; x^{(i)})$:</p>
<p>$$\mathcal{L}(\theta, \phi; x^{(i)}) \simeq -D_{KL}(q_\phi(z|x^{(i)}) \parallel p_\theta(z)) + \frac{1}{L} \sum_{l=1}^L \log p_\theta(x^{(i)}|z^{(i,l)})$$</p>
<ul>
<li><strong>First Term (Regularizer)</strong>: Forces the approximate posterior to match the prior (integrated analytically for Gaussians).</li>
<li><strong>Second Term (Reconstruction Error)</strong>: The expected negative reconstruction error (estimated via sampling).</li>
</ul>
<p>This mirrors the standard auto-encoder objective, adding a variational regularizer.</p>
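<p>For a diagonal Gaussian posterior and a standard normal prior, the KL term has the closed form $\frac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1)$. The sketch below evaluates the objective for one binary datapoint; it assumes a Bernoulli decoder whose output probabilities are already available for each sample $z^{(i,l)}$, and all function names are illustrative.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def gaussian_kl(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) in closed form."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))

def bernoulli_log_likelihood(x, probs, eps=1e-7):
    """log p(x | z) for binary x and decoder output probabilities."""
    total = 0.0
    for xi, p in zip(x, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += xi * math.log(p) + (1.0 - xi) * math.log(1.0 - p)
    return total

def elbo_estimate(x, mu, log_var, decoder_probs_per_sample):
    """Single-datapoint estimate: -KL + mean reconstruction log-likelihood over L samples."""
    recon = sum(bernoulli_log_likelihood(x, p) for p in decoder_probs_per_sample)
    recon /= len(decoder_probs_per_sample)
    return -gaussian_kl(mu, log_var) + recon

x = [1.0, 0.0, 1.0]
value = elbo_estimate(x, mu=[0.0, 0.0], log_var=[0.0, 0.0],
                      decoder_probs_per_sample=[[0.9, 0.1, 0.8]])
print(value)
</code></pre></div>
<p>In the example the posterior equals the prior, so the KL term vanishes and the estimate reduces to the reconstruction log-likelihood.</p>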
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was benchmarked against the <strong>Wake-Sleep</strong> algorithm and <strong>Monte Carlo EM (MCEM)</strong> using the <strong>MNIST</strong> (digits) and <strong>Frey Face</strong> (continuous faces) datasets.</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<ul>
<li>
<p><strong>Efficiency</strong>: AEVB converged faster and reached a better lower bound than Wake-Sleep (Figure 2). It scaled efficiently to the full MNIST dataset. MCEM&rsquo;s per-datapoint sampling cost made it impractical at full dataset scale, so comparisons were limited to small subsets (Figure 3).</p>
</li>
<li>
<p><strong>Regularization</strong>: The KL-divergence term provided a regularizing effect: increasing the number of latent dimensions ($N_z$) did not lead to more overfitting.</p>
</li>
<li>
<p><strong>Manifold Learning</strong>: The model successfully learned smooth 2D latent manifolds (visualized in Appendix A), grouping similar digits/faces together.</p>
</li>
</ul>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<h3 id="data">Data</h3>
<p><strong>Evaluation Data</strong>: For the marginal likelihood comparison (Figure 3), the paper used MNIST with $N_{\text{train}} = 1000$ and $N_{\text{train}} = 50000$ to compare data efficiency (marginal log-likelihood vs. training samples seen) across algorithms. A smaller network (100 hidden units, 3 latent variables) was used for this comparison because the marginal likelihood estimator only works reliably in low-dimensional latent spaces.</p>
<h3 id="algorithms">Algorithms</h3>
<ul>
<li><strong>Algorithm</strong>: Stochastic gradient ascent with <strong>Adagrad</strong> (global stepsizes chosen from $\{0.01, 0.02, 0.1\}$).</li>
<li><strong>Regularization</strong>: The objective included a weight decay term corresponding to a prior $p(\theta)=\mathcal{N}(0,I)$.</li>
<li><strong>Minibatches</strong>: Size $M=100$ with $L=1$ sample per datapoint.</li>
<li><strong>Initialization</strong>: Parameters sampled from $\mathcal{N}(0, 0.01)$.</li>
</ul>
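<p>The optimizer settings above can be sketched on a toy problem. This is a generic Adagrad ascent step, not the authors&rsquo; implementation; reading $\mathcal{N}(0, 0.01)$ as variance $0.01$ (standard deviation $0.1$) is an assumption, and the quadratic objective stands in for the lower bound.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def init_params(n, rng, std=0.1):
    # Assumption: N(0, 0.01) is read as variance 0.01, i.e. standard deviation 0.1.
    return [rng.gauss(0.0, std) for _ in range(n)]

def adagrad_ascent_step(params, grads, accum, stepsize=0.1, eps=1e-8):
    """One in-place Adagrad update that moves params uphill along grads."""
    for i, g in enumerate(grads):
        accum[i] += g * g                                   # running sum of squared gradients
        params[i] += stepsize * g / (math.sqrt(accum[i]) + eps)

# Toy concave objective f(theta) = -sum((theta - 3)^2), maximized at theta = 3.
rng = random.Random(0)
theta = init_params(2, rng)
accum = [0.0, 0.0]
for _ in range(3000):
    grads = [-2.0 * (t - 3.0) for t in theta]
    adagrad_ascent_step(theta, grads, accum, stepsize=0.1)
print(theta)
</code></pre></div>
<p>The per-parameter scaling by the accumulated squared gradient is what distinguishes Adagrad from plain SGD; the global stepsize $0.1$ is one of the three candidate values listed above.</p>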
<h3 id="models">Models</h3>
<p>The original VAE used simple Multi-Layered Perceptrons (MLPs):</p>
<ul>
<li><strong>Symmetry</strong>: The encoder and decoder were symmetric, having an equal number of hidden units.</li>
<li><strong>Hidden Units</strong>: 500 units for MNIST, 200 for Frey Face (to prevent overfitting on the smaller dataset).</li>
<li><strong>Activations</strong>: <strong>Tanh</strong> activation functions for the hidden layers.</li>
<li><strong>Latent Space</strong>: Experimented with $N_z$ ranging from 2 to 200.</li>
<li><strong>Outputs</strong>:
<ul>
<li><em>MNIST</em>: <strong>Bernoulli</strong> MLP (sigmoid output).</li>
<li><em>Frey Face</em>: <strong>Gaussian</strong> MLP, with means constrained to $(0,1)$ via sigmoid.</li>
</ul>
</li>
<li><strong>Encoder Architecture</strong>: For the Gaussian encoder, the mean $\mu$ and log-variance $\log(\sigma^2)$ are linear outputs from the shared hidden layer (they share the hidden layer weights and have separate output weights).</li>
<li><strong>Log-Variance</strong>: The encoder predicted $\log(\sigma^2)$ for numerical stability.</li>
</ul>
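<p>A forward pass through the Gaussian encoder described above can be written out explicitly. The sketch uses toy layer sizes and randomly initialized weights; the parameter names (<code>W_h</code>, <code>W_mu</code>, <code>W_lv</code>) and helper functions are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def linear(weights, bias, x):
    """Dense layer: weights is a list of rows, one row per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def random_layer(n_out, n_in, rng, std=0.1):
    weights = [[rng.gauss(0.0, std) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

def gaussian_encoder(x, params):
    """Shared tanh hidden layer followed by two separate linear heads."""
    h = [math.tanh(a) for a in linear(params["W_h"], params["b_h"], x)]
    mu = linear(params["W_mu"], params["b_mu"], h)
    log_var = linear(params["W_lv"], params["b_lv"], h)  # unconstrained log(sigma^2)
    return mu, log_var

rng = random.Random(0)
n_in, n_hidden, n_z = 6, 4, 2   # toy sizes; the notes above list 500 hidden units for MNIST
params = {}
params["W_h"], params["b_h"] = random_layer(n_hidden, n_in, rng)
params["W_mu"], params["b_mu"] = random_layer(n_z, n_hidden, rng)
params["W_lv"], params["b_lv"] = random_layer(n_z, n_hidden, rng)
mu, log_var = gaussian_encoder([0.0, 1.0, 0.0, 1.0, 1.0, 0.0], params)
print(len(mu), len(log_var))
</code></pre></div>
<p>Predicting $\log(\sigma^2)$ keeps the head unconstrained: any real output maps to a valid positive variance through the exponential.</p>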
<h3 id="evaluation">Evaluation</h3>
<p>The paper distinguishes between two metrics:</p>
<ul>
<li><strong>Variational Lower Bound</strong>: Used as the training objective (what the model optimizes).</li>
<li><strong>Marginal Likelihood</strong>: Used for final evaluation (Figure 3). The marginal likelihood $p_\theta(x)$ was estimated following Appendix D: posterior samples $z^{(l)}$ are drawn with Hybrid Monte Carlo (HMC), a density estimator $q(z)$ is fit to those samples, and fresh posterior samples are plugged into $p_{\theta}(x^{(i)}) \simeq \left(\frac{1}{L}\sum_{l=1}^{L} \frac{q(z^{(l)})}{p_\theta(z^{(l)})\,p_\theta(x^{(i)}|z^{(l)})}\right)^{-1}$.</li>
</ul>
<p>This distinction is critical: the training metric (lower bound) differs from the evaluation metric (estimated marginal likelihood).</p>
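<p>The Appendix D estimator is easiest to evaluate in log space. The sketch below is an illustration under stated assumptions: the HMC sampler and the density-estimation step are replaced by a conjugate Gaussian toy model in which the exact posterior is known, so $q$ can be set to the true posterior and the result checked against the analytic marginal. All function names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def log_mean_exp(values):
    """Numerically stable log of the mean of exp(values)."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))

def log_marginal_estimate(log_q, log_prior, log_lik):
    """log p(x) is approximated by -log( (1/L) * sum_l q(z_l) / (p(z_l) p(x|z_l)) ).

    All three lists are evaluated at the same posterior samples z_l.
    """
    log_ratios = [q - p - l for q, p, l in zip(log_q, log_prior, log_lik)]
    return -log_mean_exp(log_ratios)

def log_normal(z, mean, var):
    return -0.5 * (math.log(2.0 * math.pi * var) + (z - mean) ** 2 / var)

# Conjugate check: z ~ N(0, 1) and x | z ~ N(z, 1), so p(x) = N(0, 2)
# and the exact posterior is p(z | x) = N(x / 2, 1 / 2).
x = 1.3
zs = [-0.4, 0.2, 0.65, 1.1, 1.9]                    # stand-ins for posterior samples
log_q = [log_normal(z, x / 2.0, 0.5) for z in zs]   # q set to the true posterior
log_prior = [log_normal(z, 0.0, 1.0) for z in zs]
log_lik = [log_normal(x, z, 1.0) for z in zs]
estimate = log_marginal_estimate(log_q, log_prior, log_lik)
print(estimate, log_normal(x, 0.0, 2.0))
</code></pre></div>
<p>When $q$ matches the posterior exactly, every ratio equals $1/p_\theta(x)$ and the estimate is exact for any set of samples; in practice the quality of the fitted $q$ limits accuracy, which is consistent with the estimator being used only for low-dimensional latent spaces.</p>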
<h3 id="hardware">Hardware</h3>
<ul>
<li><strong>Hardware</strong>: Trained on a standard Intel Xeon CPU (approx. 40 GFLOPS); no GPUs were used.</li>
<li><strong>Training Time</strong>: Approximately 20-40 minutes per million training samples.</li>
</ul>
<h3 id="key-implementation-details-from-appendices">Key Implementation Details from Appendices</h3>
<ul>
<li><strong>Appendix A</strong>: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets.</li>
<li><strong>Appendix B</strong>: Closed-form solution for the KL divergence of two Gaussians, essential for implementing the efficient version of the estimator (Equation 10).</li>
<li><strong>Appendix C</strong>: Exact MLP equations with tanh hidden layers, covering both <strong>Bernoulli MLPs</strong> (binary data) and <strong>Gaussian MLPs</strong> (real-valued data).</li>
<li><strong>Appendix D</strong>: Marginal likelihood estimation protocol using Hybrid Monte Carlo (HMC) and importance sampling for evaluation (Figure 3).</li>
</ul>
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Diederik P. Kingma and Max Welling. &ldquo;Auto-Encoding Variational Bayes.&rdquo; arXiv:1312.6114 [stat.ML], 2013. <a href="https://doi.org/10.48550/arXiv.1312.6114">https://doi.org/10.48550/arXiv.1312.6114</a></p>
<p><strong>Publication</strong>: ICLR 2014 (arXiv preprint December 2013)</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@misc</span>{kingma2022autoencodingvariationalbayes,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{Auto-Encoding Variational Bayes}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Diederik P Kingma and Max Welling}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2013}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">eprint</span>=<span style="color:#e6db74">{1312.6114}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">archivePrefix</span>=<span style="color:#e6db74">{arXiv}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">primaryClass</span>=<span style="color:#e6db74">{stat.ML}</span>,
</span></span><span style="display:flex;"><span>      <span style="color:#a6e22e">url</span>=<span style="color:#e6db74">{https://arxiv.org/abs/1312.6114}</span>,
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://en.wikipedia.org/wiki/Variational_autoencoder">Wikipedia: Variational Autoencoder</a> - General overview</li>
<li><a href="https://openreview.net/forum?id=33X9fd2-9FyZd">OpenReview</a> - Original peer review with author responses</li>
<li><a href="/posts/modern-variational-autoencoder-in-pytorch/">Modern VAE in PyTorch</a> - Implementation tutorial on this site</li>
</ul>
]]></content:encoded></item><item><title>Contrastive Learning for Variational Autoencoder Priors</title><link>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</link><pubDate>Sun, 17 Aug 2025 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/notes/machine-learning/generative-models/contrastive-learning-for-vae-priors/</guid><description>Aneja et al.'s NeurIPS 2021 paper introducing Noise Contrastive Priors (NCPs) to address VAE's 'prior hole' problem with energy-based priors.</description><content:encoded><![CDATA[<h2 id="what-kind-of-paper-is-this">What kind of paper is this?</h2>
<p>This is a <strong>method paper</strong> that introduces a training approach for Variational Autoencoders (VAEs) to address fundamental limitations in their generative quality through improved prior learning.</p>
<h2 id="what-is-the-motivation">What is the motivation?</h2>
<p>The work is motivated by a critical limitation in Variational Autoencoders known as the <strong>&ldquo;prior hole&rdquo; problem</strong>, where the prior distribution $p(z)$ fails to match the aggregate approximate posterior $q(z)$. This mismatch leads to areas in the latent space with high density under the prior that don&rsquo;t map to realistic data samples, resulting in poor generative quality compared to GANs and other generative models.</p>















<figure class="post-figure center ">
    <img src="/img/notes/vae-prior-hole-problem-illustrated.webp"
         alt="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         title="Visualization of the VAE prior hole problem showing a ring-shaped aggregate posterior q(z) with an empty center, while the standard Gaussian prior p(z) has highest density at the center where no data exists"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The &lsquo;prior hole&rsquo; problem: the standard Gaussian prior (red dashed contours) assigns highest probability to the center, but the aggregate posterior (blue dots) forms a ring with no data in that region.</figcaption>
    
</figure>

<p>The figure above illustrates this mismatch. The blue dots represent where a trained encoder actually places data in the latent space (the aggregate posterior $q(z)$), which often forms complex, non-Gaussian shapes. The red dashed contours show the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$, which assumes data is centered at the origin. When generating new samples, we draw from this prior, making it likely to sample from the empty &ldquo;hole&rdquo; where the decoder has never seen training data, producing unrealistic outputs.</p>
<p>A natural question arises: the prior $p(z)$ is used for <em>sampling</em> at inference time, so why does learning a better prior also improve <em>likelihood</em> (NLL)? The answer lies in the VAE objective. VAEs maximize the Evidence Lower Bound (ELBO):</p>
<p>$$ \log p(x) \geq \mathcal{L}_{\text{ELBO}}(x) = \underbrace{\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{Reconstruction}} - \underbrace{\text{KL}(q(z|x) \parallel p(z))}_{\text{Regularization}} $$</p>
<p>The KL divergence term penalizes the mismatch between each data point&rsquo;s approximate posterior $q(z|x)$ and the prior $p(z)$. When the prior is a simple Gaussian but the aggregate posterior forms a complex shape (as in the figure above), this KL term remains unnecessarily high for every data point.</p>
<p>By replacing the simple prior with a learned $p_{\text{NCP}}(z)$ that matches the aggregate posterior, the KL penalty decreases, tightening the ELBO and improving NLL. The learned prior thus provides a <strong>unified solution</strong>: better likelihood during training (tighter bound) and better sampling at inference (no &ldquo;holes&rdquo;).</p>
<p>The OpenReview discussion contains a significant theoretical debate regarding the paper&rsquo;s core premise. Reviewers argued that the &ldquo;prior hole&rdquo; problem is actually a failure of the posterior to match the prior, or a failure of the encoder. The authors defended their approach by noting that even with a perfect posterior, a simple Normal prior might fail because the decoder lacks capacity to map a simple distribution to complex data without dropping modes. This justifies fixing the prior by making it learned and complex.</p>
<h2 id="what-is-the-novelty-here">What is the novelty here?</h2>
<p>The authors propose an <strong>energy-based model (EBM) prior</strong> that is trained using <strong>Noise Contrastive Estimation (NCE)</strong>, which they term a <strong>Noise Contrastive Prior (NCP)</strong>. The key innovations are:</p>
<ul>
<li><strong>Two-Stage Training Process</strong>: First, a standard VAE is trained with a simple base prior. Then, the VAE weights are frozen and a binary classifier learns to distinguish between samples from the aggregate posterior $q(z)$ and the base prior $p(z)$.</li>
<li><strong>Reweighting Strategy</strong>: The core idea is to reweight a base prior distribution $p(z)$ with a learned reweighting factor $r(z)$ so that the resulting prior $p_{\text{NCP}}(z)$ better matches the aggregate posterior $q(z)$.</li>
<li><strong>NCE for EBM Training</strong>: The method frames EBM training as a binary classification task to avoid computationally expensive MCMC sampling.</li>
<li><strong>Scalability to Hierarchical Models</strong>: For hierarchical VAEs with multiple latent groups, the NCP approach can be applied independently and in parallel to each group&rsquo;s conditional prior.</li>
</ul>
<h2 id="what-experiments-were-performed">What experiments were performed?</h2>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling</li>
<li><strong>CelebA 64x64</strong>: Applied to both standard VAE architectures and more advanced VAEs with GMM priors (RAE model)</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>The hierarchical NVAE models used 30 latent groups for CIFAR-10 and CelebA-64, 20 groups for CelebA-HQ-256, and 10 groups of $4 \times 4$ latent variables for MNIST (deliberately small to enable reliable partition function estimation). The experiments compared FID scores, likelihood metrics, and qualitative sample quality between baseline VAEs and NCP-enhanced versions, with particular focus on hierarchical VAEs (NVAE).</p>
<h2 id="what-outcomesconclusions">What outcomes/conclusions?</h2>
<p>The proposed NCP method demonstrated improvements in generative quality across evaluated datasets, with modest gains on standard VAEs and particularly large gains on hierarchical models like NVAE:</p>
<ul>
<li><strong>CelebA-64</strong>: NCP improved FID scores from 48.12 to 41.28 for standard VAEs, and from 40.95 to 39.00 for RAE models with GMM priors.</li>
<li><strong>Hierarchical Models (NVAE)</strong>: The impact was particularly pronounced on hierarchical VAEs:
<ul>
<li><strong>CIFAR-10</strong>: FID improved from 51.71 to 24.08</li>
<li><strong>CelebA-64</strong>: FID improved from 13.48 to 5.25, making it competitive with GANs</li>
<li><strong>CelebA HQ 256x256</strong>: FID reduced from 40.26 to 24.79</li>
</ul>
</li>
<li><strong>Likelihood Performance</strong>: On MNIST, NCP-VAE achieved 78.10 nats NLL vs. baseline NVAE&rsquo;s 78.67 nats</li>
</ul>
<p>On CIFAR-10 and CelebA-HQ-256, the concurrent VAEBM method (which forms an EBM on the data space rather than the latent space) outperforms NCP-VAE. However, the authors argue the two approaches are complementary: NCP-VAE targets the latent space while VAEBM operates in data space, and combining them could yield further gains. NCP-VAE also has the advantage of applicability to discrete data (e.g., binarized MNIST) and simpler setup since it only requires training binary classifiers rather than MCMC-based training and sampling.</p>
<p>The key conclusions are that <strong>two-stage training with noise contrastive estimation</strong> provides an effective framework for learning expressive energy-based priors that addresses the prior hole problem while scaling efficiently to hierarchical models.</p>
<h2 id="reproducibility-details">Reproducibility Details</h2>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Artifact</th>
          <th style="text-align: left">Type</th>
          <th style="text-align: left">License</th>
          <th style="text-align: left">Notes</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left"><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code (Google Drive)</a></td>
          <td style="text-align: left">Code</td>
          <td style="text-align: left">Unknown</td>
          <td style="text-align: left">Official implementation; hosted on Google Drive (may become inaccessible)</td>
      </tr>
      <tr>
          <td style="text-align: left"><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview</a></td>
          <td style="text-align: left">Other</td>
          <td style="text-align: left">N/A</td>
          <td style="text-align: left">Reviews, author responses, and supplementary material</td>
      </tr>
  </tbody>
</table>
<h3 id="algorithms">Algorithms</h3>
<h4 id="the-reweighting-mechanism">The Reweighting Mechanism</h4>
<p>The core innovation is defining the NCP prior as $p_{\text{NCP}}(z) \propto p(z)r(z)$. The reweighting factor $r(z)$ is derived from the binary classifier $D(z)$ using the <strong>likelihood ratio trick</strong>:</p>
<p>$$ r(z) \approx \frac{D(z)}{1 - D(z)} $$</p>
<p>Here, $D(z)$ is the sigmoid output of the trained discriminator, representing the probability that sample $z$ came from the aggregate posterior $q(z)$ (&ldquo;real&rdquo;). For an optimal discriminator $D^*(z)$, this ratio exactly equals $\frac{q(z)}{p(z)}$, allowing the model to approximate the density ratio without explicit density estimation.</p>















<figure class="post-figure center ">
    <img src="/img/notes/ncp-vae-reweighting-the-prior-posterior.webp"
         alt="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         title="Visualization of the NCP reweighting mechanism showing three 1D distributions: q(z) the complex bimodal aggregate posterior, p(z) the simple Gaussian prior, and r(z) the learned reweighting factor that transforms p(z) to match q(z)"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The reweighting mechanism: the learned factor $r(z)$ (bottom) reweights the simple Gaussian prior $p(z)$ (middle) to approximate the complex aggregate posterior $q(z)$ (top). Where $q(z)$ has high density but $p(z)$ is low, $r(z)$ compensates with high values.</figcaption>
    
</figure>
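<p>The likelihood ratio trick can be checked numerically in one dimension. In this sketch both densities are known Gaussians chosen purely for illustration (in the real method $q(z)$ is only available through samples and $D$ is a trained network), and the Bayes-optimal classifier assumes balanced classes.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def normal_pdf(z, mean, std):
    return math.exp(-0.5 * ((z - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def optimal_discriminator(z, q_pdf, p_pdf):
    """Bayes-optimal classifier with balanced classes: D*(z) = q(z) / (q(z) + p(z))."""
    q, p = q_pdf(z), p_pdf(z)
    return q / (q + p)

def reweighting_factor(d):
    """Likelihood ratio trick: r(z) = D(z) / (1 - D(z))."""
    return d / (1.0 - d)

# Toy 1D stand-ins: a shifted, narrower "aggregate posterior" and a standard normal prior.
q_pdf = lambda z: normal_pdf(z, 1.5, 0.5)
p_pdf = lambda z: normal_pdf(z, 0.0, 1.0)
for z in (-1.0, 0.0, 1.5):
    r = reweighting_factor(optimal_discriminator(z, q_pdf, p_pdf))
    print(z, r, q_pdf(z) / p_pdf(z))
</code></pre></div>
<p>The last two printed columns agree at every point. With a sigmoid-output classifier, $D/(1-D)$ is simply the exponential of the logit, so $\log r(z)$ can be read directly from the pre-sigmoid output.</p>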

<h4 id="hierarchical-architecture-strategy">Hierarchical Architecture Strategy</h4>
<p>For hierarchical models (like NVAE), the method trains $K$ binary classifiers in parallel (one for each latent group). Crucially, to ensure efficiency, the classifiers reuse the <strong>context feature</strong> $c(z_{&lt;k})$ extracted by the frozen VAE&rsquo;s prior network. This architectural choice provides significant computational savings.</p>
<h4 id="test-time-sampling-inference">Test-Time Sampling (Inference)</h4>
<p>Since $p_{\text{NCP}}(z)$ is an energy-based model known only up to a normalizing constant, it cannot be sampled directly. The paper employs two methods to generate samples:</p>
<ul>
<li><strong>Sampling-Importance-Resampling (SIR):</strong> Used for most results. It draws $M$ samples (e.g., $M=5000$) from the base prior $p(z)$ and resamples them based on weights $w^{(m)} = r(z^{(m)})$.</li>
<li><strong>Langevin Dynamics (LD):</strong> An iterative refinement method using the gradient of the energy function $E(z) = -\log r(z) - \log p(z)$.</li>
</ul>
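<p>The SIR procedure can be sketched on a one-dimensional toy problem. This is an illustration under stated assumptions, not the paper&rsquo;s implementation: the base prior is $\mathcal{N}(0, 1)$, the target is $\mathcal{N}(1, 0.5^2)$, and <code>log_r</code> is the analytic log-ratio (up to a constant) rather than a trained classifier&rsquo;s logit.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math
import random

def sir_sample(sample_prior, log_r, num_proposals, rng):
    """Sampling-importance-resampling from p_NCP(z) proportional to p(z) r(z)."""
    proposals = [sample_prior(rng) for _ in range(num_proposals)]
    log_w = [log_r(z) for z in proposals]
    top = max(log_w)
    weights = [math.exp(lw - top) for lw in log_w]  # shift by the max for numerical stability
    # Resample one proposal with probability proportional to its weight.
    return rng.choices(proposals, weights=weights, k=1)[0]

# Toy setup: log r = log target - log prior, dropping additive constants.
def log_r(z):
    return -0.5 * ((z - 1.0) / 0.5) ** 2 + 0.5 * z * z

rng = random.Random(0)
draws = [sir_sample(lambda g: g.gauss(0.0, 1.0), log_r, 200, rng) for _ in range(2000)]
mean = sum(draws) / len(draws)
print(mean)
</code></pre></div>
<p>The resampled draws concentrate around the target mean of 1 even though every proposal came from the base prior. The cost is $M$ prior samples and $M$ evaluations of $r$ per output sample, which is the source of the inference-time overhead.</p>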
<h3 id="models">Models</h3>
<h4 id="decoder-architecture">Decoder Architecture</h4>
<p>For RGB datasets (CIFAR-10, CelebA), the output likelihood must be changed from <strong>Discretized Logistic</strong> (standard NVAE) to a <strong>Normal distribution</strong>. The authors note this change alone led to &ldquo;significant improvements in the base model performance.&rdquo; Using the standard NVAE decoder will result in a weaker baseline than reported.</p>
<h4 id="discriminator-architecture">Discriminator Architecture</h4>
<p>The binary classifier uses a ResNet-style architecture with <strong>Squeeze-and-Excitation (SE)</strong> blocks:</p>
<ul>
<li><strong>Activation:</strong> Swish</li>
<li><strong>Normalization:</strong> Batch Normalization</li>
<li><strong>Optimization:</strong> Adam with Cosine Annealing (learning rate: $10^{-3} \to 10^{-7}$)</li>
</ul>
<p>The SE blocks help the model focus on channel-wise feature recalibration, which is important for distinguishing subtle differences between prior and aggregate posterior in high-dimensional latent spaces.</p>
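<p>The learning-rate schedule can be written in closed form. The sketch below uses the standard cosine annealing formula without warm restarts; whether the paper steps the schedule per iteration or per epoch is not stated in these notes, so <code>total_steps</code> is a hypothetical parameter.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def cosine_annealed_lr(step, total_steps, lr_max=1e-3, lr_min=1e-7):
    """Cosine decay from lr_max at step 0 to lr_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))

total = 100  # hypothetical schedule length, e.g. one value per epoch
for step in (0, 50, 100):
    print(step, cosine_annealed_lr(step, total))
</code></pre></div>
<p>The rate starts at $10^{-3}$, passes through roughly half that value at the midpoint, and ends at $10^{-7}$.</p>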
<h3 id="hardware">Hardware</h3>
<p>The main paper is vague on training time, but the OpenReview rebuttal explicitly lists hardware costs:</p>
<ul>
<li><strong>Hardware:</strong> NVIDIA Tesla V100 (32GB) GPUs</li>
<li><strong>Per-Discriminator Training:</strong> ~13 hours for 100 epochs</li>
<li><strong>Parallelization:</strong> Because latent groups are independent, all discriminators can train in parallel</li>
<li><strong>Total Cost (CelebA-64):</strong> ~8.1 GPU-days</li>
<li><strong>Comparison:</strong> The authors argue this is efficient compared to VDVAE, which requires ~560 GPU-days</li>
</ul>
<h3 id="evaluation">Evaluation</h3>
<h4 id="inference-speed-vs-quality-trade-off">Inference Speed vs. Quality Trade-off</h4>
<p>Reviewers flagged that SIR sampling can be prohibitively slow. The authors clarified the exact trade-off:</p>
<table>
  <thead>
      <tr>
          <th style="text-align: left">Proposal Samples ($M$)</th>
          <th style="text-align: left">Time per Image</th>
          <th style="text-align: left">FID (CelebA-64)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td style="text-align: left">5,000 (paper default)</td>
          <td style="text-align: left">~10.11 seconds</td>
          <td style="text-align: left">5.25</td>
      </tr>
      <tr>
          <td style="text-align: left">500 (practical)</td>
          <td style="text-align: left">~1.25 seconds</td>
          <td style="text-align: left">6.76</td>
      </tr>
  </tbody>
</table>
<p>The quality gain from 500 to 5,000 samples is modest (FID difference of 1.51) while inference time increases roughly 8x, suggesting $M=500$ may be a practical default.</p>
<h4 id="hyperparameters">Hyperparameters</h4>
<ul>
<li><strong>FID Calculation:</strong> 50,000 samples</li>
<li><strong>SIR Proposals:</strong> 5,000 samples (paper default) or 500 (practical)</li>
<li><strong>MNIST:</strong> Dynamically binarized version used for likelihood evaluation</li>
<li><strong>Optimizers:</strong> The study largely adopts hyperparameters from baseline papers (e.g., Lawson et al. for MNIST, Ghosh et al. for RAE)</li>
</ul>
<h4 id="debugging-benchmark-25-gaussians">Debugging Benchmark: 25-Gaussians</h4>
<p>The supplement provides a toy experiment ideal for verifying a new implementation before running on expensive image datasets:</p>
<ul>
<li><strong>Setup:</strong> Synthetic dataset of 25 2D-Gaussians arranged on a grid</li>
<li><strong>Target Log-Likelihood:</strong> ~-0.954 nats (NCP) vs. ~-2.753 nats (Vanilla VAE); higher is better</li>
<li><strong>Success Criterion:</strong> Samples should avoid low-density regions between grid points. A standard VAE will generate samples in these &ldquo;prior holes,&rdquo; while a working NCP implementation should cleanly remove these artifacts.</li>
</ul>
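<p>A dataset of this kind is straightforward to generate. The sketch below assumes a $5 \times 5$ grid with equal mixture weights; the grid spacing and per-component standard deviation are illustrative choices, not the paper&rsquo;s values, and the helper names are hypothetical.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import random

def sample_25_gaussians(n, rng, spacing=2.0, std=0.05):
    """Draw n points from an equal-weight mixture of 25 Gaussians on a 5x5 grid.

    spacing and std are illustrative assumptions, not the paper's values.
    """
    centers = [(spacing * i, spacing * j) for i in range(-2, 3) for j in range(-2, 3)]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return centers, points

def fraction_near_a_mode(points, centers, radius):
    """Share of points within radius of some grid center (a crude prior-hole check)."""
    def near(p):
        return any(radius * radius >= (p[0] - c[0]) ** 2 + (p[1] - c[1]) ** 2 for c in centers)
    return sum(1 for p in points if near(p)) / len(points)

rng = random.Random(0)
centers, data = sample_25_gaussians(1000, rng)
print(len(centers), fraction_near_a_mode(data, centers, radius=0.25))
</code></pre></div>
<p>Applying <code>fraction_near_a_mode</code> to generated samples gives a quick numeric version of the success criterion: a model that samples from prior holes will place a visible share of points far from every grid center.</p>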
<h4 id="implementation-warnings">Implementation Warnings</h4>
<ul>
<li><strong>SIR Failure Mode:</strong> If the learned prior $p_{\text{NCP}}$ deviates too far from the base prior, SIR sampling collapses (low effective sample size). The paper shows a strong correlation between the NCE classification loss and the effective sample size (Fig. 5(b)), indicating that SIR reliability depends on how well the base prior matches the aggregate posterior.</li>
<li><strong>Temperature Scaling:</strong> The qualitative images in the paper use reduced temperature for improved visual sharpness (Section 5.3). The FID tables do not specify a temperature, so results may or may not use $T=1.0$.</li>
</ul>
<h3 id="data">Data</h3>
<p>The method was evaluated on several standard image generation benchmarks:</p>
<ul>
<li><strong>MNIST</strong> (dynamically binarized): Likelihood evaluation on a controlled, small-latent-space task</li>
<li><strong>CIFAR-10</strong>: Standard computer vision benchmark for generative modeling (32x32 RGB images)</li>
<li><strong>CelebA 64x64</strong>: Face generation task with moderate resolution</li>
<li><strong>CelebA HQ 256x256</strong>: High-resolution face generation task</li>
</ul>
<p>All datasets use standard train/test splits from the computer vision literature.</p>
<h4 id="additional-metrics">Additional Metrics</h4>
<p>Beyond FID and NLL, the paper uses:</p>
<ul>
<li><strong>Effective Sample Size (ESS):</strong> Validates reliability of the SIR sampling procedure</li>
<li><strong>Maximum Mean Discrepancy (MMD):</strong> Measures distance between aggregate posterior and NCP prior distributions</li>
</ul>
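<p>The effective sample size of a set of importance weights can be computed as follows. This sketch assumes the standard definition $\text{ESS} = (\sum_m w^{(m)})^2 / \sum_m (w^{(m)})^2$, which may differ in normalization from the quantity plotted in the paper.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python">import math

def effective_sample_size(log_weights):
    """ESS = (sum w)^2 / sum(w^2), computed from log-weights for stability."""
    top = max(log_weights)
    w = [math.exp(lw - top) for lw in log_weights]  # rescaling leaves the ESS unchanged
    return sum(w) ** 2 / sum(wi * wi for wi in w)

uniform = effective_sample_size([0.0, 0.0, 0.0, 0.0])
collapsed = effective_sample_size([0.0, -50.0, -50.0, -50.0])
print(uniform, collapsed)
</code></pre></div>
<p>Uniform weights give an ESS equal to the number of proposals, while a single dominant weight drives it toward 1, which is the collapse regime in which SIR becomes unreliable.</p>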
<h2 id="paper-information">Paper Information</h2>
<p><strong>Citation</strong>: Aneja, J., Schwing, A. G., Kautz, J., &amp; Vahdat, A. (2021). A contrastive learning approach for training variational autoencoder priors. <em>Advances in Neural Information Processing Systems</em>, 34, 29604-29616. <a href="https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html">https://proceedings.neurips.cc/paper/2021/hash/0496604c1d80f66fbeb963c12e570a26-Abstract.html</a></p>
<p><strong>Publication</strong>: NeurIPS 2021</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-bibtex" data-lang="bibtex"><span style="display:flex;"><span><span style="color:#a6e22e">@inproceedings</span>{aneja2021contrastive,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">title</span>=<span style="color:#e6db74">{A Contrastive Learning Approach for Training Variational Autoencoder Priors}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">author</span>=<span style="color:#e6db74">{Aneja, Jyoti and Schwing, Alexander G and Kautz, Jan and Vahdat, Arash}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">booktitle</span>=<span style="color:#e6db74">{Advances in Neural Information Processing Systems}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">volume</span>=<span style="color:#e6db74">{34}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">pages</span>=<span style="color:#e6db74">{29604--29616}</span>,
</span></span><span style="display:flex;"><span>  <span style="color:#a6e22e">year</span>=<span style="color:#e6db74">{2021}</span>
</span></span><span style="display:flex;"><span>}
</span></span></code></pre></div><p><strong>Additional Resources</strong>:</p>
<ul>
<li><a href="https://openreview.net/forum?id=LcSfRundgwI">OpenReview Discussion</a></li>
<li><a href="https://drive.google.com/drive/folders/15tCGruQcSdm2G4yLkUpKvGASluSZPIBD">Code Repository</a> (Google Drive; link may become inaccessible over time)</li>
</ul>
]]></content:encoded></item><item><title>Modern PyTorch VAEs: A Detailed Implementation Guide</title><link>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</link><pubDate>Sun, 03 Mar 2024 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/</guid><description>Complete PyTorch VAE tutorial: Copy-paste code, ELBO derivation, KL annealing, and stable softplus parameterization.</description><content:encoded><![CDATA[<h2 id="what-is-a-variational-autoencoder">What is a Variational Autoencoder?</h2>
<p>A Variational Autoencoder (VAE) is a type of <strong>generative model</strong>, meaning its primary purpose is to learn the underlying structure of a dataset so it can generate new, similar data.</p>
<p>Whether the data is images, raw audio clips, or 2D graphs of drug-like molecules, a VAE aims to capture the essential features that define the data distribution. Once trained, it should be able to create entirely new samples that resemble the training data without simply copying specific examples.</p>
<p>Introduced by Kingma and Welling in 2013 (<a href="/notes/machine-learning/generative-models/autoencoding-variational-bayes/">Auto-Encoding Variational Bayes</a>, <a href="https://arxiv.org/abs/1312.6114">Paper</a>), VAEs are used for:</p>
<ul>
<li><strong>Generation</strong>: Creating new data (images, music, text).</li>
<li><strong>Dimensionality Reduction</strong>: Compressing data into a much smaller, meaningful representation (a &ldquo;latent space&rdquo;).</li>
<li><strong>Imputation</strong>: Intelligently filling in missing data (e.g., inpainting masked regions of an image).</li>
</ul>
<p>Importantly, they aim to provide a structured and continuous latent space, which allows for smooth interpolation between data points and meaningful manipulations of generated samples (think: optimization).</p>
<h2 id="tldr-the-complete-pytorch-implementation">TL;DR: The Complete PyTorch Implementation</h2>
<p>For those who just want the code, here is a complete, modern VAE implementation in PyTorch. It features <strong>softplus standard deviation parameterization</strong> for numerical stability and a <strong>custom training step</strong> that handles the ELBO loss correctly.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, input_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">784</span>, hidden_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">512</span>, latent_dim<span style="color:#f92672">=</span><span style="color:#ae81ff">16</span>):
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(input_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh()
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_mu <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>fc_std <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Linear(hidden_dim, latent_dim)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(latent_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, hidden_dim),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(hidden_dim, input_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x):
</span></span><span style="display:flex;"><span>        h <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>fc_mu(h)
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Softplus + epsilon for stable std deviation</span>
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>softplus(self<span style="color:#f92672">.</span>fc_std(h)) <span style="color:#f92672">+</span> <span style="color:#ae81ff">1e-6</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu, std):
</span></span><span style="display:flex;"><span>        eps <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> eps <span style="color:#f92672">*</span> std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z):
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(self, x, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>        mu, std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu, std)
</span></span><span style="display:flex;"><span>        x_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 1. Reconstruction Loss (Binary Cross Entropy for MNIST)</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Sum over features, mean over batch</span>
</span></span><span style="display:flex;"><span>        recon_loss <span style="color:#f92672">=</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(x_recon, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#39;none&#39;</span>)<span style="color:#f92672">.</span>sum(dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 2. KL Divergence</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytic KL for Normal distributions</span>
</span></span><span style="display:flex;"><span>        kl_loss <span style="color:#f92672">=</span> <span style="color:#f92672">-</span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(<span style="color:#ae81ff">1</span> <span style="color:#f92672">+</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> mu<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span> <span style="color:#f92672">-</span> std<span style="color:#f92672">**</span><span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>mean()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># 3. Total loss: the negative ELBO when kl_weight == 1.0</span>
</span></span><span style="display:flex;"><span>        loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> (kl_weight <span style="color:#f92672">*</span> kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> VAEOutput(z, mu, std, x_recon, loss, recon_loss, kl_loss)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># --- Training Loop Example ---</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">train_step</span>(model, batch, optimizer, kl_weight<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
</span></span><span style="display:flex;"><span>    model<span style="color:#f92672">.</span>train()
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>zero_grad()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Forward pass</span>
</span></span><span style="display:flex;"><span>    output <span style="color:#f92672">=</span> model(batch, kl_weight)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Backward pass</span>
</span></span><span style="display:flex;"><span>    output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>backward()
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># Gradient clipping (recommended)</span>
</span></span><span style="display:flex;"><span>    torch<span style="color:#f92672">.</span>nn<span style="color:#f92672">.</span>utils<span style="color:#f92672">.</span>clip_grad_norm_(model<span style="color:#f92672">.</span>parameters(), max_norm<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    optimizer<span style="color:#f92672">.</span>step()
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> output<span style="color:#f92672">.</span>loss<span style="color:#f92672">.</span>item()
</span></span></code></pre></div><h3 id="the-core-idea-learning-to-generate">The Core Idea: Learning to Generate</h3>
<p>The VAE is built on a key assumption: our complex, high-dimensional data (like a $28 \times 28$ pixel image, $\mathbf{x}$) is actually <em>generated</em> by some simpler, low-dimensional, unobserved variable (a &ldquo;latent&rdquo; variable, $\mathbf{z}$).</p>
<blockquote>
<p><strong>A Physical Metaphor: Water Molecules and Phase Diagrams</strong></p>
<p>Consider a glass of water. At the microscopic level, you have more than $10^{24}$ $\text{H}_2\text{O}$ molecules bouncing around in an incredibly high-dimensional space. Each molecule has a position, a velocity, and interactions with its neighbors, which is far too much state to track directly. Yet we can describe the <em>macroscopic behavior</em> of all these molecules using just two simple variables: <strong>temperature</strong> and <strong>pressure</strong>. These two dimensions create a &ldquo;phase diagram&rdquo; that tells us whether our water will be ice, liquid, or vapor. The temperature and pressure are &ldquo;latent variables&rdquo; that capture the essential physics governing this complex molecular dance.</p></blockquote>















<figure class="post-figure center ">
    <img src="/img/vae-tut/phase-diagram.webp"
         alt="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         title="Water phase diagram showing solid, liquid, and gas phases as functions of temperature and pressure"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A water phase diagram: Complex molecular behavior reduced to two simple variables (temperature and pressure). This illustrates how high-dimensional systems can often be understood through low-dimensional latent representations.</figcaption>
    
</figure>

<p>A VAE makes the same assumption: complex data (like images) emerges from simpler underlying factors. A handwritten digit might be generated by latent factors like &ldquo;pen thickness,&rdquo; &ldquo;writing angle,&rdquo; &ldquo;digit style,&rdquo; and &ldquo;size,&rdquo; a much simpler description than tracking all 784 pixel values independently.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/hypothetical-mnist-factors.webp"
         alt="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         title="Hypothetical illustration of MNIST digits generated from latent factors like pen thickness, angle, style, and size"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">A hypothetical illustration showing how MNIST digits could be generated from a few latent factors like pen thickness, writing angle, digit style, and size.</figcaption>
    
</figure>

<p>The VAE learns two functions: one that maps from complex data ($\mathbf{x}$) to these descriptive factors ($\mathbf{z}$), and another that maps from these factors back to the data. It accomplishes this with two main components, typically implemented as neural networks:</p>
<p><strong>1. The Encoder (Recognition Model)</strong></p>
<p>This network takes a complex data point $\mathbf{x}$ (an image) and determines the &ldquo;knob settings&rdquo; $\mathbf{z}$ that could explain or generate it. This allows us to <em>compress</em> or <em>understand</em> the data.</p>
<p>$$q_{\phi}(\mathbf{z} | \mathbf{x})$$</p>
<p>It&rsquo;s like examining a container of molecules and summarizing their complex arrangement into key parameters like temperature and pressure.</p>
<p>Crucially, the encoder outputs the <em>parameters</em> of a probability distribution (a simple Gaussian) that describes $\mathbf{z}$.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/encoding-diagram.webp"
         alt="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         title="Diagram mapping MNIST five to a Gaussian distribution in latent space with mean and standard deviation"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Encoder maps an input image (e.g., an MNIST digit &lsquo;5&rsquo;) to a Gaussian distribution in latent space, characterized by a mean vector and a standard deviation vector.</figcaption>
    
</figure>

<p>For each input $\mathbf{x}$, the encoder network outputs:</p>
<ul>
<li>A vector of means, $\mathbf{\mu}$</li>
<li>A vector of standard deviations, $\mathbf{\sigma}$</li>
</ul>
<p>These parameters define our approximation $q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$. We then <em>sample</em> from this distribution to get the $\mathbf{z}$ that we feed to the decoder. This probabilistic step encourages a continuous, structured latent space: the decoder must reconstruct $\mathbf{x}$ from any sample drawn near $\mathbf{\mu}$, so neighboring latent points are pushed to decode to similar outputs. That is what enables smooth interpolation and generation.</p>
<p><strong>2. The Decoder (Generative Model)</strong></p>
<p>This network learns the &ldquo;generative process.&rdquo; It takes a simple latent vector $\mathbf{z}$ and reconstructs the complex data $\mathbf{x}$. This allows us to <em>generate</em> new data by feeding it a random $\mathbf{z}$ and observing what image $\mathbf{x}$ it produces.</p>
<p>$$p_{\theta}(\mathbf{x} | \mathbf{z})$$</p>
<p>The decoder reverses the encoder: it takes the simple latent representation and &ldquo;paints&rdquo; the full, complex image from it. It&rsquo;s like taking temperature and pressure values and producing a detailed arrangement of water molecules consistent with those conditions. During training, the goal is to reproduce the original input as closely as possible.</p>
<p>After training, we have two networks that can be used for a variety of purposes:</p>
<ul>
<li><strong>Generation</strong>: If the latent space is well-structured, we can sample random $\mathbf{z}$ vectors from a simple distribution (like a standard normal) and feed them into the Decoder to generate new images. This is particularly useful for searching for data points with desired properties, like in drug discovery, where we might want to generate molecules with specific characteristics.</li>
<li><strong>Compression</strong>: The Encoder can compress complex data into a low-dimensional latent space, which can be useful for visualization or as a feature extractor for other tasks.</li>
</ul>
<h3 id="the-variational-problem">The &ldquo;Variational&rdquo; Problem</h3>
<p>Calculating the <em>true</em> distribution of latent variables $p_{\theta}(\mathbf{z}|\mathbf{x})$ (the posterior) is mathematically intractable.</p>
<p>This intractability arises from Bayes&rsquo; theorem:</p>
<p>$$p_{\theta}(\mathbf{z} | \mathbf{x}) = \frac{p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})}$$</p>
<p>Breaking down each component:</p>
<ul>
<li>$p_{\theta}(\mathbf{x} | \mathbf{z})$ is our decoder, which is straightforward to compute given our likelihood model.</li>
<li>$p_{\theta}(\mathbf{z})$ is our prior over latent variables, typically a simple distribution like a standard normal, making it easy to compute.</li>
<li>$p_{\theta}(\mathbf{x})$ is the marginal likelihood of the data. And here lies the problem. It requires integrating over all possible latent variables that could have generated $\mathbf{x}$:
$$p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x} | \mathbf{z}) p_{\theta}(\mathbf{z}) d\mathbf{z}$$
It is the normalization factor that ensures the posterior is a valid probability distribution (i.e., integrates to 1 over all $\mathbf{z}$).</li>
</ul>
<p>This integral is intractable because it ranges over a high-dimensional latent space and the likelihood is defined by a neural network. In general, no closed-form solution exists, and numerical integration is computationally prohibitive.</p>
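<p>To make the integral concrete, here is a minimal sketch in plain Python (standard library only, not part of the original implementation). It assumes a deliberately simple one-dimensional toy model, $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(x; z, s^2)$, chosen because its marginal $p(x) = \mathcal{N}(x; 0, 1 + s^2)$ is known exactly. The names <code>normal_pdf</code>, <code>marginal_likelihood_mc</code>, and <code>decoder_var</code> are illustrative.</p>

```python
import math
import random


def normal_pdf(x, mean, var):
    """Density of a univariate Gaussian N(mean, var) evaluated at x."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2.0 * math.pi * var)


def marginal_likelihood_mc(x, decoder_var, num_samples, rng):
    """Naive Monte Carlo estimate of p(x): average p(x|z) over z drawn from the prior."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)                  # z ~ p(z) = N(0, 1)
        total += normal_pdf(x, z, decoder_var)   # p(x | z) = N(x; z, decoder_var)
    return total / num_samples


rng = random.Random(0)
x, decoder_var = 1.5, 0.25
estimate = marginal_likelihood_mc(x, decoder_var, 100_000, rng)
exact = normal_pdf(x, 0.0, 1.0 + decoder_var)    # closed form, available for this toy model only
print(f"MC estimate: {estimate:.4f}   exact: {exact:.4f}")
```

<p>In one dimension, sampling from the prior works well enough. In a high-dimensional latent space with a neural-network decoder, almost all prior samples explain a given $\mathbf{x}$ poorly, so this naive estimator needs an impractical number of samples, and there is no exact answer to check it against.</p>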
<p>This is where the &ldquo;variational&rdquo; approach provides the solution. We approximate the true posterior by learning an encoder, $q_{\phi}(\mathbf{z} | \mathbf{x})$, that serves as a variational approximation to this intractable true distribution. The VAE&rsquo;s training process optimizes this approximation to be as accurate as possible, pushing this learned distribution closer to the true posterior.</p>
<h3 id="the-vae-objective-a-balancing-act">The VAE Objective: A Balancing Act</h3>
<p>To get these two networks (parameterized by $\theta$ and $\phi$) to work together, we train them jointly with a special loss function. This objective has two parts that balance two different goals:</p>
<h4 id="1-reconstruction-loss">1. Reconstruction Loss</h4>
<p>$$E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})]$$</p>
<p>This term asks: &ldquo;How well can we reconstruct our original image?&rdquo; It is an expected log-likelihood, so we want it to be large; its negative is the reconstruction loss we minimize. The process goes:</p>
<ol>
<li>Take an input point $\mathbf{x}$.</li>
<li>Use the <strong>Encoder</strong> to get its latent representation $\mathbf{z} \sim q_{\phi}(\mathbf{z} | \mathbf{x})$.</li>
<li>Use the <strong>Decoder</strong> to generate a new image $\mathbf{x}'$ from $\mathbf{z}$, $\mathbf{x}' \sim p_{\theta}(\mathbf{x} | \mathbf{z})$.</li>
<li>Compare $\mathbf{x}$ and $\mathbf{x}'$.</li>
</ol>
<p>The reconstruction loss measures the difference between the original and the reconstructed image.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstruction-loss-graphic.webp"
         alt="Graphic illustrating the reconstruction loss between original and reconstructed images"
         title="Graphic illustrating the reconstruction loss between original and reconstructed images"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The Reconstruction Loss measures how closely the Decoder&rsquo;s output matches the original input image.</figcaption>
    
</figure>

<ul>
<li><strong>For continuous inputs</strong> (like general images), this is often Mean Squared Error (MSE).</li>
<li><strong>For inputs in $[0, 1]$</strong> (like MNIST pixel intensities, which after <code>ToTensor()</code> are continuous values in $[0, 1]$), we use Binary Cross-Entropy (BCE). We treat each pixel as an independent Bernoulli variable whose target is its intensity. The decoder outputs the <em>logits</em> for each pixel, and the BCE-with-logits loss (e.g., <code>F.binary_cross_entropy_with_logits</code>) is the numerically stable way to compute the negative log-likelihood.</li>
<li><strong>More generally</strong>, you can output parameters of a desired output distribution. What if you wanted a mixture of Gaussians? The decoder could output the means, variances, and mixture weights, and you could compute the negative log-likelihood accordingly.</li>
</ul>
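<p>As a small illustration of the Bernoulli case above, the following sketch computes the per-pixel BCE from logits in plain Python. It uses the standard stable rearrangement $\max(\ell, 0) - \ell t + \log(1 + e^{-|\ell|})$ for logit $\ell$ and target $t$. The function names and the three-pixel example are assumptions for illustration, not the internals of <code>F.binary_cross_entropy_with_logits</code>.</p>

```python
import math


def bce_with_logits(logit, target):
    """Stable -[t*log(s) + (1 - t)*log(1 - s)] with s = sigmoid(logit)."""
    # Rearranged so that exp() only ever sees a non-positive argument.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def reconstruction_nll(logits, targets):
    """Sum the per-pixel terms to get the negative log-likelihood of one image."""
    return sum(bce_with_logits(logit, t) for logit, t in zip(logits, targets))


# Three hypothetical pixels: decoder logits and target intensities in [0, 1].
logits = [2.0, -1.0, 0.0]
targets = [1.0, 0.0, 0.5]
nll = reconstruction_nll(logits, targets)

# The textbook formula agrees for moderate logits.
probs = [1.0 / (1.0 + math.exp(-logit)) for logit in logits]
naive = -sum(t * math.log(p) + (1 - t) * math.log(1 - p) for t, p in zip(targets, probs))
print(nll, naive)

# With a logit of 50, sigmoid rounds to exactly 1.0 in float64, so the
# textbook formula would need log(0). The stable form stays finite.
print(bce_with_logits(50.0, 0.0))
```

<p>This is why the decoder in this post outputs raw logits and leaves the sigmoid to the loss function.</p>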
<p>This loss pushes the encoder to produce useful $\mathbf{z}$ vectors and pushes the decoder to learn how to interpret them accurately.</p>
<h4 id="2-the-kl-divergence-the-regularizer">2. The KL Divergence (The Regularizer)</h4>
<p>$$D_{KL}(q_{\phi}(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>On its own, the reconstruction loss might &ldquo;cheat.&rdquo; The encoder could learn to map every image to a different, specific point in the latent space, essentially &ldquo;memorizing&rdquo; the data. While this minimizes reconstruction error, it creates a meaningless latent space that fails at generation.</p>
<p>The KL divergence term fixes this. It&rsquo;s a regularizer that forces the latent space to be organized and smooth.</p>
<p>We force the encoder&rsquo;s output, $q_{\phi}(\mathbf{z} | \mathbf{x})$, to be close to a simple, predefined <em>prior distribution</em>, $p_{\theta}(\mathbf{z})$. This prior is almost always a standard normal distribution because it is mathematically convenient, easy to sample from, and encourages a well-behaved latent space.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/kl-loss-graphic.webp"
         alt="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         title="Graphic illustrating the KL divergence between the encoder&#39;s output distribution and the prior distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The KL Divergence measures how much the Encoder&rsquo;s output distribution diverges from the simple prior distribution.</figcaption>
    
</figure>

<p>This regularization term acts as a penalty, measuring how much the encoder&rsquo;s output distribution diverges from the simple prior. By minimizing this KL divergence, we encourage the model to:</p>
<ul>
<li><strong>Avoid overfitting</strong> by preventing the encoder from memorizing specific locations for each input</li>
<li><strong>Create meaningful clusters</strong> where similar inputs map to nearby regions in the latent space</li>
<li><strong>Maintain continuity</strong> so that points close together in latent space (like different variations of the digit &ldquo;7&rdquo;) decode into visually similar outputs</li>
</ul>
<p>This smooth, structured latent space is what enables generation: we can sample random points from our prior distribution and decode them into realistic new data.</p>
<p>Ultimately, the optimizer finds a balance between these two objectives: reconstructing the data well while keeping the latent space organized and regularized.</p>
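<p>When the encoder is a diagonal Gaussian and the prior is a standard normal, the KL term has a closed form: $D_{KL} = \frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2\right)$. The sketch below evaluates it in plain Python for a single data point. The function name <code>kl_to_standard_normal</code> and the example values are illustrative assumptions.</p>

```python
import math


def kl_to_standard_normal(mu, std):
    """KL( N(mu, diag(std^2)) || N(0, I) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + s * s - 1.0 - 2.0 * math.log(s) for m, s in zip(mu, std))


# An encoder output that already matches the prior pays no penalty.
kl_match = kl_to_standard_normal([0.0, 0.0], [1.0, 1.0])

# Shifting the mean away from zero increases the penalty.
kl_shifted = kl_to_standard_normal([2.0, 0.0], [1.0, 1.0])

# So does collapsing the std toward zero, i.e. "memorizing" a single point.
kl_narrow = kl_to_standard_normal([0.0, 0.0], [0.1, 1.0])

print(kl_match, kl_shifted, kl_narrow)
```

<p>The third case is the one that matters for the &ldquo;cheating&rdquo; scenario: the penalty grows without bound as any standard deviation shrinks to zero, so the encoder cannot place each input on its own isolated point for free.</p>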
<h3 id="the-reparameterization-trick-making-it-all-trainable">The Reparameterization Trick: Making it All Trainable</h3>
<p>We have a problem. The training process requires sampling:</p>
<ol>
<li>Encoder produces $\mathbf{\mu}$ and $\mathbf{\sigma}$.</li>
<li>We <strong>sample</strong> $\mathbf{z} \sim \mathcal{N}(\mathbf{\mu}, \mathbf{\sigma}^2\mathbf{I})$.</li>
<li>Decoder uses $\mathbf{z}$ to reconstruct $\mathbf{x}'$.</li>
<li>We calculate the loss.</li>
</ol>
<p>The &ldquo;sampling&rdquo; step is a random, non-differentiable operation. We can&rsquo;t backpropagate the reconstruction loss from the decoder <em>through</em> this random node to update the encoder&rsquo;s weights.</p>
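<p>The <strong>reparameterization trick</strong> makes the sampling process differentiable. Instead of sampling $\mathbf{z}$ directly, we sample a noise vector that does not depend on the parameters and compute $\mathbf{z}$ as a deterministic function of $\mathbf{\mu}$, $\mathbf{\sigma}$, and that noise:</p>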
<ol>
<li>Sample a random noise vector $\mathbf{\epsilon}$ from a simple, fixed distribution (e.g., the standard normal $\mathcal{N}(\mathbf{0}, \mathbf{I})$).</li>
<li>Compute $\mathbf{z}$ as: $\mathbf{z} = \mathbf{\mu} + \mathbf{\sigma} \odot \mathbf{\epsilon}$</li>
</ol>
<p>This simple change moves the randomness &ldquo;outside&rdquo; the network. The gradient can now flow deterministically from $\mathbf{z}$ back through the $\mathbf{\mu}$ and $\mathbf{\sigma}$ nodes to the encoder network. This is the key engineering insight that allows us to train the entire model end-to-end with standard backpropagation.</p>
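<p>The effect can be checked without an autograd library. The sketch below is a toy, not the training code: it assumes a scalar latent and a stand-in loss $f(z) = z^2$, applies the chain rule by hand through $z = \mu + \sigma \epsilon$, and compares the averaged gradients with the exact ones. Since $E[z^2] = \mu^2 + \sigma^2$, the exact gradients are $2\mu$ and $2\sigma$. All names are illustrative.</p>

```python
import random


def reparameterize(mu, sigma, eps):
    """z = mu + sigma * eps: deterministic in (mu, sigma) once eps is fixed."""
    return mu + sigma * eps


def pathwise_gradients(mu, sigma, num_samples, rng):
    """Estimate d/dmu and d/dsigma of E[z^2] by differentiating through z."""
    grad_mu = 0.0
    grad_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)          # the randomness does not depend on mu or sigma
        z = reparameterize(mu, sigma, eps)
        df_dz = 2.0 * z                    # derivative of the toy loss f(z) = z^2
        grad_mu += df_dz * 1.0             # chain rule: dz/dmu = 1
        grad_sigma += df_dz * eps          # chain rule: dz/dsigma = eps
    return grad_mu / num_samples, grad_sigma / num_samples


rng = random.Random(0)
mu, sigma = 0.5, 1.2
grad_mu, grad_sigma = pathwise_gradients(mu, sigma, 200_000, rng)

print(f"d/dmu:    estimate {grad_mu:.3f}, exact {2 * mu:.3f}")
print(f"d/dsigma: estimate {grad_sigma:.3f}, exact {2 * sigma:.3f}")
```

<p>In PyTorch, the same chain rule is applied automatically: <code>mu + eps * std</code> is an ordinary differentiable expression, so <code>backward()</code> reaches the encoder weights.</p>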
<h3 id="where-does-this-objective-come-from-the-math">Where Does This Objective Come From? (The Math)</h3>
<p>This two-part loss function is derived directly from the goal of maximizing the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x})$.</p>
<p>For a single data point $\mathbf{x}^{(i)}$, we can write:
$$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$</p>
<ul>
<li>The first term is the KL divergence between our encoder&rsquo;s approximation and the (intractable) true posterior. This is non-negative, and unfortunately we cannot compute it.</li>
<li>The second term, $\mathcal{L}$, is the Variational Lower Bound (also known as the Evidence Lower Bound, or ELBO). Since the KL term is $\ge 0$, we know that $\log p_{\theta}(\mathbf{x}^{(i)}) \ge \mathcal{L}$.</li>
</ul>
<p>By maximizing this lower bound $\mathcal{L}$, we push up the &ldquo;floor&rdquo; on the true likelihood of our data. This is a problem we can solve.</p>
<p>When we expand this $\mathcal{L}$ term, we get our famous two-part objective:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$</p>
<ul>
<li><strong>Term 1:</strong> The expected log-likelihood of reconstructing $\mathbf{x}^{(i)}$ from $\mathbf{z}$. Maximizing this is the same as minimizing the Reconstruction Loss.</li>
<li><strong>Term 2:</strong> The negative KL divergence between our encoder and the simple prior. Maximizing this is the same as minimizing the KL Divergence Loss.</li>
</ul>
<p>Thus, the VAE&rsquo;s objective balances these two critical goals: faithfully reconstructing the data while maintaining a simple, regularized latent structure that is useful for generation.</p>
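<p>For completeness, here is a sketch of the expansion step. It starts from the standard definition of the bound as an expectation under the encoder and uses the factorization $p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z}) = p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) \, p_{\theta}(\mathbf{z})$:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z}) - \log q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})]$$</p>
<p>$$= E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})] + E_{q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})}[\log p_{\theta}(\mathbf{z}) - \log q_{\phi}(\mathbf{z} | \mathbf{x}^{(i)})]$$</p>
<p>The second expectation is, by the definition of KL divergence, exactly $-D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$, which recovers the two terms of the objective.</p>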
<h3 id="from-elbo-to-practical-loss">From ELBO to Practical Loss</h3>
<p>Remember, our goal is to <strong>maximize</strong> the ELBO:</p>
<p>$$\mathcal{L}(\theta, \phi; \mathbf{x}) = E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] - D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>Since deep learning libraries are built to <strong>minimize</strong> a loss function, we simply flip the sign and <strong>minimize the negative ELBO ($-\mathcal{L}$)</strong>.</p>
<p>This gives us our final, practical loss function:</p>
<p>$$\text{Loss} = -\mathcal{L} = -E_{q_{\phi}(\mathbf{z} | \mathbf{x})}[\log p_{\theta}(\mathbf{x} | \mathbf{z})] + D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}) || p_{\theta}(\mathbf{z}))$$</p>
<p>This is the function you actually implement. Minimizing this loss achieves both of our goals:</p>
<ol>
<li>It <strong>minimizes the Reconstruction Loss</strong> (which is the same as maximizing the expected log-likelihood of the data under the decoder).</li>
<li>It <strong>minimizes the KL Divergence</strong>, pulling the encoder&rsquo;s distribution toward the prior.</li>
</ol>
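<p>Before moving to PyTorch, it can help to see the loss for a single data point written out in plain Python. This is a minimal sketch under stated assumptions: a Bernoulli decoder that outputs logits, a diagonal Gaussian encoder parameterized by a mean and log-variance, a standard normal prior, and a single latent sample to estimate the expectation. The function names are hypothetical and chosen only for this illustration.</p>

```python
import math


def bce_with_logits(x: list[float], logits: list[float]) -> float:
    """Summed Bernoulli negative log-likelihood, computed stably from logits."""
    # max(l, 0) - x * l + log(1 + exp(-|l|)) equals
    # -[x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l))]
    return sum(
        max(l, 0.0) - x_i * l + math.log1p(math.exp(-abs(l)))
        for x_i, l in zip(x, logits)
    )


def kl_to_standard_normal(mu: list[float], log_var: list[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I))."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


def negative_elbo(
    x: list[float], logits: list[float], mu: list[float], log_var: list[float]
) -> tuple[float, float, float]:
    """Return (loss, reconstruction term, KL term) for one data point."""
    recon = bce_with_logits(x, logits)
    kl = kl_to_standard_normal(mu, log_var)
    return recon + kl, recon, kl


# Toy inputs: uninformative logits and an encoder that already equals the prior.
loss, recon, kl = negative_elbo(
    x=[1.0, 0.0, 1.0],
    logits=[0.0, 0.0, 0.0],
    mu=[0.0, 0.0],
    log_var=[0.0, 0.0],
)
print(loss, recon, kl)
```

<p>With zero logits every pixel costs $\log 2$, and an encoder that matches the prior pays no KL penalty, so the whole loss here comes from reconstruction.</p>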
<h2 id="modern-pytorch-vae-implementation">Modern PyTorch VAE Implementation</h2>
<p>Now that we understand the VAE architecture and objective, let&rsquo;s implement a modern VAE in PyTorch. I&rsquo;ll focus primarily on the model and loss function here, though the full code is available <a href="https://github.com/hunter-heidenreich/vae">on GitHub</a>.</p>
<p>My VAE implementation uses two <code>dataclass</code>es, one for configuration and one for the forward-pass output, plus a VAE class extending <code>nn.Module</code>.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder (VAE) model implementation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">from</span> dataclasses <span style="color:#f92672">import</span> dataclass
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn <span style="color:#66d9ef">as</span> nn
</span></span><span style="display:flex;"><span><span style="color:#f92672">import</span> torch.nn.functional <span style="color:#66d9ef">as</span> F
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_activation</span>(activation: str) <span style="color:#f92672">-&gt;</span> nn<span style="color:#f92672">.</span>Module:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Get activation function by name.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>    activation_lower <span style="color:#f92672">=</span> activation<span style="color:#f92672">.</span>lower()
</span></span><span style="display:flex;"><span>    ACTIVATION_MAP <span style="color:#f92672">=</span> {
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;relu&#34;</span>: nn<span style="color:#f92672">.</span>ReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;tanh&#34;</span>: nn<span style="color:#f92672">.</span>Tanh(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;sigmoid&#34;</span>: nn<span style="color:#f92672">.</span>Sigmoid(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;leaky_relu&#34;</span>: nn<span style="color:#f92672">.</span>LeakyReLU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;elu&#34;</span>: nn<span style="color:#f92672">.</span>ELU(),
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;gelu&#34;</span>: nn<span style="color:#f92672">.</span>GELU(),
</span></span><span style="display:flex;"><span>    }
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> activation_lower <span style="color:#f92672">not</span> <span style="color:#f92672">in</span> ACTIVATION_MAP:
</span></span><span style="display:flex;"><span>        supported <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;, &#34;</span><span style="color:#f92672">.</span>join(ACTIVATION_MAP<span style="color:#f92672">.</span>keys())
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">raise</span> <span style="color:#a6e22e">ValueError</span>(
</span></span><span style="display:flex;"><span>            <span style="color:#e6db74">f</span><span style="color:#e6db74">&#34;Unsupported activation &#39;</span><span style="color:#e6db74">{</span>activation<span style="color:#e6db74">}</span><span style="color:#e6db74">&#39;. Supported: </span><span style="color:#e6db74">{</span>supported<span style="color:#e6db74">}</span><span style="color:#e6db74">&#34;</span>
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> ACTIVATION_MAP[activation_lower]
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEConfig</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE model configuration specifying architecture and behavior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    hidden_dim: int
</span></span><span style="display:flex;"><span>    latent_dim: int
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    input_shape: tuple[int, int, int] <span style="color:#f92672">=</span> (<span style="color:#ae81ff">1</span>, <span style="color:#ae81ff">28</span>, <span style="color:#ae81ff">28</span>)  <span style="color:#75715e"># Default: MNIST</span>
</span></span><span style="display:flex;"><span>    activation: str <span style="color:#f92672">=</span> <span style="color:#e6db74">&#34;tanh&#34;</span>  <span style="color:#75715e"># Default: tanh, what was used in the original VAE paper</span>
</span></span><span style="display:flex;"><span>    use_softplus_std: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>  <span style="color:#75715e"># Whether to use softplus for std parameterization</span>
</span></span><span style="display:flex;"><span>    n_samples: int <span style="color:#f92672">=</span> <span style="color:#ae81ff">1</span>  <span style="color:#75715e"># Number of latent samples per input during training</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#a6e22e">@dataclass</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAEOutput</span>:
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;VAE forward pass output containing all relevant tensors and optional losses.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    z: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    mu: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    x_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_recon: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>    loss_kl: torch<span style="color:#f92672">.</span>Tensor <span style="color:#f92672">|</span> <span style="color:#66d9ef">None</span> <span style="color:#f92672">=</span> <span style="color:#66d9ef">None</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">class</span> <span style="color:#a6e22e">VAE</span>(nn<span style="color:#f92672">.</span>Module):
</span></span><span style="display:flex;"><span>    <span style="color:#e6db74">&#34;&#34;&#34;Variational Autoencoder with support for deterministic and probabilistic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    DEFAULT_EPS <span style="color:#f92672">=</span> <span style="color:#ae81ff">1e-8</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">__init__</span>(self, config: VAEConfig) <span style="color:#f92672">-&gt;</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Initialize VAE with given configuration.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            config: VAE configuration specifying architecture and behavior
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        super()<span style="color:#f92672">.</span><span style="color:#a6e22e">__init__</span>()
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>config <span style="color:#f92672">=</span> config
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build encoder: input -&gt; hidden -&gt; latent parameters (mu, sigma)</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>encoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Flatten(),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape))), config<span style="color:#f92672">.</span>hidden_dim
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>hidden_dim, config<span style="color:#f92672">.</span>latent_dim <span style="color:#f92672">*</span> <span style="color:#ae81ff">2</span>),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Build decoder: latent -&gt; hidden -&gt; reconstructed input</span>
</span></span><span style="display:flex;"><span>        self<span style="color:#f92672">.</span>decoder <span style="color:#f92672">=</span> nn<span style="color:#f92672">.</span>Sequential(
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(config<span style="color:#f92672">.</span>latent_dim, config<span style="color:#f92672">.</span>hidden_dim),
</span></span><span style="display:flex;"><span>            get_activation(config<span style="color:#f92672">.</span>activation),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Linear(
</span></span><span style="display:flex;"><span>                config<span style="color:#f92672">.</span>hidden_dim, int(torch<span style="color:#f92672">.</span>prod(torch<span style="color:#f92672">.</span>tensor(config<span style="color:#f92672">.</span>input_shape)))
</span></span><span style="display:flex;"><span>            ),
</span></span><span style="display:flex;"><span>            nn<span style="color:#f92672">.</span>Unflatten(<span style="color:#ae81ff">1</span>, config<span style="color:#f92672">.</span>input_shape),
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">encode</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Encode input to latent distribution parameters.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        encoder_output <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encoder(x)
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>chunk(encoder_output, <span style="color:#ae81ff">2</span>, dim<span style="color:#f92672">=-</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu, sigma
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">decode</span>(self, z: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Decode latent representation to reconstruction logits.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> self<span style="color:#f92672">.</span>decoder(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">reparameterize</span>(self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Apply reparameterization trick for differentiable sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        epsilon <span style="color:#f92672">=</span> torch<span style="color:#f92672">.</span>randn_like(std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu <span style="color:#f92672">+</span> std <span style="color:#f92672">*</span> epsilon
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">forward</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        compute_loss: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">True</span>,
</span></span><span style="display:flex;"><span>        reconstruct: bool <span style="color:#f92672">=</span> <span style="color:#66d9ef">False</span>,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> VAEOutput:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Forward pass through the VAE.
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Args:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            x: Input tensor of shape (batch_size, *input_shape)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            compute_loss: Whether to compute VAE loss components
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            reconstruct: Whether to also return sigmoid-activated reconstructions (x_recon)
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            eps: Small epsilon value for numerical stability
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        Returns:
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">            VAEOutput containing all relevant tensors and optionally computed losses
</span></span></span><span style="display:flex;"><span><span style="color:#e6db74">        &#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Prepare input for multiple sampling if needed</span>
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_for_sampling(x) <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">&gt;</span> <span style="color:#ae81ff">1</span> <span style="color:#66d9ef">else</span> x
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Encode and sample from latent space</span>
</span></span><span style="display:flex;"><span>        mu, sigma <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>encode(x)
</span></span><span style="display:flex;"><span>        std <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_sigma_to_std(sigma, eps<span style="color:#f92672">=</span>eps)
</span></span><span style="display:flex;"><span>        mu_expanded, std_expanded <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_expand_latent_params(mu, std)
</span></span><span style="display:flex;"><span>        z <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>reparameterize(mu_expanded, std_expanded)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Decode latent samples</span>
</span></span><span style="display:flex;"><span>        x_logits <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>decode(z)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Create output object</span>
</span></span><span style="display:flex;"><span>        output <span style="color:#f92672">=</span> VAEOutput(
</span></span><span style="display:flex;"><span>            x_logits<span style="color:#f92672">=</span>x_logits,
</span></span><span style="display:flex;"><span>            z<span style="color:#f92672">=</span>z,
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">=</span>mu,
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">=</span>std,
</span></span><span style="display:flex;"><span>            x_recon<span style="color:#f92672">=</span>torch<span style="color:#f92672">.</span>sigmoid(x_logits) <span style="color:#66d9ef">if</span> reconstruct <span style="color:#66d9ef">else</span> <span style="color:#66d9ef">None</span>,
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Compute losses if requested</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> compute_loss:
</span></span><span style="display:flex;"><span>            loss, loss_recon, loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_loss(
</span></span><span style="display:flex;"><span>                x_expanded, x_logits, mu, sigma, std
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss <span style="color:#f92672">=</span> loss
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_recon <span style="color:#f92672">=</span> loss_recon
</span></span><span style="display:flex;"><span>            output<span style="color:#f92672">.</span>loss_kl <span style="color:#f92672">=</span> loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> output
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Helper Methods ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(
</span></span><span style="display:flex;"><span>        self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert raw encoder output (pre-softplus value or log-variance) to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_for_sampling</span>(self, x: torch<span style="color:#f92672">.</span>Tensor) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand input tensor for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        shape_dims <span style="color:#f92672">=</span> [<span style="color:#ae81ff">1</span>] <span style="color:#f92672">*</span> len(self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>        x_expanded <span style="color:#f92672">=</span> x<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)<span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#f92672">*</span>shape_dims)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> x_expanded<span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, <span style="color:#f92672">*</span>self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>input_shape)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_expand_latent_params</span>(
</span></span><span style="display:flex;"><span>        self, mu: torch<span style="color:#f92672">.</span>Tensor, std: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Expand latent parameters for multiple sampling.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples <span style="color:#f92672">==</span> <span style="color:#ae81ff">1</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> mu, std
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        mu_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            mu<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>        std_expanded <span style="color:#f92672">=</span> (
</span></span><span style="display:flex;"><span>            std<span style="color:#f92672">.</span>unsqueeze(<span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>repeat(<span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>n_samples, <span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>            <span style="color:#f92672">.</span>view(<span style="color:#f92672">-</span><span style="color:#ae81ff">1</span>, self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>latent_dim)
</span></span><span style="display:flex;"><span>        )
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> mu_expanded, std_expanded
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#75715e"># ==================== Loss Computation ====================</span>
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        x: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        x_logits: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> tuple[torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor, torch<span style="color:#f92672">.</span>Tensor]:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute VAE loss components for deterministic reconstruction.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        loss_recon <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_reconstruction_loss(x, x_logits)
</span></span><span style="display:flex;"><span>        loss_kl <span style="color:#f92672">=</span> self<span style="color:#f92672">.</span>_compute_kl_loss(mu, sigma, std)
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> loss_recon <span style="color:#f92672">+</span> loss_kl, loss_recon, loss_kl
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_reconstruction_loss</span>(
</span></span><span style="display:flex;"><span>        self, x: torch<span style="color:#f92672">.</span>Tensor, x_logits: torch<span style="color:#f92672">.</span>Tensor
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute reconstruction loss using binary cross-entropy.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>binary_cross_entropy_with_logits(
</span></span><span style="display:flex;"><span>            x_logits, x, reduction<span style="color:#f92672">=</span><span style="color:#e6db74">&#34;sum&#34;</span>
</span></span><span style="display:flex;"><span>        ) <span style="color:#f92672">/</span> x<span style="color:#f92672">.</span>size(<span style="color:#ae81ff">0</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_compute_kl_loss</span>(
</span></span><span style="display:flex;"><span>        self,
</span></span><span style="display:flex;"><span>        mu: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        sigma: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        std: torch<span style="color:#f92672">.</span>Tensor,
</span></span><span style="display:flex;"><span>        eps: float <span style="color:#f92672">=</span> DEFAULT_EPS,
</span></span><span style="display:flex;"><span>    ) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Compute KL divergence between latent distribution and standard normal prior.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#75715e"># Analytical KL: KL(N(μ,σ²) || N(0,1)) = 0.5 * Σ(μ² + σ² - 1 - log(σ²))</span>
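</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std <span style="color:#f92672">or</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma is the raw network output here, so use the derived std (σ) directly</span>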
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>                mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>            )
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#75715e"># sigma represents log-variance parameterization: log(σ²)</span>
</span></span><span style="display:flex;"><span>            kl_per_sample <span style="color:#f92672">=</span> <span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> sigma<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> sigma, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> kl_per_sample<span style="color:#f92672">.</span>mean()
</span></span></code></pre></div><h3 id="loss-scaling">Loss Scaling</h3>
<p>Both components of the VAE loss should be summed over the data (or latent) dimensions and then averaged over the batch. A common mistake is using the <code>reduction=&quot;mean&quot;</code> option in PyTorch loss functions, which averages over every element in the tensor, including the feature dimensions.</p>
<ul>
<li>The <strong>KL Divergence</strong> (<code>loss_kl</code>) is a penalty term. It is summed over latent dimensions, and each dimension can add complexity and deviate from the prior. As you increase the latent dimensionality, you typically see the KL loss increase in magnitude. That&rsquo;s the cost of having a more expressive latent space.</li>
<li>The <strong>Reconstruction Loss</strong> (<code>loss_recon</code>) measures how well the model reconstructs the input data. It is summed over input dimensions, so it scales with input dimensionality (this can bias the model toward better reconstruction for higher-dimensional data).</li>
</ul>
<p>In the case of MNIST, if we used <code>reduction=&quot;mean&quot;</code> for BCE, it would be averaged over all $784 \times \text{batch size}$ pixels, making it 784 times smaller than the per-sample sum and tiny compared to the KL loss. The KL term would dominate, and the model would learn to ignore the input, potentially leading to posterior collapse.</p>
<p>While modern optimizers can handle a variety of scenarios and you can still learn effective models with imperfect scaling, the original VAE paper used the scaling described above, and I recommend following that convention.</p>
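<p>To make the scaling concrete, here is a minimal, dependency-free sketch. It uses plain Python instead of PyTorch, and the targets and predictions are randomly generated stand-ins (an assumption for illustration, not real MNIST data). It compares &ldquo;sum over pixels, mean over batch&rdquo; against what <code>reduction=&quot;mean&quot;</code> would compute.</p>

```python
import math
import random


def bce(p, x):
    """Binary cross-entropy for one pixel with predicted probability p and target x."""
    return -(x * math.log(p) + (1.0 - x) * math.log(1.0 - p))


random.seed(0)
batch_size, num_pixels = 4, 784  # 784 = 28 * 28 MNIST pixels; the batch size is arbitrary

# Stand-in data (assumption): random binary targets and random predicted probabilities.
targets = [[float(random.randint(0, 1)) for _ in range(num_pixels)] for _ in range(batch_size)]
preds = [[random.uniform(0.05, 0.95) for _ in range(num_pixels)] for _ in range(batch_size)]

total = sum(
    bce(p, x)
    for pred_row, target_row in zip(preds, targets)
    for p, x in zip(pred_row, target_row)
)

recon_sum_then_batch_mean = total / batch_size      # sum over pixels, mean over batch
recon_mean_all = total / (batch_size * num_pixels)  # what reduction="mean" computes

ratio = recon_sum_then_batch_mean / recon_mean_all
print(f"sum/batch: {recon_sum_then_batch_mean:.2f}  mean: {recon_mean_all:.4f}  ratio: {ratio:.1f}")
```

<p>The ratio equals the number of pixels regardless of the data, which is why a mean-reduced reconstruction term ends up so small next to a KL term that is summed over latent dimensions.</p>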
<h3 id="mitigating-posterior-collapse-kl-annealingwarmup">Mitigating Posterior Collapse: KL Annealing/Warmup</h3>
<p>One common issue in training VAEs, especially with powerful decoders (like RNNs or deep CNNs), is <strong>posterior collapse</strong>. This happens when the KL term dominates the loss early in training. The model quickly learns to just output the prior distribution ($q(z|x) \approx p(z)$) to drive the KL loss to zero, so the latent code $z$ carries no information about the input. A sufficiently powerful decoder (an autoregressive one, for example) can then model the data on its own while ignoring the latent input.</p>
<p>To prevent this, we often use <strong>KL Annealing</strong> (or Warmup). We introduce a weight $\beta$ for the KL term that starts at 0 and slowly increases to 1 over the first $N$ steps or epochs.</p>
<p>$$ \mathcal{L} = \mathcal{L}_{recon} + \beta \cdot D_{KL} $$</p>
<p>This allows the model to focus purely on reconstruction first (using the full latent capacity), and then slowly adds the regularization pressure.</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Simple Linear Annealing Scheduler</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">def</span> <span style="color:#a6e22e">get_kl_weight</span>(step, total_steps, max_val<span style="color:#f92672">=</span><span style="color:#ae81ff">1.0</span>):
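</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">if</span> total_steps <span style="color:#f92672">&lt;=</span> <span style="color:#ae81ff">0</span>:  <span style="color:#75715e"># no warmup requested: use the full weight</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">return</span> max_val
</span></span><span style="display:flex;"><span>    val <span style="color:#f92672">=</span> (step <span style="color:#f92672">/</span> total_steps) <span style="color:#f92672">*</span> max_val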
</span></span><span style="display:flex;"><span>    <span style="color:#66d9ef">return</span> min(max(val, <span style="color:#ae81ff">0.0</span>), max_val)
</span></span><span style="display:flex;"><span>
</span></span><span style="display:flex;"><span><span style="color:#75715e"># In your training loop:</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> epoch <span style="color:#f92672">in</span> range(epochs):
</span></span><span style="display:flex;"><span>    beta <span style="color:#f92672">=</span> get_kl_weight(epoch, warmup_epochs)
</span></span><span style="display:flex;"><span>    loss <span style="color:#f92672">=</span> recon_loss <span style="color:#f92672">+</span> beta <span style="color:#f92672">*</span> kl_loss
</span></span></code></pre></div><h4 id="parameterizing-standard-deviation">Parameterizing Standard Deviation</h4>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span>    <span style="color:#66d9ef">def</span> <span style="color:#a6e22e">_sigma_to_std</span>(self, sigma: torch<span style="color:#f92672">.</span>Tensor, eps: float <span style="color:#f92672">=</span> DEFAULT_EPS) <span style="color:#f92672">-&gt;</span> torch<span style="color:#f92672">.</span>Tensor:
</span></span><span style="display:flex;"><span>        <span style="color:#e6db74">&#34;&#34;&#34;Convert sigma parameter to standard deviation.&#34;&#34;&#34;</span>
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">if</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">is</span> <span style="color:#f92672">not</span> <span style="color:#66d9ef">None</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>sigmoid(sigma) <span style="color:#f92672">*</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>bound_std <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">elif</span> self<span style="color:#f92672">.</span>config<span style="color:#f92672">.</span>use_softplus_std:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> F<span style="color:#f92672">.</span>softplus(sigma) <span style="color:#f92672">+</span> eps
</span></span><span style="display:flex;"><span>        <span style="color:#66d9ef">else</span>:
</span></span><span style="display:flex;"><span>            <span style="color:#66d9ef">return</span> torch<span style="color:#f92672">.</span>exp(<span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> sigma)  <span style="color:#75715e"># sigma represents log-variance</span>
</span></span></code></pre></div><p>Parameterizing the mean of the latent distribution is straightforward since each component satisfies $\mu_i \in \mathbb{R}$. However, the standard deviation $\sigma$ must be strictly positive (as must the variance $\sigma^2$). A network&rsquo;s raw outputs are unconstrained, so we need a transformation that maps them onto the positive reals.</p>
<p><strong>Log-Variance</strong>
One common approach is to have the network output the <strong>log-variance</strong> ($\log \sigma^2$). This is what the original VAE paper did. The idea is to allow the network to output any real number and treat that value as the log-variance, $s = \log \sigma^2$. We can then compute the standard deviation as $\sigma = \exp(0.5 s)$, which is always positive.</p>
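<p>The KL divergence formula simplifies nicely with this parameterization, since $\sigma_i^2 = e^{s_i}$ and $\log \sigma_i^2 = s_i$:</p>
<p>$$
\text{KL}( \mathcal{N}(\mu, \sigma^2) || \mathcal{N}(0, 1) ) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2) = \frac{1}{2} \sum_{i=1}^d (\mu_i^2 + e^{s_i} - 1 - s_i)
$$</p>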
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> s<span style="color:#f92672">.</span>exp() <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> s, dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>)
</span></span></code></pre></div><p><strong>Softplus Standard Deviation</strong>
An alternative is to have the network output $\sigma$ directly. This must be handled with care to ensure positivity. Strictly positive activations like <code>softplus</code> are required. Activations like <code>ReLU</code> can output zero, leading to numerical instability (during training) and deterministic behavior (during sampling). Additionally, adding a small epsilon value ensures numerical stability by preventing $\sigma$ from being exactly zero.</p>
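<p>The KL divergence is then computed from $\sigma$ directly, which requires an explicit logarithm (with epsilon added for stability) and makes the formula slightly more complex:</p>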
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#ae81ff">0.5</span> <span style="color:#f92672">*</span> torch<span style="color:#f92672">.</span>sum(
</span></span><span style="display:flex;"><span>    mu<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">-</span> <span style="color:#ae81ff">1</span> <span style="color:#f92672">-</span> torch<span style="color:#f92672">.</span>log(std<span style="color:#f92672">.</span>pow(<span style="color:#ae81ff">2</span>) <span style="color:#f92672">+</span> eps), dim<span style="color:#f92672">=</span><span style="color:#ae81ff">1</span>
</span></span><span style="display:flex;"><span>)
</span></span></code></pre></div><p><strong>Bounded Standard Deviation</strong>
Another option is to bound the standard deviation to a maximum value using a <code>sigmoid</code> transformation (or similar). This replaces mapping to $(0, \infty)$ with mapping to $(0, \text{bound})$. This helps prevent extremely high variance values that might destabilize training, at the cost of limiting the expressiveness of the latent distribution. As with <code>softplus</code>, adding a small epsilon ensures numerical stability by preventing $\sigma$ from being exactly zero or approaching it too closely.</p>
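<p>The three options can be put side by side in a small sketch. This is plain Python on scalars, not the tensor code shown earlier; <code>EPS</code> is an assumed stand-in for <code>DEFAULT_EPS</code> (whose value is not given here), and the default <code>bound=10.0</code> is chosen only for illustration. It also checks that the log-variance form of the KL agrees with the form written in terms of $\sigma$.</p>

```python
import math

EPS = 1e-6  # assumed value, standing in for DEFAULT_EPS


def sigmoid(s):
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def softplus(s):
    """Numerically stable log(1 + exp(s))."""
    return max(s, 0.0) + math.log1p(math.exp(-abs(s)))


def std_from_log_variance(s):
    return math.exp(0.5 * s)  # s = log(sigma^2)


def std_from_softplus(s):
    return softplus(s) + EPS


def std_from_bounded(s, bound=10.0):
    return sigmoid(s) * bound + EPS


def kl_from_std(mu, std):
    """Per-dimension KL(N(mu, std^2) || N(0, 1))."""
    return 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2))


def kl_from_log_variance(mu, s):
    """Same KL, written directly in terms of s = log(sigma^2)."""
    return 0.5 * (mu**2 + math.exp(s) - 1.0 - s)


mu, s = 0.7, -1.3
print(kl_from_log_variance(mu, s), kl_from_std(mu, std_from_log_variance(s)))
```

<p>The two printed values agree up to floating-point error, which is all the &ldquo;simplification&rdquo; in the log-variance case amounts to: the logarithm cancels against the exponential.</p>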
<p><strong>Gradient Behavior</strong>
All three parameterizations can work well in practice, but they have different gradient behaviors. Think of $g(s)$ as the transformation from the raw network output $s$ to the proper domain of $\sigma$ (or $\sigma^2$): in the log-variance case, $g(s) = \exp(s)$ yields $\sigma^2$; in the softplus case, $g(s) = \text{softplus}(s) + \epsilon$ yields $\sigma$; and in the bounded case, $g(s) = \text{bound} \cdot \text{sigmoid}(s) + \epsilon$ yields $\sigma$.</p>
<p>The gradient of the loss with respect to the raw output $s$ follows from the chain rule:</p>
<p>$$
\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial g(s)} \cdot \frac{\partial g(s)}{\partial s}
$$</p>
<p>where $\frac{\partial g(s)}{\partial s}$ is the derivative of the transformation function.</p>
<p>We need to guard against two pathological cases:</p>
<ul>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow 0$: This leads to vanishing gradients, making it hard for the network to learn.</li>
<li>$\frac{\partial g(s)}{\partial s} \rightarrow \infty$: This leads to exploding gradients, causing instability during training and potentially divergence.</li>
</ul>
<p>The log-variance parameterization, with its exponential transformation that is its own derivative, exhibits both issues at extreme values. If $s \rightarrow -\infty$, then $\sigma \rightarrow 0$ and the gradient vanishes. If $s \rightarrow \infty$, then $\sigma \rightarrow \infty$ and the gradient explodes. Since the interval $(0, 1)$ is mapped to $(-\infty, 0)$ in log-space, it&rsquo;s much more difficult for the network to drive $\sigma$ to small values. In practice, exploding gradients at high values have been more problematic in my experience. Gradient clipping, learning rate scheduling, and clamping the log-variance output to a maximum value can help mitigate this.</p>
<p>What about softplus? The derivative of <code>softplus</code> is the <code>sigmoid</code> function, which smoothly maps $(-\infty, \infty)$ to $(0, 1)$. Gradients are always bounded by unity, preventing explosion (barring explosion from other parts of the network). However, as $s \rightarrow -\infty$, the gradient approaches zero, leading to vanishing gradients. Adding a small epsilon keeps $\sigma$ from getting too close to zero, which stabilizes the logarithm in the KL term, but it does not restore the vanishing gradient, so learning can still slow down.</p>
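<p>For bounded standard deviation, the derivative of the <code>sigmoid</code> function is also bounded, preventing exploding gradients. (The gradient of <code>sigmoid</code> is defined in terms of itself: $\text{sig}^{\prime}(x) = \text{sig}(x)(1 - \text{sig}(x))$; its maximum value is $0.25$ at $x=0$, so the slope of the bounded transformation is at most $0.25 \cdot \text{bound}$.) Unlike softplus, this gradient vanishes at both extremes, as $s \rightarrow -\infty$ and as $s \rightarrow \infty$.</p>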
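<p>These derivative claims are easy to check numerically. The sketch below is plain Python on scalars; the function names are hypothetical and <code>bound=10.0</code> is an assumed value. It evaluates $\frac{\partial g(s)}{\partial s}$ for each transformation at a few raw outputs.</p>

```python
import math


def sigmoid(s):
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def d_log_variance(s):
    # g(s) = exp(s) is its own derivative: unbounded above, vanishing below.
    return math.exp(s)


def d_softplus(s):
    # d/ds softplus(s) = sigmoid(s): bounded by 1, vanishing as s goes to -inf.
    return sigmoid(s)


def d_bounded(s, bound=10.0):
    # d/ds [bound * sigmoid(s)] = bound * sig * (1 - sig): vanishes at both ends.
    sig = sigmoid(s)
    return bound * sig * (1.0 - sig)


for s in (-10.0, 0.0, 10.0):
    print(
        f"s={s:+.0f}  exp: {d_log_variance(s):.3e}  "
        f"softplus: {d_softplus(s):.3e}  bounded: {d_bounded(s):.3e}"
    )
```

<p>Only the exponential grows without limit for large $s$; the bounded variant is the only one whose slope shrinks on both sides of zero.</p>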















<figure class="post-figure center ">
    <img src="/img/vae-tut/gradient_behaviors.webp"
         alt="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         title="Graph comparing gradient behaviors of log-variance, softplus, and bounded standard deviation parameterizations"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Gradient behaviors of different standard deviation parameterizations: Log-Variance (exponential), Softplus, and Bounded Standard Deviation (sigmoid). Each has unique characteristics affecting training stability.</figcaption>
    
</figure>

<h2 id="experiments">Experiments</h2>
<h3 id="2d-mnist-vae-with-different-std-dev-parameterizations">2D MNIST VAE with Different Std. Dev. Parameterizations</h3>
<p>First, let&rsquo;s run an experiment that is close to what was done in the original VAE paper. We&rsquo;ll use MNIST as our dataset, a simple feedforward architecture with <code>tanh</code> activations, and the log-variance parameterization for the latent distribution.</p>
<p>Some of the differences from the original paper include:</p>
<ul>
<li>Using a hidden size of 512 (the original used 500)</li>
<li>Using the AdamW optimizer (the original used vanilla Adagrad)</li>
<li>Applying a similar weight decay, although it acts quite differently because of the optimizer change (AdamW decouples weight decay from the gradient update)</li>
<li>Focusing primarily on 2D latent spaces (for now)</li>
</ul>
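<p>As a sanity check on model size, the sketch below counts parameters for one layout consistent with this description. The layout itself is an assumption, not something spelled out here: a single 512-unit hidden layer in both the encoder and the decoder, with two separate linear heads (mean and sigma) mapping to the 2D latent space.</p>

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


input_dim, hidden_dim, latent_dim = 784, 512, 2  # 784 = 28 * 28 MNIST pixels

# Assumed layout: input -> hidden, then two heads (mu and sigma) -> latent.
encoder = linear_params(input_dim, hidden_dim) + 2 * linear_params(hidden_dim, latent_dim)
# Assumed layout: latent -> hidden -> input logits.
decoder = linear_params(latent_dim, hidden_dim) + linear_params(hidden_dim, input_dim)

total = encoder + decoder
print(total)  # 807700
```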
<p>This results in a network with 807,700 parameters.
I train each model for at most 150 epochs and report the checkpoint with the best reconstruction loss on the test set.
Just for fun, I sweep across different standard deviation parameterizations and learning rate warmup strategies.</p>
<table>
  <thead>
      <tr>
          <th>Std. Dev. Param</th>
          <th>LR Warmup Steps</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Log-Variance</td>
          <td>0</td>
          <td>140.88</td>
          <td>6.96</td>
          <td>147.84</td>
      </tr>
      <tr>
          <td>Log-Variance</td>
          <td>600</td>
          <td>141.41</td>
          <td>6.63</td>
          <td>148.04</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>0</td>
          <td>141.51</td>
          <td>6.56</td>
          <td>148.07</td>
      </tr>
      <tr>
          <td>Softplus</td>
          <td>600</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>0</td>
          <td>140.96</td>
          <td>6.82</td>
          <td>147.79</td>
      </tr>
      <tr>
          <td>Bounded Std. Dev. (10)</td>
          <td>600</td>
          <td>141.78</td>
          <td>6.68</td>
          <td>148.45</td>
      </tr>
  </tbody>
</table>
<p>From this summary table, all three parameterizations work well. The differences in final loss values are quite small: total test loss ranges only from 147.04 to 148.45. This could be due to the simplicity of the dataset and model architecture, further amplified by forcing the network to compress images into a very low-dimensional latent space (2D).</p>
<p>Since the softplus parameterization with learning rate warmup achieved the best reconstruction loss, let&rsquo;s visualize some of its training dynamics and results more closely.</p>
<h4 id="loss-dynamics">Loss Dynamics</h4>
<p>To understand the VAE&rsquo;s behavior, we must look at the ELBO and its two components: the Reconstruction Loss and the KL Divergence.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-elbo_epochs.webp"
         alt="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing ELBO across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Total ELBO: Training and testing ELBO across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstruction_loss_epochs.webp"
         alt="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing reconstruction loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Reconstruction Loss: Training and testing reconstruction loss across 150 epochs.</figcaption>
    
</figure>
















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-kl_loss_epochs.webp"
         alt="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         title="Plot showing training and testing KL divergence loss across 150 epochs for the softplus parameterization with learning rate warmup"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">KL Divergence: Training and testing KL divergence loss across 150 epochs.</figcaption>
    
</figure>

<p>These plots reveal a clear narrative:</p>
<ol>
<li><strong>Rapid Initial Learning:</strong> All three curves improve sharply in the first ~15 epochs.</li>
<li><strong>Overfitting:</strong> The <strong>Reconstruction Loss</strong> (middle) flatlines for the test set while continuing to improve for training, a classic sign of memorization.</li>
<li><strong>The Balancing Act:</strong> The <strong>KL Divergence</strong> (bottom) initially rises (&ldquo;The Cost of Learning&rdquo;) as the model stretches the latent space to encode digits, then saturates.</li>
<li><strong>Equilibrium:</strong> The total <strong>ELBO</strong> (top) improves slowly, driven by the model finding the optimal trade-off between reconstruction and regularization. Notice that the test and train KL track each other closely: a sign of good regularization!</li>
</ol>
<p><strong>Visualizing the VAE Trade-Off: BCE vs. KL</strong></p>
<p>While the line plots visualize progress over time, they miss the evolving <em>relationship</em> between our two competing objectives.</p>
<p>A VAE is fundamentally a multi-objective optimization problem. We want to:</p>
<ol>
<li>Minimize Reconstruction Loss (BCE)</li>
<li>Minimize KL Divergence</li>
</ol>
<p>Combining them as the ELBO is common and effective, though it can mask some of the underlying dynamics.</p>
<p>These two goals are in direct conflict. To get perfect reconstruction (BCE = 0), the encoder would need to &ldquo;memorize&rdquo; each input, mapping it to a unique, precise point in latent space. This would cause the KL divergence to skyrocket, as these specific, &ldquo;pointy&rdquo; distributions are nothing like our smooth <code>N(0, 1)</code> prior.</p>
<p>Conversely, to get perfect KL divergence (KL = 0), the encoder must <em>always</em> output <code>N(0, 1)</code>, regardless of the input. This perfectly matches the prior. Since the latent code $\mathbf{z}$ now contains zero information about the input $\mathbf{x}$, the decoder can only learn to output the &ldquo;average&rdquo; image, resulting in terrible reconstruction.</p>
<p>The training process is a search for the best compromise.</p>
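<p>The two extremes can be checked against the analytical KL for a single latent dimension. This is a minimal sketch in plain Python; the function name and the example values of $\mu$ and $\sigma$ are hypothetical. The KL is exactly zero when the encoder outputs the prior, and it grows as the posterior becomes more &ldquo;pointy&rdquo;.</p>

```python
import math


def kl_to_standard_normal(mu, std):
    """KL(N(mu, std^2) || N(0, 1)) for one latent dimension."""
    return 0.5 * (mu**2 + std**2 - 1.0 - math.log(std**2))


# Encoder ignores the input and outputs the prior: no KL cost at all.
print(kl_to_standard_normal(0.0, 1.0))  # 0.0

# Encoder pins an input to an increasingly precise location (mu = 2 is arbitrary).
for std in (0.5, 0.1, 0.01):
    print(f"std={std}: KL={kl_to_standard_normal(2.0, std):.3f}")
```

<p>The cost of a sharper posterior grows like $-\log \sigma$, so it is unbounded as $\sigma \rightarrow 0$, which is what keeps the encoder from memorizing each input as a single point.</p>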















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence, showing the training path from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training path on the Test set, plotting Reconstruction Loss (BCE) vs. KL Divergence. The model&rsquo;s journey clearly shows the trade-off between these two objectives.</figcaption>
    
</figure>

<p>This plot shows the test set&rsquo;s BCE (y-axis) vs. KL Divergence (x-axis) at every evaluation step. The color gradient from cool (blue) to warm (red) represents the training progress from Epoch 0 to 150.</p>
<p>Here&rsquo;s how to interpret this training path:</p>
<ol>
<li>
<p><strong>The Start (Green Diamond, ~Epoch 0):</strong> The model starts at the top-left.</p>
<ul>
<li><strong>High BCE (Reconstruction):</strong> The decoder is random and hasn&rsquo;t learned to reconstruct anything. Reconstruction is terrible.</li>
<li><strong>Low KL Divergence:</strong> The <em>encoder</em> is also random, so its output distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ are nearly independent of the input. With typical small-weight initialization, they have means near zero and standard deviations of order one, which is coincidentally close to the prior $p_{\theta}(\mathbf{z})$, so the KL penalty is low. The model isn&rsquo;t encoding any useful information yet, so it&rsquo;s not paying a high price for it.</li>
</ul>
</li>
<li>
<p><strong>Phase 1: The Initial Plunge (Blue Path):</strong> The path moves almost <em>straight down</em>.</p>
<ul>
<li><strong>BCE Plummets:</strong> The model&rsquo;s first and easiest task is to learn to reconstruct <em>something</em>. The optimizer finds massive, easy gains by making the decoder output &ldquo;blurry digits&rdquo; to replace the initial noise.</li>
<li><strong>KL Stays Low:</strong> The model achieves this huge reconstruction win without needing to learn a very complex latent space. It&rsquo;s the &ldquo;low-hanging fruit&rdquo; of training.</li>
</ul>
</li>
<li>
<p><strong>Phase 2: The Trade-Off (The &ldquo;Elbow&rdquo;):</strong> The path stops dropping vertically and starts moving to the <em>right and down</em>.</p>
<ul>
<li><strong>&ldquo;Spending&rdquo; KL to &ldquo;Buy&rdquo; Reconstruction:</strong> This is the true VAE trade-off in action. The easy wins are gone. To make the reconstructions sharper and more accurate (lowering BCE further), the model must now learn a more complex, informative latent representation.</li>
<li>It &ldquo;stretches&rdquo; the latent distributions $q_{\phi}(\mathbf{z} | \mathbf{x})$ to encode more details about each specific digit. This &ldquo;stretching&rdquo; moves it further from the simple <code>N(0, 1)</code> prior, and the KL divergence (the &ldquo;cost&rdquo;) goes up.</li>
</ul>
</li>
<li>
<p><strong>The End Game (Red Path &amp; Star):</strong> The path settles in the bottom-right corner.</p>
<ul>
<li><strong>Finding the &ldquo;Elbow&rdquo;:</strong> The model finds an equilibrium. It has pushed the KL divergence as high as it&rsquo;s &ldquo;worth&rdquo; for the reconstruction gains it gets. Trying to get even better reconstruction (moving further down) would cost an enormous, disproportionate amount in KL divergence (moving far to the right), and the total loss would increase.</li>
<li><strong>Best Recon (Orange Star):</strong> The best reconstruction model (Epoch 118) is found right at this &ldquo;elbow,&rdquo; representing the best-found balance point on the trade-off frontier.</li>
</ul>
</li>
</ol>
<p>This single plot visualizes the entire training dynamic as a journey toward, and then along, the <strong>Pareto frontier</strong>: the set of optimal solutions where you can&rsquo;t improve one objective (BCE) without worsening the other (KL).</p>
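<p>The definition of a Pareto frontier is short enough to write down. The sketch below uses a hypothetical helper, <code>pareto_front</code>, and applies it to the six (KL, reconstruction) pairs from the summary table earlier in this section. Those rows are separate training runs, not points on the training path in the plot, so this only illustrates the definition.</p>

```python
def pareto_front(points):
    """Return the points not dominated by any other point (both coordinates minimized)."""
    front = []
    for i, (kl_i, bce_i) in enumerate(points):
        dominated = any(
            kl_i >= kl_j and bce_i >= bce_j and (kl_j, bce_j) != (kl_i, bce_i)
            for j, (kl_j, bce_j) in enumerate(points)
            if j != i
        )
        if not dominated:
            front.append((kl_i, bce_i))
    return sorted(front)


# (test KL loss, test reconstruction loss) for the six runs in the summary table.
runs = [
    (6.96, 140.88),  # log-variance, no warmup
    (6.63, 141.41),  # log-variance, 600 warmup steps
    (6.56, 141.51),  # softplus, no warmup
    (6.67, 140.37),  # softplus, 600 warmup steps
    (6.82, 140.96),  # bounded (10), no warmup
    (6.68, 141.78),  # bounded (10), 600 warmup steps
]

print(pareto_front(runs))
# [(6.56, 141.51), (6.63, 141.41), (6.67, 140.37)]
```

<p>Three of the six runs are non-dominated: moving between them lowers one loss only by raising the other, which is the same trade-off the training path traces out.</p>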
<h4 id="generative-performance">Generative Performance</h4>
<p>Let&rsquo;s take a look at how well this model can decode samples.</p>
<p><strong>Reconstruction Performance</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         title="Grid of original and reconstructed MNIST images from the test set using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using the trained VAE model.</figcaption>
    
</figure>

<p>Immediately, we see a couple of key points:</p>
<ul>
<li>Reconstructions are quite blurry compared to the originals. This is expected given the low capacity of the model and the extreme compression into a 2D latent space. General structure is typically preserved, while fine details are lost.</li>
<li>The network struggles with 4s and 9s, often mixing them up or producing ambiguous shapes. This is a common failure mode in MNIST models due to the similarity of these digits.</li>
</ul>
<p><strong>Sampling from the Prior</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using the trained VAE model.</figcaption>
    
</figure>

<p>If we sample from the prior <code>N(0, 1)</code> and decode those samples, we get a variety of digit-like images. Almost all digit classes appear in this random sample, which suggests the decoder covers the prior reasonably well. Again, we see the standard blurriness.</p>
<p><strong>Sweeping the Latent Space</strong></p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_interpolation.webp"
         alt="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         title="Grid of images generated by sweeping across the 2D latent space of the trained VAE model"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Images generated by sweeping across the 2D latent space of the trained VAE model.</figcaption>
    
</figure>

<p>We can select two points at random (here, two zeros), embed them into our latent space, and then walk along the line between them to interpolate between the two data points.
Here, we see a walk that takes us from a zero that is askew to one that is more upright.</p>
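<p>The walk itself is just linear interpolation between two latent codes. The sketch below is plain Python with a hypothetical helper, <code>interpolate</code>, and made-up 2D latent codes; in practice the endpoints would be the encoder means of the two chosen images, and each point on the path would be passed through the decoder.</p>

```python
def interpolate(z_start, z_end, num_steps):
    """Evenly spaced points on the straight line from z_start to z_end, inclusive."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2 to include both endpoints")
    path = []
    for k in range(num_steps):
        t = k / (num_steps - 1)  # runs from 0.0 to 1.0
        path.append(tuple((1.0 - t) * a + t * b for a, b in zip(z_start, z_end)))
    return path


# Hypothetical latent codes for two encoded digits (not taken from the trained model).
z_a = (-1.5, 0.5)
z_b = (0.5, 2.0)

for z in interpolate(z_a, z_b, num_steps=5):
    print(z)
```

<p>Because the KL term pulls the encoded points toward a shared, roughly Gaussian region, points along this straight line tend to decode to plausible intermediate digits, which is what the figure shows.</p>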















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-generation_latent_sweep.webp"
         alt="2D latent sweep, varying one dimension at a time while holding the other constant"
         title="2D latent sweep, varying one dimension at a time while holding the other constant"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent sweep, varying one dimension at a time while holding the other constant.</figcaption>
    
</figure>

<p>Finally, we can also sweep each latent dimension independently to see how they affect the generated images.</p>
<ol>
<li>Sweeping <code>z_1</code> (top row), we see a 5 become an 8 and then a 9. The slant shifts from left to right as we sweep the dimension.</li>
<li>Sweeping <code>z_2</code> (bottom row), we see a 4 become a 9 and then an 8. Then it becomes a 3, a 2, some nonsense, and a 6.</li>
</ol>
<p>So clearly each latent dimension is encoding some high-level features of the digits, and we can manipulate those features by moving in latent space.</p>
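<p>A sweep like this can be sketched by holding one coordinate fixed and stepping the other across a range. The fixed value of 0 and the range of -3 to 3 below are assumptions chosen for illustration (roughly three standard deviations of the prior), and <code>sweep_dimension</code> is a hypothetical helper:</p>

```python
def sweep_dimension(base_z, dim, values):
    """Copy base_z once per value, overwriting only coordinate `dim`."""
    codes = []
    for value in values:
        z = list(base_z)
        z[dim] = value
        codes.append(z)
    return codes


# Seven evenly spaced values from -3 to 3.
values = [-3.0 + i for i in range(7)]
top_row = sweep_dimension([0.0, 0.0], dim=0, values=values)     # vary z_1
bottom_row = sweep_dimension([0.0, 0.0], dim=1, values=values)  # vary z_2
print(top_row[0], bottom_row[0])
```

<p>Decoding every code in <code>top_row</code> and <code>bottom_row</code> would give the two rows of such a figure.</p>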
<h4 id="inspecting-the-latent-space">Inspecting the Latent Space</h4>
<p>What does the actual latent space look like?</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_combined.webp"
         alt="2D latent space visualization with points colored by their true digit labels"
         title="2D latent space visualization with points colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with points colored by their true digit labels (left) and 2D heatmap of latent space density (right).</figcaption>
    
</figure>

<p>Even without class information, the network organizes the latent space to encode digit structure effectively. It also becomes immediately apparent why the model so often confuses 4s and 9s: that region is a dense mixture of the two.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-latent_marginals.webp"
         alt="1D histograms of each latent dimension compared to the standard normal distribution"
         title="1D histograms of each latent dimension compared to the standard normal distribution"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">1D histograms of each latent dimension compared to the standard normal distribution.</figcaption>
    
</figure>

<p>We can also look at the marginal distributions of each latent dimension to see how well they match the prior <code>N(0, 1)</code>. Here, <code>z_1</code> is closer to the prior than <code>z_2</code>. <code>z_2</code> exhibits a bimodal marginal distribution, indicating that the encoder is using this dimension to separate two distinct clusters of data.</p>
<p>We also might want to understand how the log-variance of the latent distributions behaves.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/z2-logvar_combined.webp"
         alt="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         title="2D latent space visualization with log-variance values and 1D histograms of log-variance for each latent dimension"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D latent space visualization with log-variance values with respect to digit class (left) and 2D heatmap of log-variance magnitude (right).</figcaption>
    
</figure>

<p>For the most part, we see similar concentration. Some digits are more concentrated than others, though in general the difference is slight.</p>
<h3 id="beyond-2d-higher-dimensional-latent-spaces">Beyond 2D: Higher-Dimensional Latent Spaces</h3>
<p>What happens as we increase the latent dimensionality? We must do dimensionality reduction to visualize latent spaces, giving us an approximate sense of how the latent space is organized.</p>
<table>
  <thead>
      <tr>
          <th>Latent Dimensionality</th>
          <th>Test Recon. Loss</th>
          <th>Test KL Loss</th>
          <th>Test Total Loss</th>
          <th>KL per Dim</th>
          <th>Active Dims (KL &gt; 0.1)</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>2</td>
          <td>140.37</td>
          <td>6.67</td>
          <td>147.04</td>
          <td>3.34</td>
          <td>2</td>
      </tr>
      <tr>
          <td>4</td>
          <td>114.94</td>
          <td>10.61</td>
          <td>125.56</td>
          <td>2.65</td>
          <td>4</td>
      </tr>
      <tr>
          <td>8</td>
          <td>89.46</td>
          <td>16.84</td>
          <td>106.31</td>
          <td>2.11</td>
          <td>8</td>
      </tr>
      <tr>
          <td>16</td>
          <td>76.57</td>
          <td>23.65</td>
          <td>100.21</td>
          <td>1.48</td>
          <td>16</td>
      </tr>
      <tr>
          <td>32</td>
          <td>74.65</td>
          <td>25.59</td>
          <td>100.25</td>
          <td>0.80</td>
          <td>24</td>
      </tr>
  </tbody>
</table>
<p>As we double the dimensionality, we see a dominant trend at first:</p>
<ul>
<li>The reconstruction loss goes down</li>
<li>The KL loss goes up</li>
<li>The KL loss per dimension goes down</li>
</ul>
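<p>Something odd happens when we jump from 16 to 32 latent dimensions: only 24 of the 32 dimensions stay active (KL &gt; 0.1), and the total test loss stops improving (100.21 vs. 100.25). The remaining dimensions become degenerate and stop encoding useful information.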
This could be an indication we need to choose our hyperparameters a little more cautiously. Perhaps we need a different architecture. Or maybe there is an intrinsic limit to the dimensionality needed for this dataset past which it&rsquo;s not really helpful to keep scaling the latent dimension.</p>
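<p>The last two columns of the table can be reproduced from encoder outputs. The sketch below assumes a diagonal Gaussian posterior and a standard normal prior, for which the KL of each dimension is $\frac{1}{2}(\mu^2 + e^{\log\sigma^2} - \log\sigma^2 - 1)$. The inputs are made-up numbers, and the helper names are hypothetical. (The "KL per Dim" column is simply the total divided by the dimensionality, e.g. 25.59 / 32 ≈ 0.80.)</p>

```python
import math


def kl_per_dimension(mus, logvars):
    """Average KL(q(z|x) || N(0, I)) per latent dimension over a batch.

    mus and logvars are lists of per-example lists (one entry per dimension).
    """
    num_examples = len(mus)
    num_dims = len(mus[0])
    totals = [0.0] * num_dims
    for mu, logvar in zip(mus, logvars):
        for d in range(num_dims):
            totals[d] += 0.5 * (mu[d] ** 2 + math.exp(logvar[d]) - logvar[d] - 1.0)
    return [total / num_examples for total in totals]


def count_active_dims(kl_values, threshold=0.1):
    """Count dimensions whose average KL exceeds the threshold."""
    return sum(1 for kl in kl_values if kl > threshold)


# Made-up encoder outputs: dimension 0 carries information,
# dimension 1 has collapsed onto the prior (mu = 0, logvar = 0).
mus = [[1.5, 0.0], [-1.5, 0.0]]
logvars = [[-2.0, 0.0], [-2.0, 0.0]]
kl = kl_per_dimension(mus, logvars)
print(kl, count_active_dims(kl))
```

<p>A dimension whose posterior matches the prior for every input contributes zero KL and carries no information about the input, which is what "inactive" means here.</p>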
<h4 id="training-dynamics">Training Dynamics</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/loss_scatter_epochs.webp"
         alt="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         title="Scatter plot of Test BCE vs KL Divergence for different latent dimensionalities, showing training paths from epoch 0 to 150"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">The training paths on the Test set for different latent dimensionalities, plotting Reconstruction Loss (BCE) vs. KL Divergence. Each path shows the model&rsquo;s journey, clearly illustrating the trade-off between these two objectives.</figcaption>
    
</figure>

<p>The training dynamics show the battle between reconstruction and KL divergence for different latent dimensionalities. As we increase the latent dimensionality, the oscillation in the KL divergence becomes more pronounced. Particularly chaotic is the $D=16$ case, which struggles to find a stable equilibrium. By the time we expand to $D=32$, the KL penalty seems to overpower the ability to encode information in the latent space, leading to many inactive dimensions. The drop in KL complexity has staircase-like steps without clearly gaining reconstruction ability.</p>
<h4 id="reconstruction-and-generation">Reconstruction and Generation</h4>
<p>As we increase the latent dimensionality, the reconstruction quality improves significantly.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/reconstructions.webp"
         alt="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         title="Grid of original and reconstructed MNIST images from the test set using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">Original (top row) vs. Reconstructed (bottom row) MNIST images from the test set using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>As we increase the dimensionality, we see the increase in quality we&rsquo;d expect given the reduction in BCE reconstruction loss. In the jump to 4D, we&rsquo;re able to better resolve the differences between 4s and 9s. Images become much sharper by the time we hit 16 dimensions. The differences between 16 and 32 dimensions, however, are marginal.</p>















<figure class="post-figure center ">
    <img src="/img/vae-tut/samples.webp"
         alt="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         title="Grid of MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">MNIST-like images generated by sampling from the prior distribution using trained VAE models with different latent dimensionalities.</figcaption>
    
</figure>

<p>Sampling quality also improves with latent dimensionality. Images are sharper as we increase the dimensionality. However, the space seems to get sparser as we increase to the largest dimensionalities, which makes sense given the size and nature of our dataset.</p>
<h4 id="latent-space-visualizations">Latent Space Visualizations</h4>















<figure class="post-figure center ">
    <img src="/img/vae-tut/latent_combined.webp"
         alt="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         title="2D PCA projections of higher-dimensional latent spaces colored by their true digit labels"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption">2D PCA projections of higher-dimensional latent spaces colored by their true digit labels.</figcaption>
    
</figure>

<p>The challenge with visualizing higher-dimensional latent spaces is that we must reduce their dimensionality to 2D. A two-component PCA projection captures less and less of the total variance as the latent dimensionality grows. The 4D and 8D plots suggest increasingly better separation of the digit classes. However, the 16D and 32D plots only show 10-20% of the variance and give a misleading impression of overlap.</p>
<h2 id="conclusion">Conclusion</h2>
<p>In this tutorial, we&rsquo;ve journeyed from the core theory of Variational Autoencoders to a practical, modern PyTorch implementation and a series of experiments on the MNIST dataset. Our findings highlight several key takeaways for practitioners:</p>
<ol>
<li>
<p><strong>The VAE is a Balancing Act:</strong> The fundamental tension between reconstruction fidelity and latent space regularization is the core of the VAE. Our visualization of the BCE vs. KL loss trade-off clearly showed training as a search for an optimal point on this Pareto frontier, where improving one objective necessarily means sacrificing the other.</p>
</li>
<li>
<p><strong>Latent Dimensionality is a Critical Hyperparameter:</strong> Increasing the latent dimension consistently improved reconstruction quality with diminishing returns. As we saw in the jump from 16 to 32 dimensions, too much capacity can lead to &ldquo;inactive&rdquo; dimensions, where the KL penalty overpowers the model&rsquo;s ability to encode useful information. This demonstrates that choosing the right latent size is crucial for both performance and efficiency.</p>
</li>
<li>
<p><strong>VAEs Learn Meaningful Unsupervised Representations:</strong> Without any labels, our VAE successfully organized the latent space, clustering similar digits and enabling smooth interpolations. This underscores the power of VAEs for unsupervised learning, dimensionality reduction, and discovering the underlying structure in complex data.</p>
</li>
<li>
<p><strong>Implementation Details Matter:</strong> While different standard deviation parameterizations yielded similar results on this simple problem, understanding their gradient behaviors is key for tackling more complex datasets where training stability can be a major challenge. Proper loss scaling is similarly crucial to prevent one term from dominating the other and leading to issues like posterior collapse.</p>
</li>
</ol>
<p>While the classic VAE produces characteristically blurry reconstructions, it remains a foundational generative model. The principles we&rsquo;ve explored here (the ELBO, the reparameterization trick, and the trade-off between reconstruction and regularization) are central to many more advanced generative models used today.</p>
<p><strong>Questions or feedback?</strong> Feel free to reach out. I&rsquo;d love to hear about your experiences with VAE experiments!</p>
]]></content:encoded></item><item><title>Understanding GANs: From Fundamentals to Objective Functions</title><link>https://hunterheidenreich.com/posts/what-is-a-gan/</link><pubDate>Sat, 18 Aug 2018 00:00:00 +0000</pubDate><guid>https://hunterheidenreich.com/posts/what-is-a-gan/</guid><description>A complete guide to Generative Adversarial Networks (GANs), covering intuitive explanations, mathematical foundations, and objective functions.</description><content:encoded><![CDATA[<h2 id="understanding-generative-models">Understanding Generative Models</h2>
<p>Modern generative AI is dominated by diffusion models and autoregressive transformers. The adversarial training dynamics and objective functions introduced by <a href="https://arxiv.org/abs/1406.2661">Generative Adversarial Networks</a> (GANs) still inform how the field thinks about loss-function design and training stability today. Before diving into GANs, let&rsquo;s establish what we&rsquo;re trying to accomplish with generative models.</p>
<p><strong>The core goal</strong>: Create a system that can generate new, realistic data that appears to come from the same distribution as our training data.</p>
<p>Think of having a model that can create images, text, or audio that are difficult to distinguish from human-created content. This is what generative modeling aims to achieve.</p>
<h3 id="the-mathematical-foundation">The Mathematical Foundation</h3>
<p>Generative models aim to estimate the probability distribution of real data. If we have parameters $\theta$, we want to find the optimal $\theta^*$ that maximizes the likelihood of observing our real samples:</p>
<p>$$
\theta^* = \arg\max_\theta \prod_{i=1}^{n} p_\theta(x_i)
$$</p>
<p>This is equivalent to minimizing a measure of discrepancy between our estimated distribution and the true data distribution. A common choice is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>: maximizing the log-likelihood is equivalent to minimizing $\text{KL}(p_{\text{data}} \parallel p_\theta)$.</p>
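<p>To see why the two problems share a solution, here is a short derivation from the definition of KL divergence. The symbol $p_{\text{data}}$ for the true data distribution is an assumption of this sketch, chosen to match the GAN objective used later in the post:</p>
<p>$$
\text{KL}(p_{\text{data}} \parallel p_\theta) = \mathbb{E}_{x \sim p_{\text{data}}}[\log p_{\text{data}}(x)] - \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)]
$$</p>
<p>The first term does not depend on $\theta$, so minimizing the divergence only involves the second term. Replacing the expectation with an average over the $n$ training samples recovers the log of the likelihood above:</p>
<p>$$
\arg\min_\theta \text{KL}(p_{\text{data}} \parallel p_\theta) = \arg\max_\theta \mathbb{E}_{x \sim p_{\text{data}}}[\log p_\theta(x)] \approx \arg\max_\theta \frac{1}{n} \sum_{i=1}^{n} \log p_\theta(x_i)
$$</p>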
<h3 id="two-approaches-to-generative-modeling">Two Approaches to Generative Modeling</h3>
<h4 id="explicit-distribution-models">Explicit Distribution Models</h4>
<p>These models define an explicit probability distribution and refine it through training.</p>
<p><strong>Example</strong>: <a href="https://arxiv.org/abs/1606.05908">Variational Auto-Encoders</a> (VAEs) require:</p>
<ul>
<li>An explicitly assumed prior distribution</li>
<li>A likelihood distribution</li>
<li>A &ldquo;variational approximation&rdquo; to the intractable posterior, used to train and evaluate the model</li>
</ul>
<h4 id="implicit-distribution-models">Implicit Distribution Models</h4>
<p>These models never write down an explicit density. Instead, they learn a procedure that draws samples from the learned distribution directly. GANs exemplify this implicit approach, learning distributions through adversarial competition.</p>















<figure class="post-figure center ">
    <img src="/img/gen_ai_types.webp"
         alt="Types of deep generative models showing taxonomy"
         title="Types of deep generative models showing taxonomy"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>Taxonomy of Deep Generative Models</strong>: GANs fall into the implicit density category, learning distributions through adversarial training. <em>Source: NeurIPS 2016 tutorial on Generative Adversarial Networks</em></figcaption>
    
</figure>

<h2 id="the-gan-architecture-a-game-of-deception">The GAN Architecture: A Game of Deception</h2>
<p>Generative Adversarial Networks get their name from three key components:</p>
<ul>
<li><strong>Generative</strong>: They create new data</li>
<li><strong>Adversarial</strong>: Two networks compete against each other</li>
<li><strong>Networks</strong>: Built using neural networks</li>
</ul>
<p>The core innovation is the adversarial setup: two neural networks compete against each other, driving mutual improvement.</p>















<figure class="post-figure center ">
    <img src="/img/GAN-70.webp"
         alt="Diagram showing data flow through a GAN architecture"
         title="Diagram showing data flow through a GAN architecture"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>GAN Data Flow</strong>: The generator creates fake samples from random noise, while the discriminator tries to distinguish real from fake data. This adversarial competition drives both networks to improve.</figcaption>
    
</figure>

<h3 id="the-generator-the-forger">The Generator: The Forger</h3>
<p><strong>Role</strong>: Create convincing fake data from random noise</p>
<p>The generator network $G$ learns a mapping function:
$$z \rightarrow G(z) \approx x_{\text{real}}$$</p>
<p>Where:</p>
<ul>
<li>$z$ is a random latent vector (the &ldquo;noise&rdquo;)</li>
<li>$G(z)$ is the generated sample</li>
<li>The goal is making $G(z)$ indistinguishable from real data</li>
</ul>
<p><strong>Key insight</strong>: The latent space $z$ is continuous, so in a well-trained generator small changes in $z$ tend to produce smooth, meaningful changes in the generated output.</p>
<h3 id="the-discriminator-the-detective">The Discriminator: The Detective</h3>
<p><strong>Role</strong>: Distinguish between real and generated samples</p>
<p>The discriminator network $D$ outputs a probability:
$$D(x) = P(x \text{ is real})$$</p>
<ul>
<li>$D(x) \approx 1$ for real samples</li>
<li>$D(x) \approx 0$ for fake samples</li>
</ul>
<p>It functions as an &ldquo;authenticity detector&rdquo; that progressively improves.</p>
<h3 id="the-adversarial-competition">The Adversarial Competition</h3>
<p>This adversarial dynamic drives the training process. The generator and discriminator have <strong>directly opposing objectives</strong>:</p>
<table>
  <thead>
      <tr>
          <th>Generator Goal</th>
          <th>Discriminator Goal</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Fool the discriminator</td>
          <td>Correctly classify all samples</td>
      </tr>
      <tr>
          <td>Maximize $D(G(z))$</td>
          <td>Maximize $D(x_{\text{real}})$ and minimize $D(G(z))$</td>
      </tr>
      <tr>
          <td>&ldquo;Create convincing fakes&rdquo;</td>
          <td>&ldquo;Never be fooled&rdquo;</td>
      </tr>
  </tbody>
</table>
<p>This creates a dynamic where both networks continuously improve:</p>
<ul>
<li>Generator creates better fakes to fool the discriminator</li>
<li>Discriminator becomes better at detecting fakes</li>
<li>The cycle continues until equilibrium</li>
</ul>















<figure class="post-figure center ">
    <img src="/img/GAN-SUMMARY-50.webp"
         alt="Illustration of GAN training process showing adversarial competition"
         title="Illustration of GAN training process showing adversarial competition"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>The Adversarial Training Process</strong>: Through competition, both networks improve. The generator learns to create increasingly realistic samples while the discriminator becomes more discerning.</figcaption>
    
</figure>

<h2 id="learning-through-metaphors">Learning Through Metaphors</h2>
<p>Relatable analogies often clarify complex concepts. Here are two metaphors that capture different aspects of how GANs work.</p>
<h3 id="the-art-forger-vs-critic">The Art Forger vs. Critic</h3>
<p><strong>Generator = Art Forger</strong><br>
<strong>Discriminator = Art Critic</strong></p>
<p>A criminal forger tries to create fake masterpieces, while an art critic must identify authentic works. Each interaction teaches both parties:</p>
<ul>
<li>The forger learns what makes art look authentic</li>
<li>The critic develops a keener eye for detecting fakes</li>
<li>Eventually, the forger becomes so skilled that even experts can&rsquo;t tell the difference</li>
</ul>
<p><em>This captures the adversarial nature and continuous improvement aspect of GANs.</em></p>
<h3 id="the-counterfeiter-vs-bank-teller">The Counterfeiter vs. Bank Teller</h3>
<p><strong>Generator = Counterfeiter</strong><br>
<strong>Discriminator = Bank Teller</strong></p>
<p>Day 1: The counterfeiter brings in a crayon drawing of a dollar bill. Even a new teller spots this fake.</p>
<p>Day 100: The counterfeiter has learned better techniques. The teller has developed expertise in security features.</p>
<p>Day 1000: The fake money is so convincing that detecting it requires advanced equipment.</p>
<p><em>This illustrates the progressive improvement and escalating sophistication in both networks.</em></p>
<h2 id="the-mathematical-foundation-1">The Mathematical Foundation</h2>
<p>Now let&rsquo;s examine the mathematical framework that makes GANs work. The core of GAN training is solving a <strong>minimax optimization problem</strong>.</p>
<h3 id="the-minimax-objective">The Minimax Objective</h3>
<p>$$
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$</p>
<p><strong>Breaking this down:</strong></p>
<ul>
<li>$\mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)]$: The expected log-probability for real data.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term to correctly classify real samples.</li>
</ul>
</li>
<li>$\mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$: The expected log-probability for fake data being correctly identified as fake.
<ul>
<li><strong>Discriminator&rsquo;s Goal</strong>: Maximize this term.</li>
<li><strong>Generator&rsquo;s Goal</strong>: Minimize this term to fool the discriminator.</li>
</ul>
</li>
</ul>
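<p>To make the two terms concrete, the value function can be estimated from a batch of discriminator outputs. This is a minimal sketch in plain Python with made-up probabilities; <code>gan_value</code> is a hypothetical helper, not part of any library:</p>

```python
import math


def gan_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) from discriminator outputs in (0, 1).

    d_real: D(x) on real samples; d_fake: D(G(z)) on generated samples.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# Made-up discriminator outputs for illustration.
confident_d = gan_value(d_real=[0.9, 0.95], d_fake=[0.05, 0.1])
fooled_d = gan_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])
print(confident_d, fooled_d)
```

<p>A discriminator that separates the two batches well pushes the value toward 0, while a fully fooled discriminator that outputs 0.5 everywhere yields $-\log 4$.</p>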
<h3 id="why-minimax">Why &ldquo;Minimax&rdquo;?</h3>
<ul>
<li><strong>Discriminator ($D$)</strong>: Tries to <strong>maximize</strong> the objective → Better at distinguishing real from fake.</li>
<li><strong>Generator ($G$)</strong>: Tries to <strong>minimize</strong> the objective → Better at fooling the discriminator.</li>
</ul>
<h3 id="a-practical-challenge-vanishing-gradients">A Practical Challenge: Vanishing Gradients</h3>
<p>The minimax objective presents a practical problem early in training. When the generator is poor, the discriminator can easily distinguish real from fake samples with high confidence ($D(G(z)) \approx 0$). This causes $\log(1 - D(G(z)))$ to saturate and results in vanishing gradients for the generator, which effectively stalls learning.</p>
<p><strong>The Solution</strong>: Practitioners typically train the generator to <strong>maximize</strong> $\log(D(G(z)))$ to provide stronger gradients early in training. This non-saturating heuristic prevents the learning process from stalling.</p>
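<p>The difference between the two generator losses shows up in their gradients. The sketch below assumes the discriminator ends in a sigmoid, $D = \sigma(a)$, and compares the derivative of each loss with respect to the logit $a$ on a fake sample. Both derivatives follow from the chain rule, and the function names are illustrative:</p>

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def saturating_grad(logit):
    """d/d(logit) of log(1 - sigmoid(logit)), the original generator loss."""
    return -sigmoid(logit)


def non_saturating_grad(logit):
    """d/d(logit) of -log(sigmoid(logit)), the non-saturating generator loss."""
    return sigmoid(logit) - 1.0


# A confident discriminator early in training: D(G(z)) is close to 0.
logit = -6.0
print("D(G(z)) =", sigmoid(logit))
print("saturating gradient     =", saturating_grad(logit))
print("non-saturating gradient =", non_saturating_grad(logit))
```

<p>When $D(G(z))$ is near 0, the saturating gradient is about as small as $D(G(z))$ itself, while the non-saturating gradient stays close to 1 in magnitude, so the generator keeps receiving a usable signal.</p>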
<h3 id="the-training-process">The Training Process</h3>
<p>The beauty of GANs lies in their alternating optimization:</p>
<ol>
<li><strong>Fix $G$, train $D$</strong>: Make the discriminator optimal for the current generator</li>
<li><strong>Fix $D$, train $G$</strong>: Improve the generator against the current discriminator</li>
<li><strong>Repeat</strong>: Continue until reaching Nash equilibrium</li>
</ol>
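<p>The alternation can be sketched with a deliberately tiny model. Everything here is an assumption made for illustration: a 1D generator $G(z) = z + \theta$ that only learns a shift, a logistic discriminator $D(x) = \sigma(wx + b)$, fixed batches instead of random sampling, hand-derived gradients, and the non-saturating generator update. Real GANs use neural networks and an autodiff framework.</p>

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def discriminator_step(w, b, theta, real, noise, lr):
    """One gradient-ascent step on V for D(x) = sigmoid(w * x + b), with G fixed."""
    fake = [z + theta for z in noise]
    grad_w = sum((1.0 - sigmoid(w * x + b)) * x for x in real) / len(real)
    grad_w -= sum(sigmoid(w * x + b) * x for x in fake) / len(fake)
    grad_b = sum(1.0 - sigmoid(w * x + b) for x in real) / len(real)
    grad_b -= sum(sigmoid(w * x + b) for x in fake) / len(fake)
    return w + lr * grad_w, b + lr * grad_b


def generator_step(w, b, theta, noise, lr):
    """One ascent step on log D(G(z)) for G(z) = z + theta, with D fixed."""
    fake = [z + theta for z in noise]
    grad_theta = sum((1.0 - sigmoid(w * x + b)) * w for x in fake) / len(fake)
    return theta + lr * grad_theta


real = [3.0, 4.0, 5.0]     # fixed "real" batch centered at 4
noise = [-1.0, 0.0, 1.0]   # fixed latent batch
w, b, theta = 0.0, 0.0, 0.0
for _ in range(3):
    w, b = discriminator_step(w, b, theta, real, noise, lr=0.1)  # fix G, train D
    theta = generator_step(w, b, theta, noise, lr=0.1)           # fix D, train G
print(w, b, theta)
```

<p>After a few rounds the discriminator has learned that larger values look more real, and the generator's shift $\theta$ has started moving toward the real data.</p>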
<h3 id="theoretical-goal-nash-equilibrium">Theoretical Goal: Nash Equilibrium</h3>
<p>At convergence, the discriminator outputs $D(x) = 0.5$ for all samples, meaning it can&rsquo;t distinguish between real and fake data. This indicates that $p_{\text{generator}} = p_{\text{data}}$. Our generator has learned the true data distribution.</p>
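<p>The 0.5 output can be checked numerically. For a fixed generator, the discriminator that maximizes $V$ is $D^*(x) = p_{\text{data}}(x) / (p_{\text{data}}(x) + p_g(x))$, a standard result from the original GAN paper rather than something derived in this post. The sketch below evaluates it on small, made-up discrete distributions:</p>

```python
import math


def optimal_discriminator(p_data, p_gen):
    """D*(x) = p_data(x) / (p_data(x) + p_gen(x)) for each outcome x."""
    return [p / (p + q) for p, q in zip(p_data, p_gen)]


def value_at_optimum(p_data, p_gen):
    """V(D*, G) for discrete distributions with full support."""
    d_star = optimal_discriminator(p_data, p_gen)
    real_term = sum(p * math.log(d) for p, d in zip(p_data, d_star))
    fake_term = sum(q * math.log(1.0 - d) for q, d in zip(p_gen, d_star))
    return real_term + fake_term


p_data = [0.2, 0.5, 0.3]
matched = value_at_optimum(p_data, p_gen=[0.2, 0.5, 0.3])
mismatched = value_at_optimum(p_data, p_gen=[0.6, 0.3, 0.1])
print(optimal_discriminator(p_data, p_data))
print(matched, -math.log(4.0), mismatched)
```

<p>When the generator matches the data, $D^*$ is 0.5 everywhere and the value equals $-\log 4$. Any mismatch leaves the discriminator something to exploit, so the value is higher.</p>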
<h2 id="the-evolution-of-objective-functions">The Evolution of Objective Functions</h2>
<p>The objective function is the mathematical heart of any GAN. It defines how we measure the &ldquo;distance&rdquo; between our generated distribution and the real data distribution. This choice profoundly impacts:</p>
<ul>
<li><strong>Training stability</strong>: Some objectives lead to more stable convergence</li>
<li><strong>Sample quality</strong>: Different losses emphasize different aspects of realism</li>
<li><strong>Mode collapse</strong>: The tendency to generate limited variety</li>
<li><strong>Computational efficiency</strong>: Some objectives are faster to compute</li>
</ul>
<p>The original GAN uses Jensen-Shannon Divergence (JSD), but researchers have discovered many alternatives that address specific limitations. Let&rsquo;s explore this evolution.</p>
<h3 id="the-original-gan-jensen-shannon-divergence">The Original GAN: Jensen-Shannon Divergence</h3>
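<p>With an optimal discriminator, the original GAN objective reduces, up to constants, to minimizing the Jensen-Shannon Divergence between the data and generator distributions:</p>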
<p>$$
\text{JSD}(P, Q) = \frac{1}{2} \text{KL}(P \parallel M) + \frac{1}{2} \text{KL}(Q \parallel M)
$$</p>
<p>Where $M = \frac{1}{2}(P + Q)$ is the average distribution, and $\text{KL}$ is the <a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">Kullback-Leibler Divergence</a>.</p>
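<p>A small numerical example shows where the vanishing-gradient limitation comes from. The sketch below computes JSD for discrete distributions given as made-up lists of probabilities (natural log, so the maximum is $\log 2$):</p>

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p[i] == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence with M = (p + q) / 2."""
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


same = js_divergence([0.5, 0.5, 0.0], [0.5, 0.5, 0.0])
disjoint = js_divergence([1.0, 0.0, 0.0], [0.0, 0.0, 1.0])
print(same, disjoint, math.log(2.0))
```

<p>Once two distributions share no support, JSD sits at its maximum of $\log 2$ no matter how far apart they are, so it gives the generator no hint about which direction reduces the gap.</p>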
<p><strong>Strengths</strong>: Solid theoretical foundation, introduced adversarial training<br>
<strong>Limitations</strong>: Can suffer from vanishing gradients and mode collapse</p>
<h3 id="wasserstein-gan-wgan">Wasserstein GAN (WGAN)</h3>
<p>The <a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a> replaced Jensen-Shannon divergence with the Earth-Mover (Wasserstein) distance, which gives meaningful gradients even when the real and generated distributions do not overlap.</p>
<h4 id="understanding-earth-mover-distance">Understanding Earth-Mover Distance</h4>
<p>The Wasserstein distance, also known as Earth-Mover distance, has an intuitive interpretation:</p>
<blockquote>
<p><strong>Imagine two probability distributions as piles of dirt.</strong> The Earth-Mover distance measures the minimum cost to transform one pile into the other, where cost = mass × distance moved.</p></blockquote>
<p>Mathematically:</p>
<p>$$
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{M \times M} d(x, y)^p \, d\gamma(x, y) \right)^{1/p}
$$</p>
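<p>In this formula, $\Gamma(\mu, \nu)$ is the set of joint distributions (transport plans) whose marginals are $\mu$ and $\nu$, and $d$ is the ground distance on the space $M$. The infimum is hard to compute in general, but for order 1 on a one-dimensional grid it has a closed form: the summed absolute difference between the two cumulative distributions. The sketch below relies on that special case, with made-up histograms on a grid assumed to have unit spacing:</p>

```python
def wasserstein_1d(p, q):
    """Order-1 Wasserstein distance between two histograms on the same unit-spaced 1D grid.

    In one dimension the optimal transport cost equals the summed absolute
    difference between the two cumulative distributions.
    """
    cost = 0.0
    cdf_gap = 0.0
    for pi, qi in zip(p, q):
        cdf_gap += pi - qi
        cost += abs(cdf_gap)
    return cost


pile = [1.0, 0.0, 0.0, 0.0, 0.0]
near = [0.0, 1.0, 0.0, 0.0, 0.0]
far = [0.0, 0.0, 0.0, 0.0, 1.0]
print(wasserstein_1d(pile, near), wasserstein_1d(pile, far))
```

<p>Moving the whole pile one cell costs 1 and moving it four cells costs 4, so the distance keeps tracking how far apart the distributions are even when they do not overlap.</p>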
<h4 id="why-earth-mover-distance-matters">Why Earth-Mover Distance Matters</h4>
<table>
  <thead>
      <tr>
          <th>Jensen-Shannon Divergence</th>
          <th>Earth-Mover Distance</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Can be discontinuous</td>
          <td><strong>Always continuous</strong></td>
      </tr>
      <tr>
          <td>May have vanishing gradients</td>
          <td><strong>Meaningful gradients everywhere</strong></td>
      </tr>
      <tr>
          <td>Limited convergence guarantees</td>
          <td><strong>Broader convergence properties</strong></td>
      </tr>
  </tbody>
</table>
<h4 id="wgan-implementation">WGAN Implementation</h4>
<p>Since we can&rsquo;t compute Wasserstein distance directly, WGAN uses the <strong>Kantorovich-Rubinstein duality</strong>:</p>
<ol>
<li><strong>Train a critic function</strong> $f$ to approximate the Wasserstein distance</li>
<li><strong>Constrain the critic</strong> to be 1-Lipschitz (using weight clipping)</li>
<li><strong>Optimize the generator</strong> to minimize this distance</li>
</ol>
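<p>The three steps translate into two losses and a clipping rule. Under the Kantorovich-Rubinstein duality the critic tries to maximize the gap between its mean score on real and generated samples, so the loss it minimizes is the negative of that gap. This is a framework-free sketch with made-up scores and hypothetical helper names. The clip value of 0.01 matches the snippet in the next section.</p>

```python
def critic_loss(critic_real, critic_fake):
    """Loss the critic minimizes: mean f(fake) - mean f(real)."""
    mean_real = sum(critic_real) / len(critic_real)
    mean_fake = sum(critic_fake) / len(critic_fake)
    return mean_fake - mean_real


def generator_loss(critic_fake):
    """Loss the generator minimizes: -mean f(fake)."""
    return -sum(critic_fake) / len(critic_fake)


def clip_weights(weights, c=0.01):
    """Clamp every weight into [-c, c], a crude way to bound the Lipschitz constant."""
    return [max(-c, min(c, w)) for w in weights]


# Made-up critic scores (unbounded real numbers, not probabilities).
scores_real = [1.2, 0.8, 1.0]
scores_fake = [-0.5, -0.7, -0.3]
print(critic_loss(scores_real, scores_fake))  # negative of the distance estimate
print(generator_loss(scores_fake))
print(clip_weights([0.5, -0.003, -2.0]))
```

<p>Unlike a discriminator, the critic outputs an unbounded score, so there is no sigmoid and no logarithm in either loss.</p>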















<figure class="post-figure center ">
    <img src="/img/wasserstein.webp"
         alt="WGAN training results showing stable convergence"
         title="WGAN training results showing stable convergence"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>WGAN Results</strong>: Demonstrating improved training stability and meaningful loss curves. <em>Source: Wasserstein GAN paper</em></figcaption>
    
</figure>

<h4 id="key-wgan-benefits">Key WGAN Benefits</h4>
<p><strong>Meaningful loss function</strong>: Loss correlates with sample quality<br>
<strong>Improved stability</strong>: Less prone to mode collapse<br>
<strong>Theoretical guarantees</strong>: Solid mathematical foundation<br>
<strong>Better convergence</strong>: Works even when distributions don&rsquo;t overlap</p>
<h3 id="improved-wgan-solving-the-weight-clipping-problem">Improved WGAN: Solving the Weight Clipping Problem</h3>
<p><a href="https://arxiv.org/abs/1704.00028">Improved WGAN</a> (WGAN-GP) addresses a critical flaw in the original WGAN: <strong>weight clipping</strong>.</p>
<h4 id="the-problem-with-weight-clipping">The Problem with Weight Clipping</h4>
<p>Original WGAN clips weights to maintain the 1-Lipschitz constraint:</p>
<div class="highlight"><pre tabindex="0" style="color:#f8f8f2;background-color:#272822;-moz-tab-size:4;-o-tab-size:4;tab-size:4;"><code class="language-python" data-lang="python"><span style="display:flex;"><span><span style="color:#75715e"># Problematic approach</span>
</span></span><span style="display:flex;"><span><span style="color:#66d9ef">for</span> param <span style="color:#f92672">in</span> critic<span style="color:#f92672">.</span>parameters():
</span></span><span style="display:flex;"><span>    param<span style="color:#f92672">.</span>data<span style="color:#f92672">.</span>clamp_(<span style="color:#f92672">-</span><span style="color:#ae81ff">0.01</span>, <span style="color:#ae81ff">0.01</span>)
</span></span></code></pre></div><p><strong>Issues with clipping</strong>:</p>
<ul>
<li>Forces critic to use extremely simple functions</li>
<li>Pushes weights toward extreme values ($\pm c$)</li>
<li>Can lead to poor gradient flow</li>
<li>Capacity limitations hurt performance</li>
</ul>
<h4 id="the-gradient-penalty-solution">The Gradient Penalty Solution</h4>
<p>WGAN-GP introduces a <strong>gradient penalty term</strong> to constrain the critic:</p>
<p>$$
L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x}}[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2]
$$</p>
<p>Where $\hat{x}$ are points sampled uniformly along straight lines between real and generated data points.</p>
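<p>The penalty term has two ingredients: the interpolated point $\hat{x}$ and the norm of the critic's gradient there. The sketch below uses a linear critic $f(x) = w \cdot x + b$ as a simplifying assumption, because its gradient is $w$ at every point and no autodiff is needed. A real critic is a neural network whose gradient would come from the training framework. The default $\lambda = 10$ is the value used in the WGAN-GP paper and does not appear in the text above.</p>

```python
import math
import random


def interpolate(x_real, x_fake, eps):
    """x_hat = eps * x_real + (1 - eps) * x_fake, a point on the segment between them."""
    return [eps * r + (1.0 - eps) * f for r, f in zip(x_real, x_fake)]


def gradient_penalty_linear(w, lam=10.0):
    """Penalty lam * (||grad f||_2 - 1)^2 for a linear critic f(x) = w . x + b.

    A linear critic has gradient w at every x_hat, so no autodiff is needed here.
    """
    grad_norm = math.sqrt(sum(wi * wi for wi in w))
    return lam * (grad_norm - 1.0) ** 2


random.seed(0)
eps = random.random()  # eps ~ Uniform(0, 1)
x_hat = interpolate([1.0, 2.0], [3.0, 0.0], eps)
print(x_hat)
print(gradient_penalty_linear([0.6, 0.8]))  # unit-norm gradient
print(gradient_penalty_linear([3.0, 4.0]))  # gradient norm 5
```

<p>A critic whose gradient norm is exactly 1 pays essentially no penalty, while one with norm 5 pays $\lambda \cdot (5 - 1)^2$, which pulls the critic toward the 1-Lipschitz constraint without clamping its weights.</p>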
<p><strong>Advantages</strong>:</p>
<ul>
<li>No capacity limitations</li>
<li>Better gradient flow</li>
<li>More stable training</li>
<li>Works across different architectures</li>
</ul>
<h3 id="lsgan-the-power-of-least-squares">LSGAN: The Power of Least Squares</h3>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares GAN</a> takes a different approach. It replaces the logarithmic loss with <strong>L2 (least squares) loss</strong>.</p>
<h4 id="motivation-beyond-binary-classification">Motivation: Beyond Binary Classification</h4>
<p>Traditional GANs use log loss, which focuses primarily on correct classification:</p>
<ul>
<li>Real sample correctly classified → minimal penalty</li>
<li>Fake sample correctly classified → minimal penalty</li>
<li>Distance from decision boundary ignored</li>
</ul>
<h4 id="l2-loss-distance-matters">L2 Loss: Distance Matters</h4>
<p>LSGAN uses L2 loss, which <strong>penalizes each sample by its squared distance from the target label</strong>:</p>
<p>$$
\min_D V_{\text{LSGAN}}(D) = \frac{1}{2}\mathbb{E}_{x \sim p_{\text{data}}(x)}[(D(x) - b)^2] + \frac{1}{2}\mathbb{E}_{z \sim p_z(z)}[(D(G(z)) - a)^2]
$$</p>
<p>$$
\min_G V_{\text{LSGAN}}(G) = \frac{1}{2}\mathbb{E}_{z \sim p_z(z)}[(D(G(z)) - c)^2]
$$</p>
<p>Where typically: $a = 0$ (fake label), $b = c = 1$ (real label)</p>
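<p>Both objectives are plain mean squared errors against the target labels. Here is a minimal sketch in plain Python using the label values above as defaults; the discriminator outputs are made up, and the function names are illustrative:</p>

```python
def lsgan_discriminator_loss(d_real, d_fake, a=0.0, b=1.0):
    """0.5 * E[(D(x) - b)^2] + 0.5 * E[(D(G(z)) - a)^2]."""
    real_term = sum((d - b) ** 2 for d in d_real) / len(d_real)
    fake_term = sum((d - a) ** 2 for d in d_fake) / len(d_fake)
    return 0.5 * real_term + 0.5 * fake_term


def lsgan_generator_loss(d_fake, c=1.0):
    """0.5 * E[(D(G(z)) - c)^2]."""
    return 0.5 * sum((d - c) ** 2 for d in d_fake) / len(d_fake)


# Made-up discriminator outputs. The second fake is further from the
# real label c = 1, so it contributes a larger generator penalty.
d_real = [0.9, 1.1]
d_fake = [0.5, -1.0]
print(lsgan_discriminator_loss(d_real, d_fake))
print(lsgan_generator_loss(d_fake))
```

<p>Because the penalty grows with the squared gap, a fake sample that sits far from the real label keeps producing a gradient even when the discriminator already classifies it correctly.</p>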
<h4 id="benefits-of-l2-loss">Benefits of L2 Loss</h4>
<table>
  <thead>
      <tr>
          <th>Log Loss</th>
          <th>L2 Loss</th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td>Binary focus</td>
          <td><strong>Distance-aware</strong></td>
      </tr>
      <tr>
          <td>Can saturate</td>
          <td><strong>Informative gradients</strong></td>
      </tr>
      <tr>
          <td>Sharp decision boundary</td>
          <td><strong>Smooth decision regions</strong></td>
      </tr>
  </tbody>
</table>















<figure class="post-figure center ">
    <img src="/img/lsgan-result.webp"
         alt="LSGAN generated samples showing improved quality"
         title="LSGAN generated samples showing improved quality"
         
         
         loading="lazy"
         class="post-image">
    
    <figcaption class="post-caption"><strong>LSGAN Results</strong>: Demonstrating improved sample quality through distance-aware loss functions. <em>Source: LSGAN paper</em></figcaption>
    
</figure>

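<p><strong>Key insight</strong>: When the labels satisfy $b - c = 1$ and $b - a = 2$ (for example $a = -1$, $b = 1$, $c = 0$), minimizing the LSGAN objective minimizes the Pearson χ² divergence, which provides a smoother optimization landscape than the Jensen-Shannon divergence (JSD) of the original GAN. The 0-1 coding shown above does not meet these conditions exactly but is the more common choice in practice.</p>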
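<p>The link to the Pearson χ² divergence can be made explicit. Following the derivation in the LSGAN paper, and writing $p_g$ for the generator distribution (a symbol not used elsewhere in this post), the optimal discriminator for a fixed generator is:</p>
<p>$$
D^*(x) = \frac{b \, p_{data}(x) + a \, p_g(x)}{p_{data}(x) + p_g(x)}
$$</p>
<p>Substituting $D^*$ into the generator objective and choosing labels with $b - c = 1$ and $b - a = 2$ gives, up to a constant factor:</p>
<p>$$
2C(G) = \int \frac{\left(2 p_g(x) - (p_{data}(x) + p_g(x))\right)^2}{p_{data}(x) + p_g(x)} dx = \chi^2_{Pearson}(p_{data} + p_g \,\|\, 2 p_g)
$$</p>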
<h3 id="relaxed-wasserstein-gan-rwgan">Relaxed Wasserstein GAN (RWGAN)</h3>
<p><a href="https://arxiv.org/abs/1705.07164">Relaxed WGAN</a> bridges the gap between WGAN and WGAN-GP, proposing a <strong>general framework</strong> for designing GAN objectives.</p>
<h4 id="key-innovations">Key Innovations</h4>
<p><strong>Asymmetric weight clamping</strong>: WGAN clips critic weights to a symmetric interval $[-c, c]$. RWGAN instead clamps the weights to an asymmetric range.</p>
<p><strong>Relaxed Wasserstein divergences</strong>: A generalized framework that extends the Wasserstein distance, enabling systematic design of new GAN variants while maintaining theoretical guarantees.</p>
<h4 id="benefits">Benefits</h4>
<ul>
<li>Better convergence properties than standard WGAN</li>
<li>Framework for designing new loss functions and GAN architectures</li>
<li>Competitive performance with other Wasserstein-based methods</li>
</ul>
<p><strong>Key insight</strong>: The authors report strong results for RWGAN parameterized with the KL divergence, while keeping the theoretical foundations that make Wasserstein GANs attractive.</p>
<h3 id="statistical-distance-approaches">Statistical Distance Approaches</h3>
<p>Several GAN variants focus on minimizing specific statistical distances between distributions.</p>
<h4 id="mcgan-mean-and-covariance-matching">McGAN: Mean and Covariance Matching</h4>
<p><a href="https://arxiv.org/abs/1702.08398">McGAN</a> belongs to the Integral Probability Metric (IPM) family, using <strong>statistical moments</strong> as the distance measure.</p>
<p><strong>Approach</strong>: Match first and second-order statistics:</p>
<ul>
<li><strong>Mean matching</strong>: Align distribution centers</li>
<li><strong>Covariance matching</strong>: Align distribution shapes</li>
</ul>
<p>Moment-matching objectives like this are conceptually related to settings where aligning statistical moments matters, such as matching a generated distribution to a target physical distribution (e.g., molecular conformations). McGAN itself, however, was introduced and demonstrated as an IPM method for image generation.</p>
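<p>To make the two statistics concrete, here is a minimal sketch that measures the gap between the means and covariances of two small batches. It is a simplification: McGAN matches moments of learned critic features and uses its own norm choices, whereas this example works on raw vectors with Euclidean and Frobenius norms. All function names are hypothetical.</p>

```python
import math


def mean_vector(rows):
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]


def covariance_matrix(rows):
    """Biased (divide-by-n) sample covariance of a list of vectors."""
    n = len(rows)
    mu = mean_vector(rows)
    dim = len(mu)
    return [
        [sum((r[i] - mu[i]) * (r[j] - mu[j]) for r in rows) / n for j in range(dim)]
        for i in range(dim)
    ]


def moment_gap(real, fake):
    """Return (mean gap, covariance gap) between two batches."""
    mu_r, mu_f = mean_vector(real), mean_vector(fake)
    mean_gap = math.sqrt(sum((p - q) ** 2 for p, q in zip(mu_r, mu_f)))
    cov_r, cov_f = covariance_matrix(real), covariance_matrix(fake)
    cov_gap = math.sqrt(
        sum((x - y) ** 2 for row_r, row_f in zip(cov_r, cov_f) for x, y in zip(row_r, row_f))
    )
    return mean_gap, cov_gap


# Two toy batches with the same shape but shifted centers.
real = [[0.0, 0.0], [2.0, 2.0]]
fake = [[1.0, 1.0], [3.0, 3.0]]
print(moment_gap(real, fake))
```

<p>In this toy case the covariance gap is zero and only the mean gap is non-zero, which shows why the two terms carry different information about the mismatch.</p>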
<p><strong>Limitation</strong>: Like the original WGAN, it relies on weight clipping to keep the critic bounded.</p>
<h4 id="gmmn-maximum-mean-discrepancy">GMMN: Maximum Mean Discrepancy</h4>
<p><a href="https://arxiv.org/abs/1502.02761">Generative Moment Matching Networks</a> eliminate the discriminator entirely and directly minimize the <strong>Maximum Mean Discrepancy (MMD)</strong>.</p>
<p><strong>MMD Intuition</strong>: Compare distributions by the distance between their mean embeddings under a feature map $\phi$ into a reproducing kernel Hilbert space $\mathcal{H}$:</p>
<p>$$
\text{MMD}^2(X, Y) = \left\| E[\phi(x)] - E[\phi(y)] \right\|_{\mathcal{H}}^2
$$</p>
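<p>In practice the feature map is never computed. Expanding the squared norm gives $E[k(x, x')] + E[k(y, y')] - 2E[k(x, y)]$ for a kernel $k$, which can be estimated from samples. The sketch below is a minimal, biased estimate with a Gaussian kernel; the fixed bandwidth and the function names are assumptions made for illustration.</p>

```python
import math


def gaussian_kernel(x, y, bandwidth=1.0):
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd_squared(xs, ys, bandwidth=1.0):
    """Biased estimate of MMD^2 that averages the kernel over all pairs."""

    def mean_kernel(a, b):
        total = sum(gaussian_kernel(p, q, bandwidth) for p in a for q in b)
        return total / (len(a) * len(b))

    return mean_kernel(xs, xs) + mean_kernel(ys, ys) - 2.0 * mean_kernel(xs, ys)


# Toy 2-D batches: one near the origin, one shifted away from it.
near = [[0.0, 0.1], [0.2, -0.1], [-0.1, 0.0]]
far = [[3.0, 3.1], [2.9, 3.0], [3.2, 2.8]]
print("same batch:", mmd_squared(near, near))
print("different batches:", mmd_squared(near, far))
```

<p>The double loop over sample pairs is also where the quadratic cost in batch size comes from.</p>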
<p><strong>Benefits</strong>:</p>
<ul>
<li>Simple, discriminator-free training</li>
<li>Theoretical guarantees</li>
<li>Can incorporate autoencoders for better MMD estimation</li>
</ul>
<p><strong>Drawbacks</strong>:</p>
<ul>
<li>Computationally expensive: the kernel estimate compares every pair of samples, so cost grows quadratically with batch size</li>
<li>Often weaker empirical results</li>
</ul>
<h4 id="mmd-gan-learning-better-kernels">MMD GAN: Learning Better Kernels</h4>
<p><a href="https://arxiv.org/abs/1705.08584">MMD GAN</a> improves on GMMN by <strong>learning the kernel adversarially</strong> instead of relying on a fixed Gaussian kernel.</p>
<p><strong>Innovation</strong>: Combine adversarial training with the MMD objective, so the kernel adapts to the data while the generator still minimizes a moment-matching distance.</p>
<h3 id="different-distance-metrics">Different Distance Metrics</h3>
<h4 id="cramer-gan-addressing-sample-bias">Cramer GAN: Addressing Sample Bias</h4>
<p><a href="https://arxiv.org/abs/1705.10743">Cramer GAN</a> identifies a critical issue with WGAN: <strong>biased sample gradients</strong>.</p>
<p><strong>The Problem</strong>: The authors argue that a distance used for learning should have three properties. The Wasserstein distance used by WGAN has only the first two:</p>
<ol>
<li><strong>Sum invariance</strong> (satisfied)</li>
<li><strong>Scale sensitivity</strong> (satisfied)</li>
<li><strong>Unbiased sample gradients</strong> (not satisfied)</li>
</ol>
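<p><strong>The Solution</strong>: Use the <strong>Cramer distance</strong>, which satisfies all three properties. For one-dimensional distributions with cumulative distribution functions $F_\mu$ and $F_\nu$:</p>
<p>$$
d_C^2(\mu, \nu) = \int_{-\infty}^{\infty} (F_\mu(x) - F_\nu(x))^2 dx
$$</p>
<p>For multivariate data, the energy distance generalizes this quantity.</p>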
<p><strong>Benefit</strong>: More reliable gradients lead to better training dynamics.</p>
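<p>For one-dimensional samples, the squared Cramer distance equals half the energy distance $2E|X - Y| - E|X - X'| - E|Y - Y'|$, which is easy to estimate from samples. The sketch below computes that estimate in plain Python. It is a scalar illustration only; the function names are hypothetical, and it does not reproduce the learned critic that Cramer GAN uses.</p>

```python
def mean_abs_diff(a, b):
    """Average of |p - q| over all pairs drawn from the two samples."""
    return sum(abs(p - q) for p in a for q in b) / (len(a) * len(b))


def energy_distance(xs, ys):
    return 2.0 * mean_abs_diff(xs, ys) - mean_abs_diff(xs, xs) - mean_abs_diff(ys, ys)


def cramer_distance_squared(xs, ys):
    """Squared Cramer distance between the two empirical distributions (1-D)."""
    return 0.5 * energy_distance(xs, ys)


# Two point masses one unit apart: the CDFs differ by 1 on an interval of length 1.
print(cramer_distance_squared([0.0], [1.0]))
# Identical samples: the distance vanishes.
print(cramer_distance_squared([1.0, 2.0], [1.0, 2.0]))
```

<p>Every term is an average over sample pairs, which is the structure that allows gradient estimates from minibatches to be unbiased.</p>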
<h4 id="fisher-gan-chi-square-distance">Fisher GAN: Chi-Square Distance</h4>
<p><a href="https://arxiv.org/abs/1705.09675">Fisher GAN</a> uses a <strong>data-dependent constraint</strong> on the critic&rsquo;s second-order moments (variance).</p>
<p><strong>Key Innovation</strong>: The constraint naturally bounds the critic without manual techniques:</p>
<ul>
<li>No weight clipping needed</li>
<li>No gradient penalties required</li>
<li>Constraint emerges from the objective itself</li>
</ul>
<p><strong>Distance</strong>: Approximates a symmetric <strong>Chi-square distance</strong> as critic capacity increases, with the equal mixture of the two distributions in the denominator:</p>
<p>$$
\chi^2(P, Q) = \int \frac{(P(x) - Q(x))^2}{\frac{1}{2}(P(x) + Q(x))} dx
$$</p>
<p>With a linear critic, the Fisher GAN objective reduces to a Mahalanobis distance between the mean feature embeddings of the two distributions, which accounts for correlated features. The variance constraint keeps the critic bounded, and as the critic&rsquo;s capacity increases the objective estimates the Chi-square distance.</p>
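<p>The effect of the second-moment constraint is easiest to see as a ratio: the mean difference of the critic outputs divided by the square root of their averaged second moments. The sketch below computes that ratio for given critic outputs. It assumes the outputs are not all zero, and the function name is illustrative; the paper enforces the constraint during training instead of dividing explicitly.</p>

```python
import math


def fisher_ratio(f_real, f_fake):
    """Mean gap of critic outputs, normalized by their pooled second moment."""
    mean_gap = sum(f_real) / len(f_real) - sum(f_fake) / len(f_fake)
    second_moment = (
        0.5 * sum(v * v for v in f_real) / len(f_real)
        + 0.5 * sum(v * v for v in f_fake) / len(f_fake)
    )
    return mean_gap / math.sqrt(second_moment)


# Hypothetical critic outputs on real and generated batches.
f_real = [1.0, 1.0]
f_fake = [-1.0, -1.0]
print(fisher_ratio(f_real, f_fake))

# Rescaling the critic leaves the ratio unchanged.
print(fisher_ratio([3.0 * v for v in f_real], [3.0 * v for v in f_fake]))
```

<p>Because rescaling the critic does not change the ratio, the critic cannot inflate the objective by growing its outputs, which is why no weight clipping or gradient penalty is needed.</p>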
<p><strong>Benefits</strong>:</p>
<ul>
<li>Efficient computation</li>
<li>Training stability</li>
<li>Unconstrained critic capacity</li>
</ul>
<h3 id="beyond-traditional-gans-alternative-approaches">Beyond Traditional GANs: Alternative Approaches</h3>
<p>The following variants explore fundamentally different architectures and training paradigms.</p>
<h4 id="ebgan-energy-based-discrimination">EBGAN: Energy-Based Discrimination</h4>
<p><a href="https://arxiv.org/abs/1609.03126">Energy-Based GAN</a> replaces the discriminator with an <strong>autoencoder</strong>.</p>
<p><strong>Key insight</strong>: Use reconstruction error as the discrimination signal:</p>
<ul>
<li>Good data → Low reconstruction error</li>
<li>Poor data → High reconstruction error</li>
</ul>
<p><strong>Architecture</strong>:</p>
<ol>
<li>The autoencoder discriminator learns to reconstruct real data with low error</li>
<li>The generator creates samples</li>
<li>Poorly generated samples incur high reconstruction error</li>
<li>That error is the signal the generator minimizes</li>
</ol>
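<p>The steps above correspond to a hinge-style pair of losses on the reconstruction error (the "energy"). The sketch below assumes the energies have already been computed by the autoencoder and are passed in as lists; the margin value and function names are illustrative.</p>

```python
def mean(values):
    return sum(values) / len(values)


def ebgan_d_loss(energy_real, energy_fake, margin):
    """Push real energy down and fake energy up, but only until it reaches the margin."""
    hinge = [max(0.0, margin - e) for e in energy_fake]
    return mean(energy_real) + mean(hinge)


def ebgan_g_loss(energy_fake):
    """The generator tries to produce samples the autoencoder reconstructs well."""
    return mean(energy_fake)


# Hypothetical reconstruction errors for one batch.
energy_real = [0.10, 0.20]
energy_fake = [0.50, 1.50]
print("D loss:", ebgan_d_loss(energy_real, energy_fake, margin=1.0))
print("G loss:", ebgan_g_loss(energy_fake))
```

<p>Generated samples whose energy already exceeds the margin contribute nothing to the discriminator loss, so the discriminator stops spending capacity on samples that are already easy to reject.</p>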
<p><strong>Benefits</strong>:</p>
<ul>
<li>Fast and stable training</li>
<li>Robust to hyperparameter changes</li>
<li>No need to balance discriminator/generator</li>
</ul>
<h4 id="began-boundary-equilibrium">BEGAN: Boundary Equilibrium</h4>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN</a> combines EBGAN&rsquo;s autoencoder approach with WGAN-style loss functions.</p>
<p><strong>Innovation</strong>: Dynamic equilibrium parameter $k_t$ that balances:</p>
<ul>
<li>Real data reconstruction quality</li>
<li>Generated data reconstruction quality</li>
</ul>
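<p><strong>Discriminator loss and equilibrium update</strong>, where $L(\cdot)$ is the autoencoder reconstruction loss, $\gamma$ is the target ratio of generated to real reconstruction loss, and $\lambda$ is the step size for $k_t$:</p>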
<p>$$
L_D = L(x) - k_t L(G(z))
$$</p>
<p>$$
k_{t+1} = k_t + \lambda(\gamma L(x) - L(G(z)))
$$</p>
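<p>The update for $k_t$ is a small proportional controller, which a few lines of code make concrete. In the sketch below the generator loss $L(G(z))$ and the clipping of $k_t$ to $[0, 1]$ follow the BEGAN paper and are not shown in the equations above; the loss values and the function name are hypothetical.</p>

```python
def began_step(loss_real, loss_fake, k_t, gamma=0.5, lambda_k=0.001):
    """One BEGAN bookkeeping step given the two reconstruction losses."""
    d_loss = loss_real - k_t * loss_fake
    g_loss = loss_fake
    # Proportional control: raise k when fakes are reconstructed too well
    # relative to the target ratio gamma, lower it otherwise.
    k_next = k_t + lambda_k * (gamma * loss_real - loss_fake)
    k_next = min(max(k_next, 0.0), 1.0)  # keep k in [0, 1]
    return d_loss, g_loss, k_next


# Hypothetical reconstruction losses over three training steps.
k = 0.0
for loss_real, loss_fake in [(1.0, 0.2), (0.9, 0.3), (0.8, 0.5)]:
    d_loss, g_loss, k = began_step(loss_real, loss_fake, k, lambda_k=0.1)
    print(f"d_loss={d_loss:.4f}  g_loss={g_loss:.4f}  k={k:.4f}")
```

<p>Starting from $k_0 = 0$, the discriminator initially focuses on reconstructing real data and only gradually gives weight to pushing generated samples away.</p>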
<h4 id="magan-adaptive-margins">MAGAN: Adaptive Margins</h4>
<p><a href="https://arxiv.org/abs/1704.03817">MAGAN</a> improves EBGAN by making the margin in the hinge loss <strong>adaptive over time</strong>.</p>
<p><strong>Concept</strong>: Start with a large margin, gradually reduce it as training progresses:</p>
<ul>
<li>Early training: Focus on major differences</li>
<li>Later training: Fine-tune subtle details</li>
</ul>
<p><strong>Result</strong>: Better sample quality and training stability.</p>
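<p>A shrinking margin can be sketched as follows. This is an assumed, simplified schedule and not MAGAN&rsquo;s exact update rule: after each epoch the margin is lowered to the mean energy of real samples whenever that value is smaller, and it is never raised. The function name and the energy values are hypothetical.</p>

```python
def update_margin(margin, real_energies):
    """Lower the margin to the mean real-sample energy; never raise it."""
    mean_real_energy = sum(real_energies) / len(real_energies)
    return min(margin, mean_real_energy)


# Hypothetical per-epoch reconstruction errors on real data.
margin = 1.0
for epoch, real_energies in enumerate([[0.8, 0.9], [0.5, 0.7], [0.6, 0.8]], start=1):
    margin = update_margin(margin, real_energies)
    print(f"epoch {epoch}: margin={margin:.3f}")
```

<p>Tying the margin to how well real data is reconstructed means the hinge loss keeps pressure on generated samples as the autoencoder improves, without a hand-tuned decay schedule.</p>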
<h2 id="summary-the-evolution-of-gan-objectives">Summary: The Evolution of GAN Objectives</h2>
<p>The evolution of GAN objective functions reflects the field&rsquo;s progression toward more stable and theoretically grounded training procedures. Each variant addresses specific limitations in earlier approaches.</p>
<h3 id="complete-reference-table">Complete Reference Table</h3>
<table>
  <thead>
      <tr>
          <th><strong>GAN Variant</strong></th>
          <th><strong>Key Innovation</strong></th>
          <th><strong>Main Benefit</strong></th>
          <th><strong>Limitation</strong></th>
      </tr>
  </thead>
  <tbody>
      <tr>
          <td><strong>Original GAN</strong></td>
          <td>Jensen-Shannon divergence</td>
          <td>Foundation of adversarial training</td>
          <td>Vanishing gradients, mode collapse</td>
      </tr>
      <tr>
          <td><strong>WGAN</strong></td>
          <td>Earth-Mover distance</td>
          <td>Meaningful loss, better stability</td>
          <td>Weight clipping issues</td>
      </tr>
      <tr>
          <td><strong>WGAN-GP</strong></td>
          <td>Gradient penalty</td>
          <td>Solves weight clipping problems</td>
          <td>Additional hyperparameter tuning</td>
      </tr>
      <tr>
          <td><strong>LSGAN</strong></td>
          <td>Least squares loss</td>
          <td>Better gradients, less saturation</td>
          <td>May converge to non-optimal points</td>
      </tr>
      <tr>
          <td><strong>RWGAN</strong></td>
          <td>Relaxed Wasserstein framework</td>
          <td>General framework for new designs</td>
          <td>Complex theoretical setup</td>
      </tr>
      <tr>
          <td><strong>McGAN</strong></td>
          <td>Mean/covariance matching</td>
          <td>Simple statistical alignment</td>
          <td>Limited by weight clipping</td>
      </tr>
      <tr>
          <td><strong>GMMN</strong></td>
          <td>Maximum mean discrepancy</td>
          <td>No discriminator needed</td>
          <td>Computationally expensive</td>
      </tr>
      <tr>
          <td><strong>MMD GAN</strong></td>
          <td>Adversarial kernels for MMD</td>
          <td>Improved GMMN performance</td>
          <td>Still computationally heavy</td>
      </tr>
      <tr>
          <td><strong>Cramer GAN</strong></td>
          <td>Cramer distance</td>
          <td>Unbiased sample gradients</td>
          <td>Complex implementation</td>
      </tr>
      <tr>
          <td><strong>Fisher GAN</strong></td>
          <td>Chi-square distance</td>
          <td>Self-constraining critic</td>
          <td>Limited empirical validation</td>
      </tr>
      <tr>
          <td><strong>EBGAN</strong></td>
          <td>Autoencoder discriminator</td>
          <td>Fast, stable training</td>
          <td>Requires careful regularization</td>
      </tr>
      <tr>
          <td><strong>BEGAN</strong></td>
          <td>Boundary equilibrium</td>
          <td>Dynamic training balance</td>
          <td>Additional equilibrium parameter</td>
      </tr>
      <tr>
          <td><strong>MAGAN</strong></td>
          <td>Adaptive margin</td>
          <td>Progressive refinement</td>
          <td>Margin scheduling complexity</td>
      </tr>
  </tbody>
</table>
<h3 id="practical-recommendations">Practical Recommendations</h3>
<p>For practitioners, the choice depends on specific requirements and engineering tradeoffs:</p>
<ul>
<li><strong>WGAN-GP</strong>: Best balance of stability and performance for most applications, although results can be sensitive to the gradient penalty coefficient $\lambda$.</li>
<li><strong>LSGAN</strong>: Simpler implementation with good empirical results.</li>
<li><strong>EBGAN</strong>: Fast experimentation and prototyping.</li>
<li><strong>Original GAN</strong>: Educational purposes and understanding fundamentals.</li>
</ul>
<p><strong>Real-World Impact:</strong> In my work training VLMs on terabyte-scale multimodal data and forecasting chaotic physical systems, these foundational dynamics still matter. Most generation today runs on diffusion models or autoregressive transformers, but the loss-design and training-stability lessons that came out of GAN research carry over. The choice of objective function shapes generation quality, training stability, and compute cost.</p>
<hr>
<p><strong>Acknowledgments</strong>: This post was inspired by the excellent survey &ldquo;<a href="https://arxiv.org/abs/1711.05914">How Generative Adversarial Networks and Their Variants Work: An Overview of GAN</a>&rdquo;.</p>
]]></content:encoded></item></channel></rss>