Heuristic: Microsoft ONNX Runtime Flash Attention Optimization
| Field | Value |
|---|---|
| Sources | docs/ORTModule_Training_Guidelines.md (L265-273, L469-476, L307-315) |
| Domains | Training, Attention Mechanisms, GPU Optimization, Triton |
| Last Updated | 2026-02-10 |
Overview
Enable Flash Attention or efficient attention kernels in ORTModule to accelerate transformer attention computation and reduce memory usage during training.
Description
Attention computation is typically the most memory-intensive and compute-heavy operation in transformer-based models. ONNX Runtime provides multiple attention optimization paths through ORTModule, each with different hardware requirements and trade-offs:
- Flash Attention (Triton-based) -- Uses OpenAI Triton to generate optimized Flash Attention kernels. This is the highest-performance option but requires both Triton support and CUDA compute capability 8.0 or above (NVIDIA Ampere architecture and newer, e.g., A100, A10, RTX 3090).
- Efficient Attention (PyTorch ATen kernel) -- Falls back to PyTorch's built-in efficient_attention ATen kernel. This is a more portable option that works on older hardware but requires PyTorch version 2.1.1 or above.
- Scaled Dot Product Attention fallback -- Pre-export fallback to PyTorch's _scaled_dot_product_efficient_attention for models using torch.nn.functional.scaled_dot_product_attention. This path has a key limitation: it only works for attention without masking (i.e., attn_mask=None).
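As a sketch of the SDPA-fallback eligibility rule above: a call with attn_mask=None qualifies for the pre-export substitution, while any explicit mask does not. The tensor shapes below are illustrative, not required by ORT:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 16, 8)
v = torch.randn(1, 4, 16, 8)

# Eligible for ORTMODULE_ATEN_SDPA_FALLBACK: no attention mask.
out_unmasked = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# NOT eligible: an explicit mask (here, a causal lower-triangular one)
# falls outside the fallback's supported pattern.
mask = torch.ones(16, 16, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```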
All three paths rely on attention fusion patterns that identify attention subgraphs in the ONNX graph and replace them with optimized implementations. ORT includes built-in patterns for common attention architectures, and users can add custom patterns for non-standard attention implementations.
Usage
Use this heuristic when:
- You are training transformer-based models (GPT, BERT, LLaMA, etc.) with ORTModule.
- You are running on Ampere or newer GPUs (for Flash Attention).
- You want to reduce the attention memory footprint (Flash Attention avoids materializing the full attention matrix).
- Attention computation is a significant portion of total training time (use GPU profiling to verify).
The Insight (Rule of Thumb)
Choose the attention optimization path based on your hardware and software environment:
Decision tree:
- If CUDA compute capability >= 8.0 (Ampere+) and Triton is available:
  - Set ORTMODULE_USE_TRITON=1 and ORTMODULE_USE_FLASH_ATTENTION=1
- Else if PyTorch >= 2.1.1:
  - Set ORTMODULE_USE_EFFICIENT_ATTENTION=1
- Else if the model uses scaled_dot_product_attention without masking:
  - Set ORTMODULE_ATEN_SDPA_FALLBACK=1
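The decision tree above can be sketched in Python. The environment variable names come from the source document; the Triton check, the device-capability check, and the helper name configure_ort_attention are assumptions for illustration:

```python
import importlib.util
import os

import torch


def _torch_version() -> tuple:
    # Parse "2.1.1", "2.1.1+cu121", etc.; non-numeric suffixes are dropped.
    parts = []
    for p in torch.__version__.split("+")[0].split(".")[:3]:
        digits = "".join(ch for ch in p if ch.isdigit())
        parts.append(int(digits or 0))
    return tuple(parts)


def configure_ort_attention() -> str:
    """Pick an ORTModule attention path for the current environment.

    Mirrors the decision tree above; returns which path was selected.
    """
    has_triton = importlib.util.find_spec("triton") is not None
    major = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0

    if major >= 8 and has_triton:
        os.environ["ORTMODULE_USE_TRITON"] = "1"
        os.environ["ORTMODULE_USE_FLASH_ATTENTION"] = "1"
        return "flash"
    if _torch_version() >= (2, 1, 1):
        os.environ["ORTMODULE_USE_EFFICIENT_ATTENTION"] = "1"
        return "efficient"
    # Last resort: only valid if the model calls
    # scaled_dot_product_attention with attn_mask=None.
    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"
    return "sdpa_fallback"
```

Note that these variables must be set before ORTModule is constructed, since they are read when the module is initialized.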
Environment variable configuration:
```shell
# Option 1: Flash Attention (best performance, requires Ampere+ and Triton)
export ORTMODULE_USE_TRITON=1
export ORTMODULE_USE_FLASH_ATTENTION=1

# Option 2: Efficient Attention (requires torch >= 2.1.1)
export ORTMODULE_USE_EFFICIENT_ATTENTION=1

# Option 3: SDPA fallback (only for attention WITHOUT masking)
export ORTMODULE_ATEN_SDPA_FALLBACK=1
```
Prerequisites:
| Option | Requirements |
|---|---|
| Flash Attention | ORTMODULE_USE_TRITON=1, CUDA capability >= 8.0 (Ampere+) |
| Efficient Attention | PyTorch >= 2.1.1 |
| SDPA Fallback | attn_mask=None (no masking) |
Verification:
Use GPU profiling (e.g., NVIDIA Nsight Systems) to confirm which attention variant is actually used at runtime. If none of the built-in attention fusion patterns match your model's attention implementation, you can manually register a custom pattern in your user script.
Reasoning
Standard attention computation has O(n^2) memory complexity in sequence length because it materializes the full attention score matrix. Flash Attention reduces this to O(n) by computing attention in tiles without materializing the full matrix, which both reduces memory usage and improves throughput through better GPU memory hierarchy utilization.

The Triton-based implementation generates hardware-optimized kernels at runtime, taking advantage of GPU-specific features like shared memory and warp-level primitives. The efficient attention fallback through PyTorch's ATen kernel provides similar benefits on older hardware or when Triton is not available. The SDPA fallback path is the most limited because it works through pre-export substitution (replacing the attention call before ONNX export), which restricts it to the specific case of unmasked attention.

The reason multiple paths exist is to cover the wide range of hardware and software environments in which ORT training is deployed: from cutting-edge data center GPUs with Triton support to older or consumer GPUs where only PyTorch's built-in kernels are available.
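The tiling argument above can be made concrete with a small NumPy sketch (function names and the tile size are illustrative): an online-softmax accumulator processes K/V in tiles, keeping only a running max, normalizer, and output per query row, and never materializes the full n x n score matrix, yet matches naive attention exactly:

```python
import numpy as np


def full_attention(q, k, v):
    # Naive attention: materializes the full (n, n) score matrix.
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v


def tiled_attention(q, k, v, tile=16):
    # Flash-Attention-style streaming: per query row, keep a running
    # max (m), running softmax normalizer (l), and running output (o),
    # rescaling previous partial results whenever the max increases.
    n, d = q.shape
    m = np.full((n, 1), -np.inf)
    l = np.zeros((n, 1))
    o = np.zeros((n, d))
    for start in range(0, k.shape[0], tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        s = q @ kt.T / np.sqrt(d)  # only an (n, tile) slice of scores
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        scale = np.exp(m - m_new)  # rescale old accumulators to new max
        l = l * scale + p.sum(axis=-1, keepdims=True)
        o = o * scale + p @ vt
        m = m_new
    return o / l


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
assert np.allclose(full_attention(q, k, v), tiled_attention(q, k, v))
```

The peak intermediate here is an (n, tile) slice of scores instead of the (n, n) matrix, which is the source of the O(n) memory claim; the real kernels additionally fuse these steps to keep tiles in on-chip shared memory.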