Implementation:Predibase Lorax Exllama V2 CUDA Bindings
| Knowledge Sources | |
|---|---|
| Domains | Quantization, GPU_Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
This C++ file provides the PyTorch/pybind11 extension bindings for ExLlama v2 quantized matrix operations, exposing CUDA GEMM kernels for mixed-precision quantized inference to Python.
Description
This extension module is the v2 iteration of ExLlama's quantized matrix support, with a more flexible QMatrix representation that supports both native EXL2 quantization and GPTQ-format weights. It exposes two functions through pybind11. make_q_matrix constructs a QMatrix object from eleven tensor parameters: the packed weight, the EXL2 tensors (permutations, inverse permutations, scales, scale maximums, groups, group maps), the GPTQ-format tensors (qzeros, scales, group index), and a temporary dequantization buffer. It returns an opaque handle to the constructed object. gemm_half_q_half performs a half-precision GEMM between a dense half-precision matrix a and a quantized matrix b (passed as that handle), writing the result into c. Tensors belonging to the unused format are passed on the meta device, and make_q_matrix checks for meta-device tensors to determine whether EXL2 or GPTQ data was provided, so a single entry point serves both formats. make_q_matrix throws a std::runtime_error if CUDA memory allocation fails.
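The meta-device dispatch described above can be sketched in a few lines of Python. This is an illustration only: the helper name detect_format is hypothetical, and keying the decision on q_scale and gptq_scales is an assumption, since the text only says that meta-device checks are used. The function relies on nothing but a boolean is_meta attribute, which torch.Tensor provides, so the stand-in objects below can be swapped for real tensors.

```python
from types import SimpleNamespace


def detect_format(q_scale, gptq_scales):
    """Return "exl2" or "gptq" depending on which scale tensor is real.

    Hypothetical helper: accepts any object with a boolean `is_meta`
    attribute (torch.Tensor has one). Which tensors the C++ code actually
    inspects is an assumption made for this sketch.
    """
    exl2 = not q_scale.is_meta
    gptq = not gptq_scales.is_meta
    if exl2 == gptq:
        # Both real or both meta: the format cannot be determined.
        raise ValueError("exactly one of the EXL2 or GPTQ tensor sets must be real")
    return "exl2" if exl2 else "gptq"


# Stand-ins for tensors, so the sketch runs without PyTorch installed.
real = SimpleNamespace(is_meta=False)
meta = SimpleNamespace(is_meta=True)

print(detect_format(real, meta))  # exl2
print(detect_format(meta, real))  # gptq
```

With PyTorch available, a placeholder for an unused argument can be created with torch.empty((1, 1), device="meta"), which allocates no storage.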
Usage
This extension is loaded by the LoRAX Python server when running inference with quantized models using the ExLlama v2 backend. It supports both EXL2 and GPTQ quantization formats and is used for efficient quantized matrix multiplication during the forward pass of transformer layers.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/exllamav2_kernels/exllamav2_kernels/ext.cpp
- Lines: 1-139
Signature
uintptr_t make_q_matrix(
torch::Tensor q_weight,
torch::Tensor q_perm,
torch::Tensor q_invperm,
torch::Tensor q_scale,
torch::Tensor q_scale_max,
torch::Tensor q_groups,
torch::Tensor q_group_map,
torch::Tensor gptq_qzeros,
torch::Tensor gptq_scales,
torch::Tensor gptq_g_idx,
torch::Tensor temp_dq
);
void gemm_half_q_half(
torch::Tensor a,
uintptr_t b,
torch::Tensor c,
bool force_cuda
);
// pybind11 module registration
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
Import
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q_weight | torch::Tensor (int32) | Yes | Packed quantized weight data |
| q_perm | torch::Tensor (int16) | No* | Row permutation indices (EXL2 format; meta device if unused) |
| q_invperm | torch::Tensor (int16) | No* | Inverse row permutation indices (EXL2 format; meta device if unused) |
| q_scale | torch::Tensor (int32) | No* | Quantization scale data (EXL2 format; meta device if unused) |
| q_scale_max | torch::Tensor (half) | No* | Maximum scale values per group (EXL2 format; meta device if unused) |
| q_groups | torch::Tensor (int16) | No* | Group definitions (EXL2 format; meta device if unused) |
| q_group_map | torch::Tensor (int16) | No* | Group-to-row mapping (EXL2 format; meta device if unused) |
| gptq_qzeros | torch::Tensor (int32) | No* | GPTQ zero points (meta device if unused) |
| gptq_scales | torch::Tensor (half) | No* | GPTQ scale factors (meta device if unused) |
| gptq_g_idx | torch::Tensor (int32) | No* | GPTQ group indices (meta device if unused) |
| temp_dq | torch::Tensor (half) | Yes | Temporary buffer for dequantized values; must hold at least width * height elements |
| a | torch::Tensor (half) | Yes | Dense half-precision input matrix for GEMM |
| b | uintptr_t | Yes | Opaque handle to a QMatrix object |
| c | torch::Tensor (half) | Yes | Pre-allocated output tensor for GEMM result |
| force_cuda | bool | Yes | Force CUDA kernel instead of cuBLAS |
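The temp_dq size requirement is easy to get wrong because the packed weight has fewer rows than the dense matrix it represents. The sketch below works through the arithmetic. Both helper names are hypothetical, and the packing rule (each int32 holds 32 // bits values along the row dimension, so a 4-bit GPTQ weight with R packed rows has height 8 * R) is an assumption about the GPTQ layout, not something stated in this page.

```python
def unpacked_height(packed_rows, bits=4):
    """Dense row count for a row-packed int32 weight (assumed GPTQ layout)."""
    values_per_word = 32 // bits  # 8 values per int32 at 4 bits
    return packed_rows * values_per_word


def temp_dq_min_elements(height, width):
    """Minimum number of half-precision elements temp_dq must hold."""
    return height * width


# Example: a 4-bit weight stored as a (512, 4096) int32 tensor.
height = unpacked_height(512)       # 4096 dense rows
elements = temp_dq_min_elements(height, 4096)
print(elements)                     # 16777216 half elements
print(elements * 2 // 2**20)        # 32 (MiB, at 2 bytes per half)
```

Since the buffer is only scratch space, one plausible design is to size it for the largest layer and share it across layers, though whether LoRAX does so is not covered here.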
Outputs
| Name | Type | Description |
|---|---|---|
| make_q_matrix return | uintptr_t | Opaque pointer handle to the constructed QMatrix object |
| c (mutated) | torch::Tensor (half) | Output tensor written in-place with result of a @ dequantize(b) |
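To make the expression a @ dequantize(b) concrete, here is a pure-Python reference for a tiny 4-bit case. It is a sketch of the arithmetic only, not of the CUDA kernels: the nibble order (lowest nibble first), the packing of eight consecutive input rows into one int32 per column, the formula scale * (q - zero), and the use of already-unpacked zero points are all assumptions chosen for readability. Real GPTQ checkpoints and the actual kernels may differ in these details.

```python
def unpack_int4(word):
    """Split one 32-bit word into eight 4-bit values, lowest nibble first."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize(packed, scales, zeros, group_size):
    """Expand a (k/8 x n) packed matrix into a dense (k x n) matrix.

    scales and zeros are (groups x n); row k belongs to group k // group_size.
    """
    dense = []
    for packed_row in packed:
        cols = [unpack_int4(w) for w in packed_row]  # n lists of 8 values
        for i in range(8):
            g = len(dense) // group_size  # group of the row being built
            dense.append([
                scales[g][n] * (cols[n][i] - zeros[g][n])
                for n in range(len(packed_row))
            ])
    return dense


def matmul(a, b):
    """Plain (m x k) @ (k x n) product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# One packed row: column 0 holds 0..7, column 1 holds eight 1s.
packed = [[0x76543210, 0x11111111]]
scales = [[0.5, 2.0]]
zeros = [[0, 1]]
a = [[1.0] * 8]  # a single input row of ones

c = matmul(a, dequantize(packed, scales, zeros, group_size=8))
print(c)  # [[14.0, 0.0]]
```

Column 0 dequantizes to 0.0, 0.5, ..., 3.5, which sums to 14.0. Column 1 dequantizes to zeros because every quantized value equals its zero point.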
Usage Examples
# From Python via the compiled extension
# (the importable module name is whatever TORCH_EXTENSION_NAME was set to at build time):
import exllamav2_ext
# Create a QMatrix from EXL2-format quantized weights
handle = exllamav2_ext.make_q_matrix(
    q_weight, q_perm, q_invperm,
    q_scale, q_scale_max, q_groups, q_group_map,
    gptq_qzeros,  # meta tensor if not using GPTQ
    gptq_scales,  # meta tensor if not using GPTQ
    gptq_g_idx,   # meta tensor if not using GPTQ
    temp_dq
)
# Perform mixed-precision GEMM: c = a @ QMatrix
# force_cuda is passed positionally because the bindings register no argument names
exllamav2_ext.gemm_half_q_half(a, handle, c, False)
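Because c must be pre-allocated, a caller has to work out its shape before invoking gemm_half_q_half. The helper below sketches that bookkeeping for an activation tensor with arbitrary leading dimensions. The name gemm_shapes is hypothetical, and flattening the leading dimensions into a single row dimension is an assumed calling convention, not something this page specifies. The only fixed rule it relies on is that an (m x k) input times a (k x n) weight gives an (m x n) output.

```python
def gemm_shapes(x_shape, out_features):
    """Return (flattened input shape, 2-D output shape, final output shape).

    x_shape is the activation shape (..., k); out_features is the width n
    of the quantized matrix.
    """
    *lead, k = x_shape
    m = 1
    for dim in lead:
        m *= dim  # collapse all leading dimensions into one row count
    return (m, k), (m, out_features), (*lead, out_features)


flat_in, flat_out, final = gemm_shapes((2, 5, 4096), 11008)
print(flat_in)   # (10, 4096)
print(flat_out)  # (10, 11008)
print(final)     # (2, 5, 11008)
```

In a PyTorch caller, flat_out would be the shape of the half-precision tensor allocated for c, and final the shape the result is viewed as afterwards.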