

Implementation:Sgl project Sglang MoE Ops

From Leeroopedia


Knowledge Sources
Domains GPU Kernels, Mixture of Experts, LLM Inference
Last Updated 2026-02-10 00:00 GMT

Overview

Python interface for Mixture-of-Experts (MoE) routing, gating, grouped computation, and reduction kernels used in models like Mixtral, DeepSeek, and Kimi K2.

Description

moe.py provides the Python-side interface for MoE computation in LLM inference, covering expert selection, token dispatch, grouped matrix multiplication, and output reduction. Every function is a wrapper that delegates to a C++/CUDA op registered under torch.ops.sgl_kernel.*.

Token-to-Expert Alignment:

  • moe_align_block_size -- Sorts token IDs by expert and pads each expert's segment to a multiple of block_size for efficient expert computation. Takes topk_ids and writes sorted_token_ids, experts_ids, and num_tokens_post_pad into caller-allocated buffers, using cumsum_buffer to hold the per-expert prefix sums. Optionally pads sorted_token_ids inside the kernel (pad_sorted_token_ids).
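To make the alignment step concrete, here is an unfused CPU sketch of what moe_align_block_size computes. It is an illustration, not the kernel: ref_moe_align_block_size is a hypothetical name, the use of topk_ids.numel() as the padding sentinel is an assumption carried over from the vLLM-derived implementation, and the real kernel fills buckets with atomics, so the order of tokens inside one expert's segment may differ.

```python
import torch


def ref_moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """CPU reference for block-aligned token dispatch (illustrative, not the kernel)."""
    flat = topk_ids.flatten().to(torch.int64)
    numel = flat.numel()  # assumed padding sentinel for unused slots
    counts = torch.bincount(flat, minlength=num_experts)
    # Round each expert's token count up to a multiple of block_size.
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total = int(cumsum[-1])
    sorted_token_ids = torch.full((total,), numel, dtype=torch.int32)
    experts_ids = torch.empty(total // block_size, dtype=torch.int32)
    for e in range(num_experts):
        # Flattened (token, slot) indices routed to expert e, in original order.
        idx = (flat == e).nonzero(as_tuple=True)[0]
        start = int(cumsum[e])
        sorted_token_ids[start:start + idx.numel()] = idx.to(torch.int32)
        # Every block inside expert e's padded segment is owned by expert e.
        experts_ids[start // block_size:int(cumsum[e + 1]) // block_size] = e
    num_tokens_post_pad = torch.tensor([total], dtype=torch.int32)
    return sorted_token_ids, experts_ids, num_tokens_post_pad, cumsum
```

For topk_ids = [[0, 2], [1, 2], [2, 0]] with four experts and block_size = 4, expert 0 receives flattened slots 0 and 5, expert 1 slot 2, expert 2 slots 1, 3, 4, and expert 3 nothing, so the padded length is 12 and each of the three blocks belongs to one expert.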

Gating Functions:

  • topk_softmax -- Computes top-k softmax for MoE routing, producing topk_weights and topk_ids from gating_output logits. Supports renormalization, tanh softcapping (moe_softcapping), and per-expert correction bias (float32).
  • topk_sigmoid -- Top-k sigmoid variant for routing, with renormalization and correction bias support.
  • moe_fused_gate -- Hierarchical 2-layer expert selection: splits experts into num_expert_group groups, selects top groups by summed weights, then selects topk experts within those groups. Supports fused shared experts (num_fused_shared_experts) and routed scaling factor. Limited to power-of-2 expert counts with at most 32 experts per group.
  • kimi_k2_moe_fused_gate -- Simplified fused gate for the Kimi K2 model (single expert group), removing the grouped topk logic and adding renormalization support.
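The routing math behind topk_softmax and moe_fused_gate can be sketched as an unfused PyTorch reference. This is an illustration under stated assumptions, not the kernels' code: ref_topk_softmax and ref_grouped_topk are hypothetical names, the correction bias is assumed to influence only which experts are selected (the returned weights come from the unbiased scores), and the grouped variant assumes DeepSeek-V3-style sigmoid scores with each group ranked by the sum of its two best biased scores. Shared-expert fusion, renormalization in the grouped path, and routed_scaling_factor are omitted.

```python
import torch


def ref_topk_softmax(gating_output, topk, renormalize=False, moe_softcapping=0.0,
                     correction_bias=None):
    """Reference top-k softmax routing (bias semantics are an assumption)."""
    logits = gating_output.float()
    if moe_softcapping > 0:
        # Tanh softcapping squashes logits into (-cap, cap).
        logits = torch.tanh(logits / moe_softcapping) * moe_softcapping
    probs = torch.softmax(logits, dim=-1)
    # Assumption: the bias only shifts the selection, not the returned weights.
    choice = probs if correction_bias is None else probs + correction_bias
    topk_ids = torch.topk(choice, topk, dim=-1).indices
    topk_weights = probs.gather(-1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)


def ref_grouped_topk(scores, bias, num_expert_group, topk_group, topk):
    """Reference 2-level selection: pick groups first, then experts inside them."""
    num_tokens, num_experts = scores.shape
    probs = torch.sigmoid(scores.float())
    biased = probs + bias
    grouped = biased.view(num_tokens, num_expert_group, -1)
    # Assumed group score: sum of the two best biased scores in the group.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    top_groups = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    # Broadcast the group mask to experts and exclude unselected groups.
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = probs.gather(-1, topk_ids)
    return topk_weights, topk_ids.to(torch.int32)
```

Grouping changes the result whenever a strong expert sits in a weak group: with scores [5, 4, -5, -5, 1, 0, 3, -5] split into four groups of two and topk_group = 2, expert 6 (score 3) is excluded because its group ranks third, so a top-3 pick returns experts 0, 1, and 4 instead of 0, 1, and 6.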

Output Reduction:

  • moe_sum_reduce -- Reduces MoE expert outputs with optional routed scaling factor.
  • moe_sum -- Simple summation of MoE expert outputs.
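Both reductions collapse the top-k axis of the expert outputs. A minimal reference, assuming the (num_tokens, topk, hidden) input layout used in the Usage Examples and assuming routed_scaling_factor simply multiplies the sum; the helper names are hypothetical, and how the kernel treats the default factor of 0 is not covered here.

```python
import torch


def ref_moe_sum(input_tensor: torch.Tensor) -> torch.Tensor:
    """Sum over the top-k axis: (num_tokens, topk, hidden) -> (num_tokens, hidden)."""
    return input_tensor.sum(dim=1)


def ref_moe_sum_reduce(input_tensor: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
    # Accumulate in float32 (useful for bf16 inputs), scale, then cast back.
    # Assumption: the factor multiplies the summed output.
    summed = input_tensor.float().sum(dim=1) * routed_scaling_factor
    return summed.to(input_tensor.dtype)
```

Accumulating in float32 before casting back mirrors the usual practice for bf16 activations, where summing many top-k slots directly in bf16 loses precision.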

Grouped Matrix Multiplication:

  • fp8_blockwise_scaled_grouped_mm -- FP8 blockwise scaled grouped matrix multiplication using CUTLASS. Takes pointer arrays (a_ptrs, b_ptrs, out_ptrs), scale factor pointer arrays, stride information, and problem sizes per expert.
  • cutlass_fp4_group_mm -- CUTLASS FP4 blockscaled group GEMM for NVFP4 quantized MoE forward. Takes FP4 inputs with block scales, global alphas, and per-expert strides/offsets from a params dictionary.
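Stripped of quantization, a grouped GEMM runs one independent matrix multiply per expert over that expert's contiguous slice of rows. The float sketch below shows only that structure: ref_grouped_mm is hypothetical, the (num_experts, n, k) weight layout and the num_experts + 1 offset convention are assumptions, and the FP8/FP4 scale handling and pointer arrays of the real CUTLASS kernels are omitted.

```python
import torch


def ref_grouped_mm(a: torch.Tensor, b: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
    """Float reference for a grouped GEMM over per-expert row segments.

    a: (total_rows, k), rows already sorted by expert.
    b: (num_experts, n, k), one weight matrix per expert (assumed layout).
    expert_offsets: (num_experts + 1,) row boundaries into `a` (assumed convention).
    """
    out = a.new_empty(a.shape[0], b.shape[1])
    for e in range(b.shape[0]):
        start, end = int(expert_offsets[e]), int(expert_offsets[e + 1])
        # Problem size for expert e is (m, n, k) = (end - start, n, k).
        out[start:end] = a[start:end] @ b[e].t()
    return out
```

A fused grouped kernel launches all of these problems at once from the per-expert problem sizes instead of looping in Python, which is why the real functions take problem_sizes and per-expert pointer arrays.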

MoE Input Preparation:

  • prepare_moe_input -- Prepares input/output permutations and problem sizes for grouped MM operations. Takes topk_ids and produces expert_offsets, problem_sizes1/problem_sizes2, input_permutation, and output_permutation. Optionally computes blockscale_offsets.
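The bookkeeping that prepare_moe_input produces can be sketched in plain PyTorch. The conventions below are assumptions borrowed from the CUTLASS MoE path in vLLM, not confirmed by this page: the first GEMM is a fused gate/up projection with problem size (m, 2n, k), the second is the down projection with (m, k, n), input_permutation maps each sorted row to its source token, and output_permutation maps each flattened (token, slot) index to its sorted row. ref_prepare_moe_input is a hypothetical name, and blockscale_offsets is omitted.

```python
import torch


def ref_prepare_moe_input(topk_ids: torch.Tensor, num_experts: int, n: int, k: int):
    """Sketch of grouped-MM MoE bookkeeping (all conventions are assumptions)."""
    num_tokens, topk = topk_ids.shape
    flat = topk_ids.flatten().to(torch.int64)
    # A stable sort groups the flattened (token, slot) pairs by expert.
    order = torch.argsort(flat, stable=True)
    counts = torch.bincount(flat, minlength=num_experts)
    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64)
    expert_offsets[1:] = torch.cumsum(counts, dim=0)
    # Sorted row r is a copy of token input_permutation[r].
    input_permutation = (order // topk).to(torch.int32)
    # output_permutation[t * topk + j] is the sorted row that holds slot j of token t.
    output_permutation = torch.empty_like(order)
    output_permutation[order] = torch.arange(order.numel())
    problem_sizes1 = torch.stack(
        [counts, torch.full_like(counts, 2 * n), torch.full_like(counts, k)], dim=1)
    problem_sizes2 = torch.stack(
        [counts, torch.full_like(counts, k), torch.full_like(counts, n)], dim=1)
    return (expert_offsets, problem_sizes1, problem_sizes2,
            input_permutation, output_permutation.to(torch.int32))
```

Because output_permutation is the inverse of the sort order, gathering the expert outputs with it restores the original (token, slot) layout that a final top-k reduction expects.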

Permutation and Reduction:

  • apply_shuffle_mul_sum -- Applies a permutation to input, multiplies by optional factors, and sums into the output tensor.
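One plausible reading of apply_shuffle_mul_sum, written as a reference: gather rows by the permutation, scale each row by its factor, and sum each token's top-k rows. The shapes, the explicit topk argument, and the helper name ref_apply_shuffle_mul_sum are assumptions for illustration; the real op writes into a preallocated output instead of returning a tensor.

```python
import torch


def ref_apply_shuffle_mul_sum(inp, permutation, factors, topk):
    """Gather rows by `permutation`, scale by `factors`, sum each token's top-k rows.

    Assumed shapes: inp (num_tokens * topk, hidden); permutation and factors
    (num_tokens * topk,); factors may be None.
    """
    rows = inp[permutation.long()]
    if factors is not None:
        # One scalar factor (e.g. a routing weight) per gathered row.
        rows = rows * factors.to(rows.dtype).unsqueeze(-1)
    return rows.view(-1, topk, inp.shape[-1]).sum(dim=1)
```

Fusing the gather, the multiply by routing weights, and the reduction avoids materializing the permuted and scaled intermediate tensor.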

Fused QK Normalization + RoPE:

  • fused_qk_norm_rope -- Fuses QK normalization and rotary position embedding into a single kernel. Takes a combined qkv tensor with head counts for Q, K, V, a head dimension, normalization weights, and RoPE parameters (base, factor, low/high frequency, attention factor). The rotary_dim defaults to head_dim if not specified.
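An unfused reference helps show what the fused kernel combines: a per-head RMSNorm on Q and K, followed by rotary embedding on the first rotary_dim channels, with V passed through. This sketch applies plain RoPE only; the YaRN-style parameters (factor, low, high, attention_factor) are deliberately left out, the helper names are hypothetical, and unlike the in-place kernel it returns a new tensor.

```python
import torch


def rms_norm(x, weight, eps):
    # Per-head RMSNorm over the last (head_dim) axis, computed in float32.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps) * weight.float()).to(x.dtype)


def apply_rope(x, positions, base, is_neox, rotary_dim):
    """Plain RoPE on the first rotary_dim channels of x: (tokens, heads, head_dim)."""
    half = rotary_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    angles = positions.float().unsqueeze(-1) * inv_freq  # (tokens, half)
    cos, sin = angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)  # broadcast over heads
    rot, rest = x[..., :rotary_dim].float(), x[..., rotary_dim:]
    if is_neox:  # NeoX style: rotate the two halves against each other
        x1, x2 = rot[..., :half], rot[..., half:]
    else:  # GPT-J style: rotate interleaved (even, odd) pairs
        x1, x2 = rot[..., 0::2], rot[..., 1::2]
    o1, o2 = x1 * cos - x2 * sin, x2 * cos + x1 * sin
    out = torch.cat([o1, o2], -1) if is_neox else torch.stack([o1, o2], -1).flatten(-2)
    return torch.cat([out.to(x.dtype), rest], dim=-1)


def ref_qk_norm_rope(qkv, num_heads_q, num_heads_k, num_heads_v, head_dim, eps,
                     q_weight, k_weight, base, is_neox, position_ids, rotary_dim=None):
    rotary_dim = head_dim if rotary_dim is None else rotary_dim
    t = qkv.shape[0]
    q, k, v = qkv.view(t, num_heads_q + num_heads_k + num_heads_v, head_dim).split(
        [num_heads_q, num_heads_k, num_heads_v], dim=1)
    q = apply_rope(rms_norm(q, q_weight, eps), position_ids, base, is_neox, rotary_dim)
    k = apply_rope(rms_norm(k, k_weight, eps), position_ids, base, is_neox, rotary_dim)
    return torch.cat([q, k, v], dim=1).reshape(t, -1)
```

Two properties make a quick sanity check: RoPE is a rotation, so with unit norm weights each Q/K head keeps an L2 norm of about sqrt(head_dim) after normalization, and at position 0 the rotation is the identity.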

Usage

Use these functions when implementing MoE model inference: topk_softmax or topk_sigmoid (or moe_fused_gate for grouped routing) selects the experts, moe_align_block_size sorts tokens for block-wise dispatch, the grouped MM functions run the expert computation, and moe_sum or moe_sum_reduce aggregates the results.
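The four stages can be traced end to end in an unfused CPU reference, which is also handy for checking a kernel-based pipeline on small inputs. This is a sketch under explicit assumptions: ref_moe_forward is a hypothetical name, the weight layouts w_up (num_experts, ffn, hidden) and w_down (num_experts, hidden, ffn) are invented for illustration, and a ReLU stands in for the model's real activation.

```python
import torch


def ref_moe_forward(x, router_logits, w_up, w_down, topk, renormalize=True):
    """Unfused reference for the gating -> dispatch -> expert -> reduce pipeline."""
    num_tokens, hidden = x.shape
    num_experts = w_up.shape[0]
    # 1. Gating (the role topk_softmax plays).
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # 2. Dispatch: sort (token, slot) pairs by expert (the role of moe_align_block_size).
    flat_ids = topk_ids.flatten()
    order = torch.argsort(flat_ids, stable=True)
    rows = x[order // topk]
    counts = torch.bincount(flat_ids, minlength=num_experts).tolist()
    # 3. Expert computation on each contiguous segment (the grouped MM step).
    expert_out = torch.empty_like(rows)
    start = 0
    for e, m in enumerate(counts):
        seg = rows[start:start + m]
        expert_out[start:start + m] = torch.relu(seg @ w_up[e].t()) @ w_down[e].t()
        start += m
    # 4. Un-permute, apply routing weights, and sum over slots (the role of moe_sum).
    per_slot = torch.empty_like(expert_out)
    per_slot[order] = expert_out
    per_slot = per_slot.view(num_tokens, topk, hidden) * topk_weights.unsqueeze(-1).to(x.dtype)
    return per_slot.sum(dim=1)
```

Sorting by expert in step 2 is what lets step 3 run one dense matmul per expert instead of one per token, which is the same reason the real pipeline aligns tokens into expert-owned blocks.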

Code Reference

Source Location

Signature

def moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
    experts_ids, num_tokens_post_pad, cumsum_buffer, pad_sorted_token_ids=False):

def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
    gating_output: torch.Tensor, renormalize: bool = False,
    moe_softcapping: float = 0.0, correction_bias: Optional[torch.Tensor] = None) -> None:

def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
    gating_output: torch.Tensor, renormalize: bool = False,
    correction_bias: Optional[torch.Tensor] = None) -> None:

def moe_sum_reduce(input_tensor, output_tensor, routed_scaling_factor=0):

def moe_sum(input_tensor: torch.Tensor, output_tensor: torch.Tensor):

def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk,
    num_fused_shared_experts=0, routed_scaling_factor=0,
    apply_routed_scaling_factor_on_output=False):

def kimi_k2_moe_fused_gate(input_tensor, bias, topk, renormalize=True,
    routed_scaling_factor=1.0, apply_routed_scaling_factor_on_output=False):

def fp8_blockwise_scaled_grouped_mm(output, a_ptrs, b_ptrs, out_ptrs,
    a_scales_ptrs, b_scales_ptrs, a, b, scales_a, scales_b,
    stride_a, stride_b, stride_c, layout_sfa, layout_sfb,
    problem_sizes, expert_offsets, workspace):

def prepare_moe_input(topk_ids, expert_offsets, problem_sizes1, problem_sizes2,
    input_permutation, output_permutation, num_experts, n, k,
    blockscale_offsets: Optional[torch.Tensor] = None):

def apply_shuffle_mul_sum(input, output, permutation, factors):

def fused_qk_norm_rope(qkv: torch.Tensor, num_heads_q: int, num_heads_k: int,
    num_heads_v: int, head_dim: int, eps: float, q_weight: torch.Tensor,
    k_weight: torch.Tensor, base: float, is_neox: bool, position_ids: torch.Tensor,
    factor: float, low: float, high: float, attention_factor: float,
    rotary_dim: Optional[int] = None) -> None:

def cutlass_fp4_group_mm(a_fp4, b_fp4, a_blockscale, b_blockscale,
    alphas, out_dtype, device, params: Dict[str, Any]):

Import

from sgl_kernel import topk_softmax, moe_align_block_size, moe_sum
from sgl_kernel import fp8_blockwise_scaled_grouped_mm, prepare_moe_input

I/O Contract

Inputs

Name | Type | Required | Description
topk_ids | torch.Tensor | Yes | Token-to-expert assignments: (num_tokens, topk)
gating_output | torch.Tensor | Yes | Router logits: (num_tokens, num_experts)
num_experts | int | Yes | Total number of experts
block_size | int | Yes | Block size for token alignment
topk | int | Yes | Number of experts selected per token
renormalize | bool | No | Whether to renormalize the top-k weights
moe_softcapping | float | No | Tanh softcapping value (0.0 disables it)
correction_bias | torch.Tensor (float32) | No | Per-expert bias correction: (num_experts,)
routed_scaling_factor | float | No | Scaling factor for routed expert weights

Outputs

Name | Type | Description
topk_weights | torch.Tensor | Expert routing weights: (num_tokens, topk)
topk_ids | torch.Tensor | Expert indices: (num_tokens, topk)
sorted_token_ids | torch.Tensor | Block-aligned sorted token IDs
output | torch.Tensor | Grouped MM result or reduced MoE output

Usage Examples

from sgl_kernel import topk_softmax, moe_align_block_size, moe_sum
import torch

num_tokens, num_experts, topk = 64, 8, 2

# Step 1: Expert gating
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

topk_softmax(topk_weights, topk_ids, gating_output, renormalize=True)

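# Step 2: Align tokens to block size for expert computation
block_size = 128
# Upper bound on the padded length: each expert adds at most block_size - 1 padding slots
max_num_tokens_padded = num_tokens * topk + num_experts * (block_size - 1)
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
# Pre-fill with num_tokens * topk so unused padding slots hold an out-of-range sentinel
# (pad_sorted_token_ids defaults to False, so the kernel does not fill them itself)
sorted_token_ids = torch.full((max_num_tokens_padded,), num_tokens * topk,
                              device="cuda", dtype=torch.int32)
experts_ids = torch.empty(max_num_m_blocks, device="cuda", dtype=torch.int32)
num_tokens_post_pad = torch.empty(1, device="cuda", dtype=torch.int32)
cumsum_buffer = torch.zeros(num_experts + 1, device="cuda", dtype=torch.int32)

moe_align_block_size(topk_ids, num_experts, block_size,
    sorted_token_ids, experts_ids, num_tokens_post_pad, cumsum_buffer)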

# Step 3: Reduce expert outputs
expert_output = torch.randn(num_tokens, topk, 4096, device="cuda", dtype=torch.bfloat16)
final_output = torch.zeros(num_tokens, 4096, device="cuda", dtype=torch.bfloat16)
moe_sum(expert_output, final_output)
