Implementation:Predibase Lorax Exllama V1 CUDA Bindings
| Knowledge Sources | |
|---|---|
| Domains | Quantization, GPU_Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
This C++ file provides the PyTorch/pybind11 extension bindings for ExLlama v1 GPTQ 4-bit quantized matrix operations, exposing CUDA kernel functions to Python for quantized inference.
Description
Adapted from turboderp's ExLlama project, this extension module bridges Python and CUDA kernels for 4-bit (Q4) quantized matrix multiplication. It exposes five functions through pybind11:
- set_tuning_params configures performance knobs: the threshold for the matrix-reconstruction path, fused column remapping, and whether half2 arithmetic is disabled.
- prepare_buffers sets up per-device scratch space from the caller-allocated temp_state and temp_dq tensors.
- cleanup frees all unmanaged CUDA allocations, including buffers and Q4 matrices.
- make_q4 constructs a Q4Matrix object from quantized weight, zero-point, scale, and group-index tensors and returns an opaque handle.
- q4_matmul performs the core fp16 × 4-bit quantized matrix multiplication, choosing between a direct CUDA kernel and a cuBLAS-backed reconstruction path based on the tuning parameters.
The file also defines a column_remap function, not exported to Python, for reordering columns in half-precision tensors. Validation macros check tensor dtypes, shapes, and device indices before each kernel dispatch.
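To make the column reordering concrete, here is a minimal pure-Python sketch of a gather-style column remap. It assumes the convention that output column c is read from input column x_map[c] and that x_map is a permutation; the function name and the list-of-lists representation are hypothetical stand-ins for the half-precision CUDA tensors the real function works on.

```python
from typing import List, Sequence


def column_remap(x: Sequence[Sequence[float]], x_map: Sequence[int]) -> List[List[float]]:
    """Hypothetical reference: return a copy of x whose column c is column x_map[c] of x.

    x is row-major (a list of rows); x_map is assumed to be a permutation of range(width).
    """
    width = len(x_map)
    if sorted(x_map) != list(range(width)):
        raise ValueError("x_map must be a permutation of the column indices")
    for row in x:
        if len(row) != width:
            raise ValueError("every row of x must have len(x_map) columns")
    # Gather: each output column pulls from the source column named by x_map
    return [[row[x_map[c]] for c in range(width)] for row in x]


if __name__ == "__main__":
    x = [[10.0, 11.0, 12.0],
         [20.0, 21.0, 22.0]]
    print(column_remap(x, [2, 0, 1]))  # [[12.0, 10.0, 11.0], [22.0, 20.0, 21.0]]
```

Judging from the matmul_fused_remap flag, a separate pass of this kind presumably runs before the matmul when fusion is off, while fusion folds the reordering into the matmul kernel itself.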
Usage
This extension is loaded by the LoRAX Python server when running inference with GPTQ 4-bit quantized models using the ExLlama v1 backend. The Python-side code creates Q4 matrices from model weights, configures tuning parameters, and calls q4_matmul during each forward pass to perform efficient quantized matrix multiplications on the GPU.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/exllama_kernels/exllama_kernels/exllama_ext.cpp
- Lines: 1-253
Signature
```cpp
void set_tuning_params(
    int matmul_recons_thd,
    bool matmul_fused_remap,
    bool matmul_no_half2
);

void prepare_buffers(
    torch::Device device,
    torch::Tensor temp_state,
    torch::Tensor temp_dq
);

void cleanup();

uintptr_t make_q4(
    torch::Tensor qweight,
    torch::Tensor qzeros,
    torch::Tensor scales,
    torch::Tensor g_idx,
    int device
);

void q4_matmul(
    torch::Tensor x,
    uintptr_t w,
    torch::Tensor out
);

// pybind11 module registration. No py::arg() annotations are given,
// so Python callers must pass every argument positionally.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
    m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
    m.def("cleanup", &cleanup, "cleanup");
    m.def("make_q4", &make_q4, "make_q4");
    m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
```
Import
```cpp
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| matmul_recons_thd | int | Yes | Threshold on the number of rows in x (batch size) for switching between the direct CUDA kernel and the cuBLAS-backed reconstruction path |
| matmul_fused_remap | bool | Yes | Whether to fuse column remapping into the matrix multiplication |
| matmul_no_half2 | bool | Yes | Disables half2 optimizations for compatibility |
| device | torch::Device (prepare_buffers) / int (make_q4) | Yes | Target CUDA device for buffer setup or matrix creation |
| temp_state | torch::Tensor (half) | Yes | Temporary state buffer for intermediate computations |
| temp_dq | torch::Tensor (half) | Yes | Temporary dequantization buffer |
| qweight | torch::Tensor (int32) | Yes | Packed 4-bit quantized weight matrix |
| qzeros | torch::Tensor (int32) | Yes | Packed quantization zero points |
| scales | torch::Tensor (half) | Yes | Per-group quantization scale factors |
| g_idx | torch::Tensor (int32) | No | Group index mapping; when absent, a tensor on the meta device may be passed in its place |
| x | torch::Tensor (half) | Yes | Input activation tensor for the matrix multiplication |
| w | uintptr_t | Yes | Opaque handle to a Q4Matrix object returned by make_q4 |
| out | torch::Tensor (half) | Yes | Pre-allocated output tensor for the matmul result |
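The matmul_recons_thd input above can be read as a row-count cutoff. The sketch below models the path choice under an assumed rule that should be checked against q4_matmul in exllama_ext.cpp: a threshold of 0 disables reconstruction, inputs with fewer rows than the threshold use the direct 4-bit kernel, and larger inputs use the reconstruct-then-cuBLAS path. It also mirrors two plausible shape checks (x and out share a row count; x's width equals the weight's input dimension). The function name and the "direct"/"recons" labels are hypothetical.

```python
from typing import Tuple


def choose_q4_matmul_path(
    x_shape: Tuple[int, int],    # (rows of x, in_features)
    out_shape: Tuple[int, int],  # (rows of out, out_features)
    w_height: int,               # in_features of the quantized weight
    matmul_recons_thd: int,
) -> str:
    """Hypothetical model of q4_matmul's dispatch; returns "direct" or "recons"."""
    x_rows, x_cols = x_shape
    # Shape checks in the spirit of the extension's validation macros
    if out_shape[0] != x_rows:
        raise ValueError("x and out must have the same number of rows")
    if x_cols != w_height:
        raise ValueError("x and w have incompatible shapes")
    # Assumed rule: 0 disables reconstruction; otherwise few rows stay on the direct kernel
    if matmul_recons_thd == 0 or x_rows < matmul_recons_thd:
        return "direct"
    return "recons"
```

The trade-off behind such a cutoff is presumably that for a handful of rows (for example single-token decoding), reading packed 4-bit weights directly saves memory bandwidth, while for many rows the one-time cost of dequantizing W into a half-precision scratch buffer is amortized over a single cuBLAS GEMM.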
Outputs
| Name | Type | Description |
|---|---|---|
| make_q4 return | uintptr_t | Opaque pointer handle to the constructed Q4Matrix object |
| out (mutated) | torch::Tensor (half) | Written in place with the result of x @ dequantize(w) |
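Because q4_matmul writes into out instead of returning a tensor, a slow pure-Python reference is handy for pinning down the contract out = x @ dequantize(w) on tiny inputs. This sketch assumes the common GPTQ packing layout (eight 4-bit values per int32, packed along the input dimension in qweight and along the output dimension in qzeros, with g_idx giving each input row's group) and a zero_offset of 1 for checkpoints whose stored zero points are off by one. Both conventions are assumptions to verify against the checkpoint and the CUDA code, and all function names are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def unpack_nibble(word: int, j: int) -> int:
    """Return the j-th 4-bit field of a packed int32 (j = 0 is the lowest nibble).

    Python's >> is arithmetic, so negative (signed int32) words unpack correctly.
    """
    return (word >> (4 * j)) & 0xF


def dequantize_gptq(
    qweight: Sequence[Sequence[int]],   # assumed [K // 8][N], 8 input rows per int32
    qzeros: Sequence[Sequence[int]],    # assumed [G][N // 8], 8 output columns per int32
    scales: Sequence[Sequence[float]],  # [G][N]
    g_idx: Sequence[int],               # [K], group of each input row
    zero_offset: int = 1,               # ASSUMPTION: "+1" stored-zero convention
) -> Matrix:
    k, n = len(qweight) * 8, len(qweight[0])
    if len(g_idx) != k:
        raise ValueError("g_idx must have one entry per input row")
    w = [[0.0] * n for _ in range(k)]
    for row in range(k):
        g = g_idx[row]
        for col in range(n):
            q = unpack_nibble(qweight[row // 8][col], row % 8)
            z = unpack_nibble(qzeros[g][col // 8], col % 8) + zero_offset
            w[row][col] = scales[g][col] * (q - z)
    return w


def q4_matmul_reference(
    x: Matrix,
    qweight: Sequence[Sequence[int]],
    qzeros: Sequence[Sequence[int]],
    scales: Sequence[Sequence[float]],
    g_idx: Sequence[int],
    out: Matrix,
    zero_offset: int = 1,
) -> None:
    """Slow reference for out = x @ dequantize(W); fills the pre-allocated out in place."""
    w = dequantize_gptq(qweight, qzeros, scales, g_idx, zero_offset)
    k, n = len(w), len(w[0])
    if len(out) != len(x) or any(len(r) != n for r in out):
        raise ValueError("out must have shape [rows of x, N]")
    for i, x_row in enumerate(x):
        if len(x_row) != k:
            raise ValueError("x and w have incompatible shapes")
        for col in range(n):
            out[i][col] = sum(x_row[t] * w[t][col] for t in range(k))
```

The extension itself computes in half precision, so any comparison between this reference and the real kernel should use a tolerance rather than exact equality.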
Usage Examples
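```python
# From Python via the compiled extension. The import name is whatever
# TORCH_EXTENSION_NAME was set to at build time; exllama_ext is illustrative.
import torch
import exllama_ext

# Assumed to exist already: temp_state and temp_dq (fp16 scratch tensors on cuda:0),
# GPTQ tensors qweight/qzeros/scales/g_idx on cuda:0, and a 2-D fp16 input x.
device = torch.device("cuda:0")

# Configure tuning parameters. The bindings declare no argument names, so pass them
# positionally: matmul_recons_thd, matmul_fused_remap, matmul_no_half2
exllama_ext.set_tuning_params(8, False, False)

# Prepare CUDA buffers on device 0
exllama_ext.prepare_buffers(device, temp_state, temp_dq)

# Create a Q4 quantized matrix from GPTQ weights (device is an int index here)
handle = exllama_ext.make_q4(qweight, qzeros, scales, g_idx, device.index)

# Perform quantized matmul into a pre-allocated output: out = x @ dequantize(W)
out = torch.empty((x.shape[0], qweight.shape[1]), dtype=torch.float16, device=device)
exllama_ext.q4_matmul(x, handle, out)

# Clean up all unmanaged GPU allocations
exllama_ext.cleanup()
```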