Implementation:NVIDIA TransformerEngine JAX Permutation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides high-level token dispatch and combine operations for Mixture-of-Experts (MoE) models, handling token routing to experts with proper automatic differentiation support.
Description
token_dispatch scatters tokens to their designated experts according to a boolean routing map; it computes the row ID map internally and can optionally pad the output for alignment. token_combine gathers the expert outputs back into the original token order, optionally merging them with the routing probabilities. Both functions use jax.custom_vjp so that gradients flow correctly: the backward pass of dispatch is an unpermute, and the backward pass of combine is a permute. The underlying operations delegate to Triton-based permutation kernels (permute_with_mask_map, unpermute_with_mask_map), and padding/unpadding for alignment requirements is fused into these operations. sort_chunks_by_index reorders contiguous expert chunks of a tensor according to a given index order.
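The mask-map permutation that the row ID map encodes can be illustrated with a small pure-JAX reference. This is a sketch under stated assumptions, not TransformerEngine's implementation (which runs Triton kernels): the helper names make_row_id_map, permute, and unpermute are hypothetical, dispatched rows are assumed to be laid out expert-major (all tokens of expert 0 in token order, then expert 1, and so on), -1 is assumed to mark unrouted (token, expert) pairs, and merging_probs is assumed to have shape [num_tokens, num_experts]. TE's actual row_id_map format may differ.

```python
import jax.numpy as jnp

# Hypothetical reference helpers; NOT the TransformerEngine API or kernel layout.


def make_row_id_map(routing_map):
    """Destination row of each (token, expert) pair, or -1 if not routed (assumed layout)."""
    mask = routing_map.astype(jnp.int32)               # [T, E]
    counts = mask.sum(axis=0)                           # tokens per expert, [E]
    expert_start = jnp.cumsum(counts) - counts          # first row of each expert's chunk
    rank = jnp.cumsum(mask, axis=0) - mask              # position of a token inside its chunk
    return jnp.where(routing_map, expert_start[None, :] + rank, -1)


def permute(inp, row_id_map, num_out_tokens):
    """Copy each token to the rows of all experts it is routed to."""
    num_tokens, num_experts = row_id_map.shape
    hidden = inp.shape[1]
    src = jnp.broadcast_to(inp[:, None, :], (num_tokens, num_experts, hidden))
    dst = row_id_map.reshape(-1)
    # Send unrouted pairs (-1) to an out-of-range row so the scatter drops them.
    dst = jnp.where(dst < 0, num_out_tokens, dst)
    out = jnp.zeros((num_out_tokens, hidden), inp.dtype)
    return out.at[dst].set(src.reshape(-1, hidden), mode="drop")


def unpermute(expert_out, row_id_map, merging_probs=None):
    """Sum each token's expert outputs, optionally weighted by merging_probs [T, E]."""
    gathered = expert_out[jnp.maximum(row_id_map, 0)]   # [T, E, H]
    weights = (row_id_map >= 0).astype(expert_out.dtype) # zero for unrouted pairs
    if merging_probs is not None:
        weights = weights * merging_probs
    return jnp.einsum("te,teh->th", weights, gathered)
```

Without merging_probs, unpermute sums a token's copies, which makes it the adjoint (transpose) of permute; that relationship is why each operation can serve as the other's backward pass.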
The module is essential infrastructure for MoE transformer architectures: it provides the differentiable token-routing layer that connects the router's decisions to the actual expert computation, with fixed output shapes so that it remains compatible with JIT compilation.
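The custom_vjp pairing (dispatch's backward is an unpermute, combine's backward is a permute) can be sketched with a simplified permutation. Assumptions: the permutation is described by a hypothetical src_token array giving the source token of each dispatched row, instead of TE's row_id_map; probability weighting and padding are omitted; dispatch, combine, and src_token are illustrative names, not the TE API. num_tokens is a nondiff (static) argument, so every output shape is fixed at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Illustrative sketch only; src_token[i] is the (hypothetical) source token of row i.


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def dispatch(x, src_token, num_tokens):
    """Forward permute: dispatched row i is a copy of token src_token[i]."""
    del num_tokens  # only the backward pass needs it
    return x[src_token]


def _dispatch_fwd(x, src_token, num_tokens):
    return dispatch(x, src_token, num_tokens), src_token


def _dispatch_bwd(num_tokens, src_token, g):
    # Backward of dispatch is an unpermute: sum each token's copies back into place.
    grad_x = jnp.zeros((num_tokens, g.shape[1]), g.dtype).at[src_token].add(g)
    return grad_x, None  # None: no cotangent for the integer index array


dispatch.defvjp(_dispatch_fwd, _dispatch_bwd)


@partial(jax.custom_vjp, nondiff_argnums=(2,))
def combine(y, src_token, num_tokens):
    """Forward unpermute: sum dispatched rows back into their source tokens."""
    return jnp.zeros((num_tokens, y.shape[1]), y.dtype).at[src_token].add(y)


def _combine_fwd(y, src_token, num_tokens):
    return combine(y, src_token, num_tokens), src_token


def _combine_bwd(num_tokens, src_token, g):
    # Backward of combine is a permute: each row receives its source token's gradient.
    del num_tokens
    return g[src_token], None


combine.defvjp(_combine_fwd, _combine_bwd)
```

Because these toy operations are plain gathers and scatter-adds, JAX could derive the same gradients automatically; the explicit rules only make visible that each backward pass is the other operation, which is what lets a kernel-backed implementation reuse its permute and unpermute kernels. Under jax.jit, num_tokens must be passed as a static argument, e.g. jax.jit(combine, static_argnums=2).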
Usage
Use this module when implementing Mixture-of-Experts routing in JAX transformer models. token_dispatch and token_combine provide the standard dispatch/combine pattern used before and after expert computation.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/permutation.py
- Lines: 1-648
Signature
```python
def token_dispatch(
    inp: jnp.ndarray,
    routing_map: jnp.ndarray,
    num_out_tokens: int,
    probs: Optional[jnp.ndarray] = None,
    align_size: Optional[int] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]: ...

def token_combine(
    inp: jnp.ndarray,
    row_id_map: jnp.ndarray,
    merging_probs: Optional[jnp.ndarray] = None,
    pad_offsets: Optional[jnp.ndarray] = None,
) -> jnp.ndarray: ...

def sort_chunks_by_index(
    inp: jnp.ndarray,
    split_sizes: jnp.ndarray,
    sorted_indices: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]: ...
```
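The chunk reordering performed by sort_chunks_by_index can be sketched in pure JAX. Assumptions: the output is the concatenation of the input's row chunks (with sizes given by split_sizes) taken in the order listed in sorted_indices, and split_sizes sums to the number of rows. The function name sort_chunks_reference and its second return value, src_row, are hypothetical; what TE's second output actually contains should be checked against the source.

```python
import jax.numpy as jnp


def sort_chunks_reference(inp, split_sizes, sorted_indices):
    """Hypothetical reference: output = concat(chunks[i] for i in sorted_indices).

    Returns (output, src_row), where src_row[j] is the input row placed at output row j.
    """
    num_rows = inp.shape[0]
    chunk_start = jnp.cumsum(split_sizes) - split_sizes   # start row of each input chunk
    sorted_sizes = split_sizes[sorted_indices]             # chunk sizes in output order
    sorted_ends = jnp.cumsum(sorted_sizes)
    out_start = sorted_ends - sorted_sizes                 # start row of each output chunk
    j = jnp.arange(num_rows)
    # Output chunk containing row j; side="right" skips zero-size chunks.
    k = jnp.searchsorted(sorted_ends, j, side="right")
    src_row = chunk_start[sorted_indices[k]] + (j - out_start[k])
    return inp[src_row], src_row
```

Computing a per-row source index instead of slicing in a Python loop keeps every shape static, so the sketch traces under jax.jit even when split_sizes is a traced array. Its backward pass would be the inverse scatter, jnp.zeros_like(inp).at[src_row].set(g).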
Import
```python
from transformer_engine.jax.permutation import token_dispatch, token_combine, sort_chunks_by_index
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | jnp.ndarray | Yes | Input token tensor of shape [num_tokens, hidden_size] |
| routing_map | jnp.ndarray | Yes | Boolean routing map of shape [num_tokens, num_experts] |
| num_out_tokens | int | Yes | Number of output tokens after dispatch |
| probs | Optional[jnp.ndarray] | No | Routing probabilities for weighted combine |
| align_size | Optional[int] | No | Padding alignment size |
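The effect of align_size can be illustrated with the underlying arithmetic. Assumption: padding is applied per expert chunk, rounding each expert's token count up to the next multiple of align_size; the helper padded_expert_layout is hypothetical and does not reproduce TE's pad_offsets format.

```python
import jax.numpy as jnp


def padded_expert_layout(routing_map, align_size):
    """Hypothetical helper: per-expert counts, counts rounded up to align_size, chunk starts."""
    counts = routing_map.astype(jnp.int32).sum(axis=0)             # tokens routed to each expert
    padded = (counts + align_size - 1) // align_size * align_size  # round up to a multiple
    starts = jnp.cumsum(padded) - padded                            # first row of each padded chunk
    return counts, padded, starts
```

Under this assumption each expert contributes at most align_size - 1 padding rows, so a static output size that must hold any routing needs room for up to num_experts * (align_size - 1) rows beyond the routed tokens.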
Outputs
| Name | Type | Description |
|---|---|---|
| dispatched_tokens | jnp.ndarray | Tokens dispatched to experts |
| merging_probs | Optional[jnp.ndarray] | Probabilities for token combining |
| row_id_map | jnp.ndarray | Mapping of token positions for combining |
| pad_offsets | Optional[jnp.ndarray] | Padding offsets for unpadding |
| expert_offsets | jnp.ndarray | Per-expert token count offsets |
Usage Examples
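```python
from transformer_engine.jax.permutation import token_dispatch, token_combine

# Placeholders supplied by the surrounding model code:
#   tokens                 [num_tokens, hidden_size] activations
#   routing_decisions      boolean [num_tokens, num_experts] routing map from the router
#   routing_probs          routing probabilities from the router
#   total_expert_capacity  static int, number of dispatched rows
#   run_experts            user-defined expert computation

# Dispatch tokens to experts
dispatched, probs_out, row_id_map, pad_offsets, expert_offsets = token_dispatch(
    inp=tokens,
    routing_map=routing_decisions,
    num_out_tokens=total_expert_capacity,
    probs=routing_probs,
)
# Run expert computation
expert_output = run_experts(dispatched, expert_offsets)
# Combine tokens back into the original token order
output = token_combine(expert_output, row_id_map, probs_out, pad_offsets)
```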
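The run_experts placeholder stands for the per-expert computation on the expert-major dispatched rows. A minimal pure-JAX stand-in is sketched below. Assumptions: the per-expert information is available as token counts (how expert_offsets actually encodes it should be checked against the source), each expert is a single dense projection, and rows beyond the last expert's chunk are unused capacity. The function name run_experts_reference and the weights layout [num_experts, hidden, out] are hypothetical.

```python
import jax.numpy as jnp


def run_experts_reference(dispatched, tokens_per_expert, weights):
    """Hypothetical stand-in: apply expert e's weights to its contiguous chunk of rows.

    dispatched:        [N, H] rows in expert-major order
    tokens_per_expert: [E] rows owned by each expert (assumed to be counts)
    weights:           [E, H, F] one dense projection per expert
    """
    num_experts = weights.shape[0]
    ends = jnp.cumsum(tokens_per_expert)
    expert_id = jnp.searchsorted(ends, jnp.arange(dispatched.shape[0]), side="right")
    # Rows past the last chunk get id E; clamp them for the gather, then zero their output.
    valid = expert_id < num_experts
    expert_id = jnp.minimum(expert_id, num_experts - 1)
    out = jnp.einsum("nh,nhf->nf", dispatched, weights[expert_id])
    return jnp.where(valid[:, None], out, 0)
```

Gathering a full weight matrix per row keeps the sketch short but is memory-heavy; a production implementation would use a grouped GEMM over the contiguous expert chunks instead.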