Implementation:Vllm project Vllm Marlin MoE Generate Kernels
| Knowledge Sources | |
|---|---|
| Domains | MoE, Quantization, GEMM, Code Generation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Python script that generates CUDA kernel instantiations for Marlin MoE (Mixture-of-Experts) workloads with weight-only quantization (WNA16) across different thread configurations and GPU architectures.
Description
This script uses Jinja2 templates to generate architecture-specific CUDA source files containing explicit template instantiations of the Marlin GEMM kernel for MoE workloads. It supports multiple quantization formats, including AWQ-INT4, GPTQ-INT4, AWQ-INT8, FP8, NVFP4, and MXFP4, with various activation types (FP16, BF16, INT8, FP8). The script also produces a kernel_selector.h header that maps runtime parameters to the matching kernel instantiation. Splitting the instantiations across multiple .cu files gives the build several independent translation units that can be compiled in parallel, which shortens overall build time compared with instantiating everything in a single file.
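The templating step can be pictured with a minimal sketch. It uses Python's standard-library `string.Template` as a stand-in for Jinja2 so the example has no third-party dependency. The placeholder names (`a_type`, `b_type`, `c_type`, `s_type`, and so on) and the `render_instantiation` helper are hypothetical; the argument order simply mirrors the instantiation shown under Usage Examples.

```python
from string import Template

# Hypothetical stand-in for the script's Jinja2 TEMPLATE (string.Template is
# used here instead of Jinja2): one explicit Marlin instantiation per parameter set.
INSTANTIATION = Template(
    "template __global__ void Marlin<\n"
    "    $a_type, $b_type, $c_type,\n"
    "    $s_type, $threads, $m_blocks, $n_blocks, $k_blocks, "
    "$m_block_size_8, $stages, $group_blocks, $is_zp_float\n"
    ">(MARLIN_KERNEL_PARAMS);"
)


def cpp_bool(flag: bool) -> str:
    """Render a Python bool as a C++ boolean literal."""
    return "true" if flag else "false"


def render_instantiation(params: dict) -> str:
    """Fill the template; substitute() raises KeyError if a field is missing."""
    values = dict(params)
    values["m_block_size_8"] = cpp_bool(params["m_block_size_8"])
    values["is_zp_float"] = cpp_bool(params["is_zp_float"])
    return INSTANTIATION.substitute(values)


example = render_instantiation({
    "a_type": "vllm::kFloat16.id()",
    "b_type": "vllm::kU4.id()",
    "c_type": "vllm::kFloat16.id()",
    "s_type": "vllm::kFloat16.id()",
    "threads": 256, "m_blocks": 1, "n_blocks": 8, "k_blocks": 8,
    "m_block_size_8": False, "stages": 4, "group_blocks": -1,
    "is_zp_float": False,
})
print(example)
```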
Usage
This script is executed during the build process with the target CUDA architecture list as a command-line argument. It writes the generated .cu files and kernel_selector.h into the same directory as the script, and these files are then compiled as part of the vLLM CUDA extension.
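A minimal sketch of how that argument could be turned into the `ARCHS` list and the `SUPPORT_*` flags listed under Signature. It assumes the comma-separated form shown on this page (e.g. "80,89") and reads the flag comments literally: FP8 for SM89/SM120, Turing for SM75, and Ampere-and-newer for SM80 and above. The `parse_archs` function and the `ArchSupport` container are hypothetical, and the real script's parsing may differ.

```python
from dataclasses import dataclass, field


@dataclass
class ArchSupport:
    """Hypothetical container mirroring ARCHS and the SUPPORT_* flags."""
    archs: list[int] = field(default_factory=list)
    support_fp8: bool = False   # SM89 / SM120 FP8 kernels
    support_sm75: bool = False  # Turing kernels
    support_sm80: bool = False  # Ampere-and-newer kernels


def parse_archs(arg: str) -> ArchSupport:
    """Parse a comma-separated arch list such as "80,89" (hypothetical helper)."""
    result = ArchSupport()
    for item in arg.split(","):
        item = item.strip()
        if not item:
            continue
        # Assumption: also tolerate dotted forms such as "8.9" -> 89.
        arch = int(item.replace(".", ""))
        result.archs.append(arch)
        if arch in (89, 120):
            result.support_fp8 = True
        if arch == 75:
            result.support_sm75 = True
        if arch >= 80:
            result.support_sm80 = True
    return result
```

At build time the input would be `sys.argv[1]`, for example `parse_archs("80,89")`.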
Code Reference
Source Location
- Repository: vllm
- File: csrc/moe/marlin_moe_wna16/generate_kernels.py
- Lines: 1-306
Signature
# Constants
ARCHS = [] # List of target GPU architectures
SUPPORT_FP8 = False # Whether SM89/SM120 FP8 is supported
SUPPORT_SM75 = False # Whether Turing (SM75) is targeted
SUPPORT_SM80 = False # Whether Ampere+ (SM80) is targeted
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
QUANT_CONFIGS = [...] # List of quantization configuration dicts
def remove_old_kernels():
    """Remove previously generated kernel files and selector header."""
def generate_new_kernels():
    """Generate .cu kernel files and kernel_selector.h for all configurations."""
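The docstring for `remove_old_kernels` suggests a cleanup pass like the following minimal sketch. The glob patterns are assumptions derived from the names in the Outputs table (`sm*_kernel_*.cu` and `kernel_selector.h`), and the `target_dir` parameter is hypothetical; the Signature shows the real function taking no arguments.

```python
import glob
import os


def remove_old_kernels(target_dir: str) -> list[str]:
    """Delete previously generated kernel sources and the selector header.

    Returns the removed paths. The patterns below are assumptions based on
    the documented output names, not the script's confirmed logic.
    """
    patterns = ["sm*_kernel_*.cu", "kernel_selector.h"]
    removed = []
    for pattern in patterns:
        for path in sorted(glob.glob(os.path.join(target_dir, pattern))):
            os.remove(path)
            removed.append(path)
    return removed
```

Restricting the patterns to generated names leaves hand-written sources in the same directory untouched.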
Import
# This script is executed directly, not imported
# python csrc/moe/marlin_moe_wna16/generate_kernels.py "80,89"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| sys.argv[1] | string | Yes | Comma-separated list of CUDA architecture versions (e.g., "80,89,90") |
| QUANT_CONFIGS | list[dict] | Yes (hardcoded) | Quantization configurations defining b_type, thread_configs, thread_m_blocks, group_blocks |
| THREAD_CONFIGS | list[tuple] | Yes (hardcoded) | Thread block configurations as (thread_k, thread_n, threads) tuples |
| TEMPLATE | string | Yes (hardcoded) | Jinja2 template string for Marlin kernel instantiation |
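Each quantization config is expanded into concrete kernel variants. The following minimal sketch of that expansion assumes the `(thread_k, thread_n, threads)` layout documented for `THREAD_CONFIGS` and 16-element tiles, so `thread_k // 16` and `thread_n // 16` give the k- and n-block counts. It also assumes that the `0.5` entry in `THREAD_M_BLOCKS` means one m-block with an 8-row `m_block_size_8` flag set. These readings agree with the `256, 1, 8, 8, false` arguments in the Usage Examples instantiation but are not confirmed by this page; the `expand_config` helper and the sample dict (including its `group_blocks` values) are hypothetical.

```python
import itertools

THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]


def expand_config(cfg: dict) -> list[dict]:
    """Expand one quantization config dict into per-kernel template arguments."""
    variants = []
    for (thread_k, thread_n, threads), m_blocks, group_blocks in itertools.product(
        cfg["thread_configs"], cfg["thread_m_blocks"], cfg["group_blocks"]
    ):
        variants.append({
            "b_type": cfg["b_type"],
            "threads": threads,
            # Assumption: 0.5 encodes a single m-block of 8 rows instead of 16.
            "m_blocks": max(1, int(m_blocks)),
            "m_block_size_8": m_blocks == 0.5,
            "n_blocks": thread_n // 16,  # assumed 16-wide tiles
            "k_blocks": thread_k // 16,
            "group_blocks": group_blocks,
        })
    return variants


# Hypothetical config in the shape described by the Inputs table.
sample = {
    "b_type": "vllm::kU4",
    "thread_configs": THREAD_CONFIGS,
    "thread_m_blocks": THREAD_M_BLOCKS,
    "group_blocks": [-1, 2, 4, 8],
}
variants = expand_config(sample)
print(len(variants))  # 4 * 5 * 4 = 80
```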
Outputs
| Name | Type | Description |
|---|---|---|
| sm80_kernel_*.cu | CUDA source files | Generated kernel instantiation files for SM80+ architectures |
| sm75_kernel_*.cu | CUDA source files | Generated kernel instantiation files for SM75 (Turing) architecture |
| sm89_kernel_*.cu | CUDA source files | Generated kernel instantiation files for SM89 (FP8 activation) kernels |
| kernel_selector.h | C++ header | Generated if/else chain mapping runtime parameters to kernel function pointers |
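A minimal sketch of how an if/else dispatch chain like `kernel_selector.h` could be emitted for a single type combination. This page only states that the header maps runtime parameters to kernel function pointers, so the runtime parameter names (`threads`, `thread_m_blocks`, `m_block_size_8`, `thread_n_blocks`, `thread_k_blocks`, `group_blocks`), the `kernel` variable, the fixed `STAGES` value, and the `render_selector` helper are all hypothetical; the real header's conditions and ordering may differ.

```python
STAGES = 4  # assumption: matches the pipeline-stage argument in the example instantiation


def cpp_bool(flag: bool) -> str:
    """Render a Python bool as a C++ boolean literal."""
    return "true" if flag else "false"


def render_selector(variants: list[dict]) -> str:
    """Emit an if/else-if chain assigning one Marlin instantiation to `kernel`."""
    lines = []
    for i, v in enumerate(variants):
        keyword = "if" if i == 0 else "else if"
        cond = " && ".join([
            f"threads == {v['threads']}",
            f"thread_m_blocks == {v['m_blocks']}",
            f"m_block_size_8 == {cpp_bool(v['m_block_size_8'])}",
            f"thread_n_blocks == {v['n_blocks']}",
            f"thread_k_blocks == {v['k_blocks']}",
            f"group_blocks == {v['group_blocks']}",
        ])
        args = ", ".join(str(a) for a in [
            v["types"], v["threads"], v["m_blocks"], v["n_blocks"], v["k_blocks"],
            cpp_bool(v["m_block_size_8"]), STAGES, v["group_blocks"],
            cpp_bool(v["is_zp_float"]),
        ])
        lines.append(f"{keyword} ({cond})")
        lines.append(f"  kernel = Marlin<{args}>;")
    return "\n".join(lines) + "\n"


TYPES = "vllm::kFloat16.id(), vllm::kU4.id(), vllm::kFloat16.id(), vllm::kFloat16.id()"
sample_variants = [
    {"types": TYPES, "threads": 256, "m_blocks": 1, "m_block_size_8": False,
     "n_blocks": 8, "k_blocks": 8, "group_blocks": -1, "is_zp_float": False},
    {"types": TYPES, "threads": 128, "m_blocks": 1, "m_block_size_8": True,
     "n_blocks": 8, "k_blocks": 4, "group_blocks": -1, "is_zp_float": False},
]
print(render_selector(sample_variants))
```

Generating the chain from the same variant list as the .cu files keeps every dispatch branch in step with an instantiation that actually exists.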
Usage Examples
# Build-time invocation from command line
# python csrc/moe/marlin_moe_wna16/generate_kernels.py "80,89"
# The script generates files like:
# sm80_kernel_float16_u4_float16.cu
# sm80_kernel_float16_u4b8_float16.cu
# sm89_kernel_fe4m3fn_u4b8_float16.cu
# kernel_selector.h
# Generated kernel instantiation example (inside .cu file):
# template __global__ void Marlin<
# vllm::kFloat16.id(), vllm::kU4.id(), vllm::kFloat16.id(),
# vllm::kFloat16.id(), 256, 1, 8, 8, false, 4, -1, false
# >(MARLIN_KERNEL_PARAMS);
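The example file names follow an `sm{arch}_kernel_{a}_{b}_{c}.cu` pattern, where the first type field is the activation type (e.g. `fe4m3fn` for the SM89 FP8 file), the second is the weight type, and the third appears to be the output type. A minimal sketch of grouping instantiations into one file per such combination; the grouping key, the `group_into_files` helper, and the sample tuples are assumptions, not the script's confirmed logic.

```python
from collections import defaultdict


def group_into_files(instances):
    """Map each generated file name to the instantiation strings it will hold.

    `instances` is an iterable of (arch, a_type, b_type, c_type, code) tuples;
    the naming pattern is inferred from the example file names (assumption).
    """
    files = defaultdict(list)
    for arch, a_type, b_type, c_type, code in instances:
        name = f"sm{arch}_kernel_{a_type}_{b_type}_{c_type}.cu"
        files[name].append(code)
    return dict(files)


# Hypothetical instantiations; the code strings are placeholders.
instances = [
    (80, "float16", "u4", "float16", "/* variant 1 */"),
    (80, "float16", "u4", "float16", "/* variant 2 */"),
    (80, "float16", "u4b8", "float16", "/* variant 3 */"),
    (89, "fe4m3fn", "u4b8", "float16", "/* variant 4 */"),
]
for name, bodies in sorted(group_into_files(instances).items()):
    print(name, len(bodies))
```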