Step-by-step derivation of the forward and backward passes, including the full gradient for the input \(\mathbf{x}\).
| Symbol | Description |
|---|---|
| \(\mathbf{x}\) | Input feature vector of length \(D\) (one token's embedding) |
| \(\mu,\; \sigma^2\) | Mean and variance computed across the \(D\) features |
| \(\gamma,\; \beta\) | Learned scale and shift parameters |
| \(\hat{x}_i\) | Normalized input: \((x_i - \mu)/\sqrt{\sigma^2+\varepsilon}\) |
| \(dy,\; dx\) | Upstream gradient \(\partial L/\partial y\) and downstream gradient \(\partial L/\partial x\) |
| \(N, L, D\) | Batch size, sequence length, feature dimension |
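
As a reference point for the notation in the table, the forward pass for a single token can be sketched as follows. This is a minimal illustration in plain Python (no array library, to keep it dependency-free); the function name `layernorm_forward`, the default `eps=1e-5`, and the sample values are assumptions made for the example, not taken from the text. The variance is the biased one (divided by \(D\)), which matches the \(\tfrac{-2}{D}\) factor that appears later in the derivation.

```python
import math

def layernorm_forward(x, gamma, beta, eps=1e-5):
    """Layer-normalize one feature vector x (a list of D floats)."""
    D = len(x)
    mu = sum(x) / D                              # mean over the D features
    var = sum((xi - mu) ** 2 for xi in x) / D    # biased variance over the D features
    inv_std = 1.0 / math.sqrt(var + eps)         # 1 / sqrt(sigma^2 + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]    # normalized input
    y = [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]  # scale and shift
    return y, x_hat, inv_std

# Hypothetical sample input with D = 4, identity scale and zero shift.
x = [1.0, 2.0, 3.0, 4.0]
y, x_hat, inv_std = layernorm_forward(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print(x_hat)
```

With identity scale and zero shift, `y` equals `x_hat`, whose entries have zero mean and (up to the effect of `eps`) unit variance. Returning `x_hat` and `inv_std` alongside `y` is a design choice: these are the quantities the backward pass reuses.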
Since \(y_i = \gamma_i \hat{x}_i + \beta_i\), the gradients with respect to \(\beta\) and \(\gamma\) follow immediately from the chain rule (shown here for a single token):
\[ \frac{\partial L}{\partial \beta_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \beta_i}}_{=\,1} = dy_i \]
\[ \frac{\partial L}{\partial \gamma_i} = \frac{\partial L}{\partial y_i} \cdot \underbrace{\frac{\partial y_i}{\partial \gamma_i}}_{=\,\hat{x}_i} = dy_i \cdot \hat{x}_i \]
The gradient with respect to the input \(\mathbf{x}\) is the trickier part. The input \(x_i\) affects the output through three paths: directly via \(\hat{x}_i\), and indirectly via \(\mu\) and \(\sigma^2\).
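
The coupling through \(\mu\) and \(\sigma^2\) can be seen numerically: perturbing a single input feature shifts every normalized value, not just its own. The snippet below is a small illustration under assumed sample values; the helper `normalize`, the step size `0.1`, and `eps=1e-5` are hypothetical choices for the example.

```python
import math

def normalize(x, eps=1e-5):
    """Return the normalized vector x_hat for one token."""
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]

x = [1.0, 2.0, 3.0, 4.0]
x_bumped = [x[0] + 0.1] + x[1:]   # perturb only the first feature

before = normalize(x)
after = normalize(x_bumped)

# Change in every normalized feature caused by bumping x[0] alone.
changes = [a - b for a, b in zip(after, before)]
print(changes)
```

Every entry of `changes` is nonzero, including those for features that were not perturbed. This is why \(dx_i\) has to collect contributions from all \(\hat{x}_j\) rather than from \(\hat{x}_i\) alone.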
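Write \(d\hat{x}_j = \partial L/\partial \hat{x}_j = dy_j\,\gamma_j\) for the gradient with respect to the normalized input (Step A). Each \(\hat{x}_j\) depends on \(\sigma^2\), so we sum contributions from all \(j\):
\[ d\sigma^2 = \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{\partial \hat{x}_j}{\partial \sigma^2} = \sum_{j=1}^{D} d\hat{x}_j \cdot (x_j - \mu) \cdot \left(-\tfrac{1}{2}(\sigma^2+\varepsilon)^{-3/2}\right) \]
The mean \(\mu\) influences the loss through two paths. Path 1, the direct effect on each \(\hat{x}_j\):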
\[ \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{\partial \hat{x}_j}{\partial \mu} = \sum_{j=1}^{D} d\hat{x}_j \cdot \frac{-1}{\sqrt{\sigma^2+\varepsilon}} \]
Path 2, the indirect effect through \(\sigma^2\):
\[ d\sigma^2 \cdot \frac{\partial \sigma^2}{\partial \mu} = d\sigma^2 \cdot \frac{-2}{D}\sum_{j=1}^{D}(x_j - \mu) = 0 \]
Path 2 vanishes because the deviations \(x_j - \mu\) sum to zero. Combined:
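\[ d\mu = \frac{-1}{\sqrt{\sigma^2+\varepsilon}} \sum_{j=1}^{D} d\hat{x}_j \]
Applying the chain rule through \(\hat{x}_i\), \(\sigma^2\), and \(\mu\), and substituting Steps A–C:
\[ dx_i = \frac{d\hat{x}_i}{\sqrt{\sigma^2+\varepsilon}} + d\sigma^2 \cdot \frac{2(x_i - \mu)}{D} + \frac{d\mu}{D} = \frac{1}{\sqrt{\sigma^2+\varepsilon}} \left[ d\hat{x}_i - \frac{1}{D}\sum_{j=1}^{D} d\hat{x}_j - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} d\hat{x}_j\,\hat{x}_j \right] \]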
Substituting \(d\hat{x}_j = dy_j\,\gamma_j\) (Step A) gives the fully expanded form:
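\[ dx_i = \frac{1}{\sqrt{\sigma^2+\varepsilon}} \left[ dy_i\,\gamma_i - \frac{1}{D}\sum_{j=1}^{D} dy_j\,\gamma_j - \frac{\hat{x}_i}{D}\sum_{j=1}^{D} dy_j\,\gamma_j\,\hat{x}_j \right] \]
The three bracketed terms each play a distinct role: the first is the upstream gradient scaled by \(\gamma\), the second subtracts the mean of that scaled gradient over the \(D\) features (the contribution of \(\mu\)), and the third subtracts its component along \(\hat{\mathbf{x}}\) (the contribution of \(\sigma^2\)).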
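| Gradient | Formula |
|---|---|
| \(d\beta_i\) | \(\sum_{b,s} dy_{b,s,i}\) |
| \(d\gamma_i\) | \(\sum_{b,s} dy_{b,s,i}\cdot\hat{x}_{b,s,i}\) |
| \(dx_i\) | \(\tfrac{1}{\sqrt{\sigma^2+\varepsilon}}\!\left[d\hat{x}_i - \tfrac{1}{D}\!\sum_j d\hat{x}_j - \tfrac{\hat{x}_i}{D}\!\sum_j d\hat{x}_j\hat{x}_j\right]\) |
where \(d\hat{x}_j = dy_j\,\gamma_j\). The sums over \(j\) run over the \(D\) feature dimension; the sums over \(b, s\) run over the \(N\) batch entries and \(L\) sequence positions.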
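
The closed-form input gradient can be checked against central finite differences. The sketch below handles a single token in plain Python; the function names, the sample values, `eps=1e-5`, and the linear test loss \(L = \sum_i w_i y_i\) (chosen so that \(dy = w\)) are all assumptions made for the check, not part of the derivation.

```python
import math

def forward(x, gamma, beta, eps=1e-5):
    D = len(x)
    mu = sum(x) / D
    var = sum((xi - mu) ** 2 for xi in x) / D
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    y = [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]
    return y, x_hat, inv_std

def backward(dy, x_hat, inv_std, gamma):
    """Single-token gradients using the closed-form expression for dx."""
    D = len(dy)
    dx_hat = [dyi * g for dyi, g in zip(dy, gamma)]                 # d x_hat = dy * gamma
    mean_dxh = sum(dx_hat) / D                                      # (1/D) sum_j dx_hat_j
    mean_dxh_xh = sum(d * xh for d, xh in zip(dx_hat, x_hat)) / D   # (1/D) sum_j dx_hat_j x_hat_j
    dx = [inv_std * (d - mean_dxh - xh * mean_dxh_xh)
          for d, xh in zip(dx_hat, x_hat)]
    dgamma = [dyi * xh for dyi, xh in zip(dy, x_hat)]               # per-token d gamma
    dbeta = list(dy)                                                # per-token d beta
    return dx, dgamma, dbeta

def loss(x, gamma, beta, w):
    # Hypothetical scalar loss L = sum_i w_i * y_i, so that dL/dy = w.
    y, _, _ = forward(x, gamma, beta)
    return sum(wi * yi for wi, yi in zip(w, y))

# Hypothetical sample values with D = 4.
x = [0.5, -1.2, 3.0, 0.7]
gamma = [1.5, 0.8, -0.3, 2.0]
beta = [0.1, 0.0, -0.2, 0.4]
w = [0.3, -0.7, 1.1, 0.2]

_, x_hat, inv_std = forward(x, gamma, beta)
dx, dgamma, dbeta = backward(w, x_hat, inv_std, gamma)

# Central finite differences with respect to each x_i.
h = 1e-6
dx_numeric = []
for i in range(len(x)):
    xp = list(x); xp[i] += h
    xm = list(x); xm[i] -= h
    dx_numeric.append((loss(xp, gamma, beta, w) - loss(xm, gamma, beta, w)) / (2 * h))

max_err = max(abs(a - n) for a, n in zip(dx, dx_numeric))
print(max_err)
```

The analytic and numeric gradients should agree to within finite-difference error. For a batch, `dgamma` and `dbeta` from each token would be summed over all \(b, s\) as in the table, while `dx` stays per token.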