ML Math

Cross Entropy

Forward and backward pass for the softmax + cross-entropy loss used in language model next-token prediction and multi-class classification.

Notation

Symbol | Description
--- | ---
\(\mathbf{z} \in \mathbb{R}^V\) | Logits (raw pre-softmax scores); \(V\) is the vocabulary or class count
\(\mathbf{p} \in \mathbb{R}^V\) | Softmax probabilities: \(p_i = e^{z_i} / \sum_j e^{z_j}\)
\(S = \sum_j e^{z_j}\) | Partition function (normalization denominator)
\(k\) | Index of the true class
\(\mathbf{y} \in \{0,1\}^V\) | One-hot target: \(y_i = \mathbf{1}[i = k]\)
\(L\) | Scalar cross-entropy loss
\(d\mathbf{z}\) | Gradient \(\partial L / \partial \mathbf{z}\)

1. Forward Pass

Step 1 — Softmax

Convert logits to a probability distribution over all \(V\) classes:

\[ p_i = \frac{e^{z_i}}{S}, \qquad S = \sum_{j=1}^{V} e^{z_j} \]

By construction \(p_i > 0\) and \(\sum_i p_i = 1\), so \(\mathbf{p}\) is a valid probability distribution.
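As a minimal pure-Python sketch of Step 1 (the function name `softmax_naive` is illustrative, and logits are assumed to arrive as a plain list of floats), the formula translates directly. It exponentiates the raw logits, so it is only safe for small values; Section 2 covers the overflow fix.

```python
import math

def softmax_naive(z):
    """Softmax exactly as in Step 1: p_i = exp(z_i) / S, with S = sum_j exp(z_j)."""
    exps = [math.exp(zi) for zi in z]  # e^{z_i} for every class
    S = sum(exps)                      # partition function
    return [e / S for e in exps]

p = softmax_naive([2.0, 1.0, 0.0])
print([round(pi, 4) for pi in p])  # [0.6652, 0.2447, 0.09]
print(abs(sum(p) - 1.0) < 1e-12)   # True
```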

Step 2 — Cross-entropy loss

Measure the log-loss against the one-hot target:

\[ L = -\sum_{i=1}^{V} y_i \log p_i = -\log p_k \]

Because only the true-class term survives the sum, the loss reduces to the negative log-probability assigned to the correct class.
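The collapse from the full sum to \(-\log p_k\) can be checked numerically. This sketch uses Python's 0-based indexing, so the true class is index `k = 0`; the helper names are illustrative.

```python
import math

def softmax_naive(z):
    exps = [math.exp(zi) for zi in z]
    S = sum(exps)
    return [e / S for e in exps]

def cross_entropy_onehot(p, y):
    """L = -sum_i y_i log p_i, the definition before simplification."""
    # Terms with y_i = 0 contribute nothing, so they are skipped.
    return -sum(yi * math.log(pi) for yi, pi in zip(y, p) if yi != 0)

z, k = [2.0, 1.0, 0.0], 0
p = softmax_naive(z)
y = [1.0 if i == k else 0.0 for i in range(len(z))]  # one-hot target
L_sum = cross_entropy_onehot(p, y)  # full sum over classes
L_pk = -math.log(p[k])              # simplified form
print(round(L_sum, 4), round(L_pk, 4))  # 0.4076 0.4076
```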

Expanding \(-\log p_k\) using the definition of softmax:

\[ L = -\log p_k = -\log \frac{e^{z_k}}{S} = -\log e^{z_k} + \log S = -z_k + \log \sum_{j=1}^{V} e^{z_j} \]
Forward pass
\[ L = -z_k + \log \sum_{j=1}^{V} e^{z_j} \]
This form avoids any division and gives a clear interpretation: the loss equals how much the log-partition function \(\log S\) exceeds the true-class logit \(z_k\). A perfect model pushes \(z_k \to \infty\) relative to all others, driving \(L \to 0\).
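A small worked example, with logits chosen purely for illustration: take \(V = 3\), \(\mathbf{z} = (2, 1, 0)\), and true class \(k = 1\). Rounded to three decimals, both forms of the loss agree:

\[ \begin{aligned} S &= e^{2} + e^{1} + e^{0} \approx 7.389 + 2.718 + 1.000 = 11.107 \\[8pt] \mathbf{p} &\approx (0.665,\ 0.245,\ 0.090) \\[8pt] -\log p_1 &\approx -\log 0.665 \approx 0.408 \\[8pt] -z_1 + \log S &\approx -2 + 2.408 = 0.408 \end{aligned} \]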

2. Numerical Stability

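Computing \(e^{z_i}\) directly overflows for large logits: in float32, \(e^{x}\) exceeds the largest representable value (about \(3.4 \times 10^{38}\)) once \(x\) is above roughly \(88.7\). The standard fix subtracts the maximum logit before exponentiating. Let \(m = \max_j z_j\):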

\[ p_i = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}} \]

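This is mathematically identical to the original (the \(e^{-m}\) factors in numerator and denominator cancel), but the largest exponent is now \(0\), so every \(e^{z_i - m}\) lies in \((0,\, 1]\) and the denominator lies in \([1, V]\): it can neither overflow nor underflow to zero.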
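A minimal sketch of the shifted softmax in pure Python (function name illustrative). Python's `math.exp` raises `OverflowError` once its result would exceed the largest double (arguments above roughly 709.78), which makes the failure of the naive version easy to see:

```python
import math

def softmax_stable(z):
    """p_i = exp(z_i - m) / sum_j exp(z_j - m), with m = max_j z_j."""
    m = max(z)                             # largest logit
    exps = [math.exp(zi - m) for zi in z]  # each value is at most 1
    total = sum(exps)                      # at least 1 (the max term is exp(0))
    return [e / total for e in exps]

big = [1000.0, 1001.0, 1002.0]
try:
    [math.exp(zi) for zi in big]           # naive exponentiation
except OverflowError:
    print("naive exp overflowed")
print([round(pi, 4) for pi in softmax_stable(big)])  # [0.09, 0.2447, 0.6652]
```

The shifted logits here are \((-2, -1, 0)\), so the result matches the unshifted softmax of \((0, 1, 2)\).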

PyTorch's F.cross_entropy and Optax's optax.softmax_cross_entropy (for JAX) both apply this shift internally. You rarely need to implement it by hand, but understanding it is important for writing numerically stable custom kernels.
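For a custom implementation, the same shift carries over to the loss form \(L = -z_k + \log S\) through the log-sum-exp identity \(\log S = m + \log \sum_j e^{z_j - m}\). The sketch below uses illustrative function names, not a library API (SciPy's `scipy.special.logsumexp` and PyTorch's `torch.logsumexp` provide the log-sum-exp step ready-made); indices are 0-based.

```python
import math

def logsumexp(z):
    """log(sum_j exp(z_j)), computed as m + log(sum_j exp(z_j - m))."""
    m = max(z)
    return m + math.log(sum(math.exp(zi - m) for zi in z))

def cross_entropy_stable(z, k):
    """L = -z_k + log S, never forming exp(z_j) for a large z_j."""
    return -z[k] + logsumexp(z)

print(round(cross_entropy_stable([2.0, 1.0, 0.0], 0), 4))           # 0.4076
print(round(cross_entropy_stable([1002.0, 1001.0, 1000.0], 0), 4))  # 0.4076
```

Working in log space also avoids taking the log of a \(p_k\) that may have underflowed to zero.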

3. Backward Pass

We want \(\partial L / \partial z_i\) for each logit. The derivation goes through the Jacobian of softmax.

Step A — Softmax Jacobian

Apply the quotient rule to \(p_k = e^{z_k}/S\), where \(S = \sum_j e^{z_j}\).

Case \(i = k\) — both numerator and \(S\) depend on \(z_k\), so the quotient rule gives \((e^{z_k} \cdot S - e^{z_k} \cdot e^{z_k})/S^2\):

\[ \begin{aligned} \frac{\partial p_k}{\partial z_k} &= \frac{e^{z_k} \cdot S - e^{z_k} \cdot e^{z_k}}{S^2} \\[8pt] &= \frac{e^{z_k}}{S} - \left(\frac{e^{z_k}}{S}\right)^{\!2} = p_k - p_k^2 = p_k(1 - p_k) \end{aligned} \]

Case \(i \ne k\) — the numerator \(e^{z_k}\) does not depend on \(z_i\), so its derivative is zero:

\[ \frac{\partial p_k}{\partial z_i} = \frac{0 \cdot S - e^{z_k} \cdot \dfrac{\partial S}{\partial z_i}}{S^2} = \frac{-e^{z_k} \cdot e^{z_i}}{S^2} = -\frac{e^{z_k}}{S} \cdot \frac{e^{z_i}}{S} = -p_k\, p_i \]
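
Both cases combine into \(\partial p_k/\partial z_i = p_k(\mathbf{1}[i = k] - p_i)\). The sketch below builds that Jacobian and checks it against central finite differences; the helper names and the step size `h = 1e-6` are illustrative choices, and indices are 0-based.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(zi - m) for zi in z]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(z):
    """J[k][i] = dp_k/dz_i = p_k * (1[i == k] - p_i), from the two cases above."""
    p = softmax(z)
    V = len(z)
    return [[p[k] * ((1.0 if i == k else 0.0) - p[i]) for i in range(V)]
            for k in range(V)]

def numerical_jacobian(z, h=1e-6):
    """Central differences: J[k][i] ~ (p_k(z + h e_i) - p_k(z - h e_i)) / (2h)."""
    V = len(z)
    J = [[0.0] * V for _ in range(V)]
    for i in range(V):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += h
        z_minus[i] -= h
        p_plus, p_minus = softmax(z_plus), softmax(z_minus)
        for k in range(V):
            J[k][i] = (p_plus[k] - p_minus[k]) / (2 * h)
    return J

z = [2.0, 1.0, 0.0]
J, J_num = softmax_jacobian(z), numerical_jacobian(z)
max_err = max(abs(a - b) for row, row_num in zip(J, J_num) for a, b in zip(row, row_num))
print(max_err < 1e-8)  # True
```

The matrix is symmetric, and each row and column sums to zero: columns because \(\sum_k p_k = 1\) is constant, rows because adding the same constant to every logit leaves \(\mathbf{p}\) unchanged (the shift invariance used in Section 2).
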
Step B — Chain rule through the log

Since \(L = -\log p_k\), the chain rule gives \(\partial L/\partial z_i = (-1/p_k)\cdot(\partial p_k/\partial z_i)\).

Case \(i = k\) — substitute the result from Step A:

\[ \frac{\partial L}{\partial z_k} = -\frac{1}{p_k} \cdot p_k(1-p_k) = -(1-p_k) = p_k - 1 \]

Case \(i \ne k\) — the two negatives cancel:

\[ \frac{\partial L}{\partial z_i} = -\frac{1}{p_k} \cdot (-p_k\, p_i) = p_i \]

Combining both cases using the one-hot label \(y_i\):

Key result: gradient of the combined loss

\[ \frac{\partial L}{\partial z_i} = p_i - y_i \]

\[ d\mathbf{z} = \mathbf{p} - \mathbf{y} \]

This remarkably clean formula is one of the most important in deep learning. The gradient at each logit is simply the probability the model assigned minus the probability it should have assigned. Neither the \(V \times V\) softmax Jacobian nor the division by \(p_k\) appears in the final expression, so the backward pass costs \(O(V)\) per example and stays well defined even when \(p_k\) is tiny.
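The key result is easy to verify numerically. This sketch (illustrative helper names, 0-based indices, step size `h = 1e-6`) computes \(\mathbf{p} - \mathbf{y}\) and compares it with central finite differences of the stable loss. With \(\mathbf{z} = (2, 1, 0)\) and the first class as the target, the gradient entries also sum to zero, as they must since both \(\mathbf{p}\) and \(\mathbf{y}\) sum to 1.

```python
import math

def cross_entropy(z, k):
    """Stable forward pass: L = -z_k + m + log(sum_j exp(z_j - m))."""
    m = max(z)
    return -z[k] + m + math.log(sum(math.exp(zi - m) for zi in z))

def grad_cross_entropy(z, k):
    """dz = p - y: softmax probabilities with 1 subtracted at the true class."""
    m = max(z)
    exps = [math.exp(zi - m) for zi in z]
    S = sum(exps)
    return [e / S - (1.0 if i == k else 0.0) for i, e in enumerate(exps)]

def numerical_grad(f, z, h=1e-6):
    """Central finite differences of a scalar function f at z."""
    g = []
    for i in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[i] += h
        z_minus[i] -= h
        g.append((f(z_plus) - f(z_minus)) / (2 * h))
    return g

z, k = [2.0, 1.0, 0.0], 0
dz = grad_cross_entropy(z, k)
dz_num = numerical_grad(lambda v: cross_entropy(v, k), z)
print([round(g, 4) for g in dz])                            # [-0.3348, 0.2447, 0.09]
print(max(abs(a - b) for a, b in zip(dz, dz_num)) < 1e-7)   # True
```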

This simplicity is not a coincidence. Cross entropy is the negative log-likelihood of a categorical distribution, and softmax is that distribution's canonical link function in the exponential family, so the log in the loss exactly cancels the exp in the softmax, leaving a linear residual.
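
The same result also follows in two lines from the forward form \(L = -z_k + \log S\), using the fact that the gradient of the log-partition function is the softmax itself. This is an alternative route to the Jacobian-based derivation in Section 3, with \(\partial z_k / \partial z_i = \mathbf{1}[i = k] = y_i\):

\[ \begin{aligned} \frac{\partial \log S}{\partial z_i} &= \frac{1}{S} \cdot \frac{\partial S}{\partial z_i} = \frac{e^{z_i}}{S} = p_i \\[8pt] \frac{\partial L}{\partial z_i} &= -\frac{\partial z_k}{\partial z_i} + \frac{\partial \log S}{\partial z_i} = -y_i + p_i \end{aligned} \]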

4. Batch Formulation

In practice, loss is averaged over a mini-batch of \(N\) examples. Let \(Z \in \mathbb{R}^{N \times V}\) be the logit matrix and \(k_1, \ldots, k_N\) the true class indices.

\[ L = -\frac{1}{N}\sum_{n=1}^{N} \log p_{n,k_n} \]

Softmax is applied row-wise to produce \(P \in \mathbb{R}^{N \times V}\). The gradient follows immediately:

\[ \frac{\partial L}{\partial Z} = \frac{1}{N}(P - Y) \]

where \(Y \in \mathbb{R}^{N \times V}\) is the batch of one-hot labels. The \(1/N\) factor comes from averaging the loss — it simply scales every gradient row uniformly.
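A pure-Python sketch of the batch formulas, with the rows of a nested list standing in for \(Z\) and 0-based target indices (names illustrative). Log-probabilities use the max-shift from Section 2 before averaging:

```python
import math

def log_softmax_row(z):
    """log p_i = z_i - logsumexp(z), using the max-shift for stability."""
    m = max(z)
    lse = m + math.log(sum(math.exp(zi - m) for zi in z))
    return [zi - lse for zi in z]

def batch_cross_entropy(Z, ks):
    """Return L = -(1/N) sum_n log p_{n,k_n} and dZ = (P - Y) / N."""
    N = len(Z)
    logP = [log_softmax_row(row) for row in Z]  # row-wise log-softmax
    loss = -sum(logP[n][ks[n]] for n in range(N)) / N
    dZ = [[(math.exp(lp) - (1.0 if i == ks[n] else 0.0)) / N
           for i, lp in enumerate(logP[n])]
          for n in range(N)]
    return loss, dZ

Z = [[2.0, 1.0, 0.0],
     [0.0, 0.0, 0.0]]  # second row is uniform: p = (1/3, 1/3, 1/3)
ks = [0, 2]
loss, dZ = batch_cross_entropy(Z, ks)
print(round(loss, 4))  # 0.7531, i.e. (0.4076 + log 3) / 2
```

Each row of `dZ` is the single-example gradient \(\mathbf{p} - \mathbf{y}\) divided by \(N = 2\).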

In language modelling the sequence dimension is typically folded into the batch, so the effective \(N\) is \(\text{batch} \times \text{sequence length}\). The same formula applies without modification.
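In array libraries the folding is a single reshape, for example `logits.reshape(-1, V)` on a `(batch, seq_len, V)` array in NumPy or PyTorch. The pure-Python sketch below spells out the same flattening with nested lists (function names and the all-zero logits are illustrative) and confirms that \(N\) becomes batch size times sequence length:

```python
import math

def fold_sequence_into_batch(logits_btv, targets_bt):
    """Flatten (B, T, V) logits and (B, T) targets into (B*T, V) and (B*T,)."""
    Z = [row for seq in logits_btv for row in seq]  # B*T rows of V logits
    ks = [k for seq in targets_bt for k in seq]     # matching class indices
    return Z, ks

def cross_entropy_stable(z, k):
    m = max(z)
    return -z[k] + m + math.log(sum(math.exp(zi - m) for zi in z))

B, T, V = 2, 3, 4
logits = [[[0.0] * V for _ in range(T)] for _ in range(B)]  # uniform predictions
targets = [[0, 1, 2], [3, 0, 1]]

Z, ks = fold_sequence_into_batch(logits, targets)
N = len(Z)  # B * T = 6
loss = sum(cross_entropy_stable(z, k) for z, k in zip(Z, ks)) / N
print(N, round(loss, 4))  # 6 1.3863  (log V for uniform logits)
```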

5. Summary

Softmax: \(p_i = e^{z_i} / \sum_j e^{z_j}\)

Loss: \(L = -\log p_k = -z_k + \log S\)

dz (single example): \(d\mathbf{z} = \mathbf{p} - \mathbf{y}\)

dZ (batch): \(dZ = \tfrac{1}{N}(P - Y)\)

Stable forward: subtract \(m = \max_j z_j\) before exponentiating

Why it's clean: log and exp cancel; cross entropy is the natural loss for softmax

Quantity | Formula | Shape
--- | --- | ---
Softmax \(\mathbf{p}\) | \(e^{\mathbf{z}} / \sum_j e^{z_j}\) | \(\mathbb{R}^V\)
Loss \(L\) | \(-\log p_k\) | scalar
Gradient \(d\mathbf{z}\) (true class \(k\)) | \(p_k - 1\) | scalar (entry \(k\))
Gradient \(d\mathbf{z}\) (other classes) | \(p_i\) | scalar (entry \(i \ne k\))
Gradient \(d\mathbf{z}\) (vectorized) | \(\mathbf{p} - \mathbf{y}\) | \(\mathbb{R}^V\)
Batch gradient \(dZ\) | \(\tfrac{1}{N}(P - Y)\) | \(\mathbb{R}^{N \times V}\)