# Heuristic: Microsoft ONNX Runtime ORTModule Wrapping Order
| Field | Value |
|---|---|
| Sources | docs/ORTModule_Training_Guidelines.md (L39-41, L326-533) |
| Domains | Training, Model Setup, Distributed Training, DeepSpeed Integration |
| Last Updated | 2026-02-10 |
## Overview
Wrap your PyTorch model with ORTModule before applying any other module wrappers (DeepSpeed, DDP, etc.) to ensure ORT can properly capture and optimize the computation graph.
## Description
When using ONNX Runtime for training via ORTModule, the order in which module wrappers are applied to a PyTorch nn.Module is critical for correct behavior. ORTModule works by exporting the PyTorch model to an ONNX graph and executing it through ORT's optimized backend. If other wrappers (such as DeepSpeed's engine or PyTorch's DistributedDataParallel) are applied first, they may modify the module structure in ways that interfere with ONNX export or prevent ORT from capturing the full computation graph.
The official guidance is: wrap with ORTModule first, then apply other wrappers on top. This has been validated across more scenarios than the reverse order. Additionally, ORTModule is explicitly not compatible with torch.nn.DataParallel (which is not recommended even in standard PyTorch usage). Users should use torch.nn.parallel.DistributedDataParallel (DDP) instead.
For pipeline parallelism with DeepSpeed, use ORTPipelineModule as a drop-in replacement for DeepSpeed's PipelineModule.
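The ORTModule-then-DDP order can be sketched as below. This is a minimal single-process illustration, assuming PyTorch with the gloo backend is available; real multi-GPU runs launch one process per GPU (e.g. via `torchrun`), and the `ORTModule` import is guarded here so the sketch also runs where `onnxruntime-training` is not installed.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group purely for illustration.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)

model = torch.nn.Linear(4, 2)  # stand-in for the real model

# Step 1: ORTModule first (guarded: requires onnxruntime-training).
try:
    from onnxruntime.training.ortmodule import ORTModule
    model = ORTModule(model)
except ImportError:
    pass  # plain PyTorch fallback, only for this sketch

# Step 2: DDP wraps the (ORTModule-wrapped) model, never the other way around.
model = DDP(model)

out = model(torch.randn(3, 4))
dist.destroy_process_group()
```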
## Usage
Use this heuristic when:
- Setting up ORTModule for the first time with any model.
- Combining ORTModule with DeepSpeed (ZeRO, pipeline parallelism).
- Using DDP alongside ORTModule for multi-GPU data parallelism.
- Debugging wrapper interaction issues where the model fails to export or run through ORT.
## The Insight (Rule of Thumb)
Always follow this wrapping order:
- Wrap the base model with `ORTModule`.
- Apply distributed wrappers (DeepSpeed, DDP) after `ORTModule`.
- Replace `torch.optim.AdamW` with `FusedAdam` for faster parameter updates.
- Optionally wrap the optimizer with `FP16_Optimizer` for DeepSpeed/APEX integration.
Correct pattern:
```python
import deepspeed

# build_model, args, lr_scheduler and mpu come from the surrounding
# training script.
model = build_model()

# Step 1: ORTModule FIRST
from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(model)

# Step 2: Optimizer
from onnxruntime.training.optim import FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1)

# Step 3: DeepSpeed (wraps ORTModule, not the bare model)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, args=args,
    lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False
)

# Step 4: Optional FP16_Optimizer
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
optimizer = FP16_Optimizer(optimizer)
```
Key constraints:
- `ORTModule` is NOT compatible with `torch.nn.DataParallel`. Use DDP instead.
- For pipeline parallelism, use `ORTPipelineModule` instead of DeepSpeed's `PipelineModule`.
- Set `ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"` to force ORT backend execution and prevent silent fallback to PyTorch (useful during benchmarking and development).
- Set `ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0` to reduce the memory peak during model export when memory is tight.
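The two environment variables above must be set before ORTModule exports the model, e.g. at the top of the training script:

```python
import os

# Fail loudly instead of silently falling back to PyTorch execution.
os.environ["ORTMODULE_FALLBACK_POLICY"] = "FALLBACK_DISABLE"
# Skip the pre-export deepcopy to lower peak memory during model export.
os.environ["ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT"] = "0"
```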
Pipeline parallelism example:
```python
from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

# `layers` is the list of pipeline stages, as with DeepSpeed's PipelineModule.
pipeline_module = ORTPipelineModule(
    layers, num_stages=2, base_seed=1234,
    partition_method="parameters",
    debug_options=DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE)
)
```
## Reasoning
ORTModule needs to intercept the model's forward() call to export the computation graph to ONNX and execute it through ORT's optimized backend. When other wrappers are applied first, they modify the module structure (adding communication hooks for DDP, sharding parameters for DeepSpeed ZeRO, etc.) before ORTModule has a chance to capture the original computation graph. This can result in:

1. ONNX export failures, because the wrapped module's forward signature or behavior is altered.
2. Suboptimal ORT graphs that include wrapper overhead as part of the computation.
3. Incompatible module structures that prevent the ORT backend from running at all.

By wrapping with ORTModule first, the original model's computation graph is cleanly captured, and subsequent wrappers operate on the ORTModule-wrapped model, which presents the same interface as the original nn.Module. The explicit incompatibility with DataParallel exists because DataParallel replicates the module across devices within a single process, which conflicts with ORT's single-graph execution model. DDP, which uses one process per GPU with gradient synchronization, is the compatible alternative.
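The ordering argument can be illustrated with a torch-free toy sketch. `CaptureWrapper` and `CommWrapper` below are hypothetical stand-ins (not real ORT or DeepSpeed classes): the first snapshots the forward it wraps, as ORTModule does at export time, and the second adds bookkeeping around it, as DDP/DeepSpeed do.

```python
class CaptureWrapper:
    """Stands in for ORTModule: snapshots the forward it wraps at wrap time."""
    def __init__(self, fn):
        self.captured = fn  # the "graph" that would be exported to ONNX
    def __call__(self, x):
        return self.captured(x)

class CommWrapper:
    """Stands in for DDP/DeepSpeed: adds machinery around the forward."""
    def __init__(self, fn):
        self.fn = fn
    def __call__(self, x):
        # (gradient-sync hooks, sharding, etc. would live here)
        return self.fn(x)

def base_model(x):
    return x * 2

# Correct order: capture first, distributed wrapper on top.
good = CommWrapper(CaptureWrapper(base_model))
# Wrong order: the capture sees the comm wrapper, not the bare model.
bad = CaptureWrapper(CommWrapper(base_model))

inner_good = good.fn.captured   # the clean base model
inner_bad = bad.captured        # a CommWrapper, polluting the "export"
```

Both orders compute the same result here, which is exactly why the mistake can go unnoticed until export fails or the captured graph carries wrapper overhead.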