Forward and backward pass for the softmax + cross-entropy loss used in language model next-token prediction and multi-class classification.
| Symbol | Description |
|---|---|
| \(\mathbf{z} \in \mathbb{R}^V\) | Logits (raw pre-softmax scores); \(V\) is the vocabulary or class count |
| \(\mathbf{p} \in \mathbb{R}^V\) | Softmax probabilities: \(p_i = e^{z_i} / \sum_j e^{z_j}\) |
| \(S = \sum_j e^{z_j}\) | Partition function (normalization denominator) |
| \(k\) | Index of the true class |
| \(\mathbf{y} \in \{0,1\}^V\) | One-hot target: \(y_i = \mathbf{1}[i = k]\) |
| \(L\) | Scalar cross-entropy loss |
| \(d\mathbf{z}\) | Gradient \(\partial L / \partial \mathbf{z}\) |
Convert logits to a probability distribution over all \(V\) classes:
\[ p_i = \frac{e^{z_i}}{S}, \qquad S = \sum_{j=1}^{V} e^{z_j} \]
By construction \(p_i > 0\) and \(\sum_i p_i = 1\), so \(\mathbf{p}\) is a valid probability distribution.
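As a minimal sketch of this definition, assuming plain Python lists and only the standard library (the function name `softmax` and the example logits are illustrative, not taken from the text), the forward pass might look like this:

```python
import math

def softmax(z):
    """Direct softmax: p_i = exp(z_i) / S, with S = sum_j exp(z_j)."""
    exps = [math.exp(z_i) for z_i in z]
    S = sum(exps)  # partition function
    return [e / S for e in exps]

z = [2.0, 1.0, 0.1]  # example logits, V = 3
p = softmax(z)
print(p)       # every entry is positive; larger logits get larger probabilities
print(sum(p))  # 1 up to floating-point rounding
```

This direct form is only safe for moderate logits, since `math.exp` raises `OverflowError` for sufficiently large arguments.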
Measure the log-loss against the one-hot target:
\[ L = -\sum_{i=1}^{V} y_i \log p_i = -\log p_k \]
Because only the true-class term survives the sum, the loss reduces to the negative log-probability assigned to the correct class.
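To illustrate that the full sum and the single-term form agree, here is a small sketch using only the standard library. The helper name `cross_entropy` and the example probabilities are assumptions made for illustration.

```python
import math

def cross_entropy(p, k):
    """Cross-entropy against a one-hot target with true class k: L = -log p_k."""
    return -math.log(p[k])

p = [0.7, 0.2, 0.1]  # example probabilities (already normalized)
k = 0                # index of the true class
y = [1.0 if i == k else 0.0 for i in range(len(p))]  # one-hot target

# Full sum over all classes; only the i = k term is nonzero.
full_sum = -sum(y_i * math.log(p_i) for y_i, p_i in zip(y, p))

print(cross_entropy(p, k))  # about 0.357
print(full_sum)             # same value
```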
Expanding \(-\log p_k\) using the definition of softmax:
\[ L = -\log p_k = -\log \frac{e^{z_k}}{S} = -\log e^{z_k} + \log S = -z_k + \log \sum_{j=1}^{V} e^{z_j} \]
Computing \(e^{z_i}\) directly overflows for large logits (in float32, \(e^{z}\) exceeds the largest finite value once \(z\) is above roughly 88.7). The standard fix subtracts the maximum logit before exponentiating. Let \(m = \max_j z_j\):
\[ p_i = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}} \]
This is mathematically identical to the original (the \(e^{-m}\) factors in numerator and denominator cancel), but the largest exponent is now \(0\), so every exponential lies in \((0,\, 1]\) and cannot overflow.
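The shift can be sketched as follows, again with the standard library only. The name `softmax_stable` and the example logits are illustrative assumptions; the example uses logits large enough that the unshifted exponential fails in 64-bit floats.

```python
import math

def softmax_stable(z):
    """Softmax with the max-logit shift: p_i = exp(z_i - m) / sum_j exp(z_j - m)."""
    m = max(z)
    exps = [math.exp(z_i - m) for z_i in z]  # largest entry is exactly 1.0
    total = sum(exps)                        # at least 1, so the division is safe
    return [e / total for e in exps]

big = [1000.0, 999.0, 998.0]

try:
    math.exp(big[0])  # the unshifted exponential
except OverflowError:
    print("math.exp(1000.0) overflows a 64-bit float")

print(softmax_stable(big))
# Adding the same constant to every logit leaves the probabilities unchanged:
print(softmax_stable([2.0, 1.0, 0.0]))
```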
PyTorch's F.cross_entropy and Optax's optax.softmax_cross_entropy (for JAX) both take raw logits and handle this stabilization internally. You rarely need to implement the shift by hand, but understanding it matters when writing numerically stable custom kernels.
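For a hand-written version, the loss can be computed directly from logits using \(L = -z_k + \log \sum_j e^{z_j}\) together with the max-logit shift, without ever forming \(\mathbf{p}\). The following is a sketch under that assumption, not a description of how either library is implemented; the helper name `cross_entropy_from_logits` is hypothetical.

```python
import math

def cross_entropy_from_logits(z, k):
    """L = -z_k + log(sum_j exp(z_j)), evaluated with the max-logit shift."""
    m = max(z)
    # log-sum-exp: log(sum_j exp(z_j)) = m + log(sum_j exp(z_j - m))
    log_S = m + math.log(sum(math.exp(z_j - m) for z_j in z))
    return log_S - z[k]

z = [1000.0, 999.0, 998.0]  # logits that would overflow a direct exp
print(cross_entropy_from_logits(z, k=0))  # about 0.408
```

Working in log space also avoids taking the logarithm of a probability that has underflowed to zero.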
We want \(\partial L / \partial z_i\) for each logit. The derivation goes through the Jacobian of softmax.
Apply the quotient rule to \(p_k = e^{z_k}/S\), where \(S = \sum_j e^{z_j}\) and therefore \(\partial S / \partial z_i = e^{z_i}\) for every \(i\).
Case \(i = k\): both the numerator \(e^{z_k}\) and the denominator \(S\) depend on \(z_k\), and each has derivative \(e^{z_k}\):
\[ \begin{aligned} \frac{\partial p_k}{\partial z_k} &= \frac{e^{z_k} \cdot S - e^{z_k} \cdot e^{z_k}}{S^2} \\[8pt] &= \frac{e^{z_k}}{S} - \left(\frac{e^{z_k}}{S}\right)^{\!2} = p_k - p_k^2 = p_k(1 - p_k) \end{aligned} \]
Case \(i \ne k\): the numerator \(e^{z_k}\) does not depend on \(z_i\), so its derivative is zero and only the \(\partial S / \partial z_i = e^{z_i}\) term survives:
\[ \frac{\partial p_k}{\partial z_i} = \frac{0 \cdot S - e^{z_k} \cdot \dfrac{\partial S}{\partial z_i}}{S^2} = \frac{-e^{z_k} \cdot e^{z_i}}{S^2} = -\frac{e^{z_k}}{S} \cdot \frac{e^{z_i}}{S} = -p_k\, p_i \]
Since \(L = -\log p_k\), the chain rule gives \(\partial L/\partial z_i = (-1/p_k)\cdot(\partial p_k/\partial z_i)\).
Case \(i = k\): substitute \(\partial p_k/\partial z_k = p_k(1 - p_k)\):
\[ \frac{\partial L}{\partial z_k} = -\frac{1}{p_k} \cdot p_k(1-p_k) = -(1-p_k) = p_k - 1 \]
Case \(i \ne k\): substitute \(\partial p_k/\partial z_i = -p_k\, p_i\); the two negatives cancel:
\[ \frac{\partial L}{\partial z_i} = -\frac{1}{p_k} \cdot (-p_k\, p_i) = p_i \]
Combining both cases using the one-hot label \(y_i\) (which is \(1\) for \(i = k\) and \(0\) otherwise):
\[ \frac{\partial L}{\partial z_i} = p_i - y_i \qquad\Longleftrightarrow\qquad d\mathbf{z} = \mathbf{p} - \mathbf{y} \]
The gradient at each logit is the probability the model assigned minus the probability it should have assigned. No Jacobian matrix and no division by \(p_k\) remain in the final expression, so once \(\mathbf{p}\) is available the backward pass costs one subtraction per logit.
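The result \(d\mathbf{z} = \mathbf{p} - \mathbf{y}\) can be checked numerically. The sketch below, which assumes standard-library Python and uses illustrative function names throughout, computes the gradient three ways: in closed form, via the chain rule through \(\partial p_k/\partial z_i\), and by central finite differences of the loss.

```python
import math

def softmax(z):
    m = max(z)  # max-logit shift for numerical stability
    exps = [math.exp(z_i - m) for z_i in z]
    total = sum(exps)
    return [e / total for e in exps]

def loss(z, k):
    return -math.log(softmax(z)[k])

def grad_closed_form(z, k):
    """dz = p - y for true class k."""
    p = softmax(z)
    return [p_i - (1.0 if i == k else 0.0) for i, p_i in enumerate(p)]

def grad_chain_rule(z, k):
    """dL/dz_i = (-1/p_k) * dp_k/dz_i, with dp_k/dz_i = p_k * (1[i == k] - p_i)."""
    p = softmax(z)
    dpk_dz = [p[k] * ((1.0 if i == k else 0.0) - p[i]) for i in range(len(z))]
    return [(-1.0 / p[k]) * d for d in dpk_dz]

def grad_numeric(z, k, eps=1e-6):
    """Central finite differences of the loss, one logit at a time."""
    g = []
    for i in range(len(z)):
        z_plus = list(z)
        z_minus = list(z)
        z_plus[i] += eps
        z_minus[i] -= eps
        g.append((loss(z_plus, k) - loss(z_minus, k)) / (2 * eps))
    return g

z, k = [2.0, 1.0, 0.1], 0
closed = grad_closed_form(z, k)
chained = grad_chain_rule(z, k)
numeric = grad_numeric(z, k)

print(closed)  # entry k is p_k - 1 (negative); the other entries are p_i (positive)
print(max(abs(a - b) for a, b in zip(closed, chained)))  # agreement up to rounding
print(max(abs(a - b) for a, b in zip(closed, numeric)))  # small finite-difference error
```

Because \(\sum_i p_i = \sum_i y_i = 1\), the entries of the gradient sum to zero, which is a cheap sanity check for a custom implementation.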
In practice, the loss is averaged over a mini-batch of \(N\) examples. Let \(Z \in \mathbb{R}^{N \times V}\) be the logit matrix and \(k_1, \ldots, k_N\) the true class indices. The batch loss is:
\[ L = -\frac{1}{N}\sum_{n=1}^{N} \log p_{n,k_n} \]
Softmax is applied row-wise to produce \(P \in \mathbb{R}^{N \times V}\). The gradient follows immediately:
\[ \frac{\partial L}{\partial Z} = \frac{1}{N}(P - Y) \]
where \(Y \in \mathbb{R}^{N \times V}\) is the batch of one-hot labels. The \(1/N\) factor comes from averaging the loss; it simply scales every gradient row uniformly.
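A batched version can be sketched with nested lists standing in for the matrices \(Z\), \(P\), and \(dZ\). The function name `batch_loss_and_grad` and the example data are assumptions made for illustration; a real implementation would typically use an array library instead.

```python
import math

def softmax(z):
    m = max(z)  # max-logit shift for numerical stability
    exps = [math.exp(z_i - m) for z_i in z]
    total = sum(exps)
    return [e / total for e in exps]

def batch_loss_and_grad(Z, targets):
    """Mean cross-entropy over N rows of logits and its gradient (1/N)(P - Y)."""
    N = len(Z)
    P = [softmax(row) for row in Z]  # row-wise softmax
    L = -sum(math.log(P[n][k]) for n, k in enumerate(targets)) / N
    dZ = [
        [(p - (1.0 if i == k else 0.0)) / N for i, p in enumerate(P[n])]
        for n, k in enumerate(targets)
    ]
    return L, dZ

Z = [[2.0, 1.0, 0.1],
     [0.5, 2.5, 0.3]]  # N = 2 examples, V = 3 classes
targets = [0, 1]       # true class index k_n for each row

L, dZ = batch_loss_and_grad(Z, targets)
print(L)
for row in dZ:
    print(row)  # each row sums to 0 up to rounding
```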
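Softmax: \(p_i = e^{z_i} / \sum_j e^{z_j}\)
Loss: \(L = -\log p_k = -z_k + \log S\)
Gradient: \(d\mathbf{z} = \mathbf{p} - \mathbf{y}\)
Batch gradient: \(dZ = \tfrac{1}{N}(P - Y)\)
Numerical stability: subtract \(m = \max_j z_j\) from every logit before exponentiating
Why the gradient is simple: the \(\log\) in the loss cancels the \(\exp\) in softmax, which makes cross-entropy the natural loss to pair with softmax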
| Quantity | Formula | Shape |
|---|---|---|
| Softmax \(\mathbf{p}\) | \(e^{\mathbf{z}} / \sum_j e^{z_j}\) | \(\mathbb{R}^V\) |
| Loss \(L\) | \(-\log p_k\) | scalar |
| Gradient entry \(dz_k\) (true class \(k\)) | \(p_k - 1\) | scalar (entry \(k\)) |
| Gradient entry \(dz_i\) (other classes) | \(p_i\) | scalar (entry \(i \ne k\)) |
| Gradient \(d\mathbf{z}\) (vectorized) | \(\mathbf{p} - \mathbf{y}\) | \(\mathbb{R}^V\) |
| Batch gradient \(dZ\) | \(\tfrac{1}{N}(P - Y)\) | \(\mathbb{R}^{N \times V}\) |