Implementation:NVIDIA TransformerEngine Triton Permutation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements Triton JIT kernels for efficient token permutation and unpermutation operations used in Mixture-of-Experts (MoE) routing, including row ID mapping, argsort, and forward/backward permutation passes.
Description
triton/permutation.py provides several Triton JIT kernels essential for MoE architectures:
- Bitonic sort: _compare_and_swap, _bitonic_merge, and _argsort implement a bitonic sort for argsort operations within Triton.
- Row ID mapping: _row_id_map_pass_1_kernel and _row_id_map_pass_2_kernel build a row ID mapping from routing maps using a two-pass prefix sum approach (per-block cumsum, then cross-block aggregation).
- Permutation: _permute_kernel reorders token rows based on routing maps, supporting probs-based weighted permutation and top-k expert routing.
- Unpermutation: _unpermute_kernel reverses the permutation to restore the original token order.
- Backward passes: Additional kernels handle gradient propagation for both permutation and unpermutation.
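The two-pass prefix sum described in the row ID mapping item can be illustrated with a pure-Python reference. This is a sketch under stated assumptions, not the Triton kernels: the function name `build_row_id_map`, the `[expert][token]` layout of the result, and the use of `-1` for unrouted tokens are all hypothetical choices made for the example.

```python
def build_row_id_map(routing_map, block_size):
    """Reference two-pass prefix sum over a 0/1 routing map (illustrative).

    routing_map[t][e] is 1 if token t is routed to expert e.
    Returns row_id_map[e][t]: the destination row of token t in an
    expert-major permuted buffer, or -1 if token t is not routed to e.
    """
    num_tokens = len(routing_map)
    num_experts = len(routing_map[0]) if num_tokens else 0
    num_blocks = (num_tokens + block_size - 1) // block_size

    # Pass 1: inclusive cumsum inside each (expert, block) tile, plus the
    # per-tile total written to a workspace.
    local = [[0] * num_tokens for _ in range(num_experts)]
    workspace = [[0] * num_blocks for _ in range(num_experts)]
    for e in range(num_experts):
        for b in range(num_blocks):
            running = 0
            for t in range(b * block_size, min((b + 1) * block_size, num_tokens)):
                running += routing_map[t][e]
                local[e][t] = running
            workspace[e][b] = running

    # Pass 2: add the totals of all tiles that precede this tile in
    # expert-major order, then convert to a 0-based row index.
    row_id_map = [[-1] * num_tokens for _ in range(num_experts)]
    offset = 0
    for e in range(num_experts):
        for b in range(num_blocks):
            for t in range(b * block_size, min((b + 1) * block_size, num_tokens)):
                if routing_map[t][e]:
                    row_id_map[e][t] = offset + local[e][t] - 1
            offset += workspace[e][b]
    return row_id_map


# 5 tokens, 2 experts; token 1 is routed to both experts.
routing_map = [
    [1, 0],
    [1, 1],
    [0, 1],
    [1, 0],
    [0, 1],
]
row_id_map = build_row_id_map(routing_map, block_size=2)
```

The result does not depend on `block_size`; the block split only determines how much work each program would handle independently in the first pass.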
Usage
Use these kernels in MoE layers where tokens must be efficiently routed to and gathered from different expert subnetworks based on a learned routing policy.
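To make the routing and gathering concrete, the following pure-Python reference mimics what permutation and unpermutation compute on small lists. It is a hedged sketch, not the library's API: `permute`, `unpermute`, `source_token`, and `weights` are hypothetical names, and the expert-major output order is an assumption.

```python
def permute(tokens, routing_map):
    """Gather token rows into an expert-major buffer (illustrative).

    Returns (permuted_rows, source_token), where source_token[i] is the
    index of the original token stored in permuted row i.
    """
    num_experts = len(routing_map[0])
    permuted, source_token = [], []
    for e in range(num_experts):
        for t, row in enumerate(tokens):
            if routing_map[t][e]:
                permuted.append(list(row))
                source_token.append(t)
    return permuted, source_token


def unpermute(permuted, source_token, num_tokens, weights=None):
    """Scatter-add permuted rows back to their original token positions.

    weights[i], if given, scales permuted row i before accumulation.
    """
    hidden = len(permuted[0]) if permuted else 0
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for i, row in enumerate(permuted):
        w = 1.0 if weights is None else weights[i]
        t = source_token[i]
        for c in range(hidden):
            out[t][c] += w * row[c]
    return out


# Top-1 routing: each token goes to exactly one expert, so the round
# trip restores the original rows.
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
routing_map = [[0, 1], [1, 0], [0, 1]]
permuted, source_token = permute(tokens, routing_map)
restored = unpermute(permuted, source_token, num_tokens=len(tokens))
```

With top-k routing for k greater than 1, a token appears in several permuted rows, and passing per-row `weights` (for example, routing probabilities) turns the scatter-add into a weighted merge of the expert outputs.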
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/triton/permutation.py
- Lines: 1--642
Signature
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    ...

@triton.jit
def _row_id_map_pass_1_kernel(routing_map_ptr, num_tokens,
                              stride_routing_map_token, ...,
                              row_id_map_ptr, workspace_ptr,
                              BLOCK_SIZE: tl.constexpr):
    ...

@triton.jit
def _permute_kernel(input_ptr, output_ptr, sorted_row_id_ptr, ...):
    ...

@triton.jit
def _unpermute_kernel(input_ptr, output_ptr, row_id_map_ptr, ...):
    ...
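For readers unfamiliar with bitonic sorting networks, the idea behind an argsort that carries indices alongside values can be sketched in plain Python. This is an illustrative reference only, not the Triton implementation: `bitonic_argsort` is a hypothetical name, and the sketch assumes the input length is a power of two.

```python
def bitonic_argsort(values):
    """Ascending argsort via an iterative bitonic sorting network.

    len(values) must be a non-zero power of two.
    Returns (sorted_values, indices) with sorted_values[i] == values[indices[i]].
    """
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a non-zero power of two")
    x = list(values)
    idx = list(range(n))

    k = 2
    while k <= n:          # size of the bitonic sequences being merged
        j = k // 2
        while j >= 1:      # distance between compared elements
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    # Sort direction alternates between neighbouring
                    # k-sized groups; the final stage (k == n) is ascending.
                    ascending = (i & k) == 0
                    if ascending:
                        out_of_order = x[i] > x[partner]
                    else:
                        out_of_order = x[i] < x[partner]
                    if out_of_order:
                        # Compare-and-swap values and their indices together.
                        x[i], x[partner] = x[partner], x[i]
                        idx[i], idx[partner] = idx[partner], idx[i]
            j //= 2
        k *= 2
    return x, idx


values = [3, 7, 1, 5, 8, 2, 6, 4]
sorted_values, indices = bitonic_argsort(values)
```

The compare-and-swap pattern is fixed in advance and independent of the data, which is what makes this family of sorts a natural fit for data-parallel hardware. Bitonic sort is not stable, so the relative order of equal values is not guaranteed.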
Import
from transformer_engine.common.triton.permutation import (
    _permute_kernel,
    _unpermute_kernel,
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| routing_map_ptr | pointer | Yes | Routing map from MoE gating |
| input_ptr | pointer | Yes | Input token tensor |
| num_tokens | int | Yes | Number of tokens in the batch |
| BLOCK_SIZE | tl.constexpr | Yes | Triton block size |
Outputs
| Name | Type | Description |
|---|---|---|
| output_ptr | pointer | Permuted (or unpermuted) token tensor |
| row_id_map_ptr | pointer | Row ID mapping for reverse permutation |
Usage Examples
import triton
from transformer_engine.common.triton.permutation import _permute_kernel

# Launch the permutation kernel.
# The pointer and size arguments below are placeholders that the caller
# is assumed to have prepared before the launch.
grid = (num_experts,)
_permute_kernel[grid](
    input_ptr, output_ptr, sorted_row_id_ptr,
    row_id_map_ptr, prob_ptr, prob_grad_ptr,
    input_fwd_ptr, num_rows, topK, num_cols,
    num_out_tokens, BLOCK_SIZE=128,
)