If you’ve worked with Variational Autoencoders (VAEs), you’ve almost certainly used the standard $\mathcal{L}_1$ objective, better known as the evidence lower bound (ELBO). During training, it’s estimated by taking one sample ($k=1$) from the recognition network to calculate the loss.

A natural question follows: “What if I use more samples? Won’t that make it better?”

The answer is a fantastic “yes, but…” that reveals a crucial insight. Just averaging the loss over $k$ samples isn’t the right idea. Instead, you need to change the objective itself. This post explores the difference between a “multi-sample VAE” and the Importance Weighted Autoencoder (IWAE), a model that uses the same architecture as a VAE but is trained with a fundamentally more powerful objective.

All ideas here are based on the fantastic paper: “Importance Weighted Autoencoders” by Burda, Grosse, and Salakhutdinov.

The Two Ways to Use $k$ Samples

Let’s say we have our encoder (the recognition network) $q(h|x)$ and our generative model $p(x,h) = p(h)\,p(x|h)$, whose conditional $p(x|h)$ is the decoder. We decide to use $k=5$ samples instead of $k=1$. We have two main options for how to calculate our loss.

Option 1: The “Multi-Sample VAE” (The Naive Way)

This is the most straightforward idea. For each input $x$ in our batch:

  1. Draw 5 samples ($h_1, …, h_5$) from $q(h|x)$.
  2. Calculate the standard VAE $\mathcal{L}_1$ loss for each sample.
  3. Average these 5 losses together.

This is an average of logs. As the IWAE paper shows experimentally, this approach gives you a more stable gradient, but the final performance (in terms of log-likelihood) is “only slightly” better. You’re paying a 5x computational cost for a marginal gain because you’re still optimizing the same “loose” $\mathcal{L}_1$ bound.

Option 2: The Importance Weighted Autoencoder (IWAE) (The Right Way)

The IWAE takes a different approach. For each input $x$:

  1. Draw 5 samples ($h_1, …, h_5$) from $q(h|x)$.
  2. Calculate an “importance weight” $w_i$ for each sample.
  3. Average these 5 weights together.
  4. Take the logarithm of that average.

This is a log of an average, and this mathematical difference is profound.
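To see the two options side by side, here is a minimal sketch on five made-up importance weights. The numbers are purely illustrative assumptions, not values from the paper or from a trained model.

```python
import math

# Hypothetical importance weights w_i = p(x, h_i) / q(h_i | x) for k = 5 samples.
weights = [0.50, 0.10, 0.01, 0.20, 0.05]
k = len(weights)

# Option 1 (multi-sample VAE): take the log of each weight, then average.
average_of_logs = sum(math.log(w) for w in weights) / k

# Option 2 (IWAE): average the weights first, then take the log.
log_of_average = math.log(sum(weights) / k)

print(f"average of logs: {average_of_logs:.3f}")
print(f"log of average:  {log_of_average:.3f}")
```

On these weights the average of logs comes out around -2.44 and the log of the average around -1.76. The single tiny weight (0.01) drags the first number down much more than the second.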

The Math: Average-of-Logs vs. Log-of-Averages

Let’s make this concrete. The standard VAE $\mathcal{L}_1$ objective is:

$$ \mathcal{L}_1(x) = \mathbb{E} _{h\sim q(h|x)} \left[ \log \frac{p(x,h)}{q(h|x)} \right] $$

A multi-sample VAE simply gets a better estimate of this same value:

$$ \mathcal{L} _{\text{VAE}, k}(x) \approx \frac{1}{k} \sum _{i=1}^{k} \log w_i \quad \text{where} \quad w_i = \frac{p(x,h_i)}{q(h_i|x)} $$

The IWAE objective, $\mathcal{L}_k$, is fundamentally different:

$$ \mathcal{L} _k (x) = \mathbb{E} _{h_1, \dots, h_k \sim q(h|x)} \left[ \log \left( \frac{1}{k} \sum _{i=1}^{k} \frac{p(x,h_i)}{q(h_i|x)} \right) \right] $$

In practice, we estimate this with a single Monte Carlo sample (of $k$ latents):

$$ \mathcal{L} _k (x) \approx \log \left( \frac{1}{k} \sum _{i=1}^{k} w_i \right) $$

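Because the logarithm is a concave function, Jensen’s Inequality tells us that, for any set of weights, the “log of an average” is always greater than or equal to the “average of logs.” Taking the expectation over the samples on both sides gives:

$$ \mathcal{L}_k(x) \ge \mathcal{L}_1(x) $$

Applying Jensen’s Inequality once more, together with the fact that each weight is an unbiased estimate of the marginal likelihood ($\mathbb{E}_{h \sim q(h|x)}[w] = p(x)$), gives $\mathcal{L}_k(x) \le \log p(x)$. This means the IWAE is optimizing a lower bound on the true log-likelihood that is never looser than the VAE’s, and the paper shows that the bound keeps tightening as $k$ grows.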
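The ordering of the bounds can be checked numerically on a toy problem where $\log p(x)$ is known in closed form. The sketch below is an assumption for illustration only and is not from the paper: a 1-D model with $p(h) = \mathcal{N}(0, 1)$ and $p(x|h) = \mathcal{N}(h, 1)$, so that $p(x) = \mathcal{N}(0, 2)$, paired with a deliberately imperfect $q(h|x)$. All function names are hypothetical.

```python
import math
import random

def log_normal(value, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var) at `value`."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)

def log_weight(x, h, q_mean, q_var):
    """log w = log p(h) + log p(x|h) - log q(h|x) for the toy model."""
    log_joint = log_normal(h, 0.0, 1.0) + log_normal(x, h, 1.0)
    return log_joint - log_normal(h, q_mean, q_var)

def estimate_bound(x, k, q_mean, q_var, trials, rng):
    """Monte Carlo estimate of L_k(x): mean over trials of log(average of k weights)."""
    total = 0.0
    for _ in range(trials):
        log_ws = [
            log_weight(x, rng.gauss(q_mean, math.sqrt(q_var)), q_mean, q_var)
            for _ in range(k)
        ]
        m = max(log_ws)  # log-sum-exp trick for numerical stability
        total += m + math.log(sum(math.exp(lw - m) for lw in log_ws) / k)
    return total / trials

rng = random.Random(0)
x = 1.0
q_mean, q_var = 0.5, 1.0          # deliberately too wide: the true posterior is N(0.5, 0.5)
log_px = log_normal(x, 0.0, 2.0)  # exact log marginal likelihood, since x ~ N(0, 2)

l1 = estimate_bound(x, 1, q_mean, q_var, 20000, rng)
l5 = estimate_bound(x, 5, q_mean, q_var, 20000, rng)
print(f"L_1 ~ {l1:.3f}, L_5 ~ {l5:.3f}, log p(x) = {log_px:.3f}")
```

Up to Monte Carlo noise, the estimates should satisfy $\mathcal{L}_1 \le \mathcal{L}_5 \le \log p(x)$, with $\mathcal{L}_5$ sitting much closer to the exact value.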

Why Does This “Log-of-Average” Matter?

This isn’t just a mathematical curiosity. It has two huge practical benefits.

1. Better Density Estimation

Because $\mathcal{L}_k$ is a tighter bound on the true $\log p(x)$, optimizing it pushes the model to learn a much better generative distribution. The paper shows that IWAEs achieve “significantly higher log-likelihoods” than VAEs.

2. Richer Latent Representations

This is the most interesting part. The standard VAE $\mathcal{L}_1$ objective “harshly penalizes” the model if its one sample $h$ is a poor explanation for $x$. This pressure forces the recognition network $q(h|x)$ to be “overly simplified” to avoid bad samples, which can lead to the “dead units” problem.

The IWAE objective is more flexible. It only needs one of the $k$ samples to be good. This “increased flexibility” allows the model to learn far more complex posterior distributions and “richer latent space representations.” The paper’s experiments confirm this, showing IWAEs learn to use many more “active units” in their latent space.
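A small numeric sketch makes the difference in penalty concrete. The log-weights below are hypothetical values chosen for illustration: one sample explains $x$ well and four explain it poorly.

```python
import math

# Hypothetical log importance weights for k = 5 samples of a single input.
log_w = [-90.0, -250.0, -300.0, -280.0, -260.0]
k = len(log_w)

# Multi-sample VAE estimate: every poor sample drags the average down.
average_of_logs = sum(log_w) / k

# IWAE estimate via log-sum-exp: dominated by the best sample.
m = max(log_w)
log_of_average = m + math.log(sum(math.exp(v - m) for v in log_w) / k)

# Normalized importance weights (a softmax over log_w). In the IWAE gradient,
# each sample's contribution is scaled by its normalized weight.
unnormalized = [math.exp(v - m) for v in log_w]
normalized = [u / sum(unnormalized) for u in unnormalized]

print(f"average of logs: {average_of_logs:.2f}")
print(f"log of average:  {log_of_average:.2f}")
print(f"normalized weight of the best sample: {normalized[0]:.6f}")
```

Here the average of logs is -236, while the log of the average stays near -90 - log(5), so the four poor samples cost the IWAE objective almost nothing. The same normalized weights mean the gradient is driven almost entirely by the one good sample.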

What This Looks Like in Code (PyTorch)

The implementation difference makes this crystal clear.

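First, the “k-sample” trick: for a batch x of shape [B, D] and k=5 samples, we repeat each row of x k times to get x_repeated of shape [B*k, D], keeping the k copies of each input adjacent (for example, with x.repeat_interleave(k, dim=0)). We do all our forward passes on this large tensor. The ordering matters later, when the IWAE loss reshapes a [B*k] vector back to [B, k].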
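How the rows are repeated determines how a flat [B*k] vector must be reshaped to group samples by input. The sketch below compares two standard PyTorch calls, repeat_interleave and repeat, on a tiny tensor. The sizes and the per-row stand-in quantity are assumptions for illustration.

```python
import torch

B, D, k = 3, 2, 5
# Row b of x is filled with the value b, so we can track where each input ends up.
x = torch.arange(B, dtype=torch.float32).unsqueeze(1).expand(B, D)

# Interleaved layout: x0, x0, ..., x0, x1, x1, ... (k adjacent copies per input)
x_interleaved = x.repeat_interleave(k, dim=0)   # shape [B*k, D]

# Tiled layout: x0, x1, x2, x0, x1, x2, ... (the whole batch repeated k times)
x_tiled = x.repeat(k, 1)                        # shape [B*k, D]

# Stand-in for any per-row quantity such as a log-weight: here the row mean.
per_row_interleaved = x_interleaved.mean(dim=1)  # shape [B*k]
per_row_tiled = x_tiled.mean(dim=1)              # shape [B*k]

print(per_row_interleaved.view(B, k))   # each row holds one input's k copies
print(per_row_tiled.view(B, k))         # rows mix different inputs
print(per_row_tiled.view(k, B).t())     # correct grouping for the tiled layout
```

With the interleaved layout a plain view(B, k) groups correctly. With the tiled layout the same call silently mixes inputs, and view(k, B).t() is needed instead.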

VAE (Multi-Sample, k > 1) Loss

Here, we can still use the analytical KL divergence, which is a big simplification.

import torch
import torch.nn.functional as F

# x_repeated has shape [B*k, 784]
# mu, logvar have shape [B*k, latent_dim]
# recon_x has shape [B*k, 784], with values in (0, 1) (e.g. sigmoid outputs)

# recon_loss_all shape: [B*k]
recon_loss_all = F.binary_cross_entropy(recon_x, x_repeated, reduction='none').sum(dim=1)

# kl_loss_all shape: [B*k]
# We use the simple, analytical KL term!
kl_loss_all = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

# total_loss_all shape: [B*k]
total_loss_all = recon_loss_all + kl_loss_all

# --- The Key Step ---
# Just average all B*k losses. This is the "average of logs".
loss = total_loss_all.mean()

IWAE (k > 1) Loss

Here, we cannot use the analytical KL. We must compute the exact log-probabilities of the specific samples we drew.

import math

# Helper function to compute log-prob of a sample from a diagonal Gaussian.
# mu and logvar must be tensors with the same shape as sample.
def log_prob_gaussian(sample, mu, logvar):
    const = -0.5 * sample.shape[-1] * math.log(2 * math.pi)
    log_det = -0.5 * torch.sum(logvar, dim=-1)
    log_exp = -0.5 * torch.sum((sample - mu)**2 / torch.exp(logvar), dim=-1)
    return const + log_det + log_exp

# --- Get the 3 log-prob components ---
# x_repeated, recon_x, z_samples, mu_repeated, logvar_repeated
# all have a first dimension of [B*k]

# 1. log p(x|h_i): Log-Reconstruction Probability
# log_p_x_given_h shape: [B*k]
log_p_x_given_h = -F.binary_cross_entropy(recon_x, x_repeated, reduction='none').sum(dim=1)

# 2. log p(h_i): Log-Prior Probability (under N(0, I))
# log_p_h shape: [B*k]
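# mu=0, logvar=0 must be passed as tensors: torch.sum and torch.exp reject Python floats
prior_zeros = torch.zeros_like(z_samples)
log_p_h = log_prob_gaussian(z_samples, prior_zeros, prior_zeros)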

# 3. log q(h_i|x): Log-Encoder Probability
# log_q_h_given_x shape: [B*k]
log_q_h_given_x = log_prob_gaussian(z_samples, mu_repeated, logvar_repeated)

# --- The Key Step ---
# Combine to get the log-importance-weight
# log_w shape: [B*k]
log_w = log_p_x_given_h + log_p_h - log_q_h_given_x

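# Reshape to [B, k] to group samples by their original input.
# This assumes the k copies of each input are adjacent (repeat_interleave layout);
# for a tiled x.repeat(k, 1) layout, use log_w.view(k, B).t() instead.
log_w_matrix = log_w.view(B, k) # B is original batch size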

# --- Apply the IWAE Objective (Log-Sum-Exp Trick) ---
# This is the "log of the average"
# log( (1/k) * sum(exp(log_w)) ) = logsumexp(log_w) - log(k)
log_iwae_bound_per_x = torch.logsumexp(log_w_matrix, dim=1) - math.log(k)

# The objective is to MAXIMIZE this bound, so the loss is its negative
loss = -log_iwae_bound_per_x.mean()
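Since the IWAE loss depends on hand-written log-densities, a quick sanity check is worthwhile. The sketch below compares a helper like the one above against torch.distributions.Normal on random inputs, and checks that with k=1 the bound collapses to the single log-weight. The tensor sizes and the random inputs are assumptions made only for this check.

```python
import math
import torch
from torch.distributions import Normal

def log_prob_gaussian(sample, mu, logvar):
    """Log-density of a diagonal Gaussian, summed over the last dimension."""
    const = -0.5 * sample.shape[-1] * math.log(2 * math.pi)
    log_det = -0.5 * torch.sum(logvar, dim=-1)
    log_exp = -0.5 * torch.sum((sample - mu) ** 2 / torch.exp(logvar), dim=-1)
    return const + log_det + log_exp

torch.manual_seed(0)
B, k, latent_dim = 4, 5, 3               # hypothetical sizes for the check
mu = torch.randn(B * k, latent_dim)
logvar = torch.randn(B * k, latent_dim)
std = torch.exp(0.5 * logvar)
z = mu + std * torch.randn_like(std)     # reparameterized sample from q(h|x)

ours = log_prob_gaussian(z, mu, logvar)
reference = Normal(mu, std).log_prob(z).sum(dim=-1)
print(torch.allclose(ours, reference, atol=1e-4))

# With k = 1, the log-mean-exp of a single weight is the weight itself,
# so the IWAE bound reduces to the single-sample ELBO estimate.
log_w = torch.randn(B, 1)
iwae_k1 = torch.logsumexp(log_w, dim=1) - math.log(1)
print(torch.allclose(iwae_k1, log_w.squeeze(1)))
```

If either check fails in your own code, the usual suspects are a missing sum over the latent dimension or a sign error in the log-determinant term.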

The Critical Implementation Detail

Notice the key difference in the final step:

  • VAE: loss = total_loss_all.mean(), the average of the individual losses
  • IWAE: loss = -(torch.logsumexp(log_w_matrix, dim=1) - math.log(k)).mean(), the negative log of the averaged weights

This seemingly small change implements the fundamental mathematical difference between optimizing an “average of logs” versus a “log of averages.”

When to Use Each Approach

| Model | When to Use | Key Benefit |
| --- | --- | --- |
| VAE ($k=1$) | Your default baseline. It’s fast, simple, and often “good enough” for many tasks. | Speed and simplicity. |
| Multi-Sample VAE ($k>1$) | When you want slightly more stable gradients but aren’t ready for the full IWAE complexity. | Marginal improvement with minimal code changes. |
| IWAE ($k>1$) | When your baseline VAE is insufficient. Specifically, if you need: (1) the best possible log-likelihood; (2) to fix a “dead unit” problem or learn richer representations. | Better performance and richer latents, at the cost of compute (scales linearly with $k$). |

The Computational Trade-off

Both approaches scale linearly with $k$—if you use $k=5$ samples, you’re doing roughly 5x the computation. The question is whether you get 5x the benefit.

For multi-sample VAEs, the answer is usually “no”—you get more stable gradients but only marginal performance improvements.

For IWAEs, the answer is often “yes”—you get meaningfully better log-likelihoods and richer latent representations that can be worth the computational cost.

Conclusion

The next time you’re tempted to use more samples with your VAE, remember: don’t just average the losses. If you’re going to pay the computational cost of $k > 1$, switch to the IWAE objective to get the full benefit.

The mathematical insight is simple but powerful: Jensen’s Inequality tells us that the “log of an average” is always greater than or equal to the “average of logs.” By optimizing this tighter bound, IWAEs achieve better density estimation and learn richer latent representations than standard VAEs.

The implementation requires computing exact log-probabilities rather than using analytical KL divergence, but the result is a fundamentally more powerful model using the exact same architecture.

Want to dive deeper? Check out the original IWAE paper for experimental results and theoretical analysis, or explore my VAE tutorial for hands-on implementation details.