Implementation:Microsoft Onnxruntime GlobalSubscriberManager Usage

From Leeroopedia


Overview

Shows how to use GlobalSubscriberManager with StatisticsSubscriber to capture per-step activation statistics, so that convergence behavior can be compared between native PyTorch and ORTModule training, and how to use inspect_activation to inspect intermediate tensors.

Metadata

Field | Value
Implementation Name | GlobalSubscriberManager_Usage
Type | API Doc
Language | Python
API | onnxruntime.training.utils.hooks.GlobalSubscriberManager.subscribe(model, [StatisticsSubscriber(output_dir)]), inspect_activation(name, tensor)
Import | from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
Domain | Accelerated_Training, PyTorch_Integration
Repository | microsoft/onnxruntime
Source Reference | docs/ORTModule_Convergence_Notes.md:L33-38 (subscriber), L91-115 (inspect_activation)
Last Updated | 2026-02-10

Description

This implementation provides two mechanisms for capturing activation data during training:

GlobalSubscriberManager

The GlobalSubscriberManager subscribes to nn.Module forward outputs across the entire model hierarchy. When a StatisticsSubscriber is attached, it captures per-step tensor statistics (min, max, mean, standard deviation, etc.) and writes them to files organized by step number in the specified output directory.

Key characteristics:

  • Can be subscribed before or after wrapping with ORTModule
  • Works identically on both native PyTorch and ORTModule execution paths
  • Only captures nn.Module forward output tensors (not intermediate values)
  • Does not handle race conditions between ranks; use a separate output directory for each rank
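
To make "per-step tensor statistics" concrete, the following is a minimal pure-Python sketch of the kind of summary such a subscriber might record for one module output at one step. It is not the StatisticsSubscriber implementation: the function names summarize and write_step_summary, the step_<n>/<name>.json layout, and the JSON file format are all assumptions made for illustration, and a flat list of floats stands in for a tensor.

```python
import json
import math
import os
import tempfile


def summarize(values):
    """Return basic statistics for a flat list of floats (stand-in for a tensor)."""
    n = len(values)
    mean = sum(values) / n
    # Population standard deviation; the real subscriber may use a different estimator.
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    return {"min": min(values), "max": max(values), "mean": mean, "std": std}


def write_step_summary(output_dir, step, name, values):
    """Write one summary file under a per-step directory (hypothetical layout)."""
    step_dir = os.path.join(output_dir, f"step_{step}")
    os.makedirs(step_dir, exist_ok=True)
    path = os.path.join(step_dir, f"{name}.json")
    with open(path, "w") as f:
        json.dump(summarize(values), f)
    return path


# Example: record the output of a hypothetical "lm_head" module at step 0.
demo_dir = tempfile.mkdtemp()
demo_path = write_step_summary(demo_dir, 0, "lm_head", [1.0, 2.0, 3.0, 4.0])
with open(demo_path) as f:
    demo_stats = json.load(f)
```

Keeping one directory per step means the same step from two runs can be lined up side by side, which is the property the comparison workflow relies on.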

inspect_activation

For capturing intermediate tensors within a module's forward() method, inspect_activation is inserted inline. It takes a unique name and a tensor, records its statistics, and returns the tensor unchanged (pass-through). The statistics are stored in the output_dir configured by the subscriber.

Each activation name must be unique across the entire model to prevent file overwrites.
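
The pass-through contract and the uniqueness requirement can be illustrated with a small stand-in. The sketch below is a simplified model, not the library's code: make_inspector, inspect_values, and the in-memory records dict are hypothetical, and a list of floats replaces the tensor. It shows that the inspected value comes back as the same object, and that reusing a name replaces the earlier record, which is the analogue of a file overwrite.

```python
def make_inspector():
    """Build a hypothetical pass-through inspector backed by an in-memory dict."""
    records = {}

    def inspect_values(name, values):
        # Record a tiny summary keyed by name; a repeated name replaces the old entry.
        records[name] = {"min": min(values), "max": max(values)}
        # Pass-through: hand back the very same object so the forward pass is unchanged.
        return values

    return inspect_values, records


inspect_values, records = make_inspector()

logits = [0.5, -1.0, 2.0]
same = inspect_values("lm_logits", logits)

# Reusing the name silently drops the first record.
inspect_values("lm_logits", [10.0, 20.0])
```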

API Signature

GlobalSubscriberManager

from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir="output_directory",
        override_output_dir=True,
        start_step=None,        # optional: first step to collect
        end_step=None,          # optional: last step (exclusive) to collect
        run_on_cpu=False,       # optional: run on CPU to avoid OOM
        bucket_size=None,       # optional: bucket size for statistic calculation
    )],
)

inspect_activation

from onnxruntime.training.utils.hooks import inspect_activation

# Inside a module's forward() method:
tensor = inspect_activation("unique_activation_name", tensor)

Key Parameters

StatisticsSubscriber

Parameter | Type | Description
output_dir | str | Directory where activation statistics files are stored
override_output_dir | bool | Whether to overwrite the output directory if it already exists
start_step | int (optional) | First training step at which to begin collecting statistics
end_step | int (optional) | Training step at which to stop collecting statistics (exclusive)
run_on_cpu | bool | Whether to run subscriber actions on CPU (last resort for OOM situations)
bucket_size | int (optional) | Size of the bucket for splitting statistic calculations
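
The start_step and end_step parameters describe a half-open window of steps. A minimal sketch of that window follows, assuming start_step is inclusive, end_step is exclusive as stated above, and None means "no bound". The helper name in_collection_window is hypothetical and not part of onnxruntime.

```python
def in_collection_window(step, start_step=None, end_step=None):
    """Return True if `step` falls in [start_step, end_step); None means unbounded."""
    if start_step is not None and step < start_step:
        return False
    if end_step is not None and step >= end_step:
        return False
    return True


# Assuming steps are counted from zero, end_step=100 keeps steps 0 through 99.
collected = [s for s in range(0, 105) if in_collection_window(s, end_step=100)]
```

Under these assumptions, end_step=100 yields exactly 100 collected steps, and step 100 itself is skipped.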

inspect_activation

Parameter | Type | Description
name | str | Unique identifier for the activation (used as filename)
tensor | torch.Tensor | The tensor to inspect (returned unchanged)

I/O Contract

Direction | Type | Description
Input | nn.Module model | Model to subscribe to (native or ORTModule-wrapped)
Output | Statistics files | Per-step directories containing activation summary files
Input (inspect_activation) | Named tensor | Intermediate tensor to capture statistics for
Output (inspect_activation) | Same tensor | Pass-through (tensor is returned unchanged)

Code Reference

From docs/ORTModule_Convergence_Notes.md:

Baseline (Native PyTorch)

from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True)]
)

ORTModule

model = ORTModule(model)
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber
GlobalSubscriberManager.subscribe(
    model, [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True)]
)

inspect_activation

+from onnxruntime.training.utils.hooks import inspect_activation
class BloomForCausalLM(BloomPreTrainedModel):
  def forward(self, input_ids, ...):
    ...
    lm_logits = self.lm_head(hidden_states)
+   lm_logits = inspect_activation("lm_logits", lm_logits)
    shift_logits = lm_logits[..., :-1, :].contiguous()
+   shift_logits = inspect_activation("shift_logits", shift_logits)
    ...

Generating Comparison Summary

python -m onnxruntime.training.utils.hooks.merge_activation_summary \
    --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
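
Once both runs have produced summaries, the goal of the comparison is to find the earliest activation where the two backends disagree. The sketch below illustrates that idea on plain dictionaries. It does not reproduce the logic or output format of merge_activation_summary: the function name first_divergence, the tolerance values, and the sample numbers are all made up for illustration, and both runs are assumed to report the same set of statistics for each activation.

```python
import math


def first_divergence(pt_stats, ort_stats, rel_tol=1e-3, abs_tol=1e-6):
    """Return (activation, statistic) for the first mismatch, or None if all match.

    Both arguments map activation name -> {statistic name -> float},
    with activations listed in execution order.
    """
    for name, pt_summary in pt_stats.items():
        ort_summary = ort_stats.get(name)
        if ort_summary is None:
            # Activation missing from the ORT run; skip rather than guess.
            continue
        for stat, pt_value in pt_summary.items():
            if not math.isclose(pt_value, ort_summary[stat], rel_tol=rel_tol, abs_tol=abs_tol):
                return name, stat
    return None


# Illustrative (made-up) numbers: the two runs agree on the first activation only.
pt = {"lm_logits": {"mean": 0.1, "std": 1.0}, "shift_logits": {"mean": 0.2, "std": 1.1}}
ort = {"lm_logits": {"mean": 0.1, "std": 1.0}, "shift_logits": {"mean": 0.25, "std": 1.1}}
result = first_divergence(pt, ort)
```

Scanning in execution order matters because a mismatch early in the forward pass usually propagates to everything after it, so the first differing activation is the most useful starting point.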

Usage Example

Complete Convergence Investigation Workflow

import torch
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import (
    GlobalSubscriberManager,
    StatisticsSubscriber,
    inspect_activation,
)

# Baseline run (native PyTorch).
# build_model() is a placeholder for your own model constructor.
model_pt = build_model()
GlobalSubscriberManager.subscribe(
    model_pt,
    [StatisticsSubscriber(output_dir="pt_out", override_output_dir=True, end_step=100)],
)
# Run baseline training for 100 steps...

# ORTModule run
model_ort = build_model()
model_ort = ORTModule(model_ort)
GlobalSubscriberManager.subscribe(
    model_ort,
    [StatisticsSubscriber(output_dir="ort_out", override_output_dir=True, end_step=100)],
)
# Run ORTModule training for 100 steps...

# Compare results
# python -m onnxruntime.training.utils.hooks.merge_activation_summary \
#     --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/comparison

Multi-Rank Collection

import torch.distributed
from onnxruntime.training.utils.hooks import GlobalSubscriberManager, StatisticsSubscriber

rank = torch.distributed.get_rank()
GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir=f"ort_out_{rank}",
        override_output_dir=True,
    )],
)
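
If the process group is not yet initialized at the point where the subscriber is created, the rank can often be read from the environment instead. The sketch below assumes a launcher such as torchrun that sets the RANK environment variable. The helper name rank_output_dir and the fallback to rank 0 are illustrative choices, not part of onnxruntime.

```python
import os


def rank_output_dir(base_dir, env=None):
    """Return a per-rank directory name such as 'ort_out_3'."""
    env = os.environ if env is None else env
    # torchrun exports RANK; fall back to 0 for single-process runs.
    rank = int(env.get("RANK", "0"))
    return f"{base_dir}_{rank}"


# Passing an explicit mapping keeps the example independent of the real environment.
example_dir = rank_output_dir("ort_out", env={"RANK": "3"})
```

The returned string can be passed as output_dir to StatisticsSubscriber in place of the f-string shown above.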

Implements

Principle:Microsoft_Onnxruntime_Training_Monitoring_and_Debugging
