Generative Adversarial Networks (GAN)

Understanding Generative Adversarial Networks for generating realistic synthetic data

What are GANs?

Generative Adversarial Networks (GANs) are a class of deep learning models introduced by Ian Goodfellow and colleagues in 2014. A GAN consists of two neural networks, a Generator and a Discriminator, that compete against each other in a zero-sum game: the generator creates fake data, while the discriminator tries to distinguish real data from fake.

💡 Simple Analogy:

Think of GANs like an art forger (generator) trying to create fake paintings, while an art expert (discriminator) tries to identify which paintings are real and which are fake. As they compete, both become better at their tasks - the forger creates more realistic fakes, and the expert becomes better at detecting them.

GAN Architecture

A GAN consists of two neural networks trained simultaneously:

Generator (G)

Creates fake data from random noise

Takes a random vector (latent space) as input and transforms it into synthetic data that resembles the training data. The goal is to fool the discriminator.

Input: Random noise vector z ~ p(z)

Output: Synthetic data G(z)

Discriminator (D)

Distinguishes between real and fake data

A binary classifier that takes data as input and outputs a probability that the data is real (from the training set) rather than fake (from the generator).

Input: Real data x or fake data G(z)

Output: Probability D(x) ∈ [0, 1]

Minimax Objective

min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

The discriminator maximizes the objective (correctly classifying real vs fake), while the generator minimizes it (fooling the discriminator).
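To make the objective concrete, here is a toy evaluation of V(D, G) that collapses the two expectations into single "average" discriminator outputs (a simplification for illustration; `value_fn`, `d_real`, and `d_fake` are hypothetical names, not library functions):

```python
import math

def value_fn(d_real, d_fake):
    """Toy V(D, G): d_real stands in for the average D(x) on real data,
    d_fake for the average D(G(z)) on generated data."""
    return math.log(d_real) + math.log(1 - d_fake)

# A discriminator that is confident and correct keeps V high ...
strong_D = value_fn(d_real=0.9, d_fake=0.1)   # log 0.9 + log 0.9 ≈ -0.21
# ... while a generator that fools it (d_fake near 1) drives V down.
fooled_D = value_fn(d_real=0.9, d_fake=0.9)   # log 0.9 + log 0.1 ≈ -2.41
# At the ideal equilibrium D outputs 0.5 everywhere and V = -log 4.
equilibrium = value_fn(d_real=0.5, d_fake=0.5)  # ≈ -1.386
print(strong_D, fooled_D, equilibrium)
```

Discriminator updates push V up (both log terms toward log 1 = 0), generator updates push the second term down; in theory the tug-of-war settles at V = -log 4 when the generated distribution matches the data.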

Training Process

GANs are trained through an adversarial process:

1. Train Discriminator

Feed real data (labeled as 1) and generated fake data (labeled as 0). Update discriminator weights to maximize classification accuracy.

2. Train Generator

Generate fake data and feed it to the discriminator. Update generator weights to maximize the discriminator's error (make it classify fake as real).

3. Repeat

Alternate between training the discriminator and the generator until reaching a Nash equilibrium, where the generator produces realistic data and the discriminator outputs 0.5 for every input (it can no longer tell the difference).

Convergence: Ideally, the GAN reaches the equilibrium where the generator's distribution matches the data distribution and D(x) = 0.5 for every input, meaning the discriminator cannot distinguish real from fake data.

Code Example: Simple GAN

python

# Simple GAN Implementation in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim

# Generator Network
class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_dim=784):  # 28x28 = 784
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()  # Output in range [-1, 1]
        )

    def forward(self, z):
        return self.model(z)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Output probability [0, 1]
        )

    def forward(self, img):
        return self.model(img)

# Training setup
latent_dim = 100
lr = 0.0002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

opt_gen = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
criterion = nn.BCELoss()

def train_step(real_images):
    batch_size = real_images.size(0)

    # Labels
    real_labels = torch.ones(batch_size, 1).to(device)
    fake_labels = torch.zeros(batch_size, 1).to(device)

    # ========== Train Discriminator ==========
    # On real images
    real_output = discriminator(real_images)
    loss_real = criterion(real_output, real_labels)

    # On fake images
    z = torch.randn(batch_size, latent_dim).to(device)
    fake_images = generator(z)
    fake_output = discriminator(fake_images.detach())
    loss_fake = criterion(fake_output, fake_labels)

    # Total discriminator loss
    loss_disc = (loss_real + loss_fake) / 2
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # ========== Train Generator ==========
    z = torch.randn(batch_size, latent_dim).to(device)
    fake_images = generator(z)
    output = discriminator(fake_images)
    # Generator tries to maximize discriminator error
    loss_gen = criterion(output, real_labels)
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    return loss_disc.item(), loss_gen.item()

# Usage:
# for epoch in range(num_epochs):
#     for real_images, _ in dataloader:
#         real_images = real_images.view(-1, 784).to(device)
#         d_loss, g_loss = train_step(real_images)

GAN Variants

Many GAN architectures have been developed to address specific challenges:

DCGAN (Deep Convolutional GAN)

Uses convolutional and transposed convolutional layers

Features:

  • Removes fully connected layers
  • Uses batch normalization
  • ReLU in generator, LeakyReLU in discriminator

Use Cases:

Image generation, more stable training

CGAN (Conditional GAN)

Conditions generation on additional information (labels, text, images)

Features:

  • Both G and D receive conditioning input
  • Controls the mode of data generated

Use Cases:

Text-to-image, image-to-image translation, class-specific generation

StyleGAN / StyleGAN2

Generates high-quality images with controllable style features

Features:

  • Progressive growing
  • Style mixing
  • High resolution (1024×1024+)

Use Cases:

Photorealistic face generation, art creation, deepfakes

CycleGAN

Image-to-image translation without paired examples

Features:

  • Cycle consistency loss
  • No need for paired training data

Use Cases:

Style transfer (photo to painting), season transfer, horse to zebra
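The cycle-consistency idea can be sketched in a few lines: translate a sample to the other domain and back, then penalize the reconstruction error. The two `nn.Linear` maps below are hypothetical stand-ins for CycleGAN's real convolutional generators, just to show the loss:

```python
import torch
import torch.nn as nn

# Stand-in "generators" for the two directions (G: X -> Y, F: Y -> X);
# a real CycleGAN would use convolutional networks here.
G = nn.Linear(64, 64)
F = nn.Linear(64, 64)
l1 = nn.L1Loss()

x = torch.randn(8, 64)  # batch from domain X
y = torch.randn(8, 64)  # batch from domain Y

# Cycle consistency: X -> Y -> X (and Y -> X -> Y) should reconstruct the
# input, which is what lets CycleGAN train without paired examples.
loss_cycle = l1(F(G(x)), x) + l1(G(F(y)), y)
print(loss_cycle.item())  # non-negative; added to the adversarial losses
```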

Pix2Pix

Paired image-to-image translation

Features:

  • Conditional GAN with U-Net generator
  • PatchGAN discriminator

Use Cases:

Sketch to photo, black-and-white to color, semantic labels to photo

WGAN (Wasserstein GAN)

Uses Wasserstein distance instead of Jensen-Shannon divergence

Features:

  • More stable training
  • Meaningful loss metric
  • Lipschitz constraint

Use Cases:

Addresses mode collapse, vanishing gradients

Code Example: DCGAN

python

# DCGAN (Deep Convolutional GAN) for Image Generation
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, channels=3):
        super(DCGANGenerator, self).__init__()
        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State: 512 x 4 x 4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State: 256 x 8 x 8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State: 128 x 16 x 16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State: 64 x 32 x 32
            nn.ConvTranspose2d(64, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 64 x 64
        )

    def forward(self, z):
        # Reshape latent vector to (batch, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.main(z)

class DCGANDiscriminator(nn.Module):
    def __init__(self, channels=3):
        super(DCGANDiscriminator, self).__init__()
        self.main = nn.Sequential(
            # Input: channels x 64 x 64
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 64 x 32 x 32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 128 x 16 x 16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 256 x 8 x 8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 512 x 4 x 4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 x 1 x 1
        )

    def forward(self, img):
        return self.main(img).view(-1, 1)

# Initialize weights as N(0, 0.02), following the DCGAN paper
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Usage
generator = DCGANGenerator(latent_dim=100, channels=3)
discriminator = DCGANDiscriminator(channels=3)
generator.apply(weights_init)
discriminator.apply(weights_init)

Code Example: Conditional GAN

python

# Conditional GAN - Generate images based on class labels
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10, img_dim=784):
        super(ConditionalGenerator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # Concatenate latent vector with label embedding
        label_embed = self.label_embedding(labels)
        gen_input = torch.cat([z, label_embed], dim=1)
        return self.model(gen_input)

class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=10, img_dim=784):
        super(ConditionalDiscriminator, self).__init__()
        self.label_embedding = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(img_dim + num_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # Concatenate image with label embedding
        label_embed = self.label_embedding(labels)
        disc_input = torch.cat([img, label_embed], dim=1)
        return self.model(disc_input)

# Training with conditional GAN
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100
num_classes = 10

gen = ConditionalGenerator(latent_dim, num_classes).to(device)
disc = ConditionalDiscriminator(num_classes).to(device)

# Generate a specific digit (e.g., digit 7)
z = torch.randn(16, latent_dim).to(device)
labels = torch.full((16,), 7, dtype=torch.long).to(device)  # 16 images of digit 7
fake_images = gen(z, labels)
# Now you can control which class of images to generate!

Training Challenges & Solutions

GANs are notoriously difficult to train. Common challenges include:

Mode Collapse

Problem:

Generator produces limited variety of outputs (collapses to single mode)

Symptoms:

All generated images look very similar

Solutions:

  • Minibatch discrimination
  • Unrolled GANs
  • Use different loss functions (WGAN)

Training Instability

Problem:

Generator and discriminator don't converge, loss oscillates

Symptoms:

Losses fluctuate wildly, no convergence

Solutions:

  • Two-timescale update rule (TTUR)
  • Spectral normalization
  • Progressive growing

Vanishing Gradients

Problem:

When discriminator is too good, generator gradients vanish

Symptoms:

Generator stops learning

Solutions:

  • Feature matching
  • Modified loss functions
  • Wasserstein loss
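The effect is easy to see from the derivatives of the two generator losses with respect to d = D(G(z)); the toy calculation below compares them (`grad_saturating` and `grad_non_saturating` are illustrative names):

```python
# Derivative of the original ("saturating") generator loss log(1 - d)
# with respect to d = D(G(z)): -1 / (1 - d)
def grad_saturating(d):
    return -1.0 / (1.0 - d)

# Derivative of the "non-saturating" alternative -log(d): -1 / d
# (this is the loss obtained by training G with BCE against real labels)
def grad_non_saturating(d):
    return -1.0 / d

d = 0.01  # a strong discriminator confidently rejects the fake
print(abs(grad_saturating(d)))      # ≈ 1.01: almost no learning signal
print(abs(grad_non_saturating(d)))  # ≈ 100: strong learning signal
```

When the discriminator is winning (d near 0), the saturating gradient nearly vanishes while the non-saturating one grows, which is why most implementations train the generator with the latter.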

Hyperparameter Sensitivity

Problem:

Small changes in hyperparameters cause training failure

Symptoms:

Hard to reproduce results

Solutions:

  • Extensive hyperparameter search
  • Use proven architectures (DCGAN)
  • Self-attention mechanisms

Code Example: WGAN (Wasserstein GAN)

python

# WGAN (Wasserstein GAN) - More stable training
import torch
import torch.nn as nn
import torch.optim as optim

class WGANCritic(nn.Module):
    """Critic (not discriminator) - outputs an unbounded score"""
    def __init__(self, img_dim=784):
        super(WGANCritic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
            # No sigmoid! Output is unbounded
        )

    def forward(self, img):
        return self.model(img)

class WGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_dim=784):
        super(WGANGenerator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Shared settings
latent_dim = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training loop for WGAN
def train_wgan(critic, generator, dataloader, n_critic=5, clip_value=0.01):
    """
    n_critic: Train critic n_critic times per generator update
    clip_value: Weight clipping for Lipschitz constraint
    """
    opt_critic = optim.RMSprop(critic.parameters(), lr=0.00005)
    opt_gen = optim.RMSprop(generator.parameters(), lr=0.00005)

    for real_images, _ in dataloader:
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1).to(device)

        # ========== Train Critic ==========
        for _ in range(n_critic):
            z = torch.randn(batch_size, latent_dim).to(device)
            fake_images = generator(z)

            # Wasserstein loss: maximize D(real) - D(fake)
            critic_real = critic(real_images).mean()
            critic_fake = critic(fake_images.detach()).mean()
            loss_critic = -(critic_real - critic_fake)  # Negative because we're minimizing

            opt_critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Weight clipping for Lipschitz constraint
            for p in critic.parameters():
                p.data.clamp_(-clip_value, clip_value)

        # ========== Train Generator ==========
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_images = generator(z)
        # Generator tries to maximize D(fake)
        loss_gen = -critic(fake_images).mean()

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

# Advantages of WGAN:
# 1. More stable training
# 2. Loss correlates with sample quality
# 3. Reduced mode collapse
# 4. No need to balance G and D carefully
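Weight clipping is the original Lipschitz trick but is known to be crude; a common refinement, WGAN-GP (gradient penalty), instead pushes the critic's gradient norm toward 1 on random interpolates between real and fake samples. A minimal sketch (the helper name `gradient_penalty` and the toy critic are illustrative):

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """WGAN-GP term: mean of (||grad_x critic(x_hat)||_2 - 1)^2
    over random interpolates x_hat between real and fake batches."""
    alpha = torch.rand(real.size(0), 1)  # per-sample mixing coefficient
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Toy critic and batches just to exercise the function
critic = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))
real, fake = torch.randn(8, 784), torch.randn(8, 784)
gp = gradient_penalty(critic, real, fake)
print(gp.item())  # non-negative scalar
```

In training, the clipping loop is dropped and lambda_gp * gp is added to loss_critic (the WGAN-GP paper uses lambda_gp = 10).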

Real-World Applications

GANs have revolutionized several domains:

Image Synthesis

  • Photorealistic face generation (ThisPersonDoesNotExist.com)
  • Art creation and style transfer
  • Super-resolution (enhancing image quality)
  • Deepfakes and face swapping

Data Augmentation

  • Generating synthetic training data
  • Balancing imbalanced datasets
  • Creating medical images for rare conditions
  • Synthetic data for privacy preservation

Image-to-Image Translation

  • Converting day to night scenes
  • Colorizing black and white photos
  • Converting sketches to photos
  • Changing seasons in landscapes

Text and Audio

  • Text-to-image generation (DALL-E style)
  • Voice synthesis and conversion
  • Music generation
  • Video game asset creation

Drug Discovery

  • Generating novel molecular structures
  • Predicting protein structures
  • Designing new drug candidates

Anomaly Detection

  • Fraud detection in finance
  • Medical image analysis for abnormalities
  • Manufacturing defect detection

Key Concepts

Latent Space

The random noise vector input to the generator, representing a compressed representation of the data distribution.
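A common way to probe the latent space is to interpolate between two noise vectors and decode each intermediate point; with a trained generator, the samples morph smoothly from one to the other. A sketch using an untrained stand-in generator (the random linear map `W` is purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, img_dim = 100, 784

# Untrained stand-in for a generator: a fixed random linear map + tanh
# (a real GAN would use the trained network here)
W = rng.normal(0, 0.05, size=(latent_dim, img_dim))
def generate(z):
    return np.tanh(z @ W)

# Walk a straight line between two latent points, decoding each step
z1, z2 = rng.normal(size=latent_dim), rng.normal(size=latent_dim)
steps = np.linspace(0.0, 1.0, 8)
samples = np.stack([generate((1 - t) * z1 + t * z2) for t in steps])
print(samples.shape)  # (8, 784): one decoded image per interpolation step
```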

Adversarial Training

Training paradigm where two models compete, improving through competition rather than direct optimization.

Nash Equilibrium

The ideal convergence point where neither generator nor discriminator can improve without the other adapting.

Mode Collapse

When the generator learns to produce only a limited variety of samples, ignoring parts of the data distribution.

Discriminator Saturation

When the discriminator becomes too confident (outputs near 0 or 1), leading to vanishing gradients for the generator.

Inception Score (IS)

Metric to evaluate GAN quality: how realistic and diverse the generated images are.

Fréchet Inception Distance (FID)

A more robust metric than IS; it measures the distance between feature distributions of real and generated images.
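The FID formula, ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2(C_r C_g)^(1/2)), is short enough to sketch directly. In practice the feature vectors come from a pretrained Inception network; here random arrays stand in, and the function name `fid` is illustrative:

```python
import numpy as np
from scipy.linalg import sqrtm

def fid(feats_real, feats_fake):
    """Frechet distance between Gaussians fit to two (n, d) feature arrays."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    c1 = np.cov(feats_real, rowvar=False)
    c2 = np.cov(feats_fake, rowvar=False)
    covmean = sqrtm(c1 @ c2)
    if np.iscomplexobj(covmean):  # numerical noise can leave tiny imaginary parts
        covmean = covmean.real
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(c1 + c2 - 2 * covmean))

rng = np.random.default_rng(0)
real = rng.normal(0, 1, size=(500, 16))
same = rng.normal(0, 1, size=(500, 16))     # same distribution -> small FID
shifted = rng.normal(3, 1, size=(500, 16))  # shifted distribution -> large FID
print(fid(real, same), fid(real, shifted))
```

Lower is better: identical distributions give FID near 0, and the score grows as the generated features drift from the real ones.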

Interview Tips

  • 💡Explain the adversarial training process clearly: two networks competing in a zero-sum game
  • 💡Understand the minimax objective function and what each term represents
  • 💡Know common GAN variants: DCGAN, CGAN, StyleGAN, CycleGAN, Pix2Pix, WGAN
  • 💡Be ready to discuss mode collapse and training instability with solutions
  • 💡Explain the difference between generator and discriminator loss functions
  • 💡Understand evaluation metrics: Inception Score (IS) and Fréchet Inception Distance (FID)
  • 💡Know real-world applications: image synthesis, data augmentation, image translation
  • 💡Explain why GANs are hard to train compared to supervised learning
  • 💡Discuss the vanishing gradient problem when discriminator is too strong
  • 💡Understand conditional GANs and how they enable controlled generation
  • 💡Know the difference between WGAN and traditional GAN (Wasserstein distance vs JS divergence)
  • 💡Be able to write basic GAN training loop in PyTorch or TensorFlow
  • 💡Understand the role of batch normalization in stabilizing GAN training
  • 💡Explain how GANs can be used for semi-supervised learning
  • 💡Discuss ethical concerns: deepfakes, misinformation, privacy violations