Implementation:NVIDIA TransformerEngine JAX Attention
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Defines the multi-head attention infrastructure for the JAX backend, including attention configuration enums, memory layout types, context parallelism strategies, and the high-level fused attention API.
Description
This module provides enum wrappers (AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVFormat, QKVLayout) that map to the underlying C++ NVTE enum values from transformer_engine_jax. It implements CPStrategy for context-parallel attention (Ring P2P and AllGather), ReorderStrategy for context-parallel load balancing, SequenceDescriptor for variable-length sequence metadata, and helper functions like make_swa_mask for sliding window attention. The fused_attn entry point dispatches to the appropriate C++ fused attention primitive based on QKV layout and context parallelism settings.
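To make the sliding window idea concrete, the following is a minimal NumPy sketch of a window mask. It is not the library's `make_swa_mask`: the function name `sliding_window_mask`, its arguments, and the conventions used (True means "may attend", a negative window bound means "unbounded on that side") are assumptions chosen for illustration.

```python
import numpy as np

def sliding_window_mask(seqlen_q, seqlen_kv, window_size):
    """Hypothetical helper: boolean mask, True where query i may attend to key j.

    window_size is (left, right); a negative entry is treated as unbounded
    (an assumption made for this sketch).
    """
    left, right = window_size
    q_pos = np.arange(seqlen_q)[:, None]
    kv_pos = np.arange(seqlen_kv)[None, :]
    offset = kv_pos - q_pos  # > 0: key lies in the future, < 0: in the past
    allowed = np.ones((seqlen_q, seqlen_kv), dtype=bool)
    if left >= 0:
        allowed &= offset >= -left   # at most `left` keys back
    if right >= 0:
        allowed &= offset <= right   # at most `right` keys ahead
    return allowed

# Causal window that looks back at most two positions.
mask = sliding_window_mask(5, 5, (2, 0))
```

With `window_size=(2, 0)`, each query sees itself and up to two earlier keys, so the last row of `mask` allows only keys 2, 3, and 4.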
This is the central attention module that bridges high-level JAX/Flax transformer layers to the optimized cuDNN fused attention kernels, supporting diverse attention patterns (causal, padding, sliding window) and distributed strategies (context parallelism).
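Why load balancing matters for context parallelism with a causal mask can be shown with a small pure-Python sketch. It uses one common "zigzag" scheme, where each rank holds an early chunk and its mirrored late chunk. The helper names `zigzag_shards` and `causal_work` are hypothetical, and the exact token ordering produced by `ReorderStrategy` may differ from this illustration.

```python
def zigzag_shards(seqlen, cp_size):
    """Hypothetical helper: rank r receives chunks r and 2*cp_size-1-r."""
    num_chunks = 2 * cp_size
    if seqlen % num_chunks != 0:
        raise ValueError("seqlen must be a multiple of 2 * cp_size")
    chunk = seqlen // num_chunks
    shards = []
    for rank in range(cp_size):
        mirror = num_chunks - 1 - rank
        early = range(rank * chunk, (rank + 1) * chunk)
        late = range(mirror * chunk, (mirror + 1) * chunk)
        shards.append(list(early) + list(late))
    return shards

def causal_work(positions):
    """Count (query, key) pairs a rank computes under a causal mask."""
    return sum(p + 1 for p in positions)

shards = zigzag_shards(seqlen=8, cp_size=2)
balanced = [causal_work(s) for s in shards]
# Naive contiguous split for comparison: rank 0 gets 0..3, rank 1 gets 4..7.
contiguous = [causal_work(range(0, 4)), causal_work(range(4, 8))]
```

In this example the contiguous split gives the two ranks 10 and 26 query-key pairs, while the zigzag split gives each rank 18.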
Usage
Use this module when implementing multi-head attention in JAX-based transformer models. It is the primary interface for invoking cuDNN fused attention kernels with support for various mask types, bias types, QKV layouts, and context parallelism. Higher-level Flax transformer modules (DotProductAttention, MultiHeadAttention) call into this module.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/attention.py
- Lines: 1-1401
Signature
class AttnBiasType(Enum):
    NO_BIAS = ...
    PRE_SCALE_BIAS = ...
    POST_SCALE_BIAS = ...

class AttnMaskType(Enum):
    NO_MASK = ...
    PADDING_MASK = ...
    CAUSAL_MASK = ...
    PADDING_CAUSAL_MASK = ...
    CAUSAL_BOTTOM_RIGHT_MASK = ...
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = ...

class QKVLayout(Enum):
    BS3HD = ...
    BSHD_BS2HD = ...
    BSHD_BSHD_BSHD = ...
    T3HD = ...
    THD_T2HD = ...
    THD_THD_THD = ...

class CPStrategy(Enum):
    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2

class SequenceDescriptor: ...

def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    context_checkpoint_name: str = "context",
    softmax_offset: Optional[jnp.ndarray] = None,
    stripe_size: int | None = None,
) -> jnp.ndarray: ...
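The `qkv` tuple and `qkv_layout` have to agree with each other. The sketch below spells out the tensor shapes implied by the three BSHD-family layout names, read as B = batch, S = sequence, H = heads, D = head dimension, with the digit marking a packed axis. The helper `expected_qkv_shapes` is hypothetical, it assumes query and key/value share the same head count, and it does not cover the T3HD, THD_T2HD, or THD_THD_THD layouts.

```python
def expected_qkv_shapes(layout_name, batch, seqlen_q, seqlen_kv, heads, head_dim):
    """Hypothetical helper: shapes of the tensors in the qkv tuple per layout."""
    q = (batch, seqlen_q, heads, head_dim)
    kv = (batch, seqlen_kv, heads, head_dim)
    if layout_name == "BS3HD":
        # One tensor with Q, K, V packed along a size-3 axis (self-attention).
        if seqlen_q != seqlen_kv:
            raise ValueError("a packed QKV tensor needs seqlen_q == seqlen_kv")
        return ((batch, seqlen_q, 3, heads, head_dim),)
    if layout_name == "BSHD_BS2HD":
        # Separate Q, plus K and V packed along a size-2 axis.
        return (q, (batch, seqlen_kv, 2, heads, head_dim))
    if layout_name == "BSHD_BSHD_BSHD":
        # Three separate tensors.
        return (q, kv, kv)
    raise ValueError(f"layout not covered by this sketch: {layout_name}")

packed = expected_qkv_shapes("BS3HD", 2, 128, 128, 8, 64)
cross = expected_qkv_shapes("BSHD_BS2HD", 2, 128, 256, 8, 64)
separate = expected_qkv_shapes("BSHD_BSHD_BSHD", 2, 128, 256, 8, 64)
```

The length of the returned tuple (1, 2, or 3) matches the number of tensors expected in `qkv` for that layout.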
Import
from transformer_engine.jax.attention import fused_attn, AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVLayout, CPStrategy, SequenceDescriptor
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| qkv | Tuple[jnp.ndarray, ...] | Yes | Tuple of query, key, value tensors in packed or separate form |
| bias | Optional[jnp.ndarray] | No | Optional attention bias tensor |
| sequence_descriptor | SequenceDescriptor | Yes | Descriptor for sequence lengths and segment metadata |
| seed | Optional[jnp.ndarray] | No | Random seed for dropout |
| attn_bias_type | AttnBiasType | Yes | Type of attention bias (NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS) |
| attn_mask_type | AttnMaskType | Yes | Type of attention mask (NO_MASK, CAUSAL_MASK, PADDING_MASK, etc.) |
| qkv_layout | QKVLayout | Yes | Memory layout of QKV tensors |
| softmax_type | AttnSoftmaxType | Yes | Type of softmax computation |
| scaling_factor | float | Yes | Scaling factor for attention scores |
| dropout_probability | float | Yes | Dropout probability |
| is_training | bool | Yes | Whether in training mode |
Outputs
| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Attention output tensor with same shape as query |
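As a point of comparison for the contract above, here is an unfused reference computation in NumPy for separate BSHD tensors. It is a sketch of standard scaled dot-product attention, not the cuDNN kernel or the library's implementation. The function `reference_attention` is hypothetical, and it omits bias, dropout, padding, and sliding windows.

```python
import numpy as np

def reference_attention(q, k, v, scaling_factor, causal=False):
    """Hypothetical unfused reference; inputs are [batch, seqlen, heads, head_dim]."""
    # Attention scores per head: [batch, heads, seqlen_q, seqlen_kv].
    scores = np.einsum("bqhd,bkhd->bhqk", q, k) * scaling_factor
    if causal:
        s_q, s_kv = scores.shape[-2:]
        keep = np.tril(np.ones((s_q, s_kv), dtype=bool))  # key index <= query index
        scores = np.where(keep, scores, -np.inf)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, returned in the query's BSHD layout.
    return np.einsum("bhqk,bkhd->bqhd", weights, v)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 4, 3, 8)) for _ in range(3))
out = reference_attention(q, k, v, scaling_factor=1.0 / np.sqrt(8), causal=True)
```

Under the causal mask the first query position can only see the first key, so `out[:, 0]` equals `v[:, 0]`, and `out` has the same shape as `q`.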
Usage Examples
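from transformer_engine.jax.attention import (
    fused_attn, AttnBiasType, AttnMaskType, QKVLayout,
    AttnSoftmaxType, CPStrategy, SequenceDescriptor,
)
import jax.numpy as jnp

# Example sizes (illustrative values)
batch, seqlen, num_heads, head_dim = 2, 128, 8, 64

# Self-attention with causal mask.
# Half precision is used here on the assumption that the fused kernels expect it.
q = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=jnp.bfloat16)
k = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=jnp.bfloat16)
v = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=jnp.bfloat16)

# Every sequence uses the full length in this example.
seqlens_q = jnp.full((batch,), seqlen, dtype=jnp.int32)
seqlens_kv = jnp.full((batch,), seqlen, dtype=jnp.int32)

# Constructor name and arguments as given on this page; confirm them against
# the SequenceDescriptor classmethods in your installed version.
seq_desc = SequenceDescriptor.from_seqlens_padded(seqlens_q, seqlens_kv, batch, seqlen, seqlen)

output = fused_attn(
    qkv=(q, k, v),
    bias=None,
    sequence_descriptor=seq_desc,
    seed=None,  # no dropout, so no random seed is needed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
    scaling_factor=1.0 / (head_dim ** 0.5),
    dropout_probability=0.0,
    is_training=True,
)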
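The per-sequence lengths passed to the sequence descriptor (`seqlens_q`, `seqlens_kv`) describe which positions in each padded row hold real tokens. The NumPy sketch below shows how such lengths translate into a padding mask. It is an illustration of the concept only: the helper `padding_mask` is hypothetical and says nothing about how `SequenceDescriptor` stores this information internally.

```python
import numpy as np

def padding_mask(seqlens_q, seqlens_kv, max_seqlen_q, max_seqlen_kv):
    """Hypothetical helper: [batch, max_seqlen_q, max_seqlen_kv], True = real token pair."""
    q_valid = np.arange(max_seqlen_q)[None, :] < np.asarray(seqlens_q)[:, None]
    kv_valid = np.arange(max_seqlen_kv)[None, :] < np.asarray(seqlens_kv)[:, None]
    # A (query, key) pair is valid only if both positions are inside their sequences.
    return q_valid[:, :, None] & kv_valid[:, None, :]

# Two sequences padded to length 4: lengths (3, 2) for q/kv and (1, 4) for q/kv.
pad_mask = padding_mask([3, 1], [2, 4], max_seqlen_q=4, max_seqlen_kv=4)
```

For the first sequence, 3 query positions and 2 key positions are real, giving 6 valid pairs; for the second, 1 query and 4 keys give 4 valid pairs.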