

Implementation:Microsoft Onnxruntime TrainingDataLoader

From Leeroopedia


Knowledge Sources
Domains: Training, DataLoading, Models
Last Updated: 2026-02-10 04:00 GMT

Overview

Defines data loading infrastructure for ORT training, including thread-safe buffering, file-based data loading with prefetching, and a single-dataset loader.

Description

This header provides the data loading layer for ORT Training model runners. It includes:

  • `DataSetBuffer`: A thread-safe buffer that stores `DataSet` objects by index. It uses a mutex and condition variable so that `Get(index)` blocks until the requested dataset is available, enabling asynchronous prefetching. `Set(index, data_set)` stores a dataset and notifies waiting threads; `Remove(index)` releases a dataset.
  • `IDataLoader`: An abstract interface defining the data loading contract. Key methods include `InitializeDataSetIndex`, `CurrentDataSet`, `MoveToNextDataSet`, `NumShards`, and `DataSetTensorNames`.
  • `DataLoader`: The primary file-based implementation. It reads training data from binary files in which each sample is a sequence of byte-size-prefixed `TensorProto` features. It supports:
 - Asynchronous prefetching with a configurable number of preloaded files (`max_num_files_preload`, default 2).
 - Distributed training sharding via `world_rank` and `world_size` parameters.
 - Input name remapping via `MapStringToString`.
 - A single-threaded thread pool for serialized loading operations.
  • `SingleDataLoader`: A simple loader wrapping a single pre-loaded `DataSet` instance. It always returns the same dataset and reports one shard.
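The blocking hand-off that `DataSetBuffer` provides (a mutex plus a condition variable, with `Get` waiting until `Set` has filled the slot) can be sketched with the standard library alone. This is a minimal illustration, not the ORT source: `DataSetBufferSketch`, `PrefetchDemo`, and the `DataSet` stand-in (which only records a file name) are hypothetical, and storing entries in a `std::map` keyed by index is an assumed design choice.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

// Hypothetical stand-in for the ORT Training DataSet type; it only records a file name.
struct DataSet {
  std::string source_file;
};

// Minimal sketch of an index-keyed, thread-safe buffer in the spirit of DataSetBuffer.
class DataSetBufferSketch {
 public:
  // Blocks until a dataset has been stored at `index`, then returns a shared reference to it.
  std::shared_ptr<DataSet> Get(size_t index) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return buffer_.count(index) > 0; });
    return buffer_.at(index);
  }

  // Stores a dataset and wakes all threads blocked in Get().
  void Set(size_t index, std::shared_ptr<DataSet> data_set) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      buffer_[index] = std::move(data_set);
    }
    cv_.notify_all();
  }

  // Drops the buffer's reference; returns false if nothing was stored at `index`.
  bool Remove(size_t index) {
    std::lock_guard<std::mutex> lock(mutex_);
    return buffer_.erase(index) > 0;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::map<size_t, std::shared_ptr<DataSet>> buffer_;
};

// Simulated prefetch: a worker thread fills slot 1 while the caller blocks on Get(1).
std::string PrefetchDemo() {
  DataSetBufferSketch buffer;
  std::thread worker([&buffer] {
    buffer.Set(1, std::make_shared<DataSet>(DataSet{"shard_1.pb"}));
  });
  std::shared_ptr<DataSet> ds = buffer.Get(1);  // returns once the worker has called Set(1, ...)
  worker.join();
  assert(ds != nullptr);
  return ds->source_file;
}
```

In this sketch, a caller that already holds the `shared_ptr` returned by `Get` keeps the dataset alive after `Remove`, so a consumed shard can be released from the buffer while it is still in use.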

The binary file format stores samples as: `[Sample ByteSize][Feature0 ByteSize][Feature0 TensorProto]...[FeatureN ByteSize][FeatureN TensorProto]`, with all byte-size fields stored as `uint32_t`.
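The layout can be exercised without protobuf by treating each `TensorProto` payload as opaque bytes. The sketch below encodes and decodes samples under two assumptions the description does not spell out: `[Sample ByteSize]` covers all of the sample's feature size fields and payloads, and the `uint32_t` fields are written in host byte order. `EncodeSample`, `DecodeSamples`, `AppendU32`, and `ReadU32` are hypothetical helpers, not ORT functions; a real reader would still need to deserialize each payload as an ONNX `TensorProto`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Appends a uint32_t in host byte order (assumption: reader and writer share endianness).
void AppendU32(std::vector<uint8_t>& out, uint32_t value) {
  uint8_t bytes[sizeof(value)];
  std::memcpy(bytes, &value, sizeof(value));
  out.insert(out.end(), bytes, bytes + sizeof(value));
}

// Reads a uint32_t at `pos` without crossing `limit`, then advances `pos`.
uint32_t ReadU32(const std::vector<uint8_t>& in, size_t& pos, size_t limit) {
  if (pos + sizeof(uint32_t) > limit) throw std::runtime_error("truncated size field");
  uint32_t value;
  std::memcpy(&value, in.data() + pos, sizeof(value));
  pos += sizeof(value);
  return value;
}

// Encodes one sample; each feature is an already-serialized TensorProto kept as opaque bytes.
// Assumption: [Sample ByteSize] counts every byte that follows it within the sample.
std::vector<uint8_t> EncodeSample(const std::vector<std::string>& features) {
  std::vector<uint8_t> body;
  for (const std::string& f : features) {
    assert(f.size() <= UINT32_MAX);
    AppendU32(body, static_cast<uint32_t>(f.size()));
    body.insert(body.end(), f.begin(), f.end());
  }
  std::vector<uint8_t> out;
  AppendU32(out, static_cast<uint32_t>(body.size()));
  out.insert(out.end(), body.begin(), body.end());
  return out;
}

// Splits back-to-back samples into per-sample lists of raw feature payloads.
std::vector<std::vector<std::string>> DecodeSamples(const std::vector<uint8_t>& in) {
  std::vector<std::vector<std::string>> samples;
  size_t pos = 0;
  while (pos < in.size()) {
    const uint32_t sample_size = ReadU32(in, pos, in.size());
    const size_t sample_end = pos + sample_size;
    if (sample_end > in.size()) throw std::runtime_error("truncated sample");
    std::vector<std::string> features;
    while (pos < sample_end) {
      const uint32_t feature_size = ReadU32(in, pos, sample_end);
      if (pos + feature_size > sample_end) throw std::runtime_error("feature overruns sample");
      features.emplace_back(reinterpret_cast<const char*>(in.data() + pos), feature_size);
      pos += feature_size;
    }
    samples.push_back(std::move(features));
  }
  return samples;
}
```

Bounding each feature read by `sample_end` means a corrupt size field is reported as an error instead of silently bleeding into the next sample.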

Usage

Use this header when implementing training runners that need to load batched training data from binary files. The `DataLoader` is used by BERT and other model training runners in the ORT Training examples.

Code Reference

Source Location

Signature

```cpp
class DataSetBuffer {
 public:
  std::shared_ptr<DataSet> Get(size_t index);
  void Set(size_t index, std::shared_ptr<DataSet> data_set);
  bool Remove(size_t index);
};

class IDataLoader {
 public:
  virtual Status InitializeDataSetIndex(size_t initial_data_set_index) = 0;
  virtual std::shared_ptr<DataSet> CurrentDataSet() = 0;
  virtual size_t CurrentDataSetIndex() const = 0;
  virtual size_t NumShards() const = 0;
  virtual std::shared_ptr<DataSet> MoveToNextDataSet() = 0;
  virtual const VectorString& DataSetTensorNames() const = 0;
};

class DataLoader : public IDataLoader {
 public:
  DataLoader(const MapStringToString& input_name_map,
             const PathString& dir_path,
             size_t max_num_files_preload = 2,
             size_t world_rank = 0,
             size_t world_size = 1);
  // ... IDataLoader interface implementation ...
};

class SingleDataLoader : public IDataLoader {
 public:
  SingleDataLoader(std::shared_ptr<DataSet> single_data_set, VectorString input_tensor_names);
  // ... IDataLoader interface implementation ...
};
```
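To make the `IDataLoader` contract concrete, here is a self-contained sketch that mirrors the interface with stand-in types and implements the single-dataset behavior described for `SingleDataLoader` (one shard, always the same dataset). It is not the ORT source: `IDataLoaderSketch`, `SingleDataLoaderSketch`, `CountShardsVisited`, and the `DataSet` stand-in are hypothetical, and `bool` replaces the `Status` return of `InitializeDataSetIndex`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: the real DataSet and VectorString come from ORT Training headers.
struct DataSet {
  int id;
};
using VectorString = std::vector<std::string>;

// Simplified mirror of the IDataLoader contract (bool replaces onnxruntime's Status).
class IDataLoaderSketch {
 public:
  virtual ~IDataLoaderSketch() = default;
  virtual bool InitializeDataSetIndex(size_t initial_data_set_index) = 0;
  virtual std::shared_ptr<DataSet> CurrentDataSet() = 0;
  virtual size_t CurrentDataSetIndex() const = 0;
  virtual size_t NumShards() const = 0;
  virtual std::shared_ptr<DataSet> MoveToNextDataSet() = 0;
  virtual const VectorString& DataSetTensorNames() const = 0;
};

// Single-dataset loader: reports one shard and always hands back the same dataset.
class SingleDataLoaderSketch : public IDataLoaderSketch {
 public:
  SingleDataLoaderSketch(std::shared_ptr<DataSet> data_set, VectorString names)
      : data_set_(std::move(data_set)), names_(std::move(names)) {
    assert(data_set_ != nullptr);
  }
  // Assumption: with a single shard there is no position to seek, so any index is accepted.
  bool InitializeDataSetIndex(size_t /*initial_data_set_index*/) override { return true; }
  std::shared_ptr<DataSet> CurrentDataSet() override { return data_set_; }
  size_t CurrentDataSetIndex() const override { return 0; }
  size_t NumShards() const override { return 1; }
  std::shared_ptr<DataSet> MoveToNextDataSet() override { return data_set_; }
  const VectorString& DataSetTensorNames() const override { return names_; }

 private:
  std::shared_ptr<DataSet> data_set_;
  VectorString names_;
};

// Walks any loader through one full pass over its shards and counts non-null datasets.
size_t CountShardsVisited(IDataLoaderSketch& loader) {
  if (!loader.InitializeDataSetIndex(0)) return 0;
  size_t visited = 0;
  std::shared_ptr<DataSet> ds = loader.CurrentDataSet();
  for (size_t i = 0; i < loader.NumShards(); ++i) {
    if (ds) ++visited;
    ds = loader.MoveToNextDataSet();
  }
  return visited;
}
```

Because `CountShardsVisited` depends only on the interface, the same driving loop would work unchanged against a multi-file loader in this sketch.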

Import

```cpp
#include "orttraining/models/runner/data_loader.h"
```

I/O Contract

| Class/Method | Inputs | Outputs | Description |
|---|---|---|---|
| `DataLoader` (ctor) | `input_name_map`, `dir_path`, `max_num_files_preload`, `world_rank`, `world_size` | `DataLoader` instance | Creates a file-based data loader with distributed sharding |
| `DataSetBuffer::Get` | `index` (`size_t`) | `shared_ptr<DataSet>` | Blocks until the dataset at `index` is ready, then returns it |
| `MoveToNextDataSet` | (none) | `shared_ptr<DataSet>` | Advances to the next data shard (file), with prefetching |
| `CurrentDataSet` | (none) | `shared_ptr<DataSet>` | Returns the currently active dataset |
| `NumShards` | (none) | `size_t` | Returns the total number of data files (shards) |
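`NumShards` and `MoveToNextDataSet` imply two pieces of bookkeeping: which files a given rank owns, and which shards to keep loaded or in flight. The sketch below models both, but the policies are assumptions rather than documented ORT behavior: files are dealt round-robin by `world_rank`, and the prefetch window includes the active shard, holds at most `max_num_files_preload` entries, and wraps around to shard 0. `ShardFilesForRank` and `PrefetchWindow` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Assumption: files (already sorted) are dealt round-robin, so rank r keeps r, r + world_size, ...
std::vector<std::string> ShardFilesForRank(const std::vector<std::string>& sorted_files,
                                           size_t world_rank, size_t world_size) {
  assert(world_size > 0 && world_rank < world_size);
  std::vector<std::string> mine;
  for (size_t i = world_rank; i < sorted_files.size(); i += world_size) {
    mine.push_back(sorted_files[i]);
  }
  return mine;
}

// Shards that should be loaded or in flight while `active` is current.
// Assumption: the window includes the active shard, holds at most `max_num_files_preload`
// entries (minimum 1), and wraps to shard 0 so the next pass can start without a stall.
std::vector<size_t> PrefetchWindow(size_t active, size_t num_shards, size_t max_num_files_preload) {
  assert(num_shards > 0 && active < num_shards);
  const size_t count = std::min(num_shards, std::max<size_t>(max_num_files_preload, 1));
  std::vector<size_t> window;
  for (size_t k = 0; k < count; ++k) {
    window.push_back((active + k) % num_shards);
  }
  return window;
}
```

Under these assumptions, the default `max_num_files_preload = 2` amounts to double buffering: the current shard is being consumed while the next one loads.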

Usage Examples

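```cpp
#include "orttraining/models/runner/data_loader.h"

using namespace onnxruntime::training;

// Input name remapping handed to the loader.
MapStringToString name_map = {{"input_ids", "input1"}, {"labels", "input2"}};

// Preload up to 2 files; single process (world_rank 0 of world_size 1).
DataLoader loader(name_map, ORT_TSTR("/data/training/"),
                  /*max_num_files_preload*/ 2, /*world_rank*/ 0, /*world_size*/ 1);
ORT_THROW_IF_ERROR(loader.InitializeDataSetIndex(0));

const size_t batch_size = 32;  // example value
auto dataset = loader.CurrentDataSet();

for (size_t shard = 0; shard < loader.NumShards(); ++shard) {
  const size_t total_batches = dataset->TotalBatch(batch_size);
  for (size_t b = 0; b < total_batches; ++b) {
    auto batch = dataset->GetKthBatch(batch_size, b);
    // ... train on batch ...
  }
  // Move to the next shard; upcoming files are prefetched in the background.
  dataset = loader.MoveToNextDataSet();
}
```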
