Implementation:Vllm project Vllm Marlin MoE Template
| Knowledge Sources | |
|---|---|
| Domains | MoE, Quantization, GEMM, CUDA Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the Marlin GEMM kernel template for Mixture-of-Experts workloads with quantized weights and 16-bit (fp16/bf16) or 8-bit (int8/fp8) activations, targeting NVIDIA Turing (SM75) and later GPUs.
Description
This header contains the core CUDA kernel template for Marlin MoE GEMM operations, adapted from the IST-DASLab Marlin project and modified by Neural Magic. It extends the standard Marlin GEMM with MoE-specific features including sorted token IDs, expert routing via expert_ids, topk weight application, and atomic accumulation for parallel expert processing. The kernel supports multiple quantization formats (INT4/INT8/FP8/FP4) with configurable thread block sizes, pipeline stages, and group quantization parameters. It uses inline PTX assembly for ldmatrix instructions and tensor core MMA operations.
Usage
This header is included by the generated .cu kernel files produced by generate_kernels.py. It provides the Marlin kernel template that is explicitly instantiated with specific type and configuration parameters for each supported combination of activation type, weight type, and thread block configuration.
Code Reference
Source Location
- Repository: vllm
- File: csrc/moe/marlin_moe_wna16/marlin_template.h
- Lines: 1-2230
Signature
namespace marlin_moe_wna16 {
// Marlin GEMM kernel for Mixture-of-Experts layers (declaration only).
//
// The template parameter list mirrors the explicit instantiations emitted into
// the generated .cu files: four scalar-type ids (activation, weight, output,
// scale) followed by the tiling / pipeline configuration.
// NOTE(review): the function parameter list below is the abbreviated form used
// on this page; the generated files spell it via MARLIN_KERNEL_PARAMS, so
// confirm the full list against the header before relying on it.
template <const vllm::ScalarTypeId a_type_id,  // scalar type id of A (activations)
const vllm::ScalarTypeId b_type_id,  // scalar type id of B (quantized weights)
const vllm::ScalarTypeId c_type_id,  // scalar type id of C (output)
const vllm::ScalarTypeId s_type_id,  // scalar type id of the scales -- presumably; confirm
const int threads,  // thread-block size configuration
const int thread_m_blocks,  // per-block tile configuration along M
const int thread_n_blocks,  // per-block tile configuration along N
const int thread_k_blocks,  // per-block tile configuration along K
const bool m_block_size_8,  // NOTE(review): presumably selects an 8-row M tile -- confirm
const int stages,  // number of pipeline stages
const int group_blocks,  // group-quantization parameter
const bool is_zp_float>  // NOTE(review): presumably zero-points are stored as floats -- confirm
__global__ void Marlin(
// Input activation matrix, shape (M, K).
const int4* __restrict__ A,
// Quantized weight matrix, shape (K, N), packed.
const int4* __restrict__ B,
// Output buffer, shape (M, N).
int4* __restrict__ C,
// Temporary fp32 output buffer (see use_fp32_reduce).
int4* __restrict__ C_tmp,
// Quantization scales, shape (K/groupsize, N).
const int4* __restrict__ scales_ptr,
// Optional packed zero-points, shape (K/groupsize, N/pack_factor).
const int4* __restrict__ zp_ptr,
// Optional group indices, shape (K), used for act_order reordering.
const int* __restrict__ g_idx,
// MoE sorted token ids used for expert routing.
const int32_t* __restrict__ sorted_token_ids_ptr,
// MoE expert ids: which expert processes each token group.
const int32_t* __restrict__ expert_ids_ptr,
// Token count after padding.
const int32_t* __restrict__ num_tokens_past_padded_ptr,
// Optional top-k routing weights.
const float* __restrict__ topk_weights_ptr,
// Number of experts per token.
int top_k,
// Whether the top-k routing weights are applied to the output.
bool mul_topk_weights,
// NOTE(review): presumably the number of quantization groups -- confirm.
int num_groups,
// Problem dimensions: batch (M), output (N), reduction (K).
int prob_m,
int prob_n,
int prob_k,
// Global synchronization buffer for thread-block coordination.
int* locks,
// Enables atomic accumulation of partial results.
bool use_atomic_add,
// Enables the fp32 reduction path (uses C_tmp).
bool use_fp32_reduce
);
// Helper device functions
// NOTE(review): FragA / FragB / FragS / scalar_t / scalar_t2 are not declared
// in this excerpt; presumably they are derived from type_id in the header --
// confirm.
// Loads a fragment from the shared-memory address smem_ptr into frag_a.
template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(FragA& frag_a, const void* smem_ptr);
// Applies the scale taken from frag_s (selected by i) to frag_b.
template <vllm::ScalarTypeId type_id>
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i);
// Applies scale s and zero-point zp to frag_b in one step.
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub(FragB& frag_b, scalar_t s, scalar_t zp);
// Subtracts the zero-point taken from frag_zp (selected by i) from frag_b.
template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(FragB& frag_b, scalar_t2& frag_zp, int i);
} // namespace marlin_moe_wna16
Import
#include "marlin_template.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | const int4* | Yes | Input activation matrix of shape (M, K) in fp16/bf16/int8/fp8 format |
| B | const int4* | Yes | Quantized weight matrix of shape (K, N) in packed format |
| scales_ptr | const int4* | Yes | Quantization scales of shape (K/groupsize, N) in fp16/bf16 |
| zp_ptr | const int4* | No | Packed zero-points of shape (K/groupsize, N/pack_factor) |
| g_idx | const int* | No | Group indices of shape (K) for act_order reordering |
| sorted_token_ids_ptr | const int32_t* | Yes | MoE sorted token IDs for expert routing |
| expert_ids_ptr | const int32_t* | Yes | MoE expert IDs indicating which expert processes each token group |
| num_tokens_past_padded_ptr | const int32_t* | Yes | Pointer to the total number of entries in the sorted token-ID list after padding (a single value, not one per expert) |
| topk_weights_ptr | const float* | No | Top-k routing weights for weighted expert output combination |
| top_k | int | Yes | Number of experts per token |
| prob_m, prob_n, prob_k | int | Yes | Problem dimensions (batch, output, reduction) |
| locks | int* | Yes | Global synchronization buffer for thread block coordination |
Outputs
| Name | Type | Description |
|---|---|---|
| C | int4* | Output buffer of shape (M, N) in fp16/bf16 format |
| C_tmp | int4* | Temporary fp32 output buffer used when fp32 reduction or atomic add is enabled |
Usage Examples
// This template is instantiated by generated .cu files:
// (from sm80_kernel_float16_u4_float16.cu)
#include "kernel.h"
#include "marlin_template.h"
namespace marlin_moe_wna16 {
// Explicit instantiation of the Marlin kernel template for one configuration.
// Template arguments are positional; the trailing comment on each line names
// the template parameter that the value binds to.
template __global__ void Marlin<
vllm::kFloat16.id(), // a_type_id
vllm::kU4.id(), // b_type_id
vllm::kFloat16.id(), // c_type_id
vllm::kFloat16.id(), // s_type_id
256, // threads
1, // thread_m_blocks
8, // thread_n_blocks
8, // thread_k_blocks
false, // m_block_size_8
4, // stages
// NOTE(review): -1 presumably selects per-channel (ungrouped) scales -- confirm.
-1, // group_blocks
false // is_zp_float
// MARLIN_KERNEL_PARAMS is a macro standing in for the kernel's function
// parameter list; presumably defined in kernel.h -- confirm.
>(MARLIN_KERNEL_PARAMS);
} // namespace marlin_moe_wna16