Pytorch's Autograd allows the automatic differentiation of complex computations and is central to the backpropagation used to train neural networks. This post focuses on the foundations of computing gradients with this functionality, points out things to watch out for, and provides an example walk-through to consolidate the points being made. No reference is made to an actual neural network model, but the ideas should apply to one as well.
Nomenclature
$$ \frac{\partial{f}}{\partial{x}} $$ is called:
- The (first) partial derivative of \(f\) with respect to \(x\).
- The (first) partial derivative of \(f\) w.r.t. \(x\).
- The gradient of \(f\).
- The \(x\)-component of the gradient of \(f\).
- The gradient of \(f\) on \(x\).
- The gradient on \(x\).
- The gradient on \(x\) w.r.t. \(f\).
In deep learning, 'gradient' is more commonly used than 'derivative', so 'gradient' will be used in general in this post, but hopefully the explicit mathematical symbols shown alongside will help avoid confusion.
How to compute gradients: the basics
To compute the gradient in Pytorch, it's generally sufficient to simply:
- Set requires_grad to True.
- Call the backward method.
x = torch.tensor(3., requires_grad=True)
x.backward()
x, x.grad
(tensor(3., requires_grad=True), tensor(1.))
The gradient that is computed by x.backward() here is:
$$
\frac{dx}{dx}\bigr\rvert_{x=3} = 1
$$
and it is stored in the grad attribute of x, i.e. x.grad.
Computation graph
In a piece of code, statements which involve, either directly or indirectly, one or more tensors with requires_grad=True form a computation graph [1]. The computation graph keeps track of the computation described by the statements and provides the information Autograd needs to compute the gradients when backward is called.
In the following statements, because none of the tensors have requires_grad=True, no information is recorded to allow the computation of gradients.
x = torch.tensor(9.)
y = x**2 + x + torch.exp(x)
z = y + 2
x, y, z
(tensor(9.), tensor(8193.0840), tensor(8195.0840))
Whereas in the following computation there is a computation graph, which you can tell because, when printed, requires_grad=True or a grad_fn attribute is displayed alongside each tensor's value. Notice here that x can have neither gradients computed on it nor its backward called, because it has neither requires_grad=True nor a grad_fn; it's not part of the computation graph.
x = torch.tensor(9.)
y = x**2 + x + torch.exp(x)
y.requires_grad_()
z = y + 2
x, y, z
(tensor(9.),
tensor(8193.0840, requires_grad=True),
tensor(8195.0840, grad_fn=<AddBackward0>))
In general, the statements leading up to a backward call constitute what's called a forward propagation.
Can only get gradients on leaf variables
Let's examine the effect of z.backward() on the grad attributes of the variables.
print(x.grad, y.grad, z.grad)
z.backward()
print(x.grad, y.grad, z.grad)
None None None
None tensor(1.) None
It can be seen that z has no grad, even though the gradient of z on z, \(\frac{dz}{dz}\), is expected to be \(1\). This is because z is not what Pytorch considers a leaf variable. Leaf variables are the earliest variables in a computation graph; often they are the inputs of a computation. Pytorch only makes the gradients on leaf variables available after a backward call, in order to save memory. It appears that one way to distinguish a leaf variable from one that isn't (an intermediary variable) is to print it: a leaf variable has requires_grad=True printed after its value, while an intermediary variable has a grad_fn of some sort printed after its value.
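Pytorch also exposes this distinction directly through the is_leaf attribute. One caveat: a tensor outside any computation graph, like the original x above, counts as a leaf too, so it's the combination of is_leaf and requires_grad that identifies the variables whose grad gets populated by backward. A quick check on the tensors from the example above:
print(x.is_leaf, x.requires_grad)
print(y.is_leaf, y.requires_grad)
print(z.is_leaf, z.requires_grad)
True False
True True
False True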
If you do want to get the gradients on intermediary variables, you need to use hooks, but this is outside the scope of this post.
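(That said, a simpler alternative is worth a quick sketch here: calling an intermediary variable's retain_grad() method before backward asks Autograd to keep its gradient around, at the cost of the extra memory.)
x = torch.tensor(3., requires_grad=True)
y = x**2
y.retain_grad()  # keep the gradient on this intermediary variable
z = y + 2
z.backward()
y.grad
tensor(1.)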
Resetting the gradient
It's important to note that x.grad accumulates additively with each successful call of the backward method, so care needs to be taken to reset the gradient whenever appropriate.
For example, it is expected that \(\frac{dx}{dx}\) should always be \(1\).
x = torch.tensor(3., requires_grad=True)
x.backward()
x.grad
tensor(1.)
x.backward()
x.grad
tensor(2.)
There are various ways to reset the gradient: grad.zero_(), grad.data.zero_(), setting grad = None, etc.
x.grad.zero_()
x.backward()
x.grad
tensor(1.)
x.grad.data.zero_()
x.backward()
x.grad
tensor(1.)
x.grad = None
x.backward()
x.grad
tensor(1.)
Computation graph is flushed after backward()
In the previous section, x.backward() was called multiple times without any problems (except for the need to reset grad in order for it to make sense).
However, as soon as additional computations are done on x, the backward() method of the result of those computations can no longer be called more than once. For example, take the following computation, which takes the square of \(x\):
$$
m = x^2
$$
x = torch.tensor(3., requires_grad=True)
m = x**2
m
tensor(9., grad_fn=<PowBackward0>)
Differentiating with respect to \(x\): $$ \frac{dm}{dx} = \frac{d}{dx} x^2 = 2x $$
m.backward()
x.grad
tensor(6.)
However, repeating the above cell raises a RuntimeError. As explained in this thread, this is because all the 'intermediary results' of the computation are deleted after the first backward call, so another backward call cannot be made, as it needs these intermediary results. These intermediary results are deleted by default in order to save memory.
x.grad.zero_()
m.backward()
RuntimeError
...
...
RuntimeError: Trying to backward through the graph a second time, but the
buffers have already been freed. Specify retain_graph=True when calling
backward the first time.
For a computation that is very small, like the current one, this can be resolved by simply repeating the computation, hence making the intermediary results available again.
m = x**2
x.grad.zero_()
m.backward()
x.grad
tensor(6.)
However, for large computations which take a long time, it's better to call backward(retain_graph=True) to tell Pytorch not to delete those intermediary results, so that another backward call can be made.
m = x**2
x.grad.zero_()
m.backward(retain_graph=True)
x.grad
tensor(6.)
x.grad.zero_()
m.backward()
x.grad
tensor(6.)
Practice: staged computation
Let's now consider a larger computation and do the backpass on it in stages, to practice the ideas described so far. This is in the same vein as in [1], except Pytorch is used, so there is no need to first find the analytical expressions of the gradients.
Consider the following computation: $$ x = 3\\ m = x^2\\ y = m + x\\ p = 4y\\ q = -2y $$ Note that each equation is represented as a gate in a computation graph, with the variable on the LHS being the gate output, and the variable(s) on the RHS being the gate input(s).
First, forward pass through the computation graph. The values of all variables in the computation are printed out, as these might be needed during the backward pass. The values are: $$ x = 3 \\ m = 9 \\ y = 12 \\ p = 48 \\ q = -24 $$
x = torch.tensor(3., requires_grad=True)
m = x**2
y = m + x
p = 4*y
q = -2*y
print(f'x={x} m={m} y={y} p={p} q={q}')
x=3.0 m=9.0 y=12.0 p=48.0 q=-24.0
Here, the computation graph has two final branches, one that ends with \(p\), and one that ends with \(q\). Let's first backpass through the gate along the former: $$ y=12\\ p=4y $$ Notice that, because Pytorch only makes the leaf variables' gradients available, it's necessary to use a computation graph that starts with \(y\), using its value obtained during the forward pass (even though here its value doesn't matter in the backpass): $$ \frac{dp}{dy}\bigr\rvert_{y=12} = 4 $$
y = torch.tensor(12., requires_grad=True)
p = 4*y
p.backward()
y.grad
tensor(4.)
Similarly for \(q = -2y\): $$ \frac{dq}{dy}\bigr\rvert_{y=12} = -2 $$
y = torch.tensor(12., requires_grad=True)
q = -2*y
q.backward()
y.grad
tensor(-2.)
The remaining part of the computation graph to backpass through can be expressed as: $$ x = 3\\ m = x^2\\ y = m + x $$ Instead of backpassing first through the y gate and then through the m gate, \(\frac{dy}{dx}\) can be obtained more directly by using the expression $$ \frac{dy}{dx} = \frac{dm}{dx} + \frac{dx}{dx} $$ and the facts that:
- x.grad accumulates additively, so calculating the RHS of the above equation amounts to calling m.backward() followed by x.backward().
- x.backward() doesn't require any intermediary results, since x is already a leaf variable.
The gradient should be: $$ \frac{dy}{dx}\bigr\rvert_{x=3} = \frac{dm}{dx}\bigr\rvert_{x=3} + \frac{dx}{dx}\bigr\rvert_{x=3} = 2\cdot 3 + 1 = 7 $$
x = torch.tensor(3., requires_grad=True)
m = x**2
m.backward()
print(f'dm/dx = {x.grad}')
x.backward()
print(f'dy/dx = dm/dx + dx/dx = {x.grad}')
dm/dx = 6.0
dy/dx = dm/dx + dx/dx = 7.0
Having backpropagated through all the gates and noted the gradients at each, the chain rule for derivatives can be used to construct the gradient for the whole computation graph. In this case, there will be two overall gradients, \(\frac{dp}{dx}\) and \(\frac{dq}{dx}\). $$ \frac{dp}{dx} = \frac{dp}{dy}\frac{dy}{dx} = 4 \cdot 7 = 28\\ \frac{dq}{dx} = \frac{dq}{dy}\frac{dy}{dx} = -2 \cdot 7 = -14 $$
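Putting the staged results together numerically, as a quick arithmetic check on the gradients noted at each gate:
dp_dy, dq_dy, dy_dx = 4., -2., 7.  # gradients recorded at each gate above
print(f'dp/dx = {dp_dy * dy_dx}  dq/dx = {dq_dy * dy_dx}')
dp/dx = 28.0  dq/dx = -14.0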
Checking the result of the staged computation
In the section above, the computation of the gradients \(\frac{dp}{dx}\) and \(\frac{dq}{dx}\) is broken up into small steps. In this section, backward is called for the whole graph at once, so that the backpass is done in one big step instead. This is the normal way to use backward: describe whatever computation is desired through ordinary statements and, at the end, call the backward method of the computation's output to backpropagate all the way to the gradients on the leaf variables.
x = torch.tensor(3., requires_grad=True)
m = x**2
y = m + x
p = 4*y
q = -2*y
print(f'x={x} m={m} y={y} p={p} q={q}')
x=3.0 m=9.0 y=12.0 p=48.0 q=-24.0
It can be seen, either from the chain rule or from the computation graph, that \(\frac{dp}{dx}\) and \(\frac{dq}{dx}\) have \(\frac{dy}{dx}\) in common. Therefore, it's necessary to pass retain_graph=True when calling backward to obtain one of them before calling backward again to obtain the other; otherwise an error will result, because \(\frac{dy}{dx}\), being an intermediary result here, will have been cleared from the buffer.
x.grad = None
p.backward(retain_graph=True)
x.grad
tensor(28.)
x.grad = None
q.backward()
x.grad
tensor(-14.)
The resulting \(\frac{dp}{dx}\) and \(\frac{dq}{dx}\) are consistent with those obtained by staged computation in the previous section.
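As a closing aside, the same gradients can also be obtained with torch.autograd.grad, which returns them directly instead of accumulating into .grad (so no resetting is needed); retain_graph=True is still required on the first of the two calls, for the same reason as above:
x = torch.tensor(3., requires_grad=True)
m = x**2
y = m + x
p = 4*y
q = -2*y
dp_dx, = torch.autograd.grad(p, x, retain_graph=True)  # keep the graph for the second call
dq_dx, = torch.autograd.grad(q, x)
dp_dx, dq_dx
(tensor(28.), tensor(-14.))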