177 Expert

GADT Typed Expression Evaluator

Functional Programming

Tutorial

The Problem

Building on the GADT introduction, this example shows a full typed expression evaluator where the type system guarantees that evaluation never fails with a type mismatch. Each node type in the AST is a separate Rust struct implementing an Expr trait with an associated Value type. As a result, eval on an Add node always produces an integer and eval on an Eq node always produces a boolean: type safety is structural and enforced at compile time, not checked at runtime.
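
The guarantee described above can be sketched with three node types. This is a trimmed illustration, not the full source; the helper name `demo` is hypothetical. The commented-out line is expected to be rejected by the compiler, because `BLit` has `Value = bool` and so does not satisfy the `Expr<Value = i64>` bound on `Add`.

```rust
use std::fmt;

trait Expr: fmt::Debug {
    type Value;
    fn eval(&self) -> Self::Value;
}

#[derive(Debug)]
struct Lit(i64);

#[derive(Debug)]
struct BLit(bool);

#[derive(Debug)]
struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);

impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0
    }
}

impl Expr for BLit {
    type Value = bool;
    fn eval(&self) -> bool {
        self.0
    }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0.eval() + self.1.eval()
    }
}

// Hypothetical helper: the result types are known statically,
// so no runtime tag check is needed on either value.
fn demo() -> (i64, bool) {
    let sum: i64 = Add(Lit(40), Lit(2)).eval();
    let flag: bool = BLit(true).eval();
    // let bad = Add(Lit(1), BLit(true)); // does not compile: bool is not i64
    (sum, flag)
}
```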

🎯 Learning Outcomes

  • Implement a typed expression tree using a trait with an associated Value type
  • See how separate structs per node type (rather than one enum) provide GADT-like safety
  • Understand the trade-off: trait-based approach vs. phantom-type enum approach
  • Learn to build composite expressions: IfExpr<C, T, F>, parameterized by the condition type and the two branch types

    Code Example

    trait Expr: fmt::Debug {
        type Value;
        fn eval(&self) -> Self::Value;
    }
    
    struct Lit(i64);
    struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);
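
A short usage sketch for `IfExpr`, assuming the trait and structs shown above. The helper names `pick_int` and `pick_bool` are hypothetical. The point is that the bound `F: Expr<Value = T::Value>` lets the same node work for integer branches and for boolean branches, while still requiring both branches to agree.

```rust
use std::fmt;

trait Expr: fmt::Debug {
    type Value;
    fn eval(&self) -> Self::Value;
}

#[derive(Debug)]
struct Lit(i64);
#[derive(Debug)]
struct BLit(bool);
#[derive(Debug)]
struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
#[derive(Debug)]
struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);

impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0
    }
}

impl Expr for BLit {
    type Value = bool;
    fn eval(&self) -> bool {
        self.0
    }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Eq<A, B> {
    type Value = bool;
    fn eval(&self) -> bool {
        self.0.eval() == self.1.eval()
    }
}

impl<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>> Expr for IfExpr<C, T, F> {
    // The result type is whatever the "then" branch produces.
    type Value = T::Value;
    fn eval(&self) -> T::Value {
        if self.0.eval() {
            self.1.eval()
        } else {
            self.2.eval()
        }
    }
}

// Integer branches: Value is i64.
fn pick_int() -> i64 {
    IfExpr(Eq(Lit(2), Lit(2)), Lit(10), Lit(20)).eval()
}

// Boolean branches: the same node type now has Value = bool.
fn pick_bool() -> bool {
    IfExpr(BLit(false), BLit(true), BLit(false)).eval()
}
```

Mixing the branches, for example `IfExpr(BLit(true), Lit(1), BLit(false))`, should fail to compile because the two `Value` types differ.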

    Key Differences

  • Unification: OCaml uses one GADT type with multiple constructors; Rust uses one struct per node, all implementing the same trait — a different structural choice.
  • Recursive evaluation: OCaml's eval is let rec over the unified GADT; Rust's eval calls self.0.eval(), self.1.eval() as method calls on the tuple-struct fields of each node.
  • Type annotations: OCaml needs one polymorphic annotation on eval (type a. a expr -> a) and refines the type in each match arm; Rust requires explicit type bounds on the generic parameters of each node struct.
  • Verbosity: Rust's approach requires N struct definitions and N trait implementations; OCaml's requires one type and one function.
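
One practical consequence of the per-struct design is that `Lit` and `Add<Lit, Lit>` are different Rust types, so they cannot share a `Vec` directly. A minimal sketch of one workaround, assuming the `Expr` trait from this example: box the nodes as trait objects with the associated type pinned to `i64`. The helper names `eval_all` and `demo_mixed` are hypothetical.

```rust
use std::fmt;

trait Expr: fmt::Debug {
    type Value;
    fn eval(&self) -> Self::Value;
}

#[derive(Debug)]
struct Lit(i64);

#[derive(Debug)]
struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);

impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0
    }
}

impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0.eval() + self.1.eval()
    }
}

// Evaluate a heterogeneous list of integer-valued expressions.
fn eval_all(exprs: &[Box<dyn Expr<Value = i64>>]) -> Vec<i64> {
    exprs.iter().map(|e| e.eval()).collect()
}

fn demo_mixed() -> Vec<i64> {
    // Three different concrete types behind the same trait object type.
    let exprs: Vec<Box<dyn Expr<Value = i64>>> = vec![
        Box::new(Lit(1)),
        Box::new(Add(Lit(2), Lit(3))),
        Box::new(Add(Add(Lit(1), Lit(1)), Lit(4))),
    ];
    eval_all(&exprs)
}
```

In OCaml all three values would simply have type `int expr`, which is the unification difference noted in the list above.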
    OCaml Approach

    OCaml's GADT evaluator is more concise:

    type _ expr =
      | Lit : int -> int expr
      | BLit : bool -> bool expr
      | Add : int expr * int expr -> int expr
      | Compare : int expr * int expr -> bool expr
      | If : bool expr * 'a expr * 'a expr -> 'a expr
    let rec eval : type a. a expr -> a = function ...
    

    OCaml's single recursive type with indexed constructors is more unified than Rust's per-struct approach, and eval is a single function with exhaustive pattern matching.
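
For contrast, Rust can get closer to a single unified type with a phantom-typed wrapper around an untyped enum. The sketch below is an illustration of that alternative, not part of the original source; all names (`Raw`, `Val`, `Typed`, `wrap`, `int`, `boolean`, `add`, `eq`, `if_`, `eval_raw`, `eval_int`) are hypothetical. Construction is type-checked through the smart constructors, but the evaluator still needs runtime mismatch arms because the compiler cannot refine types per constructor the way OCaml does.

```rust
use std::marker::PhantomData;

// Untyped tree: nothing at this level stops Add(Bool, Int).
#[derive(Debug, Clone)]
enum Raw {
    Int(i64),
    Bool(bool),
    Add(Box<Raw>, Box<Raw>),
    Eq(Box<Raw>, Box<Raw>),
    If(Box<Raw>, Box<Raw>, Box<Raw>),
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Val {
    Int(i64),
    Bool(bool),
}

// Typed wrapper: T is a compile-time tag only and stores no data.
struct Typed<T> {
    raw: Raw,
    _ty: PhantomData<T>,
}

fn wrap<T>(raw: Raw) -> Typed<T> {
    Typed { raw, _ty: PhantomData }
}

// Smart constructors: the intended way to build a Typed<T>.
fn int(n: i64) -> Typed<i64> {
    wrap(Raw::Int(n))
}
fn boolean(b: bool) -> Typed<bool> {
    wrap(Raw::Bool(b))
}
fn add(a: Typed<i64>, b: Typed<i64>) -> Typed<i64> {
    wrap(Raw::Add(Box::new(a.raw), Box::new(b.raw)))
}
fn eq(a: Typed<i64>, b: Typed<i64>) -> Typed<bool> {
    wrap(Raw::Eq(Box::new(a.raw), Box::new(b.raw)))
}
fn if_<T>(c: Typed<bool>, t: Typed<T>, f: Typed<T>) -> Typed<T> {
    wrap(Raw::If(Box::new(c.raw), Box::new(t.raw), Box::new(f.raw)))
}

// One evaluator over the single enum. The mismatch arms should not be
// reachable for trees built via the smart constructors, but they remain
// as runtime panics because the type tag is not visible here.
fn eval_raw(e: &Raw) -> Val {
    match e {
        Raw::Int(n) => Val::Int(*n),
        Raw::Bool(b) => Val::Bool(*b),
        Raw::Add(a, b) => match (eval_raw(a), eval_raw(b)) {
            (Val::Int(x), Val::Int(y)) => Val::Int(x + y),
            _ => unreachable!("Add on non-integers"),
        },
        Raw::Eq(a, b) => match (eval_raw(a), eval_raw(b)) {
            (Val::Int(x), Val::Int(y)) => Val::Bool(x == y),
            _ => unreachable!("Eq on non-integers"),
        },
        Raw::If(c, t, f) => match eval_raw(c) {
            Val::Bool(true) => eval_raw(t),
            Val::Bool(false) => eval_raw(f),
            Val::Int(_) => unreachable!("If on non-boolean"),
        },
    }
}

fn eval_int(e: &Typed<i64>) -> i64 {
    match eval_raw(&e.raw) {
        Val::Int(n) => n,
        Val::Bool(_) => unreachable!("integer-tagged tree produced a bool"),
    }
}
```

This is the trade-off named in the learning outcomes: the phantom-type enum gives one type and one evaluator, while the trait-based approach removes the unreachable arms at the cost of one struct and one impl per node.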

    Full Source

    #![allow(clippy::all)]
    // Example 177: GADT Typed Expression Evaluator
    // Only well-typed expressions can be constructed
    
    use std::fmt;
    
    // === Approach 1: Trait-based typed expression tree ===
    // Each node type is a separate struct; the trait ensures type safety
    
    trait Expr: fmt::Debug {
        type Value;
        fn eval(&self) -> Self::Value;
        fn to_expr_string(&self) -> String;
    }
    
    #[derive(Debug)]
    struct Lit(i64);
    
    #[derive(Debug)]
    struct BLit(bool);
    
    #[derive(Debug)]
    struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    
    #[derive(Debug)]
    struct Mul<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    
    #[derive(Debug)]
    struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    
    #[derive(Debug)]
    struct And<A: Expr<Value = bool>, B: Expr<Value = bool>>(A, B);
    
    #[derive(Debug)]
    struct Not<A: Expr<Value = bool>>(A);
    
    #[derive(Debug)]
    struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);
    
    impl Expr for Lit {
        type Value = i64;
        fn eval(&self) -> i64 {
            self.0
        }
        fn to_expr_string(&self) -> String {
            self.0.to_string()
        }
    }
    
    impl Expr for BLit {
        type Value = bool;
        fn eval(&self) -> bool {
            self.0
        }
        fn to_expr_string(&self) -> String {
            self.0.to_string()
        }
    }
    
    impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
        type Value = i64;
        fn eval(&self) -> i64 {
            self.0.eval() + self.1.eval()
        }
        fn to_expr_string(&self) -> String {
            format!(
                "({} + {})",
                self.0.to_expr_string(),
                self.1.to_expr_string()
            )
        }
    }
    
    impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Mul<A, B> {
        type Value = i64;
        fn eval(&self) -> i64 {
            self.0.eval() * self.1.eval()
        }
        fn to_expr_string(&self) -> String {
            format!(
                "({} * {})",
                self.0.to_expr_string(),
                self.1.to_expr_string()
            )
        }
    }
    
    impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Eq<A, B> {
        type Value = bool;
        fn eval(&self) -> bool {
            self.0.eval() == self.1.eval()
        }
        fn to_expr_string(&self) -> String {
            format!(
                "({} = {})",
                self.0.to_expr_string(),
                self.1.to_expr_string()
            )
        }
    }
    
    impl<A: Expr<Value = bool>, B: Expr<Value = bool>> Expr for And<A, B> {
        type Value = bool;
        fn eval(&self) -> bool {
            self.0.eval() && self.1.eval()
        }
        fn to_expr_string(&self) -> String {
            format!(
                "({} && {})",
                self.0.to_expr_string(),
                self.1.to_expr_string()
            )
        }
    }
    
    impl<A: Expr<Value = bool>> Expr for Not<A> {
        type Value = bool;
        fn eval(&self) -> bool {
            !self.0.eval()
        }
        fn to_expr_string(&self) -> String {
            format!("not({})", self.0.to_expr_string())
        }
    }
    
    impl<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>> Expr for IfExpr<C, T, F> {
        type Value = T::Value;
        fn eval(&self) -> T::Value {
            if self.0.eval() {
                self.1.eval()
            } else {
                self.2.eval()
            }
        }
        fn to_expr_string(&self) -> String {
            format!(
                "if {} then {} else {}",
                self.0.to_expr_string(),
                self.1.to_expr_string(),
                self.2.to_expr_string()
            )
        }
    }
    
    // === Approach 2: Boxed dynamic dispatch for runtime-built trees ===
    
    trait DynExprI64: fmt::Debug {
        fn eval(&self) -> i64;
    }
    
    struct DynLit(i64);
    impl fmt::Debug for DynLit {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "{}", self.0)
        }
    }
    impl DynExprI64 for DynLit {
        fn eval(&self) -> i64 {
            self.0
        }
    }
    
    struct DynAdd(Box<dyn DynExprI64>, Box<dyn DynExprI64>);
    impl fmt::Debug for DynAdd {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "({:?} + {:?})", self.0, self.1)
        }
    }
    impl DynExprI64 for DynAdd {
        fn eval(&self) -> i64 {
            self.0.eval() + self.1.eval()
        }
    }
    
    // === Approach 3: Enum-based with optimization pass ===
    
    #[derive(Debug, Clone)]
    enum IntExpr {
        Lit(i64),
        Add(Box<IntExpr>, Box<IntExpr>),
        Mul(Box<IntExpr>, Box<IntExpr>),
        IfB(Box<BoolExpr>, Box<IntExpr>, Box<IntExpr>),
    }
    
    #[derive(Debug, Clone)]
    enum BoolExpr {
        Lit(bool),
        Eq(Box<IntExpr>, Box<IntExpr>),
        And(Box<BoolExpr>, Box<BoolExpr>),
        Not(Box<BoolExpr>),
    }
    
    impl IntExpr {
        fn eval(&self) -> i64 {
            match self {
                IntExpr::Lit(n) => *n,
                IntExpr::Add(a, b) => a.eval() + b.eval(),
                IntExpr::Mul(a, b) => a.eval() * b.eval(),
                IntExpr::IfB(c, t, f) => {
                    if c.eval() {
                        t.eval()
                    } else {
                        f.eval()
                    }
                }
            }
        }
    
        fn optimize(self) -> Self {
            match self {
                IntExpr::Add(a, b) => {
                    let a = a.optimize();
                    let b = b.optimize();
                    match (&a, &b) {
                        (IntExpr::Lit(0), _) => b,
                        (_, IntExpr::Lit(0)) => a,
                        (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x + y),
                        _ => IntExpr::Add(Box::new(a), Box::new(b)),
                    }
                }
                IntExpr::Mul(a, b) => {
                    let a = a.optimize();
                    let b = b.optimize();
                    match (&a, &b) {
                        (IntExpr::Lit(0), _) | (_, IntExpr::Lit(0)) => IntExpr::Lit(0),
                        (IntExpr::Lit(1), _) => b,
                        (_, IntExpr::Lit(1)) => a,
                        (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x * y),
                        _ => IntExpr::Mul(Box::new(a), Box::new(b)),
                    }
                }
                other => other,
            }
        }
    }
    
    impl BoolExpr {
        fn eval(&self) -> bool {
            match self {
                BoolExpr::Lit(b) => *b,
                BoolExpr::Eq(a, b) => a.eval() == b.eval(),
                BoolExpr::And(a, b) => a.eval() && b.eval(),
                BoolExpr::Not(a) => !a.eval(),
            }
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_static_eval() {
            assert_eq!(Lit(42).eval(), 42);
            assert_eq!(Add(Lit(1), Lit(2)).eval(), 3);
            assert_eq!(Mul(Lit(3), Lit(4)).eval(), 12);
            assert_eq!(Eq(Lit(1), Lit(1)).eval(), true);
            assert_eq!(Eq(Lit(1), Lit(2)).eval(), false);
            assert_eq!(And(BLit(true), BLit(true)).eval(), true);
            assert_eq!(Not(BLit(true)).eval(), false);
        }
    
        #[test]
        fn test_if_expr() {
            assert_eq!(IfExpr(BLit(true), Lit(10), Lit(20)).eval(), 10);
            assert_eq!(IfExpr(BLit(false), Lit(10), Lit(20)).eval(), 20);
        }
    
        #[test]
        fn test_pretty_print() {
            assert_eq!(Add(Lit(1), Lit(2)).to_expr_string(), "(1 + 2)");
            assert_eq!(Not(BLit(true)).to_expr_string(), "not(true)");
        }
    
        #[test]
        fn test_dynamic() {
            let d = DynAdd(Box::new(DynLit(10)), Box::new(DynLit(32)));
            assert_eq!(d.eval(), 42);
        }
    
        #[test]
        fn test_optimize() {
            let e = IntExpr::Add(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(5)));
            assert_eq!(e.optimize().eval(), 5);
    
            let e = IntExpr::Mul(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(999)));
            assert_eq!(e.optimize().eval(), 0);
        }
    }
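
Approach 2 in the source above exists for trees whose shape is only known at runtime. A minimal sketch of why that matters, assuming the same `DynExprI64` trait: the hypothetical helper `sum_tree` folds a slice of any length into a chain of `DynAdd` nodes. With the static structs, the type of the result would depend on the slice length, so this loop could not be written.

```rust
use std::fmt;

trait DynExprI64: fmt::Debug {
    fn eval(&self) -> i64;
}

#[derive(Debug)]
struct DynLit(i64);

impl DynExprI64 for DynLit {
    fn eval(&self) -> i64 {
        self.0
    }
}

#[derive(Debug)]
struct DynAdd(Box<dyn DynExprI64>, Box<dyn DynExprI64>);

impl DynExprI64 for DynAdd {
    fn eval(&self) -> i64 {
        self.0.eval() + self.1.eval()
    }
}

// Hypothetical helper: build a left-nested chain of additions,
// starting from a literal 0 so that an empty slice is also valid.
fn sum_tree(values: &[i64]) -> Box<dyn DynExprI64> {
    let mut tree: Box<dyn DynExprI64> = Box::new(DynLit(0));
    for &v in values {
        // Every iteration keeps the same type: Box<dyn DynExprI64>.
        tree = Box::new(DynAdd(tree, Box::new(DynLit(v))));
    }
    tree
}
```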

    Deep Comparison

    Comparison: Example 177 — GADT Typed Expression Evaluator

    Type Definition

    OCaml

    type _ expr =
      | Lit  : int -> int expr
      | BLit : bool -> bool expr
      | Add  : int expr * int expr -> int expr
      | Eq   : int expr * int expr -> bool expr
      | If   : bool expr * 'a expr * 'a expr -> 'a expr
      | Pair : 'a expr * 'b expr -> ('a * 'b) expr
      | Fst  : ('a * 'b) expr -> 'a expr
    

    Rust

    trait Expr: fmt::Debug {
        type Value;
        fn eval(&self) -> Self::Value;
    }
    
    struct Lit(i64);
    struct Add<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    struct Eq<A: Expr<Value = i64>, B: Expr<Value = i64>>(A, B);
    struct IfExpr<C: Expr<Value = bool>, T: Expr, F: Expr<Value = T::Value>>(C, T, F);
    

    Evaluation

    OCaml

    let rec eval : type a. a expr -> a = function
      | Lit n -> n
      | BLit b -> b
      | Add (a, b) -> eval a + eval b
      | Eq (a, b) -> eval a = eval b
      | If (c, t, f) -> if eval c then eval t else eval f
      | Pair (a, b) -> (eval a, eval b)
      | Fst p -> fst (eval p)
    

    Rust

    impl Expr for Lit {
        type Value = i64;
        fn eval(&self) -> i64 { self.0 }
    }
    
    impl<A: Expr<Value = i64>, B: Expr<Value = i64>> Expr for Add<A, B> {
        type Value = i64;
        fn eval(&self) -> i64 { self.0.eval() + self.1.eval() }
    }
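
The OCaml type in this comparison includes Pair and Fst, which the Rust side does not show. A possible trait-based counterpart is sketched below; the `Pair` and `Fst` structs are assumptions added for illustration and do not appear in the original Rust source. The `Fst` impl pins the inner expression's `Value` to a tuple, so taking the first component of a non-pair should be a compile error.

```rust
use std::fmt;

trait Expr: fmt::Debug {
    type Value;
    fn eval(&self) -> Self::Value;
}

#[derive(Debug)]
struct Lit(i64);
#[derive(Debug)]
struct BLit(bool);

// Hypothetical nodes mirroring OCaml's Pair and Fst constructors.
#[derive(Debug)]
struct Pair<A: Expr, B: Expr>(A, B);
#[derive(Debug)]
struct Fst<P>(P);

impl Expr for Lit {
    type Value = i64;
    fn eval(&self) -> i64 {
        self.0
    }
}

impl Expr for BLit {
    type Value = bool;
    fn eval(&self) -> bool {
        self.0
    }
}

impl<A: Expr, B: Expr> Expr for Pair<A, B> {
    // Mirrors ('a * 'b) expr: the value is a tuple of both results.
    type Value = (A::Value, B::Value);
    fn eval(&self) -> Self::Value {
        (self.0.eval(), self.1.eval())
    }
}

// X and Y are determined by P's associated type, so they count as
// constrained type parameters for this impl.
impl<X, Y, P: Expr<Value = (X, Y)>> Expr for Fst<P> {
    type Value = X;
    fn eval(&self) -> X {
        self.0.eval().0
    }
}
```

With these definitions, `Fst(Pair(Lit(7), BLit(true))).eval()` has type `i64`, while calling `eval` on `Fst(Lit(7))` should not type-check because `i64` is not a tuple.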
    

    Constant Folding

    OCaml

    let rec optimize : type a. a expr -> a expr = function
      | Add (Lit 0, b) -> optimize b
      | Mul (Lit 0, _) -> Lit 0
      | e -> e
    

    Rust

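    fn optimize(self) -> Self {
        match self {
            IntExpr::Add(a, b) => {
                // Optimize the children first, then inspect the results.
                let (a, b) = (a.optimize(), b.optimize());
                match (&a, &b) {
                    (IntExpr::Lit(0), _) => b,
                    _ => IntExpr::Add(Box::new(a), Box::new(b)),
                }
            }
            other => other,
        }
    }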
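
A self-contained variant of the folding pass, trimmed to Lit, Add and Mul, may help to see the rewrite rules in action. It matches on the optimized children by value instead of by reference, which is a stylistic choice of this sketch rather than the original code; the helper `sample` is hypothetical. Because this trimmed enum has no variables or conditionals, every tree should fold down to a single literal.

```rust
#[derive(Debug, Clone, PartialEq)]
enum IntExpr {
    Lit(i64),
    Add(Box<IntExpr>, Box<IntExpr>),
    Mul(Box<IntExpr>, Box<IntExpr>),
}

impl IntExpr {
    fn eval(&self) -> i64 {
        match self {
            IntExpr::Lit(n) => *n,
            IntExpr::Add(a, b) => a.eval() + b.eval(),
            IntExpr::Mul(a, b) => a.eval() * b.eval(),
        }
    }

    // Bottom-up rewrite: children are optimized before the parent is inspected.
    fn optimize(self) -> Self {
        match self {
            IntExpr::Add(a, b) => match (a.optimize(), b.optimize()) {
                // 0 + e and e + 0 both reduce to e.
                (IntExpr::Lit(0), e) | (e, IntExpr::Lit(0)) => e,
                (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x + y),
                (a, b) => IntExpr::Add(Box::new(a), Box::new(b)),
            },
            IntExpr::Mul(a, b) => match (a.optimize(), b.optimize()) {
                // 0 * e and e * 0 reduce to 0; safe here since eval has no side effects.
                (IntExpr::Lit(0), _) | (_, IntExpr::Lit(0)) => IntExpr::Lit(0),
                // 1 * e and e * 1 both reduce to e.
                (IntExpr::Lit(1), e) | (e, IntExpr::Lit(1)) => e,
                (IntExpr::Lit(x), IntExpr::Lit(y)) => IntExpr::Lit(x * y),
                (a, b) => IntExpr::Mul(Box::new(a), Box::new(b)),
            },
            other => other,
        }
    }
}

// Hypothetical sample tree: (0 + 5) * (1 * 7)
fn sample() -> IntExpr {
    IntExpr::Mul(
        Box::new(IntExpr::Add(Box::new(IntExpr::Lit(0)), Box::new(IntExpr::Lit(5)))),
        Box::new(IntExpr::Mul(Box::new(IntExpr::Lit(1)), Box::new(IntExpr::Lit(7)))),
    )
}
```

The optimized tree evaluates to the same value as the original, which is the property the `test_optimize` test in the full source checks for simpler inputs.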
    

    Exercises

  • Add a Mul<L: Expr<Value=i64>, R: Expr<Value=i64>> node for multiplication.
  • Add a Not<E: Expr<Value=bool>> node for boolean negation.
  • Implement a pretty_print method on each node type via a separate PrettyPrint trait.