Implementation:NVIDIA TransformerEngine PyTorch Triton Permutation

From Leeroopedia


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, PyTorch, Quantization
Last Updated  2026-02-07 14:00 GMT

Overview

PyTorch wrapper for Triton-based token permutation kernels used in Mixture-of-Experts routing, handling token-to-expert mapping and reverse unpermutation.

Description

make_row_id_map builds a token-to-expert mapping in three Triton passes: pass 1 computes per-block cumulative sums, pass 2 converts them to global destination row indices, and pass 3 compacts the sparse structure into a dense format.

permute_with_mask_map reorders tokens according to row_id_map, optionally permuting probabilities and scales alongside them, with FP8 padding support.

unpermute_with_mask_map reverses the permutation, optionally merging expert probabilities. unpermute_with_mask_map_bwd_with_merging_probs implements the combined backward pass for the unpermute with probability merging.

Additional utilities (make_chunk_sort_map, sort_chunks_by_map) handle chunk sorting for expert-aligned memory layouts.
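
The three passes can be illustrated with a minimal pure-Python reference. This is a sketch under stated assumptions, not the Triton implementation: the name make_row_id_map_reference, the block_size parameter, the expert-major ordering of output rows, and the (expert, row) pair layout of the dense result are all hypothetical choices made for illustration.

```python
from itertools import accumulate


def make_row_id_map_reference(mask, block_size=2):
    """Hypothetical reference: mask[t][e] is True when token t goes to expert e."""
    num_tokens, num_experts = len(mask), len(mask[0])
    num_blocks = (num_tokens + block_size - 1) // block_size

    # Pass 1: inclusive cumulative sum inside each (expert, block) tile,
    # plus the total of every tile in expert-major order.
    local = [[0] * num_tokens for _ in range(num_experts)]
    tile_totals = []
    for e in range(num_experts):
        for b in range(num_blocks):
            running = 0
            for t in range(b * block_size, min((b + 1) * block_size, num_tokens)):
                running += int(mask[t][e])
                local[e][t] = running
            tile_totals.append(running)

    # Pass 2: an exclusive prefix sum over tile totals gives each tile's
    # global offset; offset + local rank - 1 is the destination row.
    offsets = [0] + list(accumulate(tile_totals))[:-1]
    sparse = [[-1] * num_experts for _ in range(num_tokens)]
    for e in range(num_experts):
        for t in range(num_tokens):
            if mask[t][e]:
                tile = e * num_blocks + t // block_size
                sparse[t][e] = offsets[tile] + local[e][t] - 1

    # Pass 3: compact each token's row, keeping only the routed experts
    # (assumed dense layout: a list of (expert, destination_row) pairs).
    dense = [[(e, row) for e, row in enumerate(rows) if row >= 0] for rows in sparse]
    return sparse, dense


T, F = True, False
example_mask = [[T, F, T], [F, T, F], [T, T, F], [F, F, T]]
sparse_map, dense_map = make_row_id_map_reference(example_mask)
# sparse_map == [[0, -1, 4], [-1, 2, -1], [1, 3, -1], [-1, -1, 5]]
```

In this sketch the result does not depend on block_size; the blocking only mirrors how a parallel kernel would split the cumulative sum across tiles.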

Usage

Use these functions when implementing MoE layers, where token permutation is the core data-movement operation. The Triton kernels fuse the scatter/gather steps into single GPU kernels and support FP8 padding alignment.
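
To make the data movement concrete, the following pure-Python sketch shows a scatter (permute) followed by a probability-weighted gather (unpermute). It is an illustration only: permute_reference and unpermute_reference are hypothetical names, the sparse map layout (one destination row per token and expert, -1 when unrouted) is an assumption, and nothing here calls the TransformerEngine kernels.

```python
def permute_reference(tokens, sparse_map, num_out_tokens):
    """Scatter: copy token t to every destination row listed for it."""
    hidden = len(tokens[0])
    permuted = [[0.0] * hidden for _ in range(num_out_tokens)]
    for t, rows in enumerate(sparse_map):
        for row in rows:
            if row >= 0:
                permuted[row] = list(tokens[t])
    return permuted


def unpermute_reference(expert_out, sparse_map, probs=None):
    """Gather: sum each token's expert copies, optionally weighted by probs[t][e]."""
    hidden = len(expert_out[0])
    restored = [[0.0] * hidden for _ in sparse_map]
    for t, rows in enumerate(sparse_map):
        for e, row in enumerate(rows):
            if row < 0:
                continue  # token t was not routed to expert e
            weight = 1.0 if probs is None else probs[t][e]
            for h in range(hidden):
                restored[t][h] += weight * expert_out[row][h]
    return restored


# 4 tokens, 3 experts, top-2 routing for tokens 0 and 2 (illustrative values)
sparse_map = [[0, -1, 4], [-1, 2, -1], [1, 3, -1], [-1, -1, 5]]
tokens = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
probs = [[0.5, 0.0, 0.5], [0.0, 1.0, 0.0], [0.25, 0.75, 0.0], [0.0, 0.0, 1.0]]

permuted = permute_reference(tokens, sparse_map, num_out_tokens=6)
restored = unpermute_reference(permuted, sparse_map, probs)
```

Because the example probabilities sum to 1 per token and the expert step is skipped, the weighted unpermute recovers the original tokens; without probs, tokens routed to two experts come back doubled.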

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/triton/permutation.py
Lines
1-442

Signature

def make_row_id_map(
    mask: torch.Tensor, num_out_tokens: int = 0,
    num_topK: int = 1, fp8_padding: int = 0, ...
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def permute_with_mask_map(
    inp: torch.Tensor, row_id_map: torch.Tensor,
    num_out_tokens: int, probs: Optional[torch.Tensor] = None, ...
) -> torch.Tensor: ...

def unpermute_with_mask_map(
    inp: torch.Tensor, row_id_map: torch.Tensor,
    probs: Optional[torch.Tensor] = None, ...
) -> torch.Tensor: ...

def make_chunk_sort_map(split_sizes, sorted_indices, ...): ...
def sort_chunks_by_map(inp, chunk_map, ...): ...
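
The chunk-sorting utilities can be pictured with a small pure-Python reference. This sketch assumes that split_sizes gives the row count of each contiguous input chunk and that sorted_indices lists which input chunk lands at each output position; both the semantics and the *_reference names are assumptions for illustration, not the library's confirmed behavior.

```python
def make_chunk_sort_map_reference(split_sizes, sorted_indices):
    """Return row_map where row_map[src_row] is that row's destination row."""
    # Start offset of every chunk in the input layout
    starts = [0]
    for size in split_sizes[:-1]:
        starts.append(starts[-1] + size)

    row_map = [0] * sum(split_sizes)
    dst = 0
    for chunk in sorted_indices:  # chunks in their output order
        for i in range(split_sizes[chunk]):
            row_map[starts[chunk] + i] = dst
            dst += 1
    return row_map


def sort_chunks_by_map_reference(rows, row_map):
    """Move every row to the destination given by row_map."""
    out = [None] * len(rows)
    for src, dst in enumerate(row_map):
        out[dst] = rows[src]
    return out


# Three chunks of sizes 2, 1, 3; output order is chunk 2, chunk 0, chunk 1
chunk_map = make_chunk_sort_map_reference([2, 1, 3], [2, 0, 1])
sorted_rows = sort_chunks_by_map_reference(
    ["a0", "a1", "b0", "c0", "c1", "c2"], chunk_map
)
# chunk_map == [3, 4, 5, 0, 1, 2]
# sorted_rows == ["c0", "c1", "c2", "a0", "a1", "b0"]
```

Splitting the work into a map-building step and a map-applying step lets the same map be reused for several tensors that share a layout.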

Import

from transformer_engine.pytorch.triton.permutation import (
    make_row_id_map,
    permute_with_mask_map,
    unpermute_with_mask_map,
)

I/O Contract

Inputs

Name            Type          Required  Description
mask            torch.Tensor  Yes       Boolean routing mask (tokens x experts)
inp             torch.Tensor  Yes       Input tokens to permute
row_id_map      torch.Tensor  Yes       Pre-computed mapping from input to output positions
num_out_tokens  int           No        Total number of output tokens
probs           torch.Tensor  No        Routing probabilities for weighted combination
fp8_padding     int           No        Padding alignment for FP8 quantized operations
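
One plausible reading of the fp8_padding input is that each expert's token count is rounded up to a multiple of the alignment so that every expert's block starts on an aligned row. The sketch below shows only that arithmetic; the function name padded_expert_offsets and the interpretation itself are assumptions, not behavior confirmed by the source.

```python
def padded_expert_offsets(tokens_per_expert, alignment):
    """Round each count up to a multiple of alignment and return start offsets.

    alignment <= 0 is treated here as "no padding" (an assumption).
    """
    if alignment <= 0:
        padded = list(tokens_per_expert)
    else:
        padded = [
            ((n + alignment - 1) // alignment) * alignment for n in tokens_per_expert
        ]
    # Exclusive prefix sum: first output row of each expert's block
    offsets = [0]
    for n in padded[:-1]:
        offsets.append(offsets[-1] + n)
    return padded, offsets


padded, offsets = padded_expert_offsets([5, 16, 0, 17], alignment=16)
# padded == [16, 16, 0, 32], offsets == [0, 16, 32, 32]
```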

Outputs

Name        Type          Description
row_id_map  torch.Tensor  Token-to-expert position mapping (from make_row_id_map)
permuted    torch.Tensor  Tokens reordered by expert assignment
unpermuted  torch.Tensor  Tokens restored to original order with routing weights

Usage Examples

from transformer_engine.pytorch.triton.permutation import (
    make_row_id_map,
    permute_with_mask_map,
    unpermute_with_mask_map,
)

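import torch

# Illustrative inputs; the sizes are assumptions, not values from the source
num_tokens, num_experts, hidden_size = 8, 4, 16
tokens = torch.randn(num_tokens, hidden_size, device="cuda")
routing_probs = torch.softmax(torch.randn(num_tokens, num_experts, device="cuda"), dim=-1)
top2 = routing_probs.topk(2, dim=-1).indices
routing_mask = torch.zeros_like(routing_probs, dtype=torch.bool).scatter_(1, top2, True)

# Build routing map from boolean mask
row_id_map, num_out = make_row_id_map(routing_mask, num_topK=2)

# Permute tokens to expert order
permuted = permute_with_mask_map(tokens, row_id_map, num_out)

# After expert processing (identity stand-in here), unpermute back
expert_output = permuted
output = unpermute_with_mask_map(expert_output, row_id_map, probs=routing_probs)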
