Principle:Microsoft Onnxruntime PyTorch Model Export
Overview
Exporting a PyTorch neural network model to the ONNX interchange format for on-device training.
Metadata
| Field | Value |
|---|---|
| Principle Name | PyTorch_Model_Export |
| Category | External Tool Doc |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/python/on_device_training/training_artifacts.rst:L13-20 |
| Last Updated | 2026-02-10 |
Description
PyTorch models must be exported to ONNX format before they can be used with ONNX Runtime's on-device training API. The export uses torch.onnx.export() in training mode to preserve parameters and disable constant folding, producing a forward-only ONNX graph that serves as input for training artifact generation.
The export process requires three critical configuration parameters:
- export_params=True -- Ensures that all model parameters are included in the exported ONNX graph as initializers. Without this, the graph would lack the trained weight values.
- do_constant_folding=False -- Prevents the exporter from collapsing constant expressions. In training mode, parameters must remain mutable rather than being folded into constant nodes.
- training=torch.onnx.TrainingMode.TRAINING -- Instructs the exporter to preserve training-specific behaviors such as dropout and batch normalization in training mode, rather than switching to inference mode.
The output is a forward-only ONNX model containing only the forward computation graph. This model does not yet include loss computation, gradient calculations, or optimizer logic. Those components are added in the subsequent Training Artifact Generation step.
Theoretical Basis
ONNX export traces the PyTorch model's forward computation graph using a dummy input tensor, converting PyTorch operations to ONNX operators. Training mode export preserves parameter mutability needed for gradient computation.
The tracing mechanism works by feeding the dummy input through the model and recording the operations that are executed. Each PyTorch operation is mapped to its corresponding ONNX operator. The resulting graph is a static representation of the forward pass computation.
Key theoretical considerations include:
- Dynamic vs. Static Graphs -- PyTorch uses dynamic computation graphs, while ONNX uses static graphs. The export step captures a single execution trace, which means dynamic control flow may not be fully represented.
- Operator Mapping -- Each PyTorch operator must have a corresponding ONNX operator or a composition of ONNX operators. The exporter maintains a mapping table for this conversion.
- Parameter Preservation -- In training mode, all named parameters are preserved as mutable graph inputs rather than being baked into the graph as constants. This is essential because gradient computation requires the ability to modify parameter values.
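The single-trace limitation can be seen directly with torch.jit.trace, which uses the same record-one-execution mechanism. In this sketch (the Gate model is a hypothetical example), a data-dependent Python branch is baked into the trace, so inputs that should take the other branch still follow the recorded one:

```python
import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent Python branch: tracing records only the
        # branch that executes for the example input.
        if x.sum() > 0:
            return x * 2
        return x - 1

# Trace with a positive input, so the "x * 2" branch is recorded.
traced = torch.jit.trace(Gate(), torch.ones(3))

# A negative input still flows through the recorded branch:
# the trace is a static graph, not the original Python logic.
out = traced(-torch.ones(3))
```

PyTorch emits a TracerWarning for exactly this situation; models intended for export should avoid data-dependent control flow or script it explicitly.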
Usage
The typical workflow for PyTorch model export in the on-device training pipeline is:
- Define and instantiate a standard torch.nn.Module model
- Create a dummy input tensor with the correct shape and data type
- Call torch.onnx.export() with training-specific parameters
- Use the resulting .onnx file as input to artifact generation
```python
import torch

# Minimal example model; in practice, define or load your own.
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

model = MyModel()
model.train()

# Create a dummy input matching the model's expected input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX with training-mode settings
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)
```
Implemented By
Implementation:Microsoft_Onnxruntime_Torch_Onnx_Export
Related Pages
- Training Artifact Generation -- The next step in the pipeline after model export
- PyTorch Model Preparation -- Model construction for ORTModule-based training