Implementation:Microsoft Onnxruntime TrainingUtil
| Knowledge Sources | |
|---|---|
| Domains | Training, Models, Utilities |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Implements the DataSet, RandomDataSet, TrainingUtil, LossScaler, and LearningRateScheduler classes for training data management, debugging, loss scaling, and learning rate scheduling.
Description
This source file (`training_util.cc`) provides the implementations for the training utility classes declared in `training_util.h`:
- DataSet: Implements sample storage and batch retrieval. `AddData` accepts either a ready-made sample of `OrtValue`s or a vector of `TensorProto` features, in which case the dataset manages the backing buffers internally. `GetKthBatch` builds the k-th batch by stacking samples along a new batch dimension. `TotalBatch` computes the number of batches for a given batch size. `GetTensorDimensionsFromInputs` extracts dimension metadata from input tensor shapes for configuration mapping. `RandomShuffle` shuffles the stored samples in place.
- RandomDataSet: Overrides `GetKthBatch` to generate zero-filled tensors of specified shapes and types (INT64, INT32, FLOAT) for testing/benchmarking without actual data.
- TrainingUtil: Static utility methods `PrintNameMLValMap` and `PrintTensor` provide debug printing for OrtValue maps and tensors respectively, supporting float, int64, and bool data types.
- LossScaler: Implements dynamic loss scaling for mixed-precision training. `UpdateLossScale` doubles the loss scale after a configurable number of stable steps (no overflow) and halves it on overflow, bounded by min/max limits. `SaveToString` and `LoadFromString` provide checkpoint serialization.
- LearningRateScheduler::Create: Factory method that instantiates the appropriate scheduler (NoWarmup, Cosine, Constant, Linear, or Poly) based on `LearningRateParameters::warmup_mode`.
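The dynamic loss scaling rule described for `LossScaler` can be illustrated with a small standalone sketch. This is not the ORT class: `SketchLossScaler` and its member names are hypothetical, and the sketch assumes that an overflow resets the stable-step counter and that the counter restarts after each doubling.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical stand-in for the behaviour described above; not the ORT class.
class SketchLossScaler {
 public:
  SketchLossScaler(float initial_scale, std::size_t up_scale_window,
                   float min_scale, float max_scale)
      : scale_(initial_scale),
        up_scale_window_(up_scale_window),
        min_scale_(min_scale),
        max_scale_(max_scale) {}

  // is_all_finite == false means at least one gradient overflowed.
  void Update(bool is_all_finite) {
    if (!is_all_finite) {
      // Overflow: halve the scale, but never go below the minimum.
      scale_ = std::max(scale_ * 0.5f, min_scale_);
      stable_steps_ = 0;  // assumption: overflow resets the counter
      return;
    }
    // Stable step: double the scale once the window is full.
    if (++stable_steps_ >= up_scale_window_) {
      scale_ = std::min(scale_ * 2.0f, max_scale_);
      stable_steps_ = 0;  // assumption: counting restarts after doubling
    }
  }

  float scale() const { return scale_; }

 private:
  float scale_;
  std::size_t up_scale_window_;
  float min_scale_;
  float max_scale_;
  std::size_t stable_steps_ = 0;
};
```

With a window of 2, two consecutive finite steps double the scale, and a later overflow halves it again. The min/max clamps keep a long run of overflows or stable steps from driving the scale to zero or infinity.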
Usage
Use this implementation when running ORT Training model runners (e.g., BERT training). The DataSet class is used by DataLoader; LossScaler is used for mixed-precision training; LearningRateScheduler provides the older-style scheduling (distinct from the training_api LRScheduler).
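To make the warmup-mode dispatch concrete, here is a rough sketch of how a mode string could select a learning rate multiplier. The function `SketchLrMultiplier` is hypothetical, the mode string `"None"` for the no-warmup case is an assumption, and the formulas follow common BERT-style schedules rather than anything confirmed from `training_util.cc`. Only three of the five modes are sketched.

```cpp
#include <algorithm>
#include <stdexcept>
#include <string>

// Hypothetical helper, not part of ORT.
// progress     = current_step / total_steps, expected in [0, 1]
// warmup_ratio = fraction of training spent warming up
// Returns a factor to multiply with the initial learning rate.
float SketchLrMultiplier(const std::string& warmup_mode, float progress,
                         float warmup_ratio) {
  const bool warming_up = warmup_ratio > 0.0f && progress < warmup_ratio;

  if (warmup_mode == "None") {  // assumed name for the no-warmup mode
    return 1.0f;
  }
  if (warmup_mode == "Constant") {
    // Ramp up linearly, then hold the full learning rate.
    return warming_up ? progress / warmup_ratio : 1.0f;
  }
  if (warmup_mode == "Linear") {
    if (warming_up) return progress / warmup_ratio;
    // After warmup, decay linearly to zero at the end of training.
    const float decay_span = 1.0f - warmup_ratio;
    if (decay_span <= 0.0f) return 1.0f;  // no room left to decay
    return std::max((1.0f - progress) / decay_span, 0.0f);
  }
  throw std::invalid_argument("unsupported warmup mode: " + warmup_mode);
}
```

For example, with `warmup_ratio = 0.5`, the `"Linear"` mode reaches the full rate halfway through training and falls back to half of it at 75% progress.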
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/models/runner/training_util.cc
- Lines: 1-233
Signature
```cpp
// DataSet implementation
DataSet::DataSet(const vector<string>& tensor_names);
DataSet::~DataSet();
common::Status DataSet::AddData(DataSet::SampleType&& single_sample);
common::Status DataSet::AddData(const vector<ONNX_NAMESPACE::TensorProto>& features);
size_t DataSet::TotalBatch(size_t batch_size) const;
std::vector<OrtValue> DataSet::GetKthBatch(size_t batch_size, size_t k_th, AllocatorPtr allocator) const;
void DataSet::RandomShuffle();

// RandomDataSet implementation
std::vector<OrtValue> RandomDataSet::GetKthBatch(size_t batch_size, size_t k_th, AllocatorPtr allocator) const;

// TrainingUtil implementation
void TrainingUtil::PrintNameMLValMap(const NameMLValMap& mlvalue_map);
void TrainingUtil::PrintTensor(const string& name, const Tensor& tensor, ostream& os);

// LossScaler implementation
std::string LossScaler::SaveToString() const;
Status LossScaler::LoadFromString(const std::string& input);

// LearningRateScheduler factory
std::unique_ptr<LearningRateScheduler> LearningRateScheduler::Create(
    LearningRateParameters& lr_params, size_t training_step_count);
```
Import
```cpp
#include "orttraining/models/runner/training_util.h"
```
I/O Contract
| Function | Inputs | Outputs | Description |
|---|---|---|---|
| DataSet::AddData | SampleType&& (a moved-in sample holding a vector of OrtValues), or a vector of TensorProto features | Status | Adds one data sample to the dataset |
| DataSet::GetKthBatch | batch_size, k_th index, allocator | vector<OrtValue> | Returns the k-th batch of data as batched tensors |
| LossScaler::UpdateLossScale | is_all_finite (bool) | void | Updates loss scale based on gradient overflow detection |
| LearningRateScheduler::Create | LearningRateParameters, step_count | unique_ptr<LearningRateScheduler> | Factory for creating the appropriate scheduler type |
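The batching contract of `TotalBatch` and `GetKthBatch` can be pictured with plain vectors instead of `OrtValue`s. The sketch below is only an illustration: `SketchBatch`, `SketchTotalBatch`, and `SketchGetKthBatch` are hypothetical names, every sample is assumed to have the same length, and both the ceiling division and the truncated final batch are assumptions rather than confirmed ORT behaviour.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical batched tensor: a shape plus row-major data.
struct SketchBatch {
  std::vector<std::size_t> shape;  // {batch, feature_len}
  std::vector<float> data;
};

// Number of batches needed to cover all samples (assumed ceiling division).
std::size_t SketchTotalBatch(std::size_t num_samples, std::size_t batch_size) {
  if (batch_size == 0) return 0;
  return (num_samples + batch_size - 1) / batch_size;
}

// Stack samples [k_th * batch_size, ...) along a new batch dimension.
// Assumes all samples share one length; the last batch may be shorter.
SketchBatch SketchGetKthBatch(const std::vector<std::vector<float>>& samples,
                              std::size_t batch_size, std::size_t k_th) {
  const std::size_t begin = std::min(k_th * batch_size, samples.size());
  const std::size_t end = std::min(begin + batch_size, samples.size());
  const std::size_t feature_len = begin < end ? samples[begin].size() : 0;

  SketchBatch batch;
  batch.shape = {end - begin, feature_len};
  batch.data.reserve((end - begin) * feature_len);
  for (std::size_t i = begin; i < end; ++i) {
    // Each sample becomes one row of the batched tensor.
    batch.data.insert(batch.data.end(), samples[i].begin(), samples[i].end());
  }
  return batch;
}
```

Three samples of length 2 with a batch size of 2 give two batches under these assumptions: the first has shape `{2, 2}` and the second has shape `{1, 2}`.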
Usage Examples
```cpp
#include "orttraining/models/runner/training_util.h"

using namespace onnxruntime::training;

// Create a dataset whose samples carry these three named tensors.
DataSet dataset({"input_ids", "attention_mask", "labels"});
DataSet::SampleType sample = std::make_unique<std::vector<OrtValue>>();
// ... populate sample ...
auto status = dataset.AddData(std::move(sample));
if (!status.IsOK()) {
  // Handle the failure, e.g. log status.ErrorMessage().
}

// Get batch 0 with a batch size of 32. The allocator argument is omitted,
// which assumes training_util.h declares a default value for it.
auto batch = dataset.GetKthBatch(32, 0);

// Dynamic loss scaling
LossScaler scaler("loss_scale", true, 65536.0f, 2000);
scaler.UpdateLossScale(true);  // all gradients finite: counts as a stable step
float scale = scaler.GetLossScale();

// Create an LR scheduler and query the learning rate at step 500
LearningRateParameters params{0.001f, 0.1f, "Linear", "Learning_Rate"};
auto scheduler = LearningRateScheduler::Create(params, 10000);
float lr = scheduler->GetLearningRate(500);
```