Implementation:Sgl project Sglang MoE Ops
| Knowledge Sources | |
|---|---|
| Domains | GPU Kernels, Mixture of Experts, LLM Inference |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python interface for Mixture-of-Experts (MoE) routing, gating, grouped computation, and reduction kernels used in models like Mixtral, DeepSeek, and Kimi K2.
Description
moe.py provides the complete infrastructure for MoE computation in LLM inference, covering expert selection, token dispatch, grouped matrix multiplication, and output reduction. All functions delegate to torch.ops.sgl_kernel.* C++ CUDA ops.
Token-to-Expert Alignment:
- moe_align_block_size -- Sorts and aligns token IDs to fixed block sizes for efficient expert computation. Takes topk_ids and produces sorted_token_ids, experts_ids, and num_tokens_post_pad with a cumsum_buffer for prefix sums. Supports optional padding of sorted token IDs.
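To make the alignment step concrete, here is a minimal plain-Python sketch of what block-aligned sorting could look like on a tiny input. It is not the CUDA kernel: `moe_align_block_size_ref` is a hypothetical helper, and it assumes the vLLM-style convention in which `sorted_token_ids` holds indices into the flattened `topk_ids` and padding slots are filled with `topk_ids.numel()`.

```python
def moe_align_block_size_ref(topk_ids, num_experts, block_size):
    """Plain-Python reference for block-aligned token sorting (illustrative only).

    topk_ids: list of rows, each row holding the expert IDs chosen for one token.
    Returns (sorted_token_ids, expert_ids, num_tokens_post_pad).
    """
    flat = [expert for row in topk_ids for expert in row]
    pad_value = len(flat)  # assumed sentinel that marks a padding slot

    # Bucket flattened slot indices by expert, keeping ascending order per bucket.
    buckets = [[] for _ in range(num_experts)]
    for slot, expert in enumerate(flat):
        buckets[expert].append(slot)

    sorted_token_ids, expert_ids = [], []
    for expert, slots in enumerate(buckets):
        num_blocks = -(-len(slots) // block_size)  # ceiling division
        padded_len = num_blocks * block_size
        sorted_token_ids.extend(slots + [pad_value] * (padded_len - len(slots)))
        expert_ids.extend([expert] * num_blocks)  # one expert ID per block

    return sorted_token_ids, expert_ids, len(sorted_token_ids)


# Three tokens, top-2 routing over 3 experts, blocks of 4 slots.
ids, blocks, total = moe_align_block_size_ref([[0, 2], [2, 1], [0, 2]], 3, 4)
# ids    -> [0, 4, 6, 6, 3, 6, 6, 6, 1, 2, 5, 6]
# blocks -> [0, 1, 2]
# total  -> 12
```

Under this convention, `slot // topk` recovers the token index for any non-padding entry, and every block of `block_size` consecutive entries belongs to exactly one expert.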
Gating Functions:
- topk_softmax -- Computes top-k softmax for MoE routing, producing topk_weights and topk_ids from gating_output logits. Supports renormalization, tanh softcapping (moe_softcapping), and per-expert correction bias (float32).
- topk_sigmoid -- Top-k sigmoid variant for routing, with renormalization and correction bias support.
- moe_fused_gate -- Hierarchical 2-layer expert selection: splits experts into num_expert_group groups, selects top groups by summed weights, then selects topk experts within those groups. Supports fused shared experts (num_fused_shared_experts) and routed scaling factor. Limited to power-of-2 expert counts with at most 32 experts per group.
- kimi_k2_moe_fused_gate -- Simplified fused gate for the Kimi K2 model (single expert group), removing the grouped topk logic and adding renormalization support.
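The difference between flat and hierarchical gating can be sketched for a single token in plain Python. Both helpers below (`topk_softmax_ref`, `grouped_topk_ref`) are hypothetical reference functions, not part of sgl_kernel. The grouped variant assumes DeepSeek-V3-style routing: sigmoid scores, a bias that only influences selection, and a group score equal to the sum of the two largest biased scores in the group. The real kernel may differ in these details.

```python
import math


def topk_softmax_ref(logits, topk, renormalize=False):
    """Softmax over all experts, then keep the `topk` largest probabilities."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ids = sorted(range(len(probs)), key=lambda i: -probs[i])[:topk]
    weights = [probs[i] for i in ids]
    if renormalize:  # make the kept weights sum to 1
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return weights, ids


def grouped_topk_ref(logits, bias, num_expert_group, topk_group, topk,
                     routed_scaling_factor=1.0):
    """Two-level selection: pick the best groups, then the best experts inside them."""
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]  # sigmoid
    biased = [s + b for s, b in zip(scores, bias)]  # bias affects selection only
    group_size = len(logits) // num_expert_group

    def group_score(g):
        # Assumption: a group is ranked by the sum of its two largest biased scores.
        members = sorted(biased[g * group_size:(g + 1) * group_size], reverse=True)
        return sum(members[:2])

    top_groups = sorted(range(num_expert_group), key=lambda g: -group_score(g))
    top_groups = top_groups[:topk_group]
    candidates = [i for g in top_groups
                  for i in range(g * group_size, (g + 1) * group_size)]
    ids = sorted(candidates, key=lambda i: -biased[i])[:topk]

    # Weights come from the unbiased scores, renormalized and then scaled.
    norm = sum(scores[i] for i in ids)
    weights = [scores[i] / norm * routed_scaling_factor for i in ids]
    return weights, ids


logits = [2.0, 1.0, -1.0, -2.0, 0.5, 3.0, -3.0, 0.0]  # 8 experts, 4 groups of 2
flat_w, flat_ids = topk_softmax_ref(logits, topk=2, renormalize=True)
grp_w, grp_ids = grouped_topk_ref(logits, [0.0] * 8, num_expert_group=4,
                                  topk_group=1, topk=2)
# flat_ids -> [5, 0]   (the two largest logits overall)
# grp_ids  -> [0, 1]   (expert 5 is dropped because its group ranks second)
```

The example is chosen so the two schemes disagree: expert 5 has the highest individual score, but its group loses to group 0 once only one group may be kept.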
Output Reduction:
- moe_sum_reduce -- Reduces MoE expert outputs with optional routed scaling factor.
- moe_sum -- Simple summation of MoE expert outputs.
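As a rough illustration of the reduction, the sketch below sums the per-expert outputs of each token over the top-k dimension and optionally applies a scaling factor. `moe_sum_ref` is a hypothetical helper, and it assumes the expert outputs have already been multiplied by their routing weights, so the reduction itself is a plain sum.

```python
def moe_sum_ref(expert_out, routed_scaling_factor=1.0):
    """Sum expert outputs over the top-k dimension (illustrative only).

    expert_out[t][k] is the hidden vector produced for token t by its k-th expert.
    Returns one hidden vector per token.
    """
    return [
        [routed_scaling_factor * sum(column) for column in zip(*per_token)]
        for per_token in expert_out
    ]


# Two tokens, top-2 routing, hidden size 2.
expert_out = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.5, 0.5], [1.5, 2.5]],
]
summed = moe_sum_ref(expert_out)                              # like moe_sum
scaled = moe_sum_ref(expert_out, routed_scaling_factor=2.0)   # like moe_sum_reduce
# summed -> [[4.0, 6.0], [2.0, 3.0]]
# scaled -> [[8.0, 12.0], [4.0, 6.0]]
```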
Grouped Matrix Multiplication:
- fp8_blockwise_scaled_grouped_mm -- FP8 blockwise scaled grouped matrix multiplication using CUTLASS. Takes pointer arrays (a_ptrs, b_ptrs, out_ptrs), scale factor pointer arrays, stride information, and problem sizes per expert.
- cutlass_fp4_group_mm -- CUTLASS FP4 blockscaled group GEMM for NVFP4 quantized MoE forward. Takes FP4 inputs with block scales, global alphas, and per-expert strides/offsets from a params dictionary.
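The idea behind a blockwise-scaled grouped GEMM can be sketched without any GPU code. The hypothetical `blockwise_scaled_grouped_mm_ref` below assumes that activation rows are packed contiguously per expert, that weights are stored as `[expert][N][K]`, and that scales follow a DeepSeek-style layout (one activation scale per row and K-block, one weight scale per N-block and K-block). The layouts the CUTLASS kernels actually expect are defined by `layout_sfa`/`layout_sfb` and may differ.

```python
def blockwise_scaled_grouped_mm_ref(a_q, a_scales, b_q, b_scales,
                                    expert_offsets, block):
    """Dequantize-and-multiply reference for a grouped GEMM (illustrative only).

    a_q:       [M][K] quantized activations, rows grouped by expert.
    a_scales:  [M][K // block] per-row, per-K-block scales.
    b_q:       [E][N][K] quantized weights.
    b_scales:  [E][N // block][K // block] per-block weight scales.
    expert_offsets: expert e owns rows [expert_offsets[e], expert_offsets[e + 1]).
    """
    out = []
    for e in range(len(expert_offsets) - 1):
        for m in range(expert_offsets[e], expert_offsets[e + 1]):
            row = []
            for n in range(len(b_q[e])):
                acc = 0.0
                for kk in range(len(a_q[m])):
                    scale = a_scales[m][kk // block] * b_scales[e][n // block][kk // block]
                    acc += scale * a_q[m][kk] * b_q[e][n][kk]
                row.append(acc)
            out.append(row)
    return out


# Two experts, K = N = 2, a single 2x2 scale block per expert weight.
a_q = [[1, 2], [3, 4], [5, 6]]
a_scales = [[0.5], [1.0], [2.0]]
b_q = [[[1, 0], [0, 1]], [[1, 1], [1, -1]]]
b_scales = [[[2.0]], [[1.0]]]
result = blockwise_scaled_grouped_mm_ref(a_q, a_scales, b_q, b_scales,
                                         expert_offsets=[0, 1, 3], block=2)
# result -> [[1.0, 2.0], [7.0, -1.0], [22.0, -2.0]]
```

Each expert's problem is an independent matrix multiplication whose M dimension is the number of tokens routed to it, which is why the kernels take per-expert problem sizes and offsets rather than one global shape.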
MoE Input Preparation:
- prepare_moe_input -- Prepares input/output permutations and problem sizes for grouped MM operations. Takes topk_ids and produces expert_offsets, problem_sizes1/problem_sizes2, input_permutation, and output_permutation. Optionally computes blockscale_offsets.
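The bookkeeping this step produces can be illustrated with a small plain-Python sketch. `prepare_moe_input_ref` is a hypothetical helper. It assumes that `expert_offsets` is an exclusive prefix sum with `num_experts + 1` entries, that the first grouped MM is a fused gate/up projection with shape `(M_e, 2n, k)` and the second a down projection with shape `(M_e, k, n)`, and that the permutations map between token order and expert-packed row order. These conventions follow similar CUTLASS MoE code paths and are not confirmed by the text above.

```python
def prepare_moe_input_ref(topk_ids, num_experts, n, k):
    """Compute offsets, problem sizes, and permutations (illustrative only)."""
    topk = len(topk_ids[0])
    flat = [expert for row in topk_ids for expert in row]

    counts = [0] * num_experts
    for expert in flat:
        counts[expert] += 1

    # Exclusive prefix sum: expert e owns packed rows [offsets[e], offsets[e + 1]).
    expert_offsets = [0]
    for count in counts:
        expert_offsets.append(expert_offsets[-1] + count)

    # Assumed (M, N, K) shapes for the two grouped matrix multiplications.
    problem_sizes1 = [(count, 2 * n, k) for count in counts]
    problem_sizes2 = [(count, k, n) for count in counts]

    next_row = expert_offsets[:-1].copy()
    input_permutation = [0] * len(flat)
    output_permutation = [0] * len(flat)
    for slot, expert in enumerate(flat):
        row = next_row[expert]
        next_row[expert] += 1
        input_permutation[row] = slot // topk  # source token for this packed row
        output_permutation[slot] = row         # packed row holding this slot's result

    return (expert_offsets, problem_sizes1, problem_sizes2,
            input_permutation, output_permutation)


offsets, sizes1, sizes2, in_perm, out_perm = prepare_moe_input_ref(
    [[0, 2], [2, 1], [0, 2]], num_experts=3, n=4, k=8)
# offsets  -> [0, 2, 3, 6]
# in_perm  -> [0, 2, 1, 0, 1, 2]
# out_perm -> [0, 3, 4, 2, 1, 5]
```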
Permutation and Reduction:
- apply_shuffle_mul_sum -- Applies a permutation to input, multiplies by optional factors, and sums into the output tensor.
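A minimal sketch of the shuffle-multiply-sum pattern follows. `apply_shuffle_mul_sum_ref` is a hypothetical helper, and it assumes that `permutation[t * topk + j]` gives the expert-packed row holding the j-th expert result of token t and that `factors` holds the matching routing weights in the same flattened order.

```python
def apply_shuffle_mul_sum_ref(inp, permutation, factors, topk):
    """Gather packed rows per token, weight them, and sum (illustrative only).

    inp:         [num_tokens * topk][hidden] expert-packed rows.
    permutation: for each flattened (token, k) slot, the packed row to read.
    factors:     per-slot multipliers, or None to sum without weighting.
    """
    num_tokens = len(permutation) // topk
    hidden = len(inp[0])
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for t in range(num_tokens):
        for j in range(topk):
            slot = t * topk + j
            row = inp[permutation[slot]]
            weight = 1.0 if factors is None else factors[slot]
            for h in range(hidden):
                out[t][h] += weight * row[h]
    return out


packed = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
combined = apply_shuffle_mul_sum_ref(packed, permutation=[0, 2, 3, 1],
                                     factors=[0.5, 0.5, 1.0, 2.0], topk=2)
# combined -> [[2.0, 2.0], [8.0, 8.0]]
```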
Fused QK Normalization + RoPE:
- fused_qk_norm_rope -- Fuses QK normalization and rotary position embedding into a single kernel. Takes a combined qkv tensor with head counts for Q, K, V, a head dimension, normalization weights, and RoPE parameters (base, factor, low/high frequency, attention factor). The rotary_dim defaults to head_dim if not specified.
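What the fused kernel computes for one Q or K head can be approximated in two unfused steps. The sketch below is a simplification under stated assumptions: the normalization is taken to be RMSNorm, the rotation uses the NeoX pairing (element `i` with element `i + rotary_dim / 2`), and the frequency scaling controlled by `factor`, `low`, `high`, and `attention_factor` is left out. All function names are hypothetical.

```python
import math


def rms_norm(x, weight, eps):
    """Scale x to unit root-mean-square, then apply the learned weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rope_neox(x, position, base, rotary_dim=None):
    """Rotate pairs (i, i + half) by position-dependent angles; no frequency scaling."""
    rotary_dim = len(x) if rotary_dim is None else rotary_dim
    half = rotary_dim // 2
    out = list(x)  # elements beyond rotary_dim pass through unchanged
    for i in range(half):
        theta = position * base ** (-2.0 * i / rotary_dim)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + half]
        out[i] = a * cos_t - b * sin_t
        out[i + half] = a * sin_t + b * cos_t
    return out


def qk_norm_rope_ref(head, weight, eps, position, base, rotary_dim=None):
    """Normalize one head vector, then apply the rotary embedding."""
    return rope_neox(rms_norm(head, weight, eps), position, base, rotary_dim)


head = [3.0, 4.0, 0.0, 0.0]
ones = [1.0] * 4
at_zero = qk_norm_rope_ref(head, ones, eps=0.0, position=0, base=10000.0)
at_five = qk_norm_rope_ref(head, ones, eps=0.0, position=5, base=10000.0)
# at_zero is approximately [1.2, 1.6, 0.0, 0.0]: position 0 applies no rotation.
```

Because the rotation is orthogonal, `at_five` has the same L2 norm as `at_zero`; only the direction of each rotated pair changes with position.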
Usage
Use these functions when implementing MoE model inference. topk_softmax or topk_sigmoid handles expert selection, moe_align_block_size prepares sorted token dispatching, the grouped MM functions execute expert computation, and moe_sum or moe_sum_reduce aggregates the results.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/moe.py
- Lines: 1-336
Signature
def moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
    experts_ids, num_tokens_post_pad, cumsum_buffer, pad_sorted_token_ids=False):
def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
    gating_output: torch.Tensor, renormalize: bool = False,
    moe_softcapping: float = 0.0, correction_bias: Optional[torch.Tensor] = None) -> None:
def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
    gating_output: torch.Tensor, renormalize: bool = False,
    correction_bias: Optional[torch.Tensor] = None) -> None:
def moe_sum_reduce(input_tensor, output_tensor, routed_scaling_factor=0):
def moe_sum(input_tensor: torch.Tensor, output_tensor: torch.Tensor):
def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk,
    num_fused_shared_experts=0, routed_scaling_factor=0,
    apply_routed_scaling_factor_on_output=False):
def kimi_k2_moe_fused_gate(input_tensor, bias, topk, renormalize=True,
    routed_scaling_factor=1.0, apply_routed_scaling_factor_on_output=False):
def fp8_blockwise_scaled_grouped_mm(output, a_ptrs, b_ptrs, out_ptrs,
    a_scales_ptrs, b_scales_ptrs, a, b, scales_a, scales_b,
    stride_a, stride_b, stride_c, layout_sfa, layout_sfb,
    problem_sizes, expert_offsets, workspace):
def prepare_moe_input(topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
    input_permutation, output_permutation, num_experts, n, k,
    blockscale_offsets: Optional[torch.Tensor] = None):
def apply_shuffle_mul_sum(input, output, permutation, factors):
def fused_qk_norm_rope(qkv: torch.Tensor, num_heads_q: int, num_heads_k: int,
    num_heads_v: int, head_dim: int, eps: float, q_weight: torch.Tensor,
    k_weight: torch.Tensor, base: float, is_neox: bool, position_ids: torch.Tensor,
    factor: float, low: float, high: float, attention_factor: float,
    rotary_dim: Optional[int] = None) -> None:
def cutlass_fp4_group_mm(a_fp4, b_fp4, a_blockscale, b_blockscale,
    alphas, out_dtype, device, params: Dict[str, Any]):
Import
from sgl_kernel import topk_softmax, moe_align_block_size, moe_sum
from sgl_kernel import fp8_blockwise_scaled_grouped_mm, prepare_moe_input
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| topk_ids | torch.Tensor | Yes | Token-to-expert assignments: (num_tokens, topk) |
| gating_output | torch.Tensor | Yes | Router logits: (num_tokens, num_experts) |
| num_experts | int | Yes | Total number of experts |
| block_size | int | Yes | Block size for token alignment |
| topk | int | Yes | Number of experts selected per token |
| renormalize | bool | No | Whether to renormalize top-k weights |
| moe_softcapping | float | No | Tanh softcapping value (0.0 to disable) |
| correction_bias | torch.Tensor (float32) | No | Per-expert bias correction: (num_experts,) |
| routed_scaling_factor | float | No | Scaling factor for routed expert weights |
Outputs
| Name | Type | Description |
|---|---|---|
| topk_weights | torch.Tensor | Expert routing weights: (num_tokens, topk) |
| topk_ids | torch.Tensor | Expert indices: (num_tokens, topk) |
| sorted_token_ids | torch.Tensor | Block-aligned sorted token IDs |
| output | torch.Tensor | Grouped MM result or reduced MoE output |
Usage Examples
import torch
from sgl_kernel import topk_softmax, moe_align_block_size, moe_sum

num_tokens, num_experts, topk = 64, 8, 2

# Step 1: Expert gating. topk_softmax writes into the preallocated output tensors.
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
topk_softmax(topk_weights, topk_ids, gating_output, renormalize=True)

# Step 2: Align tokens to block size for expert computation.
# The buffers below are sized as generous upper bounds on the padded output.
block_size = 128
sorted_token_ids = torch.empty(num_tokens * topk + num_experts * block_size,
                               device="cuda", dtype=torch.int32)
experts_ids = torch.empty(num_experts * block_size, device="cuda", dtype=torch.int32)
num_tokens_post_pad = torch.empty(1, device="cuda", dtype=torch.int32)
cumsum_buffer = torch.empty(num_experts + 1, device="cuda", dtype=torch.int32)
moe_align_block_size(topk_ids, num_experts, block_size,
                     sorted_token_ids, experts_ids, num_tokens_post_pad, cumsum_buffer)

# Step 3: Reduce expert outputs over the top-k dimension.
# expert_output is random here and stands in for the grouped MM results.
expert_output = torch.randn(num_tokens, topk, 4096, device="cuda", dtype=torch.bfloat16)
final_output = torch.zeros(num_tokens, 4096, device="cuda", dtype=torch.bfloat16)
moe_sum(expert_output, final_output)