Implementation:NVIDIA TransformerEngine Cpp Fused Attn
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Python interface to the cuDNN-backed fused attention C++ extensions, providing fused_attn_fwd and fused_attn_bwd functions.
Description
Defines the Python-to-C++ enum mappings used to configure the cuDNN kernels:
- QKV formats: bshd, sbhd, thd, and the conversion variants between them.
- QKV layouts: packed and separate arrangements such as sb3hd, bs3hd, and t3hd, plus the paged_kv variants.
- Attention bias types: no_bias, pre_scale_bias, post_scale_bias, alibi.
- Mask types: no_mask, padding, causal, padding_causal, causal_bottom_right.
- Softmax types: vanilla, off-by-one, learnable.
- Fused attention backends: F16_max512_seqlen, F16_arbitrary_seqlen, FP8.

It also defines the FP8 metadata tensor indices (META_QKV, META_O, META_S, META_DP, META_DQKV, META_DO). fused_attn_fwd marshals all attention parameters, converts them to the types the extension expects, and calls tex.fused_attn_fwd; fused_attn_bwd does the same for the backward pass. Both handle optional FP8 quantizers, page tables, sliding windows, and CUDA graph support.
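The enum mappings are essentially lookup tables from the strings a caller passes (for example a qkv_layout of "bs3hd") to C++ enum values. The sketch below illustrates that pattern under stated assumptions: QKVLayoutEnum is a hypothetical stand-in for the real C++ enum, and layout_to_format is a hypothetical helper that derives the QKV format by dropping the packing digit from the first component of a layout string.

```python
from enum import IntEnum


class QKVLayoutEnum(IntEnum):
    """Hypothetical stand-in for the C++ QKV layout enum (values are illustrative)."""
    SB3HD = 0
    BS3HD = 1
    T3HD = 2
    BSHD_BSHD_BSHD = 3


# Python-side table from user-facing layout strings to enum values.
QKV_LAYOUT = {
    "sb3hd": QKVLayoutEnum.SB3HD,
    "bs3hd": QKVLayoutEnum.BS3HD,
    "t3hd": QKVLayoutEnum.T3HD,
    "bshd_bshd_bshd": QKVLayoutEnum.BSHD_BSHD_BSHD,
}


def layout_to_format(qkv_layout: str) -> str:
    """Derive the QKV format from a layout by keeping only the letters of its
    first component: 'sb3hd' -> 'sbhd', 'bshd_bs2hd' -> 'bshd', 't3hd' -> 'thd'.
    paged_kv_* layouts would need their prefix stripped first (not handled here)."""
    first = qkv_layout.split("_")[0]
    return "".join(ch for ch in first if ch.isalpha())


def to_cpp_layout(qkv_layout: str) -> QKVLayoutEnum:
    """Look up the enum value, failing loudly on an unsupported layout string."""
    try:
        return QKV_LAYOUT[qkv_layout]
    except KeyError:
        raise ValueError(f"unsupported qkv_layout: {qkv_layout!r}") from None
```

Raising a ValueError with the offending string, rather than letting a bare KeyError escape, makes a misspelled layout easy to diagnose at the Python boundary before anything reaches the C++ extension.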
Usage
Used as the direct Python binding to cuDNN's fused multi-head attention kernels. All fused attention paths in the backends module call through these two functions.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/cpp_extensions/fused_attn.py
- Lines: 1-544
Signature
```python
def fused_attn_fwd(
    is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
    Q, K, V, qkv_dtype, fused_attention_backend, attn_bias, ...
): ...

def fused_attn_bwd(
    max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
    Q, K, V, O, dO, qkv_dtype, fused_attention_backend, ...
): ...
```
Import
```python
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd,
    fused_attn_bwd,
    FusedAttnBackend,
    META_QKV,
    META_O,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| is_training | bool | Yes | Whether in training mode (affects dropout) |
| max_seqlen_q | int | Yes | Maximum query sequence length |
| max_seqlen_kv | int | Yes | Maximum key/value sequence length |
| cu_seqlens_q | torch.Tensor | Yes | Cumulative sequence lengths for queries, shape [batch_size + 1] |
| cu_seqlens_kv | torch.Tensor | Yes | Cumulative sequence lengths for keys/values, shape [batch_size + 1] |
| Q | torch.Tensor | Yes | Query tensor |
| K | torch.Tensor | Yes | Key tensor |
| V | torch.Tensor | Yes | Value tensor |
| qkv_dtype | TE_DType | Yes | Data type of the Q, K, and V tensors |
| fused_attention_backend | FusedAttnBackend | Yes | Backend, looked up as FusedAttnBackend[name] with name one of F16_max512_seqlen, F16_arbitrary_seqlen, or FP8 |
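cu_seqlens_q and cu_seqlens_kv encode a batch of variable-length sequences as prefix sums: a leading 0 followed by the running total of lengths. The minimal pure-Python sketch below builds them together with the matching max_seqlen; build_cu_seqlens is a hypothetical helper, not part of Transformer Engine.

```python
from itertools import accumulate
from typing import List, Tuple


def build_cu_seqlens(seqlens: List[int]) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) for a batch of variable-length sequences.

    cu_seqlens has batch_size + 1 entries, so sequence i occupies
    tokens cu_seqlens[i] to cu_seqlens[i + 1] (exclusive) in a packed buffer.
    """
    if any(n < 0 for n in seqlens):
        raise ValueError("sequence lengths must be non-negative")
    cu_seqlens = [0, *accumulate(seqlens)]
    max_seqlen = max(seqlens, default=0)
    return cu_seqlens, max_seqlen
```

For lengths [3, 5, 2] this yields cu_seqlens [0, 3, 8, 10] and max_seqlen 5. In a real call the list would be converted to a GPU tensor (typically torch.int32) before being passed as cu_seqlens_q or cu_seqlens_kv.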
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Attention output tensor O |
| aux_ctx_tensors | List[torch.Tensor] | Auxiliary tensors saved for the backward pass: backend-dependent softmax statistics (for example the per-row log-sum-exp on the F16_arbitrary_seqlen backend) and the RNG state used for dropout reproducibility |
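Keeping a per-row log-sum-exp (LSE) of the attention scores lets the backward pass recompute the softmax probabilities as exp(s - LSE) instead of storing the full [seqlen_q, seqlen_kv] probability matrix. The sketch below illustrates that identity for a single query row in plain Python; it is a conceptual illustration rather than the cuDNN kernel, and the helper names are hypothetical.

```python
import math
from typing import List


def row_logsumexp(scores: List[float]) -> float:
    """Numerically stable log(sum(exp(s))) for one query row of scaled Q*K^T scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def softmax_from_lse(scores: List[float], lse: float) -> List[float]:
    """Rebuild softmax probabilities from raw scores and a saved LSE value."""
    return [math.exp(s - lse) for s in scores]
```

Only one float per query row needs to be stored, which is what makes this approach attractive for long sequences.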
Usage Examples
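```python
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
    fused_attn_fwd,
    FusedAttnBackend,
)

# Assumes query/key/value are CUDA tensors in bshd layout
# ([batch, seqlen, heads, head_dim]) and cu_seqlens_q/cu_seqlens_kv are
# int32 tensors of shape [batch_size + 1]. The FP8 backend would also need
# FP8 quantizers, so this example uses the F16 arbitrary-seqlen backend.
out, aux_ctx_tensors = fused_attn_fwd(
    True,                # is_training
    512,                 # max_seqlen_q
    512,                 # max_seqlen_kv
    cu_seqlens_q,
    cu_seqlens_kv,
    query, key, value,
    qkv_dtype,           # data type of Q, K, V
    FusedAttnBackend["F16_arbitrary_seqlen"],
    qkv_layout="bshd_bshd_bshd",
    attn_mask_type="padding",
)
# aux_ctx_tensors holds the softmax statistics and RNG state for fused_attn_bwd.
```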
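Both functions also accept a sliding-window setting for local attention. The sketch below shows which keys each query may attend to, assuming Transformer Engine's window_size = (left, right) convention with bottom-right alignment, where -1 on either side means unbounded, so (-1, -1) is full attention and (-1, 0) is causal. window_mask is a hypothetical helper for illustration; the real kernels apply the window internally.

```python
from typing import List, Tuple


def window_mask(seqlen_q: int, seqlen_kv: int,
                window_size: Tuple[int, int] = (-1, -1)) -> List[List[bool]]:
    """Return mask[i][j] == True where query i may attend to key j.

    Assumes a bottom-right aligned window: query i attends to keys in
    [i + seqlen_kv - seqlen_q - left, i + seqlen_kv - seqlen_q + right],
    where -1 for left or right removes the bound on that side.
    """
    left, right = window_size
    shift = seqlen_kv - seqlen_q  # aligns the last query with the last key
    mask = []
    for i in range(seqlen_q):
        lo = 0 if left < 0 else i + shift - left
        hi = seqlen_kv - 1 if right < 0 else i + shift + right
        mask.append([lo <= j <= hi for j in range(seqlen_kv)])
    return mask
```

With seqlen_q = seqlen_kv = 3 and window_size = (-1, 0) this produces a lower-triangular (causal) mask; when seqlen_kv exceeds seqlen_q, the shift lets each query also see the extra leading keys.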