Step-by-step derivation of the forward and backward passes for RMSNorm — the simpler normalization used in LLaMA, Gemma, and most modern LLMs.
RMSNorm (Zhang & Sennrich, 2019) strips Layer Normalization down to its re-scaling step and drops mean subtraction entirely. The hypothesis is that LayerNorm's re-centering is dispensable and re-scaling alone is sufficient. In the paper's experiments, RMSNorm reached performance comparable to LayerNorm while reducing running time by 7–64%, depending on the model.
It is the default normalization in LLaMA, LLaMA 2/3, Gemma, Mistral, and Qwen.
| Symbol | Description |
|---|---|
| \(\mathbf{x}\) | Input feature vector of length \(D\) (one token's embedding) |
| \(r\) | Root mean square of \(\mathbf{x}\): \(\sqrt{\tfrac{1}{D}\sum_i x_i^2 + \varepsilon}\) |
| \(\hat{x}_i\) | Normalized input: \(x_i / r\) |
| \(\gamma\) | Learned scale parameter (no shift \(\beta\) in RMSNorm) |
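| \(dy,\; dx\) | Upstream gradient \(\partial L/\partial y\) and downstream gradient \(\partial L/\partial x\) |
| \(d\hat{x}\) | Gradient with respect to the normalized input: \(d\hat{x}_i = \partial L/\partial \hat{x}_i = \gamma_i\, dy_i\) |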
| \(N, L, D\) | Batch size, sequence length, feature dimension |
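Together with \(y_i = \gamma_i \hat{x}_i\), these definitions give the whole forward pass. The following is a minimal pure-Python sketch for a single token vector. The function name and the default \(\varepsilon = 10^{-6}\) are assumptions (models choose their own \(\varepsilon\)). For an \(N \times L \times D\) tensor, the same computation is applied independently to each of the \(N \cdot L\) token vectors.

```python
import math

def rmsnorm_forward(x, gamma, eps=1e-6):
    """RMSNorm forward pass for one token vector (hypothetical helper).

    x, gamma: lists of floats of length D.
    eps: small constant inside the square root (default is an assumption).
    Returns (y, x_hat, r) so a backward pass can reuse x_hat and r.
    """
    D = len(x)
    # r = sqrt(mean(x^2) + eps): no mean subtraction, unlike LayerNorm
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    # normalized input x_hat_i = x_i / r
    x_hat = [v / r for v in x]
    # per-feature learned scale; RMSNorm has no shift beta
    y = [g * xh for g, xh in zip(gamma, x_hat)]
    return y, x_hat, r

y, x_hat, r = rmsnorm_forward([3.0, 4.0], [1.0, 1.0], eps=0.0)
# here r = sqrt((9 + 16) / 2) = sqrt(12.5), and mean(x_hat^2) = 1
```

The function returns \(\hat{x}\) and \(r\) along with \(y\) because the backward pass needs both, so they can be cached rather than recomputed.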
Since \(y_i = \gamma_i \hat{x}_i\):
\[ \frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \gamma_i}}_{=\,\hat{x}_i} = dy_i \cdot \hat{x}_i \]
The same \(\gamma\) is applied at every position, so the per-token contributions are summed over the batch and sequence:
\[ d\gamma_i = \sum_{b=1}^{N}\sum_{s=1}^{L} dy_{b,s,i} \cdot \hat{x}_{b,s,i} \]
Each \(x_i\) affects the output through two paths: directly, as the numerator of \(\hat{x}_i\), and indirectly, through the denominator \(r\), which depends on every \(x_k\).
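Written as a total derivative, the two paths combine as shown below. This is a worked sketch in the same notation, with \(d\hat{x}_i = \gamma_i\,dy_i\) and \(dr = \partial L/\partial r\):

\[ dx_i = \underbrace{d\hat{x}_i \cdot \frac{1}{r}}_{\text{direct: numerator of } \hat{x}_i} + \underbrace{dr \cdot \frac{\partial r}{\partial x_i}}_{\text{indirect: through } r}, \qquad \frac{\partial r}{\partial x_i} = \frac{x_i}{D\,r} \]

Once \(dr = -\tfrac{1}{r^2}\sum_{j} d\hat{x}_j\,x_j\) is known (it is computed next), substituting it into the second term and using \(x_k / r = \hat{x}_k\) gives

\[ dx_i = \frac{d\hat{x}_i}{r} - \frac{x_i}{D\,r^3}\sum_{j=1}^{D} d\hat{x}_j\,x_j = \frac{1}{r}\left[d\hat{x}_i - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j\right] \]

This agrees term by term with the result of the Jacobian calculation.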
Every \(\hat{x}_j = x_j/r\) depends on \(r\), so all \(j\) contribute to \(dr = \partial L/\partial r\):
\[ \frac{\partial \hat{x}_j}{\partial r} = -\frac{x_j}{r^2} \]
\[ dr = \sum_{j=1}^{D} d\hat{x}_j \cdot \left(-\frac{x_j}{r^2}\right) = -\frac{1}{r^2} \sum_{j=1}^{D} d\hat{x}_j\, x_j \]
Using \(\hat{x}_j = x_j / r\), the partial derivative of \(\hat{x}_j\) with respect to \(x_i\) is:
\[ \frac{\partial \hat{x}_j}{\partial x_i} = \frac{\delta_{ij}}{r} - \frac{x_j}{r^2} \cdot \frac{\partial r}{\partial x_i} = \frac{\delta_{ij}}{r} - \frac{x_j\, x_i}{D\, r^3} = \frac{1}{r}\!\left(\delta_{ij} - \frac{\hat{x}_j\,\hat{x}_i}{D}\right) \]
where we used \(\partial r/\partial x_i = x_i/(Dr)\), which follows from differentiating \(r^2 = \tfrac{1}{D}\sum_k x_k^2 + \varepsilon\). Summing over all output dimensions \(j\):
\[ dx_i = \sum_{j=1}^{D} \frac{\partial L}{\partial y_j}\cdot \gamma_j \cdot \frac{\partial \hat{x}_j}{\partial x_i} = \frac{1}{r}\sum_{j=1}^{D} d\hat{x}_j\!\left(\delta_{ij} - \frac{\hat{x}_j\,\hat{x}_i}{D}\right) \] \[ = \frac{1}{r}\!\left[d\hat{x}_i - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j\right] \]
The two bracketed terms match the two paths. \(d\hat{x}_i\) is the direct path through the numerator of \(\hat{x}_i\), and \(-\tfrac{\hat{x}_i}{D}\sum_j d\hat{x}_j\,\hat{x}_j\) is the correction through the shared denominator \(r\). When \(\varepsilon = 0\), \(\sum_j \hat{x}_j^2 = D\), so the bracket is exactly \(d\hat{x}\) with its component along \(\hat{x}\) projected out.
RMSNorm's simpler forward pass produces a simpler backward pass — one fewer correction term in \(dx\).
| Property | LayerNorm | RMSNorm |
|---|---|---|
| Mean subtraction in forward | Yes | No |
| Learnable shift \(\beta\) | Yes | No |
| Correction terms in \(dx\) | 2 (mean + variance) | 1 (RMS only) |
| Normalizing statistic | \(\sqrt{\sigma^2 + \varepsilon}\) | \(\sqrt{\tfrac{1}{D}\sum x_i^2 + \varepsilon}\) |
| Used in | BERT, GPT-2 | T5, LLaMA, Gemma, Mistral, Qwen |
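The "Normalizing statistic" row implies that the two layers agree whenever the input already has zero mean, because then \(\sigma^2 = \tfrac{1}{D}\sum_i x_i^2\). The following minimal sketch checks this. It assumes pure Python, omits the learned \(\gamma\) and \(\beta\), and sets \(\varepsilon = 0\). Real implementations keep \(\varepsilon > 0\), for example to avoid dividing by zero on an all-zero input. The function names are hypothetical.

```python
import math

def rms_normalize(x, eps=0.0):
    """x / sqrt(mean(x^2) + eps): RMSNorm without gamma."""
    D = len(x)
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    return [v / r for v in x]

def layer_normalize(x, eps=0.0):
    """(x - mean) / sqrt(var + eps): LayerNorm without gamma and beta."""
    D = len(x)
    mu = sum(x) / D
    var = sum((v - mu) ** 2 for v in x) / D  # population (biased) variance
    return [(v - mu) / math.sqrt(var + eps) for v in x]

zero_mean = [1.0, -2.0, 3.0, -2.0]
shifted = [v + 5.0 for v in zero_mean]

# Same output when the mean is already zero...
print(all(math.isclose(a, b) for a, b in zip(rms_normalize(zero_mean), layer_normalize(zero_mean))))  # → True
# ...but not once the input has a nonzero mean.
print(all(math.isclose(a, b) for a, b in zip(rms_normalize(shifted), layer_normalize(shifted))))  # → False
```

LayerNorm gives the same output for `shifted` and `zero_mean` because it subtracts the mean. RMSNorm does not, which is exactly the re-centering step it drops.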
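Summary of the RMSNorm formulas:

| Quantity | Formula |
|---|---|
| \(r\) | \(\sqrt{\tfrac{1}{D}\sum_i x_i^2 + \varepsilon}\) |
| \(d\gamma_i\) | \(\sum_{b,s} dy_{b,s,i}\cdot\hat{x}_{b,s,i}\) |
| \(dx_i\) | \(\tfrac{1}{r}\bigl[\gamma_i dy_i - \tfrac{\hat{x}_i}{D}\overline{\gamma \cdot dy \cdot \hat{x}}\bigr]\) |

Here \(\overline{(\cdot)}\) denotes the sum over the \(D\) feature dimensions, \(\overline{\gamma \cdot dy \cdot \hat{x}} = \sum_{j=1}^{D}\gamma_j\,dy_j\,\hat{x}_j\). It is not divided by \(D\), because the \(1/D\) factor already appears in the prefactor \(\hat{x}_i/D\).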
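The summary formulas translate directly into code. The following is a minimal pure-Python sketch that processes one token vector per call, so the indices match the derivation. The function names, the \(\varepsilon = 10^{-6}\) default, and the linear test loss \(L = \sum_{t,i} c_{t,i}\, y_{t,i}\) (chosen so that \(dy = c\) exactly) are all assumptions for illustration. The sketch compares the analytic \(dx\) and \(d\gamma\) against central finite differences on a small random batch of tokens.

```python
import math
import random

def rmsnorm_forward(x, gamma, eps=1e-6):
    """Forward pass for one token vector of length D; returns (y, x_hat, r)."""
    D = len(x)
    r = math.sqrt(sum(v * v for v in x) / D + eps)
    x_hat = [v / r for v in x]
    y = [g * xh for g, xh in zip(gamma, x_hat)]
    return y, x_hat, r

def rmsnorm_backward(dy, x_hat, r, gamma):
    """Backward pass for one token, following the summary table.

    Returns (dx, dgamma_token). dgamma_token is this token's contribution
    dy_i * x_hat_i; the caller sums it over all batch/sequence positions.
    """
    D = len(x_hat)
    dx_hat = [g * d for g, d in zip(gamma, dy)]    # gamma_i * dy_i
    s = sum(a * b for a, b in zip(dx_hat, x_hat))  # plain sum over D, not a mean
    dx = [(a - b * s / D) / r for a, b in zip(dx_hat, x_hat)]
    dgamma_token = [d * b for d, b in zip(dy, x_hat)]
    return dx, dgamma_token

def analytic_grads(xs, gamma, dys, eps=1e-6):
    """dx for each token and dgamma summed over a flat list of N*L tokens."""
    dgamma = [0.0] * len(gamma)
    dxs = []
    for x, dy in zip(xs, dys):
        _, x_hat, r = rmsnorm_forward(x, gamma, eps)
        dx, dg = rmsnorm_backward(dy, x_hat, r, gamma)
        dxs.append(dx)
        dgamma = [a + b for a, b in zip(dgamma, dg)]
    return dxs, dgamma

def linear_loss(xs, gamma, cs, eps=1e-6):
    """L = sum_t sum_i c[t][i] * y[t][i], so dL/dy = c exactly."""
    total = 0.0
    for x, c in zip(xs, cs):
        y, _, _ = rmsnorm_forward(x, gamma, eps)
        total += sum(ci * yi for ci, yi in zip(c, y))
    return total

def max_grad_error(num_tokens=4, D=6, h=1e-5, seed=0):
    """Largest |analytic - central difference| over every entry of dx and dgamma."""
    rng = random.Random(seed)
    xs = [[rng.uniform(-2, 2) for _ in range(D)] for _ in range(num_tokens)]
    cs = [[rng.uniform(-1, 1) for _ in range(D)] for _ in range(num_tokens)]
    gamma = [rng.uniform(0.5, 1.5) for _ in range(D)]
    dxs, dgamma = analytic_grads(xs, gamma, cs)

    worst = 0.0
    for t in range(num_tokens):  # perturb each input entry
        for i in range(D):
            xp = [row[:] for row in xs]
            xm = [row[:] for row in xs]
            xp[t][i] += h
            xm[t][i] -= h
            num = (linear_loss(xp, gamma, cs) - linear_loss(xm, gamma, cs)) / (2 * h)
            worst = max(worst, abs(num - dxs[t][i]))
    for i in range(D):  # perturb each gamma entry
        gp = gamma[:]
        gm = gamma[:]
        gp[i] += h
        gm[i] -= h
        num = (linear_loss(xs, gp, cs) - linear_loss(xs, gm, cs)) / (2 * h)
        worst = max(worst, abs(num - dgamma[i]))
    return worst
```

Calling `max_grad_error()` runs the check on four random tokens of width six, and the result is expected to be well under `1e-6`. The per-token loops keep the index bookkeeping visible. A production implementation would instead vectorize over the \(N \cdot L\) tokens and reduce \(d\gamma\) with a single sum over those axes.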