Implementation:Microsoft Onnxruntime Torch NN Module Construction
Overview
Standard PyTorch torch.nn.Module construction, producing a model instance ready for wrapping with ORTModule for accelerated training.
Metadata
| Field | Value |
|---|---|
| Implementation Name | Torch_NN_Module_Construction |
| Type | External Tool Doc |
| Language | Python |
| API | Standard `torch.nn.Module` construction: `model = MyModel()` or `model = build_model()` |
| Import | `import torch` |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L32-33 |
| Last Updated | 2026-02-10 |
Description
This implementation covers the standard PyTorch model construction step that precedes ORTModule wrapping. The model is constructed using standard PyTorch conventions -- there are no ONNX Runtime-specific modifications required at this stage.
The model must satisfy the following requirements for compatibility with ORTModule:
- Must be a subclass of `torch.nn.Module`
- The `forward()` method must accept and return tensors (or nested structures of tensors)
- All learnable parameters must be registered via `nn.Parameter` or contained in submodules
- Dynamic control flow based on tensor values should be avoided where possible for optimal ONNX tracing
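The requirements above can be spot-checked before wrapping. The helper below is an illustrative sketch, not part of the ORTModule API; it covers only the checks that are cheap to verify programmatically (module subclassing and registered trainable parameters).

```python
import torch.nn as nn


def check_ortmodule_ready(model) -> bool:
    """Illustrative pre-flight check for the requirements above (not an ORTModule API)."""
    # Requirement: must be a subclass of torch.nn.Module
    if not isinstance(model, nn.Module):
        return False
    # Requirement: learnable parameters must be registered
    # (via nn.Parameter or contained in submodules)
    if not any(p.requires_grad for p in model.parameters()):
        return False
    # Requirement: a callable forward() must exist
    return callable(getattr(model, "forward", None))


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
print(check_ortmodule_ready(model))  # True for a standard PyTorch model
```

Tensor-valued control flow in `forward()` cannot be detected this way; it only surfaces later, during ONNX tracing.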
API Signature
```python
import torch
import torch.nn as nn

# Option 1: Direct instantiation
model = MyModel(config)

# Option 2: Factory function
model = build_model()
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| Model class | `torch.nn.Module` subclass | Any standard PyTorch model class |
| Constructor args | Varies | Model-specific configuration (hidden size, number of layers, etc.) |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | Model configuration | Architecture hyperparameters passed to the model constructor |
| Output | `torch.nn.Module` instance | Constructed model with initialized parameters, ready for ORTModule wrapping |
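To illustrate the output side of this contract: a freshly constructed module already exposes initialized, trainable parameters. This is a sketch using a toy single-layer model; ORTModule itself is not involved at this stage.

```python
import torch.nn as nn

# Input: architecture hyperparameters passed to the constructor
model = nn.Linear(in_features=784, out_features=10)

# Output: an nn.Module instance with initialized, trainable parameters
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # 784 * 10 weights + 10 biases = 7850
print(all(p.requires_grad for p in model.parameters()))  # True
```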
Code Reference
From docs/ORTModule_Training_Guidelines.md:
```diff
  model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```
The first line (`model = build_model()`) represents the standard model construction step covered by this implementation.
Usage Example
```python
import torch
import torch.nn as nn

# Example 1: Simple feedforward model
class FeedForward(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

model = FeedForward(784, 256, 10)

# Example 2: Using a pretrained model from a library
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Example 3: Using a factory function
def build_model():
    config = ModelConfig(hidden_size=768, num_layers=12, num_heads=12)
    return TransformerModel(config)

model = build_model()
# Model is now ready for ORTModule wrapping
```
Implements
Principle:Microsoft_Onnxruntime_PyTorch_Model_Preparation
Related Pages
- ORTModule Wrap -- The next step: wrapping the constructed model with ORTModule
- Torch Onnx Export -- Alternative path for on-device training export