Part 4 · Calculus · Chapter 15 · 85 min

Chain Rule and Computational Graphs

How gradients flow backward through composition

Learning objectives

  • Apply the chain rule to nested functions
  • Draw a computational graph and label local derivatives
  • Run a forward pass and a backward pass by hand
  • See backpropagation as the chain rule bookkept over a graph

Why the chain rule is the engine of deep learning

Every neural network you will ever train is one enormous composed function: raw pixels flow into a linear layer, then a nonlinearity, then another linear layer, then another nonlinearity, dozens of times, until a single loss number comes out the end. Training means asking one question over and over — if I nudge this weight, how does the loss change? — and adjusting the weight against that answer.

That question is a derivative of a deeply nested function. The tool that answers it is the chain rule, and the data structure that makes the answer cheap to compute is the computational graph. Put the two together and you have backpropagation — the algorithm that trains essentially every deep model in existence. This chapter builds it from the ground up, and by the end you will have written a tiny backprop pass by hand and checked it against the definition of the derivative.

Intuition: rates multiply through a chain

Suppose a gear train drives a wheel. Turning the crank one turn spins the middle gear 3 turns; each turn of the middle gear spins the wheel 2 turns. How fast does the wheel turn per turn of the crank? You multiply: $3 \times 2 = 6$. Rates of change compose by multiplication.

The chain rule says exactly this for functions. If $y$ depends on $u$, and $u$ depends on $x$, then the sensitivity of $y$ to $x$ is the sensitivity of $y$ to $u$ times the sensitivity of $u$ to $x$:

$$\frac{dy}{dx} = \frac{dy}{du}\cdot\frac{du}{dx}.$$

The intermediate variable $u$ acts like the middle gear. The notation almost begs you to "cancel" the $du$; that is a useful mnemonic, though not a literal cancellation of fractions. The real content is this: a small change $dx$ produces a change $du = \frac{du}{dx}\,dx$ in the middle, which produces a change $dy = \frac{dy}{du}\,du$ at the end. Substitute one into the other and the local rates multiply.
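
To see the rates multiply on actual numbers, here is a minimal Python sketch. The functions `g` and `f` (with $u = g(x) = x^2$ and $y = f(u) = \sin u$) and the evaluation point are illustrative assumptions, not the lesson's running example. The check compares the product of the two local rates against a central finite difference of the composed function.

```python
import math

# Illustrative composition (an assumption, not the lesson's running example):
# u = g(x) = x**2, y = f(u) = sin(u)
def g(x):
    return x ** 2

def f(u):
    return math.sin(u)

def dg_dx(x):
    # Local rate of the inner step: du/dx
    return 2 * x

def df_du(u):
    # Local rate of the outer step: dy/du
    return math.cos(u)

x = 1.3
u = g(x)

# Chain rule: multiply the two local rates.
chain = df_du(u) * dg_dx(x)

# Central finite difference on the composed function, for comparison.
eps = 1e-6
numeric = (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)

print(f"chain rule: {chain:.8f}")
print(f"numeric:    {numeric:.8f}")
```

The two printed values should agree to many decimal places; the small remaining gap is finite-difference and rounding error.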

[Interactive lab: Chain-Rule Visualizer]

The lab above traces our running example, $y = (wx+b)^2$, through its intermediate steps. Drag the inputs and watch each local rate light up, then watch them multiply back along the arrows. Keep it open; everything below is a written account of what that picture is doing.

The chain rule, formally
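
In its standard single-variable form (stated here the usual textbook way, with $f$ and $g$ standing for generic differentiable functions), the rule reads: if $g$ is differentiable at $x$ and $f$ is differentiable at $u = g(x)$, then $y = f(g(x))$ is differentiable at $x$ and

$$\frac{dy}{dx} = f'(g(x))\,g'(x) = \frac{dy}{du}\cdot\frac{du}{dx}.$$

For a longer chain $x = u_0 \to u_1 \to \dots \to u_n = y$, applying the same statement repeatedly gives one factor per link:

$$\frac{dy}{dx} = \prod_{k=1}^{n} \frac{du_k}{du_{k-1}}.$$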

The word local is the one to hold onto. Each step in a computation knows only its own tiny derivative — how its output moves when its input moves. The chain rule is the bookkeeping that assembles those purely local facts into the global answer.

A numerical example, two ways

Take the running function with a single variable in view. Fix $w$ and $b$ and ask for $\dfrac{dy}{dx}$, where $y = (wx + b)^2$.

We will differentiate it twice, by two different routes, and confirm the answers are identical — that is the whole promise of the chain rule.
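
A sketch of the two routes, written for general $w$ and $b$ and then evaluated at the sample point $w = 2$, $x = 3$, $b = 1$ (the values used in the graph walkthrough later in the chapter; choosing them at this point is an assumption):

$$
\begin{aligned}
\textbf{Route 1 (expand first):}\quad y &= w^2x^2 + 2wbx + b^2, \\
\frac{dy}{dx} &= 2w^2x + 2wb = 2(4)(3) + 2(2)(1) = 28. \\[4pt]
\textbf{Route 2 (chain rule):}\quad z &= wx + b, \quad y = z^2, \\
\frac{dy}{dx} &= \frac{dy}{dz}\cdot\frac{dz}{dx} = 2z \cdot w = 2(7)(2) = 28.
\end{aligned}
$$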

The two routes agree, as they must. The lesson is not that they agree — it is why the chain-rule route scales. Expanding worked here because the function was tiny. For a hundred-layer network there is no "expanded form" to write down; the chain rule, applied step by step, is the only feasible route.

Deriving it through the computational graph

To make the chain rule mechanical, we lay the computation out as a graph. Each node is one elementary operation; each edge carries a value forward and, later, a gradient backward. Our running example $y = (wx+b)^2$ breaks into three steps: $u = wx$, $z = u + b$, and $y = z^2$.

The graph has three inputs: $w$ and $x$ feed a multiply node ($u$), $u$ and $b$ feed an add node ($z$), and $z$ feeds a square node ($y$).

Forward pass — compute values, left to right

Pick concrete inputs $w = 2$, $x = 3$, $b = 1$ and walk the graph forward, storing each intermediate result (we will need them in a moment): $u = wx = 6$, $z = u + b = 7$, $y = z^2 = 49$.

That is an ordinary function evaluation. The one discipline that matters: save the intermediates $u$ and $z$. Backprop reuses them, and recomputing them from scratch at every node would repeat work the forward pass has already done, a waste that compounds quickly as the graph deepens.
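
As a minimal sketch, the forward pass can be written as a loop over the graph's nodes that records every value by name. The `graph` list, the `ops` table, and the `forward` function are illustrative names assumed here, not part of any library.

```python
# Nodes in forward (topological) order: (output name, operation, input names).
graph = [
    ("u", "mul", ("w", "x")),     # u = w * x
    ("z", "add", ("u", "b")),     # z = u + b
    ("y", "square", ("z",)),      # y = z ** 2
]

# Elementary operations available to the graph.
ops = {
    "mul": lambda a, b: a * b,
    "add": lambda a, b: a + b,
    "square": lambda a: a ** 2,
}

def forward(graph, inputs):
    """Evaluate each node in order, caching every intermediate by name."""
    cache = dict(inputs)
    for out, op, args in graph:
        cache[out] = ops[op](*(cache[name] for name in args))
    return cache

cache = forward(graph, {"w": 2.0, "x": 3.0, "b": 1.0})
print(cache["u"], cache["z"], cache["y"])  # 6.0 7.0 49.0
```

Keeping the whole `cache` dictionary, rather than returning only `y`, is the point: the backward pass reads its values instead of recomputing them.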

Local derivatives — one per edge

Before flowing anything backward, write down each node's own derivative with respect to its inputs. These are the "gears," and each is trivial in isolation:

The square node contributes $2z$; an addition node passes its gradient through untouched (derivative $1$ to each input); a multiply node hands each input the other factor.
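
Written out for the three nodes of the running example (multiply node $u = wx$, add node $z = u + b$, square node $y = z^2$), the local derivatives are:

$$
\frac{\partial y}{\partial z} = 2z, \qquad
\frac{\partial z}{\partial u} = 1, \quad \frac{\partial z}{\partial b} = 1, \qquad
\frac{\partial u}{\partial w} = x, \quad \frac{\partial u}{\partial x} = w.
$$

At the sample point $w = 2$, $x = 3$, $b = 1$ (so $z = 7$) these evaluate to $14$, $1$, $1$, $3$, and $2$ respectively.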

Backward pass — multiply local derivatives back to the inputs

Now start at the output with $\dfrac{dy}{dy} = 1$ and flow right to left, multiplying by the local derivative at each edge. This accumulated quantity, the derivative of the final output with respect to the current node, is the node's gradient.
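
With the inputs $w = 2$, $x = 3$, $b = 1$ and the cached value $z = 7$ from the forward pass, the sweep works out as follows. Each line multiplies a gradient already computed downstream by one local derivative:

$$
\begin{aligned}
\frac{dy}{dy} &= 1, \\
\frac{dy}{dz} &= \frac{dy}{dy}\cdot 2z = 1 \cdot 14 = 14, \\
\frac{dy}{du} &= \frac{dy}{dz}\cdot 1 = 14, \qquad \frac{dy}{db} = \frac{dy}{dz}\cdot 1 = 14, \\
\frac{dy}{dw} &= \frac{dy}{du}\cdot x = 14 \cdot 3 = 42, \qquad \frac{dy}{dx} = \frac{dy}{du}\cdot w = 14 \cdot 2 = 28.
\end{aligned}
$$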

Collecting the results in closed form (substitute $z = wx + b$):

$$\frac{dy}{dw} = 2(wx+b)\,x, \qquad \frac{dy}{db} = 2(wx+b), \qquad \frac{dy}{dx} = 2(wx+b)\,w.$$

Sanity check against the earlier example: $\dfrac{dy}{dx} = 2z\,w = 2(7)(2) = 28$, exactly the number we got by expanding and by the one-line chain rule. Three routes, one answer.

This is backpropagation

What we just did by hand is the backpropagation algorithm in full. There is no extra machinery hiding in PyTorch or JAX: a deep-learning framework builds the computational graph as your forward code runs, stores each intermediate, and then sweeps backward multiplying local derivatives, reusing those stored intermediates exactly as we reused $z$ and $u$. Training a network is:

  1. Forward pass: run the input through the graph to a scalar loss, caching intermediates.
  2. Backward pass: seed $\frac{d\,\text{loss}}{d\,\text{loss}} = 1$ and propagate gradients back to every weight.
  3. Update: nudge each weight against its gradient (next chapter: gradient descent).
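
The three steps can be sketched as a loop on the running example, treating $y = (wx+b)^2$ itself as the loss. This is only an illustration: the learning rate `lr`, the step count, and the choice to use $y$ as a loss are assumptions made here, and the update is plain gradient descent.

```python
# Toy training loop on y = (w*x + b)**2, treated here as the loss (an assumption).
w, b = 2.0, 1.0      # parameters to train
x = 3.0              # fixed input
lr = 0.01            # hypothetical learning rate

losses = []
for step in range(5):
    # 1. Forward pass: compute the loss, caching the intermediate z.
    z = w * x + b
    loss = z ** 2
    losses.append(loss)

    # 2. Backward pass: seed d(loss)/d(loss) = 1 and multiply local derivatives.
    dloss = 1.0
    dz = dloss * 2 * z   # square node
    dw = dz * x          # add node passes dz through; multiply node contributes x
    db = dz              # add node passes dz through

    # 3. Update: nudge each parameter against its gradient.
    w -= lr * dw
    b -= lr * db

print(losses)
```

The recorded losses start at 49 and shrink on every step, which is the behaviour the update rule is meant to produce for a small enough learning rate.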

Because the total derivative along any path is a product of local derivatives, the number of layers controls how many factors get multiplied — and that has a sharp practical consequence.
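
A quick numerical illustration of that consequence: multiply many local derivatives together and the product either collapses toward zero (a vanishing gradient) or blows up (an exploding gradient), depending on whether the typical factor sits below or above 1 in magnitude. The factor values 0.9 and 1.1 and the depth of 50 are arbitrary choices for illustration.

```python
import math

depth = 50  # number of layers, i.e. number of factors in the product

# Every local derivative slightly below 1: the gradient shrinks (vanishes).
shrinking = math.prod([0.9] * depth)

# Every local derivative slightly above 1: the gradient grows (explodes).
growing = math.prod([1.1] * depth)

print(f"0.9^{depth} = {shrinking:.6f}")
print(f"1.1^{depth} = {growing:.2f}")
```

Factors only 10% away from 1 already give products of roughly 0.005 and roughly 117 at this depth.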

NumPy: forward, backward, and a gradient check

The real test of a backward pass is whether it matches the definition of the derivative. We approximate each true derivative with a central finite difference, $\frac{dy}{d\theta} \approx \frac{y(\theta+\varepsilon) - y(\theta-\varepsilon)}{2\varepsilon}$, and assert that our analytic gradients agree. This "gradient check" is exactly how you debug a hand-written backprop in practice. Run it:

chain_rule_backprop.py
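
A minimal sketch of what such a script can look like (not the original listing). The function names `forward`, `backward`, and `numerical_grad`, the step size `eps`, and the tolerance are assumptions chosen for illustration; the only NumPy call used is `np.isclose`.

```python
import numpy as np

def forward(w, x, b):
    """Forward pass for y = (w*x + b)**2; returns y and a cache of intermediates."""
    u = w * x
    z = u + b
    y = z ** 2
    return y, {"w": w, "x": x, "u": u, "z": z}

def backward(cache):
    """Backward pass: walk the cache in reverse, multiplying local derivatives."""
    dy = 1.0                     # seed: dy/dy
    dz = dy * 2.0 * cache["z"]   # square node: dy/dz = 2z
    du = dz * 1.0                # add node passes the gradient through
    db = dz * 1.0
    dw = du * cache["x"]         # multiply node: each input gets the other factor
    dx = du * cache["w"]
    return {"w": dw, "x": dx, "b": db}

def numerical_grad(params, name, eps=1e-5):
    """Central finite difference of y with respect to params[name]."""
    plus = dict(params)
    plus[name] += eps
    minus = dict(params)
    minus[name] -= eps
    return (forward(**plus)[0] - forward(**minus)[0]) / (2 * eps)

params = {"w": 2.0, "x": 3.0, "b": 1.0}
y, cache = forward(**params)
grads = backward(cache)

# Gradient check: every analytic gradient must match its finite difference.
for name in params:
    numeric = numerical_grad(params, name)
    assert np.isclose(grads[name], numeric, rtol=1e-6), name
    print(f"dy/d{name}: analytic={grads[name]:.6f}  numeric={numeric:.6f}")
```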

Notice the shape of the backward function: it is a straight-line sequence of multiplications, one per node, walking the cache in reverse. That is all backpropagation ever is. Scale this pattern up to matrices and thousands of nodes and you have the training loop of a real network.

Summary

  • The chain rule composes rates by multiplication: for $y = f(g(x))$ with $u = g(x)$, $\dfrac{dy}{dx} = \dfrac{dy}{du}\cdot\dfrac{du}{dx}$, the outer derivative times the inner derivative. A long chain multiplies one local derivative per link.
  • A computational graph turns a formula into nodes (operations) and edges (values forward, gradients backward). Our running example splits as $u = wx$, $z = u + b$, $y = z^2$.
  • The forward pass computes and caches values; the backward pass seeds $\frac{dy}{dy} = 1$ and multiplies local derivatives right to left, reusing the cached intermediates. For $(wx+b)^2$ this gives $\frac{dy}{dw} = 2zx$, $\frac{dy}{db} = 2z$, $\frac{dy}{dx} = 2zw$.
  • This bookkeeping is backpropagation, the algorithm that trains essentially every neural network.
  • Because a deep gradient is a product of many local derivatives, it can vanish (factors of magnitude $< 1$) or explode (factors of magnitude $> 1$); much of modern architecture design exists to control that product.
  • A variable used on multiple paths sums the gradients from each path; forgetting one is the classic backprop bug, caught by a finite-difference gradient check.
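
A small sketch of the multi-path rule, on a hypothetical function chosen for illustration (it is not the chapter's running example): $y = u \cdot v$ with $u = x + 1$ and $v = 2x$, so $x$ reaches $y$ along two paths. The function names and the evaluation point are assumptions.

```python
# Hypothetical two-path example: x feeds both u and v, and y = u * v.
def forward(x):
    u = x + 1.0
    v = 2.0 * x
    y = u * v
    return y, {"u": u, "v": v}

def backward(cache):
    dy = 1.0
    du = dy * cache["v"]          # multiply node: u gets the other factor v
    dv = dy * cache["u"]          # multiply node: v gets the other factor u
    dx_via_u = du * 1.0           # path x -> u -> y  (du/dx = 1)
    dx_via_v = dv * 2.0           # path x -> v -> y  (dv/dx = 2)
    # Return the correct sum, plus the result of the one-path bug for contrast.
    return dx_via_u + dx_via_v, dx_via_u

x = 3.0
y, cache = forward(x)
dx_correct, dx_buggy = backward(cache)

# Central finite difference as the reference.
eps = 1e-6
numeric = (forward(x + eps)[0] - forward(x - eps)[0]) / (2 * eps)

print(f"summed paths: {dx_correct}, one path only: {dx_buggy}, numeric: {numeric:.6f}")
```

Only the summed gradient matches the finite difference; the one-path value is off by exactly the contribution of the dropped path, which is what a gradient check flags.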

Active recall

Answer from memory before checking the lesson:

  1. State the chain rule for $y = f(g(x))$ and explain, in words, why the two local derivatives multiply rather than add.
  2. For $y = (wx+b)^2$ with $w = 2$, $x = 3$, $b = 1$, run the forward pass and report $u$, $z$, and $y$.
  3. From that forward pass, run the backward pass to get $\frac{dy}{dw}$, $\frac{dy}{db}$, and $\frac{dy}{dx}$. What role does the cached value $z$ play?
  4. Why can a gradient vanish in a deep network, and what does that do to the early layers' learning?
  5. A variable is used in two places in a graph. How do you combine the gradients coming back along the two paths, and what goes wrong if you don't?

Exercises

Level A · Recall & basic calculation

Level B · Conceptual understanding

Level C · Derivation & implementation

Level D · Research-thinking challenge