Implementation:NVIDIA TransformerEngine JAX XLA Attention
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements XLA FFI handlers for fused multi-head attention forward and backward passes, along with workspace size query functions and auxiliary tensor preparation, enabling efficient attention computation from JAX.
Description
GetFusedAttnBackend queries Transformer Engine (TE) for the appropriate backend (max512 or arbitrary sequence length). PrepareFusedAttnForwardAuxTensors builds an NVTETensorPack with softmax, RNG state, bias, and softmax offset tensors shaped according to the chosen backend. PrepareFusedAttnBackwardAuxTensors reuses the forward logic with dummy parameters to pack all auxiliary tensors. GetFusedAttnForwardWorkspaceSizes and GetFusedAttnBackwardWorkspaceSizes create dummy tensor wrappers and call nvte_fused_attn_fwd and nvte_fused_attn_bwd, respectively, to determine the required workspace sizes. The actual FFI forward and backward handlers construct full tensor wrappers with Q, K, V, bias, and sequence-length descriptors and dispatch to TE's fused attention kernels.
This component exposes TE's cuDNN-backed fused attention to JAX, providing memory-efficient multi-head attention with configurable masking, bias types, QKV layouts, and dropout for JAX-based transformer models.
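The workspace-size functions described above follow a common two-phase pattern: call the kernel entry point once with a workspace that has no backing storage so it only reports its scratch requirement, then allocate and call again. The sketch below illustrates that pattern in plain Python. `Workspace`, `toy_fused_op`, and `query_workspace_size` are hypothetical names invented for illustration, and the "twice the input length" requirement is made up; none of this is TE's actual API.

```python
class Workspace:
    """Hypothetical stand-in for a workspace tensor wrapper."""

    def __init__(self):
        self.data = None  # no backing storage yet
        self.size = 0     # filled in by the query call


def toy_fused_op(inputs, workspace):
    """Hypothetical kernel entry point (not a TE function).

    With an unallocated workspace it only records the required size;
    with an allocated one it performs the computation.
    """
    required = 2 * len(inputs)  # made-up scratch requirement
    if workspace.data is None:
        workspace.size = required  # query phase: report size and return
        return None
    assert len(workspace.data) >= required, "workspace too small"
    # Execution phase: use the scratch buffer for intermediate values.
    for i, x in enumerate(inputs):
        workspace.data[i] = x * x
    return sum(workspace.data[:len(inputs)])


def query_workspace_size(num_inputs):
    """Run the query phase with dummy inputs; only their count matters."""
    ws = Workspace()
    toy_fused_op([0.0] * num_inputs, ws)
    return ws.size


# Phase 1: ask how much scratch space is needed.
size = query_workspace_size(4)
# Phase 2: allocate the workspace and run for real.
ws = Workspace()
ws.data = [0.0] * size
result = toy_fused_op([1.0, 2.0, 3.0, 4.0], ws)
```

Separating the query from the execution lets the caller (here, the JAX side) own the allocation, so the scratch buffer can be planned alongside the other buffers of the compiled program.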
Usage
This C++ extension is invoked internally by the Python-side FusedAttnFwdPrimitive and FusedAttnBwdPrimitive in transformer_engine.jax.cpp_extensions.attention. Users do not call these FFI handlers directly.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/csrc/extensions/attention.cpp
- Lines: 1-670
Signature
namespace transformer_engine {
namespace jax {
NVTE_Fused_Attn_Backend GetFusedAttnBackend(
bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim,
size_t v_head_dim, int64_t window_size_left, int64_t window_size_right,
bool deterministic);
void PrepareFusedAttnForwardAuxTensors(
NVTETensorPack *tensor_pack, const size_t input_batch,
const size_t bias_batch, const size_t attn_heads,
const size_t bias_heads, const size_t q_max_seqlen,
const size_t kv_max_seqlen, DType dtype,
NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend,
void *softmax_buf, void *rng_state_buf, void *bias_buf,
void *softmax_offset_buf);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(...);
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(...);
} // namespace jax
} // namespace transformer_engine
Import
#include "../extensions.h"
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q_buf | Buffer_Type | Yes | Query tensor buffer |
| k_buf | Buffer_Type | Yes | Key tensor buffer |
| v_buf | Buffer_Type | Yes | Value tensor buffer |
| bias_buf | Buffer_Type | No | Optional attention bias buffer |
| cu_seqlen_q | Buffer_Type | Yes | Cumulative query sequence lengths |
| cu_seqlen_kv | Buffer_Type | Yes | Cumulative key/value sequence lengths |
| seed_buf | Buffer_Type | No | RNG seed for dropout |
| qkv_layout | NVTE_QKV_Layout | Yes | QKV memory layout |
| bias_type | NVTE_Bias_Type | Yes | Bias type enum |
| mask_type | NVTE_Mask_Type | Yes | Attention mask type enum |
| scaling_factor | float | Yes | Attention scaling factor |
| dropout_probability | float | Yes | Dropout probability |
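The cu_seqlen_q and cu_seqlen_kv inputs hold cumulative sequence lengths. As a minimal sketch, assuming the usual convention of a prefix sum with a leading zero (so a batch of B sequences yields B + 1 entries), they can be derived from per-sequence lengths as follows. The helper name `cumulative_seqlens` is hypothetical and not part of TE.

```python
from itertools import accumulate


def cumulative_seqlens(seqlens):
    """Return prefix sums of per-sequence lengths, starting at 0.

    Entry i is the offset where sequence i begins; the last entry is
    the total number of tokens in the batch.
    """
    return [0] + list(accumulate(seqlens))


# Three sequences of lengths 3, 5, and 2 tokens.
cu_seqlens = cumulative_seqlens([3, 5, 2])
```

With this layout, sequence i occupies the token range from `cu_seqlens[i]` up to, but not including, `cu_seqlens[i + 1]`.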
Outputs
| Name | Type | Description |
|---|---|---|
| output_buf | Result_Type | Attention output tensor |
| softmax_aux | Result_Type | Softmax auxiliary data for backward pass |
| rng_state | Result_Type | RNG state for backward pass dropout |
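To make the contract above concrete, the following is an unfused reference for what a single attention head computes from Q, K, V, an optional bias, and scaling_factor. It is a readability sketch in plain Python, not the fused kernel: it omits dropout and the softmax_aux and rng_state outputs, handles one head with no batching, and assumes a simple causal rule (query i may attend to keys 0 through i). The function name `reference_attention` is hypothetical.

```python
import math


def reference_attention(q, k, v, scaling_factor, bias=None, causal=False):
    """Unfused single-head attention on nested lists.

    q: [q_len][head_dim], k: [kv_len][head_dim], v: [kv_len][v_dim].
    bias, if given, is [q_len][kv_len] and is added to the scaled scores.
    """
    out = []
    for i, q_row in enumerate(q):
        scores = []
        for j, k_row in enumerate(k):
            if causal and j > i:
                scores.append(-math.inf)  # masked position gets zero weight
                continue
            s = scaling_factor * sum(a * b for a, b in zip(q_row, k_row))
            if bias is not None:
                s += bias[i][j]
            scores.append(s)
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of value rows.
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v))
                    for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal_out = reference_attention(q, k, v, scaling_factor=1.0, causal=True)
```

Under the causal rule the first query row can only see the first key, so its output equals the first value row. A fused implementation produces the same result without materializing the full score matrix in memory.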
Usage Examples
// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
// from transformer_engine.jax.attention import fused_attn
// output = fused_attn(qkv=(q, k, v), ...)