Step-by-step derivation of the forward and backward passes of batch normalization, which normalizes each feature across the batch dimension and then applies a learned per-feature scale and shift.
| Symbol | Description |
|---|---|
| \(x_{i,j}\) | Input for sample \(i \in \{1,\ldots,N\}\), feature \(j \in \{1,\ldots,D\}\) |
| \(\mu_j,\; \sigma^2_j\) | Mean and variance of feature \(j\) computed across the \(N\) batch samples |
| \(\hat{x}_{i,j}\) | Normalized value: \((x_{i,j} - \mu_j)/\sqrt{\sigma^2_j + \varepsilon}\) |
| \(\gamma_j,\; \beta_j\) | Learned per-feature scale and shift; shared across all batch samples |
| \(y_{i,j}\) | Output: \(\gamma_j \hat{x}_{i,j} + \beta_j\) |
| \(dy_{i,j},\; dx_{i,j}\) | Upstream gradient \(\partial L/\partial y_{i,j}\) and downstream gradient \(\partial L/\partial x_{i,j}\) |
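For each feature \(j\), compute statistics across the \(N\) samples, then normalize and rescale:
\[ \mu_j = \frac{1}{N}\sum_{i=1}^{N} x_{i,j} \qquad \sigma^2_j = \frac{1}{N}\sum_{i=1}^{N} (x_{i,j}-\mu_j)^2 \]
\[ \hat{x}_{i,j} = \frac{x_{i,j}-\mu_j}{\sqrt{\sigma^2_j+\varepsilon}} \qquad y_{i,j} = \gamma_j\,\hat{x}_{i,j} + \beta_j \]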
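As a minimal sketch of these steps, assuming the batch is stored as a NumPy array of shape \((N, D)\) with samples along axis 0, the forward pass might look as follows. The function name `batchnorm_forward`, the contents of the returned cache, and the default `eps=1e-5` are illustrative choices, not part of the derivation.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Batch-norm forward pass for x of shape (N, D).

    Statistics are taken over axis 0 (the batch), one value per feature j.
    Returns the output y and a cache of intermediates for a backward pass.
    """
    mu = x.mean(axis=0)                   # mu_j, shape (D,)
    var = ((x - mu) ** 2).mean(axis=0)    # sigma^2_j, biased (divides by N)
    inv_std = 1.0 / np.sqrt(var + eps)    # 1 / sqrt(sigma^2_j + eps)
    xhat = (x - mu) * inv_std             # normalized values, shape (N, D)
    y = gamma * xhat + beta               # per-feature scale and shift
    cache = (xhat, inv_std, gamma)
    return y, cache

# Usage: a batch of N = 4 samples with D = 3 features.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(4, 3))
y, (xhat, inv_std, _) = batchnorm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
print(xhat.mean(axis=0))  # close to 0 for every feature
print(xhat.var(axis=0))   # close to 1 (slightly below, because of eps)
```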
Since \(y_{i,j} = \gamma_j \hat{x}_{i,j} + \beta_j\), the chain rule gives a contribution from every sample \(i\) in the batch:
\[ d\beta_j = \sum_{i=1}^{N} \frac{\partial L}{\partial y_{i,j}} \cdot \underbrace{\frac{\partial y_{i,j}}{\partial \beta_j}}_{=\,1} = \sum_{i=1}^{N} dy_{i,j} \]
\[ d\gamma_j = \sum_{i=1}^{N} \frac{\partial L}{\partial y_{i,j}} \cdot \underbrace{\frac{\partial y_{i,j}}{\partial \gamma_j}}_{=\,\hat{x}_{i,j}} = \sum_{i=1}^{N} dy_{i,j}\cdot\hat{x}_{i,j} \]
Each input \(x_{i,j}\) reaches the loss through three paths: directly via \(\hat{x}_{i,j}\), and indirectly via the batch mean \(\mu_j\) and batch variance \(\sigma^2_j\).
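Because both parameter gradients sum over the batch index \(i\), \(d\beta_j\) and \(d\gamma_j\) reduce to column sums over axis 0 of an \((N, D)\) array. A minimal NumPy sketch (the function name is hypothetical), with a two-sample example small enough to check by hand:

```python
import numpy as np

def batchnorm_param_grads(dy, xhat):
    """Return (dbeta, dgamma), each of shape (D,), for dy and xhat of shape (N, D)."""
    # gamma_j and beta_j are shared by all N samples, so the per-sample
    # contributions are summed over the batch axis (axis 0).
    dbeta = dy.sum(axis=0)            # sum_i dy_{i,j}
    dgamma = (dy * xhat).sum(axis=0)  # sum_i dy_{i,j} * xhat_{i,j}
    return dbeta, dgamma

# Tiny hand-checkable example: N = 2 samples, D = 2 features.
dy = np.array([[1.0, 2.0],
               [3.0, 4.0]])
xhat = np.array([[-1.0, 1.0],
                 [1.0, -1.0]])
dbeta, dgamma = batchnorm_param_grads(dy, xhat)
# dbeta  = [1 + 3, 2 + 4]               = [4, 6]
# dgamma = [1*(-1) + 3*1, 2*1 + 4*(-1)] = [2, -2]
```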
Since \(y_{i,j} = \gamma_j \hat{x}_{i,j} + \beta_j\):
\[ d\hat{x}_{i,j} = dy_{i,j} \cdot \gamma_j \]
Every \(\hat{x}_{i,j}\) depends on \(\sigma^2_j\), so contributions from all \(N\) samples accumulate:
\[ d\sigma^2_j = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{\partial \hat{x}_{i,j}}{\partial \sigma^2_j} = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot (x_{i,j} - \mu_j) \cdot \left(-\tfrac{1}{2}(\sigma^2_j+\varepsilon)^{-3/2}\right) \]
Substituting \(x_{i,j} - \mu_j = \hat{x}_{i,j}\sqrt{\sigma^2_j+\varepsilon}\) simplifies the expression:
\[ d\sigma^2_j = -\frac{1}{2(\sigma^2_j+\varepsilon)} \sum_{i=1}^{N} d\hat{x}_{i,j}\,\hat{x}_{i,j} \]
The mean \(\mu_j\) reaches the loss through two paths: directly through every \(\hat{x}_{i,j}\), and indirectly through \(\sigma^2_j\), which is itself computed from \(\mu_j\). Path 1, the direct effect on each \(\hat{x}_{i,j}\):
\[ \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{\partial \hat{x}_{i,j}}{\partial \mu_j} = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{-1}{\sqrt{\sigma^2_j+\varepsilon}} \]
Path 2, the indirect effect through \(\sigma^2_j\), vanishes because the deviations from the mean sum to zero:
\[ d\sigma^2_j \cdot \frac{\partial \sigma^2_j}{\partial \mu_j} = d\sigma^2_j \cdot \frac{-2}{N}\sum_{i=1}^{N}(x_{i,j} - \mu_j) = 0 \]
Adding the two paths therefore leaves only Path 1:
\[ d\mu_j = \frac{-1}{\sqrt{\sigma^2_j+\varepsilon}} \sum_{i=1}^{N} d\hat{x}_{i,j} \]
Three paths contribute to \(dx_{i,j}\): (1) through \(\hat{x}_{i,j}\) with \(\mu_j\) and \(\sigma^2_j\) held fixed, (2) through \(\sigma^2_j\) with \(\partial\sigma^2_j/\partial x_{i,j} = 2(x_{i,j}-\mu_j)/N\), and (3) through \(\mu_j\) with \(\partial\mu_j/\partial x_{i,j} = 1/N\):
\[ dx_{i,j} = \frac{d\hat{x}_{i,j}}{\sqrt{\sigma^2_j+\varepsilon}} + d\sigma^2_j \cdot \frac{2(x_{i,j}-\mu_j)}{N} + \frac{d\mu_j}{N} \]
Substituting the expressions for \(d\sigma^2_j\) and \(d\mu_j\) derived above, and using \(x_{i,j} - \mu_j = \hat{x}_{i,j}\sqrt{\sigma^2_j+\varepsilon}\):
\[ dx_{i,j} = \frac{1}{\sqrt{\sigma^2_j+\varepsilon}} \left[ d\hat{x}_{i,j} - \frac{1}{N}\sum_{k=1}^{N} d\hat{x}_{k,j} - \frac{\hat{x}_{i,j}}{N}\sum_{k=1}^{N} d\hat{x}_{k,j}\,\hat{x}_{k,j} \right] \]
Since \(d\hat{x}_{k,j} = dy_{k,j}\cdot\gamma_j\) and \(\gamma_j\) does not vary with the batch index \(k\), it factors out of both sums:
\[ dx_{i,j} = \frac{\gamma_j}{\sqrt{\sigma^2_j+\varepsilon}} \left[ dy_{i,j} - \frac{1}{N}\sum_{k=1}^{N} dy_{k,j} - \frac{\hat{x}_{i,j}}{N}\sum_{k=1}^{N} dy_{k,j}\,\hat{x}_{k,j} \right] \]
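The three bracketed terms each play a distinct role. The first, \(dy_{i,j}\), is the gradient that would flow back (up to the common factor \(\gamma_j/\sqrt{\sigma^2_j+\varepsilon}\)) if \(\mu_j\) and \(\sigma^2_j\) were constants. The second subtracts the batch mean of the upstream gradient; its sum \(\sum_k dy_{k,j}\) is exactly \(d\beta_j\). This term guarantees \(\sum_i dx_{i,j} = 0\): adding the same constant to every input of feature \(j\) leaves every \(\hat{x}_{i,j}\) unchanged, so the loss has zero gradient along that direction. The third removes the component of the gradient that lies along \(\hat{x}_{\cdot,j}\); its sum \(\sum_k dy_{k,j}\,\hat{x}_{k,j}\) is exactly \(d\gamma_j\). This term makes \(\sum_i dx_{i,j}\,\hat{x}_{i,j}\) vanish when \(\varepsilon = 0\), reflecting that rescaling all inputs of feature \(j\) by a positive constant then leaves \(\hat{x}_{i,j}\) unchanged.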
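The staged gradients and the compact formula can be compared numerically. The NumPy sketch below is illustrative only: the helper names are hypothetical, and both functions recompute the forward statistics instead of reading a cache. `backward_staged` follows the chain \(d\hat{x} \to d\sigma^2 \to d\mu \to dx\) term by term, while `backward_compact` uses the bracketed form with \(\gamma_j\) factored out; the usage lines also check the zero-sum and (for \(\varepsilon = 0\)) orthogonality properties of \(dx\).

```python
import numpy as np

def _stats(x, eps):
    # Forward statistics over the batch axis, recomputed here for clarity.
    mu = x.mean(axis=0)
    var = ((x - mu) ** 2).mean(axis=0)
    std = np.sqrt(var + eps)
    return mu, var, std, (x - mu) / std

def backward_staged(dy, x, gamma, eps=1e-5):
    """dx via the intermediate gradients: dxhat -> dvar -> dmu -> three paths."""
    n = x.shape[0]
    mu, var, std, xhat = _stats(x, eps)
    dxhat = dy * gamma                                      # d xhat_{i,j}
    dvar = -0.5 / (var + eps) * (dxhat * xhat).sum(axis=0)  # d sigma^2_j
    dmu = -dxhat.sum(axis=0) / std                          # d mu_j (path 2 is zero)
    # (1) through xhat, (2) through sigma^2_j, (3) through mu_j
    return dxhat / std + dvar * 2.0 * (x - mu) / n + dmu / n

def backward_compact(dy, x, gamma, eps=1e-5):
    """dx via the final formula, with gamma factored out of both batch sums."""
    n = x.shape[0]
    _, _, std, xhat = _stats(x, eps)
    sum_dy = dy.sum(axis=0)                # equals dbeta_j
    sum_dy_xhat = (dy * xhat).sum(axis=0)  # equals dgamma_j
    return gamma / std * (dy - sum_dy / n - xhat * sum_dy_xhat / n)

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 5))
gamma = rng.normal(size=5)
dy = rng.normal(size=(8, 5))

dx_staged = backward_staged(dy, x, gamma)
dx_compact = backward_compact(dy, x, gamma)
print(np.allclose(dx_staged, dx_compact))    # True
print(np.abs(dx_compact.sum(axis=0)).max())  # ~0 up to rounding: the centering term

# With eps = 0 the third term also makes dx orthogonal to xhat, per feature.
dx0 = backward_compact(dy, x, gamma, eps=0.0)
xhat0 = (x - x.mean(axis=0)) / x.std(axis=0)
print(np.abs((dx0 * xhat0).sum(axis=0)).max())  # ~0 up to rounding
```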
The forward pass above uses batch statistics — \(\mu_j\) and \(\sigma^2_j\) computed from the current mini-batch. This couples all samples together and causes the mean and variance to fluctuate across batches.
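The coupling can be observed directly. In this hypothetical NumPy sketch (scale and shift omitted, i.e. \(\gamma_j = 1\), \(\beta_j = 0\)), only the last sample of a four-sample batch is changed, yet the normalized output of the first, untouched sample changes too:

```python
import numpy as np

def bn_train(x, eps=1e-5):
    """Normalize x of shape (N, D) with its own batch statistics (gamma=1, beta=0)."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)  # biased variance (divides by N), as in the text
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(2)
batch_a = rng.normal(size=(4, 3))
batch_b = batch_a.copy()
batch_b[3] += 10.0  # modify only the last sample

out_a = bn_train(batch_a)
out_b = bn_train(batch_b)
# Sample 0 is identical in both batches, yet its normalized output differs.
print(np.array_equal(batch_a[0], batch_b[0]))  # True
print(np.allclose(out_a[0], out_b[0]))         # False
```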
At inference time the batch may contain a single example, making batch statistics meaningless. Instead, running statistics accumulated during training are used:
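\[ \mu_j^{\text{run}} \;\leftarrow\; (1-\alpha)\,\mu_j^{\text{run}} + \alpha\,\mu_j^{\text{batch}} \qquad \sigma^{2,\text{run}}_j \;\leftarrow\; (1-\alpha)\,\sigma^{2,\text{run}}_j + \alpha\,\sigma^{2,\text{batch}}_j \]
where \(\alpha\) is the momentum hyperparameter (e.g. 0.1, the default in PyTorch; PyTorch normalizes with the biased batch variance but feeds the unbiased estimate, with denominator \(N-1\), into \(\sigma^{2,\text{run}}_j\)). At inference the normalization becomes a fixed affine transform per feature, \(y_{i,j} = \gamma_j\,(x_{i,j}-\mu_j^{\text{run}})/\sqrt{\sigma^{2,\text{run}}_j+\varepsilon} + \beta_j\), so no batch dependence remains.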
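In PyTorch, switching between batch statistics (training) and running statistics (inference) is controlled by the `model.train()` / `model.eval()` calls.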
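The running-statistics update and the train/inference switch can be sketched in plain NumPy. The class `BatchNorm1D` and its `training` flag are hypothetical stand-ins, not PyTorch code; for simplicity the sketch feeds the biased batch variance into the running estimate. The usage lines contrast a single-example batch under both modes: with batch statistics and \(N = 1\), every \(\hat{x}_{i,j}\) is 0 and the output collapses to \(\beta_j\).

```python
import numpy as np

class BatchNorm1D:
    """Hypothetical minimal batch norm with running statistics, inputs of shape (N, D)."""

    def __init__(self, d, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(d)
        self.beta = np.zeros(d)
        self.running_mean = np.zeros(d)
        self.running_var = np.ones(d)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def __call__(self, x):
        if self.training:
            mu, var = x.mean(axis=0), x.var(axis=0)  # batch statistics (biased var)
            a = self.momentum
            # Exponential moving averages, as in the update rule above.
            self.running_mean = (1 - a) * self.running_mean + a * mu
            self.running_var = (1 - a) * self.running_var + a * var
        else:
            mu, var = self.running_mean, self.running_var  # fixed statistics
        xhat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * xhat + self.beta

rng = np.random.default_rng(3)
bn = BatchNorm1D(d=2)
for _ in range(200):  # "training" on batches of 32 with mean 3, std 2
    bn(rng.normal(loc=3.0, scale=2.0, size=(32, 2)))

single = np.array([[3.0, 7.0]])  # a batch containing one example

bn.training = False
out_eval = bn(single)   # running statistics: approximately [[0, 2]]
bn.training = True
out_train = bn(single)  # batch statistics with N = 1: exactly beta, i.e. [[0, 0]]
print(out_eval, out_train)
```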
| | Batch Norm | Layer Norm |
|---|---|---|
| Normalize over | Batch dimension \(N\) (per feature \(j\)) | Feature dimension \(D\) (per token/sample) |
| Statistics depend on | All samples in the mini-batch | Only the current sample |
| Train vs. inference | Different (running stats at inference) | Same |
| Small batch size | Unstable (noisy statistics) | Unaffected |
| \(\gamma\) in \(dx\) formula | Factors out of sums (sum is over batch) | Must stay inside sums (sum is over features) |
| Typical use | CNNs, vision models | Transformers, LLMs |
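The \(\gamma\) row of the table can be verified numerically. The NumPy sketch below (all function names hypothetical) applies the same three-path derivation to layer norm, where the sums run over the feature axis, and compares it with central finite differences of the scalar loss \(L = \sum y \cdot dy\). A second version that pulls \(\gamma\) out of the sums, as is valid for batch norm, does not match.

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Statistics over the feature axis (axis 1): one mean/variance per sample.
    mu = x.mean(axis=1, keepdims=True)
    std = np.sqrt(x.var(axis=1, keepdims=True) + eps)
    xhat = (x - mu) / std
    return gamma * xhat + beta, xhat, std

def layernorm_backward(dy, x, gamma, eps=1e-5):
    _, xhat, std = layernorm_forward(x, gamma, 0.0, eps)
    dxhat = dy * gamma  # gamma_j varies with j, so it is applied before the feature sums
    return (dxhat
            - dxhat.mean(axis=1, keepdims=True)
            - xhat * (dxhat * xhat).mean(axis=1, keepdims=True)) / std

def layernorm_backward_gamma_factored(dy, x, gamma, eps=1e-5):
    # Incorrect for layer norm: copies the batch-norm step of pulling gamma out of the sums.
    _, xhat, std = layernorm_forward(x, gamma, 0.0, eps)
    return gamma / std * (dy
                          - dy.mean(axis=1, keepdims=True)
                          - xhat * (dy * xhat).mean(axis=1, keepdims=True))

rng = np.random.default_rng(4)
x, dy = rng.normal(size=(3, 6)), rng.normal(size=(3, 6))
gamma, beta = rng.normal(size=6), rng.normal(size=6)

def loss(x_):
    # Scalar loss L = sum(y * dy), whose upstream gradient dL/dy is exactly dy.
    return np.sum(layernorm_forward(x_, gamma, beta)[0] * dy)

h = 1e-5
num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    step = np.zeros_like(x)
    step[idx] = h
    num[idx] = (loss(x + step) - loss(x - step)) / (2 * h)  # central difference

dx_ln = layernorm_backward(dy, x, gamma)
dx_naive = layernorm_backward_gamma_factored(dy, x, gamma)
print(np.allclose(dx_ln, num, atol=1e-6))     # True
print(np.allclose(dx_naive, num, atol=1e-6))  # False
```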
| Quantity | Formula |
|---|---|
| \(\mu_j\) | \(\tfrac{1}{N}\sum_i x_{i,j}\) |
| \(\sigma^2_j\) | \(\tfrac{1}{N}\sum_i (x_{i,j}-\mu_j)^2\) |
| \(\hat{x}_{i,j}\) | \((x_{i,j}-\mu_j)/\sqrt{\sigma^2_j+\varepsilon}\) |
| \(d\beta_j\) | \(\sum_i dy_{i,j}\) |
| \(d\gamma_j\) | \(\sum_i dy_{i,j}\,\hat{x}_{i,j}\) |
| \(d\hat{x}_{i,j}\) | \(dy_{i,j}\,\gamma_j\) |
| \(d\sigma^2_j\) | \(-\tfrac{1}{2(\sigma^2_j+\varepsilon)}\sum_i d\hat{x}_{i,j}\,\hat{x}_{i,j}\) |
| \(d\mu_j\) | \(-\tfrac{1}{\sqrt{\sigma^2_j+\varepsilon}}\sum_i d\hat{x}_{i,j}\) |
| \(dx_{i,j}\) | \(\tfrac{\gamma_j}{\sqrt{\sigma^2_j+\varepsilon}}\!\left[dy_{i,j} - \tfrac{1}{N}\sum_k dy_{k,j} - \tfrac{\hat{x}_{i,j}}{N}\sum_k dy_{k,j}\hat{x}_{k,j}\right]\) |
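The table maps almost line for line onto code. As a sketch (function names hypothetical; forward statistics recomputed rather than cached), the backward pass below implements the \(d\beta_j\), \(d\gamma_j\) and \(dx_{i,j}\) rows and checks all three against central finite differences of \(L = \sum_{i,j} y_{i,j}\,dy_{i,j}\), whose upstream gradient is exactly \(dy\).

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)
    var = ((x - mu) ** 2).mean(axis=0)
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta, xhat, var

def bn_backward(dy, x, gamma, eps=1e-5):
    """(dx, dgamma, dbeta) following the rows of the summary table."""
    n = x.shape[0]
    _, xhat, var = bn_forward(x, gamma, 0.0, eps)
    dbeta = dy.sum(axis=0)
    dgamma = (dy * xhat).sum(axis=0)
    dx = gamma / np.sqrt(var + eps) * (dy - dbeta / n - xhat * dgamma / n)
    return dx, dgamma, dbeta

def numerical_grad(f, a, h=1e-5):
    """Central-difference gradient of the scalar function f() w.r.t. array a (perturbed in place)."""
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        orig = a[idx]
        a[idx] = orig + h
        f_plus = f()
        a[idx] = orig - h
        f_minus = f()
        a[idx] = orig  # restore the original value
        g[idx] = (f_plus - f_minus) / (2 * h)
    return g

rng = np.random.default_rng(5)
x = rng.normal(size=(6, 4))
gamma, beta = rng.normal(size=4), rng.normal(size=4)
dy = rng.normal(size=(6, 4))

def loss():
    # L = sum(y * dy), so the upstream gradient dL/dy is exactly dy.
    return np.sum(bn_forward(x, gamma, beta)[0] * dy)

dx, dgamma, dbeta = bn_backward(dy, x, gamma)
print(np.allclose(dx, numerical_grad(loss, x), atol=1e-6))          # True
print(np.allclose(dgamma, numerical_grad(loss, gamma), atol=1e-6))  # True
print(np.allclose(dbeta, numerical_grad(loss, beta), atol=1e-6))    # True
```

Perturbing each entry in place keeps the check short; it costs two forward passes per entry, so it is only practical for toy-sized batches like this one.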