Principle:Microsoft Onnxruntime PyTorch Model Export

From Leeroopedia


Overview

Exporting a PyTorch neural network model to the ONNX interchange format for on-device training.

Metadata

Principle Name: PyTorch_Model_Export
Category: External Tool Doc
Domain: On_Device_Training, Training_Infrastructure
Repository: microsoft/onnxruntime
Source Reference: docs/python/on_device_training/training_artifacts.rst:L13-20
Last Updated: 2026-02-10

Description

PyTorch models must be exported to ONNX format before they can be used with ONNX Runtime's on-device training API. The export calls torch.onnx.export() in training mode, with parameter export enabled and constant folding disabled, producing a forward-only ONNX graph that serves as input for training artifact generation.

The export process requires three critical configuration parameters:

  • export_params=True -- Ensures that all model parameters are included in the exported ONNX graph as initializers. Without this, the graph would lack the trained weight values.
  • do_constant_folding=False -- Prevents the exporter from collapsing constant expressions. In training mode, parameters must remain mutable rather than being folded into constant nodes.
  • training=torch.onnx.TrainingMode.TRAINING -- Instructs the exporter to preserve training-specific behaviors such as dropout and batch normalization in training mode, rather than switching to inference mode.

The output is a forward-only ONNX model containing only the forward computation graph. This model does not yet include loss computation, gradient calculations, or optimizer logic. Those components are added in the subsequent Training Artifact Generation step.
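As a rough mental model of this split (plain Python, not the actual onnxruntime API; the op names and the generate_training_graph helper are invented for illustration), artifact generation extends the forward-only op list with loss and gradient nodes:

```python
# Toy sketch: the exported model contains only forward ops; training
# artifact generation (conceptually) appends loss and gradient nodes.
forward_graph = ["Conv", "Relu", "Gemm"]  # forward-only exported graph

def generate_training_graph(forward_ops):
    # Hypothetical helper for illustration only -- not the real API.
    loss_ops = ["SoftmaxCrossEntropyLoss"]
    # Gradients are computed in reverse order of the forward pass.
    grad_ops = [op + "Grad" for op in reversed(forward_ops)]
    return forward_ops + loss_ops + grad_ops

training_graph = generate_training_graph(forward_graph)
print(training_graph)
```

The point of the sketch is only that the .onnx file produced by export stops at the forward pass; everything after it is supplied later.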

Theoretical Basis

ONNX export traces the PyTorch model's forward computation graph using a dummy input tensor, converting PyTorch operations to ONNX operators. Training mode export preserves parameter mutability needed for gradient computation.

The tracing mechanism works by feeding the dummy input through the model and recording the operations that are executed. Each PyTorch operation is mapped to its corresponding ONNX operator. The resulting graph is a static representation of the forward pass computation.
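The recording idea can be sketched in plain Python (a toy stand-in, not the real torch.onnx tracer; the TracingTensor class is invented for illustration): a wrapper value logs each operation applied to it as the dummy input flows through the model.

```python
# Toy sketch of trace-based export: push a dummy input through the
# model and record every operation executed, yielding a static op list.
class TracingTensor:
    """Wraps a value and records each operation applied to it."""
    def __init__(self, value, trace):
        self.value = value
        self.trace = trace

    def __add__(self, other):
        self.trace.append("Add")
        return TracingTensor(self.value + other, self.trace)

    def __mul__(self, other):
        self.trace.append("Mul")
        return TracingTensor(self.value * other, self.trace)

def model(x):
    # Stand-in for a forward pass: y = x * 2 + 1
    return x * 2 + 1

trace = []
dummy_input = TracingTensor(3.0, trace)
output = model(dummy_input)
print(trace)         # the recorded static "graph": ['Mul', 'Add']
print(output.value)  # 7.0
```

The real exporter does far more (shape propagation, operator mapping, initializer handling), but the trace-then-serialize shape of the process is the same.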

Key theoretical considerations include:

  • Dynamic vs. Static Graphs -- PyTorch uses dynamic computation graphs, while ONNX uses static graphs. The export step captures a single execution trace, which means dynamic control flow may not be fully represented.
  • Operator Mapping -- Each PyTorch operator must have a corresponding ONNX operator or a composition of ONNX operators. The exporter maintains a mapping table for this conversion.
  • Parameter Preservation -- In training mode, all named parameters are preserved as mutable graph inputs rather than being baked into the graph as constants. This is essential because gradient computation requires the ability to modify parameter values.
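The first point above can be demonstrated with a toy sketch (plain Python, not torch.onnx; the trace_export helper is invented for illustration): tracing a data-dependent branch captures only the path the dummy input happens to take, so the "exported" function silently misbehaves on inputs that would take the other branch.

```python
# Toy illustration of why data-dependent control flow is lost under
# tracing: only the branch taken by the dummy input is recorded.
def forward(x):
    if x > 0:          # data-dependent branch
        return x * 2
    return x - 1

def trace_export(fn, dummy):
    # A real tracer records the ops executed for `dummy`; here we just
    # capture which branch the dummy input takes.
    if dummy > 0:
        return lambda x: x * 2   # only this path is in the "graph"
    return lambda x: x - 1

exported = trace_export(forward, dummy=1.0)  # traces the positive path
print(exported(5.0))    # 10.0 -- matches forward(5.0)
print(exported(-5.0))   # -10.0 -- diverges from forward(-5.0) == -6.0
```

This is why a representative dummy input matters, and why models with genuinely dynamic control flow may need restructuring before export.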

Usage

The typical workflow for PyTorch model export in the on-device training pipeline is:

  1. Define and instantiate a standard torch.nn.Module model
  2. Create a dummy input tensor with the correct shape and data type
  3. Call torch.onnx.export() with training-specific parameters
  4. Use the resulting .onnx file as input to artifact generation

import torch

# Define or load the PyTorch model (a minimal module stands in here)
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))

model = MyModel()
model.train()  # put the model in training mode before export

# Create a dummy input matching the model's expected input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX with training-mode settings
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)

Implemented By

Implementation:Microsoft_Onnxruntime_Torch_Onnx_Export
