Implementation:NVIDIA TransformerEngine Fused Attn C API
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Declares the comprehensive C API for fused multi-head attention operations, including enums for QKV layouts, mask types, bias types, softmax variants, and backend selection, plus forward/backward functions for both FP16/BF16 and FP8 fused attention.
Description
fused_attn.h declares TransformerEngine's fused attention interface, which sits on a performance-critical path. It defines the following enums:
- NVTE_QKV_Layout: 25 layout variants including SBHD, BSHD, THD, and paged-KV formats
- NVTE_Mask_Type: No mask, padding, causal (top-left and bottom-right aligned)
- NVTE_Bias_Type: None, pre/post-scale bias, ALiBi
- NVTE_Softmax_Type: Vanilla, off-by-one, learnable
- NVTE_Fused_Attn_Backend: cuDNN-based F16 and FP8 backends
The main API functions are nvte_fused_attn_fwd and nvte_fused_attn_bwd. They take the Q/K/V tensors along with the attention configuration, dropout probability, sliding window sizes, and a workspace tensor. The API supports variable-length sequences, GQA/MQA, paged KV caches for inference, and FP8 attention via cuDNN backends.
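Variable-length batches are described to the API through cumulative sequence length arrays (the cu_seqlens_q and cu_seqlens_kv arguments in the usage example). A minimal sketch of how such an array could be built on the host, assuming the common convention of batch + 1 int32 entries starting at 0; build_cu_seqlens is a hypothetical helper and not part of fused_attn.h:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper (not declared in fused_attn.h).
 * Writes the running total of seq_lens into cu_seqlens, so that sequence i
 * occupies token range [cu_seqlens[i], cu_seqlens[i + 1]).
 * cu_seqlens must have room for batch + 1 entries. */
static void build_cu_seqlens(const int32_t *seq_lens, size_t batch,
                             int32_t *cu_seqlens) {
    assert(seq_lens != NULL && cu_seqlens != NULL);
    cu_seqlens[0] = 0;
    for (size_t i = 0; i < batch; ++i) {
        assert(seq_lens[i] >= 0);
        cu_seqlens[i + 1] = cu_seqlens[i] + seq_lens[i];
    }
}
```

For sequence lengths 3, 5, and 2 this yields 0, 3, 8, 10. The result would still need to be copied to device memory and wrapped in an NVTETensor before being passed to the attention functions.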
Usage
Use for all fused attention operations in TransformerEngine. The backend is automatically selected based on data type, sequence length, and GPU capability.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/include/transformer_engine/fused_attn.h
- Lines: 1-875
Signature
enum NVTE_QKV_Layout { NVTE_SB3HD = 0, NVTE_BS3HD = 5, NVTE_T3HD = 10, ... };
enum NVTE_QKV_Layout_Group { NVTE_3HD = 0, NVTE_H3D = 1, ... };
enum NVTE_Bias_Type { NVTE_NO_BIAS = 0, NVTE_PRE_SCALE_BIAS = 1, ... };
enum NVTE_Mask_Type { NVTE_NO_MASK = 0, NVTE_CAUSAL_MASK = 2, ... };
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout);
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout);
void nvte_fused_attn_fwd(...);
void nvte_fused_attn_bwd(...);
Import
#include "transformer_engine/fused_attn.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| Q, K, V | NVTETensor | Yes | Query, Key, Value tensors |
| Bias | NVTETensor | No | Attention bias tensor |
| qkv_layout | NVTE_QKV_Layout | Yes | QKV memory layout |
| mask_type | NVTE_Mask_Type | Yes | Attention mask type |
| attn_scale | float | Yes | Attention scaling factor (typically 1/sqrt(d)) |
| p_dropout | float | Yes | Dropout probability |
Outputs
| Name | Type | Description |
|---|---|---|
| O | NVTETensor | Attention output |
| Aux_CTX_Tensors | NVTETensorPack* | Auxiliary context tensors for backward |
Usage Examples
#include "transformer_engine/fused_attn.h"
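// Forward pass with causal masking.
// The argument list is abbreviated; check the header for the full parameter
// list in your TransformerEngine version. All tensors are assumed to have
// been created beforehand.
nvte_fused_attn_fwd(Q, K, V, Bias, S, O, &aux_ctx,
                    cu_seqlens_q, cu_seqlens_kv, rng_state,
                    max_seqlen_q, max_seqlen_kv,
                    /*is_training=*/true, attn_scale, dropout_prob,
                    NVTE_BS3HD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK,
                    /*window_left=*/-1, /*window_right=*/0,
                    workspace, stream);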
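The window arguments in the call above (-1 on the left, 0 on the right) correspond to plain causal attention. The sketch below illustrates that reading, assuming that -1 means "unbounded on that side" and that query and key sequences have equal length so positions line up directly; window_allows is a hypothetical helper for illustration, not a function from fused_attn.h:

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical helper (not declared in fused_attn.h).
 * Returns true if query position q may attend to key position k under the
 * window (window_left, window_right). A value of -1 is treated as unbounded.
 * Assumes equal query and key sequence lengths. */
static bool window_allows(int64_t q, int64_t k,
                          int64_t window_left, int64_t window_right) {
    assert(q >= 0 && k >= 0);
    assert(window_left >= -1 && window_right >= -1);
    bool left_ok = (window_left < 0) || (k >= q - window_left);
    bool right_ok = (window_right < 0) || (k <= q + window_right);
    return left_ok && right_ok;
}
```

Under these assumptions, (-1, 0) lets each query see itself and every earlier key, (-1, -1) disables windowing entirely, and (2, 0) restricts each query to itself and the two preceding keys.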