
Principle:Microsoft Onnxruntime ORTModule Training Loop

From Leeroopedia


Overview

Execution of a standard PyTorch training loop with ORTModule providing transparent ORT-accelerated computation.

Metadata

Field Value
Principle Name ORTModule_Training_Loop
Category Wrapper Doc
Domain Accelerated_Training, PyTorch_Integration
Repository microsoft/onnxruntime
Source Reference docs/ORTModule_Training_Guidelines.md:L379-400
Last Updated 2026-02-10

Description

The training loop with ORTModule follows standard PyTorch patterns: forward pass, loss computation, backward pass, optimizer step. ORTModule transparently intercepts forward and backward calls, executing them through ORT's optimized engine while maintaining full PyTorch API compatibility.

The loop structure is identical to a native PyTorch training loop:

  1. Forward Pass -- loss = model(inputs) triggers ORTModule to execute the ONNX graph via ORT's engine rather than PyTorch's autograd.
  2. Backward Pass -- loss.backward() triggers ORT to compute gradients through the optimized backward graph.
  3. Optimizer Step -- optimizer.step() updates parameters. This can use standard PyTorch optimizers or ORT's FusedAdam.
  4. Gradient Reset -- optimizer.zero_grad() clears gradients for the next iteration.

The key insight is that ORTModule maintains full PyTorch API compatibility. Existing training scripts require only two lines of modification: importing ORTModule and wrapping the model. All other training code (data loading, loss computation, optimizer, scheduler) remains unchanged.
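The interception mechanism behind this two-line change can be illustrated with a stdlib-only sketch: a wrapper object that preserves the wrapped model's call interface while routing execution through itself. This is not the real ORTModule (which delegates to ORT's execution engine), only a toy stand-in showing why the surrounding training code needs no changes.

```python
class ToyModel:
    """Stand-in for an nn.Module: callable, returns a 'loss'."""
    def __call__(self, x):
        return sum(x)

class ToyAcceleratedWrapper:
    """Stand-in for ORTModule: same call interface, delegated execution."""
    def __init__(self, module):
        self._module = module
        self.calls_intercepted = 0

    def __call__(self, *args, **kwargs):
        # In ORTModule this is the point where execution is handed to ORT's
        # engine; here we just record the interception and delegate back.
        self.calls_intercepted += 1
        return self._module(*args, **kwargs)

model = ToyModel()
model = ToyAcceleratedWrapper(model)   # the single "wrap the model" line

loss = model([1, 2, 3])                # calling code is unchanged
```

Because the wrapper is itself callable with the same signature, everything downstream of the wrap (loss computation, optimizer, scheduler) is oblivious to the substitution.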

Environment Variables

Several environment variables control ORTModule behavior during the training loop:

Variable Default Description
ORTMODULE_ENABLE_COMPUTE_OPTIMIZER 1 (enabled) Enables graph-level compute optimizations
ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT 0 (disabled) Enables memory-efficient gradient management that releases ORT gradient buffers early
ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER 1 (enabled) Enables sparse optimization for embedding layers
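As a minimal sketch, these variables can be set from the training script before ORTModule is imported, since they are read at module initialization time (the commented-out import marks where the real import would go):

```python
import os

# Defaults shown in the table above; set before importing ORTModule.
os.environ["ORTMODULE_ENABLE_COMPUTE_OPTIMIZER"] = "1"        # default: enabled
os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"  # opt in (default: 0)
os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] = "1"

# from onnxruntime.training.ortmodule import ORTModule  # import after setting
```

Setting them in the launching shell (e.g. `export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1`) is equivalent.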

Theoretical Basis

The training loop implements the standard stochastic gradient descent paradigm:

  • Loss Minimization -- Each iteration moves model parameters in the direction that locally minimizes the loss function, as estimated by the gradient on the current mini-batch.
  • Transparent Acceleration -- ORTModule's interception of forward/backward passes is possible because PyTorch's nn.Module.__call__ mechanism allows the forward method to be overridden. ORTModule replaces the forward method with one that delegates to ORT's execution engine.
  • Graph Caching -- The ONNX graph is exported once (on the first forward call) and cached for subsequent iterations. Re-export occurs only when the graph structure changes, minimizing the amortized overhead of ONNX conversion.

Usage

# Build the model as usual; build_model() is a placeholder for your own
# model constructor.
model = build_model()

# The only two required changes: import ORTModule and wrap the model.
from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(model)

# Optional: ORT's FusedAdam as a drop-in replacement optimizer.
from onnxruntime.training.optim import FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1e-4)

# Optional: DeepSpeed integration. args, lr_scheduler, and mpu are assumed
# to be defined elsewhere in the training script.
import deepspeed
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model, optimizer=optimizer, args=args,
    lr_scheduler=lr_scheduler, mpu=mpu, dist_init_required=False,
)

# Optional: FP16 optimizer wrapper for mixed-precision training.
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
optimizer = FP16_Optimizer(optimizer)

# Standard training loop -- unchanged from native PyTorch. The model is
# assumed to return the loss directly from its forward pass.
for epoch in range(num_epochs):
    for batch in train_dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Implemented By

Implementation:Microsoft_Onnxruntime_ORTModule_Training_Execution
