
Implementation:Microsoft Onnxruntime Module TrainStep

From Leeroopedia


Overview

Executes the on-device training loop operations: TrainStep (forward + backward), Optimizer::Step (parameter update), LazyResetGrad (gradient reset), and EvalStep (evaluation forward pass).

Metadata

Implementation Name: Module_TrainStep
Type: API Doc
Language: C++ and Python
API (C++): Module::TrainStep(inputs, outputs), Optimizer::Step(), Module::LazyResetGrad(), Module::EvalStep(inputs, outputs)
API (Python): module.train(), module(inputs), optimizer.step(), module.lazy_reset_grad(), module.eval()
Domain: On_Device_Training, Training_Infrastructure
Repository: microsoft/onnxruntime
Source Reference: orttraining/orttraining/training_api/module.cc:L623 (TrainStep), optimizer.cc:L183+ (Step), module.cc:L618 (LazyResetGrad), module.cc:L645 (EvalStep)
Last Updated: 2026-02-10

Description

This implementation covers the four core operations of the on-device training loop:

  • TrainStep -- Constructs the feed list from user inputs, parameter weights, gradient buffers, and a gradient reset flag. Runs the training ONNX session to execute forward and backward passes. After execution, sets accumulate_gradient_ to true so subsequent calls accumulate gradients unless LazyResetGrad is called.
  • Optimizer::Step -- Runs the optimizer ONNX session with the current parameters, gradients, and momentum states as inputs. Updates parameters in-place and increments the step counter.
  • LazyResetGrad -- Sets accumulate_gradient_ = false, which causes the next TrainStep to reset gradients via the InPlaceAccumulator node in the training graph rather than performing an explicit zeroing operation.
  • EvalStep -- Runs the evaluation ONNX session with user inputs and parameter weights (no gradient buffers). Computes the forward pass only, returning the evaluation outputs.

Both TrainStep and EvalStep check that the checkpoint state is not nominal (i.e., that full parameter values have been loaded) before proceeding.
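The interaction between the accumulate flag and the reset input can be modeled in a few lines of plain Python. This is a toy sketch, not the real implementation: the actual overwrite-or-accumulate decision is made by the InPlaceAccumulator node inside the training graph, and the `ToyModule` class and its names are illustrative only.

```python
class ToyModule:
    """Toy model of Module's lazy gradient-reset bookkeeping."""

    def __init__(self, n_params):
        self.gradients = [0.0] * n_params
        self.accumulate_gradient = True  # mirrors accumulate_gradient_

    def lazy_reset_grad(self):
        # LazyResetGrad only flips the flag; no zeroing happens here.
        self.accumulate_gradient = False

    def train_step(self, new_grads):
        # The reset_grad input fed to the graph is the flag's negation.
        reset = not self.accumulate_gradient
        for i, g in enumerate(new_grads):
            # InPlaceAccumulator semantics: overwrite on reset, else add.
            self.gradients[i] = g if reset else self.gradients[i] + g
        # After every TrainStep, subsequent calls accumulate again.
        self.accumulate_gradient = True

m = ToyModule(2)
m.lazy_reset_grad()
m.train_step([1.0, 2.0])  # overwrites: gradients == [1.0, 2.0]
m.train_step([1.0, 2.0])  # accumulates: gradients == [2.0, 4.0]
m.lazy_reset_grad()
m.train_step([5.0, 5.0])  # overwrites again: gradients == [5.0, 5.0]
```

Because the reset is deferred to the next TrainStep, calling lazy_reset_grad between iterations costs nothing extra: the zeroing is folded into the accumulator node's existing graph execution.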

API Signature

C++

// Module methods
Status TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
Status EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
Status LazyResetGrad();

// Optimizer method
Status Step();

Python

# Switch to training mode
module.train()

# Execute forward + backward (returns loss)
training_loss = module(input_data)

# Update parameters
optimizer.step()

# Reset gradients for next iteration
module.lazy_reset_grad()

# Switch to eval mode and run evaluation
module.eval()
eval_loss = module(eval_data)

Key Parameters

inputs (std::vector<OrtValue> in C++; tensors in Python): User-provided input data for the training or evaluation step.
outputs (std::vector<OrtValue>&): Output values populated by the session run. C++ only; the Python binding returns the outputs directly.

I/O Contract

Input (TrainStep): User input tensors; a batch of training data matching the model's expected input shapes.
Output (TrainStep): The training loss value and any additional forward outputs.
Side effect (TrainStep): Gradients are accumulated in the Parameter gradient buffers.
Side effect (Optimizer::Step): Parameters are updated in-place using the optimizer's update rule.
Side effect (LazyResetGrad): Sets the flag that triggers gradient zeroing on the next TrainStep.
Input (EvalStep): User input tensors; a batch of evaluation data.
Output (EvalStep): The evaluation loss and any additional forward outputs.
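The Step side effect can be illustrated with a simplified stand-in. The real Optimizer::Step runs an optimizer ONNX session that takes parameters, gradients, and momentum states as inputs; the plain SGD rule below is only a substitute to show the two observable effects, the in-place parameter update and the step-counter increment. The `ToyOptimizer` class is hypothetical.

```python
class ToyOptimizer:
    """Simplified stand-in for Optimizer::Step (plain SGD, no momentum)."""

    def __init__(self, params, grads, lr=0.1):
        self.params = params   # updated in place, as in the real API
        self.grads = grads     # produced by a preceding TrainStep
        self.lr = lr
        self.step_count = 0    # incremented on every Step

    def step(self):
        for i, g in enumerate(self.grads):
            self.params[i] -= self.lr * g  # in-place parameter update
        self.step_count += 1

params, grads = [1.0, 2.0], [0.5, -0.5]
opt = ToyOptimizer(params, grads, lr=0.1)
opt.step()
# params is now [0.95, 2.05] and opt.step_count == 1
```

Because both the real optimizer and this sketch mutate the parameter buffers in place, the Module sees the updated weights on its next TrainStep without any explicit copy.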

Code Reference

From orttraining/orttraining/training_api/module.cc:

Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                "Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
  std::vector<OrtValue> feeds{inputs};
  feeds.insert(feeds.end(), weights_.begin(), weights_.end());
  feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());

  OrtValue reset_grad_input;
  utils::WrapInOrtValue<bool>(!accumulate_gradient_, &reset_grad_input);
  feeds.push_back(reset_grad_input);

  ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_.AllInputNames(),
                                       feeds, train_output_names_, &outputs));
  accumulate_gradient_ = true;
  return Status::OK();
}

Status Module::LazyResetGrad() {
  accumulate_gradient_ = false;
  return Status::OK();
}

Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                "Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
  ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
  std::vector<OrtValue> feeds{inputs};
  feeds.insert(feeds.end(), weights_.begin(), weights_.end());
  auto status = eval_sess_->Run(RunOptions(), eval_input_names_, feeds, eval_output_names_, &outputs);
  ORT_THROW_IF_ERROR(status);
  return Status::OK();
}
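The key structural difference between the two code paths above is how they assemble their feed lists. The Python sketch below mirrors that assembly; the function names and list-of-strings stand-ins are illustrative, not the real C++ containers.

```python
def build_train_feeds(inputs, weights, gradients, accumulate_gradient):
    # TrainStep feed order: user inputs, then weights, then gradient
    # buffers, then the reset flag (the negation of accumulate_gradient_).
    feeds = list(inputs) + list(weights) + list(gradients)
    feeds.append(not accumulate_gradient)  # reset_grad_input
    return feeds

def build_eval_feeds(inputs, weights):
    # EvalStep feed order: user inputs and weights only; no gradient
    # buffers and no reset flag, since no backward pass runs.
    return list(inputs) + list(weights)

train_feeds = build_train_feeds(["x"], ["w1", "w2"], ["g1", "g2"], False)
# -> ["x", "w1", "w2", "g1", "g2", True]
eval_feeds = build_eval_feeds(["x"], ["w1", "w2"])
# -> ["x", "w1", "w2"]
```

Feeding the weights (and, for training, the gradient buffers) on every call is what lets the sessions read and write the shared parameter state in place rather than copying it per step.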

Usage Example

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Setup
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")
module = Module("training_artifacts/training_model.onnx", state,
                "training_artifacts/eval_model.onnx", device="cpu")
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)

# Training loop
for epoch in range(num_epochs):
    module.train()
    for batch in train_dataloader:
        training_loss = module(batch)
        optimizer.step()
        module.lazy_reset_grad()

    # Periodic evaluation
    module.eval()
    total_eval_loss = 0.0
    for batch in eval_dataloader:
        total_eval_loss += module(batch)
    print(f"Epoch {epoch}: avg_eval_loss={total_eval_loss / len(eval_dataloader)}")

# Save checkpoint after training
CheckpointState.save_checkpoint(state, "checkpoints/final")

Implements

Principle:Microsoft_Onnxruntime_On_Device_Training_Loop
