

Implementation:NVIDIA TransformerEngine JAX Softmax



Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, JAX, Attention
Last Updated | 2026-02-07 14:00 GMT

Overview

Provides a high-level differentiable softmax function that dispatches to different fused softmax kernel variants based on the fusion type.

Description

The SoftmaxFusionType enum defines three variants: SCALED, SCALED_MASKED, and SCALED_UPPER_TRIANG_MASKED. The public softmax() delegates to _softmax, which is wrapped with jax.custom_vjp. Its forward rule dispatches on the fusion type to the matching C++ extension (tex.scaled_softmax_fwd, tex.scaled_masked_softmax_fwd, or tex.scaled_upper_triang_masked_softmax_fwd), and its backward rule dispatches to the corresponding backward kernel. The forward rule saves the softmax output, the logits, and the mask as context for the backward pass.

It is a thin wrapper that exposes the fused softmax primitives through a single API. Flax modules call it when the unfused attention path is selected; the fused attention path computes softmax internally.

Usage

Use this function when implementing attention with the unfused path, or when a standalone softmax backed by fused GPU kernels is needed. The fused attention path (fused_attn) computes softmax internally and does not call this module.
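To show where this function sits in the unfused path, here is a minimal sketch of unfused causal self-attention in plain JAX. The helper name unfused_causal_attention is hypothetical, it assumes q_len == k_len, and it uses jax.nn.softmax so it runs without TransformerEngine; the comment marks the step an unfused path would hand to softmax().

```python
import math

import jax
import jax.numpy as jnp


def unfused_causal_attention(q, k, v):
    """Hypothetical unfused attention; q, k, v have shape [batch, heads, seq, head_dim]."""
    head_dim = q.shape[-1]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k)  # raw attention logits
    # This is the step an unfused path would hand to
    #   softmax(scores, scale_factor=1 / sqrt(head_dim),
    #           softmax_fusion_type=SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED).
    # Plain JAX is used here so the sketch runs without TransformerEngine.
    scaled = scores * (1.0 / math.sqrt(head_dim))
    seq = scaled.shape[-1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))  # query i sees keys 0..i
    probs = jax.nn.softmax(jnp.where(causal, scaled, -jnp.inf), axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", probs, v)  # weighted sum of values
```

In the fused attention path, the two einsums and the softmax are a single kernel, which is why fused_attn never calls this module.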

Code Reference

Source Location

Repository | File | Lines
NVIDIA/TransformerEngine | transformer_engine/jax/softmax.py | 1-68

Signature

from enum import Enum
from typing import Optional

import jax.numpy as jnp

class SoftmaxFusionType(Enum):
    SCALED = "scaled"
    SCALED_MASKED = "scaled_masked"
    SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"

def softmax(
    logits: jnp.ndarray,
    mask: Optional[jnp.ndarray] = None,
    scale_factor: Optional[float] = 1.0,
    softmax_fusion_type: Optional[SoftmaxFusionType] = SoftmaxFusionType.SCALED,
) -> jnp.ndarray: ...

Import

from transformer_engine.jax.softmax import softmax, SoftmaxFusionType

I/O Contract

Inputs

Name | Type | Required | Description
logits | jnp.ndarray | Yes | Attention logits tensor
mask | Optional[jnp.ndarray] | No | Attention mask; required for SCALED_MASKED and not used by the other variants
scale_factor | Optional[float] | No | Factor by which the logits are multiplied before the softmax (default 1.0)
softmax_fusion_type | Optional[SoftmaxFusionType] | No | Fused softmax variant to dispatch to (default SCALED)
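For SCALED_MASKED, the mask is commonly a key-padding mask. The sketch below builds one from per-example key lengths. The helper name padding_mask_from_lengths, the [batch, 1, q_len, k_len] layout, the uint8 dtype, and the convention that 1 marks an excluded position are all assumptions to check against the TransformerEngine documentation for your version.

```python
import jax.numpy as jnp


def padding_mask_from_lengths(kv_lengths, q_len, k_len):
    """Hypothetical helper: build a [batch, 1, q_len, k_len] uint8 padding mask.

    Assumption: 1 marks a key position to exclude (padding), 0 a position to keep.
    """
    key_pos = jnp.arange(k_len)                          # [k_len]
    is_pad = key_pos[None, :] >= kv_lengths[:, None]     # [batch, k_len]
    # Same key mask for every query row; the singleton axis broadcasts over heads.
    mask = jnp.broadcast_to(
        is_pad[:, None, None, :], (kv_lengths.shape[0], 1, q_len, k_len)
    )
    return mask.astype(jnp.uint8)
```

The result would be passed as mask=... together with softmax_fusion_type=SoftmaxFusionType.SCALED_MASKED.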

Outputs

Name | Type | Description
output | jnp.ndarray | Softmax probabilities over the last axis, with the same shape as logits

Usage Examples

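import math

import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import softmax, SoftmaxFusionType

# Illustrative inputs of shape [batch, heads, q_len, k_len]. The fused kernels
# typically expect 16-bit floating-point logits (bf16 here).
batch, heads, seq_len, head_dim = 2, 8, 128, 64
logits = jax.random.normal(
    jax.random.PRNGKey(0), (batch, heads, seq_len, seq_len), dtype=jnp.bfloat16
)
scale = 1.0 / math.sqrt(head_dim)

# Scaled softmax
probs = softmax(logits, scale_factor=scale)

# Scaled masked softmax (placeholder all-zero mask broadcast over heads;
# mask shape and dtype are illustrative)
attention_mask = jnp.zeros((batch, 1, seq_len, seq_len), dtype=jnp.uint8)
probs = softmax(
    logits, mask=attention_mask,
    scale_factor=scale,
    softmax_fusion_type=SoftmaxFusionType.SCALED_MASKED,
)

# Causal (upper triangular) masked softmax; no mask argument is passed
probs = softmax(
    logits,
    scale_factor=scale,
    softmax_fusion_type=SoftmaxFusionType.SCALED_UPPER_TRIANG_MASKED,
)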
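The name SCALED_UPPER_TRIANG_MASKED refers to masking the strict upper triangle of each [q_len, k_len] score matrix, that is, every key after the query position. Because that pattern is built into the variant, the example above passes no mask. The pure-JAX sketch below checks that a masked softmax with an explicit strict-upper-triangular mask gives the same result as a causal softmax; all helpers are hypothetical reference code, and the convention that nonzero mask entries are excluded is an assumption.

```python
import jax
import jax.numpy as jnp


def causal_mask(q_len, k_len):
    """Hypothetical helper: uint8 mask with 1 strictly above the diagonal (future keys)."""
    return jnp.triu(jnp.ones((q_len, k_len), dtype=jnp.uint8), k=1)


def masked_softmax_ref(logits, mask, scale_factor):
    """Reference scaled masked softmax; assumes nonzero mask entries are excluded."""
    x = logits * scale_factor
    x = jnp.where(mask != 0, jnp.finfo(x.dtype).min, x)
    return jax.nn.softmax(x, axis=-1)


def causal_softmax_ref(logits, scale_factor):
    """Reference causal softmax: query i may attend to keys 0..i."""
    x = logits * scale_factor
    q_len, k_len = x.shape[-2:]
    keep = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    return jax.nn.softmax(jnp.where(keep, x, -jnp.inf), axis=-1)
```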
