Micrograd in Haskell: Evaluation and Backward Pass
Welcome to the second post in a series on building a minimal neural network library in Haskell that performs automatic differentiation. If you haven't already, check out the first post.
You can find the complete code for this post at this playground link.
Evaluation
Here is the Op type that we have been working with.
data Op
  = Constant Double
  | Var Double
  | ReLU Op
  | Add Op Op
  | Multiply Op Op
  | Negate Op
  deriving (Eq, Show)
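As a reminder, an expression like a * b + 3 is just a tree built from these constructors. For illustration, here is that expression written out by hand with explicit constructors (the name example is ours, not from the original post):

-- The expression (-4) * 2 + 3 as an explicit Op tree.
example :: Op
example = Add (Multiply (Var (-4)) (Var 2)) (Constant 3)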
Haskell makes evaluating an expression tree straightforward using "pattern matching". We destructure the Op in a case expression, and recursively evaluate the sub-expressions. Here is what the evaluate function looks like.
evaluate :: Op -> Double
evaluate op = case op of
  Constant x -> x
  Var x -> x
  Negate x -> negate (evaluate x)
  ReLU x -> max 0 (evaluate x)
  Add x y -> evaluate x + evaluate y
  Multiply x y -> evaluate x * evaluate y
So elegant! Note the reliance on recursion in the pattern-matched branches.
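For the GHCi session below to work, Op needs a Num instance so that numeric literals and the + and * operators build Op trees. The first post defines the actual instance; here is a minimal sketch of what it plausibly looks like (assumed, not copied verbatim; the fromInteger-to-Var behaviour is what the notes at the end of this post refer to):

-- Sketch of a Num instance for Op: literals become Var nodes, and the
-- arithmetic operators build the corresponding constructors.
instance Num Op where
  x + y = Add x y
  x * y = Multiply x y
  negate x = Negate x
  fromInteger n = Var (fromInteger n)
  abs = error "abs: not needed for this post"
  signum = error "signum: not needed for this post"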
Let's try it out in GHCi.
ghci> a :: Op = -4
ghci> b :: Op = 2
ghci> c = a * b + 3
ghci> evaluate c
-5.0
ghci> d = ReLU c
ghci> evaluate d
0.0
Great, it works! Now let's implement the backward pass.
Backward Pass
Given a gradient value of 1.0 at the output node d, we want to know the gradients at a and b. Once we have the gradients, we can use gradient-descent-like techniques to update a and b to minimize a loss function.
It is not necessary for this post, but if you're curious, here is an accessible tutorial on gradient descent with lots of visuals.
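For our running example, the chain rule makes the gradients easy to compute by hand: for c = a * b + 3 we have ∂c/∂a = b = 2 and ∂c/∂b = a = -4. These are the numbers we will check our implementation against below.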
How should we represent the gradients? Let's see how micrograd does this. Here is the implementation of the backward pass for the __add__ operation.
class Value:
    ...
    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward

        return out
    ...
Some notes about this implementation:
- We don't have an ADT equivalent in python. When constructing the out node, we pass in the parent nodes self and other, to keep track of the children of the operation.
- We keep a grad field in the Value class to store the gradient value.
- You normally need to invoke model.zero_grad() to reset the gradients to zero before running the backward pass. This is not shown in the snippet above, but you can see an example in micrograd's notebook.
How should we approach this in Haskell? We could try to mimic the python implementation, but that would not be very idiomatic. Instead, we use a "pure" functional approach: we define a backward function that takes in an Op and a gradient value, and returns a list of tuples containing the Var nodes and their corresponding gradients.
backward :: Op -> Double -> [(Op, Double)]
backward op grad = case op of
  Constant _ -> []
  Var _ -> [(op, grad)]
  Negate x -> backward x (negate grad)
  ReLU x ->
    let xVal = evaluate x
        reluGrad = if xVal > 0 then grad else 0
    in backward x reluGrad
  Add x y ->
    backward x grad ++ backward y grad
  Multiply x y ->
    let xVal = evaluate x
        yVal = evaluate y
    in backward x (grad * yVal) ++ backward y (grad * xVal)
Once again, this is quite elegant!
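One subtlety: the returned list can mention the same Var more than once (consider a + a, whose backward pass yields the variable twice), so to get the total gradient per variable we need to sum the entries. Here is a minimal sketch of doing that with Data.Map, assuming we also derive Ord for Op so it can be used as a map key (Ord is not part of the original deriving clause):

import qualified Data.Map as Map

-- Sum the per-variable contributions into a single gradient per Var.
-- Assumes Op derives Ord in addition to Eq and Show.
gradients :: Op -> Map.Map Op Double
gradients op = Map.fromListWith (+) (backward op 1.0)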
Astute readers will observe that this implementation, while simple and elegant, is not very efficient: the evaluate function is called multiple times for the same sub-expression. We'll ignore this for this blog post and leave it as an exercise for the reader!
Let's take an example and compare the gradients we get against pytorch for verification. We'll skip the ReLU operation, since it zeroes out the gradients for the given values.
Here is the pytorch code.
import torch
a = torch.tensor([-4.0])
a.requires_grad=True
b = torch.tensor([2.0])
b.requires_grad=True
c = a * b + 3
c.backward()
print(a.grad, b.grad)
# tensor([2.]) tensor([-4.])
And here is the GHCi session.
ghci> a :: Op = -4
ghci> b :: Op = 2
ghci> c = a * b + Constant 3
ghci> backward c 1.0
[(Var 4.0,-2.0),(Var 2.0,-4.0)]
We get the same gradients as pytorch! Hooray!
Some notes:
- We needed to explicitly use Constant for the 3. Without this, our typeclass would have converted it into a Var, and we would compute gradients for it as well.
- The first element in the tuple is Var 4.0 instead of Var -4.0. If we show a, we see that it is Negate (Var 4.0). The negation is captured in the tree structure, rather than in the value of the Var! Although the gradients are correct, this may not be what we want. Fixing this is left as yet another exercise for the reader.
That's it for this post, thank you for reading! You can play around with all code needed in this blog post here.