Principle:NVIDIA TransformerEngine Fused Attention

From Leeroopedia


Metadata

Page Type: Principle
Knowledge Sources: Paper (Flash Attention), Paper (Attention Is All You Need), Repo (TransformerEngine)
Domains: Deep_Learning, Attention
Last Updated: 2026-02-07 14:00 GMT

Overview

Computing scaled dot-product attention with fused GPU kernels that combine the QK^T multiplication, scaling, masking, softmax, dropout, and multiplication by V into a single kernel launch. This avoids materializing the full attention matrix and improves both performance and memory efficiency.

Description

Fused attention combines the entire scaled dot-product attention computation -- from query-key dot products through to the final value-weighted sum -- into a single fused GPU kernel call. This eliminates the need to write the full [batch, heads, seq_len, seq_len] attention score matrix to global memory, which is the primary memory bottleneck for long-sequence Transformer models.

Standard (Unfused) Attention

A naive implementation of scaled dot-product attention proceeds in discrete steps, each requiring a separate kernel launch and global memory round-trip:

  1. QK^T computation: Matrix multiply query and key tensors to produce attention scores of shape [B, H, S_q, S_kv].
  2. Scaling: Divide by sqrt(d_k).
  3. Masking: Apply causal mask, padding mask, or arbitrary attention mask.
  4. Softmax: Compute softmax along the key dimension.
  5. Dropout: Apply attention dropout.
  6. Weighted sum of values: Matrix multiply the attention probabilities with the value tensor to produce the output of shape [B, H, S_q, D_v].

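The intermediate score tensor S (steps 1-3) and probability tensor P (steps 4-5) each have B * H * S_q * S_kv elements, which is quadratic in sequence length. Writing them to global memory and reading them back dominates memory consumption and bandwidth for long sequences.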
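The six unfused steps can be sketched as a single-head reference in plain Python. This is an illustrative sketch, not TransformerEngine code. The helper names (`matmul`, `transpose`, `unfused_attention`) and the boolean mask convention (True means "may attend") are assumptions. Each step produces a full [S_q, S_kv] list of lists, which mirrors the intermediate tensors a naive GPU implementation writes to HBM.

```python
import math
import random
from typing import List, Optional

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix multiply: [m, k] x [k, n] -> [m, n]."""
    n_cols = len(b[0])
    return [[sum(a_row[t] * b[t][j] for t in range(len(b))) for j in range(n_cols)]
            for a_row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def unfused_attention(q: Matrix, k: Matrix, v: Matrix,
                      mask: Optional[List[List[bool]]] = None,
                      dropout_p: float = 0.0, seed: int = 0) -> Matrix:
    """Single-head attention with one materialized intermediate per step.

    q: [S_q, D], k: [S_kv, D], v: [S_kv, D_v].
    mask[i][j] is True where query i may attend to key j (None = no mask).
    Assumes every query row keeps at least one unmasked key.
    """
    d_k = len(q[0])
    # 1. QK^T: the full [S_q, S_kv] score matrix
    scores = matmul(q, transpose(k))
    # 2. Scaling by 1/sqrt(d_k)
    scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    # 3. Masking: disallowed positions get -inf, so softmax assigns them 0
    if mask is not None:
        scores = [[s if m else -math.inf for s, m in zip(row, mrow)]
                  for row, mrow in zip(scores, mask)]
    # 4. Softmax along the key dimension (max-subtracted for stability)
    probs = []
    for row in scores:
        row_max = max(row)
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    # 5. Dropout (inverted dropout: survivors are scaled by 1/(1-p))
    if dropout_p > 0.0:
        rng = random.Random(seed)
        probs = [[p / (1.0 - dropout_p) if rng.random() >= dropout_p else 0.0
                  for p in row] for row in probs]
    # 6. Weighted sum of values: [S_q, S_kv] x [S_kv, D_v]
    return matmul(probs, v)
```

For example, `unfused_attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]], mask=[[True, False]])` masks out the second key, so the output is exactly the first value row.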

Fused Attention Strategy

Fused attention implementations avoid materializing the full attention matrix by processing attention in tiles or blocks:

  • Flash Attention: Processes Q, K, V in blocks, computing partial softmax results in SRAM (shared memory / registers) and accumulating the output incrementally. Only the final output tensor and a small amount of bookkeeping data (logsumexp per row) are written to global memory.
  • cuDNN fused attention: NVIDIA's cuDNN library provides a fused multi-head attention kernel optimized for specific GPU architectures, handling the full attention pipeline including FP8 quantization of intermediate tensors.

Supported Features

TransformerEngine's fused attention supports:

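  • Multiple backends: cuDNN fused attention and Flash Attention, selected automatically based on hardware capabilities, sequence length, and configuration. An unfused PyTorch implementation serves as the fallback when neither fused backend supports a configuration.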
  • Attention mask types: "no_mask", "padding", "causal", "padding_causal", "causal_bottom_right", "padding_causal_bottom_right", and "arbitrary".
  • Grouped Query Attention (GQA): Different numbers of query heads and key/value heads, where multiple query heads share the same KV head.
  • QKV formats: "sbhd" (sequence, batch, head, dim), "bshd" (batch, sequence, head, dim), and "thd" (token, head, dim) for variable-length sequences.
  • Context parallelism: Distributing the sequence dimension across multiple GPUs with P2P, all-gather, or all-to-all communication patterns.
  • FP8 attention: Quantizing the QKV inputs, attention scores, and output to FP8 for additional throughput on Hopper GPUs.
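Two of the features above are easiest to see as small index computations: the difference between the "causal" and "causal_bottom_right" mask types, and the cumulative sequence offsets that describe a packed "thd" tensor. The following is a minimal sketch with hypothetical helper names. It assumes the convention the mask names suggest, where "causal" aligns the diagonal to the top-left corner and "causal_bottom_right" aligns it to the bottom-right corner, so the two differ only when S_q != S_kv. It also assumes the [batch_size + 1] offset layout used by TransformerEngine's `cu_seqlens_q` / `cu_seqlens_kv` arguments. Verify both conventions against the TE documentation before relying on them.

```python
from itertools import accumulate
from typing import List


def causal_mask(s_q: int, s_kv: int, bottom_right: bool = False) -> List[List[bool]]:
    """allowed[i][j] is True when query i may attend to key j.

    Top-left aligned ("causal"): j <= i.
    Bottom-right aligned ("causal_bottom_right"): the diagonal is shifted
    so the last query lines up with the last key: j <= i + (s_kv - s_q).
    """
    offset = s_kv - s_q if bottom_right else 0
    return [[j <= i + offset for j in range(s_kv)] for i in range(s_q)]


def cu_seqlens(seq_lens: List[int]) -> List[int]:
    """Cumulative start offsets of each sequence in the packed token dim ("thd")."""
    return [0] + list(accumulate(seq_lens))


def sequence_slice(cu: List[int], i: int) -> slice:
    """Token range of sequence i inside the packed tensor (hypothetical helper)."""
    return slice(cu[i], cu[i + 1])
```

With two queries and four keys (for example, decoding with a KV cache), `causal_mask(2, 4)` lets the first query see only key 0. `causal_mask(2, 4, bottom_right=True)` lets it see keys 0 to 2. For sequences of lengths 3, 5, and 2 packed into 10 tokens, `cu_seqlens([3, 5, 2])` gives `[0, 3, 8, 10]`.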

Theoretical Basis

Scaled Dot-Product Attention

The fundamental attention operation from the "Attention Is All You Need" paper:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

where:

  • Q is the query tensor of shape [B, H_q, S_q, D].
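  • K is the key tensor of shape [B, H_kv, S_kv, D], where H_q must be a multiple of H_kv (H_kv = H_q for standard multi-head attention, H_kv < H_q for GQA).
  • V is the value tensor of shape [B, H_kv, S_kv, D_v].
  • d_k = D is the head dimension shared by queries and keys.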
  • The output has shape [B, H_q, S_q, D_v].
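These shape rules can be written as a small validation function. It also includes the query-head to KV-head mapping that GQA implies. This is a sketch. It assumes the common contiguous grouping, in which KV head j serves query heads j*g through (j+1)*g - 1 with g = H_q // H_kv. The function names are hypothetical and not part of TransformerEngine.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]  # [B, H, S, D]


def attention_output_shape(q: Shape, k: Shape, v: Shape) -> Shape:
    """Check [B, H, S, D] shapes for SDPA (with GQA) and return [B, H_q, S_q, D_v]."""
    b, h_q, s_q, d = q
    b_k, h_kv, s_kv, d_k = k
    b_v, h_v, s_v, d_v = v
    if not (b == b_k == b_v):
        raise ValueError("batch sizes differ")
    if d != d_k:
        raise ValueError("Q and K must share the head dimension D (= d_k)")
    if (h_kv, s_kv) != (h_v, s_v):
        raise ValueError("K and V must have the same head count and sequence length")
    if h_q % h_kv != 0:
        raise ValueError("H_q must be a multiple of H_kv")
    return (b, h_q, s_q, d_v)


def kv_head_for(query_head: int, h_q: int, h_kv: int) -> int:
    """KV head shared by a query head, assuming contiguous groups of H_q // H_kv."""
    return query_head // (h_q // h_kv)
```

For example, 32 query heads over 8 KV heads means query heads 0 to 3 share KV head 0 and query heads 28 to 31 share KV head 7. `attention_output_shape((2, 32, 128, 64), (2, 8, 256, 64), (2, 8, 256, 128))` returns `(2, 32, 128, 128)`.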

Flash Attention Algorithm

Flash Attention (Dao et al., 2022) avoids materializing the [S_q, S_kv] attention matrix by:

  1. Tiling: Dividing Q into blocks of query rows and K, V into blocks along the key sequence, so that each tile of the score matrix fits in on-chip SRAM.
  2. Online softmax: Computing softmax incrementally using the "online softmax trick" -- maintaining a per-row running maximum and denominator, and rescaling the partial output, as each K/V block is processed.
  3. Recomputation in the backward pass: Instead of storing the attention matrix, recomputing it block by block from Q, K, V and the stored logsumexp values.
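The tiling and online-softmax steps can be illustrated with a pure-Python forward pass for a single head. This is a teaching sketch of the technique, not the FlashAttention or cuDNN kernel. It omits masking, dropout, and every GPU-specific detail, and the block sizes are arbitrary assumptions. The key point is that only one [block_q, block_kv] tile of scores exists at a time. When a new block raises the running maximum, the previously accumulated denominator and output are rescaled by exp(old_max - new_max).

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def flash_attention_forward(q: Matrix, k: Matrix, v: Matrix,
                            block_q: int = 2, block_kv: int = 2
                            ) -> Tuple[Matrix, List[float]]:
    """Tiled single-head attention with online softmax (no mask, no dropout).

    Returns the output [S_q, D_v] and the per-row logsumexp. The full
    [S_q, S_kv] score matrix is never built.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out: Matrix = []
    lse: List[float] = []
    for q0 in range(0, len(q), block_q):                # outer loop: Q row blocks
        q_blk = q[q0:q0 + block_q]
        row_max = [-math.inf] * len(q_blk)              # running max per row
        denom = [0.0] * len(q_blk)                      # running softmax denominator
        acc = [[0.0] * d_v for _ in q_blk]              # unnormalized output rows
        for k0 in range(0, len(k), block_kv):           # inner loop: K/V blocks
            k_blk, v_blk = k[k0:k0 + block_kv], v[k0:k0 + block_kv]
            for r, q_row in enumerate(q_blk):
                s = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_blk]
                new_max = max(row_max[r], max(s))
                corr = math.exp(row_max[r] - new_max)   # rescales state built with the old max
                p = [math.exp(x - new_max) for x in s]
                denom[r] = denom[r] * corr + sum(p)
                acc[r] = [a * corr + sum(pj * v_row[c] for pj, v_row in zip(p, v_blk))
                          for c, a in enumerate(acc[r])]
                row_max[r] = new_max
        for r in range(len(q_blk)):
            out.append([a / denom[r] for a in acc[r]])  # normalize once at the end
            lse.append(row_max[r] + math.log(denom[r]))
    return out, lse
```

Block sizes need not divide the sequence lengths. The last Q or K/V block is simply shorter. The result matches a direct softmax(QK^T / sqrt(d_k)) V up to floating-point rounding.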

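This reduces the extra memory for attention from O(S^2) to O(S) per head (beyond Q, K, V and the output themselves). It still computes exact attention rather than an approximation; results differ from the unfused computation only by floating-point rounding.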
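The recomputation step relies on one identity. If lse = log(sum_j exp(s_j)) is stored per query row, then each probability is simply exp(s_j - lse), so the backward pass can rebuild any tile of P from Q, K and that single number per row. A minimal single-row sketch (the function name is hypothetical):

```python
import math
from typing import List


def recompute_probs(q_row: List[float], k: List[List[float]], lse: float) -> List[float]:
    """Rebuild one row of P = softmax(q k^T / sqrt(d_k)) from the stored logsumexp.

    Because lse = log(sum_j exp(s_j)), each probability is exp(s_j - lse),
    so no running max or second reduction over the row is needed.
    """
    scale = 1.0 / math.sqrt(len(q_row))
    scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
    return [math.exp(s - lse) for s in scores]
```

In a real kernel the same identity is applied tile by tile, so the recomputed probabilities never need to exist as a full [S_q, S_kv] matrix either.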

Memory and Compute Trade-offs

Approach | Extra memory for attention | Kernel launches (forward) | HBM reads/writes
Unfused (standard) | O(B * H * S_q * S_kv) | 5-6 per attention layer | Each intermediate (S, P) written to and read back from HBM
Flash Attention | O(B * H * S_q) (logsumexp only) | 1 (fused) | Q, K, V read; only the output and logsumexp written
cuDNN Fused | O(B * H * S_q) (softmax statistics) | 1 (fused) | Q, K, V read; only the output and softmax statistics written
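The memory column can be made concrete with a back-of-the-envelope calculation. The following sketch assumes FP16 storage (2 bytes) for a materialized score or probability tensor and FP32 (4 bytes) for the per-row logsumexp, which is a typical but not guaranteed choice. It counts a single tensor only, so an unfused implementation that keeps both S and P (and a dropout mask) would need a multiple of this.

```python
def attention_matrix_bytes(b: int, h: int, s_q: int, s_kv: int, bytes_per_elem: int = 2) -> int:
    """Size of one materialized [B, H, S_q, S_kv] tensor such as S or P."""
    return b * h * s_q * s_kv * bytes_per_elem


def logsumexp_bytes(b: int, h: int, s_q: int, bytes_per_elem: int = 4) -> int:
    """Size of the [B, H, S_q] per-row softmax statistics kept by fused kernels."""
    return b * h * s_q * bytes_per_elem
```

For one sequence of 16,384 tokens with 32 heads, `attention_matrix_bytes(1, 32, 16384, 16384)` is 17,179,869,184 bytes (16 GiB) per attention layer. By contrast, `logsumexp_bytes(1, 32, 16384)` is 2,097,152 bytes (2 MiB).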

Usage

Use fused attention when:

  • Training or running inference on Transformer models where attention is a memory or compute bottleneck.
  • Working with long sequences: The O(S^2) memory of unfused attention becomes prohibitive for sequences longer than a few thousand tokens; fused attention scales to sequences of 16K, 64K, or longer.
  • Maximizing GPU utilization: Fused kernels reduce the number of global memory round-trips and kernel launches, improving both latency and throughput.
  • Using FP8 training: cuDNN fused attention supports FP8 quantization of the attention pipeline, providing additional throughput gains on Hopper GPUs.
  • Distributing long sequences across GPUs: Context parallelism (CP) support enables splitting the sequence dimension across multiple GPUs with efficient communication patterns.
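With causal masks, splitting the sequence into cp_size contiguous pieces would leave the last rank with far more work than the first, because a query at position t attends to t + 1 keys. A common remedy is to cut the sequence into 2 * cp_size chunks and give rank r chunks r and 2 * cp_size - 1 - r. TransformerEngine's context-parallel guidance describes a load-balancing scheme of this kind for causal attention, but treat the exact chunk assignment below as an assumption. The function is a hypothetical illustration, not a TE API.

```python
from typing import List


def cp_token_indices(seq_len: int, cp_size: int, rank: int) -> List[int]:
    """Token positions held by `rank` under a load-balanced causal split.

    The sequence is cut into 2 * cp_size equal chunks. Rank r keeps chunk r
    and chunk 2 * cp_size - 1 - r, pairing an "early" chunk (little causal
    work) with a "late" chunk (much causal work).
    """
    n_chunks = 2 * cp_size
    if seq_len % n_chunks != 0:
        raise ValueError("seq_len must be divisible by 2 * cp_size")
    chunk = seq_len // n_chunks
    first, second = rank, n_chunks - 1 - rank
    return (list(range(first * chunk, (first + 1) * chunk))
            + list(range(second * chunk, (second + 1) * chunk)))
```

With 8 tokens on 2 ranks, rank 0 holds positions [0, 1, 6, 7] and rank 1 holds [2, 3, 4, 5]. Counting t + 1 keys per query, both ranks do 18 units of causal work.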

Fused attention is typically used inside a te.TransformerLayer or te.MultiheadAttention module (where te is transformer_engine.pytorch). It can also be called directly through te.DotProductAttention when building a custom attention block.
