935-tree-map-fold — Tree Map and Fold
Tutorial
The Problem
Map and fold over lists generalize naturally to trees. Tree map preserves the structure while transforming the values; tree fold collapses the tree into a single value. Once fold_tree is defined, the aggregate operations in this example (size, depth, sum, and flattening to a list) can be expressed without any additional explicit recursion. This is the principle behind the Foldable typeclass and catamorphisms: define one fold, derive everything else. The pattern behind OCaml's List.map and List.fold_left carries over to trees unchanged. This example shows how much a well-designed fold can express.
🎯 Learning Outcomes
- map_tree that preserves tree structure while transforming each node value
- fold_tree (catamorphism) that generalizes all tree reductions
- size, depth, sum, flatten from fold_tree without additional recursion

Code Example
```rust
#![allow(clippy::all)]
/// Map and Fold on Trees
///
/// Lifting map and fold from lists to binary trees. Once you define
/// `fold_tree`, you can express size, depth, sum, and traversals
/// without any explicit recursion: the fold does it all.
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

/// Map a function over every node value, producing a new tree.
pub fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    match tree {
        Tree::Leaf => Tree::Leaf,
        Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
    }
}

/// Fold (catamorphism) on a tree. The function `f` receives the node value
/// and the results of folding the left and right subtrees.
/// `acc` is the value returned for every `Leaf`; it is cloned so that
/// both subtrees receive their own copy.
pub fn fold_tree<T, A>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A
where
    A: Clone,
{
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

/// All derived via fold: no explicit recursion needed.
pub fn size<T>(t: &Tree<T>) -> usize {
    fold_tree(t, 0, &|_, l, r| 1 + l + r)
}

pub fn depth<T>(t: &Tree<T>) -> usize {
    fold_tree(t, 0, &|_, l, r| 1 + l.max(r))
}

pub fn sum(t: &Tree<i32>) -> i32 {
    fold_tree(t, 0, &|v, l, r| v + l + r)
}

pub fn preorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, vec![], &|v, l, r| {
        let mut result = vec![v.clone()];
        result.extend(l);
        result.extend(r);
        result
    })
}

pub fn inorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, vec![], &|v, l, r| {
        let mut result = l;
        result.push(v.clone());
        result.extend(r);
        result
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use Tree::*;

    fn sample() -> Tree<i32> {
        //       4
        //      / \
        //     2   6
        //    / \
        //   1   3
        Tree::node(
            4,
            Tree::node(2, Tree::node(1, Leaf, Leaf), Tree::node(3, Leaf, Leaf)),
            Tree::node(6, Leaf, Leaf),
        )
    }

    #[test]
    fn test_size() {
        assert_eq!(size(&sample()), 5);
        assert_eq!(size::<i32>(&Leaf), 0);
    }

    #[test]
    fn test_depth() {
        assert_eq!(depth(&sample()), 3);
        assert_eq!(depth::<i32>(&Leaf), 0);
    }

    #[test]
    fn test_sum() {
        assert_eq!(sum(&sample()), 16);
        assert_eq!(sum(&Leaf), 0);
    }

    #[test]
    fn test_preorder() {
        assert_eq!(preorder(&sample()), vec![4, 2, 1, 3, 6]);
    }

    #[test]
    fn test_inorder() {
        assert_eq!(inorder(&sample()), vec![1, 2, 3, 4, 6]);
    }

    #[test]
    fn test_map_tree() {
        let doubled = map_tree(&sample(), &|v| v * 2);
        assert_eq!(sum(&doubled), 32);
        assert_eq!(preorder(&doubled), vec![8, 4, 2, 6, 12]);
    }

    #[test]
    fn test_single_node() {
        let t = Tree::node(42, Leaf, Leaf);
        assert_eq!(size(&t), 1);
        assert_eq!(sum(&t), 42);
        assert_eq!(preorder(&t), vec![42]);
    }
}
```

Key Differences
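- Argument order: f: &impl Fn(&T, A, A) -> A receives (value, left_result, right_result); OCaml code often uses f left_result value right_result, but the order is a matter of convention.
- Sharing the base value: fold_tree requires A: Clone so it can pass acc to both subtrees; OCaml passes the same immutable value to both branches without copying it.
- Recursive types: Node(T, Box<Tree<T>>, Box<Tree<T>>) needs an explicit Box so the recursive enum has a known size; in OCaml, Node of 'a * 'a tree * 'a tree stores its subtrees by reference automatically, so no boxing is written.

OCaml Approach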
```ocaml
type 'a tree = Leaf | Node of 'a * 'a tree * 'a tree

let rec map_tree f = function
  | Leaf -> Leaf
  | Node (v, l, r) -> Node (f v, map_tree f l, map_tree f r)

let rec fold_tree leaf node = function
  | Leaf -> leaf
  | Node (v, l, r) -> node (fold_tree leaf node l) v (fold_tree leaf node r)

(* Derived operations *)
let size t = fold_tree 0 (fun l _ r -> 1 + l + r) t
let depth t = fold_tree 0 (fun l _ r -> 1 + max l r) t
```

The OCaml version is slightly more concise thanks to currying and the function keyword.
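For comparison, the OCaml argument order can be mirrored in Rust. The sketch below is hypothetical (fold_tree_ocaml and the free function node are illustrative names, not part of the tutorial's code): the base value comes first and the node closure receives (left, value, right), so size and depth read almost exactly like the OCaml one-liners. It assumes the same Tree<T> shape as above and keeps the Clone requirement on the accumulator.

```rust
/// Same shape as the tutorial's Tree<T>.
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

/// Hypothetical helper, equivalent to the tutorial's Tree::node.
fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

/// Hypothetical fold with OCaml's argument order: `leaf` first, and the
/// combining closure receives (left_result, value, right_result).
fn fold_tree_ocaml<T, B: Clone>(leaf: B, node_fn: &impl Fn(B, &T, B) -> B, tree: &Tree<T>) -> B {
    match tree {
        Tree::Leaf => leaf,
        Tree::Node(v, l, r) => {
            // Each subtree gets its own copy of the base value.
            let left = fold_tree_ocaml(leaf.clone(), node_fn, l);
            let right = fold_tree_ocaml(leaf, node_fn, r);
            node_fn(left, v, right)
        }
    }
}

// Compare with: let size t = fold_tree 0 (fun l _ r -> 1 + l + r) t
fn size<T>(t: &Tree<T>) -> usize {
    fold_tree_ocaml(0usize, &|l, _, r| 1 + l + r, t)
}

// Compare with: let depth t = fold_tree 0 (fun l _ r -> 1 + max l r) t
fn depth<T>(t: &Tree<T>) -> usize {
    fold_tree_ocaml(0usize, &|l, _, r| 1 + l.max(r), t)
}
```

Which order to use is a style choice; putting the value between the two subtree results matches the node's in-order position.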
Deep Comparison
Map and Fold on Trees: OCaml vs Rust
The Core Insight
Once you can fold a tree, you can express many tree computations as one-liners. This example shows that the catamorphism pattern, replacing each constructor with a function (Leaf with a base value, Node with a combining function), works the same way in both languages. Rust's ownership model adds some friction, though: the base value has to be cloned for each subtree, and the combining closure is passed by reference.
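To make "replacing constructors with functions" concrete: even map_tree can be derived from the fold, by replacing Leaf with Leaf and Node with a function that rebuilds a node around f(v). A minimal sketch, assuming the tutorial's Tree<T> and fold_tree shapes; map_via_fold and the free function node are hypothetical names.

```rust
/// Same shape as the tutorial's Tree<T>.
#[derive(Debug, Clone, PartialEq)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

/// Hypothetical helper, equivalent to the tutorial's Tree::node.
fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

/// Same signature and behavior as the tutorial's fold_tree.
fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

/// Hypothetical: map expressed as a fold. Leaf is replaced by Leaf, and
/// Node by a function that builds a new node from f(v) and the mapped subtrees.
fn map_via_fold<T, U: Clone>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
    fold_tree(tree, Tree::Leaf, &|v, l, r| node(f(v), l, r))
}
```

The derived version needs U: Clone, which the direct map_tree does not, because fold_tree clones its base value (here an empty Tree::Leaf) for every subtree.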
OCaml Approach
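In the formulation used here, OCaml's fold_tree takes a function f : 'a -> 'b -> 'b -> 'b and a base value, then recurses over the tree structure (this argument order differs from the one in the OCaml Approach section; either order works). Thanks to currying, let size = fold_tree (fun _ l r -> 1 + l + r) 0 reads cleanly as a partially applied fold. The GC reclaims all the intermediate lists created by preorder and inorder, so [v] @ l @ r can allocate freely. Pattern matching with the function keyword keeps the code terse.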
Rust Approach
Rust's fold_tree needs A: Clone because the base value must be handed to both subtrees, and an owned value cannot be in two places at once. The closure is passed as &impl Fn(...) so that both recursive calls can share it through a reference instead of moving it. The vec! macro and the extend method replace OCaml's @ list append. The code is slightly more verbose, but every allocation is explicit.
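One way to drop the A: Clone bound is to produce the base value on demand instead of cloning it. The sketch below is a hypothetical variant (fold_tree_with, node, and Stats are illustrative names, not part of the tutorial) that takes a leaf closure, so an accumulator type that does not implement Clone can still be used.

```rust
/// Same shape as the tutorial's Tree<T>.
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

/// Hypothetical helper, equivalent to the tutorial's Tree::node.
fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

/// Hypothetical variant: every Leaf calls `leaf()` to get a fresh base
/// value, so no cloning happens and `A` needs no Clone bound.
fn fold_tree_with<T, A>(
    tree: &Tree<T>,
    leaf: &impl Fn() -> A,
    f: &impl Fn(&T, A, A) -> A,
) -> A {
    match tree {
        Tree::Leaf => leaf(),
        Tree::Node(v, l, r) => {
            let left = fold_tree_with(l, leaf, f);
            let right = fold_tree_with(r, leaf, f);
            f(v, left, right)
        }
    }
}

/// Deliberately not Clone: the Clone-based fold_tree could not use it as `A`.
#[derive(Debug, PartialEq)]
struct Stats {
    count: usize,
    total: i64,
}

/// Node count and value sum in a single pass.
fn stats(t: &Tree<i32>) -> Stats {
    fold_tree_with(
        t,
        &|| Stats { count: 0, total: 0 },
        &|v, l: Stats, r: Stats| Stats {
            count: 1 + l.count + r.count,
            total: i64::from(*v) + l.total + r.total,
        },
    )
}
```

The trade-off is one extra closure parameter; for cheap base values such as 0 or an empty Vec, the Clone-based version works just as well.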
Side-by-Side
| Concept | OCaml | Rust |
|---|---|---|
| Fold signature | ('a -> 'b -> 'b -> 'b) -> 'b -> 'a tree -> 'b | (&Tree<T>, A, &impl Fn(&T, A, A) -> A) -> A |
| Accumulator | Passed freely (GC) | Requires Clone bound |
| List append | @ operator | extend() method |
| Closure passing | Implicit currying | &impl Fn(...) reference |
| Derived operations | One-liners via fold | One-liners via fold |
| Memory | GC handles intermediates | Explicit Vec allocation |
What Rust Learners Should Notice
- The Clone bound on the accumulator is the price of ownership: each subtree needs its own copy of the base case.
- &impl Fn(...) borrows the closure instead of taking ownership of it, so fold_tree can pass the same closure to both recursive calls.
- vec![] plus extend is the idiomatic way to build up a collection, replacing OCaml's @ list concatenation.
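The derived preorder and inorder also need T: Clone, because the &T that fold_tree hands to its closure is only borrowed for the duration of that call (the elided lifetime in Fn(&T, A, A) is higher-ranked), so it cannot be stored in the result. A sketch of one fix, assuming the same Tree<T>: give the fold an explicit lifetime 'a so the closure receives &'a T, tied to the tree itself. The names fold_tree_ref, inorder_refs, and node are hypothetical.

```rust
/// Same shape as the tutorial's Tree<T>.
#[derive(Debug)]
enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

/// Hypothetical helper, equivalent to the tutorial's Tree::node.
fn node<T>(v: T, l: Tree<T>, r: Tree<T>) -> Tree<T> {
    Tree::Node(v, Box::new(l), Box::new(r))
}

/// Like fold_tree, but the closure sees `&'a T` with the tree's own
/// lifetime, so the result may keep those borrows.
fn fold_tree_ref<'a, T, A: Clone>(
    tree: &'a Tree<T>,
    acc: A,
    f: &impl Fn(&'a T, A, A) -> A,
) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree_ref(l, acc.clone(), f);
            let right = fold_tree_ref(r, acc, f);
            f(v, left, right)
        }
    }
}

/// In-order traversal that returns references, so T needs no Clone bound.
fn inorder_refs<T>(t: &Tree<T>) -> Vec<&T> {
    fold_tree_ref(t, Vec::new(), &|v, l, r| {
        let mut result = l;
        result.push(v);
        result.extend(r);
        result
    })
}
```

The closures used by size, depth, and sum would compile unchanged against this signature; only closures that store the borrowed values actually need the named lifetime.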
Exercises
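- flatten_inorder(t: &Tree<T>) -> Vec<&T> using fold_tree with a Vec accumulator. Note that the fold_tree above only lends its closure a short-lived &T, so storing the references requires adding a lifetime parameter to fold_tree.
- count_leaves(t: &Tree<T>) -> usize and count_inner_nodes(t: &Tree<T>) -> usize using fold_tree.
- mirror(t: Tree<T>) -> Tree<T> that swaps the left and right subtrees at every node. map_tree alone cannot do this, because it preserves the tree's shape; rebuild each node with its children swapped instead.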