$ State management in Haskell
I have been building a tiny neural network library in Haskell. One of its requirements is updating the weights based on computed gradients. In many other programming languages, this is done by mutating the weight variables in place. In Haskell, however, we try to keep things pure until the outermost layers, a pattern sometimes called "functional core, imperative shell".
Motivating example
Let's take a simple example where we need to update a single weight based on a gradient.
{-# OPTIONS_GHC -Wall #-}
type Weight = Double
type Gradient = Double
type LearningRate = Double
learningRate :: LearningRate
learningRate = 0.1
backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation
updateWeight :: Weight -> Gradient -> Weight
updateWeight oldWeight gradient = oldWeight - learningRate * gradient
train :: Weight -> Int -> Weight
train oldWeight numEpochs = foldl step oldWeight [1 .. numEpochs]
  where
    step w _ =
      let gradient = backward w
       in updateWeight w gradient
main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = train initialWeight 10
  putStrLn $ "Final weight: " ++ show finalWeight
Notice that functions which need to update the weight take the current weight as an argument and return the updated weight.
The biggest downside of this approach is that we rely on the callers to thread
the state correctly and always pass in the latest weight. The state update
itself is not encapsulated; the function would be more accurately named
calculateNewWeight rather than updateWeight.
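To make that risk concrete, here is a small sketch that threads the weight by hand through two training steps. It reuses the dummy gradient from above; calculateNewWeight is the name suggested in the previous paragraph, while twoStepsCorrect and twoStepsBuggy are hypothetical helpers for illustration only. The buggy version starts its second step from the stale weight, yet it still type-checks.

```haskell
type Weight = Double
type Gradient = Double
type LearningRate = Double

learningRate :: LearningRate
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation, as in the post

-- Purely computes a new weight; nothing here "updates" anything.
calculateNewWeight :: Weight -> Gradient -> Weight
calculateNewWeight oldWeight gradient = oldWeight - learningRate * gradient

-- Two steps with the weight threaded correctly: w0 -> w1 -> w2.
twoStepsCorrect :: Weight -> Weight
twoStepsCorrect w0 =
  let w1 = calculateNewWeight w0 (backward w0)
      w2 = calculateNewWeight w1 (backward w1)
   in w2

-- Same shape, but the second step starts from the stale w0 instead of w1.
-- It still type-checks, so the compiler cannot catch the lost update.
twoStepsBuggy :: Weight -> Weight
twoStepsBuggy w0 =
  let w1 = calculateNewWeight w0 (backward w0)
      w2 = calculateNewWeight w0 (backward w1)
   in w2
```

Starting from 0.5, the correct version ends near 0.48 and the buggy one near 0.49, and nothing in the types tells us which one we meant.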
How do we encapsulate the state management?
The State Monad
Here is an excellent and visual tutorial on monads in Haskell.
The mtl package provides the State monad to encapsulate stateful computations.
Fire up ghci and import the Control.Monad.State module.
import Control.Monad.State
Let's look up a few definitions.
ghci> :info State
type State :: * -> * -> *
type State s = StateT s Data.Functor.Identity.Identity :: * -> *
-- Defined in ‘Control.Monad.Trans.State.Lazy’
If you have not seen the * -> * -> * notation before, it is the kind of State:
State is a type constructor that takes two type parameters. The first parameter
is the type of the state, and the second is the type of the "result".
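One way to build intuition for the two parameters is a simplified, hand-rolled version of the idea: a state computation is a function from an input state s to a pair of a result a and a new state s. The sketch below uses a hypothetical MyState type with hypothetical getS and putS helpers. It is not how mtl actually defines State (that goes through StateT), but it captures the same idea.

```haskell
-- Hypothetical, simplified state monad: NOT mtl's definition.
-- s is the state type, a is the result type.
newtype MyState s a = MyState {runMyState :: s -> (a, s)}

instance Functor (MyState s) where
  -- Apply f to the result; leave the state untouched.
  fmap f (MyState g) = MyState $ \s ->
    let (a, s') = g s
     in (f a, s')

instance Applicative (MyState s) where
  -- Return a value without touching the state.
  pure a = MyState $ \s -> (a, s)
  -- Run the function-producing computation first, then the argument one.
  MyState gf <*> MyState ga = MyState $ \s ->
    let (f, s1) = gf s
        (a, s2) = ga s1
     in (f a, s2)

instance Monad (MyState s) where
  -- Run the first computation, pass its result to k, and run the
  -- resulting computation on the updated state.
  MyState g >>= k = MyState $ \s ->
    let (a, s') = g s
     in runMyState (k a) s'

-- Read the current state as the result.
getS :: MyState s s
getS = MyState $ \s -> (s, s)

-- Replace the state, returning ().
putS :: s -> MyState s ()
putS s = MyState $ \_ -> ((), s)

-- State type Int, result type String.
doubleS :: MyState Int String
doubleS = do
  n <- getS
  putS (n * 2)
  pure ("Doubled the state to " ++ show (n * 2))
```

Running runMyState doubleS 10 gives ("Doubled the state to 20", 20): the result comes first and the new state second.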
The second line of the :info output says that State s is a type synonym for
StateT s Identity: the StateT monad transformer with Identity as the underlying
monad.¹ Monad transformers themselves are out of scope for this post.
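Without going into transformers, one quick way to see that State s is just StateT s Identity is to run the same computation both ways. This sketch assumes the mtl package (which provides Control.Monad.State) and uses Data.Functor.Identity from base; increment, asStateT, viaState and viaStateT are hypothetical names.

```haskell
import Control.Monad.State
import Data.Functor.Identity (Identity, runIdentity)

-- A tiny hypothetical computation: add one to the state, then return it.
increment :: State Int Int
increment = do
  modify (+ 1)
  get

-- State s is a synonym for StateT s Identity, so this needs no conversion.
asStateT :: StateT Int Identity Int
asStateT = increment

-- Running it through the State interface...
viaState :: Int -> (Int, Int)
viaState = runState increment

-- ...or through the StateT interface, unwrapping Identity by hand.
viaStateT :: Int -> (Int, Int)
viaStateT s0 = runIdentity (runStateT asStateT s0)
```

Both runners return (42, 42) for an initial state of 41.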
Let's take a simple example to start with.
import Control.Monad.State
-- The type of state is Int
-- The type of result is String
doubleState :: State Int String
doubleState = do
  n <- get -- Get the current state
  put (n * 2) -- Update the state to n * 2
  pure ("Doubled the state to " ++ show (n * 2)) -- Return a result
-- Different state monad, where we don't have a result
-- No need for a do block here
-- No need for a pure return either, modify already returns a State Int ()
doubleState' :: State Int ()
doubleState' = modify (* 2) -- Modify the state by applying a function
-- To run this
main :: IO ()
main = do
  -- Note that runState returns the result first, and the final state second
  -- (reversed from the type parameters)
  let (result, finalState) = runState doubleState 10
  putStrLn $ "Result: " ++ result
  putStrLn $ "Final state: " ++ show finalState
  let finalState' = execState doubleState' 100
  putStrLn $ "Final state from doubleState': " ++ show finalState'
Taking a look at each of these functions:
- get retrieves the current state.
- put replaces the state with a new value.
- pure returns a result without changing the state.
- runState runs the stateful computation with an initial state and returns a tuple of the result and the final state.
- modify applies a function to the current state to produce a new state. This is an alternative to using get and put separately.
- execState runs the stateful computation and returns only the final state.
There is also evalState if you only need the result, and not the state.
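To see runState, evalState and execState side by side, here is a sketch that uses the state as a counter for handing out labels. freshLabel, threeLabels and the label format are hypothetical; replicateM from Control.Monad runs an action a fixed number of times and collects the results in a list.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State

-- Hypothetical helper: return a label built from the current counter,
-- then bump the counter for the next call.
freshLabel :: State Int String
freshLabel = do
  n <- get
  put (n + 1)
  pure ("label" ++ show n)

-- Three labels in a row.
threeLabels :: State Int [String]
threeLabels = replicateM 3 freshLabel

-- runState: both the result and the final state.
labelsAndCount :: ([String], Int)
labelsAndCount = runState threeLabels 0

-- evalState: only the result.
labelsOnly :: [String]
labelsOnly = evalState threeLabels 0

-- execState: only the final state.
countOnly :: Int
countOnly = execState threeLabels 0
```

Starting the counter at 0, runState gives both the labels "label0", "label1", "label2" and the final count 3; evalState gives only the labels, and execState only the count.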
The do block is syntactic sugar for chaining monadic operations. For
example, you could implement doubleState without a do block as follows:
doubleState :: State Int String
doubleState = get >>= \n -> put (n * 2) >> pure ("Doubled the state to " ++ show (n * 2))
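The same desugaring applies to modify. As a sketch, here is a hypothetical modifyViaGetPut built only from get, put and >>=, next to a do-notation version (modifyViaDo, also hypothetical) and the library's modify. All three double a state of 100 to 200.

```haskell
import Control.Monad.State

-- Hypothetical re-implementation of modify using only get, put and >>=.
modifyViaGetPut :: (s -> s) -> State s ()
modifyViaGetPut f = get >>= \current -> put (f current)

-- The same thing with do-notation, which desugars to the line above.
modifyViaDo :: (s -> s) -> State s ()
modifyViaDo f = do
  current <- get
  put (f current)

-- doubleState' from the post, written three ways.
doubleWithModify, doubleWithBind, doubleWithDo :: State Int ()
doubleWithModify = modify (* 2)
doubleWithBind = modifyViaGetPut (* 2)
doubleWithDo = modifyViaDo (* 2)
```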
Refactoring the training example
In our case, we don't care about the result of each training step; we just
want to update the weight. So we can use State Weight () as our monadic type.
Here is the equivalent code using the State monad.
{-# OPTIONS_GHC -Wall #-}
import Control.Monad (replicateM_)
import Control.Monad.State
type Weight = Double
type Gradient = Double
type LearningRate = Double
learningRate :: LearningRate
learningRate = 0.1
backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation
updateWeight :: Gradient -> State Weight ()
updateWeight gradient = modify (\w -> w - learningRate * gradient)
train :: Int -> State Weight ()
train numEpochs = replicateM_ numEpochs $ do
  weight <- get
  let gradient = backward weight
  updateWeight gradient
main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = execState (train 10) initialWeight
  putStrLn $ "Final weight: " ++ show finalWeight
We used a new function, replicateM_ from Control.Monad, to repeat a monadic
action a given number of times and discard the results.
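replicateM_ has a sibling, replicateM, that keeps the result of every repetition instead of discarding it. A minimal sketch, assuming the same dummy gradient as above: a hypothetical trainWithHistory returns the weight after each epoch, which could be handy for logging. The step and trainWithHistory names are illustrative, not part of the library described here.

```haskell
import Control.Monad (replicateM, replicateM_)
import Control.Monad.State

type Weight = Double
type Gradient = Double
type LearningRate = Double

learningRate :: LearningRate
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation, as in the post

-- One training step: compute the gradient from the current weight and apply it.
step :: State Weight ()
step = modify (\w -> w - learningRate * backward w)

-- replicateM_ runs step numEpochs times and throws away the () results.
train :: Int -> State Weight ()
train numEpochs = replicateM_ numEpochs step

-- Hypothetical variant: replicateM keeps every result. Each repetition runs
-- a step and then returns the updated weight, giving a per-epoch history.
trainWithHistory :: Int -> State Weight [Weight]
trainWithHistory numEpochs = replicateM numEpochs (step >> get)
```

Because both versions apply exactly the same updates in the same order, the last element of evalState (trainWithHistory 10) 0.5 equals execState (train 10) 0.5.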
Now that we understand how the State monad works, we need to figure out how
to integrate it with the rest of the library - slated for a future post!
Footnotes
1. Corrected on 2024-10-05: the original post said that State itself is a monad transformer.