Implementation:Mlc ai Mlc llm AWQ Quantization
Overview
The AWQ Quantization module implements Activation-aware Weight Quantization (AWQ) for MLC LLM models. Located at python/mlc_llm/quantization/awq_quantization.py (282 lines), it provides the configuration class AWQQuantize, the quantized linear layer AWQQuantizeLinear, and supporting utility functions for group-wise weight quantization and dequantization using TVM's tensor expression (TE) and tensor IR (TIR) primitives.
Purpose
AWQ is a low-bit weight quantization technique that groups weight elements, quantizes them to integer types (e.g., int3, int4, int8), and stores them in packed unsigned integer storage. This module:
- Defines the quantization configuration and validates dtype constraints
- Provides a model mutation mechanism that replaces nn.Linear layers with quantized counterparts
- Implements dequantization logic using TVM TE compute expressions for on-the-fly weight reconstruction during inference
Key Components
Helper Functions
_make_divisible: Performs ceiling division, returning the number of divisor-sized units needed to cover a value (the result is a count, not a multiple of the divisor):
def _make_divisible(c, divisor):
    return (c + divisor - 1) // divisor
_calculate_zeros_width: Calculates the width of the zeros tensor based on input features, group size, and packing:
def _calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError
    base_width = _make_divisible(in_features // group_size, pack_num)
    base_width = _make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width
AWQQuantize Configuration Class
A dataclass defining the full AWQ quantization configuration:
@dataclass
class AWQQuantize:
    name: str
    kind: str
    group_size: int
    quantize_dtype: str  # "int3", "int4", "int8"
    storage_dtype: str  # "uint32"
    model_dtype: str  # "float16", "float32"
    num_elem_per_storage: int = 0
    num_storage_per_group: int = 0
    max_int_value: int = 0
    prebuilt_quantize_func: Dict[str, Callable[[Tensor], Tensor]] = field(default_factory=lambda: {})
Post-initialization validation (__post_init__):
def __post_init__(self):
    assert self.kind == "awq"
    quantize_dtype = DataType(self.quantize_dtype)
    storage_dtype = DataType(self.storage_dtype)
    model_dtype = DataType(self.model_dtype)
    assert quantize_dtype.type_code == DataTypeCode.INT
    assert storage_dtype.type_code == DataTypeCode.UINT
    assert model_dtype.type_code == DataTypeCode.FLOAT
    if storage_dtype.bits < quantize_dtype.bits:
        raise ValueError("Storage unit should be greater or equal to quantized element")
    self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits
    self.num_storage_per_group = self.group_size // self.num_elem_per_storage
    self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1
This computes derived values: the number of quantized elements packed per storage unit, the number of storage units per group, and the maximum representable integer value.
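The arithmetic for the derived values can be checked without TVM. The following is a minimal sketch that assumes the bit width can be read from the trailing digits of the dtype string; dtype_bits and derive are hypothetical helpers that stand in for DataType(...).bits and the body of __post_init__, and are not part of the module.

```python
import re


def dtype_bits(dtype: str) -> int:
    # Hypothetical stand-in for tvm's DataType(dtype).bits:
    # read the trailing digits, e.g. "int4" -> 4, "uint32" -> 32.
    match = re.search(r"(\d+)$", dtype)
    if match is None:
        raise ValueError(f"cannot parse bit width from {dtype!r}")
    return int(match.group(1))


def derive(group_size: int, quantize_dtype: str, storage_dtype: str = "uint32"):
    # Mirrors the three assignments at the end of __post_init__.
    q_bits = dtype_bits(quantize_dtype)
    s_bits = dtype_bits(storage_dtype)
    if s_bits < q_bits:
        raise ValueError("Storage unit should be greater or equal to quantized element")
    num_elem_per_storage = s_bits // q_bits
    return {
        "num_elem_per_storage": num_elem_per_storage,
        "num_storage_per_group": group_size // num_elem_per_storage,
        "max_int_value": (2 ** (q_bits - 1)) - 1,
    }


int4_values = derive(128, "int4")
int3_values = derive(128, "int3")
print(int4_values)
print(int3_values)
```

With int3, integer division packs 10 elements into each 32-bit word, leaving 2 bits of every word unused.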
quantize_model Method
Applies AWQ quantization to an entire model by replacing nn.Linear layers:
def quantize_model(self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str) -> nn.Module:
    class _Mutator(nn.Mutator):
        def __init__(self, config: AWQQuantize, quant_map: QuantizeMapping) -> None:
            super().__init__()
            self.config = config
            self.quant_map = quant_map
        def visit_module(self, name: str, node: nn.Module) -> Any:
            if (
                isinstance(node, nn.Linear)
                and not is_final_fc(name)
                and not is_moe_gate(name, node)
            ):
                return AWQQuantizeLinear.from_linear(node, self.config)
            return self.visit(name, node)
    model.to(dtype=self.model_dtype)
    mutator = _Mutator(self, quant_map)
    model = mutator.visit(name_prefix, model)
    return model
The method uses the nn.Mutator pattern to traverse the model graph. It excludes the final fully-connected layer (e.g., lm_head) and MoE gate layers from quantization by calling is_final_fc and is_moe_gate from the Quantization Utils module.
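The traversal-and-replace idea can be illustrated in plain Python, without TVM. In this sketch, nested dicts stand in for the nn.Module tree, and Linear, QuantLinear, replace_linears, and skip_final_fc are all hypothetical names. The skip rule (a name ending in "lm_head") is an assumption made for illustration and does not reproduce the actual logic of is_final_fc or is_moe_gate.

```python
class Linear:
    """Hypothetical stand-in for nn.Linear."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features


class QuantLinear:
    """Hypothetical stand-in for AWQQuantizeLinear."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

    @staticmethod
    def from_linear(linear):
        # Copy only the dimensions; packed weights are loaded separately.
        return QuantLinear(linear.in_features, linear.out_features)


def replace_linears(name, node, should_skip):
    # Replace a leaf Linear unless its qualified name is excluded.
    if isinstance(node, Linear):
        return node if should_skip(name) else QuantLinear.from_linear(node)
    # Recurse into containers, extending the dotted name as we go.
    if isinstance(node, dict):
        return {
            key: replace_linears(f"{name}.{key}", child, should_skip)
            for key, child in node.items()
        }
    return node


def skip_final_fc(name):
    # Assumed rule for this sketch only.
    return name.endswith("lm_head")


model = {
    "layers": {"0": {"q_proj": Linear(8, 8)}},
    "lm_head": Linear(8, 32),
}
quantized = replace_linears("model", model, skip_final_fc)
print(type(quantized["layers"]["0"]["q_proj"]).__name__)
print(type(quantized["lm_head"]).__name__)
```

The real implementation delegates recursion to nn.Mutator.visit rather than walking containers by hand.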
_dequantize Method
Implements the dequantization computation using TVM tensor expressions. It converts packed integer weights back to floating-point values using zero-point subtraction and scale multiplication:
def _dequantize(self, weight, zeros, scale, out_shape=None):
    float_weight = convert_uint_to_float(weight, ...)
    float_zeros = convert_uint_to_float(zeros, ...)
    float_weight = topi.transpose(float_weight)
    float_zeros = topi.transpose(float_zeros)
    scale = topi.transpose(scale)
    return te.compute(
        shape=...,
        fcompute=lambda i, j: tir.multiply(
            tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]),
            scale[i, j // self.group_size],
        ),
        name="dequantize",
    )
The dequantization formula is: output[i, j] = (float_weight[i, j] - float_zeros[i, j // group_size]) * scale[i, j // group_size]
The ft_reorder=True flag is passed to convert_uint_to_float to apply FasterTransformer-style bit reordering during unpacking.
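The formula can be checked numerically with a small pure-Python reference. This is a minimal sketch that assumes the weights and zero points have already been unpacked to plain integers and transposed to the (out_features, in_features) layout used inside te.compute; it does not model bit unpacking or the FasterTransformer reordering, and dequantize_reference is a hypothetical name.

```python
def dequantize_reference(int_weight, int_zeros, scale, group_size):
    # int_weight: out_features x in_features unpacked integers
    # int_zeros, scale: out_features x (in_features // group_size)
    out_features = len(int_weight)
    in_features = len(int_weight[0])
    return [
        [
            (float(int_weight[i][j]) - float(int_zeros[i][j // group_size]))
            * scale[i][j // group_size]
            for j in range(in_features)
        ]
        for i in range(out_features)
    ]


# One output row, four input columns, two groups of size 2.
weights = [[0, 15, 8, 7]]
zeros = [[8, 8]]
scales = [[0.5, 0.25]]
dequantized = dequantize_reference(weights, zeros, scales, group_size=2)
print(dequantized)
```

Columns 0 and 1 share the first zero point and scale, while columns 2 and 3 share the second, which is the effect of indexing with j // group_size.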
AWQQuantizeLinear Class
A quantized replacement for nn.Linear that stores weights in packed AWQ format:
Parameters:
class AWQQuantizeLinear(nn.Module):
    def __init__(self, in_features, out_features, config, bias=True, out_dtype=None):
        self.qweight = nn.Parameter(
            (in_features, out_features // config.num_elem_per_storage), config.storage_dtype)
        self.qzeros = nn.Parameter(
            (in_features // config.group_size, out_features // config.num_elem_per_storage),
            config.storage_dtype)
        self.scales = nn.Parameter(
            (in_features // config.group_size, out_features), config.model_dtype)
        if bias:
            self.bias = nn.Parameter(
                (out_features,), config.model_dtype if out_dtype is None else out_dtype)
| Parameter | Shape | Dtype | Description |
|---|---|---|---|
| qweight | (in_features, out_features // num_elem_per_storage) | storage dtype | Packed quantized weights |
| qzeros | (in_features // group_size, out_features // num_elem_per_storage) | storage dtype | Packed quantized zero points |
| scales | (in_features // group_size, out_features) | model dtype | Per-group scale factors |
| bias | (out_features,) | model/out dtype | Optional bias vector |
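To see what these shapes look like in practice, the sketch below evaluates them for one configuration. The layer size (4096 by 11008) and the helper awq_param_shapes are illustrative assumptions; the int4-in-uint32 setting gives num_elem_per_storage = 32 // 4 = 8.

```python
def awq_param_shapes(in_features, out_features, group_size, num_elem_per_storage):
    # Hypothetical helper: evaluates the shape expressions from the table.
    return {
        "qweight": (in_features, out_features // num_elem_per_storage),
        "qzeros": (in_features // group_size, out_features // num_elem_per_storage),
        "scales": (in_features // group_size, out_features),
        "bias": (out_features,),
    }


shapes = awq_param_shapes(
    in_features=4096,
    out_features=11008,
    group_size=128,
    num_elem_per_storage=8,  # int4 values packed into uint32 words
)
for name, shape in shapes.items():
    print(name, shape)
```

Packing runs along the out_features axis, while grouping runs along the in_features axis, so qzeros is reduced along both dimensions.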
Forward Method:
def forward(self, x: nn.Tensor) -> nn.Tensor:
    w = nn.op.tensor_expr_op(
        lambda weight, zeros, scale: self.config._dequantize(weight, zeros, scale, [...]),
        name_hint="dequantize",
        args=[self.qweight, self.qzeros, self.scales],
    )
    w = nn.op.permute_dims(w)
    x = nn.op.matmul(x, w, out_dtype=self.out_dtype)
    if self.bias is not None:
        x = x + self.bias
    return x
The forward pass dequantizes weights on-the-fly using tensor_expr_op, transposes the result, performs matrix multiplication, and optionally adds bias.
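The data flow after dequantization can be sketched with plain Python lists. This example assumes the dequantized weight arrives in (out_features, in_features) layout, as produced by the te.compute expression, and uses hypothetical helpers transpose, matmul, and linear_forward in place of the nn.op calls.

```python
def transpose(matrix):
    # Swap the two axes, like nn.op.permute_dims on a 2-D tensor.
    return [list(row) for row in zip(*matrix)]


def matmul(a, b):
    # Naive (m x k) @ (k x n) product.
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def linear_forward(x, w_dequant, bias=None):
    # w_dequant: out_features x in_features, so transpose before multiplying.
    w = transpose(w_dequant)
    y = matmul(x, w)
    if bias is not None:
        y = [[value + b for value, b in zip(row, bias)] for row in y]
    return y


x = [[1.0, 2.0]]                                   # batch 1, in_features 2
w_dequant = [[1.0, 0.0], [0.5, 0.5], [2.0, -1.0]]  # out_features 3
bias = [0.5, 0.5, 0.5]
y = linear_forward(x, w_dequant, bias)
print(y)
```

The transpose is needed because the dequantized tensor is laid out output-major, while the matrix product expects the input dimension first.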
from_linear Static Method: Converts a standard nn.Linear to an AWQQuantizeLinear by copying dimensions and configuration.
to Method Override: Overrides the default to() to avoid converting bias dtype when out_dtype is specified, preventing dtype mismatches.
Dependencies
- tvm: TVM compiler framework (DataType, DataTypeCode, te, tir, topi)
- tvm.relax.frontend.nn: Neural network module abstraction
- mlc_llm.loader.QuantizeMapping: Quantization name/function mapping
- mlc_llm.quantization.utils: Utility functions (convert_uint_to_float, is_final_fc, is_moe_gate)
File Location
python/mlc_llm/quantization/awq_quantization.py