Implementation:Predibase Lorax GPTQ Utils Quant Linear

From Leeroopedia


Knowledge Sources
Domains Quantization, Inference
Last Updated 2026-02-08 00:00 GMT

Overview

Provides a GPTQ-quantized linear layer implementation using Triton GPU kernels to perform dequantization and matrix multiplication for 2-bit, 4-bit, and 8-bit weight representations.

Description

This module implements the core GPTQ quantized linear layer used during inference. It contains three main components:

matmul_248_kernel: A Triton JIT-compiled GPU kernel, wrapped in a custom autotuner, that computes the matrix product C = A × B, where A is a float16 activation matrix and B is a weight matrix packed into int32 words. The kernel dequantizes B on the fly using per-group scales and zero points, so a full-precision copy of B never has to be materialized. The autotuner explores eight tile configurations and selects one based on the problem dimensions M, N, and K.

QuantLinearFunction: A custom PyTorch autograd function that wraps the Triton kernel call via matmul248(). It casts inputs to float16 using custom_fwd and delegates the computation to the kernel.

QuantLinear: An nn.Module that holds the quantized weight buffers (qweight, qzeros, scales, g_idx) and provides a standard forward() method. It includes a pack() method for converting full-precision weights into the packed quantized format, and a new() classmethod factory for creating empty quantized layers with the correct buffer dimensions.
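
The on-the-fly dequantization described for matmul_248_kernel can be illustrated on the CPU with a small NumPy reference. This is a minimal sketch under stated assumptions, not the Triton kernel itself: the helper names pack_rows, unpack_rows, and dequant_matmul are hypothetical, uint32 stands in for the int32 storage so that right shifts are logical, the zero points are passed unpacked for brevity, and the exact zero-point offset convention of the real kernel is not reproduced.

```python
import numpy as np

def pack_rows(q, bits):
    """Pack unsigned `bits`-wide integers along axis 0 into 32-bit words."""
    per_word = 32 // bits
    rows, cols = q.shape
    assert rows % per_word == 0, "row count must fill whole 32-bit words"
    packed = np.zeros((rows // per_word, cols), dtype=np.uint32)
    for j in range(per_word):
        # Rows j, j + per_word, ... occupy bit positions [bits*j, bits*(j+1)).
        packed |= q[j::per_word].astype(np.uint32) << np.uint32(bits * j)
    return packed

def unpack_rows(packed, bits):
    """Inverse of pack_rows: recover the (infeatures, outfeatures) integers."""
    per_word = 32 // bits
    mask = np.uint32((1 << bits) - 1)
    shifts = np.arange(per_word, dtype=np.uint32) * np.uint32(bits)
    # (words, per_word, cols) -> (words * per_word, cols)
    vals = (packed[:, None, :] >> shifts[None, :, None]) & mask
    return vals.reshape(-1, packed.shape[1])

def dequant_matmul(x, qweight, scales, zeros, g_idx, bits):
    """Dequantize per group, then multiply: x @ (scale * (q - zero))."""
    q = unpack_rows(qweight, bits).astype(np.float32)  # (in, out)
    w = scales[g_idx] * (q - zeros[g_idx])             # (in, out)
    return x @ w

# Tiny demonstration with made-up sizes (assumptions, not real model shapes).
rng = np.random.default_rng(0)
bits, groupsize, infeatures, outfeatures = 4, 8, 16, 4
groups = infeatures // groupsize
q = rng.integers(0, 2**bits, size=(infeatures, outfeatures))
scales = rng.random((groups, outfeatures), dtype=np.float32) + 0.5
zeros = rng.integers(0, 2**bits, size=(groups, outfeatures)).astype(np.float32)
g_idx = np.arange(infeatures) // groupsize
x = rng.standard_normal((3, infeatures)).astype(np.float32)

qweight = pack_rows(q, bits)
y = dequant_matmul(x, qweight, scales, zeros, g_idx, bits)
w_ref = scales[g_idx] * (q.astype(np.float32) - zeros[g_idx])
```

With bits = 4, eight weights share one 32-bit word, which is why the packed tensor has infeatures // 32 * bits rows. The GPU kernel fuses the unpacking step into the tiled matrix multiplication instead of building the dense matrix w first.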

Usage

This module is used as a drop-in replacement for standard nn.Linear layers when serving GPTQ-quantized models. The QuantLinear class is instantiated during model loading when GPTQ quantization is detected, and its forward() method is called during inference to perform quantized matrix multiplication. The pack() method is used during the quantization process to convert calibrated weights into the compressed format.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/gptq/quant_linear.py
  • Lines: 1-346

Signature

def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)

class QuantLinearFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq)

class QuantLinear(nn.Module):
    def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize)
    @classmethod
    def new(cls, bits, groupsize, infeatures, outfeatures, bias)
    def pack(self, linear, scales, zeros, g_idx=None)
    def forward(self, x)
    @property
    def weight(self) -> torch.Tensor

Import

from lorax_server.utils.gptq.quant_linear import QuantLinear

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| qweight | torch.Tensor (int32) | Yes | Packed quantized weight tensor of shape (infeatures // 32 * bits, outfeatures) |
| qzeros | torch.Tensor (int32) | Yes | Packed quantized zero-point tensor of shape (groups, outfeatures // 32 * bits) |
| scales | torch.Tensor (float16) | Yes | Per-group scale factors of shape (groups, outfeatures) |
| g_idx | torch.Tensor (int32) | Yes | Group index mapping each input feature to its quantization group |
| bias | torch.Tensor (float16) or None | No | Optional bias vector of shape (outfeatures,) |
| bits | int | Yes | Quantization bit-width; must be 2, 4, or 8 |
| groupsize | int | Yes | Number of input features per quantization group |
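
The shape formulas in the table above can be checked with a few lines of arithmetic. The helper below is a hypothetical illustration, not part of the module; it assumes a positive groupsize, that groups equals ceil(infeatures / groupsize), and that g_idx holds one entry per input feature.

```python
import math

def gptq_buffer_shapes(bits, groupsize, infeatures, outfeatures):
    """Return the expected buffer shapes for the given layer configuration."""
    if bits not in (2, 4, 8):
        raise ValueError("bits must be 2, 4, or 8")
    # Assumption: one quantization group per `groupsize` input features.
    groups = math.ceil(infeatures / groupsize)
    return {
        "qweight": (infeatures // 32 * bits, outfeatures),
        "qzeros": (groups, outfeatures // 32 * bits),
        "scales": (groups, outfeatures),
        "g_idx": (infeatures,),
        "bias": (outfeatures,),
    }

shapes = gptq_buffer_shapes(bits=4, groupsize=128, infeatures=4096, outfeatures=4096)
for name, shape in shapes.items():
    print(f"{name:8s} {shape}")
```

For a 4096 x 4096 layer at 4 bits with groupsize 128, the packed weight has 512 rows, an eightfold reduction in row count relative to the dense matrix, and there are 32 groups.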

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor (float16) | Result of the quantized linear transformation with shape (*input_shape[:-1], outfeatures) |
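
The output shape (*input_shape[:-1], outfeatures) suggests that leading dimensions are flattened before the 2-D matrix multiplication and restored afterwards. The sketch below shows that pattern in NumPy with a dense weight standing in for the quantized kernel; the function name apply_over_last_dim is hypothetical, and the exact ordering of steps inside QuantLinear.forward() is an assumption.

```python
import numpy as np

def apply_over_last_dim(x, matmul_2d, outfeatures, bias=None):
    """Flatten leading dims, run a 2-D matmul, then restore the leading dims."""
    out_shape = x.shape[:-1] + (outfeatures,)
    y = matmul_2d(x.reshape(-1, x.shape[-1]))  # (rows, outfeatures)
    if bias is not None:
        y = y + bias                           # broadcast over rows
    return y.reshape(out_shape)

# Dense stand-in for the quantized weight (an assumption for illustration).
rng = np.random.default_rng(0)
weight = rng.standard_normal((16, 4)).astype(np.float32)
bias = rng.standard_normal(4).astype(np.float32)
x = rng.standard_normal((2, 5, 16)).astype(np.float32)  # (batch, seq, infeatures)

out = apply_over_last_dim(x, lambda a: a @ weight, outfeatures=4, bias=bias)
```

Here a (2, 5, 16) input becomes a (10, 16) matrix for the multiplication and the result is reshaped back to (2, 5, 4).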

Usage Examples

# Internal usage during model loading
from lorax_server.utils.gptq.quant_linear import QuantLinear

# Create a new quantized linear layer
qlayer = QuantLinear.new(bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=True)

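# Or initialize from pre-loaded quantized buffers (qweight, qzeros, scales, g_idx,
# and bias are tensors read from a GPTQ checkpoint; see the I/O Contract for shapes)
qlayer = QuantLinear(qweight, qzeros, scales, g_idx, bias, bits=4, groupsize=128)

# Forward pass during inference; input_tensor has last dimension infeatures
output = qlayer(input_tensor)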
