Introduction
The last decade has seen explosive growth in generative modeling research, including Generative Adversarial Networks (GANs), normalizing flows, and diffusion models. Despite these advances, Variational Autoencoders (VAEs) remain foundational to generative AI, serving as the stepping stone for understanding more complex models.
Many VAE tutorials exist, but I’ve found that fewer leverage some newer PyTorch features that can help with optimization and numerical stability. This tutorial explores what I’ve learned about VAE implementation, focusing on techniques that have helped me reduce common issues like “NaN” loss in my own work.
We’ll cover:
- VAE fundamentals and mathematical foundations
- A PyTorch VAE implementation featuring:
  - torchvision.transforms.v2 for preprocessing
  - torch.distributions for structured VAEs
  - dataclasses for cleaner code organization
  - tensorboard for comprehensive metric tracking
  - Softplus and epsilon addition for numerical stability
- VAE validation using the MNIST dataset
- Extensions and practical limitations
Let’s dive in!
What is a Variational Autoencoder?
The Variational Autoencoder (VAE) is a generative model introduced in Auto-Encoding Variational Bayes by Kingma and Welling in 2013. To understand VAEs, we need to first understand the problem they solve.
The Fundamental Problem
Consider a scenario where you have a dataset $$\mathbf{X} = \{\mathbf{x}^{(i)}\}_{i=1}^N$$ where each $\mathbf{x}^{(i)}$ is independently and identically distributed (IID) and can be continuous or discrete.
We make the modeling assumption that this dataset is generated from some unobserved, lower-dimensional random process. Specifically, we assume that there exists a random variable $\mathbf{z}$ such that
- Each $\mathbf{z}^{(i)} \sim p_{\theta^\ast}(\mathbf{z})$, a true and unknown prior distribution
- Each $\mathbf{x}^{(i)} \sim p_{\theta^\ast}(\mathbf{x}|\mathbf{z}^{(i)})$, a true and unknown conditional distribution
Since we cannot know the true distributions nor observe $\mathbf{z}$, our goal is to learn parameters that approximate these distributions. We make additional assumptions:
- The prior $p_{\theta}(\mathbf{z})$ comes from a simple, known parametric family (e.g., standard normal)
- The conditional $p_{\theta}(\mathbf{x}|\mathbf{z})$ comes from a known parametric family (e.g., a Gaussian whose parameters are a deterministic function of $\mathbf{z}$)
- The distributions are differentiable almost everywhere with respect to parameters $\theta$ and inputs $\mathbf{z}$
These assumptions enable algorithms that handle:
- Intractability: VAEs handle cases with intractable components (e.g., intractable marginal likelihood, posterior inference) through variational approximations.
- Large Datasets: VAEs use gradient-based optimization rather than sampling loops per data point, enabling efficient training on large datasets.
This provides algorithms for diverse applications:
- Dimensionality Reduction: Learn low-dimensional representations for visualization, compression, and feature extraction. VAEs are similar to PCA, t-SNE, and UMAP but provide a probabilistic framework.
- Imputation: Fill in missing data for preprocessing and augmentation, including image inpainting, denoising, and super-resolution.
- Generation: Create new data for augmentation and synthesis across domains like images, text, and music. When modeling physical processes, the learned parameters often provide scientific insights.
The VAE Solution
To achieve these goals, VAEs use two key architectural components:
Encoder (Recognition Model): Maps input data to latent space: $$q_{\phi}(\mathbf{z} | \mathbf{x})$$ where $\phi$ are the encoder parameters. This approximates the intractable true posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$.
Decoder (Generative Model): Maps latent space back to input data: $$p_{\theta}(\mathbf{x} | \mathbf{z})$$ where $\theta$ are the decoder parameters.
The encoder approximation makes VAEs practical when the true posterior is intractable.

How do we jointly learn $\phi$ and $\theta$? The answer lies in the VAE objective function.
The VAE Objective
The VAE objective balances two terms: reconstruction loss and KL divergence.
Reconstruction Loss
The reconstruction loss measures how well the model reconstructs input data from the latent space. It’s the negative log-likelihood of the input given the latent representation.
For continuous inputs: Uses mean squared error (MSE) or negative Gaussian log-likelihood:
- Deterministic decoder: $\mathcal{L}_{\text{rec}} = \frac{1}{N} \sum_{i=1}^N \left\| \mathbf{x}^{(i)} - f(\mathbf{z}^{(i)}) \right\|^2$
- Stochastic decoder: $\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \log \mathcal{N}\left(\mathbf{x}^{(i)} \,\middle|\, f(\mathbf{z}^{(i)}), \sigma^2 \mathbf{I}\right)$
For discrete inputs: Uses cross-entropy loss. For MNIST, we use binary cross-entropy: $$\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_{i=1}^N \left[ \mathbf{x}^{(i)} \log f(\mathbf{z}^{(i)}) + (1 - \mathbf{x}^{(i)}) \log \left(1 - f(\mathbf{z}^{(i)})\right) \right]$$
For the MNIST dataset, we can treat each pixel as a binary random variable (either 0 or 1). We are therefore modeling the output $p_{\theta}(\mathbf{x}|\mathbf{z})$ as a product of 784 independent Bernoulli distributions. The negative log-likelihood of a Bernoulli distribution is the Binary Cross-Entropy (BCE). This is why our decoder outputs logits (the un-normalized log-probabilities) and we use PyTorch’s F.binary_cross_entropy_with_logits function, as it is the numerically stable way to compute this exact loss.
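To make this concrete, here is a tiny sanity check (not part of the tutorial’s training code; the tensors are made up) showing that the logits-based call computes the same BCE value as an explicit sigmoid followed by F.binary_cross_entropy, just without the risk of saturating the sigmoid first:
import torch
import torch.nn.functional as F
logits = torch.randn(4, 784)           # pretend decoder outputs (raw logits)
targets = torch.rand(4, 784).round()   # pretend binary pixel targets in {0, 1}
# Numerically stable: operates directly on logits
loss_stable = F.binary_cross_entropy_with_logits(logits, targets)
# Mathematically equivalent, but the explicit sigmoid can saturate to exactly 0 or 1
loss_naive = F.binary_cross_entropy(torch.sigmoid(logits), targets)
print(loss_stable.item(), loss_naive.item())  # values agree to several decimal places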
Reconstruction loss encourages the decoder to accurately reconstruct input data from latent representations. Through backpropagation, the encoder learns to map inputs to latent spaces that enable effective reconstruction.
KL Divergence
The KL divergence measures how much the approximate posterior deviates from the prior, encouraging similarity between them. For Gaussian priors and posteriors, this has a closed-form solution available in PyTorch as torch.distributions.kl.kl_divergence.
For multivariate Gaussians $\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$ and $\mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$: $$\text{KL}(\mathcal{N}_1 || \mathcal{N}_2) = \frac{1}{2} \left[ \text{tr}(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \boldsymbol{\Sigma}_2^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - k + \log \frac{\text{det}(\boldsymbol{\Sigma}_2)}{\text{det}(\boldsymbol{\Sigma}_1)} \right]$$
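For the special case we actually use, a diagonal Gaussian posterior $\mathcal{N}(\boldsymbol{\mu}, \text{diag}(\boldsymbol{\sigma}^2))$ against a standard normal prior, this reduces to $\frac{1}{2} \sum_j \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right)$. Here is a small check (the values are made up) that torch.distributions agrees with that closed form:
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
mu = torch.tensor([0.5, -1.0])
std = torch.tensor([0.8, 1.2])
posterior = Normal(mu, std)
prior = Normal(torch.zeros_like(mu), torch.ones_like(std))
kl_library = kl_divergence(posterior, prior)                  # per-dimension KL
kl_manual = 0.5 * (std**2 + mu**2 - 1.0 - torch.log(std**2))  # closed form above
print(kl_library, kl_manual)  # the two match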
Where Does This KL Divergence Come From?
Consider maximizing the marginal likelihood $\log p_{\theta}(\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(N)})$. For IID data, this becomes $\sum_i \log p_{\theta}(\mathbf{x}^{(i)})$.
For a single data point: $$\log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$$
The first term (the KL divergence between the approximate and true posterior) is intractable. The second term is the variational lower bound (ELBO), which we can optimize. Expanding it:
$$\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) \right] - \text{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \,\|\, p_{\theta}(\mathbf{z}))$$
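If you are curious where the decomposition above comes from, it follows from the definition of KL divergence and Bayes’ rule (dropping the $(i)$ superscript for brevity):
$$\begin{aligned} \log p_{\theta}(\mathbf{x}) &= E_{q_\phi(\mathbf{z}|\mathbf{x})}\left[ \log \frac{p_{\theta}(\mathbf{x}, \mathbf{z})}{q_\phi(\mathbf{z}|\mathbf{x})} \right] + E_{q_\phi(\mathbf{z}|\mathbf{x})}\left[ \log \frac{q_\phi(\mathbf{z}|\mathbf{x})}{p_{\theta}(\mathbf{z}|\mathbf{x})} \right] \\ &= \mathcal{L}(\theta, \phi; \mathbf{x}) + \text{KL}(q_\phi(\mathbf{z}|\mathbf{x}) \,\|\, p_{\theta}(\mathbf{z}|\mathbf{x})) \end{aligned}$$
Splitting $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x}|\mathbf{z})\, p_{\theta}(\mathbf{z})$ inside the first expectation yields exactly the reconstruction term minus the prior KL term, and since the second KL is non-negative, $\mathcal{L}$ is a lower bound on $\log p_{\theta}(\mathbf{x})$.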
This gives us the familiar two-term objective: reconstruction loss and KL divergence. The loss balances reconstructing input data while maintaining reasonable latent representation structure.
Modern PyTorch VAE Implementation
Now that we understand the VAE architecture and objective, let’s implement a modern VAE in PyTorch.
Before we dive into the code, I want to address what I consider a critical implementation detail: numerical stability.
A VAE encoder outputs the parameters of a Gaussian distribution: a mean $\boldsymbol{\mu}$ and a standard deviation $\boldsymbol{\sigma}$. A key constraint is that the standard deviation $\boldsymbol{\sigma}$ should always be positive.
There are two common methods to enforce this positivity, and the choice can have a significant impact on training stability.
Method 1: The “Log-Variance” Approach
This is the most common method seen in introductory VAE tutorials, and it’s what the original 2013 paper implicitly suggests.
- Encoder Output: The encoder network outputs $\boldsymbol{\mu}$ and $\log(\boldsymbol{\sigma}^2)$ (the log-variance).
- Get std: To get the standard deviation $\boldsymbol{\sigma}$ for sampling, you compute $\boldsymbol{\sigma} = \exp(0.5 \cdot \log(\boldsymbol{\sigma}^2))$.
- The Problem: This is mathematically correct, but the exp() function is numerically unstable. If the network, especially early in training, accidentally outputs a large positive log-variance (e.g., 100), then $\exp(50) \approx 5 \times 10^{21}$, which quickly propagates to inf values, NaN losses, and a crashed training run.
Method 2: The “Stable Standard Deviation” Approach (My Preferred Choice)
An alternative approach that I’ve found more robust is to parameterize $\boldsymbol{\sigma}$ directly using an activation function that is always positive but doesn’t explode.
- Encoder Output: The encoder network outputs $\boldsymbol{\mu}$ and a raw, unconstrained tensor (let’s call it std_param).
- Get std: To get the standard deviation $\boldsymbol{\sigma}$, you pass this tensor through the softplus function: $$\boldsymbol{\sigma} = \text{softplus}\left(\text{std}_\text{param}\right) + \epsilon$$ where $\text{softplus}(x) = \log(1 + \exp(x))$ and $\epsilon$ is a small value (e.g., 1e-6) that prevents $\boldsymbol{\sigma}$ from ever being exactly zero.
- The Benefit: The softplus function is a “smooth” version of the ReLU function.
  - It is always positive, satisfying our constraint.
  - Critically, for large positive inputs $x$, $\text{softplus}(x) \approx x$. This linear behavior (unlike the exponential behavior of exp()) prevents numerical overflow and leads to much more stable gradients.
This pattern of using softplus to parameterize the standard deviation of a distribution is a practice I’ve seen in several modern generative models, and I’ll use Method 2 for this implementation.
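To see the difference concretely, here is a tiny comparison (the raw values are made up) of what each parameterization does when the encoder emits a large output:
import torch
import torch.nn.functional as F
raw = torch.tensor([0.0, 5.0, 100.0])   # hypothetical raw encoder outputs
# Method 1: interpret raw as log-variance -> exponential blow-up
std_logvar = torch.exp(0.5 * raw)        # ~ [1.0, 12.2, 5.2e21]
# Method 2: interpret raw as std_param -> roughly linear growth
std_softplus = F.softplus(raw) + 1e-6    # ~ [0.69, 5.01, 100.0]
print(std_logvar)
print(std_softplus)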
PyTorch VAE Implementation
My VAE implementation uses an output dataclass and a VAE class extending nn.Module.
Key Features I’ve Included:
- Logits-Based Loss: The decoder doesn’t have a final Sigmoid layer. It outputs raw logits. This allows us to use PyTorch’s F.binary_cross_entropy_with_logits function, which I’ve found to be more numerically stable than a Sigmoid layer followed by F.binary_cross_entropy.
- Numerical Stability: I parameterize the standard deviation $\boldsymbol{\sigma}$ using the softplus function: $\boldsymbol{\sigma} = \text{softplus}(\text{std}_\text{param}) + \epsilon$.
- PyTorch Distributions: I use torch.distributions.Normal for what I find to be cleaner, more efficient reparameterization and KL divergence calculation.
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from torch.nn.utils.clip_grad import clip_grad_norm_
from torchvision import datasets
from torchvision.transforms import v2 as T
from tqdm import tqdm
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
@dataclass
class VAEOutput:
"""Output dataclass for VAE forward pass."""
x_recon: torch.Tensor
z: torch.Tensor
mu: torch.Tensor
std_param: torch.Tensor # The raw output from the encoder
std: torch.Tensor # The stable, positive standard deviation
loss: Optional[torch.Tensor] = None
loss_recon: Optional[torch.Tensor] = None
loss_kl: Optional[torch.Tensor] = None
class VAE(nn.Module):
"""
Variational Autoencoder (VAE) class.
This implementation uses the 'Stable Standard Deviation' (softplus)
method for numerical stability.
Args:
input_dim (int): Dimensionality of the input data.
hidden_dim (int): Dimensionality of the hidden layer.
latent_dim (int): Dimensionality of the latent space.
"""
def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
super().__init__()
# Encoder network
self.encoder = nn.Sequential(
nn.Flatten(),
nn.Linear(input_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.SiLU(),
nn.Linear(hidden_dim // 2, latent_dim * 2), # Combined mu and std_param
)
# Decoder network (outputs logits; apply sigmoid only for visualization)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim // 2),
nn.SiLU(),
nn.Linear(hidden_dim // 2, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, input_dim),
)
self.latent_dim = latent_dim
def encode(self, x: torch.Tensor):
"""Encode input to latent parameters."""
h = self.encoder(x)
mu, std_param = torch.chunk(h, 2, dim=-1)
return mu, std_param
def reparameterize(self, mu: torch.Tensor, std_param: torch.Tensor):
"""
Reparameterization trick using softplus for stable std.
Args:
mu: Mean tensor
std_param: Raw, unconstrained tensor from encoder
Returns:
z: Sampled latent vector
std: Calculated standard deviation
"""
# Calculate std using softplus for stability + epsilon
std = F.softplus(std_param) + 1e-6
# Use torch.distributions.Normal for efficient sampling
dist = torch.distributions.Normal(mu, std)
# Get sample z ~ N(mu, std)
z = dist.rsample() # rsample() enables backpropagation
return z, std
def decode(self, z: torch.Tensor):
"""Decode latent representation to logits (apply sigmoid externally if needed)."""
return self.decoder(z)
def forward(
self,
x: torch.Tensor,
compute_loss: bool = True,
sample: Optional[bool] = None,
kl_beta: float = 1.0,
) -> VAEOutput:
"""
Forward pass through VAE.
Args:
x: input tensor shaped (B, 1, 28, 28)
compute_loss: whether to compute BCE+KL
sample: if True, draw z ~ N(mu, std) with rsample(); if False, use z = mu.
Defaults to self.training.
"""
mu, std_param = self.encode(x)
std = F.softplus(std_param) + 1e-6
if sample is None:
sample = self.training
if sample:
# Use reparameterization helper (stable std inside)
z, std = self.reparameterize(mu, std_param)
else:
z = mu
x_logits = self.decode(z)
x_recon = torch.sigmoid(x_logits) # Apply sigmoid only for output/visualization
output = VAEOutput(
x_recon=x_recon,
z=z,
mu=mu,
std_param=std_param,
std=std,
)
if compute_loss:
# Reconstruction loss: sum over pixels, mean over batch (matches KL scale)
x_target = x.view(x.size(0), -1) # ensure same shape as logits (B, 784)
# Use BCEWithLogits for numerical stability
loss_recon = F.binary_cross_entropy_with_logits(
x_logits, x_target, reduction="sum"
) / x.size(0)
# KL divergence loss using torch.distributions, element-wise per latent dim
prior = torch.distributions.Normal(
torch.zeros_like(mu), torch.ones_like(std)
)
posterior = torch.distributions.Normal(mu, std)
kl_per_dim = torch.distributions.kl.kl_divergence(posterior, prior)
# Raw KL for logging (per-example sum over dims, averaged over batch)
raw_kl = kl_per_dim.sum(dim=1).mean()
loss_kl = raw_kl
loss = loss_recon + kl_beta * loss_kl
output.loss = loss
output.loss_recon = loss_recon
# Log the unclamped KL so you can monitor collapse
output.loss_kl = raw_kl
return output
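Before adding the training loop, a quick smoke test of the class is useful (a sketch; the batch and shapes below are just illustrative MNIST-sized tensors):
model = VAE(input_dim=784, hidden_dim=512, latent_dim=2)
x = torch.rand(8, 1, 28, 28)   # fake batch of images in [0, 1]
out = model(x)                 # training-mode forward pass (sampling + loss)
print(out.x_recon.shape)       # torch.Size([8, 784])
print(out.z.shape)             # torch.Size([8, 2])
print(out.loss.item(), out.loss_recon.item(), out.loss_kl.item())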
A Critical Note on Loss Scaling: You might notice we use reduction="sum" for the reconstruction loss and then divide by the batch size, rather than the simpler reduction="mean". This is intentional and crucial.
- The KL Divergence (loss_kl) is a single value per sample in the batch, summed over the latent_dim dimensions.
- The Reconstruction Loss (loss_recon) must also be a single value per sample in the batch to be on the same scale. We sum the BCE loss over all 784 pixels.
If we used reduction="mean" for BCE, it would be averaged over all $784 \times \text{batch-size}$ pixels, making it tiny compared to the KL loss. The KL term would dominate, and the model would learn to ignore the input, leading to posterior collapse.
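To see the scale mismatch, here is a quick sketch on random tensors: the per-sample sum is 784 times larger than the per-pixel mean, which is the difference between a reconstruction term that can push back against the KL term and one that gets drowned out by it.
logits = torch.randn(128, 784)
targets = torch.rand(128, 784).round()
per_pixel_mean = F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
per_sample_sum = F.binary_cross_entropy_with_logits(logits, targets, reduction="sum") / logits.size(0)
print(per_pixel_mean.item())   # order of 1 (average over every pixel in the batch)
print(per_sample_sum.item())   # 784x larger (total BCE per image, averaged over the batch)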
Features I’ve included:
- Numerical Stability: Instead of the common (but potentially unstable) exp(log_var) method, I parameterize the standard deviation $\boldsymbol{\sigma}$ using the softplus function: $\boldsymbol{\sigma} = \text{softplus}(\text{std}_\text{param}) + \epsilon$. This approach has helped me prevent numerical overflow and NaN losses.
- Efficient Architecture: The encoder’s final linear layer produces the mean and standard-deviation parameters as one combined tensor, which encode() then splits with torch.chunk.
- PyTorch Distributions: I use torch.distributions.Normal for what I find to be cleaner, more efficient reparameterization and KL divergence calculation. This seems simpler and faster than MultivariateNormal for a diagonal covariance.
- Clean Loss Computation: Uses torch.distributions.kl.kl_divergence for a clean, one-line KL term calculation.
- Organization: Dataclass output structure (VAEOutput) for what I hope is more organized and self-documenting code.
Data Preparation
I’ll use MNIST with a simpler, more standard transform. I just convert the images to [0, 1] float tensors. The BCEWithLogits loss handles the [0, 1] target range well.
def get_dataloaders(batch_size: int):
# Use modern v2 transforms for explicit operations
transform = T.Compose(
[
T.ToImage(), # Converts PIL/Numpy to tensor, keeps shape (1, 28, 28)
T.ToDtype(torch.float32, scale=True), # Converts to float and scales [0, 255] -> [0, 1]
]
)
train_data = datasets.MNIST(
os.path.expanduser("~/.pytorch/MNIST_data/"),
download=True,
train=True,
transform=transform,
)
test_data = datasets.MNIST(
os.path.expanduser("~/.pytorch/MNIST_data/"),
download=True,
train=False,
transform=transform,
)
from torch.utils.data import DataLoader
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
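A quick sanity check on what these loaders produce (a sketch):
train_loader, test_loader = get_dataloaders(batch_size=128)
images, labels = next(iter(train_loader))
print(images.shape, images.dtype)                # torch.Size([128, 1, 28, 28]) torch.float32
print(images.min().item(), images.max().item())  # typically 0.0 and 1.0 after scaling
print(labels.shape)                              # torch.Size([128])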
An Important Training Technique: β-Annealing (KL Warmup)
A common failure mode for VAEs is “posterior collapse,” especially when the KL loss (the regularization term) is too strong, too early. The model learns that the easiest way to minimize the loss is to ignore the input and just output blurry averages. This problem becomes especially pronounced when VAEs are paired with powerful decoders that can generate reasonable outputs even from collapsed latent representations.
To help address this, I use a β-annealing technique (also called KL Warmup). I add a kl_beta parameter to the loss:
loss = loss_recon + kl_beta * loss_kl
This kl_beta parameter is exactly the β from the β-VAE paper, which formally studies this reconstruction-regularization trade-off. For the first N optimizer steps (roughly the first 10 epochs in my runs), we linearly increase kl_beta from 0.0 to 1.0. This tells the model:
- Phase 1 (β ≈ 0): “Ignore the KL loss. Just focus all your effort on learning to reconstruct the image.”
- Phase 2 (β → 1): “Okay, you’re good at reconstructing. Now, I will slowly add ‘pressure’ to organize your latent space into a clean Gaussian.”
This single technique has been one of the most effective approaches I’ve found for getting a stable and well-performing VAE. When β > 1, you apply more pressure to learn a disentangled latent space at the cost of reconstruction quality—exactly the trade-off explored in the β-VAE paper.
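Concretely, the linear schedule used in the training function below looks like this (the warmup length is just an illustrative choice, roughly ten MNIST epochs at batch size 128):
kl_warmup_steps = 4_690   # hypothetical: ~10 epochs x ~469 steps per epoch
for step in [0, 469, 2345, 4689, 10_000]:
    kl_beta = min(1.0, (step + 1) / float(kl_warmup_steps))
    print(step, round(kl_beta, 3))   # 0.0, 0.1, 0.5, 1.0, 1.0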
The training function includes this kl_beta logic:
def train(
model: nn.Module,
dataloader,
optimizer,
prev_updates: int,
device,
writer=None,
history: Optional[dict] = None,
kl_warmup_steps: int = 0,
):
"""
Trains the model on the given data.
"""
model.train()
for batch_idx, (data, _) in enumerate(tqdm(dataloader, desc="Training")):
n_upd = prev_updates + batch_idx
data = data.to(device)
optimizer.zero_grad()
# KL warmup schedule (linear 0->1 over kl_warmup_steps)
if kl_warmup_steps and kl_warmup_steps > 0:
kl_beta = min(1.0, (n_upd + 1) / float(kl_warmup_steps))
else:
kl_beta = 1.0
output = model(data, kl_beta=kl_beta)
loss = output.loss
loss.backward()
grad_norm_pre_clip = clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
if n_upd % 100 == 0:
total_norm = float(grad_norm_pre_clip)
print(
f"Step {n_upd:,} (N samples: {n_upd * dataloader.batch_size:,}), "
f"Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) "
f"Grad: {total_norm:.4f}"
)
if writer is not None:
global_step = n_upd
writer.add_scalar("Loss/Train", loss.item(), global_step)
writer.add_scalar(
"Loss/Train/BCE", output.loss_recon.item(), global_step
)
writer.add_scalar("Loss/Train/KLD", output.loss_kl.item(), global_step)
writer.add_scalar("KL/beta", kl_beta, global_step)
writer.add_scalar("GradNorm/Train", total_norm, global_step)
if history is not None:
history.setdefault("step", []).append(n_upd)
history.setdefault("loss", []).append(float(loss.item()))
history.setdefault("recon", []).append(float(output.loss_recon.item()))
history.setdefault("kl", []).append(float(output.loss_kl.item()))
history.setdefault("grad", []).append(float(total_norm))
return prev_updates + len(dataloader)
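The full driver script isn’t reproduced here, but a minimal sketch of how I wire these pieces together (the hyperparameters are illustrative, and there is no separate evaluation loop in this sketch) looks roughly like this:
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 128
    num_epochs = 50
    train_loader, test_loader = get_dataloaders(batch_size)
    model = VAE(input_dim=784, hidden_dim=512, latent_dim=2).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    writer = SummaryWriter(f"runs/vae_{datetime.now():%Y%m%d_%H%M%S}")
    # Linear KL warmup over roughly the first 10 epochs
    kl_warmup_steps = 10 * len(train_loader)
    prev_updates = 0
    for epoch in range(num_epochs):
        prev_updates = train(
            model, train_loader, optimizer, prev_updates, device,
            writer=writer, kl_warmup_steps=kl_warmup_steps,
        )
    writer.close()

if __name__ == "__main__":
    main()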
From Blurry to Sharp: A VAE’s Trade-offs in Practice
Most VAE tutorials show blurry images and say “this is a known limitation.” This is only part of the story. The blurriness is often a choice—it’s the trade-off for a specific model configuration.
When I first encountered blurry VAE outputs, my first two questions were:
- “Is my model undertrained?” (Looking at flat loss curves suggested it was not.)
- “Is my model too small?”
Let’s explore the “levers” we can adjust, starting with that second question.
Lever 1: A Common Misconception (Network Size)
The first lever we can adjust is Network Capacity (the hidden_dim). It’s logical to think, “If my model has more parameters, it can learn more details and the images will be sharper.”
Let’s test this. I ran two experiments, both with latent_dim = 2:
- Model A (Large): hidden_dim=512. Total Parameters: ~1.15 million.
- Model B (Small): hidden_dim=256. Total Parameters: ~469,000.
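If you want to verify parameter counts like these for your own configuration, a one-liner over model.parameters() does it (a sketch; exact totals depend on the architecture details):
model_a = VAE(input_dim=784, hidden_dim=512, latent_dim=2)
model_b = VAE(input_dim=784, hidden_dim=256, latent_dim=2)
print(sum(p.numel() for p in model_a.parameters()))   # Model A (large)
print(sum(p.numel() for p in model_b.parameters()))   # Model B (small)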


The result? The reconstructions were identically blurry. The final test losses were almost the same. This suggests that the blurriness isn’t a capacity problem. The smaller 469k parameter model appears to be already powerful enough to solve the task.
This suggests that Lever 1 isn’t the main issue. The real culprit seems to be something else.
Lever 2: The Real Culprit (The Information Bottleneck)
This brings us to what seems to be the real lever: The Information Bottleneck (latent_dim).
This appears to be a core trade-off of a VAE. Let’s run two experiments where I only change this one hyperparameter. Note that this is closely related to the β-VAE trade-off: a smaller latent_dim forces more compression (similar to higher β values), trading reconstruction quality for latent space structure.
Experiment 1: The Visualizer (latent_dim = 2)
First, I train the model with latent_dim = 2. My goal is to create a 2D latent space I can plot and visually inspect. I expect that this significant information bottleneck (compressing 784 pixel dimensions to 2) will likely hurt reconstruction quality.
Latent Space Analysis: Encouraging Results!
The model successfully learns a highly structured latent space. Digits are clustered by class, and the space is continuous, all without a single label. This is a form of disentanglement—where individual latent dimensions control meaningful generative factors. You can see how the ‘1s’ (orange) form a long arc that appears to correspond to digit slant, suggesting the model has learned to separate “what digit it is” from “how it’s written.”
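For reference, here is a sketch of how a latent-space plot like this can be produced (it assumes the model, device, and test_loader defined earlier; the variable names are mine):
model.eval()
zs, ys = [], []
with torch.no_grad():
    for images, labels in test_loader:
        mu, _ = model.encode(images.to(device))   # use the posterior mean as the 2D embedding
        zs.append(mu.cpu())
        ys.append(labels)
zs = torch.cat(zs).numpy()
ys = torch.cat(ys).numpy()
plt.figure(figsize=(6, 6))
plt.scatter(zs[:, 0], zs[:, 1], c=ys, cmap="tab10", s=2, alpha=0.6)
plt.colorbar(label="digit class")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.savefig("latent_space.png")   # matplotlib is in Agg mode, so save rather than show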

Reconstruction Quality: The Trade-Off
This beautiful 2D map comes at a cost.

Experiment 2: The Reconstructor (latent_dim = 32)
Now, let’s try adjusting this lever. I change one line of code (latent_dim = 32) and retrain. I’m giving the model a much wider “bandwidth” to pass information from the encoder to the decoder. I lose the 2D plot, but what do I gain?
Reconstruction Quality: Dramatic Improvement!

The samples generated from the prior (feeding z = torch.randn(B, 32) into the decoder) are also richer and more detailed.
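The sampling itself is only a couple of lines (a sketch, assuming the trained latent_dim=32 model and device from above):
model.eval()
with torch.no_grad():
    z = torch.randn(16, 32, device=device)       # 16 draws from the standard normal prior
    samples = torch.sigmoid(model.decode(z))     # decoder outputs logits, so apply sigmoid
    samples = samples.view(-1, 1, 28, 28).cpu()  # reshape flat 784-vectors into 28x28 images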

The Story in the Loss Curves
This trade-off isn’t just visual; it’s also quantitative. Let’s compare the loss curves from both experiments.


Further Reading
This tutorial explored the core VAE concepts and trade-offs, but there’s a rich literature building on these foundations. Here are the key papers that directly extend what we’ve covered:
Original Paper: You should, of course, start with the foundational work: “Auto-Encoding Variational Bayes” (Kingma & Welling, 2013). This paper introduces the mathematical framework and reparameterization trick that makes VAEs possible.
β-VAE: To formally control the reconstruction-regularization trade-off you experimented with, the next step is the β-VAE; the kl_beta warmup we used above is a form of β-annealing. Read: “β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework” (Higgins et al., 2017). This work shows how varying β can encourage disentangled representations in the latent space.
VQ-VAE (To Fix Blurriness): To solve the “blurry image” problem in a fundamentally different way, VQ-VAEs replace the continuous latent space with a discrete one. By forcing the encoder to choose from a finite “codebook” of latent vectors, the model cannot spend its capacity modeling imperceptible noise and is forced to learn higher-level features. This discrete bottleneck prevents posterior collapse and leads to much sharper reconstructions while maintaining the autoencoder structure. Read: “Neural Discrete Representation Learning” (van den Oord et al., 2017).
These papers represent the natural next steps in understanding how to tune VAEs for specific applications and overcome their traditional limitations.
Conclusion
This tutorial explored PyTorch techniques for building a Variational Autoencoder. More importantly, I’ve tried to show that the VAE isn’t just a “blurry image generator.” It appears to be a tool with specific “levers” that trade reconstruction quality for latent space structure.
Many additional “levers” can be explored, such as the choice of prior distribution, decoder architecture, and training strategies.
Key takeaways from my experience:
- Consider Numerical Stability: BCEWithLogits (no Sigmoid in the decoder) and softplus for the standard deviation have been helpful for stable training in my experiments.
- Try KL Warmup: This technique seems to give the model time to learn what to reconstruct before worrying about how to organize the latent space.
- The Bottleneck Matters: The latent_dim appears to be a key hyperparameter for controlling the reconstruction-regularization trade-off.
By understanding these trade-offs, you might be able to tune a VAE for your specific goal, whether it’s a 2D visualization or a higher-fidelity generative model.
Questions or feedback? Feel free to reach out—I’d love to hear about your experiences with VAE experiments!
