Chain Rule and Computational Graphs
How gradients flow backward through composition
Prerequisites
Learning objectives
- Apply the chain rule to nested functions
- Draw a computational graph and label local derivatives
- Run a forward pass and a backward pass by hand
- See backpropagation as the chain rule bookkept over a graph
Why the chain rule is the engine of deep learning
Every neural network you will ever train is one enormous composed function: raw pixels flow into a linear layer, then a nonlinearity, then another linear layer, then another nonlinearity, dozens of times, until a single loss number comes out the end. Training means asking one question over and over — if I nudge this weight, how does the loss change? — and adjusting the weight against that answer.
That question is a derivative of a deeply nested function. The tool that answers it is the chain rule, and the data structure that makes the answer cheap to compute is the computational graph. Put the two together and you have backpropagation — the algorithm that trains essentially every deep model in existence. This chapter builds it from the ground up, and by the end you will have written a tiny backprop pass by hand and checked it against the definition of the derivative.
Intuition: rates multiply through a chain
Suppose a gear train drives a wheel. Turning the crank one turn spins the middle gear 3 turns; each turn of the middle gear spins the wheel 2 turns. How fast does the wheel turn per turn of the crank? You multiply: $3 \times 2 = 6$ wheel turns per crank turn. Rates of change compose by multiplication.
The chain rule says exactly this for functions. If $y$ depends on $u$, and $u$ depends on $x$, then the sensitivity of $y$ to $x$ is the sensitivity of $y$ to $u$ times the sensitivity of $u$ to $x$:
$$\frac{dy}{dx} = \frac{dy}{du}\cdot\frac{du}{dx}$$
The intermediate variable $u$ acts like the middle gear. Notice how the notation almost begs you to "cancel" the $du$; that is a useful mnemonic, though not a literal fraction cancellation. The real content is: a small change $dx$ produces a change $du$ in the middle, which produces a change $dy$ at the end. Substitute one into the other and the local rates multiply.
The lab above traces our running example, $y = (wx + b)^2$, through its intermediate steps. Drag the inputs and watch each local rate light up, then watch them multiply back along the arrows. Keep it open; everything below is a written account of what that picture is doing.
The chain rule, formally
| Symbol | Meaning | Type | Shape | Role |
|---|---|---|---|---|
| $x$ | Input variable we differentiate with respect to | scalar | 1 | variable |
| $u$ | Intermediate value (the inner function) | scalar | 1 | variable |
| $y$ | Output (the outer function) | scalar | 1 | variable |
| $\frac{du}{dx}$ | Local derivative of the inner step | scalar | 1 | derivative |
| $\frac{dy}{du}$ | Local derivative of the outer step | scalar | 1 | derivative |
| $\frac{dy}{dx}$ | Total derivative (product of the locals) | scalar | 1 | derivative |
The word local is the one to hold onto. Each step in a computation knows only its own tiny derivative — how its output moves when its input moves. The chain rule is the bookkeeping that assembles those purely local facts into the global answer.
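As a quick numerical illustration of local derivatives assembling into the global one, the sketch below uses a hypothetical pair of functions, `g(x) = x**2 + 1` and `f(u) = sin(u)` (neither appears in this lesson), and compares the product of the two local derivatives with a central finite difference of the composed function. The evaluation point and step size are arbitrary choices.

```python
import math

def g(x):
    """Inner step: u = x^2 + 1 (hypothetical example function)."""
    return x * x + 1.0

def f(u):
    """Outer step: y = sin(u) (hypothetical example function)."""
    return math.sin(u)

def chain_rule_derivative(x):
    """dy/dx assembled from the two local derivatives."""
    u = g(x)
    du_dx = 2.0 * x        # local derivative of the inner step
    dy_du = math.cos(u)    # local derivative of the outer step, evaluated at u
    return dy_du * du_dx

def numeric_derivative(x, h=1e-6):
    """Central finite difference of the composed function f(g(x))."""
    return (f(g(x + h)) - f(g(x - h))) / (2.0 * h)

x0 = 0.7  # arbitrary evaluation point
analytic = chain_rule_derivative(x0)
numeric = numeric_derivative(x0)
print(analytic, numeric)
```

The two printed values should agree to several decimal places; neither local derivative needs to know anything about the other step.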
A numerical example, two ways
Take the running function with a single variable in view. Hold $x$ and $b$ fixed and ask for $\frac{\partial y}{\partial w}$, where $y = (wx + b)^2$.
We will differentiate it twice, by two different routes, and confirm the answers are identical — that is the whole promise of the chain rule.
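Keeping $x$ and $b$ symbolic, the two routes can be sketched as follows. The intermediate name $v$ is introduced here only as shorthand for the inner expression $wx + b$.

```latex
\begin{aligned}
\textbf{Route 1 (expand first):}\quad
y &= (wx+b)^2 = w^2x^2 + 2wxb + b^2, \\
\frac{\partial y}{\partial w} &= 2wx^2 + 2xb = 2x\,(wx+b). \\[6pt]
\textbf{Route 2 (chain rule):}\quad
v &= wx+b, \qquad y = v^2, \\
\frac{\partial y}{\partial w} &= \frac{dy}{dv}\cdot\frac{\partial v}{\partial w}
  = 2v\cdot x = 2x\,(wx+b).
\end{aligned}
```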
The two routes agree, as they must. The lesson is not that they agree but why the chain-rule route scales. Expanding $(wx + b)^2$ worked here because the function was tiny. For a hundred-layer network there is no "expanded form" to write down; the chain rule, applied step by step, is the only feasible route.
Deriving it through the computational graph
To make the chain rule mechanical, we lay the computation out as a graph. Each node is one elementary operation; each edge carries a value forward and, later, a gradient backward. Our running example breaks into three steps:
$$u = wx, \qquad v = u + b, \qquad y = v^2$$
The graph has inputs $w$, $x$, $b$ feeding a multiply node ($u = wx$), an add node ($v = u + b$), and a square node ($y = v^2$).
Forward pass — compute values, left to right
Pick concrete values for the inputs $w$, $x$, and $b$, and walk the graph forward, computing $u = wx$, then $v = u + b$, then $y = v^2$, and storing each intermediate result (we will need them in a moment).
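A minimal sketch of this forward pass in plain Python. The input values `w = 2`, `x = 3`, `b = 1` are hypothetical choices made for illustration, and the function name `forward` and the dictionary cache are assumptions rather than a fixed API.

```python
def forward(w, x, b):
    """Evaluate y = (w*x + b)**2 node by node, caching every value."""
    u = w * x        # multiply node
    v = u + b        # add node
    y = v ** 2       # square node
    cache = {"w": w, "x": x, "b": b, "u": u, "v": v}
    return y, cache

# Hypothetical inputs, chosen only for illustration.
y, cache = forward(w=2.0, x=3.0, b=1.0)
print(y, cache["u"], cache["v"])  # prints 49.0 6.0 7.0
```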
That is an ordinary function evaluation. The one discipline that matters: save the intermediates $u$ and $v$. Backprop reuses them, and recomputing them is the difference between a cheap algorithm and an exponentially expensive one.
Local derivatives — one per edge
Before flowing anything backward, write down each node's own derivative with respect to its inputs. These are the "gears," and each is trivial in isolation:
The square node contributes $\frac{dy}{dv} = 2v$; an addition node passes its gradient through untouched (derivative $1$ to each input); a multiply node hands each input the other factor.
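Written out node by node, and assuming the decomposition $u = wx$, $v = u + b$, $y = v^2$, the local derivatives would be:

```latex
\begin{aligned}
\text{multiply node } u = wx:&\quad
  \frac{\partial u}{\partial w} = x, \qquad \frac{\partial u}{\partial x} = w \\
\text{add node } v = u + b:&\quad
  \frac{\partial v}{\partial u} = 1, \qquad \frac{\partial v}{\partial b} = 1 \\
\text{square node } y = v^2:&\quad
  \frac{dy}{dv} = 2v
\end{aligned}
```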
Backward pass — multiply local derivatives back to the inputs
Now start at the output with $\frac{\partial y}{\partial y} = 1$ and flow right to left, at each edge multiplying by the local derivative. This accumulated quantity, the derivative of the final output with respect to the current node, is the node's gradient.
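Step by step, and again assuming the decomposition $u = wx$, $v = u + b$, $y = v^2$, the sweep can be sketched as:

```latex
\begin{aligned}
\frac{\partial y}{\partial y} &= 1 && \text{(seed)} \\
\frac{\partial y}{\partial v} &= \frac{\partial y}{\partial y}\cdot 2v = 2v && \text{(square node)} \\
\frac{\partial y}{\partial u} &= \frac{\partial y}{\partial v}\cdot 1 = 2v, \qquad
\frac{\partial y}{\partial b} = \frac{\partial y}{\partial v}\cdot 1 = 2v && \text{(add node)} \\
\frac{\partial y}{\partial w} &= \frac{\partial y}{\partial u}\cdot x = 2vx, \qquad
\frac{\partial y}{\partial x} = \frac{\partial y}{\partial u}\cdot w = 2vw && \text{(multiply node)}
\end{aligned}
```

Each line uses only the gradient already computed for the node to its right and one local derivative.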
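Collecting the results in closed form (substitute $v = wx + b$):
$$\frac{\partial y}{\partial w} = 2(wx+b)\,x, \qquad \frac{\partial y}{\partial x} = 2(wx+b)\,w, \qquad \frac{\partial y}{\partial b} = 2(wx+b)$$
Sanity check against the earlier example: with $x$ and $b$ held fixed, $\frac{\partial y}{\partial w} = 2(wx+b)\,x$ is exactly the result we got by expanding, and by the one-line chain rule. Three routes, one answer.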
This is backpropagation
What we just did by hand is the backpropagation algorithm in full. There is no extra machinery hiding in PyTorch or JAX: a deep-learning framework builds the computational graph as your forward code runs, stores each intermediate, and then sweeps backward multiplying local derivatives, reusing those stored intermediates exactly as we reused the values cached in our forward pass. Training a network is:
- Forward pass: run the input through the graph to a scalar loss $L$, caching intermediates.
- Backward pass: seed $\frac{\partial L}{\partial L} = 1$ and propagate gradients back to every weight.
- Update: nudge each weight against its gradient (next chapter: gradient descent).
Because the total derivative along any path is a product of local derivatives, the number of layers controls how many factors get multiplied — and that has a sharp practical consequence.
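To get a feel for that consequence, the sketch below multiplies the same local derivative across a stack of layers. The depth of 20 and the factor 1.5 are hypothetical values; 0.25 is the largest value the sigmoid's derivative can take.

```python
def gradient_scale(local_derivative, num_layers):
    """Product of identical local derivatives along a chain of layers."""
    scale = 1.0
    for _ in range(num_layers):
        scale *= local_derivative
    return scale

num_layers = 20  # hypothetical depth
shrinking = gradient_scale(0.25, num_layers)  # every factor below 1
growing = gradient_scale(1.5, num_layers)     # every factor above 1
print(f"0.25^{num_layers} = {shrinking:.3e}")
print(f"1.5^{num_layers}  = {growing:.3e}")
```

Factors below 1 drive the product toward zero, while factors above 1 make it grow quickly; both effects get stronger with depth.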
NumPy: forward, backward, and a gradient check
The real test of a backward pass is whether it matches the definition of the derivative. We approximate each true derivative with a finite difference, $\frac{f(\theta + h) - f(\theta - h)}{2h}$, and assert our analytic gradients agree. This "gradient check" is exactly how you debug a hand-written backprop in practice. Run it:
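A minimal sketch of such a check with NumPy. The inputs `w = 2`, `x = 3`, `b = 1`, the step `h = 1e-5`, and the helper names `forward`, `backward`, and `numeric_grad` are assumptions made for illustration; the finite difference used here is the central one.

```python
import numpy as np

def forward(w, x, b):
    """Forward pass for y = (w*x + b)**2; returns y and a cache."""
    u = w * x
    v = u + b
    y = v ** 2
    return y, (w, x, v)

def backward(cache):
    """Backward pass: multiply local derivatives from y back to the inputs."""
    w, x, v = cache
    dy = 1.0            # seed: dy/dy
    dv = dy * 2.0 * v   # square node
    du = dv * 1.0       # add node passes the gradient through
    db = dv * 1.0
    dw = du * x         # multiply node: each input gets the other factor
    dx = du * w
    return np.array([dw, dx, db])

def numeric_grad(params, h=1e-5):
    """Central finite difference for each of w, x, b."""
    grads = np.zeros(len(params))
    for i in range(len(params)):
        plus = np.array(params, dtype=float)
        minus = np.array(params, dtype=float)
        plus[i] += h
        minus[i] -= h
        grads[i] = (forward(*plus)[0] - forward(*minus)[0]) / (2.0 * h)
    return grads

params = [2.0, 3.0, 1.0]  # hypothetical w, x, b
_, cache = forward(*params)
analytic = backward(cache)
numeric = numeric_grad(params)
assert np.allclose(analytic, numeric, atol=1e-6)
print("analytic:", analytic)
print("numeric: ", numeric)
print("ok")
```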
Notice the shape of the backward function: it is a straight-line sequence of multiplications, one per node, walking the cache in reverse. That is all backpropagation ever is. Scale this pattern up to matrices and thousands of nodes and you have the training loop of a real network.
Summary
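- The chain rule composes rates by multiplication: for $y = f(g(x))$, $\frac{dy}{dx} = f'(g(x))\,g'(x)$, that is, outer derivative times inner derivative. A long chain multiplies one local derivative per link.
- A computational graph turns a formula into nodes (operations) and edges (values forward, gradients backward). Our running example splits as $u = wx$, $v = u + b$, $y = v^2$.
- The forward pass computes and caches values; the backward pass seeds $\frac{\partial y}{\partial y} = 1$ and multiplies local derivatives right-to-left, reusing the cached intermediates. For $y = (wx+b)^2$ this gives $\frac{\partial y}{\partial w} = 2(wx+b)\,x$, $\frac{\partial y}{\partial x} = 2(wx+b)\,w$, $\frac{\partial y}{\partial b} = 2(wx+b)$.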
- This bookkeeping is backpropagation — the algorithm that trains every neural network.
- Because a deep gradient is a product of many local derivatives, it can vanish (factors $< 1$) or explode (factors $> 1$); much of modern architecture design exists to control that product.
- A variable used on multiple paths sums the gradients from each path; forgetting one is the classic backprop bug, caught by a finite-difference gradient check.
Active recall
Answer from memory before checking the lesson:
- State the chain rule for $y = f(g(x))$ and explain, in words, why the two local derivatives multiply rather than add.
- For $y = (wx + b)^2$, pick concrete values for $w$, $x$, and $b$, run the forward pass, and report $u$, $v$, and $y$.
- From that forward pass, run the backward pass to get $\frac{\partial y}{\partial w}$, $\frac{\partial y}{\partial x}$, and $\frac{\partial y}{\partial b}$. What role does the cached value $v$ play?
- Why can a gradient vanish in a deep network, and what does that do to the early layers' learning?
- A variable is used in two places in a graph. How do you combine the gradients coming back along the two paths, and what goes wrong if you don't?
Exercises
Level A: Recall & basic calculation
Chain rule at a point
Let . Use the chain rule to compute at .
Forward pass of the graph
For decomposed as , , , run the forward pass with , , and report .
Local derivative of the square node
The square node computes . What is its local derivative evaluated at ?
Gradient with respect to the bias
For , the backward pass gives . Evaluate it at , , .
Gradient through an add node
In the graph , the gradient arriving at is . What gradient does the add node pass back to ?
Gradient through a multiply node
In the graph , the gradient arriving at is and . What gradient does the multiply node pass back to ?
Level B: Conceptual understanding
Why local derivatives multiply
For $y = f(u)$ with $u = g(x)$, why is $\frac{dy}{dx}$ a product of the two local rates rather than a sum?
Backward pass with new inputs
For run forward with , , , then use the backward pass to compute . (Recall .)
Diagnosing a vanishing gradient
A 20-layer network uses an activation whose local derivative is at most everywhere. Roughly what happens to the gradient reaching the first layer, and why?
Why cache the intermediates?
Backpropagation stores intermediate values (like $v$ in $y = v^2$) during the forward pass instead of recomputing them during the backward pass. In one or two sentences, explain why this matters for cost.
A variable used on two paths
In a graph, a single variable feeds two different downstream nodes. When you run the backward pass, how do you combine the gradients arriving at that variable from the two paths?
Level C: Derivation & implementation
Implement backprop for (wx+b)²
Implement forward(w, x, b) returning plus a cache, and backward(cache) returning by multiplying local derivatives back through , , . Verify at that the gradients are , then print ok.
Backprop a three-input product
Consider with graph , , . Derive , , and by the backward pass, then evaluate them at , , .
Finite-difference gradient check
Write a central finite-difference checker for . Use with to approximate , , at , assert they match the analytic gradients with np.allclose, then print ok.
Simulate a vanishing-gradient product
The gradient at the first of layers is a product of local derivatives. Numerically compute the product of factors each equal to , and separately each equal to , print both, and assert the first is below and the second is above . End by printing ok.
Level D: Research-thinking challenge
Why do gradients vanish, and how is it fixed?
Explain, using the chain rule, why deep networks with sigmoid activations suffer vanishing gradients. Then explain mechanistically how two of {ReLU activation, residual (skip) connections, batch/layer normalization} address it — pointing to what each does to the per-layer local derivative.
Reverse-mode vs forward-mode autodiff
Backpropagation is reverse-mode automatic differentiation. For a function $f: \mathbb{R}^n \to \mathbb{R}$ (many inputs and a scalar loss, the neural-network case), explain why reverse mode computes all $n$ input gradients in roughly one backward sweep, whereas forward-mode autodiff would need about $n$ passes. What does this asymmetry cost, and when would forward mode actually be preferable?