449: Rayon Join — Fork-Join Parallelism
Tutorial
The Problem
Divide-and-conquer algorithms (merge sort, quicksort, parallel tree traversal) split a problem into two independent sub-problems that can be solved concurrently. The fork-join model captures this: join(f, g) runs both f and g, potentially in parallel, waiting for both to complete. rayon::join is the idiomatic way to express this in Rust — it automatically decides whether to run in parallel (if threads are available) or sequentially (if the thread pool is saturated), adapting to load.
Fork-join appears in parallel merge sort, Fibonacci computation, tree operations, and any divide-and-conquer algorithm where sub-problems are independent.
🎯 Learning Outcomes
- `thread::spawn` + `join` implements fork-join for `'static` data
- `thread::scope` implements fork-join for borrowed data
- `rayon::join`'s adaptive behavior (parallel or sequential based on pool state)

Code Example
fn join<A: Send + 'static, B>(
    f: impl FnOnce() -> A + Send + 'static,
    g: impl FnOnce() -> B,
) -> (A, B) {
    let handle = thread::spawn(f);
    let b = g();
    let a = handle.join().unwrap();
    (a, b)
}

Key Differences
- `rayon::join` runs sequentially when the thread pool is saturated; manual `thread::spawn` always creates a thread.
- `thread::scope` enables fork-join with borrowed slices; `rayon::join` on borrowed data requires `Sync` bounds.
- Spawning an OS thread costs on the order of 100 µs per `join` call; `rayon::join` reuses the existing pool with near-zero overhead.
- Nested `rayon::join` calls build a dynamic tree of tasks; naive recursive `thread::spawn` makes the thread count grow exponentially.

OCaml Approach
OCaml 5.x's Domain.spawn f + Domain.join h implements fork-join: let h = Domain.spawn f in let b = g () in let a = Domain.join h in (a, b). Domainslib.Task.async/await provide a higher-level composable version. For recursive divide-and-conquer, Domainslib.Task.parallel_for handles the recursion internally. OCaml 4.x's threads achieve fork-join but without parallelism.
Full Source
#![allow(clippy::all)]
//! # Rayon Join — Fork-Join Parallelism
//!
//! Execute two closures in parallel and wait for both results.
//! This is the core primitive for divide-and-conquer parallelism.
use std::thread;
/// Approach 1: Simple join - run two tasks in parallel
///
/// One task runs in a spawned thread, the other in the current thread.
pub fn join<A, B, FA, FB>(f: FA, g: FB) -> (A, B)
where
    A: Send + 'static,
    B: Send + 'static,
    FA: FnOnce() -> A + Send + 'static,
    FB: FnOnce() -> B + Send + 'static,
{
    let handle = thread::spawn(f);
    let b = g();
    let a = handle.join().unwrap();
    (a, b)
}
/// Approach 2: Scoped join for borrowed data
pub fn scoped_join<'a, A, B, FA, FB>(f: FA, g: FB) -> (A, B)
where
    A: Send,
    B: Send,
    FA: FnOnce() -> A + Send + 'a,
    FB: FnOnce() -> B + Send + 'a,
{
    thread::scope(|s| {
        let handle = s.spawn(f);
        let b = g();
        let a = handle.join().unwrap();
        (a, b)
    })
}
/// Merge two sorted vectors
fn merge(a: Vec<i64>, b: Vec<i64>) -> Vec<i64> {
    let mut out = Vec::with_capacity(a.len() + b.len());
    let (mut i, mut j) = (0, 0);
    while i < a.len() && j < b.len() {
        if a[i] <= b[j] {
            out.push(a[i]);
            i += 1;
        } else {
            out.push(b[j]);
            j += 1;
        }
    }
    out.extend_from_slice(&a[i..]);
    out.extend_from_slice(&b[j..]);
    out
}
/// Approach 3: Parallel merge sort using join
pub fn parallel_sort(mut v: Vec<i64>) -> Vec<i64> {
    // Base case: small arrays use sequential sort
    if v.len() <= 512 {
        v.sort();
        return v;
    }
    // Split and sort in parallel
    let right = v.split_off(v.len() / 2);
    let left = v;
    let (sorted_left, sorted_right) =
        join(move || parallel_sort(left), move || parallel_sort(right));
    merge(sorted_left, sorted_right)
}
/// Approach 4: Parallel sum using join
pub fn parallel_sum(data: &[i64]) -> i64 {
    const THRESHOLD: usize = 1000;
    if data.len() <= THRESHOLD {
        return data.iter().sum();
    }
    let mid = data.len() / 2;
    let (left, right) = data.split_at(mid);
    let (sum_left, sum_right) = scoped_join(|| parallel_sum(left), || parallel_sum(right));
    sum_left + sum_right
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_join_basic() {
        let (a, b) = join(|| 6 * 7, || "hello".len());
        assert_eq!(a, 42);
        assert_eq!(b, 5);
    }

    #[test]
    fn test_join_independent_computations() {
        let (sum1, sum2) = join(
            || (1u64..=5000).sum::<u64>(),
            || (5001u64..=10000).sum::<u64>(),
        );
        assert_eq!(sum1 + sum2, 50005000);
    }

    #[test]
    fn test_scoped_join() {
        let data = vec![1, 2, 3, 4, 5];
        let (sum, len) = scoped_join(|| data.iter().sum::<i32>(), || data.len());
        assert_eq!(sum, 15);
        assert_eq!(len, 5);
    }

    #[test]
    fn test_parallel_sort_small() {
        let data = vec![5i64, 3, 8, 1, 9, 2, 7, 4, 6];
        let sorted = parallel_sort(data);
        assert_eq!(sorted, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[test]
    fn test_parallel_sort_large() {
        let data: Vec<i64> = (0..2000).rev().collect();
        let sorted = parallel_sort(data);
        let expected: Vec<i64> = (0..2000).collect();
        assert_eq!(sorted, expected);
    }

    #[test]
    fn test_parallel_sum_small() {
        let data: Vec<i64> = (1..=100).collect();
        assert_eq!(parallel_sum(&data), 5050);
    }

    #[test]
    fn test_parallel_sum_large() {
        let data: Vec<i64> = (1..=10000).collect();
        assert_eq!(parallel_sum(&data), 50005000);
    }

    #[test]
    fn test_merge() {
        let a = vec![1, 3, 5, 7];
        let b = vec![2, 4, 6, 8];
        assert_eq!(merge(a, b), vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }
}
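The manual `join` above spawns one OS thread per call, so deep recursion multiplies threads. A common mitigation, sketched here with the standard library only (the `capped_sum` name and the `MAX_DEPTH` cutoff are my own illustrative choices, not part of the example source), is to stop spawning past a fixed recursion depth:

```rust
use std::thread;

// Illustrative sketch: cap recursive spawning so at most 2^MAX_DEPTH leaf
// tasks get their own threads; deeper recursion runs sequentially.
// rayon::join achieves a similar bound automatically via its worker pool.
const MAX_DEPTH: u32 = 3;

fn capped_sum(data: &[u64], depth: u32) -> u64 {
    if data.len() <= 4 || depth >= MAX_DEPTH {
        return data.iter().sum();
    }
    let (left, right) = data.split_at(data.len() / 2);
    thread::scope(|s| {
        // Fork: left half on a spawned thread, right half on the current one.
        let handle = s.spawn(|| capped_sum(left, depth + 1));
        let r = capped_sum(right, depth + 1);
        // Join: wait for the spawned half and combine.
        handle.join().unwrap() + r
    })
}

fn main() {
    let data: Vec<u64> = (1..=1000).collect();
    println!("{}", capped_sum(&data, 0)); // 500500
}
```

With `MAX_DEPTH = 3` the tree spawns at most 7 extra threads regardless of input size, which keeps the ~100 µs per-spawn cost bounded.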
Deep Comparison
OCaml vs Rust: Join (Fork-Join Parallelism)
Basic Join Pattern
OCaml
let join f g =
  let result_g = ref None in
  let thread = Thread.create (fun () ->
    result_g := Some (g ())
  ) () in
  let result_f = f () in
  Thread.join thread;
  (result_f, Option.get !result_g)
Rust
fn join<A: Send + 'static, B>(
    f: impl FnOnce() -> A + Send + 'static,
    g: impl FnOnce() -> B,
) -> (A, B) {
    let handle = thread::spawn(f);
    let b = g();
    let a = handle.join().unwrap();
    (a, b)
}
Key Differences
| Feature | OCaml | Rust |
|---|---|---|
| Return value | Via ref cell | Direct from join() |
| Type safety | Option.get can fail | Compile-time guaranteed |
| Thread spawning | Thread.create f () | thread::spawn(closure) |
| Result extraction | Manual unwrap | Built into handle |
Parallel Sum Example
OCaml
let rec psum arr lo hi =
  if hi - lo <= 500 then
    Array.fold_left (+) 0 (Array.sub arr lo (hi - lo))
  else
    let mid = (lo + hi) / 2 in
    let (l, r) = join
      (fun () -> psum arr lo mid)
      (fun () -> psum arr mid hi)
    in
    l + r
Rust
fn parallel_sum(data: &[i64]) -> i64 {
    if data.len() <= 1000 {
        return data.iter().sum();
    }
    let mid = data.len() / 2;
    let (left, right) = data.split_at(mid);
    let (sum_l, sum_r) = scoped_join(
        || parallel_sum(left),
        || parallel_sum(right),
    );
    sum_l + sum_r
}
Parallel Merge Sort
Rust
fn parallel_sort(mut v: Vec<i64>) -> Vec<i64> {
    if v.len() <= 512 {
        v.sort();
        return v;
    }
    let right = v.split_off(v.len() / 2);
    let left = v;
    let (sorted_l, sorted_r) = join(
        move || parallel_sort(left),
        move || parallel_sort(right),
    );
    merge(sorted_l, sorted_r)
}
With Rayon (One-liner)
// rayon::join lives at the crate root, so no prelude import is needed;
// work-stealing decides whether the closures actually run in parallel.
let (a, b) = rayon::join(
    || expensive_computation_a(),
    || expensive_computation_b(),
);
Exercises

1. Implement a parallel sort built on the `join` function. Set a threshold below which sequential sort is used. Benchmark against `slice::sort()` for arrays of 1M, 10M, and 100M elements.
2. Implement `parallel_fold(tree, identity, combine)` using fork-join: fold the left subtree in one branch and the right subtree in another, then combine the results. Test with a balanced tree of 65535 nodes.
3. Implement a parallel search over a `Vec`. Split at the midpoint, search left and right in parallel, and return the first match found. Compare with sequential binary search.
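A possible starting point for the search exercise is sketched below (the name `parallel_find` and the 256-element sequential cutoff are my own choices); `thread::scope` lets both halves borrow the slice without `'static` bounds:

```rust
use std::thread;

// Sketch: fork-join search over a slice. Returns the index of the first
// match, preferring the left half when both halves contain the target.
fn parallel_find(data: &[i64], target: i64) -> Option<usize> {
    // Sequential cutoff: below this size, a plain linear scan wins.
    if data.len() <= 256 {
        return data.iter().position(|&x| x == target);
    }
    let mid = data.len() / 2;
    let (left, right) = data.split_at(mid);
    let (l, r) = thread::scope(|s| {
        // Fork: search the left half on a spawned thread.
        let handle = s.spawn(|| parallel_find(left, target));
        let r = parallel_find(right, target);
        (handle.join().unwrap(), r)
    });
    // Prefer a left-half hit; offset right-half indices by `mid`.
    l.or(r.map(|i| i + mid))
}

fn main() {
    let v: Vec<i64> = (0..1000).collect();
    println!("{:?}", parallel_find(&v, 500)); // Some(500)
}
```

Note that spawning a thread per split only pays off when each half does substantial work; compare wall-clock time against a sequential `binary_search` on sorted input to see where the crossover lies.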