Implementation:Vllm project Vllm Marlin Dequant
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Marlin, Dequantization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements fast GPU dequantization routines that convert INT4/INT8/FP4/FP8 quantized weights to FP16/BF16 using bitwise manipulation and fused zero-point/scale operations.
Description
This header provides device-inline dequantization functions for the Marlin GEMM kernel. It uses lookup-table-based 3-input logical operations (lop3 PTX instruction) and byte permutation (prmt PTX instruction) to efficiently extract and convert packed quantized values. The dequantization process combines bitwise extraction with floating-point computation, optionally fusing zero-point subtraction and scale factor application using __hsub2, __hmul2, and __hfma2 intrinsics. Template specializations are provided for multiple format combinations including U4B8, U4, U8B128, FE4M3fn, and FE2M1f targeting both half2 and nv_bfloat162 output types.
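The two PTX primitives named above can be modeled on the host to see what they compute. The sketch below is an illustration, not code from dequant.h: `lop3_ref` and `prmt_ref` are hypothetical names, the loops stand in for single hardware instructions, and the `prmt` model covers only the default mode with selector nibbles 0 to 7 (the sign-replication bit is not modeled).

```cpp
#include <cstdint>

// Host-side model of lop3.b32: at every bit position the three input bits
// (a, b, c) form a 3-bit index into the 8-bit truth table `lut`.
template <int lut>
uint32_t lop3_ref(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t out = 0;
  for (int bit = 0; bit < 32; ++bit) {
    uint32_t idx = (((a >> bit) & 1u) << 2) | (((b >> bit) & 1u) << 1) |
                   ((c >> bit) & 1u);
    out |= ((static_cast<uint32_t>(lut) >> idx) & 1u) << bit;
  }
  return out;
}

// Host-side model of prmt.b32 (default mode only): the four bytes of `a`
// are numbered 0..3, the four bytes of `start_byte` are numbered 4..7, and
// each nibble of `mask` picks the source byte for one result byte.
template <uint32_t start_byte, uint32_t mask>
uint32_t prmt_ref(uint32_t a) {
  uint64_t pool = (static_cast<uint64_t>(start_byte) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t sel = (mask >> (4 * i)) & 0x7u;  // byte index 0..7
    out |= static_cast<uint32_t>((pool >> (8 * sel)) & 0xffu) << (8 * i);
  }
  return out;
}
```

With this model, `lop3_ref<(0xF0 & 0xCC) | 0xAA>(q, mask, ex)` evaluates `(q & mask) | ex`, because 0xF0, 0xCC, and 0xAA are the truth-table columns of the first, second, and third operand. Likewise, `prmt_ref<0x64646464u, 0x5250u>(q)` places bytes 0 and 2 of `q` under a constant 0x64 upper byte, giving two FP16 bit patterns of the form 0x64xx; these constants are chosen for illustration.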
Usage
This header is included by the Marlin kernel template (marlin_template.h) and is compiled as part of the Marlin quantized GEMM CUDA kernels. It is used at inference time whenever quantized weight dequantization is needed during matrix multiplication on Turing (SM75) and later GPU architectures.
Code Reference
Source Location
- Repository: vllm
- File: csrc/quantization/marlin/dequant.h
- Lines: 1-609
Signature
namespace MARLIN_NAMESPACE_NAME {

// Lookup-table based 3-input logical operation
template <int lut>
__device__ inline int lop3(int a, int b, int c);

// Byte permutation: builds the result from bytes of a and of the
// compile-time constant start_byte, as selected by mask
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a);

// Primary dequantization template
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
          bool skip_flop = false>
__device__ inline void dequant(int q, scalar_t2* frag_b);

// Specializations for INT4 (U4B8) to FP16
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), true>(
    int q, half2* frag_b);

template <>
__device__ inline void dequant<half2, vllm::kU4B8.id(), false>(
    int q, half2* frag_b);

// Specialization for INT4 (U4) to FP16
template <>
__device__ inline void dequant<half2, vllm::kU4.id(), true>(
    int q, half2* frag_b);

// Specialization for INT8 (U8B128) to FP16
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id(), false>(
    int q, half2* frag_b);

// Specializations for FP8 (FE4M3fn) and FP4 (FE2M1f) to FP16
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id(), false>(
    int q, half2* frag_b);

template <>
__device__ inline void dequant<half2, vllm::kFE2M1f.id(), false>(
    int q, half2* frag_b);

}  // namespace MARLIN_NAMESPACE_NAME
Import
#include "dequant.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | int | Yes | Packed quantized weight values in a 32-bit integer (e.g., 8 x 4-bit values or 4 x 8-bit values) |
| w_type_id | vllm::ScalarTypeId | Yes | Compile-time template parameter specifying the quantization type (kU4B8, kU4, kU8B128, kFE4M3fn, kFE2M1f) |
| skip_flop | bool | No | Compile-time template parameter, default false. When true, skips the floating-point correction step (used when zero-point subtraction handles the bias) |
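The effect of skip_flop can be illustrated with host arithmetic. The sketch below rests on two assumptions that are not confirmed by this page: that the skip_flop = true path for 4-bit values leaves each result as the FP16 number 1024 + v (the nibble ORed into the mantissa of the bit pattern 0x6400), and that the zero point is converted the same way so that the offset cancels in the subtraction. The function names are hypothetical, and plain float math stands in for the CUDA intrinsics.

```cpp
#include <cmath>
#include <cstdint>

// Decode an IEEE 754 binary16 bit pattern to float.
// Inf/NaN (exponent 31) are not handled; they do not occur in this sketch.
float half_bits_to_float(uint16_t h) {
  int exp = (h >> 10) & 0x1f;
  int man = h & 0x3ff;
  // Normal numbers carry an implicit leading 1; subnormals (exp == 0) do not.
  float mag = (exp == 0)
                  ? std::ldexp(static_cast<float>(man), -24)
                  : std::ldexp(static_cast<float>(man | 0x400), exp - 25);
  return (h & 0x8000) ? -mag : mag;
}

// Assumed model of the skip_flop = true result for one 4-bit code:
// the nibble sits in the low mantissa bits of FP16 1024.0 (0x6400).
float biased_from_nibble(uint32_t v) {
  return half_bits_to_float(static_cast<uint16_t>(0x6400u | (v & 0xfu)));
}

// Weight minus zero point: both operands carry the same +1024 offset,
// so the offset cancels and no separate correction step is needed.
float dequant_with_zero_point(uint32_t w, uint32_t z) {
  return biased_from_nibble(w) - biased_from_nibble(z);
}
```

Under these assumptions, `dequant_with_zero_point(3, 8)` yields -5, the same value an explicit "subtract 1024, then subtract the zero point" sequence would give, with one fewer floating-point operation per element.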
Outputs
| Name | Type | Description |
|---|---|---|
| frag_b | scalar_t2* (half2* or nv_bfloat162*) | Caller-provided array of 2 half2/nv_bfloat162 values that receives the dequantized result (4 scalar values total), forming a B-fragment for tensor core MMA |
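Since a 32-bit word holds eight 4-bit codes but one call fills only four scalars, it helps to see which nibbles a call consumes. The sketch below assumes the 4-bit path masks with 0x000f000f for frag_b[0] and takes the next nibble up for frag_b[1]. It also assumes the caller covers the remaining nibbles with a second call on q >> 8. Both points are assumptions about the kernel rather than statements from this page, and `codes_for_call` is a hypothetical helper.

```cpp
#include <array>
#include <cstdint>

// Raw 4-bit codes that land in frag_b[0].x, frag_b[0].y, frag_b[1].x and
// frag_b[1].y for one call (x is the low 16 bits of each 32-bit register).
std::array<uint32_t, 4> codes_for_call(uint32_t q) {
  uint32_t lo = q & 0x000f000fu;         // nibble 0 (low half), nibble 4 (high half)
  uint32_t hi = (q >> 4) & 0x000f000fu;  // nibble 1 (low half), nibble 5 (high half)
  return {lo & 0xffffu, lo >> 16, hi & 0xffffu, hi >> 16};
}
```

For q = 0x76543210, where nibble i holds the value i, the first call sees codes 0, 4, 1, 5 and a call on q >> 8 sees 2, 6, 3, 7. The packed layout is therefore interleaved rather than sequential, which is why weights are typically repacked offline before the kernel runs.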
Usage Examples
// Inside a Marlin kernel: dequantize 4-bit weights to FP16
int packed_weights = ...; // 8 x 4-bit values packed in int32
half2 frag_b[2];
// skip_flop=false: dequant applies the zero-point (bias) subtraction itself
dequant<half2, vllm::kU4B8.id(), false>(packed_weights, frag_b);
// frag_b[0] and frag_b[1] now hold 4 dequantized FP16 values;
// one call converts 4 of the 8 packed 4-bit values
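The numeric result of the call above can be checked with a host model. This is a sketch under stated assumptions, not the header's code: it assumes the U4B8 path ORs the masked nibbles into the FP16 pattern 0x6400, subtracts 1032 (1024 plus the bias of 8) for the low nibbles, and uses a multiply by 1/16 followed by a subtraction of 72 for the nibbles that sit four bits higher. `dequant_u4b8_ref` and `half_bits_to_float` are hypothetical helpers, and float math replaces __hsub2 and __hfma2.

```cpp
#include <array>
#include <cmath>
#include <cstdint>

// Decode an IEEE 754 binary16 bit pattern to float.
// Inf/NaN (exponent 31) are not handled; they do not occur in this sketch.
float half_bits_to_float(uint16_t h) {
  int exp = (h >> 10) & 0x1f;
  int man = h & 0x3ff;
  float mag = (exp == 0)
                  ? std::ldexp(static_cast<float>(man), -24)
                  : std::ldexp(static_cast<float>(man | 0x400), exp - 25);
  return (h & 0x8000) ? -mag : mag;
}

// Host model of one U4B8 dequant call with skip_flop = false.
// Returns frag_b[0].x, frag_b[0].y, frag_b[1].x, frag_b[1].y as floats.
std::array<float, 4> dequant_u4b8_ref(uint32_t q) {
  const uint32_t EX = 0x64006400u;       // two copies of FP16 1024.0
  uint32_t lo = (q & 0x000f000fu) | EX;  // each half encodes 1024 + v
  uint32_t hi = (q & 0x00f000f0u) | EX;  // each half encodes 1024 + 16 * v
  auto half_at = [](uint32_t w, int i) {
    return half_bits_to_float(static_cast<uint16_t>(w >> (16 * i)));
  };
  // lo: (1024 + v) - 1032 = v - 8
  // hi: (1024 + 16 * v) / 16 - 72 = v - 8
  return {half_at(lo, 0) - 1032.0f, half_at(lo, 1) - 1032.0f,
          half_at(hi, 0) / 16.0f - 72.0f, half_at(hi, 1) / 16.0f - 72.0f};
}
```

Every intermediate value here is an integer or a multiple of 1/16 well within FP16 range, so the float model gives exactly the values FP16 arithmetic would produce for these inputs. Leaving the upper nibble in place and folding the shift into the multiply saves one integer shift per pair of values.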