Implementation:NVIDIA TransformerEngine JAX XLA Attention

From Leeroopedia


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, JAX, Attention
Last Updated  2026-02-07 14:00 GMT

Overview

Implements XLA FFI handlers for fused multi-head attention forward and backward passes, along with workspace size query functions and auxiliary tensor preparation, enabling efficient attention computation from JAX.

Description

GetFusedAttnBackend queries TransformerEngine (TE) for the appropriate backend (max512 or arbitrary sequence length). PrepareFusedAttnForwardAuxTensors builds an NVTETensorPack holding the softmax, RNG state, bias, and softmax offset tensors, shaped according to the chosen backend. PrepareFusedAttnBackwardAuxTensors reuses the forward logic with dummy parameters to pack all auxiliary tensors. GetFusedAttnForwardWorkspaceSizes and GetFusedAttnBackwardWorkspaceSizes create dummy tensor wrappers and call nvte_fused_attn_fwd and nvte_fused_attn_bwd, respectively, to determine the required workspace sizes. The FFI forward and backward handlers themselves construct full tensor wrappers for Q, K, V, bias, and the sequence-length descriptors, then dispatch to TE's fused attention kernels.
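The workspace-size functions described above follow a query-then-run convention: the kernel entry point is first called with an empty workspace so that it only reports how much scratch memory it needs. The following is a minimal, self-contained sketch of that pattern and is not TE code. Workspace, scaled_sum, query_workspace_bytes, and run_with_workspace are hypothetical names invented for illustration, and the toy computation merely stands in for the attention kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical workspace descriptor (not a TransformerEngine type).
struct Workspace {
  float *data = nullptr;     // nullptr means "size query only"
  std::size_t elements = 0;  // filled in by the query pass
};

// Hypothetical kernel entry point. With an empty workspace it only reports
// the scratch size it needs; with a real workspace it does the work.
// Returns true when it actually computed a result.
bool scaled_sum(const std::vector<float> &in, float scale, float *out,
                Workspace *ws) {
  assert(out != nullptr && ws != nullptr);
  if (ws->data == nullptr) {
    ws->elements = in.size();  // one scratch float per input element
    return false;
  }
  assert(ws->elements >= in.size());
  float total = 0.0f;
  for (std::size_t i = 0; i < in.size(); ++i) {
    ws->data[i] = in[i] * scale;  // stage scaled values in scratch memory
    total += ws->data[i];
  }
  *out = total;
  return true;
}

// Pass 1: ask the kernel how much scratch memory it needs, in bytes.
std::size_t query_workspace_bytes(const std::vector<float> &in) {
  Workspace ws;  // empty workspace triggers the size query
  float unused = 0.0f;
  scaled_sum(in, 1.0f, &unused, &ws);
  return ws.elements * sizeof(float);
}

// Pass 2: allocate exactly that much and run the real computation.
float run_with_workspace(const std::vector<float> &in, float scale) {
  std::vector<float> storage(query_workspace_bytes(in) / sizeof(float));
  Workspace ws;
  ws.data = storage.data();
  ws.elements = storage.size();
  float result = 0.0f;
  scaled_sum(in, scale, &result, &ws);
  return result;
}
```

The real functions return a pybind11::tuple (see Signature); the sketch returns a plain byte count to stay dependency-free.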

This component exposes TE's fused attention, with cuDNN backend support, to JAX-based transformer models. It enables memory-efficient multi-head attention with configurable masking, bias types, QKV layouts, and dropout.

Usage

This C++ extension is invoked internally by the Python-side FusedAttnFwdPrimitive and FusedAttnBwdPrimitive in transformer_engine.jax.cpp_extensions.attention. Users do not call these FFI handlers directly.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/csrc/extensions/attention.cpp
Lines: 1-670

Signature

namespace transformer_engine {
namespace jax {

NVTE_Fused_Attn_Backend GetFusedAttnBackend(
    bool is_training, DType q_dtype, DType kv_dtype,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
    float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads,
    size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim,
    size_t v_head_dim, int64_t window_size_left, int64_t window_size_right,
    bool deterministic);

void PrepareFusedAttnForwardAuxTensors(
    NVTETensorPack *tensor_pack, const size_t input_batch,
    const size_t bias_batch, const size_t attn_heads,
    const size_t bias_heads, const size_t q_max_seqlen,
    const size_t kv_max_seqlen, DType dtype,
    NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
    void *softmax_buf, void *rng_state_buf, void *bias_buf,
    void *softmax_offset_buf);

// Parameter lists are elided on this page; see the source file for the full signatures.
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(...);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(...);

} // namespace jax
} // namespace transformer_engine

Import

#include "../extensions.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

I/O Contract

Inputs

Name                 Type             Required  Description
q_buf                Buffer_Type      Yes       Query tensor buffer
k_buf                Buffer_Type      Yes       Key tensor buffer
v_buf                Buffer_Type      Yes       Value tensor buffer
bias_buf             Buffer_Type      No        Optional attention bias buffer
cu_seqlen_q          Buffer_Type      Yes       Cumulative query sequence lengths
cu_seqlen_kv         Buffer_Type      Yes       Cumulative key/value sequence lengths
seed_buf             Buffer_Type      No        RNG seed for dropout
qkv_layout           NVTE_QKV_Layout  Yes       QKV memory layout
bias_type            NVTE_Bias_Type   Yes       Bias type enum
mask_type            NVTE_Mask_Type   Yes       Attention mask type enum
scaling_factor       float            Yes       Attention scaling factor
dropout_probability  float            Yes       Dropout probability
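
The cu_seqlen_q and cu_seqlen_kv inputs hold cumulative sequence lengths. As a sketch of what such a buffer conventionally contains, the helper below builds one from per-sequence lengths. It assumes the common layout of batch + 1 entries starting at 0; this page does not spell out the exact layout the handlers expect. The name cumulative_seqlens is hypothetical and not part of TE.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: turns per-sequence lengths into a running total with
// a leading zero, so the result has seqlens.size() + 1 entries.
std::vector<std::int32_t> cumulative_seqlens(
    const std::vector<std::int32_t> &seqlens) {
  std::vector<std::int32_t> cu(seqlens.size() + 1, 0);
  for (std::size_t i = 0; i < seqlens.size(); ++i) {
    assert(seqlens[i] >= 0);         // lengths cannot be negative
    cu[i + 1] = cu[i] + seqlens[i];  // prefix sum
  }
  return cu;
}
```

For lengths {3, 5, 2} this yields {0, 3, 8, 10}, so sequence i occupies token positions from cu[i] up to but not including cu[i + 1].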

Outputs

Name         Type         Description
output_buf   Result_Type  Attention output tensor
softmax_aux  Result_Type  Softmax auxiliary data for backward pass
rng_state    Result_Type  RNG state for backward pass dropout
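
To make the roles of scaling_factor, the attention output, and the softmax auxiliary data concrete, here is an unfused reference computation for a single query row. It is an illustrative sketch only: it omits masking, bias, and dropout, and it does not reflect how the fused kernels are implemented. RowResult and attend_row are hypothetical names. Saving the per-row log-sum-exp is one common choice of softmax statistic for a backward pass; whether softmax_aux holds exactly this quantity depends on the backend and is not confirmed by this page.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type for one query row.
struct RowResult {
  std::vector<float> out;  // softmax-weighted sum of the value vectors
  float logsumexp;         // per-row softmax statistic
};

// Reference (unfused) attention for one query vector against all keys.
RowResult attend_row(const std::vector<float> &q,
                     const std::vector<std::vector<float>> &keys,
                     const std::vector<std::vector<float>> &values,
                     float scaling_factor) {
  assert(!keys.empty() && keys.size() == values.size());
  // Scaled dot-product scores: scaling_factor * (q . k_j).
  std::vector<float> scores(keys.size(), 0.0f);
  for (std::size_t j = 0; j < keys.size(); ++j) {
    assert(keys[j].size() == q.size());
    float dot = 0.0f;
    for (std::size_t d = 0; d < q.size(); ++d) dot += q[d] * keys[j][d];
    scores[j] = scaling_factor * dot;
  }
  // Subtract the row maximum before exponentiating for numerical stability.
  float max_score = scores[0];
  for (float s : scores) max_score = (s > max_score) ? s : max_score;
  float denom = 0.0f;
  for (float &s : scores) {
    s = std::exp(s - max_score);
    denom += s;
  }
  // Weighted sum of values using the normalized probabilities.
  RowResult r;
  r.out.assign(values[0].size(), 0.0f);
  for (std::size_t j = 0; j < values.size(); ++j) {
    assert(values[j].size() == r.out.size());
    const float w = scores[j] / denom;
    for (std::size_t d = 0; d < r.out.size(); ++d) r.out[d] += w * values[j][d];
  }
  r.logsumexp = max_score + std::log(denom);
  return r;
}
```

A typical choice for scaling_factor is 1 / sqrt(head_dim), though the handlers accept any float value passed from Python.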

Usage Examples

// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
//   from transformer_engine.jax.attention import fused_attn
//   output = fused_attn(qkv=(q, k, v), ...)
