Implementation:Predibase Lorax AWQ Quantized Linear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements an AWQ (Activation-aware Weight Quantization) 4-bit quantized linear layer that uses a CUDA-accelerated GEMM kernel for efficient inference.
Description
This module provides a single class adapted from the MIT-HAN-Lab AWQ project.
AWQLinear (nn.Module): A quantized linear layer that stores weights in 4-bit AWQ format and performs matrix multiplication using the awq_inference_engine.gemm_forward_cuda CUDA kernel. Key characteristics:
- 4-bit only: Raises NotImplementedError if w_bit is not 4.
- Group quantization: Weights are quantized in groups of group_size input features, with one scale and one zero point per group and output column. If group_size is -1, a single group spans all in_features.
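- Split-K iterations: Passes split_k_iters=8 to the CUDA GEMM kernel. Split-K divides the reduction dimension (K = in_features) into chunks that are computed in parallel and then summed, which helps keep the GPU busy when the number of input rows is small.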
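- Automatic dtype casting: Inputs that are not float16 are cast to half precision before the CUDA kernel call, and the result is cast back to the original dtype afterward. Because inputs and outputs pass through float16, bfloat16 or float32 callers are limited to float16 precision and range.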
- Bias handling: Optional bias is added after the GEMM operation.
- Weight property: Exposes a weight property that returns the packed qweight tensor, so code that accesses a .weight attribute keeps working. The returned tensor holds packed 4-bit integer data, not a dequantized weight matrix.
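The split-K idea behind split_k_iters can be illustrated on plain Python lists. This is a minimal sketch of the general technique under the assumption of simple contiguous chunking; it is not the tiling used by awq_inference_engine, and the helper names `matmul` and `split_k_matmul` are hypothetical. Integer inputs are used so that the reordered summation matches the reference exactly (with floats the last bits can differ).

```python
def matmul(a, b):
    """Reference matmul on lists of lists: (m x k) @ (k x n)."""
    k, n = len(b), len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(n)] for row in a]


def split_k_matmul(a, b, split_k):
    """Split-K matmul: partial products over K-chunks, then a final sum."""
    m, k, n = len(a), len(b), len(b[0])
    chunk = -(-k // split_k)  # ceil(k / split_k) reduction indices per chunk
    partials = []
    for start in range(0, k, chunk):
        stop = min(start + chunk, k)
        # Independent partial product over reduction indices [start, stop);
        # on a GPU each chunk could be handled by a separate thread block.
        partials.append(
            [[sum(a[i][t] * b[t][j] for t in range(start, stop)) for j in range(n)]
             for i in range(m)]
        )
    # Final reduction: element-wise sum of all partial products
    return [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]


a = [[(i + t) % 5 for t in range(16)] for i in range(2)]  # 2 x 16
b = [[(t * j) % 7 for j in range(3)] for t in range(16)]  # 16 x 3
print(split_k_matmul(a, b, split_k=8) == matmul(a, b))
```

The trade-off is an extra reduction step in exchange for more independent work units along K.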
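in_features and out_features are not passed to the constructor; they are derived from the shape of qweight (in_features = qweight.shape[0], out_features = qweight.shape[1] * 32 // w_bit). The module then checks that in_features is divisible by group_size and that out_features is divisible by 32 // w_bit.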
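The shape derivation and checks described above can be mirrored in a few lines of plain Python. This is a sketch only: `awq_linear_dims` is a hypothetical helper, and it raises ValueError where the real constructor may use assertions.

```python
def awq_linear_dims(qweight_shape, w_bit, group_size):
    """Derive (in_features, out_features, group_size) from qweight's shape (sketch)."""
    if w_bit != 4:
        raise NotImplementedError("Only 4-bit is supported")
    pack_factor = 32 // w_bit  # 4-bit values per 32-bit word -> 8
    in_features = qweight_shape[0]
    out_features = qweight_shape[1] * pack_factor
    if group_size == -1:
        group_size = in_features  # a single group spans all input features
    if in_features % group_size != 0:
        raise ValueError("in_features must be divisible by group_size")
    # Always true when out_features is derived as above; kept to mirror the check
    if out_features % pack_factor != 0:
        raise ValueError("out_features must be divisible by 32 // w_bit")
    return in_features, out_features, group_size


print(awq_linear_dims((4096, 512), w_bit=4, group_size=128))  # (4096, 4096, 128)
```

A qweight of shape (4096, 512) therefore describes a 4096 -> 4096 projection, since each packed column carries eight 4-bit outputs.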
Usage
This module is used as a drop-in replacement for nn.Linear layers when serving AWQ-quantized models. The AWQLinear class is instantiated during model loading when AWQ quantization is detected, with the pre-quantized weight buffers (qweight, qzeros, scales) loaded from the model checkpoint.
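Gathering the three pre-quantized buffers for one layer might look like the sketch below. The helper `collect_awq_buffers`, the plain dict standing in for checkpoint contents, and the `<prefix>.qweight` / `<prefix>.qzeros` / `<prefix>.scales` key convention are all assumptions for illustration, not LoRAX's actual loading code.

```python
AWQ_BUFFERS = ("qweight", "qzeros", "scales")


def collect_awq_buffers(tensors, prefix):
    """Gather one layer's AWQ buffers from a name -> tensor mapping (sketch).

    `tensors` stands in for checkpoint contents; the "<prefix>.<buffer>"
    naming convention is assumed, not taken from LoRAX.
    """
    missing = [name for name in AWQ_BUFFERS if f"{prefix}.{name}" not in tensors]
    if missing:
        raise KeyError(f"{prefix}: missing AWQ buffers {missing}")
    buffers = {name: tensors[f"{prefix}.{name}"] for name in AWQ_BUFFERS}
    buffers["bias"] = tensors.get(f"{prefix}.bias")  # optional; None if absent
    return buffers


# Placeholder strings stand in for real tensors
ckpt = {
    "layer.q_proj.qweight": "QW",
    "layer.q_proj.qzeros": "QZ",
    "layer.q_proj.scales": "S",
}
print(collect_awq_buffers(ckpt, "layer.q_proj"))
```

The resulting dict could then be splatted into the constructor, for example `AWQLinear(w_bit=4, group_size=128, **buffers)`.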
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/awq/awq.py
- Lines: 1-51
Signature
```python
class AWQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): ...

    @torch.no_grad()
    def forward(self, x): ...

    @property
    def weight(self) -> torch.Tensor: ...
```
Import
```python
from lorax_server.utils.awq.awq import AWQLinear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| w_bit | int | Yes | Quantization bit-width (must be 4) |
| group_size | int | Yes | Number of input features per quantization group (-1 for full in_features) |
| qweight | torch.Tensor | Yes | Packed quantized weight tensor of shape (in_features, out_features * w_bit // 32) |
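| qzeros | torch.Tensor | Yes | Packed quantized zero-point tensor of shape (in_features // group_size, out_features * w_bit // 32) |
| scales | torch.Tensor | Yes | Per-group scale factors of shape (in_features // group_size, out_features) |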
| bias | torch.Tensor or None | No | Optional bias vector |
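The three buffers work together as in the standard AWQ scheme: each 32-bit entry of qweight holds eight 4-bit weights along the output dimension, qzeros packs one 4-bit zero point per (group, output column), and scales holds one scale per (group, output column). For input row r and output column c, with g = r // group_size, the real weight is (q[r][c] - z[g][c]) * s[g][c]. The sketch below shows that arithmetic for a single group of eight values. The helper names are hypothetical, and the sequential nibble order is an assumption chosen for readability; real AWQ GEMM buffers use an interleaved nibble order, so this code should not be used to decode actual checkpoints. The CUDA kernel performs the dequantization inside the GEMM rather than in Python.

```python
def pack_int4(values):
    """Pack eight 4-bit ints into one 32-bit word, lowest nibble first (illustrative order)."""
    assert len(values) == 8 and all(0 <= v < 16 for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (4 * i)  # value i occupies bits [4*i, 4*i + 4)
    return word


def unpack_int4(word):
    """Inverse of pack_int4: recover the eight 4-bit values."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize_group(q_values, zero, scale):
    """Asymmetric dequantization used by AWQ: w = (q - zero) * scale."""
    return [(q - zero) * scale for q in q_values]


q = [0, 3, 7, 8, 9, 12, 15, 1]
word = pack_int4(q)
print(unpack_int4(word) == q)                      # True
print(dequantize_group(q, zero=8, scale=0.5))
# [-4.0, -2.5, -0.5, 0.0, 0.5, 2.0, 3.5, -3.5]
```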
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of quantized linear transformation with shape (*input_shape[:-1], out_features), in the same dtype as the input |
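Although the output is returned in the input's dtype, the dtype-casting behavior described earlier means the values pass through float16 on the way. The sketch below uses Python's `struct` "e" format (IEEE 754 half precision) to show the two effects a bfloat16 or float32 caller can observe: rounding to roughly three decimal digits, and overflow past 65504. The helper `to_fp16` is hypothetical; mapping overflow to infinity mimics what a half-precision tensor cast typically produces, whereas `struct` itself raises an error.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x):
    """Round a Python float through IEEE 754 half precision and back (sketch)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # struct refuses out-of-range values; a tensor cast would yield +/-inf
        return math.copysign(math.inf, x)


print(to_fp16(0.1))       # 0.0999755859375
print(to_fp16(FP16_MAX))  # 65504.0
print(to_fp16(1e5))       # inf
```

bfloat16 and float32 both represent values far above 65504, so activations that are safe in those dtypes can still overflow once cast to float16 for the kernel.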
Usage Examples
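```python
# Internal usage during model loading
from lorax_server.utils.awq.awq import AWQLinear

# qweight_tensor, qzeros_tensor, scales_tensor and bias_tensor are the
# pre-quantized buffers read from the AWQ checkpoint (bias may be None)
layer = AWQLinear(
    w_bit=4,
    group_size=128,
    qweight=qweight_tensor,
    qzeros=qzeros_tensor,
    scales=scales_tensor,
    bias=bias_tensor,
)

# Forward pass during inference; input_tensor must be a CUDA tensor
# whose last dimension equals in_features
output = layer(input_tensor)
```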
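To put the group_size=128 example in perspective, the buffer sizes can be estimated from the shapes in the Inputs table. This sketch assumes int32 storage for qweight and qzeros and float16 scales, ignores the bias, and uses a hypothetical 4096 x 4096 layer; `awq_buffer_bytes` is a hypothetical helper.

```python
def awq_buffer_bytes(in_features, out_features, group_size, w_bit=4):
    """Estimate AWQ buffer storage vs. a dense fp16 weight (sketch).

    Assumes int32 qweight/qzeros and fp16 scales with shapes
    (in, out // 8), (in // group, out // 8) and (in // group, out).
    """
    pack = 32 // w_bit
    groups = in_features // group_size
    qweight = in_features * (out_features // pack) * 4  # int32 = 4 bytes
    qzeros = groups * (out_features // pack) * 4        # int32 = 4 bytes
    scales = groups * out_features * 2                  # fp16 = 2 bytes
    dense_fp16 = in_features * out_features * 2         # fp16 = 2 bytes
    return qweight + qzeros + scales, dense_fp16


awq, dense = awq_buffer_bytes(4096, 4096, 128)
print(awq, dense)  # 8716288 33554432
```

Under these assumptions the quantized buffers take about 8.3 MiB versus 32 MiB for a dense float16 weight, roughly a 3.85x reduction; the per-group scales and zero points are why it falls short of a full 4x.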