Implementation:NVIDIA TransformerEngine TE DotProductAttention

From Leeroopedia


Field Value
Sources TransformerEngine, Flash Attention, Attention Is All You Need
Domains Deep_Learning, Attention
Last Updated 2026-02-07 14:00 GMT

Overview

te.DotProductAttention is a PyTorch module, provided by NVIDIA's TransformerEngine library, that implements fused scaled dot-product attention. It computes multi-head attention using cuDNN and Flash Attention backends, with support for Grouped Query Attention, context parallelism, multiple QKV formats, FP8 quantization, and a range of attention mask types.

Description

te.DotProductAttention computes scaled dot-product attention:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
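For reference, the formula corresponds to the following unfused PyTorch computation (an illustrative sketch; the actual module dispatches to fused cuDNN or Flash Attention kernels rather than materializing the score matrix):

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    """Unfused scaled dot-product attention for [batch, heads, seq, dim] inputs."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)  # [batch, heads, seq_q, seq_kv]
    return F.softmax(scores, dim=-1) @ v             # [batch, heads, seq_q, dim]
```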

The class inherits from TransformerEngineBaseModule and automatically selects between cuDNN fused attention and Flash Attention backends based on hardware capabilities, input dimensions, and configuration. It supports:

  • Grouped Query Attention (GQA): The num_gqa_groups parameter controls how many key/value heads are used. When num_gqa_groups < num_attention_heads, multiple query heads share the same KV head, reducing KV cache memory during inference and KV computation during training.
  • Multiple QKV formats: "sbhd" (sequence, batch, head, dim), "bshd" (batch, sequence, head, dim), and "thd" (token, head, dim) for variable-length/unpadded sequences.
  • Attention mask types: "no_mask", "padding", "causal", "padding_causal", "causal_bottom_right", "padding_causal_bottom_right", and "arbitrary".
  • Context parallelism (CP): Distributes the sequence dimension across multiple GPUs using P2P ring communication ("p2p"), all-gather ("all_gather"), all-to-all ("a2a"), or hierarchical ("a2a+p2p") strategies.
  • FP8 attention: When used inside an FP8 autocast context with fp8_dpa=True, the QKV tensors, attention scores, and output are quantized to FP8 for additional throughput on Hopper GPUs. Supports both delayed scaling and current scaling FP8 recipes.
  • Sliding window attention: The window_size parameter enables local attention windows for causal mask types.
  • Softmax variants: Standard (vanilla) softmax and custom softmax types via the softmax_type parameter.
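To make the GQA bullet concrete: sharing one KV head across a query group is numerically equivalent to replicating each KV head group_size times along the head axis. A plain-PyTorch sketch of that mapping (the fused kernels avoid materializing the copy):

```python
import torch

num_attention_heads, num_gqa_groups = 32, 8
group_size = num_attention_heads // num_gqa_groups  # 4 query heads share each KV head

# "bshd" layout: [batch, seq, heads, head_dim]
key = torch.randn(2, 16, num_gqa_groups, 64)
key_expanded = key.repeat_interleave(group_size, dim=2)
# key_expanded shape: [2, 16, 32, 64] -- one KV head per query head
```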

The backend selection is automatic: cuDNN fused attention is preferred when available and supported for the given configuration, with Flash Attention as a fallback.
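The automatic choice can be overridden for debugging or benchmarking via TransformerEngine's environment variables (NVTE_FLASH_ATTN and NVTE_FUSED_ATTN); they must be set before the first attention forward pass:

```python
import os

# Force the cuDNN fused backend by disabling Flash Attention.
# Set these before the first forward pass through the module.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
```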

Usage

Import te.DotProductAttention when building custom attention modules that need fused attention kernels, or when constructing Transformer layers outside of the higher-level te.TransformerLayer API. It is most commonly used as a component inside te.MultiheadAttention.
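For comparison, the higher-level wrapper can be used directly. A minimal sketch, assuming the te.MultiheadAttention constructor's hidden_size and num_attention_heads arguments, which add the QKV and output projections around an internal DotProductAttention:

```python
import torch
import transformer_engine.pytorch as te

# te.MultiheadAttention wraps QKV/output projections around DotProductAttention
mha = te.MultiheadAttention(
    hidden_size=768,
    num_attention_heads=12,
    attention_dropout=0.1,
)

# Default layout is sequence-first: [seq_len, batch, hidden_size]
x = torch.randn(512, 8, 768, device="cuda")
out = mha(x)
# out shape: [512, 8, 768]
```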

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
Class
DotProductAttention
Lines
__init__ at L323--346

Signature

class DotProductAttention(TransformerEngineBaseModule):
    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: Union[int, Tuple[int, int]],
        num_gqa_groups: Optional[int] = None,
        attention_dropout: float = 0.0,
        qkv_format: str = "sbhd",
        attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        sequence_parallel: bool = False,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        layer_number: Optional[int] = None,
        attention_type: str = "self",
        cp_group: Optional[Union[dist_group_type, List[dist_group_type]]] = None,
        cp_global_ranks: List[int] = None,
        cp_stream: torch.cuda.Stream = None,
        cp_comm_type: str = "p2p",
        softmax_scale: Optional[float] = None,
        softmax_type: str = "vanilla",
        return_max_logit: Optional[bool] = False,
    ) -> None:

Import

from transformer_engine.pytorch.attention import DotProductAttention

# or equivalently:
import transformer_engine.pytorch as te
te.DotProductAttention

I/O Contract

Inputs

Name Type Required Description
query_layer torch.Tensor Yes Query tensor; shape depends on qkv_format: [S, B, H, D] for "sbhd", [B, S, H, D] for "bshd", [T, H, D] for "thd"
key_layer torch.Tensor Yes Key tensor; same layout as query but with H_kv heads (num_gqa_groups)
value_layer torch.Tensor Yes Value tensor; same layout as key; head dimension may differ from key if kv_channels is a tuple
attention_mask torch.Tensor / None No Attention mask; required when attn_mask_type includes "padding" or is "arbitrary"

Outputs

Name Type Description
output torch.Tensor Attention output; same layout as query with head dimension equal to value head dimension (D_v)

Key Parameters

Parameter Type Default Description
num_attention_heads int required Number of query attention heads
kv_channels int or (int, int) required Head dimension for key and value; if a tuple, first element is key dim and second is value dim
num_gqa_groups int / None None (= num_attention_heads) Number of GQA groups (key/value heads); GQA-1 is Multi-Query Attention, GQA-H is standard MHA
attention_dropout float 0.0 Dropout probability applied to the attention weights after softmax
attn_mask_type str "causal" Attention mask type: "no_mask", "padding", "causal", "padding_causal", "causal_bottom_right", "padding_causal_bottom_right", "arbitrary"
qkv_format str "sbhd" QKV tensor layout: "sbhd" (sequence-first), "bshd" (batch-first), "thd" (packed variable-length tokens)
softmax_scale float / None None (= 1/sqrt(d_k)) Scaling factor applied to QK^T before softmax; defaults to inverse square root of key head dimension
attention_type str "self" "self" for self-attention, "cross" for cross-attention
cp_group ProcessGroup / List[ProcessGroup] / None None Context parallelism process group(s); a single group for "p2p"/"all_gather"/"a2a", a list of two groups for "a2a+p2p"
cp_comm_type str "p2p" Context parallelism communication strategy: "p2p", "all_gather", "a2a", or "a2a+p2p"
window_size (int, int) / None None Sliding window size for local attention with causal mask types; tuple of (left_window, right_window)
layer_number int / None None Layer index for debugging and logging purposes

Usage Examples

Basic Self-Attention

import torch
import transformer_engine.pytorch as te

# Create fused dot-product attention module
dpa = te.DotProductAttention(
    num_attention_heads=12,
    kv_channels=64,
    attention_dropout=0.1,
    attn_mask_type="causal",
    qkv_format="bshd",
)

# Input shapes: [batch, seq_len, heads, head_dim]
batch_size, seq_len, num_heads, head_dim = 8, 512, 12, 64
query = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
value = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

output = dpa(query, key, value)
# output shape: [8, 512, 12, 64]

Grouped Query Attention

import torch
import transformer_engine.pytorch as te

# GQA with 32 query heads and 8 KV heads (4:1 ratio)
dpa = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=128,
    num_gqa_groups=8,
    attn_mask_type="causal",
    qkv_format="bshd",
)

# Query has 32 heads, Key/Value have 8 heads
query = torch.randn(4, 2048, 32, 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn(4, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)
value = torch.randn(4, 2048, 8, 128, device="cuda", dtype=torch.bfloat16)

output = dpa(query, key, value)

With Context Parallelism

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assuming cp_group is an initialized process group spanning the
# context-parallel ranks
dpa = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
    qkv_format="sbhd",
    cp_group=cp_group,
    cp_global_ranks=list(range(dist.get_world_size())),
    cp_stream=torch.cuda.Stream(),  # side stream for overlapping CP communication
    cp_comm_type="p2p",
)

With FP8 Autocast

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

dpa = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="causal",
)

# Default qkv_format is "sbhd": [seq_len, batch, heads, head_dim]
query = torch.randn(512, 4, 16, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(512, 4, 16, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(512, 4, 16, 64, device="cuda", dtype=torch.bfloat16)

# FP8 attention with delayed scaling recipe (requires Hopper or newer)
fp8_recipe = DelayedScaling(fp8_dpa=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = dpa(query, key, value)
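With Packed Variable-Length Sequences (thd)

For unpadded batches, qkv_format="thd" packs all tokens along a single dimension, and sequence boundaries are supplied as cumulative sequence-length tensors. A sketch; the cu_seqlens_q/cu_seqlens_kv forward-argument names are assumed from the TE attention API:

```python
import torch
import torch.nn.functional as F
import transformer_engine.pytorch as te

dpa = te.DotProductAttention(
    num_attention_heads=12,
    kv_channels=64,
    attn_mask_type="padding_causal",
    qkv_format="thd",
)

# Three sequences of lengths 100, 200, 50 packed into one token dimension
seq_lens = torch.tensor([100, 200, 50], dtype=torch.int32)
# Cumulative offsets with a leading zero: [0, 100, 300, 350]
cu_seqlens = F.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0)).cuda()
total_tokens = int(seq_lens.sum())  # 350

# "thd" layout: [total_tokens, heads, head_dim]
query = torch.randn(total_tokens, 12, 64, device="cuda", dtype=torch.bfloat16)
key = torch.randn(total_tokens, 12, 64, device="cuda", dtype=torch.bfloat16)
value = torch.randn(total_tokens, 12, 64, device="cuda", dtype=torch.bfloat16)

output = dpa(query, key, value, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)
# output shape: [350, 12, 64]
```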
