Introduction

The last decade has seen explosive growth in generative modeling research, including Generative Adversarial Networks (GANs), normalizing flows, and diffusion models. Despite these advances, Variational Autoencoders (VAEs) remain foundational to generative AI, serving as the stepping stone for understanding more complex models.

Many VAE tutorials exist, but few leverage modern PyTorch features that improve optimization and numerical stability. This tutorial demonstrates current best practices for VAE implementation, focusing on techniques that significantly reduce common issues like NaN losses.

We’ll cover:

  • VAE fundamentals and mathematical foundations
  • A modern PyTorch VAE implementation featuring:
    • torchvision.transforms.v2 for preprocessing
    • torch.distributions for structured VAEs
    • dataclasses for cleaner code organization
    • tensorboard for comprehensive metric tracking
    • Softplus and epsilon addition for numerical stability
  • VAE validation using the MNIST dataset
  • Extensions and practical limitations

Let’s dive in!

What is a Variational Autoencoder?

The Variational Autoencoder (VAE) is a generative model introduced in Auto-Encoding Variational Bayes by Kingma and Welling in 2013. To understand VAEs, we need to first understand the problem they solve.

The Fundamental Problem

Consider a scenario where you have a dataset $$\mathbf{X} = \{\mathbf{x}^{(i)}\}_{i=1}^N$$ where each $\mathbf{x}^{(i)}$ is independently and identically distributed (IID) and can be continuous or discrete.

We make the modeling assumption that this dataset is generated by some underlying, lower-dimensional random process. Specifically, we assume that there exists an unobserved latent variable $\mathbf{z}$ such that

  • Each $\mathbf{z}^{(i)} \sim p_{\theta^\ast}(\mathbf{z})$, a true and unknown prior distribution
  • Each $\mathbf{x}^{(i)} \sim p_{\theta^\ast}(\mathbf{x}|\mathbf{z}^{(i)})$, a true and unknown conditional distribution

Since we cannot know the true distributions nor observe $\mathbf{z}$, our goal is to learn parameters that approximate these distributions. We make additional assumptions:

  • The prior distribution $p_{\theta^\ast}(\mathbf{z})$ is simple and known (e.g., standard normal)
  • The conditional distribution $p_{\theta^\ast}(\mathbf{x}|\mathbf{z})$ is simple and known (e.g., Gaussian or deterministic function)
  • The distributions are differentiable almost everywhere with respect to parameters $\theta$ and inputs $\mathbf{z}$

These assumptions enable algorithms that handle:

  • Intractability: VAEs handle cases with intractable components (e.g., intractable marginal likelihood, posterior inference) through variational approximations.
  • Large Datasets: VAEs use gradient-based optimization rather than sampling loops per data point, enabling efficient training on large datasets.

These algorithms, in turn, serve diverse applications:

  • Dimensionality Reduction: Learn low-dimensional representations for visualization, compression, and feature extraction. VAEs are similar to PCA, t-SNE, and UMAP but provide a probabilistic framework.
  • Imputation: Fill in missing data for preprocessing and augmentation, including image inpainting, denoising, and super-resolution.
  • Generation: Create new data for augmentation and synthesis across domains like images, text, and music. When modeling physical processes, the learned parameters often provide scientific insights.

The VAE Solution

To achieve these goals, VAEs use two key architectural components:

Encoder (Recognition Model): Maps input data to latent space: $$q_{\phi}(\mathbf{z} | \mathbf{x})$$ where $\phi$ are the encoder parameters. This approximates the intractable true posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$.

Decoder (Generative Model): Maps latent space back to input data: $$p_{\theta}(\mathbf{x} | \mathbf{z})$$ where $\theta$ are the decoder parameters.

The encoder approximation makes VAEs practical when the true posterior is intractable.

VAE Architecture
VAE Architecture overview

How do we jointly learn $\phi$ and $\theta$? The answer lies in the VAE objective function.

The VAE Objective

The VAE objective balances two terms: reconstruction loss and KL divergence.

Reconstruction Loss

The reconstruction loss measures how well the model reconstructs input data from the latent space. It’s the negative log-likelihood of the input given the latent representation.

For continuous inputs: Uses mean squared error (MSE) or negative Gaussian log-likelihood:

  • Deterministic decoder: $\mathcal{L}_{\text{rec}} = \frac{1}{N} \sum_{i=1}^N \left\| \mathbf{x}^{(i)} - f(\mathbf{z}^{(i)}) \right\|^2$
  • Stochastic decoder: $\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \log \mathcal{N}(\mathbf{x}^{(i)} \mid f(\mathbf{z}^{(i)}), \sigma^2 \mathbf{I})$

For discrete inputs: Uses cross-entropy loss. For MNIST, we use binary cross-entropy: $$\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \left[ \mathbf{x}^{(i)} \log f(\mathbf{z}^{(i)}) + (1 - \mathbf{x}^{(i)}) \log (1 - f(\mathbf{z}^{(i)})) \right]$$
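
In PyTorch this reduces to a one-liner. A small sketch, with random stand-ins for the decoder output $f(\mathbf{z})$ and targets $\mathbf{x}$ (shapes assume flattened MNIST):

import torch
import torch.nn.functional as F

x = torch.rand(128, 784)      # targets in [0, 1]
x_hat = torch.rand(128, 784)  # stand-in for decoder outputs f(z) in (0, 1)
# Sum over pixels, then average over the batch
loss_rec = F.binary_cross_entropy(x_hat, x, reduction='none').sum(-1).mean()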

Reconstruction loss encourages the decoder to accurately reconstruct input data from latent representations. Through backpropagation, the encoder learns to map inputs to latent spaces that enable effective reconstruction.

KL Divergence

The KL divergence measures how much the approximate posterior deviates from the prior, encouraging similarity between them. For Gaussian priors and posteriors, this has a closed-form solution available in PyTorch as torch.distributions.kl.kl_divergence.

For multivariate Gaussians $\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$ and $\mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$: $$\text{KL}(\mathcal{N}_1 || \mathcal{N}_2) = \frac{1}{2} \left[ \text{tr}(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \boldsymbol{\Sigma}_2^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - k + \log \frac{\text{det}(\boldsymbol{\Sigma}_2)}{\text{det}(\boldsymbol{\Sigma}_1)} \right]$$
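
As a quick sanity check, torch.distributions.kl.kl_divergence reproduces this closed form (a hypothetical verification snippet):

import torch
from torch.distributions import MultivariateNormal
from torch.distributions.kl import kl_divergence

k = 2
q = MultivariateNormal(torch.ones(k), covariance_matrix=0.5 * torch.eye(k))  # N(mu_1, Sigma_1)
p = MultivariateNormal(torch.zeros(k), covariance_matrix=torch.eye(k))       # N(mu_2, Sigma_2)
print(kl_divergence(q, p))  # 0.5 * (1 + 2 - 2 + log 4) ≈ 1.1931, matching the formula above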

Where Does This KL Divergence Come From?

Consider maximizing the marginal likelihood $\log p_{\theta}(\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(N)})$. For IID data, this becomes $\sum_i \log p_{\theta}(\mathbf{x}^{(i)})$.

For a single data point: $$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$
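
This decomposition follows by inserting $q_\phi$ and splitting the logarithm:

$$\log p_{\theta}(\mathbf{x}^{(i)}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log \frac{p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z})}{p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})} \right] = \mathbb{E}_{q_\phi} \left[ \log \frac{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})}{p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})} \right] + \mathbb{E}_{q_\phi} \left[ \log \frac{p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z})}{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \right]$$

The first expectation is exactly the KL term above; the second is $\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$.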

The first term, the KL divergence between the approximate and true posterior, is intractable but non-negative, so the second term, the variational lower bound (ELBO), is what we actually optimize:

Expanding the variational lower bound: $$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = \mathbb{E}_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) \right] - \text{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \,||\, p_{\theta}(\mathbf{z}))$$

This gives us the familiar two-term objective: reconstruction loss and KL divergence. The loss balances reconstructing input data while maintaining reasonable latent representation structure.

Modern PyTorch VAE Implementation

Now that we understand the VAE architecture and objective, let’s implement a modern VAE in PyTorch using MNIST for validation.

PyTorch VAE Implementation

Our VAE implementation uses an output dataclass and a VAE class extending nn.Module; a minimal sketch follows the feature list below. Here are the key modern PyTorch features:

Key Modern Features:

  • Numerical Stability: nn.Softplus and nn.SiLU activations improve convergence, and a small epsilon eps is added to the softplus-activated scale so it stays strictly positive (see the sketch below).
  • Efficient Architecture: A single linear layer produces both the mean and the raw scale, which torch.chunk then separates.
  • PyTorch Distributions: torch.distributions.MultivariateNormal enables clean re-parameterized sampling and efficient KL computation.
  • Clean Loss Computation: Uses torch.distributions.kl.kl_divergence for KL terms and BCE for reconstruction loss.
  • Best Practices: Dataclass output structure for organized code. Consider adding hyperparameter dataclasses and configuration management for production use.
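
The full class lives in the accompanying notebook, so here is a minimal sketch reconstructed to be consistent with the architecture printout and loss logging shown below; the exact dataclass fields, helper names, and the eps default are assumptions:

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class VAEOutput:
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.Tensor
    loss: Optional[torch.Tensor] = None
    loss_recon: Optional[torch.Tensor] = None
    loss_kl: Optional[torch.Tensor] = None


class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=2):
        super().__init__()
        # Widths halve at each stage (512 -> 256 -> 128 -> 64), matching the
        # model printout below; the final layer emits mean and raw scale together.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4), nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 8), nn.SiLU(),
            nn.Linear(hidden_dim // 8, 2 * latent_dim),
        )
        self.softplus = nn.Softplus()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 8), nn.SiLU(),
            nn.Linear(hidden_dim // 8, hidden_dim // 4), nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 2), nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # outputs in [0, 1] for BCE
        )

    def encode(self, x, eps=1e-8):
        mu, raw_scale = torch.chunk(self.encoder(x), 2, dim=-1)
        scale = self.softplus(raw_scale) + eps  # strictly positive std dev
        return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(scale))

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, compute_loss=True):
        dist = self.encode(x)
        z = dist.rsample()  # re-parameterized sample, so gradients reach the encoder
        x_recon = self.decode(z)
        if not compute_loss:
            return VAEOutput(z_dist=dist, z_sample=z, x_recon=x_recon)
        # Inputs were zero-centered by the transform, so shift back to [0, 1] for BCE.
        loss_recon = F.binary_cross_entropy(x_recon, x + 0.5, reduction='none').sum(-1).mean()
        prior = torch.distributions.MultivariateNormal(
            torch.zeros_like(z), scale_tril=torch.diag_embed(torch.ones_like(z)))
        loss_kl = torch.distributions.kl.kl_divergence(dist, prior).mean()
        return VAEOutput(z_dist=dist, z_sample=z, x_recon=x_recon,
                         loss=loss_recon + loss_kl, loss_recon=loss_recon, loss_kl=loss_kl)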

Data Preparation

For validation, we’ll use MNIST with modern preprocessing via torchvision.transforms.v2:

import torch
from torchvision import datasets
from torchvision.transforms import v2

batch_size = 128
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.view(-1) - 0.5),
])

# Download and load the training data
train_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=True,
    transform=transform,
)
# Download and load the test data
test_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/',
    download=True,
    train=False,
    transform=transform,
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    shuffle=False,
)
  • v2.ToImage() and v2.ToDtype handle uint8→float32 conversion and [0,1] scaling
  • v2.Lambda zero-centers data and flattens for the feed-forward network
  • Batch size of 128 balances computational efficiency with the original paper’s recommendation (≥100)
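
As a quick sanity check (assuming the loaders above), each batch should come out already flattened and zero-centered:

x, y = next(iter(train_loader))
print(x.shape, x.min().item(), x.max().item())
# torch.Size([128, 784]) with values in roughly [-0.5, 0.5]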

Training and Validation

We can instantiate a model, optimizer, and tensorboard writer, and then train the model using the following code:

from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 50
latent_dim = 2
hidden_dim = 512

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')

This yields a network with about 1.1M parameters.
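
The exact count can be reproduced with a one-liner over model.parameters():

print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')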

Number of parameters: 1,149,972

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): SiLU()
    (8): Linear(in_features=64, out_features=4, bias=True)
  )
  (softplus): Softplus(beta=1, threshold=20)
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): SiLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): SiLU()
    (8): Linear(in_features=512, out_features=784, bias=True)
    (9): Sigmoid()
  )
)

Our train function will look as follows:

from tqdm import tqdm

def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Trains the model on the given data.

    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer: The optimizer.
        prev_updates (int): Update count carried over from previous epochs.
        writer: The TensorBoard writer (optional).
    """
    model.train()  # Set the model to training mode

    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx

        data = data.to(device)

        optimizer.zero_grad()  # Zero the gradients

        output = model(data)  # Forward pass
        loss = output.loss

        loss.backward()

        if n_upd % 100 == 0:
            # Calculate and log gradient norms
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)

            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) Grad: {total_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('Loss/Train/BCE', output.loss_recon.item(), global_step)
                writer.add_scalar('Loss/Train/KLD', output.loss_kl.item(), global_step)
                writer.add_scalar('GradNorm/Train', total_norm, global_step)

        # Gradient clipping (applied after logging, so the norms printed above are pre-clip)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()  # Update the model parameters

    return prev_updates + len(dataloader)

And our test loop:

def test(model, dataloader, cur_step, writer=None):
    """
    Tests the model on the given data.

    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The current step.
        writer: The TensorBoard writer.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            data = data.view(data.size(0), -1)  # flatten (a no-op here, since the transform already flattens)

            output = model(data, compute_loss=True)  # Forward pass

            test_loss += output.loss.item()
            test_recon_loss += output.loss_recon.item()
            test_kl_loss += output.loss_kl.item()

    test_loss /= len(dataloader)
    test_recon_loss /= len(dataloader)
    test_kl_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f} (BCE: {test_recon_loss:.4f}, KLD: {test_kl_loss:.4f})')

    if writer is not None:
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/BCE', test_recon_loss, global_step=cur_step)  # averaged, not last-batch
        writer.add_scalar('Loss/Test/KLD', test_kl_loss, global_step=cur_step)

        # Log reconstructions
        writer.add_images('Test/Reconstructions', output.x_recon.view(-1, 1, 28, 28), global_step=cur_step)
        writer.add_images('Test/Originals', data.view(-1, 1, 28, 28), global_step=cur_step)

        # Log random samples from the latent space (wrapped in no_grad since this is inference only)
        with torch.no_grad():
            z = torch.randn(16, latent_dim).to(device)
            samples = model.decode(z)
        writer.add_images('Test/Samples', samples.view(-1, 1, 28, 28), global_step=cur_step)

Then we’ll run the training job:

prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)

TensorBoard Visualization

The loss curves during training over 50 epochs:

Loss Curves
Loss curves during training

The loss curves show convergence around 140 total loss. The reconstruction loss (~130) dominates, giving ~0.165 per pixel—quite reasonable. The KL term (~7) contributes meaningfully to regularization.

Gradient norms during training:

Gradient Norms
Gradient norms during training

Gradient norms reach ~$10^2$ before clipping at 1.0. While clipping provided slight stability improvements, the impact on final loss was minimal.

64 samples from the latent space:

Latent Space Samples
64 samples from the latent space

The samples are somewhat blurry—typical for VAEs—but show clear digit structure.

Latent Space Analysis

To analyze the learned latent space, I plotted the training set as a scatter plot colored by digit class:

MNIST 2D Scatter
MNIST 2D scatter plot

The scatter plot reveals impressive unsupervised digit clustering. While some confusion exists (notably between 4s and 9s), the model discovers meaningful structure without labels.

The latent distribution deviates from the Gaussian prior, showing radial “flowery” patterns. This is expected when compressing 784 dimensions to 2—information loss and non-Gaussian structure are inevitable. To enforce stronger Gaussian structure, increase the KL weight β.

MNIST 2D Histogram
MNIST 2D histogram plot
MNIST 1D Marginals
MNIST 1D marginals plot

Interpolating in Latent Space

Linear interpolation in latent space demonstrates smooth transitions between digit types:

import torch
import matplotlib.pyplot as plt


n = 15
z1 = torch.linspace(0, 1, n)
z2 = torch.zeros_like(z1) + 2
z = torch.stack([z1, z2], dim=-1).to(device)
with torch.no_grad():
    samples = model.decode(z)  # decoder already ends in Sigmoid, so no extra torch.sigmoid here

# Plot the generated images
fig, ax = plt.subplots(1, n, figsize=(n, 1))
for i in range(n):
    ax[i].imshow(samples[i].view(28, 28).cpu().detach().numpy(), cmap='gray')
    ax[i].axis('off')

plt.savefig('vae_mnist_interp.webp')
Latent Space Interpolation
Latent space interpolation

Complete Code

The full implementation is available in this Jupyter notebook.

Extensions and Limitations

Conditional VAEs (CVAEs): Condition on auxiliary information like class labels for semi-supervised learning.

Alternative Priors: Experiment with non-Gaussian priors. PyTorch’s torch.distributions.kl.kl_divergence supports many distribution families.

Output Distributions: Beyond BCE loss, try Gaussian outputs with learned variance or categorical distributions over pixel intensities.
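
A rough sketch of the Gaussian option, assuming a decoder whose final layer is widened to emit a mean and a raw scale per pixel (the names dec_out, mu_x, and raw_scale are hypothetical):

import torch
import torch.nn.functional as F

x = torch.rand(128, 784) - 0.5       # zero-centered targets, as in our transform
dec_out = torch.randn(128, 2 * 784)  # stand-in for a decoder with a doubled output layer
mu_x, raw_scale = torch.chunk(dec_out, 2, dim=-1)
px = torch.distributions.Normal(mu_x, F.softplus(raw_scale) + 1e-8)
loss_recon = -px.log_prob(x).sum(-1).mean()  # negative Gaussian log-likelihood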

Disentangled VAEs: Learn interpretable latent factors using β-VAE, Factor-VAE, or other disentanglement methods.

Hierarchical VAEs: Model data at multiple abstraction levels for complex generation tasks.

Limitations

Prior Limitations: Gaussian priors assume isotropic latent spaces and struggle with multi-modal distributions. The choice of prior family fundamentally constrains learning.

Approximation Quality: Tractability often trumps accuracy. Mean-field approximations and other simplifying assumptions create gaps between learned and true distributions.

Sample Quality: VAEs typically produce blurrier samples than GANs due to the averaging effect of the KL term and maximum likelihood training.

Mode Collapse: VAEs can ignore data regions, especially when gradient clipping or other regularization techniques are overly aggressive.

Loss Balancing: The reconstruction-KL tradeoff requires careful tuning. Use β-weighting (see the sketch after this list) and monitor latent distribution marginals to assess balance:

  • Better reconstruction: Decrease β (KL weight)
  • More Gaussian latents: Increase β
  • Assessment: Examine latent marginals’ deviation from the prior
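
A minimal β-weighting sketch, reusing the loss terms from the VAE output dataclass (beta is a hypothetical hyperparameter, not part of the code above):

beta = 0.5  # <1 favors reconstruction; >1 pushes latents toward the prior
loss = output.loss_recon + beta * output.loss_kl
loss.backward()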

Conclusion

This tutorial demonstrated modern PyTorch techniques for building robust VAEs. We covered VAE fundamentals, implemented a clean architecture using current best practices, and validated our approach on MNIST. Key takeaways include:

  • Numerical stability through modern activations and epsilon additions
  • Clean architecture with torch.distributions and dataclasses
  • Practical insights on loss balancing and latent space analysis

Modern PyTorch features like torch.distributions and improved preprocessing pipelines make VAE implementation more reliable and maintainable. These techniques provide a solid foundation for exploring advanced generative models.

Questions or feedback? Feel free to reach out—I’d love to hear about your VAE experiments!