Implementation:NVIDIA TransformerEngine PyTorch Ext Attention
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements fused multi-head attention forward and backward passes in C++, along with helper functions for attention tensor format conversion and KV cache management.
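For reference, the core computation of the forward pass can be written out in unfused form. The plain-Python sketch below covers one head with no mask, bias, or dropout. `reference_attention` is a hypothetical helper, not part of TransformerEngine. A fused kernel typically reaches the same result without materializing the full score matrix in GPU memory.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix, attn_scale: float) -> Matrix:
    """Unfused single-head attention: O = softmax(attn_scale * Q K^T) V.

    Shapes: q is [s_q][d], k is [s_kv][d], v is [s_kv][d_v].
    No mask, bias, or dropout is applied in this sketch.
    """
    out: Matrix = []
    for q_row in q:
        # Scaled dot-product score of this query against every key.
        scores = [attn_scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax: subtract the row max before exponentiating.
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        probs = [w / total for w in weights]
        # Output row = probability-weighted sum of the value rows.
        out.append([sum(p * v_row[j] for p, v_row in zip(probs, v)) for j in range(len(v[0]))])
    return out


head_dim = 2
print(reference_attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]],
                          [[1.0, 2.0], [3.0, 4.0]], attn_scale=1.0 / math.sqrt(head_dim)))
```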
Description
fused_attn_fwd and fused_attn_bwd wrap the nvte_fused_attn_fwd/nvte_fused_attn_bwd CUDA kernels. They handle separate Q/K/V tensors, quantizers for S (softmax intermediate) and O (output), optional bias, masks, paged KV caches, sliding window attention, dropout, and CUDA graph compatibility. The file also provides:
- quantizer_helper, which creates appropriately typed tensor wrappers for each quantization mode (none, delayed scaling FP8, current scaling FP8)
- mha_fill, for fast GPU zero-fill
- the format converters convert_thd_to_bshd and convert_bshd_to_thd
- copy_to_kv_cache, for KV cache copies
- fa_prepare_fwd and fa_prepare_bwd, FlashAttention preparation helpers
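The THD and BSHD layouts store a batch of variable-length sequences differently:

- THD packs all tokens back to back as [total_tokens, heads, head_dim].
- BSHD pads every sequence to max_seqlen as [batch, max_seqlen, heads, head_dim].

The plain-Python sketch below illustrates the mapping between the two. It assumes cu_seqlens holds cumulative sequence lengths starting at 0 (batch_size + 1 entries), and it collapses each token's [heads, head_dim] slice into one opaque element. The real converters operate on GPU tensors; the function names and the `pad` argument here are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def thd_to_bshd(tokens: Sequence[T], cu_seqlens: Sequence[int], max_seqlen: int, pad: T) -> List[List[T]]:
    """Unpack packed (THD-style) tokens into padded (BSHD-style) rows.

    Sequence i occupies tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    batch: List[List[T]] = []
    for i in range(len(cu_seqlens) - 1):
        seq = list(tokens[cu_seqlens[i]:cu_seqlens[i + 1]])
        if len(seq) > max_seqlen:
            raise ValueError(f"sequence {i} has {len(seq)} tokens > max_seqlen={max_seqlen}")
        # Right-pad each sequence up to max_seqlen.
        batch.append(seq + [pad] * (max_seqlen - len(seq)))
    return batch


def bshd_to_thd(batch: Sequence[Sequence[T]], cu_seqlens: Sequence[int]) -> List[T]:
    """Inverse of thd_to_bshd: drop the padding and concatenate valid tokens."""
    tokens: List[T] = []
    for i, row in enumerate(batch):
        length = cu_seqlens[i + 1] - cu_seqlens[i]
        tokens.extend(row[:length])
    return tokens
```

Round-tripping through both functions returns the original packed tokens, because cu_seqlens records exactly which positions are padding.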
Usage
This is the largest and most complex extension file, because fused attention is a primary performance-critical operation in Transformers. It supports the major attention variants (multi-head, grouped-query, and multi-query attention: MHA, GQA, MQA), including with FP8 quantization.
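MHA, GQA, and MQA differ only in how many key/value heads the query heads share. The plain-Python sketch below assumes the common convention that consecutive query heads share one KV head. Under that assumption, the mapping reduces to an integer division. The function name is hypothetical and is not a TransformerEngine API.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the key/value head read by a given query head.

    MHA: num_kv_heads == num_q_heads (one-to-one).
    GQA: 1 < num_kv_heads < num_q_heads (groups of query heads share a KV head).
    MQA: num_kv_heads == 1 (all query heads share a single KV head).
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise ValueError("q_head out of range")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


print([kv_head_for_query_head(h, 8, 2) for h in range(8)])
```

With 8 query heads, 2 KV heads give two groups of 4 (GQA), and 1 KV head collapses to MQA.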
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/attention.cpp
- Lines: 1-873
Signature
```cpp
// Abbreviated: "..." marks parameters omitted from this reference.
namespace transformer_engine::pytorch {

py::object fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
    float attn_scale, float dropout, int qkv_layout,
    int bias_type, int mask_type, int softmax_type,
    at::Tensor Q, at::Tensor K, at::Tensor V,
    py::handle S_quantizer, py::handle O_quantizer, ...);

std::vector<at::Tensor> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, int qkv_layout,
    at::Tensor Q, at::Tensor K, at::Tensor V,
    at::Tensor O, at::Tensor dO, ...);

void mha_fill(at::Tensor &self, at::Scalar value);
at::Tensor convert_thd_to_bshd(at::Tensor input, at::Tensor cu_seqlens, int max_seqlen);
at::Tensor convert_bshd_to_thd(at::Tensor input, at::Tensor cu_seqlens, int max_seqlen);
void copy_to_kv_cache(at::Tensor kv_cache, at::Tensor new_kv, ...);

}  // namespace transformer_engine::pytorch
```
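The full parameter list of copy_to_kv_cache is elided above, so the following is only a general illustration of appending new K or V tokens to a paged KV cache. It is not TransformerEngine's implementation. The sketch assumes three things:

- the cache is a list of fixed-size physical pages;
- a per-sequence page table maps logical pages to physical ones;
- a per-sequence count records tokens already cached.

All names (`append_to_paged_cache`, `page_table`, `cached_lens`) are hypothetical.

```python
from typing import List, Sequence


def append_to_paged_cache(
    cache: List[List[object]],               # [num_pages][page_size] physical pages
    page_table: Sequence[Sequence[int]],     # page_table[b] = physical page ids for sequence b
    cached_lens: List[int],                  # tokens already cached per sequence (updated in place)
    new_tokens: Sequence[Sequence[object]],  # new_tokens[b] = tokens to append to sequence b
) -> None:
    page_size = len(cache[0])
    for b, toks in enumerate(new_tokens):
        for j, tok in enumerate(toks):
            pos = cached_lens[b] + j                # logical position within sequence b
            page = page_table[b][pos // page_size]  # physical page that holds this position
            cache[page][pos % page_size] = tok      # slot within that page
        cached_lens[b] += len(toks)
```

Paging lets sequences grow without reserving max-length contiguous buffers. A sequence's tokens may live in non-adjacent physical pages, and only the page table records their order.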
Import
```cpp
#include "../extensions.h"
#include "common.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| Q | at::Tensor | Yes | Query tensor |
| K | at::Tensor | Yes | Key tensor |
| V | at::Tensor | Yes | Value tensor |
| max_seqlen_q | size_t | Yes | Maximum query sequence length |
| max_seqlen_kv | size_t | Yes | Maximum key/value sequence length |
| attn_scale | float | Yes | Scaling factor applied to the Q·Kᵀ scores before softmax |
| S_quantizer | py::handle | No | Quantizer for the softmax intermediate (S) |
| O_quantizer | py::handle | No | Quantizer for the attention output (O) |
| bias | at::Tensor | No | Optional attention bias |
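S_quantizer and O_quantizer determine how the softmax intermediate and the output are stored. The plain-Python sketch below contrasts, in simplified form, the two FP8 scaling recipes named in the description:

- Current scaling derives the scale from the tensor being quantized.
- Delayed scaling reuses amax values recorded in earlier iterations.

The sketch only scales and clamps (real FP8 also rounds the mantissa). It ignores refinements such as a scaling margin, and every function name in it is hypothetical. The constant 448 is the largest finite FP8 E4M3 value.

```python
from typing import List, Sequence

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def current_scale(x: Sequence[float], fp8_max: float = FP8_E4M3_MAX) -> float:
    """Current scaling: scale chosen from this tensor's own absolute maximum."""
    amax = max(abs(v) for v in x)
    return fp8_max / amax if amax > 0 else 1.0


def delayed_scale(amax_history: Sequence[float], fp8_max: float = FP8_E4M3_MAX) -> float:
    """Delayed scaling: scale chosen from amax values seen in previous steps."""
    amax = max(amax_history)
    return fp8_max / amax if amax > 0 else 1.0


def fake_quantize(x: Sequence[float], scale: float, fp8_max: float = FP8_E4M3_MAX) -> List[float]:
    """Scale into the FP8 range and clamp (no mantissa rounding in this sketch)."""
    return [max(-fp8_max, min(fp8_max, v * scale)) for v in x]


x = [1.0, -2.0, 0.5]
print(fake_quantize(x, current_scale(x)))      # uses the full range, no saturation
print(fake_quantize(x, delayed_scale([0.5])))  # stale, too-small amax history saturates
```

The second call shows the trade-off. Delayed scaling avoids an extra pass over the data, but it can saturate when the tensor's range grows faster than the recorded history.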
Outputs
| Name | Type | Description |
|---|---|---|
| output | py::object | Attention output (possibly quantized) |
| softmax_lse | at::Tensor | Log-sum-exp of softmax |
| rng_state | at::Tensor | RNG state for dropout reproducibility |
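Saving one log-sum-exp value per query row is what allows a FlashAttention-style backward pass to avoid storing the full softmax matrix. Given the raw scores, the probabilities can be recomputed as P = exp(S - lse). The plain-Python sketch below demonstrates that identity on a single row. It says nothing about how the kernels lay out softmax_lse, and the helper names are hypothetical.

```python
import math
from typing import List, Sequence


def logsumexp(scores: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(s))) for one row of attention scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def softmax_from_lse(scores: Sequence[float], lse: float) -> List[float]:
    """Recover the softmax probabilities from raw scores and the saved LSE."""
    return [math.exp(s - lse) for s in scores]


row = [1.0, 2.0, 3.0]
print(softmax_from_lse(row, logsumexp(row)))
```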
Usage Examples
```python
import transformer_engine_torch as tex

# Called internally by the fused_attn Python wrappers rather than by user code.
# Arguments after O_quantizer are elided ("...") in this reference.
output, softmax_lse, rng_state = tex.fused_attn_fwd(
    max_seqlen_q, max_seqlen_kv, True,  # True is passed as is_training
    attn_scale, dropout, qkv_layout,
    bias_type, mask_type, softmax_type,
    Q, K, V, S_quantizer, O_quantizer, ...
)
```
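Returning rng_state "for dropout reproducibility" typically means the backward pass regenerates the same dropout pattern from the saved generator state instead of storing a mask. The sketch below shows that principle with Python's `random` module standing in for the GPU generator, and an integer seed playing the role of rng_state. TransformerEngine's actual RNG state format and generator are not described here, and all names are hypothetical.

```python
import random
from typing import List, Sequence


def dropout_keep_mask(seed: int, n: int, p: float) -> List[bool]:
    """Keep-mask for n elements with drop probability p, fully determined by seed."""
    rng = random.Random(seed)  # stand-in for restoring the saved generator state
    return [rng.random() >= p for _ in range(n)]


def apply_dropout(values: Sequence[float], keep: Sequence[bool], p: float) -> List[float]:
    """Zero the dropped elements and rescale kept ones by 1 / (1 - p)."""
    return [v / (1.0 - p) if k else 0.0 for v, k in zip(values, keep)]


seed = 1234                                       # recorded during the forward pass
fwd_mask = dropout_keep_mask(seed, 8, p=0.25)
bwd_mask = dropout_keep_mask(seed, 8, p=0.25)     # regenerated in the backward pass
print(fwd_mask == bwd_mask)
```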