ML Math

Layer Normalization

Step-by-step derivation of the forward and backward passes, including the full gradient for the input \(\mathbf{x}\).

Notation

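| Symbol | Description |
|---|---|
| \(\mathbf{x}\) | Input feature vector of length \(D\) (one token's embedding) |
| \(\mu,\; \sigma^2\) | Mean and variance computed across the \(D\) features |
| \(\gamma,\; \beta\) | Learned scale and shift parameters (vectors of length \(D\), shared by all tokens) |
| \(\varepsilon\) | Small constant added to the variance for numerical stability |
| \(\hat{x}_i\) | Normalized input: \((x_i - \mu)/\sqrt{\sigma^2+\varepsilon}\) |
| \(dy,\; dx\) | Upstream gradient \(\partial L/\partial y\) and downstream gradient \(\partial L/\partial x\); here \(L\) is the scalar loss, not the sequence length |
| \(N, L, D\) | Batch size, sequence length, feature dimension |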

1. Forward Pass

  1. Mean \[\mu = \frac{1}{D}\sum_{i=1}^{D} x_i\]
  2. Variance \[\sigma^2 = \frac{1}{D}\sum_{i=1}^{D}(x_i - \mu)^2\]
  3. Normalize \[\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}}\]
  4. Scale and shift \[y_i = \gamma_i \hat{x}_i + \beta_i\]
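The four forward steps can be sketched for a single token in plain Python (standard library only, to keep the example dependency-free). This is a minimal illustration rather than a reference implementation: the name `layernorm_forward` is hypothetical, \(\varepsilon = 10^{-5}\) is an assumed default that the text does not fix, and returning \(\hat{\mathbf{x}}\) and \(1/\sqrt{\sigma^2+\varepsilon}\) alongside \(y\) is a design choice so that a backward pass can reuse them.

```python
import math

# Hypothetical helper: eps=1e-5 is an assumed default, not taken from the text.
def layernorm_forward(x, gamma, beta, eps=1e-5):
    """Layer norm for one token; x, gamma, beta are lists of length D."""
    D = len(x)
    mu = sum(x) / D                                    # 1. mean
    var = sum((xi - mu) ** 2 for xi in x) / D          # 2. variance (divides by D)
    inv_std = 1.0 / math.sqrt(var + eps)               # 1 / sqrt(sigma^2 + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]          # 3. normalize
    y = [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]  # 4. scale and shift
    # x_hat and inv_std are returned so a backward pass can reuse them.
    return y, x_hat, inv_std


y, x_hat, inv_std = layernorm_forward([1.0, 2.0, 4.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0])
```

With \(\gamma = \mathbf{1}\) and \(\beta = \mathbf{0}\), \(y\) equals \(\hat{\mathbf{x}}\), whose entries have mean 0 and mean square \(\sigma^2/(\sigma^2+\varepsilon)\), i.e. just under 1.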

2. Backward Pass — Gradients for \(\gamma\) and \(\beta\)

Single token

Since \(y_i = \gamma_i \hat{x}_i + \beta_i\), both gradients follow immediately from the chain rule:

\[ \frac{\partial L}{\partial \beta_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \beta_i}}_{=\,1} = dy_i \] \[ \frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \gamma_i}}_{=\,\hat{x}_i} = dy_i \cdot \hat{x}_i \]

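Accumulated across batch and sequence

Because \(\gamma\) and \(\beta\) are shared by every token, their gradients sum the per-token contributions over all \(N \times L\) positions: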

\[ d\beta_i = \sum_{b=1}^{N}\sum_{s=1}^{L} dy_{b,s,i} \qquad\qquad d\gamma_i = \sum_{b=1}^{N}\sum_{s=1}^{L} dy_{b,s,i}\cdot\hat{x}_{b,s,i} \]
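The batch and sequence sums simply accumulate the single-token formulas, so a plain-Python sketch over nested lists of shape \((N, L, D)\) could look like the following. The helper names `normalize` and `param_grads` are hypothetical, and \(\varepsilon = 10^{-5}\) is an assumed value.

```python
import math

def normalize(x, eps=1e-5):
    """Forward steps 1-3 for one token: returns x_hat (eps is an assumed default)."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(xi - mu) * inv_std for xi in x]

def param_grads(x, dy, eps=1e-5):
    """x, dy: nested lists of shape (N, L, D). Returns (d_gamma, d_beta), each of length D."""
    D = len(x[0][0])
    d_gamma = [0.0] * D
    d_beta = [0.0] * D
    for x_seq, dy_seq in zip(x, dy):               # sum over batch index b
        for x_tok, dy_tok in zip(x_seq, dy_seq):   # sum over sequence index s
            x_hat = normalize(x_tok, eps)
            for i in range(D):
                d_beta[i] += dy_tok[i]              # dL/dbeta_i  = dy_i
                d_gamma[i] += dy_tok[i] * x_hat[i]  # dL/dgamma_i = dy_i * x_hat_i
    return d_gamma, d_beta


# N = 1 sequence, L = 2 tokens, D = 3 features
x = [[[1.0, 2.0, 4.0], [0.0, 1.0, -1.0]]]
dy = [[[1.0, 1.0, 1.0], [0.5, 0.5, 0.5]]]
d_gamma, d_beta = param_grads(x, dy)
```

Here every entry of `d_beta` is \(1.0 + 0.5 = 1.5\), and `d_gamma` sums to (nearly) zero because each token's \(\hat{\mathbf{x}}\) has zero mean.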

3. Backward Pass — Gradient for the Input (\(dx\))

This is the trickier part. The input \(x_i\) affects the output through three paths: directly via \(\hat{x}_i\), and indirectly via \(\mu\) and \(\sigma^2\).

Step A — Gradient w.r.t. normalized input \(\hat{x}_i\)
\[d\hat{x}_i = dy_i \cdot \gamma_i\]
Step B — Gradient w.r.t. variance \(\sigma^2\)

Each \(\hat{x}_j\) depends on \(\sigma^2\), so we sum contributions from all \(j\):

\[ d\sigma^2 = \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{\partial \hat{x}_j}{\partial \sigma^2} = \sum_{j=1}^{D} d\hat{x}_j \cdot (x_j - \mu) \cdot \left(-\tfrac{1}{2}(\sigma^2+\varepsilon)^{-3/2}\right) \]
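Step B translates directly into code. The sketch below (hypothetical function name, assumed \(\varepsilon = 10^{-5}\)) computes \(d\sigma^2\) for one token from \(\mathbf{x}\) and \(d\hat{\mathbf{x}}\), treating \(\mu\) as a separate quantity exactly as the sum above does.

```python
import math

def grad_variance(x, dx_hat, eps=1e-5):
    """dL/dsigma^2 for one token (Step B); eps is an assumed default."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    # d x_hat_j / d sigma^2 = (x_j - mu) * (-1/2) * (sigma^2 + eps)^(-3/2)
    dxhat_dvar = [(xj - mu) * (-0.5) * (var + eps) ** -1.5 for xj in x]
    return sum(g * d for g, d in zip(dx_hat, dxhat_dvar))


x = [1.0, 2.0, 4.0]
dx_hat = [0.3, -2.4, 2.1]
d_var = grad_variance(x, dx_hat)
```

A finite-difference check that perturbs only \(\sigma^2\) (holding \(\mu\) fixed) should agree with `d_var` to several decimal places.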
Step C — Gradient w.r.t. mean \(\mu\)

Path 1 — direct effect on each \(\hat{x}_j\):

\[ \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{\partial \hat{x}_j}{\partial \mu} = \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{-1}{\sqrt{\sigma^2+\varepsilon}} \]

Path 2 — indirect effect through \(\sigma^2\):

\[ d\sigma^2 \cdot \frac{\partial \sigma^2}{\partial \mu} = d\sigma^2 \cdot \frac{-2}{D}\sum_{j=1}^{D}(x_j - \mu) = 0 \]
Path 2 is zero because \(\sum_j (x_j - \mu) = 0\) by definition of the mean.

Combined:

\[ d\mu = \frac{-1}{\sqrt{\sigma^2+\varepsilon}} \sum_{j=1}^{D} d\hat{x}_j \]
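To see numerically that Path 2 vanishes, the following sketch computes both paths separately for one token. The function name and \(\varepsilon = 10^{-5}\) are assumptions; in floating point, Path 2 comes out at round-off level rather than exactly zero.

```python
import math

def grad_mean(x, dx_hat, eps=1e-5):
    """Returns (d_mu, path1, path2) for one token (Step C); eps is an assumed default."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    # Step B: dL/dsigma^2
    d_var = sum(g * (xj - mu) for g, xj in zip(dx_hat, x)) * (-0.5) * inv_std ** 3
    # Path 1: direct effect of mu on every x_hat_j
    path1 = -inv_std * sum(dx_hat)
    # Path 2: effect through sigma^2; sum_j (x_j - mu) is zero up to round-off
    path2 = d_var * (-2.0 / D) * sum(xj - mu for xj in x)
    return path1 + path2, path1, path2


d_mu, path1, path2 = grad_mean([1.0, 2.0, 4.0], [0.5, -1.0, 2.0])
```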
Step D — Final gradient for input \(dx_i\)

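Applying the chain rule through \(\hat{x}_i\), \(\sigma^2\), and \(\mu\), with the local partial derivatives \(\partial \hat{x}_i/\partial x_i = 1/\sqrt{\sigma^2+\varepsilon}\), \(\partial \sigma^2/\partial x_i = 2(x_i-\mu)/D\), and \(\partial \mu/\partial x_i = 1/D\) (each taken with the other intermediate quantities held fixed):

\[ dx_i = \frac{d\hat{x}_i}{\sqrt{\sigma^2+\varepsilon}} + d\sigma^2 \cdot \frac{2(x_i - \mu)}{D} + \frac{d\mu}{D} \]

Substituting Steps B and C, and using \(x_j - \mu = \hat{x}_j\sqrt{\sigma^2+\varepsilon}\) to rewrite Step B as \(d\sigma^2 = -\tfrac{1}{2}(\sigma^2+\varepsilon)^{-1}\sum_{j} d\hat{x}_j\,\hat{x}_j\), gives: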

Result (in terms of \(d\hat{x}\))
\[ dx_i = \frac{1}{\sqrt{\sigma^2+\varepsilon}} \left[ d\hat{x}_i - \frac{1}{D}\sum_{j=1}^{D} d\hat{x}_j - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j \right] \]
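The compact result can be checked against the unsimplified chain rule. The sketch below, with hypothetical function names and an assumed \(\varepsilon = 10^{-5}\), computes \(dx_i\) once from the \(d\hat{x}\) form and once by adding the three paths (direct, through \(\sigma^2\), through \(\mu\)) term by term; the two should agree to round-off.

```python
import math

def dx_compact(x, dx_hat, eps=1e-5):
    """dx from the d(x_hat) form of Step D; eps is an assumed default."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    mean_g = sum(dx_hat) / D
    mean_gx = sum(g * xh for g, xh in zip(dx_hat, x_hat)) / D
    return [inv_std * (dx_hat[i] - mean_g - x_hat[i] * mean_gx) for i in range(D)]

def dx_chain_rule(x, dx_hat, eps=1e-5):
    """Same gradient, summed path by path from Steps B and C."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    d_var = sum(g * (xj - mu) for g, xj in zip(dx_hat, x)) * (-0.5) * inv_std ** 3  # Step B
    d_mu = -inv_std * sum(dx_hat)                                                   # Step C
    # Local partials: dx_hat_i/dx_i = inv_std, dvar/dx_i = 2(x_i - mu)/D, dmu/dx_i = 1/D
    return [dx_hat[i] * inv_std + d_var * 2.0 * (x[i] - mu) / D + d_mu / D for i in range(D)]


x = [1.0, 2.0, 4.0]
dx_hat = [0.3, -2.4, 2.1]
dx_a = dx_compact(x, dx_hat)
dx_b = dx_chain_rule(x, dx_hat)
```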

Substituting \(d\hat{x}_j = dy_j\,\gamma_j\) (Step A) gives the fully expanded form:

\[ dx_i = \frac{1}{\sqrt{\sigma^2+\varepsilon}} \left[ dy_i\,\gamma_i - \frac{1}{D}\sum_{j=1}^{D} dy_j\,\gamma_j - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} dy_j\,\gamma_j\,\hat{x}_j \right] \]
A common mistake is to factor \(\gamma_i\) out of the summations, writing \(\gamma_i \sum_j dy_j\) instead of \(\sum_j dy_j\gamma_j\). Because \(\gamma\) is a vector (\(\gamma_j\) varies with \(j\)), it must stay inside the sum. The \(d\hat{x}\) form avoids this pitfall entirely.
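The pitfall is easy to catch with a finite-difference gradient check. The sketch below uses hypothetical helper names and assumed values \(\varepsilon = 10^{-5}\) and step \(h = 10^{-5}\). It defines the scalar loss \(L = \sum_i dy_i\,y_i\), so that \(\partial L/\partial y_i = dy_i\), and compares both the correct \(dx\) and the \(\gamma_i\)-factored variant against central differences for a non-constant \(\gamma\).

```python
import math

def forward(x, gamma, beta, eps=1e-5):
    """Single-token layer norm; returns (y, x_hat, inv_std). eps is an assumed default."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    y = [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]
    return y, x_hat, inv_std

def dx_correct(x, gamma, dy, eps=1e-5):
    """Expanded Step D result: gamma_j stays inside both sums."""
    D = len(x)
    _, x_hat, inv_std = forward(x, gamma, [0.0] * D, eps)
    dx_hat = [d * g for d, g in zip(dy, gamma)]
    s1 = sum(dx_hat) / D
    s2 = sum(g * xh for g, xh in zip(dx_hat, x_hat)) / D
    return [inv_std * (dx_hat[i] - s1 - x_hat[i] * s2) for i in range(D)]

def dx_gamma_factored(x, gamma, dy, eps=1e-5):
    """The common mistake: gamma_i pulled out of the sums (only valid if gamma is constant)."""
    D = len(x)
    _, x_hat, inv_std = forward(x, gamma, [0.0] * D, eps)
    s1 = sum(dy) / D
    s2 = sum(d * xh for d, xh in zip(dy, x_hat)) / D
    return [inv_std * gamma[i] * (dy[i] - s1 - x_hat[i] * s2) for i in range(D)]

def dx_numeric(x, gamma, beta, dy, h=1e-5):
    """Central differences of L = sum_i dy_i * y_i with respect to each x_i."""
    def loss(v):
        y, _, _ = forward(v, gamma, beta)
        return sum(d * yi for d, yi in zip(dy, y))
    grads = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        grads.append((loss(xp) - loss(xm)) / (2 * h))
    return grads


x, gamma, beta = [1.0, 2.0, 4.0], [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
dy = [0.3, -1.2, 0.7]
num = dx_numeric(x, gamma, beta, dy)
err_correct = max(abs(a - b) for a, b in zip(dx_correct(x, gamma, dy), num))
err_factored = max(abs(a - b) for a, b in zip(dx_gamma_factored(x, gamma, dy), num))
```

With a constant \(\gamma\) the two variants coincide, which is why the mistake can survive tests that only use \(\gamma = \mathbf{1}\), the usual initial value.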

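The three bracketed terms each play a distinct role:

  1. \(d\hat{x}_i\) is the direct path through \(\hat{x}_i\): the gradient that would result if \(\mu\) and \(\sigma^2\) were constants.
  2. \(-\tfrac{1}{D}\sum_j d\hat{x}_j\) is the path through \(\mu\). It subtracts the mean of \(d\hat{x}\), which makes the gradient sum to zero over the features (\(\sum_i dx_i = 0\)), consistent with the output being unchanged when the same constant is added to every \(x_i\).
  3. \(-\tfrac{\hat{x}_i}{D}\sum_j d\hat{x}_j\,\hat{x}_j\) is the path through \(\sigma^2\). It removes the component of \(d\hat{x}\) along \(\hat{x}\), reflecting that rescaling \(\mathbf{x}\) leaves \(\hat{x}\) nearly unchanged (exactly unchanged when \(\varepsilon = 0\)).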

4. Summary

| Gradient | Formula |
|---|---|
| \(d\beta_i\) | \(\sum_{b,s} dy_{b,s,i}\) |
| \(d\gamma_i\) | \(\sum_{b,s} dy_{b,s,i}\cdot\hat{x}_{b,s,i}\) |
| \(dx_i\) | \(\tfrac{1}{\sqrt{\sigma^2+\varepsilon}}\!\left[d\hat{x}_i - \tfrac{1}{D}\!\sum_j d\hat{x}_j - \tfrac{\hat{x}_i}{D}\!\sum_j d\hat{x}_j\hat{x}_j\right]\) |

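where \(d\hat{x}_j = dy_j\,\gamma_j\). The sums in \(d\beta_i\) and \(d\gamma_i\) run over all batch and sequence positions; the sums in \(dx_i\) run over the \(D\) features of a single token.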
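Putting the three summary formulas together, a complete backward pass over an \((N, L, D)\) batch can be sketched in plain Python as below. The function name `layernorm_backward` is hypothetical, \(\varepsilon = 10^{-5}\) is assumed, and the sketch recomputes \(\hat{x}\) per token instead of caching it from the forward pass, a simplification a real implementation would usually avoid.

```python
import math

def layernorm_backward(x, gamma, dy, eps=1e-5):
    """x, dy: nested lists of shape (N, L, D); gamma: length D.
    Returns (dx, d_gamma, d_beta). eps is an assumed default."""
    D = len(gamma)
    d_gamma = [0.0] * D
    d_beta = [0.0] * D
    dx = []
    for x_seq, dy_seq in zip(x, dy):                  # batch index b
        dx_seq = []
        for x_tok, dy_tok in zip(x_seq, dy_seq):      # sequence index s
            # Recompute the forward quantities for this token.
            mu = sum(x_tok) / D
            var = sum((xi - mu) ** 2 for xi in x_tok) / D
            inv_std = 1.0 / math.sqrt(var + eps)
            x_hat = [(xi - mu) * inv_std for xi in x_tok]
            # Parameter gradients: accumulate over every (b, s).
            for i in range(D):
                d_beta[i] += dy_tok[i]
                d_gamma[i] += dy_tok[i] * x_hat[i]
            # Input gradient: sums run over the D features of this token only.
            dx_hat = [d * g for d, g in zip(dy_tok, gamma)]
            mean_g = sum(dx_hat) / D
            mean_gx = sum(g * xh for g, xh in zip(dx_hat, x_hat)) / D
            dx_seq.append([inv_std * (dx_hat[i] - mean_g - x_hat[i] * mean_gx)
                           for i in range(D)])
        dx.append(dx_seq)
    return dx, d_gamma, d_beta


x = [[[1.0, 2.0, 4.0], [0.0, 1.0, -1.0]],
     [[3.0, -2.0, 0.5], [1.5, 1.5, 2.0]]]
dy = [[[0.3, -1.2, 0.7], [1.0, 0.0, -1.0]],
      [[0.2, 0.2, 0.2], [-0.5, 1.0, 0.5]]]
gamma = [1.0, 2.0, 3.0]
dx, d_gamma, d_beta = layernorm_backward(x, gamma, dy)
```

As a sanity check, each token's row of `dx` should sum to zero up to round-off, because \(\sum_i \hat{x}_i = 0\).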