Implementation:Microsoft Onnxruntime TrainingDataLoader
| Knowledge Sources | |
|---|---|
| Domains | Training, DataLoading, Models |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Defines data loading infrastructure for ORT training, including thread-safe buffering, file-based data loading with prefetching, and a single-dataset loader.
Description
This header provides the data loading layer for ORT Training model runners. It includes:
- `DataSetBuffer`: A thread-safe buffer that stores `DataSet` objects by index. It uses a mutex and condition variable so that `Get(index)` blocks until the requested dataset is available, enabling asynchronous prefetching. `Set(index, data_set)` stores a dataset and notifies waiting threads; `Remove(index)` releases a dataset.
- `IDataLoader`: An abstract interface defining the data loading contract. Key methods include `InitializeDataSetIndex`, `CurrentDataSet`, `MoveToNextDataSet`, `NumShards`, and `DataSetTensorNames`.
- `DataLoader`: The primary file-based implementation that reads training data from binary files in a specific format: each sample contains byte-size prefixed TensorProto features. It supports:
  - Asynchronous prefetching with a configurable number of preloaded files (`max_num_files_preload`, default 2).
  - Distributed training sharding via `world_rank` and `world_size` parameters.
  - Input name remapping via `MapStringToString`.
  - A single-threaded thread pool for serialized loading operations.
- `SingleDataLoader`: A simple loader wrapping a single pre-loaded `DataSet` instance. It always returns the same dataset and reports one shard.
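The blocking behaviour described for `DataSetBuffer` can be illustrated with a minimal sketch. This is not the ORT implementation: `BlockingBuffer`, `FakeDataSet`, and `PrefetchDemo` are hypothetical names, and the storage choice (an `std::unordered_map`) is an assumption. It only shows how a mutex and a condition variable let `Get(index)` wait until another thread calls `Set(index, ...)`.

```cpp
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for DataSet; the real class holds tensors.
struct FakeDataSet {
  std::string name;
};

// Sketch of an index-keyed blocking buffer (assumed design, not ORT code).
class BlockingBuffer {
 public:
  // Blocks until an entry for `index` has been stored, then returns it.
  std::shared_ptr<FakeDataSet> Get(size_t index) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return items_.count(index) > 0; });
    return items_[index];
  }

  // Stores an entry and wakes every thread waiting in Get().
  void Set(size_t index, std::shared_ptr<FakeDataSet> data_set) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      items_[index] = std::move(data_set);
    }
    cv_.notify_all();
  }

  // Drops the entry; returns true if something was actually removed.
  bool Remove(size_t index) {
    std::lock_guard<std::mutex> lock(mutex_);
    return items_.erase(index) > 0;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::unordered_map<size_t, std::shared_ptr<FakeDataSet>> items_;
};

// A producer thread stores index 0 while the caller waits in Get(0).
std::string PrefetchDemo() {
  BlockingBuffer buffer;
  std::thread producer([&buffer] {
    buffer.Set(0, std::make_shared<FakeDataSet>(FakeDataSet{"shard-0"}));
  });
  std::string name = buffer.Get(0)->name;
  producer.join();
  return name;
}
```

Waiting on a predicate rather than a bare `wait` guards against spurious wakeups and against a `Set` that happens before the consumer starts waiting.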
The binary file format stores samples as: `[Sample ByteSize][Feature0 ByteSize][Feature0 TensorProto]...[FeatureN ByteSize][FeatureN TensorProto]`, with all byte-size fields stored as `uint32_t`.
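The layout above can be sketched as a small encoder and decoder. This is an illustration under stated assumptions, not the ORT reader: features are treated as opaque byte strings instead of serialized `TensorProto` messages, the `uint32_t` fields are written in host byte order, and the sample byte size is assumed to cover the feature size prefixes as well as the feature payloads. `AppendU32`, `ReadU32`, `EncodeSample`, and `DecodeSamples` are hypothetical helper names.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

using Bytes = std::vector<uint8_t>;

// Appends a uint32_t in host byte order (an assumption about the format).
void AppendU32(Bytes& out, uint32_t value) {
  uint8_t raw[sizeof(value)];
  std::memcpy(raw, &value, sizeof(value));
  out.insert(out.end(), raw, raw + sizeof(value));
}

// Reads a uint32_t at `offset`, staying below `limit`; advances `offset`.
bool ReadU32(const Bytes& in, size_t limit, size_t& offset, uint32_t& value) {
  if (offset > limit || limit - offset < sizeof(value)) return false;
  std::memcpy(&value, in.data() + offset, sizeof(value));
  offset += sizeof(value);
  return true;
}

// Encodes one sample: [sample size][f0 size][f0 bytes]...[fN size][fN bytes].
Bytes EncodeSample(const std::vector<std::string>& features) {
  Bytes body;
  for (const std::string& f : features) {
    AppendU32(body, static_cast<uint32_t>(f.size()));
    body.insert(body.end(), f.begin(), f.end());
  }
  Bytes out;
  AppendU32(out, static_cast<uint32_t>(body.size()));
  out.insert(out.end(), body.begin(), body.end());
  return out;
}

// Decodes every sample in `in`; returns false if the data is truncated.
bool DecodeSamples(const Bytes& in,
                   std::vector<std::vector<std::string>>& samples) {
  size_t offset = 0;
  while (offset < in.size()) {
    uint32_t sample_size = 0;
    if (!ReadU32(in, in.size(), offset, sample_size)) return false;
    if (sample_size > in.size() - offset) return false;
    const size_t end = offset + sample_size;
    std::vector<std::string> features;
    while (offset < end) {
      uint32_t feature_size = 0;
      if (!ReadU32(in, end, offset, feature_size)) return false;
      if (feature_size > end - offset) return false;
      features.emplace_back(in.begin() + offset,
                            in.begin() + offset + feature_size);
      offset += feature_size;
    }
    samples.push_back(std::move(features));
  }
  return true;
}
```

In a real reader, each feature payload would be handed to a protobuf parser instead of being stored as a string.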
Usage
Use this header when implementing training runners that need to load batched training data from binary files. The `DataLoader` is used by BERT and other model training runners in the ORT Training examples.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/models/runner/data_loader.h
- Lines: 1-191
Signature
class DataSetBuffer {
public:
std::shared_ptr<DataSet> Get(size_t index);
void Set(size_t index, std::shared_ptr<DataSet> data_set);
bool Remove(size_t index);
};
class IDataLoader {
public:
virtual Status InitializeDataSetIndex(size_t initial_data_set_index) = 0;
virtual std::shared_ptr<DataSet> CurrentDataSet() = 0;
virtual size_t CurrentDataSetIndex() const = 0;
virtual size_t NumShards() const = 0;
virtual std::shared_ptr<DataSet> MoveToNextDataSet() = 0;
virtual const VectorString& DataSetTensorNames() const = 0;
};
class DataLoader : public IDataLoader {
public:
DataLoader(const MapStringToString& input_name_map,
const PathString& dir_path,
size_t max_num_files_preload = 2,
size_t world_rank = 0,
size_t world_size = 1);
// ... IDataLoader interface implementation ...
};
class SingleDataLoader : public IDataLoader {
public:
SingleDataLoader(std::shared_ptr<DataSet> single_data_set, VectorString input_tensor_names);
// ... IDataLoader interface implementation ...
};
Import
#include "orttraining/models/runner/data_loader.h"
I/O Contract
| Class/Method | Inputs | Outputs | Description |
|---|---|---|---|
| DataLoader (ctor) | input_name_map, dir_path, max_num_files_preload, world_rank, world_size | DataLoader instance | Creates a file-based data loader with distributed sharding |
| Get (DataSetBuffer) | index (size_t) | shared_ptr<DataSet> | Blocks until dataset at index is ready, then returns it |
| MoveToNextDataSet | (none) | shared_ptr<DataSet> | Advances to the next data shard (file) with prefetching |
| CurrentDataSet | (none) | shared_ptr<DataSet> | Returns the currently active dataset |
| NumShards | (none) | size_t | Returns the total number of data files/shards |
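The `world_rank` and `world_size` constructor inputs select which data files a given worker reads. The source does not say how files are divided among ranks, so the sketch below assumes round-robin assignment over a sorted file listing; `FilesForRank` is a hypothetical helper, not part of the header.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Returns the files one rank would read under round-robin sharding
// (an assumed policy; the real DataLoader may partition differently).
std::vector<std::string> FilesForRank(std::vector<std::string> files,
                                      size_t world_rank,
                                      size_t world_size) {
  std::vector<std::string> mine;
  if (world_size == 0 || world_rank >= world_size) return mine;
  // Sort so that every rank sees the files in the same order.
  std::sort(files.begin(), files.end());
  for (size_t i = world_rank; i < files.size(); i += world_size) {
    mine.push_back(files[i]);
  }
  return mine;
}
```

With the defaults `world_rank = 0` and `world_size = 1`, this policy assigns every file to the single worker.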
Usage Examples
#include "orttraining/models/runner/data_loader.h"
using namespace onnxruntime::training;
MapStringToString name_map = {{"input_ids", "input1"}, {"labels", "input2"}};
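// Arguments: name map, data directory, max_num_files_preload, world_rank, world_size.
DataLoader loader(name_map, ORT_TSTR("/data/training/"), 2, 0, 1);
// InitializeDataSetIndex returns a Status; throw if it reports an error.
ORT_THROW_IF_ERROR(loader.InitializeDataSetIndex(0));
auto dataset = loader.CurrentDataSet();
const size_t batch_size = 32;  // example value; choose one that fits the model
size_t total_batches = dataset->TotalBatch(batch_size);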
for (size_t b = 0; b < total_batches; ++b) {
auto batch = dataset->GetKthBatch(batch_size, b);
// ... train on batch ...
}
// Move to the next shard and keep the dataset it returns
dataset = loader.MoveToNextDataSet();