Implementation:LaurentMazare Tch rs Sparse Adam
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Optimization, Sparse Computation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A custom Adam optimizer implementation in Rust that supports both sparse and dense gradient updates, demonstrating how to build optimizers from scratch using the tch-rs variable store.
Description
The SparseAdam optimizer implements the Adam optimization algorithm with support for sparse gradients. It is composed of two main structs:
Buffer: Stores the first-order moment (mean) and second-order moment (uncentered variance) for each trainable variable, along with a timestep counter. The moments are initialized to zero tensors matching the variable dimensions.
SparseAdam: The optimizer struct that holds the hyperparameters (learning rate, beta1, beta2, epsilon), a `force_sparse` flag, a shared reference to the variable store's `Variables` (via `Arc<Mutex<Variables>>`), and the per-variable buffers.
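To make the role of Buffer concrete, here is a minimal plain-Rust sketch that uses `Vec<f64>` in place of tch tensors. The names `MomentBuffer` and `bias_corrections` are hypothetical. The counter convention (increment before use, so the first step is t = 1) is an assumption and is not copied from the example's `idx` handling.

```rust
/// Hypothetical stand-in for the per-variable Buffer: two moment vectors with
/// the same length as the parameter, both starting at zero, plus a step counter.
/// (The real Buffer stores tch Tensors shaped like the variable.)
pub struct MomentBuffer {
    pub first_moment: Vec<f64>,
    pub second_moment: Vec<f64>,
    step: u32,
}

impl MomentBuffer {
    /// Zero-initialised moments matching the parameter length.
    pub fn new(len: usize) -> Self {
        MomentBuffer {
            first_moment: vec![0.0; len],
            second_moment: vec![0.0; len],
            step: 0,
        }
    }

    /// Advance the timestep and return the new value t (1 on the first call).
    pub fn inc(&mut self) -> u32 {
        self.step += 1;
        self.step
    }
}

/// Adam bias-correction factors (1 - beta1^t, 1 - beta2^t) for a timestep t >= 1.
pub fn bias_corrections(beta1: f64, beta2: f64, t: u32) -> (f64, f64) {
    (1.0 - beta1.powi(t as i32), 1.0 - beta2.powi(t as i32))
}
```

Keeping one counter per buffer means each variable tracks its own number of updates, and that count is the t used for its bias correction.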
The update step (_step) iterates through all trainable variables and branches based on whether the gradient is sparse:
- Sparse path: If `force_sparse` is enabled, dense gradients are first converted to sparse format. Sparse indices are deduplicated with `coalesce()`, the indices and values are extracted, and the moment buffers are updated at those indices with `index_add_`. The parameter update applies the bias-corrected moments only at the indices with a non-zero gradient.
- Dense path: Performs the standard Adam update, maintaining exponential moving averages of the gradient (first moment) and of the squared gradient (second moment), with bias correction. Parameters are updated in place using `addcdiv_`.
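For reference, one common formulation of the per-entry Adam update behind both paths is shown below. Here g is the gradient, m and v are the first and second moments, t is the timestep, and theta is the parameter. The last line assumes the sparse path adds the increment (1 - beta1)(g_i - m_i) at each touched index i. Under that assumption, it shows why an indexed add reproduces the dense moving average at that index. The exact placement of epsilon in the example's code is not asserted here.

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
\theta_t &= \theta_{t-1} - \mathrm{lr}\,\frac{m_t / (1-\beta_1^t)}{\sqrt{v_t / (1-\beta_2^t)} + \epsilon}, \\
m_{t,i} &= m_{t-1,i} + (1-\beta_1)\,(g_{t,i} - m_{t-1,i}) = \beta_1 m_{t-1,i} + (1-\beta_1)\, g_{t,i}.
\end{aligned}
$$

The same identity with beta2 and g_i^2 covers the second moment. Entries whose index does not appear in the sparse gradient are left untouched rather than decayed. This is where the sparse path differs from running the dense path on a gradient that is zero at those entries.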
The `step` method wraps `_step` inside `tch::no_grad` so that the in-place parameter and buffer updates are not recorded in the autograd computation graph.
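The arithmetic that `_step` performs in each branch can be sketched in plain Rust over slices instead of tch tensors. This is a minimal sketch under several assumptions. `AdamParams`, `dense_step` and `sparse_step` are hypothetical names. Parameters are treated as flat 1-D vectors, so indexing along dim 0 selects single entries rather than rows. The 1-based timestep `t` is passed in explicitly. Epsilon is added after the square root, which is one common Adam formulation.

```rust
/// Hyperparameters mirroring SparseAdam's fields (hypothetical plain-Rust sketch).
pub struct AdamParams {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
}

/// Dense path: every moment entry and every parameter entry is updated.
/// `t` is the 1-based timestep used for bias correction.
pub fn dense_step(p: &AdamParams, t: u32, theta: &mut [f64], m: &mut [f64], v: &mut [f64], grad: &[f64]) {
    let bc1 = 1.0 - p.beta1.powi(t as i32);
    let bc2 = 1.0 - p.beta2.powi(t as i32);
    for i in 0..theta.len() {
        m[i] = p.beta1 * m[i] + (1.0 - p.beta1) * grad[i];
        v[i] = p.beta2 * v[i] + (1.0 - p.beta2) * grad[i] * grad[i];
        theta[i] -= p.lr * (m[i] / bc1) / ((v[i] / bc2).sqrt() + p.eps);
    }
}

/// Sparse path: `indices` must already be coalesced (no duplicates).
/// Only the listed entries of m, v and theta change, mimicking index_add_.
pub fn sparse_step(p: &AdamParams, t: u32, theta: &mut [f64], m: &mut [f64], v: &mut [f64], indices: &[usize], values: &[f64]) {
    let bc1 = 1.0 - p.beta1.powi(t as i32);
    let bc2 = 1.0 - p.beta2.powi(t as i32);
    for (&i, &g) in indices.iter().zip(values) {
        // Adding (1 - beta) * (g - m_i) equals the moving average beta * m_i + (1 - beta) * g.
        m[i] += (1.0 - p.beta1) * (g - m[i]);
        v[i] += (1.0 - p.beta2) * (g * g - v[i]);
        theta[i] -= p.lr * (m[i] / bc1) / ((v[i] / bc2).sqrt() + p.eps);
    }
}
```

Running `sparse_step` with every index listed reproduces `dense_step`. Entries absent from `indices` keep both their parameter and moment values, which is what makes the sparse path cheap when only a few rows receive gradient.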
Usage
Use this implementation when you need an Adam optimizer that efficiently handles sparse gradients (e.g., for embedding layers), or as a reference for building custom optimizers that interact with the tch-rs VarStore and Variables internals.
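For an embedding layer, only the rows looked up in a batch receive gradient. A token that appears several times contributes several entries at the same index. The sketch below shows the deduplication the sparse path relies on. It is plain Rust, and `coalesce` here is a hypothetical helper operating on a 1-D index/value list; the example itself calls `coalesce()` on a tch sparse tensor.

```rust
use std::collections::BTreeMap;

/// Hypothetical sketch of coalescing a 1-D sparse gradient: values at
/// duplicate indices are summed, and indices come back in ascending order.
pub fn coalesce(indices: &[usize], values: &[f64]) -> (Vec<usize>, Vec<f64>) {
    let mut acc: BTreeMap<usize, f64> = BTreeMap::new();
    for (&i, &g) in indices.iter().zip(values) {
        *acc.entry(i).or_insert(0.0) += g;
    }
    acc.into_iter().unzip()
}
```

Suppose a batch looks up token ids 3, 1, 3 with per-lookup gradients 0.5, -0.25 and 1.0. Coalescing gives indices [1, 3] with values [-0.25, 1.5]. Summing before squaring matters for the second moment. Index 3 should contribute (0.5 + 1.0)^2 = 2.25, whereas squaring the duplicates separately would give 0.25 + 1.0 = 1.25.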
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: examples/custom-optimizer/sparse_adam.rs
- Lines: 1-147
Signature
```rust
struct Buffer {
    pub first_moment: Tensor,
    pub second_moment: Tensor,
    idx: usize,
}

impl Buffer {
    pub fn new(size: &[i64]) -> Buffer
    pub fn inc(&mut self) -> usize
}

pub struct SparseAdam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    force_sparse: bool,
    vars: Arc<Mutex<Variables>>,
    buffers: Vec<Buffer>,
}

impl SparseAdam {
    pub fn new(vs: &VarStore, lr: f64, beta1: f64, beta2: f64, eps: f64, force_sparse: bool) -> SparseAdam
    pub fn step(&mut self)
    pub fn _step(&mut self)
    pub fn zero_grad(&mut self)
}
```
Import
```rust
// Used as a module within the custom-optimizer example:
// mod sparse_adam;
use std::sync::{Arc, Mutex};
use tch::nn::{VarStore, Variables};
use tch::{no_grad, Device, Kind, Tensor};
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| vs | &VarStore | Yes | The variable store containing trainable parameters. |
| lr | f64 | Yes | Learning rate. |
| beta1 | f64 | Yes | Exponential decay rate for the first moment estimate (typically 0.9). |
| beta2 | f64 | Yes | Exponential decay rate for the second moment estimate (typically 0.999). |
| eps | f64 | Yes | Small constant for numerical stability (typically 1e-8). |
| force_sparse | bool | Yes | If true, converts dense gradients to sparse before updating. |
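The typical values listed above are beta1 = 0.9, beta2 = 0.999 and eps = 1e-8. With these, bias correction matters most at the start. On the first step the corrected moments equal g and g^2, so each updated entry moves by roughly lr in the direction opposite its gradient, regardless of the gradient's magnitude. The small plain-Rust check below illustrates that arithmetic. It uses a hypothetical helper and places eps outside the square root, which is an assumption about the exact formulation.

```rust
/// Hypothetical helper: parameter change on the very first Adam step
/// (t = 1, zero-initialised moments) for a single scalar gradient `g`.
pub fn first_step_delta(lr: f64, beta1: f64, beta2: f64, eps: f64, g: f64) -> f64 {
    let m = (1.0 - beta1) * g; // beta1 * 0 + (1 - beta1) * g
    let v = (1.0 - beta2) * g * g; // beta2 * 0 + (1 - beta2) * g^2
    let m_hat = m / (1.0 - beta1); // bias correction with t = 1 recovers g
    let v_hat = v / (1.0 - beta2); // ... and g^2
    -lr * m_hat / (v_hat.sqrt() + eps)
}
```

With lr = 5e-3, as in the usage example below, gradients of 0.2 and -1000.0 both produce a first-step change of magnitude about 5e-3.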
Outputs
| Name | Type | Description |
|---|---|---|
| (in-place) | Tensor mutations | Trainable variable tensors in the VarStore are updated in-place. |
Usage Examples
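```rust
mod sparse_adam; // examples/custom-optimizer/sparse_adam.rs
use anyhow::Result;
use tch::{nn, nn::Module, Device};

fn main() -> Result<()> {
    // Load MNIST from a local "data" directory (images flattened to 784 features).
    let m = tch::vision::mnist::load_dir("data")?;
    // Build a model and variable store.
    let vs = nn::VarStore::new(Device::Cpu);
    let net = nn::seq()
        .add(nn::linear(vs.root() / "layer1", 784, 128, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs.root(), 128, 10, Default::default()));
    // Create SparseAdam once the model's variables exist.
    let mut opt = sparse_adam::SparseAdam::new(&vs, 5e-3, 0.9, 0.999, 1e-8, false);
    // One training step.
    let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
    opt.zero_grad();
    loss.backward();
    opt.step();
    Ok(())
}
```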