
Principle:Microsoft Onnxruntime Memory Optimization

From Leeroopedia


Overview

Reduction of GPU memory consumption during training through activation recomputation strategies.

Metadata

Principle Name: Memory_Optimization
Category: Pattern Doc
Domain: Accelerated_Training, PyTorch_Integration
Repository: microsoft/onnxruntime
Source Reference: docs/Memory_Optimizer.md:L33-34 (level), L81-96 (config)
Last Updated: 2026-02-10

Description

ONNX Runtime's memory optimizer trades compute for memory by selectively recomputing activations during the backward pass instead of storing them. Three levels are available: disabled, automatic transformer-layer recompute, and manual subgraph-level recompute via configuration.

The memory optimizer is implemented as a graph transformer that identifies re-computable subgraphs -- groups of connected operators whose intermediate outputs can be discarded during the forward pass and recomputed during the backward pass when needed for gradient calculation.
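The idea can be sketched in a few lines of plain Python (no ONNX Runtime involved; `subgraph` and `train_step` are illustrative names, and the real re-computable clusters are groups of ONNX operators, not Python functions):

```python
def subgraph(x):
    # stand-in for a fused re-computable cluster such as BiasGelu+
    # (illustrative only -- real clusters are ONNX operator groups)
    return x * x + 1.0

def train_step(x, recompute):
    # forward pass: either keep the subgraph's intermediate output
    # for the backward pass, or discard it to free memory
    act = subgraph(x)
    stashed = None if recompute else act
    out = act * 3.0
    # backward pass: the activation is needed for gradient computation;
    # if it was discarded, re-run the subgraph on demand
    act_for_grad = stashed if stashed is not None else subgraph(x)
    grad_x = 3.0 * (2.0 * x)  # d(out)/dx = 3 * d(subgraph)/dx
    return out, grad_x, act_for_grad

# Both modes produce identical results; recompute trades one extra
# subgraph evaluation for not holding `act` between the two passes.
print(train_step(2.0, recompute=False))  # (15.0, 12.0, 5.0)
print(train_step(2.0, recompute=True))   # (15.0, 12.0, 5.0)
```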

Optimization Levels

Level 0 -- Disabled / User-Selected: Memory optimization is off by default. When ORTMODULE_MEMORY_OPT_CONFIG is also set, this level enables user-selected subgraph recompute.
Level 1 -- Transformer Layerwise Recompute: Aggressively recomputes all supported nodes within each transformer layer (attention and MLP sublayers). User configuration is not respected in this mode.
Level 2 -- Compromised Recompute: Extends level 1 to also include "compromised" re-computable subgraphs, which may have less favorable compute-to-memory trade-offs.

Typical Use Cases

  • Training with ORTModule at batch size B, where GPU memory and compute are not fully saturated, but attempting batch size 2B causes OOM.
  • Large models where even the minimum allowed batch size exceeds available GPU memory.

Not all models benefit from this optimization. If the current batch size fully saturates GPU compute and memory bandwidth, enabling recompute with a larger batch size may not improve throughput.

Theoretical Basis

Activation checkpointing (gradient checkpointing) reduces peak activation memory from O(n) to O(sqrt(n)) for an n-layer network by discarding intermediate activations and recomputing them during the backward pass. The trade-off is additional forward computation time.

  • Memory-Compute Trade-off -- During the forward pass of a standard training pipeline, all intermediate activations are stored in memory for use during the backward pass. For deep networks, this can consume the majority of GPU memory. By discarding some activations and recomputing them on demand, peak memory consumption is reduced at the cost of additional forward computation.
  • Subgraph Granularity -- Unlike PyTorch's layer-level gradient checkpointing, ORT's memory optimizer operates at the subgraph level. It identifies specific operator clusters (e.g., BiasGelu+, BiasSoftmax+) and selectively recomputes them. This provides finer-grained control over the memory-compute trade-off.
  • Symbolic Size Analysis -- The optimizer computes symbolic expressions for the memory savings of each recomputable subgraph. This allows users to make informed decisions about which subgraphs to enable based on their actual batch size and sequence length.
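The O(sqrt(n)) figure follows from a simple counting argument, sketched below with a toy cost model (the `peak_activations` helper is illustrative, not an ORT API): keeping only every segment boundary costs n/s stored activations, plus s more while one segment is re-materialized, and the sum n/s + s is minimized at s = sqrt(n).

```python
import math

def peak_activations(n_layers, segment=None):
    """Toy count of activations resident at the backward-pass peak.

    Without checkpointing, every layer output is kept: n_layers.
    With segment-wise checkpointing, only segment boundaries are kept
    (n_layers // segment) plus one segment re-materialized at a time.
    """
    if segment is None:
        return n_layers
    return n_layers // segment + segment

n = 64
s = int(math.sqrt(n))                 # optimal segment length ~ sqrt(n)
print(peak_activations(n))            # 64 activations without checkpointing
print(peak_activations(n, s))         # 8 + 8 = 16, i.e. O(sqrt(n))
```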

Usage

Simple Mode (Transformer Layerwise Recompute)

export ORTMODULE_MEMORY_OPT_LEVEL=1
# Run training as usual -- all supported nodes within transformer layers will be recomputed
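The same setting can be made from Python instead of the shell, provided the variable is set before the ORTModule-wrapped model is built (a minimal sketch; the commented import shows where ORTModule would normally come in):

```python
import os

# Must be set before ORTModule wraps the model, since ORT reads the
# environment when the training session is constructed.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"

# from onnxruntime.training.ortmodule import ORTModule
# model = ORTModule(model)   # then train as usual
```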

Advanced Mode (User-Selected Subgraph Recompute)

export ORTMODULE_MEMORY_OPT_LEVEL=0
export ORTMODULE_MEMORY_OPT_CONFIG="mem_opt.json"

Where mem_opt.json contains:

[
    "BiasGelu+:1:1",
    "Dropout+:1:-1"
]

The configuration format is: <cluster_id>:<strategy>:<count> where strategy 0=none, 1=recompute, 2=compromised recompute, and count specifies how many occurrences to apply (-1 for all).
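A small stdlib-only helper can validate entries of this form and emit the config file. Note that `parse_entry` and `STRATEGIES` are hypothetical names for illustration, not part of ONNX Runtime:

```python
import json

# Strategy codes from the config format described above.
STRATEGIES = {0: "none", 1: "recompute", 2: "compromised recompute"}

def parse_entry(entry):
    """Split '<cluster_id>:<strategy>:<count>' (hypothetical helper)."""
    cluster, strategy, count = entry.rsplit(":", 2)
    strategy, count = int(strategy), int(count)
    assert strategy in STRATEGIES
    assert count == -1 or count >= 1   # -1 means "all occurrences"
    return cluster, STRATEGIES[strategy], count

entries = ["BiasGelu+:1:1", "Dropout+:1:-1"]
for e in entries:
    print(parse_entry(e))
# ('BiasGelu+', 'recompute', 1)
# ('Dropout+', 'recompute', -1)

# Write the list in the JSON shape shown above.
with open("mem_opt.json", "w") as f:
    json.dump(entries, f, indent=4)
```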

Implemented By

Implementation:Microsoft_Onnxruntime_Memory_Opt_Env_Config
