Implementation:NVIDIA TransformerEngine JAX Triton Permutation

From Leeroopedia


Sources: TransformerEngine
Domains: Deep_Learning, JAX
Last Updated: 2026-02-07 14:00 GMT

Overview

Implements JAX custom primitives for MoE token permutation operations using Triton kernels, including row ID map generation, token permutation/unpermutation, and chunk sorting.

Description

The row ID map that drives token routing is generated in three passes: Pass 1 (RowIdMapPass1Primitive) computes cumulative sums within each block of tokens, Pass 2 (RowIdMapPass2Primitive) aggregates the per-block results across blocks, and Pass 3 (RowIdMapPass3Primitive) produces the final output position for each token. PermuteWithMaskMapPrimitive uses the row ID map to scatter tokens to their expert slots, and UnpermuteWithMaskMapPrimitive gathers them back into the original token order. UnpermuteBwdWithMergingProbsPrimitive implements the backward pass of the probability-weighted merge. The padding variants (_and_pad, _and_unpad) handle alignment requirements. MakeChunkSortMapPrimitive and SortChunksByMapPrimitive reorder expert chunks. All primitives derive from BasePrimitive, lower to MLIR through triton_call_lowering, and support SPMD sharding via custom partitioning.
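The three passes can be illustrated with a pure-JAX reference sketch. It is not the Triton implementation: the function name reference_row_id_map, the block_size parameter, and the output layout (an int32 [num_tokens, num_experts] array holding each routed pair's destination row in an expert-major permuted buffer, -1 elsewhere) are assumptions chosen for clarity, and the row ID map TransformerEngine actually produces may be laid out differently.

```python
import jax.numpy as jnp


def reference_row_id_map(routing_map, block_size=4):
    """Hypothetical pure-JAX reference of a three-pass row ID map.

    Assumed output layout: int32 [num_tokens, num_experts], where entry (t, e)
    is the destination row of token t in an expert-major permuted buffer, or
    -1 if token t is not routed to expert e.
    """
    num_tokens, num_experts = routing_map.shape
    mask = routing_map.astype(jnp.int32)

    # Pad the token axis so it splits evenly into blocks of block_size tokens.
    num_blocks = -(-num_tokens // block_size)
    pad = num_blocks * block_size - num_tokens
    blocks = jnp.pad(mask, ((0, pad), (0, 0))).reshape(num_blocks, block_size, num_experts)

    # Pass 1: inclusive cumulative sum within each block, separately per expert.
    local = jnp.cumsum(blocks, axis=1)
    block_totals = local[:, -1, :]  # tokens each block sends to each expert

    # Pass 2: exclusive scans across blocks (offset of each block within an
    # expert) and across experts (first row of each expert's segment).
    block_offsets = jnp.cumsum(block_totals, axis=0) - block_totals
    expert_counts = block_totals.sum(axis=0)
    expert_base = jnp.cumsum(expert_counts) - expert_counts

    # Pass 3: combine local rank, block offset and expert base into final rows.
    rank = (local + block_offsets[:, None, :]).reshape(-1, num_experts)[:num_tokens] - 1
    rows = expert_base[None, :] + rank
    return jnp.where(routing_map, rows, -1).astype(jnp.int32)
```

For example, with the boolean routing map [[1,0,1],[0,1,1],[1,1,0],[0,0,1],[1,0,0]], expert 0's tokens 0, 2 and 4 land in rows 0 to 2, expert 1's tokens 1 and 2 in rows 3 and 4, and expert 2's tokens 0, 1 and 3 in rows 5 to 7, whatever block_size is used. Blocking is what makes the passes parallel: pass 1 only touches one block at a time, and only the small [num_blocks, num_experts] table of block totals needs a global scan.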

This module provides the GPU kernels that underpin MoE token routing. The three-pass row ID map algorithm lets the token dispatch and combine operations run in parallel on the GPU, which is critical for scaling MoE models.
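To make the dispatch and combine semantics concrete, here is a pure-JAX reference sketch of what a permute and a probability-weighted unpermute compute. It assumes a hypothetical row ID map layout ([num_tokens, num_experts], destination rows in expert-major order, -1 for unrouted pairs) and dense probs of shape [num_tokens, num_experts]; neither is confirmed to match what make_row_id_map and unpermute_with_mask_map use, and the function names are illustrative only.

```python
import jax.numpy as jnp


def simple_row_id_map(routing_map):
    """Hypothetical single-scan row ID map: rank routed pairs in expert-major order."""
    expert_major = routing_map.T.reshape(-1).astype(jnp.int32)
    rows = (jnp.cumsum(expert_major) - 1).reshape(routing_map.shape[::-1]).T
    return jnp.where(routing_map, rows, -1)


def reference_permute(tokens, row_id_map, num_out_tokens):
    """Dispatch: copy token t to row row_id_map[t, e] for every routed expert e."""
    num_tokens, num_experts = row_id_map.shape
    src = jnp.repeat(jnp.arange(num_tokens), num_experts)
    dst = row_id_map.reshape(-1)
    # Send unrouted pairs (-1) out of bounds so the scatter drops them.
    dst = jnp.where(dst < 0, num_out_tokens, dst)
    out = jnp.zeros((num_out_tokens, tokens.shape[-1]), tokens.dtype)
    return out.at[dst].set(tokens[src], mode="drop")


def reference_unpermute(permuted, row_id_map, probs):
    """Combine: out[t] = sum over routed e of probs[t, e] * permuted[row_id_map[t, e]]."""
    gathered = permuted[jnp.maximum(row_id_map, 0)]   # [T, E, H]
    weights = jnp.where(row_id_map >= 0, probs, 0.0)  # ignore unrouted pairs
    return jnp.einsum("te,teh->th", weights, gathered)
```

Because every copy of a token is identical before the experts run, applying reference_unpermute directly to the output of reference_permute, with probs summing to 1 over each token's routed experts, returns the original tokens.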

Usage

Most code reaches this module indirectly through token_dispatch and token_combine in transformer_engine.jax.permutation. Call it directly only when implementing a custom MoE routing strategy.
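Implementing a custom routing strategy mostly means producing the routing inputs yourself. The sketch below is a hypothetical top-k router (topk_routing is not a TransformerEngine function) that emits a boolean routing_map of shape [num_tokens, num_experts], as the I/O contract specifies, together with a dense probs array of the same shape and num_out_tokens. The dense probs shape and the renormalization of the selected scores are assumptions, not documented requirements.

```python
import jax
import jax.numpy as jnp


def topk_routing(logits, k):
    """Hypothetical top-k router returning (routing_map, probs, num_out_tokens)."""
    num_tokens, num_experts = logits.shape
    scores = jax.nn.softmax(logits, axis=-1)
    top_vals, top_idx = jax.lax.top_k(scores, k)  # each [num_tokens, k]
    # One-hot each selected expert index and OR them together per token.
    routing_map = jax.nn.one_hot(top_idx, num_experts, dtype=jnp.int32).sum(axis=1) > 0
    # Renormalize the selected scores so each token's weights sum to 1 (assumption).
    top_vals = top_vals / top_vals.sum(axis=-1, keepdims=True)
    rows = jnp.arange(num_tokens)[:, None]
    probs = jnp.zeros_like(scores).at[rows, top_idx].set(top_vals)
    num_out_tokens = num_tokens * k  # no token dropping: each token goes to k experts
    return routing_map, probs, num_out_tokens
```

A router like this would replace the routing_map, probs and num_out_tokens that the Usage Examples below take as given; a capacity-limited strategy would additionally clear entries of routing_map and reduce num_out_tokens accordingly.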

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/triton_extensions/permutation.py
Lines: 1-2259

Signature

class RowIdMapPass1Primitive(BasePrimitive): ...
class RowIdMapPass2Primitive(BasePrimitive): ...
class RowIdMapPass3Primitive(BasePrimitive): ...
class PermuteWithMaskMapPrimitive(BasePrimitive): ...
class UnpermuteWithMaskMapPrimitive(BasePrimitive): ...
class UnpermuteBwdWithMergingProbsPrimitive(BasePrimitive): ...
class MakeChunkSortMapPrimitive(BasePrimitive): ...
class SortChunksByMapPrimitive(BasePrimitive): ...

def make_row_id_map(
    routing_map: jnp.ndarray, num_tokens_per_expert: int,
) -> jnp.ndarray: ...

def permute_with_mask_map(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, num_out_tokens: int,
) -> jnp.ndarray: ...

def permute_with_mask_map_and_pad(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, num_out_tokens: int,
    pad_offsets: jnp.ndarray,
) -> jnp.ndarray: ...

def unpermute_with_mask_map(
    tokens: jnp.ndarray, row_id_map: jnp.ndarray, probs: jnp.ndarray,
    num_original_tokens: int,
) -> jnp.ndarray: ...

def make_chunk_sort_map(split_sizes, sorted_indices) -> jnp.ndarray: ...
def sort_chunks_by_map(tokens, chunk_sort_map, num_out_tokens) -> jnp.ndarray: ...
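The chunk-sorting pair above can be illustrated with a pure-JAX reference sketch. It assumes that the input rows form consecutive chunks of sizes split_sizes, that output chunk j is input chunk sorted_indices[j], and that the chunk sort map stores the destination row of every source row. These semantics and the reference_* names are assumptions for illustration, not documented behavior of make_chunk_sort_map and sort_chunks_by_map.

```python
import jax.numpy as jnp


def reference_chunk_sort_map(split_sizes, sorted_indices):
    """Hypothetical reference: destination row for every source row."""
    split_sizes = jnp.asarray(split_sizes)
    sorted_indices = jnp.asarray(sorted_indices)
    num_chunks = split_sizes.shape[0]
    num_rows = int(split_sizes.sum())  # concrete value, so not usable under jit
    src_end = jnp.cumsum(split_sizes)
    src_start = src_end - split_sizes  # first input row of each chunk
    # Output chunk j is input chunk sorted_indices[j]; its output start is the
    # exclusive prefix sum of the chunk sizes taken in the new order.
    sorted_sizes = split_sizes[sorted_indices]
    dst_start_sorted = jnp.cumsum(sorted_sizes) - sorted_sizes
    # Invert the permutation: output start row for each *input* chunk.
    dst_start = jnp.zeros(num_chunks, split_sizes.dtype).at[sorted_indices].set(dst_start_sorted)
    # Chunk id and offset within the chunk for every source row.
    rows = jnp.arange(num_rows)
    chunk_id = jnp.searchsorted(src_end, rows, side="right")
    return dst_start[chunk_id] + (rows - src_start[chunk_id])


def reference_sort_chunks(tokens, chunk_sort_map, num_out_tokens):
    """Scatter every source row to the destination row named by the map."""
    out = jnp.zeros((num_out_tokens,) + tokens.shape[1:], tokens.dtype)
    return out.at[chunk_sort_map].set(tokens)
```

With split_sizes = [2, 3, 1] and sorted_indices = [2, 0, 1], the last single-row chunk moves to the front, so rows 0 to 4 shift down by one and row 5 lands in row 0.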

Import

from transformer_engine.jax.triton_extensions.permutation import make_row_id_map, permute_with_mask_map, unpermute_with_mask_map

I/O Contract

Inputs

Name            Type         Required  Description
routing_map     jnp.ndarray  Yes       Boolean routing map of shape [num_tokens, num_experts]
tokens          jnp.ndarray  Yes       Token tensor to permute
row_id_map      jnp.ndarray  Yes       Row ID map produced by make_row_id_map
num_out_tokens  int          Yes       Number of output (permuted) tokens
probs           jnp.ndarray  No        Routing probabilities used for merging in unpermute

Outputs

Name             Type         Description
permuted_tokens  jnp.ndarray  Tokens scattered to expert slots
row_id_map       jnp.ndarray  Token-to-output-position mapping

Usage Examples

from transformer_engine.jax.triton_extensions.permutation import (
    make_row_id_map, permute_with_mask_map, unpermute_with_mask_map
)

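# routing_map, probs, capacity, tokens and num_out_tokens are assumed to come
# from the router; expert_output is the result of running the experts on
# `permuted`, and num_original_tokens is the token count before permutation.

# Build the row ID map from the routing decisions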
row_id_map = make_row_id_map(routing_map, num_tokens_per_expert=capacity)

# Scatter tokens to experts
permuted = permute_with_mask_map(tokens, row_id_map, num_out_tokens)

# Gather tokens back from experts
output = unpermute_with_mask_map(expert_output, row_id_map, probs, num_original_tokens)
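The backward pass that UnpermuteBwdWithMergingProbsPrimitive handles can be written out for a probability-weighted merge such as the last step of the example. The sketch below is a hypothetical pure-JAX reference (reference_unpermute_bwd is not a TransformerEngine function) that assumes the same illustrative layout as elsewhere on this page: row_id_map of shape [num_tokens, num_experts] with -1 for unrouted pairs, and dense probs of the same shape. Each expert output row receives its token's output gradient scaled by the routing probability, and each probability receives the dot product of the token's output gradient with the corresponding expert output row.

```python
import jax.numpy as jnp


def reference_unpermute_bwd(grad_out, permuted, row_id_map, probs):
    """Hypothetical gradients of the merge, assumed forward:

    out[t] = sum over routed e of probs[t, e] * permuted[row_id_map[t, e]]
    """
    num_rows, hidden = permuted.shape
    routed = row_id_map >= 0
    # Unrouted pairs are redirected out of bounds so the scatter drops them.
    dst = jnp.where(routed, row_id_map, num_rows).reshape(-1)
    # d out / d permuted: each destination row gets probs[t, e] * grad_out[t].
    weights = jnp.where(routed, probs, 0.0)
    contrib = (weights[:, :, None] * grad_out[:, None, :]).reshape(-1, hidden)
    grad_permuted = jnp.zeros_like(permuted).at[dst].add(contrib, mode="drop")
    # d out / d probs: dot product of grad_out[t] with the expert output row.
    gathered = permuted[jnp.maximum(row_id_map, 0)]  # [T, E, H]
    grad_probs = jnp.where(routed, jnp.einsum("th,teh->te", grad_out, gathered), 0.0)
    return grad_permuted, grad_probs
```

Fusing both gradients into one kernel avoids materializing the [num_tokens, num_experts, hidden] gathered tensor that this reference builds, which is one plausible reason to give the backward pass its own primitive.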
