
Backpropagation — the Chain Rule Unwrapped

Backpropagation is the algorithm that makes training deep networks practical. It is the chain rule of calculus applied systematically to a computational graph, yet when Rumelhart, Hinton and Williams popularised it in 1986 it transformed the field. We trace the math from scalar chain rules through Jacobians and matrix calculus, then build a minimal autograd engine that automatically differentiates expressions built from its supported operations (add, multiply, power, ReLU).

1. The Chain Rule

Suppose f(x) = g(h(x)). The chain rule says the derivative of f with respect to x is:

df/dx = dg/dh · dh/dx

Example: f = sin(x²)
  h = x²      → dh/dx = 2x
  g = sin(h)  → dg/dh = cos(h)
  df/dx = cos(x²) · 2x

For a chain of n composed functions this generalises to a product of n local derivatives, each evaluated at the value produced by the function before it. Backprop is exactly this product, evaluated layer-by-layer starting from the loss output.
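As a quick sanity check of the sin(x²) example, the sketch below compares the chain-rule derivative with a central finite difference. The function names, the test point 1.3 and the step size h are illustrative choices, not part of the derivation itself.

```javascript
// f(x) = sin(x^2), built from h(x) = x^2 and g(h) = sin(h)
const f = x => Math.sin(x * x);

// Chain rule: df/dx = dg/dh * dh/dx = cos(x^2) * 2x
const dfAnalytic = x => Math.cos(x * x) * 2 * x;

// Central finite difference as an independent numerical estimate
const dfNumeric = (x, h = 1e-5) => (f(x + h) - f(x - h)) / (2 * h);

const x0 = 1.3; // arbitrary test point
console.log(dfAnalytic(x0), dfNumeric(x0)); // the two values should agree closely
```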

2. Computational Graphs

A computational graph is a directed acyclic graph (DAG) where nodes are operations (+, ·, exp, …) and edges carry values. Every mathematical expression can be expressed as such a graph:

L = (a·w + b − y)²

Graph:
  m = a·w    (multiply)
  s = m + b  (add)
  e = s − y  (subtract)
  L = e²     (square)

During the forward pass we evaluate values left-to-right. During the backward pass we multiply local partial derivatives right-to-left to accumulate each node's contribution to the total gradient.

3. Forward and Backward Passes

For the graph above, the backward pass (setting dL/dL = 1 as seed) propagates:

dL/de = 2e
dL/ds = dL/de · de/ds = 2e · 1 = 2e
dL/dm = dL/ds · ds/dm = 2e · 1 = 2e
dL/db = dL/ds · ds/db = 2e · 1 = 2e
dL/dw = dL/dm · dm/dw = 2e · a
dL/da = dL/dm · dm/da = 2e · w

Each node accumulates gradient from all nodes that depend on it (multivariate chain rule). Where a value is used in multiple downstream nodes its gradient contributions are summed.
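The forward and backward passes for this graph can be written out by hand, one node at a time. This is a minimal sketch: the numeric values are assumptions borrowed from the autograd example in section 7 (with a playing the role of x), and the variable names such as dL_dw are only illustrative.

```javascript
// Forward pass for L = (a*w + b - y)^2, one node at a time
const a = 2.0, w = -3.0, b = 6.88, y = 1.0; // illustrative values
const m = a * w; // multiply
const s = m + b; // add
const e = s - y; // subtract
const L = e * e; // square

// Backward pass: seed dL/dL = 1 and walk the graph in reverse
const dL_de = 2 * e;
const dL_ds = dL_de * 1;  // de/ds = 1
const dL_dy = dL_de * -1; // de/dy = -1
const dL_dm = dL_ds * 1;  // ds/dm = 1
const dL_db = dL_ds * 1;  // ds/db = 1
const dL_dw = dL_dm * a;  // dm/dw = a
const dL_da = dL_dm * w;  // dm/da = w

console.log({ L, dL_dw, dL_da, dL_db, dL_dy });
```

Every backward line multiplies the gradient arriving from downstream by one local derivative, which is all the chain rule asks for.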

4. Jacobians and Matrix Gradients

When inputs and outputs are vectors, the "derivative" is the Jacobian matrix J, where Jᵢⱼ = ∂yᵢ/∂xⱼ. For a linear layer y = Wx + b:

∂L/∂W = δ · xᵀ   (outer product)
∂L/∂x = Wᵀ · δ   (passed to previous layer)
∂L/∂b = δ        (sum over batch dimension)
where δ = ∂L/∂y  (upstream gradient vector)
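The three linear-layer formulas translate almost directly into code. The sketch below handles a single sample with plain arrays; the function name linearBackward and the example numbers are hypothetical, chosen only to make the shapes visible.

```javascript
// Backward pass of a linear layer y = W x + b for a single sample.
// W is an array of rows (shape out x in); x and delta are plain arrays.
function linearBackward(W, x, delta) {
  // dL/dW = delta * x^T (outer product), same shape as W
  const dW = delta.map(d => x.map(xj => d * xj));
  // dL/dx = W^T * delta, passed on to the previous layer
  const dx = x.map((_, j) => delta.reduce((acc, d, i) => acc + W[i][j] * d, 0));
  // dL/db = delta (a copy, since there is only one sample)
  const db = delta.slice();
  return { dW, dx, db };
}

const W = [[1, 2, 3], [4, 5, 6]]; // 2 outputs, 3 inputs
const x = [1, 0, -1];
const delta = [0.5, -1];          // upstream gradient dL/dy
const grads = linearBackward(W, x, delta);
console.log(grads);
```

Note that dW has the shape of W and dx has the shape of x, which is a handy check when debugging matrix gradients.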

Element-wise ops

For y = σ(z) element-wise, the Jacobian is diagonal: ∂L/∂z = ∂L/∂y ⊙ σ'(z). Just a pointwise product.

Softmax Jacobian

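For ŷ = softmax(z), Jᵢⱼ = ŷᵢ(δᵢⱼ − ŷⱼ), a full dense matrix (δᵢⱼ here is the Kronecker delta, not the upstream gradient). Combined with cross-entropy loss against a one-hot target y, the gradient with respect to z simplifies to ŷ − y.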
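The softmax and cross-entropy simplification is easy to confirm in code. The following is a minimal sketch, assuming a one-hot target vector; the helper names softmax, crossEntropy and gradLogits, and the example logits, are illustrative rather than taken from any library.

```javascript
// Numerically stable softmax over a plain array of logits
function softmax(z) {
  const zMax = Math.max(...z);
  const exps = z.map(v => Math.exp(v - zMax));
  const total = exps.reduce((acc, v) => acc + v, 0);
  return exps.map(v => v / total);
}

// Cross-entropy loss against a one-hot target vector
const crossEntropy = (z, target) =>
  -softmax(z).reduce((acc, p, i) => acc + target[i] * Math.log(p), 0);

// Closed-form gradient with respect to the logits: softmax(z) - target
const gradLogits = (z, target) => softmax(z).map((p, i) => p - target[i]);

const logits = [2.0, 1.0, 0.1];
const target = [0, 1, 0]; // true class is index 1
console.log(gradLogits(logits, target));
```

The gradient entries sum to zero (up to rounding), because the softmax outputs and the one-hot target each sum to one.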

Batch dimension

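With a mini-batch of m samples, weight gradients are averaged over the batch: dW = (δ · Aᵀ) / m, where A stacks the layer inputs as columns (one per sample) and δ stacks the matching upstream gradients. Averaging keeps the gradient magnitude independent of batch size.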

Gradient check

Verify analytical gradients numerically: compare dL/dw to [L(w+ε)−L(w−ε)]/(2ε) for a small ε (around 1e-5 in double precision). If the two agree to about 6 significant digits, your backprop is very likely correct.
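A gradient check along these lines might look as follows. This is a sketch under stated assumptions: gradCheck is a hypothetical helper, the relative-error formula is one common convention among several, and the toy loss is made up for the demonstration.

```javascript
// Central-difference gradient check for a scalar loss over a parameter array.
// Returns the largest relative error between analytic and numeric gradients.
function gradCheck(lossFn, params, analyticGrad, eps = 1e-5) {
  let worst = 0;
  for (let i = 0; i < params.length; i++) {
    const plus = params.slice();
    plus[i] += eps;
    const minus = params.slice();
    minus[i] -= eps;
    const numeric = (lossFn(plus) - lossFn(minus)) / (2 * eps);
    // Guard the denominator so two zero gradients do not divide by zero
    const denom = Math.max(Math.abs(numeric) + Math.abs(analyticGrad[i]), 1e-12);
    worst = Math.max(worst, Math.abs(numeric - analyticGrad[i]) / denom);
  }
  return worst;
}

// Toy loss L(p) = (2*p0 + p1 - 1)^2 with its hand-derived gradient
const loss = p => (2 * p[0] + p[1] - 1) ** 2;
const params = [-3.0, 6.88];
const residual = 2 * params[0] + params[1] - 1;
const analytic = [2 * residual * 2, 2 * residual];
console.log(gradCheck(loss, params, analytic)); // a small value suggests the gradients match
```

Perturbing one parameter at a time costs two loss evaluations per parameter, so this is a debugging tool rather than a training method.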

5. Vanishing and Exploding Gradients

In a 20-layer network, the gradient of the first layer involves a product of 20 weight matrices and 20 activation derivatives. If each factor has magnitude < 1 (sigmoid saturates to σ' ≈ 0), the product shrinks exponentially → vanishing gradient. If each factor > 1, it grows exponentially → exploding gradient.

Vanishing: σ'(z) ≤ 0.25 for sigmoid → (0.25)²⁰ ≈ 10⁻¹²
Exploding: ||W|| > 1 per layer → ||W||²⁰ grows exponentially with depth
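The arithmetic is worth seeing once in actual numbers. In the sketch below, gradientScale is a hypothetical helper that treats every layer as contributing the same scalar factor, which is a deliberate simplification; the factor 1.5 is an arbitrary stand-in for a per-layer gain above 1.

```javascript
// Product of identical per-layer factors after `depth` layers
const gradientScale = (factor, depth) => Math.pow(factor, depth);

console.log(gradientScale(0.25, 20)); // sigmoid upper bound per layer: about 9.1e-13
console.log(gradientScale(1.5, 20));  // illustrative factor above 1: about 3.3e3
```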

Solutions:

6. Beyond SGD — Adam and Friends

Plain SGD applies the same learning rate to every weight and ignores its gradient history. Adaptive optimisers track per-parameter gradient statistics:

Adam (Kingma & Ba 2014):
  m_t = β₁ m_{t-1} + (1−β₁) g_t     (1st moment: mean)
  v_t = β₂ v_{t-1} + (1−β₂) g_t²    (2nd moment: uncentred variance)
  m̂_t = m_t / (1−β₁ᵗ)               (bias correction)
  v̂_t = v_t / (1−β₂ᵗ)
  θ_t = θ_{t-1} − lr · m̂_t / (√v̂_t + ε)

Typical: β₁=0.9, β₂=0.999, ε=1e-8, lr=0.001

Adam often converges faster than plain SGD and is less sensitive to the learning-rate choice. For fine-tuning language models, AdamW decouples weight decay from the gradient-based update instead of folding an L2 penalty into the gradient.
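The Adam update equations map onto a short function. This is a minimal sketch over a flat parameter array, not a production optimiser: adamStep and the shape of the state object are assumptions made for illustration, and the toy loss f(θ) = θ² is chosen only because its gradient is trivial.

```javascript
// One Adam update over a flat array of parameters.
// `state` holds the running moments m, v and the step counter t.
function adamStep(theta, grad, state, { lr = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8 } = {}) {
  state.t += 1;
  return theta.map((p, i) => {
    state.m[i] = beta1 * state.m[i] + (1 - beta1) * grad[i];
    state.v[i] = beta2 * state.v[i] + (1 - beta2) * grad[i] * grad[i];
    const mHat = state.m[i] / (1 - Math.pow(beta1, state.t)); // bias correction
    const vHat = state.v[i] / (1 - Math.pow(beta2, state.t));
    return p - lr * mHat / (Math.sqrt(vHat) + eps);
  });
}

// Usage: a few steps on the toy loss f(theta) = theta^2, whose gradient is 2*theta
let theta = [1.0];
const state = { m: [0], v: [0], t: 0 };
for (let step = 0; step < 5; step++) {
  theta = adamStep(theta, theta.map(p => 2 * p), state);
}
console.log(theta);
```

On the very first step the bias-corrected ratio m̂/√v̂ is close to ±1, so each parameter moves by roughly lr regardless of the gradient's scale.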

7. Minimal Autograd Engine

// Scalar autograd — reverse-mode automatic differentiation
class Value {
  constructor(data, _children = [], _op = '') {
    this.data = data;
    this.grad = 0;
    this._backward = () => {};
    this._prev = new Set(_children);
    this._op = _op;
  }
  add(other) {
    other = other instanceof Value ? other : new Value(other);
    const out = new Value(this.data + other.data, [this, other], '+');
    out._backward = () => { this.grad += out.grad; other.grad += out.grad; };
    return out;
  }
  mul(other) {
    other = other instanceof Value ? other : new Value(other);
    const out = new Value(this.data * other.data, [this, other], '*');
    out._backward = () => {
      this.grad  += other.data * out.grad;
      other.grad += this.data  * out.grad;
    };
    return out;
  }
  pow(n) {
    const out = new Value(this.data ** n, [this], `**${n}`);
    out._backward = () => { this.grad += n * (this.data ** (n - 1)) * out.grad; };
    return out;
  }
  relu() {
    const out = new Value(this.data > 0 ? this.data : 0, [this], 'relu');
    out._backward = () => { this.grad += (out.data > 0 ? 1 : 0) * out.grad; };
    return out;
  }
  backward() {
    // Topological sort, then call _backward in reverse order
    const topo = []; const visited = new Set();
    const build = v => {
      if (!visited.has(v)) {
        visited.add(v);
        for (const child of v._prev) build(child);
        topo.push(v);
      }
    };
    build(this);
    this.grad = 1;
    for (const v of topo.reverse()) v._backward();
  }
}

// Example: L = (x·w + b - y)²
const x = new Value(2.0);
const w = new Value(-3.0);
const b = new Value(6.88);
const y = new Value(1.0);
const L = x.mul(w).add(b).add(y.mul(-1)).pow(2);
L.backward();
console.log(w.grad); // ∂L/∂w = 2*(x*w + b - y)*x ≈ -0.48

8. Automatic Differentiation Modes

Reverse mode (backprop)

One backward pass computes ∂L/∂θ for ALL parameters simultaneously. Ideal when outputs ≪ inputs — exactly the neural net case.

Forward mode

Computes the Jacobian-vector product Jv in a single forward pass; building the full gradient takes one pass per input dimension. Efficient when inputs ≪ outputs (rare in ML).

Symbolic diff

Derives closed-form expressions (Mathematica, SymPy). Exact but can produce exponentially large expressions ("expression swell").

Numerical diff

Finite differences [f(x+h)−f(x)]/h. Simple but slow (one pass per parameter) and subject to floating-point cancellation errors.
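To contrast with the reverse-mode engine in section 7, forward mode can be sketched with dual numbers: every value carries its derivative along with it, so no backward pass is needed. The Dual class below is a hypothetical minimal version that supports only add, mul and sin.

```javascript
// Forward-mode AD with dual numbers: each value carries (val, dot),
// where dot is the derivative with respect to one chosen input.
class Dual {
  constructor(val, dot = 0) { this.val = val; this.dot = dot; }
  add(o) { return new Dual(this.val + o.val, this.dot + o.dot); }
  // Product rule: (uv)' = u'v + uv'
  mul(o) { return new Dual(this.val * o.val, this.dot * o.val + this.val * o.dot); }
  // Chain rule through sin: (sin u)' = cos(u) * u'
  sin() { return new Dual(Math.sin(this.val), Math.cos(this.val) * this.dot); }
}

// f(x) = sin(x^2), the example from section 1
const fDual = x => x.mul(x).sin();

// Seed dot = 1 on the input we differentiate with respect to
const out = fDual(new Dual(1.3, 1));
console.log(out.val, out.dot); // out.dot should match cos(x^2) * 2x at x = 1.3
```

Each call yields the derivative with respect to the single seeded input, which is why a function with many inputs needs many forward passes.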

PyTorch uses reverse-mode AD over a dynamic computation graph that is built as the code runs (define-by-run). JAX instead traces pure Python functions and exposes AD as function transformations: grad for reverse mode, jvp and jacfwd for forward mode. These compose cleanly with jit and vmap because of its functional design.
