

Implementation:Microsoft Onnxruntime TrainingUtils

From Leeroopedia


Knowledge Sources
Domains: Training, API, Utilities
Last Updated: 2026-02-10 04:00 GMT

Overview

Provides utility types and functions for the ORT Training API, including ModelIdentifiers for specifying training model paths, graph introspection helpers, and OrtValue conversion utilities.

Description

The `utils.h` header in the training API provides foundational utilities used across the training components:

  • `ModelIdentifiers`: A struct that holds references to the training, evaluation, and optimizer ONNX models. Each model can be specified either as a file path (`std::string`) or an in-memory byte array (`gsl::span<const uint8_t>`). The training model is required; eval and optimizer models are optional. Methods `IsEvalModelAvailable()` and `IsOptimizerModelAvailable()` check availability.
  • `GetGraphInputOutputNames`: Extracts input and output names from an InferenceSession.
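  • `GetParamNameFromSuffix` / `GetParamNameFromGradient`: Recover a parameter name from a graph value name by stripping a naming suffix. `GetParamNameFromSuffix` takes the suffix explicitly; `GetParamNameFromGradient` applies the gradient-name convention. Both return a `bool` and write the recovered name to `param_name`.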
  • `CreateZeroValuedOrtValueLike`: Creates a zero-initialized OrtValue with the same type and device as a reference value.
  • `WrapInOrtValue<T>`: Creates a scalar OrtValue from a single value of type T.
  • `GetScalarFromOrtValue<T>`: Extracts a scalar value from a rank-0 or rank-1 OrtValue tensor.
  • `CopyTensorToTensorProto`: Converts a Tensor to a TensorProto for checkpoint serialization.
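The availability checks on `ModelIdentifiers` can be pictured with a small standard-library sketch. It is not ORT code: `ByteView` is a hypothetical stand-in for `gsl::span<const uint8_t>` (to stay within C++17), `ModelIdentifiersSketch` and `IsAvailable` are illustrative names, and the rule that an optional model counts as available when its path is set or its byte buffer is non-empty is an assumption, not the documented behavior of `IsEvalModelAvailable()` / `IsOptimizerModelAvailable()`.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-in for gsl::span<const uint8_t>: pointer plus length.
struct ByteView {
  const std::uint8_t* data = nullptr;
  std::size_t size = 0;
};

// Required model: a file path or an in-memory buffer.
using RequiredModel = std::variant<std::string, ByteView>;
// Optional model: a possibly-empty path or an in-memory buffer.
using OptionalModel = std::variant<std::optional<std::string>, ByteView>;

// Assumption: an optional model is "available" if a path is present,
// or if the in-memory buffer is non-empty.
bool IsAvailable(const OptionalModel& model) {
  if (const auto* path = std::get_if<std::optional<std::string>>(&model)) {
    return path->has_value();
  }
  return std::get<ByteView>(model).size > 0;
}

// Illustrative mirror of the ModelIdentifiers layout shown in the signature.
struct ModelIdentifiersSketch {
  RequiredModel train_model;
  OptionalModel eval_model;
  OptionalModel optim_model;

  bool IsEvalModelAvailable() const { return IsAvailable(eval_model); }
  bool IsOptimizerModelAvailable() const { return IsAvailable(optim_model); }
};
```

`std::get_if` returns a null pointer rather than throwing when the variant holds the other alternative, so the check needs no exception handling.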

Usage

Use this header for model identification, graph introspection, and OrtValue manipulation when working with the ORT Training API internals.

Code Reference

Source Location

Signature

struct ModelIdentifiers {
  std::variant<std::string, gsl::span<const uint8_t>> train_model;
  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> eval_model;
  std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optim_model;

  ModelIdentifiers(std::variant<std::string, gsl::span<const uint8_t>> training_model,
                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> evaluation_model,
                   std::variant<std::optional<std::string>, gsl::span<const uint8_t>> optimizer_model);

  bool IsEvalModelAvailable() const;
  bool IsOptimizerModelAvailable() const;
};

namespace utils {
void GetGraphInputOutputNames(const std::unique_ptr<InferenceSession>& session_object,
                              InlinedVector<std::string>& input_names,
                              InlinedVector<std::string>& output_names);

bool GetParamNameFromSuffix(const std::string& name, const std::string& suffix,
                            std::string& param_name);
bool GetParamNameFromGradient(const std::string& grad_name, std::string& param_name);

Status CreateZeroValuedOrtValueLike(const SessionState& sess_state,
                                    const OrtValue& input_val, OrtValue& output_val);

template <typename T>
void WrapInOrtValue(T value, OrtValue* p_ortvalue, AllocatorPtr alloc = nullptr);

template <typename T>
T GetScalarFromOrtValue(OrtValue& ort_value);

ONNX_NAMESPACE::TensorProto CopyTensorToTensorProto(const Tensor& tensor,
    const std::string& tensor_proto_name, const DataTransferManager& data_transfer_manager);
}  // namespace utils
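Suffix-based name recovery reduces to an ends-with test followed by a prefix slice. The sketch below assumes gradients are named `<param>_grad`; this page does not state the exact suffix ORT uses, and `GetParamNameFromSuffixSketch` / `GetParamNameFromGradientSketch` are hypothetical names, not the library functions.

```cpp
#include <string>

// Strip `suffix` from the end of `name`. On a match, writes the remaining
// prefix to `param_name` and returns true; otherwise clears `param_name`
// and returns false.
bool GetParamNameFromSuffixSketch(const std::string& name, const std::string& suffix,
                                  std::string& param_name) {
  const bool ends_with =
      name.size() >= suffix.size() &&
      name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
  if (!ends_with) {
    param_name.clear();
    return false;
  }
  param_name = name.substr(0, name.size() - suffix.size());
  return true;
}

// Assumption: gradient values are named "<param>_grad".
bool GetParamNameFromGradientSketch(const std::string& grad_name, std::string& param_name) {
  return GetParamNameFromSuffixSketch(grad_name, "_grad", param_name);
}
```

Clearing `param_name` on a mismatch is a choice made in this sketch so that callers never read a stale value after a failed lookup.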

Import

#include "orttraining/training_api/utils.h"

I/O Contract

  • ModelIdentifiers (constructor). Inputs: train_model, eval_model, optim_model, each a file path or an in-memory byte span. Output: a ModelIdentifiers instance. Holds model references for training session creation.
  • WrapInOrtValue<T>. Inputs: a T value, the OrtValue* to fill, and an optional AllocatorPtr. Output: none (returns void); the result is written to *p_ortvalue. Creates a scalar tensor OrtValue from a single value.
  • GetScalarFromOrtValue<T>. Input: OrtValue&. Output: T. Extracts a scalar from a rank-0 or rank-1 tensor.
  • CreateZeroValuedOrtValueLike. Inputs: a SessionState and a reference OrtValue. Outputs: a zero-valued OrtValue (written to output_val) and a Status. Creates a zero-initialized OrtValue matching the reference value's type and device.
  • CopyTensorToTensorProto. Inputs: a Tensor, the TensorProto name, and a DataTransferManager. Output: a TensorProto. Converts a Tensor to a TensorProto for serialization.
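The scalar helpers and the zero-like helper can be illustrated on a toy tensor type. `ToyTensor`, `WrapScalar`, `GetScalar`, and `ZeroLike` are hypothetical: the sketch ignores element types, devices, allocators, and the `SessionState` argument, and it assumes that "rank-1" in `GetScalarFromOrtValue` means a rank-1 tensor holding exactly one element.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical toy tensor: a shape plus a flat float buffer. It stands in
// for OrtValue/Tensor and ignores element type and device.
struct ToyTensor {
  std::vector<std::int64_t> shape;  // empty shape == rank 0 (scalar)
  std::vector<float> data;
};

// Analogue of WrapInOrtValue<float>: store one value as a rank-0 tensor.
void WrapScalar(float value, ToyTensor* out) {
  out->shape.clear();
  out->data.assign(1, value);
}

// Analogue of GetScalarFromOrtValue<float>: accept rank 0, or rank 1 with
// exactly one element (assumed interpretation); otherwise throw.
float GetScalar(const ToyTensor& t) {
  const bool rank0 = t.shape.empty();
  const bool rank1_single = t.shape.size() == 1 && t.shape[0] == 1;
  if (!(rank0 || rank1_single) || t.data.size() != 1) {
    throw std::invalid_argument("expected a single-element rank-0 or rank-1 tensor");
  }
  return t.data[0];
}

// Analogue of CreateZeroValuedOrtValueLike: same shape as `ref`, all zeros.
ToyTensor ZeroLike(const ToyTensor& ref) {
  return ToyTensor{ref.shape, std::vector<float>(ref.data.size(), 0.0f)};
}
```

A typical round trip wraps a learning rate with `WrapScalar`, reads it back with `GetScalar`, and uses `ZeroLike` to allocate a buffer (for example, a gradient accumulator) shaped like an existing parameter.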

Usage Examples

#include "orttraining/training_api/utils.h"

using namespace onnxruntime::training::api;

// Create model identifiers from file paths
ModelIdentifiers ids(
    std::string("train_model.onnx"),
    std::optional<std::string>("eval_model.onnx"),
    std::optional<std::string>("optimizer_model.onnx"));

// Wrap a scalar value into OrtValue
OrtValue lr_value;
utils::WrapInOrtValue<float>(0.001f, &lr_value);

// Extract scalar from OrtValue
float lr = utils::GetScalarFromOrtValue<float>(lr_value);
