975 Fundamental

975 Sparse Matrix

Functional Programming

Tutorial

The Problem

Implement a sparse matrix using HashMap<(usize, usize), f64> to store only non-zero elements. Operations include get (returns 0.0 for absent entries), set (removes entry on zero assignment), matrix-vector multiply, and transpose. Compare with OCaml's Hashtbl.Make approach using a custom key module.

🎯 Learning Outcomes

  • Store sparse data in HashMap<(usize, usize), f64>; (usize, usize) already implements Hash and Eq, so it can be used as a key directly
  • Implement set so that it removes the entry when v == 0.0, maintaining the sparse invariant
  • Implement get returning *self.data.get(&(r, c)).unwrap_or(&0.0), so absent entries read as an implicit zero
  • Implement matrix-vector multiply matvec(&self, v: &[f64]) -> Vec<f64> by iterating only the non-zero entries
  • Implement transpose by building a new SparseMatrix with swapped row/column indices

Code Example

    #![allow(clippy::all)]
    // 975: Sparse Matrix
    // Only store non-zero elements using HashMap<(usize,usize), f64>
    // OCaml uses custom Hashtbl.Make; Rust uses std HashMap with tuple keys
    
    use std::collections::HashMap;
    
    pub struct SparseMatrix {
        rows: usize,
        cols: usize,
        data: HashMap<(usize, usize), f64>,
    }
    
    impl SparseMatrix {
        pub fn new(rows: usize, cols: usize) -> Self {
            SparseMatrix {
                rows,
                cols,
                data: HashMap::new(),
            }
        }
    
        pub fn set(&mut self, r: usize, c: usize, v: f64) {
            assert!(r < self.rows && c < self.cols, "index out of bounds");
            if v == 0.0 {
                self.data.remove(&(r, c));
            } else {
                self.data.insert((r, c), v);
            }
        }
    
        pub fn get(&self, r: usize, c: usize) -> f64 {
            *self.data.get(&(r, c)).unwrap_or(&0.0)
        }
    
        /// Number of non-zero elements
        pub fn nnz(&self) -> usize {
            self.data.len()
        }
    
        pub fn rows(&self) -> usize {
            self.rows
        }
        pub fn cols(&self) -> usize {
            self.cols
        }
    
        /// Matrix-vector multiply: result[i] = sum_j mat[i,j] * v[j]
        pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
            assert_eq!(v.len(), self.cols, "vector length mismatch");
            let mut result = vec![0.0f64; self.rows];
            for (&(r, c), &val) in &self.data {
                result[r] += val * v[c];
            }
            result
        }
    
        /// Transpose: returns new SparseMatrix with rows/cols swapped
        pub fn transpose(&self) -> SparseMatrix {
            let mut t = SparseMatrix::new(self.cols, self.rows);
            for (&(r, c), &v) in &self.data {
                t.data.insert((c, r), v);
            }
            t
        }
    
        /// Element-wise add: returns new matrix
        pub fn add(&self, other: &SparseMatrix) -> SparseMatrix {
            assert_eq!(self.rows, other.rows);
            assert_eq!(self.cols, other.cols);
            let mut result = SparseMatrix::new(self.rows, self.cols);
            // Copy self
            for (&k, &v) in &self.data {
                result.data.insert(k, v);
            }
            // Add other
            for (&(r, c), &v) in &other.data {
                let entry = result.data.entry((r, c)).or_insert(0.0);
                *entry += v;
                if *entry == 0.0 {
                    result.data.remove(&(r, c));
                }
            }
            result
        }
    
        /// Iterate non-zero entries (sorted for determinism in tests)
        pub fn entries(&self) -> Vec<((usize, usize), f64)> {
            let mut v: Vec<_> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
            v.sort_by_key(|(k, _)| *k);
            v
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        fn make_matrix() -> SparseMatrix {
            let mut m = SparseMatrix::new(4, 4);
            m.set(0, 0, 1.0);
            m.set(0, 2, 2.0);
            m.set(1, 1, 3.0);
            m.set(2, 0, 4.0);
            m.set(2, 3, 5.0);
            m.set(3, 3, 6.0);
            m
        }
    
        #[test]
        fn test_get_set() {
            let m = make_matrix();
            assert_eq!(m.nnz(), 6);
            assert_eq!(m.get(0, 0), 1.0);
            assert_eq!(m.get(0, 1), 0.0); // zero element
            assert_eq!(m.get(1, 1), 3.0);
        }
    
        #[test]
        fn test_set_zero_removes() {
            let mut m = make_matrix();
            m.set(1, 1, 0.0);
            assert_eq!(m.nnz(), 5);
            assert_eq!(m.get(1, 1), 0.0);
        }
    
        #[test]
        fn test_matvec() {
            let mut m = make_matrix();
            m.set(1, 1, 0.0); // remove entry
            let v = vec![1.0, 0.0, 1.0, 0.0];
            let result = m.matvec(&v);
            assert_eq!(result[0], 3.0); // 1*1 + 2*1
            assert_eq!(result[1], 0.0);
            assert_eq!(result[2], 4.0); // 4*1
        }
    
        #[test]
        fn test_transpose() {
            let m = make_matrix();
            let mt = m.transpose();
            assert_eq!(mt.get(0, 0), 1.0);
            assert_eq!(mt.get(2, 0), 2.0);
            assert_eq!(mt.get(0, 2), 4.0);
            assert_eq!(mt.get(3, 2), 5.0);
            assert_eq!(mt.get(3, 3), 6.0);
            assert_eq!(mt.nnz(), 6);
        }
    
        #[test]
        fn test_add() {
            let m1 = make_matrix();
            let mut m2 = SparseMatrix::new(4, 4);
            m2.set(0, 0, 1.0);
            m2.set(1, 1, -3.0); // cancels out
    
            let sum = m1.add(&m2);
            assert_eq!(sum.get(0, 0), 2.0); // 1+1
            assert_eq!(sum.get(1, 1), 0.0); // 3+(-3)=0, removed
            // m1 had 6 entries; (0,0) merges and (1,1) cancels, leaving 5 non-zeros
            assert_eq!(sum.nnz(), 5);
        }
    }
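
    The entries method sorts its output because HashMap iteration order is unspecified. If ordered traversal matters more than O(1) average lookups, one alternative is a BTreeMap keyed on the same (row, col) tuples: tuples compare lexicographically, so iteration is already row-major, and a range query returns one row without scanning every entry. This is a sketch under that assumption; OrderedSparse and its row method are hypothetical names, not part of the code above, and lookups become O(log nnz).

```rust
use std::collections::BTreeMap;

/// Hypothetical ordered variant of SparseMatrix backed by a BTreeMap.
pub struct OrderedSparse {
    rows: usize,
    cols: usize,
    // (row, col) tuples order lexicographically, so iteration is row-major
    data: BTreeMap<(usize, usize), f64>,
}

impl OrderedSparse {
    pub fn new(rows: usize, cols: usize) -> Self {
        OrderedSparse { rows, cols, data: BTreeMap::new() }
    }

    /// Same sparse invariant as the HashMap version: storing 0.0 removes the entry.
    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        if v == 0.0 {
            self.data.remove(&(r, c));
        } else {
            self.data.insert((r, c), v);
        }
    }

    pub fn get(&self, r: usize, c: usize) -> f64 {
        *self.data.get(&(r, c)).unwrap_or(&0.0)
    }

    /// Non-zero (col, value) pairs of row r, found with a range query over
    /// the keys from (r, 0) up to but excluding (r + 1, 0).
    pub fn row(&self, r: usize) -> Vec<(usize, f64)> {
        assert!(r < self.rows, "row out of bounds");
        self.data
            .range((r, 0usize)..(r + 1, 0usize))
            .map(|(&(_, c), &v)| (c, v))
            .collect()
    }

    /// Already in row-major order: no explicit sort is needed.
    pub fn entries(&self) -> Vec<((usize, usize), f64)> {
        self.data.iter().map(|(&k, &v)| (k, v)).collect()
    }
}
```

    The trade-off is the usual one between hash and ordered maps: the HashMap version has cheaper point lookups, while the BTreeMap version gets deterministic output and per-row access without extra work.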

    Key Differences

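    | Aspect | Rust | OCaml |
    |---|---|---|
    | Tuple as key | (usize, usize): Hash + Eq provided by std for tuples | Hashtbl.Make(IntPairKey) module in this example (the polymorphic Hashtbl also accepts int * int keys) |
    | Default value | unwrap_or(&0.0) | try find ... with Not_found -> 0.0 |
    | Sparse iteration | for (&(r, c), &v) in &self.data | SparseHashtbl.iter (fun (r, c) v -> ...) |
    | Floating-point == | v == 0.0 (exact zero) | v = 0.0 (same) |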

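    Comparing with == 0.0 is deliberate here: set removes only values that are exactly zero (including -0.0, since -0.0 == 0.0 under IEEE 754), and add removes only sums that cancel exactly. Tiny leftovers from rounding, such as the result of 0.1 + 0.2 - 0.3, are not exactly zero and stay stored. Production code that must keep the matrix sparse after floating-point arithmetic typically prunes entries whose magnitude falls below a tolerance.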
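
    A minimal sketch of tolerance-based pruning, written as a free function over the same HashMap<(usize, usize), f64> layout. The name prune_below and the use of an absolute tolerance are assumptions for illustration; a relative tolerance scaled to the matrix's magnitude is a common alternative.

```rust
use std::collections::HashMap;

/// Hypothetical helper: remove every stored entry with |value| <= tol and
/// return how many entries were removed.
pub fn prune_below(data: &mut HashMap<(usize, usize), f64>, tol: f64) -> usize {
    let before = data.len();
    // Written as !(|v| <= tol) rather than |v| > tol so that NaN entries,
    // which compare false against everything, are kept rather than silently dropped.
    data.retain(|_, v| !(v.abs() <= tol));
    before - data.len()
}
```

    With tol = 0.0 this removes exactly the values the exact-zero rule would (0.0 and -0.0), so the tolerance is the only policy difference.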

    OCaml Approach

    module IntPairKey = struct
      type t = int * int
      let compare = compare
      let hash (r, c) = Hashtbl.hash (r lsl 32 lor c)
      let equal a b = a = b
    end
    
    module SparseHashtbl = Hashtbl.Make(IntPairKey)
    
    type sparse_matrix = {
      rows: int; cols: int;
      data: float SparseHashtbl.t;
    }
    
    let create rows cols =
      { rows; cols; data = SparseHashtbl.create 16 }
    
    let set m r c v =
      if v = 0.0 then SparseHashtbl.remove m.data (r, c)
      else SparseHashtbl.replace m.data (r, c) v
    
    let get m r c =
      try SparseHashtbl.find m.data (r, c) with Not_found -> 0.0
    
    let mv m vec =
      let result = Array.make m.rows 0.0 in
      SparseHashtbl.iter (fun (r, c) v ->
        result.(r) <- result.(r) +. v *. vec.(c)
      ) m.data;
      result
    

    OCaml's Hashtbl.Make(Key) requires defining a key module with hash and equal. Rust's HashMap automatically uses Hash and Eq trait implementations on (usize, usize) — no extra boilerplate.
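
    If you prefer a named key type, closer in spirit to OCaml's IntPairKey module, deriving the traits gives the same behaviour as the tuple key without a hand-written hash or equality. The Pos struct and get_or_zero function are hypothetical illustrations, not part of the example's code. Derived PartialOrd/Ord compares fields in declaration order, so Pos sorts row-major just like the tuple.

```rust
use std::collections::HashMap;

/// Hypothetical named key; the derives stand in for OCaml's hand-written
/// `hash` and `equal` (and `compare`, via PartialOrd/Ord).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Pos {
    pub row: usize,
    pub col: usize,
}

/// Look up an entry, treating absent keys as an implicit zero.
pub fn get_or_zero(data: &HashMap<Pos, f64>, row: usize, col: usize) -> f64 {
    *data.get(&Pos { row, col }).unwrap_or(&0.0)
}
```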


    Deep Comparison

    Sparse Matrix — Comparison

    Core Insight

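    A sparse matrix stores only its non-zero entries, which saves memory when most values are 0. Both versions use a hash map from (row, col) pairs to floats. The OCaml version builds a specialized table with the Hashtbl.Make functor, which needs a key module supplying equal and hash; OCaml's polymorphic Hashtbl would also accept int * int keys through its generic structural hash, so the functor is a choice for an explicit, monomorphic key rather than a requirement. Rust's HashMap<(usize, usize), f64> works out of the box because the standard library implements Hash and Eq for tuples whose elements implement them.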
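
    To put a rough number on the memory saving, this sketch compares the bytes a dense row-major matrix of f64 would need with the raw payload of the hash-map entries. It counts only key and value bytes; HashMap also keeps spare capacity and per-bucket control bytes, so real usage is noticeably higher. The function names are hypothetical.

```rust
use std::mem::size_of;

/// Bytes for a dense rows x cols matrix of f64.
pub fn dense_bytes(rows: usize, cols: usize) -> usize {
    rows * cols * size_of::<f64>()
}

/// Lower bound for the sparse map: one ((row, col), value) payload per non-zero,
/// ignoring the hash table's own overhead.
pub fn sparse_payload_bytes(nnz: usize) -> usize {
    nnz * size_of::<((usize, usize), f64)>()
}
```

    On a 64-bit target each payload is 24 bytes, so a 1000×1000 matrix with 1% fill (10,000 non-zeros) needs 240,000 bytes of payload against 8,000,000 bytes dense.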

    OCaml Approach

  • module IntPair = struct type t = int * int let equal ... let hash ... end
  • module PairHash = Hashtbl.Make(IntPair) — functor application for typed hashtable
  • PairHash.find_opt m.data (r,c) |> Option.value ~default:0.0
  • PairHash.remove when setting to 0 (keep sparsity invariant)
  • PairHash.iter for matvec and transpose iteration
  • Floats compared with = 0.0 (works for exact zero)

    Rust Approach

  • HashMap<(usize, usize), f64> — tuple key, hash derived automatically
  • .unwrap_or(&0.0) for zero default
  • .remove(&(r, c)) when setting to 0.0
  • for (&(r, c), &val) in &self.data — destructuring in for loop
  • .entry((r,c)).or_insert(0.0) for accumulate-or-init pattern
  • Same float-zero comparison: v == 0.0
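
    The accumulate-or-init pattern is also the usual way to assemble a sparse matrix from (row, col, value) triplets in which the same position may appear more than once. from_triplets is a hypothetical helper, not part of the example: it sums duplicates with entry/or_insert and then drops exact cancellations so the sparse invariant still holds.

```rust
use std::collections::HashMap;

/// Hypothetical helper: build the sparse map from (row, col, value) triplets.
/// Duplicate positions are summed; positions that sum to exactly 0.0 are dropped.
pub fn from_triplets(triplets: &[(usize, usize, f64)]) -> HashMap<(usize, usize), f64> {
    let mut data: HashMap<(usize, usize), f64> = HashMap::new();
    for &(r, c, v) in triplets {
        // first visit inserts 0.0, then every visit adds v in place
        *data.entry((r, c)).or_insert(0.0) += v;
    }
    // restore the sparse invariant after any exact cancellations
    data.retain(|_, v| *v != 0.0);
    data
}
```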

    Comparison Table

    | Aspect | OCaml | Rust |
    |---|---|---|
    | Tuple key hash | Hashtbl.Make(IntPair) functor | HashMap<(usize, usize), f64> (auto-Hash) |
    | Default zero | Option.value ~default:0.0 | .unwrap_or(&0.0) |
    | Remove zero | PairHash.remove m.data key | data.remove(&key) |
    | Iteration | PairHash.iter (fun (r, c) v -> ...) | for (&(r, c), &v) in &data |
    | Accumulate | existing +. v, then replace | .entry(k).or_insert(0.0), then *e += v |
    | nnz | PairHash.length | data.len() |
    | Index check | failwith | assert! |

    Exercises

  • Implement add_matrices(a, b) -> SparseMatrix that merges two sparse matrices entry-by-entry.
  • Implement multiply(a: &SparseMatrix, b: &SparseMatrix) -> SparseMatrix — only compute non-zero products.
  • Add density(&self) -> f64 returning nnz / (rows * cols) — the fraction of non-zero entries.
  • Implement to_dense(&self) -> Vec<Vec<f64>> that materializes the full matrix.
  • Implement a CSR (Compressed Sparse Row) format alongside the HashMap format and benchmark matvec for a 1000×1000 matrix with 1% fill.
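
    As a possible starting point for the CSR exercise, here is one way to convert the HashMap layout to CSR, together with a matvec that walks each row's slice of the value array. The Csr type and its field names (row_ptr, col_idx, values) follow common CSR terminology but are otherwise our own, and the conversion assumes every key is within bounds, as SparseMatrix::set guarantees; benchmarking it against the HashMap version is left to the exercise.

```rust
use std::collections::HashMap;

/// Compressed Sparse Row storage (a sketch; names are illustrative).
pub struct Csr {
    pub rows: usize,
    pub cols: usize,
    /// Length rows + 1: the entries of row i occupy row_ptr[i]..row_ptr[i + 1].
    pub row_ptr: Vec<usize>,
    /// Column index of each stored value, in row-major order.
    pub col_idx: Vec<usize>,
    pub values: Vec<f64>,
}

impl Csr {
    /// Convert from the HashMap<(row, col), value> layout.
    pub fn from_map(rows: usize, cols: usize, data: &HashMap<(usize, usize), f64>) -> Self {
        // Sort entries row-major so each row's values end up contiguous.
        let mut entries: Vec<((usize, usize), f64)> =
            data.iter().map(|(&k, &v)| (k, v)).collect();
        entries.sort_by_key(|&(k, _)| k);

        // Count entries per row, then prefix-sum the counts into offsets.
        let mut row_ptr = vec![0usize; rows + 1];
        for &((r, _), _) in &entries {
            row_ptr[r + 1] += 1;
        }
        for i in 0..rows {
            row_ptr[i + 1] += row_ptr[i];
        }

        let col_idx: Vec<usize> = entries.iter().map(|&((_, c), _)| c).collect();
        let values: Vec<f64> = entries.iter().map(|&(_, v)| v).collect();
        Csr { rows, cols, row_ptr, col_idx, values }
    }

    /// result[i] = sum over row i's stored entries of value * v[col].
    pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(v.len(), self.cols, "vector length mismatch");
        (0..self.rows)
            .map(|i| {
                (self.row_ptr[i]..self.row_ptr[i + 1])
                    .map(|k| self.values[k] * v[self.col_idx[k]])
                    .sum::<f64>()
            })
            .collect()
    }
}
```

    Unlike the HashMap version, this matvec touches memory sequentially and produces each output element in one pass over its row, at the cost of an up-front conversion and no cheap single-element updates.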