Implementation:Microsoft Onnxruntime Torch Onnx Export
Overview
Uses torch.onnx.export() to convert a PyTorch nn.Module into a forward-only ONNX model file configured for subsequent on-device training artifact generation.
Metadata
| Field | Value |
|---|---|
| Implementation Name | Torch_Onnx_Export |
| Type | External Tool Doc |
| Language | Python |
| API | torch.onnx.export(model, dummy_input, path, export_params=True, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING) |
| Import | import torch |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/python/on_device_training/training_artifacts.rst:L13-20 |
| Last Updated | 2026-02-10 |
Description
This implementation uses PyTorch's built-in ONNX exporter to convert a torch.nn.Module into the ONNX interchange format. The export is configured for on-device training by setting three parameters (export_params, do_constant_folding, and training). Together they keep the model's parameters mutable and preserve training-mode behavior.
The export is an external tool operation: it calls PyTorch's torch.onnx.export() rather than an ONNX Runtime API. The resulting ONNX file is the input to ONNX Runtime's training artifact generation step.
API Signature
```python
torch.onnx.export(
    model,                                      # torch.nn.Module: the model to export
    dummy_input,                                # Union[torch.Tensor, Tuple]: example input(s)
    path,                                       # str: output file path for the .onnx file
    export_params=True,                         # bool: include trained parameters in the model
    do_constant_folding=False,                  # bool: disable constant folding for training
    training=torch.onnx.TrainingMode.TRAINING,  # TrainingMode: preserve training behavior
)
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model | torch.nn.Module | The PyTorch model to be exported. Must be set to training mode (model.train()) before export. |
| dummy_input | torch.Tensor or tuple | A sample input tensor (or tuple of tensors) used to trace the forward computation graph. The shape and dtype must match the model's expected input. |
| path | str | File path where the exported ONNX model will be saved. Typically ends in .onnx. |
| export_params | bool | Must be True. Includes all model parameters as initializers in the ONNX graph so they are available for training. |
| do_constant_folding | bool | Must be False. Prevents collapsing constant expressions so that parameters remain mutable for gradient computation. |
| training | torch.onnx.TrainingMode | Must be TrainingMode.TRAINING. Preserves training-specific behavior (e.g., dropout active, batch norm using batch statistics). |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | torch.nn.Module | A valid PyTorch model in training mode |
| Input | torch.Tensor | Dummy input tensor matching the model's expected input shape and dtype |
| Output | .onnx file | Forward-only ONNX model file containing the computation graph and parameter initializers |
Code Reference
From docs/python/on_device_training/training_artifacts.rst:
```rst
If using PyTorch to export the model, please use the following export arguments
so training artifact generation can be successful:

- ``export_params``: ``True``
- ``do_constant_folding``: ``False``
- ``training``: ``torch.onnx.TrainingMode.TRAINING``
```
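The three required arguments can also be enforced in code. The sketch below defines a hypothetical helper, `check_export_args`, which is not part of PyTorch or ONNX Runtime. It compares keyword arguments against the values listed in the quoted documentation.

The helper flags any omitted argument. This is stricter than necessary for `export_params`, whose PyTorch default is already True. It does catch the two arguments whose PyTorch defaults differ from the required values: `do_constant_folding` (default True) and `training` (default `TrainingMode.EVAL`).

```python
import torch

# Required values, as listed in training_artifacts.rst.
REQUIRED_EXPORT_ARGS = {
    "export_params": True,
    "do_constant_folding": False,
    "training": torch.onnx.TrainingMode.TRAINING,
}


def check_export_args(**kwargs) -> list:
    """Hypothetical helper: return a list of problems with the given export kwargs."""
    problems = []
    for name, expected in REQUIRED_EXPORT_ARGS.items():
        if name not in kwargs:
            problems.append(f"{name} not set (expected {expected!r})")
        elif kwargs[name] != expected:
            problems.append(f"{name}={kwargs[name]!r} (expected {expected!r})")
    return problems


ok = check_export_args(
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)
bad = check_export_args(export_params=True)  # relies on PyTorch defaults for the rest
print(ok)
print(bad)
```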
Usage Example
```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create and set model to training mode
model = SimpleNet()
model.train()

# Create dummy input matching the expected input shape
dummy_input = torch.randn(1, 784)

# Export to ONNX for on-device training
torch.onnx.export(
    model,
    dummy_input,
    "simple_net.onnx",
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)
# Output: simple_net.onnx (forward-only ONNX model)
```
Implements
Principle:Microsoft_Onnxruntime_PyTorch_Model_Export
Related Pages
- Generate Artifacts -- Consumes the exported ONNX model to produce training artifacts
- Torch NN Module Construction -- Model construction for ORTModule training