366 Advanced

366: Segment Tree

Functional Programming


Tutorial

The Problem

Database range aggregation, stock price range queries, and competitive programming problems frequently ask: "What is the sum/min/max over elements from index L to R?" With a plain array, each such query costs O(n). Sorting allows binary search for range bounds but not arbitrary aggregates. The segment tree (developed independently in multiple contexts and later popularized in competitive programming) achieves O(log n) for both range queries and point updates, after an O(n) build, by maintaining a binary tree in which each node stores the aggregate of a contiguous subarray. The same idea of precomputed aggregates over nested ranges underlies range aggregation in some databases and time-series systems.
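To make the trade-off concrete, here is a minimal sketch of the two obvious baselines (the names `naive_range_sum` and `PrefixSums` are illustrative, not part of the example): a plain array answers range queries in O(n) with O(1) updates, while a prefix-sum array answers them in O(1) but pays O(n) per update. A segment tree sits between the two at O(log n) for both.

```rust
/// Baseline 1: plain array scan. Point updates are O(1), but each
/// range query touches every element in [l, r]: O(n).
fn naive_range_sum(arr: &[i64], l: usize, r: usize) -> i64 {
    arr[l..=r].iter().sum()
}

/// Baseline 2: prefix sums. Range queries are O(1), but a point
/// update must rewrite every prefix that includes it: O(n).
struct PrefixSums {
    // prefix[i] == arr[0] + ... + arr[i - 1], with prefix[0] == 0
    prefix: Vec<i64>,
}

impl PrefixSums {
    fn new(arr: &[i64]) -> Self {
        let mut prefix = vec![0; arr.len() + 1];
        for (i, &x) in arr.iter().enumerate() {
            prefix[i + 1] = prefix[i] + x;
        }
        PrefixSums { prefix }
    }

    /// Inclusive range sum as the difference of two prefixes.
    fn range_sum(&self, l: usize, r: usize) -> i64 {
        self.prefix[r + 1] - self.prefix[l]
    }

    /// Adding `delta` at `pos` shifts every later prefix.
    fn add(&mut self, pos: usize, delta: i64) {
        for p in &mut self.prefix[pos + 1..] {
            *p += delta;
        }
    }
}
```

For `[1, 2, 3, 4, 5]`, both `naive_range_sum(&arr, 1, 3)` and `PrefixSums::new(&arr).range_sum(1, 3)` return 9. Prefix differences also rely on subtraction, so they do not extend to min or max.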

🎯 Learning Outcomes

  • Build a segment tree stored in a flat array with the 1-indexed convention (left child at 2v, right child at 2v+1)
  • Understand the heap-like layout: root at index 1, parent of v at v/2; with the recursive build the leaves are not necessarily contiguous and indices can reach nearly 4n, so the array is allocated with 4n slots
  • Implement the recursive build from a base array in O(n)
  • Implement the recursive range query query(ql, qr) in O(log n)
  • Implement point updates set(pos, val) and add(pos, delta), recomputing the ancestors of the changed leaf in O(log n)
  • Generalize the aggregate operation (sum, min, max, GCD) by changing the merge function

Code Example

    fn build(&mut self, arr: &[i64], v: usize, l: usize, r: usize) {
        if l == r {
            self.data[v] = arr[l];
            return;
        }
        let m = (l + r) / 2;
        self.build(arr, 2 * v, l, m);
        self.build(arr, 2 * v + 1, m + 1, r);
        self.data[v] = self.data[2 * v] + self.data[2 * v + 1];
    }
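The recursive build does not pack the leaves into one contiguous block; where each leaf lands depends on how the midpoint splits fall. A small sketch that replays the same recursion and records each leaf's index makes this visible (`leaf_indices` is a hypothetical helper, not part of the example's API):

```rust
/// Hypothetical helper: for each array position i, returns the `data`
/// index of its leaf under the same recursive split as `build`
/// (children of node v at 2v and 2v + 1, root at 1).
fn leaf_indices(n: usize) -> Vec<usize> {
    fn go(v: usize, l: usize, r: usize, out: &mut [usize]) {
        if l == r {
            out[l] = v; // node v is the leaf for position l
            return;
        }
        let m = (l + r) / 2;
        go(2 * v, l, m, out);
        go(2 * v + 1, m + 1, r, out);
    }
    let mut out = vec![0; n];
    if n > 0 {
        go(1, 0, n - 1, &mut out);
    }
    out
}
```

For n = 6 it returns `[8, 9, 5, 12, 13, 7]`: the leaves are scattered, and index 13 already exceeds 2n = 12, which is why `new` allocates 4n slots.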

    Key Differences

    | Aspect | Rust `SegmentTree` | OCaml segment tree |
    |---|---|---|
    | Storage | `Vec<i64>` (heap) | `int array` |
    | Mutability | `&mut self` for update | Array mutation |
    | Aggregate | Hardcoded `+` (extensible via generics) | Hardcoded `+` |
    | Lazy propagation | Requires a `lazy: Vec<i64>` array | Same pattern |
    | Alternative | Fenwick tree for prefix sums only | Same |
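The table names a Fenwick (binary indexed) tree as the alternative when only sums are needed. A minimal sketch, assuming i64 values and 0-based positions (`Fenwick` and its methods are illustrative names, not part of the example):

```rust
/// Minimal Fenwick tree for i64 sums; internally 1-indexed.
struct Fenwick {
    tree: Vec<i64>, // tree[0] is unused
}

impl Fenwick {
    fn new(n: usize) -> Self {
        Fenwick { tree: vec![0; n + 1] }
    }

    /// Add `delta` at 0-based position `pos`: O(log n).
    fn add(&mut self, pos: usize, delta: i64) {
        let mut i = pos + 1;
        while i < self.tree.len() {
            self.tree[i] += delta;
            i += i & i.wrapping_neg(); // move to the next node covering i
        }
    }

    /// Sum of positions 0..count (exclusive upper bound): O(log n).
    fn prefix_sum(&self, count: usize) -> i64 {
        let mut i = count;
        let mut s = 0;
        while i > 0 {
            s += self.tree[i];
            i -= i & i.wrapping_neg(); // drop the lowest set bit
        }
        s
    }

    /// Inclusive range sum [l, r] as a difference of two prefix sums.
    fn range_sum(&self, l: usize, r: usize) -> i64 {
        self.prefix_sum(r + 1) - self.prefix_sum(l)
    }
}
```

It needs only n + 1 slots instead of 4n, but because a range sum is computed as a difference of prefix sums, it fits invertible aggregates such as `+`; arbitrary range min or max queries still call for a segment tree.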

    OCaml Approach

    let build arr =
      let n = Array.length arr in
      let data = Array.make (4 * n) 0 in
      (* go v l r fills node v, which covers arr.(l) .. arr.(r) *)
      let rec go v l r =
        if l = r then data.(v) <- arr.(l)
        else begin
          let m = (l + r) / 2 in
          go (2*v) l m; go (2*v+1) (m+1) r;
          data.(v) <- data.(2*v) + data.(2*v+1)
        end
      in
      (* guard: an empty array has no root segment to build *)
      if n > 0 then go 1 0 (n-1);
      data
    
    let query data n ql qr =
      (* sum over [ql, qr]; 0 is the identity for disjoint segments *)
      let rec go v l r =
        if ql > r || qr < l then 0
        else if ql <= l && r <= qr then data.(v)
        else let m = (l + r) / 2 in
          go (2*v) l m + go (2*v+1) (m+1) r
      in
      if n = 0 then 0 else go 1 0 (n-1)
    

    The algorithm is identical: OCaml's recursive functions mirror Rust's recursive methods, and both use a mutable array with 4n slots for the tree storage.
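The OCaml version hardcodes `+`, and so does `SegmentTree`. A sketch of the same free-function style in Rust, with the aggregate passed in as a hypothetical `Monoid` value (an identity plus a merge function; `build_with` and `query_with` are also illustrative names), shows that changing only the merge function turns the sum tree into a GCD tree:

```rust
/// Hypothetical pairing of an aggregate with its identity element.
#[derive(Clone, Copy)]
struct Monoid {
    identity: i64,              // returned for disjoint segments
    merge: fn(i64, i64) -> i64, // must be associative
}

/// Free-function build, mirroring OCaml's `build arr`; returns the tree array.
fn build_with(arr: &[i64], op: Monoid) -> Vec<i64> {
    fn go(arr: &[i64], data: &mut [i64], op: Monoid, v: usize, l: usize, r: usize) {
        if l == r {
            data[v] = arr[l];
            return;
        }
        let m = (l + r) / 2;
        go(arr, data, op, 2 * v, l, m);
        go(arr, data, op, 2 * v + 1, m + 1, r);
        data[v] = (op.merge)(data[2 * v], data[2 * v + 1]);
    }
    let n = arr.len();
    let mut data = vec![op.identity; 4 * n];
    if n > 0 {
        go(arr, &mut data, op, 1, 0, n - 1);
    }
    data
}

/// Free-function query over [ql, qr], mirroring OCaml's `query data n ql qr`.
fn query_with(data: &[i64], n: usize, op: Monoid, ql: usize, qr: usize) -> i64 {
    fn go(data: &[i64], op: Monoid, v: usize, l: usize, r: usize, ql: usize, qr: usize) -> i64 {
        if qr < l || r < ql {
            return op.identity; // disjoint segment
        }
        if ql <= l && r <= qr {
            return data[v]; // segment fully inside the query
        }
        let m = (l + r) / 2;
        (op.merge)(
            go(data, op, 2 * v, l, m, ql, qr),
            go(data, op, 2 * v + 1, m + 1, r, ql, qr),
        )
    }
    if n == 0 {
        return op.identity;
    }
    go(data, op, 1, 0, n - 1, ql, qr)
}

/// Euclid's algorithm; gcd(x, 0) == |x|, so 0 is the identity.
fn gcd(a: i64, b: i64) -> i64 {
    if b == 0 { a.abs() } else { gcd(b, a % b) }
}
```

Unlike OCaml's `let rec go`, a nested Rust `fn` cannot capture local variables, so `arr`, `data`, and the monoid are passed explicitly. With `Monoid { identity: 0, merge: gcd }` over `[12, 18, 24, 9, 30]`, the range [0, 2] yields 6.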

    Full Source

    #![allow(clippy::all)]
    //! Segment Tree for Range Queries
    //!
    //! O(log n) range queries and point updates.
    
    /// A segment tree for range sum queries
    pub struct SegmentTree {
        data: Vec<i64>,
        n: usize,
    }
    
    impl SegmentTree {
        // === Approach 1: Build from array ===
    
        /// Build a segment tree from an array
        pub fn new(arr: &[i64]) -> Self {
            let n = arr.len();
            let mut st = Self {
                data: vec![0; 4 * n],
                n,
            };
            if n > 0 {
                st.build(arr, 1, 0, n - 1);
            }
            st
        }
    
        fn build(&mut self, arr: &[i64], v: usize, l: usize, r: usize) {
            if l == r {
                self.data[v] = arr[l];
                return;
            }
            let m = (l + r) / 2;
            self.build(arr, 2 * v, l, m);
            self.build(arr, 2 * v + 1, m + 1, r);
            self.data[v] = self.data[2 * v] + self.data[2 * v + 1];
        }
    
        // === Approach 2: Range queries ===
    
        /// Query sum in range [ql, qr] - O(log n)
        pub fn query(&self, ql: usize, qr: usize) -> i64 {
            if self.n == 0 {
                return 0;
            }
            self.query_internal(1, 0, self.n - 1, ql, qr)
        }
    
        fn query_internal(&self, v: usize, l: usize, r: usize, ql: usize, qr: usize) -> i64 {
            if qr < l || r < ql {
                return 0;
            }
            if ql <= l && r <= qr {
                return self.data[v];
            }
            let m = (l + r) / 2;
            self.query_internal(2 * v, l, m, ql, qr) + self.query_internal(2 * v + 1, m + 1, r, ql, qr)
        }
    
        /// Alias for query - sum in range [l, r]
        pub fn sum(&self, l: usize, r: usize) -> i64 {
            self.query(l, r)
        }
    
        // === Approach 3: Point updates ===
    
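        /// Set value at position pos - O(log n)
        ///
        /// Panics if `pos >= len()`; without this check an out-of-range
        /// `pos` would silently overwrite the last leaf.
        pub fn set(&mut self, pos: usize, val: i64) {
            assert!(pos < self.n, "position {} out of range for length {}", pos, self.n);
            self.update_internal(1, 0, self.n - 1, pos, val);
        }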
    
        fn update_internal(&mut self, v: usize, l: usize, r: usize, pos: usize, val: i64) {
            if l == r {
                self.data[v] = val;
                return;
            }
            let m = (l + r) / 2;
            if pos <= m {
                self.update_internal(2 * v, l, m, pos, val);
            } else {
                self.update_internal(2 * v + 1, m + 1, r, pos, val);
            }
            self.data[v] = self.data[2 * v] + self.data[2 * v + 1];
        }
    
        /// Add delta to value at position pos
        pub fn add(&mut self, pos: usize, delta: i64) {
            let current = self.query(pos, pos);
            self.set(pos, current + delta);
        }
    
        /// Get the size of the underlying array
        pub fn len(&self) -> usize {
            self.n
        }
    
        /// Check if empty
        pub fn is_empty(&self) -> bool {
            self.n == 0
        }
    }
    
    /// Generic segment tree with custom combine function
    pub struct GenericSegmentTree<T, F>
    where
        F: Fn(&T, &T) -> T,
    {
        data: Vec<Option<T>>,
        n: usize,
        identity: T,
        combine: F,
    }
    
    impl<T: Clone, F: Fn(&T, &T) -> T> GenericSegmentTree<T, F> {
        /// Create a new generic segment tree
        pub fn new(arr: &[T], identity: T, combine: F) -> Self {
            let n = arr.len();
            let mut st = Self {
                data: vec![None; 4 * n.max(1)],
                n,
                identity,
                combine,
            };
            if n > 0 {
                st.build(arr, 1, 0, n - 1);
            }
            st
        }
    
        fn build(&mut self, arr: &[T], v: usize, l: usize, r: usize) {
            if l == r {
                self.data[v] = Some(arr[l].clone());
                return;
            }
            let m = (l + r) / 2;
            self.build(arr, 2 * v, l, m);
            self.build(arr, 2 * v + 1, m + 1, r);
            let left = self.data[2 * v].as_ref().unwrap();
            let right = self.data[2 * v + 1].as_ref().unwrap();
            self.data[v] = Some((self.combine)(left, right));
        }
    
        /// Query range [ql, qr]
        pub fn query(&self, ql: usize, qr: usize) -> T {
            if self.n == 0 {
                return self.identity.clone();
            }
            self.query_internal(1, 0, self.n - 1, ql, qr)
        }
    
        fn query_internal(&self, v: usize, l: usize, r: usize, ql: usize, qr: usize) -> T {
            if qr < l || r < ql {
                return self.identity.clone();
            }
            if ql <= l && r <= qr {
                return self.data[v].as_ref().unwrap().clone();
            }
            let m = (l + r) / 2;
            let left = self.query_internal(2 * v, l, m, ql, qr);
            let right = self.query_internal(2 * v + 1, m + 1, r, ql, qr);
            (self.combine)(&left, &right)
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        #[test]
        fn test_range_sum() {
            let st = SegmentTree::new(&[1, 2, 3, 4, 5]);
            assert_eq!(st.sum(0, 4), 15);
            assert_eq!(st.sum(1, 3), 9);
            assert_eq!(st.sum(2, 2), 3);
        }
    
        #[test]
        fn test_point_update() {
            let mut st = SegmentTree::new(&[1, 2, 3, 4, 5]);
            st.set(2, 10);
            assert_eq!(st.sum(0, 4), 22); // 1+2+10+4+5
            assert_eq!(st.sum(2, 2), 10);
        }
    
        #[test]
        fn test_add() {
            let mut st = SegmentTree::new(&[1, 2, 3, 4, 5]);
            st.add(2, 7);
            assert_eq!(st.sum(2, 2), 10);
        }
    
        #[test]
        fn test_single_element() {
            let st = SegmentTree::new(&[42]);
            assert_eq!(st.sum(0, 0), 42);
        }
    
        #[test]
        fn test_empty() {
            let st = SegmentTree::new(&[]);
            assert!(st.is_empty());
        }
    
        #[test]
        fn test_generic_max() {
            let arr = vec![3, 1, 4, 1, 5, 9, 2, 6];
            let st = GenericSegmentTree::new(&arr, i32::MIN, |a, b| *a.max(b));
            assert_eq!(st.query(0, 7), 9);
            assert_eq!(st.query(0, 4), 5);
            assert_eq!(st.query(5, 5), 9);
        }
    
        #[test]
        fn test_generic_min() {
            let arr = vec![3, 1, 4, 1, 5, 9, 2, 6];
            let st = GenericSegmentTree::new(&arr, i32::MAX, |a, b| *a.min(b));
            assert_eq!(st.query(0, 7), 1);
            assert_eq!(st.query(4, 7), 2);
        }
    
        #[test]
        fn test_multiple_updates() {
            let mut st = SegmentTree::new(&[1, 1, 1, 1, 1]);
            st.set(0, 5);
            st.set(4, 5);
            assert_eq!(st.sum(0, 4), 13);
        }
    }
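The source above uses the recursive, top-down layout. A common alternative is the iterative bottom-up segment tree, which stores the leaves contiguously at `data[n..2n]` and needs only 2n slots. A minimal sketch for sums (a swapped-in variant, not the example's implementation; `BottomUpSumTree` is a hypothetical name), using half-open ranges [l, r):

```rust
/// Iterative bottom-up segment tree for sums. Leaves live at
/// data[n..2n]; node i has children 2i and 2i + 1; data[0] is unused.
struct BottomUpSumTree {
    n: usize,
    data: Vec<i64>,
}

impl BottomUpSumTree {
    fn new(arr: &[i64]) -> Self {
        let n = arr.len();
        let mut data = vec![0; 2 * n];
        data[n..].copy_from_slice(arr); // leaves
        for i in (1..n).rev() {
            data[i] = data[2 * i] + data[2 * i + 1]; // parents, bottom-up
        }
        BottomUpSumTree { n, data }
    }

    /// Overwrite position `pos`, then recompute each ancestor: O(log n).
    fn set(&mut self, pos: usize, val: i64) {
        assert!(pos < self.n, "position out of range");
        let mut i = pos + self.n;
        self.data[i] = val;
        while i > 1 {
            i /= 2;
            self.data[i] = self.data[2 * i] + self.data[2 * i + 1];
        }
    }

    /// Sum over the half-open range [l, r): O(log n).
    fn query(&self, l: usize, r: usize) -> i64 {
        let (mut l, mut r) = (l + self.n, r + self.n);
        let mut s = 0;
        while l < r {
            if l & 1 == 1 {
                // l is a right child; its parent reaches left of l, so take l alone
                s += self.data[l];
                l += 1;
            }
            if r & 1 == 1 {
                // r - 1 is a left child; its parent reaches past r, so take it alone
                r -= 1;
                s += self.data[r];
            }
            l /= 2;
            r /= 2;
        }
        s
    }
}
```

Because the two boundaries climb independently and add nodes in no fixed left-to-right order, this loop relies on `+` being commutative; a non-commutative merge would need separate left and right accumulators.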

    Deep Comparison

    OCaml vs Rust: Segment Tree

    Side-by-Side Comparison

    Build

    OCaml:

    let rec build_ v l r =
      if l=r then t.data.(v) <- arr.(l)
      else begin
        let m = (l+r)/2 in
        build_ (2*v) l m;
        build_ (2*v+1) (m+1) r;
        t.data.(v) <- t.data.(2*v) + t.data.(2*v+1)
      end
    

    Rust:

    fn build(&mut self, arr: &[i64], v: usize, l: usize, r: usize) {
        if l == r {
            self.data[v] = arr[l];
            return;
        }
        let m = (l + r) / 2;
        self.build(arr, 2 * v, l, m);
        self.build(arr, 2 * v + 1, m + 1, r);
        self.data[v] = self.data[2 * v] + self.data[2 * v + 1];
    }
    

    Query

    OCaml:

    let rec query t v l r ql qr =
      if qr < l || r < ql then 0
      else if ql <= l && r <= qr then t.data.(v)
      else
        let m = (l+r)/2 in
        query t (2*v) l m ql qr + query t (2*v+1) (m+1) r ql qr
    

    Rust:

    fn query_internal(&self, v: usize, l: usize, r: usize, ql: usize, qr: usize) -> i64 {
        if qr < l || r < ql { return 0; }
        if ql <= l && r <= qr { return self.data[v]; }
        let m = (l + r) / 2;
        self.query_internal(2*v, l, m, ql, qr) + self.query_internal(2*v+1, m+1, r, ql, qr)
    }
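Why is the query O(log n)? At any depth, at most two nodes straddle a boundary of [ql, qr] and recurse; every other visited node is disjoint or fully covered and returns immediately. So at most four nodes are visited per level, or at most 1 + 4 * ceil(log2 n) in total. A hypothetical instrumented copy of the query (it only counts nodes and reads no data) lets you check that bound:

```rust
/// Hypothetical instrumented query: same recursion as `query_internal`,
/// but returns the number of nodes visited instead of a sum.
fn count_visits(l: usize, r: usize, ql: usize, qr: usize) -> usize {
    if qr < l || r < ql || (ql <= l && r <= qr) {
        return 1; // disjoint or fully covered: no further recursion
    }
    let m = (l + r) / 2;
    1 + count_visits(l, m, ql, qr) + count_visits(m + 1, r, ql, qr)
}

/// Depth of the recursive tree over n leaves: ceil(log2 n).
fn tree_height(n: usize) -> usize {
    n.next_power_of_two().trailing_zeros() as usize
}
```

For n = 5, the full range visits only the root, while the range [1, 3] visits 9 nodes against a bound of 1 + 4 * 3 = 13.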
    

    Key Differences

    | Aspect | OCaml | Rust |
    |---|---|---|
    | Index type | `int` | `usize` |
    | Array access | `arr.(i)` | `arr[i]` |
    | Recursion | Natural | Same |
    | Mutability | Mutable record | `&mut self` |

    Exercises

  • Min segment tree: Change the aggregate from + to min, and return i64::MAX (the identity for min) instead of 0 for disjoint segments; implement range_min(ql, qr) and verify correctness on a range that spans multiple subtrees.
  • Lazy propagation: Implement range-update (add delta to all elements in [ql, qr]) using a lazy propagation array; push lazy values down when recursing into children.
  • Generic aggregate: Make SegmentTree<T> generic over a monoid (T, identity: T, combine: fn(T, T) -> T); test with (i64, 0, +), (i64, i64::MAX, min), and (i64, 1, *).