# Heuristic: Microsoft ONNX Runtime Memory Recomputation Optimization
| Field | Value |
|---|---|
| Sources | docs/Memory_Optimizer.md, docs/ORTModule_Training_Guidelines.md (L285-305) |
| Domains | Training, Memory Management, GPU Optimization |
| Last Updated | 2026-02-10 |
## Overview
Trade GPU compute time for reduced VRAM consumption by selectively recomputing activations instead of stashing them during the forward pass.
## Description
ONNX Runtime Training provides a memory optimizer that trades node and subgraph recomputations for better memory efficiency. During the forward pass of training, intermediate activations are normally saved ("stashed") for use in the backward pass. When GPU memory is scarce, the memory optimizer can instead discard some of these activations and recompute them on demand during the backward pass. This is conceptually equivalent to PyTorch's gradient checkpointing but operates at the ONNX graph level, giving ORT fine-grained control over exactly which subgraphs to recompute.
The optimizer works by scanning the execution graph to identify all recomputable subgraph candidates from a pre-defined list of recomputable operators. Each candidate is represented by a cluster id (e.g., BiasGelu+, BiasSoftmax+) and reports how much memory it can save. Users then choose which subgraphs to recompute, either automatically via optimization levels or manually via a JSON configuration file. The typical trade-off is roughly 20% additional compute overhead in exchange for substantial memory savings that allow larger batch sizes or bigger models.
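The selection step above can be pictured as a simple greedy loop: given candidate clusters with their reported savings, enable the biggest savers until the memory deficit is covered. This is only a sketch of the idea, not the optimizer's actual algorithm; the cluster ids mimic the real naming, but the numbers are invented for illustration.

```python
# Illustrative sketch of subgraph selection (NOT ORT's real algorithm).
# Savings are hypothetical per-cluster figures in MB.
def pick_clusters(candidates, memory_deficit_mb):
    """Greedily enable the largest savers until the deficit is covered."""
    chosen = []
    saved = 0
    for name, saving in sorted(candidates.items(), key=lambda kv: -kv[1]):
        if saved >= memory_deficit_mb:
            break
        chosen.append(name)
        saved += saving
    return chosen, saved

candidates = {"BiasGelu+": 512, "BiasSoftmax+": 384, "Dropout+": 128}
chosen, saved = pick_clusters(candidates, memory_deficit_mb=800)
# chosen == ["BiasGelu+", "BiasSoftmax+"], saved == 896
```

In practice the optimizer reports each cluster's savings for you (see the DebugOptions note in Reasoning); the point is that recompute targets can be ranked and chosen incrementally rather than all-or-nothing.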
## Usage
Use this heuristic when:
- OOM errors occur during training, even with the minimum required batch size.
- You want to increase the batch size (e.g., from 2^N to 2^(N+1)) but GPU memory is the bottleneck.
- The model is too large to fit on a single GPU without memory optimization.
- GPU compute and memory bandwidth are not fully saturated at the current batch size.
Do not use this heuristic when:
- Training already fully saturates GPU compute and memory at the current batch size (e.g., batch size 6 with full utilization; bumping to 8 with recompute may not yield better throughput).
- The model is small enough that memory is not a constraint.
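The saturated-GPU caveat can be made concrete with back-of-envelope arithmetic. All numbers below are illustrative, not measurements: if the GPU is already fully utilized, step time grows linearly with batch size, so a larger batch adds no throughput, and recompute overhead only subtracts from it.

```python
# Back-of-envelope throughput check (illustrative numbers, not measurements).
def throughput(batch_size, step_time_s):
    return batch_size / step_time_s  # samples per second

# Baseline: batch 6 on a fully saturated GPU, 1.0 s per step.
base = throughput(batch_size=6, step_time_s=1.0)  # 6.0 samples/s

# Batch 8 with recompute: step time scales with batch (GPU was already
# saturated) and recompute adds ~20% compute on top.
with_recompute = throughput(batch_size=8, step_time_s=(8 / 6) * 1.0 * 1.20)

assert with_recompute < base  # 5.0 < 6.0: recompute does not pay off here
```

The trade only wins when the GPU has idle capacity, so the larger batch raises utilization faster than the recompute overhead eats into it.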
## The Insight (Rule of Thumb)
Set `ORTMODULE_MEMORY_OPT_LEVEL=1` as a first step when hitting OOM during training. This enables transformer layerwise recompute across all detected subgraphs within each transformer layer, which is the ORT equivalent of PyTorch gradient checkpointing. If memory savings are still insufficient, escalate to `ORTMODULE_MEMORY_OPT_LEVEL=2` to also include compromised recomputable subgraphs (which save a portion, not all, of their stashed activations). For fine-grained control, use `ORTMODULE_MEMORY_OPT_LEVEL=0` with `ORTMODULE_MEMORY_OPT_CONFIG=<path to JSON>` to hand-pick which subgraphs to recompute and how many occurrences to target, iterating until you find the best plan. Additionally, enable `ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1` to release gradient buffers early, further reducing memory peaks.
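In a Python training script, these variables must be in the environment before ORTModule builds its graph. A minimal sketch of the escalation's first step (the wrap itself is left as a comment because it requires onnxruntime-training and a model):

```python
import os

# Step 1 of the escalation path: layerwise recompute plus early gradient
# release. Set these BEFORE constructing ORTModule so the graph
# transformer sees them.
os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = "1"
os.environ["ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT"] = "1"

# With onnxruntime-training installed, wrapping is then a one-liner:
# from onnxruntime.training.ortmodule import ORTModule
# model = ORTModule(model)
```

If this still OOMs, bump `ORTMODULE_MEMORY_OPT_LEVEL` to `"2"`, or drop to `"0"` plus a config file for manual selection.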
Configuration format: the JSON config file is an array of strings, each of the form `"ClusterID:strategy:count"` (cluster ids themselves end in `+`, e.g. `BiasGelu+`), where strategy is `0` (disabled), `1` (recompute), or `2` (compromised recompute), and count is the number of occurrences to apply (`-1` for all).
Example:

```shell
export ORTMODULE_MEMORY_OPT_LEVEL=0
export ORTMODULE_MEMORY_OPT_CONFIG="mem_opt.json"
```

Contents of mem_opt.json:

```json
[
  "BiasGelu+:1:1",
  "Dropout+:1:-1"
]
```
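A hypothetical helper (not part of the ORT API) shows how such entries decompose, following the `ClusterID:strategy:count` format described above:

```python
# Hypothetical helper, not part of the ORT API: split and validate one
# memory-optimizer config entry of the form "ClusterID:strategy:count".
def parse_entry(entry: str):
    """Return (cluster_id, strategy, count) for one config entry."""
    cluster_id, strategy, count = entry.rsplit(":", 2)
    strategy, count = int(strategy), int(count)
    if strategy not in (0, 1, 2):  # 0=disabled, 1=recompute, 2=compromised
        raise ValueError(f"unknown strategy {strategy} in {entry!r}")
    return cluster_id, strategy, count

parse_entry("BiasGelu+:1:1")   # ("BiasGelu+", 1, 1): recompute 1 occurrence
parse_entry("Dropout+:1:-1")   # ("Dropout+", 1, -1): recompute all occurrences
```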
Key environment variables:

- `ORTMODULE_MEMORY_OPT_LEVEL=0` -- manual subgraph selection (default)
- `ORTMODULE_MEMORY_OPT_LEVEL=1` -- transformer layerwise recompute (excludes compromised subgraphs)
- `ORTMODULE_MEMORY_OPT_LEVEL=2` -- all recomputable subgraphs, including compromised ones
- `ORTMODULE_MEMORY_OPT_CONFIG=<path>` -- JSON file specifying which subgraphs to recompute
- `ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1` -- early gradient buffer release
## Reasoning
Modern transformer models generate large amounts of intermediate activations during the forward pass that must be retained for gradient computation in the backward pass. For models with many layers, this memory consumption scales linearly with depth and can easily exceed GPU VRAM. Recomputation addresses this by discarding activations and recomputing them when needed, effectively trading compute cycles (relatively cheap on modern GPUs) for memory (the scarce resource).

The ORT memory optimizer improves on naive gradient checkpointing by operating at the subgraph level, allowing users to select only the most memory-hungry subgraphs for recomputation while leaving others untouched. The compromised recompute option (strategy 2) offers a middle ground in which only a portion of the stashed activation (e.g., half) is saved, balancing memory savings against compute overhead.

Using `DebugOptions(log_level=LogLevel.DEVINFO)` produces detailed tables showing each subgraph's frequency, estimated memory savings (both absolute and symbolic), and optimization status.
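The linear-scaling argument can be quantified with a toy model (all sizes illustrative): stashing keeps every layer's activations, while layerwise recompute keeps only each layer's boundary input and re-materializes one layer's worth of activations at a time during backward.

```python
# Toy activation-memory model for layerwise recompute (illustrative sizes).
def stash_all(n_layers, act_mb_per_layer):
    # Every intermediate activation is kept for backward.
    return n_layers * act_mb_per_layer

def layerwise_recompute(n_layers, act_mb_per_layer, boundary_mb):
    # Keep only each layer's input (boundary tensor) plus one layer's
    # worth of activations, re-materialized while its backward runs.
    return n_layers * boundary_mb + act_mb_per_layer

full = stash_all(n_layers=48, act_mb_per_layer=400)    # 19200 MB
ckpt = layerwise_recompute(48, 400, boundary_mb=16)    #  1168 MB
```

Under these made-up numbers, peak activation memory drops by more than 15x, which is why layerwise recompute (level 1) is a sensible first response to OOM before resorting to manual per-subgraph tuning.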