410: Arithmetic Operator Overloading
Difficulty level: Fundamental. Key concepts covered: Functional Programming.
Tutorial
The Problem
Mathematical types (vectors, matrices, complex numbers, rational numbers, polynomials) are naturally expressed with arithmetic notation: v1 + v2, m * v, -c. Without operator overloading, these become method calls such as v1.add(&v2), making code more verbose and harder to read than the underlying mathematics. Rust's std::ops module provides a trait for each arithmetic operator (Add, Sub, Mul, Div, Neg, Rem) plus compound-assignment variants (AddAssign, SubAssign, and so on). Implementing these traits makes custom types feel like first-class numeric citizens.
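Operator syntax is sugar for a trait method call: the compiler rewrites a + b to Add::add(a, b) and -a to Neg::neg(a). The sketch below uses a minimal Vec2 defined only for this illustration and checks that the infix, method-call, and fully qualified forms all reach the same impl.

```rust
use std::ops::{Add, Neg};

// Minimal illustrative type, defined only for this sketch.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

impl Neg for Vec2 {
    type Output = Vec2;
    fn neg(self) -> Vec2 {
        Vec2 { x: -self.x, y: -self.y }
    }
}

fn main() {
    let v1 = Vec2 { x: 1.0, y: 2.0 };
    let v2 = Vec2 { x: 3.0, y: 4.0 };

    // All three forms call the same Add::add implementation.
    let infix = v1 + v2;
    let method = v1.add(v2); // method syntax needs the Add trait in scope
    let explicit = Add::add(v1, v2);
    assert_eq!(infix, method);
    assert_eq!(infix, explicit);

    // Unary minus is sugar for Neg::neg.
    assert_eq!(-v1, Neg::neg(v1));
    println!("{:?}", infix); // Vec2 { x: 4.0, y: 6.0 }
}
```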
Rust's numeric libraries rely on operator overloading: nalgebra (linear algebra), num (big integers, rationals, and complex numbers), ndarray (n-dimensional arrays), and embedded DSP/signal processing libraries.
🎯 Learning Outcomes
- std::ops::Add, Sub, Mul, Div, and Neg enable arithmetic operators for custom types
- The associated type Output determines the result type of each operation
- AddAssign enables += in addition to +
- Scalar multiplication (Vec2 * f64) requires a separate impl from vector multiplication
- Copy types can take self by value in operator implementations at no cost, because the caller keeps its own copy
Code Example
use std::ops::{Add, Mul};
#[derive(Clone, Copy)]
struct Vec2 { x: f64, y: f64 }
impl Add for Vec2 {
type Output = Vec2;
fn add(self, other: Vec2) -> Vec2 {
Vec2 { x: self.x + other.x, y: self.y + other.y }
}
}
impl Mul<f64> for Vec2 {
type Output = Vec2;
fn mul(self, s: f64) -> Vec2 {
Vec2 { x: self.x * s, y: self.y * s }
}
}
fn main() {
let a = Vec2 { x: 3.0, y: 4.0 };
let b = Vec2 { x: 1.0, y: 2.0 };
let sum = a + b; // Uses Add trait
let scaled = a * 2.0; // Uses Mul<f64>
}
Key Differences
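1. Rust's operator overloading is type-directed: the compiler picks the Add impl based on the operand types, while OCaml's operator shadows apply lexically.
2. Rust can implement Add<f64> for Vec2 and Add<Vec2> for f64 separately; OCaml requires either uniform types or explicit conversions.
3. In OCaml, let (+) = Vec2.add shadows the integer + for the whole module scope; a Rust impl adds a case for one type and leaves i32 + i32 untouched.
4. Rust needs separate impls for a + b and b + a when the operand types differ; OCaml's + is always a single function.
OCaml Approach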
OCaml supports operator overloading by redefining operators locally. A definition such as let (+) = Vec2.add shadows the built-in + for the rest of its scope. This is more fragile than Rust's trait system: the shadowing applies to every use of + in that scope, whatever the operand types, so integer + is no longer available there. Local opens such as Vec2.(a + b) can narrow the scope, but they cannot restrict the operator to specific types. The Base library uses module-local operators extensively for its numeric types.
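By contrast, a Rust impl never replaces an existing operator; it only adds a case that the compiler selects by operand type. In this sketch (again with an illustrative Vec2 defined just for the example), integer, float, and vector addition coexist in one scope. A mismatched pair such as i32 + Vec2 would simply fail to compile, because no impl covers it.

```rust
use std::ops::Add;

// Illustrative 2D vector, defined only for this sketch.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

fn main() {
    // Same scope, same `+` token, three impls chosen by operand type.
    let n: i32 = 1 + 2; // built-in i32 addition
    let f: f64 = 1.5 + 2.5; // built-in f64 addition
    let v = Vec2 { x: 1.0, y: 2.0 } + Vec2 { x: 3.0, y: 4.0 }; // Add for Vec2

    // let bad = 1 + v; // would not compile: no impl Add<Vec2> for i32
    println!("{n} {f} {:?}", v);
}
```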
Full Source
#![allow(clippy::all)]
//! Arithmetic Operator Overloading
//!
//! Using Add, Sub, Mul, Div, Neg traits for custom types.
use std::fmt;
use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub};
/// A 2D vector with operator overloading.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Vec2 {
pub x: f64,
pub y: f64,
}
impl Vec2 {
/// Creates a new vector.
pub fn new(x: f64, y: f64) -> Self {
Vec2 { x, y }
}
/// Creates a zero vector.
pub fn zero() -> Self {
Vec2 { x: 0.0, y: 0.0 }
}
/// Creates a unit vector in the X direction.
pub fn unit_x() -> Self {
Vec2 { x: 1.0, y: 0.0 }
}
/// Creates a unit vector in the Y direction.
pub fn unit_y() -> Self {
Vec2 { x: 0.0, y: 1.0 }
}
/// Calculates the magnitude (length) of the vector.
pub fn magnitude(self) -> f64 {
(self.x * self.x + self.y * self.y).sqrt()
}
/// Calculates the dot product with another vector.
pub fn dot(self, other: Vec2) -> f64 {
self.x * other.x + self.y * other.y
}
/// Returns a normalized (unit) vector.
pub fn normalized(self) -> Vec2 {
let mag = self.magnitude();
if mag > 0.0 {
self / mag
} else {
Vec2::zero()
}
}
}
impl Add for Vec2 {
type Output = Vec2;
fn add(self, other: Vec2) -> Vec2 {
Vec2::new(self.x + other.x, self.y + other.y)
}
}
impl Sub for Vec2 {
type Output = Vec2;
fn sub(self, other: Vec2) -> Vec2 {
Vec2::new(self.x - other.x, self.y - other.y)
}
}
impl Mul<f64> for Vec2 {
type Output = Vec2;
fn mul(self, scalar: f64) -> Vec2 {
Vec2::new(self.x * scalar, self.y * scalar)
}
}
// Scalar * Vec2
impl Mul<Vec2> for f64 {
type Output = Vec2;
fn mul(self, v: Vec2) -> Vec2 {
Vec2::new(self * v.x, self * v.y)
}
}
impl Div<f64> for Vec2 {
type Output = Vec2;
fn div(self, scalar: f64) -> Vec2 {
Vec2::new(self.x / scalar, self.y / scalar)
}
}
impl Neg for Vec2 {
type Output = Vec2;
fn neg(self) -> Vec2 {
Vec2::new(-self.x, -self.y)
}
}
impl AddAssign for Vec2 {
fn add_assign(&mut self, other: Vec2) {
self.x += other.x;
self.y += other.y;
}
}
impl fmt::Display for Vec2 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Vec2({:.2}, {:.2})", self.x, self.y)
}
}
/// A complex number with arithmetic operators.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Complex {
pub re: f64,
pub im: f64,
}
impl Complex {
pub fn new(re: f64, im: f64) -> Self {
Complex { re, im }
}
pub fn from_polar(r: f64, theta: f64) -> Self {
Complex {
re: r * theta.cos(),
im: r * theta.sin(),
}
}
pub fn magnitude(self) -> f64 {
(self.re * self.re + self.im * self.im).sqrt()
}
pub fn conjugate(self) -> Complex {
Complex::new(self.re, -self.im)
}
}
impl Add for Complex {
type Output = Complex;
fn add(self, other: Complex) -> Complex {
Complex::new(self.re + other.re, self.im + other.im)
}
}
impl Sub for Complex {
type Output = Complex;
fn sub(self, other: Complex) -> Complex {
Complex::new(self.re - other.re, self.im - other.im)
}
}
impl Mul for Complex {
type Output = Complex;
fn mul(self, other: Complex) -> Complex {
// (a + bi)(c + di) = (ac - bd) + (ad + bc)i
Complex::new(
self.re * other.re - self.im * other.im,
self.re * other.im + self.im * other.re,
)
}
}
impl Neg for Complex {
type Output = Complex;
fn neg(self) -> Complex {
Complex::new(-self.re, -self.im)
}
}
impl fmt::Display for Complex {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self.im >= 0.0 {
write!(f, "{:.2} + {:.2}i", self.re, self.im)
} else {
write!(f, "{:.2} - {:.2}i", self.re, -self.im)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f64 = 1e-9;
fn approx_eq(a: f64, b: f64) -> bool {
(a - b).abs() < EPSILON
}
#[test]
fn test_vec2_add() {
let a = Vec2::new(1.0, 2.0);
let b = Vec2::new(3.0, 4.0);
assert_eq!(a + b, Vec2::new(4.0, 6.0));
}
#[test]
fn test_vec2_sub() {
let a = Vec2::new(5.0, 7.0);
let b = Vec2::new(2.0, 3.0);
assert_eq!(a - b, Vec2::new(3.0, 4.0));
}
#[test]
fn test_vec2_mul_scalar() {
let v = Vec2::new(2.0, 3.0);
assert_eq!(v * 2.0, Vec2::new(4.0, 6.0));
assert_eq!(2.0 * v, Vec2::new(4.0, 6.0));
}
#[test]
fn test_vec2_div() {
let v = Vec2::new(4.0, 6.0);
assert_eq!(v / 2.0, Vec2::new(2.0, 3.0));
}
#[test]
fn test_vec2_neg() {
let v = Vec2::new(1.0, -2.0);
assert_eq!(-v, Vec2::new(-1.0, 2.0));
}
#[test]
fn test_vec2_add_assign() {
let mut v = Vec2::new(1.0, 2.0);
v += Vec2::new(3.0, 4.0);
assert_eq!(v, Vec2::new(4.0, 6.0));
}
#[test]
fn test_vec2_magnitude() {
let v = Vec2::new(3.0, 4.0);
assert!(approx_eq(v.magnitude(), 5.0));
}
#[test]
fn test_vec2_dot() {
let a = Vec2::new(1.0, 2.0);
let b = Vec2::new(3.0, 4.0);
assert!(approx_eq(a.dot(b), 11.0)); // 1*3 + 2*4
}
#[test]
fn test_vec2_normalized() {
let v = Vec2::new(3.0, 4.0);
let n = v.normalized();
assert!(approx_eq(n.magnitude(), 1.0));
}
#[test]
fn test_complex_add() {
let a = Complex::new(1.0, 2.0);
let b = Complex::new(3.0, 4.0);
assert_eq!(a + b, Complex::new(4.0, 6.0));
}
#[test]
fn test_complex_mul() {
let a = Complex::new(1.0, 2.0);
let b = Complex::new(3.0, 4.0);
// (1 + 2i)(3 + 4i) = 3 + 4i + 6i + 8i² = 3 + 10i - 8 = -5 + 10i
assert_eq!(a * b, Complex::new(-5.0, 10.0));
}
#[test]
fn test_complex_conjugate() {
let c = Complex::new(3.0, 4.0);
let conj = c.conjugate();
assert_eq!(conj, Complex::new(3.0, -4.0));
}
#[test]
fn test_complex_magnitude() {
let c = Complex::new(3.0, 4.0);
assert!(approx_eq(c.magnitude(), 5.0));
}
}
Deep Comparison
OCaml vs Rust: Arithmetic Operator Overloading
Side-by-Side Code
OCaml: Custom infix operators
type vec2 = { x: float; y: float }
(* Define custom operators *)
let ( +^ ) a b = { x = a.x +. b.x; y = a.y +. b.y }
let ( -^ ) a b = { x = a.x -. b.x; y = a.y -. b.y }
let ( *^ ) s v = { x = s *. v.x; y = s *. v.y }
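let () =
  let a = { x = 3.0; y = 4.0 } in
  let b = { x = 1.0; y = 2.0 } in
  let sum = a +^ b in
  let scaled = 2.0 *^ a in
  (* Print the results so the bindings are used (avoids unused-variable warnings) *)
  Printf.printf "sum = (%g, %g), scaled = (%g, %g)\n" sum.x sum.y scaled.x scaled.y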
Rust: Traits for standard operators
use std::ops::{Add, Mul};
#[derive(Clone, Copy)]
struct Vec2 { x: f64, y: f64 }
impl Add for Vec2 {
type Output = Vec2;
fn add(self, other: Vec2) -> Vec2 {
Vec2 { x: self.x + other.x, y: self.y + other.y }
}
}
impl Mul<f64> for Vec2 {
type Output = Vec2;
fn mul(self, s: f64) -> Vec2 {
Vec2 { x: self.x * s, y: self.y * s }
}
}
fn main() {
let a = Vec2 { x: 3.0, y: 4.0 };
let b = Vec2 { x: 1.0, y: 2.0 };
let sum = a + b; // Uses Add trait
let scaled = a * 2.0; // Uses Mul<f64>
}
Comparison Table
| Aspect | OCaml | Rust |
|---|---|---|
| Operator style | Custom infix: +^, *^ | Standard: +, * |
| Definition | let ( +^ ) a b = ... | impl Add for Type |
| Return type | Inferred | type Output = ... |
| Symmetric ops | Define both sides | impl Mul<Vec2> for f64 |
| Compound ops | Not built-in | AddAssign, MulAssign, etc. |
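Because Output is an associated type, generic code can state what an operator must return. The lerp helper below is hypothetical (linear interpolation, written as a * (1 - t) + b * t); it accepts any T whose Add and Mul<f64> both return T, so one function serves plain f64 and a Vec2. It is a sketch, not code from the example crate.

```rust
use std::ops::{Add, Mul};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

impl Mul<f64> for Vec2 {
    type Output = Vec2;
    fn mul(self, s: f64) -> Vec2 {
        Vec2 { x: self.x * s, y: self.y * s }
    }
}

/// Hypothetical helper: computes a * (1 - t) + b * t for any T
/// whose `+` and `* f64` both produce another T.
fn lerp<T>(a: T, b: T, t: f64) -> T
where
    T: Add<Output = T> + Mul<f64, Output = T>,
{
    a * (1.0 - t) + b * t
}

fn main() {
    let x = lerp(0.0_f64, 10.0, 0.5); // T = f64
    let v = lerp(Vec2 { x: 0.0, y: 0.0 }, Vec2 { x: 2.0, y: 4.0 }, 0.5); // T = Vec2
    println!("{x} {v:?}");
}
```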
Operator Trait Mapping
| Operator | Trait | Method |
|---|---|---|
| + | Add | add(self, rhs) |
| - | Sub | sub(self, rhs) |
| * | Mul | mul(self, rhs) |
| / | Div | div(self, rhs) |
| % | Rem | rem(self, rhs) |
| -x (unary) | Neg | neg(self) |
| += | AddAssign | add_assign(&mut self, rhs) |
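Every binary method in the table takes self and rhs by value. That is free for a Copy type like Vec2, but for a heap-backed type it would move both operands. A common pattern is to implement the trait for references instead. This sketch uses a hypothetical VecN(Vec<f64>) type and assumes equal lengths (zip silently stops at the shorter one).

```rust
use std::ops::Add;

// Hypothetical heap-backed vector: not Copy, so by-value `a + b` would move both.
#[derive(Debug, Clone, PartialEq)]
struct VecN(Vec<f64>);

// Implementing Add on references lets callers write `&a + &b` and keep a and b.
impl<'a> Add<&'a VecN> for &'a VecN {
    type Output = VecN;
    fn add(self, rhs: &'a VecN) -> VecN {
        // Element-wise sum; assumes both vectors have the same length.
        VecN(self.0.iter().zip(&rhs.0).map(|(x, y)| x + y).collect())
    }
}

fn main() {
    let a = VecN(vec![1.0, 2.0, 3.0]);
    let b = VecN(vec![10.0, 20.0, 30.0]);
    let c = &a + &b; // borrows both operands, no clone needed
    let d = &c + &a; // a is still usable here
    println!("{:?}", d);
}
```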
Symmetric Operations
To support both vec * 2.0 and 2.0 * vec, implement Mul in both directions; the compiler never swaps operands for you:
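use std::ops::Mul;
#[derive(Clone, Copy)]
struct Vec2 { x: f64, y: f64 }
// vec * scalar: Vec2 is the left operand (Self) and f64 is the Rhs
impl Mul<f64> for Vec2 {
    type Output = Vec2;
    fn mul(self, s: f64) -> Vec2 { Vec2 { x: self.x * s, y: self.y * s } }
}
// scalar * vec: f64 is the left operand, so the impl goes on f64
impl Mul<Vec2> for f64 {
    type Output = Vec2;
    fn mul(self, v: Vec2) -> Vec2 { v * self } // delegate to the impl above
}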
Compound Assignment
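use std::ops::AddAssign;
// Vec2 and Vec2::new as defined in the full source above
impl AddAssign for Vec2 {
    // Mutates the left operand in place; returns nothing
    fn add_assign(&mut self, other: Vec2) {
        self.x += other.x;
        self.y += other.y;
    }
}
fn main() {
    let mut v = Vec2::new(1.0, 2.0);
    v += Vec2::new(3.0, 4.0); // Uses AddAssign
}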
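Rust does not derive += from +: each compound operator is its own trait and must be implemented separately. For a Copy type, one simple approach is to delegate to the binary operator so the two forms cannot drift apart. This sketch defines its own minimal Vec2 and also shows that a compound operator's right-hand side can be a different type (MulAssign<f64>).

```rust
use std::ops::{Add, AddAssign, Mul, MulAssign};

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

impl Add for Vec2 {
    type Output = Vec2;
    fn add(self, rhs: Vec2) -> Vec2 {
        Vec2 { x: self.x + rhs.x, y: self.y + rhs.y }
    }
}

impl Mul<f64> for Vec2 {
    type Output = Vec2;
    fn mul(self, s: f64) -> Vec2 {
        Vec2 { x: self.x * s, y: self.y * s }
    }
}

// For a Copy type, the compound form can delegate to the binary operator.
impl AddAssign for Vec2 {
    fn add_assign(&mut self, rhs: Vec2) {
        *self = *self + rhs;
    }
}

// The right-hand side of a compound operator may be a different type.
impl MulAssign<f64> for Vec2 {
    fn mul_assign(&mut self, s: f64) {
        *self = *self * s;
    }
}

fn main() {
    let mut v = Vec2 { x: 1.0, y: 2.0 };
    v += Vec2 { x: 3.0, y: 4.0 }; // now (4, 6)
    v *= 0.5; // now (2, 3)
    println!("{:?}", v);
}
```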
5 Takeaways
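1. Rust lets custom types use the standard operators (+, *, and so on) instead of new symbols like OCaml's +^ and *^, which feels more natural for math types.
2. The result type is explicit: Mul<f64> with type Output = Vec2.
3. Operand order matters: Vec2 * f64 and f64 * Vec2 are separate impls.
4. += uses AddAssign, not Add.
5. The output type can differ from the operand types: Matrix * Vector = Vector.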
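The last takeaway generalizes: Output need not be Self or either operand type. As an illustration with hypothetical Point and Vec2 types (not part of the example crate), subtracting two positions yields a displacement, and adding a displacement to a position yields a position. Leaving out impl Add for Point turns the meaningless p + q into a compile-time error.

```rust
use std::ops::{Add, Sub};

// Hypothetical types: a position and a displacement are different things.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Point {
    x: f64,
    y: f64,
}

#[derive(Debug, Clone, Copy, PartialEq)]
struct Vec2 {
    x: f64,
    y: f64,
}

// Point - Point = Vec2: the displacement between two positions.
impl Sub for Point {
    type Output = Vec2;
    fn sub(self, rhs: Point) -> Vec2 {
        Vec2 { x: self.x - rhs.x, y: self.y - rhs.y }
    }
}

// Point + Vec2 = Point: moving a position by a displacement.
impl Add<Vec2> for Point {
    type Output = Point;
    fn add(self, d: Vec2) -> Point {
        Point { x: self.x + d.x, y: self.y + d.y }
    }
}

// There is deliberately no `impl Add for Point`, so `p + q` does not compile.

fn main() {
    let start = Point { x: 1.0, y: 1.0 };
    let end = Point { x: 4.0, y: 5.0 };
    let step: Vec2 = end - start; // Vec2 { x: 3.0, y: 4.0 }
    let back: Point = start + step; // equal to `end`
    println!("{:?} {:?}", step, back);
}
```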
Exercises
1. Implement Complex { re: f64, im: f64 } with all four arithmetic operators plus Neg. Ensure i * i = -1 passes as a test. Implement Display as a + bi.
2. Define Matrix2x2([[f64; 2]; 2]) and implement Mul<Matrix2x2> (matrix multiplication), Mul<Vec2> (matrix-vector product), and Add<Matrix2x2>. Verify with the identity matrix.
3. Define Polynomial(Vec<f64>) where coeffs[i] is the coefficient of x^i. Implement Add (zip coefficients), Sub, Mul (polynomial multiplication using convolution), and Display (3x^2 + 2x + 1).