Implementation:Vllm project Vllm Marlin Template

From Leeroopedia


Knowledge Sources
Domains: Quantization, Marlin, GEMM
Last Updated: 2026-02-08 00:00 GMT

Overview

Core Marlin GEMM kernel template implementing high-performance mixed-precision quantized matrix multiplication using tensor cores, async copy pipelines, and shared memory optimization on Turing/Ampere/Hopper GPUs.
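The multi-stage async copy pipeline can be pictured as a ring of `stages` shared-memory buffers that are refilled while earlier tiles are consumed. The host-side C++ sketch below models only that schedule, assuming k-tile `t` occupies slot `t % stages`. `PipelineEvent`, `pipeline_schedule`, and `schedule_is_safe` are hypothetical names, and the real kernel's `cp.async` issue, wait, and barrier logic is not modeled.

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of a `stages`-deep global->shared copy pipeline.
// Each event records a k-tile being fetched ('L') or consumed ('C') and the
// shared-memory slot it uses. No real async copies or barriers happen here.
struct PipelineEvent {
  char kind;  // 'L' = issue async load, 'C' = compute on a resident tile
  int tile;   // k-tile index
  int slot;   // shared-memory buffer slot (assumed tile % stages)
};

std::vector<PipelineEvent> pipeline_schedule(int num_k_tiles, int stages) {
  std::vector<PipelineEvent> events;
  // Prologue: prefetch the first (stages - 1) tiles.
  for (int t = 0; t < stages - 1 && t < num_k_tiles; ++t)
    events.push_back({'L', t, t % stages});
  // Steady state: start loading tile t + stages - 1, then compute on tile t.
  // The slot being refilled held tile t - 1, which was consumed last iteration.
  for (int t = 0; t < num_k_tiles; ++t) {
    int next = t + stages - 1;
    if (next < num_k_tiles) events.push_back({'L', next, next % stages});
    events.push_back({'C', t, t % stages});
  }
  return events;
}

// Checks that every tile is loaded before it is consumed and that a slot is
// never reloaded while it still holds an unconsumed tile.
bool schedule_is_safe(const std::vector<PipelineEvent>& ev, int stages) {
  std::vector<int> occupant(stages, -1);  // tile held by each slot, -1 if free
  for (const auto& e : ev) {
    if (e.kind == 'L') {
      if (occupant[e.slot] != -1) return false;  // would clobber a live tile
      occupant[e.slot] = e.tile;
    } else {
      if (occupant[e.slot] != e.tile) return false;  // tile not resident yet
      occupant[e.slot] = -1;
    }
  }
  return true;
}
```

Prefetching `stages - 1` tiles up front lets each compute step overlap with up to `stages - 1` loads in flight; a deeper pipeline hides more memory latency at the cost of more shared memory.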

Description

This header implements the Marlin global kernel function template, which computes C = A @ B, where the activations A are in FP16/BF16/INT8/FP8 and the weights B are quantized to INT4/INT8/FP4/FP8. The kernel uses multi-stage asynchronous global-to-shared memory copy pipelines, ldmatrix instructions that load operand fragments directly in tensor core layout, and grouped or channelwise dequantization with optional activation reordering (act_order). It includes the helper device functions ldsm (loading matrix fragments from shared memory), scale (applying quantization scales), sub_zp (subtracting zero points), and scale4 (per-element scaling for act_order). The kernel is parameterized on scalar types, thread block dimensions, pipeline stages, group block sizes, and the M-block configuration.
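To make the computed result concrete, the sketch below is a plain float reference for the math (not the implementation) in the GPTQ INT4 case, assuming the `vllm::kU4B8` weight type means an unsigned 4-bit value with bias 8, so w = (q - 8) * scale. The function name `reference_marlin_gemm` is hypothetical; weights are stored one value per byte in row-major (k x n) order, so Marlin's actual packed and permuted B layout is deliberately not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical float reference for grouped/channelwise INT4 dequant + GEMM.
// group_size <= 0 is treated as channelwise (a single scale row).
// If g_idx is given (act_order), it overrides the k / group_size mapping.
std::vector<float> reference_marlin_gemm(
    const std::vector<float>& A,     // m x k activations, row-major
    const std::vector<uint8_t>& Bq,  // k x n quantized values in [0, 15]
    const std::vector<float>& S,     // num_groups x n scales, row-major
    const std::vector<int>* g_idx,   // optional group index per k row
    int m, int n, int k, int group_size) {
  std::vector<float> C(static_cast<std::size_t>(m) * n, 0.0f);
  for (int kk = 0; kk < k; ++kk) {
    int g = g_idx ? (*g_idx)[kk] : (group_size > 0 ? kk / group_size : 0);
    for (int j = 0; j < n; ++j) {
      // Dequantize: remove the bias of 8, then apply the group's scale.
      float w = (static_cast<int>(Bq[kk * n + j]) - 8) * S[g * n + j];
      for (int i = 0; i < m; ++i) C[i * n + j] += A[i * k + kk] * w;
    }
  }
  return C;
}
```

A reference like this is useful for checking a quantized kernel's output on small shapes, since it spells out which scale row each K index uses.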

Usage

This template is instantiated by the auto-generated kernel files produced by generate_kernels.py. Each instantiation specifies concrete template parameters for a specific quantization format, thread configuration, and target architecture. The resulting kernels are called from the Marlin Python interface during quantized model inference.
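Each instantiation fixes the threadblock tile shape at compile time. The following sketch derives that shape from the template parameters under stated assumptions: each `*_blocks` unit is 16 elements, `m_block_size_8` shrinks the M tile to 8 rows, and `group_blocks == -1` denotes channelwise scales. `TileShape`, `marlin_tile_shape`, and `ceil_div` are hypothetical helpers, not vLLM functions.

```cpp
#include <cassert>

// Hypothetical per-threadblock tile geometry implied by the template params.
struct TileShape {
  int tile_m, tile_n, tile_k;
  int group_size;  // K rows sharing one scale; -1 for channelwise (assumed)
};

constexpr TileShape marlin_tile_shape(int thread_m_blocks, int thread_n_blocks,
                                      int thread_k_blocks, bool m_block_size_8,
                                      int group_blocks) {
  return TileShape{
      m_block_size_8 ? 8 : 16 * thread_m_blocks,  // half-height M block
      16 * thread_n_blocks,
      16 * thread_k_blocks,
      group_blocks == -1 ? -1 : 16 * group_blocks};
}

// Number of tiles needed to cover a dimension, rounding partial tiles up.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// The configuration from the usage example (1, 8, 8, false, 8) would give a
// 16 x 128 output tile, a 128-deep K tile, and a group size of 128.
static_assert(marlin_tile_shape(1, 8, 8, false, 8).tile_n == 128,
              "8 N-blocks of 16 columns");
```

Because these values are compile-time constants, shared-memory sizes and loop trip counts can be fixed per instantiation, which is one reason a separate kernel is generated for each configuration.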

Code Reference

Source Location

Signature

namespace MARLIN_NAMESPACE_NAME {

template <typename scalar_t,             // compute dtype (half or nv_bfloat16)
          const vllm::ScalarTypeId b_type_id,  // weight quantization type
          const vllm::ScalarTypeId s_type_id,  // scale type
          const int threads,             // threads per threadblock
          const int thread_m_blocks,     // 16x16 blocks in M dimension
          const int thread_n_blocks,     // 16x16 blocks in N dimension
          const int thread_k_blocks,     // 16x16 blocks in K dimension
          const bool m_block_size_8,     // use an 8-row M block (only with thread_m_blocks == 1)
          const int stages,              // async pipeline stages
          const bool has_act_order,      // activation reordering enabled
          const int group_blocks,        // 16-row K blocks per scale group (-1 = channelwise)
          const bool is_zp_float         // zero points stored as float16
          >
__global__ void Marlin(
    const int4* __restrict__ A,          // input matrix (mxk)
    const int4* __restrict__ B,          // quantized weight matrix (kxn)
    int4* __restrict__ C,                // output buffer (mxn)
    int4* __restrict__ C_tmp,            // fp32 tmp output buffer (for reduce)
    const int4* __restrict__ scales_ptr, // quantization scales ((k/groupsize)xn)
    const int* __restrict__ g_idx,       // group indices (k)
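    int num_groups,                      // scale groups per output channel
    int prob_m,                          // rows of A and C (M)
    int prob_n,                          // columns of B and C (N)
    int prob_k,                          // reduction dimension (K)
    int* locks,                          // global locks for inter-threadblock barriers
    bool use_fp32_reduce                 // reduce partial results in FP32 via C_tmp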
);

// Load 16x16 matrix fragment from shared memory
template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(
    typename MarlinScalarType<type_id>::FragA& frag_a,
    const void* smem_ptr);

// Scale dequantized values by quantization scale
template <vllm::ScalarTypeId type_id>
__device__ inline void scale(
    typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::FragS& frag_s,
    int i);

// Scale and subtract zero point
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub(
    typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::scalar_t s,
    typename MarlinScalarType<type_id>::scalar_t zp);

// Subtract zero point from dequantized fragment
template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(
    typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::scalar_t2& frag_zp,
    int i);

} // namespace MARLIN_NAMESPACE_NAME
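The dequantization helpers compose simple element-wise steps. The scalar analogs below are a sketch only: a fragment is modeled as four floats instead of packed half2/bfloat162 registers, and the names `FragB`, `sub_zp_ref`, `scale_ref`, and `scale_and_sub_ref` are hypothetical. The fused variant shows how (q - zp) * s can be rewritten as one FMA once zp * s is known; it is illustrative and not claimed to match the exact arithmetic of the kernel's scale_and_sub.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Hypothetical model of a B fragment: four dequantized values.
using FragB = std::array<float, 4>;

// sub_zp analog: subtract one zero point from every element.
void sub_zp_ref(FragB& b, float zp) {
  for (float& v : b) v -= zp;
}

// scale analog: multiply every element by one quantization scale.
void scale_ref(FragB& b, float s) {
  for (float& v : b) v *= s;
}

// Fused analog: (q - zp) * s == fma(q, s, -(zp * s)), one FMA per element
// instead of a subtract followed by a multiply.
void scale_and_sub_ref(FragB& b, float s, float zp_times_s) {
  for (float& v : b) v = std::fma(v, s, -zp_times_s);
}
```

Folding the zero point into the FMA's addend is a common way to save an instruction per element in dequantization loops.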

Import

#include "marlin_template.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| A | const int4* | Yes | Input activation matrix of shape (m x k), in FP16/BF16/INT8/FP8 |
| B | const int4* | Yes | Quantized weight matrix of shape (k x n), packed in the target quantization format |
| scales_ptr | const int4* | Yes | Quantization scale factors of shape ((k/groupsize) x n) |
| g_idx | const int* | No | Group indices of shape (k), required when has_act_order is true |
| num_groups | int | Yes | Number of quantization scale groups per output channel |
| prob_m | int | Yes | Number of rows of A and C (batch dimension M) |
| prob_n | int | Yes | Number of columns of B and C (output dimension N) |
| prob_k | int | Yes | Reduction dimension K (columns of A, rows of B) |
| locks | int* | Yes | Global-memory lock array used for barrier synchronization between threadblocks |
| use_fp32_reduce | bool | Yes | Whether partial results are reduced across threadblocks in FP32 (through C_tmp) instead of FP16 |
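The relationship between `prob_k`, `num_groups`, `g_idx`, and the ((k/groupsize) x n) scale layout can be sketched as follows. This is a conceptual model under assumptions: scales are viewed as a row-major (num_groups x n) array, a non-positive group size means channelwise (one group), and a non-empty `g_idx` replaces the contiguous k / group_size mapping. `num_scale_groups`, `scale_group_for_k`, and `scale_offset` are hypothetical helpers, and prob_k is assumed divisible by the group size.

```cpp
#include <cassert>
#include <vector>

// Number of scale rows: prob_k / group_size, or 1 for channelwise.
int num_scale_groups(int prob_k, int group_size) {
  return group_size > 0 ? prob_k / group_size : 1;
}

// Scale row used for K index k. With act_order, g_idx gives an arbitrary
// row-to-group mapping; otherwise groups are contiguous runs of K rows.
int scale_group_for_k(int k, int group_size, const std::vector<int>& g_idx) {
  if (!g_idx.empty()) return g_idx[k];
  return group_size > 0 ? k / group_size : 0;
}

// Flat offset of the scale applied to weight element (k, col).
int scale_offset(int k, int col, int prob_n, int group_size,
                 const std::vector<int>& g_idx) {
  return scale_group_for_k(k, group_size, g_idx) * prob_n + col;
}
```

For example, with prob_k = 4096 and a group size of 128 there are 32 scale rows, and K index 130 uses row 1.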

Outputs

| Name | Type | Description |
|------|------|-------------|
| C | int4* | Output matrix of shape (m x n), in FP16/BF16 matching the compute type |
| C_tmp | int4* | Temporary FP32 output buffer used when use_fp32_reduce is true |
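Why `use_fp32_reduce` matters can be shown with a simplified numeric model of reducing one output element across threadblocks that split K. Assumptions: FP16 is modeled only by its 11-bit significand with ties-to-even rounding (range limits and subnormals ignored), the FP16 path rounds the running sum after every partial (as when accumulating through C), and the FP32 path rounds once at the end (the role C_tmp plays). `round_to_fp16_precision` and `reduce_partials` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Round a float to FP16's 11-bit significand, ties-to-even (default
// rounding mode). Exponent range and subnormals are not modeled.
float round_to_fp16_precision(float x) {
  if (x == 0.0f) return x;
  int exp = 0;
  float m = std::frexp(x, &exp);          // x = m * 2^exp, 0.5 <= |m| < 1
  float r = std::nearbyint(m * 2048.0f);  // keep 11 significant bits
  return std::ldexp(r, exp - 11);
}

// Reduce per-threadblock partial sums for one output element.
float reduce_partials(const std::vector<float>& partials, bool fp32_reduce) {
  float acc = 0.0f;
  for (float p : partials) {
    acc += p;
    if (!fp32_reduce) acc = round_to_fp16_precision(acc);  // FP16 round-trip
  }
  return round_to_fp16_precision(acc);  // final output is written as FP16
}
```

With partials {2048, 1, 1, 1, 1}, each +1 is lost when rounding to FP16 spacing near 2048, so the FP16 path returns 2048 while the FP32 path returns 2052. The trade-off is an extra FP32 buffer and the extra memory traffic to it.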

Usage Examples

// Kernel instantiation (auto-generated by generate_kernels.py)
template __global__ void Marlin<
    vllm::kFloat16.id(),  // a_type_id
    vllm::kU4B8.id(),     // b_type_id (GPTQ INT4)
    vllm::kFloat16.id(),  // c_type_id
    vllm::kFloat16.id(),  // s_type_id
    256,                   // threads
    1,                     // thread_m_blocks
    8,                     // thread_n_blocks
    8,                     // thread_k_blocks
    false,                 // m_block_size_8
    4,                     // stages
    8,                     // group_blocks
    false                  // is_zp_float
>(MARLIN_KERNEL_PARAMS);
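The generated line above is a C++ explicit instantiation whose runtime parameter list comes from a macro. The toy below reproduces that pattern on the host: `TOY_KERNEL_PARAMS` is a hypothetical stand-in for MARLIN_KERNEL_PARAMS, and `toy_num_tiles` is a hypothetical function template standing in for the kernel, with compile-time tile configuration and runtime problem sizes.

```cpp
#include <cassert>

// Hypothetical parameter-list macro, shared by the definition and every
// explicit instantiation so their signatures cannot drift apart.
#define TOY_KERNEL_PARAMS int prob_m, int prob_n, int prob_k

// Toy stand-in for the kernel: tile shape fixed at compile time.
template <int thread_m_blocks, int thread_n_blocks, int thread_k_blocks>
int toy_num_tiles(TOY_KERNEL_PARAMS) {
  constexpr int tile_m = 16 * thread_m_blocks;
  constexpr int tile_n = 16 * thread_n_blocks;
  (void)prob_k;  // unused in this toy
  return ((prob_m + tile_m - 1) / tile_m) * ((prob_n + tile_n - 1) / tile_n);
}

// Explicit instantiations, one line per configuration, in the same shape as
// the generated kernel files: each line emits object code for that variant.
template int toy_num_tiles<1, 8, 8>(TOY_KERNEL_PARAMS);
template int toy_num_tiles<4, 4, 8>(TOY_KERNEL_PARAMS);
```

Generating one such line per configuration lets the template definition live in a single header while the compiled variants are spread across several translation units.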
