$ State management in Haskell

[October 4, 2025]

I have been building a tiny neural network library in Haskell. One of the requirements is being able to update the weights based on some gradients. In many other programming languages, this is done by mutating the weight variables in place. In Haskell, however, we try to keep the code pure and push side effects to the outermost layers, a pattern sometimes referred to as "functional core, imperative shell".

Motivating example

Let's take a simple example where we need to update a single weight based on a gradient.

{-# OPTIONS_GHC -Wall #-}

type Weight = Double

type Gradient = Double

type LearningRate = Double

learningRate :: LearningRate
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation

updateWeight :: Weight -> Gradient -> Weight
updateWeight oldWeight gradient = oldWeight - learningRate * gradient

train :: Weight -> Int -> Weight
train oldWeight numEpochs = foldl step oldWeight [1 .. numEpochs]
  where
    step w _ =
      let gradient = backward w
       in updateWeight w gradient

main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = train initialWeight 10
  putStrLn $ "Final weight: " ++ show finalWeight

Notice that functions which need to update the weight take the current weight as an argument and return the updated weight.

The biggest downside of this approach is that we rely on the calling functions to thread the state correctly and pass in the right weight. The update itself is not encapsulated: updateWeight does not update anything, it only computes a new value, so calculateNewWeight would be a more accurate name.
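To make that downside concrete, here is a small sketch of how manual threading can go wrong. The functions twoStepsCorrect and twoStepsBuggy are hypothetical names made up for this illustration; the rest mirrors the example above.

```haskell
type Weight = Double

type Gradient = Double

learningRate :: Double
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation, as in the example above

updateWeight :: Weight -> Gradient -> Weight
updateWeight oldWeight gradient = oldWeight - learningRate * gradient

-- Intended behaviour: each step starts from the previous step's result.
twoStepsCorrect :: Weight -> Weight
twoStepsCorrect w0 =
  let w1 = updateWeight w0 (backward w0)
      w2 = updateWeight w1 (backward w1)
   in w2

-- Hypothetical mistake: the second step starts from w0 instead of w1,
-- so the first update is silently lost. This still type-checks.
twoStepsBuggy :: Weight -> Weight
twoStepsBuggy w0 =
  let w1 = updateWeight w0 (backward w0)
      w2 = updateWeight w0 (backward w1)
   in w2
```

Loading this in ghci and comparing twoStepsCorrect 0.5 with twoStepsBuggy 0.5 should show that the buggy version has applied only one update. The compiler cannot help here, because both weights have the same type.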

How do we encapsulate the state management?

The State Monad

Info

Here is an excellent and visual tutorial on monads in Haskell.

Haskell provides the State monad to encapsulate stateful computations. Fire up ghci and load the Control.Monad.State module (it comes from the mtl package).

import Control.Monad.State

Let's look at a few definitions.

ghci> :info State
type State :: * -> * -> *
type State s = StateT s Data.Functor.Identity.Identity :: * -> *
        -- Defined in ‘Control.Monad.Trans.State.Lazy’

If you have not seen the * -> * -> * notation before, it is the kind of State: it says that State is a type constructor that takes two type parameters. The first parameter is the type of the state, and the second parameter is the type of the "result".

Info

Technically, State s is a type synonym for StateT s Identity: StateT is a "monad transformer", and here the underlying monad is Identity. Discussion of this is out-of-scope for this post.

Let's take a simple example to start with.

import Control.Monad.State

-- The type of state is Int
-- The type of result is String
doubleState :: State Int String
doubleState = do
  n <- get -- Get the current state
  put (n * 2) -- Update the state to n * 2
  pure ("Doubled the state to " ++ show (n * 2)) -- Return a result


-- A different stateful computation, with no meaningful result (the result type is ())
-- No need for a do block here
-- No need for a pure either, modify already returns a State Int ()
doubleState' :: State Int ()
doubleState' = modify (* 2) -- Modify the state by applying a function

-- To run this
main :: IO ()
main = do
  -- Note that runState returns the result first, and the final state second
  -- (reversed from the type parameters)
  let (result, finalState) = runState doubleState 10
  putStrLn $ "Result: " ++ result
  putStrLn $ "Final state: " ++ show finalState

  let finalState' = execState doubleState' 100
  putStrLn $ "Final state from doubleState': " ++ show finalState'

Taking a look at each of these functions:

  1. get retrieves the current state.
  2. put updates the state to a new value.
  3. pure wraps a value in the State monad as the result, leaving the state unchanged.
  4. runState runs the stateful computation with an initial state and returns a tuple of the result and the final state.
  5. modify applies a function to the current state to produce a new state. This is an alternative to using get and put separately.
  6. execState runs the stateful computation and returns only the final state.

There is also evalState if you only need the result, and not the state.
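Here is a small sketch comparing the three runners side by side. The names nextId and threeIds are made up for this illustration, and the code assumes the mtl package is available.

```haskell
import Control.Monad.State

-- Return the current counter value as the result,
-- and increment the counter stored in the state.
nextId :: State Int Int
nextId = do
  n <- get
  put (n + 1)
  pure n

-- Draw three ids in a row. The counter is threaded between
-- the calls implicitly; we never pass it around by hand.
threeIds :: State Int [Int]
threeIds = do
  a <- nextId
  b <- nextId
  c <- nextId
  pure [a, b, c]
```

Starting from an initial state of 5, runState threeIds 5 gives ([5,6,7],8), evalState threeIds 5 gives only the result [5,6,7], and execState threeIds 5 gives only the final state 8.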

Info

The do block is syntactic sugar for chaining monadic operations. For example, you could implement doubleState without a do block as follows:

doubleState :: State Int String
doubleState = get >>= \n -> put (n * 2) >> pure ("Doubled the state to " ++ show (n * 2))
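To demystify get, put and >>= a little, here is a sketch of how a State-like type could be written by hand. This is a simplified stand-in, not the library's actual definition (which goes through StateT, as noted earlier), and the names MyState, runMyState, myGet, myPut and doubleMyState are made up for this illustration.

```haskell
-- A stateful computation is just a function from an initial state
-- to a pair of (result, final state).
newtype MyState s a = MyState {runMyState :: s -> (a, s)}

instance Functor (MyState s) where
  -- Transform the result, leave the state handling untouched
  fmap f (MyState g) = MyState $ \s ->
    let (a, s') = g s
     in (f a, s')

instance Applicative (MyState s) where
  -- Produce a result without touching the state
  pure a = MyState $ \s -> (a, s)
  MyState f <*> MyState g = MyState $ \s ->
    let (h, s') = f s
        (a, s'') = g s'
     in (h a, s'')

instance Monad (MyState s) where
  -- Run the first computation, then feed its result and its
  -- final state into the next computation
  MyState g >>= f = MyState $ \s ->
    let (a, s') = g s
     in runMyState (f a) s'

-- The current state becomes the result
myGet :: MyState s s
myGet = MyState $ \s -> (s, s)

-- Replace the state, with no meaningful result
myPut :: s -> MyState s ()
myPut s = MyState $ \_ -> ((), s)

-- Same logic as doubleState, written against the hand-rolled type
doubleMyState :: MyState Int String
doubleMyState = do
  n <- myGet
  myPut (n * 2)
  pure ("Doubled the state to " ++ show (n * 2))
```

In ghci, runMyState doubleMyState 10 evaluates to ("Doubled the state to 20",20), the same pair that runState doubleState 10 produces. The point of the sketch is that there is no mutation anywhere: the monad instance only automates passing the state from one step to the next.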

Refactoring the training example

In our case, we don't care about the result of each training step; we just want to update the weight. So we can use State Weight () as our monadic type. Here is the equivalent code using the State monad.

{-# OPTIONS_GHC -Wall #-}

import Control.Monad (replicateM_)
import Control.Monad.State

type Weight = Double

type Gradient = Double

type LearningRate = Double

learningRate :: LearningRate
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation

updateWeight :: Gradient -> State Weight ()
updateWeight gradient = modify (\w -> w - learningRate * gradient)

train :: Int -> State Weight ()
train numEpochs = replicateM_ numEpochs $ do
  weight <- get
  let gradient = backward weight
  updateWeight gradient

main :: IO ()
main = do
  let initialWeight = 0.5
  let finalWeight = execState (train 10) initialWeight
  putStrLn $ "Final weight: " ++ show finalWeight

We used a new function, replicateM_ from Control.Monad, which repeats a monadic action a given number of times and discards the results.
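If we did want to keep the results, replicateM (without the underscore) collects them into a list. The sketch below is a variation on the training code, not part of the library: trainStep and trainWithHistory are hypothetical names, and each step returns the updated weight as its result so that we get the weight after every epoch.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State

type Weight = Double

type Gradient = Double

learningRate :: Double
learningRate = 0.1

backward :: Weight -> Gradient
backward _ = 0.1 -- Dummy gradient computation

-- One training step: update the weight in the state,
-- and also return the new weight as the result.
trainStep :: State Weight Weight
trainStep = do
  w <- get
  let w' = w - learningRate * backward w
  put w'
  pure w'

-- Run numEpochs steps and collect the weight after each one.
trainWithHistory :: Int -> State Weight [Weight]
trainWithHistory numEpochs = replicateM numEpochs trainStep
```

With this version, runState (trainWithHistory 3) 0.5 returns a list of three weights together with the final state, and the last element of the list equals that final state.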

Now that we understand how the State monad works, we need to figure out how to integrate it with the rest of the library - slated for a future post!