ML Math

Batch Normalization

Step-by-step derivation of the forward and backward passes, normalizing across the batch dimension with per-feature scale and shift.

Notation

Symbol | Description
\(x_{i,j}\) | Input for sample \(i \in \{1,\ldots,N\}\), feature \(j \in \{1,\ldots,D\}\)
\(\mu_j,\; \sigma^2_j\) | Mean and variance of feature \(j\) computed across the \(N\) batch samples
\(\hat{x}_{i,j}\) | Normalized value: \((x_{i,j} - \mu_j)/\sqrt{\sigma^2_j + \varepsilon}\)
\(\gamma_j,\; \beta_j\) | Learned per-feature scale and shift; shared across all batch samples
\(y_{i,j}\) | Output: \(\gamma_j \hat{x}_{i,j} + \beta_j\)
\(dy_{i,j},\; dx_{i,j}\) | Upstream gradient \(\partial L/\partial y_{i,j}\) and downstream gradient \(\partial L/\partial x_{i,j}\)

1. Forward Pass

For each feature \(j\), compute statistics across the \(N\) samples, then normalize and rescale:

  1. Batch mean \[\mu_j = \frac{1}{N}\sum_{i=1}^{N} x_{i,j}\]
  2. Batch variance \[\sigma^2_j = \frac{1}{N}\sum_{i=1}^{N}(x_{i,j} - \mu_j)^2\]
  3. Normalize \[\hat{x}_{i,j} = \frac{x_{i,j} - \mu_j}{\sqrt{\sigma^2_j + \varepsilon}}\]
  4. Scale and shift \[y_{i,j} = \gamma_j \hat{x}_{i,j} + \beta_j\]
The normalization axis is the batch dimension \(N\): each feature \(j\) gets its own \(\mu_j\) and \(\sigma^2_j\) computed from all samples in the mini-batch. Layer Normalization uses the other axis: it computes its statistics across the \(D\) features of each individual sample.
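The four steps above can be sketched in a few lines. This is a minimal illustration in plain Python (standard library only), not a reference implementation. Because every feature is processed independently, the hypothetical helper `batchnorm_forward` handles a single feature column: a list of the \(N\) values that feature \(j\) takes across the batch. The default `eps=1e-5` is an assumed value.

```python
import math


def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Batch-norm forward pass for ONE feature column.

    x is the list of N values that feature j takes across the batch;
    gamma and beta are that feature's scalar scale and shift.
    """
    n = len(x)
    mu = sum(x) / n                              # step 1: batch mean
    var = sum((xi - mu) ** 2 for xi in x) / n    # step 2: batch variance (divides by N)
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]    # step 3: normalize
    y = [gamma * xh + beta for xh in x_hat]      # step 4: scale and shift
    return y, x_hat, mu, var


# Illustrative input: one feature observed on a batch of four samples.
x = [1.0, 2.0, 3.0, 4.0]
y, x_hat, mu, var = batchnorm_forward(x, gamma=2.0, beta=0.5)
```

The function also returns \(\hat{x}\), \(\mu_j\) and \(\sigma^2_j\) because the backward pass reuses them.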

2. Backward Pass — Gradients for \(\gamma\) and \(\beta\)

Since \(y_{i,j} = \gamma_j \hat{x}_{i,j} + \beta_j\), the chain rule gives a contribution from every sample \(i\) in the batch:

\[ d\beta_j = \sum_{i=1}^{N} \frac{\partial L}{\partial y_{i,j}} \cdot \underbrace{\frac{\partial y_{i,j}}{\partial \beta_j}}_{=\,1} = \sum_{i=1}^{N} dy_{i,j} \] \[ d\gamma_j = \sum_{i=1}^{N} \frac{\partial L}{\partial y_{i,j}} \cdot \underbrace{\frac{\partial y_{i,j}}{\partial \gamma_j}}_{=\,\hat{x}_{i,j}} = \sum_{i=1}^{N} dy_{i,j}\cdot\hat{x}_{i,j} \]
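As a small sketch of these two sums, again for a single feature column in plain Python: the function name `batchnorm_param_grads` and the sample values are made up for illustration, with `x_hat` standing in for the output of a forward pass and `dy` for the gradient arriving from the next layer.

```python
def batchnorm_param_grads(dy, x_hat):
    """Gradients of the loss w.r.t. gamma_j and beta_j for one feature column."""
    # d y_i / d beta = 1, so every sample contributes its upstream gradient.
    dbeta = sum(dy)
    # d y_i / d gamma = x_hat_i, so each contribution is weighted by x_hat_i.
    dgamma = sum(dyi * xhi for dyi, xhi in zip(dy, x_hat))
    return dgamma, dbeta


# Hypothetical values for a batch of three samples.
x_hat = [-1.0, 0.0, 1.0]
dy = [0.1, -0.2, 0.4]
dgamma, dbeta = batchnorm_param_grads(dy, x_hat)
```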

3. Backward Pass — Gradient for the Input (\(dx\))

Each input \(x_{i,j}\) reaches the loss through three paths: directly via \(\hat{x}_{i,j}\), and indirectly via the batch mean \(\mu_j\) and batch variance \(\sigma^2_j\).

Step A — Gradient w.r.t. normalized input \(\hat{x}_{i,j}\)

Since \(y_{i,j} = \gamma_j \hat{x}_{i,j} + \beta_j\):

\[d\hat{x}_{i,j} = dy_{i,j} \cdot \gamma_j\]
Step B — Gradient w.r.t. batch variance \(\sigma^2_j\)

Every \(\hat{x}_{i,j}\) depends on \(\sigma^2_j\), so contributions from all \(N\) samples accumulate:

\[ d\sigma^2_j = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{\partial \hat{x}_{i,j}}{\partial \sigma^2_j} = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot (x_{i,j} - \mu_j) \cdot \left(-\tfrac{1}{2}(\sigma^2_j+\varepsilon)^{-3/2}\right) \]

Substituting \(x_{i,j} - \mu_j = \hat{x}_{i,j}\sqrt{\sigma^2_j+\varepsilon}\) simplifies the expression:

\[ d\sigma^2_j = -\frac{1}{2(\sigma^2_j+\varepsilon)} \sum_{i=1}^{N} d\hat{x}_{i,j}\,\hat{x}_{i,j} \]
Step C — Gradient w.r.t. batch mean \(\mu_j\)

Path 1 — direct effect on each \(\hat{x}_{i,j}\):

\[ \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{\partial \hat{x}_{i,j}}{\partial \mu_j} = \sum_{i=1}^{N} d\hat{x}_{i,j} \cdot \frac{-1}{\sqrt{\sigma^2_j+\varepsilon}} \]

Path 2 — indirect effect through \(\sigma^2_j\):

\[ d\sigma^2_j \cdot \frac{\partial \sigma^2_j}{\partial \mu_j} = d\sigma^2_j \cdot \frac{-2}{N}\sum_{i=1}^{N}(x_{i,j} - \mu_j) = 0 \]
Path 2 vanishes because \(\sum_i (x_{i,j} - \mu_j) = 0\) by definition of the batch mean.

Combined:

\[ d\mu_j = \frac{-1}{\sqrt{\sigma^2_j+\varepsilon}} \sum_{i=1}^{N} d\hat{x}_{i,j} \]
Step D — Final gradient for input \(dx_{i,j}\)

Three paths contribute to \(dx_{i,j}\): (1) through \(\hat{x}_{i,j}\), (2) through \(\sigma^2_j\) with \(\partial\sigma^2_j/\partial x_{i,j} = 2(x_{i,j}-\mu_j)/N\), (3) through \(\mu_j\) with \(\partial\mu_j/\partial x_{i,j} = 1/N\):

\[ dx_{i,j} = \frac{d\hat{x}_{i,j}}{\sqrt{\sigma^2_j+\varepsilon}} + d\sigma^2_j \cdot \frac{2(x_{i,j}-\mu_j)}{N} + \frac{d\mu_j}{N} \]
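Before any simplification, Steps A through D can be followed literally in code. The sketch below does so for one feature column in plain Python. The name `batchnorm_backward_staged` and the default `eps` are assumptions for illustration, and the forward statistics are recomputed inside the function only to keep the example self-contained.

```python
import math


def batchnorm_backward_staged(dy, x, gamma, eps=1e-5):
    """Input gradient for one feature column, computed stage by stage."""
    n = len(x)
    mu = sum(x) / n
    var = sum((xi - mu) ** 2 for xi in x) / n
    std = math.sqrt(var + eps)
    x_hat = [(xi - mu) / std for xi in x]

    # Step A: gradient w.r.t. the normalized input.
    dx_hat = [dyi * gamma for dyi in dy]
    # Step B: gradient w.r.t. the batch variance (simplified form).
    dvar = -0.5 / (var + eps) * sum(d * xh for d, xh in zip(dx_hat, x_hat))
    # Step C: gradient w.r.t. the batch mean (Path 2 vanishes, only Path 1 remains).
    dmu = -sum(dx_hat) / std
    # Step D: combine the three paths into dx.
    return [d / std + dvar * 2.0 * (xi - mu) / n + dmu / n
            for d, xi in zip(dx_hat, x)]


# Illustrative inputs: four samples of one feature and their upstream gradients.
x = [0.5, -1.2, 3.0, 0.7]
dy = [0.3, -0.1, 0.8, -0.5]
dx = batchnorm_backward_staged(dy, x, gamma=1.5)
```

Keeping the stages separate costs a few extra passes over the batch, but each intermediate quantity can be inspected against the formulas above.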

Substituting Steps A–C and using \(x_{i,j} - \mu_j = \hat{x}_{i,j}\sqrt{\sigma^2_j+\varepsilon}\):

\[ dx_{i,j} = \frac{1}{\sqrt{\sigma^2_j+\varepsilon}} \left[ d\hat{x}_{i,j} - \frac{1}{N}\sum_{k=1}^{N} d\hat{x}_{k,j} - \frac{\hat{x}_{i,j}}{N}\sum_{k=1}^{N} d\hat{x}_{k,j}\,\hat{x}_{k,j} \right] \]
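As a worked check of that substitution, using only the results of Steps B and C, the two indirect terms reduce as follows (the summation index is renamed to \(k\) to keep it distinct from the sample index \(i\)):

\[ d\sigma^2_j \cdot \frac{2(x_{i,j}-\mu_j)}{N} = -\frac{1}{2(\sigma^2_j+\varepsilon)}\left(\sum_{k=1}^{N} d\hat{x}_{k,j}\,\hat{x}_{k,j}\right)\cdot\frac{2\,\hat{x}_{i,j}\sqrt{\sigma^2_j+\varepsilon}}{N} = -\frac{1}{\sqrt{\sigma^2_j+\varepsilon}}\cdot\frac{\hat{x}_{i,j}}{N}\sum_{k=1}^{N} d\hat{x}_{k,j}\,\hat{x}_{k,j} \]
\[ \frac{d\mu_j}{N} = -\frac{1}{\sqrt{\sigma^2_j+\varepsilon}}\cdot\frac{1}{N}\sum_{k=1}^{N} d\hat{x}_{k,j} \]

Adding these to the direct term \(d\hat{x}_{i,j}/\sqrt{\sigma^2_j+\varepsilon}\) and pulling out the common factor gives the bracketed expression.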

Since \(d\hat{x}_{k,j} = dy_{k,j}\cdot\gamma_j\) and \(\gamma_j\) does not vary with the batch index \(k\), it factors out of both sums:

Result
\[ dx_{i,j} = \frac{\gamma_j}{\sqrt{\sigma^2_j+\varepsilon}} \left[ dy_{i,j} - \frac{1}{N}\sum_{k=1}^{N} dy_{k,j} - \frac{\hat{x}_{i,j}}{N}\sum_{k=1}^{N} dy_{k,j}\,\hat{x}_{k,j} \right] \]

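The three bracketed terms each play a distinct role:

  1. \(dy_{i,j}\) is the upstream gradient arriving directly through \(\hat{x}_{i,j}\).
  2. \(-\tfrac{1}{N}\sum_k dy_{k,j}\) comes from the path through \(\mu_j\). It subtracts the batch mean of the upstream gradients, so the values \(dx_{i,j}\) sum to zero over the batch.
  3. \(-\tfrac{\hat{x}_{i,j}}{N}\sum_k dy_{k,j}\,\hat{x}_{k,j}\) comes from the path through \(\sigma^2_j\). It subtracts a multiple of \(\hat{x}_{i,j}\) proportional to how strongly the upstream gradients correlate with the normalized inputs.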

Unlike Layer Normalization, \(\gamma_j\) here factors cleanly out of the sums because those sums run over the batch index \(k\), and \(\gamma_j\) is constant across samples. In LayerNorm the equivalent sums run over the feature index \(j\), where \(\gamma_j\) varies — so it cannot be factored out.
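A common way to gain confidence in a hand-derived gradient is to compare it with a finite-difference estimate. The sketch below does that for the closed-form result, for one feature column in plain Python. The function names, the step size `h`, and the surrogate loss \(L = \sum_i dy_i\, y_i\) (chosen so that \(\partial L/\partial y_i = dy_i\)) are all assumptions made for this illustration.

```python
import math


def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Forward pass for one feature column; returns y, x_hat and the variance."""
    n = len(x)
    mu = sum(x) / n
    var = sum((xi - mu) ** 2 for xi in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    return [gamma * xh + beta for xh in x_hat], x_hat, var


def batchnorm_backward(dy, x_hat, var, gamma, eps=1e-5):
    """Closed-form dx: scale * (dy - mean(dy) - x_hat * mean(dy * x_hat))."""
    n = len(dy)
    mean_dy = sum(dy) / n
    mean_dy_xhat = sum(d * xh for d, xh in zip(dy, x_hat)) / n
    scale = gamma / math.sqrt(var + eps)
    return [scale * (d - mean_dy - xh * mean_dy_xhat) for d, xh in zip(dy, x_hat)]


def numerical_dx(x, dy, gamma, beta, h=1e-6):
    """Central-difference gradient of the surrogate loss sum_i dy_i * y_i."""
    def loss(xs):
        y, _, _ = batchnorm_forward(xs, gamma, beta)
        return sum(d * yi for d, yi in zip(dy, y))

    grads = []
    for i in range(len(x)):
        plus = x[:i] + [x[i] + h] + x[i + 1:]
        minus = x[:i] + [x[i] - h] + x[i + 1:]
        grads.append((loss(plus) - loss(minus)) / (2 * h))
    return grads


# Illustrative inputs: four samples of one feature and their upstream gradients.
x = [0.5, -1.2, 3.0, 0.7]
dy = [0.3, -0.1, 0.8, -0.5]
gamma, beta = 1.5, 0.2

_, x_hat, var = batchnorm_forward(x, gamma, beta)
dx = batchnorm_backward(dy, x_hat, var, gamma)
dx_num = numerical_dx(x, dy, gamma, beta)
max_err = max(abs(a - b) for a, b in zip(dx, dx_num))
```

If the derivation is right, `max_err` should be tiny compared with the gradient values themselves, and `sum(dx)` should be zero up to rounding because the mean term removes any constant offset.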

4. Training vs. Inference

The forward pass above uses batch statistics — \(\mu_j\) and \(\sigma^2_j\) computed from the current mini-batch. This couples all samples together and causes the mean and variance to fluctuate across batches.

At inference time the batch may contain a single example. With \(N = 1\) the batch mean equals the input itself and the batch variance is zero, so every normalized value would collapse to zero. Instead, running statistics accumulated during training are used:

\[ \mu_j^{\text{run}} \;\leftarrow\; (1-\alpha)\,\mu_j^{\text{run}} + \alpha\,\mu_j^{\text{batch}} \qquad \sigma^{2,\text{run}}_j \;\leftarrow\; (1-\alpha)\,\sigma^{2,\text{run}}_j + \alpha\,\sigma^{2,\text{batch}}_j \]

where \(\alpha\) is the momentum hyperparameter (e.g. 0.1 in PyTorch). At inference the normalization becomes a fixed affine transform per feature — no batch dependence remains.
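The update rule and the switch between the two modes can be sketched with a small class. `BatchNormFeature` is a hypothetical single-feature toy in plain Python, not a library API. Its initial running mean of 0 and running variance of 1 are assumed starting values, and it feeds the batch variance from the formula above (dividing by \(N\)) into the running average; real libraries may differ in such details.

```python
import math


class BatchNormFeature:
    """Toy batch norm for one feature, with running statistics (illustrative only)."""

    def __init__(self, gamma=1.0, beta=0.0, alpha=0.1, eps=1e-5):
        self.gamma, self.beta = gamma, beta
        self.alpha, self.eps = alpha, eps
        self.run_mean, self.run_var = 0.0, 1.0   # assumed initial values
        self.training = True

    def __call__(self, x):
        if self.training:
            # Training: use the statistics of the current mini-batch ...
            n = len(x)
            mu = sum(x) / n
            var = sum((xi - mu) ** 2 for xi in x) / n
            # ... and fold them into the exponential moving averages.
            self.run_mean = (1 - self.alpha) * self.run_mean + self.alpha * mu
            self.run_var = (1 - self.alpha) * self.run_var + self.alpha * var
        else:
            # Inference: fixed statistics, so each output depends only on its own input.
            mu, var = self.run_mean, self.run_var
        inv_std = 1.0 / math.sqrt(var + self.eps)
        return [self.gamma * (xi - mu) * inv_std + self.beta for xi in x]


bn = BatchNormFeature()
bn([1.0, 2.0, 3.0, 4.0])         # one training step updates the running statistics
bn.training = False              # switch to inference mode
out_single = bn([2.0])           # a batch of one is fine in this mode
out_batch = bn([2.0, 100.0])     # the first output is unaffected by its batch-mate
```

In inference mode the first element of `out_batch` matches `out_single`, which is the "no batch dependence" property described above.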

This train/inference mismatch is Batch Norm's main practical difficulty. It means the model behaves differently at training and eval time, which can mask bugs and requires careful handling of model.train() / model.eval() calls.

5. Batch Norm vs. Layer Norm

Aspect | Batch Norm | Layer Norm
Normalize over | Batch dimension \(N\) (per feature \(j\)) | Feature dimension \(D\) (per token/sample)
Statistics depend on | All samples in the mini-batch | Only the current sample
Train vs. inference | Different (running stats at inference) | Same
Small batch size | Unstable (noisy statistics) | Unaffected
\(\gamma\) in \(dx\) formula | Factors out of sums (sum is over batch) | Must stay inside sums (sum is over features)
Typical use | CNNs, vision models | Transformers, LLMs
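The difference in normalization axis can be seen on a small \(N \times D\) input. The following plain-Python sketch uses hypothetical helpers `batch_norm_axes` and `layer_norm_axes`, and omits the scale and shift (equivalent to \(\gamma = 1\), \(\beta = 0\)) so that only the choice of axis differs.

```python
import math


def normalize(values, eps=1e-5):
    """Zero-mean, unit-variance normalization of a flat list of numbers."""
    n = len(values)
    mu = sum(values) / n
    var = sum((v - mu) ** 2 for v in values) / n
    return [(v - mu) / math.sqrt(var + eps) for v in values]


def batch_norm_axes(x):
    """Normalize each feature column across the N samples."""
    columns = [normalize(list(col)) for col in zip(*x)]
    return [list(row) for row in zip(*columns)]   # transpose back to N x D


def layer_norm_axes(x):
    """Normalize each sample row across its D features."""
    return [normalize(row) for row in x]


# Illustrative input: N = 3 samples (rows), D = 3 features (columns).
x = [[1.0, 10.0, 100.0],
     [2.0, 20.0, 200.0],
     [3.0, 30.0, 300.0]]
bn = batch_norm_axes(x)
ln = layer_norm_axes(x)

# Replace sample 1 and look at what happens to sample 0.
x_changed = [x[0], [0.0, 0.0, 0.0], x[2]]
bn_row0_changed = batch_norm_axes(x_changed)[0] != bn[0]   # batch norm couples samples
ln_row0_changed = layer_norm_axes(x_changed)[0] != ln[0]   # layer norm does not
```

Here `bn_row0_changed` comes out true and `ln_row0_changed` false, matching the "Statistics depend on" row of the table.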

6. Summary

\(d\beta_j\): \(\displaystyle\sum_{i=1}^{N} dy_{i,j}\)

\(d\gamma_j\): \(\displaystyle\sum_{i=1}^{N} dy_{i,j}\cdot\hat{x}_{i,j}\)

\(dx_{i,j}\):
\[ \begin{aligned} dx_{i,j} = \frac{\gamma_j}{\sqrt{\sigma^2_j+\varepsilon}}\Bigl[ &\,dy_{i,j} - \tfrac{1}{N}\sum_k dy_{k,j} \\ &- \tfrac{\hat{x}_{i,j}}{N}\sum_k dy_{k,j}\hat{x}_{k,j} \Bigr] \end{aligned} \]

Quantity | Formula
\(\mu_j\) | \(\tfrac{1}{N}\sum_i x_{i,j}\)
\(\sigma^2_j\) | \(\tfrac{1}{N}\sum_i (x_{i,j}-\mu_j)^2\)
\(\hat{x}_{i,j}\) | \((x_{i,j}-\mu_j)/\sqrt{\sigma^2_j+\varepsilon}\)
\(d\beta_j\) | \(\sum_i dy_{i,j}\)
\(d\gamma_j\) | \(\sum_i dy_{i,j}\,\hat{x}_{i,j}\)
\(d\hat{x}_{i,j}\) | \(dy_{i,j}\,\gamma_j\)
\(d\sigma^2_j\) | \(-\tfrac{1}{2(\sigma^2_j+\varepsilon)}\sum_i d\hat{x}_{i,j}\,\hat{x}_{i,j}\)
\(d\mu_j\) | \(-\tfrac{1}{\sqrt{\sigma^2_j+\varepsilon}}\sum_i d\hat{x}_{i,j}\)
\(dx_{i,j}\) | \(\tfrac{\gamma_j}{\sqrt{\sigma^2_j+\varepsilon}}\!\left[dy_{i,j} - \tfrac{1}{N}\sum_k dy_{k,j} - \tfrac{\hat{x}_{i,j}}{N}\sum_k dy_{k,j}\hat{x}_{k,j}\right]\)