Implementation:Vllm project Vllm Marlin Template
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Marlin, GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
The core Marlin GEMM kernel template. It implements high-performance mixed-precision quantized matrix multiplication using tensor cores, asynchronous copy pipelines, and shared-memory optimizations on Turing, Ampere, and Hopper GPUs.
Description
This header implements the Marlin global kernel function template, which computes the matrix product A @ B, where A is in FP16, BF16, INT8, or FP8 and B is quantized to INT4, INT8, FP4, or FP8. The kernel uses multi-stage asynchronous global-to-shared memory copy pipelines, ldmatrix instructions that load operand fragments directly in tensor core layout, and grouped or channelwise dequantization with optional activation reordering (act_order). The header also provides helper device functions: ldsm (loads matrix fragments from shared memory), scale (applies quantization scales), sub_zp (subtracts zero points), and scale4 (per-element scaling for act_order). The kernel is parameterized on scalar types, thread block dimensions, pipeline stages, group block sizes, and M-block configuration.
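The grouped dequantization described above can be illustrated with a small host-side reference. This is a minimal sketch, not code from marlin_template.h: the function name `dequant_ref` and its parameters are hypothetical, and it assumes the common integer-zero-point convention w = (q - zp) * s. The kernel's own helpers may apply scale and zero point in a different order for some formats.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference, not part of marlin_template.h.
// Dequantizes one weight code q located at row k of the K dimension:
//   w = (q - zero_point) * scale        (assumed convention)
// The scale group is k / group_size for plain grouped quantization, or
// g_idx[k] when activation reordering (act_order) supplies group indices.
inline float dequant_ref(int q, int k, int group_size,
                         const std::vector<float>& scales,
                         const std::vector<float>& zero_points,
                         const std::vector<int>* g_idx = nullptr) {
  assert(group_size > 0);
  const int group = g_idx ? (*g_idx)[k] : k / group_size;
  assert(group >= 0 && static_cast<std::size_t>(group) < scales.size());
  assert(static_cast<std::size_t>(group) < zero_points.size());
  return (static_cast<float>(q) - zero_points[group]) * scales[group];
}
```

With two groups of 128 rows, row 0 uses group 0 and row 128 uses group 1. Passing a `g_idx` vector overrides that mapping per row, which is the effect act_order has on scale lookup.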
Usage
This template is instantiated by the auto-generated kernel files produced by generate_kernels.py. Each instantiation specifies concrete template parameters for a specific quantization format, thread configuration, and target architecture. The resulting kernels are called from the Marlin Python interface during quantized model inference.
Code Reference
Source Location
- Repository: vllm
- File: csrc/quantization/marlin/marlin_template.h
- Lines: 1-2073
Signature
namespace MARLIN_NAMESPACE_NAME {
template <const vllm::ScalarTypeId a_type_id, // activation (A) type
const vllm::ScalarTypeId b_type_id, // weight quantization type
const vllm::ScalarTypeId c_type_id, // output (C) type
const vllm::ScalarTypeId s_type_id, // scale type
const int threads, // threads per threadblock
const int thread_m_blocks, // 16x16 blocks in M dimension
const int thread_n_blocks, // 16x16 blocks in N dimension
const int thread_k_blocks, // 16x16 blocks in K dimension
const bool m_block_size_8, // M block size of 8 (half-block)
const int stages, // async pipeline stages
const int group_blocks, // consecutive blocks per scale group
const bool is_zp_float // float16 zero points
>
__global__ void Marlin(
const int4* __restrict__ A, // input matrix (mxk)
const int4* __restrict__ B, // quantized weight matrix (kxn)
int4* __restrict__ C, // output buffer (mxn)
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // quantization scales ((k/groupsize)xn)
const int* __restrict__ g_idx, // group indices (k)
int num_groups,
int prob_m,
int prob_n,
int prob_k,
int* locks,
bool use_fp32_reduce
);
// Load 16x16 matrix fragment from shared memory
template <int count, vllm::ScalarTypeId type_id>
__device__ inline void ldsm(
typename MarlinScalarType<type_id>::FragA& frag_a,
const void* smem_ptr);
// Scale dequantized values by quantization scale
template <vllm::ScalarTypeId type_id>
__device__ inline void scale(
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragS& frag_s,
int i);
// Scale and subtract zero point
template <vllm::ScalarTypeId type_id>
__device__ inline void scale_and_sub(
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::scalar_t s,
typename MarlinScalarType<type_id>::scalar_t zp);
// Subtract zero point from dequantized fragment
template <vllm::ScalarTypeId type_id>
__device__ inline void sub_zp(
typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::scalar_t2& frag_zp,
int i);
} // namespace MARLIN_NAMESPACE_NAME
Import
#include "marlin_template.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | const int4* | Yes | Input activation matrix of shape (m x k), in FP16/BF16/INT8/FP8 |
| B | const int4* | Yes | Quantized weight matrix of shape (k x n), packed in the target quantization format |
| scales_ptr | const int4* | Yes | Quantization scale factors of shape ((k/groupsize) x n) |
| g_idx | const int* | No | Group indices of shape (k), required only when activation reordering (act_order) is enabled |
| num_groups | int | Yes | Number of quantization scale groups per output channel |
| prob_m | int | Yes | Batch dimension M |
| prob_n | int | Yes | Output dimension N |
| prob_k | int | Yes | Reduction dimension K |
| locks | int* | Yes | Global storage for barrier synchronization between threadblocks |
| use_fp32_reduce | bool | Yes | Whether to use FP32 accumulation for the global reduction step |
Outputs
| Name | Type | Description |
|---|---|---|
| C | int4* | Output matrix of shape (m x n), in FP16/BF16 matching the compute type |
| C_tmp | int4* | Temporary FP32 output buffer used when use_fp32_reduce is true |
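The shapes in the tables above can be checked against a plain CPU reference. The sketch below is an illustration only: `marlin_ref_gemm` is a hypothetical name, all buffers are row-major, B holds already-unpacked integer codes rather than packed int4 words, and zero points and act_order are omitted for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the I/O contract (not vLLM code).
// A is m x k, B is k x n (unpacked integer codes), scales is
// (k / group_size) x n, and the returned C is m x n.
inline std::vector<float> marlin_ref_gemm(const std::vector<float>& A,
                                          const std::vector<int>& B,
                                          const std::vector<float>& scales,
                                          int m, int n, int k,
                                          int group_size) {
  assert(group_size > 0 && k % group_size == 0);
  assert(A.size() == static_cast<std::size_t>(m) * k);
  assert(B.size() == static_cast<std::size_t>(k) * n);
  assert(scales.size() == static_cast<std::size_t>(k / group_size) * n);

  std::vector<float> C(static_cast<std::size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;  // accumulate in FP32
      for (int p = 0; p < k; ++p) {
        // Each column j has one scale per group of group_size rows.
        const float w = static_cast<float>(B[p * n + j]) *
                        scales[(p / group_size) * n + j];
        acc += A[i * k + p] * w;
      }
      C[i * n + j] = acc;
    }
  }
  return C;
}
```

A reference like this is mainly useful for comparing kernel output on small problems, where the assertions also document which buffer sizes the contract implies.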
Usage Examples
// Kernel instantiation (auto-generated by generate_kernels.py)
template __global__ void Marlin<
vllm::kFloat16.id(), // a_type_id
vllm::kU4B8.id(), // b_type_id (GPTQ INT4)
vllm::kFloat16.id(), // c_type_id
vllm::kFloat16.id(), // s_type_id
256, // threads
1, // thread_m_blocks
8, // thread_n_blocks
8, // thread_k_blocks
false, // m_block_size_8
4, // stages
8, // group_blocks
false // is_zp_float
>(MARLIN_KERNEL_PARAMS);
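To make the block counts in this instantiation concrete, the following sketch converts them into element counts. It is a hypothetical helper, not part of vLLM: `TileShape` and `tile_shape` are illustrative names, and the arithmetic assumes that each block spans 16 elements (as the signature comments state), that m_block_size_8 halves the M block to 8 rows, and that group_blocks is positive.

```cpp
// Hypothetical helper (not vLLM code): element counts implied by the
// block-count template parameters, assuming 16 elements per block.
struct TileShape {
  int m;           // rows of A covered by one threadblock tile
  int n;           // columns of B covered by one threadblock tile
  int k;           // reduction extent covered per tile step
  int group_size;  // rows of B sharing one scale (assumes group_blocks > 0)
};

constexpr TileShape tile_shape(int thread_m_blocks, int thread_n_blocks,
                               int thread_k_blocks, bool m_block_size_8,
                               int group_blocks) {
  return TileShape{
      m_block_size_8 ? 8 : 16 * thread_m_blocks,
      16 * thread_n_blocks,
      16 * thread_k_blocks,
      16 * group_blocks,
  };
}

// Values taken from the example instantiation.
constexpr TileShape kExample = tile_shape(1, 8, 8, false, 8);
static_assert(kExample.m == 16 && kExample.n == 128 && kExample.k == 128,
              "tile extents for the example instantiation");
static_assert(kExample.group_size == 128, "group size for group_blocks = 8");
```

Under these assumptions the example corresponds to a 16 x 128 output tile with a K step of 128 and a quantization group size of 128.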