ML Math

SwiGLU

Forward and backward pass for the Swish-Gated Linear Unit used in LLaMA, PaLM, and most modern transformer FFN layers.

Notation

| Symbol | Description |
| --- | --- |
| \(\mathbf{x} \in \mathbb{R}^d\) | Input token embedding (one token; batch/sequence dimensions suppressed) |
| \(W_g, W_v \in \mathbb{R}^{d \times n}\) | Gate and value projection weights; \(n\) is the FFN hidden width |
| \(W_o \in \mathbb{R}^{n \times d}\) | Output (down) projection weights |
| \(\mathbf{u} = \mathbf{x}W_g \in \mathbb{R}^n\) | Gate pre-activation |
| \(\mathbf{v} = \mathbf{x}W_v \in \mathbb{R}^n\) | Value pre-activation |
| \(\sigma(z)\) | Sigmoid: \(1/(1+e^{-z})\) |
| \(\mathrm{Swish}(z)\) | Swish / SiLU: \(z\,\sigma(z)\), applied element-wise |
| \(\mathbf{h} \in \mathbb{R}^n\) | Gated hidden state: \(\mathrm{Swish}(\mathbf{u}) \odot \mathbf{v}\) |
| \(\mathbf{y} \in \mathbb{R}^d\) | FFN output: \(\mathbf{h}\,W_o\) |
| \(d\mathbf{y},\, d\mathbf{h},\, d\mathbf{u},\, d\mathbf{v},\, d\mathbf{x}\) | Upstream and downstream gradients \(\partial L / \partial (\cdot)\) |

1. Swish Activation and Its Derivative

Swish (also called SiLU) is a smooth, non-monotone gating function:

\[ \mathrm{Swish}(z) = z\,\sigma(z) = \frac{z}{1+e^{-z}} \]

Its derivative follows from the product rule:

\[ \mathrm{Swish}'(z) = \sigma(z) + z\,\sigma(z)\bigl(1 - \sigma(z)\bigr) = \sigma(z) + \mathrm{Swish}(z)\bigl(1 - \sigma(z)\bigr) \]
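Both \(\sigma(z)\) and \(\mathrm{Swish}(z)\) are already computed in the forward pass, so a backward pass that caches them needs no additional transcendental function evaluations. An implementation that saves only \(z\) to conserve memory can recompute both from a single exponential.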
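
The derivative formula can be sanity-checked numerically. The following is a minimal sketch in plain Python (standard library only); the function names `sigmoid`, `swish`, `swish_grad`, and `central_difference` are illustrative choices, not taken from any particular library.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-z)). Not overflow-safe for z below about -709."""
    return 1.0 / (1.0 + math.exp(-z))

def swish(z: float) -> float:
    """Swish / SiLU: z * sigmoid(z)."""
    return z * sigmoid(z)

def swish_grad(z: float) -> float:
    """Swish'(z) = sigma(z) + Swish(z) * (1 - sigma(z))."""
    s = sigmoid(z)        # the only exp() call; reused by both terms
    sw = z * s            # Swish(z), rebuilt from s without another exp()
    return s + sw * (1.0 - s)

def central_difference(f, z: float, eps: float = 1e-5) -> float:
    """Second-order finite-difference estimate of f'(z)."""
    return (f(z + eps) - f(z - eps)) / (2.0 * eps)

# Compare the analytic derivative with the numerical estimate.
for z in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(f"z={z:+.1f}  analytic={swish_grad(z):+.6f}  "
          f"numeric={central_difference(swish, z):+.6f}")
```

The derivative is negative at \(z = -3\), which is the non-monotone region mentioned above.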

2. Forward Pass

  1. Gate pre-activation   \(\mathbf{u} = \mathbf{x}W_g\)
  2. Value pre-activation   \(\mathbf{v} = \mathbf{x}W_v\)
  3. SwiGLU gate \[ h_i = \mathrm{Swish}(u_i)\cdot v_i = u_i\,\sigma(u_i)\cdot v_i \]
  4. Output projection   \(\mathbf{y} = \mathbf{h}\,W_o\)
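
The four steps above can be sketched for a single token as follows. This is an illustrative example using plain Python lists rather than an array library; the helper `vec_mat`, the function `swiglu_forward`, and the toy weights (\(d = 2\), \(n = 3\)) are all assumptions made for the example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def swish(z):
    return z * sigmoid(z)

def vec_mat(x, W):
    """Row vector x (length r) times matrix W (r x c) -> list of length c."""
    return [sum(x[i] * W[i][j] for i in range(len(x)))
            for j in range(len(W[0]))]

def swiglu_forward(x, W_g, W_v, W_o):
    u = vec_mat(x, W_g)                            # 1. gate pre-activation (length n)
    v = vec_mat(x, W_v)                            # 2. value pre-activation (length n)
    h = [swish(ui) * vi for ui, vi in zip(u, v)]   # 3. element-wise SwiGLU gate
    y = vec_mat(h, W_o)                            # 4. output projection (length d)
    return u, v, h, y

# Toy sizes: d = 2 (model width), n = 3 (hidden width).
x = [1.0, -2.0]
W_g = [[0.5, -1.0, 0.0],
       [0.25, 0.5, 1.0]]      # d x n
W_v = [[1.0, 0.0, -1.0],
       [0.5, 1.0, 0.0]]       # d x n
W_o = [[1.0, 0.0],
       [0.0, 1.0],
       [1.0, 1.0]]            # n x d

u, v, h, y = swiglu_forward(x, W_g, W_v, W_o)
print("u =", u)
print("v =", v)
print("h =", h)
print("y =", y)
```

Returning `u`, `v`, and `h` alongside `y` mirrors what a training implementation would typically keep around for the backward pass.
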
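In the Hugging Face LLaMA implementation, the gate, value, and output projections are named gate_proj, up_proj, and down_proj (w1, w3, and w2 in Meta's reference code). The hidden width is typically set to \(n \approx \tfrac{2}{3} \cdot 4d\), rounded to a multiple of 64 or a larger power of two (Meta's reference LLaMA code defaults to 256). With three \(d \times n\) matrices instead of two \(d \times 4d\) matrices, this gives \(3dn \approx 8d^2\) parameters, matching a standard \(4d\) FFN despite the extra gate path.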
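
To make the hidden-width arithmetic concrete, here is a small sketch. The function name `swiglu_hidden_width` and its `multiple_of` parameter are illustrative, and rounding up (rather than to the nearest multiple) is an assumption of this example.

```python
def swiglu_hidden_width(d: int, multiple_of: int = 64) -> int:
    """Take floor((2/3) * 4d) and round it up to a multiple of `multiple_of`."""
    n = (2 * 4 * d) // 3
    return multiple_of * ((n + multiple_of - 1) // multiple_of)

def param_ratio(d: int, n: int) -> float:
    """SwiGLU FFN parameters (3 * d * n) relative to a standard 4d FFN (2 * d * 4d)."""
    return (3 * d * n) / (2 * d * 4 * d)

d = 4096
n_64 = swiglu_hidden_width(d, multiple_of=64)    # 10944
n_256 = swiglu_hidden_width(d, multiple_of=256)  # 11008
print(n_64, param_ratio(d, n_64))
print(n_256, param_ratio(d, n_256))
```

In both cases the ratio stays within about one percent of 1, which is the sense in which the parameter counts match.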

3. Backward Pass — Activations

Step A — Gradient through the output projection

Since \(\mathbf{y} = \mathbf{h}W_o\), each \(\partial y_j/\partial h_i = (W_o)_{ij}\), so the Jacobian is \(\partial \mathbf{y}/\partial \mathbf{h} = W_o^\top\) and the chain rule gives:

\[ d\mathbf{h} = d\mathbf{y}\,W_o^\top \in \mathbb{R}^n \]
Step B — Gradient through the gate (element-wise product)

Each output \(h_i = \mathrm{Swish}(u_i)\cdot v_i\) depends only on its own \(u_i\) and \(v_i\), so both Jacobians are diagonal and it suffices to differentiate component by component:

\[ \frac{\partial h_i}{\partial u_i} = \mathrm{Swish}'(u_i)\cdot v_i \qquad \frac{\partial h_i}{\partial v_i} = \mathrm{Swish}(u_i) \]

Applying the chain rule element-wise:

\[ du_i = dh_i \cdot v_i \cdot \mathrm{Swish}'(u_i) \qquad dv_i = dh_i \cdot \mathrm{Swish}(u_i) \]
Gate gradients (vectorized)
\[ d\mathbf{u} = d\mathbf{h} \odot \mathbf{v} \odot \mathrm{Swish}'(\mathbf{u}) \qquad\qquad d\mathbf{v} = d\mathbf{h} \odot \mathrm{Swish}(\mathbf{u}) \]

Expanding \(\mathrm{Swish}'(\mathbf{u})\):

\[ \mathrm{Swish}'(\mathbf{u}) = \boldsymbol{\sigma}(\mathbf{u}) + \mathrm{Swish}(\mathbf{u})\odot\bigl(\mathbf{1} - \boldsymbol{\sigma}(\mathbf{u})\bigr) \]
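
As a sketch of Step B in isolation, the element-wise gate gradients can be written as a single loop over components. The function name `gate_backward` and the toy values are assumptions for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gate_backward(u, v, dh):
    """Return (du, dv) for h = Swish(u) * v, computed element-wise."""
    du, dv = [], []
    for ui, vi, dhi in zip(u, v, dh):
        s = sigmoid(ui)                  # sigma(u_i)
        sw = ui * s                      # Swish(u_i)
        sw_grad = s + sw * (1.0 - s)     # Swish'(u_i)
        du.append(dhi * vi * sw_grad)    # du_i = dh_i * v_i * Swish'(u_i)
        dv.append(dhi * sw)              # dv_i = dh_i * Swish(u_i)
    return du, dv

u = [0.0, -2.0, 1.5]
v = [1.0, -2.0, 0.5]
dh = [1.0, 1.0, -2.0]   # upstream gradient arriving at h

du, dv = gate_backward(u, v, dh)
print("du =", du)
print("dv =", dv)
```

Note that `dv` is zero wherever \(\mathrm{Swish}(u_i) = 0\) (here at \(u_1 = 0\)), while `du` can still be nonzero there because \(\mathrm{Swish}'(0) = 0.5\).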
Step C — Gradient through the input projections

Both gate and value paths flow back through \(\mathbf{x}\), so their contributions add:

\[ d\mathbf{x} = d\mathbf{u}\,W_g^\top \;+\; d\mathbf{v}\,W_v^\top \in \mathbb{R}^d \]
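
Steps A through C can be chained and checked against a finite-difference gradient. The sketch below uses plain Python lists; the helpers `vec_mat` and `vec_mat_T`, the functions `forward` and `backward_activations`, and the scalar test loss \(L = \sum_j y_j\,dy_j\) are assumptions introduced for this example.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def vec_mat(x, W):
    """x (length r) @ W (r x c) -> length c."""
    return [sum(x[i] * W[i][j] for i in range(len(x)))
            for j in range(len(W[0]))]

def vec_mat_T(g, W):
    """g (length c) @ W^T for W (r x c) -> length r, without building W^T."""
    return [sum(g[j] * W[i][j] for j in range(len(g)))
            for i in range(len(W))]

def forward(x, W_g, W_v, W_o):
    u = vec_mat(x, W_g)
    v = vec_mat(x, W_v)
    h = [ui * sigmoid(ui) * vi for ui, vi in zip(u, v)]
    return u, v, h, vec_mat(h, W_o)

def backward_activations(dy, u, v, W_g, W_v, W_o):
    dh = vec_mat_T(dy, W_o)                           # Step A: dh = dy W_o^T
    du, dv = [], []
    for ui, vi, dhi in zip(u, v, dh):                 # Step B: element-wise gate
        s = sigmoid(ui)
        sw = ui * s
        du.append(dhi * vi * (s + sw * (1.0 - s)))
        dv.append(dhi * sw)
    dx = [a + b for a, b in zip(vec_mat_T(du, W_g),   # Step C: both paths add
                                vec_mat_T(dv, W_v))]
    return dh, du, dv, dx

x = [1.0, -2.0]
W_g = [[0.5, -1.0, 0.0], [0.25, 0.5, 1.0]]
W_v = [[1.0, 0.0, -1.0], [0.5, 1.0, 0.0]]
W_o = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
dy = [1.0, -0.5]                                      # pretend upstream gradient

u, v, h, y = forward(x, W_g, W_v, W_o)
dh, du, dv, dx = backward_activations(dy, u, v, W_g, W_v, W_o)

def loss(x_):
    """Scalar whose gradient with respect to y is exactly dy."""
    return sum(yj * gj for yj, gj in zip(forward(x_, W_g, W_v, W_o)[3], dy))

eps = 1e-6
numeric_dx = []
for i in range(len(x)):
    xp = list(x); xp[i] += eps
    xm = list(x); xm[i] -= eps
    numeric_dx.append((loss(xp) - loss(xm)) / (2.0 * eps))

print("analytic dx =", dx)
print("numeric  dx =", numeric_dx)
```

The two printed vectors should agree to several decimal places; a mismatch usually points to a missing transpose or a dropped path in Step C.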

4. Weight Gradients

Each of the three projection matrices receives one gradient contribution. For a single token, that contribution is the outer product of the layer's input (a row vector) with the gradient of its output:

\[ dW_g = \mathbf{x}^\top\, d\mathbf{u} \in \mathbb{R}^{d \times n} \qquad\qquad dW_v = \mathbf{x}^\top\, d\mathbf{v} \in \mathbb{R}^{d \times n} \]

\[ dW_o = \mathbf{h}^\top\, d\mathbf{y} \in \mathbb{R}^{n \times d} \]

Over a batch of \(N\) sequences of length \(T\), the per-token outer products are summed. Stacking the \(NT\) tokens into the rows of \(X \in \mathbb{R}^{NT \times d}\), \(H, dU, dV \in \mathbb{R}^{NT \times n}\), and \(dY \in \mathbb{R}^{NT \times d}\) turns that sum into a single matrix product:

\[ dW_g = X^\top \, dU, \qquad dW_v = X^\top \, dV, \qquad dW_o = H^\top \, dY \]
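
The equivalence between summing per-token outer products and the stacked matrix product can be checked directly. This is a minimal sketch with made-up numbers for three tokens; the helpers `outer`, `mat_add`, and `matT_mat` are illustrative names, and only \(dW_g\) is shown since \(dW_v\) and \(dW_o\) follow the same pattern.

```python
def outer(a, b):
    """Outer product a^T b of two row vectors -> len(a) x len(b) matrix."""
    return [[ai * bj for bj in b] for ai in a]

def mat_add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def matT_mat(A, B):
    """A^T @ B for A (tokens x r) and B (tokens x c) -> r x c matrix."""
    return [[sum(A[t][i] * B[t][j] for t in range(len(A)))
             for j in range(len(B[0]))]
            for i in range(len(A[0]))]

# Three tokens stacked as rows; d = 2, n = 3.
X = [[1.0, -2.0],
     [0.5, 3.0],
     [2.0, 0.0]]
dU = [[0.1, 0.2, -0.3],
      [0.0, -1.0, 0.5],
      [0.4, 0.0, 0.25]]

# Per-token view: accumulate one outer product per token.
dW_g_sum = [[0.0] * 3 for _ in range(2)]
for x_t, du_t in zip(X, dU):
    dW_g_sum = mat_add(dW_g_sum, outer(x_t, du_t))

# Batched view: a single X^T dU product.
dW_g_mat = matT_mat(X, dU)

print(dW_g_sum)
print(dW_g_mat)
```

Both views produce a \(d \times n\) matrix, the same shape as \(W_g\).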

5. Summary

Swish'(u): \(\sigma(\mathbf{u}) + \mathrm{Swish}(\mathbf{u})\odot(1-\sigma(\mathbf{u}))\)

dh: \(d\mathbf{y}\,W_o^\top\)

du: \(d\mathbf{h}\odot\mathbf{v}\odot\mathrm{Swish}'(\mathbf{u})\)

dv: \(d\mathbf{h}\odot\mathrm{Swish}(\mathbf{u})\)

dx: \(d\mathbf{u}\,W_g^\top + d\mathbf{v}\,W_v^\top\)

Weight grads: \(dW_g = X^\top dU\),  \(dW_v = X^\top dV\),  \(dW_o = H^\top dY\)

| Quantity | Formula | Shape |
| --- | --- | --- |
| Forward: \(\mathbf{h}\) | \(\mathrm{Swish}(\mathbf{x}W_g)\odot(\mathbf{x}W_v)\) | \(\mathbb{R}^n\) |
| Forward: \(\mathbf{y}\) | \(\mathbf{h}\,W_o\) | \(\mathbb{R}^d\) |
| \(d\mathbf{h}\) | \(d\mathbf{y}\,W_o^\top\) | \(\mathbb{R}^n\) |
| \(d\mathbf{u}\) | \(d\mathbf{h}\odot\mathbf{v}\odot\mathrm{Swish}'(\mathbf{u})\) | \(\mathbb{R}^n\) |
| \(d\mathbf{v}\) | \(d\mathbf{h}\odot\mathrm{Swish}(\mathbf{u})\) | \(\mathbb{R}^n\) |
| \(d\mathbf{x}\) | \(d\mathbf{u}\,W_g^\top + d\mathbf{v}\,W_v^\top\) | \(\mathbb{R}^d\) |