783-const-type-arithmetic — Const Type Arithmetic
Tutorial Video
Text description (accessibility)
This video demonstrates the "783-const-type-arithmetic: Const Type Arithmetic" functional Rust example. Difficulty level: Fundamental. Key concepts covered: Functional Programming. Type-level arithmetic allows the type system to reason about sizes and dimensions. Key difference from OCaml (ergonomics): Rust's Matrix<3, 4> is concise, whereas OCaml's Peano encoding (Succ (Succ (Succ Zero))) is verbose and impractical for large dimensions.
Tutorial
The Problem
Type-level arithmetic allows the type system to reason about sizes and dimensions. Add::<3, 4>::VALUE == 7 and Mul::<3, 4>::VALUE == 12 are computed by the compiler. The same mechanism lets Matrix<3, 4> and Matrix<4, 5> multiply only when the inner dimensions match, so a dimension mismatch in a matrix multiplication becomes a compile error instead of a runtime failure. Linear algebra libraries such as nalgebra, and tensor libraries that enforce dimension compatibility, rely on this idea.
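To make the compile-time check concrete, here is a minimal sketch. It assumes a stripped-down Matrix with a public data field (the tutorial's own type keeps the field private), and demo is a hypothetical helper added only for illustration.

```rust
/// Stripped-down stand-in for the tutorial's matrix type.
/// Assumption: `data` is public here purely to keep the sketch short.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Matrix<const ROWS: usize, const COLS: usize> {
    pub data: [[f64; COLS]; ROWS],
}

/// The shared parameter `N` forces the inner dimensions to agree.
pub fn matmul<const M: usize, const N: usize, const P: usize>(
    a: &Matrix<M, N>,
    b: &Matrix<N, P>,
) -> Matrix<M, P> {
    let mut out = Matrix { data: [[0.0; P]; M] };
    for i in 0..M {
        for j in 0..P {
            for k in 0..N {
                out.data[i][j] += a.data[i][k] * b.data[k][j];
            }
        }
    }
    out
}

/// Hypothetical helper: multiplies a 2x3 matrix by a 3x2 matrix.
pub fn demo() -> Matrix<2, 2> {
    let a = Matrix::<2, 3> { data: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] };
    let b = Matrix::<3, 2> { data: [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] };
    // `matmul(&a, &a)` would not compile: the second argument would have to
    // be a `Matrix<3, _>`, but `a` is a `Matrix<2, 3>`.
    matmul(&a, &b)
}
```

The result type Matrix<2, 2> is inferred from the argument types, so the caller never states the output dimensions by hand.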
🎯 Learning Outcomes
- Add<A, B> and Mul<A, B> structs with const VALUE: usize
- Vec3<const N: usize> with typed length and slice access
- concat_vec that combines two vectors and returns a Vec<f64> (not [f64; A+B], which stable Rust does not allow)
- Matrix<ROWS, COLS> multiplication checks inner dimensions at compile time
Code Example
pub struct Add<const A: usize, const B: usize>;

impl<const A: usize, const B: usize> Add<A, B> {
    pub const VALUE: usize = A + B;
}

// Compile-time sum
const SUM: usize = Add::<3, 4>::VALUE; // 7
Key Differences
1. Ergonomics: Rust's Matrix<3, 4> is concise; OCaml's Peano encoding (Succ (Succ (Succ Zero))) is verbose and impractical for large dimensions.
2. Stable limitation: stable Rust does not accept [f64; A + B] as a generic array size; OCaml's GADTs don't have this restriction since all arrays are heap-allocated.
3. Libraries: nalgebra (Rust) uses const generics for dimension-safe matrices with excellent ergonomics; OCaml lacks an equivalent mature library.
4. Nightly: the unstable generic_const_exprs feature allows { A + B } in const generic expressions; this enables concat_arr<A, B>() -> [f64; A+B] without Vec.
OCaml Approach
OCaml achieves type-level dimension checking via GADTs and phantom type arithmetic. Libraries like Tensorflow_ocaml use shape-indexed tensors. tensor-ocaml uses Z.t Succ.t for natural number type encoding (Peano arithmetic). While more verbose than Rust's const generics, OCaml's GADT approach can express more complex invariants. linalg and owl libraries sacrifice type-level safety for practical usability.
Full Source
#![allow(clippy::all)]
//! # Const Type Arithmetic
//!
//! Type-level arithmetic using const generics.

/// Type-level addition result
pub struct Add<const A: usize, const B: usize>;

impl<const A: usize, const B: usize> Add<A, B> {
    pub const VALUE: usize = A + B;
}

/// Type-level multiplication result
pub struct Mul<const A: usize, const B: usize>;

impl<const A: usize, const B: usize> Mul<A, B> {
    pub const VALUE: usize = A * B;
}

/// Vector with compile-time length
#[derive(Debug)]
pub struct Vec3<const N: usize>([f64; N]);

impl<const N: usize> Vec3<N> {
    pub fn new(data: [f64; N]) -> Self {
        Vec3(data)
    }

    pub const fn len(&self) -> usize {
        N
    }

    pub fn get(&self, idx: usize) -> Option<f64> {
        self.0.get(idx).copied()
    }

    pub fn as_slice(&self) -> &[f64] {
        &self.0
    }
}

/// Concatenate two vectors; returns a Vec since { A + B } requires nightly.
pub fn concat_vec<const A: usize, const B: usize>(a: &Vec3<A>, b: &Vec3<B>) -> Vec<f64> {
    let mut result = Vec::with_capacity(A + B);
    result.extend_from_slice(a.as_slice());
    result.extend_from_slice(b.as_slice());
    result
}

/// Matrix dimensions at type level
#[derive(Debug)]
pub struct Matrix<const ROWS: usize, const COLS: usize> {
    data: [[f64; COLS]; ROWS],
}

impl<const ROWS: usize, const COLS: usize> Matrix<ROWS, COLS> {
    pub fn new() -> Self {
        Matrix {
            data: [[0.0; COLS]; ROWS],
        }
    }

    pub fn from_array(data: [[f64; COLS]; ROWS]) -> Self {
        Matrix { data }
    }

    pub const fn rows(&self) -> usize {
        ROWS
    }

    pub const fn cols(&self) -> usize {
        COLS
    }

    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
        self.data.get(row).and_then(|r| r.get(col)).copied()
    }

    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        if row < ROWS && col < COLS {
            self.data[row][col] = val;
        }
    }
}

impl<const ROWS: usize, const COLS: usize> Default for Matrix<ROWS, COLS> {
    fn default() -> Self {
        Self::new()
    }
}

/// Matrix multiplication with dimension checking
pub fn matmul<const M: usize, const N: usize, const P: usize>(
    a: &Matrix<M, N>,
    b: &Matrix<N, P>,
) -> Matrix<M, P> {
    let mut result = Matrix::<M, P>::new();
    for i in 0..M {
        for j in 0..P {
            let mut sum = 0.0;
            for k in 0..N {
                sum += a.data[i][k] * b.data[k][j];
            }
            result.data[i][j] = sum;
        }
    }
    result
}

/// Transpose with dimension swap
pub fn transpose<const M: usize, const N: usize>(a: &Matrix<M, N>) -> Matrix<N, M> {
    let mut result = Matrix::<N, M>::new();
    for i in 0..M {
        for j in 0..N {
            result.data[j][i] = a.data[i][j];
        }
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_type() {
        assert_eq!(Add::<3, 4>::VALUE, 7);
        assert_eq!(Add::<10, 20>::VALUE, 30);
    }

    #[test]
    fn test_mul_type() {
        assert_eq!(Mul::<3, 4>::VALUE, 12);
        assert_eq!(Mul::<5, 6>::VALUE, 30);
    }

    #[test]
    fn test_vec_concat() {
        let a = Vec3::new([1.0, 2.0]);
        let b = Vec3::new([3.0, 4.0, 5.0]);
        let c = concat_vec(&a, &b);
        assert_eq!(c.len(), 5);
        assert_eq!(c[0], 1.0);
        assert_eq!(c[4], 5.0);
    }

    #[test]
    fn test_matrix_dimensions() {
        let m: Matrix<3, 4> = Matrix::new();
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 4);
    }

    #[test]
    fn test_matmul_dimensions() {
        let a: Matrix<2, 3> = Matrix::new();
        let b: Matrix<3, 4> = Matrix::new();
        let c = matmul(&a, &b);
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 4);
    }

    #[test]
    fn test_transpose() {
        let mut m: Matrix<2, 3> = Matrix::new();
        m.set(0, 1, 5.0);
        let t = transpose(&m);
        assert_eq!(t.rows(), 3);
        assert_eq!(t.cols(), 2);
        assert_eq!(t.get(1, 0), Some(5.0));
    }

    // Compile-time dimension check
    const _: () = assert!(Add::<3, 4>::VALUE == 7);
    const _: () = assert!(Mul::<3, 4>::VALUE == 12);
}
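The source returns a Vec<f64> from concat_vec because stable Rust rejects [f64; A + B] as a return type. One possible stable workaround, sketched here under assumptions, is to have the caller name the output length as a third const parameter and verify it with a compile-time assertion. The names SumIs and concat_arr_checked are hypothetical and not part of the tutorial's source, and the sketch works on plain arrays rather than Vec3. Note that this kind of assertion is evaluated when the function is instantiated, so a mismatch should surface in a full build rather than necessarily during a type-check-only pass.

```rust
/// Hypothetical compile-time witness that A + B == C.
struct SumIs<const A: usize, const B: usize, const C: usize>;

impl<const A: usize, const B: usize, const C: usize> SumIs<A, B, C> {
    // Evaluated once per concrete (A, B, C) that is actually used.
    const OK: () = assert!(A + B == C, "output length must equal A + B");
}

/// Concatenate two fixed-size arrays into a fixed-size array on stable Rust.
/// The caller supplies the output length `C`; a wrong `C` fails the build.
pub fn concat_arr_checked<const A: usize, const B: usize, const C: usize>(
    a: &[f64; A],
    b: &[f64; B],
) -> [f64; C] {
    // Referencing the constant forces the assertion to be evaluated.
    let () = SumIs::<A, B, C>::OK;
    let mut out = [0.0; C];
    out[..A].copy_from_slice(a);
    out[A..].copy_from_slice(b);
    out
}

/// Hypothetical usage: A = 2 and B = 3 are inferred from the arguments,
/// C = 5 comes from the return type.
pub fn joined() -> [f64; 5] {
    concat_arr_checked(&[1.0, 2.0], &[3.0, 4.0, 5.0])
}
```

The trade-off is that the output length is stated by the caller instead of being computed by the compiler, which is less convenient than the nightly { A + B } form but keeps the result on the stack.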
Deep Comparison
OCaml vs Rust: Const Type Arithmetic
Type-Level Arithmetic
Rust
pub struct Add<const A: usize, const B: usize>;
impl<const A: usize, const B: usize> Add<A, B> {
    pub const VALUE: usize = A + B;
}
// Compile-time sum
const SUM: usize = Add::<3, 4>::VALUE; // 7
OCaml (GADTs)
(* Complex type-level naturals *)
type z = Z
type 'n s = S of 'n
type ('a, 'b, 'c) add =
| Add_z : (z, 'b, 'b) add
| Add_s : ('a, 'b, 'c) add -> ('a s, 'b, 'c s) add
Matrix with Dimension Types
Rust
pub fn matmul<const M: usize, const N: usize, const P: usize>(
    a: &Matrix<M, N>,
    b: &Matrix<N, P>,
) -> Matrix<M, P> // Dimensions guaranteed!

// Usage
let a: Matrix<2, 3> = ...;
let b: Matrix<3, 4> = ...;
let c = matmul(&a, &b); // c: Matrix<2, 4>
OCaml
(* Runtime dimension check *)
let matmul a b =
  if Array.length a.(0) <> Array.length b then
    invalid_arg "dimension mismatch";
  ...
Key Differences
| Aspect | OCaml | Rust |
|---|---|---|
| Type-level numbers | GADTs (complex) | Native const generics |
| Dimension check | Runtime in the plain-array version shown; compile-time only with GADT witnesses | Compile-time |
| Syntax | Witness types | <const N: usize> |
| Arithmetic | Custom types | Native operators |
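The "Native operators" row can be illustrated with a short sketch. When Mul is given concrete arguments, its VALUE is an ordinary constant and can size an array on stable Rust. The names CELLS and flat_buffer are hypothetical and used only for this illustration. The same expression with generic parameters in place of 3 and 4 would not be accepted as an array length on stable Rust.

```rust
/// Type-level multiplication result, as defined in the tutorial source.
pub struct Mul<const A: usize, const B: usize>;

impl<const A: usize, const B: usize> Mul<A, B> {
    pub const VALUE: usize = A * B;
}

// Hypothetical constant: with concrete arguments the product is known
// at compile time and usable wherever a `usize` constant is expected.
pub const CELLS: usize = Mul::<3, 4>::VALUE;

/// Hypothetical helper: a flat, zeroed buffer sized for a 3x4 matrix.
pub fn flat_buffer() -> [f64; CELLS] {
    [0.0; CELLS]
}
```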
Exercises
1. Implement Matrix<R, C>::transpose() -> Matrix<C, R> that reverses the dimensions in the type signature.
2. Implement dot_product<const N: usize>(a: &Vec3<N>, b: &Vec3<N>) -> f64 that computes the inner product; the shared N constraint prevents mismatched lengths.
3. Implement outer_product<const M: usize, const N: usize>(a: &Vec3<M>, b: &Vec3<N>) -> Matrix<M, N> that computes the outer product with correct dimension types.