
Implementation:Microsoft Onnxruntime TrainingSession



Knowledge Sources
Domains: Training, API, Session
Last Updated: 2026-02-10 04:00 GMT

Overview

Defines the `TrainingSession` class: the high-level wrapper that combines `Module`, `Optimizer`, and a learning rate scheduler (`LRSchedulerBase`), and the only training class exposed through the ORT C API.

Description

The `TrainingSession` class is the top-level orchestrator for the ORT Training API. It places the three core training components behind a single interface: `Module` (forward and backward computation), `Optimizer` (parameter updates from gradients), and `LRSchedulerBase` (learning rate scheduling). The session holds a non-owning pointer to `CheckpointState`, so the caller must keep the state alive for as long as the session exists.
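The ownership split described above can be illustrated with a minimal sketch. `SessionSketch`, `Module`, `Optimizer`, and `CheckpointState` here are hypothetical stand-ins, not the ORT types; they model only who owns what (borrowed state, exclusively owned module, shared optimizer) and how calls are forwarded.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-ins for the ORT types; they only model ownership.
struct CheckpointState { std::vector<float> params; };

struct Module {
  explicit Module(CheckpointState* s) : state(s) {}
  CheckpointState* state;  // non-owning, like the session's own pointer
  int train_calls = 0;
  void TrainStep() { ++train_calls; }
};

struct Optimizer {
  explicit Optimizer(CheckpointState* s) : state(s) {}
  CheckpointState* state;  // non-owning
  int steps = 0;
  void Step() { ++steps; }
};

class SessionSketch {
 public:
  explicit SessionSketch(CheckpointState* state)
      : state_(state),
        module_(std::make_unique<Module>(state)),
        optimizer_(std::make_shared<Optimizer>(state)) {}

  void TrainStep() { module_->TrainStep(); }     // pure delegation to Module
  void OptimizerStep() { optimizer_->Step(); }   // pure delegation to Optimizer

  const Module& module() const { return *module_; }
  const Optimizer& optimizer() const { return *optimizer_; }
  CheckpointState* state() const { return state_; }

 private:
  CheckpointState* state_;                // borrowed: the caller keeps it alive
  std::unique_ptr<Module> module_;        // exclusively owned by the session
  std::shared_ptr<Optimizer> optimizer_;  // shared so a scheduler can co-own it
};
```

Because the session only borrows the state, destroying the session leaves the checkpoint intact, which is what allows the state to be saved after training finishes.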

Key methods include:

  • `TrainStep` / `EvalStep`: Delegate to `Module::TrainStep` and `Module::EvalStep` respectively; each also takes a `RunOptions` argument.
  • `OptimizerStep`: Executes a single optimizer step via the internal `Optimizer`.
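  • `LazyResetGrad`: Schedules a reset of all parameter gradients; the reset is applied lazily, just before new gradients are computed in the next `TrainStep`, rather than immediately.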
  • `SetLearningRate` / `GetLearningRate`: Direct learning rate manipulation.
  • `RegisterScheduler`: Registers a learning rate scheduler factory function and initial learning rate.
  • `SchedulerStep`: Advances the registered scheduler by one step.
  • Parameter buffer operations: `CopyParametersToBuffer` and `CopyBufferToParameters`.
  • Model export: `ExportModelForInferencing` (non-minimal build only).
  • Input/output introspection for both training and eval graphs.
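The lazy reset semantics of `LazyResetGrad` can be modeled with a single flag. This is a simplified, hypothetical sketch (`LazyGradSketch` and `accumulate_` are illustrative names, not ORT internals): resetting costs nothing up front, and the next training step overwrites the gradients instead of accumulating into them.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of lazy gradient reset: LazyResetGrad() only flips a flag,
// and the next TrainStep overwrites gradients instead of accumulating into them.
class LazyGradSketch {
 public:
  explicit LazyGradSketch(std::size_t n) : grads_(n, 0.0f) {}

  void LazyResetGrad() { accumulate_ = false; }  // O(1): no gradient memory is touched

  // `step_grads` stands in for the gradients of one backward pass;
  // precondition: step_grads.size() == grads_.size().
  void TrainStep(const std::vector<float>& step_grads) {
    for (std::size_t i = 0; i < grads_.size(); ++i) {
      grads_[i] = accumulate_ ? grads_[i] + step_grads[i] : step_grads[i];
    }
    accumulate_ = true;  // later steps accumulate until the next reset
  }

  const std::vector<float>& grads() const { return grads_; }

 private:
  std::vector<float> grads_;
  bool accumulate_ = false;  // the first step writes fresh gradients
};
```

Calling `TrainStep` several times without a reset therefore accumulates gradients, which is how gradient accumulation over micro-batches falls out of the same mechanism.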

The class is neither copyable nor movable (`ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE`).
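The effect of that macro is to delete the four copy and move special members. The sketch below uses a hypothetical macro `DISALLOW_COPY_ASSIGNMENT_AND_MOVE` and class `PinnedSession`; the real ORT macro may be composed differently, but the resulting restrictions are the ones checked by the `static_assert`s.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical equivalent of ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE.
#define DISALLOW_COPY_ASSIGNMENT_AND_MOVE(T) \
  T(const T&) = delete;                      \
  T& operator=(const T&) = delete;           \
  T(T&&) = delete;                           \
  T& operator=(T&&) = delete

class PinnedSession {
 public:
  PinnedSession() = default;
  DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PinnedSession);
};

// Compile-time checks: no copy, no move, in either construction or assignment.
static_assert(!std::is_copy_constructible<PinnedSession>::value, "no copy");
static_assert(!std::is_copy_assignable<PinnedSession>::value, "no copy assignment");
static_assert(!std::is_move_constructible<PinnedSession>::value, "no move");
static_assert(!std::is_move_assignable<PinnedSession>::value, "no move assignment");
```

Pinning the object in place means any raw pointer handed out for it (for example, an opaque handle returned through a C API) cannot be silently invalidated by a copy or move.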

Usage

Use this class when you need a complete training session that manages the training loop lifecycle. This is the class instantiated by the ORT C API's `CreateTrainingSession`.

Code Reference

Source Location

Signature

class TrainingSession {
 public:
  TrainingSession(const Environment& session_env,
                  const SessionOptions& session_options,
                  const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
                  CheckpointState* state,
                  const ModelIdentifiers& model_identifiers,
                  gsl::span<OrtCustomOpDomain* const> custom_op_domains = {});

  Status RegisterScheduler(
      const std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
      float initial_lr);

  Status TrainStep(const RunOptions& run_options,
                   const std::vector<OrtValue>& inputs,
                   std::vector<OrtValue>& fetches);
  Status EvalStep(const RunOptions& run_options,
                  const std::vector<OrtValue>& inputs,
                  std::vector<OrtValue>& fetches) const;
  Status LazyResetGrad();
  Status OptimizerStep(const RunOptions& run_options);
  Status SetLearningRate(float learning_rate) noexcept;
  float GetLearningRate() const;
  Status SchedulerStep() noexcept;
  size_t GetParametersSize(const bool trainable_only = true) const;
  Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true);
  Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
  Status ExportModelForInferencing(const std::string& inference_model_path,
                                   gsl::span<const std::string> graph_output_names) const;

 private:
  CheckpointState* state_;
  std::unique_ptr<Module> module_;
  std::shared_ptr<Optimizer> optimizer_;
  std::unique_ptr<LRSchedulerBase> scheduler_;
};
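`GetParametersSize`, `CopyParametersToBuffer`, and `CopyBufferToParameters` flatten parameters into one contiguous buffer and scatter them back. The sketch below is a hypothetical, simplified model: it uses `std::vector<float>` instead of an `OrtValue`, a made-up `Param` record instead of `CheckpointState`, and free functions that merely mirror the method names. The fixed iteration order and the size check are assumptions about how such a round trip must behave, not a description of the ORT implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical parameter record; in ORT the parameters live in CheckpointState.
struct Param {
  std::string name;
  std::vector<float> data;
  bool trainable;
};

// Number of scalar elements the flat buffer must hold.
std::size_t GetParametersSize(const std::vector<Param>& params, bool trainable_only = true) {
  std::size_t n = 0;
  for (const auto& p : params)
    if (!trainable_only || p.trainable) n += p.data.size();
  return n;
}

// Concatenate the selected parameters, in a fixed order, into one contiguous buffer.
bool CopyParametersToBuffer(const std::vector<Param>& params, std::vector<float>& buffer,
                            bool trainable_only = true) {
  if (buffer.size() != GetParametersSize(params, trainable_only)) return false;  // size mismatch
  std::size_t offset = 0;
  for (const auto& p : params) {
    if (trainable_only && !p.trainable) continue;
    for (float v : p.data) buffer[offset++] = v;
  }
  return true;
}

// Inverse operation: scatter the buffer back into the parameters in the same order.
bool CopyBufferToParameters(std::vector<Param>& params, const std::vector<float>& buffer,
                            bool trainable_only = true) {
  if (buffer.size() != GetParametersSize(params, trainable_only)) return false;
  std::size_t offset = 0;
  for (auto& p : params) {
    if (trainable_only && !p.trainable) continue;
    for (float& v : p.data) v = buffer[offset++];
  }
  return true;
}
```

With `trainable_only = true`, frozen parameters are skipped in both directions, so a buffer produced by one call can be fed back to the other without disturbing them.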

Import

#include "orttraining/training_api/training_session.h"

I/O Contract

Method | Inputs | Outputs | Description
TrainStep | RunOptions, vector<OrtValue> inputs | vector<OrtValue> fetches, Status | Runs the forward and backward passes; returns the forward outputs
EvalStep | RunOptions, vector<OrtValue> inputs | vector<OrtValue> fetches, Status | Runs a forward-only evaluation pass
OptimizerStep | RunOptions | Status | Executes one optimizer update
RegisterScheduler | scheduler factory, initial_lr | Status | Registers a learning rate scheduler
SchedulerStep | (none) | Status | Advances the LR scheduler by one step
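`RegisterScheduler` takes a factory of type `std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>`, so the scheduler is built around the session's own optimizer and can co-own it. The sketch below models that pattern with hypothetical minimal types (`Optimizer`, `LRSchedulerBase`, `HalvingScheduler`, `SchedulerHost`); the toy halving schedule and the boolean return values in place of `Status` are illustrative assumptions.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical minimal types modeling RegisterScheduler's ownership pattern.
struct Optimizer {
  float lr = 0.0f;
};

struct LRSchedulerBase {
  explicit LRSchedulerBase(std::shared_ptr<Optimizer> opt) : optimizer(std::move(opt)) {}
  virtual ~LRSchedulerBase() = default;
  virtual void Step() = 0;
  std::shared_ptr<Optimizer> optimizer;  // co-owns the optimizer whose LR it updates
};

// Toy schedule for illustration: halve the learning rate on every step.
struct HalvingScheduler : LRSchedulerBase {
  using LRSchedulerBase::LRSchedulerBase;
  void Step() override { optimizer->lr *= 0.5f; }
};

using SchedulerFactory =
    std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>;

class SchedulerHost {
 public:
  SchedulerHost() : optimizer_(std::make_shared<Optimizer>()) {}

  bool RegisterScheduler(const SchedulerFactory& make, float initial_lr) {
    scheduler_ = make(optimizer_);  // the factory receives the shared optimizer
    if (!scheduler_) return false;  // reject a null scheduler
    optimizer_->lr = initial_lr;
    return true;
  }

  bool SchedulerStep() {
    if (!scheduler_) return false;  // nothing registered yet
    scheduler_->Step();
    return true;
  }

  float GetLearningRate() const { return optimizer_->lr; }

 private:
  std::shared_ptr<Optimizer> optimizer_;
  std::unique_ptr<LRSchedulerBase> scheduler_;
};
```

Passing a factory rather than a ready-made scheduler lets the session decide which optimizer instance the scheduler binds to, so callers never need direct access to the session's internals.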

Usage Examples

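#include "orttraining/training_api/training_session.h"
#include "orttraining/training_api/lr_scheduler.h"  // LinearLRScheduler

using namespace onnxruntime::training::api;

// Assumes env, session_options, providers, state (a loaded CheckpointState),
// model_ids, run_options and std::vector<OrtValue> inputs are already set up.
// `state` must outlive `session`.
TrainingSession session(env, session_options, providers, &state, model_ids);

// Register a linear LR scheduler (100 warmup steps, 1000 total steps)
// with an initial learning rate of 0.001.
ORT_THROW_IF_ERROR(session.RegisterScheduler(
    [](std::shared_ptr<Optimizer> opt) {
      return std::make_unique<LinearLRScheduler>(opt, 100, 1000);
    },
    0.001f));

// Training loop: every call returns a Status that must be checked.
std::vector<OrtValue> outputs;
for (int step = 0; step < 1000; ++step) {
  ORT_THROW_IF_ERROR(session.TrainStep(run_options, inputs, outputs));  // forward + backward
  ORT_THROW_IF_ERROR(session.OptimizerStep(run_options));               // apply gradients
  ORT_THROW_IF_ERROR(session.SchedulerStep());                          // advance the LR
  ORT_THROW_IF_ERROR(session.LazyResetGrad());                          // reset grads on next step
}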
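To see what the scheduler in this example does to the learning rate, here is a sketch of a linear warmup-then-decay rule. It assumes that `LinearLRScheduler(opt, 100, 1000)` means 100 warmup steps and 1000 total steps and that it follows the common formula (ramp the multiplier from 0 to 1 during warmup, then decay it linearly to 0 at the final step); `LinearMultiplier` and `LinearLearningRate` are hypothetical helpers, so check the exact behavior against the ORT scheduler source.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Assumed rule: linear warmup from 0 to 1, then linear decay back to 0.
float LinearMultiplier(std::int64_t step, std::int64_t warmup_steps, std::int64_t total_steps) {
  if (step < warmup_steps) {
    // Warmup phase: step / warmup_steps (guarding against division by zero).
    return static_cast<float>(step) /
           static_cast<float>(std::max<std::int64_t>(1, warmup_steps));
  }
  // Decay phase: (total - step) / (total - warmup), clamped at 0 past the end.
  const std::int64_t decay_steps = std::max<std::int64_t>(1, total_steps - warmup_steps);
  return std::max(0.0f, static_cast<float>(total_steps - step) /
                            static_cast<float>(decay_steps));
}

// Effective learning rate = initial_lr * multiplier.
float LinearLearningRate(float initial_lr, std::int64_t step,
                         std::int64_t warmup_steps, std::int64_t total_steps) {
  return initial_lr * LinearMultiplier(step, warmup_steps, total_steps);
}
```

Under this assumption, with an initial learning rate of 0.001 the rate climbs to 0.0005 at step 50, peaks at 0.001 at step 100, falls back to 0.0005 at step 550, and reaches 0 at step 1000.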
