Implementation:NVIDIA TransformerEngine PyTorch Ext Swizzle
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements scale factor swizzling for MXFP8 and NVFP4 tensors, reordering scaling factors into the memory layout required by cuBLAS GEMM kernels.
Description
The file provides three entry points:
- swizzle_scales_for_gemm takes a tensor whose TensorWrapper uses the MXFP8 or NVFP4 scaling mode. It swizzles the row-wise and column-wise scale_inv tensors independently, calling nvte_swizzle_scaling_factors for each. It allocates output buffers for the swizzled scales and updates the tensor's scale_inv pointers to reference them.
- multi_tensor_swizzle_scales_for_gemm performs the same operation on multiple tensors at once, allocating a single contiguous buffer for all of their swizzled scales.
- convert_block_scaling_to_mxfp8_tensor converts an FP8 block-scaled tensor to MXFP8 format in place, reinterpreting column-wise data as row-wise when needed, and returns the swizzled scaling factors.
A reset_tensor_data helper clears unused data and scale slots. The implementation also returns early for tensors whose scales are already swizzled.
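The batching in multi_tensor_swizzle_scales_for_gemm can be pictured as a size-and-offset pass followed by a single allocation. The Python sketch below is hypothetical: plan_contiguous_buffer, ScaleSlice, rowwise_scale_shape and the padding rule (rows rounded up to 128, scale columns to 4) are illustrative assumptions rather than TransformerEngine code. The shapes assume one row-wise scale per 32 elements along K for MXFP8, or per 16 for NVFP4.

```python
# Hypothetical sketch: plan one contiguous buffer holding the swizzled scales
# of several tensors. Names and padding rules are illustrative assumptions,
# not TransformerEngine's implementation.
from dataclasses import dataclass


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def rowwise_scale_shape(m: int, k: int, block: int) -> tuple[int, int]:
    """Unpadded row-wise scale shape: one scale per `block` elements along K."""
    return m, (k + block - 1) // block


@dataclass
class ScaleSlice:
    offset: int  # first scale element of this tensor in the shared buffer
    size: int    # padded number of scale elements for this tensor


def plan_contiguous_buffer(
    scale_shapes: list[tuple[int, int]],
) -> tuple[list[ScaleSlice], int]:
    """Assign each tensor a slice of one buffer; return slices and total size."""
    slices: list[ScaleSlice] = []
    offset = 0
    for rows, cols in scale_shapes:
        # Assumed padding: rows to a multiple of 128, scale columns to a multiple of 4.
        size = round_up(rows, 128) * round_up(cols, 4)
        slices.append(ScaleSlice(offset, size))
        offset += size
    return slices, offset


if __name__ == "__main__":
    shapes = [rowwise_scale_shape(256, 1024, 32),  # MXFP8 scales: (256, 32)
              rowwise_scale_shape(100, 64, 32)]    # MXFP8 scales: (100, 2), padded to (128, 4)
    slices, total = plan_contiguous_buffer(shapes)
    print([(s.offset, s.size) for s in slices], total)  # [(0, 8192), (8192, 512)] 8704
```

With one-byte scales (E8M0 for MXFP8, E4M3 for NVFP4), every padded size in this sketch is a multiple of 512 bytes, so each slice starts on a 512-byte boundary; a real allocator may impose its own alignment on top of that.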
Usage
Scale swizzling is a required data layout transformation before MXFP8 or NVFP4 tensors can be used in a GEMM: cuBLAS expects their scaling factors in a specific interleaved (tiled) layout rather than a plain row-major one.
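To make the interleaved format concrete, here is a pure-Python sketch of the 128x4 tile layout described in the cuBLAS documentation for block-scaled GEMM. The padded scale matrix is split into tiles of 128 rows by 4 scale columns, tiles are stored in row-major order, and each tile's 512 entries are ordered as [row % 32][(row % 128) // 32][col % 4]. This is an illustrative reference under that assumption, not the nvte_swizzle_scaling_factors kernel; swizzled_offset and swizzle are hypothetical names, and padding entries are simply left as 0.

```python
# Illustrative sketch (assumption): 128x4-tile swizzle of block scaling factors,
# following the layout described in the cuBLAS docs. Not TransformerEngine code.


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return (x + multiple - 1) // multiple * multiple


def swizzled_offset(row: int, col: int, padded_cols: int) -> int:
    """Offset of scale (row, col) in the swizzled buffer.

    Tiles of 128x4 scales are stored in row-major tile order; inside a tile,
    the 512 entries are ordered as [row % 32][(row % 128) // 32][col % 4].
    """
    tiles_per_row = padded_cols // 4
    tile = (row // 128) * tiles_per_row + (col // 4)
    r = row % 128
    c = col % 4
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c


def swizzle(scales: list[list[int]]) -> list[int]:
    """Pad a row-major scale matrix to 128x4 multiples and swizzle it."""
    rows, cols = len(scales), len(scales[0])
    padded_rows, padded_cols = round_up(rows, 128), round_up(cols, 4)
    out = [0] * (padded_rows * padded_cols)  # padding entries stay 0 here
    for i in range(rows):
        for j in range(cols):
            out[swizzled_offset(i, j, padded_cols)] = scales[i][j]
    return out
```

For example, swizzle([[1, 2], [3, 4]]) returns a 512-entry list with 2 at offset 1, 3 at offset 16 and 4 at offset 17: consecutive rows land 16 entries apart, while rows 32 apart share the same 16-entry group.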
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/swizzle.cpp
- Lines: 1-394
Signature
```cpp
namespace transformer_engine::pytorch {

void swizzle_scales_for_gemm(py::handle tensor_handle);

void multi_tensor_swizzle_scales_for_gemm(
    std::vector<py::handle> tensor_handles);

py::tuple convert_block_scaling_to_mxfp8_tensor(
    py::handle tensor_handle, bool force_columnwise_as_rowwise);

}  // namespace transformer_engine::pytorch
```
Import
```cpp
#include "extensions.h"
#include "common.h"
```
I/O Contract
Inputs
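| Name | Type | Required | Description |
|---|---|---|---|
| tensor_handle | py::handle | Yes | MXFP8 or NVFP4 tensor with unswizzled scales (an FP8 block-scaled tensor for convert_block_scaling_to_mxfp8_tensor) |
| tensor_handles | std::vector<py::handle> | Yes (multi_tensor_swizzle_scales_for_gemm only) | Tensors whose scales are swizzled in one batched call |
| force_columnwise_as_rowwise | bool | No | Used by convert_block_scaling_to_mxfp8_tensor; forces column-wise data to be reinterpreted as row-wise |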
Outputs
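| Function | Return type | Description |
|---|---|---|
| swizzle_scales_for_gemm | void | Swizzled scales are written to newly allocated buffers, and the tensor's scale_inv pointers are updated to reference them |
| multi_tensor_swizzle_scales_for_gemm | void | Same as above for every tensor, with all swizzled scales placed in one contiguous buffer |
| convert_block_scaling_to_mxfp8_tensor | py::tuple | The swizzled scaling factors; the tensor itself is converted to MXFP8 in place |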
Usage Examples
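```python
import transformer_engine_torch as tex

# mxfp8_tensor and tensor1..tensor3 are assumed to be existing MXFP8 (or NVFP4)
# tensors whose scales have not been swizzled yet.
# Swizzle MXFP8 tensor scales for GEMM compatibility
tex.swizzle_scales_for_gemm(mxfp8_tensor)
# Batch swizzle multiple tensors (all swizzled scales share one buffer)
tex.multi_tensor_swizzle_scales_for_gemm([tensor1, tensor2, tensor3])
```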