594 Advanced

Continuation-Passing Style (CPS)

Functional Programming


Tutorial

The Problem

Continuation-Passing Style (CPS) is a program transformation in which a function, instead of returning a value, passes its result to a callback called the continuation. Every function takes an extra argument k: impl FnOnce(T) -> R and calls k(result) instead of returning. CPS has deep roots in compiler theory: compilers such as SML/NJ use a CPS intermediate representation, and the SSA form used by LLVM is closely related to CPS. It makes control flow explicit and underlies implementations of coroutines, async/await, generators, and exception handling. Because every call in CPS is a tail call, it can also make deep recursion stack-safe, but only in a language that eliminates tail calls or when combined with a trampoline; Rust does not guarantee tail-call elimination.
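To make the transformation concrete, here is a small sketch that writes the same computation in direct style and in CPS. It is not part of the original example; add, square, add_k, square_k, sum_squared and sum_squared_k are illustrative names. In the CPS version the order of operations is spelled out by which continuation wraps which.

```rust
// Direct style: each function returns its result to the caller.
fn add(a: i32, b: i32) -> i32 {
    a + b
}

fn square(x: i32) -> i32 {
    x * x
}

// (a + b)^2: the evaluation order is implicit in the nesting of calls.
fn sum_squared(a: i32, b: i32) -> i32 {
    square(add(a, b))
}

// CPS: each function takes a continuation `k` and calls it with the result.
fn add_k<R>(a: i32, b: i32, k: impl FnOnce(i32) -> R) -> R {
    k(a + b)
}

fn square_k<R>(x: i32, k: impl FnOnce(i32) -> R) -> R {
    k(x * x)
}

// (a + b)^2 in CPS: "add, then square the sum, then continue with k".
fn sum_squared_k<R>(a: i32, b: i32, k: impl FnOnce(i32) -> R) -> R {
    add_k(a, b, |sum| square_k(sum, k))
}
```

Passing the identity continuation |x| x turns the CPS computation back into an ordinary value (sum_squared_k(2, 3, |x| x) gives 25); any other continuation decides what happens to the result instead.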

🎯 Learning Outcomes

  • How direct-style functions are transformed to CPS using continuation arguments
  • How fact_k<R>(n, k: Box<dyn FnOnce(u64) -> R>) -> R passes results to continuations
  • How CPS makes the call stack explicit as a chain of closures
  • How CPS is related to trampolining for stack-safe recursion
  • Where CPS appears: compilers (CPS IR), async/await desugaring, effect systems

Code Example

    fn fact(n: u64) -> u64 {
        if n <= 1 { 1 } else { n * fact(n - 1) }
    }

    Key Differences

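  • TCO: OCaml eliminates tail calls, so CPS factorial in OCaml runs in O(1) stack while its continuations live on the heap. Rust does not guarantee TCO: both the recursive fact_k calls and the final invocation of the nested continuation chain use stack proportional to n, so very large n needs a trampoline or an explicit loop.
  • Continuation type: OCaml continuations are plain function values. In Rust every closure has its own anonymous type, so a recursive CPS function that keeps wrapping k in new closures must erase that type with Box<dyn FnOnce(T) -> R>.
  • Practical use: CPS is a classic intermediate representation in functional-language compilers; Rust's async/await desugars to state machines rather than CPS.
  • Allocation: Rust's boxed CPS allocates one heap object per continuation frame; OCaml also allocates a closure per frame, on its GC-managed heap.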
OCaml Approach

    CPS is natural in OCaml: functions are first-class values, closures are lightweight, and the compiler optimizes calls in tail position:

    let rec fact_k n k = if n <= 1 then k 1 else fact_k (n-1) (fun r -> k (n * r))
    (* Usage: *)
    let result = fact_k 10 (fun x -> x)
    

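    OCaml's tail-call optimization lets CPS functions like fact_k run in constant stack space, with the continuations allocated on the heap. This is a key advantage over Rust, where neither the recursive fact_k calls nor the invocation of a Box<dyn FnOnce> chain is guaranteed to run in constant stack, so sufficiently deep CPS recursion can still overflow.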
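Rust has no guaranteed tail-call elimination, so a common way to get stack-safe CPS is a trampoline: every step returns a description of the next step, and a loop runs the steps one at a time. The sketch below is a hypothetical illustration, not part of the original example: Trampoline, Cont, sum_t and sum_to are invented names, and it computes a running sum rather than a factorial so that deep inputs do not overflow u64.

```rust
/// One step of a trampolined computation: a final value, or a deferred
/// thunk that produces the next step.
enum Trampoline<T> {
    Done(T),
    More(Box<dyn FnOnce() -> Trampoline<T>>),
}

impl<T> Trampoline<T> {
    /// Drive the steps in a loop: each thunk returns immediately, so stack
    /// depth stays constant no matter how many steps there are.
    fn run(mut self) -> T {
        loop {
            match self {
                Trampoline::Done(value) => return value,
                Trampoline::More(next) => self = next(),
            }
        }
    }
}

/// Continuations return the next step instead of calling onward directly.
type Cont = Box<dyn FnOnce(u64) -> Trampoline<u64>>;

/// CPS sum 1 + 2 + ... + n, shaped like fact_k, but every recursive call and
/// every continuation call is deferred to the driver loop.
fn sum_t(n: u64, k: Cont) -> Trampoline<u64> {
    if n == 0 {
        Trampoline::More(Box::new(move || k(0)))
    } else {
        Trampoline::More(Box::new(move || {
            // The new continuation adds n to the sub-result, then defers to k.
            let next: Cont =
                Box::new(move |r: u64| Trampoline::More(Box::new(move || k(n + r))));
            sum_t(n - 1, next)
        }))
    }
}

/// Start with the identity continuation, which ends the computation.
fn sum_to(n: u64) -> u64 {
    sum_t(n, Box::new(|r: u64| Trampoline::Done(r))).run()
}
```

With this shape, sum_to(100_000) returns 5_000_050_000 without deep recursion. The price is several small heap allocations per step, on top of the one continuation per frame that plain boxed CPS already needs.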

    Full Source

    #![allow(clippy::all)]
    //! # Continuation-Passing Style (CPS)
    //!
    //! Transform direct-style functions to pass results to continuations.
    
    /// Factorial in direct style (for comparison).
    pub fn fact(n: u64) -> u64 {
        if n <= 1 {
            1
        } else {
            n * fact(n - 1)
        }
    }
    
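    /// Factorial in CPS: the result is passed to the continuation `k`.
    /// `k` is a `Box<dyn FnOnce>` because each recursive call wraps it in a new
    /// closure; a generic `impl FnOnce` would create infinitely many closure types.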
    pub fn fact_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
        if n <= 1 {
            k(1)
        } else {
            fact_k(n - 1, Box::new(move |r| k(n * r)))
        }
    }
    
    /// Fibonacci in direct style.
    pub fn fib(n: u64) -> u64 {
        if n <= 1 {
            n
        } else {
            fib(n - 1) + fib(n - 2)
        }
    }
    
    /// Fibonacci in CPS (requires boxed closure for recursion).
    pub fn fib_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
        if n <= 1 {
            k(n)
        } else {
            fib_k(
                n - 1,
                Box::new(move |r1| fib_k(n - 2, Box::new(move |r2| k(r1 + r2)))),
            )
        }
    }
    
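    /// Map in CPS style: the mapped vector is passed to `k` instead of returned.
    pub fn map_k<T, U, R>(items: Vec<T>, f: impl Fn(T) -> U + Clone, k: impl FnOnce(Vec<U>) -> R) -> R {
        fn go<I, T, U, R>(
            mut items: I,
            f: impl Fn(T) -> U + Clone,
            mut acc: Vec<U>,
            k: impl FnOnce(Vec<U>) -> R,
        ) -> R
        where
            I: Iterator<Item = T>,
        {
            // Take the next element in O(1); Vec::remove(0) would shift the
            // remaining elements on every step, making the map O(n^2).
            match items.next() {
                // Input exhausted: hand the accumulated result to the continuation.
                None => k(acc),
                Some(head) => {
                    acc.push(f(head));
                    go(items, f, acc, k)
                }
            }
        }
        go(items.into_iter(), f, Vec::new(), k)
    }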
    
    /// Fold in CPS style: the final accumulator is passed to `k`.
    pub fn fold_k<T, A, R>(
        items: Vec<T>,
        init: A,
        f: impl Fn(A, T) -> A + Clone,
        k: impl FnOnce(A) -> R,
    ) -> R {
        fn go<I, T, A, R>(
            mut items: I,
            acc: A,
            f: impl Fn(A, T) -> A + Clone,
            k: impl FnOnce(A) -> R,
        ) -> R
        where
            I: Iterator<Item = T>,
        {
            match items.next() {
                // Input exhausted: hand the accumulator to the continuation.
                None => k(acc),
                Some(head) => {
                    let new_acc = f(acc, head);
                    go(items, new_acc, f, k)
                }
            }
        }
        go(items.into_iter(), init, f, k)
    }
    
    /// Safe division with success and error continuations.
    pub fn safe_div_k<R>(a: f64, b: f64, ok: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
        if b == 0.0 {
            err("division by zero")
        } else {
            ok(a / b)
        }
    }
    
    /// Parse integer with continuation for success/failure.
    pub fn parse_int_k<R>(s: &str, ok: impl FnOnce(i64) -> R, err: impl FnOnce(&str) -> R) -> R {
        match s.parse::<i64>() {
            Ok(n) => ok(n),
            Err(_) => err(s),
        }
    }
    
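    /// Chained CPS operations: parse `s` as an integer n, then compute 100 / n,
    /// routing a parse failure or a division by zero to `err`.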
    pub fn chain_example<R>(s: &str, k: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
        match s.parse::<i64>() {
            Ok(n) => safe_div_k(100.0, n as f64, k, err),
            Err(_) => err(s),
        }
    }
    
    /// Identity continuation - extract value from CPS.
    pub fn run<T>(f: impl FnOnce(Box<dyn FnOnce(T) -> T>) -> T) -> T {
        f(Box::new(|x| x))
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_fact_direct() {
            assert_eq!(fact(5), 120);
            assert_eq!(fact(10), 3_628_800);
        }
    
        #[test]
        fn test_fact_cps() {
            // The identity continuation returns the result, so the assertion
            // cannot be silently skipped if the continuation is never called.
            assert_eq!(fact_k(5, Box::new(|n| n)), 120);
            assert_eq!(fact_k(10, Box::new(|n| n)), 3_628_800);
        }
    
        #[test]
        fn test_fact_equivalence() {
            for n in 0..=12 {
                assert_eq!(fact_k(n, Box::new(|cps| cps)), fact(n));
            }
        }
    
        #[test]
        fn test_fib_direct() {
            assert_eq!(fib(10), 55);
        }
    
        #[test]
        fn test_fib_cps() {
            assert_eq!(fib_k(10, Box::new(|n| n)), 55);
        }
    
        #[test]
        fn test_map_k() {
            let doubled = map_k(vec![1, 2, 3], |x| x * 2, |result| result);
            assert_eq!(doubled, vec![2, 4, 6]);
        }
    
        #[test]
        fn test_map_k_empty() {
            let doubled = map_k(Vec::<i32>::new(), |x| x * 2, |result| result);
            assert!(doubled.is_empty());
        }
    
        #[test]
        fn test_fold_k() {
            let sum = fold_k(vec![1, 2, 3, 4], 0, |acc, x| acc + x, |sum| sum);
            assert_eq!(sum, 10);
        }
    
        #[test]
        fn test_safe_div_ok() {
            safe_div_k(
                10.0,
                2.0,
                |r| assert_eq!(r, 5.0),
                |_| panic!("unexpected error"),
            );
        }
    
        #[test]
        fn test_safe_div_err() {
            let mut got_error = false;
            safe_div_k(
                10.0,
                0.0,
                |_| panic!("unexpected success"),
                |_| got_error = true,
            );
            assert!(got_error);
        }
    
        #[test]
        fn test_parse_int_ok() {
            parse_int_k("42", |n| assert_eq!(n, 42), |_| panic!("unexpected error"));
        }
    
        #[test]
        fn test_parse_int_err() {
            let mut got_error = false;
            parse_int_k(
                "abc",
                |_| panic!("unexpected success"),
                |_| got_error = true,
            );
            assert!(got_error);
        }
    
        #[test]
        fn test_chain() {
            // "10" -> parse to 10 -> 100/10 = 10.0
            chain_example(
                "10",
                |r| assert_eq!(r, 10.0),
                |_| panic!("unexpected error"),
            );
        }
    
        #[test]
        fn test_chain_div_zero() {
            let mut got_error = false;
            chain_example("0", |_| panic!("unexpected success"), |_| got_error = true);
            assert!(got_error);
        }
    
        #[test]
        fn test_chain_parse_error() {
            let mut got_error = false;
            chain_example(
                "abc",
                |_| panic!("unexpected success"),
                |_| got_error = true,
            );
            assert!(got_error);
        }
    }

    Deep Comparison

    OCaml vs Rust: Continuation-Passing Style

    Direct vs CPS

    OCaml Direct

    let rec fact n = if n <= 1 then 1 else n * fact (n-1)
    

    OCaml CPS

    let rec fact_k n k =
      if n <= 1 then k 1
      else fact_k (n-1) (fun result -> k (n * result))
    

    Rust Direct

    fn fact(n: u64) -> u64 {
        if n <= 1 { 1 } else { n * fact(n - 1) }
    }
    

    Rust CPS

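    // A generic `impl FnOnce` continuation does not compile here: each recursive
    // call wraps `k` in a new closure type, so monomorphization never terminates.
    // Boxing erases the closure type, as in the Full Source above.
    fn fact_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
        if n <= 1 {
            k(1)
        } else {
            fact_k(n - 1, Box::new(move |r| k(n * r)))
        }
    }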
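To see how CPS turns the call stack into a chain of closures, here is what the recursive CPS factorial builds for n = 3, unrolled by hand (fact3_unrolled is a hypothetical name used only for illustration). Each pending multiplication that direct-style fact would keep in a stack frame becomes one closure wrapping the previous continuation. Because the depth is fixed at 3, plain generic closures suffice; the recursive fact_k in the Full Source needs Box<dyn FnOnce> because the nesting depth, and therefore the closure type, depends on n.

```rust
/// What the recursive CPS factorial builds for n = 3, written out by hand.
fn fact3_unrolled<R>(k: impl FnOnce(u64) -> R) -> R {
    // Frame for n = 3: "multiply the sub-result by 3, then continue with k".
    let k3 = move |r: u64| k(3 * r);
    // Frame for n = 2: "multiply the sub-result by 2, then continue with k3".
    let k2 = move |r: u64| k3(2 * r);
    // Base case (n <= 1): feed 1 into the chain, which evaluates k(3 * (2 * 1)).
    k2(1)
}
```

Calling fact3_unrolled(|x| x) yields 6, the same value fact_k(3, Box::new(|x| x)) produces.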
    

    Error Handling with CPS

    OCaml

    let safe_div a b ok err =
      if b = 0.0 then err "division by zero"
      else ok (a /. b)
    

    Rust

    fn safe_div_k<R>(
        a: f64, b: f64,
        ok: impl FnOnce(f64) -> R,
        err: impl FnOnce(&str) -> R
    ) -> R {
        if b == 0.0 { err("division by zero") }
        else { ok(a / b) }
    }
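The two-continuation style and Result are two views of the same thing. A minimal sketch, assuming the safe_div_k signature shown above (safe_div and result_k are hypothetical helper names): choosing Ok and Err as the continuations recovers a direct-style Result, and matching on a Result converts it back to CPS.

```rust
/// CPS division, as in the comparison above.
fn safe_div_k<R>(a: f64, b: f64, ok: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
    if b == 0.0 {
        err("division by zero")
    } else {
        ok(a / b)
    }
}

/// CPS to direct style: the continuations simply build a Result.
fn safe_div(a: f64, b: f64) -> Result<f64, String> {
    safe_div_k(a, b, Ok, |msg| Err(msg.to_string()))
}

/// Direct style to CPS: route each Result variant to its continuation.
fn result_k<T, E, R>(r: Result<T, E>, ok: impl FnOnce(T) -> R, err: impl FnOnce(E) -> R) -> R {
    match r {
        Ok(value) => ok(value),
        Err(e) => err(e),
    }
}
```

The choice between the two is mostly ergonomic: Result composes with the ? operator, while explicit continuations let each caller decide the result type R.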
    

    Key Differences

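    | Aspect | OCaml | Rust |
    |---|---|---|
    | Closure syntax | fun x -> ... | \|x\| ... or move \|x\| ... |
    | Recursive CPS | Plain closures | Box<dyn FnOnce> to erase the nested closure types |
    | Type inference | Full | Signatures need explicit generic bounds such as impl FnOnce(T) -> R |
    | Move semantics | None; captured values are shared and GC-managed | Explicit with move |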

    Why CPS?

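  • Explicit control flow - No hidden returns; "what happens next" is an ordinary value
  • Tail calls - Every call is a tail call, so languages with tail-call elimination (such as OCaml) run CPS code in constant stack; Rust does not guarantee this
  • Error handling - Separate continuations for success and error
  • Backtracking - Save a continuation and call it later to resume the search from that point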
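Backtracking can be sketched with a failure continuation: a closure, saved when a choice is made, that resumes the search with the other alternative if the current branch is a dead end. The example below is a hypothetical illustration (subset_sum and Fail are not from the original): it searches for a subset of items that adds up to target, returns a success directly, and passes only the failure continuation explicitly. Each continuation is borrowed as &dyn Fn so the closures can live on the stack without boxing.

```rust
/// A saved "what to try next" continuation for when the current branch fails.
type Fail<'a> = &'a dyn Fn() -> Option<Vec<u32>>;

/// Search for a subset of `items` summing to `target`, trying to include
/// each item before skipping it. `chosen` holds the picks on the current path.
fn subset_sum(items: &[u32], target: u32, chosen: Vec<u32>, fail: Fail<'_>) -> Option<Vec<u32>> {
    if target == 0 {
        return Some(chosen); // success: stop searching
    }
    match items.split_first() {
        None => fail(), // dead end: resume from the last saved choice point
        Some((&x, rest)) => {
            // Choice point: save "skip x" as the failure continuation for the
            // branch that includes x.
            let skip_x = || subset_sum(rest, target, chosen.clone(), fail);
            if x <= target {
                let mut with_x = chosen.clone();
                with_x.push(x);
                subset_sum(rest, target - x, with_x, &skip_x)
            } else {
                skip_x()
            }
        }
    }
}
```

For example, searching [8, 6, 7, 5, 3] for a subset summing to 16 first tries 8 + 6 and then 8 + 7, hits dead ends in both, and backtracks through the saved continuations to return Some(vec![8, 5, 3]). The initial failure continuation &|| None reports that no subset exists.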
Exercises

  • CPS identity: Write a direct-style identity(x: i32) -> i32 and its CPS version identity_k<R>(x: i32, k: impl FnOnce(i32) -> R) -> R — verify they produce the same results.
  • CPS addition: Transform fn add(a: i32, b: i32) -> i32 to CPS and use it to build fn sum_list_k(items: &[i32], k: impl FnOnce(i32)) -> () that passes the total to k.
  • Defunctionalize: Transform the CPS fact_k to use an explicit stack of enum Cont { Identity, MultAndCall(u64, Box<Cont>) } instead of closures — this is defunctionalization.
Open Source Repos