Principle:Microsoft Onnxruntime PyTorch Model Preparation
Overview
Preparation of a standard PyTorch nn.Module for accelerated training with ONNX Runtime.
Metadata
| Field | Value |
|---|---|
| Principle Name | PyTorch_Model_Preparation |
| Category | External Tool Doc |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L32-33 |
| Last Updated | 2026-02-10 |
Description
Before using ORTModule, a standard PyTorch model must be constructed following PyTorch conventions. The model should be a valid `torch.nn.Module` that can be traced for ONNX export. This serves as the input to ORTModule wrapping.
The model preparation step is deliberately simple -- it uses standard PyTorch model construction patterns. There are no ONNX Runtime-specific requirements at this stage. However, the following PyTorch conventions should be followed to ensure compatibility with ORTModule's ONNX export tracing:
- The model must be a subclass of `torch.nn.Module`.
- The model's `forward()` method must accept tensor inputs and produce tensor outputs.
- Dynamic control flow (e.g., Python if/else based on tensor values) may not be fully captured during ONNX tracing. Static control flow is preferred.
- Custom autograd functions should follow PyTorch's `torch.autograd.Function` conventions, with `forward` and `backward` methods.
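The control-flow caveat above can be illustrated with a short sketch. This is not ORTModule code; the function names are illustrative. A Python `if` on a tensor value is evaluated once during tracing, so only the branch taken for the example input survives in the exported graph, whereas a tensor-level operation such as `torch.where` keeps both branches traceable:

```python
import torch

def dynamic_branch(x):
    # Python-level branch on a tensor value: during ONNX tracing, only
    # the branch taken for the example input is recorded, so the exported
    # graph is wrong for inputs that would take the other branch.
    if x.sum() > 0:
        return x * 2
    return x - 1

def static_branch(x):
    # torch.where evaluates both branches as tensor ops, so the traced
    # graph preserves the conditional for all inputs.
    return torch.where(x.sum() > 0, x * 2, x - 1)
```

In eager mode both functions agree; the difference only appears in the traced/exported graph.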
The prepared model is passed directly to the ORTModule constructor, which handles all ONNX-related conversion transparently.
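The custom autograd convention noted in the list above can be sketched as follows. `ClampedReLU` is an illustrative name, not an ORTModule or PyTorch API; it simply shows the `torch.autograd.Function` shape with paired `forward` and `backward` static methods:

```python
import torch

class ClampedReLU(torch.autograd.Function):
    """Illustrative custom autograd function (a plain ReLU)."""

    @staticmethod
    def forward(ctx, input):
        # Save the input so backward() can compute the gradient mask.
        ctx.save_for_backward(input)
        return input.clamp(min=0.0)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Pass the gradient through only where the input was non-negative.
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0.0
        return grad_input
```

It is invoked via `ClampedReLU.apply(x)`, as with any `torch.autograd.Function` subclass.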
Theoretical Basis
PyTorch's `nn.Module` is the standard abstraction for neural network layers and models. It provides:
- Parameter Management -- Automatic registration of learnable parameters via `nn.Parameter` and submodule attributes.
- Forward Computation -- The `forward()` method defines the model's computation graph.
- State Management -- `state_dict()` and `load_state_dict()` enable parameter serialization.
- Device Placement -- `.to(device)` moves all parameters and buffers to the specified device.
ORTModule relies on these conventions to discover parameters, trace the forward computation, and manage device placement during its acceleration process.
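The conventions listed above can be observed on a minimal model. This is a hedged sketch; `TinyNet` is an illustrative name with arbitrary layer sizes, not part of ORTModule:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)                  # registered as a submodule
        self.scale = nn.Parameter(torch.ones(1))   # registered directly

    def forward(self, x):
        return self.fc(x) * self.scale

model = TinyNet()

# Parameter Management: every learnable parameter is discoverable by name.
names = [n for n, _ in model.named_parameters()]
# contains 'scale', 'fc.weight', and 'fc.bias'

# State Management: state_dict() serializes all parameters and buffers.
state = model.state_dict()

# Device Placement: .to(device) moves everything at once.
model = model.to("cpu")
```

ORTModule uses exactly these hooks, so any model that behaves correctly here is a valid candidate for wrapping.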
Usage
```python
import torch
import torch.nn as nn

# Standard PyTorch model construction
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        return self.fc(x)

# Construct the model (standard PyTorch)
model = TransformerModel(vocab_size=30000, d_model=512, nhead=8, num_layers=6)
# or: model = build_model()

# Model is now ready for ORTModule wrapping
```
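Before handing the model to ORTModule, a plain forward pass is a cheap way to confirm the tensor-in/tensor-out requirement. The standalone sketch below repeats the `TransformerModel` class from the Usage section (with smaller, illustrative hyperparameters) so it runs on its own:

```python
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        x = self.transformer(x)
        return self.fc(x)

# Small illustrative sizes for a quick check.
model = TransformerModel(vocab_size=100, d_model=16, nhead=4, num_layers=1)
model.eval()

tokens = torch.randint(0, 100, (8, 2))   # (seq_len, batch) token ids
with torch.no_grad():
    logits = model(tokens)

# Tensor in, tensor out: per-token vocabulary logits.
assert logits.shape == (8, 2, 100)
```

A model that passes this check satisfies the basic tensor I/O convention ORTModule's export tracing relies on.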
Implemented By
Implementation:Microsoft_Onnxruntime_Torch_NN_Module_Construction
Related Pages
- ORT Accelerated Training -- The next step: wrapping the model with ORTModule
- PyTorch Model Export -- Alternative path for on-device training