975 Sparse Matrix
Tutorial
The Problem
Implement a sparse matrix using HashMap<(usize, usize), f64> to store only non-zero elements. Operations include get (returns 0.0 for absent entries), set (removes entry on zero assignment), matrix-vector multiply, and transpose. Compare with OCaml's Hashtbl.Make approach using a custom key module.
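Before the full struct, the storage idea can be sketched on a bare map. This is only an illustration: the helper names `get_or_zero` and `set_sparse` are hypothetical and are not part of the tutorial's API.

```rust
use std::collections::HashMap;

/// Hypothetical helper: read an entry, treating absent keys as 0.0.
pub fn get_or_zero(data: &HashMap<(usize, usize), f64>, r: usize, c: usize) -> f64 {
    data.get(&(r, c)).copied().unwrap_or(0.0)
}

/// Hypothetical helper: write an entry, dropping it when the value is exactly zero.
pub fn set_sparse(data: &mut HashMap<(usize, usize), f64>, r: usize, c: usize, v: f64) {
    if v == 0.0 {
        // Keep the sparse invariant: zeros are never stored.
        data.remove(&(r, c));
    } else {
        data.insert((r, c), v);
    }
}
```

Writing a zero over an existing entry shrinks the map, so the number of stored keys always equals the number of non-zero elements.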
🎯 Learning Outcomes
- HashMap<(usize, usize), f64>: tuple keys are hashable and comparable
- set that removes the entry when v == 0.0 to maintain the sparse invariant
- get returning *data.get(&(r, c)).unwrap_or(&0.0) for implicit zero
- matvec(v: &[f64]) -> Vec<f64> by iterating only non-zero entries
- transpose by building a new SparseMatrix with swapped row/column indices
Code Example
#![allow(clippy::all)]
// 975: Sparse Matrix
// Only store non-zero elements using HashMap<(usize,usize), f64>
// OCaml uses custom Hashtbl.Make; Rust uses std HashMap with tuple keys

use std::collections::HashMap;

pub struct SparseMatrix {
    rows: usize,
    cols: usize,
    data: HashMap<(usize, usize), f64>,
}

impl SparseMatrix {
    pub fn new(rows: usize, cols: usize) -> Self {
        SparseMatrix {
            rows,
            cols,
            data: HashMap::new(),
        }
    }

    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        if v == 0.0 {
            self.data.remove(&(r, c));
        } else {
            self.data.insert((r, c), v);
        }
    }

    pub fn get(&self, r: usize, c: usize) -> f64 {
        *self.data.get(&(r, c)).unwrap_or(&0.0)
    }

    /// Number of non-zero elements
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    pub fn rows(&self) -> usize {
        self.rows
    }

    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Matrix-vector multiply: result[i] = sum_j mat[i,j] * v[j]
    pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(v.len(), self.cols, "vector length mismatch");
        let mut result = vec![0.0f64; self.rows];
        for (&(r, c), &val) in &self.data {
            result[r] += val * v[c];
        }
        result
    }

    /// Transpose: returns new SparseMatrix with rows/cols swapped
    pub fn transpose(&self) -> SparseMatrix {
        let mut t = SparseMatrix::new(self.cols, self.rows);
        for (&(r, c), &v) in &self.data {
            t.data.insert((c, r), v);
        }
        t
    }

    /// Element-wise add: returns new matrix
    pub fn add(&self, other: &SparseMatrix) -> SparseMatrix {
        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);
        let mut result = SparseMatrix::new(self.rows, self.cols);
        // Copy self
        for (&k, &v) in &self.data {
            result.data.insert(k, v);
        }
        // Add other
        for (&(r, c), &v) in &other.data {
            let entry = result.data.entry((r, c)).or_insert(0.0);
            *entry += v;
            if *entry == 0.0 {
                result.data.remove(&(r, c));
            }
        }
        result
    }

    /// Iterate non-zero entries (sorted for determinism in tests)
    pub fn entries(&self) -> Vec<((usize, usize), f64)> {
        let mut v: Vec<_> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
        v.sort_by_key(|(k, _)| *k);
        v
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_matrix() -> SparseMatrix {
        let mut m = SparseMatrix::new(4, 4);
        m.set(0, 0, 1.0);
        m.set(0, 2, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 0, 4.0);
        m.set(2, 3, 5.0);
        m.set(3, 3, 6.0);
        m
    }

    #[test]
    fn test_get_set() {
        let m = make_matrix();
        assert_eq!(m.nnz(), 6);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 1), 0.0); // zero element
        assert_eq!(m.get(1, 1), 3.0);
    }

    #[test]
    fn test_set_zero_removes() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0);
        assert_eq!(m.nnz(), 5);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    fn test_matvec() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0); // remove entry
        let v = vec![1.0, 0.0, 1.0, 0.0];
        let result = m.matvec(&v);
        assert_eq!(result[0], 3.0); // 1*1 + 2*1
        assert_eq!(result[1], 0.0);
        assert_eq!(result[2], 4.0); // 4*1
    }

    #[test]
    fn test_transpose() {
        let m = make_matrix();
        let mt = m.transpose();
        assert_eq!(mt.get(0, 0), 1.0);
        assert_eq!(mt.get(2, 0), 2.0);
        assert_eq!(mt.get(0, 2), 4.0);
        assert_eq!(mt.get(3, 2), 5.0);
        assert_eq!(mt.get(3, 3), 6.0);
        assert_eq!(mt.nnz(), 6);
    }

    #[test]
    fn test_add() {
        let m1 = make_matrix();
        let mut m2 = SparseMatrix::new(4, 4);
        m2.set(0, 0, 1.0);
        m2.set(1, 1, -3.0); // cancels out
        let sum = m1.add(&m2);
        assert_eq!(sum.get(0, 0), 2.0); // 1+1
        assert_eq!(sum.get(1, 1), 0.0); // 3+(-3)=0, removed
        // m1 had 6 entries, m2 adds (0,0) which merges, (1,1) cancels → 5 non-zero
        assert_eq!(sum.nnz(), 5);
    }
}
Key Differences
| Aspect | Rust | OCaml |
|---|---|---|
| Tuple as key | (usize, usize), with Hash + Eq implemented by std for tuples | Hashtbl.Make(IntPairKey) functor with explicit hash and equal |
| Default value | unwrap_or(&0.0) | try find ... with Not_found -> 0.0 |
| Sparse iterate | for (&(r, c), &v) in &self.data | Hashtbl.iter (fun (r,c) v -> ...) |
| Floating-point == | v == 0.0 (exact zero) | v = 0.0 (same) |
Comparing floating-point values with == 0.0 is deliberate here — we only want to remove entries that were explicitly set to zero, not entries that are approximately zero. For numerical stability, use a threshold in production code.
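A thresholded variant might look like the sketch below. The constant `EPS` and the function name `set_with_tolerance` are assumptions made for illustration; a suitable tolerance depends on the scale of the data.

```rust
use std::collections::HashMap;

/// Hypothetical tolerance; the right value depends on the scale of the data.
pub const EPS: f64 = 1e-12;

/// Sketch of a tolerant `set`: values with |v| <= EPS are treated as zero
/// and removed, so rounding residue does not pile up as stored entries.
pub fn set_with_tolerance(
    data: &mut HashMap<(usize, usize), f64>,
    r: usize,
    c: usize,
    v: f64,
) {
    if v.abs() <= EPS {
        data.remove(&(r, c));
    } else {
        data.insert((r, c), v);
    }
}
```

For example, 0.1 + 0.2 - 0.3 is not exactly zero in f64, so the exact comparison would store it as an entry, while the tolerant version drops it.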
OCaml Approach
module IntPairKey = struct
  type t = int * int
  let compare = compare
  let hash (r, c) = Hashtbl.hash (r lsl 32 lor c)
  let equal a b = a = b
end

module SparseHashtbl = Hashtbl.Make(IntPairKey)

type sparse_matrix = {
  rows: int; cols: int;
  data: float SparseHashtbl.t;
}

let create rows cols =
  { rows; cols; data = SparseHashtbl.create 16 }

let set m r c v =
  if v = 0.0 then SparseHashtbl.remove m.data (r, c)
  else SparseHashtbl.replace m.data (r, c) v

let get m r c =
  try SparseHashtbl.find m.data (r, c) with Not_found -> 0.0

let mv m vec =
  let result = Array.make m.rows 0.0 in
  SparseHashtbl.iter (fun (r, c) v ->
    result.(r) <- result.(r) +. v *. vec.(c)
  ) m.data;
  result
OCaml's Hashtbl.Make(Key) requires defining a key module with hash and equal. Rust's HashMap automatically uses Hash and Eq trait implementations on (usize, usize) — no extra boilerplate.
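If a named key type is preferred over a bare tuple, the closest Rust analogue to the OCaml key module is a struct with derived traits. The sketch below is illustrative only: `Index` and `lookup` are hypothetical names, not part of the tutorial's code.

```rust
use std::collections::HashMap;

/// Hypothetical named key type, playing the role of OCaml's IntPairKey.
/// The derives generate the hashing and equality that HashMap needs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Index {
    pub row: usize,
    pub col: usize,
}

/// Look up an entry by row and column, with absent keys reading as 0.0.
pub fn lookup(data: &HashMap<Index, f64>, row: usize, col: usize) -> f64 {
    data.get(&Index { row, col }).copied().unwrap_or(0.0)
}
```

The derive line replaces the hand-written hash and equal functions; the tuple key in the tutorial gets the same behavior without declaring a type at all.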
Full Source
#![allow(clippy::all)]
// 975: Sparse Matrix
// Only store non-zero elements using HashMap<(usize,usize), f64>
// OCaml uses custom Hashtbl.Make; Rust uses std HashMap with tuple keys

use std::collections::HashMap;

pub struct SparseMatrix {
    rows: usize,
    cols: usize,
    data: HashMap<(usize, usize), f64>,
}

impl SparseMatrix {
    pub fn new(rows: usize, cols: usize) -> Self {
        SparseMatrix {
            rows,
            cols,
            data: HashMap::new(),
        }
    }

    pub fn set(&mut self, r: usize, c: usize, v: f64) {
        assert!(r < self.rows && c < self.cols, "index out of bounds");
        if v == 0.0 {
            self.data.remove(&(r, c));
        } else {
            self.data.insert((r, c), v);
        }
    }

    pub fn get(&self, r: usize, c: usize) -> f64 {
        *self.data.get(&(r, c)).unwrap_or(&0.0)
    }

    /// Number of non-zero elements
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    pub fn rows(&self) -> usize {
        self.rows
    }

    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Matrix-vector multiply: result[i] = sum_j mat[i,j] * v[j]
    pub fn matvec(&self, v: &[f64]) -> Vec<f64> {
        assert_eq!(v.len(), self.cols, "vector length mismatch");
        let mut result = vec![0.0f64; self.rows];
        for (&(r, c), &val) in &self.data {
            result[r] += val * v[c];
        }
        result
    }

    /// Transpose: returns new SparseMatrix with rows/cols swapped
    pub fn transpose(&self) -> SparseMatrix {
        let mut t = SparseMatrix::new(self.cols, self.rows);
        for (&(r, c), &v) in &self.data {
            t.data.insert((c, r), v);
        }
        t
    }

    /// Element-wise add: returns new matrix
    pub fn add(&self, other: &SparseMatrix) -> SparseMatrix {
        assert_eq!(self.rows, other.rows);
        assert_eq!(self.cols, other.cols);
        let mut result = SparseMatrix::new(self.rows, self.cols);
        // Copy self
        for (&k, &v) in &self.data {
            result.data.insert(k, v);
        }
        // Add other
        for (&(r, c), &v) in &other.data {
            let entry = result.data.entry((r, c)).or_insert(0.0);
            *entry += v;
            if *entry == 0.0 {
                result.data.remove(&(r, c));
            }
        }
        result
    }

    /// Iterate non-zero entries (sorted for determinism in tests)
    pub fn entries(&self) -> Vec<((usize, usize), f64)> {
        let mut v: Vec<_> = self.data.iter().map(|(&k, &v)| (k, v)).collect();
        v.sort_by_key(|(k, _)| *k);
        v
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn make_matrix() -> SparseMatrix {
        let mut m = SparseMatrix::new(4, 4);
        m.set(0, 0, 1.0);
        m.set(0, 2, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 0, 4.0);
        m.set(2, 3, 5.0);
        m.set(3, 3, 6.0);
        m
    }

    #[test]
    fn test_get_set() {
        let m = make_matrix();
        assert_eq!(m.nnz(), 6);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(0, 1), 0.0); // zero element
        assert_eq!(m.get(1, 1), 3.0);
    }

    #[test]
    fn test_set_zero_removes() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0);
        assert_eq!(m.nnz(), 5);
        assert_eq!(m.get(1, 1), 0.0);
    }

    #[test]
    fn test_matvec() {
        let mut m = make_matrix();
        m.set(1, 1, 0.0); // remove entry
        let v = vec![1.0, 0.0, 1.0, 0.0];
        let result = m.matvec(&v);
        assert_eq!(result[0], 3.0); // 1*1 + 2*1
        assert_eq!(result[1], 0.0);
        assert_eq!(result[2], 4.0); // 4*1
    }

    #[test]
    fn test_transpose() {
        let m = make_matrix();
        let mt = m.transpose();
        assert_eq!(mt.get(0, 0), 1.0);
        assert_eq!(mt.get(2, 0), 2.0);
        assert_eq!(mt.get(0, 2), 4.0);
        assert_eq!(mt.get(3, 2), 5.0);
        assert_eq!(mt.get(3, 3), 6.0);
        assert_eq!(mt.nnz(), 6);
    }

    #[test]
    fn test_add() {
        let m1 = make_matrix();
        let mut m2 = SparseMatrix::new(4, 4);
        m2.set(0, 0, 1.0);
        m2.set(1, 1, -3.0); // cancels out
        let sum = m1.add(&m2);
        assert_eq!(sum.get(0, 0), 2.0); // 1+1
        assert_eq!(sum.get(1, 1), 0.0); // 3+(-3)=0, removed
        // m1 had 6 entries, m2 adds (0,0) which merges, (1,1) cancels → 5 non-zero
        assert_eq!(sum.nnz(), 5);
    }
}
Deep Comparison
Sparse Matrix — Comparison
Core Insight
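A sparse matrix stores only non-zero entries, saving memory when most values are 0. Both languages use a hash map from (row, col) pairs to floats. The OCaml version builds a typed table with the Hashtbl.Make functor, which takes a key module supplying equal and hash; the polymorphic Hashtbl would also accept tuple keys, but the functor makes the key type and hash function explicit. Rust's HashMap<(usize, usize), f64> works out of the box, because the standard library implements Hash and Eq for tuples whose elements implement them.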
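To put a rough number on the memory saving, the sketch below compares the value storage of a dense matrix with the key-plus-value payload of the sparse entries. The function names are hypothetical, and the sparse figure is only a lower bound, since a real HashMap adds per-bucket metadata and spare capacity.

```rust
use std::mem::size_of;

/// Bytes needed for a dense rows x cols matrix of f64 (values only).
pub fn dense_bytes(rows: usize, cols: usize) -> usize {
    rows * cols * size_of::<f64>()
}

/// Lower bound on bytes for `nnz` stored entries: key plus value only.
/// Hash table overhead (control bytes, unused capacity) is not counted.
pub fn sparse_payload_bytes(nnz: usize) -> usize {
    nnz * size_of::<((usize, usize), f64)>()
}
```

On a 64-bit target each entry's payload is 24 bytes, so a 1000×1000 matrix with 10,000 non-zero entries needs about 240 KB of payload against 8 MB for the dense layout. The advantage shrinks as the fill rate grows, because every stored value also carries its two indices.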
OCaml Approach
- module IntPair = struct type t = int * int; let equal ...; let hash ... end
- module PairHash = Hashtbl.Make(IntPair): functor application for a typed hashtable
- PairHash.find_opt m.data (r,c) |> Option.value ~default:0.0
- PairHash.remove when setting to 0 (keep sparsity invariant)
- PairHash.iter for matvec and transpose iteration
- = 0.0 comparison (works for exact zero)
Rust Approach
- HashMap<(usize, usize), f64>: tuple key, Hash provided by the standard library
- .unwrap_or(&0.0) for zero default
- .remove(&(r, c)) when setting to 0.0
- for (&(r, c), &val) in &self.data: destructuring in the for loop
- .entry((r,c)).or_insert(0.0) for the accumulate-or-init pattern
- v == 0.0 comparison for exact zero
Comparison Table
| Aspect | OCaml | Rust |
|---|---|---|
| Tuple key hash | Hashtbl.Make(IntPair) functor | HashMap<(usize,usize), f64> (auto-Hash) |
| Default zero | Option.value ~default:0.0 | .unwrap_or(&0.0) |
| Remove zero | PairHash.remove m.data key | data.remove(&key) |
| Iteration | PairHash.iter (fun (r,c) v -> ...) | for (&(r,c), &v) in &data |
| Accumulate | existing +. v; replace | .entry(k).or_insert(0.0) then *e += v |
| nnz | PairHash.length | data.len() |
| Index check | failwith | assert! |
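The accumulate row of the table can be sketched on its own. The helper below, with the hypothetical name `from_triplets`, builds the entry map from (row, col, value) triplets; it is an illustration under that assumption, not a method of the tutorial's SparseMatrix.

```rust
use std::collections::HashMap;

/// Hypothetical helper: build the entry map from (row, col, value) triplets,
/// summing duplicates and dropping entries that cancel to exactly zero.
pub fn from_triplets(triplets: &[(usize, usize, f64)]) -> HashMap<(usize, usize), f64> {
    let mut data = HashMap::new();
    for &(r, c, v) in triplets {
        // Accumulate-or-init: absent keys start at 0.0, then v is added.
        *data.entry((r, c)).or_insert(0.0) += v;
    }
    // Restore the sparse invariant after all sums are complete.
    data.retain(|_, v| *v != 0.0);
    data
}
```

Filtering once at the end with retain avoids the second lookup that an entry-then-remove sequence performs for each cancelled key.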
Exercises
1. add_matrices(a, b) -> SparseMatrix that merges two sparse matrices entry-by-entry.
2. multiply(a: &SparseMatrix, b: &SparseMatrix) -> SparseMatrix, computing only non-zero products.
3. density(&self) -> f64 returning nnz / (rows * cols), the fraction of non-zero entries.
4. to_dense(&self) -> Vec<Vec<f64>> that materializes the full matrix.
5. HashMap format and benchmark mv for a 1000×1000 matrix with 1% fill.