

Implementation:Microsoft Onnxruntime TrainingUtil

From Leeroopedia


Knowledge Sources
Domains Training, Models, Utilities
Last Updated 2026-02-10 04:00 GMT

Overview

Implements the DataSet, RandomDataSet, TrainingUtil, LossScaler, and LearningRateScheduler classes for training data management, debugging, loss scaling, and learning rate scheduling.

Description

This source file (`training_util.cc`) provides the implementations for the training utility classes declared in `training_util.h`:

  • DataSet: Implements sample storage and batch retrieval. `AddData` supports both raw `OrtValue` vectors and `TensorProto` feature vectors (with internal buffer management). `GetKthBatch` builds batched tensors by stacking samples along a new leading batch dimension. `TotalBatch` computes the number of batches for a given batch size. `GetTensorDimensionsFromInputs` extracts dimension metadata from input tensor shapes for configuration mapping. `RandomShuffle` shuffles the samples randomly in place.
  • RandomDataSet: Overrides `GetKthBatch` to generate zero-filled tensors of specified shapes and types (INT64, INT32, FLOAT) for testing/benchmarking without actual data.
  • TrainingUtil: Static utility methods `PrintNameMLValMap` and `PrintTensor` provide debug printing for OrtValue maps and tensors respectively, supporting float, int64, and bool data types.
  • LossScaler: Implements dynamic loss scaling for mixed-precision training. `UpdateLossScale` doubles the loss scale after a configurable number of stable steps (no overflow) and halves it on overflow, bounded by min/max limits. `SaveToString` and `LoadFromString` provide checkpoint serialization.
  • LearningRateScheduler::Create: Factory method that instantiates the appropriate scheduler (NoWarmup, Cosine, Constant, Linear, or Poly) based on `LearningRateParameters::warmup_mode`.
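The dynamic rule in the `LossScaler` bullet can be made concrete with a small standalone class. This is a minimal sketch under stated assumptions: `DynamicLossScaleSketch`, its constructor parameters, and the exact point at which the stable-step counter is checked are hypothetical and are not ORT's `LossScaler` API. It only mirrors the behaviour described above: double the scale after a window of finite steps, halve it on overflow, and clamp to the min/max limits.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical sketch of the dynamic loss-scaling rule described above.
// DynamicLossScaleSketch and its parameter names are illustrative, not ORT's API.
class DynamicLossScaleSketch {
 public:
  DynamicLossScaleSketch(float initial_scale, float min_scale, float max_scale,
                         int up_scale_window)
      : scale_(initial_scale),
        min_scale_(min_scale),
        max_scale_(max_scale),
        up_scale_window_(up_scale_window) {
    assert(min_scale > 0.0f && min_scale <= max_scale);
    assert(up_scale_window > 0);
  }

  // is_all_finite == false means an overflow (inf/NaN in the gradients).
  void Update(bool is_all_finite) {
    if (is_all_finite) {
      // Count consecutive stable steps; double the scale once the window fills.
      if (++stable_steps_ >= up_scale_window_) {
        scale_ = std::min(max_scale_, scale_ * 2.0f);
        stable_steps_ = 0;
      }
    } else {
      // Overflow: halve the scale (never below the minimum) and restart the count.
      scale_ = std::max(min_scale_, scale_ / 2.0f);
      stable_steps_ = 0;
    }
  }

  float scale() const { return scale_; }
  int stable_steps() const { return stable_steps_; }

 private:
  float scale_;
  float min_scale_;
  float max_scale_;
  int up_scale_window_;
  int stable_steps_ = 0;
};
```

For example, with a window of 3 and an initial scale of 65536, three consecutive finite steps raise the scale to 131072, and a single overflow step brings it back to 65536.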

Usage

Use this implementation when running ORT Training model runners (e.g., BERT training). The DataSet class is used by DataLoader; LossScaler is used for mixed-precision training; LearningRateScheduler provides the older-style scheduling (distinct from the training_api LRScheduler).

Code Reference

Source Location

Signature

// DataSet implementation
DataSet::DataSet(const vector<string>& tensor_names);
DataSet::~DataSet();
common::Status DataSet::AddData(DataSet::SampleType&& single_sample);
common::Status DataSet::AddData(const vector<ONNX_NAMESPACE::TensorProto>& features);
size_t DataSet::TotalBatch(size_t batch_size) const;
std::vector<OrtValue> DataSet::GetKthBatch(size_t batch_size, size_t k_th, AllocatorPtr allocator) const;
void DataSet::RandomShuffle();

// RandomDataSet implementation
std::vector<OrtValue> RandomDataSet::GetKthBatch(size_t batch_size, size_t k_th, AllocatorPtr allocator) const;

// TrainingUtil implementation
void TrainingUtil::PrintNameMLValMap(const NameMLValMap& mlvalue_map);
void TrainingUtil::PrintTensor(const string& name, const Tensor& tensor, ostream& os);

// LossScaler implementation
std::string LossScaler::SaveToString() const;
Status LossScaler::LoadFromString(const std::string& input);

// LearningRateScheduler factory
std::unique_ptr<LearningRateScheduler> LearningRateScheduler::Create(
    LearningRateParameters& lr_params, size_t training_step_count);
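The batching behaviour of `DataSet` can be pictured with a simplified model in which each input of a sample is a flat `float` buffer. This sketch is hypothetical: `DataSetSketch` and `Sample` are illustrative names. Two details are assumptions rather than facts from this page: `TotalBatch` rounds up (a trailing partial batch counts), and a batch that runs past the last sample wraps around to the start.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// Hypothetical, simplified model of DataSet batching: each sample holds one
// flat float buffer per input name, and all samples of a given input have the
// same length. Names are illustrative, not ORT's API.
using Sample = std::vector<std::vector<float>>;

class DataSetSketch {
 public:
  explicit DataSetSketch(std::size_t num_inputs) : num_inputs_(num_inputs) {}

  void AddData(Sample sample) {
    assert(sample.size() == num_inputs_);
    samples_.push_back(std::move(sample));
  }

  // Assumption: a trailing partial batch counts as a batch (ceiling division).
  std::size_t TotalBatch(std::size_t batch_size) const {
    assert(batch_size > 0);
    return (samples_.size() + batch_size - 1) / batch_size;
  }

  // Returns one buffer per input laid out as [batch_size, sample_len], built by
  // copying the k-th window of samples back to back (a new leading batch dim).
  // Assumption: indices past the end wrap around to the start of the data.
  std::vector<std::vector<float>> GetKthBatch(std::size_t batch_size,
                                              std::size_t k_th) const {
    assert(!samples_.empty() && batch_size > 0);
    std::vector<std::vector<float>> batch(num_inputs_);
    for (std::size_t input = 0; input < num_inputs_; ++input) {
      for (std::size_t i = 0; i < batch_size; ++i) {
        const Sample& s = samples_[(k_th * batch_size + i) % samples_.size()];
        batch[input].insert(batch[input].end(), s[input].begin(), s[input].end());
      }
    }
    return batch;
  }

  // In-place shuffle of the sample order; the seed is fixed for reproducibility.
  void RandomShuffle(unsigned seed = 42) {
    std::mt19937 rng(seed);
    std::shuffle(samples_.begin(), samples_.end(), rng);
  }

  std::size_t NumSamples() const { return samples_.size(); }

 private:
  std::size_t num_inputs_;
  std::vector<Sample> samples_;
};
```

With five samples and a batch size of 2, this model reports 3 batches, and batch 2 holds samples 4 and 0 because of the assumed wrap-around. The real implementation works on typed `OrtValue` tensors and takes an `AllocatorPtr`, both of which the sketch leaves out.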

Import

#include "orttraining/models/runner/training_util.h"

I/O Contract

| Function | Inputs | Outputs | Description |
| --- | --- | --- | --- |
| `DataSet::AddData` | `SampleType` (vector of `OrtValue`s) | `Status` | Adds one data sample to the dataset |
| `DataSet::GetKthBatch` | `batch_size`, `k_th` index, `allocator` | `vector<OrtValue>` | Returns the k-th batch of data as batched tensors |
| `LossScaler::UpdateLossScale` | `is_all_finite` (bool) | `void` | Updates the loss scale based on gradient overflow detection |
| `LearningRateScheduler::Create` | `LearningRateParameters`, step count | `unique_ptr<LearningRateScheduler>` | Factory that creates the appropriate scheduler type |
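`LearningRateScheduler::Create` selects a scheduler from the `warmup_mode` string. The sketch below shows that string-keyed factory pattern for three of the five modes. Everything in it is an illustrative assumption: the class names are hypothetical, and the multipliers (a linear ramp over the first `warmup_ratio` of training, followed by either a constant rate or a linear decay to zero) follow a common BERT-style schedule rather than ORT's actual formulas.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical sketch: names and formulas are assumptions, not ORT's code.
struct LrParamsSketch {
  float initial_lr;
  float warmup_ratio;       // fraction of total steps spent warming up
  std::string warmup_mode;  // "NoWarmup", "Constant" or "Linear" in this sketch
};

class LrSchedulerSketch {
 public:
  LrSchedulerSketch(const LrParamsSketch& params, std::size_t total_steps)
      : params_(params), total_steps_(total_steps) {}
  virtual ~LrSchedulerSketch() = default;

  float GetLearningRate(std::size_t step) const {
    float progress = static_cast<float>(step) / static_cast<float>(total_steps_);
    return params_.initial_lr * Multiplier(progress);
  }

 protected:
  // Scale factor applied to initial_lr at a given fraction of training.
  virtual float Multiplier(float progress) const = 0;

  LrParamsSketch params_;
  std::size_t total_steps_;
};

class NoWarmupSketch final : public LrSchedulerSketch {
 public:
  using LrSchedulerSketch::LrSchedulerSketch;

 protected:
  float Multiplier(float) const override { return 1.0f; }
};

class ConstantWarmupSketch final : public LrSchedulerSketch {
 public:
  using LrSchedulerSketch::LrSchedulerSketch;

 protected:
  // Ramp linearly up to initial_lr during warmup, then hold it.
  float Multiplier(float progress) const override {
    float w = params_.warmup_ratio;
    return progress < w ? progress / w : 1.0f;
  }
};

class LinearWarmupSketch final : public LrSchedulerSketch {
 public:
  using LrSchedulerSketch::LrSchedulerSketch;

 protected:
  // Ramp up during warmup, then decay linearly to 0 at the last step.
  float Multiplier(float progress) const override {
    float w = params_.warmup_ratio;
    if (progress < w) return progress / w;
    float m = (1.0f - progress) / (1.0f - w);
    return m > 0.0f ? m : 0.0f;
  }
};

std::unique_ptr<LrSchedulerSketch> CreateSchedulerSketch(
    const LrParamsSketch& params, std::size_t total_steps) {
  assert(total_steps > 0);
  assert(params.warmup_ratio >= 0.0f && params.warmup_ratio < 1.0f);
  if (params.warmup_mode == "NoWarmup")
    return std::make_unique<NoWarmupSketch>(params, total_steps);
  if (params.warmup_mode == "Constant")
    return std::make_unique<ConstantWarmupSketch>(params, total_steps);
  if (params.warmup_mode == "Linear")
    return std::make_unique<LinearWarmupSketch>(params, total_steps);
  throw std::invalid_argument("warmup_mode not covered by this sketch: " +
                              params.warmup_mode);
}
```

With `initial_lr = 0.001`, `warmup_ratio = 0.1`, and 10000 total steps, the `Linear` sketch returns about 0.0005 at step 500 (halfway through warmup), 0.001 at step 1000, and 0 at step 10000. Modes the sketch does not cover, such as `Poly`, throw instead of being silently mapped to another schedule.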

Usage Examples

#include "orttraining/models/runner/training_util.h"

using namespace onnxruntime::training;

// Create dataset
DataSet dataset({"input_ids", "attention_mask", "labels"});
DataSet::SampleType sample = std::make_unique<std::vector<OrtValue>>();
// ... populate sample ...
auto status = dataset.AddData(std::move(sample));  // returns a Status; check status.IsOK()

// Get a batch
auto batch = dataset.GetKthBatch(32, 0);

// Dynamic loss scaling
LossScaler scaler("loss_scale", true, 65536.0f);  // dynamic scaling, initial scale 2^16; limits and window use defaults
scaler.UpdateLossScale(true);  // stable step
float scale = scaler.GetLossScale();

// Create LR scheduler
LearningRateParameters params{0.001f, 0.1f, "Linear", "Learning_Rate"};
auto scheduler = LearningRateScheduler::Create(params, 10000);
float lr = scheduler->GetLearningRate(500);
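`SaveToString` and `LoadFromString` let the loss-scaler state travel with a checkpoint. This page does not describe their encoding, so the following round-trip sketch invents a plain-text format. `LossScaleStateSketch`, `SaveStateToString`, and `LoadStateFromString` are hypothetical names, and the choice of fields (current scale and stable-step count) is an assumption.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical loss-scaler state and text format; ORT's actual
// SaveToString/LoadFromString encoding is not described on this page.
struct LossScaleStateSketch {
  float loss_scale = 65536.0f;
  int stable_steps = 0;
};

std::string SaveStateToString(const LossScaleStateSketch& s) {
  assert(s.loss_scale > 0.0f);
  std::ostringstream os;
  os.precision(9);  // 9 significant digits round-trip any float exactly
  os << s.loss_scale << ' ' << s.stable_steps;
  return os.str();
}

// Returns false (leaving `out` untouched) if the string cannot be parsed.
bool LoadStateFromString(const std::string& input, LossScaleStateSketch& out) {
  std::istringstream is(input);
  LossScaleStateSketch parsed;
  if (!(is >> parsed.loss_scale >> parsed.stable_steps)) return false;
  out = parsed;
  return true;
}
```

Writing the float with 9 significant digits means the scale is read back bit-for-bit, and parsing into a temporary means a malformed string leaves the existing state untouched.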
