Implementation:Sgl project Sglang Expert Specialization
| Knowledge Sources | |
|---|---|
| Domains | Kernel, Mixture of Experts, Quantization |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python interface for expert-specialized grouped matrix multiplication and quantization kernels, targeting FP8 blockwise-scaled and MXFP8 block-scaled formats.
Description
The expert_specialization.py module wraps three C++ ops for expert-specialized MoE inference. es_fp8_blockwise_scaled_grouped_mm performs an FP8 blockwise-scaled grouped GEMM. It takes activation and weight tensors with per-block scale factors, stride information, per-expert problem sizes, and expert offsets. es_sm100_mxfp8_blockscaled_grouped_mm performs a grouped GEMM on SM100 (Blackwell) GPUs in the MXFP8 (microscaling FP8) format, with block scale factors sfa (activations) and sfb (weights). es_sm100_mxfp8_blockscaled_grouped_quant quantizes input tensors into MXFP8 format with per-block scale factors, for use in a subsequent grouped GEMM. Each function forwards its arguments to the corresponding torch.ops.sgl_kernel.* C++ op.
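To show what "grouped GEMM" means here, the following is a minimal, unscaled CPU reference in plain Python. It is a sketch under stated assumptions, not the kernel's implementation. The helper name reference_grouped_mm is hypothetical. The code assumes that each expert's activation rows form a contiguous slice of a starting at expert_offsets[e], that problem_sizes[e] holds (M, N, K), that every expert shares the same N, and that each expert's weight matrix is stored as N x K. The real kernels' layouts are not specified on this page.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def reference_grouped_mm(
    a: Matrix,                               # all experts' activation rows, concatenated (total_M x K)
    b: Sequence[Matrix],                     # one weight matrix per expert, assumed stored as N x K
    problem_sizes: Sequence[Sequence[int]],  # per expert (M, N, K)
    expert_offsets: Sequence[int],           # first row of each expert's slice in `a`
) -> Matrix:
    """Hypothetical unscaled reference: out[rows of e] = a[rows of e] @ b[e]^T."""
    total_rows = len(a)
    n_out = problem_sizes[0][1] if problem_sizes else 0  # assumes all experts share N
    out = [[0.0] * n_out for _ in range(total_rows)]
    for e, (m, n, k) in enumerate(problem_sizes):
        start = expert_offsets[e]
        w = b[e]
        # Each expert runs an independent (m x k) @ (k x n) GEMM on its own row slice
        for i in range(start, start + m):
            row = a[i]
            for j in range(n):
                out[i][j] = sum(row[t] * w[j][t] for t in range(k))
    return out
```

All experts are issued as one grouped launch, even though each expert has a different M. That avoids padding every expert to the same token count.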
Usage
Use these functions for expert-specialized MoE inference. The MXFP8 functions target Blackwell (SM100) GPUs, and the FP8 blockwise-scaled grouped GEMM targets Hopper GPUs.
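The hardware guidance above can be expressed as a small dispatch rule. pick_es_grouped_mm is a hypothetical helper and is not part of sgl_kernel. Mapping SM100 to compute capability (10, 0) and Hopper to (9, 0) follows NVIDIA's numbering. This page does not say whether the kernels also run on other 10.x or 12.x parts, so the sketch rejects everything else.

```python
from typing import Tuple


def pick_es_grouped_mm(capability: Tuple[int, int]) -> str:
    """Hypothetical helper: map a CUDA compute capability to a kernel name."""
    if capability == (10, 0):
        # SM100 (Blackwell datacenter): MXFP8 block-scaled grouped GEMM
        return "es_sm100_mxfp8_blockscaled_grouped_mm"
    if capability == (9, 0):
        # SM90 (Hopper): FP8 blockwise-scaled grouped GEMM
        return "es_fp8_blockwise_scaled_grouped_mm"
    major, minor = capability
    raise ValueError(f"no expert-specialization kernel assumed for sm_{major}{minor}")
```

In practice the tuple could come from torch.cuda.get_device_capability(), which returns (major, minor) for the current device.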
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/expert_specialization.py
- Lines: 1-51
Signature
```python
def es_fp8_blockwise_scaled_grouped_mm(
    output, a, b, scales_a, scales_b,
    stride_a, stride_b, stride_d,
    problem_sizes, expert_offsets, workspace,
): ...


def es_sm100_mxfp8_blockscaled_grouped_mm(
    output, a, b, sfa, sfb,
    problem_sizes, expert_offsets, blockscale_offsets,
): ...


def es_sm100_mxfp8_blockscaled_grouped_quant(
    input, problem_sizes, expert_offsets,
    blockscale_offsets, quant_output, scale_factor,
): ...
```
Import
```python
from sgl_kernel.expert_specialization import (
    es_fp8_blockwise_scaled_grouped_mm,
    es_sm100_mxfp8_blockscaled_grouped_mm,
    es_sm100_mxfp8_blockscaled_grouped_quant,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| output | torch.Tensor | Yes | Pre-allocated output tensor for GEMM result |
| a | torch.Tensor | Yes | Activation tensor (FP8 or MXFP8 format) |
| b | torch.Tensor | Yes | Weight tensor (FP8 or MXFP8 format) |
| scales_a | torch.Tensor | Yes (FP8) | Per-block scale factors for activations |
| scales_b | torch.Tensor | Yes (FP8) | Per-block scale factors for weights |
| sfa | torch.Tensor | Yes (MXFP8) | MXFP8 scale factors for activations |
| sfb | torch.Tensor | Yes (MXFP8) | MXFP8 scale factors for weights |
| problem_sizes | torch.Tensor | Yes | Per-expert GEMM problem sizes (M, N, K) |
| expert_offsets | torch.Tensor | Yes | Starting offset of each expert's slice of the grouped tensors |
| blockscale_offsets | torch.Tensor | Yes (MXFP8) | Starting offset of each expert's slice of the block scale-factor tensors |
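| workspace | torch.Tensor | Yes (FP8) | Workspace buffer for FP8 GEMM |
| stride_a | torch.Tensor | Yes (FP8) | Stride information for the activation tensor a |
| stride_b | torch.Tensor | Yes (FP8) | Stride information for the weight tensor b |
| stride_d | torch.Tensor | Yes (FP8) | Stride information for the output tensor |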
| input | torch.Tensor | Yes (quant) | Input tensor to quantize |
| quant_output | torch.Tensor | Yes (quant) | Pre-allocated quantized output |
| scale_factor | torch.Tensor | Yes (quant) | Pre-allocated scale factor output |
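The per-expert metadata inputs can be derived from the number of tokens routed to each expert. This is a minimal sketch under explicit assumptions. build_grouped_metadata and its blockscale_pad parameter are hypothetical. The sketch assumes:

- expert_offsets is the exclusive prefix sum of each expert's M.
- blockscale_offsets uses the same prefix sum, but with each M rounded up to a padding granularity.

The value 128 is a guess for the padding granularity. Check the kernel source for the actual scale-factor layout, offset dtype, and whether the arrays carry a trailing total entry.

```python
from typing import List, Sequence, Tuple


def build_grouped_metadata(
    tokens_per_expert: Sequence[int],
    n: int,
    k: int,
    blockscale_pad: int = 128,  # hypothetical padding granularity for scale-factor rows
) -> Tuple[List[Tuple[int, int, int]], List[int], List[int]]:
    """Sketch of per-expert (problem_sizes, expert_offsets, blockscale_offsets)."""
    problem_sizes = [(m, n, k) for m in tokens_per_expert]
    expert_offsets: List[int] = []
    blockscale_offsets: List[int] = []
    row, padded_row = 0, 0
    for m in tokens_per_expert:
        expert_offsets.append(row)             # exclusive prefix sum of M
        blockscale_offsets.append(padded_row)  # prefix sum of M rounded up to the pad
        row += m
        padded_row += -(-m // blockscale_pad) * blockscale_pad  # ceil(m / pad) * pad
    return problem_sizes, expert_offsets, blockscale_offsets
```

An expert that receives zero tokens still gets an entry with M = 0. It occupies no rows, so the next expert starts at the same offset.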
Outputs
| Name | Type | Description |
|---|---|---|
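| output | torch.Tensor (modified in place) | Grouped GEMM result, written by es_fp8_blockwise_scaled_grouped_mm and es_sm100_mxfp8_blockscaled_grouped_mm |
| quant_output, scale_factor | torch.Tensor (modified in place) | MXFP8 data and per-block scale factors, written by es_sm100_mxfp8_blockscaled_grouped_quant |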
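To illustrate what quant_output and scale_factor contain conceptually, here is a numerical model of MXFP8 quantization in plain Python. It assumes the OCP Microscaling (MX) v1.0 conventions:

- 32-element blocks.
- One E8M0 (power-of-two) scale per block, stored here as a biased exponent.
- FP8 E4M3 elements, rounded to nearest-even with saturation at 448.

The kernel's actual block size, scale-factor memory layout (padding or swizzling), and rounding mode are not documented on this page. The function names are hypothetical.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0   # largest finite FP8 E4M3 magnitude
E4M3_EMAX = 8      # exponent of the largest E4M3 normal (448 = 1.75 * 2**8)
E4M3_MIN_EXP = -6  # exponent of the smallest E4M3 normal
MX_BLOCK = 32      # assumed MX block size


def round_to_e4m3(v: float) -> float:
    """Emulate round-to-nearest-even into E4M3 with saturation (no NaN handling)."""
    if v == 0.0:
        return 0.0
    e = max(math.floor(math.log2(abs(v))), E4M3_MIN_EXP)
    ulp = 2.0 ** (e - 3)  # E4M3 has 3 mantissa bits
    q = round(v / ulp) * ulp
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def mxfp8_quantize_block(block: List[float]) -> Tuple[int, List[float]]:
    """Return (biased E8M0 scale exponent, E4M3-rounded elements) for one block."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 127, [0.0] * len(block)  # scale 2**0
    shared_exp = math.floor(math.log2(amax)) - E4M3_EMAX
    shared_exp = max(-127, min(127, shared_exp))  # E8M0 range
    scale = 2.0 ** shared_exp
    return shared_exp + 127, [round_to_e4m3(x / scale) for x in block]


def mxfp8_quantize(row: List[float]) -> Tuple[List[int], List[float]]:
    """Quantize a row block by block; dequantize with q * 2**(scale - 127)."""
    scales: List[int] = []
    elems: List[float] = []
    for start in range(0, len(row), MX_BLOCK):
        s, q = mxfp8_quantize_block(row[start:start + MX_BLOCK])
        scales.append(s)
        elems.extend(q)
    return scales, elems
```

Power-of-two scales make dequantization an exponent adjustment rather than a floating-point multiply. The cost is a coarser fit to each block's range than an arbitrary FP32 scale would give.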
Usage Examples
```python
from sgl_kernel.expert_specialization import (
    es_fp8_blockwise_scaled_grouped_mm,
    es_sm100_mxfp8_blockscaled_grouped_mm,
)

# All tensors below are assumed to be pre-allocated on the GPU with the shapes,
# dtypes, and layouts each kernel expects; results are written in place.

# FP8 blockwise grouped GEMM
es_fp8_blockwise_scaled_grouped_mm(
    output, a, b, scales_a, scales_b,
    stride_a, stride_b, stride_d,
    problem_sizes, expert_offsets, workspace,
)

# MXFP8 block-scaled grouped GEMM (Blackwell, SM100)
es_sm100_mxfp8_blockscaled_grouped_mm(
    output, a, b, sfa, sfb,
    problem_sizes, expert_offsets, blockscale_offsets,
)
```
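To sanity-check the FP8 blockwise path, a CPU reference for a single expert can be compared against the relevant rows of output. This sketch uses plain floats in place of FP8 values. It assumes:

- Activations carry one scale per row per K-block (scales_a).
- Weights, stored as N x K, carry one scale per blk x blk tile (scales_b).

That 1 x 128 / 128 x 128 scheme is a common convention but is not stated on this page. blockwise_scaled_mm_ref and blk are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def blockwise_scaled_mm_ref(
    a: Matrix,         # M x K quantized activations (plain floats here)
    b: Matrix,         # N x K quantized weights (layout assumed)
    scales_a: Matrix,  # M x ceil(K / blk): one scale per row per K-block
    scales_b: Matrix,  # ceil(N / blk) x ceil(K / blk): one scale per tile
    blk: int = 128,    # assumed block size
) -> Matrix:
    m, n, k = len(a), len(b), len(a[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for kb in range(0, k, blk):
                # Dot product over one K-block, then apply both block scales
                partial = sum(a[i][t] * b[j][t] for t in range(kb, min(kb + blk, k)))
                acc += partial * scales_a[i][kb // blk] * scales_b[j // blk][kb // blk]
            out[i][j] = acc
    return out
```

Applying the scales to each K-block's partial sum is algebraically the same as dequantizing a and b first and then multiplying. Each block's scales are constant within that block, so they factor out of the block's dot product.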