Implementation:Microsoft Onnxruntime TrainingUtils
| Knowledge Sources | |
|---|---|
| Domains | Training, API, Utilities |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Provides utility types and functions for the ORT Training API, including ModelIdentifiers for specifying training model paths, graph introspection helpers, and OrtValue conversion utilities.
Description
The `utils.h` header in the training API provides foundational utilities used across the training components:
- `ModelIdentifiers`: A struct that holds references to the training, evaluation, and optimizer ONNX models. Each model can be specified either as a file path (`std::string`) or an in-memory byte array (`gsl::span<const uint8_t>`). The training model is required; eval and optimizer models are optional. Methods `IsEvalModelAvailable()` and `IsOptimizerModelAvailable()` check availability.
- `GetGraphInputOutputNames`: Extracts input and output names from an InferenceSession.
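- `GetParamNameFromSuffix` / `GetParamNameFromGradient`: Recover a parameter name from a graph name that follows the training naming convention. The first strips a caller-supplied suffix; the second takes a gradient name. Both return a `bool` and write the result to the `param_name` output argument.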
- `CreateZeroValuedOrtValueLike`: Creates a zero-initialized OrtValue with the same shape, type, and device as a reference value.
- `WrapInOrtValue<T>`: Creates a scalar OrtValue from a single value of type T.
- `GetScalarFromOrtValue<T>`: Extracts a scalar value from a rank-0 or rank-1 OrtValue tensor.
- `CopyTensorToTensorProto`: Converts a Tensor to a TensorProto for checkpoint serialization.
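The suffix-based name parsing described above can be illustrated with a small standalone sketch. `StripSuffix` is a hypothetical helper written for this page, not the ORT implementation. The `_grad` suffix used in the usage note and the choice to clear `param_name` on failure are assumptions.

```cpp
#include <string>

// Hypothetical stand-in for the suffix-stripping contract; not the ORT code.
// Returns true and writes the prefix into param_name when `name` ends with
// `suffix`. Otherwise clears param_name (an assumption) and returns false.
inline bool StripSuffix(const std::string& name, const std::string& suffix,
                        std::string& param_name) {
  if (name.size() >= suffix.size() &&
      name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) {
    param_name = name.substr(0, name.size() - suffix.size());
    return true;
  }
  param_name.clear();
  return false;
}
```

Under these assumptions, calling `StripSuffix("fc1.weight_grad", "_grad", p)` would return `true` and leave `p` holding `fc1.weight`, while a name without the suffix would return `false`.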
Usage
Use this header for model identification, graph introspection, and OrtValue manipulation when working with the ORT Training API internals.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_api/utils.h
- Lines: 1-106
Signature
struct ModelIdentifiers {
  std::variant<std::string, gsl::span<const uint8_t>> train_model;
  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> eval_model;
  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optim_model;

  ModelIdentifiers(std::variant<std::string, gsl::span<const uint8_t>> training_model,
                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> evaluation_model,
                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optimzer_model);

  bool IsEvalModelAvailable() const;
  bool IsOptimizerModelAvailable() const;
};

namespace utils {

void GetGraphInputOutputNames(const std::unique_ptr<InferenceSession>& session_object,
                              InlinedVector<std::string>& input_names,
                              InlinedVector<std::string>& output_names);

bool GetParamNameFromSuffix(const std::string& name, const std::string& suffix,
                            std::string& param_name);

bool GetParamNameFromGradient(const std::string& grad_name, std::string& param_name);

Status CreateZeroValuedOrtValueLike(const SessionState& sess_state,
                                    const OrtValue& input_val, OrtValue& output_val);

template <typename T>
void WrapInOrtValue(T value, OrtValue* p_ortvalue, AllocatorPtr alloc = nullptr);

template <typename T>
T GetScalarFromOrtValue(OrtValue& ort_value);

ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& tensor,
                                                    const std::string& tensor_proto_name,
                                                    const DataTransferManager& data_transfer_manager);

}  // namespace utils
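One way the availability checks on the optional models might work is sketched here using only the standard library. `ByteSpan`, `OptionalModel`, and `IsModelAvailable` are hypothetical names introduced for illustration. The rule itself (a path counts when the optional holds a value, a byte buffer counts when it is non-empty) is an assumption about the behavior, not a copy of the ORT code.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-in for gsl::span<const uint8_t>, so the sketch
// compiles without the GSL dependency.
struct ByteSpan {
  const std::uint8_t* data = nullptr;
  std::size_t size = 0;
};

// Mirrors the shape of the eval_model / optim_model members.
using OptionalModel = std::variant<std::optional<std::string>, ByteSpan>;

// Assumed rule: a path alternative is available when the optional holds a
// value; a byte-buffer alternative is available when it is non-empty.
inline bool IsModelAvailable(const OptionalModel& model) {
  if (const auto* path = std::get_if<std::optional<std::string>>(&model)) {
    return path->has_value();
  }
  return std::get<ByteSpan>(model).size > 0;
}
```

Keeping `std::optional` inside the path alternative lets a caller pass `std::nullopt` to mean "no model" without needing a third variant alternative.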
Import
#include "orttraining/training_api/utils.h"
I/O Contract
| Function | Inputs | Outputs | Description |
|---|---|---|---|
| ModelIdentifiers (ctor) | train_model, eval_model, optim_model (path or bytes) | ModelIdentifiers | Holds model references for training session creation |
| WrapInOrtValue<T> | T value, OrtValue*, AllocatorPtr | void | Creates a scalar tensor OrtValue from a single value |
| GetScalarFromOrtValue<T> | OrtValue& | T | Extracts scalar from a rank-0/rank-1 tensor |
| CreateZeroValuedOrtValueLike | SessionState, OrtValue (reference) | OrtValue (zero-valued), Status | Creates zero-initialized OrtValue matching reference shape/type/device |
| CopyTensorToTensorProto | Tensor, name, DataTransferManager | TensorProto | Converts Tensor to TensorProto for serialization |
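The rank-0/rank-1 contract listed for `GetScalarFromOrtValue<T>` can be illustrated with a mock tensor. `MockTensor` and `GetScalar` are hypothetical and do not use any ORT types. The single-element check and the exception on failure are assumptions about how such a contract could be enforced.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical mock of a tensor: a shape plus a flat buffer. Not an ORT type.
template <typename T>
struct MockTensor {
  std::vector<std::int64_t> shape;  // empty for rank 0
  std::vector<T> data;
};

// Assumed contract: accept only tensors whose element count is exactly one
// (for example rank 0, or rank 1 with shape {1}) and return that element.
template <typename T>
T GetScalar(const MockTensor<T>& t) {
  std::int64_t count = 1;
  for (std::int64_t dim : t.shape) count *= dim;
  if (count != 1 || t.data.size() != 1) {
    throw std::invalid_argument("expected a tensor with exactly one element");
  }
  return t.data[0];
}
```

In this sketch a tensor with shape `{2}` is rejected, which is the kind of misuse a scalar accessor would be expected to catch.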
Usage Examples
#include "orttraining/training_api/utils.h"
using namespace onnxruntime::training::api;
// Create model identifiers from file paths
ModelIdentifiers ids(
    std::string("train_model.onnx"),
    std::optional<std::string>("eval_model.onnx"),
    std::optional<std::string>("optimizer_model.onnx"));
// Wrap a scalar value into OrtValue
OrtValue lr_value;
utils::WrapInOrtValue<float>(0.001f, &lr_value);
// Extract scalar from OrtValue
float lr = utils::GetScalarFromOrtValue<float>(lr_value);