Implementation:NVIDIA TransformerEngine JAX Triton Permutation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements JAX custom primitives for MoE token permutation operations using Triton kernels, including row ID map generation, token permutation/unpermutation, and chunk sorting.
Description
Token routing relies on a row ID map that is built in three passes:
- Pass 1 (RowIdMapPass1Primitive) computes per-block cumulative sums.
- Pass 2 (RowIdMapPass2Primitive) aggregates those sums across blocks.
- Pass 3 (RowIdMapPass3Primitive) produces each token's final output position.

PermuteWithMaskMapPrimitive scatters tokens to expert slots according to the row ID map, and UnpermuteWithMaskMapPrimitive gathers them back. UnpermuteBwdWithMergingProbsPrimitive implements the backward pass with probability-weighted merging. The padding variants (_and_pad, _and_unpad) support alignment requirements. MakeChunkSortMapPrimitive and SortChunksByMapPrimitive reorder expert chunks. All primitives derive from BasePrimitive, are lowered to MLIR through triton_call_lowering, and support SPMD sharding via custom partitioning.
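The three passes can be illustrated with a small NumPy reference. This is a sketch, not TransformerEngine's Triton kernel: the function name make_row_id_map_reference and its block_size parameter are hypothetical. It also assumes a simplified map layout in which row_id_map[t, e] holds the output row of token t for expert e (or -1 if t is not routed to e), with all of expert 0's tokens placed first, then expert 1's, and so on. TE's actual row ID map layout may differ.

```python
import numpy as np


def make_row_id_map_reference(routing_map: np.ndarray, block_size: int = 4) -> np.ndarray:
    """Hypothetical reference (assumed layout): row_id_map[t, e] is the output
    row of token t for expert e, or -1 if token t is not routed to expert e."""
    routing = routing_map.astype(np.int64)
    num_tokens, num_experts = routing.shape
    num_blocks = -(-num_tokens // block_size)  # ceiling division

    # Pass 1: per-block, per-expert inclusive cumulative sums and block totals.
    local_rank = np.zeros_like(routing)
    block_totals = np.zeros((num_blocks, num_experts), dtype=np.int64)
    for b in range(num_blocks):
        rows = slice(b * block_size, (b + 1) * block_size)
        csum = np.cumsum(routing[rows], axis=0)
        local_rank[rows] = csum - 1  # 0-based rank of a routed token inside its block
        block_totals[b] = csum[-1]

    # Pass 2: aggregate across blocks with an exclusive prefix sum taken in
    # expert-major order (expert 0 over all blocks, then expert 1, ...).
    flat = block_totals.T.reshape(-1)
    offsets = np.concatenate([[0], np.cumsum(flat)[:-1]])
    block_offset = offsets.reshape(num_experts, num_blocks).T  # [block, expert]

    # Pass 3: final output position = block offset + rank within the block.
    row_id_map = np.full_like(routing, -1)
    for b in range(num_blocks):
        rows = slice(b * block_size, (b + 1) * block_size)
        pos = block_offset[b] + local_rank[rows]
        row_id_map[rows] = np.where(routing[rows] == 1, pos, -1)
    return row_id_map
```

For routing_map = [[1, 0], [1, 1], [0, 1], [1, 0], [0, 1]] and block_size=2, the sketch returns [[0, -1], [1, 3], [-1, 4], [2, -1], [-1, 5]]. Other block sizes give the same map: blocking only changes how the work would be split across parallel programs.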
This module provides the high-performance GPU kernels that underpin MoE token routing. Its three-pass row ID map algorithm lets token dispatch and combine run as efficient parallel operations, which are critical for scaling MoE models.
Usage
Most code should use this module indirectly, through token_dispatch and token_combine in transformer_engine.jax.permutation. Call it directly only when implementing a custom MoE routing strategy.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/triton_extensions/permutation.py
- Lines: 1-2259
Signature
```python
class RowIdMapPass1Primitive(BasePrimitive): ...
class RowIdMapPass2Primitive(BasePrimitive): ...
class RowIdMapPass3Primitive(BasePrimitive): ...
class PermuteWithMaskMapPrimitive(BasePrimitive): ...
class UnpermuteWithMaskMapPrimitive(BasePrimitive): ...
class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): ...
class MakeChunkSortMapPrimitive(BasePrimitive): ...
class SortChunksByMapPrimitive(BasePrimitive): ...

def make_row_id_map(
    routing_map: jnp.ndarray, num_tokens_per_expert: int,
) -> jnp.ndarray: ...

def permute_with_mask_map(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, num_out_tokens: int,
) -> jnp.ndarray: ...

def permute_with_mask_map_and_pad(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, num_out_tokens: int,
    pad_offsets: jnp.ndarray,
) -> jnp.ndarray: ...

def unpermute_with_mask_map(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray,
    num_original_tokens: int,
) -> jnp.ndarray: ...

def make_chunk_sort_map(split_sizes, sorted_indices) -> jnp.ndarray: ...
def sort_chunks_by_map(tokens, chunk_sort_map, num_out_tokens) -> jnp.ndarray: ...
```
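make_chunk_sort_map and sort_chunks_by_map reorder contiguous chunks of rows. The NumPy sketch below shows one plausible reading of their contract. It assumes split_sizes holds the row count of each chunk in source order and sorted_indices lists source chunk ids in the desired output order. The _reference functions are hypothetical, and the sketch ignores num_out_tokens.

```python
import numpy as np


def make_chunk_sort_map_reference(split_sizes, sorted_indices):
    """Hypothetical reference: return dst, where dst[i] is the output row of source row i.
    Assumes split_sizes is in source order and sorted_indices gives the output chunk order."""
    split_sizes = np.asarray(split_sizes, dtype=np.int64)
    sorted_indices = np.asarray(sorted_indices, dtype=np.int64)
    # First row of every chunk in the source layout.
    src_starts = np.concatenate([[0], np.cumsum(split_sizes)[:-1]])
    # First row of every chunk in the output layout (sizes taken in sorted order).
    sorted_sizes = split_sizes[sorted_indices]
    sorted_starts = np.concatenate([[0], np.cumsum(sorted_sizes)[:-1]])
    dst_starts = np.empty_like(split_sizes)
    dst_starts[sorted_indices] = sorted_starts
    # Each source row keeps its offset within its chunk.
    chunk_of_row = np.repeat(np.arange(len(split_sizes)), split_sizes)
    offset_in_chunk = np.arange(split_sizes.sum()) - src_starts[chunk_of_row]
    return dst_starts[chunk_of_row] + offset_in_chunk


def sort_chunks_by_map_reference(tokens, chunk_sort_map):
    """Hypothetical reference: scatter every source row to its destination row."""
    out = np.empty_like(tokens)
    out[chunk_sort_map] = tokens
    return out
```

With split_sizes=[2, 1, 3] and sorted_indices=[2, 0, 1], the map is [3, 4, 5, 0, 1, 2]. Chunk 2 moves to the front, followed by chunks 0 and 1.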
Import
```python
from transformer_engine.jax.triton_extensions.permutation import (
    make_row_id_map, permute_with_mask_map, unpermute_with_mask_map
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| routing_map | jnp.ndarray | Yes | Boolean routing map of shape [num_tokens, num_experts] |
| tokens | jnp.ndarray | Yes | Token tensor to permute |
| row_id_map | jnp.ndarray | Yes | Row ID map produced by make_row_id_map |
| num_out_tokens | int | Yes | Number of output (permuted) tokens |
| probs | jnp.ndarray | No | Routing probabilities used to merge tokens during unpermutation |
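Assuming no tokens are dropped, every True entry of routing_map becomes one permuted row, so num_out_tokens equals routing_map.sum(). The hypothetical NumPy helper below computes that value together with per-expert counts. Its padded_per_expert output rounds each expert's group up to a multiple of align. That is only an assumed interpretation of the padding variants' alignment requirement, not TE's definition of pad_offsets.

```python
import numpy as np


def routing_stats_reference(routing_map, align=1):
    """Hypothetical helper: derive output sizes from a boolean [num_tokens, num_experts] map."""
    routing = np.asarray(routing_map, dtype=bool)
    tokens_per_expert = routing.sum(axis=0)        # rows each expert receives
    num_out_tokens = int(tokens_per_expert.sum())  # one row per (token, expert) pair, no dropping
    # Assumed alignment rule: round each expert's group up to a multiple of `align`.
    padded_per_expert = -(-tokens_per_expert // align) * align
    return tokens_per_expert, num_out_tokens, padded_per_expert
```

For routing_map=[[1, 0, 1], [1, 1, 0], [0, 1, 0], [1, 0, 0]] and align=2, it returns per-expert counts [3, 2, 1], num_out_tokens=6, and padded sizes [4, 2, 2].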
Outputs
| Name | Type | Description |
|---|---|---|
| permuted_tokens | jnp.ndarray | Tokens scattered to expert slots |
| row_id_map | jnp.ndarray | Token-to-output-position mapping |
Usage Examples
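```python
from transformer_engine.jax.triton_extensions.permutation import (
    make_row_id_map, permute_with_mask_map, unpermute_with_mask_map
)

# Inputs assumed to be defined by the surrounding MoE layer:
#   routing_map: bool [num_tokens, num_experts]; tokens: [num_tokens, hidden]
#   probs: routing probabilities; capacity, num_out_tokens, num_original_tokens: ints

# Build the row ID map from the routing decisions
row_id_map = make_row_id_map(routing_map, num_tokens_per_expert=capacity)
# Scatter tokens to expert slots
permuted = permute_with_mask_map(tokens, row_id_map, num_out_tokens)
# ... run the experts on `permuted` to obtain expert_output ...
# Gather expert outputs back to the original token order, merging with probs
output = unpermute_with_mask_map(expert_output, row_id_map, probs, num_original_tokens)
```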
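The NumPy sketch below makes the scatter/gather contract concrete with a permute/unpermute round trip. It uses a simplified, assumed layout in which row_id_map[t, e] is the output row of token t for expert e, or -1. The helper names are hypothetical, and this is not TE's kernel. Unpermute here merges each token's expert outputs as a probs-weighted sum, which is an assumption about how the merge is defined.

```python
import numpy as np


def simple_row_id_map(routing_map):
    """Hypothetical layout: expert 0's tokens first (in token order), then expert 1's, ..."""
    routing = np.asarray(routing_map, dtype=bool)
    row_id_map = np.full(routing.shape, -1, dtype=np.int64)
    next_row = 0
    for e in range(routing.shape[1]):
        for t in np.nonzero(routing[:, e])[0]:
            row_id_map[t, e] = next_row
            next_row += 1
    return row_id_map


def permute_reference(tokens, row_id_map, num_out_tokens):
    """Scatter: copy token t into its slot for every expert it is routed to."""
    out = np.zeros((num_out_tokens, tokens.shape[1]), dtype=tokens.dtype)
    t_idx, e_idx = np.nonzero(row_id_map >= 0)
    out[row_id_map[t_idx, e_idx]] = tokens[t_idx]
    return out


def unpermute_reference(expert_output, row_id_map, probs, num_original_tokens):
    """Gather: sum each token's expert outputs, weighted by probs[t, e] (assumed merge rule)."""
    out = np.zeros((num_original_tokens, expert_output.shape[1]), dtype=expert_output.dtype)
    for t, e in zip(*np.nonzero(row_id_map >= 0)):
        out[t] += probs[t, e] * expert_output[row_id_map[t, e]]
    return out


routing_map = np.array([[1, 0], [1, 1], [0, 1]], dtype=bool)
tokens = np.arange(6, dtype=np.float32).reshape(3, 2)
probs = np.array([[1.0, 0.0], [0.25, 0.75], [0.0, 1.0]], dtype=np.float32)
row_id_map = simple_row_id_map(routing_map)  # [[0, -1], [1, 2], [-1, 3]]
permuted = permute_reference(tokens, row_id_map, int(routing_map.sum()))
restored = unpermute_reference(permuted, row_id_map, probs, tokens.shape[0])
# With identity "experts" and per-token probs summing to 1, the round trip recovers the input.
assert np.allclose(restored, tokens)
```

Token 1 is routed to both experts, so it occupies two permuted rows (1 and 2), and its merged output is 0.25 and 0.75 times the same copy.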