Principle:NVIDIA TransformerEngine Fused Attention
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Knowledge Sources | Paper (Flash Attention), Paper (Attention Is All You Need), Repo (TransformerEngine) |
| Domains | Deep_Learning, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Computing scaled dot-product attention using fused GPU kernels that combine the QK^T multiplication, scaling, masking, softmax, dropout, and value projection into a single kernel launch, avoiding materialization of the full attention matrix for improved performance and memory efficiency.
Description
Fused attention combines the entire scaled dot-product attention computation -- from query-key dot products through to the final value-weighted sum -- into a single fused GPU kernel call. This eliminates the need to write the full [batch, heads, seq_len, seq_len] attention score matrix to global memory, which is the primary memory bottleneck for long-sequence Transformer models.
Standard (Unfused) Attention
A naive implementation of scaled dot-product attention proceeds in discrete steps, each requiring a separate kernel launch and global memory round-trip:
- QK^T computation: Matrix multiply query and key tensors to produce attention scores of shape [B, H, S_q, S_kv].
- Scaling: Divide by sqrt(d_k).
- Masking: Apply causal mask, padding mask, or arbitrary attention mask.
- Softmax: Compute softmax along the key dimension.
- Dropout: Apply attention dropout.
- Value projection: Matrix multiply the attention weights with value tensors.
Each intermediate tensor (the score matrix S and the post-softmax probability matrix P) has O(n^2) elements in the sequence length, leading to significant memory consumption and bandwidth waste.
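The unfused pipeline above can be sketched in NumPy (a minimal single-precision illustration; in practice each step runs as a separate GPU kernel with the intermediates written to global memory):

```python
import numpy as np

def naive_attention(q, k, v, causal=True):
    """Unfused scaled dot-product attention: every intermediate
    (scores S, probabilities P) is materialized at full
    [B, H, S_q, S_kv] size, mirroring one kernel per step."""
    B, H, S_q, D = q.shape
    S_kv = k.shape[2]
    # Step 1: QK^T -> attention scores of shape [B, H, S_q, S_kv]
    s = np.einsum("bhqd,bhkd->bhqk", q, k)
    # Step 2: scale by 1/sqrt(d_k)
    s = s / np.sqrt(D)
    # Step 3: causal mask (each query attends only to keys at or before it)
    if causal:
        mask = np.triu(np.ones((S_q, S_kv), dtype=bool), k=1)
        s = np.where(mask, -np.inf, s)
    # Step 4: softmax along the key dimension
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    # (Step 5, attention dropout, omitted here for determinism)
    # Step 6: value projection -> [B, H, S_q, D_v]
    return np.einsum("bhqk,bhkd->bhqd", p, v)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 2, 8, 4)) for _ in range(3))
out = naive_attention(q, k, v)
print(out.shape)  # (1, 2, 8, 4)
```

Note that `s` and `p` each occupy B * H * S_q * S_kv elements; these are exactly the tensors that fused kernels never write out.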
Fused Attention Strategy
Fused attention implementations avoid materializing the full attention matrix by processing attention in tiles or blocks:
- Flash Attention: Processes Q, K, V in blocks, computing partial softmax results in SRAM (shared memory / registers) and accumulating the output incrementally. Only the final output tensor and a small amount of bookkeeping data (logsumexp per row) are written to global memory.
- cuDNN fused attention: NVIDIA's cuDNN library provides a fused multi-head attention kernel optimized for specific GPU architectures, handling the full attention pipeline including FP8 quantization of intermediate tensors.
Supported Features
TransformerEngine's fused attention supports:
- Multiple backends: cuDNN fused attention and Flash Attention, selected automatically based on hardware capabilities, sequence length, and configuration.
- Attention mask types: "no_mask", "padding", "causal", "padding_causal", "causal_bottom_right", "padding_causal_bottom_right", and "arbitrary".
- Grouped Query Attention (GQA): Different numbers of query heads and key/value heads, where multiple query heads share the same KV head.
- QKV formats: "sbhd" (sequence, batch, head, dim), "bshd" (batch, sequence, head, dim), and "thd" (token, head, dim) for variable-length sequences.
- Context parallelism: Distributing the sequence dimension across multiple GPUs with P2P, all-gather, or all-to-all communication patterns.
- FP8 attention: Quantizing the QKV inputs, attention scores, and output to FP8 for additional throughput on Hopper GPUs.
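The GQA head-sharing mentioned above can be illustrated with a small NumPy sketch (illustration only; fused kernels broadcast KV heads implicitly rather than copying them):

```python
import numpy as np

def gqa_expand_kv(k, v, n_q_heads):
    """Grouped Query Attention: H_q query heads share H_kv KV heads.
    Conceptually, each KV head serves a group of
    n_q_heads // n_kv_heads query heads."""
    n_kv_heads = k.shape[1]
    assert n_q_heads % n_kv_heads == 0, "H_q must be a multiple of H_kv"
    group = n_q_heads // n_kv_heads
    # [B, H_kv, S, D] -> [B, H_q, S, D] by repeating each KV head per group
    return np.repeat(k, group, axis=1), np.repeat(v, group, axis=1)

k = np.zeros((1, 2, 16, 64))  # 2 KV heads
v = np.zeros((1, 2, 16, 64))
k8, v8 = gqa_expand_kv(k, v, n_q_heads=8)  # 8 query heads in groups of 4
print(k8.shape)  # (1, 8, 16, 64)
```

A fused kernel reads each KV head once from global memory and reuses it for the whole group, which is what makes GQA cheaper than the explicit copy shown here.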
Theoretical Basis
Scaled Dot-Product Attention
The fundamental attention operation from the "Attention Is All You Need" paper:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
where:
- Q is the query tensor of shape [B, H_q, S_q, D].
- K is the key tensor of shape [B, H_kv, S_kv, D].
- V is the value tensor of shape [B, H_kv, S_kv, D_v].
- d_k is the key head dimension.
- The output has shape [B, H_q, S_q, D_v].
Flash Attention Algorithm
Flash Attention (Dao et al., 2022) avoids materializing the [S_q, S_kv] attention matrix by:
- Tiling: Dividing Q into blocks of rows and K, V into blocks of columns.
- Online softmax: Computing softmax incrementally using the "online softmax trick" -- maintaining a running maximum and denominator that are updated as each K block is processed.
- Recomputation in backward pass: Instead of storing the attention matrix for the backward pass, recomputing it from Q, K, V using the stored logsumexp values.
This reduces memory complexity from O(S^2) to O(S) while maintaining exact numerical equivalence (no approximation).
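The tiling and online-softmax steps above can be sketched for a single head in NumPy (a reference illustration of the accumulation scheme, not the SRAM-tiled kernel itself; block size and shapes are arbitrary):

```python
import numpy as np

def flash_attention_reference(q, k, v, block=4):
    """Flash-style attention for one head: process K/V in column blocks,
    keeping a running row maximum (m), softmax denominator (l), and an
    unnormalized output accumulator. The full [S_q, S_kv] score matrix
    is never materialized; only one [S_q, block] tile exists at a time."""
    S_q, D = q.shape
    scale = 1.0 / np.sqrt(D)
    m = np.full(S_q, -np.inf)           # running row maximum
    l = np.zeros(S_q)                   # running softmax denominator
    acc = np.zeros((S_q, v.shape[1]))   # unnormalized output accumulator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale          # scores for this K block only
        m_new = np.maximum(m, s.max(axis=1))
        correction = np.exp(m - m_new)  # rescale previously accumulated state
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ vb
        m = m_new
    # final normalization; m + log(l) is the logsumexp saved for backward
    return acc / l[:, None]

rng = np.random.default_rng(1)
q, k, v = rng.standard_normal((3, 16, 8))
out = flash_attention_reference(q, k, v)
```

Because the rescaling by `correction` is exact, the result matches the full-matrix softmax bit-for-bit up to floating-point rounding, consistent with the "no approximation" property noted above.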
Memory and Compute Trade-offs
| Approach | Memory (Attention Matrix) | Kernel Launches | HBM Reads/Writes |
|---|---|---|---|
| Unfused (standard) | O(B * H * S_q * S_kv) | 5-6 per attention layer | Each intermediate written to HBM |
| Flash Attention | O(B * H * S_q) (logsumexp only) | 1 (fused) | Only Q, K, V read and output written |
| cuDNN Fused | O(B * H * S_q) | 1 (fused) | Only Q, K, V read and output written |
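The table's asymptotic difference becomes concrete with a quick back-of-the-envelope calculation (example sizes chosen for illustration: batch 1, 16 heads, 64K sequence, FP16 scores, FP32 logsumexp):

```python
# Memory to materialize the unfused attention score matrix versus
# flash attention's per-row bookkeeping, for an example configuration.
B, H, S = 1, 16, 65536          # batch, heads, sequence length
score_bytes, lse_bytes = 2, 4   # FP16 scores, FP32 logsumexp

unfused = B * H * S * S * score_bytes   # full [B, H, S, S] score matrix
flash = B * H * S * lse_bytes           # one logsumexp value per row

print(f"unfused scores:     {unfused / 2**30:.0f} GiB")  # 128 GiB
print(f"flash bookkeeping:  {flash / 2**20:.0f} MiB")    # 4 MiB
```

At this sequence length the unfused score matrix alone exceeds any single GPU's memory, which is why the O(S^2) term dominates the decision to use fused kernels.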
Usage
Use fused attention when:
- Training or running inference on Transformer models where attention is a memory or compute bottleneck, especially with long sequences.
- Working with long sequences: The O(S^2) memory of unfused attention becomes prohibitive for sequences longer than a few thousand tokens; fused attention scales to sequences of 16K, 64K, or longer.
- Maximizing GPU utilization: Fused kernels reduce the number of global memory round-trips and kernel launches, improving both latency and throughput.
- Using FP8 training: cuDNN fused attention supports FP8 quantization of the attention pipeline, providing additional throughput gains on Hopper GPUs.
- Distributing long sequences across GPUs: Context parallelism (CP) support enables splitting the sequence dimension across multiple GPUs with efficient communication patterns.
Fused attention is typically used as part of a te.TransformerLayer or te.MultiheadAttention module, but can also be used standalone via te.DotProductAttention for custom attention implementations.