Implementation:Microsoft Onnxruntime GlobalSubscriberManager Usage
Overview
Uses GlobalSubscriberManager and StatisticsSubscriber to capture per-step activation statistics for comparing convergence behavior between native PyTorch and ORTModule training, with inspect_activation for intermediate tensor inspection.
Metadata
| Field | Value |
|---|---|
| Implementation Name | GlobalSubscriberManager_Usage |
| Type | API Doc |
| Language | Python |
| API | onnxruntime.training.utils.hooks.GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir)]), inspect_activation(name, tensor) |
| Import | from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber |
| Domain | Accelerated_Training, PyTorch_Integration |
| Repository | microsoft/onnxruntime |
| Source Reference | docs/ORTModule_Convergence_Notes.md:L33-38 (subscriber), L91-115 (inspect_activation) |
| Last Updated | 2026-02-10 |
Description
This implementation provides two mechanisms for capturing activation data during training:
GlobalSubscriberManager
The GlobalSubscriberManager subscribes to nn.Module forward outputs across the entire model hierarchy. When a StatisticsSubscriber is attached, it captures per-step tensor statistics (min, max, mean, standard deviation, etc.) and writes them to files organized by step number in the specified output directory.
Key characteristics:
- Can be subscribed before or after wrapping with ORTModule
- Works identically on both native PyTorch and ORTModule execution paths
- Only captures nn.Module forward output tensors (not intermediate values)
- Does not handle multi-rank race conditions -- separate output directories must be used per rank
inspect_activation
For capturing intermediate tensors within a module's forward() method, inspect_activation is inserted inline. It takes a unique name and a tensor, records its statistics, and returns the tensor unchanged (pass-through). The statistics are stored in the output_dir configured by the subscriber.
Each activation name must be unique across the entire model to prevent file overwrites.
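To make the pass-through contract concrete, here is an illustrative stand-in (the function name `inspect_activation_sketch` and its printed output are ours, not part of the onnxruntime API; the real helper additionally writes statistics files to the subscriber's output_dir). It shows why the call can be inserted inline without altering the forward computation:

```python
import torch


def inspect_activation_sketch(name: str, tensor: torch.Tensor) -> torch.Tensor:
    # Record summary statistics for the tensor under a unique name...
    stats = {
        "min": tensor.min().item(),
        "max": tensor.max().item(),
        "mean": tensor.float().mean().item(),
        "std": tensor.float().std().item(),
    }
    print(f"{name}: {stats}")
    # ...then return the tensor unchanged, so downstream code is unaffected.
    return tensor


x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
y = inspect_activation_sketch("demo_activation", x)
assert y is x  # pass-through: the very same tensor object comes back
```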
API Signature
GlobalSubscriberManager
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir="output_directory",
        override_output_dir=True,
        start_step=None,   # optional: first step to collect
        end_step=None,     # optional: last step (exclusive) to collect
        run_on_cpu=False,  # optional: run on CPU to avoid OOM
        bucket_size=None,  # optional: bucket size for statistic calculation
    )],
)
```
inspect_activation
```python
from onnxruntime.training.utils.hooks import inspect_activation

# Inside a module's forward() method:
tensor = inspect_activation("unique_activation_name", tensor)
```
Key Parameters
StatisticsSubscriber
| Parameter | Type | Description |
|---|---|---|
| output_dir | str | Directory where activation statistics files are stored |
| override_output_dir | bool | Whether to overwrite the output directory if it already exists |
| start_step | int (optional) | First training step at which to begin collecting statistics |
| end_step | int (optional) | Training step at which to stop collecting statistics (exclusive) |
| run_on_cpu | bool | Whether to run subscriber actions on CPU (last resort for OOM situations) |
| bucket_size | int (optional) | Size of the bucket for splitting statistic calculations |
inspect_activation
| Parameter | Type | Description |
|---|---|---|
| name | str | Unique identifier for the activation (used as filename) |
| tensor | torch.Tensor | The tensor to inspect (returned unchanged) |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | nn.Module model | Model to subscribe to (native or ORTModule-wrapped) |
| Output | Statistics files | Per-step directories containing activation summary files |
| Input (inspect_activation) | Named tensor | Intermediate tensor to capture statistics for |
| Output (inspect_activation) | Same tensor | Pass-through (tensor is returned unchanged) |
Code Reference
From docs/ORTModule_Convergence_Notes.md:
Baseline (Native PyTorch)
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)
```
ORTModule
```python
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)
```
inspect_activation
```diff
+ from onnxruntime.training.utils.hooks import inspect_activation

  class BloomForCausalLM(BloomPreTrainedModel):
      def forward(self, input_ids, ...):
          ...
          lm_logits = self.lm_head(hidden_states)
+         lm_logits = inspect_activation("lm_logits", lm_logits)
          shift_logits = lm_logits[..., :-1, :].contiguous()
+         shift_logits = inspect_activation("shift_logits", shift_logits)
          ...
```
Generating Comparison Summary
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary \
    --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
Usage Example
Complete Convergence Investigation Workflow
```python
import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import (
    GlobalSubscriberManager,
    StatisticsSubscriber,
    inspect_activation,
)

# Baseline run (native PyTorch)
model_pt = build_model()
GlobalSubscriberManager.subscribe(
    model_pt,
    [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True, end_step=100)],
)
# Run baseline training for 100 steps...

# ORTModule run
model_ort = build_model()
model_ort = ORTModule(model_ort)
GlobalSubscriberManager.subscribe(
    model_ort,
    [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True, end_step=100)],
)
# Run ORTModule training for 100 steps...

# Compare results:
# python -m onnxruntime.training.utils.hooks.merge_activation_summary \
#     --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/comparison
```
Multi-Rank Collection
```python
import torch.distributed
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

# One output directory per rank, since the subscriber does not handle
# multi-rank race conditions itself.
rank = torch.distributed.get_rank()
GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir=f"ort_out_{rank}",
        override_output_dir=True,
    )],
)
```
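Note that torch.distributed.get_rank() raises if no process group is initialized, so the snippet above assumes a distributed run. A defensive variant (the helper name `rank_suffix` and the rank-0 fallback are our assumption, not part of the documented workflow) also works in single-process debugging sessions:

```python
import torch.distributed as dist


def rank_suffix() -> int:
    # Use the real rank in distributed runs; fall back to 0 when no
    # process group is initialized (e.g. single-GPU debugging).
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


output_dir = f"ort_out_{rank_suffix()}"
```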
Implements
Principle:Microsoft_Onnxruntime_Training_Monitoring_and_Debugging
Related Pages
- ORTModule Training Execution -- The training loop being monitored
- ORTModule Wrap -- The ORTModule wrapping whose behavior is being debugged
- Heuristic:Microsoft_Onnxruntime_Convergence_Debugging_Tips