859 Expert

State Monad

Functional Programming

Tutorial

The Problem

Without the State monad, threading state through a sequence of functions means passing it explicitly as an argument and returning it alongside each result: fn step(input: T, state: S) -> (R, S). This is noisy and error-prone, because every call site must remember to pass the latest state forward. The State monad encapsulates the threading: State<S, A> represents a computation S -> (A, S) that reads and may replace the state. Computations compose without explicit state passing because bind does the threading.

The pattern appears in compiler passes (threading symbol tables), game state machines, configuration accumulation, and embedded DSLs. It makes stateful computation composable and testable while remaining purely functional.
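To make the noise concrete, here is a minimal sketch of explicit threading. The names fresh_label and label_pair are hypothetical and exist only for illustration; the point is that each call must receive the counter returned by the previous one.

```rust
// Explicit state threading: every step takes the state and hands it back.
// `fresh_label` and `label_pair` are hypothetical names for illustration.
fn fresh_label(prefix: &str, counter: u32) -> (String, u32) {
    // Use the current counter in the label, then return the advanced counter.
    (format!("{}{}", prefix, counter), counter + 1)
}

fn label_pair(counter: u32) -> ((String, String), u32) {
    // Each call must be given the counter produced by the previous call.
    let (a, counter) = fresh_label("tmp", counter);
    let (b, counter) = fresh_label("tmp", counter);
    ((a, b), counter)
}
```

Passing the original counter to the second call by mistake would still compile and would silently produce a duplicate label, which is the kind of slip the State monad is meant to rule out.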

🎯 Learning Outcomes

  • Understand State<S, A> as a wrapper around FnOnce(S) -> (A, S)
  • Implement get() returning the current state, put(s) replacing the state, and modify(f) transforming the state
  • Implement monadic bind: state.and_then(|a| next_state), which threads state through both computations
  • Use run(initial) to execute the computation and get (result, final_state)
  • Recognize the tension with Rust's ownership: choosing FnOnce or Fn depending on whether a computation must be runnable more than once

Code Example

    struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
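As a small sketch of how this struct is used, the block below repeats the definition with the constructor and runner from the full source and adds one computation. The function next_id is a hypothetical example, not part of the original code.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    // Wrap any closure of shape S -> (A, S).
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    // Consume the computation and apply it to an initial state.
    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }
}

// Hypothetical example: hand out the current id and advance the counter.
fn next_id() -> State<u32, u32> {
    State::new(|n| (n, n + 1))
}
```

Nothing happens when next_id() is called; the closure only runs once an initial state is supplied through run.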

    Key Differences

    Aspect                 Rust                                      OCaml
    Type                   Box<dyn FnOnce(S) -> (A, S)>              State of ('s -> 'a * 's)
    FnOnce vs Fn           Must choose based on use                  fun s -> ... (always reusable, like Fn)
    Bind implementation    Complex with boxing                       Clean algebraic unwrap
    get                    Requires S: Clone                         No copy needed (the value is shared)
    Thread safety          Send + Sync bounds needed to cross threads    Not applicable (single-threaded)
    'static bound          Required for boxed closures               Not required
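The FnOnce vs Fn row can be illustrated with a reusable variant. This is a sketch under assumptions, not the tutorial's implementation: ReState is a hypothetical type that stores an Fn closure behind Rc so the same computation can run more than once, at the cost of requiring every closure to be callable repeatedly.

```rust
use std::rc::Rc;

// Hypothetical reusable variant: `Fn` behind `Rc` instead of `FnOnce` behind `Box`.
struct ReState<S, A> {
    run: Rc<dyn Fn(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> ReState<S, A> {
    fn new(f: impl Fn(S) -> (A, S) + 'static) -> Self {
        ReState { run: Rc::new(f) }
    }

    // Borrowing `&self` lets the same computation run repeatedly.
    fn run(&self, s: S) -> (A, S) {
        (*self.run)(s)
    }

    fn and_then<B: 'static>(&self, f: impl Fn(A) -> ReState<S, B> + 'static) -> ReState<S, B> {
        // Share the first computation with the new closure via a cheap Rc clone.
        let first = Rc::clone(&self.run);
        ReState::new(move |s| {
            let (a, s2) = (*first)(s);
            f(a).run(s2)
        })
    }
}

fn tick() -> ReState<i32, i32> {
    ReState::new(|n| (n, n + 1))
}

fn tick_twice() -> ReState<i32, (i32, i32)> {
    tick().and_then(|a| tick().and_then(move |b| ReState::new(move |s| ((a, b), s))))
}
```

The trade-off is that closures can no longer move captured values out, so a put built this way would need S: Clone, whereas the FnOnce version can simply hand over the value it captured.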

    OCaml Approach

    OCaml represents State as type ('s, 'a) state = State of ('s -> 'a * 's), and running a computation is just unwrapping it: run_state (State f) s = f s. Monadic bind is let bind (State f) k = State (fun s -> let (a, s') = f s in let State g = k a in g s'). The primitives are get = State (fun s -> (s, s)) and put s = State (fun _ -> ((), s)). OCaml's algebraic types make the State monad clean and readable, and the ppx_let extension provides let%bind syntax for threading state.

    Full Source

    #![allow(clippy::all)]
    // Example 060: State Monad
    // Thread state through computations without explicit passing
    
    // State monad: S -> (A, S)
    struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
    
    impl<S: 'static, A: 'static> State<S, A> {
        fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
            State { run: Box::new(f) }
        }
    
        fn run(self, s: S) -> (A, S) {
            (self.run)(s)
        }
    
        fn pure(a: A) -> Self {
            State::new(move |s| (a, s))
        }
    
        fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
            State::new(move |s| {
                let (a, s2) = self.run(s);
                f(a).run(s2)
            })
        }
    
        fn map<B: 'static>(self, f: impl FnOnce(A) -> B + 'static) -> State<S, B> {
            State::new(move |s| {
                let (a, s2) = self.run(s);
                (f(a), s2)
            })
        }
    }
    
    fn get<S: Clone + 'static>() -> State<S, S> {
        State::new(|s: S| (s.clone(), s))
    }
    
    fn put<S: 'static>(new_s: S) -> State<S, ()> {
        State::new(move |_| ((), new_s))
    }
    
    fn modify<S: 'static>(f: impl FnOnce(S) -> S + 'static) -> State<S, ()> {
        State::new(move |s| ((), f(s)))
    }
    
    // Approach 1: Counter
    fn tick() -> State<i32, i32> {
        get::<i32>().and_then(|n| put(n + 1).map(move |()| n))
    }
    
    fn count3() -> State<i32, (i32, i32, i32)> {
        tick().and_then(|a| tick().and_then(move |b| tick().map(move |c| (a, b, c))))
    }
    
    // Approach 2: Explicit state threading (no State monad — idiomatic Rust)
    fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
        let a = state;
        let state = state + 1;
        let b = state;
        let state = state + 1;
        let c = state;
        let state = state + 1;
        ((a, b, c), state)
    }
    
    // Approach 3: Stack operations
    fn push(x: i32) -> State<Vec<i32>, ()> {
        modify(move |mut stack: Vec<i32>| {
            stack.push(x);
            stack
        })
    }
    
    fn pop() -> State<Vec<i32>, Option<i32>> {
        State::new(|mut stack: Vec<i32>| {
            let val = stack.pop();
            (val, stack)
        })
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_counter() {
            let (result, state) = count3().run(0);
            assert_eq!(result, (0, 1, 2));
            assert_eq!(state, 3);
        }
    
        #[test]
        fn test_counter_nonzero_start() {
            let (result, state) = count3().run(10);
            assert_eq!(result, (10, 11, 12));
            assert_eq!(state, 13);
        }
    
        #[test]
        fn test_explicit_same_as_monadic() {
            let (r1, s1) = count3().run(0);
            let (r2, s2) = count3_explicit(0);
            assert_eq!(r1, r2);
            assert_eq!(s1, s2);
        }
    
        #[test]
        fn test_stack_push_pop() {
            let ops = push(10).and_then(|()| push(20)).and_then(|()| pop());
            let (val, stack) = ops.run(vec![]);
            assert_eq!(val, Some(20));
            assert_eq!(stack, vec![10]);
        }
    
        #[test]
        fn test_stack_pop_empty() {
            let (val, stack) = pop().run(vec![]);
            assert_eq!(val, None);
            assert_eq!(stack, Vec::<i32>::new());
        }
    
        #[test]
        fn test_pure() {
            let (val, state) = State::<i32, _>::pure(42).run(0);
            assert_eq!(val, 42);
            assert_eq!(state, 0);
        }
    }

    Deep Comparison

    Comparison: State Monad

    State Type

    OCaml:

    type ('s, 'a) state = State of ('s -> 'a * 's)
    let run_state (State f) s = f s
    

    Rust:

    struct State<S, A> {
        run: Box<dyn FnOnce(S) -> (A, S)>,
    }
    

    Bind / and_then

    OCaml:

    let bind m f = State (fun s ->
      let (a, s') = run_state m s in
      run_state (f a) s')
    

    Rust:

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
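One way to gain confidence in a bind implementation is to spot-check the three monad laws. The sketch below repeats the definitions from the full source and adds hypothetical helpers (add, left_identity, right_identity, associativity) that compare both sides of each law at a chosen starting state. This checks sample points only and proves nothing in general.

```rust
struct State<S, A> {
    run: Box<dyn FnOnce(S) -> (A, S)>,
}

impl<S: 'static, A: 'static> State<S, A> {
    fn new(f: impl FnOnce(S) -> (A, S) + 'static) -> Self {
        State { run: Box::new(f) }
    }

    fn run(self, s: S) -> (A, S) {
        (self.run)(s)
    }

    fn pure(a: A) -> Self {
        State::new(move |s| (a, s))
    }

    fn and_then<B: 'static>(self, f: impl FnOnce(A) -> State<S, B> + 'static) -> State<S, B> {
        State::new(move |s| {
            let (a, s2) = self.run(s);
            f(a).run(s2)
        })
    }
}

// Hypothetical step: return the old state and add `x` to it.
fn add(x: i32) -> State<i32, i32> {
    State::new(move |s| (s, s + x))
}

// Left identity: pure(a).and_then(f) behaves like f(a).
fn left_identity(a: i32, s: i32) -> bool {
    State::<i32, i32>::pure(a).and_then(add).run(s) == add(a).run(s)
}

// Right identity: m.and_then(pure) behaves like m.
fn right_identity(s: i32) -> bool {
    add(5).and_then(State::<i32, i32>::pure).run(s) == add(5).run(s)
}

// Associativity: grouping of binds does not change the result.
fn associativity(s: i32) -> bool {
    let lhs = add(1).and_then(add).and_then(add).run(s);
    let rhs = add(1).and_then(|a| add(a).and_then(add)).run(s);
    lhs == rhs
}
```

Because each State is consumed by run, both sides of every law have to be built from scratch, which is the FnOnce trade-off showing up in test code.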
    

    Counter Example

    OCaml:

    let tick = get >>= fun n -> put (n + 1) >>= fun () -> return_ n
    let (result, _) = run_state (tick >>= fun a -> tick >>= fun b -> return_ (a, b)) 0
    (* result = (0, 1) *)
    

    Rust (idiomatic — no monad needed):

    fn count3_explicit(state: i32) -> ((i32, i32, i32), i32) {
        let a = state; let state = state + 1;
        let b = state; let state = state + 1;
        let c = state; let state = state + 1;
        ((a, b, c), state)
    }
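Another non-monadic style Rust code often uses is in-place mutation through a mutable reference. The sketch below is an alternative for comparison, not part of the original example; tick_mut and count3_mut are hypothetical names.

```rust
// Hypothetical alternative: mutate the counter in place through `&mut`.
fn tick_mut(state: &mut i32) -> i32 {
    let n = *state; // read the current count
    *state += 1;    // advance the counter in place
    n               // return the count seen before the increment
}

fn count3_mut(mut state: i32) -> ((i32, i32, i32), i32) {
    let a = tick_mut(&mut state);
    let b = tick_mut(&mut state);
    let c = tick_mut(&mut state);
    ((a, b, c), state)
}
```

The borrow checker guarantees exclusive access to the counter during each call, which covers much of what the State monad provides in a language without ownership tracking.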
    

    Exercises

  • Implement monadic bind for State<S, A> and use it to compose get, a transform, and put into a single computation.
  • Implement a stack using the State monad: push as a State<Vec<T>, ()> computation and pop as a State<Vec<T>, Option<T>> computation.
  • Use the State monad to implement a simple counter that increments and returns the new count at each step.
  • Compare the State monad approach with explicit state threading: implement the same computation both ways.
  • Implement modify(f: S -> S) -> State<S, ()> using get and put and verify it equals State::new(|s| ((), f(s))).