Pattern Visitor Match
This tutorial covers the "Pattern Visitor Match" functional Rust example. Difficulty level: Fundamental. Key concepts covered: Functional Programming.
Tutorial
The Problem
Pattern matching in Rust goes beyond simple value checks. It supports type-safe command processing, visitor-pattern traversals, state machine transitions, and recursive data structure manipulation. These techniques show up in compiler construction, game engines, protocol implementations, and functional idioms applied to systems code. This example focuses on one of them: replacing the object-oriented visitor pattern with plain recursive functions that match on an arithmetic expression tree.
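State machine transitions, one of the uses listed above, reduce to matching on a `(state, event)` tuple. The following is a minimal sketch. The `Door` and `Event` types and the `transition` function are hypothetical and not part of this example's source.

```rust
/// Hypothetical door states, used only to illustrate transition matching.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Door {
    Open,
    Closed,
    Locked,
}

/// Hypothetical events that drive the door.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Event {
    Push,
    Pull,
    Lock,
    Unlock,
}

/// Each arm is one legal transition; the final wildcard rejects the rest
/// with an error instead of panicking.
fn transition(state: Door, event: Event) -> Result<Door, &'static str> {
    match (state, event) {
        (Door::Closed, Event::Pull) => Ok(Door::Open),
        (Door::Open, Event::Push) => Ok(Door::Closed),
        (Door::Closed, Event::Lock) => Ok(Door::Locked),
        (Door::Locked, Event::Unlock) => Ok(Door::Closed),
        _ => Err("invalid transition"),
    }
}
```

The trailing `_` arm keeps the table short. The cost is that the compiler no longer flags new `Door` or `Event` variants that lack an explicit arm.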
🎯 Learning Outcomes
Code Example
```rust
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}
```
Key Differences
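1. **Box deref**: Rust requires `Box<T>` for recursive types. Stable Rust patterns cannot destructure through a `Box`. Instead, a binding such as `l: &Box<Expr>` deref-coerces to `&Expr` when passed to a function like `eval(l)`, and `l.as_ref()` or `&**l` gives the inner `&Expr` for nested matching. OCaml's GC manages recursive variant pointers automatically, so no explicit boxing is needed.
2. **Constants in patterns**: Rust can use `const` values directly in patterns. In OCaml, a lowercase identifier in a pattern always binds a new variable. Bringing constants into scope with `let open Consts in` therefore does not make them usable as patterns, and a `when` guard is needed to compare against them.
OCaml Approach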
OCaml's ML heritage makes it the natural reference point for these patterns. Its variant types, exhaustive matching, and recursive types are equivalent in power to Rust's enums and `match`:
```ocaml
(* Pattern matching in OCaml handles:
   - Variant constructors with data: Cmd (arg1, arg2) -> ...
   - Guards: | x when x > threshold -> ...
   - Nested patterns: Node { left; right } -> ...
   - Recursive cases: the natural form for tree traversal *)
```
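On the Rust side, the same features look roughly like this, together with the `Box` and `const` points from Key Differences. This is a sketch that assumes a three-variant version of the tutorial's `Expr`. `MAX_DEPTH`, `describe`, and `depth_label` are hypothetical names added for illustration.

```rust
/// Same shape as the tutorial's `Expr`, trimmed to three variants.
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Hypothetical limit, used to show a `const` in pattern position.
const MAX_DEPTH: usize = 3;

/// Recursive case: `l` and `r` bind as `&Box<Expr>`, and deref coercion
/// lets them be passed where `&Expr` is expected.
fn eval(e: &Expr) -> f64 {
    match e {
        Expr::Lit(n) => *n,
        Expr::Add(l, r) => eval(l) + eval(r),
        Expr::Mul(l, r) => eval(l) * eval(r),
    }
}

/// Guards (`if`, OCaml's `when`) and nested patterns. Stable patterns
/// cannot destructure a `Box`, so the children are dereferenced with
/// `&**` (`&Box<Expr>` -> `&Expr`) and matched as a tuple.
fn describe(e: &Expr, threshold: f64) -> &'static str {
    match e {
        Expr::Lit(n) if *n > threshold => "large literal",
        Expr::Lit(_) => "small literal",
        Expr::Mul(l, r) => match (&**l, &**r) {
            (_, Expr::Lit(n)) if *n == 1.0 => "multiplication by one",
            _ => "multiplication",
        },
        Expr::Add(..) => "addition",
    }
}

/// A named `const` used directly as a pattern, which OCaml cannot do.
fn depth_label(depth: usize) -> &'static str {
    match depth {
        0 => "empty",
        MAX_DEPTH => "at the limit",
        d if d > MAX_DEPTH => "too deep",
        _ => "within the limit",
    }
}
```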
Full Source
```rust
#![allow(clippy::all)]
//! # Visitor Pattern via Match
//!
//! Implement multiple traversal operations over a recursive data structure
//! using pattern matching instead of the traditional OOP visitor pattern.

/// Expression AST for arithmetic operations.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}

impl Expr {
    /// Helper to create a literal.
    pub fn lit(n: f64) -> Box<Self> {
        Box::new(Expr::Lit(n))
    }

    /// Helper to create an addition.
    pub fn add(l: Box<Expr>, r: Box<Expr>) -> Box<Self> {
        Box::new(Expr::Add(l, r))
    }

    /// Helper to create a subtraction.
    pub fn sub(l: Box<Expr>, r: Box<Expr>) -> Box<Self> {
        Box::new(Expr::Sub(l, r))
    }

    /// Helper to create a multiplication.
    pub fn mul(l: Box<Expr>, r: Box<Expr>) -> Box<Self> {
        Box::new(Expr::Mul(l, r))
    }

    /// Helper to create a division.
    pub fn div(l: Box<Expr>, r: Box<Expr>) -> Box<Self> {
        Box::new(Expr::Div(l, r))
    }
}

/// Visitor 1: Evaluate expression to a number.
pub fn eval(e: &Expr) -> f64 {
    match e {
        Expr::Lit(n) => *n,
        Expr::Add(l, r) => eval(l) + eval(r),
        Expr::Sub(l, r) => eval(l) - eval(r),
        Expr::Mul(l, r) => eval(l) * eval(r),
        Expr::Div(l, r) => eval(l) / eval(r),
    }
}

/// Visitor 2: Count number of operations.
pub fn count_ops(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 0,
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) | Expr::Div(l, r) => {
            1 + count_ops(l) + count_ops(r)
        }
    }
}

/// Visitor 3: Pretty print the expression.
pub fn pretty(e: &Expr) -> String {
    match e {
        Expr::Lit(n) => format!("{}", n),
        Expr::Add(l, r) => format!("({} + {})", pretty(l), pretty(r)),
        Expr::Sub(l, r) => format!("({} - {})", pretty(l), pretty(r)),
        Expr::Mul(l, r) => format!("({} * {})", pretty(l), pretty(r)),
        Expr::Div(l, r) => format!("({} / {})", pretty(l), pretty(r)),
    }
}

/// Visitor 4: Collect all literal values.
pub fn collect_lits(e: &Expr) -> Vec<f64> {
    match e {
        Expr::Lit(n) => vec![*n],
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) | Expr::Div(l, r) => {
            let mut v = collect_lits(l);
            v.extend(collect_lits(r));
            v
        }
    }
}

/// Visitor 5: Calculate tree depth.
pub fn depth(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 1,
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) | Expr::Div(l, r) => {
            1 + depth(l).max(depth(r))
        }
    }
}

/// Visitor 6: Simplify constant expressions (constant folding).
pub fn simplify(e: &Expr) -> Box<Expr> {
    match e {
        Expr::Lit(n) => Expr::lit(*n),
        Expr::Add(l, r) => {
            let l = simplify(l);
            let r = simplify(r);
            if let (Expr::Lit(a), Expr::Lit(b)) = (l.as_ref(), r.as_ref()) {
                Expr::lit(a + b)
            } else {
                Expr::add(l, r)
            }
        }
        Expr::Sub(l, r) => {
            let l = simplify(l);
            let r = simplify(r);
            if let (Expr::Lit(a), Expr::Lit(b)) = (l.as_ref(), r.as_ref()) {
                Expr::lit(a - b)
            } else {
                Expr::sub(l, r)
            }
        }
        Expr::Mul(l, r) => {
            let l = simplify(l);
            let r = simplify(r);
            if let (Expr::Lit(a), Expr::Lit(b)) = (l.as_ref(), r.as_ref()) {
                Expr::lit(a * b)
            } else {
                Expr::mul(l, r)
            }
        }
        Expr::Div(l, r) => {
            let l = simplify(l);
            let r = simplify(r);
            if let (Expr::Lit(a), Expr::Lit(b)) = (l.as_ref(), r.as_ref()) {
                Expr::lit(a / b)
            } else {
                Expr::div(l, r)
            }
        }
    }
}

/// Visitor 7: Check if expression contains division.
pub fn has_division(e: &Expr) -> bool {
    match e {
        Expr::Lit(_) => false,
        Expr::Div(_, _) => true,
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) => has_division(l) || has_division(r),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn sample_expr() -> Box<Expr> {
        // (3 * 4) + (10 - 2) = 12 + 8 = 20
        Expr::add(
            Expr::mul(Expr::lit(3.0), Expr::lit(4.0)),
            Expr::sub(Expr::lit(10.0), Expr::lit(2.0)),
        )
    }

    #[test]
    fn test_eval() {
        let e = sample_expr();
        assert_eq!(eval(&e), 20.0);
    }

    #[test]
    fn test_eval_simple() {
        assert_eq!(eval(&Expr::Lit(42.0)), 42.0);
        assert_eq!(eval(&Expr::Add(Expr::lit(2.0), Expr::lit(3.0))), 5.0);
        assert_eq!(eval(&Expr::Mul(Expr::lit(4.0), Expr::lit(5.0))), 20.0);
    }

    #[test]
    fn test_count_ops() {
        let e = sample_expr();
        assert_eq!(count_ops(&e), 3); // add, mul, sub
    }

    #[test]
    fn test_count_ops_lit() {
        assert_eq!(count_ops(&Expr::Lit(1.0)), 0);
    }

    #[test]
    fn test_pretty() {
        let e = Expr::add(Expr::lit(2.0), Expr::lit(3.0));
        assert_eq!(pretty(&e), "(2 + 3)");
    }

    #[test]
    fn test_collect_lits() {
        let e = sample_expr();
        let lits = collect_lits(&e);
        assert_eq!(lits, vec![3.0, 4.0, 10.0, 2.0]);
    }

    #[test]
    fn test_depth() {
        let e = sample_expr();
        assert_eq!(depth(&e), 3); // add -> mul/sub -> lit
    }

    #[test]
    fn test_depth_lit() {
        assert_eq!(depth(&Expr::Lit(1.0)), 1);
    }

    #[test]
    fn test_simplify() {
        let e = sample_expr();
        let simplified = simplify(&e);
        assert_eq!(*simplified, Expr::Lit(20.0));
    }

    #[test]
    fn test_has_division() {
        let e = sample_expr();
        assert!(!has_division(&e));
        let e_div = Expr::div(Expr::lit(10.0), Expr::lit(2.0));
        assert!(has_division(&e_div));
    }
}
```
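Most of the visitors above (`eval`, `count_ops`, `pretty`, `collect_lits`, `depth`) repeat the same recursion: handle `Lit`, recurse into both children, and combine the results. That shared shape can be factored into a single fold (a catamorphism), so each visitor supplies only the leaf case and the combining step. This is a sketch under assumptions: `fold` and `Op` are hypothetical additions, not part of the source above. `&dyn Fn` is used only to keep the recursive calls simple.

```rust
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}

/// Hypothetical operator tag handed to the combining callback.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Add,
    Sub,
    Mul,
    Div,
}

/// Bottom-up fold: `lit` handles leaves; `bin` combines the two
/// already-folded children of an operator node.
fn fold<T>(e: &Expr, lit: &dyn Fn(f64) -> T, bin: &dyn Fn(Op, T, T) -> T) -> T {
    let (op, l, r) = match e {
        Expr::Lit(n) => return lit(*n),
        Expr::Add(l, r) => (Op::Add, l, r),
        Expr::Sub(l, r) => (Op::Sub, l, r),
        Expr::Mul(l, r) => (Op::Mul, l, r),
        Expr::Div(l, r) => (Op::Div, l, r),
    };
    bin(op, fold(l, lit, bin), fold(r, lit, bin))
}

// Each former visitor becomes a pair of closures.
fn eval(e: &Expr) -> f64 {
    fold::<f64>(e, &|n| n, &|op, a, b| match op {
        Op::Add => a + b,
        Op::Sub => a - b,
        Op::Mul => a * b,
        Op::Div => a / b,
    })
}

fn count_ops(e: &Expr) -> usize {
    fold::<usize>(e, &|_| 0, &|_, a, b| 1 + a + b)
}

fn depth(e: &Expr) -> usize {
    fold::<usize>(e, &|_| 1, &|_, a, b| 1 + a.max(b))
}

fn pretty(e: &Expr) -> String {
    fold::<String>(e, &|n| n.to_string(), &|op, a, b| {
        let sym = match op {
            Op::Add => "+",
            Op::Sub => "-",
            Op::Mul => "*",
            Op::Div => "/",
        };
        format!("({} {} {})", a, sym, b)
    })
}
```

One trade-off is that `fold` always visits both children. A short-circuiting visitor such as `has_division`, which skips the right subtree once `||` sees `true`, loses that early exit if rewritten on top of it.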
Deep Comparison
OCaml vs Rust: Visitor Pattern via Match
Expression Type
OCaml
```ocaml
type expr =
  | Lit of float
  | Add of expr * expr
  | Sub of expr * expr
  | Mul of expr * expr
  | Div of expr * expr
```
Rust
```rust
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}
```
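The `Box` shows up when building values. The OCaml value `Add (Mul (Lit 3., Lit 4.), Sub (Lit 10., Lit 2.))` needs an explicit `Box::new` at every recursive position in Rust. The following sketch constructs it directly, without the `Expr::lit`/`Expr::add` helpers from the full source. `sample` and `count_lits` are hypothetical names. Without the `Box`, rustc rejects the enum with error E0072 (recursive type has infinite size). With it, each child field is a single pointer.

```rust
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}

/// (3 * 4) + (10 - 2), spelled out with explicit boxing.
fn sample() -> Expr {
    Expr::Add(
        Box::new(Expr::Mul(Box::new(Expr::Lit(3.0)), Box::new(Expr::Lit(4.0)))),
        Box::new(Expr::Sub(Box::new(Expr::Lit(10.0)), Box::new(Expr::Lit(2.0)))),
    )
}

/// Counts the leaves, to check that the tree was built as intended.
fn count_lits(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 1,
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) | Expr::Div(l, r) => {
            count_lits(l) + count_lits(r)
        }
    }
}
```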
Visitors as Functions
OCaml
```ocaml
let rec eval = function
  | Lit n -> n
  | Add (l, r) -> eval l +. eval r
  | Sub (l, r) -> eval l -. eval r
  | Mul (l, r) -> eval l *. eval r
  | Div (l, r) -> eval l /. eval r

let rec count_ops = function
  | Lit _ -> 0
  | Add (l, r) | Sub (l, r) | Mul (l, r) | Div (l, r) ->
      1 + count_ops l + count_ops r
```
Rust
```rust
fn eval(e: &Expr) -> f64 {
    match e {
        Expr::Lit(n) => *n,
        Expr::Add(l, r) => eval(l) + eval(r),
        Expr::Sub(l, r) => eval(l) - eval(r),
        Expr::Mul(l, r) => eval(l) * eval(r),
        Expr::Div(l, r) => eval(l) / eval(r),
    }
}

fn count_ops(e: &Expr) -> usize {
    match e {
        Expr::Lit(_) => 0,
        Expr::Add(l, r) | Expr::Sub(l, r) | Expr::Mul(l, r) | Expr::Div(l, r) => {
            1 + count_ops(l) + count_ops(r)
        }
    }
}
```
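Adding another visitor is one more function of the same shape. As an illustration, here is a hypothetical `to_rpn` that prints the tree in postfix (reverse Polish) order. It is not part of the source above, and the enum is repeated so the sketch compiles on its own.

```rust
enum Expr {
    Lit(f64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}

/// Postfix rendering: both operands first, then the operator.
fn to_rpn(e: &Expr) -> String {
    match e {
        // f64's Display drops a trailing ".0", so 3.0 prints as "3".
        Expr::Lit(n) => n.to_string(),
        Expr::Add(l, r) => format!("{} {} +", to_rpn(l), to_rpn(r)),
        Expr::Sub(l, r) => format!("{} {} -", to_rpn(l), to_rpn(r)),
        Expr::Mul(l, r) => format!("{} {} *", to_rpn(l), to_rpn(r)),
        Expr::Div(l, r) => format!("{} {} /", to_rpn(l), to_rpn(r)),
    }
}
```

No existing code changes: `Expr`, `eval`, and `count_ops` stay as they are.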
Key Insight
In functional languages, the visitor pattern is simply "write a recursive function with pattern matching." No interfaces, no accept/visit methods, no double dispatch.
Each "visitor" is just a function that matches on the structure.
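The following sketch shows what the match-based version avoids: a classic OOP visitor written in Rust with traits. It is for contrast only. `Node`, `Visitor`, `Lit`, `Add`, and `Eval` are hypothetical names, and `Sub`, `Mul`, and `Div` are omitted for brevity. `root.accept(...)` dispatches on the node type, and `v.visit_*` then dispatches on the visitor type. Together they form the double dispatch that a single `match` replaces.

```rust
/// One method per node type; every operation implements all of them.
trait Visitor {
    fn visit_lit(&mut self, n: f64);
    fn visit_add(&mut self, l: &dyn Node, r: &dyn Node);
}

/// Every node type must forward itself to the right visit method.
trait Node {
    fn accept(&self, v: &mut dyn Visitor);
}

struct Lit(f64);
struct Add(Box<dyn Node>, Box<dyn Node>);

impl Node for Lit {
    fn accept(&self, v: &mut dyn Visitor) {
        v.visit_lit(self.0);
    }
}

impl Node for Add {
    fn accept(&self, v: &mut dyn Visitor) {
        v.visit_add(&*self.0, &*self.1);
    }
}

/// The "eval" operation as a visitor; results are passed through mutable state.
struct Eval {
    result: f64,
}

impl Visitor for Eval {
    fn visit_lit(&mut self, n: f64) {
        self.result = n;
    }

    fn visit_add(&mut self, l: &dyn Node, r: &dyn Node) {
        l.accept(self);
        let left = self.result;
        r.accept(self);
        self.result += left;
    }
}

fn eval(root: &dyn Node) -> f64 {
    let mut v = Eval { result: 0.0 };
    root.accept(&mut v);
    v.result
}
```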
Advantages over OOP Visitor
| Aspect | OOP Visitor | FP Pattern Match |
|---|---|---|
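| New operation | Add one visitor class that implements every visit method | Add one function |
| New node type | Add a visit method to every visitor | Add a match arm to every function (the compiler flags each missing arm) |
| Boilerplate | Accept/Visit interfaces | None |
| Double dispatch | Required | Not needed |
| Exhaustiveness | Via the visitor interface (default methods can hide gaps) | Compiler-checked by `match` (unless a `_` arm is used) |
| Code location | `accept` on every node class, plus one visitor class per operation | One function per operation |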
Exercises
1. … `Vec<T>` using only pattern matching and recursion.
2. … `Err("invalid transition")` instead of panicking.