# Heuristic: Microsoft ONNX Runtime Convergence Debugging Tips
| Field | Value |
|---|---|
| Sources | docs/ORTModule_Convergence_Notes.md, docs/ORTModule_Training_Guidelines.md (L52-133, L199-209) |
| Domains | Training, Debugging, Convergence Analysis, Numerical Validation |
| Last Updated | 2026-02-10 |
## Overview
Systematically diagnose and resolve convergence discrepancies between ORTModule and PyTorch training by collecting activation statistics and eliminating sources of non-determinism.
## Description
When training with ORTModule, convergence issues may manifest as large discrepancies in training loss, evaluation loss, or model-specific AUC metrics compared to a PyTorch baseline. These differences can stem from several sources: non-deterministic operations (dropout, random initialization), compute optimizations that alter floating-point accumulation order, or genuine bugs in graph export or operator implementation.
ONNX Runtime provides a structured debugging workflow centered on the GlobalSubscriberManager and StatisticsSubscriber classes. These tools hook into nn.Module forward outputs to collect activation statistics (min, max, mean, norm, etc.) at each training step, enabling side-by-side comparison between PyTorch and ORTModule runs. For intermediate tensors within a module's forward function (not just the outputs), the inspect_activation utility can be inserted directly into the model code to dump specific named tensors.
The debugging process follows a systematic elimination approach: first verify that the discrepancy is not due to randomness, then remove all sources of non-determinism, and finally use activation statistics to pinpoint the exact module and step where outputs diverge.
## Usage
Use this heuristic when:
- Training loss diverges between ORTModule and PyTorch after a certain number of steps.
- Evaluation metrics (AUC, accuracy, etc.) are significantly worse with ORTModule.
- Runtime failures occur (e.g., loss scaler reaching minimum and raising an exception).
- You need to validate that ORTModule produces numerically equivalent results to PyTorch.
## The Insight (Rule of Thumb)
Follow this systematic debugging workflow:
### Step 1: Rule out randomness
Change the seed for the baseline PyTorch run. If the metric difference is comparable in magnitude to the ORTModule discrepancy, the issue is likely random variation, not a bug.
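This check can be sketched as a small helper: compare the ORTModule metric against the spread of several differently seeded PyTorch baselines. The function name and the `factor` threshold below are illustrative choices, not part of ONNX Runtime.

```python
def exceeds_seed_noise(baseline_metrics, ort_metric, factor=2.0):
    """baseline_metrics: final losses from several PyTorch runs with different seeds."""
    lo, hi = min(baseline_metrics), max(baseline_metrics)
    spread = hi - lo
    # Flag only when the ORT result lies well outside the seed-induced band.
    return ort_metric > hi + factor * spread or ort_metric < lo - factor * spread

# Example: three PyTorch seeds give eval losses in [2.31, 2.35].
pt_losses = [2.31, 2.33, 2.35]
print(exceeds_seed_noise(pt_losses, 2.90))  # far outside the band: likely a real issue
print(exceeds_seed_noise(pt_losses, 2.36))  # within random variation
```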
### Step 2: Remove non-determinism
- Set identical seeds for both PyTorch and ORTModule runs.
- Set dropout ratio to 0.
- Enable deterministic compute: set `ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0` (disables ORT compute optimizations to guarantee exact parity with PyTorch).
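The items above can be bundled into one setup helper. This is a sketch (the function name is ours); the cuDNN flags are an extra precaution on the PyTorch side beyond the list above.

```python
import os
import random

import numpy as np
import torch


def make_deterministic(seed: int = 42) -> None:
    # Same seed for every RNG that can influence training on either path.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op on CPU-only machines
    # Disable ORT compute optimizations so the floating-point accumulation
    # order matches PyTorch. Must be set before ORTModule wraps the model.
    os.environ["ORTMODULE_ENABLE_COMPUTE_OPTIMIZER"] = "0"
    # Extra precaution: force deterministic cuDNN kernels in the baseline run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Dropout still has to be zeroed in the model configuration itself, since it is a model hyperparameter rather than an RNG seed.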
### Step 3: Collect activation statistics

For the PyTorch baseline run:

```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)
```

For the ORTModule run, wrap the model first, then subscribe:

```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)
```
### Step 4: Inspect intermediate tensors (if needed)

```python
from onnxruntime.training.utils.hooks import inspect_activation

def forward(self, input_ids, ...):
    hidden_states = self.transformer(...)
    lm_logits = self.lm_head(hidden_states)
    lm_logits = inspect_activation("lm_logits", lm_logits)  # dumps stats to output_dir
    ...
```
Ensure each activation name is unique; otherwise, statistics files will be overwritten.
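One way to guarantee uniqueness is a small counter-based wrapper around the names you pass in. This helper is our own convenience, not an onnxruntime API:

```python
from collections import defaultdict
from itertools import count

# One independent counter per base name.
_name_counters = defaultdict(count)


def unique_name(base: str) -> str:
    """Append a running index so repeated dumps never share a file name."""
    return f"{base}_{next(_name_counters[base])}"

# unique_name("lm_logits") yields "lm_logits_0", then "lm_logits_1", ...
```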
### Step 5: Multi-rank collection

For distributed training, use a unique output_dir per rank to avoid file write conflicts:

```python
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(
        output_dir="ort_out_" + str(torch.distributed.get_rank()),
        override_output_dir=True,
    )]
)
```
### Step 6: Generate per-step comparison

```shell
python -m onnxruntime.training.utils.hooks.merge_activation_summary \
    --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
Manually review the summary to identify the first step and module where a large divergence appears.
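The manual review can be partly automated with a generic directory diff. The sketch below assumes a flat "one numeric value per line, one file per activation" dump layout, which is an illustration rather than the exact on-disk format the tools produce; adapt the parsing to what you actually find in the output directories.

```python
from pathlib import Path


def first_divergence(pt_dir: str, ort_dir: str, atol: float = 1e-4):
    """Return (file name, detail) for the first mismatched dump, or None if all match."""
    for pt_file in sorted(Path(pt_dir).rglob("*")):
        if not pt_file.is_file():
            continue
        ort_file = Path(ort_dir) / pt_file.relative_to(pt_dir)
        if not ort_file.is_file():
            return (pt_file.name, "missing on ORT side")
        pt_lines = pt_file.read_text().splitlines()
        ort_lines = ort_file.read_text().splitlines()
        for pt_line, ort_line in zip(pt_lines, ort_lines):
            try:
                if abs(float(pt_line) - float(ort_line)) > atol:
                    return (pt_file.name, f"{pt_line} vs {ort_line}")
            except ValueError:
                continue  # non-numeric line (header, label); skip it
    return None
```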
Log levels for debugging:

| Level | Purpose |
|---|---|
| WARNING | Default for users; minimal output |
| INFO | Experimental feature statistics; slightly more error detail |
| DEVINFO | Recommended for debugging; enables ORT backend INFO logs and multi-rank logging |
| VERBOSE | Last resort for hard problems; maximum log output |
Key environment variables:

- `ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0` -- Disable compute optimizations for exact numerical parity with PyTorch.
- `ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"` -- Prevent silent fallback to PyTorch (ensures you are actually testing ORT).
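A minimal shell setup applying both variables before launching training (the script name is a placeholder for your entry point):

```shell
# Deterministic ORT compute, and fail loudly instead of silently using PyTorch.
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=0
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
# python train.py  # placeholder for the actual training launch command
```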
## Reasoning
Convergence debugging is challenging because discrepancies can originate from many sources: random initialization, non-deterministic GPU operations, floating-point accumulation order differences due to operator fusion, or genuine export/execution bugs. The systematic elimination approach works by progressively reducing the search space.

First, establishing that the discrepancy exceeds what random variation alone would produce confirms that there is a real issue to investigate. Then, removing all sources of non-determinism (seeds, dropout, compute optimizations) creates a controlled environment where any remaining difference must come from the ORT execution path itself.

The GlobalSubscriberManager with StatisticsSubscriber provides an automated way to collect comparable statistics without manually modifying every module, while inspect_activation enables targeted inspection of specific tensors within complex forward functions. The per-rank output directory pattern prevents file corruption in multi-GPU distributed training, which is the environment where most convergence issues are encountered in practice. The merge_activation_summary command produces a unified comparison that highlights divergent activations, making it straightforward to trace the issue back to a specific layer and step.