Implementation:NVIDIA TransformerEngine Fp8Unpadding

From Leeroopedia
Revision as of 15:57, 16 February 2026 by Admin (auto-imported from implementations/NVIDIA_TransformerEngine_Fp8Unpadding.md)


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, PyTorch, Quantization
Last Updated  2026-02-07 14:00 GMT

Overview

Removes row padding from tensors after Grouped GEMM operations, reversing the alignment padding applied by Fp8Padding.

Description

Fp8Unpadding is the counterpart to Fp8Padding: it removes the alignment padding from output tensors after quantized Grouped GEMM operations. The implementation wraps a custom autograd function (_Fp8Unpadding) that calls tex.fused_multi_row_unpadding in the forward pass to drop each GEMM's padding rows, and tex.fused_multi_row_padding in the backward pass to re-insert padding rows so that the gradient matches the shape of the padded input. If no padding was needed (every entry of m_splits is already a multiple of the alignment size), the module returns its input unchanged.
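To make the forward and backward semantics concrete, here is a pure-Python model that treats a tensor as a list of rows. It is a sketch, not TransformerEngine code: the helper names padded_splits, unpad_rows and pad_rows are hypothetical, each split is assumed to be rounded up to the next multiple of align_size, and the padding rows created in the backward direction are assumed to be zero-filled because they never contribute to the output.

```python
from typing import List, Sequence

Row = List[float]


def padded_splits(m_splits: Sequence[int], align_size: int) -> List[int]:
    """Round each split up to the next multiple of align_size (assumed padding rule)."""
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]


def unpad_rows(rows: List[Row], m_splits: Sequence[int], align_size: int) -> List[Row]:
    """Forward direction: keep the first m rows of each padded group, drop the rest."""
    padded = padded_splits(m_splits, align_size)
    if list(m_splits) == padded:
        return rows  # already aligned: return the input untouched (no-op path)
    if len(rows) != sum(padded):
        raise ValueError("input row count must equal the sum of the padded splits")
    out: List[Row] = []
    start = 0
    for m, p in zip(m_splits, padded):
        out.extend(rows[start:start + m])  # real rows of this group
        start += p                         # jump over the group's padding rows
    return out


def pad_rows(grad: List[Row], m_splits: Sequence[int], align_size: int) -> List[Row]:
    """Backward direction: re-insert zero rows so the gradient matches the padded input."""
    padded = padded_splits(m_splits, align_size)
    width = len(grad[0]) if grad else 0
    out: List[Row] = []
    start = 0
    for m, p in zip(m_splits, padded):
        out.extend(grad[start:start + m])                # gradient of the real rows
        out.extend([0.0] * width for _ in range(p - m))  # padding rows get zero gradient
        start += m
    return out


rows = [[float(i)] for i in range(8)]  # two groups of 3 and 2 rows, each padded to 4
print(unpad_rows(rows, [3, 2], align_size=4))
# [[0.0], [1.0], [2.0], [4.0], [5.0]]
```

The early return mirrors the no-op behaviour described above: when no split needed padding, the input is handed back as-is rather than copied.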

Usage

Apply Fp8Unpadding to the output of a Grouped GEMM to restore the original (unpadded) row counts. It is typically paired with an Fp8Padding module applied before the GEMM, with both modules given the same original m_splits.
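The pairing works because a GEMM computes each output row from the matching input row only, so the zero rows appended by padding produce output rows that unpadding can simply discard. The pure-Python sketch below checks this round trip; pad, unpad and grouped_gemm are hypothetical stand-ins for Fp8Padding, Fp8Unpadding and a Grouped GEMM, and the round-up-to-align_size rule is an assumption.

```python
from typing import List, Sequence

Row = List[float]
Matrix = List[List[float]]


def padded_splits(m_splits: Sequence[int], align_size: int) -> List[int]:
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]


def pad(rows: List[Row], m_splits: Sequence[int], align_size: int) -> List[Row]:
    """Stand-in for Fp8Padding: append zero rows to every group."""
    width = len(rows[0]) if rows else 0
    out: List[Row] = []
    start = 0
    for m, p in zip(m_splits, padded_splits(m_splits, align_size)):
        out.extend(rows[start:start + m])
        out.extend([0.0] * width for _ in range(p - m))
        start += m
    return out


def unpad(rows: List[Row], m_splits: Sequence[int], align_size: int) -> List[Row]:
    """Stand-in for Fp8Unpadding: drop each group's trailing padding rows."""
    out: List[Row] = []
    start = 0
    for m, p in zip(m_splits, padded_splits(m_splits, align_size)):
        out.extend(rows[start:start + m])
        start += p
    return out


def grouped_gemm(rows: List[Row], splits: Sequence[int], weights: List[Matrix]) -> List[Row]:
    """Stand-in for a Grouped GEMM: rows of group g are multiplied by weights[g]."""
    out: List[Row] = []
    start = 0
    for s, w in zip(splits, weights):
        for r in rows[start:start + s]:
            out.append([sum(r[k] * w[k][j] for k in range(len(r))) for j in range(len(w[0]))])
        start += s
    return out


m_splits = [3, 2]
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
weights = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]  # identity, then column swap

# Pad -> GEMM on padded splits -> unpad gives the same rows as the unpadded GEMM.
y_padded = grouped_gemm(pad(x, m_splits, 4), padded_splits(m_splits, 4), weights)
y = unpad(y_padded, m_splits, 4)
assert y == grouped_gemm(x, m_splits, weights)
```

Note that unpadding takes the original m_splits, not the padded ones; the padded split sizes are recomputed from them.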

Code Reference

Source Location

Repository  NVIDIA/TransformerEngine
File        transformer_engine/pytorch/module/fp8_unpadding.py
Lines       1-142

Signature

class Fp8Unpadding(torch.nn.Module):
    def __init__(self, num_gemms: int, align_size: Optional[int] = None) -> None: ...
    def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: ...

Import

from transformer_engine.pytorch.module.fp8_unpadding import Fp8Unpadding

I/O Contract

Inputs

Name      Type          Required  Description
inp       torch.Tensor  Yes       Padded GEMM output; the rows of each GEMM are grouped by their padded split sizes
m_splits  List[int]     Yes       Original (unpadded) row count of each GEMM, one entry per GEMM

Outputs

Name  Type          Description
out   torch.Tensor  Unpadded tensor with sum(m_splits) rows and the same last dimension as inp
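As a worked example of this row contract, the sketch below computes the row counts for the m_splits used in the usage example. It assumes an alignment of 16 rows and the round-up rule; align_size=16 is an illustrative value, since the alignment the module picks when align_size is None may depend on the FP8 configuration. The helper unpadding_row_counts is hypothetical.

```python
from typing import List, Sequence, Tuple


def unpadding_row_counts(
    m_splits: Sequence[int], num_gemms: int, align_size: int = 16
) -> Tuple[List[int], int, int]:
    """Return (padded splits, rows expected in inp, rows produced in out)."""
    if len(m_splits) != num_gemms:
        raise ValueError("m_splits needs one entry per GEMM")
    # Each split is rounded up to the next multiple of align_size (assumed rule).
    padded = [(m + align_size - 1) // align_size * align_size for m in m_splits]
    return padded, sum(padded), sum(m_splits)


padded, in_rows, out_rows = unpadding_row_counts([128, 65, 200, 33], num_gemms=4)
print(padded, in_rows, out_rows)  # [128, 80, 208, 48] 464 426
```

So with this alignment, inp must have 464 rows and out has 426; the 38 dropped rows are the padding added to the second, third and fourth GEMMs.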

Usage Examples

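import torch
from transformer_engine.pytorch.module.fp8_unpadding import Fp8Unpadding

m_splits = [128, 65, 200, 33]  # original (unpadded) rows per GEMM
unpadding = Fp8Unpadding(num_gemms=4, align_size=16)
# Stand-in GEMM output: each split rounded up to a multiple of 16, 464 rows in total
gemm_output = torch.randn(464, 1024, device="cuda", dtype=torch.bfloat16)
output = unpadding(gemm_output, m_splits=m_splits)  # shape: (426, 1024)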
