Implementation:Predibase Lorax Exllama V2 CUDA Bindings

From Leeroopedia


Knowledge Sources
Domains Quantization, GPU_Kernels
Last Updated 2026-02-08 00:00 GMT

Overview

This C++ file provides the PyTorch/pybind11 extension bindings for ExLlama v2 quantized matrix operations, exposing CUDA GEMM kernels for mixed-precision quantized inference to Python.

Description

This extension module is the v2 iteration of ExLlama's quantized matrix support. Its QMatrix representation is more flexible than in v1 and accepts both native EXL2-quantized weights and GPTQ-format weights. It exposes two functions through pybind11.

make_q_matrix builds a QMatrix object from eleven tensor arguments: the packed quantized weight (q_weight), the EXL2 metadata (permutation, inverse permutation, scales, scale maximums, groups, group map), the GPTQ metadata (qzeros, scales, group index), and a temporary dequantization buffer (temp_dq). It returns the object to Python as an opaque integer handle.

gemm_half_q_half performs a half-precision GEMM between a dense half-precision matrix a and the quantized matrix referenced by the handle b, writing the result into the pre-allocated tensor c.

All eleven arguments of make_q_matrix must always be passed. Tensors that belong to the format not in use are passed as placeholders on PyTorch's meta device, and make_q_matrix checks which tensors are on the meta device to decide whether to interpret the weights as EXL2 or GPTQ, so a single code path serves both formats. If CUDA memory allocation for the QMatrix fails, make_q_matrix throws a std::runtime_error.
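The meta-device format dispatch can be sketched in plain host-side C++. This is a minimal illustration under stated assumptions, not the extension's code: TensorInfo, QMatrixShape, and infer_q_matrix_shape are hypothetical names, and the shape rules (height taken from the inverse permutation for EXL2, eight 4-bit values per int32 row for GPTQ) are assumptions modeled on the tensor roles listed on this page.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the parts of torch::Tensor this sketch needs.
struct TensorInfo {
    bool is_meta = true;         // placeholder on the meta device?
    std::vector<int64_t> sizes;  // tensor shape
};

enum class QFormat { EXL2, GPTQ };

// Hypothetical summary of the geometry derived for a QMatrix.
struct QMatrixShape {
    QFormat format;
    int64_t height;  // rows of the dequantized weight (input features)
    int64_t width;   // columns of the dequantized weight (output features)
    int64_t groups;  // number of quantization groups
};

// Sketch: choose EXL2 or GPTQ from which metadata tensors are real
// (non-meta), derive the matrix geometry, and validate temp_dq.
QMatrixShape infer_q_matrix_shape(const TensorInfo& q_weight,
                                  const TensorInfo& q_invperm,
                                  const TensorInfo& q_scale,
                                  const TensorInfo& gptq_qzeros,
                                  const TensorInfo& temp_dq) {
    assert(q_weight.sizes.size() == 2 && "q_weight is expected to be 2-D");
    QMatrixShape s{};
    s.width = q_weight.sizes.at(1);
    if (!q_scale.is_meta) {
        // EXL2: assume one inverse-permutation entry per input row.
        s.format = QFormat::EXL2;
        s.groups = q_scale.sizes.at(0);
        s.height = q_invperm.sizes.at(0);
    } else if (!gptq_qzeros.is_meta) {
        // GPTQ: assume 4-bit packing, i.e. 8 weight rows per int32 row.
        s.format = QFormat::GPTQ;
        s.groups = gptq_qzeros.sizes.at(0);
        s.height = q_weight.sizes.at(0) * 8;
    } else {
        throw std::invalid_argument("no EXL2 or GPTQ metadata provided");
    }
    if (temp_dq.sizes.at(0) < s.width * s.height)
        throw std::invalid_argument("temp_dq buffer is too small");
    return s;
}
```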

Usage

This extension is loaded by the LoRAX Python server when running inference with quantized models using the ExLlama v2 backend. It supports both EXL2 and GPTQ quantization formats and is used for efficient quantized matrix multiplication during the forward pass of transformer layers.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/exllamav2_kernels/exllamav2_kernels/ext.cpp
  • Lines: 1-139

Signature

uintptr_t make_q_matrix(
    torch::Tensor q_weight,
    torch::Tensor q_perm,
    torch::Tensor q_invperm,
    torch::Tensor q_scale,
    torch::Tensor q_scale_max,
    torch::Tensor q_groups,
    torch::Tensor q_group_map,
    torch::Tensor gptq_qzeros,
    torch::Tensor gptq_scales,
    torch::Tensor gptq_g_idx,
    torch::Tensor temp_dq
);

void gemm_half_q_half(
    torch::Tensor a,
    uintptr_t b,
    torch::Tensor c,
    bool force_cuda
);

// pybind11 module registration
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
    m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
}
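A hedged sketch of what a caller of gemm_half_q_half has to satisfy: a is assumed to be [m, k] and c [m, n], where the QMatrix has k rows (height) and n columns (width). The routing function illustrates the force_cuda flag (custom CUDA kernel versus cuBLAS); the row-count threshold condition is an assumption, and GemmPath, check_gemm_shapes, choose_gemm_path, and max_q_gemm_rows are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

enum class GemmPath { QuantizedKernel, DequantizeThenCublas };

// Sketch of the assumed shape contract for c = a @ dequantize(b):
// a is [m, k], the QMatrix is [height, width], c is [m, n],
// with k == height and n == width.
void check_gemm_shapes(int64_t a_rows, int64_t a_cols,
                       int64_t q_height, int64_t q_width,
                       int64_t c_rows, int64_t c_cols) {
    if (a_rows != c_rows)
        throw std::invalid_argument("a and c must have the same number of rows");
    if (a_cols != q_height)
        throw std::invalid_argument("a and b have incompatible shapes");
    if (q_width != c_cols)
        throw std::invalid_argument("b and c have incompatible shapes");
}

// Sketch of force_cuda routing. max_q_gemm_rows is a hypothetical
// threshold: above it, the non-forced path dequantizes and uses cuBLAS.
GemmPath choose_gemm_path(int64_t m, bool force_cuda, int64_t max_q_gemm_rows) {
    assert(max_q_gemm_rows > 0 && "threshold must be positive");
    if (m > max_q_gemm_rows && !force_cuda)
        return GemmPath::DequantizeThenCublas;
    return GemmPath::QuantizedKernel;
}
```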

Import

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "config.h"
#include "cuda/q_matrix.cuh"
#include "cuda/q_gemm.cuh"
#include "cpp/util.h"

I/O Contract

Inputs

Name Type Required Description
q_weight torch::Tensor (int32) Yes Packed quantized weight data
q_perm torch::Tensor (int16) No* Row permutation indices (EXL2 format; meta device if unused)
q_invperm torch::Tensor (int16) No* Inverse row permutation indices (EXL2 format; meta device if unused)
q_scale torch::Tensor (int32) No* Quantization scale data (EXL2 format; meta device if unused)
q_scale_max torch::Tensor (half) No* Maximum scale values per group (EXL2 format; meta device if unused)
q_groups torch::Tensor (int16) No* Group definitions (EXL2 format; meta device if unused)
q_group_map torch::Tensor (int16) No* Per-row group mapping (EXL2 format; meta device if unused)
gptq_qzeros torch::Tensor (int32) No* GPTQ zero points (meta device if unused)
gptq_scales torch::Tensor (half) No* GPTQ scale factors (meta device if unused)
gptq_g_idx torch::Tensor (int32) No* GPTQ group indices (meta device if unused)
temp_dq torch::Tensor (half) Yes Temporary buffer for dequantized values; must hold at least width * height elements, where height-by-width is the shape of the dequantized weight matrix
a torch::Tensor (half) Yes Dense half-precision input matrix for GEMM
b uintptr_t Yes Opaque handle to a QMatrix object
c torch::Tensor (half) Yes Pre-allocated output tensor for GEMM result
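force_cuda bool Yes If true, always use the custom quantized CUDA kernel; if false, some inputs may instead be dequantized into temp_dq and multiplied with cuBLAS

* All eleven make_q_matrix arguments must be passed. These are optional only in the sense that a tensor belonging to the format not in use (EXL2 or GPTQ) may be a placeholder on the meta device.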
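To make the roles of gptq_qzeros, gptq_scales, and gptq_g_idx concrete, here is a CPU reference for affine dequantization of 4-bit GPTQ-style weights, w = (q - z) * s, where input row k uses the zero point and scale of its group g_idx[k]. It is a sketch under stated assumptions, not the extension's kernel: the packing layout is the common 4-bit GPTQ convention, scales are float instead of half, and any format-specific zero-point offset is ignored. The function name dequantize_gptq_4bit is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch: CPU reference dequantization for 4-bit GPTQ-style tensors.
// Assumed layouts (row-major):
//   qweight: uint32 [K/8, N], 8 4-bit values packed along K
//   qzeros:  uint32 [G, N/8], 8 4-bit zero points packed along N
//   scales:  float  [G, N]   (half in the real extension)
//   g_idx:   int32  [K], quantization group of each input row
// Returns W as float [K, N] with W[k][n] = (q - z) * s.
std::vector<float> dequantize_gptq_4bit(const std::vector<uint32_t>& qweight,
                                        const std::vector<uint32_t>& qzeros,
                                        const std::vector<float>& scales,
                                        const std::vector<int32_t>& g_idx,
                                        int K, int N) {
    assert(K % 8 == 0 && N % 8 == 0);
    std::vector<float> w(static_cast<std::size_t>(K) * N);
    for (int k = 0; k < K; ++k) {
        const int g = g_idx[k];  // group that row k belongs to
        for (int n = 0; n < N; ++n) {
            uint32_t q = (qweight[(k / 8) * N + n] >> (4 * (k % 8))) & 0xF;
            uint32_t z = (qzeros[g * (N / 8) + n / 8] >> (4 * (n % 8))) & 0xF;
            float s = scales[g * N + n];
            w[static_cast<std::size_t>(k) * N + n] =
                (static_cast<float>(q) - static_cast<float>(z)) * s;
        }
    }
    return w;
}
```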

Outputs

Name Type Description
make_q_matrix return uintptr_t Opaque pointer handle to the constructed QMatrix object
c (mutated) torch::Tensor (half) Output tensor written in-place with result of a @ dequantize(b)
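The output contract (an opaque uintptr_t handle from make_q_matrix, and a @ dequantize(b) written into c) can be modeled on the host. In this sketch RefQMatrix, make_ref_handle, ref_gemm, and free_ref_handle are hypothetical; the QMatrix is replaced by an already-dequantized float matrix so that only the handle round-trip and the in-place GEMM semantics remain.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical host-side stand-in for QMatrix: an already-dequantized
// [height, width] weight stored row-major as float.
struct RefQMatrix {
    int height;
    int width;
    std::vector<float> w;
};

// Return the heap object as an opaque integer handle, mirroring the
// uintptr_t that make_q_matrix hands back to Python.
uintptr_t make_ref_handle(int height, int width, std::vector<float> w) {
    assert(w.size() == static_cast<std::size_t>(height) * width);
    return reinterpret_cast<uintptr_t>(new RefQMatrix{height, width, std::move(w)});
}

// Reference semantics: c[m, width] = a[m, height] @ W. This sketch
// overwrites c in place; it does not accumulate into existing values.
void ref_gemm(const std::vector<float>& a, uintptr_t b, std::vector<float>& c, int m) {
    const RefQMatrix* q = reinterpret_cast<const RefQMatrix*>(b);
    assert(a.size() == static_cast<std::size_t>(m) * q->height);
    assert(c.size() == static_cast<std::size_t>(m) * q->width);
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < q->width; ++j) {
            float acc = 0.0f;
            for (int k = 0; k < q->height; ++k)
                acc += a[i * q->height + k] * q->w[k * q->width + j];
            c[i * q->width + j] = acc;
        }
    }
}

// The sketch allocated the object, so it also provides the matching release.
void free_ref_handle(uintptr_t b) { delete reinterpret_cast<RefQMatrix*>(b); }
```

Because the handle is a raw pointer cast to an integer, the caller must keep the object alive for as long as the handle is used. The module registration on this page lists no function for releasing a QMatrix, which is why the sketch supplies its own.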

Usage Examples

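# From Python via the compiled extension. The module name is assumed to be
# exllamav2_kernels, matching the extension built from server/exllamav2_kernels.
import torch
from exllamav2_kernels import make_q_matrix, gemm_half_q_half

# Placeholder for arguments that belong to the unused quantization format
none_tensor = torch.empty((1, 1), device="meta")

# Create a QMatrix from EXL2-format quantized weights
handle = make_q_matrix(
    q_weight, q_perm, q_invperm,
    q_scale, q_scale_max, q_groups, q_group_map,
    none_tensor,  # gptq_qzeros (meta tensor: not using GPTQ)
    none_tensor,  # gptq_scales
    none_tensor,  # gptq_g_idx
    temp_dq,
)

# Perform mixed-precision GEMM: c = a @ dequantize(handle), written into c.
# The binding declares no argument names, so force_cuda is passed positionally.
gemm_half_q_half(a, handle, c, False)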
