1057 Advanced

1057-matrix-chain — Matrix Chain Multiplication

Functional Programming

Tutorial

The Problem

Multiplying a sequence of matrices is associative: (AB)C = A(BC), but the computational cost varies dramatically with parenthesization. Multiplying a p×q matrix by a q×r matrix takes p×q×r scalar multiplications. For A (10×30), B (30×5) and C (5×60): (AB)C costs 10×30×5 + 10×5×60 = 1,500 + 3,000 = 4,500 multiplications, while A(BC) costs 30×5×60 + 10×30×60 = 9,000 + 18,000 = 27,000. That is a 6× difference for just three matrices, and for longer chains the gap between good and bad orderings can be far larger.
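The cost model above can be checked directly. The sketch below is an illustration, not part of the original example: it multiplies three matrices of those shapes with a naive triple loop that counts scalar multiplications, confirms that both orders give the same product, and reports the two counts. The `Matrix` alias and the `matmul_counted` and `filled` helpers are hypothetical names introduced here.

```rust
// Row-major dense matrix (an assumption made for this sketch).
type Matrix = Vec<Vec<i64>>;

// Naive p x q times q x r multiplication; adds one to `count` per scalar multiply,
// so a single call contributes exactly p * q * r.
fn matmul_counted(a: &Matrix, b: &Matrix, count: &mut u64) -> Matrix {
    let (p, q, r) = (a.len(), b.len(), b[0].len());
    assert_eq!(a[0].len(), q, "inner dimensions must agree");
    let mut c = vec![vec![0i64; r]; p];
    for i in 0..p {
        for j in 0..r {
            for k in 0..q {
                c[i][j] += a[i][k] * b[k][j];
                *count += 1;
            }
        }
    }
    c
}

// rows x cols matrix with small deterministic entries (values 0..=6).
fn filled(rows: usize, cols: usize, seed: i64) -> Matrix {
    (0..rows)
        .map(|i| (0..cols).map(|j| (i as i64 + 2 * j as i64 + seed) % 7).collect())
        .collect()
}

fn main() {
    let a = filled(10, 30, 1); // A: 10 x 30
    let b = filled(30, 5, 2); // B: 30 x 5
    let c = filled(5, 60, 3); // C: 5 x 60

    let mut left = 0; // scalar multiplications for (AB)C
    let ab = matmul_counted(&a, &b, &mut left);
    let ab_c = matmul_counted(&ab, &c, &mut left);

    let mut right = 0; // scalar multiplications for A(BC)
    let bc = matmul_counted(&b, &c, &mut right);
    let a_bc = matmul_counted(&a, &bc, &mut right);

    assert_eq!(ab_c, a_bc); // associativity: same product either way
    println!("(AB)C: {left} multiplications"); // 4500
    println!("A(BC): {right} multiplications"); // 27000
}
```

The counter makes the p×q×r cost model concrete: each call adds exactly the product of the three dimensions involved.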

Matrix chain ordering is a classic interval DP problem and a fundamental optimization in scientific computing, neural network inference, and linear algebra libraries.
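Why interval DP rather than plain search: the recurrence tries every split point k, but without a table the same subchains are solved over and over, and the number of distinct parenthesizations grows like the Catalan numbers. A minimal sketch, assuming the same dims convention as the code below (matrix m is dims[m] × dims[m+1]); `brute_force` and `count_parenthesizations` are hypothetical helpers for comparison only.

```rust
// Exhaustive search over every parenthesization of matrices i..=j, where
// matrix m has shape dims[m] x dims[m + 1]. Same recurrence as the DP, but
// with no table, so identical subchains are recomputed many times.
fn brute_force(dims: &[usize], i: usize, j: usize) -> usize {
    if i == j {
        return 0;
    }
    (i..j)
        .map(|k| {
            brute_force(dims, i, k)
                + brute_force(dims, k + 1, j)
                + dims[i] * dims[k + 1] * dims[j + 1]
        })
        .min()
        .expect("i < j, so at least one split point exists")
}

// Number of distinct parenthesizations of a chain of n >= 1 matrices:
// P(1) = 1 and P(n) = sum of P(k) * P(n - k) for k in 1..n (Catalan number C(n - 1)).
fn count_parenthesizations(n: usize) -> u64 {
    assert!(n >= 1, "need at least one matrix");
    let mut p = vec![0u64; n + 1];
    p[1] = 1;
    for len in 2..=n {
        let total: u64 = (1..len).map(|k| p[k] * p[len - k]).sum();
        p[len] = total;
    }
    p[n]
}

fn main() {
    let dims = [30usize, 35, 15, 5, 10, 20, 25];
    println!("minimum cost: {}", brute_force(&dims, 0, dims.len() - 2)); // 15125
    for n in [4, 6, 10, 20] {
        println!("{n} matrices: {} parenthesizations", count_parenthesizations(n));
    }
}
```

Both searches agree on the answer; the difference is that the DP stores each dp[i][j] once, giving O(n³) time instead of exponential.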

🎯 Learning Outcomes

  • Implement matrix chain DP with dp[i][j] = minimum cost for matrices i..j
  • Understand interval DP: fill by increasing chain length
  • Recover the optimal parenthesization using a split table
  • Recognize that matrix multiplication associativity enables optimization
  • Connect to BLAS/LAPACK and deep learning frameworks that optimize compute graphs

Code Example

    #![allow(clippy::all)]
    // 1057: Matrix Chain Multiplication — Optimal Parenthesization
    
    use std::collections::HashMap;
    
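    // Approach 1: Bottom-up DP.
    // Matrix i has shape dims[i] x dims[i + 1], so n matrices need n + 1 dims.
    // dp[i][j] = minimum scalar multiplications to compute A_i..A_j (0-based).
    fn matrix_chain_dp(dims: &[usize]) -> usize {
        assert!(dims.len() >= 2, "need at least one matrix (two dimensions)");
        let n = dims.len() - 1;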
        let mut dp = vec![vec![0usize; n]; n];
        for l in 2..=n {
            for i in 0..=(n - l) {
                let j = i + l - 1;
                dp[i][j] = usize::MAX;
                for k in i..j {
                    let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                    dp[i][j] = dp[i][j].min(cost);
                }
            }
        }
        dp[0][n - 1]
    }
    
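    // Approach 2: Same DP, plus a split table recording the best k for each (i, j)
    fn matrix_chain_parens(dims: &[usize]) -> (usize, String) {
        assert!(dims.len() >= 2, "need at least one matrix (two dimensions)");
        let n = dims.len() - 1;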
        let mut dp = vec![vec![0usize; n]; n];
        let mut split = vec![vec![0usize; n]; n];
        for l in 2..=n {
            for i in 0..=(n - l) {
                let j = i + l - 1;
                dp[i][j] = usize::MAX;
                for k in i..j {
                    let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                    if cost < dp[i][j] {
                        dp[i][j] = cost;
                        split[i][j] = k;
                    }
                }
            }
        }
        fn build(i: usize, j: usize, split: &[Vec<usize>]) -> String {
            if i == j {
                format!("A{}", i + 1)
            } else {
                format!(
                    "({}*{})",
                    build(i, split[i][j], split),
                    build(split[i][j] + 1, j, split)
                )
            }
        }
        (dp[0][n - 1], build(0, n - 1, &split))
    }
    
    // Approach 3: Recursive with memoization
    fn matrix_chain_memo(dims: &[usize]) -> usize {
        fn solve(
            i: usize,
            j: usize,
            dims: &[usize],
            cache: &mut HashMap<(usize, usize), usize>,
        ) -> usize {
            if i == j {
                return 0;
            }
            if let Some(&v) = cache.get(&(i, j)) {
                return v;
            }
            let mut best = usize::MAX;
            for k in i..j {
                let cost = solve(i, k, dims, cache)
                    + solve(k + 1, j, dims, cache)
                    + dims[i] * dims[k + 1] * dims[j + 1];
                best = best.min(cost);
            }
            cache.insert((i, j), best);
            best
        }
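        assert!(dims.len() >= 2, "need at least one matrix (two dimensions)");
        let mut cache = HashMap::new();
        // Solve the whole chain: matrices 0..=n-1, where n = dims.len() - 1.
        solve(0, dims.len() - 2, dims, &mut cache)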
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_matrix_chain_dp() {
            assert_eq!(matrix_chain_dp(&[30, 35, 15, 5, 10, 20, 25]), 15125);
            assert_eq!(matrix_chain_dp(&[10, 20, 30, 40]), 18000);
        }
    
        #[test]
        fn test_matrix_chain_parens() {
            let (cost, parens) = matrix_chain_parens(&[30, 35, 15, 5, 10, 20, 25]);
            assert_eq!(cost, 15125);
            assert_eq!(parens, "((A1*(A2*A3))*((A4*A5)*A6))");
        }
    
        #[test]
        fn test_matrix_chain_memo() {
            assert_eq!(matrix_chain_memo(&[30, 35, 15, 5, 10, 20, 25]), 15125);
            assert_eq!(matrix_chain_memo(&[10, 20, 30, 40]), 18000);
        }
    }

    Key Differences

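  • usize::MAX vs max_int: Rust uses usize::MAX as the "infinity" sentinel; OCaml uses max_int. In the code above the sentinel only seeds the running minimum and is never used as an operand, so it cannot overflow. The real risk is the cost arithmetic itself with very large dimensions, which panics on overflow in Rust debug builds and wraps silently in OCaml (and in Rust release builds).
  • Interval filling order: Both fill by increasing length l, so every subchain cost dp[i][k] and dp[k+1][j] is final before dp[i][j] reads it; the inner loop then tries each split point k.
  • Reconstruction: Both build a separate split table during DP and recursively decode it to produce a parenthesization string.
  • Applications: the same reordering pays off wherever chains of matrix or tensor products are evaluated, whether in OCaml ML code or in Rust frameworks such as candle; NumPy's numpy.linalg.multi_dot, for example, chooses an optimal parenthesization before multiplying.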
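If inputs may be large enough to overflow, one option is to drop the sentinel altogether. The sketch below is a variant, not the article's code: it uses `Option<usize>` for "no candidate yet" and `checked_mul`/`checked_add` so that overflow yields `None` instead of a panic or a wrapped value. The name `matrix_chain_checked` is hypothetical, and treating any overflowing candidate as a failure of the whole computation is a deliberately conservative assumption.

```rust
// Bottom-up matrix chain DP without a usize::MAX sentinel: the running minimum
// is an Option (None = no candidate yet), and every multiplication and addition
// is checked. Returns None for an empty chain or if any candidate cost overflows.
fn matrix_chain_checked(dims: &[usize]) -> Option<usize> {
    if dims.len() < 2 {
        return None; // fewer than two dimensions means no matrix at all
    }
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    for len in 2..=n {
        for i in 0..=(n - len) {
            let j = i + len - 1;
            let mut best: Option<usize> = None;
            for k in i..j {
                // `?` returns None from the whole function on overflow (conservative).
                let cost = dims[i]
                    .checked_mul(dims[k + 1])?
                    .checked_mul(dims[j + 1])?
                    .checked_add(dp[i][k])?
                    .checked_add(dp[k + 1][j])?;
                best = Some(best.map_or(cost, |b| b.min(cost)));
            }
            // i < j, so the loop ran at least once and `best` is Some.
            dp[i][j] = best?;
        }
    }
    Some(dp[0][n - 1])
}

fn main() {
    println!("{:?}", matrix_chain_checked(&[30, 35, 15, 5, 10, 20, 25])); // Some(15125)
    println!("{:?}", matrix_chain_checked(&[usize::MAX, 2, 2])); // None
}
```

A less conservative design would skip only the overflowing candidates and keep the cheapest finite one, at the cost of storing `Option<usize>` in every table cell.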

    OCaml Approach

    let matrix_chain dims =
      let n = Array.length dims - 1 in
      let dp = Array.make_matrix n n 0 in
      for l = 2 to n do
        for i = 0 to n - l do
          let j = i + l - 1 in
          dp.(i).(j) <- max_int;
          for k = i to j - 1 do
            let cost = dp.(i).(k) + dp.(k+1).(j) + dims.(i) * dims.(k+1) * dims.(j+1) in
            if cost < dp.(i).(j) then dp.(i).(j) <- cost
          done
        done
      done;
      dp.(0).(n-1)
    

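    The algorithm is identical to the Rust matrix_chain_dp. Interval DP has a canonical implementation structure (loop over chain length, then start index, then split point), so it carries over between languages almost line for line.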
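The shared structure can be made explicit by writing the three loops once and passing the problem-specific combine cost as a closure. This is a hypothetical helper (`interval_dp`), sketched under the assumption that single-element intervals cost 0 and that the cost of a split is additive; the matrix chain is one instance of it.

```rust
// Hypothetical generic helper: bottom-up interval DP over positions 0..n.
// dp[i][j] = min over split points k in i..j of
//            dp[i][k] + dp[k + 1][j] + combine_cost(i, k, j),
// with single-element intervals (i == j) costing 0.
fn interval_dp<F>(n: usize, combine_cost: F) -> usize
where
    F: Fn(usize, usize, usize) -> usize,
{
    assert!(n >= 1, "need at least one element");
    let mut dp = vec![vec![0usize; n]; n];
    for len in 2..=n {
        // 1. increasing interval length
        for i in 0..=(n - len) {
            // 2. every start position
            let j = i + len - 1;
            // 3. every split point
            let best = (i..j)
                .map(|k| dp[i][k] + dp[k + 1][j] + combine_cost(i, k, j))
                .min()
                .expect("i < j, so at least one split point exists");
            dp[i][j] = best;
        }
    }
    dp[0][n - 1]
}

fn main() {
    let dims = [30usize, 35, 15, 5, 10, 20, 25];
    let n = dims.len() - 1; // number of matrices
    // Matrix chain instance: joining A_i..A_k with A_(k+1)..A_j costs dims[i] * dims[k+1] * dims[j+1].
    let best = interval_dp(n, |i, k, j| dims[i] * dims[k + 1] * dims[j + 1]);
    println!("{best}"); // 15125
}
```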


    Deep Comparison

    Matrix Chain Multiplication — Comparison

    Core Insight

    Matrix chain multiplication is the canonical interval DP problem. The key is trying every split point k in range [i, j) and taking the minimum total cost. A separate split table enables reconstructing the optimal parenthesization.
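Besides a string, the split table can be decoded into an explicit multiplication schedule. In this sketch, `split_table` refills the table the same way as `matrix_chain_parens`, and the hypothetical `plan` function walks it in post-order, so each step only uses products that earlier steps have already produced.

```rust
// Fill the split table exactly as the bottom-up DP does:
// split[i][j] = the k giving the cheapest (A_i..A_k)(A_(k+1)..A_j), 0-based.
fn split_table(dims: &[usize]) -> Vec<Vec<usize>> {
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    let mut split = vec![vec![0usize; n]; n];
    for len in 2..=n {
        for i in 0..=(n - len) {
            let j = i + len - 1;
            dp[i][j] = usize::MAX;
            for k in i..j {
                let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    split[i][j] = k;
                }
            }
        }
    }
    split
}

// Hypothetical decoder: post-order walk of the split tree. Each step (i, k, j)
// means "multiply the product of A_i..A_k by the product of A_(k+1)..A_j";
// both operands are produced by earlier steps (or are single input matrices).
fn plan(split: &[Vec<usize>], i: usize, j: usize, steps: &mut Vec<(usize, usize, usize)>) {
    if i == j {
        return; // a single input matrix needs no multiplication
    }
    let k = split[i][j];
    plan(split, i, k, steps);
    plan(split, k + 1, j, steps);
    steps.push((i, k, j));
}

fn main() {
    let dims = [30usize, 35, 15, 5, 10, 20, 25];
    let split = split_table(&dims);
    let mut steps = Vec::new();
    plan(&split, 0, dims.len() - 2, &mut steps);
    for &(i, k, j) in &steps {
        println!("(A{}..A{}) * (A{}..A{})", i + 1, k + 1, k + 2, j + 1);
    }
}
```

A chain of n matrices always yields n - 1 steps. For the dims [30, 35, 15, 5, 10, 20, 25], the schedule starts with A2*A3 (step (1, 1, 2) in 0-based indices) and ends with the top-level split after A3.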

    OCaml Approach

  • Buffer for building parenthesization string recursively
  • Printf.sprintf for formatting matrix names
  • max_int as initial sentinel
  • ref cells for tracking best in inner loop

    Rust Approach

  • format! macro for string building in recursive parenthesization
  • usize::MAX as sentinel
  • Nested function for recursive string building
  • HashMap with tuple keys for memoization
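The OCaml version's use of a Buffer has a direct Rust analogue: pass one `&mut String` down the recursion and append to it, rather than allocating a fresh `String` with `format!` at every level. A minimal sketch, assuming a split table in the same layout as `matrix_chain_parens`; `write_parens` is a hypothetical name, and the table in `main` is written out by hand for the dims [10, 30, 5, 60] from the opening example.

```rust
use std::fmt::Write;

// Hypothetical buffer-style decoder: appends the parenthesization of A_(i+1)..A_(j+1)
// to one shared String instead of building a new String at each level with format!.
// `split` has the same layout as in matrix_chain_parens (0-based indices).
fn write_parens(split: &[Vec<usize>], i: usize, j: usize, out: &mut String) {
    if i == j {
        // Writing into a String cannot fail, so the fmt::Result is discarded.
        let _ = write!(out, "A{}", i + 1);
    } else {
        let k = split[i][j];
        out.push('(');
        write_parens(split, i, k, out);
        out.push('*');
        write_parens(split, k + 1, j, out);
        out.push(')');
    }
}

fn main() {
    // Split table for dims [10, 30, 5, 60], filled in by hand:
    // split[0][1] = 0, split[1][2] = 1, and split[0][2] = 1 because (AB)C is cheaper.
    let split: Vec<Vec<usize>> = vec![vec![0, 0, 1], vec![0, 0, 1], vec![0, 0, 0]];
    let mut out = String::new();
    write_parens(&split, 0, 2, &mut out);
    println!("{out}"); // ((A1*A2)*A3)
}
```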

    Comparison Table

    Aspect              OCaml                                   Rust
    String building     Buffer + Printf.sprintf                 format!() macro
    Infinity sentinel   max_int                                 usize::MAX
    2D table init       Array.init n (fun _ -> Array.make n 0)  vec![vec![0; n]; n]
    Split tracking      Parallel split array                    Parallel split vec
    Recursion           Natural OCaml recursion                 Inner fn with explicit params

    Exercises

  • The memoized top-down version (Approach 3) is already provided: write a test that checks it produces the same answer as the bottom-up version on many random dimension vectors.
  • Implement the reconstruction function format_chain(split: &Vec<Vec<usize>>, i: usize, j: usize, names: &[&str]) -> String that produces a parenthesized expression like "((A×B)×(C×D))".
  • Extend to weighted matrix chain where some multiplications have additional overhead (e.g., GPU memory transfer costs).