Implementation:Avhz RustQuant Overload Operators
| Knowledge Sources | |
|---|---|
| Domains | Automatic_Differentiation, Mathematics |
| Last Updated | 2026-02-07 19:00 GMT |
Overview
This module overloads Rust's standard arithmetic operators and mathematical functions for the Variable type so that every operation automatically records partial derivatives on the computational graph.
Description
The overload module is the core mechanism that makes automatic differentiation transparent to users of RustQuant. It provides trait implementations on the Variable<'v> type for the binary arithmetic operators (+, -, *, /), unary negation (-), and the compound assignment forms (+=, -=, *=, /=). Each binary operator is implemented for three operand combinations: Variable op Variable, Variable op f64, and f64 op Variable.
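The three operand combinations can be illustrated with a self-contained toy tape. This is a minimal sketch, not RustQuant's implementation: `Tape`, `Var`, `gradient`, and `demo` are hypothetical names introduced here, and only `Mul` is shown.

```rust
use std::cell::RefCell;
use std::ops::Mul;

/// Toy tape (hypothetical, not RustQuant's `Graph`): each node stores
/// two (parent index, partial derivative) pairs.
#[derive(Default)]
pub struct Tape {
    nodes: RefCell<Vec<[(usize, f64); 2]>>,
}

/// Toy variable: a value plus the index of its node on the tape.
#[derive(Clone, Copy)]
pub struct Var<'t> {
    pub tape: &'t Tape,
    pub index: usize,
    pub value: f64,
}

impl Tape {
    /// Creates an input variable. A leaf has no parents, so both slots
    /// point at the leaf itself with zero weight.
    pub fn var(&self, value: f64) -> Var<'_> {
        let index = self.nodes.borrow().len();
        self.push([(index, 0.0), (index, 0.0)], value)
    }

    fn push(&self, parents: [(usize, f64); 2], value: f64) -> Var<'_> {
        let mut nodes = self.nodes.borrow_mut();
        let index = nodes.len();
        nodes.push(parents);
        Var { tape: self, index, value }
    }

    /// Reverse sweep: returns d(output)/d(node) for every node on the tape.
    pub fn gradient(&self, output: Var<'_>) -> Vec<f64> {
        let nodes = self.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[output.index] = 1.0;
        for i in (0..nodes.len()).rev() {
            let a = adjoints[i];
            for &(parent, partial) in &nodes[i] {
                adjoints[parent] += a * partial;
            }
        }
        adjoints
    }
}

// Variable * Variable: d(xy)/dx = y, d(xy)/dy = x.
impl<'t> Mul<Var<'t>> for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        let parents = [(self.index, rhs.value), (rhs.index, self.value)];
        self.tape.push(parents, self.value * rhs.value)
    }
}

// Variable * f64: the constant is not on the tape, so there is one real parent.
impl<'t> Mul<f64> for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: f64) -> Var<'t> {
        self.tape.push([(self.index, rhs), (self.index, 0.0)], self.value * rhs)
    }
}

// f64 * Variable: delegates to the impl above, since multiplication commutes.
impl<'t> Mul<Var<'t>> for f64 {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        rhs * self
    }
}

/// Uses all three impls in one expression; returns (value, dz/dx, dz/dy).
pub fn demo() -> (f64, f64, f64) {
    let tape = Tape::default();
    let x = tape.var(5.0);
    let y = tape.var(2.0);
    let z = 3.0 * (x * y) * 0.5; // f64 * Var, Var * Var, Var * f64
    let grad = tape.gradient(z);
    (z.value, grad[x.index], grad[y.index])
}
```

Calling `demo()` evaluates z = 1.5 * x * y at x = 5, y = 2. Implementing the trait on `f64` is what allows a constant to appear on the left-hand side of the operator.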
Beyond standard arithmetic, the module implements a comprehensive set of primitive mathematical functions as methods on Variable, including:
- Trigonometric: sin, cos, tan, asin, acos, atan
- Hyperbolic: sinh, cosh, tanh, asinh, acosh, atanh
- Exponential/Logarithmic: exp, exp2, exp_m1, ln, ln_1p, log10, log2
- Power: sqrt, cbrt, recip, abs
- Error functions: erf, erfc
The module also defines custom traits for operations not available through std::ops:
- Powf<T> -- floating-point exponentiation (x^y, where both may be variables)
- Powi<T> -- integer exponentiation (x^n)
- Log<T> -- logarithm with variable base
- Min<T> and Max<T> -- differentiable min/max with subgradients
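For reference, the partial derivatives that operations of this kind need to record follow from standard calculus identities. They are listed here as a worked reminder, not as a transcription of the module's code. How ties in min/max are handled is a convention that this page does not specify.

```latex
\begin{aligned}
z = x^{y}: \quad & \frac{\partial z}{\partial x} = y\,x^{y-1}, \qquad
                   \frac{\partial z}{\partial y} = x^{y}\ln x \quad (x > 0) \\
z = x^{n}: \quad & \frac{\partial z}{\partial x} = n\,x^{n-1} \\
z = \log_{b} x = \frac{\ln x}{\ln b}: \quad &
                   \frac{\partial z}{\partial x} = \frac{1}{x \ln b}, \qquad
                   \frac{\partial z}{\partial b} = -\frac{\ln x}{b\,(\ln b)^{2}} \\
z = \min(x, y): \quad &
                   \frac{\partial z}{\partial x} =
                   \begin{cases} 1 & x < y \\ 0 & x > y \end{cases}, \qquad
                   \frac{\partial z}{\partial y} = 1 - \frac{\partial z}{\partial x}
                   \quad (x \neq y)
\end{aligned}
```

The max case mirrors min with the inequalities reversed.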
Additionally, Sum and Product iterator traits are implemented for Variable, enabling idiomatic Rust iterator usage such as .sum() and .product() over collections of variables.
Each implementation pushes a new vertex onto the computation graph via graph.push(), recording the operation's arity (nullary, unary, or binary), parent indices, and partial derivatives with respect to each parent.
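The vertex-recording step can be sketched with a small stand-alone graph. The names `ToyGraph`, `Vertex`, `Arity`, `adjoints`, and `demo` are assumptions made for illustration, as is the shape of `push`; the real `graph.push()` signature is not shown on this page.

```rust
use std::cell::RefCell;

/// Number of parents an operation has (hypothetical enum for this sketch).
#[derive(Clone, Copy)]
pub enum Arity {
    Nullary,
    Unary,
    Binary,
}

/// One recorded operation: parent indices and the partial derivative
/// of this vertex with respect to each parent.
#[derive(Clone, Copy)]
pub struct Vertex {
    pub parents: [usize; 2],
    pub partials: [f64; 2],
}

#[derive(Default)]
pub struct ToyGraph {
    vertices: RefCell<Vec<Vertex>>,
}

impl ToyGraph {
    /// Appends a vertex and returns its index.
    pub fn push(&self, arity: Arity, parents: &[usize], partials: &[f64]) -> usize {
        let mut vertices = self.vertices.borrow_mut();
        let index = vertices.len();
        let vertex = match arity {
            // Input variable: no parents, so both slots point at itself with zero weight.
            Arity::Nullary => Vertex { parents: [index, index], partials: [0.0, 0.0] },
            // One parent; the unused slot points at itself with zero weight.
            Arity::Unary => Vertex { parents: [parents[0], index], partials: [partials[0], 0.0] },
            Arity::Binary => Vertex {
                parents: [parents[0], parents[1]],
                partials: [partials[0], partials[1]],
            },
        };
        vertices.push(vertex);
        index
    }

    /// Reverse-mode sweep: propagates adjoints from `output` back to every vertex.
    pub fn adjoints(&self, output: usize) -> Vec<f64> {
        let vertices = self.vertices.borrow();
        let mut adjoints = vec![0.0; vertices.len()];
        adjoints[output] = 1.0;
        for i in (0..vertices.len()).rev() {
            let a = adjoints[i];
            let v = vertices[i];
            adjoints[v.parents[0]] += a * v.partials[0];
            adjoints[v.parents[1]] += a * v.partials[1];
        }
        adjoints
    }
}

/// Records z = sin(x) * y at x = 1, y = 2 and returns (dz/dx, dz/dy).
pub fn demo() -> (f64, f64) {
    let g = ToyGraph::default();
    let (xv, yv) = (1.0_f64, 2.0_f64);
    let x = g.push(Arity::Nullary, &[], &[]);
    let y = g.push(Arity::Nullary, &[], &[]);
    // s = sin(x): unary, ds/dx = cos(x).
    let s = g.push(Arity::Unary, &[x], &[xv.cos()]);
    // z = s * y: binary, dz/ds = y and dz/dy = s.
    let z = g.push(Arity::Binary, &[s, y], &[yv, xv.sin()]);
    let adj = g.adjoints(z);
    (adj[x], adj[y])
}
```

Because each vertex stores only local partials, the chain rule is applied entirely in the reverse sweep: the adjoint of x picks up y * cos(x) by passing through the intermediate vertex s.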
Usage
Use the overloaded operators whenever you build differentiable expressions with Variable values. Standard Rust syntax applies -- no special function calls are needed for basic arithmetic. For power, logarithm, min, and max operations, import and use the corresponding custom traits (Powf, Powi, Log, Min, Max).
Code Reference
Source Location
- Repository: RustQuant
- File: crates/RustQuant_autodiff/src/overload.rs
- Lines: 1-2005
Signature
// Standard arithmetic operators (Add, Sub, Mul, Div, Neg, and Assign variants)
impl<'v> Add<Variable<'v>> for Variable<'v> {
type Output = Variable<'v>;
fn add(self, other: Variable<'v>) -> Self::Output;
}
impl<'v> Sub<Variable<'v>> for Variable<'v> {
type Output = Variable<'v>;
fn sub(self, other: Variable<'v>) -> Self::Output;
}
impl<'v> Mul<Variable<'v>> for Variable<'v> {
type Output = Variable<'v>;
fn mul(self, other: Variable<'v>) -> Self::Output;
}
impl<'v> Div<Variable<'v>> for Variable<'v> {
type Output = Variable<'v>;
fn div(self, other: Variable<'v>) -> Self::Output;
}
impl<'v> Neg for Variable<'v> {
type Output = Self;
fn neg(self) -> Self::Output;
}
// Primitive mathematical functions on Variable
impl<'v> Variable<'v> {
pub fn abs(self) -> Self;
pub fn acos(self) -> Self;
pub fn acosh(self) -> Self;
pub fn asin(self) -> Self;
pub fn asinh(self) -> Self;
pub fn atan(self) -> Self;
pub fn atanh(self) -> Self;
pub fn cbrt(self) -> Self;
pub fn cos(self) -> Self;
pub fn cosh(self) -> Self;
pub fn exp(self) -> Self;
pub fn exp2(self) -> Self;
pub fn exp_m1(self) -> Self;
pub fn ln(self) -> Self;
pub fn ln_1p(self) -> Self;
pub fn log10(self) -> Self;
pub fn log2(self) -> Self;
pub fn recip(self) -> Self;
pub fn sin(self) -> Self;
pub fn sinh(self) -> Self;
pub fn sqrt(self) -> Self;
pub fn tan(self) -> Self;
pub fn tanh(self) -> Self;
pub fn erf(self) -> Self;
pub fn erfc(self) -> Self;
}
// Custom traits for power, log, min, max
pub trait Powf<T> {
type Output;
fn powf(&self, other: T) -> Self::Output;
}
pub trait Powi<T> {
type Output;
fn powi(&self, other: T) -> Self::Output;
}
pub trait Log<T> {
type Output;
fn log(&self, base: T) -> Self::Output;
}
pub trait Min<T> {
type Output;
fn min(&self, other: T) -> Self::Output;
}
pub trait Max<T> {
type Output;
fn max(&self, other: T) -> Self::Output;
}
// Iterator traits
impl<'v> Sum<Variable<'v>> for Variable<'v>;
impl<'v> Product<Variable<'v>> for Variable<'v>;
Import
use RustQuant::autodiff::*;
// Custom traits (needed for .powf(), .log(), Min::min(), Max::max())
use RustQuant::autodiff::{Powf, Powi, Log, Min, Max};
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| self | Variable<'v> | Yes | The left-hand operand (or sole operand for unary functions) |
| other | Variable<'v> or f64 or i32 | Depends on operation | The right-hand operand for binary operations |
Outputs
| Name | Type | Description |
|---|---|---|
| result | Variable<'v> | A new variable whose value contains the computed result and whose index points to the newly created vertex in the computation graph, recording partial derivatives for reverse-mode accumulation. |
Usage Examples
Arithmetic Operators
use RustQuant_autodiff::*;
let g = Graph::new();
let x = g.var(5.0);
let y = g.var(2.0);
// Addition: d/dx (x+y) = 1, d/dy (x+y) = 1
let z = x + y;
let grad = z.accumulate();
assert_eq!(z.value, 7.0);
assert_eq!(grad.wrt(&x), 1.0);
assert_eq!(grad.wrt(&y), 1.0);
Multiplication with Mixed Types
use RustQuant_autodiff::*;
let g = Graph::new();
let x = g.var(5.0);
let a = 2.0_f64;
// Variable * f64
let z = x * a;
let grad = z.accumulate();
assert_eq!(z.value, 10.0);
assert_eq!(grad.wrt(&x), 2.0);
Trigonometric Functions
use RustQuant_autodiff::*;
let g = Graph::new();
let x = g.var(1.0);
// d/dx sin(x) = cos(x)
let z = x.sin();
let grad = z.accumulate();
// z.value ~ 0.8415, grad.wrt(&x) ~ 0.5403
Power Function with Powf Trait
use RustQuant_autodiff::*;
let g = Graph::new();
let x = g.var(2.0);
// x^3.0, d/dx = 3 * x^2 = 12.0
let z = x.powf(3.0);
let grad = z.accumulate();
assert_eq!(z.value, 8.0);
Sum and Product Iterators
use RustQuant_autodiff::*;
let g = Graph::new();
let params = (0..100).map(|x| g.var(x as f64)).collect::<Vec<_>>();
let sum = params.iter().copied().sum::<Variable>();
let derivs = sum.accumulate();
// d/dx_i (sum of all x_j) = 1.0 for all i
for i in derivs.wrt(&params) {
assert_eq!(i, 1.0);
}