Implementation:Vllm project Vllm Marlin Dequant

From Leeroopedia


Knowledge Sources
Domains Quantization, Marlin, Dequantization
Last Updated 2026-02-08 00:00 GMT

Overview

Implements fast GPU dequantization routines that convert INT4/INT8/FP4/FP8 quantized weights to FP16/BF16 using bitwise manipulation and fused zero-point/scale operations.

Description

This header provides device-inline dequantization functions for the Marlin GEMM kernel. It uses lookup-table-based 3-input logical operations (lop3 PTX instruction) and byte permutation (prmt PTX instruction) to efficiently extract and convert packed quantized values. The dequantization process combines bitwise extraction with floating-point computation, optionally fusing zero-point subtraction and scale factor application using __hsub2, __hmul2, and __hfma2 intrinsics. Template specializations are provided for multiple format combinations including U4B8, U4, U8B128, FE4M3fn, and FE2M1f targeting both half2 and nv_bfloat162 output types.
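
The bit-level extraction described above can be checked on the host. The following is a minimal sketch in plain C++, not the header's code. The names `lop3_emul`, `half_bits_to_float`, and `extract_low_nibbles` are hypothetical, and the constants `0x000f000f` and `0x64006400` are assumptions for illustration: they follow the common "magic number" pattern in which the FP16 bit pattern `0x6400` equals 1024.0, so OR-ing a 4-bit value n into the mantissa yields 1024 + n.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Host emulation of PTX lop3.b32 semantics: at every bit position the three
// input bits form an index (a*4 + b*2 + c) into the 8-bit truth table `lut`.
inline uint32_t lop3_emul(uint32_t a, uint32_t b, uint32_t c, uint32_t lut) {
  uint32_t out = 0;
  for (int i = 0; i < 32; ++i) {
    uint32_t idx = (((a >> i) & 1u) << 2) | (((b >> i) & 1u) << 1) |
                   ((c >> i) & 1u);
    out |= ((lut >> idx) & 1u) << i;
  }
  return out;
}

// Truth table for (a & b) | c, obtained by applying the expression to the
// canonical operand patterns 0xF0, 0xCC, 0xAA.
constexpr uint32_t kLutAndOr = (0xF0 & 0xCC) | 0xAA;  // == 0xEA

// Decode an IEEE binary16 bit pattern (zero, subnormal, and normal values;
// infinities and NaNs are not handled in this sketch).
inline float half_bits_to_float(uint16_t h) {
  const int sign = (h >> 15) & 1;
  const int exp = (h >> 10) & 0x1F;
  const int man = h & 0x3FF;
  const float v = (exp == 0)
                      ? std::ldexp(static_cast<float>(man), -24)
                      : std::ldexp(1.0f + man / 1024.0f, exp - 15);
  return sign ? -v : v;
}

// Assumed constants: keep bits 0-3 and 16-19 of q, then OR in the exponent
// bits of 1024.0 so that each 16-bit lane reads as the FP16 value 1024 + n.
inline uint32_t extract_low_nibbles(uint32_t q) {
  return lop3_emul(q, 0x000f000fu, 0x64006400u, kLutAndOr);
}

// Small self-check that can be called from any test driver.
inline void check_lop3_sketch() {
  const uint32_t lo = extract_low_nibbles(0x000A0003u);
  assert(half_bits_to_float(static_cast<uint16_t>(lo & 0xFFFFu)) == 1027.0f);
  assert(half_bits_to_float(static_cast<uint16_t>(lo >> 16)) == 1034.0f);
}
```

Subtracting 1024 (or 1032 for a format with a built-in bias of 8) from each lane would then recover the integer value as a float. On the device this masking and OR-ing is a single lop3 instruction per register rather than a loop.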

Usage

This header is included by the Marlin kernel template (marlin_template.h) and is compiled as part of the Marlin quantized GEMM CUDA kernels. It is used at inference time whenever quantized weight dequantization is needed during matrix multiplication on Turing (SM75) and later GPU architectures.

Code Reference

Source Location

Signature

namespace MARLIN_NAMESPACE_NAME {

// Lookup-table based 3-input logical operation
template <int lut>
__device__ inline int lop3(int a, int b, int c);

// Byte permutation from 2 sources
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a);

// Primary dequantization template
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);

// Specializations for INT4 (U4B8) to FP16
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(
    int q, half2* frag_b);
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(
    int q, half2* frag_b);

// Specializations for INT4 (U4) to FP16
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(
    int q, half2* frag_b);

// Specialization for INT8 (U8B128) to FP16 (nv_bfloat162 variants omitted from this listing)
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
    int q, half2* frag_b);

// Specializations for FP8 (FE4M3fn) and FP4 (FE2M1f)
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
    int q, half2* frag_b);
template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
    int q, half2* frag_b);

} // namespace MARLIN_NAMESPACE_NAME
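
As a sketch of what a byte permutation does, the following host-side function emulates the default mode of the PTX `prmt.b32` instruction. The name `prmt_emul` is hypothetical. How the template parameters `start_byte` and `mask` of the header's `prmt` helper map onto the instruction's operands is not shown in the signature above, so this sketch takes all three operands explicitly.

```cpp
#include <cassert>
#include <cstdint>

// Emulates prmt.b32 d, a, b, sel in default mode.
// The 8 source bytes are numbered 0-3 (from a) and 4-7 (from b).
// Each 4-bit nibble of `sel` controls one destination byte:
//   bits 0-2: index of the source byte to take
//   bit 3:    if set, replicate the sign bit of that byte over all 8 bits
inline uint32_t prmt_emul(uint32_t a, uint32_t b, uint32_t sel) {
  const uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    const uint32_t nib = (sel >> (4 * i)) & 0xFu;
    uint32_t byte = static_cast<uint32_t>(pool >> (8 * (nib & 7u))) & 0xFFu;
    if (nib & 8u) {
      byte = (byte & 0x80u) ? 0xFFu : 0x00u;
    }
    out |= byte << (8 * i);
  }
  return out;
}

// Small self-check: pick bytes 0, 5, 2, 7 from the 8-byte pool.
inline void check_prmt_sketch() {
  assert(prmt_emul(0x33221100u, 0x77665544u, 0x7250u) == 0x77225500u);
}
```

A permutation like this lets a kernel place each packed 8-bit value into the byte position expected by the target floating-point layout in one instruction, instead of a sequence of shifts and masks.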

Import

#include "dequant.h"

I/O Contract

Inputs

Name | Type | Required | Description
q | int | Yes | Packed quantized weight values in a 32-bit integer (e.g., 8 x 4-bit values or 4 x 8-bit values)
w_type_id | vllm::ScalarTypeId | Yes | Compile-time template parameter specifying the quantization type (kU4B8, kU4, kU8B128, kFE4M3fn, kFE2M1f)
skip_flop | bool | No | Compile-time template parameter, default false. When true, skips the floating-point correction step, leaving the bias to be handled by a later zero-point subtraction
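
One way to read the `skip_flop` parameter: if both the weight and the zero-point are left in the same biased form, the later subtraction cancels the bias without a separate correction step. The sketch below illustrates only this arithmetic on the host. The helper names are hypothetical, and the bias of 1024 (the FP16 bit pattern `0x6400`) is an assumption, not something taken from the header.

```cpp
#include <cassert>
#include <cstdint>

// Value of an FP16 lane with bit pattern 0x6400 | n for a 4-bit n: 1024 + n.
// These values are exactly representable in FP16 and in float.
inline float biased_lane(uint32_t n) {
  return 1024.0f + static_cast<float>(n & 0xFu);
}

// Sketch of skip_flop = false: the bias is removed right away.
inline float dequant_with_flop(uint32_t n) {
  return biased_lane(n) - 1024.0f;
}

// Sketch of skip_flop = true: the weight keeps its bias, and a zero-point
// left in the same biased form cancels it during the subtraction.
inline float dequant_skip_then_sub_zp(uint32_t n, uint32_t zp) {
  return biased_lane(n) - biased_lane(zp);
}

// Both routes give the same result: (1024 + n) - (1024 + zp) == n - zp.
inline void check_skip_flop_sketch() {
  assert(dequant_skip_then_sub_zp(11u, 8u) ==
         dequant_with_flop(11u) - dequant_with_flop(8u));
}
```

Under these assumptions the skipped variant saves one floating-point instruction per fragment, at the cost of requiring the caller to perform the matching subtraction.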

Outputs

Name | Type | Description
frag_b | scalar_t2* (half2* or nv_bfloat162*) | Array of 2 dequantized half2 or nv_bfloat162 values (4 scalar values total) forming a B-fragment for tensor core MMA

Usage Examples

// Inside a Marlin kernel: dequantize 4-bit weights to FP16
int packed_weights = ...; // 8 x 4-bit values packed in int32
half2 frag_b[2];

// Dequantize with the floating-point correction step enabled (skip_flop=false)
dequant<half2, vllm::kU4B8.id(), false>(packed_weights, frag_b);

// frag_b[0] and frag_b[1] now contain 4 dequantized FP16 values
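
To sanity-check a call like the one above, a host-side reference can reproduce the arithmetic in plain floats. This is a sketch under stated assumptions, not the header's implementation: `dequant_u4b8_reference` is a hypothetical name, and the masks and fused constants in the comments follow the common Marlin-style pattern rather than being copied from dequant.h. Under these assumptions, one call consumes the nibbles at bit offsets 0, 16, 4, and 20, and the remaining four nibbles would need a second call on a shifted copy of the input.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Reference for one U4B8 -> signed value call, under assumed constants:
//   lo = (q & 0x000f000f) | 0x64006400  -> lanes hold 1024 + n
//   hi = (q & 0x00f000f0) | 0x64006400  -> lanes hold 1024 + 16 * n
//   frag_b[0] = lo - 1032               -> n - 8
//   frag_b[1] = hi * (1/16) - 72        -> 64 + n - 72 = n - 8
// The returned array is {frag_b[0].x, frag_b[0].y, frag_b[1].x, frag_b[1].y}.
inline std::array<float, 4> dequant_u4b8_reference(uint32_t q) {
  auto nib = [q](int shift) {
    return static_cast<float>((q >> shift) & 0xFu);
  };
  const float lo0 = 1024.0f + nib(0);
  const float lo1 = 1024.0f + nib(16);
  const float hi0 = 1024.0f + 16.0f * nib(4);
  const float hi1 = 1024.0f + 16.0f * nib(20);
  return {{lo0 - 1032.0f, lo1 - 1032.0f,
           hi0 * 0.0625f - 72.0f, hi1 * 0.0625f - 72.0f}};
}

// Small self-check: nibbles 0xC, 0x3, 0x5, 0xB map to 4, -5, -3, 3.
inline void check_u4b8_reference() {
  const std::array<float, 4> r = dequant_u4b8_reference(0x00B3005Cu);
  assert(r[0] == 4.0f && r[1] == -5.0f && r[2] == -3.0f && r[3] == 3.0f);
}
```

Every intermediate value here is an integer below 2048 or an exact multiple of 1/16, so the float arithmetic matches what FP16 would produce exactly, which makes the reference suitable for bitwise comparison against device output.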
