Implementation:NVIDIA TransformerEngine JAX Permutation

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, JAX
Last Updated   2026-02-07 14:00 GMT

Overview

Provides high-level token dispatch and combine operations for Mixture-of-Experts (MoE) models, handling token routing to experts with proper automatic differentiation support.

Description

token_dispatch scatters each token to the experts selected for it in a boolean routing map. It computes the row ID map internally and can optionally pad the output for alignment. token_combine gathers the expert outputs back into the original token order and, when merging probabilities are given, weights each expert's contribution by its routing probability. Both functions are wrapped in jax.custom_vjp so that gradients flow correctly: the backward pass of dispatch is an unpermute, and the backward pass of combine is a permute. The data movement itself is delegated to Triton-based permutation kernels (permute_with_mask_map, unpermute_with_mask_map), and padding/unpadding for alignment can be fused into these steps.

sort_chunks_by_index reorders chunks of tokens, such as per-expert chunks: it splits its input into contiguous chunks whose sizes are given by split_sizes and emits them in the order given by sorted_indices.
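To make the routing-map semantics concrete, here is a minimal pure-JAX reference sketch of a mask-map permute and unpermute. It is not TE's Triton kernel. It assumes an expert-major layout (all rows for expert 0, then expert 1, and so on) and merging probabilities shaped [num_tokens, num_experts]; the helper names make_row_id_map, reference_dispatch and reference_combine are hypothetical, and TE's internal row_id_map encoding may differ.

```python
import jax.numpy as jnp


def make_row_id_map(routing_map):
    """Destination row for every (token, expert) pair, or -1 if not routed.

    Assumed layout is expert-major: all copies for expert 0, then expert 1, ...
    Within an expert, tokens keep their original order.
    """
    mask = routing_map.astype(jnp.int32)              # [num_tokens, num_experts]
    counts = mask.sum(axis=0)                         # tokens per expert
    expert_start = jnp.cumsum(counts) - counts        # exclusive prefix sum
    rank_in_expert = jnp.cumsum(mask, axis=0) - mask  # position inside the expert
    row_id = expert_start[None, :] + rank_in_expert
    return jnp.where(routing_map, row_id, -1)


def reference_dispatch(inp, routing_map, num_out_tokens):
    """Copy each token to one row per selected expert (a 'permute')."""
    row_id_map = make_row_id_map(routing_map)
    num_tokens, hidden = inp.shape
    # Unrouted pairs point one past the end and are discarded by mode="drop"
    dest = jnp.where(row_id_map >= 0, row_id_map, num_out_tokens).reshape(-1)
    src = jnp.broadcast_to(inp[:, None, :], (num_tokens, routing_map.shape[1], hidden))
    out = jnp.zeros((num_out_tokens, hidden), inp.dtype)  # fixed, static output shape
    out = out.at[dest].set(src.reshape(-1, hidden), mode="drop")
    return out, row_id_map


def reference_combine(expert_out, row_id_map, merging_probs=None):
    """Sum each token's expert outputs back into its original row (an 'unpermute')."""
    valid = row_id_map >= 0
    gathered = expert_out[jnp.where(valid, row_id_map, 0)]  # [num_tokens, num_experts, hidden]
    weights = valid.astype(expert_out.dtype)
    if merging_probs is not None:
        weights = weights * merging_probs                    # assumed [num_tokens, num_experts]
    return jnp.einsum("te,teh->th", weights, gathered)
```

For routing_map [[T, F, T], [F, T, T], [T, F, F]], make_row_id_map returns [[0, -1, 3], [-1, 2, 4], [1, -1, -1]]: expert 0 receives tokens 0 and 2, expert 1 receives token 1, and expert 2 receives tokens 0 and 1. With probabilities that sum to 1 over each token's selected experts, reference_combine undoes reference_dispatch exactly.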

This module is core infrastructure for MoE transformer architectures: it is the differentiable token-routing layer that connects the router's decisions to the expert computation, and its output shapes are fixed so that it stays compatible with jax.jit.

Usage

Use this module when implementing Mixture-of-Experts routing in JAX transformer models. token_dispatch and token_combine provide the standard dispatch/combine pattern used before and after expert computation.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/permutation.py
Lines: 1-648

Signature

from typing import Optional, Tuple

import jax.numpy as jnp

def token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: Optional[jnp.ndarray] = None,
    align_size: Optional[int] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: ...

def token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray] = None,
    pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray: ...

def sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]: ...
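The chunk reordering performed by sort_chunks_by_index can be illustrated with a fixed-shape gather. This pure-JAX sketch is hypothetical and assumes that split_sizes sums to the number of rows and that sorted_indices is a permutation of the chunk indices. Its second return value is the source-row index used for the gather, chosen for illustration; it is not necessarily what TE's second output contains.

```python
import jax.numpy as jnp


def reference_sort_chunks_by_index(inp, split_sizes, sorted_indices):
    """Reorder contiguous row chunks of `inp`.

    Chunk i covers split_sizes[i] consecutive rows; the output holds chunk
    sorted_indices[0] first, then sorted_indices[1], and so on.
    """
    num_rows = inp.shape[0]
    src_start = jnp.cumsum(split_sizes) - split_sizes      # chunk starts in inp
    out_sizes = split_sizes[sorted_indices]                # chunk sizes in output order
    out_end = jnp.cumsum(out_sizes)
    out_start = out_end - out_sizes                        # chunk starts in the output
    rows = jnp.arange(num_rows)
    # Which output chunk each output row belongs to
    out_chunk = jnp.searchsorted(out_end, rows, side="right")
    chunk = sorted_indices[out_chunk]
    src_row = src_start[chunk] + (rows - out_start[out_chunk])
    return inp[src_row], src_row
```

For split_sizes [2, 1, 3] and sorted_indices [2, 0, 1], the output rows come from input rows [3, 4, 5, 0, 1, 2]. Calling the function again with split_sizes[sorted_indices] and jnp.argsort(sorted_indices) restores the original order, which is the kind of inverse a backward pass needs.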

Import

from transformer_engine.jax.permutation import token_dispatch, token_combine, sort_chunks_by_index

I/O Contract

Inputs (token_dispatch)

Name            Type                    Required   Description
inp             jnp.ndarray             Yes        Input token tensor of shape [num_tokens, hidden_size]
routing_map     jnp.ndarray             Yes        Boolean routing map of shape [num_tokens, num_experts]
num_out_tokens  int                     Yes        Number of output tokens after dispatch (a Python int, so the output shape is fixed at trace time)
probs           Optional[jnp.ndarray]   No         Routing probabilities for weighted combine
align_size      Optional[int]           No         Padding alignment size
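These inputs are typically produced by a top-k router. The sketch below (the helper topk_routing is hypothetical) builds a boolean routing_map, probabilities renormalized over the selected experts, and num_out_tokens as a plain Python int. It assumes probs uses the same [num_tokens, num_experts] layout as routing_map, which the table does not state.

```python
import jax
import jax.numpy as jnp


def topk_routing(logits, topk):
    """Build routing inputs from router logits of shape [num_tokens, num_experts].

    Returns a boolean routing_map, probabilities that are zero for unselected
    experts and sum to 1 per token, and num_out_tokens as a static Python int.
    """
    num_tokens, num_experts = logits.shape
    _, topk_idx = jax.lax.top_k(logits, topk)                     # [num_tokens, topk]
    routing_map = jax.nn.one_hot(topk_idx, num_experts).sum(axis=1) > 0
    # Softmax restricted to the selected experts
    probs = jax.nn.softmax(jnp.where(routing_map, logits, -jnp.inf), axis=-1)
    # Without token dropping, every token is copied exactly topk times
    num_out_tokens = num_tokens * topk
    return routing_map, probs, num_out_tokens
```

Keeping num_out_tokens a Python value rather than a traced array is what lets the dispatched buffer have a fixed shape under jax.jit.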

Outputs (token_dispatch, in return order)

Name               Type                    Description
dispatched_tokens  jnp.ndarray             Tokens dispatched to experts
merging_probs      Optional[jnp.ndarray]   Probabilities for token combining
row_id_map         jnp.ndarray             Mapping of token positions for combining
pad_offsets        Optional[jnp.ndarray]   Padding offsets for unpadding
expert_offsets     jnp.ndarray             Per-expert token count offsets
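The table leaves the exact layout of expert_offsets and of the padded buffer unspecified. Assuming each expert's segment is rounded up to a multiple of align_size, segment sizes follow from a ceiling division and start offsets from an exclusive prefix sum. Because the padded total depends on the routing, a fixed-shape buffer needs a static upper bound. Both helpers below are hypothetical.

```python
import jax.numpy as jnp


def padded_expert_layout(routing_map, align_size=None):
    """Per-expert segment sizes (optionally rounded up to align_size) and start offsets."""
    counts = routing_map.astype(jnp.int32).sum(axis=0)          # tokens per expert
    if align_size is not None:
        # Ceiling division, then scale back up: round up to a multiple of align_size
        counts = (counts + align_size - 1) // align_size * align_size
    offsets = jnp.cumsum(counts) - counts                        # exclusive prefix sum
    return counts, offsets


def max_padded_tokens(num_out_tokens, num_experts, align_size):
    """Static upper bound on the padded size: each expert adds at most align_size - 1 rows."""
    return num_out_tokens + num_experts * (align_size - 1)
```

With per-expert counts [2, 1, 2] and align_size 4, the segments become [4, 4, 4] starting at [0, 4, 8], which fits inside the bound of 5 + 3 * 3 = 14 rows.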

Usage Examples

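import jax
import jax.numpy as jnp
from transformer_engine.jax.permutation import token_dispatch, token_combine

num_tokens, hidden_size, num_experts, topk = 16, 64, 4, 2
key_x, key_r = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.normal(key_x, (num_tokens, hidden_size))

# Illustrative top-k router: boolean routing map [num_tokens, num_experts] and probabilities
router_logits = jax.random.normal(key_r, (num_tokens, num_experts))
_, topk_idx = jax.lax.top_k(router_logits, topk)
routing_map = jax.nn.one_hot(topk_idx, num_experts).sum(axis=1) > 0
routing_probs = jax.nn.softmax(jnp.where(routing_map, router_logits, -jnp.inf), axis=-1)

def run_experts(x, expert_offsets):
    # Hypothetical placeholder: a real model applies each expert to its segment of x
    return x

# Dispatch tokens to experts; without token dropping, each token is copied topk times
dispatched, probs_out, row_id_map, pad_offsets, expert_offsets = token_dispatch(
    inp=tokens,
    routing_map=routing_map,
    num_out_tokens=num_tokens * topk,
    probs=routing_probs,
)
# Run expert computation
expert_output = run_experts(dispatched, expert_offsets)
# Combine tokens back into the original token order
output = token_combine(expert_output, row_id_map, probs_out, pad_offsets)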
