$ Micrograd in Haskell
Hello world! This post is intended to be one of a series to help myself (and hopefully others) learn Haskell.
Let's build a minimal neural network library, inspired by Karpathy's excellent micrograd library.
Here is a quick example we can work with. We have two variables a and b, which we multiply, add 3 to the result, and then apply the ReLU activation function. In torch, this can be expressed as:
a = torch.Tensor([-4.0])
b = torch.Tensor([2.0])
c = a * b + 3
d = c.relu()
Unexcitingly, this will yield d = 0
due to the ReLU activation, but I'd like
to keep a non-linearity in the example.
Automatic Differentiation
Autograd (a.k.a. autodiff or backprop) is at the heart of many deep learning libraries. It is this basic high-school calculus component that is responsible for all the fancy AI stuff we see today. 1
We want to compute the gradients of d with respect to a and b. Why? Because we want to know how to change a and b to move d in a desired direction (often to minimize a loss function).
This is a repeated application of the chain rule from calculus. We can determine the gradient of d with respect to c, and then of c with respect to a and b. We can compute this if we keep track of the operations we are performing.
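For our running example, with c = a * b + 3 and d = ReLU(c), the chain rule reads:

\[
\frac{\partial d}{\partial a} = \frac{\partial d}{\partial c} \cdot \frac{\partial c}{\partial a} = \mathrm{relu}'(c) \cdot b
\]

With a = -4 and b = 2 we get c = -5, so relu'(c) = 0 and the whole gradient is 0; had c been positive, it would simply have been b = 2.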
Capturing structure through ADTs
In Haskell, we can use Algebraic Data Types (ADTs) to represent the operations we perform on our variables. We can define an Op type that represents numbers and operations on them.
Here is a guide to setting up a Haskell development environment.
Let's fire up GHCi, the interactive Haskell environment, and define our Op type. You can enter a multi-line block by wrapping it in :{ and :}.
data Op
  = Constant Double   -- a fixed constant; we won't ask for its gradient
  | Var Double        -- a variable we may want to differentiate with respect to
  | ReLU Op
  | Add Op Op
  | Multiply Op Op
  | Negate Op
  deriving (Eq, Show)
To construct our expression d = ReLU(a * b + 3), we can do:
ghci> a = Var (-4)
ghci> b = Var 2
ghci> c = Add (Multiply a b) (Constant 3)
ghci> d = ReLU c
ghci> d
ReLU (Add (Multiply (Var (-4.0)) (Var 2.0)) (Constant 3.0))
Great! We have captured the structure of our computation. Let's see how we can make it more ergonomic.
Typeclasses
When defining c, it is a bit tedious to write Add and Multiply; it is so much easier in the Python version. We can make our Op type an instance of the Num typeclass.
It is not required for this tutorial, but if you're curious, you can read more about Haskell's typeclasses here.
instance Num Op where
  (+) = Add
  (*) = Multiply
  negate = Negate
  fromInteger n = Var (fromInteger n)
This gives a warning because we have not implemented all of the methods in the Num class, but we can ignore that for now. Let's try our brand new Num instance!
ghci> x = 4 + 3
ghci> x == Add (Var 4.0) (Var 3.0)
True
ghci> x == 7
True
Huh? How is x both Add (Var 4.0) (Var 3.0) and 7? Let's check some types:
ghci> :type 3
3 :: Num a => a
ghci> :type x
x :: Num a => a
ghci> :type (==)
(==) :: Eq a => a -> a -> Bool
This starts to make sense. The literals 3 and 4 don't have a concrete type. See the a in Num a => a: it means that 3 can be any type that is an instance of the Num typeclass. The same applies to x.

When we do x == Add (Var 4.0) (Var 3.0), Haskell can infer that x is of type Op from the right-hand side, so it works. In x == 7, nothing forces Op, so GHCi falls back to its default numeric type (Integer), where 4 + 3 really is 7.
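We can make the difference visible by pinning x to Op ourselves. With the derived Eq instance, I'd expect this comparison to come out False, since 7 now becomes Var 7.0 while x becomes the Add expression:

ghci> (x :: Op) == 7
False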
You can check out the full Num typeclass and its instances with :info:
ghci> :info Num
type Num :: * -> Constraint
class Num a where
(+) :: a -> a -> a
(-) :: a -> a -> a
(*) :: a -> a -> a
negate :: a -> a
abs :: a -> a
signum :: a -> a
fromInteger :: Integer -> a
{-# MINIMAL (+), (*), abs, signum, fromInteger, (negate | (-)) #-}
-- Defined in ‘GHC.Num’
instance [safe] Num Op -- Defined at <interactive>:75:10
instance Num Double -- Defined in ‘GHC.Float’
instance Num Float -- Defined in ‘GHC.Float’
instance Num Int -- Defined in ‘GHC.Num’
instance Num Integer -- Defined in ‘GHC.Num’
instance Num Word -- Defined in ‘GHC.Num’
You can see the instance Num Op
line there. But there are also other
instances.
We can add type annotations to pin the types down and build out the expression.
ghci> a :: Op = -4
ghci> b :: Op = 2
ghci> c = a * b + 3
ghci> d = ReLU c
ghci> d
ReLU (Add (Multiply (Negate (Var 4.0)) (Var 2.0)) (Var 3.0))
This is much nicer! One issue is that 3 is now a Var instead of a Constant. We could fix that by writing c = a * b + Constant 3, as in the snippet below, but the Var is good enough for now.
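For reference, here is roughly what that looks like in the same GHCi session (c' is just a fresh name so we don't clobber c; the printed form is what I'd expect from the derived Show instance):

ghci> c' = a * b + Constant 3
ghci> ReLU c'
ReLU (Add (Multiply (Negate (Var 4.0)) (Var 2.0)) (Constant 3.0))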
This post turned out longer than I expected, so I'll continue in the next post where we will implement evaluation and backpropagation.
Footnotes
- I am exaggerating a bit - here is a great explanation of some nuances in backpropagation. ↩