Implementation:NVIDIA TransformerEngine JAX Softmax
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Attention |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides a high-level differentiable softmax function that dispatches to different fused softmax kernel variants based on the fusion type.
Description
SoftmaxFusionType is an enum with three variants: SCALED, SCALED_MASKED, and SCALED_UPPER_TRIANG_MASKED. The public softmax() delegates to _softmax, which is wrapped with jax.custom_vjp. Its forward rule dispatches to the matching C++ extension (tex.scaled_softmax_fwd, tex.scaled_masked_softmax_fwd, or tex.scaled_upper_triang_masked_softmax_fwd), and its backward rule dispatches to the corresponding backward kernel. The forward rule saves the softmax output, the logits, and the mask as context for the backward pass.
This is a thin wrapper that provides a clean API over the fused softmax primitives, used by Flax modules when the unfused attention path is selected (the fused attention path handles softmax internally).
Usage
Use this function when implementing attention on the unfused path, or when a standalone softmax backed by fused GPU kernels is needed. The fused attention path (fused_attn) handles softmax internally and does not use this module.
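A plain JAX attention sketch shows where this function sits on the unfused path. Everything here is hypothetical: the unfused_attention name, the [batch, heads, seqlen, head_dim] layout, and the use of jax.nn.softmax in place of the TE call. The substitution keeps the block runnable without a GPU or TransformerEngine. Step 2 computes the scaled, optionally causal softmax that the SCALED and SCALED_UPPER_TRIANG_MASKED variants are named for.

```python
import math

import jax
import jax.numpy as jnp


def unfused_attention(q, k, v, causal=False):
    """Hypothetical unfused attention; layout assumed [batch, heads, seqlen, head_dim]."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    # 1. Raw attention logits: [batch, heads, q_seqlen, k_seqlen]
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k)
    # 2. Scaled (optionally causal) softmax. On the unfused path this is the
    #    step where TE's softmax(logits, scale_factor=scale, ...) would be
    #    called; jax.nn.softmax stands in so the sketch runs anywhere.
    x = logits * scale
    if causal:
        q_len, k_len = x.shape[-2:]
        future = jnp.triu(jnp.ones((q_len, k_len), dtype=bool), k=1)
        x = jnp.where(future, jnp.finfo(x.dtype).min, x)
    probs = jax.nn.softmax(x, axis=-1)
    # 3. Weighted sum of values: [batch, heads, q_seqlen, head_dim]
    return jnp.einsum("bhqk,bhkd->bhqd", probs, v)
```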
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/softmax.py
- Lines: 1-68
Signature
```python
from enum import Enum
from typing import Optional

import jax.numpy as jnp


class SoftmaxFusionType(Enum):
    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"


def softmax(
    logits: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None,
    scale_factor: Optional[float] = 1.0,
    softmax_fusion_type: Optional[SoftmaxFusionType] = SoftmaxFusionType.SCALED,
) -> jnp.ndarray: ...
```
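As a reference for what each variant computes, the following row-wise definitions over the last (key) axis may help. They also give the gradient each backward kernel has to produce. They assume $x$ is the logits tensor, $s$ the scale_factor, and $m$ a mask whose entries equal 1 at masked-out positions (an assumed convention). Leading batch and head indices are omitted. The fused kernels may realize the exclusion with a large negative fill value rather than exactly, so read these as the intended math, not a bit-exact specification.

$$
\begin{aligned}
\texttt{SCALED:}\quad & y_{ij} = \frac{\exp(s\,x_{ij})}{\sum_{k} \exp(s\,x_{ik})} \\
\texttt{SCALED\_MASKED:}\quad & y_{ij} = \frac{(1 - m_{ij})\,\exp(s\,x_{ij})}{\sum_{k} (1 - m_{ik})\,\exp(s\,x_{ik})} \\
\texttt{SCALED\_UPPER\_TRIANG\_MASKED:}\quad & \text{as above with } m_{ij} = \mathbb{1}[\,j > i\,] \\
\text{backward (all variants):}\quad & \frac{\partial L}{\partial x_{ij}} = s\,y_{ij}\Big(g_{ij} - \sum_{k} g_{ik}\,y_{ik}\Big), \qquad g = \frac{\partial L}{\partial y}
\end{aligned}
$$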
Import
```python
from transformer_engine.jax.softmax import softmax, SoftmaxFusionType
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | jnp.ndarray | Yes | Attention logits tensor |
| mask | Optional[jnp.ndarray] | No | Optional attention mask (required for SCALED_MASKED) |
| scale_factor | Optional[float] | No | Scaling factor applied before softmax (default 1.0) |
| softmax_fusion_type | Optional[SoftmaxFusionType] | No | Softmax fusion variant (default SCALED) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Softmax probability tensor with the same shape as logits |
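The two tables can be read as a checkable contract. The helper below is hypothetical (check_softmax_contract is not part of TransformerEngine). It encodes the mask requirement for SCALED_MASKED and the output guarantee of matching the logits shape. It also assumes two things the tables do not state: a supplied mask must broadcast to the logits shape, and the softmax is taken over the last (key) axis, so each output row sums to 1.

```python
from enum import Enum
from typing import Optional

import jax.numpy as jnp


class SoftmaxFusionType(Enum):
    # Redeclared locally so the sketch is self-contained.
    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"


def check_softmax_contract(
    logits: jnp.ndarray,
    mask: Optional[jnp.ndarray],
    fusion_type: SoftmaxFusionType,
    probs: jnp.ndarray,
    atol: float = 1e-3,
) -> None:
    """Hypothetical checker for the I/O tables; raises ValueError on violation."""
    # Inputs: SCALED_MASKED needs a mask (assumed to broadcast to logits).
    if fusion_type is SoftmaxFusionType.SCALED_MASKED:
        if mask is None:
            raise ValueError("SCALED_MASKED requires a mask")
        if jnp.broadcast_shapes(mask.shape, logits.shape) != logits.shape:
            raise ValueError("mask must broadcast to the logits shape")
    # Output: same shape as logits.
    if probs.shape != logits.shape:
        raise ValueError("output shape differs from logits shape")
    # Output: probabilities over the last axis sum to 1 (assumed softmax axis).
    row_sums = jnp.sum(probs.astype(jnp.float32), axis=-1)
    if not bool(jnp.allclose(row_sums, 1.0, atol=atol)):
        raise ValueError("output rows do not sum to 1")
```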
Usage Examples
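```python
import math
import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import softmax, SoftmaxFusionType

# Illustrative inputs (assumed layout [batch, heads, q_seqlen, k_seqlen], bf16)
batch, heads, seqlen, head_dim = 2, 16, 128, 64
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (batch, heads, seqlen, seqlen), dtype=jnp.bfloat16)
# Hypothetical padding mask (assumed convention: nonzero = masked out)
attention_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)

# Scaled softmax
probs = softmax(logits, scale_factor=1.0 / math.sqrt(head_dim))
# Scaled masked softmax
probs = softmax(
    logits, mask=attention_mask,
    scale_factor=1.0 / math.sqrt(head_dim),
    softmax_fusion_type=SoftmaxFusionType.SCALED_MASKED,
)
# Causal (upper triangular) masked softmax; no mask argument is passed
probs = softmax(
    logits,
    scale_factor=1.0 / math.sqrt(head_dim),
    softmax_fusion_type=SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
)
```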