Introduction

The last decade has seen a surge in generative modeling research, including Generative Adversarial Networks (GANs), normalizing flows, and diffusion models. Despite the advancements, Variational Autoencoders (VAEs) remain a cornerstone in this field. A comprehensive understanding of VAEs is essential for anyone delving into generative AI, serving as a foundation for more complex models.

Many tutorials on VAEs exist, yet few leverage the latest PyTorch advancements, potentially missing out on optimization and numerical stability enhancements. This tutorial aims to fill that gap by demonstrating modern PyTorch techniques applied to VAEs, reducing the risk of issues like “NaN” loss.

We’ll cover:

  • VAE fundamentals
  • A modern PyTorch VAE implementation, featuring:
    • torchvision.transforms.v2 for preprocessing
    • torch.distributions for structured VAEs
    • dataclasses for cleaner code
    • tensorboard for metric tracking
    • Softplus and epsilon addition for stability
  • VAE validation using the MNIST dataset
  • Discussion on VAE extensions and limitations

Let’s dive in!

What is a Variational Autoencoder?

The Variational Autoencoder (VAE) is a generative model first introduced in Auto-Encoding Variational Bayes by Kingma and Welling in 2013. To best understand VAEs, it helps to start with why they were developed.

The Fundamental Problem

Consider a scenario where you have a dataset $$\mathbf{X} = \{ \mathbf{x}^{(i)} \}_{i=1}^N$$ whose samples $\mathbf{x}^{(i)}$ are independent and identically distributed (IID) and may be continuous- or discrete-valued.

We make the modeling assumption that this dataset is generated by some other, lower-dimensional, random process $\mathbf{z}$. Specifically, we assume that there exists a random variable $\mathbf{z}$ such that

  • Each $\mathbf{z}^{(i)} \sim p_{\theta^\ast}(\mathbf{z})$, a true and unknown prior distribution
  • Each $\mathbf{x}^{(i)} \sim p_{\theta^\ast}(\mathbf{x}|\mathbf{z}^{(i)})$, a true and unknown conditional distribution

Since we cannot know the true distributions nor observe $\mathbf{z}$, our goal is to learn parameters that best approximate these distributions. To do so, we must further assume that:

  • The prior distribution $p_{\theta^\ast}(\mathbf{z})$ is simple and known (e.g., a standard normal distribution)
  • The conditional distribution $p_{\theta^\ast}(\mathbf{x}|\mathbf{z})$ is simple and known (e.g., a Gaussian distribution, a deterministic function, etc.)
  • The distributions are differentiable almost everywhere with respect to both their parameters $\theta$ and their inputs $\mathbf{z}$

That’s a lot of assumptions, so what do they gain us? A general algorithm that can handle:

  • Intractability: VAEs can handle cases with intractable components (e.g., an intractable marginal likelihood or intractable posterior inference). Through approximations, VAEs work around these issues (the relevant quantities are written out just after this list).
  • Large Datasets: Variational algorithms that require an expensive sampling loop per data point are not feasible for large datasets. VAEs are instead trained with minibatch gradient-based optimization, which is much more efficient, so they scale to large datasets.
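
To make the intractability concrete: under the assumptions above, fitting $\theta$ by maximum likelihood would require the marginal likelihood, and exact inference would require the posterior, $$ p_{\theta}(\mathbf{x}) = \int p_{\theta}(\mathbf{x}|\mathbf{z}) \, p_{\theta}(\mathbf{z}) \, d\mathbf{z}, \qquad p_{\theta}(\mathbf{z}|\mathbf{x}) = \frac{p_{\theta}(\mathbf{x}|\mathbf{z}) \, p_{\theta}(\mathbf{z})}{p_{\theta}(\mathbf{x})} $$ When the conditional is parameterized by a neural network, the integral over $\mathbf{z}$ has no closed form, so both quantities are intractable; these are exactly the gaps the VAE’s approximations work around.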

More concretely, this gives us an algorithm that one can use for a diverse set of purposes, including:

  • Dimensionality Reduction: VAEs can be used to learn a low-dimensional representation of high-dimensional data. This is useful for visualization, data compression, and feature extraction. In this way, VAEs are similar to PCA, t-SNE, UMAP, and classical autoencoders. You can also think of VAEs as a tool for identifying the intrinsic dimensionality of the data.
  • Imputation: VAEs can be used to fill in missing data. This is useful for data preprocessing and data augmentation. Image in-painting, de-noising, and super-resolution are all examples of imputation.
  • Generation: VAEs can be used to generate new data. This is useful for data augmentation, data synthesis, and generative modeling. Image generation, text generation, and music generation are all examples of generation. Additionally, if we are mimicking a physical process, we may also be interested in the learned parameters of the model.

The VAE Solution

To achieve these aims, the VAE architecture has two components.

First, an encoder (or recognition model) that maps the input data to a latent space, $$q_{\phi}(\mathbf{z} | \mathbf{x})$$ where $\phi$ are the parameters of the encoder (e.g., a neural network). It serves as an approximation to the true posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$ which is intractable and unknown.

Second, a decoder (or generative model) that maps the latent space back to the input data, $$p_{\theta}(\mathbf{x} | \mathbf{z})$$ where $\theta$ are the parameters of the decoder (e.g., a neural network). We call $q_{\phi}(\mathbf{z} | \mathbf{x})$ an approximation because the true posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$ induced by this decoder is intractable to compute in many scenarios. It is this approximation that makes the VAE practical in those cases.

VAE Architecture

VAE Architecture overview

How can we jointly learn $\phi$ and $\theta$? Let’s turn to our objective function.

The VAE Objective

The VAE objective consists of two terms: the reconstruction loss and the KL divergence.

Reconstruction Loss

The reconstruction loss is a measure of how well the model can reconstruct the input data from the latent space. It is typically the negative log-likelihood of the input data given the latent space.

  • For a continuous input space, the reconstruction loss is often the negative log-likelihood of the input data given the latent representation; for a Gaussian likelihood with fixed variance this reduces, up to constants, to the mean squared error (MSE) between the input data and the reconstruction (see the short PyTorch sketch after this list).
    • Deterministic decoder loss: $\mathcal{L}_{\text{rec}} = \frac{1}{N} \sum_i^N \left( \mathbf{x}^{(i)} - f(\mathbf{z}^{(i)}) \right)^2$, just the mean squared error between the input data and the reconstructed data.
    • Stochastic decoder (assumed multivariate Gaussian form) loss: $\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_i^N \left(\log \mathcal{N}(\mathbf{x}^{(i)} | f(\mathbf{z}^{(i)})) \right)$, the negative log-likelihood of the input data given the latent space.
  • For a discrete input space, the reconstruction loss is often the cross-entropy between the input data and the reconstructed data.
    • In our demonstration, we’ll use the binary cross-entropy loss when dealing with MNIST pixel values: $\mathcal{L}_{\text{rec}} = -\frac{1}{N} \sum_i^N \left( \mathbf{x}^{(i)} \log f(\mathbf{z}^{(i)}) + (1 - \mathbf{x}^{(i)}) \log (1 - f(\mathbf{z}^{(i)})) \right)$.
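
As a quick illustration of how these map to PyTorch, here’s a minimal sketch using hypothetical tensors x (the flattened input batch) and x_recon (the decoder output, assumed to already lie in [0, 1], e.g. after a sigmoid):

import torch
import torch.nn.functional as F

# Hypothetical example tensors: a batch of 128 flattened 28x28 inputs in [0, 1]
# and their reconstructions (e.g., decoder outputs after a sigmoid).
x = torch.rand(128, 784)
x_recon = torch.rand(128, 784)

# Deterministic (or fixed-variance Gaussian) decoder: mean squared error per example.
recon_mse = F.mse_loss(x_recon, x, reduction='none').sum(dim=-1).mean()

# Bernoulli decoder (what we use for MNIST below): binary cross-entropy per example.
recon_bce = F.binary_cross_entropy(x_recon, x, reduction='none').sum(dim=-1).mean()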

Intuitively, the reconstruction loss encourages the decoder to learn to reconstruct the input data from the latent space. Similarly, through backpropagation, the encoder learns to map the input data to the latent space in a way that lets the decoder reconstruct it.

KL Divergence

The KL divergence term measures how far the approximate posterior is from the prior. Because we want to penalize a posterior that deviates too much from the prior, this term encourages the two to stay similar.

  • For just about every situation, you’ll likely be using a Gaussian prior and a Gaussian approximate posterior. In this case, the KL divergence has a closed-form solution. In fact, it’s already implemented in PyTorch as torch.distributions.kl.kl_divergence (a quick numerical check appears after this list).
    • The KL divergence between two multivariate Gaussian distributions $\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1)$ and $\mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)$ is given by: $$\text{KL}(\mathcal{N}_1 || \mathcal{N}_2) = \frac{1}{2} \left( \text{tr}(\boldsymbol{\Sigma}_2^{-1} \boldsymbol{\Sigma}_1) + (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1)^\top \boldsymbol{\Sigma}_2^{-1} (\boldsymbol{\mu}_2 - \boldsymbol{\mu}_1) - k + \log \frac{\text{det}(\boldsymbol{\Sigma}_2)}{\text{det}(\boldsymbol{\Sigma}_1)} \right)$$ where $k$ is the dimensionality of the Gaussian distribution.
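
Here’s a quick numerical check of that closed form against PyTorch’s kl_divergence, using arbitrary example parameters:

import torch
from torch.distributions import MultivariateNormal, kl_divergence

k = 2  # dimensionality
mu1, cov1 = torch.tensor([0.5, -0.3]), torch.diag(torch.tensor([0.8, 1.2]))
mu2, cov2 = torch.zeros(k), torch.eye(k)

p = MultivariateNormal(mu1, covariance_matrix=cov1)
q = MultivariateNormal(mu2, covariance_matrix=cov2)

# Closed-form expression from above.
kl_manual = 0.5 * (
    torch.trace(torch.inverse(cov2) @ cov1)
    + (mu2 - mu1) @ torch.inverse(cov2) @ (mu2 - mu1)
    - k
    + torch.logdet(cov2) - torch.logdet(cov1)
)

print(kl_manual.item(), kl_divergence(p, q).item())  # the two should agree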

Where does this KL divergence come from?

Consider that what we really seek is to maximize the marginal likelihood of the data, $\log p_{\theta}(\mathbf{x}^{(1)}, \ldots, \mathbf{x}^{(N)})$. As this data is assumed IID, we can write this as the sum of the log marginal likelihoods of the individual data points, $\log p_{\theta}(\mathbf{x}^{(1)}) + \ldots + \log p_{\theta}(\mathbf{x}^{(N)})$. For a single data point, we have:

$$ \log p_{\theta}(\mathbf{x}^{(i)}) = D_{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) || p_{\theta}(\mathbf{z} | \mathbf{x}^{(i)})) + \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) $$

  • The first term is the KL divergence between the approximate posterior and the true posterior. It is a measure of how much information is lost when using the approximate posterior to represent the true posterior. As we do not know the true posterior, we cannot compute this term directly. This is necessarily so; we want this method to work for cases where the true posterior is intractable.
  • The second term is the variational lower bound, $\mathcal{L}(\theta, \phi; \mathbf{x}^{(i)})$. Because the first term is the KL divergence, it is non-negative. When we omit it, we focus on the second term which is necessarily a lower bound on the marginal likelihood. This is the term we can compute and optimize.

How so? Let’s expand out the term further:

$$ \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)}, \mathbf{z}) - \log q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \right] $$ which, using $p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x}|\mathbf{z}) \, p_{\theta}(\mathbf{z})$, expands to: $$ E_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) + \log p_{\theta}(\mathbf{z}) - \log q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \right] $$ Grouping the last two terms into a KL divergence (recall the prior is assumed known and simple), we get: $$ \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) = E_{q_\phi(\mathbf{z} | \mathbf{x}^{(i)})} \left[ \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) \right] - \text{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \,\|\, p_{\theta}(\mathbf{z})) $$ Hence, we have the variational lower bound, with a reconstruction term and a KL divergence term.
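
In practice, the expectation over $q_\phi$ is estimated with a (typically single) reparameterized sample per data point, as introduced in the original paper. For a Gaussian posterior we draw $$ \mathbf{z} = \boldsymbol{\mu}_\phi(\mathbf{x}^{(i)}) + \boldsymbol{\sigma}_\phi(\mathbf{x}^{(i)}) \odot \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) $$ so that the estimate $$ \mathcal{L}(\theta, \phi; \mathbf{x}^{(i)}) \approx \log p_{\theta}(\mathbf{x}^{(i)} | \mathbf{z}) - \text{KL}(q_\phi(\mathbf{z} | \mathbf{x}^{(i)}) \,\|\, p_{\theta}(\mathbf{z})) $$ is differentiable with respect to both $\phi$ and $\theta$. This reparameterization is exactly what rsample provides in the implementation below.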

Intuitively, we can see the full loss as a tug-of-war between being able to reconstruct the input data from the latent space and not deviating too much from the assumed prior form for a latent space representation.

Modern PyTorch VAE Implementation

Now that we have an understanding of the VAE architecture and objective, let’s implement a modern VAE in PyTorch. We’ll use the MNIST dataset for validation.

PyTorch VAE Implementation

Our VAE implementation is broken into an Output template (in the form of a dataclass) and a VAE class that extends nn.Module. The full code lives in the Jupyter Notebook linked at the end of this post; a hedged sketch of the same structure follows the feature list below.

Some of the key features to pay mind to:

  • The usage of modern activation functions nn.Softplus and nn.SiLU (also known as Swish) for better numerical stability and convergence. Additionally, notice on line 58 scale = self.softplus(logvar) + eps adds a small value eps to the softplus-activated log-variance to better handle numerical stability by enforcing a lower bound.
  • The combining of the linear layers for mean and log-variance into a single nn.Linear layer, which is more efficient and cleaner. Line 28 sets up the encoder this way, while line 57 demonstrates how to use torch’s in-built chunk function to separate the output of the encoder into the mean and log-variance: mu, logvar = torch.chunk(x, 2, dim=-1)
  • Using torch.distributions for an encoder output: torch.distributions.MultivariateNormal(mu, scale_tril=scale_tril). This gives access to in-built functionality for re-parameterized sampling (see the function that’s just a simple wrapper around rsample). It also makes the KL divergence calculation cleaner and more efficient.
  • Check out the loss function portion, nested within the forward method. We calculate the reconstruction loss using binary cross-entropy loss, and the KL divergence using torch.distributions.kl.kl_divergence. We then sum these two terms together to get the final loss.
    • Note the construction of a standard normal multivariate Gaussian distribution for the prior.
    • For the BCE loss, the + 0.5 is for MNIST and is not generally needed. You’ll see below this is just undoing the - 0.5 we do to zero-center the input data.
  • While personal preference, I think it’s good practice to set up some basic dataclasses for the VAE output. In the most general case, I’d also make a dataclass for the input hyperparameters, use omegaconf to manage the configuration, and be sure to save that alongside the weights (very HuggingFace-esque).
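
To make the list above concrete, here is a minimal sketch of what such a VAE class can look like. It is my reconstruction based on the features described (the notebook code it accompanies may differ in details such as field and method names), so treat it as illustrative rather than canonical:

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class VAEOutput:
    """Everything the forward pass produces (loss fields stay None if not computed)."""
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.Tensor
    loss: torch.Tensor = None
    loss_recon: torch.Tensor = None
    loss_kl: torch.Tensor = None


class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=512, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4), nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 8), nn.SiLU(),
            # A single linear layer emits both the mean and the log-variance.
            nn.Linear(hidden_dim // 8, 2 * latent_dim),
        )
        self.softplus = nn.Softplus()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 8), nn.SiLU(),
            nn.Linear(hidden_dim // 8, hidden_dim // 4), nn.SiLU(),
            nn.Linear(hidden_dim // 4, hidden_dim // 2), nn.SiLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.SiLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def encode(self, x, eps=1e-8):
        """Map inputs to a Gaussian posterior over the latent space."""
        mu, logvar = torch.chunk(self.encoder(x), 2, dim=-1)
        # Softplus plus a small epsilon keeps the scale strictly positive.
        scale = self.softplus(logvar) + eps
        return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(scale))

    def reparameterize(self, dist):
        """Thin wrapper around rsample so gradients flow through the encoder."""
        return dist.rsample()

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x, compute_loss=True):
        dist = self.encode(x)
        z = self.reparameterize(dist)
        x_recon = self.decode(z)
        if not compute_loss:
            return VAEOutput(z_dist=dist, z_sample=z, x_recon=x_recon)

        # Reconstruction: BCE against the input shifted back to [0, 1] (undoing the -0.5).
        loss_recon = F.binary_cross_entropy(x_recon, x + 0.5, reduction='none').sum(-1).mean()
        # KL divergence against a standard normal prior (broadcast over the batch).
        prior = torch.distributions.MultivariateNormal(
            torch.zeros_like(z), scale_tril=torch.eye(z.shape[-1], device=z.device),
        )
        loss_kl = torch.distributions.kl.kl_divergence(dist, prior).mean()
        return VAEOutput(z_dist=dist, z_sample=z, x_recon=x_recon,
                         loss=loss_recon + loss_kl, loss_recon=loss_recon, loss_kl=loss_kl)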

Data Preparation

For a “sanity check” of our VAE, we’ll use the MNIST dataset. For pre-processing images, we’ll be using torchvision.transforms.v2 for cleaner, more efficient code over the older torchvision.transforms API.

import torch
from torchvision import datasets
from torchvision.transforms import v2


batch_size = 128
transform = v2.Compose([
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.view(-1) - 0.5),
])

# Download and load the training data
train_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/', 
    download=True, 
    train=True, 
    transform=transform,
)
# Download and load the test data
test_data = datasets.MNIST(
    '~/.pytorch/MNIST_data/', 
    download=True, 
    train=False, 
    transform=transform,
)

# Create data loaders
train_loader = torch.utils.data.DataLoader(
    train_data, 
    batch_size=batch_size, 
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_data, 
    batch_size=batch_size, 
    shuffle=False,
)
  • We use v2.ToImage() to convert the PIL input to an image tensor, and v2.ToDtype with scale=True to convert it to float32. Together these take care of the conversion from uint8 to float32 and the scaling of the pixel values to the range [0, 1].
  • We then use v2.Lambda to zero-center the input data. Since we’re using a simple feed-forward network, we’re also flattening the input data to a 1D tensor in this step.
  • We’re using a batch size of 128. We want at least 100 (the setting used in the paper), and I prefer to round up to a power of 2, since many libraries are optimized for powers of 2 and it’s a good habit to get into.

Training and Validation

We can instantiate a model, optimizer, and tensorboard writer, and then train the model using the following code:

from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 50
latent_dim = 2
hidden_dim = 512

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
writer = SummaryWriter(f'runs/mnist/vae_{datetime.now().strftime("%Y%m%d-%H%M%S")}')
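
To double-check the model size and architecture, a quick count of the trainable parameters looks like this:

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params:,}')
print(model)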

This yields a network with about 1.1M parameters:

Number of parameters: 1,149,972

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): SiLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): SiLU()
    (8): Linear(in_features=64, out_features=4, bias=True)
  )
  (softplus): Softplus(beta=1, threshold=20)
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): SiLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): SiLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): SiLU()
    (8): Linear(in_features=512, out_features=784, bias=True)
    (9): Sigmoid()
  )
)

Our train function will look as follows:

from tqdm import tqdm


def train(model, dataloader, optimizer, prev_updates, writer=None):
    """
    Trains the model on the given data.
    
    Args:
        model (nn.Module): The model to train.
        dataloader (torch.utils.data.DataLoader): The data loader.
        optimizer: The optimizer.
        prev_updates (int): The number of optimizer updates from previous epochs.
        writer: Optional TensorBoard SummaryWriter for logging.
    """
    model.train()  # Set the model to training mode
    
    for batch_idx, (data, target) in enumerate(tqdm(dataloader)):
        n_upd = prev_updates + batch_idx
        
        data = data.to(device)
        
        optimizer.zero_grad()  # Zero the gradients
        
        output = model(data)  # Forward pass
        loss = output.loss
        
        loss.backward()
        
        if n_upd % 100 == 0:
            # Calculate and log gradient norms
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** (1. / 2)
        
            print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Loss: {loss.item():.4f} (Recon: {output.loss_recon.item():.4f}, KL: {output.loss_kl.item():.4f}) Grad: {total_norm:.4f}')

            if writer is not None:
                global_step = n_upd
                writer.add_scalar('Loss/Train', loss.item(), global_step)
                writer.add_scalar('Loss/Train/BCE', output.loss_recon.item(), global_step)
                writer.add_scalar('Loss/Train/KLD', output.loss_kl.item(), global_step)
                writer.add_scalar('GradNorm/Train', total_norm, global_step)
            
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)    
        
        optimizer.step()  # Update the model parameters
        
    return prev_updates + len(dataloader)

And our test loop:

def test(model, dataloader, cur_step, writer=None):
    """
    Tests the model on the given data.
    
    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
        cur_step (int): The current step.
        writer: The TensorBoard writer.
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0
    
    with torch.no_grad():
        for data, target in tqdm(dataloader, desc='Testing'):
            data = data.to(device)
            data = data.view(data.size(0), -1)  # Flatten the data
            
            output = model(data, compute_loss=True)  # Forward pass
            
            test_loss += output.loss.item()
            test_recon_loss += output.loss_recon.item()
            test_kl_loss += output.loss_kl.item()
            
    test_loss /= len(dataloader)
    test_recon_loss /= len(dataloader)
    test_kl_loss /= len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f} (BCE: {test_recon_loss:.4f}, KLD: {test_kl_loss:.4f})')
    
    if writer is not None:
        writer.add_scalar('Loss/Test', test_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/BCE', test_recon_loss, global_step=cur_step)
        writer.add_scalar('Loss/Test/KLD', test_kl_loss, global_step=cur_step)
        
        # Log reconstructions
        writer.add_images('Test/Reconstructions', output.x_recon.view(-1, 1, 28, 28), global_step=cur_step)
        writer.add_images('Test/Originals', data.view(-1, 1, 28, 28), global_step=cur_step)
        
        # Log random samples from the latent space
        z = torch.randn(16, latent_dim).to(device)
        samples = model.decode(z)
        writer.add_images('Test/Samples', samples.view(-1, 1, 28, 28), global_step=cur_step)

Then we’ll run the training job:

prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer)
    test(model, test_loader, prev_updates, writer=writer)

TensorBoard Visualization

My loss curves during training over the 50 epochs looked like this:

Loss Curves

Loss Curves during training

where you can see the final test loss settle at roughly 140.
Looking at the reconstruction error, we see that it dominates the loss, coming in at about 130. Per pixel, that’s roughly $130 / 784 \approx 0.166$, which seems much more reasonable with that context. In fact, it makes the KL term, at roughly $7 / 2 = 3.5$ per latent dimension, seem quite a bit larger in comparison. For this reason, it’s worth seeing what our latent space looks like.

For the curious, this is what my gradient norms looked like:

Gradient Norms

Gradient Norms during training

  • Notably, the gradient norms grow to the order of $10^2$. We’re clipping the norm at 1.0. I experimented with and without this, but didn’t notice a significant difference in the final loss. It did seem like clipping offered slightly better stability, but it’s not clear if that’s a generalizable result.

And 64 samples from the latent space:

Latent Space Samples

64 samples from the latent space

I’d say these are a bit blurry, but they’re not terrible!

Marginals and Joint Distributions in Latent Space

To satisfy my own curiosity, I’ve plotted the training set as a scatter plot colored by their number class:

MNIST 2D Scatter

MNIST 2D Scatter plot
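
For reference, here’s a hedged sketch of how a plot like this can be produced; it assumes an encode() method returning the posterior distribution (as in the class sketch earlier), so adapt it to whatever your implementation exposes:

import matplotlib.pyplot as plt
import torch

# Collect posterior means for the training set, along with the digit labels.
zs, labels = [], []
model.eval()
with torch.no_grad():
    for data, target in train_loader:
        dist = model.encode(data.to(device))  # assumed encode() -> MultivariateNormal
        zs.append(dist.mean.cpu())
        labels.append(target)
zs, labels = torch.cat(zs), torch.cat(labels)

plt.figure(figsize=(8, 8))
plt.scatter(zs[:, 0], zs[:, 1], c=labels, cmap='tab10', s=2, alpha=0.5)
plt.colorbar(label='digit class')
plt.xlabel('$z_1$')
plt.ylabel('$z_2$')
plt.savefig('vae_mnist_scatter.webp')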

This figure tells us a lot about how our model is learning to represent the data. We can actually see that, in general, it does a good job at figuring out the number classes. Looking at specific numbers though, we can see it struggles. Look at the mixture of 4s and 9s. That said, not bad for a model that never sees the labels!

We know from this scatter plot that the latent space is not aligning super well with the prior. But how bad is it? I plotted the 2D histogram of the space, with a LogNorm filter to better visualize the distribution:

MNIST 2D Histogram

MNIST 2D Histogram plot

I find it quite pretty. The data does concentrate around the origin, but it’s not a perfect Gaussian. Instead, it radiates flowery patterns out from the origin. This makes sense though. We’re reducing 784 dimensions down to 2. We’re going to lose information and even then, it won’t be Gaussian.

Our 1D marginals are also quite interesting:

MNIST 1D Marginals

MNIST 1D Marginals plot

If we want our latent space to be more Gaussian, we might need to up-weight the KL divergence term in our loss function.

Interpolating in Latent Space

Holding one latent coordinate fixed and sweeping the other, I did a little interpolation through the latent space, illustrating the transformation from a 7 to a 0:

import torch
import matplotlib.pyplot as plt


n = 15
z1 = torch.linspace(-0, 1, n)
z2 = torch.zeros_like(z1) + 2
z = torch.stack([z1, z2], dim=-1).to(device)
samples = model.decode(z)
samples = torch.sigmoid(samples)

# Plot the generated images
fig, ax = plt.subplots(1, n, figsize=(n, 1))
for i in range(n):
    ax[i].imshow(samples[i].view(28, 28).cpu().detach().numpy(), cmap='gray')
    ax[i].axis('off')
    
plt.savefig('vae_mnist_interp.webp')

Latent Space Interpolation

Latent Space Interpolation

Jupyter Notebook

All code for this tutorial is contained within one location in this Variational Autoencoder Jupyter Notebook.

Extensions and Limitations

Extensions

How might you extend the above? Here are a few ideas:

  • Conditional VAEs: You can condition the VAE on some auxiliary information, such as class labels. This is useful for tasks like semi-supervised learning, where you have a small amount of labeled data and a large amount of unlabeled data.
  • Different Priors: You can use different priors for the latent space. Here, we used a standard normal distribution, but perhaps you can identify a case that warrants a different prior. Hopefully, it’s still closed form to evaluate. If so, you can likely re-use the torch.distributions.kl.kl_divergence function!
  • Different Output Distributions: You can use different output distributions for the decoder. Here, we used something implicit in the binary cross-entropy loss, but you could use a Gaussian distribution with a constrained mean. Or, you could model an output pixel as a categorical distribution over the 256 possible pixel values. What’s the best modeling form? What balances efficiency, expressiveness, and interpretability?
  • Disentangled VAEs: You can try to learn a disentangled representation of the data. This is useful for tasks like transfer learning, where you want to transfer knowledge from one domain to another. It’s also useful for tasks like data augmentation, where you want to generate new data that is similar to the original data, but different in some way.
  • Hierarchical VAEs: You can use a hierarchical VAE to learn a hierarchical representation of the data. This is useful for tasks like image generation, where you want to generate images at different levels of abstraction. It’s also useful for tasks like data compression, where you want to compress the data in a way that is efficient and lossless.

Limitations

What are the limitations of the VAE?

  • Limited Priors: The VAE is limited by the choice of priors. If the true prior is not in the family of priors you’re considering, you may not be able to learn a good representation of the data. This is a fundamental limitation of the VAE, and it’s not clear how to overcome it. Frequently, we’re stuck with a Gaussian prior, which is not always the best choice. It assumes that the latent space is isotropic, which is not always the case. It also will not be able to capture multi-modal distributions in the latent space well.
  • Approximation Quality: The VAE is limited by the quality of the approximations. If the approximations are not good, you may not be able to learn a good representation of the data. Often, we rely on what is tractable over what is accurate and this will lead to gaps in the learned representation and the true distribution.
  • Sample Blurriness: Often attributed to the averaging effect of a KL divergence term, the VAE can produce blurry samples. This is because the KL divergence term encourages the approximate posterior to be close to the prior, which can lead to a loss of detail in the samples.
  • Mode Collapse: In a VAE, mode collapse can occur when the model focuses on some portions of the data distribution while neglecting others. Especially when conditioning training to be better behaved through things like gradient clipping, we can induce mode collapse, where the model ignores some portion of the data. This is a common problem in generative models.
  • Balancing Reconstruction and KL Loss: We need to balance the two portions of our loss function. Sometimes this is done through the introduction of a hyperparameter, $\beta$, which is used to weight the KL divergence term. This is a bit of a hack, and it’s not clear how to set $\beta$ in practice.
    • One tip from my own experience is to look at the marginals of your latent distribution. This will allow you to assess how “normal” your latent space is, and whether you’re getting a good balance between the two terms. If you want better reconstruction, you can down-weight the KL divergence term; if you want a more “normal” latent space, you can up-weight it. It’s all empirical and subject to tuning, but it’s definitely a fruitful place to experiment (see the short sketch after this list).
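
As a concrete example of the $\beta$ weighting (a sketch only; it assumes the separate loss_recon and loss_kl fields from the implementation above, and it is not something the tutorial code does by default):

beta = 0.5  # < 1 favors reconstruction; > 1 pushes the latent space toward the prior

# Inside the training loop, for a single batch `data`:
output = model(data)
loss = output.loss_recon + beta * output.loss_kl  # re-weighted ELBO
loss.backward()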

Conclusion

In this tutorial, we’ve explored modern PyTorch techniques for building Variational Autoencoders. We’ve covered the fundamentals of VAEs, a modern PyTorch VAE implementation, and validation using the MNIST dataset. We’ve also discussed VAE extensions and limitations.

I hope this tutorial has been helpful in bridging the gap in existing VAE literature by integrating modern PyTorch functionalities like torch.distributions and dataclasses for more efficient, cleaner code. This tutorial is aimed at advancing understanding and application of VAEs with the latest PyTorch features.

If you have any questions or feedback, feel free to reach out.