Implementation:NVIDIA TransformerEngine PyTorch Permutation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements Mixture-of-Experts (MoE) token permutation and unpermutation: operations that reorder tokens according to expert routing decisions, supporting both index-based and mask-based routing maps.
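The core contract can be shown with a framework-free sketch for top-1 routing. A stable sort of token positions by expert index gives the permutation, and scattering rows back through the same index list undoes it. The helper names `permute_top1` and `unpermute_top1` are hypothetical. This is not TransformerEngine's kernel; it only illustrates the reordering.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def permute_top1(tokens: Sequence[Row], expert_ids: Sequence[int]) -> Tuple[List[Row], List[int]]:
    """Hypothetical helper: group token rows by their single assigned expert.

    Returns the permuted rows and sorted_indices, where sorted_indices[j]
    is the original position of the row now stored at output position j.
    """
    # sorted() is stable, so tokens routed to the same expert keep their order.
    sorted_indices = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    permuted = [list(tokens[i]) for i in sorted_indices]
    return permuted, sorted_indices


def unpermute_top1(permuted: Sequence[Row], sorted_indices: Sequence[int]) -> List[Row]:
    """Hypothetical helper: scatter rows back to their original positions."""
    out: List[Row] = [[] for _ in sorted_indices]
    for j, src in enumerate(sorted_indices):
        out[src] = list(permuted[j])
    return out


tokens = [[0.0], [1.0], [2.0], [3.0]]
expert_ids = [2, 0, 2, 1]
permuted, idx = permute_top1(tokens, expert_ids)
print(permuted)  # [[1.0], [3.0], [0.0], [2.0]]
print(unpermute_top1(permuted, idx) == tokens)  # True
```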
Description
Provides two permutation strategies: index-based (_moe_permute_index_map), which uses tex.moe_permute_fwd and tex.moe_permute_bwd, and mask-based (_moe_permute_mask_map), which takes a boolean routing map. Each has a matching unpermute counterpart that reverses the reordering and applies routing probabilities as weights. _moe_chunk_sort sorts token chunks by expert index for memory-coalesced access. The module handles FP8 quantized tensors (Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor) and falls back to Triton-based permutation kernels when appropriate. Public API functions include moe_permute, moe_permute_with_probs, moe_permute_and_pad_with_probs, moe_unpermute, and sorting utilities such as moe_sort_chunks_by_index.
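To make the mask-based path and the probability weighting concrete, here is a pure-Python sketch under stated assumptions: the routing map is a boolean `[num_tokens][num_experts]` matrix, permuted rows are grouped expert by expert, and unpermute sums each token's expert outputs weighted by its routing probability. The function names are hypothetical; the real module does this with fused CUDA/Triton kernels and may order rows or apply probabilities differently.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def permute_mask(tokens: Sequence[Row],
                 routing_map: Sequence[Sequence[bool]]) -> Tuple[List[Row], List[Tuple[int, int]]]:
    """Hypothetical mask-based permute.

    Walks experts in order and, for each, copies every token routed to it.
    Returns the permuted rows and, per output row, its (token, expert) source.
    """
    num_experts = len(routing_map[0]) if routing_map else 0
    permuted: List[Row] = []
    sources: List[Tuple[int, int]] = []
    for e in range(num_experts):
        for t, row_mask in enumerate(routing_map):
            if row_mask[e]:
                permuted.append(list(tokens[t]))
                sources.append((t, e))
    return permuted, sources


def unpermute_mask(expert_out: Sequence[Row], sources: Sequence[Tuple[int, int]],
                   probs: Sequence[Sequence[float]], num_tokens: int) -> List[Row]:
    """Hypothetical unpermute: probability-weighted sum of each token's copies."""
    hidden = len(expert_out[0]) if expert_out else 0
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for row, (t, e) in zip(expert_out, sources):
        w = probs[t][e]
        for h in range(hidden):
            out[t][h] += w * row[h]
    return out


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
routing_map = [[True, True, False], [False, True, True], [True, False, True]]  # top-2 of 3 experts
probs = [[0.75, 0.25, 0.0], [0.0, 0.5, 0.5], [0.25, 0.0, 0.75]]

permuted, sources = permute_mask(tokens, routing_map)
print(len(permuted))  # 6 rows = 3 tokens x top-2
# With identity "experts", each token's probs sum to 1, so the inputs come back.
print(unpermute_mask(permuted, sources, probs, num_tokens=len(tokens)) == tokens)  # True
```

Grouping rows expert by expert is what lets each expert read one contiguous block; the weighted sum on the way back is what merges a token's top-K expert outputs into a single row.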
Usage
Essential infrastructure for MoE execution. After the router selects experts, tokens must be physically reordered so each expert processes contiguous data.
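Once tokens are grouped by expert, each expert's work is a contiguous slice whose length is the number of tokens routed to it. The sketch below derives those per-expert counts (the kind of split sizes a grouped GEMM would consume) from top-K expert indices and turns them into row ranges of the permuted buffer. `split_sizes_from_indices` and `expert_slices` are hypothetical names used only for illustration.

```python
from typing import List, Sequence


def split_sizes_from_indices(indices: Sequence[Sequence[int]], num_experts: int) -> List[int]:
    """Count how many (token, k) assignments land on each expert."""
    counts = [0] * num_experts
    for token_experts in indices:
        for e in token_experts:
            counts[e] += 1
    return counts


def expert_slices(split_sizes: Sequence[int]) -> List[range]:
    """Turn per-expert counts into contiguous [start, end) row ranges."""
    slices, start = [], 0
    for n in split_sizes:
        slices.append(range(start, start + n))
        start += n
    return slices


indices = [[0, 1], [1, 2], [0, 2], [1, 3]]  # top-2 experts for 4 tokens
sizes = split_sizes_from_indices(indices, num_experts=4)
print(sizes)                 # [2, 3, 2, 1]
print(expert_slices(sizes))  # [range(0, 2), range(2, 5), range(5, 7), range(7, 8)]
```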
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/permutation.py
- Lines: 1-850
Signature
```python
def moe_permute(
    inp, indices, num_out_tokens=None, padded_mode=False, ...
): ...

def moe_permute_with_probs(
    inp, indices, probs, num_out_tokens=None, ...
): ...

def moe_permute_and_pad_with_probs(
    inp, indices, probs, num_topK, num_experts, ...
): ...

def moe_unpermute(
    inp, sorted_indices, probs=None, ...
): ...

def moe_sort_chunks_by_index(inp, split_sizes, sorted_indices): ...

class _moe_permute_index_map(torch.autograd.Function): ...
class _moe_unpermute_index_map(torch.autograd.Function): ...
class _moe_permute_mask_map(torch.autograd.Function): ...
class _moe_unpermute_mask_map(torch.autograd.Function): ...
class _moe_chunk_sort(torch.autograd.Function): ...
```
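moe_sort_chunks_by_index reorders contiguous chunks rather than individual tokens. A plausible reading of its signature, stated here as an assumption, is: split the input into chunks of `split_sizes` rows, then emit those chunks in the order listed by `sorted_indices`. One typical use is regrouping rows that arrive source-major (for example, per sending rank) into expert-major order. The pure-Python `sort_chunks` below is a hypothetical reference for that behaviour, not the library function.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sort_chunks(rows: Sequence[T], split_sizes: Sequence[int],
                sorted_indices: Sequence[int]) -> List[T]:
    """Hypothetical reference: reorder variable-length chunks of rows.

    Chunk i holds split_sizes[i] consecutive rows; the output concatenates
    chunks in the order listed by sorted_indices.
    """
    if sum(split_sizes) != len(rows):
        raise ValueError("split_sizes must cover every row exactly once")
    # Start offset of every chunk in the input.
    offsets, start = [], 0
    for n in split_sizes:
        offsets.append(start)
        start += n
    out: List[T] = []
    for i in sorted_indices:
        out.extend(rows[offsets[i]:offsets[i] + split_sizes[i]])
    return out


# Two sources x two experts, laid out source-major:
# chunks = [s0e0, s0e1, s1e0, s1e1]; regroup them expert-major.
rows = ["s0e0a", "s0e1a", "s0e1b", "s1e0a", "s1e1a"]
regrouped = sort_chunks(rows, split_sizes=[1, 2, 1, 1], sorted_indices=[0, 2, 1, 3])
print(regrouped)  # ['s0e0a', 's1e0a', 's0e1a', 's0e1b', 's1e1a']
```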
Import
```python
from transformer_engine.pytorch.permutation import (
    moe_permute,
    moe_unpermute,
    moe_sort_chunks_by_index,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | torch.Tensor | Yes | Input tokens to permute |
| indices | torch.Tensor | Yes | Expert routing indices for each token |
| probs | torch.Tensor | No | Routing probabilities/weights |
| num_out_tokens | int | No | Number of output tokens (for padding) |
| num_topK | int | No | Top-K experts per token |
| num_experts | int | No | Total number of experts |
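The index-based and mask-based variants carry the same routing information in two layouts: `indices` as `[num_tokens][num_topK]` expert IDs, or a boolean `[num_tokens][num_experts]` map. The conversion below uses hypothetical helpers and assumes no expert is repeated within a token's top-K. It also shows that, with nothing dropped, the number of routed (token, expert) pairs is `num_tokens * num_topK`.

```python
from typing import List, Sequence


def indices_to_mask(indices: Sequence[Sequence[int]], num_experts: int) -> List[List[bool]]:
    """Hypothetical helper: [num_tokens][topK] expert IDs -> boolean routing map."""
    mask = [[False] * num_experts for _ in indices]
    for t, token_experts in enumerate(indices):
        for e in token_experts:
            mask[t][e] = True
    return mask


def mask_to_indices(mask: Sequence[Sequence[bool]]) -> List[List[int]]:
    """Inverse direction; expert IDs come back in ascending order per token."""
    return [[e for e, hit in enumerate(row) if hit] for row in mask]


indices = [[2, 0], [1, 3], [3, 2]]  # num_tokens=3, num_topK=2
mask = indices_to_mask(indices, num_experts=4)
print(mask_to_indices(mask))          # [[0, 2], [1, 3], [2, 3]]
print(sum(sum(row) for row in mask))  # 6 == num_tokens * num_topK
```

Note that the mask layout loses the original top-K ordering within a token, which is why the round trip returns sorted expert IDs.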
Outputs
| Name | Type | Description |
|---|---|---|
| permuted_tokens | torch.Tensor | Tokens reordered by expert assignment |
| sorted_indices | torch.Tensor | Indices for unpermuting back to original order |
| row_id_map | torch.Tensor | Mapping from input to output positions |
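A row_id_map records where each routed copy of a token lands in the permuted buffer. One simple layout, assumed here purely for illustration (the library's actual layout may differ), is a `[num_experts][num_tokens]` table that holds the destination row, or -1 when the token is not routed to that expert. `build_row_id_map` is a hypothetical name.

```python
from typing import List, Sequence


def build_row_id_map(mask: Sequence[Sequence[bool]]) -> List[List[int]]:
    """Hypothetical layout: row_id_map[e][t] = output row, or -1 if unrouted.

    Rows are numbered expert by expert, matching an expert-major permuted buffer.
    """
    num_tokens = len(mask)
    num_experts = len(mask[0]) if mask else 0
    row_id_map = [[-1] * num_tokens for _ in range(num_experts)]
    next_row = 0
    for e in range(num_experts):
        for t in range(num_tokens):
            if mask[t][e]:
                row_id_map[e][t] = next_row
                next_row += 1
    return row_id_map


mask = [[True, False, True], [False, True, True]]  # 2 tokens, 3 experts
print(build_row_id_map(mask))  # [[0, -1], [-1, 1], [2, 3]]
```

With such a table, unpermuting token t is a direct gather: read the rows row_id_map[e][t] for every expert e whose entry is not -1, rather than searching the permuted buffer.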
Usage Examples
```python
from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute

# tokens, expert_indices, routing_probs, total, grouped_linear and m_splits
# are placeholders supplied by the surrounding MoE layer.

# Permute tokens to expert order
permuted, sorted_indices = moe_permute(tokens, expert_indices, num_out_tokens=total)
# Process by experts (grouped_linear applies each expert to its contiguous slice)
expert_output = grouped_linear(permuted, m_splits)
# Unpermute back to original order, weighting by routing probabilities
output = moe_unpermute(expert_output, sorted_indices, probs=routing_probs)
```