Heuristic:Microsoft Onnxruntime ORTModule Wrapping Order

Field         Value
Sources       docs/ORTModule_Training_Guidelines.md (L39-41, L326-533)
Domains       Training, Model Setup, Distributed Training, DeepSpeed Integration
Last Updated  2026-02-10

Overview

Wrap your PyTorch model with ORTModule before applying any other module wrappers (DeepSpeed, DDP, etc.) to ensure ORT can properly capture and optimize the computation graph.

Description

When using ONNX Runtime for training via ORTModule, the order in which module wrappers are applied to a PyTorch nn.Module is critical for correct behavior. ORTModule works by exporting the PyTorch model to an ONNX graph and executing it through ORT's optimized backend. If other wrappers (such as DeepSpeed's engine or PyTorch's DistributedDataParallel) are applied first, they may modify the module structure in ways that interfere with ONNX export or prevent ORT from capturing the full computation graph.

The official guidance is: wrap with ORTModule first, then apply other wrappers on top. This has been validated across more scenarios than the reverse order. Additionally, ORTModule is explicitly not compatible with torch.nn.DataParallel (which is not recommended even in standard PyTorch usage). Users should use torch.nn.parallel.DistributedDataParallel (DDP) instead.

For pipeline parallelism with DeepSpeed, use ORTPipelineModule as a drop-in replacement for DeepSpeed's PipelineModule.
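
The ordering rule can be sketched as a small helper that applies ORTModule first and then layers any other wrappers on top. This is an illustration only: `wrap_in_order` is a hypothetical name, not part of ONNX Runtime, and the `FakeORTModule` and `FakeDDP` classes are stand-ins so the sketch runs without PyTorch or onnxruntime-training installed.

```python
from typing import Any, Callable, Iterable

Wrapper = Callable[[Any], Any]


def wrap_in_order(
    model: Any, ort_wrapper: Wrapper, outer_wrappers: Iterable[Wrapper] = ()
) -> Any:
    """Apply ort_wrapper to the bare model, then each outer wrapper in turn.

    Hypothetical helper, not an ONNX Runtime API. In a real script,
    ort_wrapper would be ORTModule and outer_wrappers would hold callables
    that build DDP or similar wrappers.
    """
    wrapped = ort_wrapper(model)  # ORTModule sees the unmodified model
    for wrap in outer_wrappers:   # every other wrapper goes on top
        wrapped = wrap(wrapped)
    return wrapped


# Stand-ins (assumptions for illustration) that only record what they wrap.
class FakeORTModule:
    def __init__(self, module: Any) -> None:
        self.module = module


class FakeDDP:
    def __init__(self, module: Any) -> None:
        self.module = module


base = object()
stacked = wrap_in_order(base, FakeORTModule, [FakeDDP])

# The distributed wrapper is outermost; ORTModule sits directly on the model.
print(type(stacked).__name__, type(stacked.module).__name__)
# prints: FakeDDP FakeORTModule
```

With the real libraries the call might look like wrap_in_order(model, ORTModule, [lambda m: DistributedDataParallel(m, device_ids=[local_rank])]), where local_rank is whatever device index the launch script provides.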

Usage

Use this heuristic when:

  • Setting up ORTModule for the first time with any model.
  • Combining ORTModule with DeepSpeed (ZeRO, pipeline parallelism).
  • Using DDP alongside ORTModule for multi-GPU data parallelism.
  • Debugging wrapper interaction issues where the model fails to export or run through ORT.
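
When debugging wrapper interactions, it can help to print the order in which wrappers are actually nested. The sketch below walks the `.module` attribute from the outermost wrapper inward and reports likely ordering problems. It assumes every wrapper exposes the object it wraps as `.module` and that the base model has no attribute of that name. `wrapper_chain` and `check_wrapping_order` are hypothetical helpers, and the classes at the bottom are stand-ins named after the real wrappers so the example runs without PyTorch.

```python
from typing import Any, List


def wrapper_chain(model: Any) -> List[str]:
    """Return class names from the outermost wrapper down to the base model."""
    names: List[str] = []
    seen = set()
    current = model
    while current is not None and id(current) not in seen:
        seen.add(id(current))  # guard against accidental cycles
        names.append(type(current).__name__)
        current = getattr(current, "module", None)
    return names


def check_wrapping_order(model: Any) -> List[str]:
    """Return readable problems; an empty list means none were detected."""
    chain = wrapper_chain(model)
    problems: List[str] = []
    if "DataParallel" in chain:
        problems.append("DataParallel found: use DistributedDataParallel instead")
    if "ORTModule" not in chain:
        problems.append("ORTModule is not in the wrapper chain")
    elif chain.index("ORTModule") != len(chain) - 2:
        # The last entry is the base model, so ORTModule should be next to last.
        problems.append("ORTModule is not wrapped directly around the base model")
    return problems


# Stand-in classes (assumptions for illustration, not the real wrappers).
class _Stub:
    def __init__(self, module: Any = None) -> None:
        self.module = module


class BaseModel:
    pass


class ORTModule(_Stub):
    pass


class DistributedDataParallel(_Stub):
    pass


good = DistributedDataParallel(ORTModule(BaseModel()))
bad = ORTModule(DistributedDataParallel(BaseModel()))

print(wrapper_chain(good))
print(check_wrapping_order(good))
print(check_wrapping_order(bad))
```

The check matches on class names rather than importing the real classes, which keeps the helper usable in environments where only some of the libraries are installed.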

The Insight (Rule of Thumb)

Always follow this wrapping order:

  1. Wrap the base model with ORTModule.
  2. Create the optimizer from the wrapped model's parameters, replacing torch.optim.AdamW with FusedAdam for faster parameter updates.
  3. Apply distributed wrappers (DeepSpeed, DDP) on top of the ORTModule-wrapped model.
  4. Optionally wrap the optimizer with FP16_Optimizer for DeepSpeed/APEX integration.

Correct pattern:

import deepspeed

# build_model(), args, lr_scheduler, and mpu come from your own training script.
model = build_model()

# Step 1: ORTModule FIRST
from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(model)

# Step 2: Optimizer (built from the ORTModule-wrapped model's parameters)
from onnxruntime.training.optim import FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1)

# Step 3: DeepSpeed (wraps ORTModule, not the bare model)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, args=args,
    lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False
)

# Step 4: Optional FP16_Optimizer
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
optimizer = FP16_Optimizer(optimizer)

Key constraints:

  • ORTModule is NOT compatible with torch.nn.DataParallel. Use DDP instead.
  • For pipeline parallelism, use ORTPipelineModule instead of DeepSpeed's PipelineModule.
  • Set ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE" to force ORT backend execution and prevent silent fallback to PyTorch (useful during benchmarking and development).
  • Set ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 to reduce memory peak during model export when memory is tight.
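
The two environment variables above can also be set from Python instead of the shell. The sketch below is a minimal example using only the standard library. `configure_ortmodule_env` is a hypothetical helper name, and calling it before importing or constructing ORTModule is an assumption about when the variables are read, chosen as the safest ordering.

```python
import os
from typing import Dict


def configure_ortmodule_env(
    disable_fallback: bool = True, skip_deepcopy: bool = False
) -> Dict[str, str]:
    """Set the ORTModule environment variables and return what was applied.

    Hypothetical helper, not an ONNX Runtime API.
    """
    settings: Dict[str, str] = {}
    if disable_fallback:
        # Force ORT backend execution instead of silently falling back to PyTorch.
        settings["ORTMODULE_FALLBACK_POLICY"] = "FALLBACK_DISABLE"
    if skip_deepcopy:
        # Skip the deep copy before export to lower the memory peak.
        settings["ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT"] = "0"
    os.environ.update(settings)
    return settings


# Call this before importing ORTModule and wrapping the model.
applied = configure_ortmodule_env(disable_fallback=True, skip_deepcopy=True)
print(applied)
```

Returning the applied settings makes it easy to log the configuration at the start of a benchmarking run.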

Pipeline parallelism example:

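from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule

# `layers` is the layer list you would otherwise pass to DeepSpeed's PipelineModule.
# save_onnx=True needs an onnx_prefix; "model_name" is a placeholder.
pipeline_module = ORTPipelineModule(
    layers, num_stages=2, base_seed=1234,
    partition_method="parameters",
    debug_options=DebugOptions(
        save_onnx=True, onnx_prefix="model_name", log_level=LogLevel.VERBOSE
    )
)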

Reasoning

ORTModule needs to intercept the model's forward() call to export the computation graph to ONNX and execute it through ORT's optimized backend. When other wrappers are applied first, they modify the module structure (adding communication hooks for DDP, sharding parameters for DeepSpeed ZeRO, etc.) before ORTModule has a chance to capture the original computation graph. This can result in:

  1. ONNX export failures because the wrapped module's forward signature or behavior is altered.
  2. Suboptimal ORT graphs that include wrapper overhead as part of the computation.
  3. Incompatible module structures that prevent the ORT backend from running at all.

By wrapping with ORTModule first, the original model's computation graph is cleanly captured, and subsequent wrappers operate on the ORTModule-wrapped model, which presents the same interface as the original nn.Module.

The explicit incompatibility with DataParallel exists because DataParallel replicates the module across devices within a single process, which conflicts with ORT's single-graph execution model. DDP, which uses one process per GPU with gradient synchronization, is the compatible alternative.
