Heuristic:Microsoft Onnxruntime Convergence Debugging Tips

From Leeroopedia



Sources: docs/ORTModule_Convergence_Notes.md, docs/ORTModule_Training_Guidelines.md (L52-133, L199-209)
Domains: Training, Debugging, Convergence Analysis, Numerical Validation
Last Updated: 2026-02-10

Overview

Systematically diagnose and resolve convergence discrepancies between ORTModule and PyTorch training by collecting activation statistics and eliminating sources of non-determinism.

Description

When training with ORTModule, convergence issues may show up as large discrepancies in training loss, evaluation loss, or model-specific metrics such as AUC compared to a PyTorch baseline. These differences can stem from several sources: randomness (dropout, random weight initialization), compute optimizations that change the floating-point accumulation order, or genuine bugs in graph export or operator implementations.

ONNX Runtime provides a structured debugging workflow centered on the GlobalSubscriberManager and StatisticsSubscriber classes. These tools hook into nn.Module forward outputs to collect activation statistics (min, max, mean, norm, etc.) at each training step, enabling side-by-side comparison between PyTorch and ORTModule runs. For intermediate tensors within a module's forward function (not just the outputs), the inspect_activation utility can be inserted directly into the model code to dump specific named tensors.

The debugging process follows a systematic elimination approach: first verify that the discrepancy is not due to randomness, then remove all sources of non-determinism, and finally use activation statistics to pinpoint the exact module and step where outputs diverge.

Usage

Use this heuristic when:

  • Training loss diverges between ORTModule and PyTorch after a certain number of steps.
  • Evaluation metrics (AUC, accuracy, etc.) are significantly worse with ORTModule.
  • Runtime failures occur (e.g., loss scaler reaching minimum and raising an exception).
  • You need to validate that ORTModule produces numerically equivalent results to PyTorch.

The Insight (Rule of Thumb)

Follow this systematic debugging workflow:

Step 1: Rule out randomness

Rerun the PyTorch baseline with a different seed. If the metric difference between the two PyTorch runs is comparable in magnitude to the ORTModule discrepancy, the discrepancy is likely random variation, not a bug.
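The Step 1 comparison can be written as a small check. This is a library-free sketch, assuming you already have the same final metric (for example, evaluation loss) from two PyTorch runs with different seeds and from an ORTModule run that used the first seed; `within_seed_noise` and its `slack` factor are hypothetical names, not ONNX Runtime APIs.

```python
def within_seed_noise(pt_seed_a: float, pt_seed_b: float, ort: float, slack: float = 1.0) -> bool:
    """Return True if the ORTModule gap is no larger than seed-to-seed variation.

    pt_seed_a, pt_seed_b: metric from two PyTorch baseline runs with different seeds.
    ort: the same metric from the ORTModule run (same seed as pt_seed_a).
    slack: hypothetical tolerance factor applied to the seed-to-seed spread.
    """
    seed_spread = abs(pt_seed_a - pt_seed_b)  # variation caused by randomness alone
    ort_gap = abs(ort - pt_seed_a)            # discrepancy we are trying to explain
    return ort_gap <= slack * seed_spread


# Baseline seeds differ by about 0.02, ORTModule differs by about 0.15.
print(within_seed_noise(2.31, 2.33, 2.46))  # False: the gap exceeds seed noise, so investigate
```

With only two baseline seeds the spread is a rough estimate; metrics from more seeds make the comparison more reliable.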

Step 2: Remove non-determinism

  • Set identical seeds for both PyTorch and ORTModule runs.
  • Set dropout ratio to 0.
  • Make the computation comparable with PyTorch: set ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0, which disables ORT compute optimizations so that ORT performs the same computation as the PyTorch baseline.
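The three settings above can be applied at the top of a training script. The sketch below is hypothetical (the function name is illustrative, not an ONNX Runtime API) and assumes the environment variable is read when ORTModule is constructed, so it should run before the model is wrapped. It only zeroes dropout implemented as nn.Dropout-style modules; a model that calls functional dropout with a probability taken from its config needs that config value set to 0 instead.

```python
import os
import random


def remove_nondeterminism(seed: int, model=None) -> None:
    """Apply the Step 2 settings before wrapping the model with ORTModule.

    Seeds Python's RNG (and NumPy / PyTorch if installed), disables ORT compute
    optimizations via the environment, and sets p=0 on every dropout module of
    ``model`` if one is passed.
    """
    # Assumed to be read when ORTModule is constructed, so set it first.
    os.environ["ORTMODULE_ENABLE_COMPUTE_OPTIMIZER"] = "0"

    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and CUDA generators
    except ImportError:
        torch = None

    if model is not None and torch is not None:
        for module in model.modules():
            # _DropoutNd is the common base of nn.Dropout, Dropout1d/2d/3d and AlphaDropout.
            if isinstance(module, torch.nn.modules.dropout._DropoutNd):
                module.p = 0.0
```

Call it with the same seed in both the PyTorch and the ORTModule script, for example `remove_nondeterminism(42, model)` right after the model is built.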

Step 3: Collect activation statistics

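# PyTorch baseline run; `model` is your existing torch.nn.Module.
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)

# ORTModule run (a separate process with the same seeds and settings):
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)

Run both scripts up to the steps where the divergence appears. Each run writes its statistics to its output_dir (pt_out or ort_out), with one folder per step containing summaries of every activation tensor; override_output_dir=True lets a rerun replace an existing folder.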
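To make the collected numbers concrete, the sketch below computes a per-tensor summary of the kind listed in the Description (min, max, mean, norm) for a flattened activation given as a plain Python list. It is an illustration only: `summarize` is a hypothetical name, and the sketch does not reproduce the fields, precision, or file layout that StatisticsSubscriber actually writes.

```python
import math
from typing import Dict, Sequence


def summarize(values: Sequence[float]) -> Dict[str, float]:
    """Return min, max, mean and L2 norm of a flattened activation."""
    n = len(values)
    if n == 0:
        raise ValueError("empty activation")
    return {
        "min": min(values),
        "max": max(values),
        "mean": math.fsum(values) / n,                      # fsum reduces rounding error
        "norm": math.sqrt(math.fsum(v * v for v in values)),
    }


stats = summarize([0.5, -1.0, 2.0, 0.5])
print(stats["mean"])  # 0.5
```

Comparing such summaries between the two runs, rather than full tensors, keeps the dumps small enough to collect at every step.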

Step 4: Inspect intermediate tensors (if needed)

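from onnxruntime.training.utils.hooks import inspect_activation

# Excerpt from a model's forward(); other arguments and layers are elided.
def forward(self, input_ids, **kwargs):
    hidden_states = self.transformer(input_ids, **kwargs)
    lm_logits = self.lm_head(hidden_states)
    # Records statistics for this tensor in the registered StatisticsSubscriber's output_dir.
    lm_logits = inspect_activation("lm_logits", lm_logits)
    ...  # rest of forward() unchanged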

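Give each inspected activation a unique name; otherwise its statistics file is overwritten by the last write. inspect_activation works only after a StatisticsSubscriber has been registered through GlobalSubscriberManager, as in Step 3.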
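Names collide most easily when inspect_activation is called inside a loop, for example once per transformer layer. One way to catch this early is a small registry that rejects a name reused within the same forward pass. The class below is a hypothetical helper, not part of ONNX Runtime; the string it returns is what you would pass as the first argument to inspect_activation.

```python
from typing import Set


class ActivationNameRegistry:
    """Fail fast when an activation name is reused within one forward pass."""

    def __init__(self) -> None:
        self._seen: Set[str] = set()

    def start_forward(self) -> None:
        # Call at the top of forward(): names repeat across steps, never within one.
        self._seen.clear()

    def claim(self, name: str) -> str:
        if name in self._seen:
            raise ValueError(f"activation name {name!r} already used in this forward pass")
        self._seen.add(name)
        return name


registry = ActivationNameRegistry()
registry.start_forward()
for i in range(3):
    # Inside forward(): hidden = inspect_activation(registry.claim(f"layers.{i}.output"), hidden)
    registry.claim(f"layers.{i}.output")
```

Deriving the name from the layer index also keeps names identical between the PyTorch and ORTModule runs, which the later comparison relies on.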

Step 5: Multi-rank collection

For distributed training, use a unique output_dir per rank to avoid file write conflicts:

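import torch

# torch.distributed must already be initialized (e.g. via init_process_group).
# Use the same pattern (pt_out_<rank>) for the PyTorch baseline run.
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(
        output_dir="ort_out_" + str(torch.distributed.get_rank()),
        override_output_dir=True
    )]
)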
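If the subscription code can run before the process group is initialized, a fallback on the RANK environment variable (set by launchers such as torchrun) avoids a crash. This is a sketch with a hypothetical function name; it produces the same "ort_out_<rank>" pattern as above.

```python
import os


def per_rank_output_dir(prefix: str = "ort_out", rank=None) -> str:
    """Return e.g. 'ort_out_3' for rank 3.

    Uses torch.distributed when it is initialized, otherwise the RANK
    environment variable, otherwise rank 0.
    """
    if rank is None:
        try:
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized():
                rank = dist.get_rank()
        except ImportError:
            pass
    if rank is None:
        rank = int(os.environ.get("RANK", "0"))
    return f"{prefix}_{rank}"
```

Usage: `StatisticsSubscriber(output_dir=per_rank_output_dir("ort_out"), override_output_dir=True)`.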

Step 6: Generate per-step comparison

python -m onnxruntime.training.utils.hooks.merge_activation_summary \
    --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output

Manually review the summary to identify the first step and module where a large divergence appears.
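Part of that review can be scripted once the per-step statistics are loaded into Python. The sketch assumes a hypothetical in-memory layout, {step: {activation name: one statistic such as the mean}}, with names inserted in execution order; it does not parse the files produced by merge_activation_summary, and the tolerances are arbitrary starting points.

```python
import math
from typing import Dict, Optional, Tuple

# Hypothetical layout: step -> {activation name -> statistic}, names in execution order.
Stats = Dict[int, Dict[str, float]]


def first_divergence(
    pt: Stats, ort: Stats, rtol: float = 1e-3, atol: float = 1e-6
) -> Optional[Tuple[int, str, float, float]]:
    """Return (step, name, pt_value, ort_value) for the earliest mismatch, or None."""
    for step in sorted(pt.keys() & ort.keys()):
        ort_step = ort[step]
        # Dicts keep insertion order, so this walks activations in execution order.
        for name, a in pt[step].items():
            if name not in ort_step:
                continue
            b = ort_step[name]
            # A NaN on only one side never satisfies ">", so check it explicitly.
            nan_mismatch = math.isnan(a) != math.isnan(b)
            if nan_mismatch or abs(a - b) > atol + rtol * abs(a):
                return step, name, a, b
    return None
```

The first hit is only a candidate: a module can inherit a small upstream difference, so inspect the activations just before it as well.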

Log levels for debugging:

  • WARNING -- Default for users; minimal output.
  • INFO -- Experimental feature statistics; slightly more error detail.
  • DEVINFO -- Recommended for debugging; enables ORT backend INFO logs and multi-rank logging.
  • VERBOSE -- Last resort for hard problems; maximum log output.
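The level is chosen when the model is wrapped. The sketch below assumes the onnxruntime-training package, whose ortmodule module exposes DebugOptions and a LogLevel enum with the member names above; the wrapper function is hypothetical, and its import is deferred so that the name check runs even where ONNX Runtime is not installed.

```python
DOCUMENTED_LEVELS = ("WARNING", "INFO", "DEVINFO", "VERBOSE")


def debug_options_for(level_name: str = "DEVINFO"):
    """Build ORTModule DebugOptions for one of the documented log levels."""
    level_name = level_name.upper()
    if level_name not in DOCUMENTED_LEVELS:
        raise ValueError(f"unknown log level {level_name!r}; expected one of {DOCUMENTED_LEVELS}")
    # Deferred import: requires onnxruntime-training at call time.
    from onnxruntime.training.ortmodule import DebugOptions, LogLevel
    return DebugOptions(log_level=getattr(LogLevel, level_name))
```

Usage: `model = ORTModule(model, debug_options_for("DEVINFO"))`. Start at DEVINFO and move to VERBOSE only if the DEVINFO output is not enough.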

Key environment variables:

  • ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0 -- Disable ORT compute optimizations so that ORT performs the same computation as the PyTorch baseline.
  • ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE" -- Prevent silent fallback to PyTorch (ensures you are actually testing ORT).
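A variable exported in the wrong shell or set after startup silently has no effect, so a pre-flight check at the start of each run can save a wasted comparison. This is a minimal sketch with a hypothetical function name; it only compares the two variables listed above against their recommended values.

```python
import os
from typing import List, Mapping, Optional

# Values recommended above for parity debugging.
PARITY_ENV = {
    "ORTMODULE_ENABLE_COMPUTE_OPTIMIZER": "0",
    "ORTMODULE_FALLBACK_POLICY": "FALLBACK_DISABLE",
}


def parity_env_problems(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """List the parity-debugging variables that are missing or set differently."""
    env = os.environ if env is None else env
    problems = []
    for key, expected in PARITY_ENV.items():
        actual = env.get(key)
        if actual != expected:
            problems.append(f"{key}={actual!r} (expected {expected!r})")
    return problems
```

For example, call `parity_env_problems()` before wrapping the model and abort the run if the returned list is not empty.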

Reasoning

Convergence debugging is challenging because discrepancies can originate from many sources: random initialization, non-deterministic GPU operations, differences in floating-point accumulation order caused by operator fusion, or genuine export/execution bugs. The systematic elimination approach works by progressively narrowing the search space.

First, showing that the discrepancy exceeds what random variation alone produces confirms that there is a real issue to investigate. Then, removing the sources of non-determinism (seeds, dropout, compute optimizations) creates a controlled environment in which a remaining difference can be attributed to the ORT execution path itself.

GlobalSubscriberManager with StatisticsSubscriber collects comparable statistics for every module output without manually modifying each module, while inspect_activation enables targeted inspection of specific tensors inside complex forward functions. Per-rank output directories prevent file write conflicts in multi-GPU distributed training. Finally, merge_activation_summary lines up the PyTorch and ORTModule statistics step by step, so that a manual review can trace the divergence back to a specific layer and step.
