Heuristic:NVIDIA TransformerEngine Attention Backend Selection
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs, Attention |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
Attention backend selection heuristics controlling when FlashAttention-2, FlashAttention-3, cuDNN FusedAttention, or unfused PyTorch attention is used based on GPU, sequence length, head dimensions, and dtype.
Description
TransformerEngine supports multiple attention backends, each with different hardware requirements, feature sets, and performance characteristics. The backend selection is controlled by environment variables (`NVTE_FLASH_ATTN`, `NVTE_FUSED_ATTN`, `NVTE_UNFUSED_ATTN`) and automatic compatibility checks. Understanding the selection logic helps users maximize performance by ensuring the fastest backend is active for their configuration.
Usage
Use this heuristic when debugging attention performance or when encountering unexpected backend fallbacks. The most common issue is FlashAttention being silently disabled due to version mismatches, unsupported head dimensions, or causal mask incompatibilities.
The Insight (Rule of Thumb)
- Action: Check backend selection with `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2` to see which backend is chosen.
- Value: Backend priority and constraints:
  - FlashAttention-3 (SM 9.0 only): Fastest for Hopper, supports FP8 attention. Requires separate installation.
  - FlashAttention-2 (SM 8.0+): Best general-purpose backend. Version 2.1.1+ required, 2.7.3+ for Blackwell.
  - cuDNN FusedAttention: Two sub-backends:
    - `F16_max512_seqlen`: For seq_len <= 512 only. Limited feature set.
    - `F16_arbitrary_seqlen`: For any sequence length. Training on SM < 9.0 requires cuDNN >= 8.9.5.
  - Unfused PyTorch: Fallback when no fused backend is available.
- Trade-off: FlashAttention is generally fastest but has strict head-dimension and dtype constraints (FlashAttention-2 accepts only FP16 and BF16). cuDNN FusedAttention also supports FP8 but may be slower for some configurations.
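The priority order above can be modeled as a simple first-match rule. The sketch below is a simplified model under stated assumptions, not TransformerEngine's selection code: the function names and boolean inputs are hypothetical stand-ins for the results of TE's much longer compatibility checks. The cuDNN helper only encodes the sequence-length eligibility stated above; it does not model which sub-backend TE prefers when both are eligible.

```python
def pick_backend(fa3_ok: bool, fa2_ok: bool, fused_ok: bool) -> str:
    """Hypothetical first-match model of the backend priority listed above."""
    if fa3_ok:  # SM 9.0 only, separately installed
        return "FlashAttention-3"
    if fa2_ok:  # SM 8.0+, flash-attn version and shape checks passed
        return "FlashAttention-2"
    if fused_ok:  # some cuDNN FusedAttention sub-backend passed its checks
        return "FusedAttention"
    return "UnfusedDotProductAttention"  # always-available fallback


def fused_sub_backends(max_seqlen_q: int, max_seqlen_kv: int) -> set:
    """Hypothetical helper: cuDNN sub-backends eligible by sequence length alone."""
    eligible = {"F16_arbitrary_seqlen"}  # any sequence length
    if max_seqlen_q <= 512 and max_seqlen_kv <= 512:
        eligible.add("F16_max512_seqlen")  # only when both lengths fit in 512
    return eligible
```

For example, `fused_sub_backends(1024, 1024)` leaves only `F16_arbitrary_seqlen` eligible, which matches the constraint that `F16_max512_seqlen` cannot handle longer sequences.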
Key Constraints
- Head dimension (FlashAttention-2): Must be <= 256 and divisible by 8. Head dimensions > 192 are only supported on SM 8.0, 9.0, 10.0, and 12.0.
- Dtype (FlashAttention-2): Only `torch.bfloat16` and `torch.float16`.
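- Causal masking: With a top-left-aligned `causal` mask and `max_seqlen_q != max_seqlen_kv`, FlashAttention is disabled, because flash-attn >= 2.1 aligns its causal mask to the bottom-right diagonal.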
- SM 8.9 (Ada): FusedAttention is disabled for KV caching because of a cuDNN bug.
- Determinism: FP8 FusedAttention cannot be deterministic. FlashAttention-2 determinism requires version >= 2.4.1.
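The FlashAttention-2 constraints above can be collected into a single checker that reports why the backend would be skipped. This is a minimal sketch, assuming dtypes are passed as plain strings (to avoid a `torch` dependency), the SM version as a `(major, minor)` tuple, and the flash-attn version as an integer tuple. The function name is hypothetical, and treating only the top-left `"causal"` and `"padding_causal"` mask types as affected is an assumption based on TE's `attn_mask_type` names.

```python
SUPPORTED_DTYPES = {"bfloat16", "float16"}
# SM versions that accept head dimensions in (192, 256]
LARGE_HEAD_DIM_SMS = {(8, 0), (9, 0), (10, 0), (12, 0)}
# Assumed: top-left-aligned causal mask types affected by unequal seqlens
TOP_LEFT_CAUSAL = {"causal", "padding_causal"}


def flash_attn_2_blockers(*, head_dim, dtype, sm, attn_mask_type,
                          max_seqlen_q, max_seqlen_kv, deterministic,
                          fa_version):
    """Hypothetical checker: return the reasons FlashAttention-2 is ruled out."""
    reasons = []
    if head_dim > 256 or head_dim % 8 != 0:
        reasons.append("head_dim must be <= 256 and divisible by 8")
    elif head_dim > 192 and sm not in LARGE_HEAD_DIM_SMS:
        reasons.append("head_dim > 192 needs SM 8.0, 9.0, 10.0 or 12.0")
    if dtype not in SUPPORTED_DTYPES:
        reasons.append(f"dtype {dtype} is not bfloat16/float16")
    if attn_mask_type in TOP_LEFT_CAUSAL and max_seqlen_q != max_seqlen_kv:
        reasons.append("causal mask with max_seqlen_q != max_seqlen_kv")
    if deterministic and fa_version < (2, 4, 1):
        reasons.append("deterministic execution needs flash-attn >= 2.4.1")
    return reasons
```

An empty list means none of the constraints listed on this page rule FlashAttention-2 out; TE still applies further checks not modeled here.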
Reasoning
Each backend is optimized for different scenarios. FlashAttention uses IO-aware tiling to minimize HBM reads and writes, making it fastest for most configurations. cuDNN FusedAttention adds FP8 support on top of FP16/BF16 and can leverage cuDNN graph optimizations. The unfused backend serves as a correctness reference and fallback, and is the path for dtypes such as FP32 that the fused backends do not accept. The selection logic prioritizes the fastest available backend while respecting hardware and feature constraints.
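The tiling idea rests on the online (streaming) softmax: keys and values are consumed block by block while a running maximum and running sum are kept, so the full row of attention scores never has to be materialized in slow memory. The sketch below shows that technique in pure Python for a single query vector; it ignores the GPU memory hierarchy entirely and is an illustration of the math, not FlashAttention's kernel. `block_size` is a hypothetical tile size.

```python
import math
import random


def naive_attention(q, keys, values):
    """Reference: softmax(q.K^T / sqrt(d)) @ V for one query, all scores at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim_v)]


def tiled_attention(q, keys, values, block_size):
    """Same result, visiting K/V in tiles with a running max and running sum."""
    scale = 1.0 / math.sqrt(len(q))
    dim_v = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim_v  # unnormalized weighted sum of values seen so far
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim_v):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    random.seed(0)
    q = [random.uniform(-1, 1) for _ in range(8)]
    keys = [[random.uniform(-1, 1) for _ in range(8)] for _ in range(37)]
    values = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(37)]
    print(naive_attention(q, keys, values))
    print(tiled_attention(q, keys, values, block_size=5))
```

Both functions agree up to floating-point rounding for any tile size, which is what lets a tiled kernel keep only one tile of scores in fast on-chip memory at a time.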
Code Evidence
Environment variable controls from `transformer_engine/pytorch/attention/dot_product_attention/utils.py:57-64`:
```python
import os

# Excerpt: module-level flags, evaluated once when the module is imported
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1"))
```
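Because these flags are read with `os.getenv` at module level, they only take effect if they are set before `transformer_engine` is imported, or exported in the shell (for example `NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python train.py`, where `train.py` is a hypothetical script). The sketch below is a hypothetical helper, `force_attention_backend`, that pins one backend family for an A/B comparison while turning on debug logging. It assumes `NVTE_FUSED_ATTN` and `NVTE_UNFUSED_ATTN` are honored the same way as `NVTE_FLASH_ATTN`, which the excerpt above does not show.

```python
import os

# Hypothetical mapping from a short name to the NVTE_* switch for that backend
_BACKEND_FLAGS = {
    "flash": "NVTE_FLASH_ATTN",
    "fused": "NVTE_FUSED_ATTN",
    "unfused": "NVTE_UNFUSED_ATTN",
}


def force_attention_backend(backend: str, debug_level: int = 2) -> dict:
    """Enable only `backend`, disable the others, and turn on debug logging.

    Must run before `import transformer_engine`, since the flags are read
    at module import time.
    """
    if backend not in _BACKEND_FLAGS:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {sorted(_BACKEND_FLAGS)}"
        )
    env = {var: ("1" if name == backend else "0")
           for name, var in _BACKEND_FLAGS.items()}
    env["NVTE_DEBUG"] = "1"
    env["NVTE_DEBUG_LEVEL"] = str(debug_level)
    os.environ.update(env)
    return env
```

Running the same workload once with `"flash"` and once with `"fused"` and comparing step times is a direct way to check whether the automatically chosen backend is actually the fastest for a given configuration.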
Flash Attention version requirements from `transformer_engine/pytorch/attention/dot_product_attention/utils.py:114-118`:
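```python
from packaging.version import Version as PkgVersion


class FlashAttentionUtils:
    # Excerpt: other class attributes omitted
    version_required = PkgVersion("2.1.1")  # Pre-Blackwell
    version_required_blackwell = PkgVersion("2.7.3")  # Blackwell (SM 10.0+)
    max_version = PkgVersion("2.8.3")  # Maximum supported
```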
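These three attributes imply a range check that depends on the GPU generation. The sketch below reproduces that range check with the standard library only. `flash_attn_version_ok` and `release_tuple` are hypothetical names, and the parser keeps only the numeric release segment, so suffixes such as `.post1` are ignored, unlike `packaging.version.Version`. Real code should compare `PkgVersion` objects as in the excerpt.

```python
import re

VERSION_REQUIRED = (2, 1, 1)            # pre-Blackwell minimum
VERSION_REQUIRED_BLACKWELL = (2, 7, 3)  # SM 10.0+ minimum
MAX_VERSION = (2, 8, 3)                 # maximum supported


def release_tuple(version: str) -> tuple:
    """Parse the leading numeric release segment, padded to three parts."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (3 - len(parts))  # "2.8" -> (2, 8, 0)
    return tuple(parts)


def flash_attn_version_ok(version: str, sm: tuple) -> bool:
    """Hypothetical check: is this flash-attn version in range for the GPU?"""
    minimum = VERSION_REQUIRED_BLACKWELL if sm >= (10, 0) else VERSION_REQUIRED
    return minimum <= release_tuple(version) <= MAX_VERSION
```

For instance, flash-attn 2.6.3 passes on an SM 9.0 GPU but fails on SM 10.0, which is the kind of silent FlashAttention-2 fallback the Usage section warns about.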