Implementation:LaurentMazare Tch rs Custom Optimizer Example
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Optimization, Image Classification |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Trains an MNIST classifier using a custom SparseAdam optimizer, demonstrating how to integrate user-defined optimizers with the tch-rs training loop.
Description
This example shows how to replace tch-rs's built-in optimizers with a custom optimizer implementation. It defines a simple two-layer feedforward neural network (784 -> 128 -> 10) built with `nn::seq` and a ReLU activation, and trains it on MNIST using the `SparseAdam` optimizer defined in the companion module `sparse_adam`.
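To get a sense of the model's size, the sketch below counts the trainable parameters of the 784 -> 128 -> 10 stack in plain Rust. It assumes both linear layers keep their default bias term (they are built with `Default::default()`). The helper names `linear_params` and `mlp_params` are illustrative only and are not part of the example.

```rust
/// Trainable parameters in one fully connected layer:
/// one weight per (input, output) pair plus one bias per output.
/// Assumes the layer has a bias term.
fn linear_params(inputs: u64, outputs: u64) -> u64 {
    inputs * outputs + outputs
}

/// Total parameters for a stack of linear layers given their widths,
/// e.g. [784, 128, 10]. ReLU has no parameters, so it adds nothing.
fn mlp_params(widths: &[u64]) -> u64 {
    widths.windows(2).map(|w| linear_params(w[0], w[1])).sum()
}

fn main() {
    let total = mlp_params(&[784, 128, 10]);
    println!("trainable parameters: {total}");
}
```

Under that assumption the total is 784 * 128 + 128 + 128 * 10 + 10 = 101,770 parameters, almost all of them in the first layer.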
Key aspects:
- The network architecture is a sequential model: a linear layer mapping 784 input dimensions to 128 hidden nodes, followed by ReLU, then a linear layer mapping to 10 output classes.
- The custom optimizer (`SparseAdam`) is instantiated with learning rate 0.005, beta1 = 0.9, beta2 = 0.999, and epsilon = 1e-8. A `force_sparse` flag allows testing sparse gradient updates on a dense problem; the example sets it to `false`.
- The training loop calls `opt.zero_grad()`, `loss.backward()`, and `opt.step()` explicitly rather than using the built-in `backward_step` convenience method.
- The loop is written as `for epoch in 1..200`, which runs 199 full-batch epochs because a Rust range excludes its upper bound. The example targets 97% test accuracy on MNIST.
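The companion `sparse_adam` module is not reproduced on this page. For orientation, the following stdlib-only sketch shows a standard Adam update using the same hyperparameters, together with a hypothetical "lazy" sparse mode that skips entries whose gradient is exactly zero. This skipping behavior is one common way sparse Adam variants work. `AdamConfig`, `AdamState`, and the zero-skipping rule are assumptions made for illustration; they are not the module's actual code.

```rust
/// Hyperparameters; the example passes lr=5e-3, beta1=0.9, beta2=0.999, eps=1e-8.
struct AdamConfig {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
}

/// Hypothetical per-parameter Adam state (first/second moments, step count).
#[derive(Clone)]
struct AdamState {
    m: Vec<f64>,
    v: Vec<f64>,
    t: i32,
}

impl AdamState {
    fn new(n: usize) -> Self {
        Self { m: vec![0.0; n], v: vec![0.0; n], t: 0 }
    }

    /// One bias-corrected Adam step. With `sparse == true`, entries whose
    /// gradient is exactly zero are skipped: their moments and parameters
    /// stay untouched (an assumed "lazy" sparse update on a dense gradient).
    fn step(&mut self, cfg: &AdamConfig, params: &mut [f64], grads: &[f64], sparse: bool) {
        self.t += 1;
        let bc1 = 1.0 - cfg.beta1.powi(self.t);
        let bc2 = 1.0 - cfg.beta2.powi(self.t);
        for i in 0..params.len() {
            let g = grads[i];
            if sparse && g == 0.0 {
                continue;
            }
            self.m[i] = cfg.beta1 * self.m[i] + (1.0 - cfg.beta1) * g;
            self.v[i] = cfg.beta2 * self.v[i] + (1.0 - cfg.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            params[i] -= cfg.lr * m_hat / (v_hat.sqrt() + cfg.eps);
        }
    }
}

fn main() {
    let cfg = AdamConfig { lr: 5e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8 };
    let mut params = vec![1.0, 1.0];
    let mut state = AdamState::new(params.len());
    // Second step has a zero gradient for entry 1, so sparse mode leaves it alone.
    state.step(&cfg, &mut params, &[1.0, 1.0], true);
    state.step(&cfg, &mut params, &[1.0, 0.0], true);
    println!("{params:?}");
}
```

Skipping zero-gradient entries means their moment estimates are not decayed on that step. This is the main behavioral difference from dense Adam, where earlier momentum keeps moving a parameter even when its current gradient is zero. Whether `SparseAdam` behaves exactly like this depends on the companion module.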
Usage
Use this example as a reference for implementing and integrating custom optimizers within the tch-rs framework, especially when the built-in Adam/SGD optimizers are insufficient for specialized training requirements such as sparse gradient updates.
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: examples/custom-optimizer/main.rs
- Lines: 1-50
Signature
```rust
fn net(vs: &nn::Path) -> impl Module
pub fn run() -> Result<()>
fn main()
```
Import
```rust
// Standalone binary example. Run with:
// cargo run --example custom-optimizer
mod sparse_adam;
use anyhow::Result;
use tch::{nn, nn::Module, Device};
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data/ | Directory | Yes | MNIST dataset directory loadable via `tch::vision::mnist::load_dir`. |
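If the dataset files are missing, `load_dir` fails. A small pre-flight check can report which files are absent before training starts. This sketch assumes that `load_dir` reads the four standard uncompressed MNIST IDX files listed in `MNIST_FILES`. `missing_mnist_files` is a hypothetical helper and is not part of tch.

```rust
use std::path::Path;

/// File names `mnist::load_dir` is assumed to read
/// (the standard uncompressed MNIST IDX files).
const MNIST_FILES: [&str; 4] = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
];

/// Returns the expected MNIST files that are not present as regular files in `dir`.
fn missing_mnist_files(dir: &Path) -> Vec<&'static str> {
    MNIST_FILES
        .iter()
        .copied()
        .filter(|name| !dir.join(name).is_file())
        .collect()
}

fn main() {
    // "data" matches the relative path the example passes to load_dir.
    let missing = missing_mnist_files(Path::new("data"));
    if missing.is_empty() {
        println!("data/ contains all expected MNIST files");
    } else {
        eprintln!("missing from data/: {missing:?}");
    }
}
```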
Outputs
| Name | Type | Description |
|---|---|---|
| stdout | Text | Per-epoch training loss and test accuracy percentage. |
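Because the output is plain text in a fixed format, it can be post-processed, for example to find the epoch with the best test accuracy. The sketch below parses lines that use the exact format string from the example's `println!` and reads them from stdin. `EpochLog` and `parse_epoch_line` are hypothetical names introduced here.

```rust
use std::io::{self, BufRead};

/// One per-epoch record as printed by the example.
#[derive(Debug, PartialEq)]
struct EpochLog {
    epoch: u32,
    train_loss: f64,
    test_acc_pct: f64,
}

/// Parses a line produced by
/// "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%".
/// Returns None for lines that do not follow that layout.
fn parse_epoch_line(line: &str) -> Option<EpochLog> {
    let rest = line.trim().strip_prefix("epoch:")?;
    let (epoch, rest) = rest.split_once("train loss:")?;
    let (loss, acc) = rest.split_once("test acc:")?;
    Some(EpochLog {
        epoch: epoch.trim().parse().ok()?,
        train_loss: loss.trim().parse().ok()?,
        test_acc_pct: acc.trim().strip_suffix('%')?.trim().parse().ok()?,
    })
}

fn main() {
    // Reads the example's stdout (piped in) and reports the best epoch.
    let best = io::stdin()
        .lock()
        .lines()
        .filter_map(|l| l.ok())
        .filter_map(|l| parse_epoch_line(&l))
        .max_by(|a, b| a.test_acc_pct.total_cmp(&b.test_acc_pct));
    match best {
        Some(r) => println!("best test accuracy {:.2}% at epoch {}", r.test_acc_pct, r.epoch),
        None => println!("no epoch lines found"),
    }
}
```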
Usage Examples
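```rust
// Companion module defining the custom SparseAdam optimizer.
mod sparse_adam;

use anyhow::Result;
use tch::{nn, nn::Module, Device};

const IMAGE_DIM: i64 = 784; // 28x28 MNIST images, flattened
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10; // digits 0-9

// Two-layer MLP: 784 -> 128 (ReLU) -> 10 logits.
fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    // Set to true to exercise the sparse update path on this dense problem.
    let force_sparse = false;
    // Arguments: var store, learning rate, beta1, beta2, epsilon, force_sparse.
    let mut opt = sparse_adam::SparseAdam::new(&vs, 5e-3, 0.9, 0.999, 1e-8, force_sparse);
    // 1..200 excludes 200, so this runs 199 epochs.
    for epoch in 1..200 {
        // Full-batch forward pass over the whole training set.
        let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
        // Explicit optimizer calls instead of `backward_step`.
        opt.zero_grad();
        loss.backward();
        opt.step();
        let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::try_from(&loss)?,
            100. * f64::try_from(&test_accuracy)?,
        );
    }
    Ok(())
}

// Entry point; assumed here to simply run the example and panic on error.
fn main() {
    run().unwrap()
}
```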
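The loop relies on two tch tensor helpers. `cross_entropy_for_logits` applies a log-softmax over the last dimension and then takes the mean negative log-likelihood of the target class. `accuracy_for_logits` returns the fraction of rows whose arg-max matches the label. The stdlib-only sketch below mirrors both computations on nested `Vec`s to make the metrics concrete. It is an illustrative reimplementation, not tch's code, and the function names simply echo the tch methods.

```rust
/// Mean cross-entropy of raw logits against integer class labels:
/// a per-row log-sum-exp (max subtracted for numerical stability) minus the
/// true-class logit, i.e. -log_softmax(row)[y], averaged over rows.
fn cross_entropy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let total: f64 = logits
        .iter()
        .zip(labels)
        .map(|(row, &y)| {
            let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let log_sum_exp = max + row.iter().map(|z| (z - max).exp()).sum::<f64>().ln();
            log_sum_exp - row[y]
        })
        .sum();
    total / labels.len() as f64
}

/// Index of the largest value; the first one wins on ties.
fn argmax(row: &[f64]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of rows whose arg-max logit equals the label.
fn accuracy_for_logits(logits: &[Vec<f64>], labels: &[usize]) -> f64 {
    let mut correct = 0;
    for (row, &y) in logits.iter().zip(labels) {
        if argmax(row) == y {
            correct += 1;
        }
    }
    correct as f64 / labels.len() as f64
}

fn main() {
    // Two toy samples with three classes; values chosen only to exercise the code.
    let logits = vec![vec![2.0, 0.5, -1.0], vec![0.1, 0.2, 3.0]];
    let labels = [0, 2];
    println!("loss = {:.4}", cross_entropy_for_logits(&logits, &labels));
    println!("acc  = {:.2}", accuracy_for_logits(&logits, &labels));
}
```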