Principle:Microsoft Onnxruntime ORT Accelerated Training
Overview
Transparent acceleration of PyTorch model training by routing computation through the ONNX Runtime execution engine.
Metadata
| Field | Value |
|---|---|
| Principle Name | ORT_Accelerated_Training |
| Category | API Doc |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L35-49 |
| Last Updated | 2026-02-10 |
Description
ORTModule wraps a PyTorch nn.Module to intercept forward and backward passes, converting them to optimized ONNX graphs executed by ONNX Runtime. This is transparent to the training loop -- the wrapped model behaves like a standard PyTorch module but benefits from ORT's graph optimizations, operator fusion, and memory planning.
The key design principle is transparency: after wrapping with ORTModule, the model continues to work with standard PyTorch training constructs including:
- Standard loss functions (nn.CrossEntropyLoss, etc.)
- Standard optimizers (torch.optim.AdamW, etc.) or ORT's FusedAdam
- Distributed training wrappers (DistributedDataParallel, DeepSpeed)
- Learning rate schedulers
ORTModule provides DebugOptions for developers, enabling ONNX model export for inspection, verbose logging, and custom model name prefixes.
Important ordering: It is strongly recommended to wrap the model with ORTModule before other module wrappers (e.g., DeepSpeed, DistributedDataParallel). ORTModule is not compatible with torch.nn.DataParallel.
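The reason for this ordering can be illustrated with a minimal, library-free sketch. The classes below are hypothetical stand-ins, not the real ORTModule or DistributedDataParallel: the point is that the acceleration wrapper must receive the raw model so it can trace the original computation graph, while the outer distributed wrapper only needs a callable to coordinate.

```python
class Model:
    """Stand-in for a plain PyTorch nn.Module."""
    def __call__(self, x):
        return x * 2

class OrtStyleWrapper:
    """Hypothetical stand-in for ORTModule: it must see the raw model to trace it."""
    def __init__(self, module):
        if not isinstance(module, Model):
            raise TypeError("wrap the raw model before any other wrapper")
        self.module = module

    def __call__(self, x):
        # Tracing / export would happen here on the first call.
        return self.module(x)

class DistributedWrapper:
    """Hypothetical stand-in for DistributedDataParallel: wraps any callable."""
    def __init__(self, module):
        self.module = module

    def __call__(self, x):
        return self.module(x)

# Correct order: the ORT-style wrapper goes on first, the distributed wrapper outside.
model = DistributedWrapper(OrtStyleWrapper(Model()))
print(model(3))  # 6
```

Reversing the order hands the ORT-style wrapper an already-wrapped object whose internals it cannot trace, which is why ORTModule should be applied before DeepSpeed or DistributedDataParallel.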
Theoretical Basis
PyTorch models are dynamically traced to ONNX at the first forward call. The ONNX graph is then optimized and executed by ORT's execution engine. Subsequent calls reuse the cached graph unless the model structure changes.
The acceleration mechanism works through the following stages:
- Lazy Export -- On the first forward pass, ORTModule traces the PyTorch model to produce an ONNX graph. This uses PyTorch's ONNX export infrastructure but is triggered automatically.
- Graph Optimization -- ORT applies a series of graph transformations including operator fusion (e.g., combining MatMul + Add into a single Gemm), constant folding, and memory-efficient execution planning.
- Kernel Selection -- ORT's execution engine selects optimized kernel implementations for each operator based on the execution provider (CPU, CUDA, etc.), often using hand-tuned kernels that outperform PyTorch's default implementations.
- Graph Caching -- The optimized graph is cached and reused across training iterations. Re-export is triggered only when the model's computation graph changes (e.g., due to dynamic input shapes or control flow changes).
- Gradient Computation -- ORT handles both forward and backward passes through its own automatic differentiation, avoiding the overhead of PyTorch's autograd for the portions of the graph executed by ORT.
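The lazy-export and graph-caching stages above can be sketched in plain Python. This is an illustrative model of the control flow only, not ORT's real internals: the "export" is a counter bump, and the input's length stands in for its shape signature.

```python
class CachingWrapper:
    """Illustrative sketch of lazy export plus graph caching.

    The wrapped function is "exported" on the first call for a given input
    shape; later calls with the same shape reuse the cached artifact, while
    a new shape triggers a re-export, mirroring ORTModule's behavior.
    """
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}        # shape signature -> "optimized graph" (here: fn itself)
        self.export_count = 0

    def __call__(self, xs):
        shape = len(xs)        # proxy for the input's shape signature
        if shape not in self.cache:
            self.export_count += 1       # lazy export happens on a cache miss
            self.cache[shape] = self.fn  # stand-in for the optimized graph
        return [self.cache[shape](x) for x in xs]

model = CachingWrapper(lambda x: x + 1)
model([1, 2, 3])   # first call: export
model([4, 5, 6])   # same shape: cached graph is reused
model([7, 8])      # new shape: re-export
print(model.export_count)  # 2
```

This is why stable input shapes matter for ORTModule throughput: each new shape signature can trigger a fresh export and optimization pass.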
Usage
```python
from onnxruntime.training.ortmodule import ORTModule

model = build_model()
model = ORTModule(model)
```

```python
# Optional: attach debug options for developers
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="my_model"))
```
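The transparency that lets the wrapped model drop into an existing training loop rests on attribute delegation. A library-free sketch of the pattern (hypothetical names, not ORTModule's actual implementation) shows how unknown attribute lookups fall through to the wrapped module, so optimizers and schedulers still see the original model's API:

```python
class TransparentWrapper:
    """Hypothetical sketch: attributes missing on the wrapper delegate to the
    wrapped module, so user code sees the original model's interface."""
    def __init__(self, module):
        self._module = module

    def __call__(self, *args, **kwargs):
        # Forward pass is intercepted here (ORTModule would run the ONNX graph).
        return self._module(*args, **kwargs)

    def __getattr__(self, name):
        # Invoked only when the attribute is not found on the wrapper itself.
        return getattr(self._module, name)

class TinyModel:
    """Stand-in for a user-defined model with attributes and parameters."""
    def __init__(self):
        self.hidden_size = 16

    def parameters(self):
        return [1.0, 2.0]

    def __call__(self, x):
        return x * self.hidden_size

model = TransparentWrapper(TinyModel())
print(model(2))            # 32: the forward pass goes through the wrapper
print(model.hidden_size)   # 16: attribute access delegates to the inner model
print(model.parameters())  # [1.0, 2.0]: an optimizer can still find parameters
```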
Implemented By
Implementation:Microsoft_Onnxruntime_ORTModule_Wrap
Related Pages
- PyTorch Model Preparation -- The preceding step that constructs the model
- Fused Optimizer Configuration -- Complementary optimization for parameter updates
- Memory Optimization -- Reducing GPU memory consumption during ORTModule training
- ORTModule Training Loop -- Using the wrapped model in a training loop
- Heuristic:Microsoft_Onnxruntime_ORTModule_Wrapping_Order
- Heuristic:Microsoft_Onnxruntime_Flash_Attention_Optimization