783 Fundamental

783-const-type-arithmetic — Const Type Arithmetic

Functional Programming

Tutorial Video

Text description (accessibility)

This video demonstrates the Const Type Arithmetic example (783-const-type-arithmetic) in functional Rust. Difficulty level: Fundamental. Key concepts covered: Functional Programming. Type-level arithmetic allows the type system to reason about sizes and dimensions. Key difference from OCaml, ergonomics: Rust's `Matrix<3, 4>` is concise, while OCaml's Peano encoding (`Succ (Succ (Succ Zero))`) is verbose and impractical for large dimensions.

Tutorial

The Problem

Type-level arithmetic lets the type system reason about sizes and dimensions. Add::<3, 4>::VALUE == 7 and Mul::<3, 4>::VALUE == 12 are computed by the compiler. This means a Matrix<3, 4> can be multiplied by a Matrix<4, 5> because the inner dimensions (4 and 4) match, while multiplying mismatched shapes becomes a compile error instead of a runtime failure. Linear algebra libraries such as nalgebra, and tensor libraries that enforce dimension compatibility, rely on the same idea.
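
A minimal sketch of the inner-dimension check, assuming plain nested arrays (`[[f64; N]; M]`) instead of the `Matrix` wrapper from the full source. The function name `matmul_arrays` is hypothetical. Because the parameter `N` appears in both argument types, a 2x3 matrix can only be multiplied by a 3xP matrix.

```rust
/// Multiplies an M x N matrix by an N x P matrix (hypothetical helper).
/// `N` appears in both parameter types, so passing matrices whose inner
/// dimensions differ is rejected by the type checker.
fn matmul_arrays<const M: usize, const N: usize, const P: usize>(
    a: &[[f64; N]; M],
    b: &[[f64; P]; N],
) -> [[f64; P]; M] {
    let mut out = [[0.0; P]; M];
    for i in 0..M {
        for j in 0..P {
            // Dot product of row i of `a` with column j of `b`.
            out[i][j] = (0..N).map(|k| a[i][k] * b[k][j]).sum();
        }
    }
    out
}

fn main() {
    let a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2 x 3
    let b = [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]; // 3 x 2
    let c = matmul_arrays(&a, &b); // inferred as [[f64; 2]; 2]
    println!("{:?}", c); // prints [[58.0, 64.0], [139.0, 154.0]]
    // matmul_arrays(&a, &a) would not compile: `a` is 2 x 3, but the
    // second argument must have 3 rows.
}
```

The result type `[[f64; P]; M]` is derived from the inputs, so callers never pass the output size explicitly.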

🎯 Learning Outcomes

  • Implement Add<A, B> and Mul<A, B> structs with const VALUE: usize
  • Create Vec3<const N: usize> with typed length and slice access
  • Implement concat_vec, which combines two vectors and returns a Vec<f64> (not [f64; A + B], which stable Rust cannot express for generic A and B)
  • Understand why Matrix<ROWS, COLS> multiplication checks inner dimensions at compile time
  • See how nalgebra uses this for dimension-safe linear algebra operations

Code Example

    pub struct Add<const A: usize, const B: usize>;
    
    impl<const A: usize, const B: usize> Add<A, B> {
        pub const VALUE: usize = A + B;
    }
    
    // Compile-time sum
    const SUM: usize = Add::<3, 4>::VALUE; // 7
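
Because `Add::<3, 4>::VALUE` is an ordinary associated constant, it can appear wherever a constant expression is expected, for example as an array length or inside a compile-time `assert!`. A short sketch, assuming Rust 1.57 or later (needed for `assert!` in a `const` item); `BUF` and `GRID` are illustrative names, not part of the original example.

```rust
pub struct Add<const A: usize, const B: usize>;
impl<const A: usize, const B: usize> Add<A, B> {
    pub const VALUE: usize = A + B;
}

pub struct Mul<const A: usize, const B: usize>;
impl<const A: usize, const B: usize> Mul<A, B> {
    pub const VALUE: usize = A * B;
}

// The arguments are concrete, so these lengths are plain constants on stable Rust.
// (Inside a generic fn, `[u8; Add::<A, B>::VALUE]` is rejected on stable because
// generic parameters may not be used in const operations.)
const BUF: [u8; Add::<3, 4>::VALUE] = [0; Add::<3, 4>::VALUE];
const GRID: [f64; Mul::<3, 4>::VALUE] = [0.0; Mul::<3, 4>::VALUE];

// Evaluated by the compiler: a wrong value fails the build, not a test run.
const _: () = assert!(Add::<3, 4>::VALUE == 7 && Mul::<3, 4>::VALUE == 12);

fn main() {
    println!("{} {}", BUF.len(), GRID.len()); // prints "7 12"
}
```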

    Key Differences

  • Ergonomics: Rust's Matrix<3, 4> is concise; OCaml's Peano encoding (Succ (Succ (Succ Zero))) is verbose and impractical for large dimensions.
  • Expression restriction: Stable Rust cannot write [f64; A + B] as an array type when A and B are generic parameters. OCaml does not hit this restriction because its arrays do not carry their length in their type; a GADT-indexed vector instead states addition as a type-level relation.
  • Library support: nalgebra (Rust) uses const generics for dimension-safe matrices with excellent ergonomics; OCaml lacks an equivalent mature library.
  • Nightly features: Rust nightly allows { A + B } in const generic expressions through the incomplete generic_const_exprs feature; this enables concat_arr<A, B>() -> [f64; A + B] without a Vec.
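
On stable Rust, one workaround for the missing [f64; A + B] return type is to let the caller name the output length as a third parameter C and check A + B == C at compile time. This is a sketch, assuming Rust 1.79 or later for inline `const { ... }` blocks; `concat_arr` is a hypothetical function named after the bullet above, not an existing API.

```rust
/// Concatenates two arrays into an array of length C (hypothetical helper).
/// The inline const block is evaluated once per instantiation, so a call
/// where A + B != C is rejected when the crate is built instead of panicking
/// at runtime.
fn concat_arr<const A: usize, const B: usize, const C: usize>(
    a: [f64; A],
    b: [f64; B],
) -> [f64; C] {
    const { assert!(A + B == C, "output length must equal A + B") };
    let mut out = [0.0; C];
    out[..A].copy_from_slice(&a); // first A slots
    out[A..].copy_from_slice(&b); // remaining C - A == B slots
    out
}

fn main() {
    // A = 2 and B = 3 are inferred from the arguments; C = 5 comes from the annotation.
    let c: [f64; 5] = concat_arr([1.0, 2.0], [3.0, 4.0, 5.0]);
    println!("{:?}", c); // prints [1.0, 2.0, 3.0, 4.0, 5.0]
    // let bad: [f64; 4] = concat_arr([1.0, 2.0], [3.0, 4.0, 5.0]); // fails to build
}
```

The cost of this design is that the caller must state C; the benefit is a fixed-size array on stable Rust with no heap allocation.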

OCaml Approach

    OCaml achieves type-level dimension checking with GADTs and phantom types that encode natural numbers in Peano style (a zero type and a successor type). Libraries like Tensorflow_ocaml use shape-indexed tensors, and tensor-ocaml encodes natural numbers at the type level in this Peano style (Z.t, Succ.t). This is more verbose than Rust's const generics, but OCaml's GADT approach can express more complex invariants. The linalg and owl libraries instead sacrifice type-level safety for practical usability and check dimensions at runtime.

    Full Source

    #![allow(clippy::all)]
    //! # Const Type Arithmetic
    //!
    //! Type-level arithmetic using const generics.
    
    /// Type-level addition result
    pub struct Add<const A: usize, const B: usize>;
    
    impl<const A: usize, const B: usize> Add<A, B> {
        pub const VALUE: usize = A + B;
    }
    
    /// Type-level multiplication result
    pub struct Mul<const A: usize, const B: usize>;
    
    impl<const A: usize, const B: usize> Mul<A, B> {
        pub const VALUE: usize = A * B;
    }
    
    /// Vector with compile-time length
    #[derive(Debug)]
    pub struct Vec3<const N: usize>([f64; N]);
    
    impl<const N: usize> Vec3<N> {
        pub fn new(data: [f64; N]) -> Self {
            Vec3(data)
        }
    
        pub const fn len(&self) -> usize {
            N
        }
    
        pub fn get(&self, idx: usize) -> Option<f64> {
            self.0.get(idx).copied()
        }
    
        pub fn as_slice(&self) -> &[f64] {
            &self.0
        }
    }
    
    /// Concatenate two vectors — returns a Vec since { A + B } requires nightly.
    pub fn concat_vec<const A: usize, const B: usize>(a: &Vec3<A>, b: &Vec3<B>) -> Vec<f64> {
        let mut result = Vec::with_capacity(A + B);
        result.extend_from_slice(a.as_slice());
        result.extend_from_slice(b.as_slice());
        result
    }
    
    /// Matrix dimensions at type level
    #[derive(Debug)]
    pub struct Matrix<const ROWS: usize, const COLS: usize> {
        data: [[f64; COLS]; ROWS],
    }
    
    impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
        pub fn new() -> Self {
            Matrix {
                data: [[0.0; COLS]; ROWS],
            }
        }
    
        pub fn from_array(data: [[f64; COLS]; ROWS]) -> Self {
            Matrix { data }
        }
    
        pub const fn rows(&self) -> usize {
            ROWS
        }
    
        pub const fn cols(&self) -> usize {
            COLS
        }
    
        pub fn get(&self, row: usize, col: usize) -> Option<f64> {
            self.data.get(row).and_then(|r| r.get(col)).copied()
        }
    
        pub fn set(&mut self, row: usize, col: usize, val: f64) {
            if row < ROWS && col < COLS {
                self.data[row][col] = val;
            }
        }
    }
    
    impl<const ROWS: usize, const COLS: usize> Default for Matrix<ROWS, COLS> {
        fn default() -> Self {
            Self::new()
        }
    }
    
    /// Matrix multiplication: `b` must have N rows, so mismatched inner dimensions fail to compile
    pub fn matmul<const M: usize, const N: usize, const P: usize>(
        a: &Matrix<M, N>,
        b: &Matrix<N, P>,
    ) -> Matrix<M, P> {
        let mut result = Matrix::<M, P>::new();
        for i in 0..M {
            for j in 0..P {
                let mut sum = 0.0;
                for k in 0..N {
                    sum += a.data[i][k] * b.data[k][j];
                }
                result.data[i][j] = sum;
            }
        }
        result
    }
    
    /// Transpose with dimension swap
    pub fn transpose<const M: usize, const N: usize>(a: &Matrix<M, N>) -> Matrix<N, M> {
        let mut result = Matrix::<N, M>::new();
        for i in 0..M {
            for j in 0..N {
                result.data[j][i] = a.data[i][j];
            }
        }
        result
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_add_type() {
            assert_eq!(Add::<3, 4>::VALUE, 7);
            assert_eq!(Add::<10, 20>::VALUE, 30);
        }
    
        #[test]
        fn test_mul_type() {
            assert_eq!(Mul::<3, 4>::VALUE, 12);
            assert_eq!(Mul::<5, 6>::VALUE, 30);
        }
    
        #[test]
        fn test_vec_concat() {
            let a = Vec3::new([1.0, 2.0]);
            let b = Vec3::new([3.0, 4.0, 5.0]);
            let c = concat_vec(&a, &b);
            assert_eq!(c.len(), 5);
            assert_eq!(c[0], 1.0);
            assert_eq!(c[4], 5.0);
        }
    
        #[test]
        fn test_matrix_dimensions() {
            let m: Matrix<3, 4> = Matrix::new();
            assert_eq!(m.rows(), 3);
            assert_eq!(m.cols(), 4);
        }
    
        #[test]
        fn test_matmul_dimensions() {
            let a: Matrix<2, 3> = Matrix::new();
            let b: Matrix<3, 4> = Matrix::new();
            let c = matmul(&a, &b);
            assert_eq!(c.rows(), 2);
            assert_eq!(c.cols(), 4);
        }
    
        #[test]
        fn test_transpose() {
            let mut m: Matrix<2, 3> = Matrix::new();
            m.set(0, 1, 5.0);
            let t = transpose(&m);
            assert_eq!(t.rows(), 3);
            assert_eq!(t.cols(), 2);
            assert_eq!(t.get(1, 0), Some(5.0));
        }
    
        // Compile-time dimension check
        const _: () = assert!(Add::<3, 4>::VALUE == 7);
        const _: () = assert!(Mul::<3, 4>::VALUE == 12);
    }

    Deep Comparison

    OCaml vs Rust: Const Type Arithmetic

    Type-Level Arithmetic

    Rust

    pub struct Add<const A: usize, const B: usize>;
    
    impl<const A: usize, const B: usize> Add<A, B> {
        pub const VALUE: usize = A + B;
    }
    
    // Compile-time sum
    const SUM: usize = Add::<3, 4>::VALUE; // 7
    

    OCaml (GADTs)

    (* Type-level Peano naturals and an addition relation on them *)
    type z = Z
    type 'n s = S of 'n
    
    type ('a, 'b, 'c) add =
      | Add_z : (z, 'b, 'b) add
      | Add_s : ('a, 'b, 'c) add -> ('a s, 'b, 'c s) add
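
For comparison, the same Peano-style encoding can be written in Rust with zero-sized marker types and traits instead of GADTs. This is a sketch of the general technique, not code from the original example; `Z`, `S`, `Nat`, and `AddNat` are illustrative names. `AddNat` plays the role of the OCaml `add` relation: the `Z` impl mirrors `Add_z` and the `S<N>` impl mirrors `Add_s`.

```rust
use std::marker::PhantomData;

// Peano naturals as types: Z is zero, S<N> is N + 1.
#[allow(dead_code)]
struct Z;
#[allow(dead_code)]
struct S<N>(PhantomData<N>);

// Recover the numeric value of a type-level natural.
trait Nat {
    const VALUE: usize;
}
impl Nat for Z {
    const VALUE: usize = 0;
}
impl<N: Nat> Nat for S<N> {
    const VALUE: usize = N::VALUE + 1;
}

// Type-level addition by recursion on the left operand.
trait AddNat<Rhs> {
    type Output;
}
impl<Rhs> AddNat<Rhs> for Z {
    type Output = Rhs; // 0 + b = b        (like Add_z)
}
impl<N: AddNat<Rhs>, Rhs> AddNat<Rhs> for S<N> {
    type Output = S<<N as AddNat<Rhs>>::Output>; // (n + 1) + b = (n + b) + 1   (like Add_s)
}

type Three = S<S<S<Z>>>;
type Four = S<Three>;
type Seven = <Three as AddNat<Four>>::Output;

fn main() {
    println!("{}", <Seven as Nat>::VALUE); // prints 7
}
```

Spelling out `Three` and `Four` as nested `S<...>` types shows why this encoding is verbose next to `Add::<3, 4>::VALUE`.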
    

    Matrix with Dimension Types

    Rust

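    // Signature only; the body is in the full source above.
    pub fn matmul<const M: usize, const N: usize, const P: usize>(
        a: &Matrix<M, N>,
        b: &Matrix<N, P>,
    ) -> Matrix<M, P> // result dimensions follow from the input types
    
    let a: Matrix<2, 3> = Matrix::new();
    let b: Matrix<3, 4> = Matrix::new();
    let c = matmul(&a, &b); // c: Matrix<2, 4>
    // matmul(&b, &a) is rejected: b is 3x4, so the second argument must be Matrix<4, _>, not Matrix<2, 3>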
    

    OCaml

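    (* Runtime dimension check: a is m x n, b must be n x p *)
    let matmul a b =
      let n = Array.length b in
      if Array.exists (fun row -> Array.length row <> n) a then
        invalid_arg "dimension mismatch";
      let p = if n = 0 then 0 else Array.length b.(0) in
      Array.map (fun row ->
          Array.init p (fun j ->
              Array.fold_left ( +. ) 0.0 (Array.mapi (fun k x -> x *. b.(k).(j)) row)))
        a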
    

    Key Differences

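    Aspect             | OCaml             | Rust
    -------------------|-------------------|----------------------
    Type-level numbers | GADTs (complex)   | Native const generics
    Dimension check    | Runtime (typical) | Compile-time
    Syntax             | Witness types     | <const N: usize>
    Arithmetic         | Custom types      | Native operators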

    Exercises

  • Implement Matrix<R, C>::transpose() -> Matrix<C, R> as a method that swaps the dimensions in the type signature (the full source only provides a free function).
  • Write a dot_product<const N: usize>(a: &Vec3<N>, b: &Vec3<N>) -> f64 that computes the inner product; because both arguments share the same N, mismatched lengths are rejected at compile time.
  • Implement outer_product<const M: usize, const N: usize>(a: &Vec3<M>, b: &Vec3<N>) -> Matrix<M, N> that computes the outer product with correct dimension types.