VAE-ELBO
Variational Autoencoder
Traditional autoencoders learn dense representations without much meaningful “structure”. This is fine when all you want to do is compress an input and then reconstruct it, but what if you want to generate new images by sampling the representation space (latent space)? Small movements in the latent space can lead to large, discontinuous jumps in the reconstructed output. We need an additional loss term that forces the latent space to be smooth and well structured.
In the term “Variational Autoencoder” (VAE), “variational” refers to the use of variational inference, a technique from Bayesian statistics. Variational inference approximates complex probability distributions with simpler, parameterized distributions, making computations more tractable.
Understanding Variational Inference in VAEs:
In a VAE, the goal is to model the distribution of data by learning a latent space representation. Directly computing the exact posterior distribution of the latent variables given the data is often computationally intractable. Variational inference addresses this by introducing a simpler, parameterized distribution (the variational distribution) and optimizing its parameters to approximate the true posterior distribution.
How to Visualize Variational Inference in VAEs:
- Encoder Network: Maps input data to a distribution in the latent space, typically assumed to be Gaussian.
- Latent Space Sampling: Samples from this distribution to obtain latent variables.
- Decoder Network: Reconstructs the input data from the sampled latent variables.
By optimizing the variational distribution to approximate the true posterior, VAEs learn a structured latent space that captures the underlying factors of variation in the data.
This approach enables VAEs to generate new, realistic data samples by sampling from the latent space, making them powerful tools in generative modeling.
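The three steps above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the network used in this post: the class name `TinyVAE`, the fully connected layers, and the layer sizes are all hypothetical, and the sampling step uses the standard reparameterization $z = \mu + \sigma \epsilon$.

```python
import torch
import torch.nn as nn


class TinyVAE(nn.Module):
    """Hypothetical toy VAE with fully connected layers (illustration only)."""

    def __init__(self, in_dim=32 * 32, hidden=256, latent=16):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.fc_mu = nn.Linear(hidden, latent)       # mean of q(z|x)
        self.fc_logvar = nn.Linear(hidden, latent)   # log-variance of q(z|x)
        self.dec = nn.Sequential(
            nn.Linear(latent, hidden), nn.ReLU(),
            nn.Linear(hidden, in_dim), nn.Tanh(),
        )

    def encode(self, x):
        # Encoder: map the input to the parameters of a Gaussian in latent space
        h = self.enc(x.flatten(1))
        return self.fc_mu(h), self.fc_logvar(h)

    def sample(self, mu, logvar):
        # Latent sampling: z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar)
        # Decoder: reconstruct the input from the sampled latent variables
        recon = self.dec(z).view_as(x)
        return recon, mu, logvar


x = torch.randn(4, 1, 32, 32)  # stand-in batch of 32x32 single-channel images
recon, mu, logvar = TinyVAE()(x)
print(recon.shape, mu.shape, logvar.shape)
```

Returning `mu` and `logvar` alongside the reconstruction is a common design choice because the loss needs both the decoder output and the parameters of the approximate posterior.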
How to understand the ELBO expression?
The ELBO consists of two primary components:
- Reconstruction Term: Measures how well the decoder reconstructs the input data from the latent variables.
- KL Divergence Term: Regularizes the latent space by penalizing deviations of the approximate posterior distribution $q(z|x)$ from the prior distribution $p(z)$.
The KL divergence term ensures that the latent variables $z$ are distributed similarly to the prior distribution, typically a standard normal distribution. This regularization prevents the model from overfitting to the training data and encourages the latent space to capture meaningful features of the data.
Balancing Feature Extraction and Regularization:
The weight of the KL divergence term relative to the reconstruction term controls the balance between feature extraction and regularization:
- High Weight on KL Divergence: Emphasizing the KL divergence term encourages the latent space to adhere closely to the prior distribution. While this promotes regularization, it may limit the model’s capacity to capture complex features of the data, potentially leading to underfitting.
- Low Weight on KL Divergence: Reducing the emphasis on the KL divergence term allows the latent space to capture more intricate features of the data. However, this may result in overfitting, where the model learns noise in the training data rather than generalizable features.
By carefully tuning the balance between these terms, we can achieve a latent space that effectively captures the underlying structure of the data while maintaining regularization to prevent overfitting.
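The two terms described above are not arbitrary: they fall out of a standard identity for the log-likelihood, sketched below. Here $q_\phi(z|x)$ is the approximate posterior, $p_\theta(x|z)$ the decoder likelihood, $p(z)$ the prior, and $p_\theta(z|x)$ the true (intractable) posterior.

\[
\begin{aligned}
\log p_\theta(x)
&= \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x|z)\,p(z)}{q_\phi(z|x)}\right] + D_{\mathrm{KL}}\left(q_\phi(z|x)\,\|\,p_\theta(z|x)\right) \\
&= \underbrace{\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - D_{\mathrm{KL}}\left(q_\phi(z|x)\,\|\,p(z)\right)}_{\mathcal{L}(\theta,\phi;x)} + D_{\mathrm{KL}}\left(q_\phi(z|x)\,\|\,p_\theta(z|x)\right) \\
&\ge \mathcal{L}(\theta,\phi;x)
\end{aligned}
\]

Because the last KL term is non-negative, $\mathcal{L}$ is a lower bound on $\log p_\theta(x)$ (hence "evidence lower bound"), and the bound is tight exactly when the approximate posterior matches the true posterior.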
ELBO in VAE
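General definition of the KL divergence:
\[D_{\mathrm{KL}}(p\|q)=\int p(x)\ln\left(\frac{p(x)}{q(x)}\right)dx\]
Closed form for two univariate Gaussians $p=\mathcal{N}(\mu_p,\sigma_p^2)$ and $q=\mathcal{N}(\mu_q,\sigma_q^2)$:
\[D_{\mathrm{KL}}(p\|q)=\ln\left(\frac{\sigma_{q}}{\sigma_{p}}\right)+\frac{\sigma_{p}^{2}+(\mu_{p}-\mu_{q})^{2}}{2\sigma_{q}^{2}}-\frac{1}{2}\]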
\[\mathcal{L}(\theta,\phi;x)=\mathbb{E}_{q_{\phi}(z|x)}\left[\log p_{\theta}(x|z)\right]-D_{\mathrm{KL}}(q_{\phi}(z|x)\,\|\,p(z))\]
The closed-form Gaussian KL divergence can be read in two ways: first in plain language, then with more academic rigor.
Layman’s Terms
Imagine you have two bell-shaped curves (normal distributions), each representing something like the distribution of people’s heights in two different countries. Each curve has a mean (average height) and a spread (how wide or narrow the curve is, related to the standard deviation).
Now, the KL divergence is a way to measure how “surprised” we would be if we tried to use one curve to predict data from the other. It’s like saying, “How different is the second country’s height distribution from the first?”
Here’s what the formula is telling us:
- First term: $\ln\left( \frac{\sigma_q}{\sigma_p} \right)$
- This compares the spreads (widths) of the curves. If one curve is wider than the other, it penalizes that difference.
- Second term: $\frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2}$
- This part measures how much the means (average heights) differ between the two countries and how much the widths differ. If both the average heights and the widths of the curves are different, it adds to the “surprise.”
- Third term: $- \frac{1}{2}$
- This is just a small adjustment to make sure the formula behaves correctly when the two curves are exactly the same.
So in simple terms: KL divergence measures how much you’re “off” if you use one distribution to describe the other. If the curves are very similar (same spread, same average), the KL divergence will be small. If they’re different, the KL divergence will be large.
Academic Expression
In the context of KL divergence between two univariate normal distributions, $p \sim \mathcal{N}(\mu_p, \sigma_p^2)$ and $q \sim \mathcal{N}(\mu_q, \sigma_q^2)$, the divergence quantifies the relative entropy of the distribution $p$ with respect to $q$, i.e., how much information is lost when $p$ is approximated by $q$. The expression for the KL divergence is given by:
$ D_{\text{KL}}(p \| q) = \ln\left( \frac{\sigma_q}{\sigma_p} \right) + \frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2} - \frac{1}{2} $
Term 1: $\ln\left( \frac{\sigma_q}{\sigma_p} \right)$
This term captures the difference in scale (i.e., spread or standard deviation) between the two distributions. When $\sigma_q \neq \sigma_p$, this contributes positively or negatively to the divergence, reflecting how much “wider” or “narrower” the distribution $q$ is compared to $p$.
Term 2: $\frac{\sigma_p^2 + (\mu_p - \mu_q)^2}{2\sigma_q^2}$
This term accounts for the difference in both mean and variance between the distributions. The first part, $\frac{\sigma_p^2}{2\sigma_q^2}$, reflects the relative variance of $p$ and $q$. The second part, $\frac{(\mu_p - \mu_q)^2}{2\sigma_q^2}$, quantifies the mean difference, normalized by the variance of $q$. This highlights how much the location (mean) and spread (variance) of $p$ diverge from $q$.
Term 3: $- \frac{1}{2}$
This term is a constant that normalizes the divergence expression, ensuring that the KL divergence of identical distributions ($p = q$) is zero. Without this term, the divergence would not behave as expected in the case where the distributions are identical.
Thus, this formula provides a measure of dissimilarity between the two distributions, based on differences in both their means and variances. The KL divergence is zero if and only if the two distributions are identical, and it grows as their means move apart or their variances diverge. Note that it is not symmetric: in general $D_{\text{KL}}(p \| q) \neq D_{\text{KL}}(q \| p)$.
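As a sanity check on the closed-form expression, the sketch below compares it against a direct numerical approximation of the integral definition $\int p(x)\ln(p(x)/q(x))\,dx$. The function names, the integration range, and the example parameters are illustrative assumptions.

```python
import math


def kl_gauss(mu_p, sigma_p, mu_q, sigma_q):
    """Closed-form D_KL(p || q) for univariate normals p and q."""
    return (
        math.log(sigma_q / sigma_p)
        + (sigma_p**2 + (mu_p - mu_q) ** 2) / (2 * sigma_q**2)
        - 0.5
    )


def log_normal_pdf(x, mu, sigma):
    """Log-density of N(mu, sigma^2) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


def kl_numeric(mu_p, sigma_p, mu_q, sigma_q, lo=-20.0, hi=20.0, steps=100_000):
    """Midpoint-rule approximation of the integral of p(x) * ln(p(x) / q(x))."""
    dx = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * dx
        log_p = log_normal_pdf(x, mu_p, sigma_p)
        log_q = log_normal_pdf(x, mu_q, sigma_q)
        # Working with log-densities avoids dividing by a vanishing q(x)
        total += math.exp(log_p) * (log_p - log_q) * dx
    return total


closed = kl_gauss(1.0, 0.5, 0.0, 2.0)
numeric = kl_numeric(1.0, 0.5, 0.0, 2.0)
print(f"closed form: {closed:.6f}, numerical integral: {numeric:.6f}")
print("identical distributions:", kl_gauss(0.0, 1.0, 0.0, 1.0))
print("swapped arguments:", kl_gauss(0.0, 2.0, 1.0, 0.5))
```

Swapping the arguments gives a different value, which illustrates that the KL divergence is not symmetric.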
\[p(x)=\mathcal{N}(\mu_{p},\sigma_{p}^{2}),\qquad q(x)=\mathcal{N}(0,1)\]
\[D_{\mathrm{KL}}(p\|q)=\ln\left(\frac{1}{\sigma_{p}}\right)+\frac{\sigma_{p}^{2}+(\mu_{p}-0)^{2}}{2\cdot 1^{2}}-\frac{1}{2}\]
\[\Downarrow\]
\[D_{\mathrm{KL}}(p\|q)=-\frac{1}{2}\left(2\ln(\sigma_{p})-\sigma_{p}^{2}-\mu_{p}^{2}+1\right)\]
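The simplification above can be checked numerically. The following sketch evaluates the general two-Gaussian formula with $\mu_q = 0$, $\sigma_q = 1$ and compares it with the simplified expression; the function names and test values are illustrative assumptions.

```python
import math


def kl_general(mu_p, sigma_p, mu_q, sigma_q):
    """General closed-form D_KL(p || q) for two univariate normals."""
    return (
        math.log(sigma_q / sigma_p)
        + (sigma_p**2 + (mu_p - mu_q) ** 2) / (2 * sigma_q**2)
        - 0.5
    )


def kl_to_standard_normal(mu_p, sigma_p):
    """Simplified form for q = N(0, 1)."""
    return -0.5 * (2 * math.log(sigma_p) - sigma_p**2 - mu_p**2 + 1)


pairs = [(0.0, 1.0), (0.3, 0.8), (-1.5, 2.0)]
diffs = []
for mu, sigma in pairs:
    general = kl_general(mu, sigma, 0.0, 1.0)
    simplified = kl_to_standard_normal(mu, sigma)
    diffs.append(abs(general - simplified))
    print(f"mu={mu:+.1f}, sigma={sigma:.1f}: general={general:.6f}, simplified={simplified:.6f}")
```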
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm
```
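```python
batch_size = 64      # samples per mini-batch
lr = 1e-4            # learning rate
nepoch = 10          # number of training epochs
latent_size = 128    # dimensionality of the latent space
root = "../data"     # dataset directory
```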
```python
# Use the first GPU if CUDA is available, otherwise fall back to the CPU
use_data = torch.cuda.is_available()
gpu_indx = 0
device = torch.device(gpu_indx if use_data else "cpu")
device
```
```
device(type='cuda', index=0)
```
```python
# Define our transform
# We'll upsample the images to 32x32 as it's easier to construct our network
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])

train_set = Datasets.MNIST(root=root, train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)

test_set = Datasets.MNIST(root=root, train=False, transform=transform, download=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)
```
1. VAE Loss Function Explanation
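Here is the loss function used in this post:
```python
def vae_loss(recon, x, mu, logvar):
    # Reconstruction term: mean squared error over all elements
    recon_loss = F.mse_loss(recon, x)
    # KL(q(z|x) || N(0, I)), averaged over all elements of mu/logvar
    kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    # Weight the KL term by 0.1 relative to the reconstruction term
    return recon_loss + kl_loss * .1
```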
1.1 Reconstruction Loss (MSE Loss)
```python
recon_loss = F.mse_loss(recon, x)
```
Reconstruction Loss: This term is the Mean Squared Error (MSE) between the reconstructed data (`recon`) and the original data (`x`). It measures how well the decoder is able to reconstruct the input data from the latent variables.
MSE is given by:
$ \text{MSE} = \frac{1}{N} \sum_{i=1}^N (recon_i - x_i)^2 $
Where:
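- $N$ is the total number of elements compared. With its default `reduction='mean'`, `F.mse_loss` averages over every element of the batch (all pixels of all images), not just over images.
- $recon_i$ is the $i$-th reconstructed value.
- $x_i$ is the corresponding original value.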
This encourages the decoder to produce reconstructions that are as close as possible to the input data.
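To make the averaging explicit, the sketch below reproduces `F.mse_loss` by hand on random stand-in tensors (the shapes are assumptions matching 32x32 single-channel images). It also computes a per-image sum of squared errors, which differs from the default only by a constant factor; that factor matters when choosing the relative weight of the KL term.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(8, 1, 32, 32)       # stand-in batch of images
recon = torch.rand(8, 1, 32, 32)   # stand-in reconstructions

# Default behaviour: reduction="mean" averages over every element
builtin = F.mse_loss(recon, x)

# The same quantity computed by hand
manual = ((recon - x) ** 2).sum() / x.numel()

# Alternative: sum over pixels per image, then average over the batch
per_image = F.mse_loss(recon, x, reduction="none").flatten(1).sum(dim=1).mean()

print(builtin.item(), manual.item(), per_image.item())
```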
1.2 KL Divergence Loss
```python
kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
```
KL Divergence: This term is the Kullback-Leibler divergence between the approximate posterior distribution $q(z|x)$ (which is parameterized by the mean `mu` and the log-variance `logvar`) and the prior distribution $p(z)$ (typically a standard normal distribution $\mathcal{N}(0, 1)$).
For a Gaussian approximate posterior $q = \mathcal{N}(\mu_q, \sigma_q^2)$ and a standard normal prior $p = \mathcal{N}(0, 1)$, the KL divergence is:
$D_{\text{KL}}(q \| p) = \frac{1}{2} \left( \sigma_q^2 + \mu_q^2 - \log(\sigma_q^2) - 1 \right)$
In the VAE context:
- $\mu$ is the mean of the approximate posterior.
- $logvar$ is the log of the variance of the approximate posterior.
Here’s how the formula works:
- $\mu^2$ is the square of the mean of the approximate posterior.
- $logvar$ is the log of the variance, so $logvar.exp()$ is the actual variance $\sigma_q^2$.
- The KL term is then computed as:
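$ D_{\text{KL}}(q \| p) = -0.5 \times \left(1 + \log(\sigma^2) - \mu^2 - \sigma^2\right) $
This is the formula you are seeing in the code.
- The `mean()` function averages the result over all elements of the tensor, i.e. over the batch and over the latent dimensions.
- The leading factor of $-0.5$ comes from rearranging the closed form above. The resulting KL value is non-negative, and because it is added to the loss, training minimizes it.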
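One way to gain confidence in this expression is to compare it with PyTorch's own closed-form KL for Gaussians. The sketch below does that on random stand-in values of `mu` and `logvar` (the shapes are arbitrary assumptions), using `torch.distributions.Normal` and `kl_divergence`.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 8)       # stand-in posterior means
logvar = torch.randn(4, 8)   # stand-in posterior log-variances

# Expression used in vae_loss, before averaging
kl_code = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL(q || p) with q = N(mu, sigma^2) and p = N(0, 1)
q = Normal(mu, torch.exp(0.5 * logvar))  # Normal takes the standard deviation
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p)

print("max abs difference:", (kl_code - kl_ref).abs().max().item())
print("smallest KL value:", kl_code.min().item())
```

Note that `Normal` expects a standard deviation, so the log-variance is converted with `exp(0.5 * logvar)` rather than `exp(logvar)`.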
1.3 Scaling the KL Loss
```python
return recon_loss + kl_loss * .1
```
- The KL loss is scaled by a factor of 0.1. This scaling factor is commonly used to control the relative importance of the KL divergence term compared to the reconstruction loss. By scaling the KL loss, we can adjust how strongly we regularize the latent space. If this term is too large, the model might underfit by overly regularizing the latent variables, while if it is too small, the model might overfit by ignoring the regularization and focusing too much on reconstruction.
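To experiment with this trade-off, the weight can be exposed as a parameter. The sketch below is a hypothetical variant of the loss (the name `vae_loss_beta` and the argument `beta` are not from the original code) whose default of 0.1 reproduces the behaviour described above; the inputs are random stand-in tensors.

```python
import torch
import torch.nn.functional as F


def vae_loss_beta(recon, x, mu, logvar, beta=0.1):
    """Reconstruction loss plus beta times the KL term (hypothetical variant)."""
    recon_loss = F.mse_loss(recon, x)
    kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    # Return the parts as well so each term can be logged separately
    return recon_loss + beta * kl_loss, recon_loss, kl_loss


torch.manual_seed(0)
x = torch.rand(4, 1, 32, 32)
recon = torch.rand(4, 1, 32, 32)
mu = torch.randn(4, 16)
logvar = torch.randn(4, 16)

totals = {}
for beta in (0.0, 0.1, 1.0):
    total, recon_loss, kl_loss = vae_loss_beta(recon, x, mu, logvar, beta)
    totals[beta] = total.item()
    print(f"beta={beta}: total={total.item():.4f} "
          f"(recon={recon_loss.item():.4f}, kl={kl_loss.item():.4f})")
```

Logging the two terms separately makes it easier to see whether a given weight lets the KL term dominate or be ignored.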
2. KL Divergence Formula
The closed-form expression derived earlier is:
$ D_{\text{KL}}(p \| q) = -\frac{1}{2} \left( 2 \ln(\sigma_p) - \sigma_p^2 - \mu_p^2 + 1 \right) $
This is a specific case of KL divergence between two univariate Gaussian distributions. In that derivation, $p = \mathcal{N}(\mu_p, \sigma_p^2)$ played the role of the approximate posterior and $q = \mathcal{N}(0, 1)$ the standard normal prior.
The rest of this section switches to the usual VAE convention, in which the letters are swapped:
- $p$ is the prior distribution (typically a standard normal $\mathcal{N}(0, 1)$), so $\mu_p = 0$ and $\sigma_p^2 = 1$.
- $q$ is the approximate posterior distribution $\mathcal{N}(\mu_q, \sigma_q^2)$, with mean $\mu_q$ and variance $\sigma_q^2$.
The formula for the KL divergence between two Gaussian distributions $p = \mathcal{N}(\mu_1, \sigma_1^2)$ and $q = \mathcal{N}(\mu_2, \sigma_2^2)$ is:
$ D_{\text{KL}}(q \| p) = \frac{1}{2} \left( \frac{\sigma_2^2}{\sigma_1^2} + \frac{(\mu_2 - \mu_1)^2}{\sigma_1^2} - 1 - \log\left( \frac{\sigma_2^2}{\sigma_1^2} \right) \right) $
For the standard normal prior $\mathcal{N}(0, 1)$ (i.e., $\mu_1 = 0$ and $\sigma_1^2 = 1$), and writing $\mu_2 = \mu_q$ and $\sigma_2^2 = \sigma_q^2$ for the approximate posterior, the formula simplifies:
$ D_{\text{KL}}(q \| p) = \frac{1}{2} \left( \sigma_q^2 + \mu_q^2 - 1 - \log(\sigma_q^2) \right) $
This is the standard form of the KL divergence used in VAEs. The expression at the start of this section is mathematically equivalent to it once its $\mu_p, \sigma_p$ are read as the posterior parameters $\mu_q, \sigma_q$; the terms are simply rearranged to match the format used in code.
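The rearrangement linking the code-style expression to the standard form is short enough to write out. In the sketch below, $\log\sigma_q^2$ corresponds to `logvar` and $\mu_q$ to `mu` in the code.

\[
\begin{aligned}
-\tfrac{1}{2}\left(1 + \log\sigma_q^2 - \mu_q^2 - \sigma_q^2\right)
&= \tfrac{1}{2}\left(\sigma_q^2 + \mu_q^2 - 1 - \log\sigma_q^2\right) \\
&= -\tfrac{1}{2}\left(2\ln\sigma_q - \sigma_q^2 - \mu_q^2 + 1\right),
\qquad \text{since } \log\sigma_q^2 = 2\ln\sigma_q .
\end{aligned}
\]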
Why does `logvar.exp()` equal $\sigma^2$?
1. What is `logvar`?
In the context of Variational Autoencoders (VAEs), the encoder outputs two parameters:
- Mean (`mu`) of the approximate posterior distribution $q(z|x)$.
- Log of the variance (`logvar`).
The `logvar` variable is the logarithm of the variance, $\log(\sigma^2)$, not of the standard deviation.
2. Exponentiating `logvar` to get the variance
To understand why `logvar.exp()` gives us the variance, we need to look at the math:
- Variance ($\sigma^2$) and log variance ($\log(\sigma^2)$) are related by the logarithmic and exponential functions.
If we have: $ \text{logvar} = \log(\sigma^2) $
Then to recover the variance $\sigma^2$ from `logvar`, we need to “undo” the logarithm. The way to undo a logarithmic operation is by applying the exponential function $\exp(x)$ (which is the inverse of the natural logarithm $\ln(x)$):
$ \sigma^2 = e^{\text{logvar}} $
In PyTorch, we use `torch.exp()` to compute the exponential function. So, `logvar.exp()` computes:
$ \text{logvar.exp()} = e^{\log(\sigma^2)} = \sigma^2 $
3. Why does this work?
- The logarithm $\log(\sigma^2)$ tells us the “power” to which the base $e$ must be raised to get the value $\sigma^2$.
- Applying the exponential function $\exp(x)$ to $\log(\sigma^2)$ essentially “undoes” the logarithm and gives us the original value $\sigma^2$.
So, when we apply `torch.exp(logvar)`, we are computing the exponential of the log variance, which recovers the variance $\sigma^2$.
- `logvar` is the log of the variance: $\log(\sigma_q^2)$.
- By the logarithmic property, $\log(\sigma_q^2) = 2 \log(\sigma_q)$.
- Hence, `logvar` is equivalent to $2 \ln(\sigma_q)$.
- When you exponentiate `logvar` (i.e., `logvar.exp()`), you get the variance $\sigma_q^2$.
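The relationships in this list can be verified with plain Python. The sketch below uses an arbitrary example value $\sigma = 0.7$ (an assumption for illustration) and also shows the related conversion to the standard deviation, $\sigma = e^{0.5 \cdot \text{logvar}}$, which follows from the same logarithm rule.

```python
import math

sigma = 0.7                      # example standard deviation (arbitrary)
logvar = math.log(sigma**2)      # the quantity the encoder is assumed to output

variance = math.exp(logvar)      # undoes the log: sigma^2
std = math.exp(0.5 * logvar)     # half the log-variance is log(sigma)

print(f"logvar = {logvar:.4f}, 2*ln(sigma) = {2 * math.log(sigma):.4f}")
print(f"variance = {variance:.4f}, std = {std:.4f}")

# Any real-valued logvar, even a negative one, maps to a positive variance
print("exp(-3.0) =", math.exp(-3.0))
```

The last line hints at a practical benefit of this parameterization: the network output can be any real number while the implied variance always stays positive.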
For reference, the complete loss function again:
```python
def vae_loss(recon, x, mu, logvar):
    recon_loss = F.mse_loss(recon, x)
    kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    return recon_loss + kl_loss * .1
```