Implementation:Microsoft Onnxruntime Module TrainStep
Overview
Executes the on-device training loop operations: TrainStep (forward + backward), Optimizer::Step (parameter update), LazyResetGrad (gradient reset), and EvalStep (evaluation forward pass).
Metadata
| Field | Value |
|---|---|
| Implementation Name | Module_TrainStep |
| Type | API Doc |
| Language | C++ and Python |
| API | C++: Module::TrainStep(inputs, outputs), Optimizer::Step(), Module::LazyResetGrad(), Module::EvalStep(inputs, outputs). Python: module.train(), module(inputs), optimizer.step(), module.lazy_reset_grad(), module.eval() |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | orttraining/orttraining/training_api/module.cc:L623 (TrainStep), optimizer.cc:L183+ (Step), module.cc:L618 (LazyResetGrad), module.cc:L645 (EvalStep) |
| Last Updated | 2026-02-10 |
Description
This implementation covers the four core operations of the on-device training loop:
- TrainStep -- Constructs the feed list from user inputs, parameter weights, gradient buffers, and a gradient reset flag. Runs the training ONNX session to execute the forward and backward passes. After execution, sets accumulate_gradient_ to true so subsequent calls accumulate gradients unless LazyResetGrad is called.
- Optimizer::Step -- Runs the optimizer ONNX session with the current parameters, gradients, and momentum states as inputs. Updates parameters in-place and increments the step counter.
- LazyResetGrad -- Sets accumulate_gradient_ = false, which causes the next TrainStep to reset gradients via the InPlaceAccumulator node in the training graph rather than performing an explicit zeroing operation.
- EvalStep -- Runs the evaluation ONNX session with user inputs and parameter weights (no gradient buffers). Computes the forward pass only, returning the evaluation outputs.
Both TrainStep and EvalStep check that the checkpoint state is not nominal (i.e., that full parameter values have been loaded) before proceeding.
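The interplay between TrainStep's accumulate flag and LazyResetGrad can be illustrated with a small standalone sketch. This is a toy model of the flag semantics only, not the real API; ToyModule and its members are illustrative names:

```python
class ToyModule:
    """Toy sketch of the accumulate_gradient_ / LazyResetGrad semantics."""

    def __init__(self):
        self.accumulate_gradient = False  # first TrainStep always resets
        self.grad = 0.0                   # stands in for a gradient buffer

    def train_step(self, new_grad):
        # The training graph's InPlaceAccumulator either overwrites or adds,
        # driven by the reset flag (the inverse of accumulate_gradient).
        reset = not self.accumulate_gradient
        self.grad = new_grad if reset else self.grad + new_grad
        self.accumulate_gradient = True   # subsequent steps accumulate
        return self.grad

    def lazy_reset_grad(self):
        # No zeroing happens here; the next train_step overwrites instead.
        self.accumulate_gradient = False


m = ToyModule()
m.train_step(1.0)   # grad -> 1.0 (first step resets)
m.train_step(2.0)   # grad -> 3.0 (accumulates)
m.lazy_reset_grad()
m.train_step(5.0)   # grad -> 5.0 (reset on next step, not immediately)
```

This mirrors why LazyResetGrad is "lazy": it only flips a flag, and the reset is folded into the next training session run.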
API Signature
C++
```cpp
// Module methods
Status TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
Status EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs);
Status LazyResetGrad();

// Optimizer method
Status Step();
```
Python
```python
# Switch to training mode
module.train()
# Execute forward + backward (returns loss)
training_loss = module(input_data)
# Update parameters
optimizer.step()
# Reset gradients for next iteration
module.lazy_reset_grad()
# Switch to eval mode and run evaluation
module.eval()
eval_loss = module(eval_data)
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| inputs | std::vector<OrtValue> / Python tensors | User-provided input data for the training or evaluation step |
| outputs | std::vector<OrtValue>& | Output values populated by the session run (C++ only; Python returns directly) |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input (TrainStep) | User input tensors | Batch of training data matching the model's expected input shapes |
| Output (TrainStep) | Loss and forward outputs | Training loss value and any additional forward outputs |
| Side Effect (TrainStep) | Gradient accumulation | Gradients are accumulated in Parameter gradient buffers |
| Side Effect (Step) | Parameter update | Parameters are updated in-place using the optimizer's update rule |
| Side Effect (LazyResetGrad) | Gradient reset flag | Sets flag for gradient zeroing on the next TrainStep |
| Input (EvalStep) | User input tensors | Batch of evaluation data |
| Output (EvalStep) | Evaluation outputs | Evaluation loss and any additional forward outputs |
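The feed ordering that TrainStep assembles (user inputs, then parameter weights, then gradient buffers, then the boolean reset flag) can be sketched in plain Python. build_train_feeds is a hypothetical helper for illustration, not part of the API:

```python
def build_train_feeds(user_inputs, weights, gradients, accumulate_gradient):
    """Sketch of the feed list TrainStep constructs for the training session.

    Order matters: user inputs first, then weights, then gradient buffers,
    then the reset flag, which is the inverse of accumulate_gradient.
    """
    feeds = list(user_inputs)
    feeds += list(weights)
    feeds += list(gradients)
    feeds.append(not accumulate_gradient)  # the lazy gradient-reset flag
    return feeds


# First step (nothing accumulated yet): the reset flag is True.
feeds = build_train_feeds(["x"], ["w1", "w2"], ["g1", "g2"],
                          accumulate_gradient=False)
# -> ["x", "w1", "w2", "g1", "g2", True]
```

In the real implementation each entry is an OrtValue and the flag is wrapped via utils::WrapInOrtValue<bool>; only the ordering is modeled here.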
Code Reference
From orttraining/orttraining/training_api/module.cc:
```cpp
Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                "Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
  std::vector<OrtValue> feeds{inputs};
  feeds.insert(feeds.end(), weights_.begin(), weights_.end());
  feeds.insert(feeds.end(), gradients_.begin(), gradients_.end());
  OrtValue reset_grad_input;
  utils::WrapInOrtValue<bool>(!accumulate_gradient_, &reset_grad_input);
  feeds.push_back(reset_grad_input);
  ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_.AllInputNames(),
                                      feeds, train_output_names_, &outputs));
  accumulate_gradient_ = true;
  return Status::OK();
}

Status Module::LazyResetGrad() {
  accumulate_gradient_ = false;
  return Status::OK();
}

Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
  ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                "Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
  ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
  std::vector<OrtValue> feeds{inputs};
  feeds.insert(feeds.end(), weights_.begin(), weights_.end());
  auto status = eval_sess_->Run(RunOptions(), eval_input_names_, feeds, eval_output_names_, &outputs);
  ORT_THROW_IF_ERROR(status);
  return Status::OK();
}
```
Usage Example
```python
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Setup
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")
module = Module("training_artifacts/training_model.onnx", state,
                "training_artifacts/eval_model.onnx", device="cpu")
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)

# Training loop
for epoch in range(num_epochs):
    module.train()  # switch back to training mode each epoch
    for batch in train_dataloader:
        training_loss = module(batch)
        optimizer.step()
        module.lazy_reset_grad()

    # Periodic evaluation
    module.eval()
    total_eval_loss = 0.0
    for batch in eval_dataloader:
        eval_loss = module(batch)
        total_eval_loss += eval_loss
    print(f"Epoch {epoch}: eval_loss={total_eval_loss}")

# Save checkpoint after training
CheckpointState.save_checkpoint(state, "checkpoints/final")
```
Implements
Principle:Microsoft_Onnxruntime_On_Device_Training_Loop
Related Pages
- Module Optimizer Scheduler Init -- Creates the components used in the training loop
- SaveCheckpoint -- Persists training state after the loop
- ExportModelForInferencing -- Exports the trained model for deployment
- Environment:Microsoft_Onnxruntime_CUDA_GPU_Environment