410 Fundamental

410: Arithmetic Operator Overloading

Functional Programming


Tutorial

The Problem

Mathematical types (vectors, matrices, complex numbers, rational numbers, polynomials) are naturally expressed with arithmetic notation: v1 + v2, m * v, -c. Without operator overloading, these become method calls such as v1.add(&v2), which makes code more verbose and harder to read than the underlying mathematics. Rust's std::ops module provides a trait for each arithmetic operator (Add, Sub, Mul, Div, Rem, and Neg), plus compound-assignment counterparts for the binary ones (AddAssign, SubAssign, MulAssign, DivAssign, RemAssign). Implementing these makes custom types feel like first-class numeric citizens.

Operator overloading powers Rust's numeric libraries: nalgebra (linear algebra), num (arbitrary precision), ndarray (array computing), and embedded DSP/signal processing libraries.
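Concretely, an operator expression is resolved to a call of the corresponding trait method. The minimal sketch below, assuming a trimmed-down Vec2 like the one used throughout this example, checks that a + b, the method call a.add(b), and the fully qualified form <Vec2 as Add>::add(a, b) all produce the same value.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

fn main() {
    let a = Vec2 { x: 1.0, y: 2.0 };
    let b = Vec2 { x: 3.0, y: 4.0 };

    // The operator form...
    let via_operator = a + b;
    // ...resolves to the trait method, which can also be called directly once Add is in scope...
    let via_method = a.add(b);
    // ...or through fully qualified syntax.
    let via_trait = <Vec2 as Add>::add(a, b);

    assert_eq!(via_operator, Vec2 { x: 4.0, y: 6.0 });
    assert_eq!(via_operator, via_method);
    assert_eq!(via_operator, via_trait);
}
```

Because Vec2 is Copy, a and b can be reused across all three calls.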

🎯 Learning Outcomes

  • Understand how std::ops::Add, Sub, Mul, Div, and Neg enable arithmetic operators for custom types
  • Learn how the associated type Output determines the result type of each operation
  • See how AddAssign enables += in addition to +
  • Understand why scalar multiplication (Vec2 * f64) needs its own impl, separate from any vector-by-vector multiplication
  • Learn why Copy types pair well with operator traits that take self by value
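Operator traits take self by value. For a Copy type like Vec2 this costs nothing: the operands are copied and stay usable afterwards. For a non-Copy type the operands would be moved, so a common pattern is to also implement the trait for references. A minimal sketch, using a hypothetical Coeffs(Vec<f64>) type that is not part of this example's source:

```rust
use std::ops::Add;

// Hypothetical non-Copy type (heap-allocated coefficients), for illustration only.
#[derive(Debug, Clone, PartialEq)]
struct Coeffs(Vec<f64>);

// By-reference impl: borrows both operands, so they remain usable afterwards.
impl<'a, 'b> Add<&'b Coeffs> for &'a Coeffs {
    type Output = Coeffs;
    fn add(self, rhs: &'b Coeffs) -> Coeffs {
        // Assumes equal lengths for brevity: zip stops at the shorter vector.
        Coeffs(self.0.iter().zip(&rhs.0).map(|(x, y)| x + y).collect())
    }
}

// By-value impl: consumes both operands, then delegates to the reference impl.
impl Add for Coeffs {
    type Output = Coeffs;
    fn add(self, rhs: Coeffs) -> Coeffs {
        &self + &rhs
    }
}

fn main() {
    let a = Coeffs(vec![1.0, 2.0]);
    let b = Coeffs(vec![3.0, 4.0]);

    let borrowed_sum = &a + &b; // a and b are only borrowed here
    assert_eq!(borrowed_sum, Coeffs(vec![4.0, 6.0]));

    let owned_sum = a + b; // a and b are moved; using them after this line would not compile
    assert_eq!(owned_sum, Coeffs(vec![4.0, 6.0]));
}
```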
Code Example

    use std::ops::{Add, Mul};
    
    #[derive(Clone, Copy)]
    struct Vec2 { x: f64, y: f64 }
    
    impl Add for Vec2 {
        type Output = Vec2;
        fn add(self, other: Vec2) -> Vec2 {
            Vec2 { x: self.x + other.x, y: self.y + other.y }
        }
    }
    
    impl Mul<f64> for Vec2 {
        type Output = Vec2;
        fn mul(self, s: f64) -> Vec2 {
            Vec2 { x: self.x * s, y: self.y * s }
        }
    }
    
    fn main() {
        let a = Vec2 { x: 3.0, y: 4.0 };
        let b = Vec2 { x: 1.0, y: 2.0 };
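        let sum = a + b;       // Uses Add: (4.0, 6.0)
        let scaled = a * 2.0;  // Uses Mul<f64>: (6.0, 8.0); `a` is still usable because Vec2 is Copy
        println!("sum = ({}, {}), scaled = ({}, {})", sum.x, sum.y, scaled.x, scaled.y);
    }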

    Key Differences

  • Type dispatch: Rust's operator overloading is type-directed — the compiler selects the right Add impl based on operand types; OCaml's operator shadows apply lexically.
  • Mixed-type ops: Rust can implement Add<f64> for Vec2 and Add<Vec2> for f64 separately; an OCaml operator is a single function with a single type, so mixed operands need distinct operator names or explicit conversions.
  • No global pollution: Rust's overloaded operators only affect the specific types involved; OCaml's let (+) = Vec2.add shadows the integer + for every operand type until the end of the enclosing scope (the rest of the module, if defined at top level).
  • Commutativity: Rust requires separate impls for a + b and b + a when types differ; OCaml's + is always a single function.
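The type-dispatch, mixed-type, and commutativity points can be seen together in a minimal sketch. It uses a trimmed copy of Vec2 with two Add impls; the Vec2 + f64 impl (offsetting both components) is a hypothetical operation invented only to show dispatch, not part of the example's source.

```rust
use std::ops::Add;

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

// Vec2 + Vec2: component-wise addition.
impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

// Vec2 + f64: hypothetical "offset every component" operation, for illustration only.
impl Add<f64> for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: f64) -> Vec2 {
        Vec2 { x: self.x + rhs, y: self.y + rhs }
    }
}

fn main() {
    let v = Vec2 { x: 1.0, y: 2.0 };
    // The right-hand operand's type selects the impl at compile time.
    assert_eq!(v + Vec2 { x: 0.5, y: 0.5 }, Vec2 { x: 1.5, y: 2.5 });
    assert_eq!(v + 0.5, Vec2 { x: 1.5, y: 2.5 });
    // `0.5 + v` would not compile: there is no `impl Add<Vec2> for f64`,
    // so the reversed operand order needs its own impl.
}
```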
    OCaml Approach

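    OCaml has no type-directed operator overloading; operators are ordinary functions that can be redefined. Writing let (+) = Vec2.add shadows the built-in + from that point to the end of the enclosing scope, and if the Vec2 module itself defines (+), a local open such as Vec2.(a + b) limits the shadowing to a single expression. This is more fragile than Rust's trait system: within the shadowed scope, + means Vec2.add for every operand, so the redefinition cannot be restricted to particular types, and integer addition there must be written Stdlib.( + ). The Base library uses module-local operator overloading extensively for its numeric types.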

    Full Source

    #![allow(clippy::all)]
    //! Arithmetic Operator Overloading
    //!
    //! Using Add, Sub, Mul, Div, Neg traits for custom types.
    
    use std::fmt;
    use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
    
    /// A 2D vector with operator overloading.
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct Vec2 {
        pub x: f64,
        pub y: f64,
    }
    
    impl Vec2 {
        /// Creates a new vector.
        pub fn new(x: f64, y: f64) -> Self {
            Vec2 { x, y }
        }
    
        /// Creates a zero vector.
        pub fn zero() -> Self {
            Vec2 { x: 0.0, y: 0.0 }
        }
    
        /// Creates a unit vector in the X direction.
        pub fn unit_x() -> Self {
            Vec2 { x: 1.0, y: 0.0 }
        }
    
        /// Creates a unit vector in the Y direction.
        pub fn unit_y() -> Self {
            Vec2 { x: 0.0, y: 1.0 }
        }
    
        /// Calculates the magnitude (length) of the vector.
        pub fn magnitude(self) -> f64 {
            (self.x * self.x + self.y * self.y).sqrt()
        }
    
        /// Calculates the dot product with another vector.
        pub fn dot(self, other: Vec2) -> f64 {
            self.x * other.x + self.y * other.y
        }
    
        /// Returns a normalized (unit) vector.
        pub fn normalized(self) -> Vec2 {
            let mag = self.magnitude();
            if mag > 0.0 {
                self / mag
            } else {
                Vec2::zero()
            }
        }
    }
    
    impl Add for Vec2 {
        type Output = Vec2;
    
        fn add(self, other: Vec2) -> Vec2 {
            Vec2::new(self.x + other.x, self.y + other.y)
        }
    }
    
    impl Sub for Vec2 {
        type Output = Vec2;
    
        fn sub(self, other: Vec2) -> Vec2 {
            Vec2::new(self.x - other.x, self.y - other.y)
        }
    }
    
    impl Mul<f64> for Vec2 {
        type Output = Vec2;
    
        fn mul(self, scalar: f64) -> Vec2 {
            Vec2::new(self.x * scalar, self.y * scalar)
        }
    }
    
    // Scalar * Vec2
    impl Mul<Vec2> for f64 {
        type Output = Vec2;
    
        fn mul(self, v: Vec2) -> Vec2 {
            Vec2::new(self * v.x, self * v.y)
        }
    }
    
    impl Div<f64> for Vec2 {
        type Output = Vec2;
    
        fn div(self, scalar: f64) -> Vec2 {
            Vec2::new(self.x / scalar, self.y / scalar)
        }
    }
    
    impl Neg for Vec2 {
        type Output = Vec2;
    
        fn neg(self) -> Vec2 {
            Vec2::new(-self.x, -self.y)
        }
    }
    
    impl AddAssign for Vec2 {
        fn add_assign(&mut self, other: Vec2) {
            self.x += other.x;
            self.y += other.y;
        }
    }
    
    impl fmt::Display for Vec2 {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "Vec2({:.2}, {:.2})", self.x, self.y)
        }
    }
    
    /// A complex number with arithmetic operators.
    #[derive(Debug, Clone, Copy, PartialEq)]
    pub struct Complex {
        pub re: f64,
        pub im: f64,
    }
    
    impl Complex {
        pub fn new(re: f64, im: f64) -> Self {
            Complex { re, im }
        }
    
        pub fn from_polar(r: f64, theta: f64) -> Self {
            Complex {
                re: r * theta.cos(),
                im: r * theta.sin(),
            }
        }
    
        pub fn magnitude(self) -> f64 {
            (self.re * self.re + self.im * self.im).sqrt()
        }
    
        pub fn conjugate(self) -> Complex {
            Complex::new(self.re, -self.im)
        }
    }
    
    impl Add for Complex {
        type Output = Complex;
    
        fn add(self, other: Complex) -> Complex {
            Complex::new(self.re + other.re, self.im + other.im)
        }
    }
    
    impl Sub for Complex {
        type Output = Complex;
    
        fn sub(self, other: Complex) -> Complex {
            Complex::new(self.re - other.re, self.im - other.im)
        }
    }
    
    impl Mul for Complex {
        type Output = Complex;
    
        fn mul(self, other: Complex) -> Complex {
            // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
            Complex::new(
                self.re * other.re - self.im * other.im,
                self.re * other.im + self.im * other.re,
            )
        }
    }
    
    impl Neg for Complex {
        type Output = Complex;
    
        fn neg(self) -> Complex {
            Complex::new(-self.re, -self.im)
        }
    }
    
    impl fmt::Display for Complex {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            if self.im >= 0.0 {
                write!(f, "{:.2} + {:.2}i", self.re, self.im)
            } else {
                write!(f, "{:.2} - {:.2}i", self.re, -self.im)
            }
        }
    }
    
    #[cfg(test)]
    mod tests {
        use super::*;
    
        const EPSILON: f64 = 1e-9;
    
        fn approx_eq(a: f64, b: f64) -> bool {
            (a - b).abs() < EPSILON
        }
    
        #[test]
        fn test_vec2_add() {
            let a = Vec2::new(1.0, 2.0);
            let b = Vec2::new(3.0, 4.0);
            assert_eq!(a + b, Vec2::new(4.0, 6.0));
        }
    
        #[test]
        fn test_vec2_sub() {
            let a = Vec2::new(5.0, 7.0);
            let b = Vec2::new(2.0, 3.0);
            assert_eq!(a - b, Vec2::new(3.0, 4.0));
        }
    
        #[test]
        fn test_vec2_mul_scalar() {
            let v = Vec2::new(2.0, 3.0);
            assert_eq!(v * 2.0, Vec2::new(4.0, 6.0));
            assert_eq!(2.0 * v, Vec2::new(4.0, 6.0));
        }
    
        #[test]
        fn test_vec2_div() {
            let v = Vec2::new(4.0, 6.0);
            assert_eq!(v / 2.0, Vec2::new(2.0, 3.0));
        }
    
        #[test]
        fn test_vec2_neg() {
            let v = Vec2::new(1.0, -2.0);
            assert_eq!(-v, Vec2::new(-1.0, 2.0));
        }
    
        #[test]
        fn test_vec2_add_assign() {
            let mut v = Vec2::new(1.0, 2.0);
            v += Vec2::new(3.0, 4.0);
            assert_eq!(v, Vec2::new(4.0, 6.0));
        }
    
        #[test]
        fn test_vec2_magnitude() {
            let v = Vec2::new(3.0, 4.0);
            assert!(approx_eq(v.magnitude(), 5.0));
        }
    
        #[test]
        fn test_vec2_dot() {
            let a = Vec2::new(1.0, 2.0);
            let b = Vec2::new(3.0, 4.0);
            assert!(approx_eq(a.dot(b), 11.0)); // 1*3 + 2*4
        }
    
        #[test]
        fn test_vec2_normalized() {
            let v = Vec2::new(3.0, 4.0);
            let n = v.normalized();
            assert!(approx_eq(n.magnitude(), 1.0));
        }
    
        #[test]
        fn test_complex_add() {
            let a = Complex::new(1.0, 2.0);
            let b = Complex::new(3.0, 4.0);
            assert_eq!(a + b, Complex::new(4.0, 6.0));
        }
    
        #[test]
        fn test_complex_mul() {
            let a = Complex::new(1.0, 2.0);
            let b = Complex::new(3.0, 4.0);
            // (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = 3 + 10i - 8 = -5 + 10i
            assert_eq!(a * b, Complex::new(-5.0, 10.0));
        }
    
        #[test]
        fn test_complex_conjugate() {
            let c = Complex::new(3.0, 4.0);
            let conj = c.conjugate();
            assert_eq!(conj, Complex::new(3.0, -4.0));
        }
    
        #[test]
        fn test_complex_magnitude() {
            let c = Complex::new(3.0, 4.0);
            assert!(approx_eq(c.magnitude(), 5.0));
        }
    }

    Deep Comparison

    OCaml vs Rust: Arithmetic Operator Overloading

    Side-by-Side Code

    OCaml — Custom infix operators

    type vec2 = { x: float; y: float }
    
    (* Define custom operators *)
    let ( +^ ) a b = { x = a.x +. b.x; y = a.y +. b.y }
    let ( -^ ) a b = { x = a.x -. b.x; y = a.y -. b.y }
    let ( *^ ) s v = { x = s *. v.x; y = s *. v.y }
    
    let () =
      let a = { x = 3.0; y = 4.0 } in
      let b = { x = 1.0; y = 2.0 } in
      let sum = a +^ b in
      let scaled = 2.0 *^ a in
      Printf.printf "sum = (%g, %g), scaled = (%g, %g)\n" sum.x sum.y scaled.x scaled.y
    

    Rust — Traits for standard operators

    use std::ops::{Add, Mul};
    
    #[derive(Clone, Copy)]
    struct Vec2 { x: f64, y: f64 }
    
    impl Add for Vec2 {
        type Output = Vec2;
        fn add(self, other: Vec2) -> Vec2 {
            Vec2 { x: self.x + other.x, y: self.y + other.y }
        }
    }
    
    impl Mul<f64> for Vec2 {
        type Output = Vec2;
        fn mul(self, s: f64) -> Vec2 {
            Vec2 { x: self.x * s, y: self.y * s }
        }
    }
    
    fn main() {
        let a = Vec2 { x: 3.0, y: 4.0 };
        let b = Vec2 { x: 1.0, y: 2.0 };
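        let sum = a + b;       // Uses Add: (4.0, 6.0)
        let scaled = a * 2.0;  // Uses Mul<f64>: (6.0, 8.0); `a` is still usable because Vec2 is Copy
        println!("sum = ({}, {}), scaled = ({}, {})", sum.x, sum.y, scaled.x, scaled.y);
    }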
    

    Comparison Table

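    | Aspect | OCaml | Rust |
    |---|---|---|
    | Operator style | Custom infix: +^, *^ | Standard: +, * |
    | Definition | let ( +^ ) a b = ... | impl Add for Type |
    | Return type | Inferred | Declared: type Output = ... |
    | Symmetric ops | A separate operator for each operand order | A second impl, e.g. impl Mul<Vec2> for f64 |
    | Compound ops | No compound-assignment operators | AddAssign, MulAssign, etc. |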

    Operator Trait Mapping

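    | Operator | Trait | Method |
    |---|---|---|
    | a + b | Add | add(self, rhs) |
    | a - b | Sub | sub(self, rhs) |
    | a * b | Mul | mul(self, rhs) |
    | a / b | Div | div(self, rhs) |
    | a % b | Rem | rem(self, rhs) |
    | -a (unary) | Neg | neg(self) |
    | a += b | AddAssign | add_assign(&mut self, rhs) |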
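The table shows only AddAssign, but Sub, Mul, Div, and Rem each have a compound counterpart (SubAssign, MulAssign, DivAssign, RemAssign), and as with Mul<f64> the right-hand type need not be Self. The sketch below exercises Rem and two compound traits on a hypothetical Degrees(i32) type that is not part of the example's source.

```rust
use std::ops::{MulAssign, Rem, SubAssign};

// Hypothetical angle type in whole degrees, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Degrees(i32);

// a % n: the right-hand side is a plain i32, not Degrees.
impl Rem<i32> for Degrees {
    type Output = Degrees;
    fn rem(self, modulus: i32) -> Degrees {
        // rem_euclid keeps the result in 0..modulus even for negative angles,
        // unlike i32's own %, which keeps the sign of the left operand.
        Degrees(self.0.rem_euclid(modulus))
    }
}

// a -= b
impl SubAssign for Degrees {
    fn sub_assign(&mut self, rhs: Degrees) {
        self.0 -= rhs.0;
    }
}

// a *= k, with a scalar right-hand side
impl MulAssign<i32> for Degrees {
    fn mul_assign(&mut self, k: i32) {
        self.0 *= k;
    }
}

fn main() {
    let mut a = Degrees(100);
    a *= 4;           // MulAssign<i32>: Degrees(400)
    a -= Degrees(30); // SubAssign: Degrees(370)
    assert_eq!(a, Degrees(370));
    assert_eq!(a % 360, Degrees(10));             // Rem<i32>
    assert_eq!(Degrees(-90) % 360, Degrees(270)); // negative angles wrap too
}
```

Note the design choice in rem: each impl decides what % means for its type, here a Euclidean remainder so angles always land in 0..360; it does not have to mirror the built-in integer semantics.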

    Symmetric Operations

    To support both vec * 2.0 and 2.0 * vec, you need one impl for each operand order:

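    use std::ops::Mul;
    
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Vec2 { x: f64, y: f64 }
    
    // vec * scalar
    impl Mul<f64> for Vec2 {
        type Output = Vec2;
        fn mul(self, s: f64) -> Vec2 { Vec2 { x: self.x * s, y: self.y * s } }
    }
    
    // scalar * vec: implementing a std trait for the foreign type f64 is allowed
    // here because the local type Vec2 appears as the trait's type parameter
    impl Mul<Vec2> for f64 {
        type Output = Vec2;
        fn mul(self, v: Vec2) -> Vec2 { v * self } // reuse vec * scalar
    }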
    

    Compound Assignment

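    use std::ops::AddAssign;
    
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Vec2 { x: f64, y: f64 }
    
    impl AddAssign for Vec2 {
        fn add_assign(&mut self, other: Vec2) {
            self.x += other.x;
            self.y += other.y;
        }
    }
    
    fn main() {
        let mut v = Vec2 { x: 1.0, y: 2.0 };
        v += Vec2 { x: 3.0, y: 4.0 };  // Uses AddAssign
        assert_eq!(v, Vec2 { x: 4.0, y: 6.0 });
    }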
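Because Add and AddAssign are independent traits, implementing one does not provide the other. One common way to avoid writing the arithmetic twice is to implement Add in terms of AddAssign; a sketch, assuming the same two-field Vec2:

```rust
use std::ops::{Add, AddAssign};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

// The in-place operation holds the actual arithmetic.
impl AddAssign for Vec2 {
    fn add_assign(&mut self, other: Vec2) {
        self.x += other.x;
        self.y += other.y;
    }
}

// `+` reuses `+=`: take the left operand by value, mutate it, return it.
impl Add for Vec2 {
    type Output = Vec2;
    fn add(mut self, other: Vec2) -> Vec2 {
        self += other;
        self
    }
}

fn main() {
    let a = Vec2 { x: 1.0, y: 2.0 };
    let b = Vec2 { x: 3.0, y: 4.0 };
    let mut c = a;
    c += b;
    assert_eq!(a + b, c);
    assert_eq!(c, Vec2 { x: 4.0, y: 6.0 });
}
```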
    

    5 Takeaways

  • OCaml code typically defines custom operator names such as +^ (or shadows + within a scope); Rust overloads the standard +, * directly, which reads more naturally for math types.
  • Rust traits fix both the operand and result types: Mul<f64> with type Output = Vec2 states explicitly that Vec2 * f64 yields a Vec2.
  • Symmetric operations need two impls: Vec2 * f64 and f64 * Vec2 are separate.
  • Compound assignment is a separate trait: += uses AddAssign, not Add.
  • The Output associated type lets the result type differ from the operand types, e.g. Matrix * Vector = Vector.
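The last point also applies beyond matrices. As a sketch with hypothetical unit newtypes (Meters, Seconds, and MetersPerSecond are illustrative names, not part of this example), Div can return a third type, so the compiler tracks units:

```rust
use std::ops::Div;

// Hypothetical unit newtypes, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Meters(f64);

#[derive(Debug, Clone, Copy, PartialEq)]
struct Seconds(f64);

#[derive(Debug, Clone, Copy, PartialEq)]
struct MetersPerSecond(f64);

// Meters / Seconds produces a third type, chosen via the Output associated type.
impl Div<Seconds> for Meters {
    type Output = MetersPerSecond;
    fn div(self, rhs: Seconds) -> MetersPerSecond {
        MetersPerSecond(self.0 / rhs.0)
    }
}

fn main() {
    let speed = Meters(100.0) / Seconds(8.0);
    assert_eq!(speed, MetersPerSecond(12.5));
    // Meters(100.0) / Meters(8.0) would not compile: there is no impl Div<Meters> for Meters.
}
```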

    Exercises

  • Complex number: Implement Complex { re: f64, im: f64 } with all four arithmetic operators plus Neg. Add a test asserting that i * i == -1, and implement Display in the form a + bi.
  • Matrix multiplication: Create a Matrix2x2([[f64; 2]; 2]) and implement Mul<Matrix2x2> (matrix multiplication), Mul<Vec2> (matrix-vector product), and Add<Matrix2x2>. Verify with the identity matrix.
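  • Polynomial arithmetic: Implement Polynomial(Vec<f64>) where element i of the vector is the coefficient of x^i. Implement Add (coefficient-wise, padding the shorter polynomial with zeros, since a plain zip would drop the extra terms), Sub, Mul (polynomial multiplication via convolution of the coefficient vectors), and Display (e.g. 3x^2 + 2x + 1).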