Implementation:Microsoft Onnxruntime ORTModule Wrap
Overview
Wraps a PyTorch nn.Module with ORTModule to transparently route forward and backward passes through the ONNX Runtime execution engine for accelerated training.
Metadata
| Field | Value |
|---|---|
| Implementation Name | ORTModule_Wrap |
| Type | API Doc |
| Language | Python |
| API | onnxruntime.training.ortmodule.ORTModule(model: torch.nn.Module, debug_options: DebugOptions = None) -> ORTModule |
| Import | from onnxruntime.training.ortmodule import ORTModule |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L35-49 |
| Last Updated | 2026-02-10 |
Description
The ORTModule constructor accepts a standard PyTorch nn.Module and returns a wrapped module that behaves identically from the PyTorch API perspective but executes computations through ONNX Runtime's optimized engine.
On the first forward call, ORTModule:
- Traces the wrapped model to produce an ONNX graph
- Applies ORT graph optimizations (operator fusion, memory planning, etc.)
- Caches the optimized graph for subsequent calls
- Executes the forward pass through ORT's execution engine
Subsequent forward/backward calls reuse the cached graph unless the computation graph changes (e.g., due to dynamic shapes or control flow changes).
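This export-once, reuse-thereafter behavior can be illustrated with a minimal pure-Python sketch. This is a simplification, not ORTModule's actual implementation: the real module exports an ONNX graph and dispatches to ORT, while here "export" is just counted so the caching logic is visible.

```python
class LazyCompilingWrapper:
    """Toy illustration of the export-once/cache pattern ORTModule uses.

    The wrapper "compiles" on the first call and re-compiles only when
    the input signature changes, mirroring re-export on graph changes.
    """

    def __init__(self, fn):
        self.fn = fn
        self.compiled_for = None  # signature the cached "graph" was built for
        self.export_count = 0     # how many times we (re-)exported

    def __call__(self, x):
        signature = (type(x).__name__, len(x))  # stand-in for input shape
        if signature != self.compiled_for:
            # First call, or the computation graph changed
            # (e.g. a new input shape): re-export and cache.
            self.export_count += 1
            self.compiled_for = signature
        return self.fn(x)  # cached path: no re-export needed


wrapped = LazyCompilingWrapper(sum)
wrapped([1, 2, 3])   # first call triggers an "export"
wrapped([4, 5, 6])   # same signature: cached graph reused
wrapped([1, 2])      # new "shape": triggers a re-export
print(wrapped.export_count)  # 2
```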
The optional DebugOptions parameter enables developer-facing features:
- save_onnx=True -- Saves the exported ONNX models to disk for inspection
- log_level=LogLevel.VERBOSE -- Enables detailed logging of ORT operations
- onnx_prefix="model_name" -- Sets a prefix for saved ONNX file names
API Signature
```python
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

# Basic usage
model = ORTModule(model)

# With debug options
model = ORTModule(
    model,
    DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="my_model"),
)
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model | torch.nn.Module | The PyTorch model to accelerate. Must be wrapped before other wrappers (DeepSpeed, DDP). |
| debug_options | DebugOptions (optional) | Configuration for debugging: ONNX model saving, log level, and model name prefix. |
DebugOptions Fields
| Field | Type | Description |
|---|---|---|
| save_onnx | bool | Whether to save exported ONNX models to disk |
| log_level | LogLevel | Logging verbosity: FATAL, ERROR, WARNING (default), INFO, VERBOSE |
| onnx_prefix | str | Prefix for saved ONNX file names |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | torch.nn.Module | Standard PyTorch model with registered parameters |
| Output | ORTModule | Wrapped module that intercepts forward/backward and executes through ORT |
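The "behaves identically from the PyTorch API perspective" contract boils down to attribute delegation: the wrapper intercepts the calls it cares about (forward/backward) and forwards everything else to the inner module. A minimal pure-Python sketch of that pattern (an illustration, not ORTModule's actual code):

```python
class TransparentWrapper:
    """Sketch of a wrapper that intercepts __call__ but delegates
    everything else (attributes, methods) to the inner object."""

    def __init__(self, inner):
        # Set via object.__setattr__ so the __getattr__ delegation
        # below cannot recurse while looking up _inner.
        object.__setattr__(self, "_inner", inner)

    def __call__(self, *args, **kwargs):
        # Interception point: ORTModule would route this through
        # its ORT execution engine instead of calling directly.
        return self._inner(*args, **kwargs)

    def __getattr__(self, name):
        # Anything not defined on the wrapper falls through,
        # so e.g. .parameters() still reaches the original module.
        return getattr(self._inner, name)


class TinyModule:
    """Stand-in for a torch.nn.Module with one 'parameter'."""

    def __init__(self):
        self.weight = 3

    def parameters(self):
        return [self.weight]

    def __call__(self, x):
        return self.weight * x


wrapped = TransparentWrapper(TinyModule())
print(wrapped(2))            # 6: the forward call goes through the wrapper
print(wrapped.parameters())  # [3]: attribute access is delegated
```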
Code Reference
From docs/ORTModule_Training_Guidelines.md:
```diff
  model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
```
With debug options:
```diff
  model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+ model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))
```
Usage Example
```python
import torch
from onnxruntime.training.ortmodule import ORTModule

# Construct the model
model = build_model()

# Wrap with ORTModule (do this BEFORE other wrappers)
model = ORTModule(model)

# Now wrap with distributed training (if needed)
model = torch.nn.parallel.DistributedDataParallel(model)

# Training loop works exactly as before
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
    loss = model(batch)  # assumes the model returns its loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
Important: ORTModule is not compatible with torch.nn.DataParallel. Use torch.nn.parallel.DistributedDataParallel instead.
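Because DDP (like several other wrappers) exposes the wrapped model as a .module attribute, the wrapping order can be inspected mechanically. A hedged sketch with stand-in classes (the class names below are illustrative, not the real API):

```python
def wrapper_chain(module):
    """Walk inward through wrappers that expose a .module attribute
    (as DistributedDataParallel does) and report class names
    outermost-first."""
    chain = []
    while module is not None:
        chain.append(type(module).__name__)
        module = getattr(module, "module", None)
    return chain


# Stand-in classes mimicking the wrapper shape (illustrative only).
class FakeORTModule:
    def __init__(self, module):
        self.module = module


class FakeDDP:
    def __init__(self, module):
        self.module = module


class Net:
    pass


# Correct order: wrap with ORTModule first, then DDP around it,
# so the distributed wrapper is outermost.
model = FakeDDP(FakeORTModule(Net()))
print(wrapper_chain(model))  # ['FakeDDP', 'FakeORTModule', 'Net']
```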
Implements
Principle:Microsoft_Onnxruntime_ORT_Accelerated_Training
Related Pages
- Torch NN Module Construction -- Constructs the model that is wrapped
- FusedAdam FP16Optimizer -- Optimized optimizer used with ORTModule
- ORTModule Training Execution -- The training loop using the wrapped model
- Memory Opt Env Config -- Memory optimization during ORTModule training
- Environment:Microsoft_Onnxruntime_CUDA_GPU_Environment
- Heuristic:Microsoft_Onnxruntime_ORTModule_Wrapping_Order
- Heuristic:Microsoft_Onnxruntime_Flash_Attention_Optimization