

Implementation:LaurentMazare Tch rs Sparse Adam

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Optimization, Sparse Computation
Last Updated: 2026-02-08 00:00 GMT

Overview

A custom Adam optimizer implementation in Rust that supports both sparse and dense gradient updates, demonstrating how to build optimizers from scratch using the tch-rs variable store.

Description

The SparseAdam optimizer implements the Adam optimization algorithm with support for sparse gradients. It is composed of two main structs:

  • Buffer: Holds the optimizer state for one trainable variable: the first-moment (mean) and second-moment (uncentered variance) estimates, plus a timestep counter used for bias correction. Both moments are initialized as zero tensors with the variable's shape.
  • SparseAdam: The optimizer struct that holds hyperparameters (learning rate, beta1, beta2, epsilon), a force_sparse flag, a shared reference to the variable store's Variables (via Arc<Mutex<Variables>>), and per-variable buffers.
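Buffer's bookkeeping can be sketched without libtorch. In this minimal std-only version a flat Vec<f64> stands in for each tch Tensor, and it assumes (the description above does not say so) that the counter starts at 1 and that inc() returns the current step before advancing, so the first update uses t = 1 in the bias-correction terms.

```rust
/// Std-only stand-in for the example's `Buffer` (illustrative, not the original code).
/// Assumption: `idx` starts at 1 and `inc()` returns the step before advancing.
struct Buffer {
    first_moment: Vec<f64>,
    second_moment: Vec<f64>,
    idx: usize,
}

impl Buffer {
    /// Zero-initialised moments with one entry per element of a tensor of shape `size`.
    fn new(size: &[i64]) -> Buffer {
        let numel: i64 = size.iter().product();
        Buffer {
            first_moment: vec![0.0; numel as usize],
            second_moment: vec![0.0; numel as usize],
            idx: 1,
        }
    }

    /// Returns the current timestep `t` and advances the counter.
    fn inc(&mut self) -> usize {
        let t = self.idx;
        self.idx += 1;
        t
    }
}

fn main() {
    let mut buf = Buffer::new(&[3, 4]);
    assert_eq!(buf.first_moment.len(), 12);
    assert!(buf.second_moment.iter().all(|&x| x == 0.0));
    // Each optimizer step asks the buffer for its timestep.
    let t1 = buf.inc();
    let t2 = buf.inc();
    println!("t1 = {}, t2 = {}", t1, t2); // t1 = 1, t2 = 2
}
```

Keeping the counter per variable (rather than global) means each parameter's bias correction depends only on how many steps it has received.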

The update step (_step) iterates over all trainable variables and, for each one, branches on whether its gradient is sparse or force_sparse is set:

  • Sparse path: Converts dense gradients to sparse format if force_sparse is enabled. Deduplicates sparse indices with coalesce(), extracts indices and values, then performs indexed updates on the moment buffers using index_add_. The parameter update uses bias-corrected moments applied only at the non-zero gradient indices.
  • Dense path: Performs standard Adam updates using exponential moving averages of the gradient (first moment) and squared gradient (second moment), with bias correction. Parameters are updated in-place using addcdiv_.
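The two paths can be modeled without libtorch. This is a hypothetical std-only sketch, not the example's code: a parameter is a Vec of rows (Vec<Vec<f64>>) instead of a tch Tensor, the sparse gradient is assumed to be already coalesced into unique row indices plus value rows, and Hyper, adam_elem, dense_step and sparse_step are illustrative names. Both paths share one bias-corrected update; the sparse one restricts it to the listed rows, which is what the index_select/index_add_ pattern achieves on tensors.

```rust
/// Hyperparameters mirroring SparseAdam's lr/beta1/beta2/eps fields (illustrative).
struct Hyper {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
}

type Mat = Vec<Vec<f64>>;

/// One Adam update for a single element; `t` is the 1-based step for bias correction.
fn adam_elem(h: &Hyper, t: usize, p: &mut f64, m: &mut f64, v: &mut f64, g: f64) {
    *m += (1.0 - h.beta1) * (g - *m); // same as beta1 * m + (1 - beta1) * g
    *v += (1.0 - h.beta2) * (g * g - *v); // same as beta2 * v + (1 - beta2) * g^2
    let bc1 = 1.0 - h.beta1.powi(t as i32);
    let bc2 = 1.0 - h.beta2.powi(t as i32);
    *p -= h.lr * (*m / bc1) / ((*v / bc2).sqrt() + h.eps);
}

/// Dense path: every element is updated, including those whose gradient is zero.
fn dense_step(h: &Hyper, t: usize, p: &mut Mat, m: &mut Mat, v: &mut Mat, g: &Mat) {
    for r in 0..p.len() {
        for c in 0..p[r].len() {
            adam_elem(h, t, &mut p[r][c], &mut m[r][c], &mut v[r][c], g[r][c]);
        }
    }
}

/// Sparse path: only rows in `idx` (assumed unique, i.e. coalesced) are touched;
/// `vals[k]` is the gradient row for parameter row `idx[k]`.
fn sparse_step(h: &Hyper, t: usize, p: &mut Mat, m: &mut Mat, v: &mut Mat, idx: &[usize], vals: &Mat) {
    for (k, &r) in idx.iter().enumerate() {
        for c in 0..p[r].len() {
            adam_elem(h, t, &mut p[r][c], &mut m[r][c], &mut v[r][c], vals[k][c]);
        }
    }
}

fn main() {
    let h = Hyper { lr: 0.1, beta1: 0.9, beta2: 0.999, eps: 1e-8 };
    let zeros = || vec![vec![0.0; 2]; 3];

    // A 3x2 parameter whose gradient is non-zero only in row 1, in sparse form.
    let (mut ps, mut ms, mut vs) = (vec![vec![1.0; 2]; 3], zeros(), zeros());
    sparse_step(&h, 1, &mut ps, &mut ms, &mut vs, &[1], &vec![vec![0.5, -0.5]]);

    // The same gradient in dense form.
    let (mut pd, mut md, mut vd) = (vec![vec![1.0; 2]; 3], zeros(), zeros());
    let g = vec![vec![0.0, 0.0], vec![0.5, -0.5], vec![0.0, 0.0]];
    dense_step(&h, 1, &mut pd, &mut md, &mut vd, &g);

    println!("sparse: {:?}", ps);
    println!("dense:  {:?}", pd);
}
```

On the first step both paths give the same result, because rows with zero gradient also have zero moments. They diverge later: once a row has non-zero moments, the dense path keeps decaying them and moving that row on steps where its gradient is zero, while the sparse path leaves the row untouched.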

The step method wraps _step inside tch::no_grad, so the in-place updates to the parameters and moment buffers are not recorded by autograd.
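For reference, these are the textbook Adam recurrences that the dense path corresponds to, with η = lr, t the per-variable step count, squares and square roots taken element-wise, and ε added outside the square root. The last line is a sketch of what "applied only at the non-zero gradient indices" means for the sparse path, assuming it applies the same recurrences to each row i in the set I_t of rows that have a gradient entry at step t: rewriting the moment update as an increment shows why an index_add_ of (1 − β1)(g − m) at those rows reproduces the dense formula there, while rows outside I_t keep their m, v and θ unchanged.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^{2} \\
\theta_t &= \theta_{t-1} - \eta\,\frac{m_t/(1-\beta_1^{t})}{\sqrt{v_t/(1-\beta_2^{t})} + \epsilon} \\
m_{t,i} &= m_{t-1,i} + (1-\beta_1)\,\bigl(g_{t,i} - m_{t-1,i}\bigr), \qquad i \in I_t
\end{aligned}
```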

Usage

Use this implementation when you need an Adam optimizer that efficiently handles sparse gradients (e.g., for embedding layers), or as a reference for building custom optimizers that interact with the tch-rs VarStore and Variables internals.
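Embedding layers are the typical case: a lookup reads only the rows for the token IDs in the batch, so only those rows receive gradient. The std-only sketch below (embedding_grad is a hypothetical name, and the upstream gradients are made-up numbers) builds such a gradient as (row index, value row) pairs and sums duplicate indices, which is what coalescing a sparse COO tensor does. The resulting unique, sorted indices are the only rows the sparse path would touch.

```rust
use std::collections::BTreeMap;

/// Hypothetical helper: the gradient of an embedding lookup in sparse
/// (row index, value row) form, coalesced so each row index appears once.
/// `upstream[k]` is the gradient flowing into the k-th looked-up vector.
fn embedding_grad(token_ids: &[usize], upstream: &[Vec<f64>]) -> (Vec<usize>, Vec<Vec<f64>>) {
    let mut summed: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
    for (&id, g) in token_ids.iter().zip(upstream) {
        let row = summed.entry(id).or_insert_with(|| vec![0.0; g.len()]);
        for (acc, x) in row.iter_mut().zip(g) {
            *acc += x; // duplicate token ids accumulate into one row
        }
    }
    // BTreeMap iterates keys in ascending order: indices come out sorted and unique.
    summed.into_iter().unzip()
}

fn main() {
    // Four looked-up tokens; token 7 appears twice.
    let ids = [7, 2, 7, 5];
    let upstream = vec![vec![1.0, 0.0], vec![0.5, 0.5], vec![2.0, 1.0], vec![0.0, -1.0]];
    let (idx, vals) = embedding_grad(&ids, &upstream);
    // Only rows 2, 5 and 7 of the embedding table would be updated.
    println!("{:?} {:?}", idx, vals); // [2, 5, 7] [[0.5, 0.5], [0.0, -1.0], [3.0, 1.0]]
}
```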

Code Reference

Source Location

Signature

struct Buffer {
    pub first_moment: Tensor,
    pub second_moment: Tensor,
    idx: usize,
}

impl Buffer {
    pub fn new(size: &[i64]) -> Buffer
    pub fn inc(&mut self) -> usize
}

pub struct SparseAdam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    force_sparse: bool,
    vars: Arc<Mutex<Variables>>,
    buffers: Vec<Buffer>,
}

impl SparseAdam {
    pub fn new(vs: &VarStore, lr: f64, beta1: f64, beta2: f64, eps: f64, force_sparse: bool) -> SparseAdam
    pub fn step(&mut self)
    pub fn _step(&mut self)
    pub fn zero_grad(&mut self)
}

Import

// Used as a module within the custom-optimizer example:
// mod sparse_adam;
use std::sync::{Arc, Mutex};
use tch::nn::{VarStore, Variables};
use tch::{no_grad, Device, Kind, Tensor};

I/O Contract

Inputs

Name | Type | Required | Description
vs | &VarStore | Yes | The variable store containing the trainable parameters.
lr | f64 | Yes | Learning rate.
beta1 | f64 | Yes | Exponential decay rate for the first-moment estimate (typically 0.9).
beta2 | f64 | Yes | Exponential decay rate for the second-moment estimate (typically 0.999).
eps | f64 | Yes | Small constant for numerical stability (typically 1e-8).
force_sparse | bool | Yes | If true, dense gradients are converted to sparse before updating.
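Because SparseAdam::new takes its hyperparameters positionally, a small config value can keep call sites readable. The following is a hypothetical helper, not part of the example: AdamConfig and validate are illustrative names, the defaults are the typical values listed in the table plus the 5e-3 learning rate from the usage example, and the checks encode Adam's standard requirements (lr > 0, betas in [0, 1), eps > 0).

```rust
/// Hypothetical bundle of SparseAdam::new's arguments (everything except `vs`).
#[derive(Debug, Clone, Copy, PartialEq)]
struct AdamConfig {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    force_sparse: bool,
}

impl Default for AdamConfig {
    fn default() -> Self {
        // Typical Adam values; lr = 5e-3 matches the usage example.
        AdamConfig { lr: 5e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, force_sparse: false }
    }
}

impl AdamConfig {
    /// Rejects values outside Adam's usual domain; `!(x > 0.0)` also rejects NaN.
    fn validate(&self) -> Result<(), String> {
        if !(self.lr > 0.0) {
            return Err(format!("lr must be > 0, got {}", self.lr));
        }
        for &(name, b) in [("beta1", self.beta1), ("beta2", self.beta2)].iter() {
            if !(0.0..1.0).contains(&b) {
                return Err(format!("{} must be in [0, 1), got {}", name, b));
            }
        }
        if !(self.eps > 0.0) {
            return Err(format!("eps must be > 0, got {}", self.eps));
        }
        Ok(())
    }
}

fn main() {
    let cfg = AdamConfig::default();
    assert!(cfg.validate().is_ok());
    // beta1 = 1.0 would freeze the first moment and make 1 - beta1^t zero.
    let bad = AdamConfig { beta1: 1.0, ..cfg };
    println!("{:?}", bad.validate()); // Err("beta1 must be in [0, 1), got 1")
    // With tch available, the fields map directly onto the constructor:
    // SparseAdam::new(&vs, cfg.lr, cfg.beta1, cfg.beta2, cfg.eps, cfg.force_sparse)
}
```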

Outputs

Name | Type | Description
(in-place) | Tensor mutations | The trainable variable tensors in the VarStore are updated in place.

Usage Examples

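mod sparse_adam; // sparse_adam.rs from the custom-optimizer example

use tch::{nn, nn::Module, Device};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load MNIST (the dataset files are expected under ./data)
    let m = tch::vision::mnist::load_dir("data")?;

    // Build a model and variable store
    let vs = nn::VarStore::new(Device::Cpu);
    let net = nn::seq()
        .add(nn::linear(vs.root() / "layer1", 784, 128, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs.root() / "layer2", 128, 10, Default::default()));

    // Create the SparseAdam optimizer (force_sparse = false: dense gradients stay dense)
    let mut opt = sparse_adam::SparseAdam::new(&vs, 5e-3, 0.9, 0.999, 1e-8, false);

    // One training step: forward pass, clear old gradients, backpropagate, update
    let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
    opt.zero_grad();
    loss.backward();
    opt.step();
    Ok(())
}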
