Variational Autoencoders (VAEs)
- A Variational Autoencoder (VAE) is a type of generative model that learns to encode input data into a compressed latent space and decode it back, so that it can generate new data similar to the training data.
- Unlike traditional autoencoders, which map each input to a single fixed code, VAEs learn a probability distribution over the latent space and sample from it.
- As a result, they can create new data, such as images, text, or sounds, by sampling points from the latent space and decoding them.
How Do VAEs Work?
A VAE consists of two main components:
- Encoder (Inference Network)
It takes input data (e.g., an image) and encodes it into a distribution over latent variables (usually a Gaussian distribution with mean μ and standard deviation σ).
- Decoder (Generative Network)
It takes a point z sampled from that latent distribution and reconstructs the input data from it.
Mathematical Intuition
Let’s try to understand the math without getting too deep into hardcore equations.
Traditional Autoencoder:
Tries to learn a function:
x → z → x̂
where z is the latent code and x̂ is the reconstructed input.
VAE’s Core Idea:
Instead of mapping x to a single point z, the encoder outputs the parameters of a distribution over z:
- Encoder outputs: μ(x) and σ(x) (in practice often log σ²(x), as in the code below)
- Sample latent variable: z ∼ N(μ(x), σ²(x))
- Decoder reconstructs x̂ from the sampled z.
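The sampling step can be illustrated in plain Python. This is a minimal sketch, not the article's implementation: the function name `sample_latent` and the example values of μ and log σ² are hypothetical. It draws z with the reparameterization trick, z = μ + σ·ε with ε ∼ N(0, 1), which is the same computation the `Sampling` layer performs in the implementation below. Unlike a traditional autoencoder, the same input yields a different z on every call, while the samples stay centered on μ(x) with spread σ(x).

```python
import math
import random


def sample_latent(mu, log_var, rng=random):
    """Draw z ~ N(mu, sigma^2) independently per dimension (reparameterization trick).

    log_var is log(sigma^2); an encoder can predict it as any real number,
    and sigma = exp(0.5 * log_var) is then always positive.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


# Hypothetical encoder outputs for a single input x in a 2-D latent space
mu = [0.5, -1.0]
log_var = [math.log(0.25), 0.0]  # sigma = 0.5 and sigma = 1.0

rng = random.Random(0)
print(sample_latent(mu, log_var, rng))  # a different z on each call
print(sample_latent(mu, log_var, rng))

# Many samples: the empirical mean and std of dimension 0 approach mu[0] and sigma[0]
samples = [sample_latent(mu, log_var, rng) for _ in range(20000)]
mean_0 = sum(s[0] for s in samples) / len(samples)
std_0 = math.sqrt(sum((s[0] - mean_0) ** 2 for s in samples) / len(samples))
print(round(mean_0, 2), round(std_0, 2))  # both approximately 0.5
```

Predicting log σ² instead of σ is a common design choice because it removes the need to constrain the network output to be positive.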
Loss Function (Objective):
- Reconstruction Loss: How close is x̂ to x? (e.g., MSE or cross-entropy)
- KL Divergence: How close is the encoder's distribution N(μ(x), σ²(x)) to the standard normal prior N(0, I)?
The total loss: L_VAE = Reconstruction Loss + KL Divergence
This encourages the encoder to compress the input while keeping the latent space smooth and well-behaved, so that points sampled from N(0, I) decode into realistic new data.
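Both loss terms can be computed by hand for a single example. The sketch below is a hedged illustration in plain Python: it assumes pixel values in [0, 1] with binary cross-entropy as the reconstruction loss (one of the two options listed above), and all helper names and example values are hypothetical. The KL term uses the closed form for a diagonal Gaussian measured against N(0, I), the same expression as the KL line in the implementation below; the Monte Carlo estimate is included only to check that formula.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


def bce_reconstruction(x, x_hat, eps=1e-7):
    """Binary cross-entropy summed over pixels; x and x_hat hold values in [0, 1]."""
    total = 0.0
    for xi, pi in zip(x, x_hat):
        pi = min(max(pi, eps), 1.0 - eps)  # clip so log() never sees 0
        total -= xi * math.log(pi) + (1.0 - xi) * math.log(1.0 - pi)
    return total


def vae_loss(x, x_hat, mu, log_var):
    """L_VAE = reconstruction loss + KL divergence for one example."""
    return bce_reconstruction(x, x_hat) + kl_to_standard_normal(mu, log_var)


def log_normal_pdf(z, mean, var):
    """Log-density of a 1-D Gaussian N(mean, var) at z."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (z - mean) ** 2 / var)


def kl_monte_carlo(mu, log_var, n=50000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z from q; used only as a check."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        for m, lv in zip(mu, log_var):
            var = math.exp(lv)
            z = m + math.sqrt(var) * rng.gauss(0.0, 1.0)
            total += log_normal_pdf(z, m, var) - log_normal_pdf(z, 0.0, 1.0)
    return total / n


mu, log_var = [0.5, -1.0], [math.log(0.25), 0.0]  # hypothetical encoder outputs
print(round(kl_to_standard_normal(mu, log_var), 4))  # 0.9431
print(round(kl_monte_carlo(mu, log_var), 2))  # close to the closed-form value
x, x_hat = [1.0, 0.0, 1.0], [0.9, 0.2, 0.7]  # hypothetical 3-pixel image
print(vae_loss(x, x_hat, mu, log_var))
```

The KL term is zero only when μ = 0 and σ = 1 in every dimension and grows as an encoding drifts away from the prior, so it pulls every input's distribution toward N(0, I); that pull is what keeps the latent space smooth enough to sample from.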
Types of VAEs
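- β-VAE
Multiplies the KL term by a hyperparameter β to control the balance between reconstruction and KL divergence. Values of β > 1 encourage disentangled representations, usually at some cost in reconstruction quality.
- Conditional VAE (CVAE)
Feeds a condition (like a class label) to both the encoder and the decoder, which allows controlled generation.
- Vector Quantized VAE (VQ-VAE)
Uses discrete latent variables: each encoder output is replaced by its nearest vector in a learned codebook (vector quantization).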
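Each variant changes the basic VAE in one local place. The plain-Python sketch below isolates those changes under simplifying assumptions: the helper names, the default β = 4, and the toy codebook are all hypothetical. The β-VAE helper only reweights the KL term, the CVAE helper appends a one-hot label to an encoder or decoder input, and the VQ-VAE helper performs just the nearest-codebook lookup; training a real VQ-VAE also needs codebook and commitment loss terms and a straight-through gradient estimator, which are omitted here.

```python
def beta_vae_loss(recon_loss, kl, beta=4.0):
    """beta-VAE objective: reconstruction + beta * KL (beta = 1 gives the plain VAE).

    The default beta = 4.0 is an illustrative value, not a recommendation.
    """
    return recon_loss + beta * kl


def one_hot(label, num_classes):
    """One-hot encoding of an integer class label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def condition(vector, label, num_classes):
    """CVAE-style conditioning: append a one-hot class label to an input vector.

    In a CVAE the same step is applied to the encoder input x and the decoder input z.
    """
    return list(vector) + one_hot(label, num_classes)


def vector_quantize(z_e, codebook):
    """VQ-VAE lookup: replace the encoder output z_e with its nearest codebook vector.

    Returns (index, code); the integer index is the discrete latent variable.
    """
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    index = min(range(len(codebook)), key=lambda k: sq_dist(z_e, codebook[k]))
    return index, codebook[index]


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # hypothetical learned codebook
print(vector_quantize([0.9, 0.8], codebook))  # (1, [1.0, 1.0])
print(condition([0.3, -0.2], label=2, num_classes=3))  # [0.3, -0.2, 0.0, 0.0, 1.0]
print(beta_vae_loss(recon_loss=250.0, kl=6.0, beta=4.0))  # 274.0
```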
Applications of VAEs
- Image generation (e.g., fashion, faces)
- Data compression
- Denoising
- Anomaly detection (e.g., medical imaging)
- Generative design in fashion or architecture
- Synthetic data creation for imbalanced datasets
Python Implementation for VAEs
# Import Necessary Libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import fashion_mnist
import matplotlib.pyplot as plt
# Load dataset
(x_train, _), (x_test, _) = fashion_mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_train = np.expand_dims(x_train, -1)
x_test = x_test.astype("float32") / 255.0
x_test = np.expand_dims(x_test, -1)
latent_dim = 2 # Low-dimensional latent space
# Encoder: maps an image to the mean and log-variance of q(z|x)
class Encoder(tf.keras.Model):
    def __init__(self, latent_dim):
        super().__init__()
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(256, activation="relu")
        self.mu = layers.Dense(latent_dim)
        # Predict log(sigma^2) rather than sigma so the output is unconstrained
        self.log_var = layers.Dense(latent_dim)

    def call(self, x):
        x = self.flatten(x)
        x = self.dense1(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        return mu, log_var
# Sampling layer using the reparameterization trick
class Sampling(layers.Layer):
    def call(self, inputs):
        mu, log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(mu))
        # z = mu + sigma * epsilon, with sigma = exp(0.5 * log_var)
        return mu + tf.exp(0.5 * log_var) * epsilon
# Decoder
class Decoder(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = layers.Dense(256, activation="relu")
        self.dense2 = layers.Dense(28 * 28, activation="sigmoid")
        self.reshape = layers.Reshape((28, 28, 1))

    def call(self, z):
        x = self.dense1(z)
        x = self.dense2(x)
        return self.reshape(x)
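# VAE Model
class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.sampling = Sampling()

    def call(self, x):
        mu, log_var = self.encoder(x)
        z = self.sampling((mu, log_var))
        reconstructed = self.decoder(z)
        # KL divergence to N(0, I): sum over latent dimensions, then average over the batch
        kl = -0.5 * tf.reduce_sum(1 + log_var - tf.square(mu) - tf.exp(log_var), axis=1)
        self.add_loss(tf.reduce_mean(kl))
        return reconstructed

# Reconstruction loss summed over all pixels of each image, so it is on the same
# per-example scale as the summed KL term (Keras then averages it over the batch).
# Binary cross-entropy suits the sigmoid output; a per-pixel mean (such as plain MSE)
# would make the KL term dominate and blur the samples.
def reconstruction_loss(x, x_hat):
    return tf.reduce_sum(tf.keras.losses.binary_crossentropy(x, x_hat), axis=(1, 2))

# Compile and train
encoder = Encoder(latent_dim)
decoder = Decoder()
vae = VAE(encoder, decoder)
vae.compile(optimizer="adam", loss=reconstruction_loss)
vae.fit(x_train, x_train, epochs=10, batch_size=128, validation_data=(x_test, x_test))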
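# Generate new images by decoding random points drawn from the N(0, I) prior
z = tf.random.normal((16, latent_dim))
generated_images = decoder(z).numpy()
# Plot the 16 generated images in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for image, ax in zip(generated_images, axes.flat):
    ax.imshow(image[:, :, 0], cmap="gray")
    ax.axis("off")
plt.tight_layout()
plt.show()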
Output Explanation:
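- The generated items should resemble Fashion-MNIST clothing, such as shirts, shoes, dresses, and coats.
- They are blurry but recognizable, likely because the model is trained for only 10 epochs, uses simple dense layers, and compresses each image into a 2-dimensional latent space.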