

Workflow:Microsoft Onnxruntime ORTModule Training

From Leeroopedia


Domains: Training_Acceleration, PyTorch_Integration, Mixed_Precision
Last Updated: 2026-02-10 04:30 GMT

Overview

End-to-end process for accelerating PyTorch model training by wrapping a torch.nn.Module with ONNX Runtime's ORTModule for optimized graph execution.

Description

This workflow demonstrates how to accelerate existing PyTorch training scripts with minimal code changes using ORTModule. The PyTorch model is wrapped with ORTModule, which transparently exports the model to ONNX, applies graph optimizations (operator fusion, memory-efficient gradient computation), and executes the optimized graph on the configured hardware. The workflow supports advanced features including FP16 mixed precision training, the FusedAdam optimizer for faster parameter updates, memory optimization through activation recomputation, and gradient checkpointing for large models.

Usage

Execute this workflow when you have an existing PyTorch training script and want to improve training throughput without rewriting the training logic. ORTModule is particularly effective for transformer-based models where operator fusion (e.g., fused bias+softmax+dropout) provides significant speedups. It is suitable when you need to keep the PyTorch training ecosystem (data loaders, loss functions, custom training loops) while leveraging ONNX Runtime's graph-level optimizations.

Execution Steps

Step 1: Prepare PyTorch Model

Start with a standard PyTorch model (torch.nn.Module) and training script. Ensure the model's forward pass uses operations that are exportable to ONNX. Custom autograd functions require special handling through the PythonOp bridge if they need to be included in the ONNX graph.

Key considerations:

  • Most standard PyTorch operations export to ONNX automatically
  • Custom autograd functions need explicit registration for the PythonOp bridge
  • Dynamic control flow in the forward pass may require special attention
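To illustrate what Step 1 checks for, here is a minimal sketch of a model that mixes standard layers (which export to ONNX automatically) with a custom torch.autograd.Function. That custom function is the kind of code that would go through the PythonOp bridge. The names `ClampStraightThrough` and `TinyModel` are hypothetical and come from neither ONNX Runtime nor the original workflow.

```python
import torch
import torch.nn as nn


class ClampStraightThrough(torch.autograd.Function):
    """Hypothetical custom op: clamps in forward, passes gradients through unchanged.

    Because it is a torch.autograd.Function rather than a built-in operator,
    ORTModule cannot map it to a native ONNX op and would rely on PythonOp.
    """

    @staticmethod
    def forward(ctx, x, limit: float):
        return x.clamp(-limit, limit)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: gradient w.r.t. x is passed unchanged;
        # `limit` is a plain Python float, so it receives no gradient (None).
        return grad_output, None


class TinyModel(nn.Module):
    """Hypothetical model: Linear, LayerNorm and GELU all export to ONNX."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.norm(self.proj(x)))
        return ClampStraightThrough.apply(h, 1.0)


if __name__ == "__main__":
    model = TinyModel()
    out = model(torch.randn(4, 16))
    out.sum().backward()  # plain PyTorch check that forward and backward both work
    print(out.shape, model.proj.weight.grad.shape)
```

Verifying forward and backward in plain PyTorch first makes it easier to tell apart later errors caused by wrapping with ORTModule.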

Step 2: Wrap Model with ORTModule

Replace the PyTorch model with an ORTModule wrapper. This is a one-line change that wraps the existing torch.nn.Module. ORTModule intercepts forward calls, exports the model to ONNX on first invocation, applies graph optimizations, and routes execution through ONNX Runtime's optimized engine.

Key considerations:

  • ORTModule is a drop-in replacement for torch.nn.Module
  • The ONNX export and optimization happen lazily on the first forward call
  • DebugOptions can be configured to save intermediate ONNX models and enable logging
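The one-line wrap can be sketched as below. It assumes the onnxruntime-training package, which provides `onnxruntime.training.ortmodule` with `ORTModule`, `DebugOptions` and `LogLevel`. When that package is missing, the hypothetical helper `wrap_with_ort` returns the unwrapped model, so the same script still runs under plain PyTorch.

```python
import torch
import torch.nn as nn

try:
    # Provided by the onnxruntime-training package, not by core onnxruntime.
    from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
except ImportError:
    ORTModule = None  # fall back to plain PyTorch execution


def wrap_with_ort(model: nn.Module, save_onnx: bool = False,
                  prefix: str = "my_model") -> nn.Module:
    """Hypothetical helper: wrap `model` with ORTModule when it is available."""
    if ORTModule is None:
        return model
    debug = DebugOptions(log_level=LogLevel.WARNING,
                         save_onnx=save_onnx, onnx_prefix=prefix)
    # Nothing is exported yet: ONNX export and graph optimization are
    # deferred until the first forward call.
    return ORTModule(model, debug)


if __name__ == "__main__":
    model = wrap_with_ort(nn.Linear(8, 2))
    print(model(torch.randn(3, 8)).shape)
```

The wrapped object is still an nn.Module, so the rest of the script keeps calling it exactly as before.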

Step 3: Configure a Fused Optimizer

Optionally replace the standard PyTorch optimizer with FusedAdam from ONNX Runtime. FusedAdam performs the entire Adam update in a single fused GPU kernel, eliminating multiple kernel launches and memory round-trips for each parameter group.

Key considerations:

  • FusedAdam provides the same semantics as torch.optim.AdamW
  • FP16_Optimizer wraps an existing DeepSpeed or Apex FP16 optimizer, which still performs the loss scaling, and speeds up parts of its mixed-precision update
  • Standard PyTorch optimizers still work with ORTModule
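The sketch below shows one way to make the optimizer swap optional. It assumes `FusedAdam` is importable from `onnxruntime.training.optim` and accepts `lr` and `weight_decay` keywords. Because the fused update is a GPU kernel, the hypothetical helper `make_optimizer` only selects it when CUDA is available and otherwise falls back to torch.optim.AdamW.

```python
import torch
import torch.nn as nn

try:
    from onnxruntime.training.optim import FusedAdam  # onnxruntime-training only
except ImportError:
    FusedAdam = None


def make_optimizer(model: nn.Module, lr: float = 1e-3,
                   weight_decay: float = 0.01) -> torch.optim.Optimizer:
    """Hypothetical helper: prefer FusedAdam on GPU, else torch.optim.AdamW."""
    params = model.parameters()
    if FusedAdam is not None and torch.cuda.is_available():
        # Assumes the model has already been moved to the GPU; the whole
        # Adam update then runs as fused CUDA kernels.
        return FusedAdam(params, lr=lr, weight_decay=weight_decay)
    # Standard PyTorch optimizers also work with ORTModule-wrapped models.
    return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
```

Since both branches return a torch.optim.Optimizer, the training loop does not need to know which one it received.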

Step 4: Enable Memory Optimization

Configure memory optimization strategies via environment variables to enable training of larger models or batch sizes. Options include activation recomputation (recompute activations during backward pass instead of storing them), transformer layer-wise recomputation, and user-specified recomputation patterns.

Key considerations:

  • ORTMODULE_MEMORY_OPT_LEVEL controls optimization aggressiveness
  • Level 1 applies transformer layer-wise recomputation automatically
  • Level 2 allows fine-grained user-specified recomputation via config file
  • Memory savings trade off against additional computation time
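Because memory optimization is configured through environment variables, it can be set from Python before training starts. The sketch below only sets ORTMODULE_MEMORY_OPT_LEVEL, using level 1 (transformer layer-wise recomputation) as described above. It assumes the variable is read when the model is wrapped, so it sets the variable before ORTModule is constructed, which is the conservative choice. The helper name and its `env` parameter are hypothetical; the parameter exists only so the function can be tested without touching the real environment.

```python
import os
from typing import MutableMapping, Optional


def configure_memory_optimization(
        level: int,
        env: Optional[MutableMapping[str, str]] = None) -> MutableMapping[str, str]:
    """Hypothetical helper: set ORTMODULE_MEMORY_OPT_LEVEL in `env` (default os.environ).

    Call this before wrapping the model with ORTModule.
    """
    if not isinstance(level, int) or level < 0:
        raise ValueError(f"memory optimization level must be a non-negative int, got {level!r}")
    env = os.environ if env is None else env
    env["ORTMODULE_MEMORY_OPT_LEVEL"] = str(level)  # environment values are strings
    return env


if __name__ == "__main__":
    configure_memory_optimization(1)  # layer-wise recompute for transformer blocks
    print(os.environ["ORTMODULE_MEMORY_OPT_LEVEL"])
```

The shell equivalent is exporting ORTMODULE_MEMORY_OPT_LEVEL=1 before launching the training script.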

Step 5: Execute Training Loop

Run the standard PyTorch training loop with the ORTModule-wrapped model. The training loop structure remains unchanged: forward pass, loss computation, backward pass, optimizer step. ORTModule transparently handles ONNX export, graph optimization, and accelerated execution behind the scenes.

Key considerations:

  • The training loop code is identical to standard PyTorch
  • First iteration is slower due to ONNX export and optimization
  • Subsequent iterations run through the optimized ONNX graph
  • Environment variables control logging, debugging, and fallback behavior
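A minimal end-to-end loop is sketched below on a toy linear-regression task (hypothetical data and helper name `train`). The ORTModule import is optional, as in the earlier steps. Apart from the wrap line, the loop is ordinary PyTorch: zero the gradients, run the forward pass, compute the loss, run the backward pass, then step the optimizer.

```python
import torch
import torch.nn as nn

try:
    from onnxruntime.training.ortmodule import ORTModule
except ImportError:
    ORTModule = None  # run the same loop in plain PyTorch


def train(steps: int = 100, seed: int = 0) -> list:
    """Hypothetical toy task: fit y = x @ w_true^T with a single Linear layer."""
    torch.manual_seed(seed)
    model = nn.Linear(4, 1)
    if ORTModule is not None:
        model = ORTModule(model)  # first forward call triggers export + optimization
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
    loss_fn = nn.MSELoss()
    w_true = torch.tensor([[1.0, -2.0, 0.5, 3.0]])

    losses = []
    for _ in range(steps):
        x = torch.randn(32, 4)
        y = x @ w_true.T
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)  # under ORTModule, iteration 1 is the slow one
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    history = train()
    print(f"first loss {history[0]:.3f}, last loss {history[-1]:.3f}")
```

When timing throughput, exclude the first iteration, because it includes the one-time ONNX export and graph optimization.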

Step 6: Monitor and Debug

Use DebugOptions and environment variables to monitor training behavior. Options include saving the ONNX model at various optimization stages, enabling ORT logging for operator-level insights, and controlling whether to fall back to PyTorch execution for unsupported operations.

Key considerations:

  • ORTMODULE_SAVE_ONNX_PATH saves exported ONNX models for inspection
  • Log levels range from verbose to fatal
  • ORTModule convergence notes provide debugging guidance for numerical differences
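Debug settings can be prepared in one place before the model is wrapped. In the sketch below, ORTMODULE_SAVE_ONNX_PATH is assumed to name the directory where exported models are written when DebugOptions has save_onnx=True. The helper `prepare_debugging` and its `env` parameter are hypothetical. DebugOptions and LogLevel are only used if onnxruntime-training is installed; otherwise the helper returns None.

```python
import os
import tempfile

try:
    from onnxruntime.training.ortmodule import DebugOptions, LogLevel
except ImportError:
    DebugOptions = LogLevel = None


def prepare_debugging(save_dir: str, prefix: str = "my_model",
                      verbose: bool = True, env=None):
    """Hypothetical helper: set the ONNX save directory and build DebugOptions."""
    env = os.environ if env is None else env
    os.makedirs(save_dir, exist_ok=True)
    # Assumed semantics: where ORTModule writes ONNX files when save_onnx=True.
    env["ORTMODULE_SAVE_ONNX_PATH"] = save_dir
    if DebugOptions is None:
        return None  # onnxruntime-training not installed
    level = LogLevel.VERBOSE if verbose else LogLevel.WARNING
    return DebugOptions(log_level=level, save_onnx=True, onnx_prefix=prefix)


if __name__ == "__main__":
    opts = prepare_debugging(os.path.join(tempfile.gettempdir(), "ortmodule_debug"))
    # With ORT installed: model = ORTModule(model, opts)
    print(opts)
```

VERBOSE logging produces operator-level output and slows training, so it is best reserved for short diagnostic runs.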
