965 Advanced

965 Segment Tree

Functional Programming

Tutorial

The Problem

Implement a segment tree for range sum queries with point updates. The tree is stored in a flat array of size 4 * n. Each internal node stores the sum of its range. query(l, r) returns the sum over index range [l, r] in O(log n). update(pos, value) replaces the value at pos and propagates the change upward in O(log n).
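For contrast, the same interface can be met without any tree at all. The sketch below is a naive baseline, not part of the lesson's solution: `NaiveRangeSum` is a name made up for this example. It updates in O(1) but scans the range on every query, which is the O(n) cost the segment tree is meant to avoid. It can also serve as a reference when checking a tree implementation against random inputs.

```rust
/// Naive baseline: O(1) point update, O(n) range sum.
/// `NaiveRangeSum` is a hypothetical name used only for this sketch.
pub struct NaiveRangeSum {
    values: Vec<i64>,
}

impl NaiveRangeSum {
    /// Copy the input array.
    pub fn new(arr: &[i64]) -> Self {
        NaiveRangeSum { values: arr.to_vec() }
    }

    /// Replace the value at `pos`.
    pub fn update(&mut self, pos: usize, value: i64) {
        self.values[pos] = value;
    }

    /// Sum over the inclusive, 0-indexed range [l, r] by scanning it.
    pub fn query(&self, l: usize, r: usize) -> i64 {
        self.values[l..=r].iter().sum()
    }
}
```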

🎯 Learning Outcomes

  • Represent the segment tree in a flat Vec<i64> with 1-indexed nodes: the children of node i are 2*i and 2*i+1
  • Implement build recursively: leaf nodes store array values; internal nodes store the sum of their two children
  • Implement query(l, r) with range splitting: return 0 when a node's range misses the query, return the stored sum when the query fully covers the node's range, and otherwise recurse into both children
  • Implement update(pos, value) that replaces the leaf and recomputes the sums on the path back up to the root
  • Understand why the flat array needs 4 * n entries (not 2 * n): when n is not a power of two, the midpoint split leaves gaps in the node numbering, so the largest index used can exceed 2 * n but always stays below 4 * n
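The 4 * n bound in the last point can be checked empirically. The helper below is a sketch written for this purpose (the names `max_node_index` and `max_index_for_len` are not from the lesson's code): it walks the same midpoint split as the recursive build and reports the largest node index reached. For n = 6 the deepest leaf lands at index 13, which would not fit in an array of 2 * n = 12 slots.

```rust
/// Largest node index touched when recursing over `lo..=hi` from `node`,
/// using the same midpoint split as the recursive build.
fn max_node_index(node: usize, lo: usize, hi: usize) -> usize {
    if lo == hi {
        // Leaf: this node is the deepest one on this path.
        node
    } else {
        let mid = (lo + hi) / 2;
        let left = max_node_index(2 * node, lo, mid);
        let right = max_node_index(2 * node + 1, mid + 1, hi);
        left.max(right)
    }
}

/// Convenience wrapper for an array of length `n` (requires n >= 1).
fn max_index_for_len(n: usize) -> usize {
    max_node_index(1, 0, n - 1)
}
```

Calling `max_index_for_len(6)` gives 13, so a tree over six elements needs at least 14 slots. Allocating 4 * n is a simple upper bound that holds for every n.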

Code Example

    #![allow(clippy::all)]
    // 965: Segment Tree for Range Sum Queries
    // 1-indexed internal nodes; O(log n) point update and range sum
    
    pub struct SegmentTree {
        n: usize,
        tree: Vec<i64>,
    }
    
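    impl SegmentTree {
        /// Allocate an all-zero tree for `n` elements (`n` must be at least 1).
        pub fn new(n: usize) -> Self {
            assert!(n > 0, "SegmentTree requires at least one element");
            SegmentTree {
                n,
                tree: vec![0i64; 4 * n],
            }
        }
    
        /// Fill the tree from `arr`, which must have exactly `n` elements.
        pub fn build(&mut self, arr: &[i64]) {
            assert_eq!(arr.len(), self.n, "input length must match tree size");
            self.build_rec(1, 0, self.n - 1, arr);
        }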
    
        fn build_rec(&mut self, node: usize, lo: usize, hi: usize, arr: &[i64]) {
            if lo == hi {
                self.tree[node] = arr[lo];
            } else {
                let mid = (lo + hi) / 2;
                self.build_rec(2 * node, lo, mid, arr);
                self.build_rec(2 * node + 1, mid + 1, hi, arr);
                self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1];
            }
        }
    
        /// Point update: set position `pos` to `value`
        pub fn update(&mut self, pos: usize, value: i64) {
            self.update_rec(1, 0, self.n - 1, pos, value);
        }
    
        fn update_rec(&mut self, node: usize, lo: usize, hi: usize, pos: usize, value: i64) {
            if lo == hi {
                self.tree[node] = value;
            } else {
                let mid = (lo + hi) / 2;
                if pos <= mid {
                    self.update_rec(2 * node, lo, mid, pos, value);
                } else {
                    self.update_rec(2 * node + 1, mid + 1, hi, pos, value);
                }
                self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1];
            }
        }
    
        /// Range sum query [l, r] (inclusive, 0-indexed)
        pub fn query(&self, l: usize, r: usize) -> i64 {
            self.query_rec(1, 0, self.n - 1, l, r)
        }
    
        fn query_rec(&self, node: usize, lo: usize, hi: usize, l: usize, r: usize) -> i64 {
            if r < lo || hi < l {
                0
            } else if l <= lo && hi <= r {
                self.tree[node]
            } else {
                let mid = (lo + hi) / 2;
                self.query_rec(2 * node, lo, mid, l, r)
                    + self.query_rec(2 * node + 1, mid + 1, hi, l, r)
            }
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        fn make_tree() -> SegmentTree {
            let arr = vec![1i64, 3, 5, 7, 9, 11];
            let mut st = SegmentTree::new(arr.len());
            st.build(&arr);
            st
        }
    
        #[test]
        fn test_total_sum() {
            let st = make_tree();
            assert_eq!(st.query(0, 5), 36);
        }
    
        #[test]
        fn test_range_queries() {
            let st = make_tree();
            assert_eq!(st.query(0, 2), 9); // 1+3+5
            assert_eq!(st.query(2, 4), 21); // 5+7+9
            assert_eq!(st.query(1, 3), 15); // 3+5+7
            assert_eq!(st.query(5, 5), 11); // single element
        }
    
        #[test]
        fn test_point_update() {
            let mut st = make_tree();
            st.update(2, 10); // replace 5 with 10
            assert_eq!(st.query(0, 5), 41); // 36 - 5 + 10
            assert_eq!(st.query(0, 2), 14); // 1+3+10
            assert_eq!(st.query(2, 4), 26); // 10+7+9
        }
    
        #[test]
        fn test_single_element() {
            let arr = vec![42i64];
            let mut st = SegmentTree::new(1);
            st.build(&arr);
            assert_eq!(st.query(0, 0), 42);
            st.update(0, 100);
            assert_eq!(st.query(0, 0), 100);
        }
    
        #[test]
        fn test_multiple_updates() {
            let mut st = make_tree();
            st.update(0, 0);
            st.update(5, 0);
            assert_eq!(st.query(0, 5), 24); // 0+3+5+7+9+0
        }
    }

    Key Differences

    Aspect            | Rust                                 | OCaml
    Array indexing    | self.tree[i]                         | st.tree.(i)
    Recursive methods | &mut self (unique mutable reference) | Mutable array held in a record
    Array size        | 4 * n                                | Same
    Identity value    | 0 (hardcoded for sum)                | Same

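    Segment trees support any associative operation that has an identity element (sum, min, max, GCD). To switch operations, replace the + in build, query, and update with the desired operation, and replace the 0 that query returns for non-overlapping ranges with that operation's identity (for example, i64::MAX for min). The structure also extends to lazy propagation for range updates.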
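As a rough illustration of that substitution, the sketch below stores the combining function and its identity in the struct. The names `MonoidTree`, `from_slice`, `combine`, and `identity` are assumptions made for this example and do not appear in the lesson's code. The recursion is otherwise the same as in the sum version.

```rust
/// Segment tree parameterised by an associative `combine` and its `identity`.
/// All names here are hypothetical; this is a sketch, not the lesson's API.
pub struct MonoidTree {
    n: usize,
    tree: Vec<i64>,
    combine: fn(i64, i64) -> i64,
    identity: i64,
}

impl MonoidTree {
    /// Build from a non-empty slice.
    pub fn from_slice(arr: &[i64], combine: fn(i64, i64) -> i64, identity: i64) -> Self {
        assert!(!arr.is_empty(), "segment tree needs at least one element");
        let n = arr.len();
        let mut st = MonoidTree { n, tree: vec![identity; 4 * n], combine, identity };
        st.build_rec(1, 0, n - 1, arr);
        st
    }

    fn build_rec(&mut self, node: usize, lo: usize, hi: usize, arr: &[i64]) {
        if lo == hi {
            self.tree[node] = arr[lo];
        } else {
            let mid = (lo + hi) / 2;
            self.build_rec(2 * node, lo, mid, arr);
            self.build_rec(2 * node + 1, mid + 1, hi, arr);
            self.tree[node] = (self.combine)(self.tree[2 * node], self.tree[2 * node + 1]);
        }
    }

    /// Point update: set position `pos` to `value`.
    pub fn update(&mut self, pos: usize, value: i64) {
        self.update_rec(1, 0, self.n - 1, pos, value);
    }

    fn update_rec(&mut self, node: usize, lo: usize, hi: usize, pos: usize, value: i64) {
        if lo == hi {
            self.tree[node] = value;
        } else {
            let mid = (lo + hi) / 2;
            if pos <= mid {
                self.update_rec(2 * node, lo, mid, pos, value);
            } else {
                self.update_rec(2 * node + 1, mid + 1, hi, pos, value);
            }
            self.tree[node] = (self.combine)(self.tree[2 * node], self.tree[2 * node + 1]);
        }
    }

    /// Range query over [l, r] (inclusive, 0-indexed).
    pub fn query(&self, l: usize, r: usize) -> i64 {
        self.query_rec(1, 0, self.n - 1, l, r)
    }

    fn query_rec(&self, node: usize, lo: usize, hi: usize, l: usize, r: usize) -> i64 {
        if r < lo || hi < l {
            // No overlap: contribute the identity instead of a hardcoded 0.
            self.identity
        } else if l <= lo && hi <= r {
            self.tree[node]
        } else {
            let mid = (lo + hi) / 2;
            let left = self.query_rec(2 * node, lo, mid, l, r);
            let right = self.query_rec(2 * node + 1, mid + 1, hi, l, r);
            (self.combine)(left, right)
        }
    }
}
```

A range-minimum tree would then be created with something like `MonoidTree::from_slice(&arr, |a: i64, b: i64| a.min(b), i64::MAX)`. A function pointer keeps the sketch short; a generic `F: Fn(i64, i64) -> i64` parameter would be the alternative if capturing closures are needed.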

    OCaml Approach

    type segment_tree = {
      n: int;
      tree: int array;
    }
    
    let create n = { n; tree = Array.make (4 * n) 0 }
    
    let rec build_rec st node lo hi arr =
      if lo = hi then st.tree.(node) <- arr.(lo)
      else
        let mid = (lo + hi) / 2 in
        build_rec st (2 * node) lo mid arr;
        build_rec st (2 * node + 1) (mid + 1) hi arr;
        st.tree.(node) <- st.tree.(2 * node) + st.tree.(2 * node + 1)
    
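    (* Point update: overwrite the leaf for pos, then recompute sums on the way up *)
    let rec update_rec st node lo hi pos value =
      if lo = hi then st.tree.(node) <- value
      else begin
        let mid = (lo + hi) / 2 in
        if pos <= mid then update_rec st (2 * node) lo mid pos value
        else update_rec st (2 * node + 1) (mid + 1) hi pos value;
        st.tree.(node) <- st.tree.(2 * node) + st.tree.(2 * node + 1)
      end
    
    (* Range sum over [l, r], inclusive and 0-indexed *)
    let rec query_rec st node lo hi l r =
      if r < lo || hi < l then 0
      else if l <= lo && hi <= r then st.tree.(node)
      else
        let mid = (lo + hi) / 2 in
        query_rec st (2 * node) lo mid l r +
        query_rec st (2 * node + 1) (mid + 1) hi l r
    
    (* Public wrappers: start every traversal at root node 1 over [0, n-1] *)
    let st_build st arr = build_rec st 1 0 (st.n - 1) arr
    let st_update st pos value = update_rec st 1 0 (st.n - 1) pos value
    let st_query st l r = query_rec st 1 0 (st.n - 1) l r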
    

    OCaml's in-place array assignment st.tree.(node) <- value corresponds directly to Rust's self.tree[node] = value. The algorithm is structurally identical in both languages; the main syntactic difference is .(i) versus [i] for array indexing.


    Deep Comparison

    Segment Tree — Comparison

    Core Insight

    A segment tree stores aggregate values (sums, min, max) for array ranges in a complete binary tree laid out in an array. Node i covers a range; its children at 2i and 2i+1 cover the halves. Both OCaml and Rust implement identical recursive build/update/query — the tree layout and algorithm are language-agnostic.
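To make the range halving concrete, the sketch below records which node ranges a query consumes whole. It mirrors the three cases of the recursive query (miss, full cover, split) but collects ranges instead of summing values. The function names `covered_segments` and `decompose` are made up for this illustration.

```rust
/// Collect the node ranges that a query [l, r] consumes whole,
/// following the same three-way case split as the recursive query.
fn covered_segments(lo: usize, hi: usize, l: usize, r: usize, out: &mut Vec<(usize, usize)>) {
    if r < lo || hi < l {
        // No overlap: this node contributes nothing.
    } else if l <= lo && hi <= r {
        // Fully covered: the stored aggregate for [lo, hi] is used as is.
        out.push((lo, hi));
    } else {
        // Partial overlap: split at the midpoint and visit both halves.
        let mid = (lo + hi) / 2;
        covered_segments(lo, mid, l, r, out);
        covered_segments(mid + 1, hi, l, r, out);
    }
}

/// Decompose [l, r] over an array of length `n` (requires n >= 1 and l <= r < n).
fn decompose(n: usize, l: usize, r: usize) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    covered_segments(0, n - 1, l, r, &mut out);
    out
}
```

For n = 6, `decompose(6, 1, 4)` yields `[(1, 1), (2, 2), (3, 4)]`: three stored aggregates are combined instead of four individual elements, and the number of such pieces grows with log n rather than with the range length.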

    OCaml Approach

  • Array.make (4 * n) 0: 4n slots are enough for any n
  • Recursive build, update, query as top-level functions
  • Public wrappers st_build, st_update, st_query start at node 1
  • st.tree.(node) <- st.tree.(2*node) + st.tree.(2*node+1) as the push-up step
  • Early returns: 0 for out-of-range, st.tree.(node) for fully covered

    Rust Approach

  • vec![0i64; 4 * n]: same layout
  • Private build_rec, update_rec, query_rec methods on the struct
  • Public build, update, query as a clean API
  • self.tree[node] = self.tree[2*node] + self.tree[2*node+1] as the push-up step
  • usize indices (never negative, which avoids signed/unsigned confusion)

    Comparison Table

    Aspect      | OCaml                               | Rust
    Storage     | int array (4n)                      | Vec<i64> (4n)
    Recursion   | Top-level functions with st arg     | Private _rec methods
    Index type  | int                                 | usize
    Push-up     | tree.(n) <- tree.(2n) + tree.(2n+1) | tree[n] = tree[2n] + tree[2n+1]
    Range miss  | 0                                   | 0
    Range cover | tree.(node)                         | tree[node]
    Split point | mid = (lo + hi) / 2                 | mid = (lo + hi) / 2
    Build init  | Array.make (4*n) 0                  | vec![0i64; 4*n]

    Exercises

  • Implement update(&mut self, pos: usize, value: i64) that replaces the value at pos and propagates.
  • Extend to support range minimum queries by replacing + with min in build and query.
  • Implement lazy propagation: range_add(l, r, delta) that adds delta to all elements in [l, r] in O(log n).
  • Generalize with a fold_fn: Fn(i64, i64) -> i64 and identity: i64 parameter to support arbitrary monoids.
  • Implement a persistent segment tree where each update returns a new root (for historical queries).