935 Fundamental

935-tree-map-fold — Tree Map and Fold

Functional Programming

Tutorial

The Problem

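Map and fold over lists generalize naturally to trees. Tree map preserves the shape of the tree while transforming each value. Tree fold collapses the tree into a single value by replacing every Leaf with a base value and every Node with a function that combines the node's value with the results for its two subtrees. Once fold_tree is defined, aggregate operations such as size, depth, sum, and flattening to a list can be expressed without any further explicit recursion. This is the idea behind catamorphisms (and, more loosely, Haskell's Foldable typeclass): define one fold, derive everything else. It is the same pattern as OCaml's List.map and List.fold_right (which replaces [] with the initial value and :: with the folding function), lifted from lists to binary trees. This example derives size, depth, sum, and two traversal orders from a single fold_tree.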
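
As a rough check of "define one fold, derive everything else", the sketch below derives two operations that the example does not define. It reuses the fold_tree signature from the example's code: a maximum uses an Option<i32> accumulator, and an all-positive test uses a bool accumulator. The names max_value and all_positive are hypothetical. Tree and fold_tree are copied in so the block compiles on its own.

```rust
// Minimal copy of the example's Tree and fold_tree so this sketch compiles on its own.
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

pub fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Hypothetical: the largest value, or None for an empty tree.
// Option orders None below every Some, so Ord::max keeps the largest value seen.
pub fn max_value(t: &Tree<i32>) -> Option<i32> {
    fold_tree(t, None, &|v, l, r| Some(*v).max(l).max(r))
}

// Hypothetical: true if every value is positive (an empty tree counts as true).
pub fn all_positive(t: &Tree<i32>) -> bool {
    fold_tree(t, true, &|v, l, r| *v > 0 && l && r)
}

fn main() {
    let t = Tree::node(4, Tree::node(-2, Tree::Leaf, Tree::Leaf), Tree::node(6, Tree::Leaf, Tree::Leaf));
    println!("{:?} {}", max_value(&t), all_positive(&t)); // prints "Some(6) false"
}
```

Only the accumulator type and the combining closure change between the two operations. The recursion lives entirely in fold_tree.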

🎯 Learning Outcomes

  • Implement map_tree, which preserves the tree's structure while transforming each node value
  • Implement fold_tree (a catamorphism), from which other tree reductions can be derived
  • Derive size, depth, sum, and flattening (preorder and inorder traversal) from fold_tree without additional recursion
  • Understand why fold is the "universal" tree consumer
  • Compare with OCaml's equivalent tree map and fold patterns

Code Example

    #![allow(clippy::all)]
    /// Map and Fold on Trees
    ///
    /// Lifting map and fold from lists to binary trees. Once you define
    /// `fold_tree`, you can express size, depth, sum, and traversals
    /// without any explicit recursion — the fold does it all.
    
    #[derive(Debug, Clone, PartialEq)]
    pub enum Tree<T> {
        Leaf,
        Node(T, Box<Tree<T>>, Box<Tree<T>>),
    }
    
    impl<T> Tree<T> {
        pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
            Tree::Node(v, Box::new(l), Box::new(r))
        }
    }
    
    /// Map a function over every node value, producing a new tree.
    pub fn map_tree<T, U>(tree: &Tree<T>, f: &impl Fn(&T) -> U) -> Tree<U> {
        match tree {
            Tree::Leaf => Tree::Leaf,
            Tree::Node(v, l, r) => Tree::node(f(v), map_tree(l, f), map_tree(r, f)),
        }
    }
    
    /// Fold (catamorphism) on a tree. The function `f` receives the node value
    /// and the results of folding the left and right subtrees; `acc` is the
    /// result for every `Leaf` and is cloned so each subtree gets its own copy.
    pub fn fold_tree<T, A>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A
    where
        A: Clone,
    {
        match tree {
            Tree::Leaf => acc,
            Tree::Node(v, l, r) => {
                let left = fold_tree(l, acc.clone(), f);
                let right = fold_tree(r, acc, f);
                f(v, left, right)
            }
        }
    }
    
    // All derived via fold: no explicit recursion needed.
    pub fn size<T>(t: &Tree<T>) -> usize {
        fold_tree(t, 0, &|_, l, r| 1 + l + r)
    }
    
    pub fn depth<T>(t: &Tree<T>) -> usize {
        fold_tree(t, 0, &|_, l, r| 1 + l.max(r))
    }
    
    pub fn sum(t: &Tree<i32>) -> i32 {
        fold_tree(t, 0, &|v, l, r| v + l + r)
    }
    
    pub fn preorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
        fold_tree(t, vec![], &|v, l, r| {
            let mut result = vec![v.clone()];
            result.extend(l);
            result.extend(r);
            result
        })
    }
    
    pub fn inorder<T: Clone>(t: &Tree<T>) -> Vec<T> {
        fold_tree(t, vec![], &|v, l, r| {
            let mut result = l;
            result.push(v.clone());
            result.extend(r);
            result
        })
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
        use Tree::*;
    
        fn sample() -> Tree<i32> {
            //      4
            //     / \
            //    2   6
            //   / \
            //  1   3
            Tree::node(
                4,
                Tree::node(2, Tree::node(1, Leaf, Leaf), Tree::node(3, Leaf, Leaf)),
                Tree::node(6, Leaf, Leaf),
            )
        }
    
        #[test]
        fn test_size() {
            assert_eq!(size(&sample()), 5);
            assert_eq!(size::<i32>(&Leaf), 0);
        }
    
        #[test]
        fn test_depth() {
            assert_eq!(depth(&sample()), 3);
            assert_eq!(depth::<i32>(&Leaf), 0);
        }
    
        #[test]
        fn test_sum() {
            assert_eq!(sum(&sample()), 16);
            assert_eq!(sum(&Leaf), 0);
        }
    
        #[test]
        fn test_preorder() {
            assert_eq!(preorder(&sample()), vec![4, 2, 1, 3, 6]);
        }
    
        #[test]
        fn test_inorder() {
            assert_eq!(inorder(&sample()), vec![1, 2, 3, 4, 6]);
        }
    
        #[test]
        fn test_map_tree() {
            let doubled = map_tree(&sample(), &|v| v * 2);
            assert_eq!(sum(&doubled), 32);
            assert_eq!(preorder(&doubled), vec![8, 4, 2, 6, 12]);
        }
    
        #[test]
        fn test_single_node() {
            let t = Tree::node(42, Leaf, Leaf);
            assert_eq!(size(&t), 1);
            assert_eq!(sum(&t), 42);
            assert_eq!(preorder(&t), vec![42]);
        }
    }

    Key Differences

  • Fold argument order: Rust's f: &impl Fn(&T, A, A) -> A takes (value, left_result, right_result). OCaml code varies by convention: the OCaml Approach below passes (left_result, value, right_result), while the Deep Comparison uses a value-first 'a -> 'b -> 'b -> 'b.
  • Cloning the base value: Rust's fold_tree requires A: Clone so that acc can be given to both subtrees; OCaml passes the same immutable value to both branches without copying it.
  • Recursive boxing: Rust's Node(T, Box<Tree<T>>, Box<Tree<T>>) needs an explicit Box to give the recursive type a known size. OCaml's Node of 'a * 'a tree * 'a tree needs no annotation, because values built with non-constant constructors are already heap-allocated.
  • Derived operations: both languages express size, depth, sum, and traversals as one-line calls to fold.

    OCaml Approach

    type 'a tree = Leaf | Node of 'a * 'a tree * 'a tree

    let rec map_tree f = function
      | Leaf -> Leaf
      | Node (v, l, r) -> Node (f v, map_tree f l, map_tree f r)

    let rec fold_tree leaf node = function
      | Leaf -> leaf
      | Node (v, l, r) -> node (fold_tree leaf node l) v (fold_tree leaf node r)

    (* Derived operations *)
    let size t = fold_tree 0 (fun l _ r -> 1 + l + r) t
    let depth t = fold_tree 0 (fun l _ r -> 1 + max l r) t

    The OCaml version is slightly more concise thanks to currying and the function keyword.
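
OCaml can hand the same leaf value to both recursive calls because immutable values are shared. One way to drop Rust's A: Clone bound is to take a closure that produces the leaf value on demand. This is sketched here as an assumption, not as the example's API. fold_tree_with and Stats are hypothetical names, and Stats deliberately does not implement Clone.

```rust
// Minimal tree type; it does not need to derive Clone for this sketch.
#[derive(Debug)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Hypothetical Clone-free fold: instead of cloning a base value, it calls
// `leaf` once for every Leaf, so the result type A needs no Clone bound.
pub fn fold_tree_with<T, A>(
    tree: &Tree<T>,
    leaf: &impl Fn() -> A,
    combine: &impl Fn(&T, A, A) -> A,
) -> A {
    match tree {
        Tree::Leaf => leaf(),
        Tree::Node(v, l, r) => {
            let left = fold_tree_with(l, leaf, combine);
            let right = fold_tree_with(r, leaf, combine);
            combine(v, left, right)
        }
    }
}

// Deliberately not Clone, so the example's fold_tree (A: Clone) could not produce it.
#[derive(Debug, PartialEq)]
pub struct Stats {
    pub count: usize,
    pub total: i64,
}

pub fn stats(t: &Tree<i32>) -> Stats {
    fold_tree_with(
        t,
        &|| Stats { count: 0, total: 0 },
        &|v, l: Stats, r: Stats| Stats {
            count: 1 + l.count + r.count,
            total: i64::from(*v) + l.total + r.total,
        },
    )
}

fn main() {
    let t = Tree::node(4, Tree::node(2, Tree::Leaf, Tree::Leaf), Tree::node(6, Tree::Leaf, Tree::Leaf));
    println!("{:?}", stats(&t)); // prints "Stats { count: 3, total: 12 }"
}
```

The trade-off is one extra closure call per Leaf in exchange for accepting result types that cannot be cloned.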

    Deep Comparison

    Map and Fold on Trees: OCaml vs Rust

    The Core Insight

    Once you can fold a tree, many tree computations (size, depth, sum, traversals) become one-liners. The catamorphism pattern replaces each constructor with a function: Leaf with a base value and Node with a combining function. It works the same way in both languages. Rust's ownership model adds friction in two places: the base value must be cloned for each subtree, and the combining closure is passed by reference.
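
To make "replacing constructors with functions" concrete, this sketch folds with the constructors themselves. Using Tree::Leaf as the base value and Tree::node as the combining function rebuilds an equal tree. Applying a function to each value before rebuilding recovers map_tree. rebuild and map_via_fold are hypothetical names; Tree and fold_tree mirror the example's definitions.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

pub fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Replacing Leaf with Tree::Leaf and Node with Tree::node rebuilds an equal tree.
// T: Clone is needed because the combining function only sees &T.
pub fn rebuild<T: Clone>(t: &Tree<T>) -> Tree<T> {
    fold_tree(t, Tree::Leaf, &|v, l, r| Tree::node(v.clone(), l, r))
}

// Hypothetical: map_tree rewritten as a fold. U: Clone is required because
// fold_tree clones its base value (here an empty Tree<U>).
pub fn map_via_fold<T, U: Clone>(t: &Tree<T>, f: impl Fn(&T) -> U) -> Tree<U> {
    fold_tree(t, Tree::Leaf, &|v, l, r| Tree::node(f(v), l, r))
}

fn main() {
    let t = Tree::node(2, Tree::node(1, Tree::Leaf, Tree::Leaf), Tree::node(3, Tree::Leaf, Tree::Leaf));
    assert_eq!(rebuild(&t), t);
    // prints: Node("2", Node("1", Leaf, Leaf), Node("3", Leaf, Leaf))
    println!("{:?}", map_via_fold(&t, |v| v.to_string()));
}
```

map_tree does not need the U: Clone bound on map_via_fold. It appears only because fold_tree clones its base value.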

    OCaml Approach

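    OCaml's fold_tree takes a function f : 'a -> 'b -> 'b -> 'b and a base value, and recurses over the tree structure. Thanks to currying, size can be written by partial application as fold_tree (fun _ l r -> 1 + l + r) 0. As a top-level definition, it should take the tree argument explicitly (let size t = fold_tree (fun _ l r -> 1 + l + r) 0 t) to stay polymorphic; otherwise the value restriction gives the partial application a weak type variable. The GC reclaims the intermediate lists that preorder/inorder build with [v] @ l @ r, so nothing has to be freed by hand, although @ copies its left operand on every call. Pattern matching with the function keyword keeps the code terse.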

    Rust Approach

    Rust's fold_tree needs A: Clone because the base value must be passed to both subtrees, and a single owned value cannot be moved into two recursive calls. The closure is passed as an &impl Fn(...) reference so that the same closure can be handed to both recursive calls without being moved. The vec! macro and the extend method replace OCaml's @ list append. The code is slightly more verbose, but every allocation (Box::new, vec!) is visible in the source.
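
For comparison with OCaml's [v] @ l @ r, here is a sketch of the two traversals written as iterator chains over the subtree results. It assumes the same Tree and fold_tree as the example. preorder_chain and inorder_chain are hypothetical names.

```rust
use std::iter;

#[derive(Debug, Clone, PartialEq)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

pub fn fold_tree<T, A: Clone>(tree: &Tree<T>, acc: A, f: &impl Fn(&T, A, A) -> A) -> A {
    match tree {
        Tree::Leaf => acc,
        Tree::Node(v, l, r) => {
            let left = fold_tree(l, acc.clone(), f);
            let right = fold_tree(r, acc, f);
            f(v, left, right)
        }
    }
}

// Preorder: node value, then left results, then right results ([v] @ l @ r).
pub fn preorder_chain<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, Vec::new(), &|v, l, r| {
        iter::once(v.clone()).chain(l).chain(r).collect::<Vec<T>>()
    })
}

// Inorder: left results, then the node value, then right results (l @ [v] @ r).
pub fn inorder_chain<T: Clone>(t: &Tree<T>) -> Vec<T> {
    fold_tree(t, Vec::new(), &|v, l, r| {
        l.into_iter().chain(iter::once(v.clone())).chain(r).collect::<Vec<T>>()
    })
}

fn main() {
    let t = Tree::node(
        4,
        Tree::node(2, Tree::node(1, Tree::Leaf, Tree::Leaf), Tree::node(3, Tree::Leaf, Tree::Leaf)),
        Tree::node(6, Tree::Leaf, Tree::Leaf),
    );
    // prints "[4, 2, 1, 3, 6] [1, 2, 3, 4, 6]"
    println!("{:?} {:?}", preorder_chain(&t), inorder_chain(&t));
}
```

Like the extend version, this still builds a new Vec at every node, so the asymptotic cost is the same. The chain only mirrors the shape of the OCaml expression.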

    Side-by-Side

    | Concept | OCaml | Rust |
    |---|---|---|
    | Fold signature | ('a -> 'b -> 'b -> 'b) -> 'b -> 'a tree -> 'b | (&Tree<T>, A, &impl Fn(&T, A, A) -> A) -> A |
    | Base value (acc) | Shared freely (immutable, GC-managed) | Requires a Clone bound |
    | List append | @ operator | extend() method |
    | Passing the combining function | Ordinary function value (curried) | &impl Fn(...) reference |
    | Derived operations | One-liners via fold | One-liners via fold |
    | Memory | GC reclaims intermediate lists | Explicit Vec allocation |

    What Rust Learners Should Notice

  • The Clone bound on the accumulator is the price of ownership: both subtrees need their own copy of the base value
  • &impl Fn(...) borrows the closure instead of taking ownership of it, so fold_tree can pass the same closure to both recursive calls
  • Rust's vec![] + extend is an idiomatic way to build up collections, replacing OCaml's @ list concatenation
  • The catamorphism pattern is not specific to trees: any recursive data type has a fold that replaces each of its constructors with a function, and many operations can then be written without explicit recursion
  • preorder and inorder build a new Vec at every node and copy subtree elements into it, so on a degenerate (list-shaped) tree the total work can be quadratic; performance-critical code would push into a single mutable Vec instead
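
The last point can be sketched as a traversal that pushes into one caller-owned Vec. Each value is cloned exactly once, and no per-node Vec is allocated. This gives up the fold-only style, since it uses explicit recursion. preorder_into and preorder_fast are hypothetical names.

```rust
#[derive(Debug)]
pub enum Tree<T> {
    Leaf,
    Node(T, Box<Tree<T>>, Box<Tree<T>>),
}

impl<T> Tree<T> {
    pub fn node(v: T, l: Tree<T>, r: Tree<T>) -> Self {
        Tree::Node(v, Box::new(l), Box::new(r))
    }
}

// Hypothetical helper: appends the preorder traversal of `t` to `out`.
// Each value is cloned exactly once and no per-node Vec is allocated.
fn preorder_into<T: Clone>(t: &Tree<T>, out: &mut Vec<T>) {
    if let Tree::Node(v, l, r) = t {
        out.push(v.clone());
        preorder_into(l, out);
        preorder_into(r, out);
    }
}

// Public wrapper that owns the single output buffer.
pub fn preorder_fast<T: Clone>(t: &Tree<T>) -> Vec<T> {
    let mut out = Vec::new();
    preorder_into(t, &mut out);
    out
}

fn main() {
    let t = Tree::node(
        4,
        Tree::node(2, Tree::node(1, Tree::Leaf, Tree::Leaf), Tree::node(3, Tree::Leaf, Tree::Leaf)),
        Tree::node(6, Tree::Leaf, Tree::Leaf),
    );
    println!("{:?}", preorder_fast(&t)); // prints "[4, 2, 1, 3, 6]"
}
```

It still recurses once per level. Like fold_tree, a very deep tree can overflow the stack.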

    Further Reading

  • [The Rust Book: Closures](https://doc.rust-lang.org/book/ch13-01-closures.html)
  • [OCaml Beyond Lists](https://cs3110.github.io/textbook/chapters/hop/beyond_lists.html)

    Exercises

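  • Implement flatten_inorder(t: &Tree<T>) -> Vec<&T> using fold_tree with a Vec accumulator. Note that the current fold_tree passes the node value as &T with a lifetime unrelated to the tree, so the references cannot be stored in the result. You will need a variant whose closure takes &'a T, where 'a is the lifetime of the tree reference.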
  • Write count_leaves(t: &Tree<T>) -> usize and count_inner_nodes(t: &Tree<T>) -> usize using fold_tree.
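  • Implement mirror(t: &Tree<T>) -> Tree<T> (with T: Clone), which swaps the left and right subtrees at every node. map_tree alone cannot do this, because it only transforms values and keeps the shape; use fold_tree instead, rebuilding each node with its two subtree results in swapped order.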