1057-matrix-chain — Matrix Chain Multiplication
Tutorial
The Problem
Matrix multiplication is associative, so (AB)C = A(BC), but the computational cost varies dramatically with parenthesization. Take a 10×30 matrix A, a 30×5 matrix B, and a 5×60 matrix C: (AB)C costs 10×30×5 + 10×5×60 = 1,500 + 3,000 = 4,500 scalar multiplications, while A(BC) costs 30×5×60 + 10×30×60 = 9,000 + 18,000 = 27,000, six times as many. For longer chains the gap between the best and worst ordering can be far larger.
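The arithmetic for the 10×30, 30×5, 5×60 example can be checked with a few lines of Rust. This is a minimal sketch: the helper names `mul_cost` and `example_costs` are assumptions introduced for illustration, and the count covers scalar multiplications only.

```rust
/// Scalar multiplications needed to multiply a (rows x inner) matrix
/// by an (inner x cols) matrix. Hypothetical helper for illustration.
fn mul_cost(rows: usize, inner: usize, cols: usize) -> usize {
    rows * inner * cols
}

/// Costs of the two parenthesizations of A (10x30) * B (30x5) * C (5x60),
/// returned as ((AB)C, A(BC)).
fn example_costs() -> (usize, usize) {
    // (AB)C: form AB (10x5) first, then multiply the result by C.
    let ab_then_c = mul_cost(10, 30, 5) + mul_cost(10, 5, 60);
    // A(BC): form BC (30x60) first, then multiply A by the result.
    let a_then_bc = mul_cost(30, 5, 60) + mul_cost(10, 30, 60);
    (ab_then_c, a_then_bc)
}
```

Calling `example_costs()` returns `(4500, 27000)`, matching the hand calculation.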
Matrix chain ordering is a classic interval DP problem and a fundamental optimization in scientific computing, neural network inference, and linear algebra libraries.
🎯 Learning Outcomes
dp[i][j] = minimum cost for matrices i..j
Code Example
#![allow(clippy::all)]
// 1057: Matrix Chain Multiplication, Optimal Parenthesization
// dims has n + 1 entries for a chain of n matrices; matrix i is dims[i] x dims[i + 1].
use std::collections::HashMap;

// Approach 1: Bottom-up DP
fn matrix_chain_dp(dims: &[usize]) -> usize {
    // Fewer than two dimensions means there is no matrix to multiply.
    if dims.len() < 2 {
        return 0;
    }
    let n = dims.len() - 1;
    // dp[i][j] = minimum cost to multiply matrices i..=j; dp[i][i] stays 0.
    let mut dp = vec![vec![0usize; n]; n];
    // l is the chain length, so shorter chains are solved before longer ones.
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = usize::MAX;
            // Try every split point k: (i..=k) times (k + 1..=j).
            for k in i..j {
                let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                dp[i][j] = dp[i][j].min(cost);
            }
        }
    }
    dp[0][n - 1]
}

// Approach 2: With parenthesization tracking
fn matrix_chain_parens(dims: &[usize]) -> (usize, String) {
    if dims.len() < 2 {
        return (0, String::new());
    }
    let n = dims.len() - 1;
    let mut dp = vec![vec![0usize; n]; n];
    // split[i][j] = the split point k that achieves dp[i][j].
    let mut split = vec![vec![0usize; n]; n];
    for l in 2..=n {
        for i in 0..=(n - l) {
            let j = i + l - 1;
            dp[i][j] = usize::MAX;
            for k in i..j {
                let cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                if cost < dp[i][j] {
                    dp[i][j] = cost;
                    split[i][j] = k;
                }
            }
        }
    }
    // Decode the split table into a string such as "((A1*A2)*A3)".
    fn build(i: usize, j: usize, split: &[Vec<usize>]) -> String {
        if i == j {
            format!("A{}", i + 1)
        } else {
            format!(
                "({}*{})",
                build(i, split[i][j], split),
                build(split[i][j] + 1, j, split)
            )
        }
    }
    (dp[0][n - 1], build(0, n - 1, &split))
}

// Approach 3: Recursive with memoization
fn matrix_chain_memo(dims: &[usize]) -> usize {
    fn solve(
        i: usize,
        j: usize,
        dims: &[usize],
        cache: &mut HashMap<(usize, usize), usize>,
    ) -> usize {
        // A single matrix needs no multiplication.
        if i == j {
            return 0;
        }
        if let Some(&v) = cache.get(&(i, j)) {
            return v;
        }
        let mut best = usize::MAX;
        for k in i..j {
            let cost = solve(i, k, dims, cache)
                + solve(k + 1, j, dims, cache)
                + dims[i] * dims[k + 1] * dims[j + 1];
            best = best.min(cost);
        }
        cache.insert((i, j), best);
        best
    }
    if dims.len() < 2 {
        return 0;
    }
    let mut cache = HashMap::new();
    solve(0, dims.len() - 2, dims, &mut cache)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matrix_chain_dp() {
        assert_eq!(matrix_chain_dp(&[30, 35, 15, 5, 10, 20, 25]), 15125);
        assert_eq!(matrix_chain_dp(&[10, 20, 30, 40]), 18000);
    }

    #[test]
    fn test_matrix_chain_parens() {
        let (cost, parens) = matrix_chain_parens(&[30, 35, 15, 5, 10, 20, 25]);
        assert_eq!(cost, 15125);
        assert!(!parens.is_empty());
    }

    #[test]
    fn test_matrix_chain_memo() {
        assert_eq!(matrix_chain_memo(&[30, 35, 15, 5, 10, 20, 25]), 15125);
        assert_eq!(matrix_chain_memo(&[10, 20, 30, 40]), 18000);
    }
}
Key Differences
usize::MAX vs max_int: Rust uses usize::MAX as infinity; OCaml uses max_int. Both can overflow if a sentinel or a very large cost is added to, so guard additions with a comparison or checked arithmetic.
Chain length l: the outer loop determines chain length before the inner loop tries split points.
Track a split table during DP and recursively decode it to produce a parenthesization string.
Libraries such as candle apply similar optimizations.
OCaml Approach
let matrix_chain dims =
  let n = Array.length dims - 1 in
  let dp = Array.make_matrix n n 0 in
  for l = 2 to n do
    for i = 0 to n - l do
      let j = i + l - 1 in
      dp.(i).(j) <- max_int;
      for k = i to j - 1 do
        let cost = dp.(i).(k) + dp.(k+1).(j) + dims.(i) * dims.(k+1) * dims.(j+1) in
        if cost < dp.(i).(j) then dp.(i).(j) <- cost
      done
    done
  done;
  dp.(0).(n-1)
The algorithm is identical. Interval DP is a mathematical technique with a canonical implementation structure.
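Returning to the overflow concern noted under Key Differences, Rust's checked arithmetic offers one way to avoid silent wraparound or a debug-build panic. The sketch below is an illustrative variant, not part of the original solution: the name `matrix_chain_checked` and the choice to return `Option<usize>` (with `None` for an overflowing result or for fewer than two dimensions) are assumptions.

```rust
/// Minimum scalar-multiplication cost for the chain described by `dims`,
/// or None if there is no matrix or the cost does not fit in usize.
fn matrix_chain_checked(dims: &[usize]) -> Option<usize> {
    // A chain of n matrices needs n + 1 dimensions.
    let n = dims.len().checked_sub(1)?;
    if n == 0 {
        return None;
    }
    // dp[i][j] = Some(min cost) for matrices i..=j, None if unrepresentable.
    let mut dp = vec![vec![Some(0usize); n]; n];
    for len in 2..=n {
        for i in 0..=(n - len) {
            let j = i + len - 1;
            let mut best: Option<usize> = None;
            for k in i..j {
                // Any overflowing step turns this candidate into None.
                let candidate = dp[i][k]
                    .zip(dp[k + 1][j])
                    .and_then(|(left, right)| left.checked_add(right))
                    .and_then(|sum| {
                        dims[i]
                            .checked_mul(dims[k + 1])
                            .and_then(|p| p.checked_mul(dims[j + 1]))
                            .and_then(|p| sum.checked_add(p))
                    });
                // Keep the smallest representable candidate seen so far.
                best = match (best, candidate) {
                    (Some(b), Some(c)) => Some(b.min(c)),
                    (b, c) => b.or(c),
                };
            }
            dp[i][j] = best;
        }
    }
    dp[0][n - 1]
}
```

Using `Option` in place of a `usize::MAX` sentinel means an unrepresentable cost can never be mistaken for a real one; the trade-off is slightly noisier code in the inner loop.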
Deep Comparison
Matrix Chain Multiplication — Comparison
Core Insight
Matrix chain multiplication is the canonical interval DP problem. The key is trying every split point k in range [i, j) and taking the minimum total cost. A separate split table enables reconstructing the optimal parenthesization.
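Written as a recurrence, under the notational assumption that matrix $i$ has dimensions $d_i \times d_{i+1}$ (matching the 0-indexed `dims` slice in the code), the split-point idea reads:

```latex
dp[i][j] =
\begin{cases}
0 & \text{if } i = j, \\
\min\limits_{i \le k < j} \bigl( dp[i][k] + dp[k+1][j] + d_i \, d_{k+1} \, d_{j+1} \bigr) & \text{if } i < j.
\end{cases}
```

For a chain of $n$ matrices the answer is $dp[0][n-1]$, and the minimizing $k$ for each pair $(i, j)$ is what the split table records.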
OCaml Approach
Buffer for building parenthesization string recursively
Printf.sprintf for formatting matrix names
max_int as initial sentinel
ref cells for tracking best in inner loop
Rust Approach
format! macro for string building in recursive parenthesization
usize::MAX as sentinel
HashMap with tuple keys for memoization
Comparison Table
| Aspect | OCaml | Rust |
|---|---|---|
| String building | Buffer + Printf.sprintf | format!() macro |
| Infinity sentinel | max_int | usize::MAX |
| 2D table init | Array.init n (fun _ -> Array.make n 0) | vec![vec![0; n]; n] |
| Split tracking | Parallel split array | Parallel split vec |
| Recursion | Natural OCaml recursion | Inner fn with explicit params |
Exercises
Implement format_chain(split: &Vec<Vec<usize>>, i: usize, j: usize, names: &[&str]) -> String that produces a parenthesized expression like "((A×B)×(C×D))".
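As one possible starting point for this exercise, the sketch below decodes a split table recursively. It assumes `split[i][j]` holds the split point k for matrices i..=j, as in the parenthesization-tracking approach; the `demo` function and its hand-built table are hypothetical and exist only to exercise the function.

```rust
/// Recursively decode `split` into a parenthesized product of names[i..=j].
fn format_chain(split: &Vec<Vec<usize>>, i: usize, j: usize, names: &[&str]) -> String {
    if i == j {
        // A single matrix is printed by name, without parentheses.
        names[i].to_string()
    } else {
        // Split at k into (i..=k) and (k + 1..=j), then format each half.
        let k = split[i][j];
        format!(
            "({}×{})",
            format_chain(split, i, k, names),
            format_chain(split, k + 1, j, names)
        )
    }
}

/// Hypothetical demo: a hand-built split table for four matrices that
/// splits the whole chain in the middle and each half at its only option.
fn demo() -> String {
    let mut split = vec![vec![0usize; 4]; 4];
    split[0][3] = 1; // (A B) | (C D)
    split[0][1] = 0; // A | B
    split[2][3] = 2; // C | D
    format_chain(&split, 0, 3, &["A", "B", "C", "D"])
}
```

With this table, `demo()` yields "((A×B)×(C×D))". In a full solution the table would come from the DP rather than being filled in by hand.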