

Implementation:Microsoft Onnxruntime Torch Onnx Export

From Leeroopedia


Overview

Uses torch.onnx.export() to convert a PyTorch nn.Module into a forward-only ONNX model file configured for subsequent on-device training artifact generation.

Metadata

Implementation Name: Torch_Onnx_Export
Type: External Tool Doc
Language: Python
API: torch.onnx.export(model, dummy_input, path, export_params=True, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING)
Import: import torch
Domain: On_Device_Training, Training_Infrastructure
Repository: microsoft/onnxruntime
Source Reference: docs/python/on_device_training/training_artifacts.rst:L13-20
Last Updated: 2026-02-10

Description

This implementation uses PyTorch's built-in ONNX export to convert a torch.nn.Module into the ONNX interchange format. Three parameters configure the export for on-device training: export_params, do_constant_folding, and training. Together they keep the model parameters available as trainable initializers and preserve training-mode behavior.

The export is an external tool operation: it calls PyTorch's torch.onnx.export() rather than an ONNX Runtime API. The resulting ONNX file is the input to ONNX Runtime's training artifact generation step.
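The hand-off from the exported file to ONNX Runtime can be sketched as follows. This is a minimal sketch under several assumptions:

- The `onnxruntime-training` package and its `onnxruntime.training.artifacts.generate_artifacts` entry point are available.
- The helper names `split_initializers` and `build_training_artifacts` are hypothetical.
- Cross-entropy loss and AdamW are illustrative choices.
- Every initializer not explicitly frozen is treated as trainable. Real models often carry non-trainable initializers, such as batch-norm running statistics, and those belong in the frozen list.

```python
from typing import Iterable, List, Tuple


def split_initializers(model_proto, frozen: Iterable[str] = ()) -> Tuple[List[str], List[str]]:
    """Split graph initializer names into (trainable, frozen) lists.

    Hypothetical helper. Assumption: every initializer not named in `frozen`
    should receive gradients.
    """
    frozen_set = set(frozen)
    names = [init.name for init in model_proto.graph.initializer]
    trainable = [name for name in names if name not in frozen_set]
    frozen_found = [name for name in names if name in frozen_set]
    return trainable, frozen_found


def build_training_artifacts(onnx_path: str, artifact_dir: str, frozen: Iterable[str] = ()) -> None:
    """Generate ORT training artifacts from a forward-only ONNX model (hypothetical helper)."""
    # Imported lazily so split_initializers can be used without these packages installed.
    import onnx
    from onnxruntime.training import artifacts

    model_proto = onnx.load(onnx_path)
    requires_grad, frozen_params = split_initializers(model_proto, frozen)
    artifacts.generate_artifacts(
        model_proto,
        requires_grad=requires_grad,
        frozen_params=frozen_params,
        loss=artifacts.LossType.CrossEntropyLoss,  # assumed loss for a classifier
        optimizer=artifacts.OptimType.AdamW,       # assumed optimizer
        artifact_directory=artifact_dir,
    )
```

For example, `build_training_artifacts("simple_net.onnx", "artifacts")` would write the training, evaluation, and optimizer models plus a checkpoint into `artifacts`. Keeping the name-splitting logic separate makes it easy to inspect which initializers will be trained before running the generator.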

API Signature

torch.onnx.export(
    model,               # torch.nn.Module: the model to export
    dummy_input,         # Union[torch.Tensor, Tuple]: example input(s)
    path,                # str: output file path for the .onnx file
    export_params=True,  # bool: include trained parameters in the model
    do_constant_folding=False,  # bool: disable constant folding for training
    training=torch.onnx.TrainingMode.TRAINING,  # TrainingMode: preserve training behavior
)
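When `dummy_input` is a tuple, the exporter treats it as the positional arguments of `forward()`, so `model(*dummy_input)` must be a valid call. A cheap pre-export check is to make that call in eager mode first. A shape or dtype mismatch then fails with an ordinary PyTorch error instead of inside the exporter. `PairNet` below is a hypothetical two-input model used only for illustration.

```python
import torch
import torch.nn as nn


class PairNet(nn.Module):
    """Hypothetical two-input model used only to illustrate tuple dummy inputs."""

    def __init__(self):
        super().__init__()
        self.proj_a = nn.Linear(16, 8)
        self.proj_b = nn.Linear(4, 8)

    def forward(self, a, b):
        return torch.relu(self.proj_a(a) + self.proj_b(b))


model = PairNet()
model.train()

# A tuple of example tensors; each element becomes one positional argument of forward().
dummy_input = (torch.randn(2, 16), torch.randn(2, 4))

# Eager-mode dry run: raises a normal PyTorch error here if shapes or dtypes are wrong.
with torch.no_grad():
    out = model(*dummy_input)
print(out.shape)  # torch.Size([2, 8])
```

The same tuple can then be passed unchanged as the second argument of torch.onnx.export().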

Key Parameters

model (torch.nn.Module): The PyTorch model to export. Put it in training mode with model.train() before export.
dummy_input (torch.Tensor or tuple): A sample input tensor, or a tuple of tensors, used to trace the forward computation graph. A tuple is passed to forward() as positional arguments. Shapes and dtypes must match what the model expects.
path (str): File path where the exported ONNX model is saved, typically ending in .onnx.
export_params (bool): Must be True. Stores all model parameters as initializers in the ONNX graph so that they are available for training.
do_constant_folding (bool): Must be False. Constant folding precomputes subexpressions whose inputs are all constants, and during export the parameters count as constants. Folding can therefore rewrite or absorb parameter initializers. Disabling it keeps each parameter as its own initializer, which artifact generation can then mark as trainable.
training (torch.onnx.TrainingMode): Must be TrainingMode.TRAINING. Preserves training-specific behavior: dropout stays active, and batch normalization uses batch statistics.
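The training-mode behavior that TrainingMode.TRAINING is meant to keep in the exported graph is easy to observe in eager PyTorch. The sketch below uses a bare nn.Dropout layer. This is an illustrative choice, since the SimpleNet usage example contains no dropout.

- In training mode, each call draws a fresh mask and rescales the surviving elements by 1 / (1 - p).
- In evaluation mode, the layer is the identity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: a new random mask per call; kept elements are scaled by 1 / (1 - 0.5) = 2.
layer.train()
train_a, train_b = layer(x), layer(x)

# Evaluation mode: dropout is disabled and the input passes through unchanged.
layer.eval()
eval_a, eval_b = layer(x), layer(x)

train_differs = not torch.equal(train_a, train_b)
eval_identical = torch.equal(eval_a, eval_b) and torch.equal(eval_a, x)
print(train_differs, eval_identical)
```

An export made in evaluation mode would capture only the identity behavior. That is why the training flag matters for a graph that will later be trained on device.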

I/O Contract

Input (torch.nn.Module): A valid PyTorch model in training mode.
Input (torch.Tensor or tuple of tensors): Dummy input matching the model's expected input shapes and dtypes.
Output (.onnx file): Forward-only ONNX model file containing the computation graph and the parameter initializers.
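One way to check the output side of this contract is to list the graph's data inputs, outputs, and initializers after export. In this minimal sketch, `summarize_exported_graph` is a hypothetical helper. It reads only the `graph.input`, `graph.output`, and `graph.initializer` fields of an `onnx.ModelProto`, so it also works on any object with the same layout. Initializer names are filtered out of the input list because some export settings (for example keep_initializers_as_inputs=True) list parameters as graph inputs too.

```python
from typing import Dict, List


def summarize_exported_graph(model_proto) -> Dict[str, List[str]]:
    """Return the data inputs, outputs, and parameter initializers of an ONNX graph."""
    graph = model_proto.graph
    initializers = [init.name for init in graph.initializer]
    init_set = set(initializers)
    # Parameters may also appear in graph.input; keep only true data inputs.
    inputs = [inp.name for inp in graph.input if inp.name not in init_set]
    outputs = [out.name for out in graph.output]
    return {"inputs": inputs, "outputs": outputs, "initializers": initializers}
```

With the `onnx` package installed, `summarize_exported_graph(onnx.load("simple_net.onnx"))` is expected to report one data input, one output, and four initializers for SimpleNet (two weights and two biases). This assumes the exporter keeps each parameter as its own initializer. An empty initializer list usually means export_params was not True.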

Code Reference

From docs/python/on_device_training/training_artifacts.rst:

If using PyTorch to export the model, please use the following export arguments
so training artifact generation can be successful:

- ``export_params``: ``True``
- ``do_constant_folding``: ``False``
- ``training``: ``torch.onnx.TrainingMode.TRAINING``
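These three requirements can be checked before the exporter is called. The helper below, `check_training_export_kwargs`, is hypothetical and not part of PyTorch or ONNX Runtime. It works as follows:

- For any omitted key, it assumes the long-standing torch.onnx.export defaults: export_params=True, do_constant_folding=True, training=TrainingMode.EVAL. Omitting do_constant_folding or training is therefore reported as a problem.
- It compares the training value by its `.name`, so no torch import is needed.

```python
from typing import Any, Dict, List

# Values required for training artifact generation.
REQUIRED = {"export_params": True, "do_constant_folding": False, "training": "TRAINING"}
# Assumed torch.onnx.export defaults, applied when a key is omitted.
ASSUMED_DEFAULTS = {"export_params": True, "do_constant_folding": True, "training": "EVAL"}


def check_training_export_kwargs(kwargs: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the kwargs match the requirements."""
    problems = []
    for key, required in REQUIRED.items():
        value = kwargs.get(key, ASSUMED_DEFAULTS[key])
        if key == "training":
            # Enum members expose .name; plain strings are compared as-is.
            value = getattr(value, "name", value)
        if value != required:
            problems.append(f"{key} should be {required!r}, got {value!r}")
    return problems
```

For example, the keyword arguments used in the usage example (export_params=True, do_constant_folding=False, training=torch.onnx.TrainingMode.TRAINING) should produce an empty list. This assumes the TrainingMode member's `.name` is "TRAINING".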

Usage Example

import torch
import torch.nn as nn

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create and set model to training mode
model = SimpleNet()
model.train()

# Create dummy input matching the expected input shape
dummy_input = torch.randn(1, 784)

# Export to ONNX for on-device training
torch.onnx.export(
    model,
    dummy_input,
    "simple_net.onnx",
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)
# Output: simple_net.onnx (forward-only ONNX model)

Implements

Principle:Microsoft_Onnxruntime_PyTorch_Model_Export
