Implementation:Predibase Lorax AWQ Quantized Linear

From Leeroopedia


Knowledge Sources
Domains: Quantization, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements an AWQ (Activation-aware Weight Quantization) 4-bit quantized linear layer that uses a CUDA-accelerated GEMM kernel for efficient inference.

Description

This module provides a single class adapted from the MIT HAN Lab AWQ project (llm-awq).

AWQLinear (nn.Module): A quantized linear layer that stores weights in 4-bit AWQ format and performs matrix multiplication using the awq_inference_engine.gemm_forward_cuda CUDA kernel. Key characteristics:

  • 4-bit only: Raises NotImplementedError if w_bit is not 4.
  • Group quantization: Weights are quantized in groups of group_size input features. If group_size is -1, it defaults to the full in_features dimension.
  • Split-K iterations: Passes a fixed split_k_iters value of 8 to the CUDA GEMM kernel, which splits the reduction (K) dimension into chunks to improve parallelism.
  • Automatic dtype casting: Input tensors that are not float16 are cast to half precision before the CUDA kernel call and cast back to the original dtype afterward.
  • Bias handling: Optional bias is added after the GEMM operation.
  • Weight property: Exposes a weight property that returns the packed qweight tensor for compatibility with weight inspection APIs.

The module validates that in_features is divisible by the effective group_size (after -1 has been resolved to in_features) and that out_features is divisible by 32 // w_bit, which is 8 for 4-bit weights.
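The dimension bookkeeping described above can be sketched in plain Python. This is a minimal illustration, not the library's code: the helper name derive_awq_dims is hypothetical, it works on a shape tuple rather than a tensor, and it raises ValueError where the real module may use a different failure mode. It assumes the qweight layout given in the I/O Contract, (in_features, out_features * w_bit // 32).

```python
def derive_awq_dims(qweight_shape, w_bit=4, group_size=128):
    """Return (in_features, out_features, effective_group_size).

    Hypothetical helper that mirrors the checks described in the text;
    it is not part of lorax_server.
    """
    if w_bit != 4:
        raise NotImplementedError("Only 4-bit quantization is supported.")

    pack_factor = 32 // w_bit  # eight 4-bit values fit in one 32-bit word
    in_features, packed_cols = qweight_shape
    out_features = packed_cols * pack_factor

    # group_size == -1 means a single group spanning every input feature.
    effective_group = in_features if group_size == -1 else group_size

    if in_features % effective_group != 0:
        raise ValueError("in_features must be divisible by group_size")
    if out_features % pack_factor != 0:
        raise ValueError("out_features must be divisible by 32 // w_bit")

    return in_features, out_features, effective_group


dims = derive_awq_dims((4096, 512), w_bit=4, group_size=128)
full_group = derive_awq_dims((4096, 512), w_bit=4, group_size=-1)
```

With a packed shape of (4096, 512), each of the 512 packed columns expands to 8 output features, so the layer is 4096 x 4096.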

Usage

This module is used as a drop-in replacement for nn.Linear layers when serving AWQ-quantized models. The AWQLinear class is instantiated during model loading when AWQ quantization is detected, with the pre-quantized weight buffers (qweight, qzeros, scales) loaded from the model checkpoint.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/awq/awq.py
  • Lines: 1-51

Signature

class AWQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): ...
    @torch.no_grad()
    def forward(self, x): ...
    @property
    def weight(self) -> torch.Tensor: ...

Import

from lorax_server.utils.awq.awq import AWQLinear

I/O Contract

Inputs

Name | Type | Required | Description
w_bit | int | Yes | Quantization bit-width (must be 4)
group_size | int | Yes | Number of input features per quantization group (-1 for full in_features)
qweight | torch.Tensor | Yes | Packed quantized weight tensor of shape (in_features, out_features * w_bit // 32)
qzeros | torch.Tensor | Yes | Packed quantized zero-point tensor
scales | torch.Tensor | Yes | Per-group scale factors
bias | torch.Tensor or None | No | Optional bias vector
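How the three buffers relate can be illustrated with a small pure-Python sketch. Two assumptions are made here that this page does not confirm: values are packed lowest nibble first (the real AWQ kernels may use a different ordering inside each 32-bit word), and dequantization follows the common asymmetric scheme w = (q - zero) * scale. All function names are hypothetical.

```python
def pack_nibbles(values):
    """Pack eight 4-bit integers into one 32-bit word, lowest nibble first.

    Assumed ordering for illustration only; the CUDA kernel's actual
    layout may differ.
    """
    assert len(values) == 8 and all(0 <= v < 16 for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (4 * i)
    return word


def unpack_nibbles(word):
    """Recover the eight 4-bit integers from a 32-bit word."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize_column(q_values, zeros, scales, group_size):
    """Dequantize one output column, one scale and zero-point per group of rows."""
    out = []
    for row, q in enumerate(q_values):
        g = row // group_size  # index of the quantization group for this input row
        out.append((q - zeros[g]) * scales[g])
    return out


word = pack_nibbles([1, 2, 3, 4, 5, 6, 7, 8])
roundtrip = unpack_nibbles(word)

# Four input rows split into two groups of two rows each.
column = dequantize_column(
    q_values=[0, 15, 8, 8], zeros=[8, 0], scales=[0.5, 0.25], group_size=2
)
```

The sketch shows why one 32-bit word holds 32 // w_bit = 8 weights and why scales and zero-points need only one entry per group rather than one per weight.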

Outputs

Name | Type | Description
output | torch.Tensor | Result of quantized linear transformation with shape (*input_shape[:-1], out_features), in the same dtype as the input
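The output shape rule can be checked without a GPU. The sketch below assumes the GEMM kernel consumes a 2-D matrix, so all leading dimensions are flattened into rows before the call and restored afterward; plan_shapes is a hypothetical helper used only for illustration.

```python
from math import prod


def plan_shapes(input_shape, out_features):
    """Return (gemm_input_shape, output_shape) for an input of any rank >= 1."""
    *leading, in_features = input_shape
    rows = prod(leading)  # prod of an empty list is 1, covering 1-D inputs
    gemm_input_shape = (rows, in_features)
    output_shape = (*leading, out_features)
    return gemm_input_shape, output_shape


# A batch of 2 sequences, 16 tokens each, with hidden size 4096.
gemm_in, out = plan_shapes((2, 16, 4096), out_features=11008)
```

Here the 2 x 16 tokens become 32 rows for the matrix multiply, and the result is reshaped back to (2, 16, 11008).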

Usage Examples

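# Internal usage during model loading (requires the awq_inference_engine CUDA extension)
from lorax_server.utils.awq.awq import AWQLinear

# Create an AWQ linear layer from pre-quantized buffers.
# qweight_tensor, qzeros_tensor, scales_tensor and bias_tensor are placeholders
# for tensors loaded from the model checkpoint.
layer = AWQLinear(
    w_bit=4,
    group_size=128,
    qweight=qweight_tensor,
    qzeros=qzeros_tensor,
    scales=scales_tensor,
    bias=bias_tensor,  # pass None if the layer has no bias
)

# Forward pass during inference; the output has the same dtype as input_tensor
output = layer(input_tensor)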
