Implementation:Microsoft Onnxruntime LrScheduler
| Knowledge Sources | |
|---|---|
| Domains | Training, API, Optimization |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Defines learning rate scheduler base classes and a concrete LinearLRScheduler for adjusting the optimizer learning rate during training.
Description
This header provides a hierarchy of learning rate schedulers for the ORT Training API. `LRSchedulerBase` is the abstract base: it holds a shared pointer to an `Optimizer` and exposes a public `Step()` method that computes a new learning rate through the private pure virtual `ComputeLearningRateInternal()` and applies it to the optimizer. It also provides protected accessors, `GetStepInternal()` and `GetInitialLRInternal()`, which read the current step and the initial learning rate from the optimizer state. `MultiplicativeLRSchedulerBase` extends this by computing the learning rate as `initial_lr * multiplicative_factor`; subclasses supply the factor by implementing `ComputeLRMultiplicativeFactorInternal(step)`. The concrete `LinearLRScheduler` implements linear warmup followed by linear decay: while the step is below `warmup_step_count`, the factor rises linearly from 0 to 1; afterwards it falls linearly from 1, reaching 0 at `total_step_count`. The `LinearLRScheduler` constructor calls `Step()`, so the learning rate is initialized as soon as the scheduler is constructed.
Usage
Use this header when you need to apply a learning rate schedule to an ORT training optimizer. The `LinearLRScheduler` is suitable for transformer-style training with a warmup phase followed by linear decay. Custom schedulers can be created by subclassing `LRSchedulerBase` or `MultiplicativeLRSchedulerBase`.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_api/lr_scheduler.h
- Lines: 1-92
Signature
struct LRSchedulerBase {
  LRSchedulerBase(std::shared_ptr<Optimizer> optimizer);
  virtual ~LRSchedulerBase() = default;
  Status Step();
 protected:
  int64_t GetStepInternal();
  float GetInitialLRInternal();
 private:
  virtual float ComputeLearningRateInternal() = 0;
};
struct MultiplicativeLRSchedulerBase : public LRSchedulerBase {
  MultiplicativeLRSchedulerBase(std::shared_ptr<Optimizer> optimizer);
 private:
  float ComputeLearningRateInternal() override;
  virtual float ComputeLRMultiplicativeFactorInternal(int64_t step) = 0;
};
struct LinearLRScheduler : public MultiplicativeLRSchedulerBase {
  explicit LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
                             int64_t warmup_step_count,
                             int64_t total_step_count);
 private:
  float ComputeLRMultiplicativeFactorInternal(int64_t step) override;
};
Import
#include "orttraining/training_api/lr_scheduler.h"
I/O Contract
| Class / Method | Input | Output | Description |
|---|---|---|---|
| LRSchedulerBase::Step | (none) | Status | Computes new LR via virtual method and updates the optimizer |
| LinearLRScheduler (ctor) | Optimizer, warmup_step_count, total_step_count | (initialized scheduler) | Creates a linear LR scheduler and calls Step() immediately |
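| MultiplicativeLRSchedulerBase::ComputeLearningRateInternal | (none; reads the current step from the optimizer state) | float learning rate | Returns initial_lr * factor, where factor comes from the subclass's ComputeLRMultiplicativeFactorInternal(step) |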
Usage Examples
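#include <memory>
#include "orttraining/training_api/lr_scheduler.h"
using namespace onnxruntime::training::api;
// The scheduler shares ownership of the optimizer it adjusts.
auto optimizer = std::make_shared<Optimizer>(/* ... */);
int64_t warmup_steps = 100;
int64_t total_steps = 1000;
// The constructor calls Step() once, so the initial learning rate is already applied.
LinearLRScheduler scheduler(optimizer, warmup_steps, total_steps);
// Each training step:
for (int64_t step = 0; step < total_steps; ++step) {
  // ... perform train step, optimizer step ...
  // Step() returns a Status; surface failures instead of discarding them.
  ORT_THROW_IF_ERROR(scheduler.Step());
}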