
Principle:Microsoft Onnxruntime Training Artifact Generation

From Leeroopedia


Overview

Generation of required training artifacts (training model, eval model, optimizer model, checkpoint) from a forward-only ONNX model.

Metadata

Field Value
Principle Name Training_Artifact_Generation
Category API Doc
Domain On_Device_Training, Model_Optimization
Repository microsoft/onnxruntime
Source Reference docs/python/on_device_training/training_artifacts.rst:L23-40
Last Updated 2026-02-10

Description

The artifact generation step takes a forward-only ONNX model and produces all files needed for on-device training: a training model with backward graph, an evaluation model, an optimizer model, and an initial checkpoint. This automates the complex process of adding loss computation, gradient calculation, and optimizer logic to the forward graph.

The four generated artifacts serve distinct roles:

  • training_model.onnx -- Contains the forward graph augmented with a loss function and the corresponding backward graph for gradient computation. This is the model used during the TrainStep call.
  • eval_model.onnx -- Contains the forward graph augmented with the loss function but without the backward graph. Used during EvalStep to compute metrics without modifying parameters.
  • optimizer_model.onnx -- Encodes the optimizer's parameter update rules (e.g., AdamW momentum updates). Executed during OptimizerStep to update parameters using computed gradients.
  • checkpoint -- A flatbuffers-encoded file containing initial parameter values and optimizer state. Serves as the starting point for training.
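Concretely, a successful generation run leaves these four files in the artifact directory. The sketch below only models the expected layout (file names follow the list above; the directory name is a hypothetical example, matching the usage section):

```python
from pathlib import Path

# Hypothetical artifact directory; the four file names follow the list above.
artifact_dir = Path("output_artifacts")
expected_files = ["training_model.onnx", "eval_model.onnx",
                  "optimizer_model.onnx", "checkpoint"]
artifact_paths = [artifact_dir / name for name in expected_files]
```

A pre-flight check like this is useful before constructing the training session, since a missing optimizer model or checkpoint only surfaces as an error at session creation time.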

The requires_grad and frozen_params lists control which parameters participate in gradient computation. Parameters listed in requires_grad will have gradients computed and will be updated by the optimizer. Parameters listed in frozen_params are included in the model but their values remain fixed during training.
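One common way to build these two lists is to partition the model's initializer names. The sketch below uses hypothetical parameter names standing in for what `model.graph.initializer` would yield from a real ONNX model:

```python
# Hypothetical parameter names, standing in for the entries of
# model.graph.initializer in a real ONNX model.
all_params = ["embed.weight", "encoder.weight", "encoder.bias",
              "fc.weight", "fc.bias"]

# A common fine-tuning setup: train only the classifier head, freeze the rest.
requires_grad = [name for name in all_params if name.startswith("fc.")]
frozen_params = [name for name in all_params if name not in requires_grad]
```

Partitioning this way guarantees every parameter is assigned to exactly one of the two lists, so no parameter is silently left out of both.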

Theoretical Basis

Automatic differentiation is applied to the forward ONNX graph to create a backward graph. The loss function is appended to the forward graph, and an optimizer model encodes the parameter update rules.
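The chain-rule mechanics can be illustrated on a toy graph. This is a hand-derived sketch of what reverse-mode differentiation produces for y = (w*x + b)^2, not ORT's actual gradient builder:

```python
def forward_backward(w, x, b):
    # Forward pass: save the intermediates the backward pass will need.
    z = w * x + b      # linear node
    y = z * z          # squaring stands in for a scalar loss
    # Backward pass: apply the chain rule in reverse topological order.
    dy_dz = 2.0 * z
    grad_w = dy_dz * x     # dz/dw = x
    grad_b = dy_dz * 1.0   # dz/db = 1
    return y, grad_w, grad_b
```

For w=2, x=3, b=1 this gives z=7, y=49, and gradients 42 and 14. The generated training model plays the role of both functions at once: its graph contains the forward nodes plus the mirrored backward nodes.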

The artifact generation process applies the following transformations:

  • Loss Function Insertion -- The specified loss function (e.g., CrossEntropyLoss, MSELoss) is appended to the forward graph's output. This converts the model's raw predictions into a scalar loss value.
  • Automatic Differentiation -- The backward graph is derived by applying the chain rule to each operation in the forward graph, in reverse topological order. Each ONNX operator has a corresponding gradient implementation.
  • Optimizer Graph Construction -- A separate ONNX graph is created that encodes the optimizer's update rule (e.g., for AdamW: maintaining first and second moment estimates, applying weight decay, and computing parameter updates).
  • Checkpoint Initialization -- All parameters are serialized along with zero-initialized optimizer states (e.g., zero-initialized momentum buffers for AdamW).
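As a reference for the update rule the optimizer graph encodes, here is a scalar sketch of a single AdamW step (standard AdamW with decoupled weight decay; the hyperparameter defaults are illustrative, not ORT's):

```python
import math

def adamw_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    # First/second moment updates -- the state the checkpoint zero-initializes.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction for the zero initialization (t is the 1-based step count).
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay, then the Adam update.
    param = param - lr * weight_decay * param
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v
```

The zero-initialized momentum buffers in the checkpoint correspond to the initial `m` and `v` values here; the optimizer model applies this rule element-wise across every parameter tensor listed in requires_grad.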

Usage

The standard usage requires a forward-only ONNX model (typically exported from PyTorch) and specification of which parameters need gradients:

from onnxruntime.training import artifacts
import onnx

# Load the forward-only ONNX model
model = onnx.load("model.onnx")

# Generate all training artifacts.
# requires_grad and frozen_params take the names of the model's parameters;
# the strings below are placeholders, not real parameter names.
artifacts.generate_artifacts(
    model,
    requires_grad=["parameters", "needing", "gradients"],
    frozen_params=["parameters", "not", "needing", "gradients"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="output_artifacts/",
)

Custom loss functions can be created by extending onnxruntime.training.onnxblock.Block and passing the instance as the loss parameter instead of a LossType enum value.

Implemented By

Implementation:Microsoft_Onnxruntime_Generate_Artifacts
