

Implementation:NVIDIA TransformerEngine Fused Softmax

From Leeroopedia
Revision as of 15:57, 16 February 2026 by Admin (auto-imported from implementations/NVIDIA_TransformerEngine_Fused_Softmax.md)


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Attention, Optimization
Last Updated   2026-02-07 14:00 GMT

Overview

Collection of fused scale + mask + softmax implementations for attention, including upper-triangular causal masking, bottom-right aligned causal masking, arbitrary masking, and unmasked variants.

Description

This module provides several fused softmax implementations for different masking scenarios in transformer attention:

  • ScaledUpperTriangMaskedSoftmax -- Scale + upper triangular mask + softmax (GPT-style causal)
  • ScaledAlignedCausalMaskedSoftmax -- Scale + bottom-right aligned causal mask + softmax
  • ScaledMaskedSoftmax -- Scale + arbitrary mask + softmax (padding, padding_causal)
  • ScaledSoftmax -- Scale + softmax (no mask)
  • FusedScaleMaskSoftmax -- High-level nn.Module that selects the appropriate fused kernel, or the PyTorch fallback, based on mask type, input dtype, tensor dimensions, and kernel constraints
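The two causal variants differ only when the query and key lengths differ. As a rough illustration, the following plain-PyTorch sketch builds both boolean masks, using True for positions that are excluded (the same convention as the masked_fill-based mask_func in the usage example). The helper names are hypothetical, and the code does not call the fused kernels; it only shows the masking geometry.

```python
import torch


def upper_triangular_causal_mask(sq: int, sk: int) -> torch.Tensor:
    """Hypothetical helper: top-left aligned causal mask, True where key j > query i."""
    return torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)


def bottom_right_causal_mask(sq: int, sk: int) -> torch.Tensor:
    """Hypothetical helper: diagonal anchored at the bottom-right corner, True where j > i + (sk - sq)."""
    return torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=sk - sq + 1)


# Self-attention (sq == sk): both alignments produce the same mask.
assert torch.equal(upper_triangular_causal_mask(4, 4), bottom_right_causal_mask(4, 4))

# sq < sk: the alignments diverge.
print(upper_triangular_causal_mask(2, 4).int())  # rows [0, 1, 1, 1] and [0, 0, 1, 1]
print(bottom_right_causal_mask(2, 4).int())      # rows [0, 0, 0, 1] and [0, 0, 0, 0]
```

For sq = 2 and sk = 4, the upper-triangular mask lets query 0 see only key 0, whereas the bottom-right aligned mask lets the last query see every key. That is the behavior usually wanted when the queries are the most recent positions of a longer key sequence (for example, incremental decoding with a KV cache).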

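The FusedScaleMaskSoftmax module automatically falls back to PyTorch softmax when no fused kernel can be used (e.g., unsupported dtype, unsupported tensor dimensions, or fusion disabled with NVTE_MASKED_SOFTMAX_FUSION=0). Fused kernels require FP16/BF16 input, a query length sq > 1, a key length sk between 16 and 16384 that is divisible by 8, and batch alignment: sq and b * np divisible by 4, and sq divisible by the number of rows each CUDA block processes (as returned by get_batch_per_block).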
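The eligibility rules for the fused path can be approximated by a small predicate. This is a hypothetical sketch, not the library's is_kernel_available: fused_softmax_eligible is an invented name, the launch constants (32 threads per warp, 128 threads per block) and the alignment checks are assumptions modeled on Megatron-LM-style fused softmax utilities, mask-type-specific rules are omitted, and the key-length bounds are treated as inclusive. The installed source is the authority on the exact conditions.

```python
import os

# Assumed launch constants for the softmax kernels.
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128


def get_batch_per_block(key_seq_len: int) -> int:
    """Sketch: rows of the (b*np*sq, sk) score matrix handled by one CUDA block."""
    pow2 = 1 << (key_seq_len - 1).bit_length()  # next power of two >= key_seq_len
    warp_size = min(pow2, THREADS_PER_WARP)
    batches_per_warp = 2 if pow2 <= 128 else 1
    warps_per_block = THREADS_PER_BLOCK // warp_size
    return warps_per_block * batches_per_warp


def fused_softmax_eligible(dtype: str, b: int, np_: int, sq: int, sk: int) -> bool:
    """Hypothetical predicate approximating FusedScaleMaskSoftmax.is_kernel_available."""
    if not bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))):
        return False  # fusion disabled by the user
    if dtype not in ("float16", "bfloat16"):
        return False  # fused kernels only accept FP16/BF16
    if sq <= 1:
        return False  # query length must exceed 1
    if not (16 <= sk <= 16384) or sk % 8 != 0:
        return False  # key length range (inclusive here) and multiple of 8
    # Assumed alignment rules: sq and b*np divisible by 4, sq divisible by rows per block.
    return sq % 4 == 0 and (b * np_) % 4 == 0 and sq % get_batch_per_block(sk) == 0
```

With BF16 scores of shape (2, 8, 128, 128), get_batch_per_block(128) is 8 under these assumptions, so every check passes; an FP32 input, or a key length such as 100 that is not a multiple of 8, would route to the PyTorch fallback.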

Usage

FusedScaleMaskSoftmax is used internally by the dot-product attention implementation. It is constructed with a mask function and, on each forward call, selects the appropriate fused kernel or falls back to PyTorch softmax.
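When no fused kernel applies, the module computes the same result with ordinary PyTorch operations. The sketch below approximates that fallback path under simplifying assumptions: torch_softmax_fallback and mask_with_large_negative are hypothetical names, mask-type-specific handling is omitted, and masking uses a large negative constant rather than -inf so that a fully masked row does not produce NaN.

```python
import torch


def mask_with_large_negative(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical mask_func: exclude positions where mask is True."""
    return scores.masked_fill(mask, -10000.0)


def torch_softmax_fallback(inp: torch.Tensor, mask, mask_func, scale=None,
                           softmax_in_fp32: bool = True) -> torch.Tensor:
    """Sketch of an unfused scale + mask + softmax path."""
    orig_dtype = inp.dtype
    if softmax_in_fp32 and orig_dtype in (torch.float16, torch.bfloat16):
        inp = inp.float()  # upcast so exp/sum run in FP32
    if scale is not None:
        inp = inp * scale  # scale first, so masked entries stay at the fill value
    if mask is not None:
        inp = mask_func(inp, mask)
    probs = torch.softmax(inp, dim=-1)
    return probs.to(orig_dtype)  # hand back the caller's dtype


scores = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)
causal = torch.triu(torch.ones(8, 8, dtype=torch.bool), diagonal=1).view(1, 1, 8, 8)
probs = torch_softmax_fallback(scores, causal, mask_with_large_negative, scale=0.125)
```

Upcasting to FP32 for the reduction and casting back afterward mirrors what the softmax_in_fp32 flag is for: FP16/BF16 storage with FP32 accumulation.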

Code Reference

Source Location

Repository   NVIDIA/TransformerEngine
File         transformer_engine/pytorch/attention/dot_product_attention/softmax.py
Lines        1--287

Signature

class FusedScaleMaskSoftmax(nn.Module):
    def __init__(self, mask_func: Callable, softmax_in_fp32: bool = True) -> None: ...
    def forward(self, inp: torch.Tensor, mask: torch.Tensor, attn_mask_type: str, scale: Optional[float] = None) -> torch.Tensor: ...
    def is_kernel_available(self, mask, b, np, sq, sk) -> bool: ...
    def forward_fused_softmax(self, inp, mask, scale) -> torch.Tensor: ...
    def forward_torch_softmax(self, inp, mask, scale) -> torch.Tensor: ...
    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int: ...

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...
class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): ...
class ScaledMaskedSoftmax(torch.autograd.Function): ...
class ScaledSoftmax(torch.autograd.Function): ...
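The four kernel wrappers are torch.autograd.Function subclasses, so each pairs a forward with a matching backward. The pure-PyTorch class below is a hypothetical reference for the unmasked scale + softmax case, not the CUDA implementation. It shows why a Function of this kind only needs to save the output probabilities: for y = softmax(s * x) and incoming gradient g, the input gradient is s * y * (g - sum(g * y)), with the sum taken over the last dimension.

```python
import torch


class ReferenceScaledSoftmax(torch.autograd.Function):
    """Hypothetical pure-PyTorch stand-in for a scale + softmax autograd.Function."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, scale: float) -> torch.Tensor:
        probs = torch.softmax(inp * scale, dim=-1)
        ctx.scale = scale
        ctx.save_for_backward(probs)  # backward needs only the output, not the input
        return probs

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        (probs,) = ctx.saved_tensors
        # Softmax Jacobian-vector product, then the chain rule through the scaling.
        dot = (grad_out * probs).sum(dim=-1, keepdim=True)
        grad_in = ctx.scale * probs * (grad_out - dot)
        return grad_in, None  # no gradient for the Python float `scale`
```

Running torch.autograd.gradcheck in double precision is a convenient way to confirm that a hand-written backward matches its forward.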

Import

from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax

I/O Contract

Inputs

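Name             Type             Required  Description
inp              torch.Tensor     Yes       Attention scores of shape (b, np, sq, sk) (must be 4-D)
mask             torch.Tensor     No        Boolean attention mask of shape (1, 1, sq, sk) or (b, 1, sq, sk); True marks positions to exclude. May be None
attn_mask_type   str              Yes       One of "no_mask", "causal", "causal_bottom_right", "padding", "padding_causal", "padding_causal_bottom_right", "arbitrary"
scale            Optional[float]  No        Scaling factor applied to the scores before masking (None is treated as 1.0)
mask_func        Callable         Yes       Constructor argument: function (inp, mask) -> masked scores
softmax_in_fp32  bool             No        Constructor argument (default True): compute the softmax in FP32 for FP16/BF16 inputs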

Outputs

Name   Type          Description
probs  torch.Tensor  Softmax probabilities of shape (b, np, sq, sk), in the same dtype as inp
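To make the shape contract concrete, the following plain-PyTorch sketch (a reference computation, not the fused kernel) builds a per-batch padding mask of shape (b, 1, sq, sk) from hypothetical key lengths and applies it to scores of shape (b, np, sq, sk). The mask broadcasts over the head dimension np, and the output keeps the input's shape, with each row summing to 1 and padded keys receiving zero probability.

```python
import torch

b, np_, sq, sk = 2, 4, 6, 8
scores = torch.randn(b, np_, sq, sk)

# Hypothetical valid key lengths per batch element; keys at or past the length are padding.
key_lens = torch.tensor([8, 5])
padding = torch.arange(sk).view(1, sk) >= key_lens.view(b, 1)  # (b, sk), True = padded key
mask = padding.view(b, 1, 1, sk).expand(b, 1, sq, sk)          # (b, 1, sq, sk)

# masked_fill broadcasts the (b, 1, sq, sk) mask across all np heads.
probs = torch.softmax(scores.masked_fill(mask, -10000.0), dim=-1)
```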

Usage Examples

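import math
import torch
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax

def mask_func(inp, mask):
    # True entries are excluded; -10000.0 instead of -inf avoids NaNs in fully masked rows.
    return inp.masked_fill(mask, -10000.0)

b, np_, sq, sk, head_dim = 2, 8, 128, 128, 64
attention_scores = torch.randn(b, np_, sq, sk, dtype=torch.bfloat16, device="cuda")
# Bottom-right aligned causal mask of shape (1, 1, sq, sk); True above the shifted diagonal.
mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool, device="cuda"), diagonal=sk - sq + 1).view(1, 1, sq, sk)
softmax = FusedScaleMaskSoftmax(mask_func=mask_func, softmax_in_fp32=True)
probs = softmax(attention_scores, mask, attn_mask_type="causal_bottom_right", scale=1.0 / math.sqrt(head_dim))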
