Implementation:NVIDIA TransformerEngine PyTorch Ext Attention

| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Implements fused multi-head attention forward and backward passes in C++, along with helper functions for attention tensor format conversion and KV cache management.

Description

fused_attn_fwd and fused_attn_bwd wrap the nvte_fused_attn_fwd and nvte_fused_attn_bwd CUDA kernels. They take separate Q, K, and V tensors plus quantizers for S (the softmax intermediate) and O (the output), and they handle optional bias, masks, paged KV caches, sliding window attention, dropout, and CUDA graph compatibility. The file also includes quantizer_helper, which creates appropriately typed tensor wrappers for each quantization mode (none, delayed-scaling FP8, current-scaling FP8). The remaining helpers are mha_fill for fast GPU zero-fill, the format converters convert_thd_to_bshd and convert_bshd_to_thd, the KV cache copy routine copy_to_kv_cache, and the FlashAttention preparation helpers fa_prepare_fwd and fa_prepare_bwd.
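The two FP8 modes named above differ mainly in which amax (maximum absolute value) feeds the scaling factor. The following is a minimal pure-Python sketch of that idea, not TransformerEngine's code. The function names are hypothetical, 448 is the largest finite value of the FP8 E4M3 format, and reducing the history with a plain maximum is an assumption made here for illustration.

```python
E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def _scale_from_amax(amax, fp8_max):
    # Map the observed range [-amax, amax] onto [-fp8_max, fp8_max].
    # Fall back to 1.0 when there is no signal (all zeros or empty history).
    return fp8_max / amax if amax > 0.0 else 1.0


def current_scale(values, fp8_max=E4M3_MAX):
    """Current scaling (sketch): derive the scale from the tensor being quantized now."""
    amax = max((abs(v) for v in values), default=0.0)
    return _scale_from_amax(amax, fp8_max)


def delayed_scale(amax_history, fp8_max=E4M3_MAX):
    """Delayed scaling (sketch): derive the scale from amax values recorded earlier.

    Assumption: the history is reduced with max(); other reductions are possible.
    """
    amax = max(amax_history, default=0.0)
    return _scale_from_amax(amax, fp8_max)


if __name__ == "__main__":
    print(current_scale([0.5, -2.0, 1.0]))  # scale from the live tensor
    print(delayed_scale([4.0, 1.0]))        # scale from earlier iterations
```

In this sketch, current scaling needs a pass over the tensor before quantizing it, while delayed scaling can reuse statistics that already exist, at the cost of reacting one step late to a change in range.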

Usage

This is the largest and most complex extension file, because fused attention is the primary performance-critical operation in Transformers. It supports all major attention variants (MHA, GQA, MQA) with FP8 quantization.
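MHA, GQA, and MQA differ only in how many key/value heads serve the query heads. As a rough illustration, the hypothetical helper below (not part of TransformerEngine) maps a query head index to the key/value head it would read from, assuming query heads are grouped contiguously.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Return the KV head index used by a given query head (illustrative sketch).

    MHA: num_kv_heads == num_q_heads (one KV head per query head)
    GQA: 1 < num_kv_heads < num_q_heads (query heads share KV heads in groups)
    MQA: num_kv_heads == 1 (all query heads share a single KV head)
    """
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a positive multiple of num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise IndexError("q_head out of range")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


if __name__ == "__main__":
    for name, kv_heads in (("MHA", 8), ("GQA", 2), ("MQA", 1)):
        mapping = [kv_head_for_query_head(h, 8, kv_heads) for h in range(8)]
        print(name, mapping)
```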

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/csrc/extensions/attention.cpp
Lines: 1-873

Signature

namespace transformer_engine::pytorch {

py::object fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
    float attn_scale, float dropout, int qkv_layout,
    int bias_type, int mask_type, int softmax_type,
    at::Tensor Q, at::Tensor K, at::Tensor V,
    py::handle S_quantizer, py::handle O_quantizer, ...);

std::vector<at::Tensor> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, int qkv_layout,
    at::Tensor Q, at::Tensor K, at::Tensor V,
    at::Tensor O, at::Tensor dO, ...);

void mha_fill(at::Tensor &self, at::Scalar value);
at::Tensor convert_thd_to_bshd(at::Tensor input, at::Tensor cu_seqlens, int max_seqlen);
at::Tensor convert_bshd_to_thd(at::Tensor input, at::Tensor cu_seqlens, int max_seqlen);
void copy_to_kv_cache(at::Tensor kv_cache, at::Tensor new_kv, ...);

}
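The two format converters take a cu_seqlens tensor and a max_seqlen, which suggests the usual relationship between a packed layout (THD: all tokens of all sequences concatenated) and a padded layout (BSHD: one fixed-length row per sequence). The pure-Python sketch below illustrates that relationship on plain lists. It is a conceptual reference only: the function names, the `pad` argument, and the treatment of each token's [heads, dim] slice as one opaque element are assumptions, and the real functions operate on GPU tensors.

```python
def thd_to_bshd(tokens, cu_seqlens, max_seqlen, pad=None):
    """Packed -> padded: split `tokens` at the cumulative offsets and pad each row."""
    batch = len(cu_seqlens) - 1
    rows = []
    for b in range(batch):
        start, end = cu_seqlens[b], cu_seqlens[b + 1]
        seq = list(tokens[start:end])
        if len(seq) > max_seqlen:
            raise ValueError("sequence longer than max_seqlen")
        rows.append(seq + [pad] * (max_seqlen - len(seq)))
    return rows


def bshd_to_thd(rows, cu_seqlens):
    """Padded -> packed: keep only the valid prefix of each row and concatenate."""
    packed = []
    for b, row in enumerate(rows):
        length = cu_seqlens[b + 1] - cu_seqlens[b]
        packed.extend(row[:length])
    return packed


if __name__ == "__main__":
    tokens = ["a0", "a1", "b0", "b1", "b2"]  # two sequences of length 2 and 3
    cu_seqlens = [0, 2, 5]                   # cumulative sequence lengths
    padded = thd_to_bshd(tokens, cu_seqlens, max_seqlen=3)
    print(padded)
    print(bshd_to_thd(padded, cu_seqlens))
```

Under these assumptions a round trip through the padded layout returns the original packed tokens, since cu_seqlens records exactly how many entries of each row are valid.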

Import

#include "../extensions.h"
#include "common.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| Q | at::Tensor | Yes | Query tensor |
| K | at::Tensor | Yes | Key tensor |
| V | at::Tensor | Yes | Value tensor |
| max_seqlen_q | size_t | Yes | Maximum query sequence length |
| max_seqlen_kv | size_t | Yes | Maximum key/value sequence length |
| attn_scale | float | Yes | Attention scaling factor |
| S_quantizer | py::handle | No | Quantizer for softmax intermediate |
| O_quantizer | py::handle | No | Quantizer for attention output |
| bias | at::Tensor | No | Optional attention bias |

Outputs

| Name | Type | Description |
|---|---|---|
| output | py::object | Attention output (possibly quantized) |
| softmax_lse | at::Tensor | Log-sum-exp of softmax |
| rng_state | at::Tensor | RNG state for dropout reproducibility |
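To make the roles of attn_scale, output, and softmax_lse concrete, here is a minimal pure-Python reference for a single query vector. It is a sketch of standard scaled dot-product attention without bias, mask, dropout, or quantization, and the function name is hypothetical. It is not the fused kernel's implementation.

```python
import math


def attend_one_query(q, keys, values, attn_scale):
    """Return (output, lse) for one query against a list of key/value vectors."""
    # Scaled dot-product scores, one per key.
    scores = [attn_scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable log-sum-exp: subtract the maximum before exponentiating.
    m = max(scores)
    exp_scores = [math.exp(s - m) for s in scores]
    denom = sum(exp_scores)
    lse = m + math.log(denom)
    # Softmax probabilities; equivalently math.exp(score - lse) for each score.
    probs = [e / denom for e in exp_scores]
    # Output is the probability-weighted sum of the value vectors.
    dim = len(values[0])
    output = [sum(p * v[j] for p, v in zip(probs, values)) for j in range(dim)]
    return output, lse


if __name__ == "__main__":
    q = [1.0, 0.0]
    keys = [[0.0, 0.0], [0.0, 0.0]]
    values = [[1.0, 2.0], [3.0, 4.0]]
    out, lse = attend_one_query(q, keys, values, attn_scale=1.0 / math.sqrt(len(q)))
    print(out, lse)
```

Because each probability equals exp(score - lse), keeping the log-sum-exp lets a backward pass rebuild the softmax from recomputed scores instead of storing the full probability matrix, which is a common reason for returning it.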

Usage Examples

import transformer_engine_torch as tex

# Called internally by the fused_attn Python wrappers
output, softmax_lse, rng_state = tex.fused_attn_fwd(
    max_seqlen_q, max_seqlen_kv, True,
    attn_scale, dropout, qkv_layout,
    bias_type, mask_type, softmax_type,
    Q, K, V, S_quantizer, O_quantizer, ...
)
