Implementation: NVIDIA TransformerEngine te.DotProductAttention
| Field | Value |
|---|---|
| Sources | TransformerEngine, Flash Attention, Attention Is All You Need |
| Domains | Deep_Learning, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
te.DotProductAttention is a fused scaled dot-product attention module provided by NVIDIA's TransformerEngine library. It computes multi-head attention using cuDNN and Flash Attention backends, with support for Grouped Query Attention, context parallelism, multiple QKV formats, FP8 quantization, and several attention mask types.
Description
te.DotProductAttention computes scaled dot-product attention:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
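As a concrete illustration of the formula, here is a dependency-free reference in plain Python (a pedagogical sketch, not TE's fused kernel, which operates on batched multi-head CUDA tensors):

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attention(Q, K, V):
    # Q: [S_q][d_k], K: [S_k][d_k], V: [S_k][d_v] as nested lists.
    # Computes softmax(QK^T / sqrt(d_k)) V row by row.
    d_k = len(K[0])
    out = []
    for q in Q:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out
```

Each output row is a convex combination of value rows, weighted by the scaled query-key similarities; this is exactly what the fused backends compute, just without tiling or recomputation tricks.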
The class inherits from TransformerEngineBaseModule and automatically selects between cuDNN fused attention and Flash Attention backends based on hardware capabilities, input dimensions, and configuration. It supports:
- Grouped Query Attention (GQA): the `num_gqa_groups` parameter controls how many key/value heads are used. When `num_gqa_groups < num_attention_heads`, multiple query heads share the same KV head, reducing KV-cache memory during inference and KV computation during training.
- Multiple QKV formats: `"sbhd"` (sequence, batch, head, dim), `"bshd"` (batch, sequence, head, dim), and `"thd"` (token, head, dim) for variable-length/unpadded sequences.
- Attention mask types: `"no_mask"`, `"padding"`, `"causal"`, `"padding_causal"`, `"causal_bottom_right"`, `"padding_causal_bottom_right"`, and `"arbitrary"`.
- Context parallelism (CP): distributes the sequence dimension across multiple GPUs using P2P ring communication (`"p2p"`), all-gather (`"all_gather"`), all-to-all (`"a2a"`), or hierarchical (`"a2a+p2p"`) strategies.
- FP8 attention: when used inside an FP8 autocast context with `fp8_dpa=True`, the QKV tensors, attention scores, and output are quantized to FP8 for additional throughput on Hopper GPUs. Both delayed-scaling and current-scaling FP8 recipes are supported.
- Sliding window attention: the `window_size` parameter enables local attention windows for causal mask types.
- Softmax variants: standard ("vanilla") softmax and custom softmax types via the `softmax_type` parameter.
The backend selection is automatic: cuDNN fused attention is preferred when available and supported for the given configuration, with Flash Attention as a fallback.
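The GQA sharing pattern above can be made concrete with a small index-mapping sketch. With H query heads and G KV heads, consecutive blocks of H/G query heads share one KV head (this is a hypothetical helper for illustration, not part of TE's API):

```python
def kv_head_for_query_head(q_head: int, num_attention_heads: int, num_gqa_groups: int) -> int:
    # Which KV head a given query head reads under GQA.
    # Query heads are partitioned into num_gqa_groups consecutive blocks,
    # each block sharing a single key/value head.
    assert num_attention_heads % num_gqa_groups == 0
    group_size = num_attention_heads // num_gqa_groups
    return q_head // group_size
```

For example, with `num_attention_heads=32` and `num_gqa_groups=8`, query heads 0-3 share KV head 0, heads 4-7 share KV head 1, and so on; `num_gqa_groups=1` recovers Multi-Query Attention, while `num_gqa_groups=num_attention_heads` is standard MHA.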
Usage
Import te.DotProductAttention when building custom attention modules that need fused attention kernels, or when constructing Transformer layers outside of the higher-level te.TransformerLayer API. It is most commonly used as a component inside te.MultiheadAttention.
Code Reference
Source Location
- Repository: `NVIDIA/TransformerEngine`
- File: `transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py`
- Class: `DotProductAttention`
- Lines: `__init__` at L323--346
Signature
```python
class DotProductAttention(TransformerEngineBaseModule):
    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
        softmax_type: str = "vanilla",
        return_max_logit: Optional[bool] = False,
    ) -> None:
```
Import
```python
from transformer_engine.pytorch.attention import DotProductAttention

# or equivalently:
import transformer_engine.pytorch as te
te.DotProductAttention
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `query_layer` | `torch.Tensor` | Yes | Query tensor; shape depends on `qkv_format`: `[S, B, H, D]` for `"sbhd"`, `[B, S, H, D]` for `"bshd"`, `[T, H, D]` for `"thd"` |
| `key_layer` | `torch.Tensor` | Yes | Key tensor; same layout as query but with `H_kv` heads (`num_gqa_groups`) |
| `value_layer` | `torch.Tensor` | Yes | Value tensor; same layout as key; head dimension may differ from key if `kv_channels` is a tuple |
| `attention_mask` | `torch.Tensor` / `None` | No | Attention mask; required when `attn_mask_type` includes `"padding"` or is `"arbitrary"` |
Outputs
| Name | Type | Description |
|---|---|---|
| output | `torch.Tensor` | Attention output; same layout as query, with head dimension equal to the value head dimension (`D_v`) |
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_attention_heads` | int | required | Number of query attention heads |
| `kv_channels` | int or (int, int) | required | Head dimension for key and value; if a tuple, the first element is the key dim and the second the value dim |
| `num_gqa_groups` | int / None | `None` (= `num_attention_heads`) | Number of GQA groups (key/value heads); GQA-1 is Multi-Query Attention, GQA-H is standard MHA |
| `attention_dropout` | float | `0.0` | Dropout probability applied to the attention weights after softmax |
| `attn_mask_type` | str | `"causal"` | Attention mask type: `"no_mask"`, `"padding"`, `"causal"`, `"padding_causal"`, `"causal_bottom_right"`, `"padding_causal_bottom_right"`, `"arbitrary"` |
| `qkv_format` | str | `"sbhd"` | QKV tensor layout: `"sbhd"` (sequence-first), `"bshd"` (batch-first), `"thd"` (packed variable-length tokens) |
| `softmax_scale` | float / None | `None` (= 1/sqrt(d_k)) | Scaling factor applied to QK^T before softmax; defaults to the inverse square root of the key head dimension |
| `attention_type` | str | `"self"` | `"self"` for self-attention, `"cross"` for cross-attention |
| `cp_group` | ProcessGroup / List[ProcessGroup] / None | `None` | Context-parallelism process group(s); a single group for `"p2p"`/`"all_gather"`/`"a2a"`, a list of two groups for `"a2a+p2p"` |
| `cp_comm_type` | str | `"p2p"` | Context-parallelism communication strategy: `"p2p"`, `"all_gather"`, `"a2a"`, or `"a2a+p2p"` |
| `window_size` | (int, int) / None | `None` | Sliding window size for local attention with causal mask types; tuple of (left_window, right_window) |
| `layer_number` | int / None | `None` | Layer index for debugging and logging purposes |
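The `window_size` semantics can be sketched as a predicate over (query, key) positions. This follows the common Flash Attention-style reading of the (left, right) convention, shown here as a hypothetical helper rather than TE code; it assumes -1 means unbounded on that side:

```python
def in_window(i: int, j: int, left: int, right: int) -> bool:
    # Hypothetical helper illustrating a (left, right) sliding window:
    # query position i may attend key position j when
    #   i - left <= j <= i + right,
    # where -1 on either side is taken to mean "unbounded".
    lo = j >= i - left if left >= 0 else True
    hi = j <= i + right if right >= 0 else True
    return lo and hi
```

Under this reading, `(left, 0)` gives a causal sliding window of width `left + 1`, and `(-1, 0)` degenerates to a plain causal mask.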
Usage Examples
Basic Self-Attention
```python
import torch
import transformer_engine.pytorch as te

# Create fused dot-product attention module
dpa = te.DotProductAttention(
    num_attention_heads=12,
    kv_channels=64,
    attention_dropout=0.1,
    attn_mask_type="causal",
    qkv_format="bshd",
)

# Input shapes: [batch, seq_len, heads, head_dim]
batch_size, seq_len, num_heads, head_dim = 8, 512, 12, 64
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

output = dpa(query, key, value)
# output shape: [8, 512, 12, 64]
```
Grouped Query Attention
```python
import torch
import transformer_engine.pytorch as te

# GQA with 32 query heads and 8 KV heads (4:1 ratio)
dpa = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=128,
    num_gqa_groups=8,
    attn_mask_type="causal",
    qkv_format="bshd",
)

# Query has 32 heads, Key/Value have 8 heads
query = torch.randn(4, 2048, 32, 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn(4, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)
value = torch.randn(4, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)

output = dpa(query, key, value)
```
With Context Parallelism
```python
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assuming cp_group is an initialized process group for context parallelism
dpa = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
    qkv_format="sbhd",
    cp_group=cp_group,
    cp_global_ranks=list(range(dist.get_world_size())),
    cp_comm_type="p2p",
)
```
With FP8 Autocast
```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dpa = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
)

# FP8 attention with delayed scaling recipe;
# query, key, value are CUDA tensors as in the basic example above
fp8_recipe = DelayedScaling(fp8_dpa=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = dpa(query, key, value)
```