Implementation:NVIDIA TransformerEngine Fused Softmax
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Attention, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Collection of fused scale + mask + softmax implementations for attention, including upper-triangular causal masking, aligned causal masking, arbitrary masking, and unmasked variants.
Description
This module provides several fused softmax implementations for different masking scenarios in transformer attention:
- ScaledUpperTriangMaskedSoftmax -- Scale + upper triangular mask + softmax (GPT-style causal)
- ScaledAlignedCausalMaskedSoftmax -- Scale + bottom-right aligned causal mask + softmax
- ScaledMaskedSoftmax -- Scale + arbitrary mask + softmax (padding, padding_causal)
- ScaledSoftmax -- Scale + softmax (no mask)
- FusedScaleMaskSoftmax -- High-level nn.Module that selects the appropriate kernel based on mask type, tensor dimensions, and hardware constraints
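The two causal variants in the list differ only in how the triangle is anchored when sq != sk. Below is a minimal pure-Python sketch that assumes a boolean mask where True means "excluded", which is the convention the mask_func in the usage example follows. upper_triangular_mask and bottom_right_causal_mask are hypothetical helper names, not TransformerEngine functions.

```python
from typing import List

Mask = List[List[bool]]  # True = masked out (excluded from softmax)


def upper_triangular_mask(sq: int, sk: int) -> Mask:
    """Top-left aligned causal mask: query i may attend to keys 0..i."""
    return [[j > i for j in range(sk)] for i in range(sq)]


def bottom_right_causal_mask(sq: int, sk: int) -> Mask:
    """Bottom-right aligned causal mask: query i may attend to keys 0..i + (sk - sq).

    The last query row therefore sees all sk keys.
    """
    offset = sk - sq
    return [[j > i + offset for j in range(sk)] for i in range(sq)]


def show(mask: Mask) -> str:
    # 'x' marks a masked (excluded) position, '.' a visible one
    return "\n".join("".join("x" if m else "." for m in row) for row in mask)


if __name__ == "__main__":
    # With sq == sk the two conventions coincide
    assert upper_triangular_mask(4, 4) == bottom_right_causal_mask(4, 4)
    # With sq < sk they differ
    print(show(upper_triangular_mask(2, 4)))
    print()
    print(show(bottom_right_causal_mask(2, 4)))
```

When sq == sk the masks are identical. With sq < sk (a few new queries against a longer key sequence), the top-left mask hides keys that the bottom-right mask exposes.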
The FusedScaleMaskSoftmax module automatically falls back to PyTorch softmax when the fused kernel is not available, for example with an unsupported dtype or unsupported dimensions, or when fusion is disabled with NVTE_MASKED_SOFTMAX_FUSION=0. The fused kernels require FP16/BF16 input, a query sequence length greater than 1, a key sequence length from 16 to 16384 that is divisible by 8, and additional batch-alignment conditions checked in is_kernel_available.
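The constraints above can be read as a simple predicate. The sketch below is hypothetical: fused_softmax_eligible is not a TransformerEngine function. It passes the dtype as a string to stay dependency-free, treats the 16 and 16384 bounds as inclusive (an assumption), and leaves out the batch-alignment rules. It is therefore a necessary but not sufficient condition for the fused path.

```python
import os
from typing import Mapping, Optional


def fused_softmax_eligible(
    dtype: str,
    sq: int,
    sk: int,
    env: Optional[Mapping[str, str]] = None,
) -> bool:
    """Hypothetical helper: True if the documented fused-kernel constraints hold.

    Only the dtype, sequence-length, and environment-variable rules are
    modeled; the batch-alignment checks done by is_kernel_available are not.
    """
    env = os.environ if env is None else env
    # Assumption: fusion is enabled unless the variable is explicitly "0"
    if env.get("NVTE_MASKED_SOFTMAX_FUSION", "1") == "0":
        return False
    if dtype not in ("float16", "bfloat16"):
        return False  # fused kernels need FP16/BF16 input
    if sq <= 1:
        return False  # query length must exceed 1
    if not 16 <= sk <= 16384:
        return False  # key length range; endpoint handling is an assumption
    return sk % 8 == 0  # key length must be divisible by 8
```

Passing env explicitly keeps the check testable without touching the process environment.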
Usage
FusedScaleMaskSoftmax is used internally by the dot-product attention implementation. It is constructed with a mask function and, on each call, selects either a fused kernel or the PyTorch fallback based on the mask type and input dimensions.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/attention/dot_product_attention/softmax.py
- Lines: 1--287
Signature
```python
from typing import Callable, Optional

import torch
from torch import nn

# Signatures only; method bodies are elided.
class FusedScaleMaskSoftmax(nn.Module):
    def __init__(self, mask_func: Callable, softmax_in_fp32: bool = True) -> None: ...
    def forward(self, inp: torch.Tensor, mask: torch.Tensor, attn_mask_type: str, scale: Optional[float] = None) -> torch.Tensor: ...
    def is_kernel_available(self, mask, b, np, sq, sk) -> bool: ...
    def forward_fused_softmax(self, inp, mask, scale) -> torch.Tensor: ...
    def forward_torch_softmax(self, inp, mask, scale) -> torch.Tensor: ...

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int: ...

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...
class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): ...
class ScaledMaskedSoftmax(torch.autograd.Function): ...
class ScaledSoftmax(torch.autograd.Function): ...
```
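Each torch.autograd.Function above pairs a fused forward with a backward pass. The math such a backward has to implement can be sketched in pure Python: for y = softmax(scale * x), the input gradient depends only on y and the upstream gradient dy. This is a reference for the calculus rather than a description of the CUDA kernels, and the helper names are hypothetical.

```python
import math
from typing import List


def scaled_softmax(x: List[float], scale: float) -> List[float]:
    """y = softmax(scale * x), computed with max subtraction for stability."""
    z = [scale * v for v in x]
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def scaled_softmax_backward(y: List[float], dy: List[float], scale: float) -> List[float]:
    """Gradient w.r.t. the unscaled input x, given y = softmax(scale * x).

    dx_i = scale * y_i * (dy_i - sum_j dy_j * y_j)
    Only the probabilities y are needed, not x itself. Positions masked
    with -inf have y_i == 0 and therefore receive zero gradient.
    """
    dot = sum(g * p for g, p in zip(dy, y))
    return [scale * p * (g - dot) for p, g in zip(y, dy)]


def numeric_grad(x: List[float], dy: List[float], scale: float, eps: float = 1e-6) -> List[float]:
    """Central finite differences of L(x) = sum_i dy_i * softmax(scale * x)_i."""
    grads = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += eps
        xm[i] -= eps
        lp = sum(g * p for g, p in zip(dy, scaled_softmax(xp, scale)))
        lm = sum(g * p for g, p in zip(dy, scaled_softmax(xm, scale)))
        grads.append((lp - lm) / (2 * eps))
    return grads


if __name__ == "__main__":
    x, dy, scale = [0.3, -1.2, 2.0, 0.5], [0.1, 0.4, -0.2, 0.7], 0.5
    y = scaled_softmax(x, scale)
    analytic = scaled_softmax_backward(y, dy, scale)
    for a, n in zip(analytic, numeric_grad(x, dy, scale)):
        assert abs(a - n) < 1e-6
    print("analytic gradient matches finite differences")
```

Because the backward needs only y, a forward pass can keep its output probabilities for the backward instead of the raw scores.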
Import
```python
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | torch.Tensor | Yes | Attention scores of shape (b, np, sq, sk); FP16/BF16 for the fused path |
| mask | torch.Tensor | No | Attention mask of shape (1, 1, sq, sk) or (b, 1, sq, sk) |
| attn_mask_type | str | Yes | One of "no_mask", "causal", "causal_bottom_right", "padding", "padding_causal", "padding_causal_bottom_right", "arbitrary" |
| scale | Optional[float] | No | Scaling factor applied to inp before masking and softmax (default None, equivalent to 1.0) |
| mask_func | Callable | Yes | Constructor argument: function (inp, mask) that applies the mask to the input |
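The mask-shape part of this contract can be enforced with a small hypothetical helper. check_mask_shape is illustrative and not a TransformerEngine API. It accepts only the two documented layouts and does not check whether a given attn_mask_type actually requires a mask.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def check_mask_shape(inp_shape: Shape, mask_shape: Optional[Shape]) -> None:
    """Raise ValueError if a mask does not match the documented layouts.

    inp is (b, np, sq, sk); a mask, when given, must be (1, 1, sq, sk)
    (shared by the whole batch) or (b, 1, sq, sk) (one mask per sequence).
    The singleton head dimension broadcasts across all np heads.
    """
    if len(inp_shape) != 4:
        raise ValueError(f"inp must be 4-D (b, np, sq, sk), got {inp_shape}")
    b, _, sq, sk = inp_shape
    if mask_shape is None:
        return  # no mask supplied; allowed by the contract
    if tuple(mask_shape) not in ((1, 1, sq, sk), (b, 1, sq, sk)):
        raise ValueError(
            f"mask must be (1, 1, {sq}, {sk}) or ({b}, 1, {sq}, {sk}), got {mask_shape}"
        )
```

With PyTorch tensors, the helper would be called as check_mask_shape(tuple(inp.shape), None if mask is None else tuple(mask.shape)).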
Outputs
| Name | Type | Description |
|---|---|---|
| probs | torch.Tensor | Softmax probabilities of shape (b, np, sq, sk) |
Usage Examples
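```python
import math
from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax

def mask_func(inp, mask):
    # Set scores where mask is True to -inf so softmax assigns them zero probability
    return inp.masked_fill(mask, float("-inf"))

softmax = FusedScaleMaskSoftmax(mask_func=mask_func, softmax_in_fp32=True)
# attention_scores (b, np, sq, sk), mask, and head_dim are assumed to be provided by the caller
probs = softmax(attention_scores, mask, attn_mask_type="causal_bottom_right", scale=1.0 / math.sqrt(head_dim))
```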
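To sanity-check the output of a "causal_bottom_right" call on small inputs, a pure-Python reference for one (batch, head) slice can help. It assumes the same semantics as the mask_func above (masked scores become -inf), applies the scale before masking, and computes in Python floats (FP64) rather than FP32. reference_bottom_right_causal_softmax is a hypothetical helper, not part of TransformerEngine, and rejecting sq > sk is this sketch's own choice.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_bottom_right_causal_softmax(scores: Matrix, scale: float) -> Matrix:
    """Scale, apply a bottom-right aligned causal mask (-inf fill), then softmax each row.

    Works on one (batch, head) slice of shape (sq, sk). Rows that would be
    fully masked (sq > sk) are rejected rather than producing NaN.
    """
    sq, sk = len(scores), len(scores[0])
    if sq > sk:
        raise ValueError("this sketch requires sq <= sk")
    offset = sk - sq  # query i may attend to keys 0 .. i + offset
    probs = []
    for i, row in enumerate(scores):
        z = [scale * v if j <= i + offset else float("-inf") for j, v in enumerate(row)]
        m = max(z)  # finite, because key 0 is always visible
        e = [math.exp(v - m) for v in z]  # exp(-inf) evaluates to 0.0
        total = sum(e)
        probs.append([v / total for v in e])
    return probs


if __name__ == "__main__":
    head_dim = 64
    scores = [[0.0] * 4 for _ in range(2)]  # sq=2 queries, sk=4 keys
    for row in reference_bottom_right_causal_softmax(scores, 1.0 / math.sqrt(head_dim)):
        print([round(p, 4) for p in row])
```

Comparing this against probs[0, 0].float().tolist() for a tiny input is a quick way to confirm the mask alignment, allowing for FP16/BF16 rounding in the fused result.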