Implementation:Predibase Lorax GPTQ Utils Quant Linear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a GPTQ-quantized linear layer that uses a Triton GPU kernel to dequantize packed 2-bit, 4-bit, or 8-bit weights and multiply them with float16 activations in a single fused pass.
Description
This module implements the core GPTQ quantized linear layer used during inference. It contains three main components:
matmul_248_kernel: A Triton JIT-compiled GPU kernel, wrapped by a custom autotuner, that computes the matrix product C = A x B, where A is a float16 activation matrix and B is the packed int32 weight matrix. The kernel unpacks and dequantizes tiles of B on the fly using per-group scales and zero points (looked up through g_idx), so the full float16 weight matrix is never materialized in GPU memory. The autotuner explores eight tile configurations and uses the problem dimensions M, N, and K as its tuning key.
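The arithmetic the kernel fuses can be written as an unfused reference loop. The sketch below is plain Python for readability, not the Triton code. The helper names are hypothetical, zero points are assumed to be already unpacked integers (the real qzeros buffer is itself packed, and its storage convention is not covered here), and each input row k is assumed to sit low-bits-first at bit offset bits * (k % (32 // bits)) of int32 word k // (32 // bits).

```python
# Unfused reference for the math matmul_248_kernel performs (illustrative sketch).
# Assumptions: packed_b has shape (K // (32 // bits), N) of int32 words; input row k
# sits at bit offset bits * (k % (32 // bits)) of word k // (32 // bits);
# `zeros` is already unpacked to integers of shape (groups, N).


def unpack_entry(packed_b, k, n, bits):
    """Return the unsigned bits-wide integer for input row k, output column n."""
    per_word = 32 // bits
    maxq = (1 << bits) - 1  # bitmask for one quantized value
    word = packed_b[k // per_word][n] & 0xFFFFFFFF  # view the int32 as unsigned
    return (word >> (bits * (k % per_word))) & maxq


def dequant_matmul(a, packed_b, scales, zeros, g_idx, bits):
    """Compute C = A x dequant(B) for a float matrix A of shape (M, K)."""
    m_dim, k_dim = len(a), len(a[0])
    n_dim = len(scales[0])
    c = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0
            for k in range(k_dim):
                g = g_idx[k]  # quantization group of input feature k
                q = unpack_entry(packed_b, k, n, bits)
                # dequantize on the fly: w = (q - zero) * scale
                acc += a[m][k] * (q - zeros[g][n]) * scales[g][n]
            c[m][n] = acc
    return c
```

Each output element is therefore the sum over k of A[m][k] * (q[k][n] - zero[g][n]) * scale[g][n] with g = g_idx[k]. A fused kernel computes the same sum tile by tile instead of element by element.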
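QuantLinearFunction: A custom PyTorch autograd function whose forward() calls matmul248(), which launches the Triton kernel. The forward() is decorated with custom_fwd(cast_inputs=torch.float16), so when it runs inside an autocast-enabled region, floating-point CUDA inputs are cast to float16 before the kernel is invoked. Outside autocast the decorator has no effect.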
QuantLinear: An nn.Module that holds the quantized weight buffers (qweight, qzeros, scales, g_idx) and provides a standard forward() method. It includes a pack() method for converting full-precision weights into the packed quantized format, and a new() classmethod factory for creating empty quantized layers with the correct buffer dimensions.
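The following plain-Python sketch illustrates what packing involves for a single group, together with the inverse used at inference time. It assumes the standard GPTQ mapping q = round(w / scale + zero) and low-bits-first packing of 32 // bits integers into each int32 word. Both helper names are hypothetical, and the clamp to [0, maxq] is a safeguard of this sketch, not a description of pack().

```python
# Illustrative sketch of GPTQ-style quantize-and-pack for one group (hypothetical helpers).
# Assumes q = round(w / scale + zero) and low-bits-first packing inside each int32 word.


def quantize_and_pack(weights, scale, zero, bits):
    """Quantize floats to bits-wide ints and pack them into signed int32 words."""
    if bits not in (2, 4, 8):
        raise NotImplementedError("Only 2, 4, 8 bits are supported.")
    per_word = 32 // bits
    maxq = (1 << bits) - 1
    if len(weights) % per_word:
        raise ValueError(f"{bits}-bit packing needs a multiple of {per_word} weights")
    # round() rounds half to even; the clamp to [0, maxq] is this sketch's own safeguard
    q = [min(max(round(w / scale + zero), 0), maxq) for w in weights]
    words = []
    for start in range(0, len(q), per_word):
        word = 0
        for j, v in enumerate(q[start:start + per_word]):
            word |= v << (bits * j)  # j-th value occupies bits [bits*j, bits*(j+1))
        if word >= 1 << 31:  # reinterpret the 32-bit pattern as a signed int32
            word -= 1 << 32
        words.append(word)
    return words


def unpack_and_dequantize(words, scale, zero, bits):
    """Invert quantize_and_pack: w ~= (q - zero) * scale."""
    per_word = 32 // bits
    maxq = (1 << bits) - 1
    return [
        (((w >> (bits * j)) & maxq) - zero) * scale
        for w in words
        for j in range(per_word)
    ]
```

With bits=4, scale=0.5, zero=8 and the 16 weights (k - 8) * 0.5 for k in 0..15, the sketch produces two words, [1985229328, -19088744]. The second word is negative because its top nibble is set, which is expected when unsigned bit patterns are stored in an int32 tensor.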
Usage
QuantLinear serves as a drop-in replacement for nn.Linear when serving GPTQ-quantized models. It is instantiated during model loading when GPTQ quantization is detected, and its forward() performs the quantized matrix multiplication at inference time. The pack() method is used during the quantization process to convert calibrated weights into the packed format.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/gptq/quant_linear.py
- Lines: 1-346
Signature
```python
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): ...

class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): ...

class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize): ...

    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias): ...

    def pack(self, linear, scales, zeros, g_idx=None): ...

    def forward(self, x): ...

    @property
    def weight(self) -> torch.Tensor: ...
```
Import
```python
from lorax_server.utils.gptq.quant_linear import QuantLinear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| qweight | torch.Tensor (int32) | Yes | Packed quantized weights of shape (infeatures // 32 * bits, outfeatures); each int32 holds 32 // bits input-feature values |
| qzeros | torch.Tensor (int32) | Yes | Packed quantized zero points of shape (groups, outfeatures // 32 * bits), where groups = ceil(infeatures / groupsize) |
| scales | torch.Tensor (float16) | Yes | Per-group scale factors of shape (groups, outfeatures) |
| g_idx | torch.Tensor (int32) | Yes | Group index of shape (infeatures,) mapping each input feature to its quantization group |
| bias | torch.Tensor (float16) or None | Yes (may be None) | Bias vector of shape (outfeatures,), or None for no bias; new() instead takes a bool flag |
| bits | int | Yes | Quantization bit-width; must be 2, 4, or 8 |
| groupsize | int | Yes | Number of input features per quantization group |
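To make the shape formulas in the table concrete, the sketch below computes the buffer shapes for a given configuration, along with a default contiguous g_idx and the maxq bitmask that is passed to matmul248(). The helpers are hypothetical. They assume groups = ceil(infeatures / groupsize), g_idx[i] = i // groupsize, and maxq = 2**bits - 1.

```python
import math


def gptq_buffer_shapes(bits, groupsize, infeatures, outfeatures):
    """Hypothetical helper: buffer shapes for a QuantLinear layer (per the I/O table)."""
    if bits not in (2, 4, 8):
        raise NotImplementedError("Only 2, 4, 8 bits are supported.")
    groups = math.ceil(infeatures / groupsize)  # assumed group-count formula
    return {
        "qweight": (infeatures // 32 * bits, outfeatures),
        "qzeros": (groups, outfeatures // 32 * bits),
        "scales": (groups, outfeatures),
        "g_idx": (infeatures,),
    }


def default_g_idx(infeatures, groupsize):
    """Assumed contiguous grouping: feature i belongs to group i // groupsize."""
    return [i // groupsize for i in range(infeatures)]


def maxq_for(bits):
    """Largest representable quantized value, used as the unpacking bitmask."""
    return 2**bits - 1
```

For bits=4, groupsize=128, and infeatures = outfeatures = 4096 (the configuration in the usage examples), this yields qweight (512, 4096), qzeros (32, 512), scales (32, 4096), and maxq 15.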
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor (float16) | Result of the quantized linear transformation with shape (*input_shape[:-1], outfeatures) |
Usage Examples
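```python
# Internal usage during model loading
import torch
from lorax_server.utils.gptq.quant_linear import QuantLinear

# Create an empty 4-bit layer (zero-initialized buffers) with a bias vector
qlayer = QuantLinear.new(bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=True)

# Or wrap pre-loaded quantized buffers (qweight, qzeros, scales, g_idx, bias from the checkpoint)
qlayer = QuantLinear(qweight, qzeros, scales, g_idx, bias, bits=4, groupsize=128)

# The Triton kernel runs on CUDA, so the buffers and the input must live on the GPU
qlayer = qlayer.to("cuda")
input_tensor = torch.randn(2, 16, 4096, dtype=torch.float16, device="cuda")
output = qlayer(input_tensor)  # shape: (2, 16, 4096)
```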