Implementation:NVIDIA TransformerEngine JAX Attention

From Leeroopedia


Field Value
Sources TransformerEngine
Domains Deep_Learning, JAX, Attention
Last Updated 2026-02-07 14:00 GMT

Overview

Defines the multi-head attention infrastructure for the JAX backend, including attention configuration enums, memory layout types, context parallelism strategies, and the high-level fused attention API.

Description

This module provides enum wrappers (AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVFormat, QKVLayout) that map to the underlying C++ NVTE enum values exposed by transformer_engine_jax. It defines CPStrategy to select the context-parallel attention algorithm (ring peer-to-peer or all-gather), ReorderStrategy to reorder tokens for load balancing under context parallelism, SequenceDescriptor to carry variable-length sequence metadata, and helpers such as make_swa_mask for sliding window attention. The fused_attn entry point dispatches to the appropriate C++ fused attention primitive based on the QKV layout and the context parallelism settings.
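Load balancing matters for causal context parallelism because, under a causal mask, later tokens attend to more keys than earlier ones. The plain-Python sketch below counts that work per rank. It compares contiguous sharding with a dual-chunk swap, one common load-balancing scheme in which the sequence is cut into 2*cp_size chunks and rank r keeps chunks r and 2*cp_size - 1 - r. The function names are hypothetical. Whether a particular ReorderStrategy member uses exactly this scheme is an assumption to check against the source.

```python
def contiguous_shards(seq_len, cp_size):
    """Split token positions 0..seq_len-1 into cp_size contiguous shards."""
    assert seq_len % cp_size == 0, "seq_len must be divisible by cp_size"
    shard = seq_len // cp_size
    return [list(range(r * shard, (r + 1) * shard)) for r in range(cp_size)]


def dual_chunk_swap_shards(seq_len, cp_size):
    """Hypothetical load-balanced split: rank r owns chunks r and 2*cp_size-1-r."""
    num_chunks = 2 * cp_size
    assert seq_len % num_chunks == 0, "seq_len must be divisible by 2*cp_size"
    chunk = seq_len // num_chunks

    def chunk_tokens(c):
        return list(range(c * chunk, (c + 1) * chunk))

    return [chunk_tokens(r) + chunk_tokens(num_chunks - 1 - r) for r in range(cp_size)]


def causal_work(shard):
    """Count unmasked (query, key) pairs: under a causal mask, query q sees keys 0..q."""
    return sum(q + 1 for q in shard)


seq_len, cp_size = 8, 2
print([causal_work(s) for s in contiguous_shards(seq_len, cp_size)])       # [10, 26]
print([causal_work(s) for s in dual_chunk_swap_shards(seq_len, cp_size)])  # [18, 18]
```

With contiguous shards, the last rank does most of the causal work. With the swap, every rank gets the same count, because each rank pairs an early chunk with a late one, and chunk r plus chunk 2*cp_size - 1 - r always carry the same total work.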

This is the central attention module that bridges high-level JAX/Flax transformer layers to the optimized cuDNN fused attention kernels, supporting diverse attention patterns (causal, padding, sliding window) and distributed strategies (context parallelism).
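The masking patterns named above can be made concrete with a small boolean-mask sketch in plain Python. The flags loosely correspond to the AttnMaskType members: causal to CAUSAL_MASK, bottom_right to the CAUSAL_BOTTOM_RIGHT variants, and q_valid/kv_valid to the PADDING variants. The sketch rests on three assumptions. First, window is a (left, right) pair in which -1 means unbounded. Second, bottom-right alignment is computed from the real, unpadded lengths. Third, the sliding window is measured from the same aligned diagonal as the causal mask. The function names are hypothetical. Transformer Engine's own make_swa_mask and the cuDNN kernels are the authority on edge cases.

```python
def attention_mask(q_len, kv_len, causal=False, bottom_right=False,
                   window=None, q_valid=None, kv_valid=None):
    """Return mask[i][j] == True where query i may attend to key j.

    Hypothetical sketch of mask semantics, not Transformer Engine code.
    window:   (left, right) sliding window; -1 leaves that side unbounded.
    q_valid:  number of real (non-padding) query tokens, for padding masks.
    kv_valid: number of real (non-padding) key/value tokens.
    """
    q_valid = q_len if q_valid is None else q_valid
    kv_valid = kv_len if kv_valid is None else kv_valid
    # Top-left alignment puts the diagonal at j == i. Bottom-right alignment
    # shifts it so the last real query lines up with the last real key.
    offset = kv_valid - q_valid if bottom_right else 0
    left, right = window if window is not None else (-1, -1)
    if causal:
        right = 0  # no key to the right of the (aligned) diagonal
    mask = []
    for i in range(q_len):
        diag = i + offset
        row = []
        for j in range(kv_len):
            ok = i < q_valid and j < kv_valid  # padding positions are masked
            if left >= 0:
                ok = ok and j >= diag - left
            if right >= 0:
                ok = ok and j <= diag + right
            row.append(ok)
        mask.append(row)
    return mask


def show(mask):
    """Render a mask as rows of '1' (attend) and '.' (masked)."""
    return ["".join("1" if m else "." for m in row) for row in mask]


print(show(attention_mask(2, 4, causal=True)))                     # ['1...', '11..']
print(show(attention_mask(2, 4, causal=True, bottom_right=True)))  # ['111.', '1111']
print(show(attention_mask(4, 4, causal=True, window=(2, 0))))      # ['1...', '11..', '111.', '.111']
```

The first two outputs show why the bottom-right variants exist. When there are fewer queries than keys, as when decoding against a KV cache, top-left alignment hides keys that the last queries should see.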

Usage

Use this module when implementing multi-head attention in JAX-based transformer models. It is the primary interface for invoking cuDNN fused attention kernels with support for various mask types, bias types, QKV layouts, and context parallelism. Higher-level Flax transformer modules (DotProductAttention, MultiHeadAttention) call into this module.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/jax/attention.py
Lines
1-1401

Signature

class AttnBiasType(Enum):
    NO_BIAS = ...
    PRE_SCALE_BIAS = ...
    POST_SCALE_BIAS = ...

class AttnMaskType(Enum):
    NO_MASK = ...
    PADDING_MASK = ...
    CAUSAL_MASK = ...
    PADDING_CAUSAL_MASK = ...
    CAUSAL_BOTTOM_RIGHT_MASK = ...
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = ...

class QKVLayout(Enum):
    BS3HD = ...
    BSHD_BS2HD = ...
    BSHD_BSHD_BSHD = ...
    T3HD = ...
    THD_T2HD = ...
    THD_THD_THD = ...

class CPStrategy(Enum):
    DEFAULT = 0
    ALL_GATHER = 1
    RING = 2

class SequenceDescriptor: ...

def fused_attn(
    qkv: Tuple[jnp.ndarray, ...],
    bias: Optional[jnp.ndarray],
    sequence_descriptor: SequenceDescriptor,
    seed: Optional[jnp.ndarray],
    attn_bias_type: AttnBiasType,
    attn_mask_type: AttnMaskType,
    qkv_layout: QKVLayout,
    softmax_type: AttnSoftmaxType,
    scaling_factor: float,
    dropout_probability: float,
    is_training: bool,
    max_segments_per_seq: int = 1,
    window_size: Optional[Tuple[int, int]] = None,
    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
    context_parallel_causal_load_balanced: bool = False,
    context_parallel_axis: str = "",
    context_checkpoint_name: str = "context",
    softmax_offset: Optional[jnp.ndarray] = None,
    stripe_size: int | None = None,
) -> jnp.ndarray: ...
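The THD layouts and the max_segments_per_seq argument deal with packed batches, in which several variable-length sequences share one row. The plain-Python sketch below shows the kind of metadata such a row needs: segment start offsets, per-token segment ids, and per-token positions, plus the block-diagonal mask they imply. The helper names, and the convention that id 0 marks padding, are assumptions for illustration. They are not SequenceDescriptor's actual fields or methods.

```python
def pack_segments(seqlens, row_len):
    """Describe several sequences packed back to back into one row of row_len tokens.

    Hypothetical sketch; assumes segment id 0 marks padding and real
    segments are numbered from 1.
    Returns (offsets, segment_ids, segment_pos).
    """
    if sum(seqlens) > row_len:
        raise ValueError("sequences do not fit in the row")
    offsets = [0]  # start of each segment, plus the end of the last one
    for n in seqlens:
        offsets.append(offsets[-1] + n)
    segment_ids = [0] * row_len
    segment_pos = [0] * row_len  # position of each token inside its segment
    for seg_id, (start, n) in enumerate(zip(offsets, seqlens), start=1):
        for p in range(n):
            segment_ids[start + p] = seg_id
            segment_pos[start + p] = p
    return offsets, segment_ids, segment_pos


def segment_mask(segment_ids):
    """Tokens attend only within their own non-padding segment."""
    return [[a == b and a != 0 for b in segment_ids] for a in segment_ids]


offsets, ids, pos = pack_segments([3, 2], row_len=6)
print(offsets)  # [0, 3, 5]
print(ids)      # [1, 1, 1, 2, 2, 0]
print(pos)      # [0, 1, 2, 0, 1, 0]
```

Here one row holds two segments plus one padding token. Restarting segment_pos at 0 for each segment lets positional logic treat each packed sequence as if it began at position 0.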

Import

from transformer_engine.jax.attention import fused_attn, AttnBiasType, AttnMaskType, AttnSoftmaxType, QKVLayout, CPStrategy, SequenceDescriptor

I/O Contract

Inputs

Name Type Required Description
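qkv Tuple[jnp.ndarray, ...] Yes One, two, or three tensors depending on qkv_layout: packed QKV (BS3HD, T3HD), query plus packed KV (BSHD_BS2HD, THD_T2HD), or separate query, key, and value (BSHD_BSHD_BSHD, THD_THD_THD)
bias Optional[jnp.ndarray] Yes (may be None) Attention bias tensor; pass None when attn_bias_type is NO_BIAS
sequence_descriptor SequenceDescriptor Yes Descriptor for sequence lengths and segment metadata
seed Optional[jnp.ndarray] Yes (may be None) Random seed for dropout
attn_bias_type AttnBiasType Yes Type of attention bias (NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS)
attn_mask_type AttnMaskType Yes Type of attention mask (NO_MASK, PADDING_MASK, CAUSAL_MASK, PADDING_CAUSAL_MASK, and the bottom-right causal variants)
qkv_layout QKVLayout Yes Memory layout of the QKV tensors
softmax_type AttnSoftmaxType Yes Softmax variant, e.g. AttnSoftmaxType.VANILLA_SOFTMAX for the standard softmax
scaling_factor float Yes Multiplier applied to the QK^T scores before softmax, typically 1/sqrt(head_dim)
dropout_probability float Yes Dropout probability applied to the attention weights
is_training bool Yes Whether the call runs in training mode

The remaining parameters in the signature (max_segments_per_seq through stripe_size) are optional keyword arguments with the defaults shown there.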
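The layout names encode both the axis order and the packing of Q, K, and V. BSHD stands for batch, sequence, heads, head dimension, and THD for packed tokens, heads, head dimension. The packing determines how many tensors go into qkv. The plain-Python sketch below derives the expected shapes from a layout name. It is a hypothetical helper that works on name strings rather than on the QKVLayout enum. It assumes that K and V share a head dimension, and that the KV-packed and separate layouts allow fewer key/value heads than query heads, as in grouped-query attention.

```python
# Packing implied by each QKVLayout name (hypothetical lookup table).
_PACKING = {
    "BS3HD": "qkv_packed", "T3HD": "qkv_packed",
    "BSHD_BS2HD": "kv_packed", "THD_T2HD": "kv_packed",
    "BSHD_BSHD_BSHD": "separate", "THD_THD_THD": "separate",
}


def expected_qkv_shapes(layout, q_len, kv_len, heads_q, heads_kv, head_dim, batch=None):
    """Hypothetical helper: shapes of the qkv tuple implied by a layout name.

    BSHD-style layouts take per-sequence lengths and a batch size; THD-style
    layouts take total token counts across the packed batch.
    """
    packing = _PACKING[layout]
    if layout.startswith("BS"):
        if batch is None:
            raise ValueError("BSHD-style layouts need a batch size")
        lead_q, lead_kv = (batch, q_len), (batch, kv_len)
    else:
        lead_q, lead_kv = (q_len,), (kv_len,)
    if packing == "qkv_packed":
        # One tensor holds Q, K and V, so lengths and head counts must match.
        if q_len != kv_len or heads_q != heads_kv:
            raise ValueError("packed QKV needs equal lengths and head counts")
        return [lead_q + (3, heads_q, head_dim)]
    if packing == "kv_packed":
        return [lead_q + (heads_q, head_dim), lead_kv + (2, heads_kv, head_dim)]
    kv_shape = lead_kv + (heads_kv, head_dim)
    return [lead_q + (heads_q, head_dim), kv_shape, kv_shape]


print(expected_qkv_shapes("BS3HD", 128, 128, 8, 8, 64, batch=2))  # [(2, 128, 3, 8, 64)]
print(expected_qkv_shapes("BSHD_BS2HD", 128, 256, 8, 2, 64, batch=2))
# [(2, 128, 8, 64), (2, 256, 2, 2, 64)]
```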

Outputs

Name Type Description
output jnp.ndarray Attention output with the shape of the query tensor (for the packed BS3HD and T3HD layouts, the shape of the query slice, without the packing axis)

Usage Examples

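import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn, AttnBiasType, AttnMaskType, QKVLayout,
    AttnSoftmaxType, SequenceDescriptor
)

# Illustrative sizes; cuDNN fused attention kernels expect fp16 or bf16 inputs
batch, seqlen, num_heads, head_dim = 2, 128, 8, 64
dtype = jnp.bfloat16

# Self-attention with causal mask: separate Q, K, V in BSHD layout
q = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=dtype)
k = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=dtype)
v = jnp.ones((batch, seqlen, num_heads, head_dim), dtype=dtype)

# Every sequence uses the full length, so no padding is involved
seqlens_q = jnp.full((batch,), seqlen, dtype=jnp.int32)
seqlens_kv = jnp.full((batch,), seqlen, dtype=jnp.int32)
seq_desc = SequenceDescriptor.from_seqlens((seqlens_q, seqlens_kv))

output = fused_attn(
    qkv=(q, k, v),
    bias=None,
    sequence_descriptor=seq_desc,
    seed=None,  # dropout is disabled below
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    softmax_type=AttnSoftmaxType.VANILLA_SOFTMAX,
    scaling_factor=1.0 / (head_dim ** 0.5),
    dropout_probability=0.0,
    is_training=True,
)
# output: (batch, seqlen, num_heads, head_dim)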
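An unfused reference can help when checking outputs on small inputs. The plain-Python sketch below computes single-head softmax attention with an optional top-left causal mask and an optional bias. It shows where scaling_factor and the two bias types enter the computation. The function is hypothetical and ignores dropout. The placement of the bias (added before scaling for PRE_SCALE_BIAS, after scaling for POST_SCALE_BIAS) is an assumption read from the enum names. To compare it against fused_attn, one would take a single batch and head slice, convert it to Python floats, and allow bf16-level tolerance.

```python
import math


def reference_attention(q, k, v, scaling_factor, causal=False,
                        bias=None, bias_type="NO_BIAS"):
    """Naive single-head attention for sanity checks on small inputs.

    q: s_q query vectors; k, v: s_kv key/value vectors (lists of floats).
    bias: optional s_q x s_kv matrix. Assumed semantics:
      PRE_SCALE_BIAS  -> softmax(scaling_factor * (q.k + bias))
      POST_SCALE_BIAS -> softmax(scaling_factor * q.k + bias)
    Standard (vanilla) softmax, no dropout.
    """
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if causal and j > i:
                scores.append(-math.inf)  # masked positions get zero weight
                continue
            s = sum(a * b for a, b in zip(qi, kj))
            b_ij = bias[i][j] if bias is not None else 0.0
            if bias_type == "PRE_SCALE_BIAS":
                s = scaling_factor * (s + b_ij)
            elif bias_type == "POST_SCALE_BIAS":
                s = scaling_factor * s + b_ij
            else:
                s = scaling_factor * s
            scores.append(s)
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


# Identical keys give uniform weights, so the output is the mean of v.
print(reference_attention([[1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]],
                          [[1.0, 2.0], [3.0, 4.0]], 0.5))  # [[2.0, 3.0]]
```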
