Batch Normalization

Accelerating deep network training by reducing internal covariate shift

What is Batch Normalization?

Analogy: Think of batch normalization like standardizing test scores across different schools. If one school has very hard tests (low scores) and another has easy tests (high scores), it's hard to compare students fairly. By normalizing scores to a standard scale (mean=0, std=1), we can compare apples to apples. Similarly, batch normalization standardizes the inputs to each layer during training, making the network easier to train and less sensitive to initialization.

Batch Normalization (BatchNorm), introduced by Ioffe and Szegedy in 2015, is a technique that normalizes the inputs of each layer to have zero mean and unit variance within each mini-batch. It then applies learnable scale (γ) and shift (β) parameters to restore the representational power of the network.

The Problem:

The Internal Covariate Shift Problem: During training, the distribution of inputs to each layer constantly changes as the parameters of previous layers update. This makes training slow and requires careful initialization and low learning rates. BatchNorm addresses this by normalizing layer inputs.
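
A tiny sketch of this moving-target effect: update an early layer's weights and watch the input distribution of the next layer shift (the 0.5 perturbation is an arbitrary stand-in for an optimizer step):

```python
import torch
import torch.nn as nn

# Sketch: how updating layer 1 shifts the input distribution seen by layer 2
torch.manual_seed(0)
layer1 = nn.Linear(10, 10)
x = torch.randn(256, 10)

h_before = layer1(x)
print(f"Layer-2 input before update: mean={h_before.mean().item():.3f}, std={h_before.std().item():.3f}")

# Simulate a gradient step on layer 1's weights
with torch.no_grad():
    layer1.weight += 0.5 * torch.randn_like(layer1.weight)

h_after = layer1(x)
print(f"Layer-2 input after update:  mean={h_after.mean().item():.3f}, std={h_after.std().item():.3f}")
# Layer 2's input statistics changed without layer 2 doing anything:
# this moving target is the shift BatchNorm aims to remove.
```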

How Batch Normalization Works

Algorithm (During Training)

  1. Compute the mean (μ) and variance (σ²) of the current mini-batch
  2. Normalize: subtract mean and divide by standard deviation: x̂ = (x - μ) / √(σ² + ε)
  3. Scale and shift with learnable parameters: y = γx̂ + β
  4. Track running statistics (moving average of μ and σ²) for inference
Note: ε (epsilon) is a small constant (e.g., 1e-5) added for numerical stability to prevent division by zero.
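
A minimal numeric sketch of steps 1-3 on a tiny one-feature batch (the γ and β values are arbitrary, chosen only for illustration):

```python
import numpy as np

x = np.array([2.0, 4.0, 6.0, 8.0])      # one feature, batch of 4
mu, var = x.mean(), x.var()              # Step 1: mu = 5.0, var = 5.0
x_hat = (x - mu) / np.sqrt(var + 1e-5)   # Step 2: zero mean, unit variance
gamma, beta = 1.5, 0.5                   # Step 3: illustrative scale/shift
y = gamma * x_hat + beta

print(x_hat.mean(), x_hat.var())  # ~0.0, ~1.0
print(y.mean(), y.var())          # ~0.5 (= beta), ~2.25 (= gamma**2)
```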

During Inference (Testing)

At test time, we don't have a batch, so we use the running statistics (population mean and variance) computed during training:

x̂ = (x - μ_running) / √(σ²_running + ε), then y = γx̂ + β
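
Since μ_running, σ²_running, γ, and β are all fixed at inference time, the whole transform collapses to a per-feature affine map y = a·x + b with a = γ/√(σ²_running + ε) and b = β - a·μ_running. A small sketch with made-up statistics; this collapse is also why BatchNorm is often folded into the preceding linear/conv layer for deployment:

```python
import numpy as np

# Made-up frozen statistics and parameters, for illustration only
mu_run, var_run, eps = 5.0, 4.0, 1e-5
gamma, beta = 1.5, 0.5

# Fold the normalization into a single affine transform
a = gamma / np.sqrt(var_run + eps)
b = beta - a * mu_run

x = np.array([3.0, 5.0, 7.0])
y_direct = gamma * (x - mu_run) / np.sqrt(var_run + eps) + beta
y_folded = a * x + b
print(np.allclose(y_direct, y_folded))  # True
```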

Learnable Parameters

  • γ (gamma): Scale parameter - learned via backpropagation
  • β (beta): Shift parameter - learned via backpropagation
These parameters allow the network to undo normalization if needed. For example, if the optimal distribution for a layer is not zero-mean unit-variance, the network can learn appropriate γ and β.
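
A quick sketch of the "undo" case: if the network learns γ = √(σ² + ε) and β = μ, the BatchNorm output reproduces its input exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(5.0, 3.0, size=(8, 2))  # batch of 8, mean ~5, std ~3
eps = 1e-5
mu, var = x.mean(axis=0), x.var(axis=0)

gamma = np.sqrt(var + eps)  # cancels the division by the std
beta = mu                   # cancels the mean subtraction
y = gamma * (x - mu) / np.sqrt(var + eps) + beta
print(np.allclose(y, x))    # True: the normalization is undone
```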

Benefits of Batch Normalization

Faster Training

Allows much higher learning rates (the original paper trained successfully with up to 30x the baseline rate) because normalization reduces sensitivity to the scale of the parameters

Reduces Internal Covariate Shift

Stabilizes the distribution of layer inputs, making training more stable

Acts as Regularization

Introduces noise through batch statistics, which has a regularizing effect similar to dropout

Less Sensitive to Initialization

Networks train successfully even with poor weight initialization

Enables Deeper Networks

Makes it possible to train very deep networks (e.g., ResNets with 100+ layers)

May Reduce Need for Dropout

The regularization effect can sometimes replace dropout

Where to Place BatchNorm

Before Activation (Original Paper)

Linear → BatchNorm → Activation

The original 2015 paper inserts BatchNorm immediately before the nonlinearity; this remains the most common placement today, especially with ReLU

After Activation (Alternative)

Linear → Activation → BatchNorm

Some practitioners report good results with this ordering, but it is less common

In Convolutional Networks

Conv2D → BatchNorm2D → ReLU → MaxPool

BatchNorm2d computes one mean and variance per channel, pooled over the batch and both spatial dimensions (N, H, W)
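
For reference, a minimal PyTorch sketch of the conv pattern above (the channel sizes are arbitrary):

```python
import torch
import torch.nn as nn

# Conv2D -> BatchNorm2D -> ReLU -> MaxPool, as a reusable block
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),    # one mean/var per channel, over (N, H, W)
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
)

x = torch.randn(8, 3, 32, 32)
print(block(x).shape)  # torch.Size([8, 64, 16, 16])
```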

Normalization Variants

Normalization                           Normalizes Over                   Use Case
Batch Normalization (BatchNorm)         Batch dimension                   Most common for large batches (CNNs, MLPs)
Layer Normalization (LayerNorm)         Feature dimension                 Transformers, RNNs, small batch sizes
Instance Normalization (InstanceNorm)   Spatial dimensions per channel    Style transfer, GANs
Group Normalization (GroupNorm)         Groups of channels                Small batch sizes, object detection

Important Considerations

Training Mode vs Eval Mode

CRITICAL: Always set model.train() during training and model.eval() during testing. BatchNorm behaves completely differently in these modes!

Training Mode:

Uses batch statistics (mean/var of current batch)

Eval Mode:

Uses running statistics (accumulated during training)

Batch Size Dependency

Issue: BatchNorm performance degrades with very small batches (< 16). The batch mean and variance become noisy, unreliable estimates of the true statistics.

Solution: Use LayerNorm or GroupNorm for small batches
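
A minimal sketch of such a swap in PyTorch (the group count of 8 is an arbitrary illustrative choice; it must evenly divide the channel count):

```python
import torch
import torch.nn as nn

# GroupNorm computes statistics per sample, so tiny batches are fine
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.GroupNorm(num_groups=8, num_channels=64),  # instead of BatchNorm2d(64)
    nn.ReLU(inplace=True),
)

x = torch.randn(2, 3, 32, 32)  # batch of 2: too small for reliable BatchNorm
print(block(x).shape)          # torch.Size([2, 64, 32, 32])
```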

Distributed Training

Issue: In multi-GPU training, each GPU computes statistics on its local batch, which may not represent the full distribution.

Solution: Use SyncBatchNorm to synchronize statistics across GPUs
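
A sketch of the conversion in PyTorch: convert_sync_batchnorm swaps every BatchNorm layer for a SyncBatchNorm, but statistics are only actually synchronized once a torch.distributed process group is initialized and the model runs under DistributedDataParallel:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Replace all BatchNorm layers with SyncBatchNorm (no effect outside DDP)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)  # BatchNorm2d is now SyncBatchNorm
```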

Advantages

  • Significantly speeds up training (2-10x faster convergence)
  • Allows much higher learning rates
  • Makes network less sensitive to weight initialization
  • Has regularizing effect, may reduce need for dropout
  • Enables training of very deep networks
  • Improves gradient flow through the network
  • Standard in modern architectures (ResNet, Inception, etc.)

Limitations

  • Introduces dependency on batch size (problematic for small batches)
  • Different behavior during training vs inference (can cause bugs)
  • Adds computational overhead
  • Not suitable for online learning (single sample at a time)
  • Can interact poorly with certain architectures or dropout
  • Requires careful handling in distributed training
  • May not work well with RNNs (LayerNorm preferred)

Code Examples

1. Batch Normalization from Scratch

Understanding the math behind BatchNorm

```python
import numpy as np


class BatchNorm1D:
    """Batch Normalization implementation from scratch."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = np.ones(num_features)   # Scale
        self.beta = np.zeros(num_features)   # Shift
        # Running statistics (for inference)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def forward(self, x, training=True):
        """
        Args:
            x: Input of shape (batch_size, num_features)
            training: If True, use batch statistics; if False, use running statistics
        Returns:
            out: Normalized output
        """
        if training:
            # Step 1: Compute batch statistics
            batch_mean = x.mean(axis=0)
            batch_var = x.var(axis=0)
            # Step 2: Normalize
            x_normalized = (x - batch_mean) / np.sqrt(batch_var + self.eps)
            # Step 3: Scale and shift
            out = self.gamma * x_normalized + self.beta
            # Step 4: Update running statistics (exponential moving average)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
        else:
            # Use running statistics for inference
            x_normalized = (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
            out = self.gamma * x_normalized + self.beta
        return out


# Example usage
bn = BatchNorm1D(num_features=3)

# Training data (batch_size=4, num_features=3)
x_train = np.random.randn(4, 3) * 10 + 5  # Mean ~5, high variance
print(f"Input mean: {x_train.mean(axis=0)}")
print(f"Input var: {x_train.var(axis=0)}")

# Forward pass in training mode
out_train = bn.forward(x_train, training=True)
print(f"\nOutput mean (training): {out_train.mean(axis=0)}")  # ~0 (normalized)
print(f"Output var (training): {out_train.var(axis=0)}")      # ~1 (normalized)

# Forward pass in inference mode
out_test = bn.forward(x_train, training=False)
print(f"\nOutput (inference): {out_test.mean(axis=0)}")
```

2. Using BatchNorm in PyTorch CNN

Standard usage in convolutional networks

```python
import torch
import torch.nn as nn


class CNNWithBatchNorm(nn.Module):
    """Standard CNN with Batch Normalization."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1: Conv -> BatchNorm -> ReLU -> Pool
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # BatchNorm2d for convolutional layers
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 2
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Block 3
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.BatchNorm1d(512),  # BatchNorm1d for fully connected layers
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


# Create model
model = CNNWithBatchNorm(num_classes=10)
print(model)

# Example input (batch of 8 RGB images, 32x32)
x = torch.randn(8, 3, 32, 32)

# CRITICAL: Set training mode
model.train()
output_train = model(x)
print(f"\nTraining mode output shape: {output_train.shape}")

# CRITICAL: Set eval mode for testing
model.eval()
output_eval = model(x)
print(f"Eval mode output shape: {output_eval.shape}")

# Inspect BatchNorm statistics
bn_layer = model.features[1]  # First BatchNorm2d layer
print(f"\nBatchNorm running mean: {bn_layer.running_mean[:5]}")
print(f"BatchNorm running var: {bn_layer.running_var[:5]}")
print(f"Learnable gamma (scale): {bn_layer.weight[:5]}")
print(f"Learnable beta (shift): {bn_layer.bias[:5]}")
```

3. Comparing Different Normalization Techniques

BatchNorm vs LayerNorm vs GroupNorm vs InstanceNorm

```python
import torch
import torch.nn as nn

# Sample input: (batch=2, channels=4, height=3, width=3)
x = torch.randn(2, 4, 3, 3)
print("Input shape:", x.shape)
print("=" * 60)

# 1. Batch Normalization
# Normalizes over the batch and spatial dims, separately for each channel
bn = nn.BatchNorm2d(4)
out_bn = bn(x)
print("\n1. Batch Normalization (BatchNorm2d)")
print("Normalizes over: Batch and spatial dimensions (N, H, W) per channel")
print(f"Output shape: {out_bn.shape}")
print("For each channel, mean across the batch ≈ 0:")
print(f"  Channel 0 mean: {out_bn[:, 0].mean():.6f}")

# 2. Layer Normalization
# Normalizes over channel and spatial dims, separately for each sample
ln = nn.LayerNorm([4, 3, 3])  # Normalize over C, H, W
out_ln = ln(x)
print("\n2. Layer Normalization (LayerNorm)")
print("Normalizes over: Channel, Height, Width dimensions per sample")
print(f"Output shape: {out_ln.shape}")
print("For each sample in the batch, mean ≈ 0:")
print(f"  Sample 0 mean: {out_ln[0].mean():.6f}")

# 3. Instance Normalization
# Normalizes over spatial dims, separately for each sample and channel
in_norm = nn.InstanceNorm2d(4)
out_in = in_norm(x)
print("\n3. Instance Normalization (InstanceNorm2d)")
print("Normalizes over: Spatial dimensions (H, W) per sample and channel")
print(f"Output shape: {out_in.shape}")
print("For each sample and channel, mean across spatial dims ≈ 0:")
print(f"  Sample 0, Channel 0 mean: {out_in[0, 0].mean():.6f}")

# 4. Group Normalization
# Normalizes over groups of channels, separately for each sample
gn = nn.GroupNorm(num_groups=2, num_channels=4)  # 4 channels, 2 groups
out_gn = gn(x)
print("\n4. Group Normalization (GroupNorm)")
print("Normalizes over: Groups of channels (here: 2 groups)")
print(f"Output shape: {out_gn.shape}")

# Comparison summary
print("\n" + "=" * 60)
print("WHEN TO USE EACH:")
print("=" * 60)
print("BatchNorm: Large batches, CNNs, standard training")
print("LayerNorm: Transformers, RNNs, small batches, NLP")
print("InstanceNorm: Style transfer, GANs")
print("GroupNorm: Small batches, object detection, batch-independent")
```

4. Training vs Eval Mode Behavior

Demonstrating the critical difference between modes

```python
import torch
import torch.nn as nn

# Create a simple model with BatchNorm
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.BatchNorm1d(20),
    nn.ReLU()
)

# The same sample placed at index 0 of two different batches
sample = torch.randn(1, 10)
batch_a = torch.cat([sample, torch.randn(31, 10)])  # batch of 32
batch_b = torch.cat([sample, torch.randn(31, 10)])  # different batch-mates

print("=" * 60)
print("TRAINING MODE (model.train())")
print("=" * 60)
model.train()
out_a = model(batch_a)[0]  # the sample's output within batch A
out_b = model(batch_b)[0]  # the same sample's output within batch B
print(f"Outputs are identical: {torch.allclose(out_a, out_b)}")
print("❌ Different outputs for the same sample!")
print("   (normalization uses the statistics of whatever batch it sits in)")

print("\n" + "=" * 60)
print("EVAL MODE (model.eval())")
print("=" * 60)
model.eval()
out_a = model(batch_a)[0]
out_b = model(batch_b)[0]
print(f"Outputs are identical: {torch.allclose(out_a, out_b)}")
print("✓ Identical outputs! (uses fixed running statistics)")

print("\n" + "=" * 60)
print("WHY THIS MATTERS")
print("=" * 60)
print("""
Training Mode (model.train()):
- Uses batch statistics (mean/var of current batch)
- Different batches -> different statistics -> different outputs
- Updates running statistics on every forward pass

Eval Mode (model.eval()):
- Uses running statistics (accumulated during training)
- Deterministic, batch-independent behavior
- Essential for reproducible inference

COMMON BUG: Forgetting model.eval() during testing.
Predictions then depend on how the test set happens to be batched!
""")

# The statistics that eval mode relies on
bn_layer = model[1]
print(f"Running mean: {bn_layer.running_mean[:5]}")
print(f"Running var: {bn_layer.running_var[:5]}")
```

5. Impact of BatchNorm on Training Speed

Comparing training with and without BatchNorm

```python
import torch
import torch.nn as nn
import torch.optim as optim
import time

# Synthetic dataset
torch.manual_seed(42)
X_train = torch.randn(1000, 100)
y_train = (X_train.sum(dim=1) > 0).long()


# Model WITHOUT BatchNorm
class ModelWithoutBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, x):
        return self.network(x)


# Model WITH BatchNorm
class ModelWithBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(100, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 2)
        )

    def forward(self, x):
        return self.network(x)


def train_model(model, lr, epochs=50):
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    losses = []
    start_time = time.time()
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    train_time = time.time() - start_time
    return losses, train_time


# Train without BatchNorm (requires a low learning rate)
print("Training WITHOUT BatchNorm (lr=0.01)...")
model_no_bn = ModelWithoutBN()
losses_no_bn, time_no_bn = train_model(model_no_bn, lr=0.01)

# Train with BatchNorm (tolerates a much higher learning rate)
print("Training WITH BatchNorm (lr=0.1, 10x higher!)...")
model_with_bn = ModelWithBN()
losses_with_bn, time_with_bn = train_model(model_with_bn, lr=0.1)

# Results
print("\n" + "=" * 60)
print("RESULTS (same number of epochs)")
print("=" * 60)
print("Without BatchNorm:")
print(f"  Final loss: {losses_no_bn[-1]:.4f}")
print(f"  Training time: {time_no_bn:.2f}s")
print("\nWith BatchNorm:")
print(f"  Final loss: {losses_with_bn[-1]:.4f}")
print(f"  Training time: {time_with_bn:.2f}s")
print("\nNote: with equal epochs, wall-clock times are similar (BatchNorm")
print("even adds a little compute per step); the win is reaching a much")
print("lower loss per epoch, i.e. faster convergence.")

print("\n" + "=" * 60)
print("KEY INSIGHTS")
print("=" * 60)
print("""
1. BatchNorm tolerates a 10x higher learning rate (0.1 vs 0.01)
2. Converges faster (fewer epochs to reach the same loss)
3. More stable training
4. Less sensitive to weight initialization

Try training WITHOUT BatchNorm at lr=0.1: training will likely diverge!
""")
```

Key Concepts

  • Normalization: Transform to zero mean, unit variance within batch
  • γ (gamma) and β (beta): Learnable scale and shift parameters
  • ε (epsilon): Small constant (1e-5) for numerical stability
  • Running statistics: Moving average of mean/variance for inference
  • Momentum: Typically 0.1 for updating running statistics
  • Internal covariate shift: Change in distribution of layer inputs during training
  • Training mode: Uses batch statistics (μ_batch, σ²_batch)
  • Eval mode: Uses running statistics (μ_running, σ²_running)
  • Placement: Usually Conv/Linear → BatchNorm → Activation

Interview Tips

  • 💡Explain the internal covariate shift problem BatchNorm solves
  • 💡Know the normalization formula: x̂ = (x - μ) / √(σ² + ε), then y = γx̂ + β
  • 💡Understand γ and β are learnable parameters (why they exist)
  • 💡Explain training vs eval mode difference (batch stats vs running stats)
  • 💡Know that BatchNorm allows much higher learning rates
  • 💡Discuss the regularization effect from batch statistics noise
  • 💡Understand placement: typically before activation (Linear → BN → ReLU)
  • 💡Know variants: LayerNorm for Transformers/RNNs, GroupNorm for small batches
  • 💡Explain why small batch sizes are problematic (unreliable statistics)
  • 💡Mention it's standard in modern CNNs (ResNet, Inception, DenseNet)
  • 💡Discuss the training speedup (2-10x faster convergence)
  • 💡Know the original 2015 Ioffe & Szegedy paper
  • 💡Explain why model.train() and model.eval() are critical in PyTorch
  • 💡Understand SyncBatchNorm for multi-GPU training
  • 💡Compare with LayerNorm: BatchNorm for CNNs, LayerNorm for Transformers