Implementation:Avhz RustQuant Overload Operators

From Leeroopedia


Knowledge Sources
Domains Automatic_Differentiation, Mathematics
Last Updated 2026-02-07 19:00 GMT

Overview

This module overloads Rust's standard arithmetic operators and mathematical functions for the Variable type so that every operation automatically records partial derivatives on the computational graph.

Description

The overload module is the core mechanism that makes automatic differentiation transparent to users of RustQuant. It provides trait implementations on the Variable<'v> type for the binary arithmetic operators (+, -, *, /), unary negation (-), and the compound assignment forms (+=, -=, *=, /=). Each binary operator is implemented for three operand combinations: Variable op Variable, Variable op f64, and f64 op Variable.

Beyond standard arithmetic, the module implements a comprehensive set of primitive mathematical functions as methods on Variable, including:

  • Trigonometric: sin, cos, tan, asin, acos, atan
  • Hyperbolic: sinh, cosh, tanh, asinh, acosh, atanh
  • Exponential/Logarithmic: exp, exp2, exp_m1, ln, ln_1p, log10, log2
  • Roots, reciprocal, and absolute value: sqrt, cbrt, recip, abs
  • Error functions: erf, erfc
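
For reference, each of these primitives records the textbook derivative of the function as its single local partial. The identities below are standard calculus for a representative subset of the list; they are not copied from the RustQuant source, and how the library treats points where a derivative is undefined (for example abs at 0) is not stated here.

```latex
\begin{aligned}
\frac{d}{dx}\sin x &= \cos x, &
\frac{d}{dx}\tanh x &= 1 - \tanh^2 x, \\
\frac{d}{dx} e^{x} &= e^{x}, &
\frac{d}{dx}\ln x &= \frac{1}{x}, \\
\frac{d}{dx}\sqrt{x} &= \frac{1}{2\sqrt{x}}, &
\frac{d}{dx}\frac{1}{x} &= -\frac{1}{x^{2}}, \\
\frac{d}{dx}\operatorname{erf}(x) &= \frac{2}{\sqrt{\pi}}\, e^{-x^{2}}, &
\frac{d}{dx}\operatorname{erfc}(x) &= -\frac{2}{\sqrt{\pi}}\, e^{-x^{2}}.
\end{aligned}
```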

The module also defines custom traits for operations not available through std::ops:

  • Powf<T> -- floating-point exponentiation (x^y where both may be variables)
  • Powi<T> -- integer exponentiation (x^n)
  • Log<T> -- logarithm with variable base
  • Min<T> and Max<T> -- differentiable min/max with subgradients
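
To make the local partials behind these traits concrete, here is a minimal sketch on plain f64 values. The function names (powf_partials, powi_partial, log_partials, max_partials, min_partials) are hypothetical helpers for illustration, not RustQuant APIs, and the tie-breaking rule for min/max (the first operand wins when both are equal) is an assumption rather than documented library behaviour.

```rust
/// Local partials of f(x, y) = x^y, returned as (df/dx, df/dy).
/// df/dx = y * x^(y-1), df/dy = x^y * ln(x) (requires x > 0).
fn powf_partials(x: f64, y: f64) -> (f64, f64) {
    (y * x.powf(y - 1.0), x.powf(y) * x.ln())
}

/// Partial of f(x) = x^n with respect to x for an integer exponent n.
fn powi_partial(x: f64, n: i32) -> f64 {
    f64::from(n) * x.powi(n - 1)
}

/// Local partials of f(x, b) = log_b(x) = ln(x) / ln(b), as (df/dx, df/db).
fn log_partials(x: f64, b: f64) -> (f64, f64) {
    let ln_b = b.ln();
    (1.0 / (x * ln_b), -x.ln() / (b * ln_b * ln_b))
}

/// One valid subgradient of max(x, y): all weight goes to the larger operand.
/// Ties go to x (an assumption made for this sketch).
fn max_partials(x: f64, y: f64) -> (f64, f64) {
    if x >= y { (1.0, 0.0) } else { (0.0, 1.0) }
}

/// One valid subgradient of min(x, y), with the same tie-breaking assumption.
fn min_partials(x: f64, y: f64) -> (f64, f64) {
    if x <= y { (1.0, 0.0) } else { (0.0, 1.0) }
}
```

When the exponent or base is a plain f64 rather than a variable, only the first component of each pair would need to be recorded.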

Additionally, Sum and Product iterator traits are implemented for Variable, enabling idiomatic Rust iterator usage such as .sum() and .product() over collections of variables.
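
The following sketch shows how Sum and Product can be wired to overloaded + and * through a fold. It deliberately uses a hypothetical forward-mode Dual type (a value plus one derivative) as a stand-in so that it stays self-contained; it is not the Variable type and not RustQuant's implementation. In a tape-based design, each fold step would instead push one vertex onto the graph.

```rust
use std::iter::{Product, Sum};
use std::ops::{Add, Mul};

/// Hypothetical stand-in type: a value and its derivative with respect
/// to one chosen input (forward mode).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    value: f64,
    deriv: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { value: self.value + rhs.value, deriv: self.deriv + rhs.deriv }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule: (ab)' = a'b + ab'
        Dual {
            value: self.value * rhs.value,
            deriv: self.deriv * rhs.value + self.value * rhs.deriv,
        }
    }
}

impl Sum for Dual {
    fn sum<I: Iterator<Item = Dual>>(iter: I) -> Dual {
        // Start from the additive identity and reuse the overloaded `+`.
        iter.fold(Dual { value: 0.0, deriv: 0.0 }, |acc, x| acc + x)
    }
}

impl Product for Dual {
    fn product<I: Iterator<Item = Dual>>(iter: I) -> Dual {
        // Start from the multiplicative identity and reuse the overloaded `*`.
        iter.fold(Dual { value: 1.0, deriv: 0.0 }, |acc, x| acc * x)
    }
}
```

With these two impls in place, xs.iter().copied().sum::<Dual>() and .product::<Dual>() work on any iterator of Dual values, which mirrors the .sum() and .product() usage described above.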

Each implementation pushes a new vertex onto the computation graph via graph.push(), recording the operation's arity (nullary, unary, or binary), parent indices, and partial derivatives with respect to each parent.
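
The recording scheme can be illustrated with a small self-contained tape. This is a sketch under stated assumptions, not RustQuant's code: Tape, Vertex, and Var are hypothetical names, only + and * are covered, and accumulate returns a plain Vec<f64> of adjoints indexed by vertex. It shows the three operand combinations (Var op Var, Var op f64, f64 op Var) and how the stored partials drive a reverse sweep.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// One recorded operation: its parents and the local partial derivative
/// of the result with respect to each parent.
#[derive(Clone, Copy)]
enum Vertex {
    Nullary, // an input variable, no parents
    Unary { parent: usize, partial: f64 },
    Binary { parents: [usize; 2], partials: [f64; 2] },
}

/// Hypothetical tape standing in for the computation graph.
struct Tape {
    vertices: RefCell<Vec<Vertex>>,
}

#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Tape { vertices: RefCell::new(Vec::new()) }
    }

    /// Append a vertex and return its index.
    fn push(&self, v: Vertex) -> usize {
        let mut vs = self.vertices.borrow_mut();
        vs.push(v);
        vs.len() - 1
    }

    fn var(&self, value: f64) -> Var<'_> {
        Var { tape: self, index: self.push(Vertex::Nullary), value }
    }
}

// Var + Var: both partials are 1.
impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, rhs: Var<'t>) -> Var<'t> {
        let index = self.tape.push(Vertex::Binary {
            parents: [self.index, rhs.index],
            partials: [1.0, 1.0],
        });
        Var { tape: self.tape, index, value: self.value + rhs.value }
    }
}

// Var * Var: d(xy)/dx = y, d(xy)/dy = x.
impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        let index = self.tape.push(Vertex::Binary {
            parents: [self.index, rhs.index],
            partials: [rhs.value, self.value],
        });
        Var { tape: self.tape, index, value: self.value * rhs.value }
    }
}

// Var * f64: the constant has no vertex, so the operation is unary.
impl<'t> Mul<f64> for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, rhs: f64) -> Var<'t> {
        let index = self.tape.push(Vertex::Unary { parent: self.index, partial: rhs });
        Var { tape: self.tape, index, value: self.value * rhs }
    }
}

// f64 * Var: delegate to Var * f64.
impl<'t> Mul<Var<'t>> for f64 {
    type Output = Var<'t>;
    fn mul(self, rhs: Var<'t>) -> Var<'t> {
        rhs * self
    }
}

impl<'t> Var<'t> {
    /// Reverse sweep: seed this vertex with 1 and propagate adjoints to parents.
    fn accumulate(&self) -> Vec<f64> {
        let vs = self.tape.vertices.borrow();
        let mut adjoints = vec![0.0; vs.len()];
        adjoints[self.index] = 1.0;
        for i in (0..=self.index).rev() {
            let a = adjoints[i];
            match vs[i] {
                Vertex::Nullary => {}
                Vertex::Unary { parent, partial } => adjoints[parent] += a * partial,
                Vertex::Binary { parents, partials } => {
                    adjoints[parents[0]] += a * partials[0];
                    adjoints[parents[1]] += a * partials[1];
                }
            }
        }
        adjoints
    }
}
```

For z = x * y + 3.0 * x built on this tape, z.accumulate()[x.index] yields y + 3 and z.accumulate()[y.index] yields x. Because vertices are appended in evaluation order, iterating indices in reverse visits every vertex after all of its consumers.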

Usage

Use the overloaded operators whenever you build differentiable expressions with Variable values. Standard Rust syntax applies -- no special function calls are needed for basic arithmetic. For power, logarithm, min, and max operations, import and use the corresponding custom traits (Powf, Powi, Log, Min, Max).

Code Reference

Source Location

Signature

// Standard arithmetic operators (Add, Sub, Mul, Div, Neg, and Assign variants)
impl<'v> Add<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;
    fn add(self, other: Variable<'v>) -> Self::Output;
}

impl<'v> Sub<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;
    fn sub(self, other: Variable<'v>) -> Self::Output;
}

impl<'v> Mul<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;
    fn mul(self, other: Variable<'v>) -> Self::Output;
}

impl<'v> Div<Variable<'v>> for Variable<'v> {
    type Output = Variable<'v>;
    fn div(self, other: Variable<'v>) -> Self::Output;
}

impl<'v> Neg for Variable<'v> {
    type Output = Self;
    fn neg(self) -> Self::Output;
}

// Primitive mathematical functions on Variable
impl<'v> Variable<'v> {
    pub fn abs(self) -> Self;
    pub fn acos(self) -> Self;
    pub fn acosh(self) -> Self;
    pub fn asin(self) -> Self;
    pub fn asinh(self) -> Self;
    pub fn atan(self) -> Self;
    pub fn atanh(self) -> Self;
    pub fn cbrt(self) -> Self;
    pub fn cos(self) -> Self;
    pub fn cosh(self) -> Self;
    pub fn exp(self) -> Self;
    pub fn exp2(self) -> Self;
    pub fn exp_m1(self) -> Self;
    pub fn ln(self) -> Self;
    pub fn ln_1p(self) -> Self;
    pub fn log10(self) -> Self;
    pub fn log2(self) -> Self;
    pub fn recip(self) -> Self;
    pub fn sin(self) -> Self;
    pub fn sinh(self) -> Self;
    pub fn sqrt(self) -> Self;
    pub fn tan(self) -> Self;
    pub fn tanh(self) -> Self;
    pub fn erf(self) -> Self;
    pub fn erfc(self) -> Self;
}

// Custom traits for power, log, min, max
pub trait Powf<T> {
    type Output;
    fn powf(&self, other: T) -> Self::Output;
}

pub trait Powi<T> {
    type Output;
    fn powi(&self, other: T) -> Self::Output;
}

pub trait Log<T> {
    type Output;
    fn log(&self, base: T) -> Self::Output;
}

pub trait Min<T> {
    type Output;
    fn min(&self, other: T) -> Self::Output;
}

pub trait Max<T> {
    type Output;
    fn max(&self, other: T) -> Self::Output;
}

// Iterator traits
impl<'v> Sum<Variable<'v>> for Variable<'v>;
impl<'v> Product<Variable<'v>> for Variable<'v>;

Import

use RustQuant::autodiff::*;
// Custom traits, listed explicitly (the glob import above already brings them into scope).
// They are needed for .powf(), .powi(), .log(), Min::min(), Max::max().
use RustQuant::autodiff::{Powf, Powi, Log, Min, Max};

I/O Contract

Inputs

Name | Type | Required | Description
self | Variable<'v> | Yes | The left-hand operand (or sole operand for unary functions)
other | Variable<'v>, f64, or i32 | Depends on operation | The right-hand operand for binary operations

Outputs

Name | Type | Description
result | Variable<'v> | A new variable whose value holds the computed result and whose index points to the newly created vertex in the computation graph, which records the partial derivatives used for reverse-mode accumulation.

Usage Examples

Arithmetic Operators

use RustQuant_autodiff::*;

let g = Graph::new();

let x = g.var(5.0);
let y = g.var(2.0);

// Addition: d/dx (x+y) = 1, d/dy (x+y) = 1
let z = x + y;
let grad = z.accumulate();
assert_eq!(z.value, 7.0);
assert_eq!(grad.wrt(&x), 1.0);
assert_eq!(grad.wrt(&y), 1.0);

Multiplication with Mixed Types

use RustQuant_autodiff::*;

let g = Graph::new();

let x = g.var(5.0);
let a = 2.0_f64;
// Variable * f64
let z = x * a;
let grad = z.accumulate();
assert_eq!(z.value, 10.0);
assert_eq!(grad.wrt(&x), 2.0);

Trigonometric Functions

use RustQuant_autodiff::*;

let g = Graph::new();

let x = g.var(1.0);
// d/dx sin(x) = cos(x)
let z = x.sin();
let grad = z.accumulate();
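// z.value ~ 0.8415, grad.wrt(&x) ~ 0.5403
assert!((z.value - 1.0_f64.sin()).abs() < 1e-12);
assert!((grad.wrt(&x) - 1.0_f64.cos()).abs() < 1e-12);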

Power Function with Powf Trait

use RustQuant_autodiff::*;

let g = Graph::new();

let x = g.var(2.0);
// x^3.0, d/dx = 3 * x^2 = 12.0
let z = x.powf(3.0);
let grad = z.accumulate();
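assert_eq!(z.value, 8.0);
// Compare the derivative with a tolerance, since it is a floating-point result
assert!((grad.wrt(&x) - 12.0).abs() < 1e-12);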

Sum and Product Iterators

use RustQuant_autodiff::*;

let g = Graph::new();

let params = (0..100).map(|x| g.var(x as f64)).collect::<Vec<_>>();
let sum = params.iter().copied().sum::<Variable>();
let derivs = sum.accumulate();

// d/dx_i (sum of all x_j) = 1.0 for all i
for i in derivs.wrt(&params) {
    assert_eq!(i, 1.0);
}
