
RMS Normalization

Step-by-step derivation of the forward and backward passes for RMSNorm — the simpler normalization used in LLaMA, Gemma, and most modern LLMs.

Motivation

RMSNorm (Zhang & Sennrich, 2019) strips Layer Normalization down to just the re-scaling step, dropping mean subtraction entirely. The hypothesis is that the re-centering in LayerNorm is dispensable and that re-scaling alone is sufficient. The paper reports running-time reductions of 7% to 64% across the models it tested.

It is the default normalization in LLaMA, LLaMA 2/3, Gemma, Mistral, and Qwen.

Notation

| Symbol | Description |
| --- | --- |
| \(\mathbf{x}\) | Input feature vector of length \(D\) (one token's embedding) |
| \(r\) | Root mean square of \(\mathbf{x}\): \(\sqrt{\tfrac{1}{D}\sum_i x_i^2 + \varepsilon}\) |
| \(\hat{x}_i\) | Normalized input: \(x_i / r\) |
| \(\gamma\) | Learned scale parameter (no shift \(\beta\) in RMSNorm) |
| \(dy,\; dx\) | Upstream gradient \(\partial L/\partial y\) and downstream gradient \(\partial L/\partial x\) |
| \(N, L, D\) | Batch size, sequence length, feature dimension |

1. Forward Pass

  1. Root mean square \[r = \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2 + \varepsilon}\]
  2. Normalize \[\hat{x}_i = \frac{x_i}{r}\]
  3. Scale \[y_i = \gamma_i \hat{x}_i\]
There is no shift parameter \(\beta\) and no mean subtraction. That is the entire point of RMSNorm.
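
The three forward steps can be sketched in a few lines of plain Python. This is a minimal illustration for a single feature vector, not a reference implementation: the function name `rms_norm_forward`, the default `eps=1e-6`, and the toy input are assumptions made for the example.

```python
import math

def rms_norm_forward(x, gamma, eps=1e-6):
    """RMSNorm forward pass for one feature vector given as a plain list."""
    D = len(x)
    # Step 1: root mean square, with eps inside the square root as in the notation
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    # Step 2: normalize every feature by the same scalar r
    x_hat = [v / r for v in x]
    # Step 3: per-feature learned scale (no shift, no mean subtraction)
    y = [g * xh for g, xh in zip(gamma, x_hat)]
    return y, x_hat, r

# Toy input (hypothetical values); eps=0 so the RMS is exact
y, x_hat, r = rms_norm_forward([3.0, -4.0], [2.0, 0.5], eps=0.0)
mean_square_of_x_hat = sum(v * v for v in x_hat) / len(x_hat)  # equals 1 up to rounding
```

With \(\varepsilon = 0\) the normalized vector has mean square 1 by construction, which is a quick sanity check on any implementation. The function also returns \(\hat{x}\) and \(r\), since the backward pass reuses both.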

2. Backward Pass — Gradient for \(\gamma\)

Since \(y_i = \gamma_i \hat{x}_i\):

\[ \frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \gamma_i}}_{=\,\hat{x}_i} = dy_i \cdot \hat{x}_i \]

Accumulated across batch and sequence:

\[ d\gamma_i = \sum_{b=1}^{N}\sum_{s=1}^{L} dy_{b,s,i} \cdot \hat{x}_{b,s,i} \]
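
The accumulation over batch and sequence can be written out explicitly. The sketch below assumes the tensors are nested lists indexed as `[b][s][i]`; the function name `rms_norm_dgamma` and the toy values are hypothetical.

```python
def rms_norm_dgamma(dy, x_hat):
    """Sum dy * x_hat over batch (N) and sequence (L); inputs are nested lists [N][L][D]."""
    D = len(dy[0][0])
    dgamma = [0.0] * D
    for dy_b, xh_b in zip(dy, x_hat):        # batch index b
        for dy_s, xh_s in zip(dy_b, xh_b):   # sequence index s
            for i in range(D):               # feature index i
                dgamma[i] += dy_s[i] * xh_s[i]
    return dgamma

# N=2, L=1, D=2 toy values (hypothetical)
dy = [[[1.0, 2.0]], [[3.0, 4.0]]]
x_hat = [[[0.5, -1.0]], [[1.0, 0.5]]]
print(rms_norm_dgamma(dy, x_hat))  # [3.5, 0.0]
```

The result has length \(D\), matching \(\gamma\): every token in the batch contributes to the same \(D\) scale parameters.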

3. Backward Pass — Gradient for the Input (\(dx\))

Each \(x_i\) affects the output through two paths: directly as the numerator of \(\hat{x}_i\), and indirectly through the denominator \(r\) which depends on all \(x_k\).

Step A — Gradient w.r.t. normalized input \(\hat{x}_i\)
\[d\hat{x}_i = dy_i \cdot \gamma_i\]
Step B — Gradient w.r.t. RMS \(r\)

Every \(\hat{x}_j = x_j/r\) depends on \(r\), so all \(j\) contribute:

\[ \frac{\partial \hat{x}_j}{\partial r} = -\frac{x_j}{r^2} \] \[ dr = \sum_{j=1}^{D} d\hat{x}_j \cdot \left(-\frac{x_j}{r^2}\right) = -\frac{1}{r^2} \sum_{j=1}^{D} d\hat{x}_j\, x_j \]
Step C — Chain rule for \(x_i\) through both paths

Using \(\hat{x}_j = x_j / r\), the partial derivative of \(\hat{x}_j\) w.r.t. \(x_i\) is:

\[ \frac{\partial \hat{x}_j}{\partial x_i} = \frac{\delta_{ij}}{r} - \frac{x_j}{r^2} \cdot \frac{\partial r}{\partial x_i} = \frac{\delta_{ij}}{r} - \frac{x_j\, x_i}{D\, r^3} = \frac{1}{r}\!\left(\delta_{ij} - \frac{\hat{x}_j\,\hat{x}_i}{D}\right) \]

where we used \(\partial r/\partial x_i = x_i/(Dr)\). Summing over all output dimensions \(j\):

\[ dx_i = \sum_{j=1}^{D} \frac{\partial L}{\partial y_j}\cdot \gamma_j \cdot \frac{\partial \hat{x}_j}{\partial x_i} = \frac{1}{r}\sum_{j=1}^{D} d\hat{x}_j\!\left(\delta_{ij} - \frac{\hat{x}_j\,\hat{x}_i}{D}\right) \] \[ = \frac{1}{r}\!\left[d\hat{x}_i - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j\right] \]
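
As a cross-check, the same expression can be reached by the explicit two-path route, combining \(d\hat{x}_i\) from Step A, \(dr\) from Step B, and \(\partial r/\partial x_i = x_i/(Dr)\):

\[ dx_i = \underbrace{\frac{d\hat{x}_i}{r}}_{\text{direct path}} + \underbrace{dr \cdot \frac{\partial r}{\partial x_i}}_{\text{path through } r} = \frac{d\hat{x}_i}{r} - \frac{1}{r^2}\Bigl(\sum_{j=1}^{D} d\hat{x}_j\, x_j\Bigr)\frac{x_i}{D\,r} = \frac{1}{r}\!\left[d\hat{x}_i - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j\right] \]

The last step substitutes \(x_i = r\,\hat{x}_i\) and \(x_j = r\,\hat{x}_j\).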
Result
\[ dx_i = \frac{1}{r}\!\left[\gamma_i\,dy_i \;-\; \frac{\hat{x}_i}{D}\sum_{j=1}^{D}\gamma_j\,dy_j\,\hat{x}_j\right] \]

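The two bracketed terms correspond to the two paths identified at the start of this section: \(\gamma_i\,dy_i\) is the direct contribution through the numerator of \(\hat{x}_i\), and the second term is the correction arising from the shared denominator \(r\), which depends on every \(x_k\).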
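
The closed-form expression for \(dx_i\) can be checked against a numerical gradient. The sketch below is one way to do that in plain Python, under stated assumptions: the helper names (`rms_norm_backward_dx`, `numerical_dx`), the step size `h`, the default `eps`, and the toy vectors are all hypothetical, and the scalar loss \(L = \sum_i dy_i\, y_i\) is chosen only because its gradient with respect to \(y\) is exactly \(dy\).

```python
import math

def rms_norm_forward(x, gamma, eps=1e-6):
    D = len(x)
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    x_hat = [v / r for v in x]
    return [g * xh for g, xh in zip(gamma, x_hat)], x_hat, r

def rms_norm_backward_dx(dy, x_hat, r, gamma):
    """dx_i = (1/r) * [gamma_i*dy_i - (x_hat_i/D) * sum_j gamma_j*dy_j*x_hat_j]."""
    D = len(dy)
    dx_hat = [g * d for g, d in zip(gamma, dy)]           # Step A
    dot = sum(dh * xh for dh, xh in zip(dx_hat, x_hat))   # shared sum over j
    return [(dh - xh * dot / D) / r for dh, xh in zip(dx_hat, x_hat)]

def numerical_dx(x, gamma, dy, eps=1e-6, h=1e-6):
    """Central differences on the scalar L = sum_i dy_i * y_i."""
    def loss(xs):
        y, _, _ = rms_norm_forward(xs, gamma, eps)
        return sum(d * v for d, v in zip(dy, y))
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += h
        xm = list(x)
        xm[i] -= h
        grads.append((loss(xp) - loss(xm)) / (2 * h))
    return grads

# Toy values (hypothetical)
x = [0.5, -1.2, 2.0, 0.3]
gamma = [1.0, 0.5, 2.0, 1.5]
dy = [0.1, -0.4, 0.7, 0.2]

_, x_hat, r = rms_norm_forward(x, gamma)
analytic = rms_norm_backward_dx(dy, x_hat, r, gamma)
numeric = numerical_dx(x, gamma, dy)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
```

If the derivation is right, `max_err` should be limited only by finite-difference error. Note that the sum over \(j\) is computed once and reused for every \(i\), so the backward pass costs \(O(D)\) per token rather than \(O(D^2)\).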

4. Comparison with Layer Normalization

RMSNorm's simpler forward pass produces a simpler backward pass — one fewer correction term in \(dx\).

Side by side — gradient for input \(dx_i\)
LayerNorm
\[ \frac{1}{\sqrt{\sigma^2+\varepsilon}} \Bigl[ \gamma_i\,dy_i \;-\; \underbrace{\tfrac{1}{D}\textstyle\sum_j \gamma_j dy_j}_{\text{mean correction}} \;-\; \tfrac{\hat{x}_i}{D}\textstyle\sum_j \gamma_j dy_j\hat{x}_j \Bigr] \]
RMSNorm
\[ \frac{1}{r} \Bigl[ \gamma_i\,dy_i \;-\; \tfrac{\hat{x}_i}{D}\textstyle\sum_j \gamma_j dy_j\hat{x}_j \Bigr] \]
| Property | LayerNorm | RMSNorm |
| --- | --- | --- |
| Mean subtraction in forward | Yes | No |
| Learnable shift \(\beta\) | Yes | No |
| Correction terms in \(dx\) | 2 (mean + variance) | 1 (RMS only) |
| Normalizing statistic | \(\sqrt{\sigma^2 + \varepsilon}\) | \(\sqrt{\tfrac{1}{D}\sum x_i^2 + \varepsilon}\) |
| Used in | BERT, GPT-2 | T5, LLaMA, Gemma, Mistral, Qwen |
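
One way to see what mean subtraction contributes is to compare the two normalized vectors directly. When the input already has zero mean, \(\sigma^2 = \tfrac{1}{D}\sum_i x_i^2\), so the two normalizations coincide; once the input is shifted, they differ. The sketch below illustrates this with hypothetical helper names and toy values, and omits \(\gamma\) and \(\beta\) to isolate the normalization step.

```python
import math

def layer_norm_hat(x, eps=0.0):
    """LayerNorm normalization step: subtract the mean, divide by the std."""
    D = len(x)
    mu = sum(x) / D
    var = sum((v - mu) ** 2 for v in x) / D
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms_norm_hat(x, eps=0.0):
    """RMSNorm normalization step: divide by the root mean square only."""
    D = len(x)
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    return [v / r for v in x]

zero_mean = [1.0, -2.0, 0.5, 0.5]        # sums to 0 (toy values)
shifted = [v + 3.0 for v in zero_mean]   # same shape, mean 3

# Largest elementwise gap between the two normalizations
gap_zero_mean = max(abs(a - b) for a, b in zip(layer_norm_hat(zero_mean), rms_norm_hat(zero_mean)))
gap_shifted = max(abs(a - b) for a, b in zip(layer_norm_hat(shifted), rms_norm_hat(shifted)))
```

Here `gap_zero_mean` is zero up to rounding, while `gap_shifted` is large: LayerNorm is invariant to the shift and RMSNorm is not.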

5. Summary

\(r\) (forward)

\(\sqrt{\tfrac{1}{D}\sum_i x_i^2 + \varepsilon}\)

\(d\gamma_i\)

\(\sum_{b,s} dy_{b,s,i}\cdot\hat{x}_{b,s,i}\)

\(dx_i\)

\(\tfrac{1}{r}\bigl[\gamma_i dy_i - \tfrac{\hat{x}_i}{D}\overline{\gamma \cdot dy \cdot \hat{x}}\bigr]\)

where \(\overline{(\cdot)}\) denotes the sum across the \(D\) feature dimension (not divided by \(D\); the division already appears in the \(\tfrac{\hat{x}_i}{D}\) factor).