Implementation:Microsoft Onnxruntime ORTModule Training Execution
Overview
Executes a standard PyTorch training loop where ORTModule transparently intercepts forward and backward passes to accelerate computation through ONNX Runtime's optimized execution engine.
Metadata
| Field | Value |
|---|---|
| Implementation Name | ORTModule_Training_Execution |
| Type | Wrapper Doc |
| Language | Python |
| API | Standard PyTorch training loop: loss = model(inputs), loss.backward(), optimizer.step(), optimizer.zero_grad() -- ORTModule transparently intercepts |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L379-400 |
| Last Updated | 2026-02-10 |
Description
The training execution with ORTModule uses the exact same training loop pattern as native PyTorch. The key difference is that the model has been wrapped with ORTModule, which intercepts:
- Forward pass (model(inputs)) -- ORTModule exports the model to ONNX (on first call), optimizes the graph, and executes it through ORT's engine.
- Backward pass (loss.backward()) -- ORTModule routes gradient computation through ORT's optimized backward graph.
The optimizer and gradient management remain standard PyTorch operations, though ORT's FusedAdam and FP16_Optimizer can be used for additional performance.
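Because the loop itself is unchanged, the wrapping step can be made optional. The helper below is an illustrative sketch, not part of ORTModule's API: it wraps the model when onnxruntime-training is installed and otherwise returns the model unchanged, so the same training loop runs either way.

```python
def wrap_for_ort(model):
    """Wrap with ORTModule when onnxruntime-training is available.

    Illustrative helper (not an ORTModule API): the surrounding
    training loop is identical whether or not the wrap happens.
    """
    try:
        from onnxruntime.training.ortmodule import ORTModule
    except ImportError:
        return model  # fall back to eager PyTorch execution
    return ORTModule(model)
```

This keeps ORT acceleration as a drop-in, opt-in dependency rather than a hard requirement of the training script.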
Key Environment Variables
Several environment variables influence ORTModule behavior during training execution:
| Variable | Default | Description |
|---|---|---|
| ORTMODULE_ENABLE_COMPUTE_OPTIMIZER | 1 (enabled) | Enables graph-level compute optimizations during execution |
| ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT | 0 (disabled) | Enables early gradient buffer release through the PythonOpGrad operator |
| ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER | 1 (enabled) | Enables sparse optimization for embedding layers |
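These variables are typically read when ORTModule initializes, so they should be set before the model is wrapped. A minimal sketch of configuring them from Python rather than the shell (the values mirror the table above; opting in to the gradient-management flag is shown as an example):

```python
import os

# Set before wrapping the model with ORTModule, since the flags are
# consumed at initialization time.
os.environ["ORTMODULE_ENABLE_COMPUTE_OPTIMIZER"] = "1"        # default: 1
os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"  # opt in; default: 0
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"  # default: 1
```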
API Signature
```python
# The training loop API is standard PyTorch
for epoch in range(num_epochs):
    for batch in train_dataloader:
        loss = model(batch)    # Forward: intercepted by ORTModule
        loss.backward()        # Backward: intercepted by ORTModule
        optimizer.step()       # Parameter update: standard PyTorch or FusedAdam
        optimizer.zero_grad()  # Gradient reset: standard PyTorch
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model | ORTModule-wrapped nn.Module | The model previously wrapped with ORTModule(model) |
| optimizer | torch.optim.Optimizer | Any PyTorch optimizer, optionally ORT's FusedAdam or FP16_Optimizer |
| inputs | Tensors | Training batch data passed to the model |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | Training data batches | Tensors from the data loader |
| Output | Loss value | Scalar tensor representing the computed loss |
| Side Effect | Parameter updates | Model parameters are updated in-place by the optimizer |
| Side Effect | ORT graph caching | ONNX graph is cached after first forward call |
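The graph-caching side effect means the first forward call pays a one-time export and optimization cost, and subsequent calls reuse the cached graph. The class below is a pure-Python analogy of that intercept-and-cache pattern, not ORT's implementation; the export counter stands in for the ONNX export step.

```python
class CachingWrapper:
    """Analogy for ORTModule's first-call behavior: build an optimized
    artifact once, then reuse it on every subsequent call."""

    def __init__(self, fn):
        self._fn = fn
        self._compiled = None
        self.export_count = 0  # how many times the expensive build ran

    def _export(self):
        self.export_count += 1  # stands in for ONNX export + graph optimization
        return self._fn

    def __call__(self, x):
        if self._compiled is None:   # first call: export and cache
            self._compiled = self._export()
        return self._compiled(x)     # later calls: reuse the cached graph

wrapped = CachingWrapper(lambda x: x * 2)
results = [wrapped(3), wrapped(4)]  # export runs only on the first call
```

This is also why the first training step is noticeably slower than steady-state steps when ORTModule is active.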
Code Reference
From docs/ORTModule_Training_Guidelines.md:
```diff
  model = build_model()
+ from onnxruntime.training.ortmodule import ORTModule
+ model = ORTModule(model)
- optimizer = AdamW(model.parameters(), lr=1)
+ from onnxruntime.training.optim import FusedAdam
+ optimizer = FusedAdam(model.parameters(), lr=1)
  model, optimizer, _, lr_scheduler = deepspeed.initialize(
      model=model,
      optimizer=optimizer,
      args=args,
      lr_scheduler=lr_scheduler,
      mpu=mpu,
      dist_init_required=False)
+ from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+ optimizer = FP16_Optimizer(optimizer)
```
Usage Example
Complete Training Pipeline
```python
import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

# Step 1: Build model
model = build_model()

# Step 2: Wrap with ORTModule (do this FIRST)
model = ORTModule(model)

# Step 3: Configure optimizer
optimizer = FusedAdam(model.parameters(), lr=1e-4)

# Step 4: Optional DeepSpeed integration
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, args=args,
    lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False,
)

# Step 5: Optional FP16 wrapper
optimizer = FP16_Optimizer(optimizer)

# Step 6: Training loop (standard PyTorch pattern)
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        inputs, labels = batch
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Evaluation
    model.eval()
    with torch.no_grad():
        for batch in eval_dataloader:
            inputs, labels = batch
            outputs = model(inputs)
            eval_loss = loss_fn(outputs, labels)
```
Environment Variable Configuration
```shell
# Enable compute optimizer (default: enabled)
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1

# Enable memory-efficient gradient management
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1

# Enable embedding sparse optimizer (default: enabled)
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1

# Enable memory optimization (transformer layerwise recompute)
export ORTMODULE_MEMORY_OPT_LEVEL=1
```
Implements
Principle:Microsoft_Onnxruntime_ORTModule_Training_Loop
Related Pages
- ORTModule Wrap -- The model wrapping step that precedes training execution
- FusedAdam FP16Optimizer -- Optimized parameter updates used in the loop
- Memory Opt Env Config -- Memory optimization active during training
- GlobalSubscriberManager Usage -- Monitoring tools for the training loop
- Heuristic:Microsoft_Onnxruntime_ORTModule_Wrapping_Order