Implementation:NVIDIA TransformerEngine Fused Attn F16 Max512
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Header declaring the forward and backward functions for fused attention with FP16/BF16 data types optimized for short sequences (up to 512 tokens), using a specialized cuDNN execution path.
Description
fused_attn_f16_max512_seqlen.h declares a simpler fused attention interface optimized for short sequences:
- fused_attn_max_512_fwd: Forward pass with a more compact parameter set than the arbitrary-length variant (no sliding window, no paged KV, no GQA groups, no softmax offset).
- fused_attn_max_512_bwd: Backward pass computing gradients for Q, K, V, and optional bias.
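For orientation, the forward pass computes scaled dot-product attention. The sketch below is an unfused FP32 reference for one batch entry and one head. It assumes no bias, no dropout and an optional top-left-aligned causal mask. Matrix and reference_attention are hypothetical names, and the cuDNN kernel's exact bias and masking semantics are not reproduced.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical dense row-major matrix type; FP32 is used for clarity even
// though the fused kernel operates on FP16/BF16 inputs.
using Matrix = std::vector<std::vector<float>>;

// Unfused reference for one batch entry and one head (hypothetical helper):
//   O = softmax(attn_scale * Q K^T, with optional causal mask) V
// No bias term and no dropout are applied.
Matrix reference_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                           float attn_scale, bool causal) {
  const std::size_t sq = Q.size(), skv = K.size();
  assert(sq > 0 && skv > 0 && V.size() == skv);
  const std::size_t d = Q[0].size(), dv = V[0].size();
  const float neg_inf = -std::numeric_limits<float>::infinity();

  Matrix O(sq, std::vector<float>(dv, 0.0f));
  std::vector<float> scores(skv);
  for (std::size_t i = 0; i < sq; ++i) {
    float row_max = neg_inf;
    for (std::size_t j = 0; j < skv; ++j) {
      // Top-left-aligned causal mask: query i may not attend to key j > i.
      if (causal && j > i) {
        scores[j] = neg_inf;
        continue;
      }
      float dot = 0.0f;
      for (std::size_t k = 0; k < d; ++k) dot += Q[i][k] * K[j][k];
      scores[j] = attn_scale * dot;
      row_max = std::max(row_max, scores[j]);
    }
    // Numerically stable softmax: subtract the row maximum first.
    // Masked entries become exp(-inf) == 0.
    float denom = 0.0f;
    for (std::size_t j = 0; j < skv; ++j) {
      scores[j] = std::exp(scores[j] - row_max);
      denom += scores[j];
    }
    // Weighted sum of value rows.
    for (std::size_t j = 0; j < skv; ++j)
      for (std::size_t k = 0; k < dv; ++k)
        O[i][k] += scores[j] / denom * V[j][k];
  }
  return O;
}
```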
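Both declarations are guarded by #if (CUDNN_VERSION >= 8901), so they are compiled only against cuDNN 8.9.1 or newer. The short-sequence specialization is intended to let cuDNN use kernels tuned for sequence lengths up to 512, instead of the general arbitrary-sequence-length cuDNN backend.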
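CUDNN_VERSION is an integer encoding of the cuDNN release. The sketch below uses hypothetical helper names. It decodes the guard value under the cuDNN 8.x encoding (MAJOR * 1000 + MINOR * 100 + PATCH). It also checks that cuDNN 9.x releases, which use MAJOR * 10000 + MINOR * 100 + PATCH, still satisfy the guard.

```cpp
// cuDNN 8.x: CUDNN_VERSION = MAJOR * 1000 + MINOR * 100 + PATCH.
constexpr int cudnn8_version(int major, int minor, int patch) {
  return major * 1000 + minor * 100 + patch;
}

// cuDNN 9.x: CUDNN_VERSION = MAJOR * 10000 + MINOR * 100 + PATCH.
constexpr int cudnn9_version(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

// Mirrors the header's preprocessor test: #if (CUDNN_VERSION >= 8901)
constexpr bool max512_declared(int cudnn_version) {
  return cudnn_version >= 8901;
}

static_assert(cudnn8_version(8, 9, 1) == 8901, "8901 encodes cuDNN 8.9.1");
```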
Usage
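Selected by the fused attention dispatch layer for FP16 or BF16 inputs when both the query and the key/value maximum sequence lengths are 512 or less. The dispatcher also checks further constraints, such as head dimension, QKV layout, bias type and mask type.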
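The selection rule above can be expressed as a predicate. The helper is hypothetical. DType and max512_path_applies are illustrative names, not TransformerEngine APIs. It models only the data-type and sequence-length conditions, not the dispatcher's other checks.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the input data types; not TransformerEngine's enum.
enum class DType { kFloat16, kBFloat16, kFloat32, kFloat8E4M3 };

// Hypothetical predicate covering only the dtype and sequence-length rule;
// the real dispatch logic applies additional constraints not modeled here.
bool max512_path_applies(DType dtype, std::size_t q_max_seqlen,
                         std::size_t kv_max_seqlen) {
  assert(q_max_seqlen > 0 && kv_max_seqlen > 0);
  const bool half_precision =
      dtype == DType::kFloat16 || dtype == DType::kBFloat16;
  return half_precision && q_max_seqlen <= 512 && kv_max_seqlen <= 512;
}
```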
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h
- Lines: 1-43
Signature
```cpp
namespace transformer_engine {
#if (CUDNN_VERSION >= 8901)
void fused_attn_max_512_fwd(
    size_t batch, size_t num_head,
    size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim,
    bool is_training, float attn_scale, float p_dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type,
    const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
    const Tensor *input_Bias, Tensor *output_O,
    NVTETensorPack *Aux_CTX_Tensors,
    const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
    const Tensor *rng_state, Tensor *workspace,
    cudaStream_t stream, cudnnHandle_t handle);

// Backward pass; its full parameter list is omitted from this summary.
void fused_attn_max_512_bwd(...);
#endif
}  // namespace transformer_engine
```
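The q_cu_seqlens and kv_cu_seqlens arguments carry cumulative sequence lengths. This sketch assumes the convention used elsewhere in TransformerEngine, an int32 array of size batch + 1 whose first entry is 0. Under that assumption, the host-side computation looks like this. make_cu_seqlens is a hypothetical name, and copying the result into a device Tensor is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: cu[0] = 0 and cu[b + 1] = cu[b] + seqlens[b],
// giving an array of size batch + 1 (the assumed cu_seqlens layout).
std::vector<std::int32_t> make_cu_seqlens(
    const std::vector<std::int32_t>& seqlens, std::int32_t max_seqlen) {
  std::vector<std::int32_t> cu(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b) {
    // Each actual length must fit within the padded maximum (<= 512 here).
    assert(seqlens[b] >= 0 && seqlens[b] <= max_seqlen);
    cu[b + 1] = cu[b] + seqlens[b];
  }
  return cu;
}
```

The last entry equals the total number of valid tokens in the batch, and consecutive differences recover each sequence's length.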
Import
```cpp
#include "fused_attn/fused_attn_f16_max512_seqlen.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
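| input_Q | const Tensor* | Yes | Query tensor |
| input_K | const Tensor* | Yes | Key tensor |
| input_V | const Tensor* | Yes | Value tensor |
| batch | size_t | Yes | Batch size |
| num_head | size_t | Yes | Number of attention heads (shared by Q, K and V) |
| q_max_seqlen | size_t | Yes | Maximum query sequence length (at most 512 on this path) |
| kv_max_seqlen | size_t | Yes | Maximum key/value sequence length (at most 512 on this path) |
| head_dim | size_t | Yes | Head dimension |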
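The scalar inputs determine the size of the tensor inputs. This sketch assumes a padded layout, with BSHD or SBHD ordering and Q, K and V not packed into one buffer. Under that assumption, each of Q, K and V holds batch × seqlen × num_head × head_dim elements of 2 bytes each. The helper names are hypothetical.

```cpp
#include <cstddef>

// FP16 and BF16 are both 16-bit formats.
constexpr std::size_t kHalfBytes = 2;

// Element count of one padded Q, K or V tensor (hypothetical helper).
// BSHD and SBHD hold the same number of elements; only the stride order differs.
constexpr std::size_t qkv_elements(std::size_t batch, std::size_t seqlen,
                                   std::size_t num_head, std::size_t head_dim) {
  return batch * seqlen * num_head * head_dim;
}

// Size in bytes of that tensor for FP16/BF16 data (hypothetical helper).
constexpr std::size_t qkv_bytes(std::size_t batch, std::size_t seqlen,
                                std::size_t num_head, std::size_t head_dim) {
  return qkv_elements(batch, seqlen, num_head, head_dim) * kHalfBytes;
}
```

For batch = 2, seqlen = 512, 16 heads and head_dim = 64, each tensor occupies 2 MiB.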
Outputs
| Name | Type | Description |
|---|---|---|
| output_O | Tensor* | Attention output |
| Aux_CTX_Tensors | NVTETensorPack* | Auxiliary context for the backward pass |
Usage Examples
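```cpp
// Called internally by the fused attention dispatch layer. The arguments are
// prepared by that caller: Q, K, V, Bias and O are Tensor pointers, aux is an
// NVTETensorPack that receives the auxiliary context for the backward pass,
// and the argument order follows the signature above.
fused_attn_max_512_fwd(batch, num_heads, q_seqlen, kv_seqlen,
                       head_dim, is_training, attn_scale, dropout,
                       qkv_layout, bias_type, mask_type,
                       Q, K, V, Bias, O, &aux, cu_seqlens_q, cu_seqlens_kv,
                       rng_state, workspace, stream, handle);
```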
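The caller supplies the scalar arguments in the call above. This sketch shows one common way to prepare them, which is an assumption rather than something the header mandates. It sets attn_scale = 1/sqrt(head_dim) and uses a dropout probability of zero outside training. Max512Scalars and make_scalars are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical bundle of the scalar arguments passed to the forward call.
struct Max512Scalars {
  float attn_scale;
  float p_dropout;
};

// Assumed convention: scaled dot-product scaling and no dropout at inference.
Max512Scalars make_scalars(std::size_t head_dim, bool is_training,
                           float dropout) {
  assert(head_dim > 0 && dropout >= 0.0f && dropout < 1.0f);
  return {1.0f / std::sqrt(static_cast<float>(head_dim)),
          is_training ? dropout : 0.0f};
}
```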