Implementation:NVIDIA TransformerEngine JAX Cpp Softmax
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements JAX custom primitives for optimized softmax operations used in attention, including scaled, masked, and upper-triangular masked variants with forward and backward passes.
Description
The SoftmaxPrimitive base class provides kernel availability checks (maximum sequence length of 16384, batch-per-block calculation based on warp size). Three specialized forward/backward primitive pairs handle the different fusion types: ScaledSoftmaxFwdPrimitive/ScaledSoftmaxBwdPrimitive, ScaledMaskedSoftmaxFwdPrimitive/ScaledMaskedSoftmaxBwdPrimitive, and ScaledUpperTriangMaskedSoftmaxFwdPrimitive/ScaledUpperTriangMaskedSoftmaxBwdPrimitive. The is_softmax_kernel_available function checks whether a fused kernel is compatible with the requested softmax type, shape, and dtype. Pure JAX fallback implementations (jax_scaled_softmax, etc.) are provided for cases where the fused kernels are unavailable.
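These are performance-critical softmax kernels that fuse scaling, masking, and softmax into one kernel. The logits tensor itself is still passed in at full shape; what the fusion avoids is materializing intermediate scaled or masked copies of it, which reduces memory usage in transformer attention computation.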
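The three fusion types can be illustrated with a plain JAX reference. This is a minimal sketch, not the library's fallback code: the `ref_*` function names are hypothetical, and the mask convention (True marks a position to exclude) and the `-1e10` fill value are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def ref_scaled_softmax(logits, scale_factor):
    """Softmax over the last (key) axis of scale_factor * logits."""
    return jax.nn.softmax(scale_factor * logits, axis=-1)


def ref_scaled_masked_softmax(logits, mask, scale_factor):
    """Assumed convention: mask == True marks positions to exclude."""
    scaled = scale_factor * logits
    # A large negative fill drives the masked probabilities to zero.
    scaled = jnp.where(mask, -1e10, scaled)
    return jax.nn.softmax(scaled, axis=-1)


def ref_scaled_upper_triang_masked_softmax(logits, scale_factor):
    """Causal variant: query i may only attend to keys j <= i."""
    q_len, kv_len = logits.shape[-2:]
    # True strictly above the diagonal, i.e. the positions to exclude.
    causal_mask = jnp.triu(jnp.ones((q_len, kv_len), dtype=bool), k=1)
    return ref_scaled_masked_softmax(logits, causal_mask, scale_factor)


# Illustrative input of shape [batch, heads, q_seqlen, kv_seqlen].
logits = jnp.arange(2 * 2 * 4 * 4, dtype=jnp.float32).reshape(2, 2, 4, 4) / 10.0
probs = ref_scaled_upper_triang_masked_softmax(logits, scale_factor=0.5)
```

Each variant computes the same row-wise softmax; they differ only in which positions are excluded before normalization, which is why a single fused kernel family can cover all three.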
Usage
This module is normally used indirectly, either through the softmax() function in transformer_engine.jax.softmax or when the unfused attention path is selected in the Flax transformer modules. Call it directly only for custom attention implementations that require a specific softmax fusion type.
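For a custom attention implementation, one plausible pattern is to gate the fused call on the availability check and otherwise fall back to plain JAX. The sketch below is an assumption about how such a dispatch could look, not the library's own logic: `softmax_with_fallback` is a hypothetical helper, and the fused function is passed in as an argument so the example runs without TransformerEngine installed.

```python
import jax
import jax.numpy as jnp


def softmax_with_fallback(logits, scale_factor, kernel_available, fused_fn):
    """Use fused_fn when the availability check passed, else plain JAX.

    kernel_available: result of a check such as is_softmax_kernel_available(...).
    fused_fn: a callable such as scaled_softmax_fwd, injected by the caller
              (hypothetical wiring for this sketch).
    """
    if kernel_available:
        return fused_fn(logits, scale_factor)
    # Unfused path: scale, then softmax over the key axis.
    return jax.nn.softmax(scale_factor * logits, axis=-1)


# Illustrative input of shape [batch, heads, q_seqlen, kv_seqlen].
logits = jnp.ones((1, 2, 3, 5), dtype=jnp.float32)
fallback_probs = softmax_with_fallback(
    logits, 0.125, kernel_available=False, fused_fn=None
)
```

In real use, `kernel_available` would come from is_softmax_kernel_available and `fused_fn` would be one of the forward functions listed under Signature.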
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/cpp_extensions/softmax.py
- Lines: 1-944
Signature
def is_softmax_kernel_available(
softmax_type: str, batch: int, heads: int, q_seqlen: int, kv_seqlen: int, dtype: jnp.dtype,
) -> bool: ...
class SoftmaxPrimitive(BasePrimitive): ...
class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): ...
class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...
class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: ...
def scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) -> jnp.ndarray: ...
def scaled_masked_softmax_fwd(logits, mask, scale_factor) -> jnp.ndarray: ...
def scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) -> jnp.ndarray: ...
def scaled_upper_triang_masked_softmax_fwd(logits, scale_factor) -> jnp.ndarray: ...
def scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) -> jnp.ndarray: ...
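The backward functions take the upstream gradient dz and the saved softmax_output. For y = softmax(s * x) along the last axis, the standard gradient is dx = s * y * (dz - sum(dz * y)), with the sum taken over the key axis. The sketch below checks that formula against JAX autodiff; `ref_scaled_softmax_bwd` is a hypothetical name, and it omits the logits argument because the formula does not need it.

```python
import jax
import jax.numpy as jnp


def ref_scaled_softmax_bwd(dz, softmax_output, scale_factor):
    """Gradient of softmax(scale_factor * x) w.r.t. x, given upstream dz."""
    # Row-wise inner product of dz and y over the key axis.
    inner = jnp.sum(dz * softmax_output, axis=-1, keepdims=True)
    return scale_factor * softmax_output * (dz - inner)


scale = 0.5
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 4))
# Forward pass plus a VJP function from autodiff, used as the reference.
y, vjp_fn = jax.vjp(lambda t: jax.nn.softmax(scale * t, axis=-1), x)
dz = jax.random.normal(jax.random.PRNGKey(1), x.shape)
(dx_autodiff,) = vjp_fn(dz)
dx_manual = ref_scaled_softmax_bwd(dz, y, scale)
```

Because the gradient depends only on the output probabilities, a backward pass can reuse the saved forward result instead of recomputing the softmax.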
Import
from transformer_engine.jax.cpp_extensions.softmax import scaled_softmax_fwd, scaled_masked_softmax_fwd, is_softmax_kernel_available
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | jnp.ndarray | Yes | Attention logits tensor of shape [batch, heads, q_seqlen, kv_seqlen] |
| mask | jnp.ndarray | No | Boolean attention mask (for masked variants) |
| scale_factor | float | Yes | Scaling factor applied before softmax |
Outputs
| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Softmax probabilities with same shape as logits |
Usage Examples
import math

from transformer_engine.jax.cpp_extensions.softmax import (
    scaled_softmax_fwd, scaled_masked_softmax_fwd,
    is_softmax_kernel_available,
)

# Assumes the caller has defined logits ([batch, heads, q_seqlen, kv_seqlen]),
# mask, head_dim, dtype, and the four shape integers.
scale_factor = 1.0 / math.sqrt(head_dim)

# Check kernel availability
available = is_softmax_kernel_available("scaled", batch, heads, q_seqlen, kv_seqlen, dtype)
# Scaled softmax forward
probs = scaled_softmax_fwd(logits, scale_factor=scale_factor)
# Scaled masked softmax forward
probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=scale_factor)