Implementation:Microsoft Onnxruntime TrainingSession
| Knowledge Sources | |
|---|---|
| Domains | Training, API, Session |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Defines the `TrainingSession` class, a high-level wrapper that combines `Module`, `Optimizer`, and `LRSchedulerBase`. It is the only class exposed via the ORT C APIs.
Description
The `TrainingSession` class is the top-level orchestrator for the ORT Training API. It aggregates the three core training components -- `Module` (for forward/backward computation), `Optimizer` (for gradient updates), and `LRSchedulerBase` (for learning rate scheduling) -- behind a single, unified interface. The session holds a non-owning pointer to `CheckpointState`, which must outlive the session.
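The ownership arrangement described above can be sketched with stand-in types. Everything in this block (`MiniState`, `MiniModule`, `MiniOptimizer`, `MiniSession`) is hypothetical and is not ORT code; it only illustrates a wrapper that owns its components while borrowing a state object that the caller must keep alive.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <memory>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins for illustration only; none of these are ORT types.
struct MiniState {
  std::vector<float> params;
  std::vector<float> grads;
};

class MiniModule {
 public:
  explicit MiniModule(MiniState* state) : state_(state) {}
  // Toy forward+backward for loss = 0.5 * p^2, whose gradient is p.
  void TrainStep() {
    for (std::size_t i = 0; i < state_->params.size(); ++i) {
      state_->grads[i] += state_->params[i];
    }
  }

 private:
  MiniState* state_;  // borrowed
};

class MiniOptimizer {
 public:
  MiniOptimizer(MiniState* state, float lr) : state_(state), lr_(lr) {}
  // Plain SGD update: p -= lr * g.
  void Step() {
    for (std::size_t i = 0; i < state_->params.size(); ++i) {
      state_->params[i] -= lr_ * state_->grads[i];
    }
  }

 private:
  MiniState* state_;  // borrowed
  float lr_;
};

// The wrapper owns its components but only borrows the state.
class MiniSession {
 public:
  explicit MiniSession(MiniState* state)
      : state_(state),
        module_(std::make_unique<MiniModule>(state)),
        optimizer_(std::make_shared<MiniOptimizer>(state, 0.1f)) {}

  // Copying or moving would duplicate the borrowed pointer, so both are disabled.
  MiniSession(const MiniSession&) = delete;
  MiniSession& operator=(const MiniSession&) = delete;
  MiniSession(MiniSession&&) = delete;
  MiniSession& operator=(MiniSession&&) = delete;

  void TrainStep() { module_->TrainStep(); }
  void OptimizerStep() { optimizer_->Step(); }
  const MiniState& State() const { return *state_; }

 private:
  MiniState* state_;  // non-owning; the caller must keep it alive
  std::unique_ptr<MiniModule> module_;
  std::shared_ptr<MiniOptimizer> optimizer_;
};
```

In use, the state object is declared before the session in the same scope, so it is destroyed after the session and the borrowed pointer never dangles.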
Key methods include:
- `TrainStep` / `EvalStep`: Delegate to `Module::TrainStep` and `Module::EvalStep` respectively, with additional `RunOptions`.
- `OptimizerStep`: Executes a single optimizer step via the internal `Optimizer`.
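- `LazyResetGrad`: Marks all parameter gradients for reset. The reset is lazy: the gradients are cleared when the next `TrainStep` computes new ones, not at the time of the call.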
- `SetLearningRate` / `GetLearningRate`: Direct learning rate manipulation.
- `RegisterScheduler`: Registers a learning rate scheduler factory function and initial learning rate.
- `SchedulerStep`: Advances the registered scheduler by one step.
- Parameter buffer operations: `CopyParametersToBuffer` and `CopyBufferToParameters`.
- Model export: `ExportModelForInferencing` (non-minimal build only).
- Input/output introspection for both training and eval graphs.
The class is non-copyable and non-movable (`ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE`).
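The parameter buffer operations in the list above can be pictured as flattening every selected parameter into one contiguous buffer and scattering it back. The sketch below is a rough illustration over plain `std::vector<float>` rather than `OrtValue`. The names `ToyParam`, `ParametersSize`, `CopyToBuffer`, and `CopyFromBuffer` are hypothetical, and it assumes that `trainable_only` simply skips frozen parameters; the real buffer layout is not described on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one model parameter; not an ORT type.
struct ToyParam {
  std::vector<float> data;
  bool requires_grad;
};

// Number of elements a flat buffer must hold for the selected parameters.
std::size_t ParametersSize(const std::vector<ToyParam>& params, bool trainable_only = true) {
  std::size_t total = 0;
  for (const ToyParam& p : params) {
    if (!trainable_only || p.requires_grad) total += p.data.size();
  }
  return total;
}

// Flattens the selected parameters into `buffer`. Returns false on a size mismatch.
bool CopyToBuffer(const std::vector<ToyParam>& params, std::vector<float>& buffer,
                  bool trainable_only = true) {
  if (buffer.size() != ParametersSize(params, trainable_only)) return false;
  std::ptrdiff_t offset = 0;
  for (const ToyParam& p : params) {
    if (trainable_only && !p.requires_grad) continue;
    std::copy(p.data.begin(), p.data.end(), buffer.begin() + offset);
    offset += static_cast<std::ptrdiff_t>(p.data.size());
  }
  return true;
}

// Scatters `buffer` back into the selected parameters. Returns false on a size mismatch.
bool CopyFromBuffer(std::vector<ToyParam>& params, const std::vector<float>& buffer,
                    bool trainable_only = true) {
  if (buffer.size() != ParametersSize(params, trainable_only)) return false;
  std::ptrdiff_t offset = 0;
  for (ToyParam& p : params) {
    if (trainable_only && !p.requires_grad) continue;
    const std::ptrdiff_t count = static_cast<std::ptrdiff_t>(p.data.size());
    std::copy(buffer.begin() + offset, buffer.begin() + offset + count, p.data.begin());
    offset += count;
  }
  return true;
}
```

The size query comes first in this sketch because the caller allocates the buffer; both copy functions reject a buffer whose length does not match the selection.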
Usage
Use this class when you need a single object that drives every stage of a training loop: forward/backward passes, optimizer updates, learning rate scheduling, and gradient resets. This is the class instantiated by the ORT C API's `CreateTrainingSession`.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_api/training_session.h
- Lines: 1-86
Signature
class TrainingSession {
 public:
  TrainingSession(const Environment& session_env,
                  const SessionOptions& session_options,
                  const std::vector<std::shared_ptr<IExecutionProvider>>& providers,
                  CheckpointState* state,
                  const ModelIdentifiers& model_identifiers,
                  gsl::span<OrtCustomOpDomain* const> custom_op_domains = {});
  Status RegisterScheduler(
      const std::function<std::unique_ptr<LRSchedulerBase>(std::shared_ptr<Optimizer>)>& get_scheduler,
      float initial_lr);
  Status TrainStep(const RunOptions& run_options,
                   const std::vector<OrtValue>& inputs,
                   std::vector<OrtValue>& fetches);
  Status EvalStep(const RunOptions& run_options,
                  const std::vector<OrtValue>& inputs,
                  std::vector<OrtValue>& fetches) const;
  Status LazyResetGrad();
  Status OptimizerStep(const RunOptions& run_options);
  Status SetLearningRate(float learning_rate) noexcept;
  float GetLearningRate() const;
  Status SchedulerStep() noexcept;
  size_t GetParametersSize(const bool trainable_only = true) const;
  Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true);
  Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true);
  Status ExportModelForInferencing(const std::string& inference_model_path,
                                   gsl::span<const std::string> graph_output_names) const;

 private:
  CheckpointState* state_;
  std::unique_ptr<Module> module_;
  std::shared_ptr<Optimizer> optimizer_;
  std::unique_ptr<LRSchedulerBase> scheduler_;
};
Import
#include "orttraining/training_api/training_session.h"
I/O Contract
| Method | Inputs | Outputs | Description |
|---|---|---|---|
| TrainStep | RunOptions, vector<OrtValue> inputs | vector<OrtValue> fetches, Status | Runs forward+backward, returns forward outputs |
| EvalStep | RunOptions, vector<OrtValue> inputs | vector<OrtValue> fetches, Status | Runs forward-only evaluation |
| OptimizerStep | RunOptions | Status | Executes one optimizer update |
| RegisterScheduler | scheduler factory, initial_lr | Status | Registers a learning rate scheduler |
| SchedulerStep | (none) | Status | Advances the LR scheduler by one step |
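The `RegisterScheduler` / `SchedulerStep` contract in the table can be illustrated with a small self-contained model. This is a sketch, not the ORT implementation: `ToyOptimizer`, `ToySchedulerBase`, `ToyLinearScheduler`, and `ToySession` are hypothetical, `bool` stands in for `Status`, and the schedule (linear warmup from 0 to 1, then linear decay to 0) is an assumed common form of a linear scheduler, not something this page specifies.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical stand-ins; none of these are ORT types.
struct ToyOptimizer {
  float initial_lr = 0.0f;
  float lr = 0.0f;
};

class ToySchedulerBase {
 public:
  explicit ToySchedulerBase(std::shared_ptr<ToyOptimizer> optimizer)
      : optimizer_(std::move(optimizer)) {}
  virtual ~ToySchedulerBase() = default;

  // Advances one step and writes the new rate into the shared optimizer.
  void Step() {
    ++step_;
    optimizer_->lr = optimizer_->initial_lr * Multiplier(step_);
  }

 protected:
  virtual float Multiplier(std::int64_t step) const = 0;

 private:
  std::shared_ptr<ToyOptimizer> optimizer_;
  std::int64_t step_ = 0;
};

// Assumed schedule: linear warmup, then linear decay. Requires total > warmup > 0.
class ToyLinearScheduler : public ToySchedulerBase {
 public:
  ToyLinearScheduler(std::shared_ptr<ToyOptimizer> optimizer, std::int64_t warmup,
                     std::int64_t total)
      : ToySchedulerBase(std::move(optimizer)), warmup_(warmup), total_(total) {}

 protected:
  float Multiplier(std::int64_t step) const override {
    if (step < warmup_) return static_cast<float>(step) / static_cast<float>(warmup_);
    const float remaining = static_cast<float>(total_ - step);
    return std::max(0.0f, remaining / static_cast<float>(total_ - warmup_));
  }

 private:
  std::int64_t warmup_;
  std::int64_t total_;
};

class ToySession {
 public:
  using SchedulerFactory =
      std::function<std::unique_ptr<ToySchedulerBase>(std::shared_ptr<ToyOptimizer>)>;

  // The factory receives the session's optimizer. Returns false if it yields nothing.
  bool RegisterScheduler(const SchedulerFactory& get_scheduler, float initial_lr) {
    scheduler_ = get_scheduler(optimizer_);
    if (!scheduler_) return false;
    optimizer_->initial_lr = initial_lr;
    optimizer_->lr = initial_lr;
    return true;
  }

  // Returns false when no scheduler has been registered.
  bool SchedulerStep() {
    if (!scheduler_) return false;
    scheduler_->Step();
    return true;
  }

  float GetLearningRate() const { return optimizer_->lr; }

 private:
  std::shared_ptr<ToyOptimizer> optimizer_ = std::make_shared<ToyOptimizer>();
  std::unique_ptr<ToySchedulerBase> scheduler_;
};
```

Passing a factory rather than a finished scheduler lets the session hand over its own optimizer, so the caller never needs direct access to that internal object.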
Usage Examples
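#include "orttraining/training_api/training_session.h"

using namespace onnxruntime::training::api;

// Assumes env, session_options, providers, state, model_ids, run_options,
// inputs, and outputs have already been set up by the caller.
TrainingSession session(env, session_options, providers, &state, model_ids);

// Register a linear LR scheduler with an initial learning rate of 0.001.
// ORT_THROW_IF_ERROR (core/common/common.h) throws on a non-OK Status.
ORT_THROW_IF_ERROR(session.RegisterScheduler(
    [](std::shared_ptr<Optimizer> opt) {
      return std::make_unique<LinearLRScheduler>(opt, 100, 1000);
    },
    0.001f));

// Training loop: each call returns a Status, so check every one
for (int step = 0; step < 1000; ++step) {
  ORT_THROW_IF_ERROR(session.TrainStep(run_options, inputs, outputs));
  ORT_THROW_IF_ERROR(session.OptimizerStep(run_options));
  ORT_THROW_IF_ERROR(session.SchedulerStep());
  ORT_THROW_IF_ERROR(session.LazyResetGrad());
}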
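A common variation on the loop above is gradient accumulation, where several `TrainStep` calls feed one `OptimizerStep`. The sketch below assumes, as the `Lazy` prefix suggests, that `LazyResetGrad` only marks the gradients for reset and that the next training step overwrites them instead of adding to them. That behaviour is an assumption here, and `ToyGrad` and `AccumulateAndStep` are hypothetical names, not ORT code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical gradient buffer with a deferred ("lazy") reset; not an ORT type.
class ToyGrad {
 public:
  explicit ToyGrad(std::size_t n) : grads_(n, 0.0f) {}

  // Marks the buffer for reset; the stored values are left untouched for now.
  void LazyResetGrad() { accumulate_ = false; }

  // Stand-in for the backward pass of one training step.
  void TrainStep(const std::vector<float>& new_grads) {
    for (std::size_t i = 0; i < grads_.size(); ++i) {
      grads_[i] = accumulate_ ? grads_[i] + new_grads[i] : new_grads[i];
    }
    accumulate_ = true;  // later steps add to the buffer again
  }

  const std::vector<float>& Values() const { return grads_; }

 private:
  std::vector<float> grads_;
  bool accumulate_ = true;  // the buffer starts at zero, so adding is safe
};

// Runs `steps` toy training steps and "steps the optimizer" every `accum` steps.
// Returns the accumulated gradient observed at each optimizer step.
std::vector<float> AccumulateAndStep(ToyGrad& grad, int steps, int accum) {
  std::vector<float> seen;
  for (int step = 1; step <= steps; ++step) {
    grad.TrainStep({1.0f});
    if (step % accum == 0) {
      seen.push_back(grad.Values()[0]);  // where OptimizerStep would consume the gradient
      grad.LazyResetGrad();              // only after the optimizer has used it
    }
  }
  return seen;
}
```

Under this assumption the reset call belongs after the optimizer step and only on accumulation boundaries; calling it after every training step would discard the partial sums.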