Workflow:Microsoft Onnxruntime ORTModule Training
| Knowledge Sources | |
|---|---|
| Domains | Training_Acceleration, PyTorch_Integration, Mixed_Precision |
| Last Updated | 2026-02-10 04:30 GMT |
Overview
End-to-end process for accelerating PyTorch model training by wrapping a torch.nn.Module with ONNX Runtime's ORTModule for optimized graph execution.
Description
This workflow demonstrates how to accelerate existing PyTorch training scripts with minimal code changes using ORTModule. The PyTorch model is wrapped with ORTModule, which transparently exports the model to ONNX, applies graph optimizations (operator fusion, memory-efficient gradient computation), and executes the optimized graph on the configured hardware. The workflow supports advanced features including FP16 mixed precision training, the FusedAdam optimizer for faster parameter updates, memory optimization through activation recomputation, and gradient checkpointing for large models.
Usage
Execute this workflow when you have an existing PyTorch training script and want to improve training throughput without rewriting the training logic. ORTModule is particularly effective for transformer-based models where operator fusion (e.g., fused bias+softmax+dropout) provides significant speedups. It is suitable when you need to keep the PyTorch training ecosystem (data loaders, loss functions, custom training loops) while leveraging ONNX Runtime's graph-level optimizations.
Execution Steps
Step 1: Prepare PyTorch Model
Start with a standard PyTorch model (torch.nn.Module) and training script. Ensure the model's forward pass uses operations that are exportable to ONNX. Custom autograd functions require special handling through the PythonOp bridge if they need to be included in the ONNX graph.
Key considerations:
- Most standard PyTorch operations export to ONNX automatically
- Custom autograd functions need explicit registration for the PythonOp bridge
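- Data-dependent control flow in the forward pass (Python if statements or loops that depend on tensor values) may require special attention, because the export records only the path taken for the inputs seen at export time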
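The following minimal sketch shows the kind of custom autograd function these considerations refer to: a hypothetical `ClampedReLU` with a hand-written backward pass, used inside a small hypothetical `Block` module. It runs in plain PyTorch. Under ORTModule, such a function is expected to be exported as a PythonOp node that calls back into these Python `forward`/`backward` methods; whether extra registration or an environment flag is needed depends on your ONNX Runtime version, so treat that part as an assumption.

```python
import torch


class ClampedReLU(torch.autograd.Function):
    """Hypothetical example: ReLU capped at `cap`, with an explicit backward pass."""

    @staticmethod
    def forward(ctx, x, cap):
        # Keep the input for the backward pass; `cap` is a plain Python float.
        ctx.save_for_backward(x)
        ctx.cap = cap
        return x.clamp(min=0.0, max=cap)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Gradient flows only where the output was not clamped.
        mask = (x > 0) & (x < ctx.cap)
        # One return value per forward input; `cap` needs no gradient.
        return grad_output * mask.to(grad_output.dtype), None


class Block(torch.nn.Module):
    """Hypothetical module whose forward pass calls the custom function."""

    def __init__(self, dim: int, cap: float = 6.0):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)
        self.cap = cap

    def forward(self, x):
        return ClampedReLU.apply(self.linear(x), self.cap)
```

Because the function is called through `ClampedReLU.apply`, the same module code can be wrapped with ORTModule later without changes.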
Step 2: Wrap Model with ORTModule
Replace the PyTorch model with an ORTModule wrapper. This is a one-line change that wraps the existing torch.nn.Module. ORTModule intercepts forward calls, exports the model to ONNX on first invocation, applies graph optimizations, and routes execution through ONNX Runtime's optimized engine.
Key considerations:
- ORTModule is a drop-in replacement for torch.nn.Module
- The ONNX export and optimization happen lazily on the first forward call
- DebugOptions can be configured to save intermediate ONNX models and enable logging
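A minimal sketch of the one-line wrap, assuming the onnxruntime-training package exposes `ORTModule`, `DebugOptions`, and `LogLevel` from `onnxruntime.training.ortmodule`, and that `DebugOptions` accepts the `log_level`, `save_onnx`, and `onnx_prefix` keywords (verify against your installed version). The helper name `wrap_for_training` is hypothetical; it returns the unwrapped model when the package is not installed, so the same script still runs in plain PyTorch.

```python
from typing import Optional

import torch


def wrap_for_training(model: torch.nn.Module,
                      onnx_prefix: Optional[str] = None) -> torch.nn.Module:
    """Hypothetical helper: wrap `model` with ORTModule when it is available."""
    try:
        from onnxruntime.training.ortmodule import DebugOptions, LogLevel, ORTModule
    except ImportError:
        # onnxruntime-training is not installed: keep plain PyTorch execution.
        return model

    if onnx_prefix:
        # Save the exported ONNX models under this file-name prefix and log at INFO level.
        debug = DebugOptions(log_level=LogLevel.INFO, save_onnx=True, onnx_prefix=onnx_prefix)
    else:
        # No model saving; only warnings and errors are logged.
        debug = DebugOptions(log_level=LogLevel.WARNING)

    # Export and graph optimization are deferred until the first forward call.
    return ORTModule(model, debug_options=debug)
```

Usage: `model = wrap_for_training(model, onnx_prefix="my_model")` before creating the optimizer; the returned object is still a `torch.nn.Module`, so the rest of the script is unchanged.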
Step 3: Configure a Fused Optimizer
Optionally replace the standard PyTorch optimizer with FusedAdam from ONNX Runtime. FusedAdam fuses the element-wise Adam update into a single GPU kernel and applies it to many parameter tensors per launch, which avoids the separate kernel launches and memory round-trips of a per-tensor update loop.
Key considerations:
- FusedAdam provides the same semantics as torch.optim.AdamW
- FP16_Optimizer wrapper adds loss scaling for mixed precision stability
- Standard PyTorch optimizers still work with ORTModule
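A minimal sketch of choosing the optimizer, assuming `FusedAdam` is importable from `onnxruntime.training.optim` and accepts `lr` and `weight_decay` keywords like `torch.optim.AdamW`. The helper `make_optimizer` is hypothetical. FusedAdam runs CUDA kernels, so the sketch uses it only when every trainable parameter is already on a GPU and otherwise falls back to `torch.optim.AdamW`. The FP16_Optimizer wrapper is omitted here.

```python
import torch


def make_optimizer(model: torch.nn.Module, lr: float = 1e-3,
                   weight_decay: float = 0.01) -> torch.optim.Optimizer:
    """Hypothetical helper: prefer FusedAdam for CUDA parameters, else AdamW."""
    params = [p for p in model.parameters() if p.requires_grad]
    if params and all(p.is_cuda for p in params):
        try:
            # Assumed import path and keyword arguments for ONNX Runtime's FusedAdam.
            from onnxruntime.training.optim import FusedAdam
            return FusedAdam(params, lr=lr, weight_decay=weight_decay)
        except ImportError:
            pass  # onnxruntime-training or its CUDA extensions are unavailable.
    # Standard PyTorch optimizers also work with ORTModule.
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
```

Create the optimizer after wrapping the model, so it receives the parameters of the module that is actually trained.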
Step 4: Enable Memory Optimization
Configure memory optimization strategies via environment variables to train larger models or use larger batch sizes. Options include activation recomputation (recomputing activations during the backward pass instead of storing them from the forward pass), transformer layer-wise recomputation, and user-specified recomputation patterns.
Key considerations:
- ORTMODULE_MEMORY_OPT_LEVEL controls optimization aggressiveness
- Level 1 applies transformer layer-wise recomputation automatically
- Level 2 allows fine-grained user-specified recomputation via config file
- Memory savings trade off against additional computation time
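A minimal sketch of setting these variables from Python, assuming they are read when ORTModule initializes, so they must be set before the model is wrapped. `ORTMODULE_MEMORY_OPT_LEVEL` is named above; the config-file variable `ORTMODULE_MEMORY_OPT_CONFIG` and the helper `configure_memory_optimization` are assumptions for illustration, so check the memory optimizer documentation for your ONNX Runtime version.

```python
import os
from typing import Dict, Optional


def configure_memory_optimization(level: int,
                                  config_path: Optional[str] = None) -> Dict[str, str]:
    """Hypothetical helper: export ORTModule memory-optimization settings."""
    if not isinstance(level, int) or level < 0:
        raise ValueError("memory optimization level must be a non-negative integer")
    settings = {"ORTMODULE_MEMORY_OPT_LEVEL": str(level)}
    if config_path is not None:
        # Assumed variable name for the user-specified recomputation config file.
        settings["ORTMODULE_MEMORY_OPT_CONFIG"] = config_path
    # Must run before ORTModule(model) is constructed (assumption, see lead-in).
    os.environ.update(settings)
    return settings
```

The equivalent shell form is `export ORTMODULE_MEMORY_OPT_LEVEL=1` before launching the training script.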
Step 5: Execute Training Loop
Run the standard PyTorch training loop with the ORTModule-wrapped model. The training loop structure remains unchanged: forward pass, loss computation, backward pass, optimizer step. ORTModule transparently handles ONNX export, graph optimization, and accelerated execution behind the scenes.
Key considerations:
- The training loop code is identical to standard PyTorch
- First iteration is slower due to ONNX export and optimization
- Subsequent iterations run through the optimized ONNX graph
- Environment variables control logging, debugging, and fallback behavior
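A minimal sketch of the unchanged loop, assuming a toy regression task with `torch.nn.MSELoss` and `torch.optim.AdamW` chosen only for illustration; `train` is a hypothetical helper. Because ORTModule is itself a `torch.nn.Module`, passing `ORTModule(model)` instead of `model` is the only change, and the returned loss history can be compared against a plain PyTorch run.

```python
from typing import List, Sequence, Tuple

import torch


def train(model: torch.nn.Module,
          batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
          epochs: int = 1, lr: float = 1e-2) -> List[float]:
    """Standard PyTorch loop; `model` may be a plain module or ORTModule(module)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    losses = []
    model.train()
    for _ in range(epochs):
        for inputs, targets in batches:
            optimizer.zero_grad()
            # Under ORTModule, the first forward call triggers ONNX export and optimization.
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses
```

When timing throughput, exclude the first iteration, since it includes the one-time export and optimization cost.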
Step 6: Monitor and Debug
Use DebugOptions and environment variables to monitor training behavior. Options include saving the ONNX model at various optimization stages, enabling ORT logging for operator-level insights, and controlling whether to fall back to PyTorch execution for unsupported operations.
Key considerations:
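- Setting ORTMODULE_SAVE_ONNX_PATH to a directory saves the exported ONNX models there for inspection
- Log levels range from VERBOSE (most detailed) to FATAL (least detailed)
- The ORTModule convergence notes give guidance on debugging numerical differences between ORTModule and plain PyTorch runs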
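A minimal sketch of preparing a debugging run from Python. `ORTMODULE_SAVE_ONNX_PATH` is named above; the log-level variable `ORTMODULE_LOG_LEVEL`, the accepted level names, and the helper `configure_ortmodule_debugging` are assumptions for illustration. As with the memory settings, the sketch assumes the variables must be set before the model is wrapped.

```python
import os
from pathlib import Path
from typing import Dict

# Assumed to mirror ORTModule's log level names, from most to least verbose.
LOG_LEVELS = ("VERBOSE", "INFO", "WARNING", "ERROR", "FATAL")


def configure_ortmodule_debugging(save_dir: str, log_level: str = "INFO") -> Dict[str, str]:
    """Hypothetical helper: choose where ONNX models are saved and how much is logged."""
    level = log_level.upper()
    if level not in LOG_LEVELS:
        raise ValueError(f"unknown log level {log_level!r}; expected one of {LOG_LEVELS}")
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    settings = {
        "ORTMODULE_SAVE_ONNX_PATH": str(out_dir.resolve()),
        # Assumed variable name for the ORTModule log level.
        "ORTMODULE_LOG_LEVEL": level,
    }
    os.environ.update(settings)
    return settings
```

A verbose level is most useful for a short diagnostic run; switch back to WARNING for long training jobs to keep logs manageable.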