Implementation:NVIDIA TransformerEngine JAX Cpp Softmax


| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Attention |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Implements JAX custom primitives for optimized softmax operations used in attention, including scaled, masked, and upper-triangular masked variants with forward and backward passes.

Description

The SoftmaxPrimitive base class provides the shared kernel-availability logic: a maximum supported sequence length of 16384 and a batch-per-block calculation based on warp size. Three specialized forward/backward primitive pairs handle the different fusion types: ScaledSoftmaxFwdPrimitive/ScaledSoftmaxBwdPrimitive (scaling only), ScaledMaskedSoftmaxFwdPrimitive/ScaledMaskedSoftmaxBwdPrimitive (scaling plus an explicit mask), and ScaledUpperTriangMaskedSoftmaxFwdPrimitive/ScaledUpperTriangMaskedSoftmaxBwdPrimitive (scaling plus an upper-triangular mask). is_softmax_kernel_available reports whether a fused kernel supports a given softmax type, problem shape (batch, heads, q_seqlen, kv_seqlen), and dtype. Pure JAX fallback implementations (jax_scaled_softmax, etc.) are provided for cases where the fused kernels are unavailable.
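
As an illustration of what a batch-per-block calculation based on warp size can look like, here is a minimal sketch. Only the 16384 limit comes from the description above. The function names and the constants (32 threads per warp, 128 threads per block, two rows per warp for short rows) are assumptions modeled on common fused-softmax kernels, and the real check in softmax.py may apply further conditions.

```python
MAX_K_SEQLEN = 16384  # limit stated in the description; everything else is assumed


def sketch_batch_per_block(k_seqlen, threads_per_warp=32, threads_per_block=128):
    """Hypothetical helper: softmax rows handled by one thread block."""
    # Round the row length up to the next power of two.
    pow2 = 1 << (k_seqlen - 1).bit_length()
    # Short rows use fewer threads than a full warp.
    warp_size = min(pow2, threads_per_warp)
    # Assumption: short rows are processed two at a time per warp.
    batches_per_warp = 2 if pow2 <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp


def sketch_kernel_available(q_seqlen, k_seqlen):
    """Hypothetical availability check built on the helper above."""
    if not 0 < k_seqlen <= MAX_K_SEQLEN:
        return False
    # Each block must be filled with whole rows.
    return q_seqlen % sketch_batch_per_block(k_seqlen) == 0
```

Under these assumptions, longer rows leave room for fewer rows per block, so the query length has to be a multiple of a smaller number as kv_seqlen grows.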

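These kernels are performance-critical: each one fuses scaling, optional masking, and the softmax itself into a single pass over the attention logits, so the scaled and masked intermediates are never written out as separate tensors. This reduces memory traffic and temporary memory usage in transformer attention computation. The logits tensor itself is still a full [batch, heads, q_seqlen, kv_seqlen] input.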
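
The three fusion types can be written as plain JAX reference functions. This is a sketch of the math only, not the library's fallback code: the ref_* names are hypothetical, and the mask convention (nonzero means "exclude this position") is an assumption.

```python
import jax
import jax.numpy as jnp


def ref_scaled_softmax(logits, scale_factor):
    """softmax(scale_factor * logits) over the last (kv_seqlen) axis."""
    return jax.nn.softmax(logits * scale_factor, axis=-1)


def ref_scaled_masked_softmax(logits, mask, scale_factor):
    """Scaled softmax with masked positions pushed to zero probability.

    Assumption: mask is nonzero/True where a position must be excluded.
    """
    scaled = logits * scale_factor
    very_negative = jnp.finfo(scaled.dtype).min
    masked = jnp.where(mask.astype(bool), very_negative, scaled)
    return jax.nn.softmax(masked, axis=-1)


def ref_scaled_upper_triang_masked_softmax(logits, scale_factor):
    """Scaled softmax where query i cannot attend to key j > i."""
    q_len, kv_len = logits.shape[-2:]
    # True strictly above the main diagonal.
    causal = jnp.triu(jnp.ones((q_len, kv_len), dtype=bool), k=1)
    return ref_scaled_masked_softmax(logits, causal, scale_factor)
```

Each reference builds the scaled and masked tensors explicitly before the softmax. Those are the intermediates a fused kernel avoids writing out.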

Usage

Use this module indirectly through the softmax() function in transformer_engine.jax.softmax; it is also used when the unfused attention path is selected in the Flax transformer modules. Call the functions here directly only for custom attention implementations that need a specific softmax fusion type.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/cpp_extensions/softmax.py
Lines: 1-944

Signature

def is_softmax_kernel_available(
    softmax_type: str, batch: int, heads: int, q_seqlen: int, kv_seqlen: int, dtype: jnp.dtype,
) -> bool: ...

class SoftmaxPrimitive(BasePrimitive): ...
class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): ...
class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...
class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...
class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...

def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: ...
def scaled_softmax_bwd(dz, softmax_output, logits, scale_factor) -> jnp.ndarray: ...
def scaled_masked_softmax_fwd(logits, mask, scale_factor) -> jnp.ndarray: ...
def scaled_masked_softmax_bwd(dz, softmax_output, logits, mask, scale_factor) -> jnp.ndarray: ...
def scaled_upper_triang_masked_softmax_fwd(logits, scale_factor) -> jnp.ndarray: ...
def scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, logits, scale_factor) -> jnp.ndarray: ...
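
The backward functions take the upstream gradient dz and the saved softmax_output. For y = softmax(scale_factor * x), the gradient with respect to x can be written from those two tensors alone. The sketch below shows that formula and checks it against JAX autodiff. The names ref_scaled_softmax_bwd and max_error_vs_autodiff are hypothetical, and the sketch does not claim to mirror how the fused kernel computes it.

```python
import jax
import jax.numpy as jnp


def ref_scaled_softmax_bwd(dz, softmax_output, scale_factor):
    """dx = scale * y * (dz - sum(dz * y)) along the softmax axis."""
    inner = jnp.sum(dz * softmax_output, axis=-1, keepdims=True)
    return scale_factor * softmax_output * (dz - inner)


def max_error_vs_autodiff(seed=0, scale_factor=0.5):
    """Largest absolute difference between the formula and jax.vjp."""
    k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
    logits = jax.random.normal(k1, (2, 3, 4, 8))
    dz = jax.random.normal(k2, logits.shape)

    def fwd(x):
        return jax.nn.softmax(x * scale_factor, axis=-1)

    y, vjp_fn = jax.vjp(fwd, logits)
    (expected,) = vjp_fn(dz)
    got = ref_scaled_softmax_bwd(dz, y, scale_factor)
    return float(jnp.max(jnp.abs(got - expected)))
```

The same expression covers the masked variants: a masked position has an output probability of zero, so its gradient is zero as well.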

Import

from transformer_engine.jax.cpp_extensions.softmax import scaled_softmax_fwd, scaled_masked_softmax_fwd, is_softmax_kernel_available

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| logits | jnp.ndarray | Yes | Attention logits tensor of shape [batch, heads, q_seqlen, kv_seqlen] |
| mask | jnp.ndarray | No | Boolean attention mask (for masked variants) |
| scale_factor | float | Yes | Scaling factor applied before softmax |

Outputs

| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Softmax probabilities with same shape as logits |

Usage Examples

import math

from transformer_engine.jax.cpp_extensions.softmax import (
    scaled_softmax_fwd,
    scaled_masked_softmax_fwd,
    is_softmax_kernel_available,
)

# logits has shape [batch, heads, q_seqlen, kv_seqlen];
# logits, mask, head_dim, and dtype are supplied by the caller.

# Check whether the fused kernel supports this shape and dtype
available = is_softmax_kernel_available("scaled", batch, heads, q_seqlen, kv_seqlen, dtype)

# Scaled softmax forward: softmax(scale_factor * logits)
scale_factor = 1.0 / math.sqrt(head_dim)
probs = scaled_softmax_fwd(logits, scale_factor=scale_factor)

# Scaled masked softmax forward
masked_probs = scaled_masked_softmax_fwd(logits, mask, scale_factor=scale_factor)
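
The availability check is typically used to decide between a fused call and a plain JAX fallback. A minimal sketch of that guard follows. The softmax_with_fallback wrapper and its fused_fn and kernel_available parameters are hypothetical, so the block runs without TransformerEngine installed.

```python
import jax
import jax.numpy as jnp


def softmax_with_fallback(logits, scale_factor, fused_fn=None, kernel_available=False):
    """Call fused_fn when the kernel check passed, else use plain JAX.

    fused_fn is expected to take (logits, scale_factor), like the
    scaled_softmax_fwd signature listed on this page.
    """
    if fused_fn is not None and kernel_available:
        return fused_fn(logits, scale_factor)
    # Fallback: unfused scale followed by softmax over the kv_seqlen axis.
    return jax.nn.softmax(logits * scale_factor, axis=-1)
```

With TransformerEngine available, one would presumably pass scaled_softmax_fwd as fused_fn and the result of is_softmax_kernel_available(...) as kernel_available.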
