Implementation:Microsoft Onnxruntime LrScheduler

From Leeroopedia


Knowledge Sources
Domains Training, API, Optimization
Last Updated 2026-02-10 04:00 GMT

Overview

Defines learning rate scheduler base classes and a concrete LinearLRScheduler for adjusting the optimizer learning rate during training.

Description

This header provides a hierarchy of learning rate schedulers for the ORT Training API. `LRSchedulerBase` is the abstract base: it holds a shared pointer to an `Optimizer` and exposes a `Step()` method that computes a new learning rate and applies it to the optimizer. It also provides protected accessors that read the current step and the initial learning rate from the optimizer state. `MultiplicativeLRSchedulerBase` extends the base by computing the learning rate as `initial_lr * multiplicative_factor`, where subclasses supply the factor by implementing `ComputeLRMultiplicativeFactorInternal(step)`. The concrete `LinearLRScheduler` implements a linear warmup followed by a linear decay: during warmup the factor increases linearly from 0 to 1; after warmup it decreases linearly from 1 and reaches 0 at the total step count. `LinearLRScheduler` calls `Step()` in its constructor, so the optimizer's learning rate is initialized as soon as the scheduler is created.
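The warmup-then-decay factor described above can be sketched as a standalone function. This is a minimal illustration, not the header's code: `LinearWarmupDecayFactor` is a hypothetical free function, and the clamping at 0 past `total_step_count` and the guards against zero-length phases are assumptions about reasonable behaviour rather than details confirmed by this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns the multiplicative factor applied to the
// initial learning rate at a given step. Not part of the ORT header.
float LinearWarmupDecayFactor(int64_t step,
                              int64_t warmup_step_count,
                              int64_t total_step_count) {
  assert(step >= 0);
  assert(warmup_step_count >= 0 && total_step_count >= warmup_step_count);

  if (step < warmup_step_count) {
    // Warmup: ramp linearly from 0 at step 0 toward 1 at warmup_step_count.
    return static_cast<float>(step) /
           static_cast<float>(std::max<int64_t>(1, warmup_step_count));
  }

  // Decay: fall linearly from 1 at the end of warmup to 0 at total_step_count.
  // Assumption: steps beyond total_step_count are clamped to a factor of 0.
  const int64_t decay_steps =
      std::max<int64_t>(1, total_step_count - warmup_step_count);
  return std::max(0.0f, static_cast<float>(total_step_count - step) /
                            static_cast<float>(decay_steps));
}
```

With 100 warmup steps and 1000 total steps, this sketch gives a factor of 0.5 halfway through warmup (step 50), 1.0 at step 100, and 0.5 again at step 550, the midpoint of the decay phase.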

Usage

Use this header when you need to apply a learning rate schedule to an ORT training optimizer. The `LinearLRScheduler` is suitable for transformer-style training with a warmup phase followed by linear decay. Custom schedulers can be created by subclassing `LRSchedulerBase` or `MultiplicativeLRSchedulerBase`.
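The subclassing pattern can be illustrated without the ORT headers. The sketch below uses simplified stand-ins: `MockOptimizer`, `MockMultiplicativeScheduler`, and `HalvingScheduler` are all hypothetical names invented for this example, and the mock `Step()` returns `void` rather than a `Status`. It only shows the shape of the design, in which the base class owns `Step()` and a subclass supplies the factor through a private virtual method.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <memory>
#include <utility>

// Hypothetical stand-in for the optimizer: only the state a scheduler reads or writes.
struct MockOptimizer {
  float initial_lr = 0.0f;
  float learning_rate = 0.0f;
  int64_t step = 0;
};

// Simplified stand-in for MultiplicativeLRSchedulerBase (not the real class).
class MockMultiplicativeScheduler {
 public:
  explicit MockMultiplicativeScheduler(std::shared_ptr<MockOptimizer> optimizer)
      : optimizer_(std::move(optimizer)) {
    assert(optimizer_ != nullptr);
  }
  virtual ~MockMultiplicativeScheduler() = default;

  // Computes initial_lr * factor for the current step and applies it to the optimizer.
  void Step() {
    optimizer_->learning_rate =
        optimizer_->initial_lr * ComputeFactor(optimizer_->step);
  }

 private:
  // Subclasses override this to define the schedule.
  virtual float ComputeFactor(int64_t step) = 0;

  std::shared_ptr<MockOptimizer> optimizer_;
};

// Example custom schedule: halve the learning rate every `period` steps.
class HalvingScheduler : public MockMultiplicativeScheduler {
 public:
  HalvingScheduler(std::shared_ptr<MockOptimizer> optimizer, int64_t period)
      : MockMultiplicativeScheduler(std::move(optimizer)), period_(period) {
    assert(period_ > 0);
    Step();  // Mirror the documented behaviour of initializing the LR in the constructor.
  }

 private:
  float ComputeFactor(int64_t step) override {
    // Integer division counts how many full periods have elapsed.
    return std::pow(0.5f, static_cast<float>(step / period_));
  }

  int64_t period_;
};
```

Keeping the factor computation in a private virtual method means callers interact only with `Step()`, while each schedule stays a small, isolated override.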

Code Reference

Source Location

Signature

struct LRSchedulerBase {
  LRSchedulerBase(std::shared_ptr<Optimizer> optimizer);
  virtual ~LRSchedulerBase() = default;
  Status Step();
protected:
  int64_t GetStepInternal();
  float GetInitialLRInternal();
private:
  virtual float ComputeLearningRateInternal() = 0;
};

struct MultiplicativeLRSchedulerBase : public LRSchedulerBase {
  MultiplicativeLRSchedulerBase(std::shared_ptr<Optimizer> optimizer);
private:
  float ComputeLearningRateInternal() override;
  virtual float ComputeLRMultiplicativeFactorInternal(int64_t step) = 0;
};

struct LinearLRScheduler : public MultiplicativeLRSchedulerBase {
  explicit LinearLRScheduler(std::shared_ptr<Optimizer> optimizer,
                             int64_t warmup_step_count,
                             int64_t total_step_count);
private:
  float ComputeLRMultiplicativeFactorInternal(int64_t step) override;
};

Import

#include "orttraining/training_api/lr_scheduler.h"

I/O Contract

Class / Method | Input | Output | Description
LRSchedulerBase::Step | (none) | Status | Computes the new learning rate via the virtual ComputeLearningRateInternal() and applies it to the optimizer
LinearLRScheduler (ctor) | Optimizer, warmup_step_count, total_step_count | (initialized scheduler) | Creates a linear LR scheduler and calls Step() immediately
MultiplicativeLRSchedulerBase | step (int64_t) | float learning rate | Computes initial_lr * factor, where the subclass computes factor from step in ComputeLRMultiplicativeFactorInternal(step)

Usage Examples

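#include <cstdint>
#include <memory>

#include "orttraining/training_api/lr_scheduler.h"

using namespace onnxruntime::training::api;

auto optimizer = std::make_shared<Optimizer>(/* ... */);
const int64_t warmup_steps = 100;
const int64_t total_steps = 1000;

// The constructor calls Step() once, so the optimizer starts with the scheduled learning rate.
LinearLRScheduler scheduler(optimizer, warmup_steps, total_steps);

// Each training step:
for (int64_t step = 0; step < total_steps; ++step) {
  // ... perform train step, optimizer step ...

  // Step() returns a Status; stop on failure instead of silently discarding it.
  const auto status = scheduler.Step();
  if (!status.IsOK()) {
    break;
  }
}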
