Normalization Layers¶
Normalization layers re-scale intermediate activations so that each sub-layer receives inputs with controlled statistics. Without normalization, deep transformer networks are notoriously difficult to train: activations drift, gradients explode or vanish, and learning rates must be tuned to unreasonably small values. This page covers every normalization technique relevant to modern LLMs, with particular emphasis on RMSNorm (used in LLaMA / Mistral) and LayerNorm (used in GPT / BERT).
1. Internal Covariate Shift¶
Internal Covariate Shift
The change in the distribution of a layer's inputs caused by updates to the parameters of preceding layers during training.1
When every gradient step shifts the input distribution of downstream layers, those layers must continuously re-adapt. This slows convergence and forces conservative learning rates. Normalization combats the problem by ensuring that, regardless of what the preceding layers produce, the inputs to each sub-layer have zero mean and unit variance (or at least controlled scale).
Even at inference time (where weights are frozen), normalization is still applied because the model was trained to expect normalized inputs. Removing it would change the function the network computes.
2. Layer Normalization¶
Layer Normalization (Ba et al. 2016)
For an input vector \( x \in \mathbb{R}^d \):
where
and \(\gamma, \beta \in \mathbb{R}^d\) are learnable scale and shift parameters, and \(\epsilon\) is a small constant for numerical stability (typically \(10^{-6}\)).2
2.1 Key Properties¶
- Per-sample normalization: Statistics are computed over the feature dimension for each sample independently. This means LayerNorm behaves identically at training and inference time, unlike BatchNorm.
- Two learnable parameters per feature: \(\gamma\) (scale) allows the layer to recover the original scale if needed; \(\beta\) (shift) allows recovery of an arbitrary mean.
- Invariances: The output is invariant to shifting and scaling of the input along the feature dimension.
2.2 Gradient Through LayerNorm¶
The Jacobian of LayerNorm is not a simple diagonal matrix. For input \(x\), the derivative of the normalized output \(\hat{x}_i = (x_i - \mu)/\sqrt{\sigma^2 + \epsilon}\) with respect to \(x_j\) is:
where \(\delta_{ij}\) is the Kronecker delta. This couples all features through the mean and variance terms, which can be beneficial for learning but adds computational cost in training.
3. RMS Normalization¶
RMSNorm (Zhang & Sennrich 2019)
where
[ \operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2} ]3
3.1 Differences from LayerNorm¶
| Aspect | LayerNorm | RMSNorm |
|---|---|---|
| Mean subtraction | Yes | No |
| Shift parameter \(\beta\) | Yes | No |
| Learnable parameters | \(2d\) | \(d\) |
| Invariance | Shift + scale | Scale only |
| Relative speed | Baseline | ~15% faster |
RMSNorm's simplification is motivated by the empirical observation that the re-centering (mean subtraction) in LayerNorm contributes little to model quality in transformer architectures, while the re-scaling (division by RMS) is the component that matters most for stabilizing activations.3
3.2 Computational Advantage¶
RMSNorm saves computation in two ways:
- One fewer reduction: No need to compute the mean \(\mu\).
- Fewer parameters: No \(\beta\) vector to store or apply.
For a model with \(d = 4096\) and 80 layers (LLaMA-65B scale), eliminating the mean computation and shift parameter across all normalization sites saves measurable wall-clock time and memory bandwidth.
4. Batch Normalization (Brief Treatment)¶
Batch Normalization (Ioffe & Szegedy 2015)
Normalizes across the batch dimension rather than the feature dimension:
where statistics are computed over all samples in the mini-batch for each feature independently.1
Why BatchNorm is rarely used in transformers:
- Variable sequence lengths make batch statistics inconsistent.
- Small batch sizes (common in LLM training) yield noisy statistics.
- Different behavior at training vs. inference (running averages) introduces discrepancies.
- LayerNorm provides the same stabilization benefits without these drawbacks.
5. Group Normalization (Brief Treatment)¶
Group Normalization (Wu & He 2018)
Divides the feature dimension into \(G\) groups and normalizes within each group independently:
where \(g(i)\) maps feature index \(i\) to its group.4
GroupNorm interpolates between LayerNorm (\(G=1\)) and InstanceNorm (\(G=d\)). It is primarily used in vision models (e.g., with convolutional networks) and is uncommon in transformer-based LLMs.
6. Pre-Norm vs Post-Norm Placement¶
The placement of normalization layers relative to attention and feed-forward sub-layers has a significant impact on training stability and gradient flow.
6.1 Post-Norm (Original Transformer)¶
The original "Attention Is All You Need" architecture applies normalization after the residual addition:
6.2 Pre-Norm (Modern Standard)¶
Most modern architectures (LLaMA, GPT-3, Mistral) apply normalization before the sub-layer:
6.3 Gradient Flow Analysis¶
Gradient Flow Comparison
Post-Norm gradient path through \(L\) layers:
The normalization Jacobian at each layer can attenuate or amplify gradients unpredictably.
Pre-Norm gradient path:
The identity matrix \(I\) in each factor ensures that even if \(\partial f_\ell / \partial x_\ell\) is small, the gradient is at least propagated through the skip connection.
6.4 Comparison Diagram¶
flowchart LR
subgraph "Post-Norm (Original)"
direction TB
A1["x"] --> B1["Attention"]
A1 --> C1["+ (Residual)"]
B1 --> C1
C1 --> D1["LayerNorm"]
D1 --> E1["FFN"]
D1 --> F1["+ (Residual)"]
E1 --> F1
F1 --> G1["LayerNorm"]
end
subgraph "Pre-Norm (Modern)"
direction TB
A2["x"] --> B2["RMSNorm"]
B2 --> C2["Attention"]
A2 --> D2["+ (Residual)"]
C2 --> D2
D2 --> E2["RMSNorm"]
E2 --> F2["FFN"]
D2 --> G2["+ (Residual)"]
F2 --> G2
end Practical Guidance
Pre-Norm is strongly preferred for deep transformers (\(L > 24\)). Post-Norm can work for shallow models but often requires learning rate warmup and careful initialization to avoid divergence.
7. Implementation in ZigLlama¶
7.1 Normalization Type Enum¶
pub const NormalizationType = enum {
LayerNorm, // Standard Layer Normalization
RMSNorm, // Root Mean Square Normalization
BatchNorm, // Batch Normalization (less common in transformers)
GroupNorm, // Group Normalization
};
7.2 LayerNorm¶
pub fn layerNorm(
comptime T: type,
input: Tensor(T),
scale: ?Tensor(T), // gamma
shift: ?Tensor(T), // beta
allocator: Allocator,
) TensorError!Tensor(T) {
const last_dim = input.shape[input.shape.len - 1];
const batch_size = input.size / last_dim;
var result = try Tensor(T).init(allocator, input.shape);
for (0..batch_size) |batch_idx| {
const offset = batch_idx * last_dim;
// Pass 1: mean
var sum: T = 0.0;
for (0..last_dim) |i| sum += input.data[offset + i];
const mean = sum / @as(T, @floatFromInt(last_dim));
// Pass 2: variance
var var_sum: T = 0.0;
for (0..last_dim) |i| {
const diff = input.data[offset + i] - mean;
var_sum += diff * diff;
}
const std_dev = @sqrt(var_sum / @as(T, @floatFromInt(last_dim)) + NORM_EPSILON);
// Normalize + affine transform
for (0..last_dim) |i| {
var val = (input.data[offset + i] - mean) / std_dev;
if (scale) |s| val *= s.data[i];
if (shift) |b| val += b.data[i];
result.data[offset + i] = val;
}
}
return result;
}
7.3 RMSNorm¶
pub fn rmsNorm(
comptime T: type,
input: Tensor(T),
scale: ?Tensor(T), // gamma (no beta)
allocator: Allocator,
) TensorError!Tensor(T) {
const last_dim = input.shape[input.shape.len - 1];
const batch_size = input.size / last_dim;
var result = try Tensor(T).init(allocator, input.shape);
for (0..batch_size) |batch_idx| {
const offset = batch_idx * last_dim;
// Single pass: sum of squares
var sum_sq: T = 0.0;
for (0..last_dim) |i| {
const x = input.data[offset + i];
sum_sq += x * x;
}
const rms = @sqrt(sum_sq / @as(T, @floatFromInt(last_dim)) + NORM_EPSILON);
// Normalize + scale
for (0..last_dim) |i| {
var val = input.data[offset + i] / rms;
if (scale) |s| val *= s.data[i];
result.data[offset + i] = val;
}
}
return result;
}
7.4 Dispatcher¶
pub fn applyNormalization(
comptime T: type,
norm_type: NormalizationType,
input: Tensor(T),
scale: ?Tensor(T),
shift: ?Tensor(T),
allocator: Allocator,
) TensorError!Tensor(T) {
return switch (norm_type) {
.LayerNorm => layerNorm(T, input, scale, shift, allocator),
.RMSNorm => rmsNorm(T, input, scale, allocator),
.BatchNorm => batchNorm(T, input, scale, shift, allocator),
.GroupNorm => groupNorm(T, input, 8, scale, shift, allocator),
};
}
Source File
Full implementation: src/neural_primitives/normalization.zig (approximately 570 lines including tests and the numerically stable computeStableStats helper).
8. Numerical Stability¶
8.1 The Two-Pass Algorithm¶
ZigLlama uses a two-pass algorithm for computing variance (LayerNorm):
- First pass: compute \(\mu\).
- Second pass: compute \(\sigma^2\) using the known \(\mu\).
The naive one-pass formula \(\sigma^2 = \mathbb{E}[X^2] - (\mathbb{E}[X])^2\) suffers from catastrophic cancellation when the variance is small relative to the mean. The two-pass approach avoids this.
8.2 Epsilon Selection¶
The constant \(\epsilon = 10^{-6}\) is added inside the square root to prevent division by zero when all elements are identical. For f16 inference (not yet supported in ZigLlama), a larger \(\epsilon = 10^{-3}\) is typical.
9. Which Models Use What?¶
| Model | Normalization | Placement | Notes |
|---|---|---|---|
| Original Transformer | LayerNorm | Post-Norm | Vaswani et al. 2017 |
| BERT | LayerNorm | Post-Norm | Devlin et al. 2019 |
| GPT-2 / GPT-3 | LayerNorm | Pre-Norm | Radford et al. 2019 |
| LLaMA / LLaMA 2 | RMSNorm | Pre-Norm | Touvron et al. 2023 |
| Mistral | RMSNorm | Pre-Norm | Jiang et al. 2023 |
| PaLM | RMSNorm | Pre-Norm | Chowdhery et al. 2022 |
| Falcon | LayerNorm | Pre-Norm | Penedo et al. 2023 |
| T5 | RMSNorm | Pre-Norm | Raffel et al. 2020 |
10. Exercises¶
- Derive the Jacobian \(\partial \hat{x}/\partial x\) for RMSNorm and compare its rank with the LayerNorm Jacobian.
- Show that LayerNorm is invariant to adding a constant vector \(c \cdot \mathbf{1}\) to the input, while RMSNorm is not.
- Measure the wall-clock difference between LayerNorm and RMSNorm on a \(2048 \times 4096\) tensor in ZigLlama. Does the ~15% claim hold on your hardware?
References¶
-
Ioffe, S. & Szegedy, C. "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift." ICML, 2015. ↩↩
-
Ba, J. L., Kiros, J. R. & Hinton, G. E. "Layer Normalization." arXiv:1607.06450, 2016. ↩
-
Zhang, B. & Sennrich, R. "Root Mean Square Layer Normalization." NeurIPS, 2019. ↩↩
-
Wu, Y. & He, K. "Group Normalization." ECCV, 2018. ↩
-
Xiong, R. et al. "On Layer Normalization in the Transformer Architecture." ICML, 2020. ↩