$ State management in Haskell
I have been building a tiny neural network library in Haskell. One requirement is the ability to update the weights based on some gradients. In many other programming languages, this is done by mutating the weight variables in place. In Haskell, however, we try to keep things pure until the outermost layers, an approach sometimes referred to as "functional core, imperative shell".
Motivating example
Let's take a simple example where we need to update a single weight based on a gradient.
{-# OPTIONS_GHC -Wall #-}
type Weight = Double
type Gradient = Double
type LearningRate = Double
learningRate :: LearningRate
learningRate = 0.1
backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation
updateWeight :: Weight -> Gradient -> Weight
updateWeight oldWeight gradient = oldWeight - learningRate * gradient
train :: Weight -> Int -> Weight
train oldWeight numEpochs = foldl step oldWeight [1 .. numEpochs]
  where
    step w _ =
      let gradient = backward w
       in updateWeight w gradient
main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = train initialWeight 10
  putStrLn $ "Final weight: " ++ show finalWeight
Notice that functions which need to update the weight take the current weight as an argument and return the updated weight.
The biggest downside of this approach is that we rely on the calling functions
to manage the state correctly, and pass in the right weight. The actual update
management is not encapsulated. The function could be better named
calculateNewWeight, rather than updateWeight.
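To make the risk concrete, here is a minimal sketch of two update steps threaded by hand. The names twoStepsCorrect and twoStepsBuggy are hypothetical, and backward is swapped for a hypothetical gradient that depends on the weight (unlike the constant dummy above) so that the difference is observable. The buggy version type-checks, yet silently discards the first update.

```haskell
type Weight = Double
type Gradient = Double

learningRate :: Double
learningRate = 0.1

-- Hypothetical gradient (derivative of w^2), for illustration only.
backward :: Weight -> Gradient
backward w = 2 * w

updateWeight :: Weight -> Gradient -> Weight
updateWeight oldWeight gradient = oldWeight - learningRate * gradient

-- Each step uses the weight produced by the previous step.
twoStepsCorrect :: Weight -> Weight
twoStepsCorrect w0 =
  let w1 = updateWeight w0 (backward w0)
      w2 = updateWeight w1 (backward w1)
   in w2

-- Bug: the second step reuses the stale w0, so the first update is lost.
twoStepsBuggy :: Weight -> Weight
twoStepsBuggy w0 =
  let _w1 = updateWeight w0 (backward w0)
      w2 = updateWeight w0 (backward w0)
   in w2

main :: IO ()
main = do
  putStrLn $ "Correct: " ++ show (twoStepsCorrect 1.0)
  putStrLn $ "Buggy:   " ++ show (twoStepsBuggy 1.0)
```

Nothing in the types distinguishes w0 from w1, so the compiler cannot catch this kind of mistake.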
How do we encapsulate the state management?
The State Monad
Here is an excellent and visual tutorial on monads in Haskell.
Haskell provides the State monad to encapsulate stateful computations. Fire
up ghci and load the Control.Monad.State module.
import Control.Monad.State
Let's look at a few definitions.
ghci> :info State
type State :: * -> * -> *
type State s = StateT s Data.Functor.Identity.Identity :: * -> *
-- Defined in ‘Control.Monad.Trans.State.Lazy’
If you have not seen the * -> * -> * notation before, it means that State is a
type constructor that takes two type parameters. The first parameter is the type
of the state, and the second parameter is the type of the "result".
Technically, State s is a type synonym for StateT s Identity: the "monad
transformer" StateT applied to the underlying monad Identity. Discussion of
this is out of scope for this post.
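Without going into transformers, it can help to picture a State s a value as a function from an initial state to a (result, final state) pair. The following is a simplified sketch of that idea, not the library's actual definition; MyState, myGet, myPut, andThen and doubleIt are all hypothetical names introduced for illustration.

```haskell
-- A simplified, hypothetical model of State: a function from the
-- current state to a (result, new state) pair.
newtype MyState s a = MyState {runMyState :: s -> (a, s)}

-- Return the current state as the result, leaving the state unchanged.
myGet :: MyState s s
myGet = MyState (\s -> (s, s))

-- Replace the state, returning unit as the result.
myPut :: s -> MyState s ()
myPut s = MyState (\_ -> ((), s))

-- Sequencing: run the first computation, then pass its result and the
-- updated state on to the next one.
andThen :: MyState s a -> (a -> MyState s b) -> MyState s b
andThen m k = MyState $ \s ->
  let (a, s') = runMyState m s
   in runMyState (k a) s'

-- Double the state and return the value it had before doubling.
doubleIt :: MyState Int Int
doubleIt =
  myGet `andThen` \n ->
    myPut (n * 2) `andThen` \_ ->
      MyState (\s -> (n, s))

main :: IO ()
main = print (runMyState doubleIt 10) -- prints (10,20)
```

The threading of state that we previously did by hand now lives in one place, andThen, which is roughly the job that the monadic bind does for the real State type.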
Let's take a simple example to start with.
import Control.Monad.State
-- The type of state is Int
-- The type of result is String
doubleState :: State Int String
doubleState = do
  n <- get -- Get the current state
  put (n * 2) -- Update the state to n * 2
  pure ("Doubled the state to " ++ show (n * 2)) -- Return a result
-- Different state monad, where we don't have a result
-- No need for a do block here
-- No need for a pure return either, modify already returns a State Int ()
doubleState' :: State Int ()
doubleState' = modify (* 2) -- Modify the state by applying a function
-- To run this
main :: IO ()
main = do
  -- Note that runState returns the result first, and the final state second
  -- (reversed from the type parameters)
  let (result, finalState) = runState doubleState 10
  putStrLn $ "Result: " ++ result
  putStrLn $ "Final state: " ++ show finalState
  let finalState' = execState doubleState' 100
  putStrLn $ "Final state from doubleState': " ++ show finalState'
Taking a look at each of these functions:
get retrieves the current state.
put updates the state to a new value.
pure wraps the result in the State monad.
runState runs the stateful computation with an initial state and returns a tuple of the result and the final state.
modify applies a function to the current state to produce a new state. This is an alternative to using get and put separately.
execState runs the stateful computation and returns only the final state.
There is also evalState if you only need the result, and not the state.
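As a quick side-by-side comparison, the sketch below runs one computation with all three runners. The computation nextId is a hypothetical example, not part of the library: it returns the current counter as the result and increments the state.

```haskell
import Control.Monad.State

-- Hypothetical example: hand out the current counter value as the
-- result, and store the incremented counter as the new state.
nextId :: State Int Int
nextId = do
  n <- get
  put (n + 1)
  pure n

main :: IO ()
main = do
  print (runState nextId 5) -- (result, final state): (5,6)
  print (evalState nextId 5) -- result only: 5
  print (execState nextId 5) -- final state only: 6
```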
The do block is syntactic sugar for chaining monadic operations. For example, you could implement doubleState without a do block as follows:
doubleState :: State Int String
doubleState = get >>= \n -> put (n * 2) >> pure ("Doubled the state to " ++ show (n * 2))
Refactoring the training example
In our case, we don't care about the result of each training step, we just
want to update the weight. So we can use State Weight () as our monadic type.
Here is the equivalent code using the State monad.
{-# OPTIONS_GHC -Wall #-}
import Control.Monad (replicateM_)
import Control.Monad.State
type Weight = Double
type Gradient = Double
type LearningRate = Double
learningRate :: LearningRate
learningRate = 0.1
backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation
updateWeight :: Gradient -> State Weight ()
updateWeight gradient = modify (\w -> w - learningRate * gradient)
train :: Int -> State Weight ()
train numEpochs = replicateM_ numEpochs $ do
  weight <- get
  let gradient = backward weight
  updateWeight gradient
main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = execState (train 10) initialWeight
  putStrLn $ "Final weight: " ++ show finalWeight
We used a new function, replicateM_ from Control.Monad, to repeat a monadic
action a given number of times and discard the results.
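To see what "discard the results" means, the sketch below compares replicateM_ with its sibling replicateM, which collects the results into a list. The tick action is a hypothetical counter used only for this illustration.

```haskell
import Control.Monad (replicateM, replicateM_)
import Control.Monad.State

-- Hypothetical action: return the current counter, then increment it.
tick :: State Int Int
tick = do
  n <- get
  put (n + 1)
  pure n

main :: IO ()
main = do
  -- replicateM keeps every result: prints ([0,1,2],3)
  print (runState (replicateM 3 tick) 0)
  -- replicateM_ throws the results away but still updates the state: prints ((),3)
  print (runState (replicateM_ 3 tick) 0)
```

Since a training step only matters for its effect on the weight, the underscore variant is the natural fit for train.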
Now that we understand how the State monad works, we need to figure out how
to integrate it with the rest of the library - slated for a future post!