Implementation:NVIDIA TransformerEngine Fp8Unpadding
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Removes row padding from tensors after Grouped GEMM operations, reversing the alignment padding applied by Fp8Padding.
Description
Fp8Unpadding is the counterpart to Fp8Padding. It removes the alignment padding from output tensors after quantized Grouped GEMM operations. The implementation uses a custom autograd function (_Fp8Unpadding). Its forward pass calls tex.fused_multi_row_unpadding. Its backward pass calls tex.fused_multi_row_padding, which pads the incoming gradient back to the row layout the GEMM produced. If no padding was needed (every entry of m_splits is already a multiple of the alignment size), the module returns its input unchanged.
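The forward/backward pairing can be illustrated with a minimal pure-PyTorch sketch of the same semantics. It is not TransformerEngine's code. `_ReferenceUnpad` is a hypothetical name, and the fused `tex` kernels are replaced by `torch.split`/`torch.cat`. The input is assumed to be 2-D, with each GEMM's real rows followed by its padding rows. `padded_m_splits` is passed in explicitly rather than derived from an alignment size.

```python
from typing import List

import torch


class _ReferenceUnpad(torch.autograd.Function):
    """Hypothetical pure-PyTorch reference; not TE's _Fp8Unpadding (no fused kernels)."""

    @staticmethod
    def forward(ctx, inp: torch.Tensor, m_splits: List[int], padded_m_splits: List[int]) -> torch.Tensor:
        # Assumed layout: one chunk per GEMM, real rows first, padding rows last.
        chunks = torch.split(inp, padded_m_splits, dim=0)
        ctx.m_splits = m_splits
        ctx.padded_m_splits = padded_m_splits
        # Keep only the first m real rows of each chunk and stack them back together.
        return torch.cat([chunk[:m] for chunk, m in zip(chunks, m_splits)], dim=0)

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # Backward of "unpad" is "pad": copy each GEMM's gradient rows back into place.
        # Padding rows never reached the output, so their gradient is zero.
        grad_in = grad_out.new_zeros((sum(ctx.padded_m_splits), grad_out.shape[-1]))
        src = torch.split(grad_out, ctx.m_splits, dim=0)
        dst = torch.split(grad_in, ctx.padded_m_splits, dim=0)  # views into grad_in
        for s, d in zip(src, dst):
            d[: s.shape[0]].copy_(s)
        # One gradient per forward input; the split lists get None.
        return grad_in, None, None


# Two GEMMs with 3 and 5 real rows, padded to 4 and 8 rows (alignment of 4 assumed).
m_splits, padded = [3, 5], [4, 8]
x = torch.randn(sum(padded), 2, requires_grad=True)
y = _ReferenceUnpad.apply(x, m_splits, padded)
y.sum().backward()
```

Because the unpadded output only reads the real rows, `x.grad` is one on those rows and zero on the padding rows, which is what re-padding the gradient produces.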
Usage
Use after Grouped GEMM operations to restore the original (unpadded) row counts. It is typically paired with Fp8Padding applied before the GEMM, and should be called with the same unpadded m_splits that were passed to Fp8Padding.
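Both modules receive the same unpadded m_splits, so they can agree on the padded layout by rounding each split up to the alignment. The sketch below assumes that round-up rule. `padded_splits` and `needs_unpadding` are hypothetical helpers, not TransformerEngine functions. The 16-row alignment is also an assumption. The module's real value comes from `align_size`, or is chosen by the library when `align_size` is None.

```python
from typing import List


def padded_splits(m_splits: List[int], align_size: int) -> List[int]:
    """Hypothetical helper: round each per-GEMM row count up to a multiple of align_size."""
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]


def needs_unpadding(m_splits: List[int], align_size: int) -> bool:
    """Hypothetical helper: False when every split is already aligned (the no-op case)."""
    return padded_splits(m_splits, align_size) != m_splits


# Example splits with an assumed 16-row alignment.
m_splits = [128, 65, 200, 33]
padded = padded_splits(m_splits, align_size=16)
print(padded)                          # [128, 80, 208, 48]
print(sum(padded), sum(m_splits))      # 464 426
print(needs_unpadding([128, 64], 16))  # False
```

Here the padded GEMM output has 464 rows and the unpadded result has 426.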
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/module/fp8_unpadding.py
- Lines: 1-142
Signature
```python
class Fp8Unpadding(torch.nn.Module):
    def __init__(self, num_gemms: int, align_size: Optional[int] = None) -> None: ...
    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: ...
```
Import
```python
from transformer_engine.pytorch.module.fp8_unpadding import Fp8Unpadding
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | torch.Tensor | Yes | Padded GEMM output; each GEMM's rows are padded to a multiple of the alignment size |
| m_splits | List[int] | Yes | Original (unpadded) row count of each GEMM; its length must equal num_gemms |
Outputs
| Name | Type | Description |
|---|---|---|
| out | torch.Tensor | Tensor with the padding rows removed; its row count is sum(m_splits) |
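The I/O contract can also be read as a row gather, where every output row is one specific row of the padded input. The sketch below assumes that each GEMM's real rows come first in its padded chunk, followed by the padding rows. `unpad_row_indices` is a hypothetical helper, not part of TransformerEngine. For a 2-D tensor, gathering with these indices (for example `inp.index_select(0, torch.tensor(idx))`) gives the same rows the unpadding keeps.

```python
from typing import List


def unpad_row_indices(m_splits: List[int], padded_m_splits: List[int]) -> List[int]:
    """Hypothetical helper: padded-input row index for every row of the unpadded output."""
    indices: List[int] = []
    offset = 0  # first row of the current GEMM's chunk in the padded input
    for m, padded_m in zip(m_splits, padded_m_splits):
        if m > padded_m:
            raise ValueError("a split cannot be larger than its padded size")
        indices.extend(range(offset, offset + m))  # keep real rows, skip the padding tail
        offset += padded_m
    return indices


idx = unpad_row_indices([3, 5], [4, 8])
print(idx)  # [0, 1, 2, 4, 5, 6, 7, 8]
```

The index list has exactly sum(m_splits) entries, and the last padded chunk ends at row sum(padded_m_splits) - 1, matching the input and output row counts in the tables.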
Usage Examples
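```python
import torch
from transformer_engine.pytorch.module.fp8_unpadding import Fp8Unpadding
m_splits = [128, 65, 200, 33]  # original (unpadded) rows per GEMM
unpadding = Fp8Unpadding(num_gemms=len(m_splits), align_size=16)  # 16-row alignment chosen for illustration
gemm_output = torch.randn(464, 1024, device="cuda")  # padded rows: 128 + 80 + 208 + 48 = 464
output = unpadding(gemm_output, m_splits=m_splits)  # expected shape: (426, 1024)
```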