
Implementation:Sgl project Sglang Marlin GEMM Template

From Leeroopedia


Knowledge Sources
Domains CUDA_Kernels, Quantized_GEMM, Marlin_GEMM
Last Updated 2026-02-10 00:00 GMT

Overview

Full CUDA kernel template implementing the Marlin quantized GEMM algorithm with on-the-fly weight dequantization, tensor core MMA, and asynchronous memory pipelines for INT4, INT8, FP4, and FP8 quantized models.
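The minimal host-side sketch below makes "on-the-fly weight dequantization" concrete for the INT4-with-bias-8 weight type (kU4B8) used in the launch example. The helper name dequant_u4b8 is hypothetical. It also assumes eight 4-bit values packed in plain nibble order in a 32-bit word. The real kernel works on Marlin's repacked weight layout and converts several values at once.

```cpp
#include <cstdint>

// Hypothetical helper: dequantize the i-th 4-bit weight (i in 0..7) from a
// 32-bit word. Assumption: plain nibble order, not Marlin's repacked layout.
constexpr float dequant_u4b8(std::uint32_t packed, int i, float scale) {
  // Extract the unsigned 4-bit code (0..15).
  const int q = static_cast<int>((packed >> (4 * i)) & 0xFu);
  // Remove the bias of 8 so codes map to -8..7, then apply the group scale.
  return static_cast<float>(q - 8) * scale;
}
```

For example, a stored code of 15 with scale 0.5 dequantizes to (15 - 8) * 0.5 = 3.5, and a code of 8 always maps to zero.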

Description

The marlin_template.h header contains the complete implementation of the Marlin __global__ CUDA kernel, adapted from the IST-DASLab Marlin project and modified by Neural Magic. The kernel is heavily templated with the following parameters:

  • scalar_t -- compute data type (half or nv_bfloat16)
  • w_type_id -- weight quantization type via sglang::ScalarTypeId
  • threads -- number of threads per threadblock
  • thread_m_blocks / thread_n_blocks / thread_k_blocks -- threadblock tile size along M, N, and K, counted in 16x16 blocks
  • m_block_size_8 -- use an 8-row M-block instead of a 16-row one (valid only when thread_m_blocks == 1)
  • stages -- number of pipeline stages for async global-to-shared fetch
  • has_act_order -- enable GPTQ-style activation reordering
  • group_blocks -- number of consecutive 16x16 blocks along K that share one quantization scale (-1 for channel-wise quantization)
  • is_zp_float -- whether zero points are floating-point type
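The tile parameters above are counted in 16x16 blocks. The sketch below converts them into element counts. The helpers tile_shape and group_size are hypothetical and are not part of marlin_template.h. They also assume that group_blocks counts 16-row K-blocks per scale group, so the group size is 16 * group_blocks.

```cpp
// Hypothetical helpers (not in marlin_template.h): convert tile template
// parameters, which count 16x16 blocks, into element counts.
struct TileShape {
  int m;  // rows of A and C covered by one threadblock tile
  int n;  // columns of B and C covered by one tile
  int k;  // reduction depth consumed per tile step
};

constexpr TileShape tile_shape(int thread_m_blocks, int thread_n_blocks,
                               int thread_k_blocks, bool m_block_size_8) {
  // With m_block_size_8 (only valid when thread_m_blocks == 1) the single
  // M-block spans 8 rows instead of 16.
  return TileShape{m_block_size_8 ? 8 : 16 * thread_m_blocks,
                   16 * thread_n_blocks, 16 * thread_k_blocks};
}

// Assumption: group_blocks counts 16-row K-blocks per scale group, and -1
// means channel-wise quantization (one scale per column across all of K).
constexpr int group_size(int group_blocks, int prob_k) {
  return group_blocks == -1 ? prob_k : 16 * group_blocks;
}
```

With the launch example's values (1, 4, 8), each threadblock covers a 16 x 64 output tile and advances through K in steps of 128.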

The file provides an mma_sp helper function that issues the m16n8k16 tensor core MMA instruction via inline PTX assembly. An empty stub is compiled for architectures below SM80 (compute capability 8.0); the full implementation targets SM80 and newer.
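An m16n8k16 MMA computes D = A * B + C on a 16x16 A tile and a 16x8 B tile, accumulating into a 16x8 FP32 C tile. The plain host C++ reference below shows only that arithmetic. It assumes the operands are already converted to float and stored row-major. On the GPU, the same tile is spread as fp16/bf16 register fragments across the 32 lanes of a warp, and this sketch does not model that layout. The function name is hypothetical.

```cpp
#include <cassert>

// Shape of one m16n8k16 tensor core MMA.
constexpr int kMmaM = 16, kMmaN = 8, kMmaK = 16;

// Hypothetical host reference: c += a * b, accumulated in float (FP32).
void mma_m16n8k16_ref(const float (&a)[kMmaM][kMmaK],
                      const float (&b)[kMmaK][kMmaN],
                      float (&c)[kMmaM][kMmaN]) {
  for (int i = 0; i < kMmaM; ++i) {
    for (int j = 0; j < kMmaN; ++j) {
      float acc = c[i][j];  // start from the existing accumulator value
      for (int p = 0; p < kMmaK; ++p) acc += a[i][p] * b[p][j];
      c[i][j] = acc;
    }
  }
}
```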

Key algorithmic features:

  • asynchronous global-to-shared memory pipelines built on cp.async
  • register-tiled matrix operations executed on tensor cores
  • configurable split-K, with partial results synchronized through lock-based barriers (the locks array)
  • optional FP32 reduction of split-K partial results for numerical stability (use_fp32_reduce)
  • activation reordering via g_idx permutation indices
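The multi-stage pipeline follows the general software-pipelining pattern that cp.async enables. The host-side model below illustrates that idea; it is not Marlin's actual loop. It assumes shared memory holds `stages` tile slots and that a prologue issues stages - 1 fetches. Each iteration then computes one K-tile while prefetching the tile stages - 1 steps ahead into a free slot. Step and pipeline_schedule are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// One iteration of the modeled pipeline.
struct Step {
  int compute_tile, compute_slot;  // K-tile consumed from shared memory
  int fetch_tile, fetch_slot;      // K-tile prefetched (-1 when none is left)
};

// Hypothetical model: which shared-memory slot each tile is computed from and
// which slot the next prefetch lands in, for a ring of `stages` slots.
std::vector<Step> pipeline_schedule(int num_k_tiles, int stages) {
  std::vector<Step> steps;
  for (int t = 0; t < num_k_tiles; ++t) {
    const int ahead = t + stages - 1;  // tile prefetched in this iteration
    const bool has_fetch = ahead < num_k_tiles;
    steps.push_back({t, t % stages, has_fetch ? ahead : -1,
                     has_fetch ? ahead % stages : -1});
  }
  return steps;
}
```

With four stages, three fetches stay in flight while one tile is being computed. The slot refilled in iteration t is the one consumed in iteration t - 1, so a barrier is needed between iterations before that slot can be overwritten.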

Usage

This kernel is instantiated by the Marlin dispatch code with specific template parameters matching the model's quantization configuration. It is the primary kernel for quantized weight-only GEMM during inference.
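Template parameters must be known at compile time, so dispatch code has to map runtime settings onto a fixed set of pre-compiled instantiations. The hypothetical sketch below shows that pattern for group_blocks alone. launch_stub stands in for a real Marlin<...> launch. The supported group sizes are assumptions, as is the mapping group_blocks = group_size / 16 with -1 for channel-wise. Act-order variants are omitted.

```cpp
#include <cassert>
#include <stdexcept>

// Stand-in for launching a Marlin<..., group_blocks, ...> instantiation;
// it only reports which compile-time value was selected.
template <int group_blocks>
int launch_stub() {
  return group_blocks;
}

// Hypothetical dispatcher: each case names a distinct instantiation.
inline int dispatch_group_blocks(int group_size) {
  switch (group_size) {
    case -1:  return launch_stub<-1>();  // channel-wise scales
    case 32:  return launch_stub<2>();
    case 64:  return launch_stub<4>();
    case 128: return launch_stub<8>();
    default:  throw std::invalid_argument("unsupported group_size");
  }
}
```

Each supported configuration is compiled ahead of time. An unsupported runtime value can only be rejected; it cannot be compiled on the fly.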

Code Reference

Source Location

Signature

namespace MARLIN_NAMESPACE_NAME {  // defaults to "marlin"

// m16n8k16 tensor core MMA via inline PTX
template <typename scalar_t>
__device__ inline void mma_sp(
    const typename ScalarType<scalar_t>::FragA& a_frag,
    const typename ScalarType<scalar_t>::FragB& b_frag,
    typename ScalarType<scalar_t>::FragC& c_frag);

template <
    typename scalar_t,                     // half or nv_bfloat16
    const sglang::ScalarTypeId w_type_id,  // weight type
    const int threads,                     // threads per block
    const int thread_m_blocks,             // M-dimension blocks
    const int thread_n_blocks,             // N-dimension blocks
    const int thread_k_blocks,             // K-dimension blocks
    const bool m_block_size_8,             // 8-element M-blocks
    const int stages,                      // pipeline stages
    const bool has_act_order,              // activation reordering
    const int group_blocks,                // quantization group size
    const bool is_zp_float                 // float zero points
>
__global__ void Marlin(
    const int4* __restrict__ A,           // fp16/bf16 input [m, k]
    const int4* __restrict__ B,           // quantized weights [k, n]
    int4* __restrict__ C,                 // fp16/bf16 output [m, n]
    int4* __restrict__ C_tmp,             // fp32 tmp for reduce
    const int4* __restrict__ scales_ptr,  // scales [k/group, n]
    const int* __restrict__ g_idx,        // group indices [k]
    int num_groups,
    int prob_m,
    int prob_n,
    int prob_k,
    int* locks,
    bool use_fp32_reduce
);

} // namespace MARLIN_NAMESPACE_NAME
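The matrix arguments are passed as int4 pointers, and CUDA's int4 is a 16-byte vector of four ints, so buffer sizes are naturally counted in int4 units. The sketch below assumes dense buffers with no padding and dimensions that divide evenly, and it ignores any padding the weight-repacking step may require. int4_count is a hypothetical helper.

```cpp
#include <cstdint>

// Size of one CUDA int4 vector in bytes.
constexpr std::int64_t kInt4Bytes = 16;

// Hypothetical helper: number of 16-byte int4 vectors needed to store a
// rows x cols matrix with bits_per_elem bits per element (dense, unpadded).
constexpr std::int64_t int4_count(std::int64_t rows, std::int64_t cols,
                                  int bits_per_elem) {
  // total bits -> bytes -> 16-byte vectors
  return rows * cols * bits_per_elem / 8 / kInt4Bytes;
}
```

A 16 x 4096 FP16 activation A occupies 8,192 int4 vectors, while a 4096 x 4096 INT4 weight B packs into 524,288.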

Import

#include "marlin_template.h"

// Headers that marlin_template.h itself depends on:
#include "dequant.h"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "scalar_type.hpp"

I/O Contract

Inputs

Name Type Required Description
A const int4* Yes FP16/BF16 activation input matrix of shape [m, k]
B const int4* Yes Packed quantized weight matrix of shape [k, n]
scales_ptr const int4* Yes FP16 quantization scales of shape [k/groupsize, n]
g_idx const int* No Group indices for activation reordering (GPTQ act_order)
prob_m / prob_n / prob_k int Yes GEMM problem dimensions
locks int* Yes Global synchronization barriers for split-K reduction
use_fp32_reduce bool Yes Whether split-K partial results are reduced in FP32 (through C_tmp) rather than in the output type
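The scales_ptr and g_idx inputs together decide which row of the [k/groupsize, n] scale matrix applies to reduction index k. The hypothetical helper below sketches that lookup under three assumptions. Without act_order, groups are contiguous runs of group_size rows. With GPTQ act_order, g_idx[k] holds the group of row k. Channel-wise quantization always uses row 0.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: row of the scale matrix used for reduction index k.
inline int scale_row(int k, int group_size, bool has_act_order,
                     const std::vector<int>& g_idx) {
  if (has_act_order) return g_idx[k];  // groups follow the GPTQ permutation
  if (group_size == -1) return 0;      // channel-wise: a single scale row
  return k / group_size;               // contiguous groups along K
}
```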

Outputs

Name Type Description
C int4* FP16/BF16 output buffer of shape [m, n]
C_tmp int4* FP32 temporary output buffer used during split-K reduction
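Split-K divides the K dimension of one output tile among several threadblocks, and their partial results must then be summed. The single-threaded model below illustrates one way a per-tile lock counter can serialize that reduction; it is not Marlin's code. Slices are admitted strictly in slice order, which makes the summation order, and therefore the result, deterministic. acc stands in for C_tmp when use_fp32_reduce is set. SplitKTile and try_reduce are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One output tile under split-K (hypothetical model).
struct SplitKTile {
  std::vector<float> acc;  // stands in for C_tmp (FP32) or C
  int lock = 0;            // stands in for this tile's entry in `locks`
};

// Try to fold one slice's partial result into the tile. Returns false when it
// is not this slice's turn; on the GPU the threadblock would spin instead.
bool try_reduce(SplitKTile& tile, int slice_idx, int slices,
                const std::vector<float>& partial) {
  if (tile.lock != slice_idx) return false;
  for (std::size_t i = 0; i < partial.size(); ++i) {
    // The first slice writes; later slices accumulate onto it.
    tile.acc[i] = (slice_idx == 0) ? partial[i] : tile.acc[i] + partial[i];
  }
  // The last slice resets the counter so the lock can be reused.
  tile.lock = (slice_idx == slices - 1) ? 0 : slice_idx + 1;
  return true;
}
```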

Usage Examples

// Kernel launch with specific template parameters.
// grid, smem, and stream are assumed to be computed by the caller; block
// must provide 256 threads to match the `threads` template argument.
Marlin<
    half,                        // scalar_t: FP16 compute
    sglang::kU4B8.id(),          // w_type_id: INT4 with bias 8
    256,                         // threads per block
    1,                           // thread_m_blocks
    4,                           // thread_n_blocks
    8,                           // thread_k_blocks
    false,                       // m_block_size_8
    4,                           // pipeline stages
    false,                       // has_act_order
    -1,                          // group_blocks (-1 = channel-wise)
    false                        // is_zp_float
><<<grid, block, smem, stream>>>(
    A_ptr, B_ptr, C_ptr, C_tmp_ptr,
    scales_ptr, g_idx_ptr,
    num_groups, prob_m, prob_n, prob_k,
    locks_ptr, use_fp32_reduce
);
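Some configurations can be rejected before launch with compile-time checks. The hypothetical trait below applies them to the example's parameters. Only the m_block_size_8 rule comes from this page. The whole-warp rule is an assumption based on warp-level tensor core MMAs needing full warps, and the positivity checks are illustrative.

```cpp
// Hypothetical compile-time sanity checks for a Marlin configuration.
template <int threads, int thread_m_blocks, int thread_n_blocks,
          int thread_k_blocks, bool m_block_size_8, int stages>
struct MarlinConfigCheck {
  // Assumption: warp-level MMAs need the block to be made of whole warps.
  static_assert(threads % 32 == 0, "threads must be a whole number of warps");
  // Stated on this page: 8-row M-blocks only with a single M-block.
  static_assert(!m_block_size_8 || thread_m_blocks == 1,
                "m_block_size_8 requires thread_m_blocks == 1");
  // Illustrative: tile sizes and stage count must be positive.
  static_assert(thread_m_blocks > 0 && thread_n_blocks > 0 &&
                    thread_k_blocks > 0,
                "tile dimensions must be positive");
  static_assert(stages >= 1, "at least one pipeline stage is required");
  static constexpr bool ok = true;
};
```

Instantiating MarlinConfigCheck<256, 1, 4, 8, false, 4> with the example's values compiles. Setting m_block_size_8 to true with thread_m_blocks = 2 would fail at compile time.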
