Implementation:NVIDIA TransformerEngine Triton Permutation

From Leeroopedia


Sources: TransformerEngine
Domains: Deep_Learning, Optimization
Last Updated: 2026-02-07 14:00 GMT

Overview

Implements Triton JIT kernels for efficient token permutation and unpermutation operations used in Mixture-of-Experts (MoE) routing, including row ID mapping, argsort, and forward/backward permutation passes.

Description

transformer_engine/common/triton/permutation.py provides the following Triton JIT kernels for MoE token routing:

  • Bitonic sort: _compare_and_swap, _bitonic_merge, and _argsort implement a bitonic sort for argsort operations within Triton.
  • Row ID mapping: _row_id_map_pass_1_kernel and _row_id_map_pass_2_kernel build a row ID mapping from routing maps using a two-pass prefix sum approach (per-block cumsum then cross-block aggregation).
  • Permutation: _permute_kernel reorders token rows based on routing maps, supporting probs-based weighted permutation and top-k expert routing.
  • Unpermutation: _unpermute_kernel reverses the permutation to restore original token order.
  • Backward passes: Additional kernels handle gradient propagation for both permutation and unpermutation.
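The two-pass prefix sum named in the row ID mapping item can be illustrated on the host in plain Python. This is a minimal sketch under assumptions, not the Triton kernels themselves: the function name `row_id_map_two_pass`, the expert-major output layout, and the use of -1 for unrouted tokens are hypothetical choices made for illustration.

```python
def row_id_map_two_pass(routing_map, block_size=4):
    """Host-side sketch of a two-pass prefix sum (assumed layout, not TE's).

    routing_map[t][e] is truthy when token t is routed to expert e.
    Returns row_id_map[e][t]: the destination row of token t in an
    expert-major permuted buffer, or -1 when the token is not routed to e.
    """
    num_tokens = len(routing_map)
    num_experts = len(routing_map[0]) if num_tokens else 0
    num_blocks = (num_tokens + block_size - 1) // block_size

    # Pass 1: inside each (expert, block) tile, compute a local 1-based rank
    # for every routed token and store the tile's total in a workspace.
    local_rank = [[0] * num_tokens for _ in range(num_experts)]
    workspace = [[0] * num_blocks for _ in range(num_experts)]
    for e in range(num_experts):
        for b in range(num_blocks):
            running = 0
            start, stop = b * block_size, min((b + 1) * block_size, num_tokens)
            for t in range(start, stop):
                if routing_map[t][e]:
                    running += 1
                    local_rank[e][t] = running
            workspace[e][b] = running

    # Pass 2: walk the tile totals in expert-major order to get each tile's
    # global offset, then turn local ranks into global destination rows.
    row_id_map = [[-1] * num_tokens for _ in range(num_experts)]
    offset = 0
    for e in range(num_experts):
        for b in range(num_blocks):
            start, stop = b * block_size, min((b + 1) * block_size, num_tokens)
            for t in range(start, stop):
                if local_rank[e][t]:
                    row_id_map[e][t] = offset + local_rank[e][t] - 1
            offset += workspace[e][b]
    return row_id_map
```

The result does not depend on `block_size`; the split only determines how much work each tile does independently before the cross-block aggregation.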

Usage

Use these kernels in MoE layers where tokens must be efficiently routed to and gathered from different expert subnetworks based on a learned routing policy.
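Routing tokens to experts and gathering them back can be sketched as a pair of plain-Python reference functions. The names `permute_reference` and `unpermute_reference`, the `row_id_map[expert][token]` layout with -1 for unrouted tokens, and the probability-weighted sum on the way back are assumptions for illustration; the Triton kernels work on device pointers and strides instead.

```python
def permute_reference(tokens, row_id_map, num_out_rows):
    """Copy each routed token into its destination row (assumed layout)."""
    permuted = [None] * num_out_rows
    for expert_rows in row_id_map:
        for token, dst in enumerate(expert_rows):
            if dst >= 0:  # -1 marks "token not routed to this expert"
                permuted[dst] = list(tokens[token])
    return permuted


def unpermute_reference(permuted, row_id_map, probs):
    """Gather expert rows back per token, weighting each copy by its prob."""
    num_tokens = len(row_id_map[0])
    hidden = len(permuted[0])
    restored = [[0.0] * hidden for _ in range(num_tokens)]
    for expert, expert_rows in enumerate(row_id_map):
        for token, src in enumerate(expert_rows):
            if src >= 0:
                weight = probs[token][expert]
                for h in range(hidden):
                    restored[token][h] += weight * permuted[src][h]
    return restored


# Hypothetical example: 5 tokens, 2 experts, token 2 is sent to both experts.
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
row_id_map = [[0, -1, 1, -1, 2], [-1, 3, 4, 5, -1]]
probs = [[1.0, 0.0], [0.0, 1.0], [0.25, 0.75], [0.0, 1.0], [1.0, 0.0]]

permuted = permute_reference(tokens, row_id_map, num_out_rows=6)
restored = unpermute_reference(permuted, row_id_map, probs)
```

Because each token's routing probabilities sum to 1 in this example and no expert computation is applied in between, `restored` equals `tokens`; in a real MoE layer the expert outputs would be combined instead.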

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/common/triton/permutation.py
Lines: 1-642

Signature

@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    ...

@triton.jit
def _row_id_map_pass_1_kernel(routing_map_ptr, num_tokens,
                              stride_routing_map_token, ...,
                              row_id_map_ptr, workspace_ptr,
                              BLOCK_SIZE: tl.constexpr):
    ...

@triton.jit
def _permute_kernel(input_ptr, output_ptr, sorted_row_id_ptr, ...):
    ...

@triton.jit
def _unpermute_kernel(input_ptr, output_ptr, row_id_map_ptr, ...):
    ...
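
The `_argsort` signature above carries values and their indices through a bitonic network. The compare-and-swap pattern of such a network can be sketched on the host in plain Python. The function name `bitonic_argsort` and the `descending` flag are hypothetical, and this loop-based version only illustrates the technique; it is not how the Triton kernels are written.

```python
def bitonic_argsort(values, descending=False):
    """Sort values with a bitonic network, carrying original indices along.

    Returns (sorted_values, indices). The length must be a power of two,
    which is the usual constraint for a bitonic network.
    """
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("bitonic networks need a power-of-two length")
    keys = list(values)
    indices = list(range(n))

    k = 2  # size of the bitonic sequences being merged in this stage
    while k <= n:
        j = k // 2  # distance between compared elements
        while j >= 1:
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    # Direction alternates per k-sized chunk; the final
                    # stage (k == n) uses a single direction throughout.
                    ascending = ((i & k) == 0) != descending
                    if ascending:
                        out_of_order = keys[i] > keys[partner]
                    else:
                        out_of_order = keys[i] < keys[partner]
                    if out_of_order:
                        # Compare-and-swap: move key and index together.
                        keys[i], keys[partner] = keys[partner], keys[i]
                        indices[i], indices[partner] = (
                            indices[partner], indices[i])
            j //= 2
        k *= 2
    return keys, indices
```

A bitonic network is not a stable sort, so equal values may come back with their indices in either order.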

Import

from transformer_engine.common.triton.permutation import (
    _permute_kernel,
    _unpermute_kernel,
    _row_id_map_pass_1_kernel,
    _row_id_map_pass_2_kernel,
)

I/O Contract

Inputs

Name             Type          Required  Description
routing_map_ptr  pointer       Yes       Routing map from MoE gating
input_ptr        pointer       Yes       Input token tensor
num_tokens       int           Yes       Number of tokens in the batch
BLOCK_SIZE       tl.constexpr  Yes       Triton block size

Outputs

Name            Type     Description
output_ptr      pointer  Permuted (or unpermuted) token tensor
row_id_map_ptr  pointer  Row ID mapping for reverse permutation

Usage Examples

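import triton
from transformer_engine.common.triton.permutation import _permute_kernel

# Every name passed below (input_ptr, num_experts, topK, ...) is a placeholder
# that the caller must define; tensors are passed where pointers are expected.
# Confirm the argument order and grid against the kernel definition in
# permutation.py before launching.
grid = (num_experts,)  # one program instance per expert
_permute_kernel[grid](
    input_ptr, output_ptr, sorted_row_id_ptr,
    row_id_map_ptr, prob_ptr, prob_grad_ptr,
    input_fwd_ptr, num_rows, topK, num_cols,
    num_out_tokens, BLOCK_SIZE=128,
)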
