Jump to content

Connect SuperML | Leeroopedia MCP: Equip your AI agents with best practices, code verification, and debugging knowledge. Powered by Leeroo — building Organizational Superintelligence. Contact us at founders@leeroo.com.

Implementation:Vllm project Vllm Marlin MoE Template

From Leeroopedia


Knowledge Sources
Domains MoE, Quantization, GEMM, CUDA Kernels
Last Updated 2026-02-08 00:00 GMT

Overview

Implements the Marlin GEMM kernel template for Mixture-of-Experts workloads with quantized weights and 16-bit activations, targeting NVIDIA Turing (SM75) and later GPUs.

Description

This header contains the core CUDA kernel template for Marlin MoE GEMM operations, adapted from the IST-DASLab Marlin project and modified by Neural Magic. It extends the standard Marlin GEMM with MoE-specific features including sorted token IDs, expert routing via expert_ids, topk weight application, and atomic accumulation for parallel expert processing. The kernel supports multiple quantization formats (INT4/INT8/FP8/FP4) with configurable thread block sizes, pipeline stages, and group quantization parameters. It uses inline PTX assembly for ldmatrix instructions and tensor core MMA operations.

Usage

This header is included by the generated .cu kernel files produced by generate_kernels.py. It provides the Marlin kernel template that is explicitly instantiated with specific type and configuration parameters for each supported combination of activation type, weight type, and thread block configuration.

Code Reference

Source Location

Signature

// Prototypes exposed by the Marlin MoE kernel template header, as excerpted on
// this page. Only declarations are shown here; no bodies are visible.
namespace marlin_moe_wna16 {

// Marlin GEMM kernel for Mixture-of-Experts workloads: consumes activations A
// and quantized weights B and produces C, with MoE routing driven by the
// sorted-token-id / expert-id arrays.
//
// NOTE(review): the usage example further down this page instantiates Marlin
// with a different template parameter list (a_type_id, b_type_id, c_type_id,
// s_type_id, ..., and no has_act_order). This excerpt and the example may come
// from different revisions of the header -- confirm against the real source.
template <typename scalar_t,                  // presumably the 16-bit activation element type -- TODO confirm
          const vllm::ScalarTypeId b_type_id, // id of the quantized weight type
          const int threads,                  // threads per thread block
          const int thread_m_blocks,          // per-block tile extent along M, in blocks -- unit not shown here
          const int thread_n_blocks,          // per-block tile extent along N, in blocks -- unit not shown here
          const int thread_k_blocks,          // per-block tile extent along K, in blocks -- unit not shown here
          const bool m_block_size_8,          // presumably selects an 8-row M block variant -- TODO confirm
          const int stages,                   // number of pipeline stages
          const bool has_act_order,           // presumably enables act_order handling via g_idx -- TODO confirm
          const int group_blocks,             // group-quantization granularity; exact encoding not shown here
          const bool is_zp_float>             // presumably: zero-points stored as floating point -- TODO confirm
__global__ void Marlin(
    // Input activation matrix, shape (M, K).
    const int4* __restrict__ A,
    // Quantized weight matrix, shape (K, N), in packed format.
    const int4* __restrict__ B,
    // Output buffer, shape (M, N).
    int4* __restrict__ C,
    // Temporary fp32 output buffer (per this page: used with fp32 reduction /
    // atomic add).
    int4* __restrict__ C_tmp,
    // Quantization scales, shape (K/groupsize, N).
    const int4* __restrict__ scales_ptr,
    // Optional packed zero-points, shape (K/groupsize, N/pack_factor).
    const int4* __restrict__ zp_ptr,
    // Optional group indices, shape (K), used for act_order reordering.
    const int* __restrict__ g_idx,
    // MoE sorted token IDs used for expert routing.
    const int32_t* __restrict__ sorted_token_ids_ptr,
    // MoE expert IDs: which expert processes each token group.
    const int32_t* __restrict__ expert_ids_ptr,
    // Padded MoE token count (exact layout not shown in this excerpt).
    const int32_t* __restrict__ num_tokens_past_padded_ptr,
    // Optional top-k routing weights for combining expert outputs.
    const float* __restrict__ topk_weights_ptr,
    // Number of experts per token.
    int top_k,
    // Presumably: whether to multiply outputs by topk_weights_ptr -- TODO confirm.
    bool mul_topk_weights,
    // Presumably: number of quantization groups -- TODO confirm.
    int num_groups,
    // Problem dimensions (batch, output, reduction).
    int prob_m,
    int prob_n,
    int prob_k,
    // Global synchronization buffer for thread block coordination.
    int* locks,
    // Presumably: accumulate partial results with atomic adds -- TODO confirm.
    bool use_atomic_add,
    // Presumably: perform the cross-block reduction in fp32 -- TODO confirm.
    bool use_fp32_reduce
);

// Helper device functions.
// NOTE(review): FragA / FragB / FragS are declared elsewhere, and scalar_t /
// scalar_t2 are not template parameters of the helpers as written below --
// presumably they are derived from type_id in the real header; confirm.

// Presumably wraps the ldmatrix instruction to load `count` fragments from
// shared memory at smem_ptr into frag_a -- TODO confirm.
template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(FragA& frag_a, const void* smem_ptr);

// Presumably applies the i-th scale from frag_s to frag_b -- TODO confirm.
template <vllm::ScalarTypeId type_id>
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i);

// Presumably scales frag_b by s and subtracts zero-point zp; the order of the
// two operations is not visible here -- TODO confirm.
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub(FragB& frag_b, scalar_t s, scalar_t zp);

// Presumably subtracts the i-th zero-point of frag_zp from frag_b -- TODO confirm.
template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(FragB& frag_b, scalar_t2& frag_zp, int i);

} // namespace marlin_moe_wna16

Import

#include "marlin_template.h"

I/O Contract

Inputs

Name Type Required Description
A const int4* Yes Input activation matrix of shape (M, K) in fp16/bf16/int8/fp8 format
B const int4* Yes Quantized weight matrix of shape (K, N) in packed format
scales_ptr const int4* Yes Quantization scales of shape (K/groupsize, N) in fp16/bf16
zp_ptr const int4* No Packed zero-points of shape (K/groupsize, N/pack_factor)
g_idx const int* No Group indices of shape (K) for act_order reordering
sorted_token_ids_ptr const int32_t* Yes MoE sorted token IDs for expert routing
expert_ids_ptr const int32_t* Yes MoE expert IDs indicating which expert processes each token group
num_tokens_past_padded_ptr const int32_t* Yes Pointer to the total number of routed tokens after padding (a single value, not one entry per expert)
topk_weights_ptr const float* No Top-k routing weights for weighted expert output combination
top_k int Yes Number of experts per token
prob_m, prob_n, prob_k int Yes Problem dimensions M, N, K (batch, output, reduction), matching the (M, K), (K, N) and (M, N) shapes of A, B and C
locks int* Yes Global synchronization buffer for thread block coordination

Outputs

Name Type Description
C int4* Output buffer of shape (M, N) in fp16/bf16 format
C_tmp int4* Temporary fp32 output buffer used when fp32 reduction or atomic add is enabled

Usage Examples

// This template is instantiated by generated .cu files:
// (from sm80_kernel_float16_u4_float16.cu)
#include "kernel.h"
#include "marlin_template.h"

// Explicit instantiation of the Marlin kernel template for one configuration:
// fp16 activations, u4 weights, fp16 output and fp16 scales.
namespace marlin_moe_wna16 {

// MARLIN_KERNEL_PARAMS is a macro defined outside this snippet (presumably in
// kernel.h) that expands to the kernel's function parameter list -- confirm.
// NOTE(review): this template argument list (four type ids, no has_act_order)
// differs from the template parameters shown in the Signature section of this
// page; the two may reflect different revisions of the header.
template __global__ void Marlin<
    vllm::kFloat16.id(),    // a_type_id
    vllm::kU4.id(),         // b_type_id
    vllm::kFloat16.id(),    // c_type_id
    vllm::kFloat16.id(),    // s_type_id
    256,                     // threads
    1,                       // thread_m_blocks
    8,                       // thread_n_blocks
    8,                       // thread_k_blocks
    false,                   // m_block_size_8
    4,                       // stages
    -1,                      // group_blocks
    false                    // is_zp_float
>(MARLIN_KERNEL_PARAMS);

} // namespace marlin_moe_wna16

Related Pages

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment