Implementation:Predibase Lorax AWQ WQLinear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a 4-bit quantized linear layer using the AWQ inference engine CUDA kernels for weight-quantized GEMM with group-wise scaling.
Description
This module is adapted from the MIT-HAN-Lab llm-awq project and provides a quantized linear layer using AWQ (Activation-aware Weight Quantization):
WQLinear: An nn.Module that performs 4-bit weight-quantized linear transformation using the awq_inference_engine CUDA extension. During __init__:
- Validates that w_bit is 4 (only 4-bit is supported).
- Computes in_features as qweight.shape[0] and out_features as qweight.shape[1] * 32 // w_bit, because each int32 in qweight packs 32 // w_bit = 8 four-bit values along the output dimension.
- Sets group_size for per-group dequantization, defaulting to in_features when -1 is passed.
- Performs alignment assertions: in_features must be divisible by group_size, and out_features must be divisible by 8 (32 // 4).
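The constructor's shape bookkeeping can be sketched in plain Python. This is a minimal illustration, not the LoRAX source. The helper name `awq_layer_dims` is hypothetical, it takes a shape tuple instead of a tensor, and it raises `ValueError` where the real module uses its own checks.

```python
def awq_layer_dims(qweight_shape, w_bit, group_size):
    """Hypothetical helper mirroring the shape logic described for WQLinear.__init__.

    qweight_shape: (in_features, packed_cols) of the int32 qweight tensor.
    Returns (in_features, out_features, group_size).
    """
    if w_bit != 4:
        raise ValueError("only 4-bit weights are supported")

    pack_factor = 32 // w_bit  # 8 four-bit values per int32
    in_features = qweight_shape[0]
    out_features = qweight_shape[1] * pack_factor

    # -1 means a single group spanning every input feature
    if group_size == -1:
        group_size = in_features

    # Alignment checks
    if in_features % group_size != 0:
        raise ValueError("in_features must be divisible by group_size")
    if out_features % pack_factor != 0:
        raise ValueError("out_features must be divisible by 32 // w_bit")

    return in_features, out_features, group_size


# A 4096 -> 11008 projection quantized with 128-wide groups
print(awq_layer_dims((4096, 1376), w_bit=4, group_size=128))  # (4096, 11008, 128)
```

The second alignment check cannot fail when out_features is derived this way, because qweight.shape[1] * 8 is always a multiple of 8. It acts only as a sanity guard.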
The forward method (decorated with @torch.no_grad()):
- Reshapes input to 2D.
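- Calls awq_inference_engine.gemm_forward_cuda with the 2D input, qweight, scales, qzeros, and a constant final argument of 8. In the upstream llm-awq kernel this argument is split_k_iters (the number of split-K partitions), not a pack factor.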
- Adds optional bias and reshapes output to match the input batch dimensions.
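Functionally, the fused kernel computes an ordinary linear layer against group-wise dequantized weights. The sketch below is a slow pure-Python reference built on several assumptions. It expects weights that are already unpacked into a nested list of 4-bit integers `q[i][j]`. It uses the usual AWQ dequantization form `(q - zero) * scale`, with one zero and one scale per (group, output column). It takes a 2D input, so it covers the computation after the reshape. The names `dequantize` and `reference_forward` are hypothetical.

```python
def dequantize(q, zeros, scales, group_size):
    """Hypothetical reference: W[i][j] = (q[i][j] - zeros[g][j]) * scales[g][j], g = i // group_size.

    q:      in_features x out_features unpacked 4-bit ints
    zeros:  (in_features // group_size) x out_features unpacked zero points
    scales: (in_features // group_size) x out_features scale factors
    """
    in_features, out_features = len(q), len(q[0])
    w = [[0.0] * out_features for _ in range(in_features)]
    for i in range(in_features):
        g = i // group_size  # input rows in the same group share zero/scale
        for j in range(out_features):
            w[i][j] = (q[i][j] - zeros[g][j]) * scales[g][j]
    return w


def reference_forward(x_rows, q, zeros, scales, group_size, bias=None):
    """Hypothetical slow stand-in for the fused GEMM on a 2D input (rows x in_features)."""
    w = dequantize(q, zeros, scales, group_size)
    out = []
    for row in x_rows:
        y = [sum(row[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))]
        if bias is not None:
            y = [a + b for a, b in zip(y, bias)]
        out.append(y)
    return out


q = [[8, 9], [7, 8], [10, 8], [8, 6]]  # in_features=4, out_features=2
zeros = [[8, 8], [8, 8]]               # 2 groups of 2 input rows each
scales = [[0.5, 1.0], [0.25, 2.0]]
print(reference_forward([[1, 2, 3, 4]], q, zeros, scales, group_size=2, bias=[0.5, 1.0]))
# [[1.0, -14.0]]
```

The real kernel never materializes the full-precision weight matrix. The sketch does so only to make the arithmetic explicit.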
A weight property returns the raw qweight tensor for compatibility with the layer interface.
Usage
This layer is instantiated by the get_linear factory (indirectly through AWQLinear) when quantize == "awq". It requires the awq_inference_engine CUDA extension to be installed.
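A caller may want to confirm that the compiled kernels are present before selecting the AWQ path, because the layer cannot run without them. The following is a minimal sketch that uses only the standard library. The helper name `awq_kernels_available` is hypothetical and not part of LoRAX.

```python
import importlib.util


def awq_kernels_available():
    """Hypothetical helper: True if the awq_inference_engine module can be located."""
    # find_spec locates a top-level module without importing it
    return importlib.util.find_spec("awq_inference_engine") is not None


print(awq_kernels_available())
```

`find_spec` only checks that the module can be found on the import path. It does not load the CUDA library or confirm that a GPU is present.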
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/awq/quantize/qmodule.py
- Lines: 1-52
Signature
```python
class WQLinear(nn.Module):
    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
        ...
```
Import
```python
from lorax_server.layers.awq.quantize.qmodule import WQLinear
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| w_bit | int | Yes | Quantization bit width, must be 4 |
| group_size | int | Yes | Number of input features per quantization group; -1 means a single group spanning all in_features |
| qweight | torch.Tensor (int32) | Yes | Packed 4-bit quantized weight matrix of shape (in_features, out_features // 8) |
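| qzeros | torch.Tensor (int32) | Yes | Packed quantized zero points, shape (in_features // group_size, out_features // 8) |
| scales | torch.Tensor (float16) | Yes | Per-group scale factors, shape (in_features // group_size, out_features) |
| bias | torch.Tensor or None | Yes (may be None) | Optional bias added to the output; the constructor has no default, so pass None explicitly when unused |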
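The out_features // 8 column count of qweight comes from packing eight 4-bit values into each int32. The sketch below illustrates that arithmetic, assuming plain sequential nibble order (value k in bits 4k..4k+3). The actual AWQ GEMM kernels use their own interleaved order, so this is not a drop-in packer for real checkpoints. The names `pack_int4`, `unpack_int4`, and `pack_row` are hypothetical.

```python
PACK = 32 // 4  # eight 4-bit values per 32-bit word


def pack_int4(values):
    """Hypothetical: pack 8 unsigned 4-bit ints (0..15) into one signed int32 (sequential order assumed)."""
    assert len(values) == PACK and all(0 <= v < 16 for v in values)
    word = 0
    for k, v in enumerate(values):
        word |= v << (4 * k)  # value k occupies bits 4k .. 4k+3
    # Reinterpret as signed 32-bit, as the value would be stored in an int32 tensor
    return word - (1 << 32) if word >= (1 << 31) else word


def unpack_int4(word):
    """Inverse of pack_int4."""
    word &= 0xFFFFFFFF  # recover the unsigned bit pattern
    return [(word >> (4 * k)) & 0xF for k in range(PACK)]


def pack_row(row):
    """Pack one row of out_features 4-bit values into out_features // 8 int32 words."""
    assert len(row) % PACK == 0
    return [pack_int4(row[c:c + PACK]) for c in range(0, len(row), PACK)]


word = pack_int4([1, 2, 3, 4, 5, 6, 7, 8])
print(hex(word & 0xFFFFFFFF), unpack_int4(word))  # 0x87654321 [1, 2, 3, 4, 5, 6, 7, 8]
```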
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the AWQ quantized linear transformation, shape (*input_shape[:-1], out_features) |
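The output shape follows from the flatten-then-restore pattern in forward. The following is a small sketch of that rule; the helper name `wqlinear_output_shape` is hypothetical. It returns both the final shape and the 2D shape that the kernel itself produces.

```python
def wqlinear_output_shape(input_shape, out_features):
    """Hypothetical helper: shapes produced by WQLinear.forward.

    forward flattens the input to (-1, in_features), runs the GEMM, then
    restores every leading dimension and replaces the last one.
    """
    rows = 1
    for d in input_shape[:-1]:
        rows *= d  # rows seen by the kernel after reshape(-1, in_features)
    flat_out = (rows, out_features)
    return tuple(input_shape[:-1]) + (out_features,), flat_out


print(wqlinear_output_shape((2, 16, 4096), 11008))  # ((2, 16, 11008), (32, 11008))
```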
Usage Examples
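```python
# Used internally by model layers; requires the awq_inference_engine CUDA extension and a GPU.
from lorax_server.layers.awq.quantize.qmodule import WQLinear

# qweight, qzeros and scales are loaded from an AWQ checkpoint (shapes as in the Inputs table);
# input_tensor is a CUDA activation whose last dimension equals in_features.
layer = WQLinear(w_bit=4, group_size=128, qweight=qweight, qzeros=qzeros, scales=scales, bias=None)
output = layer(input_tensor)  # shape: (*input_tensor.shape[:-1], layer.out_features)
```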