
Implementation:Microsoft Onnxruntime ORTModule Training Execution

From Leeroopedia


Overview

Executes a standard PyTorch training loop where ORTModule transparently intercepts forward and backward passes to accelerate computation through ONNX Runtime's optimized execution engine.

Metadata

| Field | Value |
|---|---|
| Implementation Name | ORTModule_Training_Execution |
| Type | Wrapper Doc |
| Language | Python |
| API | Standard PyTorch training loop: loss = model(inputs), loss.backward(), optimizer.step(), optimizer.zero_grad() -- ORTModule transparently intercepts |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Training_Guidelines.md:L379-400 |
| Last Updated | 2026-02-10 |

Description

Training with ORTModule uses exactly the same loop pattern as native PyTorch. The key difference is that the model has been wrapped with ORTModule, which intercepts:

  • Forward pass (model(inputs)) -- ORTModule exports the model to ONNX (on first call), optimizes the graph, and executes it through ORT's engine.
  • Backward pass (loss.backward()) -- ORTModule routes gradient computation through ORT's optimized backward graph.

The optimizer and gradient management remain standard PyTorch operations, though ORT's FusedAdam and FP16_Optimizer can be used for additional performance.
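A minimal sketch of this pattern is below, using a stand-in nn.Linear model and a hypothetical loss setup. The ORTModule wrap is shown commented out so the snippet also runs where the onnxruntime-training package is not installed; uncommenting it is the only change needed to route execution through ORT.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for any nn.Module

# The one-line change (requires the onnxruntime-training package):
# from onnxruntime.training.ortmodule import ORTModule
# model = ORTModule(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
loss = loss_fn(model(inputs), targets)  # forward: intercepted when wrapped
loss.backward()                         # backward: intercepted when wrapped
optimizer.step()                        # standard PyTorch parameter update
optimizer.zero_grad()                   # standard PyTorch gradient reset
print(loss.item())
```

Everything after the (commented-out) wrap line is unmodified PyTorch, which is the point: no loop rewrite is required to adopt ORTModule.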

Key Environment Variables

Several environment variables influence ORTModule behavior during training execution:

| Variable | Default | Description |
|---|---|---|
| ORTMODULE_ENABLE_COMPUTE_OPTIMIZER | 1 (enabled) | Enables graph-level compute optimizations during execution |
| ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT | 0 (disabled) | Enables early gradient buffer release through the PythonOpGrad operator |
| ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER | 1 (enabled) | Enables sparse optimization for embedding layers |
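These variables are read from the process environment, so they can also be set from Python rather than the shell. A sketch (the assumption here is that in-process assignment before ORTModule is imported and constructed takes effect the same way as an `export` in the shell):

```python
import os

# Set before importing/constructing ORTModule, so the values are visible
# when ORTModule reads its configuration from the environment.
os.environ["ORTMODULE_ENABLE_COMPUTE_OPTIMIZER"] = "1"          # default: enabled
os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"    # default: disabled
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1" # default: enabled
```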

API Signature

# The training loop API is standard PyTorch
for epoch in range(num_epochs):
    for batch in train_dataloader:
        loss = model(batch)       # Forward: intercepted by ORTModule
        loss.backward()           # Backward: intercepted by ORTModule
        optimizer.step()          # Parameter update: standard PyTorch or FusedAdam
        optimizer.zero_grad()     # Gradient reset: standard PyTorch

Key Parameters

| Parameter | Type | Description |
|---|---|---|
| model | ORTModule-wrapped nn.Module | The model previously wrapped with ORTModule(model) |
| optimizer | torch.optim.Optimizer | Any PyTorch optimizer, optionally ORT's FusedAdam or FP16_Optimizer |
| inputs | Tensors | Training batch data passed to the model |

I/O Contract

| Direction | Type | Description |
|---|---|---|
| Input | Training data batches | Tensors from the data loader |
| Output | Loss value | Scalar tensor representing the computed loss |
| Side effect | Parameter updates | Model parameters are updated in place by the optimizer |
| Side effect | ORT graph caching | The ONNX graph is cached after the first forward call |
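The graph-caching side effect can be observed by timing the first forward call against later ones. A sketch, which falls back to plain PyTorch when onnxruntime-training is absent (the timing gap only illustrates export-and-optimize overhead when ORTModule is actually in use):

```python
import time
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
try:
    from onnxruntime.training.ortmodule import ORTModule
    model = ORTModule(model)  # first call will export and optimize the graph
except ImportError:
    pass  # plain PyTorch fallback so the sketch still runs

x = torch.randn(4, 16)

t0 = time.perf_counter()
model(x)                      # first call: ONNX export + graph optimization
first_call = time.perf_counter() - t0

t0 = time.perf_counter()
model(x)                      # later calls reuse the cached graph
cached_call = time.perf_counter() - t0
print(f"first={first_call:.4f}s cached={cached_call:.4f}s")
```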

Code Reference

From docs/ORTModule_Training_Guidelines.md:

	model = build_model()

+	from onnxruntime.training.ortmodule import ORTModule
+	model = ORTModule(model)

-	optimizer = AdamW(model.parameters(), lr=1)
+	from onnxruntime.training.optim import FusedAdam
+	optimizer = FusedAdam(model.parameters(), lr=1)

	model, optimizer, _, lr_scheduler = deepspeed.initialize(
			model=model,
			optimizer=optimizer,
			args=args,
			lr_scheduler=lr_scheduler,
			mpu=mpu,
			dist_init_required=False)

+	from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
+	optimizer = FP16_Optimizer(optimizer)

Usage Example

Complete Training Pipeline

import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

# Step 1: Build model
model = build_model()

# Step 2: Wrap with ORTModule (do this FIRST)
model = ORTModule(model)

# Step 3: Configure optimizer
optimizer = FusedAdam(model.parameters(), lr=1e-4)

# Step 4: Optional DeepSpeed integration
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, args=args,
    lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False,
)

# Step 5: Optional FP16 wrapper
optimizer = FP16_Optimizer(optimizer)

# Step 6: Training loop (standard PyTorch pattern)
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        inputs, labels = batch
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Evaluation
    model.eval()
    with torch.no_grad():
        for batch in eval_dataloader:
            inputs, labels = batch
            outputs = model(inputs)
            eval_loss = loss_fn(outputs, labels)

Environment Variable Configuration

# Enable compute optimizer (default: enabled)
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1

# Enable memory-efficient gradient management
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1

# Enable embedding sparse optimizer (default: enabled)
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1

# Enable memory optimization (transformer layerwise recompute)
export ORTMODULE_MEMORY_OPT_LEVEL=1

Implements

Principle:Microsoft_Onnxruntime_ORTModule_Training_Loop
