Implementation:Sgl project Sglang Marlin MoE WNA16 Template
| Knowledge Sources | |
|---|---|
| Domains | CUDA_Kernels, Quantized_GEMM, Mixture_of_Experts |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
CUDA kernel template for Marlin-based Mixture-of-Experts (MoE) WNA16 quantized matrix multiplication (N-bit quantized weights, 16-bit activations), combining the optimized quantized GEMM with MoE token routing in a single fused kernel.
Description
The marlin_template.h header in the marlin_moe_wna16 directory implements a variant of the Marlin GEMM kernel specialized for MoE workloads. It lives in the marlin_moe_wna16 namespace (set via MARLIN_NAMESPACE_NAME) and shares the core algorithmic structure of the standard Marlin kernel, adding MoE-specific parameters and routing logic.
The kernel template parameters include:
- scalar_t -- compute type (half or nv_bfloat16)
- w_type_id -- weight quantization type ID from sglang::ScalarTypeId
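- threads -- number of threads per threadblock
- thread_m_blocks / thread_n_blocks / thread_k_blocks -- threadblock tile size along M, N, and K, counted in 16-element blocks
- m_block_size_8 -- use an 8-row M block instead of a 16-row one (applies only when thread_m_blocks is 1)
- stages -- number of stages in the asynchronous global-to-shared-memory pipeline
- group_blocks -- number of consecutive 16-row blocks along K that share one quantization scale (-1 for per-channel scales)
- is_zp_float -- whether zero points are stored as floating-point values rather than packed integers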
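As a rough illustration of how these template parameters map to concrete sizes, the sketch below converts block counts into element counts. It assumes the usual Marlin convention that one block spans 16 elements and that -1 encodes per-channel scales; MarlinTile, make_tile, and to_group_blocks are hypothetical names that do not appear in the sglang source.

```cpp
#include <cassert>

// Hypothetical helper types for illustration only; not part of the sglang source.
struct MarlinTile {
    int tile_m;  // rows of A covered by one threadblock tile
    int tile_n;  // output columns covered by one threadblock tile
    int tile_k;  // reduction depth covered by one threadblock tile
};

// Assumption: each *_blocks parameter counts 16-element blocks, and
// m_block_size_8 shrinks the M tile to 8 rows.
constexpr MarlinTile make_tile(int thread_m_blocks, int thread_n_blocks,
                               int thread_k_blocks, bool m_block_size_8) {
    return MarlinTile{m_block_size_8 ? 8 : 16 * thread_m_blocks,
                      16 * thread_n_blocks,
                      16 * thread_k_blocks};
}

// Assumption: a group size of -1 means per-channel scales and is passed through;
// otherwise the group size is expressed in 16-row blocks along K.
constexpr int to_group_blocks(int group_size) {
    return group_size == -1 ? -1 : group_size / 16;
}
```

Under these assumptions, thread_n_blocks = 8 and thread_k_blocks = 8 would describe a 128 x 128 weight tile, and a quantization group size of 128 would correspond to group_blocks = 8.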
MoE-specific parameters in the kernel signature include:
- sorted_token_ids_ptr -- token indices sorted by assigned expert and padded to the MoE block size
- expert_ids_ptr -- expert ID for each block of sorted tokens
- num_tokens_past_padded_ptr -- number of sorted token entries after padding
- topk_weights_ptr -- per-token expert routing weights
- top_k -- number of experts per token
- mul_topk_weights -- whether to multiply output by routing weights
- is_ep -- expert parallelism mode flag
- use_atomic_add -- use atomic adds for the global reduction of partial results across threadblocks
- use_fp32_reduce -- perform the global reduction in FP32 (through C_tmp) for numerical stability
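The routing inputs above are produced before the kernel runs. The following CPU sketch shows one plausible way such buffers could be built, assuming each entry of sorted_token_ids is a flattened (token, k) index and that padding uses a sentinel equal to the number of such entries. MoeRouting and align_block_size are hypothetical names, not functions from the sglang source.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container mirroring the three routing buffers the kernel reads.
struct MoeRouting {
    std::vector<int32_t> sorted_token_ids;  // slot indices grouped by expert, padded per expert
    std::vector<int32_t> expert_ids;        // one expert ID per block of block_size slots
    int32_t num_tokens_post_padded;         // total number of slots including padding
};

// topk_ids holds num_tokens * top_k entries; entry i is the expert chosen for slot i.
MoeRouting align_block_size(const std::vector<int32_t>& topk_ids,
                            int num_experts, int block_size) {
    // Assumption: padding slots carry an out-of-range sentinel (the slot count).
    const int32_t pad = static_cast<int32_t>(topk_ids.size());
    std::vector<std::vector<int32_t>> buckets(num_experts);
    for (int32_t i = 0; i < pad; ++i) buckets[topk_ids[i]].push_back(i);

    MoeRouting r{{}, {}, 0};
    const std::size_t bs = static_cast<std::size_t>(block_size);
    for (int e = 0; e < num_experts; ++e) {
        std::vector<int32_t>& b = buckets[e];
        while (b.size() % bs != 0) b.push_back(pad);  // pad to a whole block
        for (std::size_t j = 0; j < b.size(); j += bs) r.expert_ids.push_back(e);
        r.sorted_token_ids.insert(r.sorted_token_ids.end(), b.begin(), b.end());
    }
    r.num_tokens_post_padded = static_cast<int32_t>(r.sorted_token_ids.size());
    return r;
}
```

Because every block of slots belongs to exactly one expert, a threadblock only needs to read a single expert ID to know which weight matrix to use for its tile.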
The kernel uses m16n8k16 tensor core MMA instructions and provides an empty stub for architectures below SM80.
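For readers unfamiliar with the shape notation, m16n8k16 means that one warp-level MMA multiplies a 16 x 16 tile of A by a 16 x 8 tile of B and accumulates into a 16 x 8 tile of C. The scalar CPU reference below only illustrates that arithmetic; it says nothing about how the real kernel distributes fragments across threads, and mma_m16n8k16_ref is a hypothetical name.

```cpp
#include <array>
#include <cassert>

// Tile shapes implied by m16n8k16: A is 16x16, B is 16x8, C is 16x8.
using TileA = std::array<std::array<float, 16>, 16>;
using TileB = std::array<std::array<float, 8>, 16>;
using TileC = std::array<std::array<float, 8>, 16>;

// Scalar reference for C += A * B on one m16n8k16 tile (illustration only).
void mma_m16n8k16_ref(const TileA& a, const TileB& b, TileC& c) {
    for (int m = 0; m < 16; ++m) {
        for (int n = 0; n < 8; ++n) {
            float acc = c[m][n];
            for (int k = 0; k < 16; ++k) acc += a[m][k] * b[k][n];
            c[m][n] = acc;
        }
    }
}
```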
Usage
This kernel is instantiated by the MoE dispatch code when running quantized MoE models (e.g., Mixtral, DeepSeek). It performs both the quantized GEMM and the MoE token routing in a single kernel launch, avoiding the overhead of launching a separate GEMM kernel for each expert.
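To make the fused computation concrete, here is a minimal CPU reference of what a routed expert GEMM computes. It is a sketch under stated assumptions, not the kernel's algorithm: weights are taken as already dequantized floats, each sorted entry is a flattened (token, k) slot, padding entries are out-of-range slot indices, and the output holds one row per slot. moe_gemm_ref and all of its parameter layouts are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference. weights[e] is a row-major [k, n] float matrix
// standing in for the dequantized weights of expert e.
std::vector<float> moe_gemm_ref(
    const std::vector<float>& a,                     // activations, row-major [m, k]
    const std::vector<std::vector<float>>& weights,  // one [k, n] matrix per expert
    const std::vector<int32_t>& sorted_token_ids,    // padded slot indices grouped by expert
    const std::vector<int32_t>& expert_ids,          // expert ID per block of block_size slots
    const std::vector<float>& topk_weights,          // routing weight per slot, [m * top_k]
    int top_k, int block_size, int k, int n, bool mul_topk_weights) {
    const int32_t num_slots = static_cast<int32_t>(topk_weights.size());
    // Assumption: one output row per (token, expert) slot.
    std::vector<float> c(static_cast<std::size_t>(num_slots) * n, 0.0f);
    for (std::size_t i = 0; i < sorted_token_ids.size(); ++i) {
        const int32_t slot = sorted_token_ids[i];
        if (slot >= num_slots) continue;  // skip padding entries
        const std::vector<float>& w = weights[expert_ids[i / block_size]];
        const float* row = &a[static_cast<std::size_t>(slot / top_k) * k];
        const float s = mul_topk_weights ? topk_weights[slot] : 1.0f;
        for (int col = 0; col < n; ++col) {
            float acc = 0.0f;
            for (int kk = 0; kk < k; ++kk) acc += row[kk] * w[kk * n + col];
            c[static_cast<std::size_t>(slot) * n + col] = s * acc;
        }
    }
    return c;
}
```

A reference like this can be handy for checking a fused kernel's output on small shapes, since every routing decision is resolved in one pass without materializing per-expert sub-batches.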
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h
- Lines: 1-1899
Signature
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
namespace marlin_moe_wna16 {
template <
    typename scalar_t,
    const sglang::ScalarTypeId w_type_id,
    const int threads,
    const int thread_m_blocks,
    const int thread_n_blocks,
    const int thread_k_blocks,
    const bool m_block_size_8,
    const int stages,
    const int group_blocks,
    const bool is_zp_float
>
__global__ void Marlin(
    const int4* __restrict__ A,
    const int4* __restrict__ B,
    int4* __restrict__ C,
    int4* __restrict__ C_tmp,
    const int4* __restrict__ scales_ptr,
    const int4* __restrict__ zp_ptr,
    const int* __restrict__ g_idx,
    const int32_t* __restrict__ sorted_token_ids_ptr,
    const int32_t* __restrict__ expert_ids_ptr,
    const int32_t* __restrict__ num_tokens_past_padded_ptr,
    const float* __restrict__ topk_weights_ptr,
    int top_k,
    bool mul_topk_weights,
    bool is_ep,
    int num_groups,
    int prob_m,
    int prob_n,
    int prob_k,
    int* locks,
    bool use_atomic_add,
    bool use_fp32_reduce,
    int max_par
);
} // namespace marlin_moe_wna16
Import
// This file sets its own namespace before including shared headers
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#include "gemm/marlin/dequant.h"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | const int4* | Yes | FP16/BF16 activation input matrix of shape [m, k] |
| B | const int4* | Yes | Packed quantized weight matrix of shape [k, n] per expert |
| scales_ptr | const int4* | Yes | Quantization scales of shape [k/groupsize, n] |
| zp_ptr | const int4* | No | Packed zero-points of shape [k/groupsize, n/pack_factor] |
| sorted_token_ids_ptr | const int32_t* | Yes | MoE sorted token indices per expert |
| expert_ids_ptr | const int32_t* | Yes | Expert ID for each block of sorted tokens |
| num_tokens_past_padded_ptr | const int32_t* | Yes | Number of sorted token entries after padding |
| topk_weights_ptr | const float* | Yes | Per-token expert routing weights |
| top_k | int | Yes | Number of experts selected per token |
| mul_topk_weights | bool | Yes | Whether to multiply output by routing weights |
| is_ep | bool | Yes | Expert parallelism mode flag |
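| prob_m / prob_n / prob_k | int | Yes | GEMM problem dimensions |
| g_idx | const int* | No | Group index for each of the k input channels (used with activation reordering) |
| num_groups | int | Yes | Number of scale groups per output channel |
| locks | int* | Yes | Global scratch storage for barrier synchronization between threadblocks |
| use_atomic_add | bool | Yes | Whether the global reduction uses atomic adds |
| use_fp32_reduce | bool | Yes | Whether the global reduction is performed in FP32 |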
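To make the packed weight and scale inputs concrete, the sketch below dequantizes eight 4-bit values from one 32-bit word for a bias-8 unsigned type such as kU4B8, computing (q - 8) * scale. It assumes plain sequential nibble order; the actual Marlin weight layout is permuted for tensor-core access, so this shows the arithmetic only. dequant_u4b8 is a hypothetical name.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar dequantization for a 4-bit, bias-8 weight type.
// Assumption: nibble i occupies bits [4*i, 4*i + 3] of the packed word.
std::array<float, 8> dequant_u4b8(uint32_t packed, float scale) {
    std::array<float, 8> out{};
    for (int i = 0; i < 8; ++i) {
        const int q = static_cast<int>((packed >> (4 * i)) & 0xFu);  // raw value in [0, 15]
        out[i] = static_cast<float>(q - 8) * scale;                  // remove bias, apply scale
    }
    return out;
}
```

With a group size of 128, the same scale would be reused for 128 consecutive rows along K before the next row of the scales tensor applies.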
Outputs
| Name | Type | Description |
|---|---|---|
| C | int4* | FP16/BF16 output buffer of shape [m, n] with accumulated expert contributions |
| C_tmp | int4* | FP32 temporary buffer used for the global reduction when use_fp32_reduce is set |
Usage Examples
// Launch the MoE Marlin kernel for INT4 (kU4B8) quantized experts.
// grid, block, smem, stream, and all pointer arguments are assumed to be
// prepared by the host-side dispatch code; block must match the threads parameter.
marlin_moe_wna16::Marlin<
    half,                // scalar_t
    sglang::kU4B8.id(),  // w_type_id: 4-bit weights stored with a bias of 8
    256,                 // threads
    1, 4, 8,             // thread_m_blocks, thread_n_blocks, thread_k_blocks
    false,               // m_block_size_8
    4,                   // stages
    -1,                  // group_blocks (-1: per-channel scales)
    false                // is_zp_float
><<<grid, block, smem, stream>>>(
    A_ptr, B_ptr, C_ptr, C_tmp_ptr,
    scales_ptr, zp_ptr, g_idx_ptr,
    sorted_ids, expert_ids, num_tokens_padded,
    topk_weights, top_k,
    true,   // mul_topk_weights
    false,  // is_ep
    num_groups, prob_m, prob_n, prob_k,
    locks,
    false,  // use_atomic_add
    true,   // use_fp32_reduce
    max_par
);