Implementation:LaurentMazare Tch rs Custom Optimizer Example

From Leeroopedia


Knowledge Sources
Domains Deep Learning, Optimization, Image Classification
Last Updated 2026-02-08 00:00 GMT

Overview

Trains an MNIST classifier using a custom SparseAdam optimizer, demonstrating how to integrate user-defined optimizers with the tch-rs training loop.

Description

This example shows how to replace tch-rs's built-in optimizers with a custom optimizer implementation. It defines a simple two-layer feedforward neural network (784 -> 128 -> 10) using nn::seq with ReLU activation, and trains it on MNIST using the SparseAdam optimizer defined in a companion module.
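To make the 784 -> 128 -> 10 shape concrete, the sketch below counts the trainable parameters of such a network in plain Rust. It assumes both linear layers carry a bias term (the usual default when a layer is built with Default::default()); linear_params and total_params are hypothetical helpers, not part of tch-rs.

```rust
const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

/// Hypothetical helper: weights plus biases of one fully connected layer,
/// assuming a bias term is present.
fn linear_params(in_dim: i64, out_dim: i64) -> i64 {
    in_dim * out_dim + out_dim
}

/// Total trainable parameters of the two-layer network described above.
fn total_params() -> i64 {
    linear_params(IMAGE_DIM, HIDDEN_NODES) + linear_params(HIDDEN_NODES, LABELS)
}

fn main() {
    println!("layer1: {}", linear_params(IMAGE_DIM, HIDDEN_NODES));
    println!("output: {}", linear_params(HIDDEN_NODES, LABELS));
    println!("total:  {}", total_params());
}
```

Under that assumption the first layer holds 100480 parameters and the whole model 101770, which is small enough for the full-batch CPU training this example performs.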

Key aspects:

  • The network architecture is a sequential model: a linear layer mapping 784 input dimensions to 128 hidden nodes, followed by ReLU, then a linear layer mapping to 10 output classes.
  • The custom optimizer (SparseAdam) is instantiated with learning rate 0.005 (5e-3), beta1 = 0.9, beta2 = 0.999, and epsilon = 1e-8. A force_sparse flag, set to false in this example, allows testing sparse gradient updates on a dense problem.
  • The training loop manually calls opt.zero_grad(), loss.backward(), and opt.step() rather than using the built-in backward_step convenience method.
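  • The example targets 97% test accuracy on MNIST. The loop range 1..200 excludes its upper bound, so training runs for 199 epochs, each a single full-batch step over the entire training set.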
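For readers unfamiliar with what the four hyperparameters control, here is a minimal sketch of one textbook Adam step (with bias correction) on plain f64 slices. It is not the code of the companion sparse_adam module, whose internals are not shown on this page; AdamState is a hypothetical type introduced only for illustration.

```rust
/// Hyperparameters quoted in the list above.
const LR: f64 = 5e-3;
const BETA1: f64 = 0.9;
const BETA2: f64 = 0.999;
const EPS: f64 = 1e-8;

/// Hypothetical per-parameter optimizer state (not a tch-rs type).
struct AdamState {
    m: Vec<f64>, // running mean of gradients
    v: Vec<f64>, // running mean of squared gradients
    t: i32,      // number of steps taken so far
}

impl AdamState {
    fn new(len: usize) -> Self {
        AdamState { m: vec![0.0; len], v: vec![0.0; len], t: 0 }
    }

    /// One textbook Adam step with bias correction, applied in place.
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        self.t += 1;
        let c1 = 1.0 - BETA1.powi(self.t); // bias correction for m
        let c2 = 1.0 - BETA2.powi(self.t); // bias correction for v
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = BETA1 * self.m[i] + (1.0 - BETA1) * g;
            self.v[i] = BETA2 * self.v[i] + (1.0 - BETA2) * g * g;
            let m_hat = self.m[i] / c1;
            let v_hat = self.v[i] / c2;
            params[i] -= LR * m_hat / (v_hat.sqrt() + EPS);
        }
    }
}

fn main() {
    let mut params = vec![1.0, -1.0];
    let mut state = AdamState::new(params.len());
    state.step(&mut params, &[2.0, -0.5]);
    println!("{:?}", params);
}
```

On the very first step the bias-corrected ratio is close to the sign of the gradient, so each parameter moves by roughly the learning rate regardless of gradient magnitude.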

Usage

Use this example as a reference for implementing and integrating custom optimizers within the tch-rs framework, especially when the built-in Adam/SGD optimizers are insufficient for specialized training requirements such as sparse gradient updates.
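As an illustration of what "sparse gradient updates" can mean, the sketch below applies an Adam-style update only to the indices that appear in a gradient given as (index, value) pairs, leaving all other parameters and their moment estimates untouched. This is one common reading of sparse Adam and an assumption on my part; the companion module may handle sparsity differently. sparse_adam_step is a hypothetical function.

```rust
/// Applies an Adam-style update to the listed indices only.
/// `grad` holds (index, value) pairs for the non-zero gradient entries;
/// `step` is the 1-based global step count used for bias correction.
fn sparse_adam_step(
    params: &mut [f64],
    m: &mut [f64],
    v: &mut [f64],
    step: i32,
    grad: &[(usize, f64)],
) {
    let (lr, beta1, beta2, eps) = (5e-3_f64, 0.9_f64, 0.999_f64, 1e-8_f64);
    let c1 = 1.0 - beta1.powi(step);
    let c2 = 1.0 - beta2.powi(step);
    for &(i, g) in grad {
        m[i] = beta1 * m[i] + (1.0 - beta1) * g;
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
        params[i] -= lr * (m[i] / c1) / ((v[i] / c2).sqrt() + eps);
    }
}

fn main() {
    let mut params = vec![0.5; 4];
    let mut m = vec![0.0; 4];
    let mut v = vec![0.0; 4];
    // Only indices 1 and 3 carry a gradient in this step.
    sparse_adam_step(&mut params, &mut m, &mut v, 1, &[(1, 1.0), (3, -2.0)]);
    println!("{:?}", params);
}
```

The cost of a step scales with the number of non-zero gradient entries rather than with the parameter count, which is the usual motivation for such an optimizer on embedding-style models.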

Code Reference

Source Location

Signature

fn net(vs: &nn::Path) -> impl Module

pub fn run() -> Result<()>

fn main()

Import

// Standalone binary example. Run with:
// cargo run --example custom-optimizer
mod sparse_adam;

use anyhow::Result;
use tch::{nn, nn::Module, Device};

I/O Contract

Inputs

Name | Type | Required | Description
data/ | Directory | Yes | MNIST dataset directory loadable via tch::vision::mnist::load_dir.
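A quick pre-flight check can save a confusing I/O error at startup. The sketch below assumes that load_dir reads the four uncompressed IDX files under their standard MNIST names; verify the names against the tch version you have installed. missing_mnist_files is a hypothetical helper.

```rust
use std::path::Path;

/// File names assumed to be read by tch::vision::mnist::load_dir
/// (standard uncompressed MNIST IDX files).
const MNIST_FILES: [&str; 4] = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
];

/// Hypothetical helper: returns the expected files that are absent from `dir`.
fn missing_mnist_files(dir: &Path) -> Vec<&'static str> {
    MNIST_FILES
        .iter()
        .copied()
        .filter(|name| !dir.join(name).is_file())
        .collect()
}

fn main() {
    let missing = missing_mnist_files(Path::new("data"));
    if missing.is_empty() {
        println!("data/ looks complete");
    } else {
        println!("missing from data/: {:?}", missing);
    }
}
```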

Outputs

Name | Type | Description
stdout | Text | Per-epoch training loss and test accuracy percentage.
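The shape of each stdout line follows from the format string used in the example's println! call. The sketch below reproduces that format string on plain numbers so the column widths are visible; progress_line is a hypothetical wrapper, and the values passed in are made up for illustration, not measured results.

```rust
/// Formats one progress line with the same format string as the example.
/// `test_acc` is a fraction in [0, 1]; it is scaled to a percentage here.
fn progress_line(epoch: i64, loss: f64, test_acc: f64) -> String {
    format!(
        "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
        epoch,
        loss,
        100. * test_acc,
    )
}

fn main() {
    // Illustrative inputs only.
    println!("{}", progress_line(42, 0.125, 0.5));
}
```

The epoch is right-aligned in 4 columns, the loss in 8 columns with 5 decimals, and the accuracy in 5 columns with 2 decimals.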

Usage Examples

mod sparse_adam;

use anyhow::Result;
use tch::{nn, nn::Module, Device};

const IMAGE_DIM: i64 = 784;
const HIDDEN_NODES: i64 = 128;
const LABELS: i64 = 10;

fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
}

pub fn run() -> Result<()> {
    let m = tch::vision::mnist::load_dir("data")?;
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());

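    // force_sparse = false keeps ordinary dense gradients; set it to true
    // to exercise the sparse update path on this dense problem.
    let force_sparse = false;
    let mut opt = sparse_adam::SparseAdam::new(&vs, 5e-3, 0.9, 0.999, 1e-8, force_sparse);

    // 1..200 excludes the upper bound, so this runs 199 full-batch epochs.
    for epoch in 1..200 {
        // Forward pass over the whole training set (no mini-batching).
        let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
        // Manual equivalent of backward_step: clear old gradients,
        // backpropagate, then let the custom optimizer update the weights.
        opt.zero_grad();
        loss.backward();
        opt.step();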

        let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::try_from(&loss)?,
            100. * f64::try_from(&test_accuracy)?,
        );
    }
    Ok(())
}
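The zero_grad / backward / step ordering in run() can be seen in isolation with a toy problem. The sketch below minimizes (w - 3)^2 with a hypothetical plain-SGD optimizer on a single scalar; Param, Sgd, and backward are stand-ins written for this illustration and do not correspond to tch-rs types. The backward function adds to the stored gradient, mimicking how autograd accumulates gradients, which is why zero_grad must come first.

```rust
/// Toy scalar parameter with an accumulated gradient, standing in for a tensor.
struct Param {
    value: f64,
    grad: f64,
}

/// Hypothetical plain-SGD optimizer used only to show the call order.
struct Sgd {
    lr: f64,
}

impl Sgd {
    fn zero_grad(&self, p: &mut Param) {
        p.grad = 0.0;
    }
    fn step(&self, p: &mut Param) {
        p.value -= self.lr * p.grad;
    }
}

/// Loss (w - 3)^2.
fn loss(p: &Param) -> f64 {
    (p.value - 3.0).powi(2)
}

/// Adds d(loss)/dw to the stored gradient (accumulates, never overwrites).
fn backward(p: &mut Param) {
    p.grad += 2.0 * (p.value - 3.0);
}

fn train(epochs: usize) -> Param {
    let opt = Sgd { lr: 0.1 };
    let mut w = Param { value: 0.0, grad: 0.0 };
    for _ in 0..epochs {
        opt.zero_grad(&mut w); // clear stale gradients
        backward(&mut w);      // accumulate fresh gradients
        opt.step(&mut w);      // apply the update
    }
    w
}

fn main() {
    let w = train(50);
    println!("w = {:.4}, loss = {:.6}", w.value, loss(&w));
}
```

Dropping the zero_grad call in this toy makes each step use the sum of all past gradients, which is the failure mode the manual loop has to guard against once backward_step is no longer doing the bookkeeping.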
