Implementation:Microsoft Onnxruntime Module Optimizer Scheduler Init
Overview
Initializes the core training components -- Module, Optimizer, and LinearLRScheduler -- from training artifacts and checkpoint state to form the complete on-device training pipeline.
Metadata
| Field | Value |
|---|---|
| Implementation Name | Module_Optimizer_Scheduler_Init |
| Type | API Doc |
| Language | C++ and Python |
| API | Python: Module(train_model_path, state, eval_model_path, device), Optimizer(optimizer_model_path, module). C++: Module(model_identifiers, state, session_options, env, providers), Optimizer(model_identifiers, state, session_options, env, providers), LinearLRScheduler(optimizer, warmup_step_count, total_step_count) |
| Import | from onnxruntime.training.api import CheckpointState, Module, Optimizer |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | orttraining/orttraining/training_api/module.cc:L280-285 (Module), orttraining/orttraining/training_api/optimizer.cc:L183-188 (Optimizer), orttraining/orttraining/training_api/lr_scheduler.h:L74-81 (LinearLRScheduler) |
| Last Updated | 2026-02-10 |
Description
This implementation covers the construction of the three core training components that together form the on-device training pipeline:
- Module -- Loads the training ONNX model (and optionally the eval model) into inference sessions, resolves parameter placement across devices, and prepares input/output bindings.
- Optimizer -- Loads the optimizer ONNX model, initializes or restores momentum states (first and second order moments for AdamW), and constructs tensor sequence inputs for the optimizer graph.
- LinearLRScheduler -- Wraps the Optimizer to provide linearly decaying learning rates with optional warmup phases.
Both Module and Optimizer hold a non-owning pointer to the same CheckpointState, establishing a shared-state architecture in which parameter updates made through one component are immediately visible to the others.
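The shared-state design can be sketched in plain Python. The stand-in classes below are hypothetical (not the ORT API); the point is only that both components reference the same state object rather than copying it, so an update through one is observed by the other.

```python
# Stand-in classes (hypothetical, not the ORT API) sketching the shared-state
# design: Module and Optimizer both hold a non-owning reference to one
# CheckpointState rather than copying its parameters.
class CheckpointState:
    def __init__(self, parameters):
        self.parameters = parameters  # name -> value

class Module:
    def __init__(self, state):
        self.state = state  # shared reference, not a copy

class Optimizer:
    def __init__(self, module):
        self.state = module.state  # same underlying state object

state = CheckpointState({"fc.weight": 1.0})
module = Module(state)
optimizer = Optimizer(module)

optimizer.state.parameters["fc.weight"] = 0.9   # an update via the optimizer...
print(module.state.parameters["fc.weight"])     # ...is visible via the module: 0.9
```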
API Signature
C++ Module Constructor
Module(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
C++ Optimizer Constructor
Optimizer(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());
C++ LinearLRScheduler Constructor
LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
int64_t warmup_step_count,
int64_t total_step_count);
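The warmup-then-decay behavior can be expressed as a small pure function. The sketch below shows the multiplier a linear warmup/decay schedule applies to the base learning rate; the function name and exact boundary handling are illustrative, not the ORT implementation.

```python
def linear_lr_multiplier(step: int, warmup_step_count: int, total_step_count: int) -> float:
    """Sketch of a linear warmup/decay LR multiplier (illustrative, not ORT's code)."""
    if step < warmup_step_count:
        # Warmup: rise linearly from 0 toward 1 over warmup_step_count steps.
        return step / max(1, warmup_step_count)
    # Decay: fall linearly from 1 at the end of warmup to 0 at total_step_count.
    return max(0.0, (total_step_count - step) / max(1, total_step_count - warmup_step_count))
```

With warmup_step_count=10 and total_step_count=100, the multiplier is 0.5 at step 5 (mid-warmup), reaches 1.0 at step 10, and decays back to 0.5 at step 55. Note that the C++ constructor shown in the Code Reference section calls Step() immediately, so the step-0 multiplier takes effect at construction time.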
Python
from onnxruntime.training.api import CheckpointState, Module, Optimizer
module = Module(train_model_path, state, eval_model_path, device="cpu")
optimizer = Optimizer(optimizer_model_path, module)
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model_identifiers (C++) / train_model_path (Python) | ModelIdentifiers / str | Path(s) to the training ONNX model (and optionally the eval model) |
| state | CheckpointState* / CheckpointState | Loaded checkpoint state containing parameters and optimizer states |
| session_options (C++) | SessionOptions | ORT session configuration (weight prepacking is automatically disabled) |
| env (C++) | Environment | ORT execution environment |
| providers (C++) | vector<shared_ptr<IExecutionProvider>> | Execution providers (CPU, CUDA, etc.) |
| eval_model_path (Python) | str | Path to the eval ONNX model (optional) |
| device (Python) | str | Target device for computation (e.g., "cpu", "cuda") |
| warmup_step_count | int64_t | Number of warmup steps for LinearLRScheduler |
| total_step_count | int64_t | Total number of training steps for LinearLRScheduler |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | Training artifact files | training_model.onnx, optimizer_model.onnx, and optionally eval_model.onnx |
| Input | CheckpointState | Loaded checkpoint with parameter values and optimizer states |
| Output | Module | Configured module ready for TrainStep and EvalStep calls |
| Output | Optimizer | Configured optimizer ready for Step calls |
| Output | LinearLRScheduler (optional) | Configured scheduler that adjusts the optimizer's learning rate |
Code Reference
From orttraining/orttraining/training_api/module.cc:
Module::Module(const ModelIdentifiers& model_identifiers,
CheckpointState* state,
const onnxruntime::SessionOptions& session_options,
const Environment& env,
const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
[[maybe_unused]] gsl::span<OrtCustomOpDomain* const> op_domains)
: state_{state} {
// Enforce weight prepacking is disabled
// ...
train_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
// Load the training model
ORT_THROW_IF_ERROR(/* ... load model ... */);
// ...
}
From orttraining/orttraining/training_api/lr_scheduler.h:
struct LinearLRScheduler : public MultiplicativeLRSchedulerBase {
explicit LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
int64_t warmup_step_count,
int64_t total_step_count)
: MultiplicativeLRSchedulerBase(optimizer),
warmup_step_count_(warmup_step_count),
total_step_count_(total_step_count) {
ORT_THROW_IF_ERROR(Step());
  }

  // ... (member declarations omitted)
};
Usage Example
Python
from onnxruntime.training.api import CheckpointState, Module, Optimizer
# Load checkpoint
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")
# Create module with training and eval models
module = Module(
"training_artifacts/training_model.onnx",
state,
"training_artifacts/eval_model.onnx",
device="cpu",
)
# Create optimizer
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)
# Module and optimizer are now ready for the training loop
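Once all three components exist, they are driven in a fixed order each iteration: a train step through the module, a parameter update through the optimizer, then a learning-rate adjustment through the scheduler. The stand-in classes below are hypothetical (not the ORT API) and exist only to make that call order concrete.

```python
# Hypothetical stand-ins (not the ORT API): the only point here is the
# per-iteration call order once the components are assembled.
calls = []

class MockModule:
    def train_step(self, batch):
        calls.append("train_step")       # forward + backward; gradients land in shared state
        return 0.0                       # pretend loss

class MockOptimizer:
    def step(self):
        calls.append("optimizer.step")   # apply gradients to parameters

class MockScheduler:
    def step(self):
        calls.append("scheduler.step")   # adjust the learning rate for the next step

module, optimizer, scheduler = MockModule(), MockOptimizer(), MockScheduler()
for batch in range(2):
    loss = module.train_step(batch)
    optimizer.step()
    scheduler.step()
# calls == ["train_step", "optimizer.step", "scheduler.step"] * 2
```

In the real Python API the module is invoked directly for a train step and gradients are reset afterward (e.g., via module.lazy_reset_grad()); those details are omitted from this sketch.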
Implements
Principle:Microsoft_Onnxruntime_Training_Component_Assembly
Related Pages
- CheckpointState Load -- Provides the checkpoint state used during initialization
- Generate Artifacts -- Produces the artifact files consumed here
- Module TrainStep -- Uses the assembled components for training execution