Introduction
The last decade has seen explosive growth in generative modeling research, including Generative Adversarial Networks (GANs), normalizing flows, and diffusion models. Despite these advances, Variational Autoencoders (VAEs) remain foundational to generative AI, serving as the stepping stone for understanding more complex models.
Many VAE tutorials exist, but few leverage modern PyTorch features that improve optimization and numerical stability. This tutorial demonstrates current best practices for VAE implementation, focusing on techniques that significantly reduce common issues like “NaN” loss.
We’ll cover:
- VAE fundamentals and mathematical foundations
- A modern PyTorch VAE implementation featuring:
  - torchvision.transforms.v2 for preprocessing
  - torch.distributions for structured VAEs
  - dataclasses for cleaner code organization
  - tensorboard for comprehensive metric tracking
- Softplus and epsilon addition for numerical stability
- VAE validation using the MNIST dataset
- Extensions and practical limitations
Let’s dive in!
What is a Variational Autoencoder?
The Variational Autoencoder (VAE) is a generative model introduced in Auto-Encoding Variational Bayes by Kingma and Welling in 2013. To understand VAEs, we need to first understand the problem they solve.
The Fundamental Problem
Consider a scenario where you have a dataset $$\mathbf{X} = \{\mathbf{x}^{(i)}\}_{i=1}^N$$ where each $\mathbf{x}^{(i)}$ is independently and identically distributed (IID) and can be continuous or discrete.
We make the modeling assumption that this dataset is generated by some other, lower-dimensional, random process $\mathbf{z}$. Specifically, we assume that there exists a random variable $\mathbf{z}$ such that
- Each $\mathbf{z}^{(i)} \sim p_{\theta^\ast}(\mathbf{z})$, a true and unknown prior distribution
- Each $\mathbf{x}^{(i)} \sim p_{\theta^\ast}(\mathbf{x}|\mathbf{z}^{(i)})$, a true and unknown conditional distribution
Since we cannot know the true distributions nor observe $\mathbf{z}$, our goal is to learn parameters that approximate these distributions. We make additional assumptions:
- The prior distribution $p_{\theta^\ast}(\mathbf{z})$ is simple and known (e.g., standard normal)
- The conditional distribution $p_{\theta^\ast}(\mathbf{x}|\mathbf{z})$ is simple and known (e.g., Gaussian or deterministic function)
- The distributions are differentiable almost everywhere with respect to parameters $\theta$ and inputs $\mathbf{z}$
These assumptions enable algorithms that handle:
- Intractability: VAEs handle cases with intractable components (e.g., intractable marginal likelihood, posterior inference) through variational approximations.
- Large Datasets: VAEs use gradient-based optimization rather than sampling loops per data point, enabling efficient training on large datasets.
This provides algorithms for diverse applications:
- Dimensionality Reduction: Learn low-dimensional representations for visualization, compression, and feature extraction. VAEs are similar to PCA, t-SNE, and UMAP but provide a probabilistic framework.
- Imputation: Fill in missing data for preprocessing and augmentation, including image inpainting, denoising, and super-resolution.
- Generation: Create new data for augmentation and synthesis across domains like images, text, and music. When modeling physical processes, the learned parameters often provide scientific insights.
The VAE Solution
To achieve these goals, VAEs use two key architectural components:
Encoder (Recognition Model): Maps input data to latent space: $$q_{\phi}(\mathbf{z} | \mathbf{x})$$ where $\phi$ are the encoder parameters. This approximates the intractable true posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$.
Decoder (Generative Model): Maps latent space back to input data: $$p_{\theta}(\mathbf{x} | \mathbf{z})$$ where $\theta$ are the decoder parameters.
The encoder approximation makes VAEs practical when the true posterior is intractable.

How do we jointly learn $\phi$ and $\theta$? The answer lies in the VAE objective function.
The VAE Objective
The VAE objective balances two terms: reconstruction loss and KL divergence.
Reconstruction Loss
The reconstruction loss measures how well the model reconstructs input data from the latent space. It’s the negative log-likelihood of the input given the latent representation.
For continuous inputs: Uses mean squared error (MSE) or negative Gaussian log-likelihood:
- Deterministic decoder: $\mathcal{L}_{\text{rec}} = \frac{1}{N} \sum_{i=1}^N \left\| \mathbf{x}^{(i)} - f(\mathbf{z}^{(i)}) \right\|^2$
- Stochastic decoder: $\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \log \mathcal{N}(\mathbf{x}^{(i)} \mid f(\mathbf{z}^{(i)}), \sigma^2 \mathbf{I})$
For discrete inputs: Uses cross-entropy loss. For MNIST, we use binary cross-entropy, summed over pixels: $$\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \left[ \mathbf{x}^{(i)} \log f(\mathbf{z}^{(i)}) + (1 - \mathbf{x}^{(i)}) \log (1 - f(\mathbf{z}^{(i)})) \right]$$
Reconstruction loss encourages the decoder to accurately reconstruct input data from latent representations. Through backpropagation, the encoder learns to map inputs to latent spaces that enable effective reconstruction.
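As a concrete illustration (a standalone sketch with stand-in tensors, not the tutorial's model code), the per-example BCE sum averaged over a batch looks like this in PyTorch:

```python
import torch
import torch.nn.functional as F

# Stand-in tensors: a hypothetical batch of flattened MNIST inputs in [0, 1]
# and the matching decoder outputs (post-sigmoid probabilities).
x = torch.rand(128, 784)
x_recon = torch.rand(128, 784)

# Sum the BCE over pixels for each example, then average over the batch,
# mirroring the reconstruction term above.
loss_recon = F.binary_cross_entropy(x_recon, x, reduction='none').sum(-1).mean()
print(loss_recon.item())
```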
KL Divergence
The KL divergence measures how much the approximate posterior deviates from the prior, encouraging similarity between them. For Gaussian priors and posteriors, this has a closed-form solution available in PyTorch as torch.distributions.kl.kl_divergence.
For multivariate Gaussians $\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$ and $\mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$: $$\text{KL}(\mathcal{N}_1 || \mathcal{N}_2) = \frac{1}{2} \left[ \text{tr}(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \boldsymbol{\Sigma}_2^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - k + \log \frac{\text{det}(\boldsymbol{\Sigma}_2)}{\text{det}(\boldsymbol{\Sigma}_1)} \right]$$
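As a quick numerical check (a throwaway sketch, not part of the tutorial's code), the closed form above can be compared against PyTorch's implementation for a pair of 2-dimensional Gaussians:

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

k = 2  # dimensionality
mu1, mu2 = torch.full((k,), 0.5), torch.zeros(k)
cov1, cov2 = 0.8 * torch.eye(k), torch.eye(k)

p = MultivariateNormal(mu1, covariance_matrix=cov1)
q = MultivariateNormal(mu2, covariance_matrix=cov2)

# Library implementation
kl_lib = kl_divergence(p, q)

# Closed form from the formula above
kl_manual = 0.5 * (
    torch.trace(cov2.inverse() @ cov1)
    + (mu2 - mu1) @ cov2.inverse() @ (mu2 - mu1)
    - k
    + torch.log(torch.det(cov2) / torch.det(cov1))
)
print(kl_lib.item(), kl_manual.item())  # the two values agree
```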
Where Does This KL Divergence Come From?
Consider maximizing the marginal likelihood $\log p_{\theta}(\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(N)})$. For IID data, this becomes $\sum_i \log p_{\theta}(\mathbf{x}^{(i)})$.
For a single data point: $$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$
The first term (KL divergence between approximate and true posterior) is intractable. The second term is the variational lower bound we can optimize:
Expanding the variational lower bound: $$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) \right] - \text{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z}))$$
This gives us the familiar two-term objective: reconstruction loss and KL divergence. The loss balances reconstructing input data while maintaining reasonable latent representation structure.
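For completeness, the decomposition above follows from a standard identity. Writing $q_\phi$ for $q_\phi(\mathbf{z} | \mathbf{x}^{(i)})$: $$\log p_{\theta}(\mathbf{x}^{(i)}) = E_{q_\phi} \left[ \log \frac{p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z})}{q_\phi} \right] + E_{q_\phi} \left[ \log \frac{q_\phi}{p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})} \right]$$ The first expectation is the lower bound $\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$; factoring the joint as $p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z})\, p_{\theta}(\mathbf{z})$ inside it yields the reconstruction and prior-KL terms. The second expectation is the intractable posterior KL term.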
Modern PyTorch VAE Implementation
Now that we understand the VAE architecture and objective, let’s implement a modern VAE in PyTorch using MNIST for validation.
PyTorch VAE Implementation
Our VAE implementation uses an output dataclass and a VAE class extending nn.Module. Here are the key modern PyTorch features:
Key Modern Features:
- Numerical Stability: nn.Softplus and nn.SiLU activations improve convergence. Line 58 adds an epsilon eps to the softplus-activated log-variance for numerical stability.
- Efficient Architecture: A single linear layer produces both the mean and log-variance (line 28), which are split with torch.chunk (line 57).
- PyTorch Distributions: torch.distributions.MultivariateNormal enables clean re-parameterized sampling and efficient KL computation.
- Clean Loss Computation: Uses torch.distributions.kl.kl_divergence for the KL term and BCE for the reconstruction loss.
- Best Practices: A dataclass output structure keeps the code organized (a condensed sketch of such a class follows this list). Consider adding hyperparameter dataclasses and configuration management for production use.
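The full class lives in the notebook linked at the end of this post; as a reference point for the training code below, here is a condensed sketch of what such a class can look like. The dataclass field names (x_recon, loss, loss_recon, loss_kl), the eps value, the shallower layer stack, and the shift of the zero-centered inputs back to [0, 1] before BCE are illustrative assumptions; the actual architecture is printed further down.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class VAEOutput:
    """Forward-pass results; field names are illustrative, not the notebook's exact ones."""
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.Tensor
    loss: Optional[torch.Tensor] = None
    loss_recon: Optional[torch.Tensor] = None
    loss_kl: Optional[torch.Tensor] = None


class VAE(nn.Module):
    """Condensed VAE sketch (shallower than the printed architecture below)."""

    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * latent_dim),  # one layer emits mean and log-variance together
        )
        self.softplus = nn.Softplus()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # pixel probabilities for the BCE reconstruction term
        )

    def encode(self, x, eps=1e-8):
        mu, logvar = torch.chunk(self.encoder(x), 2, dim=-1)  # split the combined output
        scale = self.softplus(logvar) + eps  # softplus + epsilon keeps the scale strictly positive
        return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(scale))

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, compute_loss=True):
        dist = self.encode(x)
        z = dist.rsample()  # re-parameterized sample so gradients reach the encoder
        x_recon = self.decode(z)
        if not compute_loss:
            return VAEOutput(z_dist=dist, z_sample=z, x_recon=x_recon)

        # Standard-normal prior with the same batch shape as the posterior
        prior = torch.distributions.MultivariateNormal(
            torch.zeros_like(z), scale_tril=torch.diag_embed(torch.ones_like(z))
        )
        loss_kl = torch.distributions.kl.kl_divergence(dist, prior).mean()
        # Assumption: inputs are zero-centered to [-0.5, 0.5] by the transform below,
        # so we shift them back to [0, 1] before BCE.
        loss_recon = F.binary_cross_entropy(x_recon, x + 0.5, reduction='none').sum(-1).mean()
        return VAEOutput(
            z_dist=dist, z_sample=z, x_recon=x_recon,
            loss=loss_recon + loss_kl, loss_recon=loss_recon, loss_kl=loss_kl,
        )
```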
Data Preparation
For validation, we’ll use MNIST with modern preprocessing via torchvision.transforms.v2:
```python
import torch
from torchvision import datasets
from torchvision.transforms import v2

batch_size = 128

transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.view(-1) - 0.5),
])

# Download and load the training data
train_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=True,
    transform=transform,
)

# Download and load the test data
test_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=False,
    transform=transform,
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)
```
- v2.ToImage() and v2.ToDtype handle uint8 → float32 conversion and [0, 1] scaling
- v2.Lambda zero-centers the data and flattens it for the feed-forward network
- A batch size of 128 balances computational efficiency with the original paper’s recommendation (≥100)
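To confirm the pipeline produces what the model expects (a quick sanity check, not part of the training code), inspect a single batch:

```python
# Quick sanity check on one batch: flattened 784-dim vectors, zero-centered to [-0.5, 0.5]
images, labels = next(iter(train_loader))
print(images.shape)                               # torch.Size([128, 784])
print(images.min().item(), images.max().item())   # approximately -0.5 and 0.5
print(labels.shape)                               # torch.Size([128])
```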
Training and Validation
We can instantiate a model, optimizer, and tensorboard writer, and then train the model using the following code:
```python
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 50
latent_dim = 2
hidden_dim = 512

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')
```
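A small convenience snippet (not from the notebook) reproduces the parameter count and architecture summary shown next:

```python
# Count trainable parameters and print the module structure
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params:,}')
print(model)
```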
This yields a network with about 1.1M parameters:
```
Number of parameters: 1,149,972
VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): SiLU()
    (8): Linear(in_features=64, out_features=4, bias=True)
  )
  (softplus): Softplus(beta=1, threshold=20)
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): SiLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): SiLU()
    (8): Linear(in_features=512, out_features=784, bias=True)
    (9): Sigmoid()
  )
)
```
Our train function will look as follows:
```python
from tqdm import tqdm

def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Trains the model on the given data.

    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer: The optimizer.
        prev_updates (int): Number of updates already applied (used for logging).
        writer: The TensorBoard writer.
    """
    model.train()  # Set the model to training mode

    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx
        data = data.to(device)

        optimizer.zero_grad()  # Zero the gradients
        output = model(data)   # Forward pass
        loss = output.loss
        loss.backward()

        if n_upd % 100 == 0:
            # Calculate and log gradient norms
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)

            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) Grad: {total_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('Loss/Train/BCE', output.loss_recon.item(), global_step)
                writer.add_scalar('Loss/Train/KLD', output.loss_kl.item(), global_step)
                writer.add_scalar('GradNorm/Train', total_norm, global_step)

        # Gradient clipping keeps occasional large updates in check
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)
```
And our test loop:
```python
def test(model, dataloader, cur_step, writer=None):
    """
    Tests the model on the given data.

    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The current step.
        writer: The TensorBoard writer.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            data = data.view(data.size(0), -1)  # Flatten the data

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()
            test_recon_loss += output.loss_recon.item()
            test_kl_loss += output.loss_kl.item()

    test_loss /= len(dataloader)
    test_recon_loss /= len(dataloader)
    test_kl_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f} (BCE: {test_recon_loss:.4f}, KLD: {test_kl_loss:.4f})')

    if writer is not None:
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/BCE', test_recon_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/KLD', test_kl_loss, global_step=cur_step)

        # Log reconstructions from the last batch
        writer.add_images('Test/Reconstructions', output.x_recon.view(-1, 1, 28, 28), global_step=cur_step)
        writer.add_images('Test/Originals', data.view(-1, 1, 28, 28), global_step=cur_step)

        # Log random samples from the latent space
        z = torch.randn(16, latent_dim).to(device)
        samples = model.decode(z)
        writer.add_images('Test/Samples', samples.view(-1, 1, 28, 28), global_step=cur_step)
```
Then we’ll run the training job:
```python
prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)
```
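Once training finishes, it is worth flushing the TensorBoard writer and saving a checkpoint so the model can be reloaded later. This housekeeping is not shown in the notebook excerpt above; a minimal sketch:

```python
# Post-training housekeeping (not part of the excerpt above):
# flush pending TensorBoard events and save the weights for later sampling or analysis.
writer.flush()
writer.close()
torch.save(model.state_dict(), 'vae_mnist.pt')
```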
TensorBoard Visualization
The loss curves during training over 50 epochs:

The loss curves show convergence around 140 total loss. The reconstruction loss (~130) dominates, giving ~0.165 per pixel—quite reasonable. The KL term (~7) contributes meaningfully to regularization.
Gradient norms during training:

Gradient norms reach ~$10^2$ before clipping at 1.0. While clipping provided slight stability improvements, the impact on final loss was minimal.
64 samples from the latent space:

The samples are somewhat blurry—typical for VAEs—but show clear digit structure.
Latent Space Analysis
To analyze the learned latent space, I plotted the training set as a scatter plot colored by digit class:

The scatter plot reveals impressive unsupervised digit clustering. While some confusion exists (notably between 4s and 9s), the model discovers meaningful structure without labels.
The latent distribution deviates from the Gaussian prior, showing radial “flowery” patterns. This is expected when compressing 784 dimensions to 2—information loss and non-Gaussian structure are inevitable. To enforce stronger Gaussian structure, increase the KL weight β.
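The objective used here weights the KL term with an implicit β = 1; a β-weighted variant is a one-line change. Below is a runnable sketch with stand-in scalars (in the training loop these would be output.loss_recon and output.loss_kl):

```python
import torch

# β-weighted VAE objective: β > 1 pulls the latent distribution toward the prior,
# β < 1 favors reconstruction quality. Stand-in values keep the snippet self-contained.
loss_recon = torch.tensor(130.0)  # roughly the converged reconstruction loss reported above
loss_kl = torch.tensor(7.0)       # roughly the converged KL term reported above
beta = 4.0                        # hypothetical setting; tune per task
loss = loss_recon + beta * loss_kl
print(loss.item())
```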


Interpolating in Latent Space
Linear interpolation in latent space demonstrates smooth transitions between digit types:
```python
import torch
import matplotlib.pyplot as plt

n = 15
z1 = torch.linspace(-0, 1, n)
z2 = torch.zeros_like(z1) + 2
z = torch.stack([z1, z2], dim=-1).to(device)

samples = model.decode(z)
samples = torch.sigmoid(samples)

# Plot the generated images
fig, ax = plt.subplots(1, n, figsize=(n, 1))
for i in range(n):
    ax[i].imshow(samples[i].view(28, 28).cpu().detach().numpy(), cmap='gray')
    ax[i].axis('off')

plt.savefig('vae_mnist_interp.webp')
```

Complete Code
The full implementation is available in this Jupyter notebook.
Extensions and Limitations
Conditional VAEs (CVAEs): Condition on auxiliary information like class labels for semi-supervised learning.
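As an illustration of the conditioning idea (a hedged sketch; shapes and names are illustrative and not tied to the implementation above), the only structural change is concatenating label information to the encoder input and to the latent code before decoding:

```python
import torch
import torch.nn.functional as F

# Stand-in batch: flattened images plus integer class labels
x = torch.rand(128, 784)
y = torch.randint(0, 10, (128,))
y_onehot = F.one_hot(y, num_classes=10).float()

# Encoder sees the image and its label; decoder sees the latent code and the label.
encoder_input = torch.cat([x, y_onehot], dim=-1)   # shape (128, 794)
z = torch.randn(128, 2)                            # stand-in latent sample
decoder_input = torch.cat([z, y_onehot], dim=-1)   # shape (128, 12)
print(encoder_input.shape, decoder_input.shape)
```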
Alternative Priors: Experiment with non-Gaussian priors. PyTorch’s torch.distributions.kl.kl_divergence supports many distribution families.
Output Distributions: Beyond BCE loss, try Gaussian outputs with learned variance or categorical distributions over pixel intensities.
Disentangled VAEs: Learn interpretable latent factors using β-VAE, Factor-VAE, or other disentanglement methods.
Hierarchical VAEs: Model data at multiple abstraction levels for complex generation tasks.
Limitations
Prior Limitations: Gaussian priors assume isotropic latent spaces and struggle with multi-modal distributions. The choice of prior family fundamentally constrains learning.
Approximation Quality: Tractability often trumps accuracy. Mean-field approximations and other simplifying assumptions create gaps between learned and true distributions.
Sample Quality: VAEs typically produce blurrier samples than GANs due to the averaging effect of the KL term and maximum likelihood training.
Mode Collapse: VAEs can ignore data regions, especially when gradient clipping or other regularization techniques are overly aggressive.
Loss Balancing: The reconstruction-KL tradeoff requires careful tuning. Use β-weighting and monitor latent distribution marginals to assess balance:
- Better reconstruction: Decrease β (KL weight)
- More Gaussian latents: Increase β
- Assessment: Examine latent marginals’ deviation from the prior
Conclusion
This tutorial demonstrated modern PyTorch techniques for building robust VAEs. We covered VAE fundamentals, implemented a clean architecture using current best practices, and validated our approach on MNIST. Key contributions include:
- Numerical stability through modern activations and epsilon additions
- Clean architecture with torch.distributions and dataclasses
- Practical insights on loss balancing and latent space analysis
Modern PyTorch features like torch.distributions
and improved preprocessing pipelines make VAE implementation more reliable and maintainable. These techniques provide a solid foundation for exploring advanced generative models.
Questions or feedback? Feel free to reach out—I’d love to hear about your VAE experiments!