
Principle:Microsoft Onnxruntime On Device Training Loop

From Leeroopedia


Overview

Iterative execution of forward pass, backward pass, and parameter updates on device-local data.

Metadata

Field               Value
Principle Name      On_Device_Training_Loop
Category            API Doc
Domain              On_Device_Training, Training_Infrastructure
Repository          microsoft/onnxruntime
Source Reference    orttraining/orttraining/training_api/module.cc:L623 (TrainStep),
                    optimizer.cc:L183+ (Step), module.cc:L618 (LazyResetGrad),
                    module.cc:L645 (EvalStep)
Last Updated        2026-02-10

Description

The training loop repeatedly executes TrainStep (forward + backward), OptimizerStep (parameter update), and LazyResetGrad (gradient reset) for each batch of training data. EvalStep can be called periodically to assess model performance without updating parameters.

The training loop consists of the following operations per iteration:

  • TrainStep -- Executes the training ONNX session, which runs the forward graph to compute the loss and the backward graph to compute gradients. User inputs and parameter tensors are fed into the session, and the resulting gradients are accumulated in place in each Parameter's gradient buffer.
  • OptimizerStep -- Executes the optimizer ONNX session, which reads the accumulated gradients and the current optimizer states (for AdamW, the first- and second-moment estimates) to compute and apply the parameter updates. The optimizer's step counter is then incremented.
  • LazyResetGrad -- Sets a flag so that gradients are reset on the next TrainStep call instead of being zeroed immediately. Folding the reset into the next TrainStep avoids a separate pass over the gradient buffers when gradient accumulation is not needed.
  • EvalStep -- Executes the evaluation ONNX session (forward only, no backward) using shared parameter weights. Used to compute validation metrics without modifying parameters.
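
To make the division of labour between the four operations concrete, here is a plain-Python sketch that mimics them on a one-parameter linear model y = w·x with squared-error loss. It is a toy stand-in, not the ONNX Runtime API. The classes `ToyModule` and `ToySGD`, their methods, and the plain SGD update are all hypothetical and chosen only for illustration.

```python
class ToyModule:
    """Hypothetical stand-in for a training Module with one scalar parameter w."""

    def __init__(self, w: float) -> None:
        self.w = w                # shared parameter, read by both train and eval steps
        self.grad = 0.0           # gradient buffer kept alongside the parameter
        self.reset_grad = False   # set by lazy_reset_grad(), consumed by train_step()

    def train_step(self, x: float, y: float) -> float:
        # Forward: prediction and squared-error loss.
        pred = self.w * x
        loss = (pred - y) ** 2
        # Backward: dL/dw = 2 * (pred - y) * x.
        g = 2.0 * (pred - y) * x
        # Accumulate in place, or overwrite if a reset was requested.
        self.grad = g if self.reset_grad else self.grad + g
        self.reset_grad = False
        return loss

    def lazy_reset_grad(self) -> None:
        # No zeroing here: the next train_step overwrites instead of adding.
        self.reset_grad = True

    def eval_step(self, x: float, y: float) -> float:
        # Forward only: no gradient is computed and w is not modified.
        return (self.w * x - y) ** 2


class ToySGD:
    """Hypothetical optimizer: plain SGD update plus a step counter."""

    def __init__(self, module: ToyModule, lr: float) -> None:
        self.module = module
        self.lr = lr
        self.step_count = 0

    def step(self) -> None:
        self.module.w -= self.lr * self.module.grad
        self.step_count += 1


module = ToyModule(w=0.0)
opt = ToySGD(module, lr=0.1)
data = [(1.0, 2.0), (2.0, 4.0)]  # targets follow y = 2x
for epoch in range(20):
    for x, y in data:
        module.train_step(x, y)   # TrainStep
        opt.step()                # OptimizerStep
        module.lazy_reset_grad()  # LazyResetGrad
val_loss = module.eval_step(3.0, 6.0)  # EvalStep on a held-out point
```

After training, w is close to 2 and the evaluation call leaves it untouched. This mirrors how EvalStep reads the shared weights without changing them.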

The "lazy" gradient reset strategy works by passing a boolean flag (reset_grad) into the training session's InPlaceAccumulator. When the flag is set, the previous gradients are discarded as part of the same forward/backward execution, so no separate zeroing kernel has to be launched.
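
To see why folding the reset into the accumulator saves work, compare two approaches. An eager reset makes its own pass over the gradient buffer before accumulating. A fused update takes the flag as an input and either overwrites or adds in a single pass. Both function names and the returned pass count are hypothetical illustration. The real accumulation runs inside an ONNX Runtime kernel, not over Python lists.

```python
from typing import List


def eager_reset_then_accumulate(buf: List[float], grad: List[float], reset: bool) -> int:
    """Zero the buffer in a separate pass (if requested), then add. Returns passes made."""
    passes = 0
    if reset:
        for i in range(len(buf)):   # separate zeroing pass
            buf[i] = 0.0
        passes += 1
    for i in range(len(buf)):       # accumulation pass
        buf[i] += grad[i]
    passes += 1
    return passes


def fused_accumulate(buf: List[float], grad: List[float], reset: bool) -> int:
    """Single pass: overwrite when reset is set, otherwise add. Returns passes made."""
    for i in range(len(buf)):
        buf[i] = grad[i] if reset else buf[i] + grad[i]
    return 1
```

Both functions leave the buffer in the same state. With reset=False they both accumulate, which is the behaviour gradient accumulation over several mini-batches relies on.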

Theoretical Basis

Standard gradient descent training loop: compute loss via forward pass, compute gradients via backward pass, update parameters using optimizer, reset gradients for next iteration.
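
Written out for plain gradient descent on a single batch (x, y), one iteration looks as follows. The symbols are introduced here for illustration: f is the model, ℓ the loss, θ the trainable parameters, η the learning rate, and J_f the Jacobian of f with respect to θ. An optimizer such as AdamW replaces the third line with its own update rule.

```latex
\begin{aligned}
\mathcal{L}(\theta) &= \ell\big(f(x;\theta),\, y\big) && \text{forward pass} \\
g &= \nabla_\theta \mathcal{L}(\theta) = J_f^{\top}\, \nabla_{f}\,\ell && \text{backward pass (chain rule)} \\
\theta &\leftarrow \theta - \eta\, g && \text{parameter update} \\
g &\leftarrow 0 && \text{gradient reset}
\end{aligned}
```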

  • Forward Pass -- Computes the model's prediction and loss function value given the input data and current parameters.
  • Backward Pass -- Applies the chain rule of calculus to propagate the loss gradient through the computation graph in reverse, computing the gradient of the loss with respect to each trainable parameter.
  • Parameter Update -- The optimizer applies its update rule (e.g., AdamW) to adjust each parameter in the direction that reduces the loss, scaled by the learning rate and regularized by weight decay.
  • Gradient Reset -- Gradients must be zeroed between iterations to prevent accumulation across batches (unless gradient accumulation over multiple mini-batches is intended).
  • Evaluation -- Running only the forward pass on held-out data, without computing gradients or updating parameters, gives an estimate of model performance that is not biased by the data the model was trained on.
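
For the AdamW update rule mentioned above, one step on a scalar parameter can be sketched as follows, using the decoupled-weight-decay formulation. The hyperparameter defaults are assumptions taken from common practice and are not necessarily the values built into an ONNX Runtime optimizer model. The names `AdamWState` and `adamw_step` are hypothetical. The `step` field plays the role of the optimizer's step counter, which drives bias correction.

```python
import math
from dataclasses import dataclass


@dataclass
class AdamWState:
    m: float = 0.0  # first-moment (mean) estimate of the gradient
    v: float = 0.0  # second-moment (uncentered variance) estimate
    step: int = 0   # step counter, used for bias correction


def adamw_step(theta: float, grad: float, state: AdamWState,
               lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
               eps: float = 1e-8, weight_decay: float = 0.01) -> float:
    """Return the updated parameter; mutates `state` in place."""
    state.step += 1
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    theta -= lr * weight_decay * theta
    # Exponential moving averages of the gradient and of its square.
    state.m = beta1 * state.m + (1.0 - beta1) * grad
    state.v = beta2 * state.v + (1.0 - beta2) * grad * grad
    # Bias correction compensates for m and v starting at zero.
    m_hat = state.m / (1.0 - beta1 ** state.step)
    v_hat = state.v / (1.0 - beta2 ** state.step)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


# Usage: minimise (theta - 3)^2 starting from theta = 0.
theta, st = 0.0, AdamWState()
for _ in range(500):
    grad = 2.0 * (theta - 3.0)  # derivative of (theta - 3)^2
    theta = adamw_step(theta, grad, st, lr=0.1, weight_decay=0.0)
```

Weight decay is switched off in the usage loop so that the minimum stays exactly at 3. With a non-zero value the parameter is also pulled toward zero on every step.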

Usage

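from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the previously generated artifacts: checkpoint plus training, eval
# and optimizer graphs.
state = CheckpointState.load_checkpoint("output_artifacts/checkpoint")
module = Module("output_artifacts/training_model.onnx", state,
                "output_artifacts/eval_model.onnx", device="cpu")
optimizer = Optimizer("output_artifacts/optimizer_model.onnx", module)

# num_epochs, train_dataloader and eval_dataloader are supplied by the caller.
# Each batch is assumed to be a tuple of numpy arrays in the input order the
# training/eval graphs expect, e.g. (inputs, labels).

# Training loop
for epoch in range(num_epochs):
    module.train()                      # use the training graph
    for batch in train_dataloader:
        training_loss = module(*batch)  # TrainStep: forward + backward
        optimizer.step()                # OptimizerStep: update parameters
        module.lazy_reset_grad()        # reset gradients on the next TrainStep

    # Evaluation: forward pass only, parameters are left unchanged
    module.eval()
    eval_losses = [module(*batch) for batch in eval_dataloader]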
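
Because gradients accumulate in place until a reset is requested, the same three calls can implement gradient accumulation: run several train steps, then perform a single optimizer step followed by a reset. The helper below is a hypothetical sketch built only on the calls used above (`module.train()`, `module(...)`, `optimizer.step()`, `module.lazy_reset_grad()`). It assumes each batch is a tuple of graph inputs. It does not rescale losses or gradients, so each update uses the sum of the accumulated gradients.

```python
from typing import Any, Iterable, List, Sequence


def train_epoch_accumulated(module: Any, optimizer: Any,
                            batches: Iterable[Sequence[Any]],
                            accum_steps: int) -> List[Any]:
    """Hypothetical helper: one optimizer step per `accum_steps` train steps."""
    module.train()
    losses = []
    pending = 0
    for batch in batches:
        losses.append(module(*batch))   # TrainStep: gradients add onto the buffers
        pending += 1
        if pending == accum_steps:
            optimizer.step()             # OptimizerStep on the summed gradients
            module.lazy_reset_grad()     # next TrainStep starts from fresh gradients
            pending = 0
    if pending:
        # Flush a final partial group so its gradients are not carried into the next epoch.
        optimizer.step()
        module.lazy_reset_grad()
    return losses
```

With accum_steps=1 this reduces to the per-batch loop shown above.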

Implemented By

Implementation:Microsoft_Onnxruntime_Module_TrainStep
