Implementation:FMInference FlexLLMGen DeepSpeed Quantization Utils
| Knowledge Sources | |
|---|---|
| Domains | CUDA, Quantization, Deep Learning Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
A CUDA header-only library providing template-based quantization and dequantization utilities for converting floating-point tensor data to low-bit integer representations on GPU.
Description
This file implements the quantize namespace, which contains device-side utilities for quantizing half-precision (__half) data into 4-bit or 8-bit integer formats on NVIDIA GPUs. The implementation is organized around two class templates and one device function template:
- Params: Holds quantization parameters (scale and optional offset) and implements the quantize/dequantize operations. Specialized for three quantization types: Symmetric, IntegerSymmetric, and Asymmetric.
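- GroupStats: Tracks the per-group running statistics needed to compute quantization parameters. Symmetric schemes need only the group's maximum absolute value, while the asymmetric scheme needs both the minimum and the maximum. Uses warp and block reductions from the companion reduction_utils.h header.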
- local_array: A device function template (not a class) implementing the main quantization loop. It processes a thread-local array of __half2 values (which the compiler can keep in registers), computes group statistics, derives quantization parameters, and writes quantized int8 output.
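The Params specializations encapsulate the per-value arithmetic of each scheme. The host-side sketch below shows one common formulation of symmetric and asymmetric quantization, assuming round-to-nearest and a signed code range of [-2^(b-1), 2^(b-1) - 1]. SymmetricParamsRef, AsymmetricParamsRef and the helper functions are hypothetical names, and the header's exact scale convention (including what it stores per group) may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Signed code range for a b-bit quantizer: [-2^(b-1), 2^(b-1) - 1].
template <int numBits>
constexpr int32_t q_min() { return -(1 << (numBits - 1)); }
template <int numBits>
constexpr int32_t q_max() { return (1 << (numBits - 1)) - 1; }

// Round to nearest (assumption) and clamp into the signed code range.
template <int numBits>
int8_t round_and_clamp(float x) {
    const int32_t q = static_cast<int32_t>(std::lround(x));
    return static_cast<int8_t>(std::clamp(q, q_min<numBits>(), q_max<numBits>()));
}

// Symmetric: x ~= q / scale, with the group's abs-max mapped to q_max.
template <int numBits>
struct SymmetricParamsRef {
    static_assert(numBits == 4 || numBits == 8, "4- or 8-bit only");
    float scale;  // multiplies an input value to obtain its integer code

    explicit SymmetricParamsRef(float abs_max)
        : scale(abs_max == 0.0f ? 1.0f : static_cast<float>(q_max<numBits>()) / abs_max) {}

    int8_t quantize(float val) const { return round_and_clamp<numBits>(val * scale); }
    float dequantize(int8_t q) const { return static_cast<float>(q) / scale; }
};

// Asymmetric: x ~= q / scale + offset, with offset at the midpoint of [min, max]
// (assumption of this sketch).
template <int numBits>
struct AsymmetricParamsRef {
    static_assert(numBits == 4 || numBits == 8, "4- or 8-bit only");
    float scale;
    float offset;

    AsymmetricParamsRef(float min_val, float max_val)
        : scale(max_val == min_val
                    ? 1.0f
                    : static_cast<float>((1 << numBits) - 1) / (max_val - min_val)),
          offset((max_val + min_val) / 2.0f) {}

    int8_t quantize(float val) const { return round_and_clamp<numBits>((val - offset) * scale); }
    float dequantize(int8_t q) const { return static_cast<float>(q) / scale + offset; }
};
```

For any in-range input, dequantizing the quantized code recovers the value to within half a quantization step (0.5 / scale); out-of-range inputs are clamped to the nearest representable code.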
Key design choices include a 16-byte memory-access granularity (constexpr int granularity = 16, so each access moves 8 __half or 4 __half2 values), support for both 4-bit packing (via PackedInt4) and 8-bit storage, and cooperative groups for warp- and block-level synchronization during the reduction phases.
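The reduction phase can be illustrated on the host. In the symmetric case each thread first accumulates its own running maximum |x| (the role of GroupStats::update), and the threads then combine those partial results. The sketch below simulates an XOR-butterfly all-reduce across a 32-lane warp followed by a block-level combine. It is a common way to implement such reductions, not a copy of reduction_utils.h; warp_allreduce_max and block_max are hypothetical names.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <limits>
#include <vector>

constexpr int kWarpSize = 32;  // NVIDIA warps have 32 lanes

// Simulates an XOR-butterfly all-reduce (max) across one warp. On the GPU,
// each step would be `partner = warp.shfl_xor(val, offset)` on a
// cooperative-groups thread_block_tile; here the 32 lane values live in an
// array so the data movement can be checked on the host.
std::array<float, kWarpSize> warp_allreduce_max(std::array<float, kWarpSize> lanes) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        std::array<float, kWarpSize> next{};
        for (int lane = 0; lane < kWarpSize; ++lane) {
            // Every lane reads its butterfly partner (lane ^ offset) in lock-step.
            next[lane] = std::max(lanes[lane], lanes[lane ^ offset]);
        }
        lanes = next;
    }
    return lanes;  // after log2(32) = 5 steps every lane holds the warp maximum
}

// Block-level step: combine one result per warp. On the GPU the per-warp
// results would be exchanged through shared memory after a tb.sync() barrier.
float block_max(const std::vector<std::array<float, kWarpSize>>& warps) {
    float result = -std::numeric_limits<float>::infinity();
    for (const auto& w : warps) {
        result = std::max(result, warp_allreduce_max(w)[0]);
    }
    return result;
}
```

The butterfly pattern leaves the result in every lane, so no extra broadcast is needed before each thread computes its quantization parameters.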
Usage
These utilities are device-side building blocks for DeepSpeed's custom CUDA quantization kernels. Kernel source files include the header, configure the launch grid, and call the local_array device functions to quantize model weights or activations for inference.
Code Reference
Source Location
- Repository: FMInference_FlexLLMGen
- File: benchmark/third_party/DeepSpeed/csrc/includes/quantization_utils.h
- Lines: 1-510
Signature
```cpp
namespace quantize {

// Quantization parameters holder (specialized per quantization type)
template <Type qType, int numBits>
class Params {
public:
    DS_D_INLINE int8_t quantize(__half val);
    DS_D_INLINE __half dequantize(int8_t val);
    DS_D_INLINE void store(float* params, int group_index);
    DS_D_INLINE Params(const float* params, int group_index);
};

// Group statistics tracker
template <Type qType>
class GroupStats {
public:
    DS_D_INLINE void update(__half2 val);
    DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
};

// Main quantization loop (two overloads: with and without pre-computed params)
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half2* local_buffer,
                            float* global_params,
                            int8_t* output_data,
                            const int& elems_per_group,
                            const int& groups);

}  // namespace quantize
```
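The signature declares a single Params name that is specialized per quantization type. The host-compilable sketch below illustrates that pattern with class template partial specialization: each scheme gets its own data members, and generic code can query a per-scheme constant. The Type enum is a stand-in for quantize::Type, kFloatsPerGroup and param_floats are hypothetical, leaving the primary template undefined is a choice of this sketch, and treating IntegerSymmetric as scale-only is an assumption.

```cpp
#include <cassert>

// Stand-in for quantize::Type (assumption: same three enumerators).
enum class Type { Symmetric, IntegerSymmetric, Asymmetric };

// Primary template left undefined: only the specializations below are usable.
template <Type qType, int numBits>
class Params;

template <int numBits>
class Params<Type::Symmetric, numBits> {
public:
    float scale = 1.0f;  // symmetric: scale only
    static constexpr int kFloatsPerGroup = 1;
};

template <int numBits>
class Params<Type::IntegerSymmetric, numBits> {
public:
    float scale = 1.0f;  // assumed scale-only, like Symmetric
    static constexpr int kFloatsPerGroup = 1;
};

template <int numBits>
class Params<Type::Asymmetric, numBits> {
public:
    float scale = 1.0f;
    float offset = 0.0f;  // asymmetric: scale plus offset
    static constexpr int kFloatsPerGroup = 2;
};

// Floats of global_params storage needed for `groups` groups of one scheme.
template <Type qType, int numBits>
constexpr int param_floats(int groups) {
    return groups * Params<qType, numBits>::kFloatsPerGroup;
}
```

Because the scheme is a template parameter, the choice is resolved at compile time and no per-element branching on the quantization type is needed.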
Import
```cpp
#include "quantization_utils.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| local_buffer | __half2* | Yes | Pointer to the thread-local buffer (typically held in registers) containing the half-precision input data to quantize. |
| elems_per_group | int | Yes | Number of elements per quantization group. |
| groups | int | Yes | Number of quantization groups. |
| qType | Type (template) | Yes | Quantization type: Symmetric, IntegerSymmetric, or Asymmetric. |
| numBits | int (template) | Yes | Number of bits for quantized output (4 or 8). |
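| numChunks | int (template) | Yes | Number of 16-byte chunks of input data per thread. |
| threads_per_group | int (template) | Yes | Number of threads that cooperate on one quantization group. |
| max_threads | int (template) | Yes | Upper bound on the number of threads per block. |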
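The numChunks parameter implies simple sizing arithmetic: each chunk is one 16-byte access, so a thread holding numChunks chunks needs numChunks * h2_per_load __half2 slots, which is how the usage example sizes local_buffer. The sketch reproduces that arithmetic on the host. h2_per_load appears in the original example; h_per_load, Half2Stand and the helper functions are hypothetical, and using uint16_t and a 4-byte struct as stand-ins for __half and __half2 is an assumption that only their sizes matter here.

```cpp
#include <cassert>
#include <cstdint>

// Host stand-ins (assumption): uint16_t has the size of __half (2 bytes) and
// Half2Stand the size of __half2 (4 bytes), so no CUDA headers are needed.
struct Half2Stand {
    uint16_t x;
    uint16_t y;
};

constexpr int granularity = 16;  // bytes per vectorized memory access
constexpr int h_per_load = granularity / static_cast<int>(sizeof(uint16_t));     // 8
constexpr int h2_per_load = granularity / static_cast<int>(sizeof(Half2Stand));  // 4

// __half2 slots a thread's local buffer needs for numChunks 16-byte chunks.
constexpr int local_buffer_h2(int numChunks) { return numChunks * h2_per_load; }

// Scalar half values held by that buffer.
constexpr int elems_per_thread(int numChunks) { return numChunks * h_per_load; }

static_assert(local_buffer_h2(4) == 16, "numChunks = 4 gives 16 __half2 slots");
static_assert(elems_per_thread(4) == 32, "numChunks = 4 gives 32 half values");
```

With numChunks = 4, each thread therefore holds 32 half values, and a 256-thread block holds 8192 values across its local buffers.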
Outputs
| Name | Type | Description |
|---|---|---|
| global_params | float* | Quantization scale (and offset for asymmetric) parameters written per group. |
| output_data | int8_t* | Quantized integer output data (4-bit values are packed two per byte). |
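For 4-bit output, two signed codes share one byte. A minimal sketch of that packing, with sign extension on unpack and the resulting output size, follows. The helpers are hypothetical and place the first value in the high nibble; PackedInt4's actual field order is not reproduced here and may differ.

```cpp
#include <cassert>
#include <cstdint>

// Pack two signed 4-bit codes in [-8, 7] into one byte, high nibble first.
inline uint8_t pack_int4(int8_t high, int8_t low) {
    return static_cast<uint8_t>(((high & 0x0F) << 4) | (low & 0x0F));
}

// Sign-extend the low 4 bits of `nibble` to an int8_t in [-8, 7].
inline int8_t sign_extend4(uint8_t nibble) {
    const int v = nibble & 0x0F;
    return static_cast<int8_t>(v >= 8 ? v - 16 : v);
}

inline int8_t unpack_high(uint8_t byte) { return sign_extend4(static_cast<uint8_t>(byte >> 4)); }
inline int8_t unpack_low(uint8_t byte) { return sign_extend4(static_cast<uint8_t>(byte & 0x0F)); }

// Bytes of output_data needed for n quantized values.
inline int64_t output_bytes(int64_t n, int numBits) {
    return numBits == 4 ? (n + 1) / 2 : n;  // 4-bit halves the 8-bit footprint
}
```

Packing halves both the storage and the bandwidth needed to move quantized weights, at the cost of a shift and mask on each access.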
Usage Examples
```cpp
// Inside a CUDA kernel: quantize a thread-local buffer of __half2 values
// using 8-bit symmetric quantization with 256 threads per group.
// output_scales, output_data, elems_per_group and num_groups are assumed
// to be provided by the enclosing kernel.
__half2 local_buffer[4 * quantize::h2_per_load];  // numChunks = 4
// Load data into local_buffer from global memory ...
quantize::local_array<quantize::Type::Symmetric, 8, 4, 256, 256>(
    local_buffer,
    output_scales,  // float* for per-group scale parameters
    output_data,    // int8_t* for quantized output
    elems_per_group,
    num_groups);
```
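When validating a kernel built on local_array, a CPU reference is useful for comparison. The sketch below performs group-wise symmetric 8-bit quantization under several assumptions: float input instead of __half, contiguous groups of elems_per_group values, round-to-nearest, and one stored scale per group equal to the dequantization step abs_max / 127. GroupQuantResult and quantize_groups_ref are hypothetical names, and because the header's rounding and stored-scale convention may differ, kernel output should be compared with a tolerance of about one code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct GroupQuantResult {
    std::vector<int8_t> q;      // quantized codes, same layout as the input
    std::vector<float> scales;  // one dequantization step per group
};

GroupQuantResult quantize_groups_ref(const std::vector<float>& x, int elems_per_group) {
    assert(elems_per_group > 0);
    assert(x.size() % static_cast<std::size_t>(elems_per_group) == 0);
    const std::size_t groups = x.size() / static_cast<std::size_t>(elems_per_group);
    GroupQuantResult r{std::vector<int8_t>(x.size()), std::vector<float>(groups)};

    for (std::size_t g = 0; g < groups; ++g) {
        const float* grp = x.data() + g * elems_per_group;

        // Per-group statistic: maximum absolute value.
        float abs_max = 0.0f;
        for (int i = 0; i < elems_per_group; ++i) abs_max = std::max(abs_max, std::fabs(grp[i]));

        // Dequantization step; an all-zero group gets a step of 1 to avoid dividing by 0.
        const float step = abs_max == 0.0f ? 1.0f : abs_max / 127.0f;
        r.scales[g] = step;

        for (int i = 0; i < elems_per_group; ++i) {
            const long v = std::lround(grp[i] / step);
            r.q[g * elems_per_group + i] = static_cast<int8_t>(std::clamp(v, -128L, 127L));
        }
    }
    return r;
}
```

Multiplying each code by its group's scale reconstructs the input to within half a step, which gives a natural error bound to check kernel results against.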