869 Expert

869-continuation-monad — Continuation Monad

Functional Programming

Tutorial

The Problem

Continuation-passing style (CPS) is a program transformation in which every function, instead of returning a value, takes an extra "continuation" argument: a callback representing the rest of the computation. The transformation makes control flow explicit, and it makes features that are awkward or unavailable in a direct-style language straightforward to express: early exit, coroutines, and delimited continuations. Scheme popularized first-class continuations through call/cc. The continuation monad packages this pattern as a composable abstraction with the type Cont r a = (a -> r) -> r. CPS itself is widely used in compiler intermediate representations and in implementations of effect systems.
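To make the transformation concrete, here is a small direct-style function next to a hand-written CPS version. This is an illustrative sketch; the names square_k, add_k and sum_of_squares_k are hypothetical and not part of the example's code. The continuations are plain FnOnce generics rather than Box<dyn Fn>, which is enough here because each one is called exactly once.

```rust
// Direct style: each function returns its result to its caller.
pub fn square(x: i32) -> i32 {
    x * x
}

pub fn sum_of_squares(a: i32, b: i32) -> i32 {
    square(a) + square(b)
}

// CPS: each function takes a continuation `k` ("what to do next")
// and passes its result to `k` instead of returning it.
pub fn square_k<R>(x: i32, k: impl FnOnce(i32) -> R) -> R {
    k(x * x)
}

pub fn add_k<R>(x: i32, y: i32, k: impl FnOnce(i32) -> R) -> R {
    k(x + y)
}

pub fn sum_of_squares_k<R>(a: i32, b: i32, k: impl FnOnce(i32) -> R) -> R {
    // The evaluation order (square a, then square b, then add) is now explicit.
    square_k(a, |a2| square_k(b, |b2| add_k(a2, b2, k)))
}
```

With the identity continuation, sum_of_squares_k(3, 4, |r| r) yields 25, the same as the direct version; passing a different continuation, such as one that formats the number, changes what happens afterwards without touching the CPS functions.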

🎯 Learning Outcomes

  • Understand continuation-passing style and how it makes control flow explicit
  • Implement the Cont monad type in Rust using Box<dyn Fn>
  • Recognize CPS factorial as the canonical example of moving all pending work into the continuation so that every recursive call is a tail call
  • Understand how callcc enables early exit from nested computations
  • Compare Rust's boxed closures (Box<dyn Fn>) with OCaml's uniformly heap-allocated closures
Code Example

    #![allow(clippy::all)]
    /// Continuation Monad — Delimited Continuations in Rust
    ///
    /// The continuation monad wraps computations that pass their result to a callback.
    /// type Cont r a = (a -> r) -> r
    
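    // A computation that will produce an `A` and hand it to a continuation
    // of type `A -> R`; `R` is the answer type of the whole computation.
    pub struct Cont<R, A> {
        // Closures have unique, unnameable types, so the computation and the
        // continuation are both stored as boxed `dyn Fn` trait objects.
        run: Box<dyn Fn(Box<dyn Fn(A) -> R>) -> R>,
    }
    
    impl<R: 'static, A: 'static> Cont<R, A> {
        /// Build a Cont from a function that receives the continuation `k`.
        pub fn new(f: impl Fn(Box<dyn Fn(A) -> R>) -> R + 'static) -> Self {
            Cont { run: Box::new(f) }
        }
    
        /// Run the computation, supplying `k` as "the rest of the program".
        pub fn run_cont(self, k: impl Fn(A) -> R + 'static) -> R {
            (self.run)(Box::new(k))
        }
    }
    
    /// Wrap a pure value in Cont: \k -> k(a)
    /// `A: Clone` is required because the stored closure is `Fn` (it may be
    /// called more than once), so it passes a clone of `a` to `k` each time.
    pub fn cont_return<R: 'static, A: Clone + 'static>(a: A) -> Cont<R, A> {
        Cont::new(move |k: Box<dyn Fn(A) -> R>| k(a.clone()))
    }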
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_cps_add() {
            let result = (|a: i32, b: i32, k: &dyn Fn(i32) -> i32| k(a + b))(3, 4, &|x| x);
            assert_eq!(result, 7);
        }
    
        #[test]
        fn test_factorial_cps() {
            fn fact(n: i32, k: &dyn Fn(i32) -> i32) -> i32 {
                if n <= 1 {
                    k(1)
                } else {
                    fact(n - 1, &|r| k(n * r))
                }
            }
            assert_eq!(fact(5, &|x| x), 120);
        }
    
        #[test]
        fn test_cont_return() {
            let c = cont_return::<i32, i32>(42);
            assert_eq!(c.run_cont(|x| x), 42);
            let c2 = cont_return::<i32, i32>(10);
            assert_eq!(c2.run_cont(|x| x * 2), 20);
        }
    }
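Because a Cont value receives its continuation as an ordinary function, it may call that continuation once (as cont_return does), never, or several times. Calling it zero times is the essence of early exit; calling it several times gives a simple form of nondeterminism. The sketch below repeats the Cont type so it compiles on its own; abort and choose are hypothetical helpers, not part of the example's API.

```rust
// Minimal copy of the Cont type above, repeated so this sketch is standalone.
pub struct Cont<R, A> {
    run: Box<dyn Fn(Box<dyn Fn(A) -> R>) -> R>,
}

impl<R: 'static, A: 'static> Cont<R, A> {
    pub fn new(f: impl Fn(Box<dyn Fn(A) -> R>) -> R + 'static) -> Self {
        Cont { run: Box::new(f) }
    }

    pub fn run_cont(self, k: impl Fn(A) -> R + 'static) -> R {
        (self.run)(Box::new(k))
    }
}

/// Hypothetical helper: ignores its continuation and returns `answer`
/// directly, so "the rest of the computation" never runs (early exit).
pub fn abort<R: Clone + 'static, A: 'static>(answer: R) -> Cont<R, A> {
    Cont::new(move |_k: Box<dyn Fn(A) -> R>| answer.clone())
}

/// Hypothetical helper: runs the continuation once per choice and
/// concatenates all of the answers.
pub fn choose<A: Clone + 'static, B: 'static>(choices: Vec<A>) -> Cont<Vec<B>, A> {
    Cont::new(move |k: Box<dyn Fn(A) -> Vec<B>>| {
        choices.iter().cloned().flat_map(|a| k(a)).collect()
    })
}
```

Running abort(-1) with the continuation |x| x * 100 returns -1 without ever calling the continuation, while choose(vec![1, 2, 3]) with |x| vec![x * 10] calls it three times and collects [10, 20, 30].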

    Key Differences

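  • Closure representation: every Rust closure has its own anonymous type, so storing different closures behind one type such as Cont requires type erasure with Box<dyn Fn>, which heap-allocates; OCaml closures are uniformly represented heap values, so no boxing shows up in the types.
  • callcc: Scheme provides call/cc natively, but neither OCaml nor Rust does. The OCaml version defines callcc as an ordinary function over the Cont type; a Rust equivalent can be written the same way, although it needs shared ownership of the captured continuation (for example Rc), because both the escape function and the normal path use it. What safe Rust cannot do is capture and resume the native call stack.
  • Lifetime of continuations: this Rust implementation requires 'static bounds so that boxed closures own everything they capture; OCaml's GC keeps captured values alive automatically.
  • Composition: OCaml lets you define >>= and let* binding operators, so monadic chains read almost like direct-style code; Rust has no user-defined operators or do-notation, so chains are written with explicit method calls or helper functions.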
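The closure-representation point can be seen in a few lines: two closures with different anonymous types (one of them capturing a variable) can only be returned from the same function after being erased to Box<dyn Fn>. pick_continuation and increment_k are hypothetical names used for illustration.

```rust
/// Hypothetical helper: the two branches build closures of different
/// anonymous types, so they are unified by boxing them as `dyn Fn`.
pub fn pick_continuation(double: bool) -> Box<dyn Fn(i32) -> i32> {
    if double {
        Box::new(|x: i32| x * 2)
    } else {
        let offset = 10;
        // `move` copies `offset` into the closure so the box owns it ('static).
        Box::new(move |x: i32| x + offset)
    }
}

/// A CPS function that accepts any continuation through a `&dyn Fn`.
pub fn increment_k(x: i32, k: &dyn Fn(i32) -> i32) -> i32 {
    k(x + 1)
}
```

For example, increment_k(4, &*pick_continuation(true)) evaluates to 10 and increment_k(4, &*pick_continuation(false)) to 15. In OCaml the same function needs no annotation, because every closure already has the uniform function type int -> int.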
OCaml Approach

    OCaml represents Cont as an ordinary single-constructor variant, with no GADTs needed: type ('a, 'r) cont = Cont of (('a -> 'r) -> 'r). The bind function sequences two computations by running the first and passing its result to a function that produces the second, and callcc captures the current continuation as an escape function, enabling early exit from List.fold_left. OCaml's let* binding operators make monadic chains readable. The OCaml version shows find_first_negative using callcc to abandon the rest of the fold as soon as a negative number is found, which a plain fold cannot do without exceptions.
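The Rust code here stops short of callcc, but the escape idea behind the OCaml find_first_negative can be sketched in plain CPS by passing the escape continuation explicitly. This is a two-continuation variant rather than a callcc combinator, and scan_k and sum_or_first_negative are hypothetical names: the function sums a slice, but returns the first negative element immediately, discarding the pending additions held in the normal continuation.

```rust
/// CPS scan that sums `xs` but escapes as soon as it meets a negative number.
/// `k` is the normal continuation and grows with each element (it holds the
/// pending additions); `exit` is the escape continuation, fixed at the start.
fn scan_k<R>(xs: &[i32], exit: &dyn Fn(i32) -> R, k: &dyn Fn(i32) -> R) -> R {
    match xs.split_first() {
        // End of input: unwind the pending additions, starting from 0.
        None => k(0),
        // Negative found: call `exit` and never run the pending `k` chain.
        Some((&x, _)) if x < 0 => exit(x),
        // Otherwise extend the continuation with "add x to the rest's sum".
        Some((&x, rest)) => scan_k(rest, exit, &|sum| k(x + sum)),
    }
}

/// Hypothetical entry point: Ok(sum) if no element is negative,
/// otherwise Err(first negative element).
pub fn sum_or_first_negative(xs: &[i32]) -> Result<i32, i32> {
    scan_k(xs, &|neg| Err(neg), &|sum| Ok(sum))
}
```

For [1, 2, 3] the chain of pending additions runs and the result is Ok(6); for [1, 2, -5, 4, -7] the scan calls exit at -5, so the result is Err(-5) and neither the pending additions nor the remaining elements are processed. Like the CPS factorial, this recursion still uses the Rust stack, since Rust does not guarantee tail-call elimination.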


    Exercises

  • Implement cont_bind that sequences two Cont<R, A> values (monadic bind / >>=).
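  • Implement a CPS Fibonacci and make it stack-safe. CPS alone does not stop the stack from growing in Rust, which does not guarantee tail-call elimination, so combine the continuation chain with a trampoline or an explicit stack of continuations.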
  • Use the continuation monad to implement a depth-limited tree search that exits early when the depth limit is reached.
Open Source Repos