976 Matrix Multiply

Fundamental · Functional Programming · Tutorial

The Problem

Implement matrix multiplication in two styles: a naive O(n³) triple-loop version and a cache-friendly dot-product style using transpose. Also implement a 2×2 Strassen demo. Compare row-major access patterns and their effect on cache performance, and contrast with OCaml's list-of-lists functional approach.

🎯 Learning Outcomes

  • Implement naive mat_multiply(a, b) -> Vec<Vec<f64>> with triple nested loops
  • Implement a cache-friendly multiply via transpose(b) followed by row-dot-row access
  • Understand why transposing b improves cache performance: sequential row access in both operands
  • Implement transpose with a pre-allocated output matrix and index swap
  • Recognize that O(n³) is the practical complexity for dense matrices; Strassen achieves O(n^2.807) asymptotically

Code Example

    #![allow(clippy::all)]
    // 976: Matrix Multiplication
    // Naive O(n³) and Strassen 2x2 demo
    // OCaml: list-of-lists (functional) + arrays; Rust: Vec<Vec<f64>>
    
    // Approach 1: Vec<Vec<f64>> naive multiply
    pub fn mat_multiply(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = a.len();
        let m = b[0].len();
        let k = b.len();
        assert_eq!(a[0].len(), k, "dimension mismatch");
    
        let mut result = vec![vec![0.0f64; m]; n];
        for i in 0..n {
            for j in 0..m {
                for l in 0..k {
                    result[i][j] += a[i][l] * b[l][j];
                }
            }
        }
        result
    }
    
    // Transpose a matrix
    pub fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
        if m.is_empty() {
            return vec![];
        }
        let rows = m.len();
        let cols = m[0].len();
        let mut t = vec![vec![0.0f64; rows]; cols];
        for i in 0..rows {
            for j in 0..cols {
                t[j][i] = m[i][j];
            }
        }
        t
    }
    
    // Approach 2: Dot-product style (cache-friendly via transpose)
    pub fn mat_multiply_transposed(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = a.len();
        let m = b[0].len();
        let bt = transpose(b);
    
        let mut result = vec![vec![0.0f64; m]; n];
        for i in 0..n {
            for j in 0..m {
                result[i][j] = a[i].iter().zip(&bt[j]).map(|(x, y)| x * y).sum();
            }
        }
        result
    }
    
    // Approach 3: Strassen 2x2 (demonstrates the 7-multiply algorithm)
    // Real Strassen: recursively split into n/2 x n/2 blocks
    pub fn strassen_2x2(a: &[[f64; 2]; 2], b: &[[f64; 2]; 2]) -> [[f64; 2]; 2] {
        let (a11, a12, a21, a22) = (a[0][0], a[0][1], a[1][0], a[1][1]);
        let (b11, b12, b21, b22) = (b[0][0], b[0][1], b[1][0], b[1][1]);
    
        let m1 = (a11 + a22) * (b11 + b22);
        let m2 = (a21 + a22) * b11;
        let m3 = a11 * (b12 - b22);
        let m4 = a22 * (b21 - b11);
        let m5 = (a11 + a12) * b22;
        let m6 = (a21 - a11) * (b11 + b12);
        let m7 = (a12 - a22) * (b21 + b22);
    
        [[m1 + m4 - m5 + m7, m3 + m5], [m2 + m4, m1 - m2 + m3 + m6]]
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_2x2_multiply() {
            let a = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
            let b = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
            let c = mat_multiply(&a, &b);
            assert_eq!(c[0][0], 19.0);
            assert_eq!(c[0][1], 22.0);
            assert_eq!(c[1][0], 43.0);
            assert_eq!(c[1][1], 50.0);
        }
    
        #[test]
        fn test_non_square() {
            let m23 = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
            let m32 = vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]];
            let result = mat_multiply(&m23, &m32);
            assert_eq!(result[0][0], 58.0);
            assert_eq!(result[0][1], 64.0);
            assert_eq!(result[1][0], 139.0);
            assert_eq!(result[1][1], 154.0);
        }
    
        #[test]
        fn test_transposed_matches_naive() {
            let a = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
            let b = vec![vec![7.0, 8.0], vec![9.0, 10.0], vec![11.0, 12.0]];
            let naive = mat_multiply(&a, &b);
            let transposed = mat_multiply_transposed(&a, &b);
            assert_eq!(naive, transposed);
        }
    
        #[test]
        fn test_strassen_2x2() {
            let a = [[1.0, 2.0], [3.0, 4.0]];
            let b = [[5.0, 6.0], [7.0, 8.0]];
            let c = strassen_2x2(&a, &b);
            assert_eq!(c[0][0], 19.0);
            assert_eq!(c[0][1], 22.0);
            assert_eq!(c[1][0], 43.0);
            assert_eq!(c[1][1], 50.0);
        }
    
        #[test]
        fn test_identity() {
            let a = vec![vec![3.0, 4.0], vec![5.0, 6.0]];
            let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
            let result = mat_multiply(&a, &identity);
            assert_eq!(result, a);
        }
    }
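A third cache-friendly option, not shown in the example above, is to reorder the loops instead of transposing. The sketch below (the name `mat_multiply_ikj` is illustrative, not from the example) swaps the j and l loops so that the innermost loop walks both `result[i]` and `b[l]` sequentially:

```rust
// Loop-reordered multiply (illustrative variant, not part of the example above).
// With i-l-j ordering, the innermost loop reads result[i] and b[l] left to
// right, so both operands get sequential access without a transpose copy.
pub fn mat_multiply_ikj(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let m = b[0].len();
    let k = b.len();
    assert_eq!(a[0].len(), k, "dimension mismatch");

    let mut result = vec![vec![0.0f64; m]; n];
    for i in 0..n {
        for l in 0..k {
            let x = a[i][l]; // constant across the inner loop; hoist it once
            for j in 0..m {
                result[i][j] += x * b[l][j];
            }
        }
    }
    result
}
```

Unlike the transpose version, this avoids the extra O(n²) copy entirely; for large matrices the two approaches often perform comparably.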

    Key Differences

    Aspect            Rust                                   OCaml
    Matrix type       Vec<Vec<f64>> — heap rows              float list list — linked lists
    Cache behavior    Row access O(1), column access jumps   All access O(n) pointer chasing
    Functional style  map/zip/sum pipeline                   List.map (List.fold_left2 ...)
    Performance       ndarray for production                 Bigarray / owl for production

    For production matrix operations, use nalgebra or ndarray crates (Rust) or owl / Bigarray (OCaml). The Vec<Vec<f64>> approach here demonstrates the algorithm, not the optimal data layout.
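To make the layout point concrete, here is a minimal sketch of a single-allocation, row-major matrix. The `FlatMatrix` name and API are illustrative (not from any crate), but the idea — one contiguous buffer indexed as `i * cols + j` — is how production libraries store matrices:

```rust
// Illustrative single-allocation layout: one contiguous Vec<f64> indexed as
// data[i * cols + j], closer to how ndarray/nalgebra store matrices.
pub struct FlatMatrix {
    rows: usize,
    cols: usize,
    data: Vec<f64>, // row-major, one contiguous allocation
}

impl FlatMatrix {
    pub fn zeros(rows: usize, cols: usize) -> Self {
        FlatMatrix { rows, cols, data: vec![0.0; rows * cols] }
    }
    pub fn get(&self, i: usize, j: usize) -> f64 {
        self.data[i * self.cols + j]
    }
    pub fn set(&mut self, i: usize, j: usize, v: f64) {
        self.data[i * self.cols + j] = v;
    }
}

// Same triple-loop algorithm as mat_multiply, but every element access is
// plain index arithmetic into one buffer: no per-row heap indirection.
pub fn flat_multiply(a: &FlatMatrix, b: &FlatMatrix) -> FlatMatrix {
    assert_eq!(a.cols, b.rows, "dimension mismatch");
    let mut c = FlatMatrix::zeros(a.rows, b.cols);
    for i in 0..a.rows {
        for j in 0..b.cols {
            let mut s = 0.0;
            for l in 0..a.cols {
                s += a.get(i, l) * b.get(l, j);
            }
            c.set(i, j, s);
        }
    }
    c
}
```

With `Vec<Vec<f64>>`, each row is a separate heap allocation, so stepping between rows can jump anywhere in memory; the flat layout keeps the whole matrix in one cache-friendly block.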

    OCaml Approach

    (* transpose: column c of m becomes row c of the result *)
    let transpose m =
      match m with
      | [] -> []
      | row :: _ ->
          List.init (List.length row) (fun c -> List.map (fun r -> List.nth r c) m)

    let mat_multiply a b =
      let bt = transpose b in
      List.map (fun row_a ->
        List.map (fun row_bt ->
          List.fold_left2 (fun acc x y -> acc +. x *. y) 0.0 row_a row_bt
        ) bt
      ) a
    

    OCaml's list-of-lists approach is clean but cache-inefficient — linked lists have poor spatial locality. For performance-critical matrix work, OCaml uses Bigarray (C-backed arrays) or the owl library.


    Deep Comparison

    Core Insight

    Naive matrix multiplication is O(n³); Strassen's algorithm reduces this to O(n^2.807). The triple loop for i, j, k: C[i][j] += A[i][k] * B[k][j] is identical in both languages. OCaml's functional list-of-lists approach is readable but slow; its array approach matches Rust's performance. Transposing B before multiplication improves cache locality, since column access becomes row access.
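The exponent claim can be checked with a small recurrence count: classical block multiplication performs 8 half-size products per split, Strassen performs 7, so for n a power of two the totals are n³ versus 7^(log₂ n) ≈ n^2.807. A tiny sketch (these helper functions are illustrative, not part of the example code):

```rust
// Count scalar multiplications for an n x n multiply (n a power of two),
// assuming full recursion down to 1x1 blocks.
fn classical_muls(n: u64) -> u64 {
    if n == 1 { 1 } else { 8 * classical_muls(n / 2) } // 8 sub-products per split
}

fn strassen_muls(n: u64) -> u64 {
    if n == 1 { 1 } else { 7 * strassen_muls(n / 2) } // Strassen saves one product
}
```

For n = 64 this gives 262144 versus 117649 scalar multiplications. Strassen trades the saved product for extra additions, which is why it only pays off for large matrices.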

    OCaml Approach

  • [[1.0;2.0];[3.0;4.0]] — list of float lists (functional, poor cache locality)
  • List.init cols (fun c -> List.map (fun row -> List.nth row c) m) — transpose
  • List.fold_left2 for the dot product
  • Array.make_matrix n m 0.0 for the imperative approach
  • for i = 0 to n-1 do ... done — triple nested imperative loop
  • !s +. a.(i).(l) *. b.(l).(j) — float arithmetic with the . suffix

    Rust Approach

  • Vec<Vec<f64>> — row-major within each row; rows are separate heap allocations
  • vec![vec![0.0f64; m]; n] — initialize the result matrix
  • result[i][j] += a[i][l] * b[l][j] — clean triple loop
  • Transpose via a double loop (same algorithm, no magic)
  • .iter().zip(&bt[j]).map(|(x, y)| x * y).sum() — functional dot product
  • [[f64; 2]; 2] for the fixed-size Strassen demo (stack-allocated, no heap allocation)

    Comparison Table

    Aspect             OCaml                                    Rust
    Functional matrix  float list list                          Vec<Vec<f64>>
    Init result        Array.make_matrix n m 0.0                vec![vec![0.0; m]; n]
    Element access     a.(i).(l)                                a[i][l]
    Float arithmetic   +., *. (explicit)                        +, * (same operators)
    Dot product        List.fold_left2                          .zip().map().sum()
    Transpose          List.init cols (fun c -> List.map ...)   Double loop
    Fixed 2x2          [| [| ... |] |]                          [[f64; 2]; 2] (stack-allocated)
    Strassen           7 muls, same formula                     7 muls, same formula

    Exercises

  • Implement Strassen's algorithm for 2×2 matrices using 7 multiplications instead of 8.
  • Implement identity_matrix(n: usize) -> Vec<Vec<f64>> and verify multiply(A, I) == A.
  • Add a mat_add(a, b) -> Vec<Vec<f64>> function and verify distributivity.
  • Implement block matrix multiplication: split matrices into 2×2 blocks and multiply recursively.
  • Benchmark naive vs transpose-based multiply for 256×256 and 512×512 matrices.