
Implementation:Microsoft Onnxruntime Module Optimizer Scheduler Init

From Leeroopedia


Overview

Initializes the core training components -- Module, Optimizer, and LinearLRScheduler -- from training artifacts and checkpoint state to form the complete on-device training pipeline.

Metadata

Field Value
Implementation Name Module_Optimizer_Scheduler_Init
Type API Doc
Language C++ and Python
API Python: Module(train_model_path, state, eval_model_path, device), Optimizer(optimizer_model_path, module). C++: Module(model_identifiers, state, session_options, env, providers), Optimizer(model_identifiers, state, session_options, env, providers), LinearLRScheduler(optimizer, warmup_step_count, total_step_count)
Import from onnxruntime.training.api import CheckpointState, Module, Optimizer
Domain On_Device_Training, Training_Infrastructure
Repository microsoft/onnxruntime
Source Reference orttraining/orttraining/training_api/module.cc:L280-285 (Module), orttraining/orttraining/training_api/optimizer.cc:L183-188 (Optimizer), orttraining/orttraining/training_api/lr_scheduler.h:L74-81 (LinearLRScheduler)
Last Updated 2026-02-10

Description

This implementation covers the construction of the three core training components that together form the on-device training pipeline:

  • Module -- Loads the training ONNX model (and optionally the eval model) into inference sessions, resolves parameter placement across devices, and prepares input/output bindings.
  • Optimizer -- Loads the optimizer ONNX model, initializes or restores momentum states (first and second order moments for AdamW), and constructs tensor sequence inputs for the optimizer graph.
  • LinearLRScheduler -- Wraps the Optimizer to provide linearly decaying learning rates with optional warmup phases.

Both Module and Optimizer take a non-owning pointer (in C++; a shared reference in Python) to CheckpointState, establishing a shared-state architecture in which parameter updates made by one component are immediately visible to the others.
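The shared-state design can be illustrated with a minimal pure-Python sketch. The classes below are hypothetical stand-ins, not the real ORT types; they only show why an update applied through the Optimizer is instantly visible through the Module:

```python
class CheckpointState:
    """Stand-in for the real CheckpointState: owns the parameter values."""
    def __init__(self, parameters):
        self.parameters = parameters  # name -> value

class Module:
    """Stand-in Module: holds a non-owning reference to the shared state."""
    def __init__(self, state):
        self.state = state

class Optimizer:
    """Stand-in Optimizer: writes updates through the same shared state."""
    def __init__(self, state):
        self.state = state

    def step(self, grads, lr):
        for name, g in grads.items():
            self.state.parameters[name] -= lr * g

state = CheckpointState({"w": 1.0})
module = Module(state)
optimizer = Optimizer(state)

optimizer.step({"w": 0.5}, lr=0.1)
# The update is visible through the Module immediately: 1.0 - 0.1 * 0.5
print(module.state.parameters["w"])  # → 0.95
```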

API Signature

C++ Module Constructor

Module(const ModelIdentifiers& model_identifiers,
       CheckpointState* state,
       const onnxruntime::SessionOptions& session_options,
       const Environment& env,
       const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
       gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());

C++ Optimizer Constructor

Optimizer(const ModelIdentifiers& model_identifiers,
          CheckpointState* state,
          const onnxruntime::SessionOptions& session_options,
          const Environment& env,
          const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
          gsl::span<OrtCustomOpDomain* const> op_domains = gsl::span<OrtCustomOpDomain* const>());

C++ LinearLRScheduler Constructor

LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
                  int64_t warmup_step_count,
                  int64_t total_step_count);

Python

from onnxruntime.training.api import CheckpointState, Module, Optimizer

module = Module(train_model_path, state, eval_model_path, device="cpu")
optimizer = Optimizer(optimizer_model_path, module)

Key Parameters

Parameter Type Description
model_identifiers (C++) / train_model_path (Python) ModelIdentifiers / str Path(s) to the training ONNX model (and optionally eval model)
state CheckpointState* / CheckpointState Loaded checkpoint state containing parameters and optimizer states
session_options (C++) SessionOptions ORT session configuration (prepacking is automatically disabled)
env (C++) Environment ORT execution environment
providers (C++) vector<shared_ptr<IExecutionProvider>> Execution providers (CPU, CUDA, etc.)
eval_model_path (Python) str Path to the eval ONNX model (optional)
device (Python) str Target device for computation (e.g., "cpu", "cuda")
warmup_step_count int64_t Number of warmup steps for LinearLRScheduler
total_step_count int64_t Total number of training steps for LinearLRScheduler
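The step counts passed to LinearLRScheduler are typically derived from the dataset size and epoch count. The numbers below are hypothetical and the 10% warmup fraction is a common convention, not something mandated by ORT:

```python
# Hypothetical values, for illustration only.
num_epochs = 3
steps_per_epoch = 1000            # e.g. len(dataloader) in a typical setup

total_step_count = num_epochs * steps_per_epoch    # one Step() per batch
warmup_step_count = int(0.1 * total_step_count)    # 10% warmup convention

print(total_step_count, warmup_step_count)  # → 3000 300
```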

I/O Contract

Direction Type Description
Input Training artifact files training_model.onnx, optimizer_model.onnx, optionally eval_model.onnx
Input CheckpointState Loaded checkpoint with parameter values and optimizer states
Output Module Configured module ready for TrainStep and EvalStep calls
Output Optimizer Configured optimizer ready for Step calls
Output LinearLRScheduler (optional) Configured scheduler that adjusts optimizer learning rate

Code Reference

From orttraining/orttraining/training_api/module.cc:

Module::Module(const ModelIdentifiers& model_identifiers,
               CheckpointState* state,
               const onnxruntime::SessionOptions& session_options,
               const Environment& env,
               const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
               [[maybe_unused]] gsl::span<OrtCustomOpDomain* const> op_domains)
    : state_{state} {
  // Enforce weight prepacking is disabled
  // ...
  train_sess_ = std::make_unique<onnxruntime::InferenceSession>(session_options, env);
  // Load the training model
  ORT_THROW_IF_ERROR(/* ... load model ... */);
  // ...
}

From orttraining/orttraining/training_api/lr_scheduler.h:

struct LinearLRScheduler : public MultiplicativeLRSchedulerBase {
  explicit LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
                             int64_t warmup_step_count,
                             int64_t total_step_count)
      : MultiplicativeLRSchedulerBase(optimizer),
        warmup_step_count_(warmup_step_count),
        total_step_count_(total_step_count) {
    ORT_THROW_IF_ERROR(Step());
  }
};
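The multiplicative factor that a linear warmup-then-decay schedule applies to the base learning rate can be sketched in pure Python. This is a hedged reconstruction of the standard linear schedule, not the literal ORT source; edge-case handling in lr_scheduler.h may differ:

```python
def linear_lr_factor(step, warmup_step_count, total_step_count):
    """Multiplicative factor on the base learning rate at a given step.

    Linear ramp 0 -> 1 over the warmup phase, then linear decay 1 -> 0
    over the remaining steps (standard linear schedule sketch).
    """
    if step < warmup_step_count:
        return step / max(1, warmup_step_count)
    remaining = total_step_count - warmup_step_count
    return max(0.0, (total_step_count - step) / max(1, remaining))

# Warmup: factor grows linearly toward 1.
print(linear_lr_factor(50, 100, 1000))   # → 0.5
# Decay: factor shrinks linearly toward 0.
print(linear_lr_factor(550, 100, 1000))  # → 0.5
```

Note that the constructor above calls Step() immediately, so the scheduled rate is applied from step zero rather than after the first update.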

Usage Example

Python

from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load checkpoint
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")

# Create module with training and eval models
module = Module(
    "training_artifacts/training_model.onnx",
    state,
    "training_artifacts/eval_model.onnx",
    device="cpu",
)

# Create optimizer
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)

# Module and optimizer are now ready for the training loop
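With the components constructed, a typical on-device training loop invokes them in a fixed order. The sketch below uses minimal hypothetical stand-in classes so the call sequence is runnable without real training artifacts; the method names (train, __call__, step, lazy_reset_grad) follow the ORT training API, but the bodies are placeholders:

```python
# Hypothetical stand-ins mirroring the ORT training-loop call order.

class Module:
    def __init__(self):
        self._training = False
        self._grads_pending = False

    def train(self):                  # switch to the training graph
        self._training = True

    def __call__(self, *inputs):      # run one forward + backward pass
        self._grads_pending = True
        return 0.0                    # placeholder loss

    def lazy_reset_grad(self):        # mark gradients for reset next step
        self._grads_pending = False

class Optimizer:
    def __init__(self):
        self.steps = 0

    def step(self):                   # apply accumulated gradients
        self.steps += 1

module, optimizer = Module(), Optimizer()
module.train()
for batch in range(3):                # placeholder data
    loss = module(batch)              # forward + backward
    optimizer.step()                  # update params through shared state
    module.lazy_reset_grad()          # clear gradients for the next batch

print(optimizer.steps)  # → 3
```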

Implements

Principle:Microsoft_Onnxruntime_Training_Component_Assembly
