Generative Adversarial Networks (GANs)
- Generative Adversarial Networks (GANs), introduced by Goodfellow et al. (2014), are a type of deep learning model used to generate new data that resembles real data, such as images, audio, or text.
- A GAN is a machine learning system where two neural networks compete against each other:
- Generative Model (Generator, G):
- Starts with random noise (like static on a TV).
- Tries to create fake data that looks like the real thing (e.g. images, sounds).
- Discriminative Model (Discriminator, D):
- Looks at both real data (from the training set) and fake data (from the generator).
- Tries to distinguish between real and fake.
- Outputs a score between 0 and 1 (closer to 1 = real, closer to 0 = fake).
- Adversarial Training:
- These two networks are like adversaries (hence the name “adversarial”). They compete in a game:
- The Generator wants to fool the Discriminator.
- The Discriminator wants to catch the Generator’s fakes.
- This competition improves both models until fakes are indistinguishable from real data.
Adversarial Nets & How GANs Work
- Both models, the Generator and the Discriminator, are multilayer perceptrons in the original formulation.
- To create synthetic (fake) data, we define a prior on input noise variables pz(z) and feed noise samples z into the Generator, which learns a distribution pg over the data x.
- The mapping from noise space to data space is a differentiable function G(z; θg), implemented as a multilayer perceptron with parameters θg.
- The Discriminator compares real data from the dataset with fake data from the Generator.
- To make this comparison, we define a second multilayer perceptron, D(x; θd), which produces a single scalar output.
- The function D(x) estimates the probability that a given input x comes from the real data distribution rather than from pg.
- Both networks are trained with backpropagation, which adjusts their parameters so that each improves over time.
- We train the discriminator D to maximize the probability of correctly classifying both real training data and samples generated by G.
- At the same time, we train the generator G to minimize log(1 − D(G(z))), encouraging it to produce outputs that the discriminator is more likely to classify as real.
- The Discriminator gives feedback to the Generator.
- The Generator learns from this and produces better fakes.
- Over time, the fake data becomes more and more realistic.
- From a mathematical standpoint, training is a two-player minimax game with the value function:
min_G max_D V(D, G) = E_x∼pdata(x)[log D(x)] + E_z∼pz(z)[log(1 − D(G(z)))]
D(x): the probability that input x is real
G(z): the Generator’s output for random noise z
- In practice, this objective may not provide strong enough gradients for G to learn effectively.
- Early in training, G produces low-quality samples that D can reject with high confidence, so log(1 − D(G(z))) saturates and gives G only a weak learning signal.
- For this reason, instead of training G to minimize log(1 − D(G(z))), G can be trained to maximize log D(G(z)), which yields much stronger gradients early in training (see the small numerical check below).
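- A small numerical check makes this concrete. The snippet below is a minimal sketch (assuming TensorFlow 2.x; the chosen logit value is only illustrative) comparing the gradients of the two generator objectives for a fake sample that the discriminator confidently rejects:
import tensorflow as tf

logit = tf.Variable(-4.0)  # discriminator logit for a fake sample; sigmoid(-4) ≈ 0.018

with tf.GradientTape(persistent=True) as tape:
    d_fake = tf.sigmoid(logit)                 # D(G(z)), close to 0 early in training
    saturating = tf.math.log(1.0 - d_fake)     # G minimizes log(1 - D(G(z)))
    non_saturating = -tf.math.log(d_fake)      # equivalent to maximizing log D(G(z))

print(tape.gradient(saturating, logit).numpy())      # ≈ -0.018 (tiny gradient, G barely learns)
print(tape.gradient(non_saturating, logit).numpy())  # ≈ -0.98 (much larger gradient)
- The generator_loss in the code example below (binary cross-entropy of the fake outputs against labels of 1) is exactly this non-saturating variant.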
Types of GANs
Vanilla GANs
- Basic form of GANs with a generator and a discriminator in an adversarial setup.
- Generator creates fake data samples.
- Discriminator tries to distinguish between real and fake samples.
- Uses simple multilayer perceptrons (MLPs) for both components.
- Easy to implement due to straightforward architecture.
- The MLPs learn the two mappings: from noise to data (in the generator) and from data to a real/fake score (in the discriminator); a minimal sketch of this setup follows this list.
- Training is often unstable and requires careful hyperparameter tuning for good performance.
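- A minimal sketch of the two MLPs in a vanilla GAN (a rough illustration with arbitrary layer sizes, assuming TensorFlow/Keras and 28x28 grayscale images flattened to 784 values):
import tensorflow as tf
from tensorflow.keras import layers, models

latent_dim = 100

# Generator: noise vector -> flattened fake image in [-1, 1]
mlp_generator = models.Sequential([
    layers.Dense(256, activation='relu', input_shape=(latent_dim,)),
    layers.Dense(512, activation='relu'),
    layers.Dense(784, activation='tanh'),
])

# Discriminator: flattened image -> single real/fake logit
mlp_discriminator = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(784,)),
    layers.Dense(256, activation='relu'),
    layers.Dense(1),
])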
Conditional GANs (cGAN)
- Extend the basic GAN by giving both the Generator and the Discriminator extra conditioning information y (for example, a class label), so the model can generate samples of a requested class rather than random ones; a minimal sketch of this conditioning idea is shown below.
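- A minimal sketch of the conditioning idea (illustrative layer and embedding sizes, not from the original paper), assuming TensorFlow/Keras, 10 classes, and 28x28 grayscale images flattened to 784 values:
import tensorflow as tf
from tensorflow.keras import layers, Model

latent_dim, num_classes = 100, 10

# Conditional generator: (noise, label) -> flattened fake image in [-1, 1]
noise_in = layers.Input(shape=(latent_dim,))
g_label_in = layers.Input(shape=(1,), dtype='int32')
g_label_emb = layers.Flatten()(layers.Embedding(num_classes, 50)(g_label_in))
g_hidden = layers.Dense(256, activation='relu')(layers.Concatenate()([noise_in, g_label_emb]))
cond_generator = Model([noise_in, g_label_in], layers.Dense(784, activation='tanh')(g_hidden))

# Conditional discriminator: (image, label) -> real/fake logit
img_in = layers.Input(shape=(784,))
d_label_in = layers.Input(shape=(1,), dtype='int32')
d_label_emb = layers.Flatten()(layers.Embedding(num_classes, 50)(d_label_in))
d_hidden = layers.Dense(256, activation='relu')(layers.Concatenate()([img_in, d_label_emb]))
cond_discriminator = Model([img_in, d_label_in], layers.Dense(1)(d_hidden))

Code Example: A Basic GAN on Fashion-MNIST
- The code below trains a simple unconditional GAN (with a convolutional generator and discriminator) on the Fashion-MNIST dataset using TensorFlow/Keras; conditioning it on clothing-class labels would follow the pattern sketched above.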
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import fashion_mnist
# Step 1: Load and preprocess dataset
(x_train, _), (_, _) = fashion_mnist.load_data()
x_train = x_train / 127.5 - 1.0 # Normalize to [-1, 1]
x_train = np.expand_dims(x_train, axis=-1) # Add channel dimension
BUFFER_SIZE = 60000
BATCH_SIZE = 128
LATENT_DIM = 100 # Dimension of noise vector
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Step 2: Build Generator
def build_generator():
model = models.Sequential([
layers.Dense(7*7*256, use_bias=False, input_shape=(LATENT_DIM,)),
layers.BatchNormalization(),
layers.LeakyReLU(),
layers.Reshape((7, 7, 256)),
layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
layers.BatchNormalization(),
layers.LeakyReLU(),
layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
layers.BatchNormalization(),
layers.LeakyReLU(),
layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
])
return model
# Step 3: Build Discriminator
def build_discriminator():
model = models.Sequential([
layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
layers.LeakyReLU(),
layers.Dropout(0.3),
layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
layers.LeakyReLU(),
layers.Dropout(0.3),
layers.Flatten(),
layers.Dense(1)
])
return model
generator = build_generator()
discriminator = build_discriminator()
# Step 4: Define loss functions and optimizers
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
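# Note: from_logits=True matches the discriminator above, whose final Dense(1)
# layer has no activation and therefore outputs raw logits, not probabilities.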
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
def discriminator_loss(real_output, fake_output):
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
return real_loss + fake_loss
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# Step 5: Training step
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, LATENT_DIM])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_gen, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_disc, discriminator.trainable_variables))
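# Note: each train_step call updates both networks once on the same batch: D learns
# to separate real images from the current fakes, while G learns to produce fakes
# that D scores as real (the adversarial game described earlier).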
# Step 6: Training loop
def train(dataset, epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
print(f'Epoch {epoch + 1} completed.')
generate_and_plot_images(generator, tf.random.normal([16, LATENT_DIM]))
# Step 7: Generate and plot images
def generate_and_plot_images(model, test_input):
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4, 4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.show()
# Step 8: Run training
EPOCHS = 30
train(dataset, EPOCHS)
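# After training, new samples can be generated at any time by feeding fresh noise
# through the trained generator and reusing the plotting helper defined above:
new_noise = tf.random.normal([16, LATENT_DIM])
generate_and_plot_images(generator, new_noise)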
N.B.:
- Early in training, the generated images look like fuzzy blobs or abstract shapes because the generator does not yet know what a real image looks like.
- It starts from pure noise; as training progresses, it gradually learns to produce more realistic outputs by trying to “fool” the discriminator.
- I have provided images generated after 6 epochs; training had already taken 1 hour and 44 minutes by that point, so I did not run the full 30 epochs due to time and computational cost.
- Even so, the generated images can be seen becoming gradually clearer, starting from a total blur.
- With more training, the model should produce recognizable new fashion-product images.
References
- Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., … & Bengio, Y. (2014). Generative adversarial nets. Advances in neural information processing systems. https://arxiv.org/abs/1406.2661
- IBM. (n.d.). What are Generative Adversarial Networks (GANs)? IBM Think. Retrieved from https://www.ibm.com/think/topics/generative-adversarial-networks