Implementation:Sgl project Sglang Marlin GEMM Template
| Knowledge Sources | |
|---|---|
| Domains | CUDA_Kernels, Quantized_GEMM, Marlin_GEMM |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Full CUDA kernel template implementing the Marlin quantized GEMM algorithm with on-the-fly weight dequantization, tensor core MMA, and asynchronous memory pipelines for INT4, INT8, FP4, and FP8 quantized models.
Description
The marlin_template.h header contains the complete implementation of the Marlin __global__ CUDA kernel, adapted from the IST-DASLab Marlin project and modified by Neural Magic. The kernel is heavily templated, with the following parameters:
- scalar_t -- compute data type (half or nv_bfloat16)
- w_type_id -- weight quantization type via sglang::ScalarTypeId
- threads -- number of threads per threadblock
- thread_m_blocks / thread_n_blocks / thread_k_blocks -- threadblock tile dimensions along M, N, and K, in units of 16x16 blocks
- m_block_size_8 -- use an M-block of 8 elements instead of 16 (valid only when thread_m_blocks == 1)
- stages -- number of pipeline stages for async global-to-shared fetch
- has_act_order -- enable GPTQ-style activation reordering
- group_blocks -- number of consecutive blocks sharing a quantization scale (-1 selects channel-wise scales)
- is_zp_float -- whether zero points are floating-point type
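The tile parameters above can be turned into element counts with a little arithmetic. The following is a minimal host-side sketch, assuming each block unit spans 16 elements along its dimension and that m_block_size_8 shrinks the M tile to 8 rows. TileShape and tile_shape are hypothetical names used for illustration; they do not appear in marlin_template.h.

```cpp
// Hypothetical helper, not part of marlin_template.h.
struct TileShape {
  int m;  // rows of the activation tile handled by one threadblock
  int n;  // output columns handled by one threadblock
  int k;  // reduction depth consumed per pipeline stage
};

// Assumption: one block unit covers 16 elements, per the "16x16 blocks"
// wording in the parameter list; m_block_size_8 selects an 8-row M tile.
constexpr TileShape tile_shape(int thread_m_blocks, int thread_n_blocks,
                               int thread_k_blocks, bool m_block_size_8) {
  return TileShape{m_block_size_8 ? 8 : 16 * thread_m_blocks,
                   16 * thread_n_blocks,
                   16 * thread_k_blocks};
}
```

With thread_m_blocks = 1, thread_n_blocks = 4, and thread_k_blocks = 8, this sketch yields a 16 x 64 x 128 tile, or 8 x 64 x 128 when m_block_size_8 is set.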
The file provides a mma_sp helper function that issues m16n8k16 tensor core MMA instructions via inline PTX assembly. When compiling for architectures below SM80 (compute capability 8.0), only an empty stub is provided; the full implementation targets SM80 and newer.
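To make the m16n8k16 shape concrete, here is a scalar reference for the arithmetic that one such MMA performs: D = A * B + C with A of size 16x16, B of size 16x8, and C/D of size 16x8. This is only a sketch of the math. It uses float everywhere and row-major storage, and it ignores how fragments are distributed across the threads of a warp, which the real helper handles through the FragA, FragB, and FragC types. The function name is hypothetical.

```cpp
// Shape of a single m16n8k16 MMA.
constexpr int kM = 16;
constexpr int kN = 8;
constexpr int kK = 16;

// Hypothetical scalar reference: c (16x8) += a (16x16) * b (16x8).
// All matrices are row-major; float stands in for the half/bfloat16 inputs.
void mma_m16n8k16_reference(const float* a, const float* b, float* c) {
  for (int i = 0; i < kM; ++i) {
    for (int j = 0; j < kN; ++j) {
      float acc = c[i * kN + j];  // start from the existing accumulator
      for (int k = 0; k < kK; ++k) {
        acc += a[i * kK + k] * b[k * kN + j];
      }
      c[i * kN + j] = acc;
    }
  }
}
```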
Key algorithmic features include asynchronous global-to-shared memory pipelines using cp.async, register-tiled matrix operations with tensor cores, configurable split-K with lock-based synchronization (locks array), optional FP32 reduction of split-K partial results through C_tmp for numerical stability (use_fp32_reduce), and support for activation reordering via g_idx permutation indices.
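The lock-based split-K reduction can be pictured with a small single-threaded model. The sketch below assumes a scheme in which each slice of the K dimension waits until the tile's lock equals its own slice index, adds its partial sums, and then either increments the lock or, if it is the last slice, resets it to zero. That ordering is an assumption about how the locks array is used, and TileState and try_reduce are hypothetical names. A real kernel would spin on the lock in global memory rather than return false.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of one output tile shared by several split-K slices.
struct TileState {
  int lock;                // plays the role of one entry in the locks array
  std::vector<float> out;  // FP32 running sum, standing in for C_tmp
};

// Attempt to fold one slice's partial sums into the tile.
// Returns false when it is not yet this slice's turn.
bool try_reduce(TileState& tile, int slice_idx, int slice_count,
                const std::vector<float>& partial) {
  if (tile.lock != slice_idx) {
    return false;  // "acquire" fails: an earlier slice has not finished
  }
  for (std::size_t i = 0; i < partial.size(); ++i) {
    tile.out[i] += partial[i];
  }
  // "release": hand over to the next slice, or reset after the last one.
  tile.lock = (slice_idx + 1 == slice_count) ? 0 : tile.lock + 1;
  return true;
}
```

Keeping the running sum in FP32, as this model does, mirrors the intent of use_fp32_reduce: partial results are not rounded to 16 bits between slices.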
Usage
This kernel is instantiated by the Marlin dispatch code with specific template parameters matching the model's quantization configuration. It is the primary kernel for quantized weight-only GEMM during inference.
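Because the quantization configuration is only known at run time while the kernel parameters are compile-time constants, the dispatch code has to switch over a fixed set of instantiations. The sketch below illustrates that pattern for a single parameter. It assumes group_blocks is the group size divided by 16, with -1 passed through for channel-wise scales; group_blocks_for, kernel_variant, dispatch, and kUnsupported are hypothetical names, and kernel_variant merely stands in for a Marlin<...> instantiation.

```cpp
// Hypothetical sentinel for configurations with no compiled variant.
constexpr int kUnsupported = -2;

// Assumption: one block spans 16 elements along K; -1 means channel-wise.
constexpr int group_blocks_for(int group_size) {
  return group_size == -1 ? -1 : group_size / 16;
}

// Stand-in for a kernel instantiation; returns its compile-time parameter.
template <int group_blocks>
int kernel_variant() {
  return group_blocks;
}

// Map a runtime group size onto one of the precompiled variants.
int dispatch(int group_size) {
  if (group_size != -1 && group_size % 16 != 0) {
    return kUnsupported;  // not a whole number of 16-element blocks
  }
  switch (group_blocks_for(group_size)) {
    case -1: return kernel_variant<-1>();
    case 2:  return kernel_variant<2>();
    case 4:  return kernel_variant<4>();
    case 8:  return kernel_variant<8>();
    default: return kUnsupported;
  }
}
```

Only the listed cases are compiled, so an unsupported combination has to be rejected explicitly rather than silently falling through to a mismatched kernel.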
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/gemm/marlin/marlin_template.h
- Lines: 1-1629
Signature
namespace MARLIN_NAMESPACE_NAME {  // defaults to "marlin"

// m16n8k16 tensor core MMA via inline PTX
template <typename scalar_t>
__device__ inline void mma_sp(
    const typename ScalarType<scalar_t>::FragA& a_frag,
    const typename ScalarType<scalar_t>::FragB& b_frag,
    typename ScalarType<scalar_t>::FragC& c_frag);

template <
    typename scalar_t,                     // half or nv_bfloat16
    const sglang::ScalarTypeId w_type_id,  // weight type
    const int threads,                     // threads per block
    const int thread_m_blocks,             // M-dimension blocks
    const int thread_n_blocks,             // N-dimension blocks
    const int thread_k_blocks,             // K-dimension blocks
    const bool m_block_size_8,             // 8-element M-blocks
    const int stages,                      // pipeline stages
    const bool has_act_order,              // activation reordering
    const int group_blocks,                // quantization group size
    const bool is_zp_float                 // float zero points
    >
__global__ void Marlin(
    const int4* __restrict__ A,           // fp16 input [m, k]
    const int4* __restrict__ B,           // quantized weights [k, n]
    int4* __restrict__ C,                 // fp16 output [m, n]
    int4* __restrict__ C_tmp,             // fp32 tmp for reduce
    const int4* __restrict__ scales_ptr,  // scales [k/group, n]
    const int* __restrict__ g_idx,        // group indices [k]
    int num_groups,
    int prob_m,
    int prob_n,
    int prob_k,
    int* locks,
    bool use_fp32_reduce);

}  // namespace MARLIN_NAMESPACE_NAME
Import
#include "marlin_template.h"
// Headers the template depends on:
#include "dequant.h"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "scalar_type.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | const int4* | Yes | FP16/BF16 activation input matrix of shape [m, k] |
| B | const int4* | Yes | Packed quantized weight matrix of shape [k, n] |
| scales_ptr | const int4* | Yes | FP16 quantization scales of shape [k/groupsize, n] |
| g_idx | const int* | No | Group indices for activation reordering (GPTQ act_order) |
| num_groups | int | Yes | Number of quantization groups along K |
| prob_m / prob_n / prob_k | int | Yes | GEMM problem dimensions |
| locks | int* | Yes | Global synchronization barriers for split-K reduction |
| use_fp32_reduce | bool | Yes | Whether to accumulate partial results in FP32 |
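The relationship between the packed weights in B and the scales can be sketched on the host. The example below assumes a kU4B8-style encoding (unsigned 4-bit values with a bias of 8) and a scale table of shape [k/groupsize, n], so the scale for element (k, n) sits in row k / groupsize. The nibble order used here is plain lowest-first, which is an assumption for illustration only; the packed layout the kernel actually reads may be permuted to suit tensor core fragments. All three function names are hypothetical.

```cpp
#include <cstdint>

// Extract the i-th 4-bit field (i in [0, 8)) from a 32-bit word,
// lowest nibble first. Illustrative order, not the kernel's real layout.
inline int unpack_u4(std::uint32_t word, int i) {
  return static_cast<int>((word >> (4 * i)) & 0xFu);
}

// Dequantize one biased 4-bit value: remove the bias of 8, then scale.
inline float dequant_u4b8(int q, float scale) {
  return static_cast<float>(q - 8) * scale;
}

// Row of the scale table that applies to reduction index k.
// group_size == -1 models channel-wise scales (a single row).
inline int scale_row(int k, int group_size) {
  return group_size == -1 ? 0 : k / group_size;
}
```

For example, a stored value of 15 with a scale of 0.5 dequantizes to 3.5, and a stored value of 0 dequantizes to -4.0.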
Outputs
| Name | Type | Description |
|---|---|---|
| C | int4* | FP16/BF16 output buffer of shape [m, n] |
| C_tmp | int4* | FP32 temporary output buffer used during split-K reduction |
Usage Examples
// Kernel launch with specific template parameters
Marlin<
    half,                // scalar_t: FP16 compute
    sglang::kU4B8.id(),  // w_type_id: INT4 with bias 8
    256,                 // threads per block
    1,                   // thread_m_blocks
    4,                   // thread_n_blocks
    8,                   // thread_k_blocks
    false,               // m_block_size_8
    4,                   // pipeline stages
    false,               // has_act_order
    -1,                  // group_blocks (-1 = channel-wise)
    false                // is_zp_float
    ><<<grid, block, smem, stream>>>(
    A_ptr, B_ptr, C_ptr, C_tmp_ptr,
    scales_ptr, g_idx_ptr,
    num_groups, prob_m, prob_n, prob_k,
    locks_ptr, use_fp32_reduce);