
Variational Autoencoders

Dimension Reduction and Autoencoder Basics

Autoencoders

As a recap, an encoder-decoder is a pair of functions (often non-linear functions implemented as neural networks):

  • Encoder \(g: D\rightarrow F\)
  • Decoder \(f: F\rightarrow D\)

where \(D\) is the input space and \(F\) is the feature space, with \(\dim(F)\ll \dim(D)\).

The goal of an autoencoder is that, for any input,

\[\forall x \in D. \tilde x = f(g(x)) \approx x\]

Therefore, we can compress input data \(x\) into lower-dimensional features \(g(x)\) for storage, comparison, visualization, etc.
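As a minimal sketch of the encoder-decoder pair, assume both maps are linear. Under squared reconstruction error, the best linear autoencoder with \(k\)-dimensional features projects onto the top-\(k\) principal subspace, so it can be computed with an SVD instead of trained. The helper `fit_linear_autoencoder` and the toy data are hypothetical; the autoencoders in these notes are non-linear networks.

```python
import numpy as np

def fit_linear_autoencoder(X, k):
    """Hypothetical helper: linear encoder/decoder with k-dim features via PCA (SVD)."""
    mean = X.mean(axis=0)
    # Rows of Vt are the principal directions, sorted by singular value
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    W = Vt[:k].T                      # shape (dim D, k), orthonormal columns

    def g(x):                         # encoder g: D -> F
        return (x - mean) @ W

    def f(z):                         # decoder f: F -> D
        return z @ W.T + mean

    return g, f

# Toy data: 500 points in R^10 lying near a 2-D plane (plus small noise)
rng = np.random.default_rng(0)
basis = rng.normal(size=(2, 10))
X = rng.normal(size=(500, 2)) @ basis + 0.01 * rng.normal(size=(500, 10))

g, f = fit_linear_autoencoder(X, k=2)
Z = g(X)                              # 2-D features: dim(F) << dim(D)
X_tilde = f(Z)                        # reconstructions f(g(x)) ~ x
print(Z.shape, np.abs(X_tilde - X).max())
```

Because the toy data is almost exactly 2-dimensional, the reconstruction error is on the order of the added noise.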

Proximity Issue

Note that \(f\) and \(g\) are non-linear, and the decoder is in general not the inverse of the encoder (\(f\neq g^{-1}\)). Thus, a small change in the input space does not necessarily correspond to a small change in the feature space. In fact, \(f\) and \(g\) can act like hash functions that do not preserve proximity.

In addition, the data space may have discontinuities, for example several clusters with distant means and small variances. If we then pick a point from the "vacuum" between clusters in the input space, its feature vector can be essentially arbitrary.
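To see why reconstruction alone does not enforce proximity, consider an extreme, hypothetical "autoencoder" on a finite dataset: the encoder assigns each training input an arbitrary integer code, and the decoder looks that code up. Reconstruction on the training set is perfect, yet neighbouring inputs receive unrelated codes, and a point in the vacuum between training inputs has no defined feature at all. This is a toy lookup table, not a trained network.

```python
import numpy as np

rng = np.random.default_rng(0)
xs = np.linspace(0.0, 1.0, 100)            # 100 evenly spaced 1-D inputs
codes = rng.permutation(len(xs))           # arbitrary 1-D "feature" per input

encode = {x: c for x, c in zip(xs, codes)}    # toy g: input -> code
decode = {c: x for x, c in zip(xs, codes)}    # toy f: code -> input

# Perfect reconstruction on every training input: f(g(x)) == x
assert all(decode[encode[x]] == x for x in xs)

# ...but adjacent inputs (distance 1/99 apart) can land far apart in feature space
gaps = np.abs(np.diff(codes))
print("largest code gap between neighbouring inputs:", gaps.max())

# A point between training inputs (the "vacuum") has no defined feature
print(encode.get(0.5 / 99))                # None
```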

Variational Autoencoders

VAEs solve the proximity issue by introducing uncertainty. By adding noise, VAEs force the encoder to produce a smoother, more continuous mapping. The encoder outputs a distribution \(q_\phi(z|x)\) instead of a deterministic feature vector. Normally, we use a Gaussian distribution for \(q\), so that the encoder outputs \((\mu, \sigma) = g(x)\); in practice the covariance \(\Sigma = \operatorname{diag}(\sigma^2)\) is a diagonal matrix.
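A minimal sketch of sampling from the encoder's diagonal Gaussian, assuming (as the MNIST code below does) that the network outputs \(\mu\) and \(\log\sigma^2\) rather than \(\sigma\), which keeps the variance positive. With \(\Sigma = \operatorname{diag}(\sigma^2)\), each latent coordinate is an independent 1-D Gaussian, so sampling reduces to \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal N(0, I)\). The numbers below are made-up stand-ins for an encoder output, and `sample_diag_gaussian` is a hypothetical helper.

```python
import numpy as np

def sample_diag_gaussian(mu, log_var, n, rng):
    """Hypothetical helper: draw n samples from N(mu, diag(exp(log_var)))."""
    std = np.exp(0.5 * log_var)               # sigma = exp(log(sigma^2) / 2)
    eps = rng.standard_normal((n, mu.shape[0]))
    return mu + eps * std                     # broadcast over the n samples

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])                    # hypothetical encoder output
log_var = np.log(np.array([0.25, 4.0]))       # variances 0.25 and 4.0

z = sample_diag_gaussian(mu, log_var, n=100_000, rng=rng)
print(z.mean(axis=0), z.var(axis=0))          # close to mu and [0.25, 4.0]
print(np.corrcoef(z.T)[0, 1])                 # close to 0: coordinates independent
```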

The idea of the VAE comes from variational inference. Let the feature vector be the latent variable and the input data be the observations; the autoencoder then aims to recover the posterior \(p(z|x)\) from the likelihood and the prior. Because this posterior is computationally intractable, we approximate it with \(q_\phi\). To train the VAE, we maximize the ELBO, i.e. we use the negative ELBO as the loss. Thus, we have

  • Encoder \(q_{\phi_i}(z|x_i) = \mathcal N(\mu_i, \sigma_i^2)\), where we store the mean \(\mu_i\) and the log standard deviation \(\log\sigma_i\) for each input.
  • Decoder \(f(z_i) = \tilde \theta\), typically a neural network that outputs the parameters of a family of distributions.
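For reference, one standard way to write the objective for a single input \(x\), assuming a prior \(p(z)\) (the MNIST example below uses \(\mathcal N(0, I)\)) and writing the approximate posterior as \(q_\phi(z|x)\):

\[
\log p_\theta(x) = \underbrace{\mathbb E_{q_\phi(z|x)}\big[\log p_\theta(x|z)\big] - \mathrm{KL}\big(q_\phi(z|x)\,\|\,p(z)\big)}_{\mathrm{ELBO}(\theta,\phi;\,x)} + \mathrm{KL}\big(q_\phi(z|x)\,\|\,p_\theta(z|x)\big)
\]

Because the last KL term is non-negative, the ELBO is a lower bound on \(\log p_\theta(x)\). Maximizing it over \(\phi\) pulls \(q_\phi\) toward the true posterior, and maximizing it over \(\theta\) raises the likelihood; in code we minimize \(-\mathrm{ELBO}\).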

Pipeline

For each input \(x_i\), the forward pass is

  1. The encoder NN \(g\) outputs \(\phi_i = g(x_i)\).
  2. Sample a latent vector \(z_i \sim q_{\phi_i}(z|x_i)\).
  3. The decoder NN \(f\) outputs \(\theta = f(z_i)\).
  4. Sample the decoded output \(\hat x_i \sim p_{\theta}(x|z_i)\).
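The four steps of the forward pass can be sketched with tiny hypothetical linear maps standing in for the encoder \(g\) and decoder \(f\) (random weights, no training), a Gaussian \(q_\phi\), and a Bernoulli likelihood as in the MNIST example below. Only the data flow is meant to be illustrative; `W_enc`, `W_dec` and the sizes are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K = 6, 2                                   # input and latent dimensions
W_enc = rng.normal(size=(D, 2 * K))           # hypothetical encoder weights
W_dec = rng.normal(size=(K, D))               # hypothetical decoder weights

def g(x):
    """Encoder: returns phi_i = (mu, log_var) of q(z | x)."""
    h = x @ W_enc
    return h[:K], h[K:]

def f(z):
    """Decoder: returns theta = Bernoulli probability for each pixel."""
    return 1.0 / (1.0 + np.exp(-(z @ W_dec)))  # sigmoid

x_i = rng.integers(0, 2, size=D).astype(float)              # a binary "image"
mu, log_var = g(x_i)                                        # 1. phi_i = g(x_i)
z_i = mu + np.exp(0.5 * log_var) * rng.standard_normal(K)   # 2. z_i ~ q(z | x_i)
theta = f(z_i)                                              # 3. theta = f(z_i)
x_hat = (rng.random(D) < theta).astype(float)               # 4. x_hat ~ Bernoulli(theta)
print(x_i, theta.round(2), x_hat)
```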

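For the backward pass, the negative ELBO serves as the loss, and we take its gradients with respect to both the decoder parameters \(\theta\) of \(p_\theta\) and the encoder parameters \(\phi\) of \(q_\phi\). As in VI, differentiating through the sampling step requires the reparameterization trick: write \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal N(0, I)\), so that the randomness no longer depends on \(\phi\).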
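A minimal numerical check of the reparameterization trick, using a toy objective chosen for illustration: \(L(\mu,\sigma) = \mathbb E_{z\sim\mathcal N(\mu,\sigma^2)}[z^2] = \mu^2 + \sigma^2\), whose exact gradient is \((2\mu, 2\sigma)\). Writing \(z = \mu + \sigma\epsilon\) with \(\epsilon\sim\mathcal N(0,1)\) moves the randomness out of the parameters, so the derivative can go inside the expectation: \(\partial z^2/\partial\mu = 2z\) and \(\partial z^2/\partial\sigma = 2z\epsilon\). Averaging these over samples gives an unbiased Monte Carlo gradient estimate; `reparam_grad` is a hypothetical helper.

```python
import numpy as np

def reparam_grad(mu, sigma, n, rng):
    """Monte Carlo gradient of E[z^2], z ~ N(mu, sigma^2), via z = mu + sigma * eps."""
    eps = rng.standard_normal(n)
    z = mu + sigma * eps              # the sample is a differentiable function of (mu, sigma)
    d_mu = np.mean(2 * z)             # d(z^2)/d mu    = 2z * dz/dmu    = 2z
    d_sigma = np.mean(2 * z * eps)    # d(z^2)/d sigma = 2z * dz/dsigma = 2z * eps
    return d_mu, d_sigma

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
print(reparam_grad(mu, sigma, 200_000, rng))  # close to the exact (2*mu, 2*sigma) = (3.0, 1.6)
```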

Amortized VAE

Instead of doing VI from scratch for each \(x_i\), we learn a single function that works across the whole dataset. That is, instead of learning separate parameters \(\phi_i\) for each input, we learn one set of global parameters \(\phi\) for the recognition model (the encoder), which maps any \(x_i\) to its variational parameters; \(\phi\) is the only parameter set we need to store.
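As a hypothetical illustration of the non-amortized baseline (VI from scratch for each \(x_i\)), take a conjugate toy model where everything is analytic: prior \(z\sim\mathcal N(0,1)\), likelihood \(x\mid z\sim\mathcal N(z,1)\), exact posterior \(\mathcal N(x/2, 1/2)\). With \(q = \mathcal N(\mu_i,\sigma_i^2)\), the per-input ELBO is, up to a constant, \(-\tfrac12[(x_i-\mu_i)^2+\sigma_i^2] - \tfrac12[\mu_i^2+\sigma_i^2-1-2\log\sigma_i]\), so each input gets its own gradient-ascent loop and its own stored pair \((\mu_i, \log\sigma_i)\). The model, step size and iteration count are assumptions for the sketch.

```python
import numpy as np

def fit_per_datapoint(x, steps=500, lr=0.1):
    """Non-amortized VI on the toy model: optimise (mu_i, log sigma_i) for one input x."""
    mu, log_sigma = 0.0, 0.0
    for _ in range(steps):
        sigma2 = np.exp(2 * log_sigma)
        grad_mu = x - 2 * mu                  # d ELBO / d mu
        grad_log_sigma = 1 - 2 * sigma2       # d ELBO / d log sigma
        mu += lr * grad_mu                    # gradient *ascent* on the ELBO
        log_sigma += lr * grad_log_sigma
    return mu, log_sigma

xs = np.array([-1.0, 0.5, 3.0])
# One separate (mu_i, log sigma_i) pair is stored per input
params = [fit_per_datapoint(x) for x in xs]
for x, (mu, log_sigma) in zip(xs, params):
    print(x, mu, np.exp(2 * log_sigma))       # approx x/2 and 0.5
```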

The basic AVAE pipeline is very similar to the VAE, except that the variational parameters now come from a single shared, high-dimensional model. Thus, we sample from

\[z_i \sim q_{\phi}(z|x_i) = \mathcal N(\mu_\phi(x_i), \Sigma_\phi(x_i))\]
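An amortized sketch on a hypothetical conjugate toy model (prior \(z\sim\mathcal N(0,1)\), likelihood \(x\mid z\sim\mathcal N(z,1)\), exact posterior \(\mathcal N(x/2, 1/2)\)). We assume the simplest possible recognition model, \(\mu_\phi(x) = a x + b\) with one shared \(\log\sigma\), so \(\phi = (a, b, \log\sigma)\) is three numbers no matter how large the dataset is. Gradient ascent on the average ELBO should recover \(a\approx 1/2\), \(b\approx 0\), \(\sigma^2\approx 1/2\), and an unseen input then gets its posterior from one function evaluation. The linear form, step size and data are assumptions, not the network used for MNIST.

```python
import numpy as np

rng = np.random.default_rng(0)
# Dataset drawn from the toy model: z ~ N(0, 1), x | z ~ N(z, 1)
x = rng.standard_normal(1000) + rng.standard_normal(1000)

a, b, log_sigma = 0.0, 0.0, 0.0      # phi: the only parameters we store
lr = 0.1
for _ in range(1000):
    mu = a * x + b                    # amortized: mu_phi(x_i) for every input at once
    r = x - 2 * mu                    # d ELBO_i / d mu_i
    a += lr * np.mean(r * x)          # gradient ascent on the average ELBO
    b += lr * np.mean(r)
    log_sigma += lr * (1 - 2 * np.exp(2 * log_sigma))

print(a, b, np.exp(2 * log_sigma))    # approx 0.5, 0.0, 0.5

# Inference for an unseen input is a single function evaluation
x_new = 3.0
print(a * x_new + b)                  # approx x_new / 2
```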

Implementation Example: MNIST

Using the classic MNIST example:

  • Input data: \(28 \times 28\) pixels with values in \([0,1]\), representing grayscale intensity.
  • Likelihood function \(p_\theta(x|z) = \text{Bernoulli}(\theta)\)
  • Approximate posterior \(q_{\phi_i}(z|x_i) = \mathcal N(\mu_i, \sigma_i^2)\)
  • Loss: the negative ELBO, i.e. the reconstruction loss (binary cross-entropy) plus the KL divergence from \(q_{\phi_i}(z|x_i)\) to the prior \(\mathcal N(0, I)\)
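Why binary cross-entropy is the reconstruction term: for independent Bernoulli pixels, BCE summed over pixels equals \(-\log p_\theta(x|z)\). A small NumPy check on made-up pixel probabilities; `bce_sum` and `bernoulli_log_likelihood` are hypothetical helpers, with `bce_sum` written to follow the usual BCE formula used with `reduction='sum'` in the code below.

```python
import numpy as np

def bernoulli_log_likelihood(x, theta):
    """log p_theta(x | z) for independent Bernoulli pixels, computed from the pmf."""
    return np.sum(np.log(np.where(x == 1, theta, 1 - theta)))

def bce_sum(recon_x, x):
    """Binary cross-entropy summed over pixels."""
    return -np.sum(x * np.log(recon_x) + (1 - x) * np.log(1 - recon_x))

x = np.array([1.0, 0.0, 1.0, 1.0])           # a tiny binary "image"
theta = np.array([0.9, 0.2, 0.8, 0.6])       # decoder output: pixel probabilities
print(bce_sum(theta, x), -bernoulli_log_likelihood(x, theta))  # the two values agree
```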

 class VAE(nn.Module):
     def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
         super(VAE, self).__init__()

         # encoder part
         self.fc1 = nn.Linear(x_dim, h_dim1)
         self.fc2 = nn.Linear(h_dim1, h_dim2)
         self.fc31 = nn.Linear(h_dim2, z_dim)
         self.fc32 = nn.Linear(h_dim2, z_dim)
         # decoder part
         self.fc4 = nn.Linear(z_dim, h_dim2)
         self.fc5 = nn.Linear(h_dim2, h_dim1)
         self.fc6 = nn.Linear(h_dim1, x_dim)

     def encoder(self, x):
         h = F.relu(self.fc1(x))
         h = F.relu(self.fc2(h))
         return self.fc31(h), self.fc32(h) # mu, log_var

     def sampling(self, mu, log_var):
         # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
         std = torch.exp(0.5*log_var)
         eps = torch.randn_like(std)
         return eps.mul(std).add_(mu) # return z sample

     def decoder(self, z):
         h = F.relu(self.fc4(z))
         h = F.relu(self.fc5(h))
         return torch.sigmoid(self.fc6(h)) 

     def forward(self, x):
         mu, log_var = self.encoder(x.view(-1, 784))
         z = self.sampling(mu, log_var)
         return self.decoder(z), mu, log_var

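 def ELBO_loss(recon_x, x, mu, log_var):
     # reconstruction term: -log p(x|z) under a Bernoulli likelihood (BCE summed over pixels and batch)
     BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
     # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian q
     KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
     # negative ELBO, minimized during training
     return BCE + KLD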

 def train(model, train_loader, optimizer):
     model.train()
     train_loss = 0
     for data, _ in train_loader:
         data = data.cuda()
         optimizer.zero_grad()
         recon_batch, mu, log_var = model(data)
         loss = ELBO_loss(recon_batch, data, mu, log_var)
         loss.backward()
         train_loss += loss.item()
         optimizer.step()
     return train_loss / len(train_loader.dataset)

 def test(model, test_loader):
     model.eval()
     test_loss = 0
     with torch.no_grad():
         for data, _ in test_loader:
             data = data.cuda()
             recon, mu, log_var = model(data)
             # sum up batch loss
             test_loss += ELBO_loss(recon, data, mu, log_var).item()
     return test_loss / len(test_loader.dataset)
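The `KLD` line in `ELBO_loss` is the closed-form KL divergence between the diagonal Gaussian \(q = \mathcal N(\mu, \operatorname{diag}\sigma^2)\) and the standard normal prior. A quick Monte Carlo sanity check in NumPy (sample \(z\) from \(q\), average \(\log q(z) - \log p(z)\)) should agree with that formula up to sampling noise; the values of `mu` and `log_var` are arbitrary, and both helpers are hypothetical.

```python
import numpy as np

def kl_closed_form(mu, log_var):
    # Same expression as the KLD term in ELBO_loss
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

def log_normal(z, mean, var):
    # Log density of a diagonal Gaussian, summed over latent dimensions
    return np.sum(-0.5 * (np.log(2 * np.pi * var) + (z - mean) ** 2 / var), axis=-1)

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])
log_var = np.array([-0.7, 0.4])
var = np.exp(log_var)

z = mu + np.sqrt(var) * rng.standard_normal((200_000, 2))   # z ~ q
kl_mc = np.mean(log_normal(z, mu, var) - log_normal(z, 0.0, 1.0))
print(kl_closed_form(mu, log_var), kl_mc)                  # should be close
```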

Figure: The training process (loss).

Figure: The reconstruction.

Figure: Random samples from the hidden vector code.

Source code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt


# --8<-- [start:vae]
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var

    def sampling(self, mu, log_var):
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h)) 

    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, 784))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

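def ELBO_loss(recon_x, x, mu, log_var):
    # reconstruction term: -log p(x|z) under a Bernoulli likelihood (BCE summed over pixels and batch)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL(q(z|x) || N(0, I)) in closed form for a diagonal Gaussian q
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # negative ELBO, minimized during training
    return BCE + KLD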

def train(model, train_loader, optimizer):
    model.train()
    train_loss = 0
    for data, _ in train_loader:
        data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(data)
        loss = ELBO_loss(recon_batch, data, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    return train_loss / len(train_loader.dataset)

def test(model, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.cuda()
            recon, mu, log_var = model(data)
            # sum up batch loss
            test_loss += ELBO_loss(recon, data, mu, log_var).item()
    return test_loss / len(test_loader.dataset)
# --8<-- [end:vae]

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)
# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# build model
vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2).cuda()
optimizer = torch.optim.Adam(vae.parameters())
nepoch = 40
train_loss = np.empty(nepoch)
test_loss = np.empty(nepoch)
from tqdm import tqdm
for epoch in tqdm(range(nepoch)):
    train_loss[epoch] = train(vae, train_loader, optimizer)
    test_loss[epoch] = test(vae, test_loader)

plt.figure(figsize=(6,4))
plt.title("Loss")
plt.plot(train_loss, label="train loss")
plt.plot(test_loss, label="test loss")
plt.xlabel("epoch")
plt.legend()  # needed for the "train loss" / "test loss" labels to appear
plt.savefig("../assets/vae_loss.png")


fig, axs = plt.subplots(2, 4, figsize=(16, 8))
import random
with torch.no_grad():
  # top row: 4 random test digits; bottom row: their reconstructions
  for i, idx in enumerate(random.sample(range(len(test_dataset)), 4)):
    img = test_dataset[idx][0]  # image tensor of shape (1, 28, 28)
    axs[0, i].imshow(img.view(28, 28).numpy(), cmap="gray")
    recon, mu, log_var = vae(img.cuda())
    axs[1, i].imshow(recon.view(28, 28).cpu().numpy(), cmap="gray")
    axs[1, i].set_axis_off(); axs[0, i].set_axis_off()
fig.savefig("../assets/vae_recons.png")


fig, axs = plt.subplots(2, 4, figsize=(16, 8.4))
with torch.no_grad():
  for i in range(8):
    # draw a latent code z ~ N(0, I) from the prior and decode it
    z = torch.randn(2)
    axs[i%2, i//2].set_title(f"z=({z[0]:.3f}, {z[1]:.3f})")
    x_gen = vae.decoder(z.cuda())
    axs[i%2, i//2].imshow(x_gen.view(28, 28).cpu().numpy(), cmap="gray")
    axs[i%2, i//2].set_axis_off()
fig.savefig("../assets/vae_random.png")