Implementation:Predibase Lorax GPTQ Quant Linear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a GPTQ-quantized linear layer that uses a custom Triton kernel to dequantize packed 2-, 4-, or 8-bit integer weights and multiply them with the activations.
Description
This module provides a Triton-based implementation of matrix multiplication for GPTQ-quantized models. It consists of three main components:
- matmul_248_kernel: A Triton JIT-compiled kernel that multiplies a float16 activation matrix A of shape (M, K) by a packed int32 weight matrix B of shape (K // 32 * bits, N), which is (K//8, N) for 4-bit weights. The kernel unpacks the N-bit values from the 32-bit integers using bitwise shifts and a mask, dequantizes them with the per-group scale and zero point, accumulates the product in float32, and stores the result as float16. The kernel is decorated with a custom autotuner that evaluates 8 different block size configurations.
- QuantLinearFunction: A PyTorch autograd Function that wraps the Triton matrix multiplication in a custom forward pass, casting inputs to float16 via custom_fwd.
- QuantLinear: An nn.Module that holds the packed quantized weights (qweight), zero points (qzeros), scales, and group indices (g_idx) as registered buffers. It supports 2-, 4-, and 8-bit quantization. The class provides a classmethod named new for creating empty quantized layers and a pack method that quantizes a full-precision linear layer into the packed format using NumPy. The forward method flattens the input to two dimensions, delegates to QuantLinearFunction, restores the leading dimensions, and adds the bias if present.
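
The unpacking and dequantization steps described above can be sketched in plain Python. This is only an illustration of the bit arithmetic, not the Triton kernel itself: the helper names (pack_values, unpack_values, dequantize) are hypothetical, the low-bits-first packing order is an assumption, and any implementation-specific offset applied to the stored zero points is ignored.

```python
def pack_values(values, bits):
    """Pack unsigned `bits`-wide integers into 32-bit words (assumed low bits first)."""
    per_word = 32 // bits
    assert len(values) % per_word == 0, "length must fill whole 32-bit words"
    words = []
    for start in range(0, len(values), per_word):
        word = 0
        for j, v in enumerate(values[start:start + per_word]):
            assert 0 <= v < (1 << bits), "value does not fit in the bit width"
            word |= v << (bits * j)
        words.append(word)
    return words


def unpack_values(words, bits):
    """Recover the packed integers with a right shift followed by a mask."""
    per_word = 32 // bits
    mask = (1 << bits) - 1
    return [(word >> (bits * j)) & mask for word in words for j in range(per_word)]


def dequantize(q_values, scales, zeros, g_idx):
    """Apply per-group dequantization: w = (q - zero[group]) * scale[group]."""
    return [(q - zeros[g]) * scales[g] for q, g in zip(q_values, g_idx)]


# Eight 4-bit values fit into one 32-bit word.
q = [3, 15, 0, 7, 8, 1, 12, 5]
words = pack_values(q, bits=4)
restored = unpack_values(words, bits=4)

# Two quantization groups of four input features each.
weights = dequantize(restored, scales=[0.5, 0.25], zeros=[8, 8],
                     g_idx=[0, 0, 0, 0, 1, 1, 1, 1])
```

In the real kernel the same shift-and-mask step runs on whole tiles of B at once, and the dequantized tile is fed straight into the float32 accumulator rather than materialized as a full weight matrix.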
Usage
This layer is used when loading GPTQ-quantized models without exllama acceleration. It is instantiated by the get_linear factory function in linear.py when quantize == "gptq" and use_exllama is False.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/gptq/quant_linear.py
- Lines: 1-345
Signature
class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
Import
from lorax_server.layers.gptq.quant_linear import QuantLinear
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| qweight | torch.Tensor (int32) | Yes | Packed quantized weight matrix of shape (infeatures // 32 * bits, outfeatures) |
| qzeros | torch.Tensor (int32) | Yes | Packed quantized zero points of shape (groups, outfeatures // 32 * bits) |
| scales | torch.Tensor (float16) | Yes | Per-group scale factors of shape (groups, outfeatures) |
| g_idx | torch.Tensor (int32) | Yes | Group index mapping each input feature to its quantization group |
| bias | torch.Tensor or None | No | Optional bias vector of shape (outfeatures,) |
| bits | int | Yes | Quantization bit width, must be 2, 4, or 8 |
| groupsize | int | Yes | Number of input features per quantization group |
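
The shape formulas in the table can be checked with a small helper. This is a sketch under stated assumptions rather than part of the module: expected_shapes and default_g_idx are hypothetical names, the group count is assumed to be ceil(infeatures / groupsize) for a positive groupsize, and g_idx is assumed to hold one entry per input feature.

```python
import math


def expected_shapes(infeatures, outfeatures, bits, groupsize):
    """Return the buffer shapes implied by the I/O contract table."""
    if bits not in (2, 4, 8):
        raise ValueError("bits must be 2, 4, or 8")
    if groupsize <= 0:
        raise ValueError("this sketch assumes a positive groupsize")
    groups = math.ceil(infeatures / groupsize)  # assumption: one group per `groupsize` rows
    return {
        "qweight": (infeatures // 32 * bits, outfeatures),
        "qzeros": (groups, outfeatures // 32 * bits),
        "scales": (groups, outfeatures),
        "g_idx": (infeatures,),  # assumption: one group index per input feature
    }


def default_g_idx(infeatures, groupsize):
    """Sequential grouping: feature i belongs to group i // groupsize."""
    return [i // groupsize for i in range(infeatures)]


shapes = expected_shapes(infeatures=4096, outfeatures=11008, bits=4, groupsize=128)
```

For the 4-bit example, each int32 holds eight values, so qweight has 4096 / 8 = 512 rows and qzeros has 11008 / 8 = 1376 columns.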
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor (float16) | Result of the quantized linear transformation, shape (*input_shape[:-1], outfeatures) |
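
The output shape follows from the reshape that the forward method performs around the matrix multiplication. A minimal sketch of that bookkeeping, using a hypothetical forward_shapes helper and plain tuples in place of tensors:

```python
import math


def forward_shapes(input_shape, infeatures, outfeatures):
    """Return (flattened 2D input shape, final output shape) for a forward pass."""
    if input_shape[-1] != infeatures:
        raise ValueError("last input dimension must equal infeatures")
    leading = tuple(input_shape[:-1])
    rows = math.prod(leading)  # all leading dimensions collapse into M
    flat_shape = (rows, infeatures)  # the (M, K) matrix seen by the kernel
    output_shape = leading + (outfeatures,)  # leading dimensions are restored
    return flat_shape, output_shape


# A batch of 2 sequences with 16 tokens each, projected from 4096 to 11008 features.
flat_shape, output_shape = forward_shapes((2, 16, 4096), 4096, 11008)
```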
Usage Examples
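# Used internally by model layers via the linear factory.
# qweight, qzeros, scales, g_idx, and bias are tensors loaded from a GPTQ checkpoint.
from lorax_server.layers.gptq.quant_linear import QuantLinear
layer = QuantLinear(qweight, qzeros, scales, g_idx, bias, bits=4, groupsize=128)
# input_tensor has infeatures as its last dimension; output has outfeatures as its last dimension.
output = layer(input_tensor)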