Continuation-Passing Style (CPS)
Tutorial Video
Text description (accessibility)
This video demonstrates the "Continuation-Passing Style (CPS)" functional Rust example. Difficulty level: Advanced. Key concepts covered: Functional Programming. Continuation-Passing Style (CPS) is a program transformation where, instead of returning a value, a function passes its result to a callback (the continuation). Key difference from OCaml: 1. **TCO**: OCaml optimizes tail calls, so CPS factorial in OCaml uses O(1) stack; Rust does not guarantee TCO, so heap…
Tutorial
The Problem
Continuation-Passing Style (CPS) is a program transformation where, instead of returning a value, a function passes its result to a callback (the continuation). Every CPS function takes an extra argument k (in Rust, typically impl FnOnce(T) -> R, or Box<dyn FnOnce(T) -> R> for recursive functions) and calls k(result) instead of returning. CPS has deep roots in compiler theory (compilers such as SML/NJ use a CPS intermediate representation), makes control flow explicit and manipulable, and underlies common implementations of coroutines, async/await, generators, and exception handling. In a language that guarantees tail-call elimination, such as OCaml, it also keeps deeply recursive functions from overflowing the stack, because the pending work moves into heap-allocated closures; Rust makes no such guarantee.
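A minimal sketch of the transformation on a non-recursive computation. The names square_k, add_k, and sum_of_squares_k are illustrative, not from the tutorial's source. Because nothing here recurses, plain impl FnOnce continuations suffice and no boxing is needed.

```rust
/// Direct style: the result is returned to the caller.
fn square(x: i64) -> i64 {
    x * x
}

/// CPS: the result is handed to the continuation `k` instead of returned.
fn square_k<R>(x: i64, k: impl FnOnce(i64) -> R) -> R {
    k(x * x)
}

/// CPS addition, used to combine the two squares.
fn add_k<R>(a: i64, b: i64, k: impl FnOnce(i64) -> R) -> R {
    k(a + b)
}

/// Direct style: evaluation order of the sub-expressions is implicit.
fn sum_of_squares(a: i64, b: i64) -> i64 {
    square(a) + square(b)
}

/// CPS: every intermediate result is named (a2, b2) and the order is explicit.
fn sum_of_squares_k<R>(a: i64, b: i64, k: impl FnOnce(i64) -> R) -> R {
    square_k(a, |a2| square_k(b, |b2| add_k(a2, b2, k)))
}
```

Passing |s| s as the final continuation recovers the direct-style answer, while any other continuation (for example one that formats the number) changes what the caller receives.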
🎯 Learning Outcomes
How fact_k<R>(n, k: Box<dyn FnOnce(u64) -> R>) -> R passes its result to a continuation instead of returning it
Code Example
fn fact(n: u64) -> u64 {
    if n <= 1 { 1 } else { n * fact(n - 1) }
}
Key Differences
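1. **TCO**: OCaml optimizes tail calls, so CPS factorial runs in constant stack space; Rust does not guarantee TCO, so both the recursive fact_k calls and the nested continuation calls may use stack proportional to n.
2. **Closure types**: every Rust closure has its own anonymous type, so recursive CPS needs Box<dyn FnOnce(T) -> R> to store the continuations dynamically.
3. **Allocation**: Box allocates one heap object per continuation frame; OCaml's GC manages continuation closures but also allocates.
OCaml Approach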
CPS reads naturally in OCaml, where closures are lightweight and tail calls are optimized:
let rec fact_k n k = if n <= 1 then k 1 else fact_k (n-1) (fun r -> k (n * r))
(* Usage: *)
let result = fact_k 10 (fun x -> x)
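Because OCaml optimizes tail calls, fact_k runs in constant stack space: the pending multiplications live in heap-allocated closures instead of stack frames. That is the key advantage over Rust. The Rust chain of Box<dyn FnOnce> closures also lives on the heap, but neither the recursive fact_k calls nor the nested continuation calls are guaranteed to be tail-call optimized, so a large enough n can still overflow the stack.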
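Rust has no guaranteed tail calls, but a common workaround (not used in this tutorial's source) is a trampoline: each CPS step returns a boxed thunk instead of calling onward, and a plain loop drives the thunks, so stack depth stays constant. This is a minimal sketch under those assumptions; Step, sum_t, trampoline, and sum_to are hypothetical names, and summing 1..=n stands in for factorial so that deep inputs do not overflow u64.

```rust
/// One step of a trampolined computation: a finished value, or a deferred
/// thunk that the driver loop will call next.
enum Step {
    Done(u64),
    More(Box<dyn FnOnce() -> Step>),
}

/// Hypothetical CPS sum of 1..=n. Every call (including calls to `k`) is
/// wrapped in a thunk, so no function ever calls onward directly.
fn sum_t(n: u64, k: Box<dyn FnOnce(u64) -> Step>) -> Step {
    if n == 0 {
        Step::More(Box::new(move || k(0)))
    } else {
        Step::More(Box::new(move || {
            sum_t(
                n - 1,
                Box::new(move |r| Step::More(Box::new(move || k(n + r)))),
            )
        }))
    }
}

/// Driver loop: keeps calling thunks until a `Done` value appears.
fn trampoline(mut step: Step) -> u64 {
    loop {
        match step {
            Step::Done(v) => return v,
            Step::More(thunk) => step = thunk(),
        }
    }
}

/// Entry point with the identity-like final continuation.
fn sum_to(n: u64) -> u64 {
    trampoline(sum_t(n, Box::new(|r| Step::Done(r))))
}
```

The trade-off is one extra heap allocation per step, in exchange for constant stack usage even for inputs in the hundreds of thousands.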
Full Source
#![allow(clippy::all)]
//! # Continuation-Passing Style (CPS)
//!
//! Transform direct-style functions to pass results to continuations.

/// Factorial in direct style (for comparison).
pub fn fact(n: u64) -> u64 {
    if n <= 1 {
        1
    } else {
        n * fact(n - 1)
    }
}

/// Factorial in CPS: the result is passed to the continuation `k`.
/// The continuation is boxed so every recursion level has the same type;
/// with `impl FnOnce`, each level would wrap `k` in a new closure type.
pub fn fact_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
    if n <= 1 {
        k(1)
    } else {
        fact_k(n - 1, Box::new(move |r| k(n * r)))
    }
}
/// Fibonacci in direct style.
pub fn fib(n: u64) -> u64 {
    if n <= 1 {
        n
    } else {
        fib(n - 1) + fib(n - 2)
    }
}

/// Fibonacci in CPS (requires a boxed closure for recursion).
pub fn fib_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
    if n <= 1 {
        k(n)
    } else {
        // Compute fib(n - 1) first; its continuation then computes fib(n - 2)
        // and finally hands the sum to the original continuation `k`.
        fib_k(
            n - 1,
            Box::new(move |r1| fib_k(n - 2, Box::new(move |r2| k(r1 + r2)))),
        )
    }
}
/// Map in CPS style: build the mapped vector, then pass it to `k`.
pub fn map_k<T, U, R>(items: Vec<T>, f: impl Fn(T) -> U, k: impl FnOnce(Vec<U>) -> R) -> R {
    // Consume the input through an iterator: `Vec::remove(0)` would shift every
    // remaining element on each step, making the whole map O(n^2).
    // Note: `go` recurses once per element and Rust does not guarantee
    // tail-call elimination, so very long inputs may overflow the stack.
    fn go<T, U, R>(
        mut items: std::vec::IntoIter<T>,
        f: impl Fn(T) -> U,
        mut acc: Vec<U>,
        k: impl FnOnce(Vec<U>) -> R,
    ) -> R {
        match items.next() {
            None => k(acc),
            Some(head) => {
                acc.push(f(head));
                go(items, f, acc, k)
            }
        }
    }
    go(items.into_iter(), f, Vec::new(), k)
}

/// Fold in CPS style: fold the items, then pass the final accumulator to `k`.
pub fn fold_k<T, A, R>(
    items: Vec<T>,
    init: A,
    f: impl Fn(A, T) -> A,
    k: impl FnOnce(A) -> R,
) -> R {
    fn go<T, A, R>(
        mut items: std::vec::IntoIter<T>,
        acc: A,
        f: impl Fn(A, T) -> A,
        k: impl FnOnce(A) -> R,
    ) -> R {
        match items.next() {
            None => k(acc),
            Some(head) => {
                let new_acc = f(acc, head);
                go(items, new_acc, f, k)
            }
        }
    }
    go(items.into_iter(), init, f, k)
}
/// Safe division with success and error continuations.
pub fn safe_div_k<R>(a: f64, b: f64, ok: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
    if b == 0.0 {
        err("division by zero")
    } else {
        ok(a / b)
    }
}

/// Parse an integer, with continuations for success and failure.
/// On failure, `err` receives the original input string.
pub fn parse_int_k<R>(s: &str, ok: impl FnOnce(i64) -> R, err: impl FnOnce(&str) -> R) -> R {
    match s.parse::<i64>() {
        Ok(n) => ok(n),
        Err(_) => err(s),
    }
}

/// Chained CPS operations: parse `s`, then compute 100 / n.
/// An `FnOnce` error continuation cannot be moved into both of
/// `parse_int_k`'s callbacks, so the parse step matches on the `Result`
/// directly and the single `err` serves both failure paths.
pub fn chain_example<R>(s: &str, k: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
    match s.parse::<i64>() {
        Ok(n) => safe_div_k(100.0, n as f64, k, err),
        Err(_) => err(s),
    }
}

/// Identity continuation: extract the plain value from a CPS computation,
/// e.g. `run(|k| fact_k(5, k))` returns 120.
pub fn run<T>(f: impl FnOnce(Box<dyn FnOnce(T) -> T>) -> T) -> T {
    f(Box::new(|x| x))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fact_direct() {
        assert_eq!(fact(5), 120);
        assert_eq!(fact(10), 3_628_800);
    }

    #[test]
    fn test_fact_cps() {
        fact_k(5, Box::new(|n| assert_eq!(n, 120)));
        fact_k(10, Box::new(|n| assert_eq!(n, 3_628_800)));
    }

    #[test]
    fn test_fact_equivalence() {
        for n in 0..=12 {
            let direct = fact(n);
            fact_k(n, Box::new(move |cps| assert_eq!(cps, direct)));
        }
    }

    #[test]
    fn test_fib_direct() {
        assert_eq!(fib(10), 55);
    }

    #[test]
    fn test_fib_cps() {
        fib_k(10, Box::new(|n| assert_eq!(n, 55)));
    }

    #[test]
    fn test_map_k() {
        map_k(
            vec![1, 2, 3],
            |x| x * 2,
            |result| {
                assert_eq!(result, vec![2, 4, 6]);
            },
        );
    }

    #[test]
    fn test_map_k_empty() {
        map_k(
            Vec::<i32>::new(),
            |x| x * 2,
            |result| {
                assert!(result.is_empty());
            },
        );
    }

    #[test]
    fn test_fold_k() {
        fold_k(
            vec![1, 2, 3, 4],
            0,
            |acc, x| acc + x,
            |sum| {
                assert_eq!(sum, 10);
            },
        );
    }

    #[test]
    fn test_safe_div_ok() {
        safe_div_k(
            10.0,
            2.0,
            |r| assert_eq!(r, 5.0),
            |_| panic!("unexpected error"),
        );
    }

    #[test]
    fn test_safe_div_err() {
        let mut got_error = false;
        safe_div_k(
            10.0,
            0.0,
            |_| panic!("unexpected success"),
            |_| got_error = true,
        );
        assert!(got_error);
    }

    #[test]
    fn test_parse_int_ok() {
        parse_int_k("42", |n| assert_eq!(n, 42), |_| panic!("unexpected error"));
    }

    #[test]
    fn test_parse_int_err() {
        let mut got_error = false;
        parse_int_k(
            "abc",
            |_| panic!("unexpected success"),
            |_| got_error = true,
        );
        assert!(got_error);
    }

    #[test]
    fn test_chain() {
        // "10" -> parse to 10 -> 100/10 = 10.0
        chain_example(
            "10",
            |r| assert_eq!(r, 10.0),
            |_| panic!("unexpected error"),
        );
    }

    #[test]
    fn test_chain_div_zero() {
        let mut got_error = false;
        chain_example("0", |_| panic!("unexpected success"), |_| got_error = true);
        assert!(got_error);
    }

    #[test]
    fn test_chain_parse_error() {
        let mut got_error = false;
        chain_example(
            "abc",
            |_| panic!("unexpected success"),
            |_| got_error = true,
        );
        assert!(got_error);
    }
}
Deep Comparison
OCaml vs Rust: Continuation-Passing Style
Direct vs CPS
OCaml Direct
let rec fact n = if n <= 1 then 1 else n * fact (n-1)
OCaml CPS
let rec fact_k n k =
  if n <= 1 then k 1
  else fact_k (n-1) (fun result -> k (n * result))
Rust Direct
fn fact(n: u64) -> u64 {
    if n <= 1 { 1 } else { n * fact(n - 1) }
}
Rust CPS
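fn fact_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
    if n <= 1 {
        k(1)
    } else {
        // With `impl FnOnce` here, every level would wrap `k` in a new closure
        // type and monomorphization would never terminate; boxing avoids that.
        fact_k(n - 1, Box::new(move |r| k(n * r)))
    }
}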
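The result type R is chosen by whichever continuation the caller passes in. A short sketch, repeating the boxed fact_k so it compiles on its own; fact_value and fact_message are hypothetical helper names:

```rust
/// Same boxed CPS factorial as above, repeated so the sketch is self-contained.
fn fact_k<R: 'static>(n: u64, k: Box<dyn FnOnce(u64) -> R>) -> R {
    if n <= 1 {
        k(1)
    } else {
        fact_k(n - 1, Box::new(move |r| k(n * r)))
    }
}

/// Hypothetical helper: the identity continuation recovers the direct-style result.
fn fact_value(n: u64) -> u64 {
    fact_k(n, Box::new(|x| x))
}

/// Hypothetical helper: a different continuation changes what the caller gets
/// back (here a String) without touching fact_k itself.
fn fact_message(n: u64) -> String {
    fact_k(n, Box::new(move |x| format!("{n}! = {x}")))
}
```

fact_value plays the same role as OCaml's fact_k 10 (fun x -> x).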
Error Handling with CPS
OCaml
let safe_div a b ok err =
  if b = 0.0 then err "division by zero"
  else ok (a /. b)
Rust
fn safe_div_k<R>(
    a: f64,
    b: f64,
    ok: impl FnOnce(f64) -> R,
    err: impl FnOnce(&str) -> R,
) -> R {
    if b == 0.0 {
        err("division by zero")
    } else {
        ok(a / b)
    }
}
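Because the caller picks both continuations, a two-continuation function converts back into an ordinary Result when Ok and an Err-building closure are passed in. A minimal sketch; safe_div and sum_of_ratios are hypothetical wrappers around a copy of the tutorial's safe_div_k:

```rust
/// Copy of the tutorial's two-continuation safe division.
fn safe_div_k<R>(a: f64, b: f64, ok: impl FnOnce(f64) -> R, err: impl FnOnce(&str) -> R) -> R {
    if b == 0.0 {
        err("division by zero")
    } else {
        ok(a / b)
    }
}

/// Hypothetical adapter: `Ok` is the success continuation and the closure
/// builds the error value, so R becomes Result<f64, String>.
fn safe_div(a: f64, b: f64) -> Result<f64, String> {
    safe_div_k(a, b, Ok, |msg| Err(msg.to_string()))
}

/// Hypothetical caller: once back in Result form, `?` propagates the error.
fn sum_of_ratios(a: f64, b: f64, c: f64) -> Result<f64, String> {
    Ok(safe_div(a, b)? + safe_div(a, c)?)
}
```

The CPS and Result forms carry the same information; CPS simply lets the caller decide what "success" and "failure" produce.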
Key Differences
| Aspect | OCaml | Rust |
|---|---|---|
| Closure syntax | fun x -> ... | \|x\| ... or move \|x\| ... |
| Recursive CPS | Direct | Needs Box<dyn FnOnce> (each wrapping closure has a distinct type) |
| Type inference | Full | Generic bounds needed |
| Closure capture | Implicit, values shared via the GC | Borrows unless a value is consumed; move forces capture by value |
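The capture and type-annotation rows above can be seen in a small sketch. scale_k and compose_k are hypothetical helpers, not part of the tutorial's source: returning a closure that uses a local requires move, and composing continuations means spelling out the generic parameters A, B, and C that OCaml would infer.

```rust
/// Hypothetical helper: returns a continuation that multiplies by `factor`.
/// `move` is required: without it the closure would borrow the local
/// `factor`, which no longer exists once this function returns (error E0373).
fn scale_k(factor: u64) -> impl FnOnce(u64) -> u64 {
    move |x| x * factor
}

/// Hypothetical helper: composes two continuations, running `first` and
/// passing its result on to `then`. Both are moved into the new closure.
fn compose_k<A, B, C>(
    first: impl FnOnce(A) -> B,
    then: impl FnOnce(B) -> C,
) -> impl FnOnce(A) -> C {
    move |x| then(first(x))
}
```

For example, compose_k(scale_k(3), |y: u64| y + 1) yields a continuation that maps 5 to 16.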
Why CPS?
Exercises
1. Write identity(x: i32) -> i32 and its CPS version identity_k<R>(x: i32, k: impl FnOnce(i32) -> R) -> R, and verify they produce the same results.
2. Convert fn add(a: i32, b: i32) -> i32 to CPS and use it to build fn sum_list_k(items: &[i32], k: impl FnOnce(i32)) that passes the total to k.
3. Rewrite fact_k to use an explicit stack of enum Cont { Identity, MultAndCall(u64, Box<Cont>) } instead of closures; this technique is called defunctionalization.
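One possible solution sketch for the defunctionalization exercise, assuming the Cont enum exactly as the exercise states it. The continuation becomes plain data: Identity stands for |x| x and MultAndCall(n, next) for |r| next(n * r). The interpreter apply and the function name fact_d are hypothetical.

```rust
/// The continuation as data, as named in the exercise.
enum Cont {
    Identity,
    MultAndCall(u64, Box<Cont>),
}

/// Interpret a continuation: the data-structure version of calling `k(value)`.
/// It is a loop, so it uses constant stack space however long the chain is.
fn apply(mut k: Cont, mut value: u64) -> u64 {
    loop {
        match k {
            Cont::Identity => return value,
            Cont::MultAndCall(n, next) => {
                value *= n;
                k = *next;
            }
        }
    }
}

/// Defunctionalized fact_k: instead of wrapping `k` in a new closure,
/// push a MultAndCall frame onto the explicit stack.
fn fact_d(n: u64, k: Cont) -> u64 {
    if n <= 1 {
        apply(k, 1)
    } else {
        fact_d(n - 1, Cont::MultAndCall(n, Box::new(k)))
    }
}
```

fact_d still recurses like fact_k, but its recursive call is in tail position, so it can be rewritten by hand as a while loop that pushes frames and then calls apply.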