Heuristic: NVIDIA TransformerEngine Sequence Length Alignment
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs, Performance |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
Tensor dimension alignment requirement: pad sequence lengths to multiples of 16, 32, or 128 depending on the quantization recipe to enable efficient FP8/MXFP8/NVFP4 GEMM execution.
Description
FP8 quantized GEMM operations in TransformerEngine require tensor dimensions to be aligned to specific boundaries. The alignment depends on the quantization recipe: a multiple of 16 elements for standard FP8 (DelayedScaling, Float8CurrentScaling), 32 for MXFP8, and 128 for NVFP4. When dimensions are not aligned, TransformerEngine provides `Fp8Padding` and `Fp8Unpadding` modules that pad and unpad tensors automatically, but it is more efficient to ensure input dimensions are aligned from the start.
Usage
Use this heuristic when preparing training data or configuring model dimensions for FP8 training. The most common practical application is padding sequence lengths to a multiple of 16. For MXFP8 with userbuffers, dimensions must be divisible by 128.
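The rounding-up itself is simple arithmetic; a minimal sketch (not TE code, just the alignment math this heuristic relies on):

```python
def pad_to_multiple(seq_len: int, multiple: int = 16) -> int:
    """Round a sequence length up to the nearest multiple."""
    return ((seq_len + multiple - 1) // multiple) * multiple

# A 1000-token sequence padded for standard FP8 (multiple of 16):
assert pad_to_multiple(1000, 16) == 1008
# Already-aligned lengths are left unchanged:
assert pad_to_multiple(1024, 16) == 1024
```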
The Insight (Rule of Thumb)
- Action: Pad sequence lengths (and batch dimensions where applicable) to align with the quantization recipe requirements.
- Value:
- Standard FP8 (DelayedScaling, Float8CurrentScaling, Float8BlockScaling): Pad to multiple of 16
- MXFP8BlockScaling: Pad to multiple of 32
- NVFP4BlockScaling: Pad to multiple of 128
- MXFP8 + Userbuffers: Tensor dimensions must be divisible by 128
- Trade-off: Slight memory overhead from padding (typically negligible). Misaligned dimensions cause either automatic padding (with overhead) or assertion failures.
- Practical tip: For Transformer networks, set `pad_to_multiple_of = 16` in your data collator.
Reasoning
FP8 GEMM kernels on NVIDIA GPUs operate on tiles of specific sizes. When tensor dimensions are not multiples of the tile size, either the kernel must handle edge cases (slower) or the tensor must be explicitly padded. TransformerEngine's `Fp8Padding` module handles this transparently, but pre-aligned data avoids the overhead entirely. The alignment sizes correspond to the quantization block sizes: standard FP8 uses per-tensor scaling with a 16-element alignment, MXFP8 uses 32-element blocks, and NVFP4 uses 128-element blocks.
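As a rough illustration of why the memory trade-off noted above is usually negligible, the wasted fraction from padding shrinks as sequences grow (illustrative arithmetic only, not TE code):

```python
def padding_overhead(seq_len: int, align: int) -> float:
    """Fraction of the padded tensor that is spent on padding."""
    padded = ((seq_len + align - 1) // align) * align
    return (padded - seq_len) / padded

# Worst case for align=16 near a typical context length:
print(round(padding_overhead(2033, 16) * 100, 2))  # 2033 -> 2048, prints 0.73
```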
Code Evidence
Alignment size selection from `transformer_engine/pytorch/quantization.py:118-124`:

```python
def get_align_size_for_quantization(recipe: Recipe) -> int:
    """Get the alignment size for quantization."""
    if recipe.mxfp8():
        return 32
    if recipe.nvfp4():
        return 128
    return 16
```
Practical usage in the TE Llama example from `docs/examples/te_llama/utils.py:69-70`:

```python
pad_to_multiple_of = 16  # Ensure alignment for FP8 Linear layers
```