Overview
A collection of TVM-based Mixture of Experts (MoE) utility operators for token routing, cumulative summation, index computation, and output scattering within the MLC LLM framework.
Description
The moe_misc module provides low-level, performance-critical operators needed by MoE layers in large language models. These operators handle the data movement and routing logic that distributes tokens to the appropriate experts and reassembles the results. The module implements several functions:
- moe_sum: A specialized sum operator optimized for the MoE case where the second dimension equals the number of experts per token (commonly 2). It avoids generic summation overhead.
- gating_topk: A TIR (Tensor IR) kernel that selects the top-k experts from gating scores using an insertion-sort-like algorithm executed within GPU thread blocks.
- gating_softmax_topk: Combines softmax normalization with top-k selection in a single fused kernel, optionally normalizing the resulting expert weights.
- group_limited_greedy_topk: Implements group-limited greedy top-k selection used by advanced MoE architectures like DeepSeek-V2. Supports both "group_limited_greedy" and "noaux_tc" methods. Expert selection is constrained within groups, and only a subset of groups contribute experts.
- moe_cumsum: Converts expert indices into a cumulative sum array used for determining contiguous memory regions for each expert's inputs.
- get_indices: Computes shuffling and reverse-scatter indices from the cumsum array, enabling efficient token redistribution.
- get_indptr: Extracts the indptr array (index pointers) from the cumsum, marking expert boundaries in the flattened array.
- scatter_output: Scatters expert outputs back to their original token positions after expert computation.
All operators are implemented using TVM's Tensor IR (TIR) or Tensor Expression (TE) and are designed to be compiled for GPU execution with thread-level parallelism.
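The group-limited selection described in the list above can be sketched for a single token in plain Python. This is a reference for the semantics only, not the TIR kernel. It assumes that a group's score is the maximum expert score inside that group (the "group_limited_greedy" reading) and that experts are split into equal, contiguous groups. The name `group_limited_topk_reference` is hypothetical.

```python
from typing import List, Tuple


def group_limited_topk_reference(
    scores: List[float], top_k: int, n_group: int, topk_group: int
) -> Tuple[List[float], List[int]]:
    """Reference semantics for one token; `scores` has one entry per routed expert."""
    group_size = len(scores) // n_group
    # Assumption: a group's score is the maximum expert score inside it.
    group_scores = [
        max(scores[g * group_size:(g + 1) * group_size]) for g in range(n_group)
    ]
    # Keep the `topk_group` best groups (ties go to the lower group id).
    kept = sorted(range(n_group), key=lambda g: (-group_scores[g], g))[:topk_group]
    # Only experts inside the kept groups remain candidates.
    candidates = [
        e for g in kept for e in range(g * group_size, (g + 1) * group_size)
    ]
    chosen = sorted(candidates, key=lambda e: (-scores[e], e))[:top_k]
    return [scores[e] for e in chosen], chosen


# 8 experts in 4 groups of 2; only the single best group may contribute.
scores = [0.05, 0.30, 0.10, 0.02, 0.25, 0.20, 0.04, 0.04]
weights, indices = group_limited_topk_reference(
    scores, top_k=2, n_group=4, topk_group=1
)
print(indices)  # [1, 0]: expert 4 scores higher than expert 0 but its group was dropped
```

The example shows the effect of the group constraint: an unconstrained top-2 would pick experts 1 and 4, while the group-limited version stays inside group 0.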
Usage
Use these operators when building MoE model architectures in MLC LLM. They are called internally by MoE layer implementations (such as Mixtral and DeepSeek MoE) to handle token-to-expert routing, the gather/scatter data movement patterns, and the final weighted aggregation of expert outputs. They are not typically called directly by end users but rather composed within model definition modules.
Code Reference
Source Location
Signature
def moe_sum(x: Tensor, dim: int) -> Tensor
def gating_topk(scores: Tensor, k: int) -> Tuple[Tensor, Tensor]
def gating_softmax_topk(x: Tensor, k: int, norm_topk_prob=True) -> Tuple[Tensor, Tensor]
def group_limited_greedy_topk(
    scores: Tensor,
    top_k: int,
    num_routed_experts: int,
    n_group: int,
    topk_group: int,
    topk_method: Literal["group_limited_greedy", "noaux_tc"],
    num_tokens: IntExpr,
    e_score_correction_bias: Optional[Tensor],
) -> Tuple[Tensor, Tensor]
def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor
def get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tuple[Tensor, Tensor]
def get_indptr(
    cumsum: Tensor,
    num_local_experts: int,
    batch_size: Union[int, tir.Var],
    inclusive: bool,
    out_dtype: str,
) -> Tensor
def scatter_output(x: Tensor, indices: Tensor) -> Tensor
Import
from mlc_llm.op.moe_misc import (
    moe_sum,
    gating_topk,
    gating_softmax_topk,
    group_limited_greedy_topk,
    moe_cumsum,
    get_indices,
    get_indptr,
    scatter_output,
)
I/O Contract
moe_sum
| Parameter | Type | Description |
| x | Tensor | Input tensor of shape [batch_size, num_experts_per_tok, hidden_size] |
| dim | int | Axis along which to sum (typically 1) |
| Return | Type | Description |
| result | Tensor | Summed tensor of shape [batch_size, hidden_size] |
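The shape contract above can be illustrated with nested Python lists. This is a minimal sketch of the reduction over axis 1 only; `moe_sum_reference` is a hypothetical name and the code says nothing about how the TVM operator is scheduled.

```python
from typing import List


def moe_sum_reference(x: List[List[List[float]]]) -> List[List[float]]:
    """Sum over axis 1 of a [batch_size, num_experts_per_tok, hidden_size] nested list."""
    # zip(*token) walks the hidden dimension, pairing the values from each expert.
    return [[sum(column) for column in zip(*token)] for token in x]


x = [
    [[1.0, 2.0], [3.0, 4.0]],  # token 0: two expert outputs, hidden_size 2
    [[0.5, 0.5], [1.5, 2.5]],  # token 1
]
result = moe_sum_reference(x)
print(result)  # [[4.0, 6.0], [2.0, 3.0]]
```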
gating_topk
| Parameter | Type | Description |
| scores | Tensor | Gating scores with shape [batch_size, num_local_experts] |
| k | int | Number of top experts to select (num_experts_per_tok) |
| Return | Type | Description |
| expert_weights | Tensor | Top-k expert scores with shape [batch_size, k] |
| expert_indices | Tensor | Top-k expert indices with shape [batch_size, k], dtype int32 |
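An insertion-style top-k over one row of gating scores can be sketched as follows. It keeps a sorted buffer of k entries and inserts each score into its place, which is the general idea behind an insertion-sort-like selection. The function name `topk_insertion_reference` and the tie-breaking rule (earlier expert wins) are assumptions for illustration, not details taken from the kernel.

```python
from typing import List, Tuple


def topk_insertion_reference(row: List[float], k: int) -> Tuple[List[float], List[int]]:
    """Top-k of one score row, kept in descending order by insertion."""
    top_vals = [float("-inf")] * k
    top_idx = [-1] * k
    for expert, score in enumerate(row):
        # Find the first slot whose value this score strictly beats.
        pos = k
        for slot in range(k):
            if score > top_vals[slot]:
                pos = slot
                break
        if pos < k:
            # Insert at `pos`, shifting smaller entries down and dropping the last.
            top_vals.insert(pos, score)
            top_vals.pop()
            top_idx.insert(pos, expert)
            top_idx.pop()
    return top_vals, top_idx


expert_weights, expert_indices = topk_insertion_reference([0.1, 0.7, 0.2, 0.4], k=2)
print(expert_weights, expert_indices)  # [0.7, 0.4] [1, 3]
```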
gating_softmax_topk
| Parameter | Type | Description |
| x | Tensor | Input tensor with shape [batch_size, num_local_experts] |
| k | int | Number of top elements to select |
| norm_topk_prob | bool | Whether to normalize top-k expert scores (default True) |
| Return | Type | Description |
| expert_weights | Tensor | Normalized top-k expert scores [batch_size, k] |
| expert_indices | Tensor | Top-k expert indices [batch_size, k], dtype int32 |
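The combined softmax and top-k step can be written out for a single row in plain Python. This sketch assumes that normalizing means dividing the k kept probabilities by their sum so that they add up to 1. `softmax_topk_reference` is a hypothetical name, and the code is unfused, unlike the kernel.

```python
import math
from typing import List, Tuple


def softmax_topk_reference(
    logits: List[float], k: int, norm_topk_prob: bool = True
) -> Tuple[List[float], List[int]]:
    """Softmax over one row, then keep the k largest probabilities."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:k]
    weights = [probs[i] for i in order]
    if norm_topk_prob:
        # Assumption: renormalize so the kept weights sum to 1.
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return weights, order


weights, indices = softmax_topk_reference([1.0, 3.0, 2.0, 0.0], k=2)
print(indices)  # [1, 2]
```

With `norm_topk_prob=False` the returned weights are the raw softmax probabilities of the selected experts and generally sum to less than 1.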
moe_cumsum
| Parameter | Type | Description |
| expert_indices | Tensor | Topk indices [batch_size, experts_per_tok], int32 |
| num_local_experts | int | Total number of experts |
| Return | Type | Description |
| cumsum | Tensor | Cumsum result [num_local_experts * batch_size], int32 |
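One way to read the output shape [num_local_experts * batch_size] is as a running sum over an expert-major 0/1 mask. The sketch below assumes exactly that: the mask is laid out as [num_local_experts, batch_size], flattened, and summed inclusively. Both the layout and the name `moe_cumsum_reference` are assumptions for illustration.

```python
from itertools import accumulate
from typing import List


def moe_cumsum_reference(
    expert_indices: List[List[int]], num_local_experts: int
) -> List[int]:
    """Inclusive cumsum over an expert-major routing mask, flattened to 1D."""
    batch_size = len(expert_indices)
    # mask[e][b] is 1 when token b is routed to expert e.
    mask = [[0] * batch_size for _ in range(num_local_experts)]
    for b, experts in enumerate(expert_indices):
        for e in experts:
            mask[e][b] = 1
    flat = [bit for row in mask for bit in row]
    return list(accumulate(flat))


expert_indices = [[0, 1], [1, 2], [0, 2]]  # 3 tokens, top-2 routing, 3 experts
cumsum = moe_cumsum_reference(expert_indices, num_local_experts=3)
print(cumsum)  # [1, 1, 2, 3, 4, 4, 4, 5, 6]
```

Under this layout, the value at position `e * batch_size + b` (minus one) is the row that token `b` occupies in the expert-sorted array, which is why each expert's inputs end up contiguous.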
get_indices
| Parameter | Type | Description |
| cumsum | Tensor | Flattened 1D cumsum tensor |
| expert_indices | Tensor | Expert indices [batch_size, experts_per_tok] |
| Return | Type | Description |
| reverse_indices | Tensor | Scatter indices [batch_size * experts_per_tok], int32 |
| token_indices | Tensor | Shuffling indices [batch_size * experts_per_tok], int32 |
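The two index arrays can be derived from the cumsum with a short loop. This sketch assumes an expert-major, inclusive cumsum of length num_local_experts * batch_size, so that `cumsum[e * batch_size + b] - 1` is the destination row of token `b` for expert `e`. The name `get_indices_reference` is hypothetical.

```python
from typing import List, Tuple


def get_indices_reference(
    cumsum: List[int], expert_indices: List[List[int]]
) -> Tuple[List[int], List[int]]:
    """Build scatter (reverse) and shuffle (token) indices from the cumsum."""
    batch_size = len(expert_indices)
    k = len(expert_indices[0])
    reverse_indices = [0] * (batch_size * k)
    token_indices = [0] * (batch_size * k)
    for b, experts in enumerate(expert_indices):
        for j, e in enumerate(experts):
            # Assumed layout: row of (expert e, token b) in expert-sorted order.
            slot = cumsum[e * batch_size + b] - 1
            reverse_indices[b * k + j] = slot  # where this (token, choice) pair went
            token_indices[slot] = b            # which token sits in that row
    return reverse_indices, token_indices


cumsum = [1, 1, 2, 3, 4, 4, 4, 5, 6]       # 3 experts x 3 tokens, expert-major
expert_indices = [[0, 1], [1, 2], [0, 2]]
reverse_indices, token_indices = get_indices_reference(cumsum, expert_indices)
print(reverse_indices)  # [0, 2, 3, 4, 1, 5]
print(token_indices)    # [0, 2, 0, 1, 1, 2]
```

Reading `token_indices` left to right gives the tokens grouped by expert: tokens 0 and 2 for expert 0, tokens 0 and 1 for expert 1, tokens 1 and 2 for expert 2.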
scatter_output
| Parameter | Type | Description |
| x | Tensor | Expert output [batch_size * num_experts_per_tok, hidden_size] |
| indices | Tensor | Scatter indices [batch_size * num_experts_per_tok] |
| Return | Type | Description |
| out | Tensor | Scattered output [batch_size * num_experts_per_tok, hidden_size] |
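Restoring token order can be pictured as a row lookup. The sketch below assumes a gather-style read, `out[i] = x[indices[i]]`, where `x` is in expert-sorted order and `indices` holds the reverse (scatter) indices. That reading and the name `scatter_output_reference` are assumptions made for illustration.

```python
from typing import List


def scatter_output_reference(x: List[List[float]], indices: List[int]) -> List[List[float]]:
    """Reorder expert-sorted rows back into (token, choice) order."""
    # Assumption: row i of the result is read from row indices[i] of x.
    return [list(x[src]) for src in indices]


# Expert-sorted rows: expert 0 -> 10, 11; expert 1 -> 20, 21; expert 2 -> 30, 31.
x = [[10.0], [11.0], [20.0], [21.0], [30.0], [31.0]]
indices = [0, 2, 3, 4, 1, 5]
out = scatter_output_reference(x, indices)
print(out)  # [[10.0], [20.0], [21.0], [30.0], [11.0], [31.0]]
```

The output keeps the same shape as the input; consecutive groups of num_experts_per_tok rows now belong to the same token and are ready for the final weighted sum.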
Usage Examples
# Top-k expert selection from gating scores
expert_weights, expert_indices = gating_topk(gating_scores, k=2)
# Fused softmax + top-k with normalization
expert_weights, expert_indices = gating_softmax_topk(logits, k=2, norm_topk_prob=True)
# Compute cumsum for MoE routing
cumsum = moe_cumsum(expert_indices, num_local_experts=8)
# Get shuffling and reverse indices
reverse_indices, token_indices = get_indices(cumsum, expert_indices)
# Extract expert boundary pointers
indptr = get_indptr(cumsum, num_local_experts=8, batch_size=batch_size, inclusive=False, out_dtype="int32")
# Scatter expert outputs back to original positions
output = scatter_output(expert_output, reverse_indices)
# Weighted sum of expert outputs
result = moe_sum(weighted_expert_outputs, dim=1)
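The role of the indptr array in the sequence above can be shown with a small pure-Python sketch. It assumes an expert-major, inclusive cumsum of length num_local_experts * batch_size, that `inclusive=False` yields num_local_experts + 1 entries starting at 0, and that `inclusive=True` yields num_local_experts entries holding only the end offsets. `get_indptr_reference` is a hypothetical name.

```python
from typing import List


def get_indptr_reference(
    cumsum: List[int], num_local_experts: int, batch_size: int, inclusive: bool
) -> List[int]:
    """Expert boundaries in the expert-sorted array, read from the cumsum."""
    # The last cumsum entry of each expert's segment counts all rows routed
    # to that expert and to every expert before it.
    ends = [cumsum[(e + 1) * batch_size - 1] for e in range(num_local_experts)]
    return ends if inclusive else [0] + ends


cumsum = [1, 1, 2, 3, 4, 4, 4, 5, 6]  # 3 experts x 3 tokens, expert-major
indptr = get_indptr_reference(cumsum, num_local_experts=3, batch_size=3, inclusive=False)
print(indptr)  # [0, 2, 4, 6]

# Rows indptr[e]:indptr[e + 1] of the shuffled input belong to expert e.
rows_per_expert = [indptr[e + 1] - indptr[e] for e in range(3)]
print(rows_per_expert)  # [2, 2, 2]
```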
Related Pages