Implementation:NVIDIA TransformerEngine PyTorch Ext Swizzle

From Leeroopedia


| Field | Value |
| --- | --- |
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Implements scale factor swizzling for MXFP8 and NVFP4 tensors, reordering scaling factors into the memory layout required by cuBLAS GEMM kernels.

Description

swizzle_scales_for_gemm takes a TensorWrapper in MXFP8 or NVFP4 scaling mode and swizzles its row-wise and column-wise scale_inv tensors independently by calling nvte_swizzle_scaling_factors. It allocates output buffers for the swizzled scales and updates the tensor's scale_inv pointers to reference them.

multi_tensor_swizzle_scales_for_gemm batches the same operation across multiple tensors, allocating a single contiguous buffer that holds all of the swizzled scales.

convert_block_scaling_to_mxfp8_tensor converts an FP8 block-scaled tensor to MXFP8 format in place, reinterpreting column-wise data as row-wise when needed, and returns the swizzled scaling factors.

The implementation uses a reset_tensor_data helper to clear unused data and scale slots, and includes early-exit paths for tensors whose scales are already swizzled.
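The single contiguous buffer used by multi_tensor_swizzle_scales_for_gemm can be pictured as an offset-planning step: each swizzled scale tensor gets a slice of one allocation. The following is a minimal sketch in plain Python, not the code from swizzle.cpp. The function name plan_contiguous_buffer and the 16-byte alignment are hypothetical choices made for illustration.

```python
def plan_contiguous_buffer(sizes, alignment=16):
    """Return (offsets, total_bytes) for packing buffers into one allocation.

    `sizes` are byte sizes of the individual swizzled-scale buffers.
    `alignment` is an assumed value, not taken from TransformerEngine.
    """
    offsets = []
    cursor = 0
    for size in sizes:
        # Round the cursor up to the next multiple of `alignment`.
        cursor = (cursor + alignment - 1) // alignment * alignment
        offsets.append(cursor)
        cursor += size
    return offsets, cursor


# Three hypothetical scale buffers of 512, 100, and 1024 bytes.
offsets, total = plan_contiguous_buffer([512, 100, 1024])
print(offsets, total)
```

One allocation plus a list of offsets replaces one allocation per tensor, which is the usual motivation for batching many small buffers.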

Usage

Scale swizzling is a required data layout transformation: cuBLAS expects the scaling factors of MXFP8 and NVFP4 operands in a specific interleaved format, so the scales must be swizzled before the tensor is used in a GEMM.
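To make "interleaved format" concrete, here is a minimal sketch of a swizzle over one tile of scaling factors in plain Python. It assumes the 128x4 tile and the offset formula given for 1D block scaling factors in the cuBLAS documentation. Both are assumptions here, not details taken from swizzle.cpp, and should be checked against the cuBLAS documentation and nvte_swizzle_scaling_factors. The names tile_offset and swizzle_tile are hypothetical.

```python
TILE_ROWS, TILE_COLS = 128, 4  # assumed tile shape (outer x inner)


def tile_offset(outer, inner):
    """Position of scale (outer, inner) inside one swizzled 512-element tile.

    Assumed formula: rows are grouped modulo 32, and the four 32-row
    groups are interleaved in chunks of 4 elements.
    """
    return (outer % 32) * 16 + (outer // 32) * 4 + inner


def swizzle_tile(scales):
    """Reorder a row-major 128x4 tile (list of rows) into a flat list."""
    out = [None] * (TILE_ROWS * TILE_COLS)
    for outer in range(TILE_ROWS):
        for inner in range(TILE_COLS):
            out[tile_offset(outer, inner)] = scales[outer][inner]
    return out


# Label each scale with its (row, col) index to see where it lands.
tile = [[(r, c) for c in range(TILE_COLS)] for r in range(TILE_ROWS)]
swizzled = swizzle_tile(tile)
print(swizzled[:8])
```

Under this formula, rows 0, 32, 64, and 96 end up next to each other in memory, so a row-major scale tensor cannot be handed to the GEMM unchanged.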

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/csrc/extensions/swizzle.cpp
Lines
1-394

Signature

namespace transformer_engine::pytorch {

void swizzle_scales_for_gemm(py::handle tensor_handle);

void multi_tensor_swizzle_scales_for_gemm(
    std::vector<py::handle> tensor_handles);

py::tuple convert_block_scaling_to_mxfp8_tensor(
    py::handle tensor_handle, bool force_columnwise_as_rowwise);

}

Import

#include "extensions.h"
#include "common.h"

I/O Contract

Inputs

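| Name | Type | Required | Description |
| --- | --- | --- | --- |
| tensor_handle | py::handle | Yes | MXFP8 or NVFP4 tensor with unswizzled scales (swizzle_scales_for_gemm), or FP8 block-scaled tensor (convert_block_scaling_to_mxfp8_tensor) |
| tensor_handles | std::vector<py::handle> | Yes | Tensors to swizzle in one batch (multi_tensor_swizzle_scales_for_gemm) |
| force_columnwise_as_rowwise | bool | No | Force reinterpretation of columnwise data as rowwise (convert_block_scaling_to_mxfp8_tensor) |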

Outputs

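| Name | Type | Description |
| --- | --- | --- |
| N/A | void | swizzle_scales_for_gemm and multi_tensor_swizzle_scales_for_gemm return nothing; swizzling is performed in-place on the tensor's scale_inv data |
| N/A | py::tuple | convert_block_scaling_to_mxfp8_tensor returns the swizzled scaling factors |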

Usage Examples

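import transformer_engine_torch as tex

# mxfp8_tensor, tensor1, tensor2, and tensor3 are assumed to be existing
# MXFP8 (or NVFP4) quantized tensors whose scales are not yet swizzled.

# Swizzle one tensor's scales in place for GEMM compatibility
tex.swizzle_scales_for_gemm(mxfp8_tensor)

# Batch-swizzle several tensors in a single call
tex.multi_tensor_swizzle_scales_for_gemm([tensor1, tensor2, tensor3])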
