
Implementation:Microsoft Onnxruntime ORTModule Wrap

From Leeroopedia


Overview

Wraps a PyTorch nn.Module with ORTModule to transparently route forward and backward passes through the ONNX Runtime execution engine for accelerated training.

Metadata

Field Value
Implementation Name ORTModule_Wrap
Type API Doc
Language Python
API onnxruntime.training.ortmodule.ORTModule(model: torch.nn.Module, debug_options: DebugOptions = None) -> ORTModule
Import from onnxruntime.training.ortmodule import ORTModule
Domain Accelerated_Training, PyTorch_Integration
Repository microsoft/onnxruntime
Source Reference docs/ORTModule_Training_Guidelines.md:L35-49
Last Updated 2026-02-10

Description

The ORTModule constructor accepts a standard PyTorch nn.Module and returns a wrapped module that behaves identically from the PyTorch API perspective but executes computations through ONNX Runtime's optimized engine.

On the first forward call, ORTModule:

  1. Traces the wrapped model to produce an ONNX graph
  2. Applies ORT graph optimizations (operator fusion, memory planning, etc.)
  3. Caches the optimized graph for subsequent calls
  4. Executes the forward pass through ORT's execution engine

Subsequent forward/backward calls reuse the cached graph unless the computation graph changes (e.g., due to dynamic input shapes or altered control flow), in which case the model is re-exported and re-optimized.

The optional DebugOptions parameter enables developer-facing features:

  • save_onnx=True -- Saves the exported ONNX models to disk for inspection
  • log_level=LogLevel.VERBOSE -- Enables detailed logging of ORT operations
  • onnx_prefix="model_name" -- Sets a prefix for saved ONNX file names

API Signature

from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

# Basic usage
model = ORTModule(model)

# With debug options
model = ORTModule(
    model,
    DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="my_model"),
)

Key Parameters

Parameter Type Description
model torch.nn.Module The PyTorch model to accelerate. Must be wrapped before other wrappers (DeepSpeed, DDP).
debug_options DebugOptions (optional) Configuration for debugging: ONNX model saving, log level, and model name prefix.

DebugOptions Fields

Field Type Description
save_onnx bool Whether to save exported ONNX models to disk
log_level LogLevel Logging verbosity: FATAL, ERROR, WARNING (default), INFO, VERBOSE
onnx_prefix str Prefix for saved ONNX file names

I/O Contract

Direction Type Description
Input torch.nn.Module Standard PyTorch model with registered parameters
Output ORTModule Wrapped module that intercepts forward/backward and executes through ORT

Code Reference

From docs/ORTModule_Training_Guidelines.md:

	model = build_model()

+	from onnxruntime.training.ortmodule import ORTModule
+	model = ORTModule(model)

With debug options:

	model = build_model()

+	from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
+	model = ORTModule(model, DebugOptions(save_onnx=True, log_level=LogLevel.VERBOSE, onnx_prefix="model_name"))

Usage Example

import torch
from onnxruntime.training.ortmodule import ORTModule

# Construct the model
model = build_model()

# Wrap with ORTModule (do this BEFORE other wrappers)
model = ORTModule(model)

# Now wrap with distributed training (if needed)
model = torch.nn.parallel.DistributedDataParallel(model)

# Training loop works exactly as before
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
    loss = model(batch)  # assumes the wrapped model returns the loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Important: ORTModule is not compatible with torch.nn.DataParallel. Use torch.nn.parallel.DistributedDataParallel instead.
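The two ordering rules above (ORTModule inside DistributedDataParallel, never combined with DataParallel) can be enforced with a small fail-fast check. This helper is a hypothetical sketch, not part of the ORTModule API; wrappers are compared by class name as strings so the snippet runs without torch or onnxruntime installed.

```python
def check_wrap_order(wrapper_chain):
    """Validate a model's wrapper chain, listed outermost-first, e.g.
    ["DistributedDataParallel", "ORTModule", "MyModel"].

    Raises TypeError on the two unsupported configurations documented above.
    """
    if "DataParallel" in wrapper_chain:  # exact element match, not substring
        raise TypeError(
            "ORTModule is incompatible with torch.nn.DataParallel; "
            "use torch.nn.parallel.DistributedDataParallel instead."
        )
    if "ORTModule" in wrapper_chain and "DistributedDataParallel" in wrapper_chain:
        # DDP must be outermost: wrap with ORTModule first, then DDP.
        if wrapper_chain.index("DistributedDataParallel") > wrapper_chain.index("ORTModule"):
            raise TypeError("Wrap the model with ORTModule before DistributedDataParallel.")


# Correct order passes silently:
check_wrap_order(["DistributedDataParallel", "ORTModule", "MyModel"])
```

In real code the equivalent guard would inspect `type(model).__name__` at each nesting level before training starts.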

Implements

Principle:Microsoft_Onnxruntime_ORT_Accelerated_Training
