Implementation:NVIDIA TransformerEngine PyTorch Triton Permutation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
PyTorch wrapper for Triton-based token permutation kernels used in Mixture-of-Experts routing, handling token-to-expert mapping and reverse unpermutation.
Description
The module provides the following functions:
- make_row_id_map builds the token-to-expert mapping in three Triton passes. Pass 1 computes per-block cumulative sums over the routing mask. Pass 2 converts them into global destination row indices. Pass 3 compacts the resulting sparse structure into a dense format.
- permute_with_mask_map reorders tokens according to row_id_map. It can optionally permute probabilities and scales alongside the tokens, and it supports FP8 padding.
- unpermute_with_mask_map reverses the permutation. It can optionally merge each token's expert outputs by weighting them with the routing probabilities.
- unpermute_with_mask_map_bwd_with_merging_probs implements the combined backward pass of unpermute when probability merging is enabled.
- Additional utilities (make_chunk_sort_map, sort_chunks_by_map) sort contiguous chunks into expert-aligned memory layouts.
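The three-pass construction can be illustrated with a small, dependency-free Python sketch. It is not TransformerEngine's kernel, and it uses a simplified layout that is an assumption. The sparse map is an experts x tokens table that holds each routed token's destination row, or -1. The dense form is a per-token list of `(expert, destination_row)` pairs. The helper name `make_row_id_map_reference` and the `block_size` parameter are hypothetical.

```python
from typing import List, Tuple


def make_row_id_map_reference(
    mask: List[List[bool]], block_size: int = 4
) -> Tuple[List[List[Tuple[int, int]]], int]:
    """Hypothetical reference for a 3-pass row-id map (not TE's exact layout).

    mask[t][e] is True when token t is routed to expert e.
    Returns (dense_map, num_out_tokens); dense_map[t] lists (expert, row) pairs.
    """
    num_tokens = len(mask)
    num_experts = len(mask[0]) if num_tokens else 0
    num_blocks = (num_tokens + block_size - 1) // block_size

    # Pass 1: per expert, an inclusive cumulative sum inside each token block,
    # plus the number of routed tokens in each block.
    local = [[0] * num_tokens for _ in range(num_experts)]
    block_totals = [[0] * num_blocks for _ in range(num_experts)]
    for e in range(num_experts):
        for b in range(num_blocks):
            running = 0
            for t in range(b * block_size, min((b + 1) * block_size, num_tokens)):
                if mask[t][e]:
                    running += 1
                    local[e][t] = running
            block_totals[e][b] = running

    # Pass 2: add block and expert offsets to get global destination rows.
    # Rows are grouped by expert; within an expert, tokens keep their order.
    sparse = [[-1] * num_tokens for _ in range(num_experts)]
    expert_offset = 0
    for e in range(num_experts):
        block_offset = expert_offset
        for b in range(num_blocks):
            for t in range(b * block_size, min((b + 1) * block_size, num_tokens)):
                if mask[t][e]:
                    sparse[e][t] = block_offset + local[e][t] - 1
            block_offset += block_totals[e][b]
        expert_offset = block_offset

    # Pass 3: compact the sparse experts x tokens table into per-token lists.
    dense = [
        [(e, sparse[e][t]) for e in range(num_experts) if sparse[e][t] >= 0]
        for t in range(num_tokens)
    ]
    return dense, expert_offset


if __name__ == "__main__":
    dense, n = make_row_id_map_reference([[True, False], [True, True], [False, True]], block_size=2)
    print(n, dense)  # 4 [[(0, 0)], [(0, 1), (1, 2)], [(1, 3)]]
```

The block size changes how the work is split, but it does not change the result. This independence is what lets the per-block passes run in parallel on a GPU.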
Usage
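Token permutation is the core data-movement step in MoE layers. Before the expert computation, tokens are scattered into expert-grouped order, and afterwards they are gathered back into their original order. These Triton kernels fuse the scatter/gather operations into GPU kernels and support padding for FP8 alignment.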
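The scatter/gather pair can be sketched in plain Python to show the data movement. This is an illustration, not TransformerEngine's implementation. It assumes the same hypothetical dense map format as a per-token list of `(expert, destination_row)` pairs. It treats probability merging as a weighted sum of each token's expert outputs. The function names are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

Row = List[float]
# Hypothetical dense map: dense_map[t] lists (expert, destination_row) pairs for token t.
DenseMap = Sequence[Sequence[Tuple[int, int]]]


def permute_reference(tokens: Sequence[Row], dense_map: DenseMap, num_out_tokens: int) -> List[Row]:
    """Scatter: copy token t into every destination row it is routed to.

    Rows that no token maps to (e.g. padding) are left as zeros.
    """
    hidden = len(tokens[0]) if tokens else 0
    out: List[Row] = [[0.0] * hidden for _ in range(num_out_tokens)]
    for t, routes in enumerate(dense_map):
        for _expert, row in routes:
            out[row] = list(tokens[t])
    return out


def unpermute_reference(
    expert_out: Sequence[Row],
    dense_map: DenseMap,
    probs: Optional[Sequence[Sequence[float]]] = None,
) -> List[Row]:
    """Gather: sum each token's expert outputs, weighted by probs[t][expert] if given."""
    hidden = len(expert_out[0]) if expert_out else 0
    result: List[Row] = []
    for t, routes in enumerate(dense_map):
        acc = [0.0] * hidden
        for expert, row in routes:
            weight = 1.0 if probs is None else probs[t][expert]
            for h, value in enumerate(expert_out[row]):
                acc[h] += weight * value
        result.append(acc)
    return result
```

Suppose the experts are the identity function and each token's probabilities sum to 1. Then unpermuting with `probs` recovers the original tokens. Without `probs`, a token routed to two experts comes back summed twice.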
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/triton/permutation.py
- Lines: 1-442
Signature
```python
def make_row_id_map(
    mask: torch.Tensor, num_out_tokens: int = 0,
    num_topK: int = 1, fp8_padding: int = 0, ...
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def permute_with_mask_map(
    inp: torch.Tensor, row_id_map: torch.Tensor,
    num_out_tokens: int, probs: Optional[torch.Tensor] = None, ...
) -> torch.Tensor: ...

def unpermute_with_mask_map(
    inp: torch.Tensor, row_id_map: torch.Tensor,
    probs: Optional[torch.Tensor] = None, ...
) -> torch.Tensor: ...

def make_chunk_sort_map(split_sizes, sorted_indices, ...): ...

def sort_chunks_by_map(inp, chunk_map, ...): ...
```
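The chunk-sorting utilities can be read as a two-step permutation. The sketch below assumes hypothetical semantics, since only the parameter names come from the signatures above. The input is split into contiguous chunks of `split_sizes` rows, and output chunk `i` is input chunk `sorted_indices[i]`. The first step builds a per-row destination map, and the second scatters rows with it. Both function names are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def make_chunk_sort_map_reference(split_sizes: Sequence[int], sorted_indices: Sequence[int]) -> List[int]:
    """Hypothetical semantics: dest[r] is the output row of input row r.

    The input is split into contiguous chunks with the given sizes, and
    output chunk i is input chunk sorted_indices[i].
    """
    # Starting row of each input chunk (exclusive prefix sum of split_sizes).
    src_start = [0] * len(split_sizes)
    for c in range(1, len(split_sizes)):
        src_start[c] = src_start[c - 1] + split_sizes[c - 1]

    dest = [0] * sum(split_sizes)
    out_row = 0
    for c in sorted_indices:
        for k in range(split_sizes[c]):
            dest[src_start[c] + k] = out_row
            out_row += 1
    return dest


def sort_chunks_by_map_reference(inp: Sequence[T], dest: Sequence[int]) -> List[T]:
    """Scatter each input row to its destination row."""
    out = list(inp)  # every slot is overwritten because dest is a permutation
    for r, d in enumerate(dest):
        out[d] = inp[r]
    return out
```

Splitting the work this way lets the same map be reused for several tensors that share the chunk layout.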
Import
```python
from transformer_engine.pytorch.triton.permutation import (
    make_row_id_map,
    permute_with_mask_map,
    unpermute_with_mask_map,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| mask | torch.Tensor | Yes | Boolean routing mask (tokens x experts) |
| inp | torch.Tensor | Yes | Input tokens to permute |
| row_id_map | torch.Tensor | Yes | Pre-computed mapping from input to output positions |
| num_out_tokens | int | No | Total number of output tokens |
| probs | torch.Tensor | No | Routing probabilities for weighted combination |
| fp8_padding | int | No | Padding alignment for FP8 quantized operations |
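The effect of `fp8_padding` on the output layout can be sketched as follows. This sketch assumes, and the table does not confirm, that each expert's segment of the permuted output is rounded up to a multiple of `fp8_padding` and that 0 disables padding. The helper `padded_expert_offsets` is hypothetical.

```python
from typing import List, Tuple


def padded_expert_offsets(tokens_per_expert: List[int], fp8_padding: int) -> Tuple[List[int], int]:
    """Hypothetical: start row of each expert segment and total output rows.

    Assumes each segment is rounded up to a multiple of fp8_padding (0 = no padding).
    """
    offsets: List[int] = []
    total = 0
    for count in tokens_per_expert:
        offsets.append(total)
        if fp8_padding > 0:
            count = -(-count // fp8_padding) * fp8_padding  # ceil to a multiple
        total += count
    return offsets, total
```

Under this assumption, two experts with 2 tokens each and `fp8_padding=16` would occupy rows 0-15 and 16-31. Each segment's extra rows would be padding.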
Outputs
| Name | Type | Description |
|---|---|---|
| row_id_map | torch.Tensor | Token-to-expert position mapping (from make_row_id_map) |
| permuted | torch.Tensor | Tokens reordered by expert assignment |
| unpermuted | torch.Tensor | Tokens restored to original order with routing weights |
Usage Examples
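```python
from transformer_engine.pytorch.triton.permutation import (
    make_row_id_map,
    permute_with_mask_map,
    unpermute_with_mask_map,
)

# Assumed to exist already: routing_mask (boolean, tokens x experts), tokens
# (the layer input), routing_probs (router probabilities), and expert_output
# (the result of running the experts on `permuted`).

# Build the row-id map from the boolean routing mask
row_id_map, num_out = make_row_id_map(routing_mask, num_topK=2)

# Permute tokens into expert-grouped order
permuted = permute_with_mask_map(tokens, row_id_map, num_out)

# After expert processing, unpermute back to the original token order,
# merging each token's expert outputs with its routing probabilities
output = unpermute_with_mask_map(expert_output, row_id_map, probs=routing_probs)
```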
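When probabilities are merged, each output token is a weighted sum y[t] = sum over k of p[t][e_k] * x[row_k]. The chain rule then gives two gradients:
- The gradient for each expert-output row is p[t][e_k] * grad_y[t].
- The gradient for each probability is the dot product of grad_y[t] with x[row_k].

The plain-Python sketch below computes both gradients in the same hypothetical dense-map format used for illustration. It shows the math that a combined backward such as unpermute_with_mask_map_bwd_with_merging_probs has to produce, not how the Triton kernel does it. The function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def unpermute_merging_bwd_reference(
    grad_out: Sequence[Sequence[float]],
    expert_out: Sequence[Sequence[float]],
    dense_map: Sequence[Sequence[Tuple[int, int]]],
    probs: Sequence[Sequence[float]],
) -> Tuple[List[List[float]], List[List[float]]]:
    """Backward of y[t] = sum_k probs[t][e_k] * expert_out[row_k] (hypothetical reference)."""
    hidden = len(expert_out[0]) if expert_out else 0
    grad_expert = [[0.0] * hidden for _ in expert_out]  # unrouted (padding) rows stay zero
    grad_probs = [[0.0] * len(p) for p in probs]
    for t, routes in enumerate(dense_map):
        for expert, row in routes:
            # dy[t]/d expert_out[row] is probs[t][expert] times the identity.
            grad_expert[row] = [probs[t][expert] * g for g in grad_out[t]]
            # dy[t]/d probs[t][expert] is expert_out[row], so take a dot product.
            grad_probs[t][expert] = sum(g * x for g, x in zip(grad_out[t], expert_out[row]))
    return grad_expert, grad_probs
```

Each output row of the permuted tensor belongs to exactly one (token, expert) pair, so every row's gradient is written at most once and needs no accumulation.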