Paper Summary
Citation: Diederik P. Kingma and Max Welling. “Auto-Encoding Variational Bayes.” arXiv:1312.6114v11 [stat.ML], 2022.
Publication: arXiv preprint (originally 2013)
What kind of paper is this?
This paper introduces a stochastic variational inference and learning algorithm designed for directed probabilistic models with continuous latent variables. The method, called the Auto-Encoding VB (AEVB) algorithm, leads to what we now know as the variational auto-encoder (VAE) when neural networks are used as the recognition model.
What is the motivation?
The paper tackles efficient inference and learning in directed probabilistic models facing two key challenges:
Intractable Distributions: Models with continuous latent variables where the true posterior distribution $p_{\theta}(z|x)$ is intractable. This intractability also applies to the marginal likelihood $p_{\theta}(x)$, preventing use of standard methods like the EM algorithm.
Large Datasets: The method scales to large datasets where batch optimization is too costly, requiring a stochastic approach that can make parameter updates using small minibatches.
What is the novelty?
The paper makes two key contributions:
The SGVB Estimator (Reparameterization Trick): Introduces a reparameterization of the variational lower bound. This technique, now known as the “reparameterization trick,” expresses a random variable $z \sim q_{\phi}(z|x)$ as a deterministic function $z = g_{\phi}(\epsilon, x)$ of auxiliary noise $\epsilon$ drawn from a simple distribution (e.g., $\epsilon \sim \mathcal{N}(0, I)$). This yields the SGVB (Stochastic Gradient Variational Bayes) estimator, a differentiable, low-variance estimator of the lower bound that can be maximized with standard stochastic gradient methods.
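A minimal sketch of the trick, assuming a diagonal-Gaussian approximate posterior as in the paper's experiments; the function and variable names below are illustrative, not the paper's:

```python
import torch

def reparameterize(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """Sample z ~ q_phi(z|x) = N(mu, diag(sigma^2)) as a deterministic,
    differentiable function of (mu, log_var) and auxiliary noise epsilon."""
    std = torch.exp(0.5 * log_var)   # sigma
    eps = torch.randn_like(std)      # epsilon ~ N(0, I); gradients do not flow through it
    return mu + eps * std            # z = g_phi(epsilon, x) = mu + sigma * epsilon
```

Because $z$ is now a differentiable function of the variational parameters, a Monte Carlo estimate of the lower bound can be back-propagated with respect to $\phi$, which is what makes joint gradient-based training of encoder and decoder possible.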
The AEVB Algorithm (Variational Auto-Encoder): Proposes the Auto-Encoding VB (AEVB) algorithm that uses the SGVB estimator to optimize a recognition model $q_{\phi}(z|x)$:
- The recognition model acts as a probabilistic encoder, approximating the intractable true posterior
- The generative model $p_{\theta}(x|z)$ acts as a probabilistic decoder
- Encoder parameters ($\phi$) and decoder parameters ($\theta$) are learned jointly by optimizing the SGVB variational lower bound
- When implemented with neural networks (e.g., MLPs), this forms the variational auto-encoder
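A compact sketch of such an encoder/decoder pair, assuming PyTorch and a Bernoulli decoder (as used for binary MNIST data); the class name, layer sizes, and hidden activation are illustrative choices, not the paper's exact configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    """Illustrative MLP encoder (q_phi) and decoder (p_theta), trained jointly."""
    def __init__(self, x_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        # Probabilistic encoder: maps x to the parameters of q_phi(z|x)
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.Tanh())
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        # Probabilistic decoder: maps z to the logits of p_theta(x|z)
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.Tanh(),
                                 nn.Linear(h_dim, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterized sample (L = 1)
        logits = self.dec(z)
        # Negative lower bound: -log p_theta(x|z) plus the Gaussian KL term in closed form (Appendix B)
        recon = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return (recon + kl) / x.size(0)  # average negative lower bound per datapoint
```

Minimizing this quantity with stochastic gradients updates $\theta$ and $\phi$ together, i.e., it performs stochastic gradient ascent on the variational lower bound.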
What experiments were performed?
The authors evaluated their method on MNIST and Frey Face image datasets:
- Baselines: Compared AEVB against the wake-sleep algorithm and Monte Carlo EM (MCEM)
- Metrics: Evaluated variational lower bound and estimated marginal likelihood on train and test sets
- Architecture: Used MLPs (multi-layer perceptrons) for both the encoder and the decoder
- Training: Stochastic gradient ascent with Adagrad, minibatches of size $M=100$, and $L=1$ sample per datapoint (a training-loop sketch follows this list)
- Analysis: Visualized 2D latent manifolds learned by the models
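A hypothetical training loop matching the reported setup; `TinyVAE` refers to the sketch above, and the dataset pipeline and learning rate are assumptions for illustration:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Minibatches of size M = 100; L = 1 latent sample per datapoint is drawn inside
# TinyVAE.forward. Pixel intensities in [0, 1] serve as soft Bernoulli targets here.
train_loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=100, shuffle=True)

model = TinyVAE()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)  # learning rate is illustrative

for epoch in range(10):
    for x, _ in train_loader:
        x = x.view(x.size(0), -1)   # flatten 28x28 images to 784-dimensional vectors
        loss = model(x)             # negative SGVB estimate of the lower bound
        optimizer.zero_grad()
        loss.backward()             # gradients with respect to both theta and phi
        optimizer.step()
```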
What were the outcomes and conclusions?
Performance: AEVB converged faster and reached better solutions than the wake-sleep algorithm when optimizing the variational lower bound (Figure 2). It also achieved better estimated marginal log-likelihood than both wake-sleep and MCEM (Figure 3).
Regularization: Using more latent variables did not lead to more overfitting, which the authors attribute to the regularizing effect of the KL divergence term in the lower bound.
Impact: The paper successfully introduced the SGVB estimator and AEVB algorithm as efficient methods for inference and learning in models with continuous latent variables, with theoretical advantages reflected in experimental results.
Key Appendices
The paper’s appendices provide important implementation details:
- Appendix A: Visualizations of 2D latent manifolds learned for MNIST and Frey Face datasets
- Appendix B: Analytical solution of the KL divergence term $-D_{KL}(q_{\phi}(z|x)||p_{\theta}(z))$ when both the approximate posterior $q_{\phi}(z|x)$ and the prior $p_{\theta}(z)$ are Gaussian, as used in Equation 10 (the closed form is reproduced after this list)
- Appendix C: Implementation details for MLPs as probabilistic encoders and decoders, including Bernoulli MLPs for binary data (MNIST) and Gaussian MLPs for real-valued data (Frey Face)
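For reference, the Appendix B result for a diagonal-Gaussian approximate posterior $q_{\phi}(z|x) = \mathcal{N}(z; \mu, \sigma^2 I)$ and a standard-normal prior $p_{\theta}(z) = \mathcal{N}(z; 0, I)$ is

$$-D_{KL}\left(q_{\phi}(z|x)\,\|\,p_{\theta}(z)\right) = \frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right),$$

where $J$ is the dimensionality of $z$. This is the term added to the expected reconstruction log-likelihood in Equation 10, and it is what the `kl` line in the earlier code sketch computes.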