

Implementation:Mlc ai Mlc llm AWQ Quantization

From Leeroopedia
Revision as of 15:48, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Mlc_ai_Mlc_llm_AWQ_Quantization.md)


Overview

The AWQ Quantization module implements Activation-aware Weight Quantization (AWQ) for MLC LLM models. Located at python/mlc_llm/quantization/awq_quantization.py (282 lines), it provides the configuration class AWQQuantize, the quantized linear layer AWQQuantizeLinear, and supporting utility functions for group-wise weight quantization and dequantization using TVM's tensor expression (TE) and tensor IR (TIR) primitives.

Purpose

AWQ is a low-bit, weight-only quantization technique that uses activation statistics to protect the most salient weight channels. Under AWQ, weight elements are grouped, quantized to integer types (e.g., int3, int4, int8), and packed into unsigned integer storage. This module:

  • Defines the quantization configuration and validates dtype constraints
  • Provides a model mutation mechanism that replaces nn.Linear layers with quantized counterparts
  • Implements dequantization logic using TVM TE compute expressions for on-the-fly weight reconstruction during inference

Key Components

Helper Functions

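_make_divisible: Despite its name, returns the ceiling of c / divisor, i.e., the number of divisor-sized chunks needed to cover c items, rather than c rounded up to a multiple of divisor: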

def _make_divisible(c, divisor):
    return (c + divisor - 1) // divisor
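The distinction matters when reading the callers. A minimal pure-Python sketch contrasting the two operations: ceil_div repeats the helper's arithmetic, while round_up_to_multiple is a hypothetical comparison function that is not part of the module.

```python
def ceil_div(c: int, divisor: int) -> int:
    """Same arithmetic as _make_divisible: ceil(c / divisor) for positive integers."""
    return (c + divisor - 1) // divisor


def round_up_to_multiple(c: int, divisor: int) -> int:
    """Hypothetical helper: the smallest multiple of divisor that is >= c."""
    return ceil_div(c, divisor) * divisor


# 86 zero points packed 8 per storage word need 11 words, not 88.
print(ceil_div(86, 8))              # 11
print(round_up_to_multiple(86, 8))  # 88
```

This is why _calculate_zeros_width multiplies the result of the second _make_divisible call by size_multiplier when it needs an actual multiple.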

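_calculate_zeros_width: Calculates the width of the zeros tensor in packed storage units: one zero point per group of in_features, pack_num zero points per unit, with the unit count padded to a multiple of 2 (group_size 64) or 4 (group_size 32). Group sizes below 128 other than 64 and 32 raise NotImplementedError: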

def _calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = _make_divisible(in_features // group_size, pack_num)
    base_width = _make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width

AWQQuantize Configuration Class

A dataclass defining the full AWQ quantization configuration:

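from dataclasses import dataclass, field
from typing import Callable, Dict

from tvm import DataType, DataTypeCode  # used by __post_init__
from tvm.relax.frontend.nn import Tensor

@dataclass
class AWQQuantize:
    name: str
    kind: str
    group_size: int
    quantize_dtype: str   # "int3", "int4", "int8"
    storage_dtype: str    # "uint32"
    model_dtype: str      # "float16", "float32"
    # Derived fields, filled in by __post_init__
    num_elem_per_storage: int = 0
    num_storage_per_group: int = 0
    max_int_value: int = 0
    prebuilt_quantize_func: Dict[str, Callable[[Tensor], Tensor]] = field(default_factory=lambda: {})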

Post-initialization validation (__post_init__):

def __post_init__(self):
    assert self.kind == "awq"
    quantize_dtype = DataType(self.quantize_dtype)
    storage_dtype = DataType(self.storage_dtype)
    model_dtype = DataType(self.model_dtype)
    assert quantize_dtype.type_code == DataTypeCode.INT
    assert storage_dtype.type_code == DataTypeCode.UINT
    assert model_dtype.type_code == DataTypeCode.FLOAT
    if storage_dtype.bits < quantize_dtype.bits:
        raise ValueError("Storage unit should be greater or equal to quantized element")
    self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits
    self.num_storage_per_group = self.group_size // self.num_elem_per_storage
    self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1

This computes derived values: the number of quantized elements packed per storage unit, the number of storage units per group, and the maximum representable integer value.
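The derived values can be checked by redoing the arithmetic in plain Python. This sketch assumes dtype names of the form int<N> or uint<N>; bits_of and derived_values are hypothetical helpers, with bits_of standing in for tvm.DataType(...).bits so the example runs without TVM.

```python
import re


def bits_of(dtype: str) -> int:
    """Hypothetical stand-in for tvm.DataType(dtype).bits on names like 'int4'."""
    match = re.fullmatch(r"u?int(\d+)", dtype)
    if match is None:
        raise ValueError(f"unsupported dtype string: {dtype}")
    return int(match.group(1))


def derived_values(quantize_dtype: str, storage_dtype: str, group_size: int):
    """Mirror the arithmetic in AWQQuantize.__post_init__."""
    q_bits, s_bits = bits_of(quantize_dtype), bits_of(storage_dtype)
    if s_bits < q_bits:
        raise ValueError("Storage unit should be greater or equal to quantized element")
    num_elem_per_storage = s_bits // q_bits
    num_storage_per_group = group_size // num_elem_per_storage
    max_int_value = 2 ** (q_bits - 1) - 1  # largest positive signed value of q_bits width
    return num_elem_per_storage, num_storage_per_group, max_int_value


for quantize_dtype in ("int3", "int4", "int8"):
    print(quantize_dtype, derived_values(quantize_dtype, "uint32", 128))
```

With uint32 storage and group_size 128, int4 gives (8, 16, 7) and int8 gives (4, 32, 127). As the arithmetic is written, int3 packs 32 // 3 = 10 values per word (2 bits unused), and 128 // 10 floors to 12 storage units per group.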

quantize_model Method

Applies AWQ quantization to an entire model by replacing nn.Linear layers:

def quantize_model(self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str) -> nn.Module:
    class _Mutator(nn.Mutator):
        def __init__(self, config: AWQQuantize, quant_map: QuantizeMapping) -> None:
            super().__init__()
            self.config = config
            self.quant_map = quant_map

        def visit_module(self, name: str, node: nn.Module) -> Any:
            if (
                isinstance(node, nn.Linear)
                and not is_final_fc(name)
                and not is_moe_gate(name, node)
            ):
                return AWQQuantizeLinear.from_linear(node, self.config)
            return self.visit(name, node)

    model.to(dtype=self.model_dtype)
    mutator = _Mutator(self, quant_map)
    model = mutator.visit(name_prefix, model)
    return model

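Before traversal, the method casts the model to model_dtype. It then uses the nn.Mutator pattern to walk the module tree and replace each eligible nn.Linear with an AWQQuantizeLinear built by from_linear. The final fully-connected layer (e.g., lm_head) and MoE gate layers are excluded from quantization via is_final_fc and is_moe_gate from the Quantization Utils module.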
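The traversal can be pictured with a small pure-Python model. Everything in it is hypothetical: Linear, QuantLinear, and Container stand in for nn.Linear, AWQQuantizeLinear, and composite modules; is_final_fc is simplified to a name-suffix check (the real helper lives in mlc_llm.quantization.utils); and MoE gate handling is omitted.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass
class Linear:  # hypothetical stand-in for nn.Linear
    in_features: int
    out_features: int


@dataclass
class QuantLinear:  # hypothetical stand-in for AWQQuantizeLinear
    in_features: int
    out_features: int


@dataclass
class Container:  # hypothetical stand-in for a composite nn.Module
    children: Dict[str, "Module"] = field(default_factory=dict)


Module = Union[Linear, QuantLinear, Container]


def is_final_fc(name: str) -> bool:
    # Simplified stand-in: treat the output projection as the final FC layer.
    return name.endswith("lm_head")


def visit(name: str, node: Module) -> Module:
    """Depth-first walk that swaps eligible Linear nodes, like _Mutator.visit_module."""
    if isinstance(node, Linear) and not is_final_fc(name):
        return QuantLinear(node.in_features, node.out_features)
    if isinstance(node, Container):
        for child_name, child in node.children.items():
            node.children[child_name] = visit(f"{name}.{child_name}", child)
    return node


model = Container({
    "layers": Container({"q_proj": Linear(64, 64), "o_proj": Linear(64, 64)}),
    "lm_head": Linear(64, 1000),
})
model = visit("model", model)
print(type(model.children["layers"].children["q_proj"]).__name__)  # QuantLinear
print(type(model.children["lm_head"]).__name__)                    # Linear
```

The key design point is that replacement happens per node while the qualified name (here "model.layers.q_proj") is built up during recursion, so name-based exclusions such as the lm_head check can be applied locally.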

_dequantize Method

Implements the dequantization computation using TVM tensor expressions. It converts packed integer weights back to floating-point values using zero-point subtraction and scale multiplication:

def _dequantize(self, weight, zeros, scale, out_shape=None):
    # Unpack each storage word into num_elem_per_storage values of the model dtype.
    float_weight = convert_uint_to_float(weight, ...)
    float_zeros = convert_uint_to_float(zeros, ...)
    # Transpose all three operands so that axis 0 runs over output features.
    float_weight = topi.transpose(float_weight)
    float_zeros = topi.transpose(float_zeros)
    scale = topi.transpose(scale)
    # Input-feature index j uses the zero point and scale of group j // group_size.
    return te.compute(
        shape=...,
        fcompute=lambda i, j: tir.multiply(
            tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]),
            scale[i, j // self.group_size],
        ),
        name="dequantize",
    )

The dequantization formula is output[i, j] = (float_weight[i, j] - float_zeros[i, j // group_size]) * scale[i, j // group_size], where, after the transposes, i indexes output features and j indexes input features.
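A NumPy reference for the formula helps when checking the indexing. It assumes the operands are already unpacked to one integer code per element and transposed as in _dequantize; dequantize_reference is a hypothetical name, not part of the module.

```python
import numpy as np


def dequantize_reference(q_t, zeros_t, scale_t, group_size):
    """
    output[i, j] = (q_t[i, j] - zeros_t[i, j // group_size]) * scale_t[i, j // group_size]

    q_t:     (out_features, in_features) unpacked integer codes
    zeros_t: (out_features, in_features // group_size) unpacked zero points
    scale_t: (out_features, in_features // group_size) float scales
    """
    group_index = np.arange(q_t.shape[1]) // group_size     # j // group_size for every column j
    zeros = zeros_t[:, group_index].astype(scale_t.dtype)  # expand each zero point across its group
    scales = scale_t[:, group_index]                        # expand each scale across its group
    return (q_t.astype(scale_t.dtype) - zeros) * scales


q_t = np.array([[0, 1, 2, 3, 15, 14, 13, 12],
                [8, 8, 8, 8, 0, 0, 0, 0]])
zeros_t = np.array([[8, 8],
                    [8, 0]])
scale_t = np.array([[0.5, 1.0],
                    [0.25, 2.0]], dtype=np.float32)
print(dequantize_reference(q_t, zeros_t, scale_t, group_size=4))
```

With group_size=4, the first four columns of row 0 use zero point 8 and scale 0.5, giving -4.0, -3.5, -3.0, -2.5, while the last four use zero point 8 and scale 1.0, giving 7.0, 6.0, 5.0, 4.0.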

The ft_reorder=True flag is passed to convert_uint_to_float to apply FasterTransformer-style bit reordering during unpacking.
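To make the reordering concrete, here is a pure-Python pack/unpack sketch for int4 codes in a 32-bit word. It assumes the interleaved order [0, 2, 4, 6, 1, 3, 5, 7] used by common AWQ 4-bit packing tools; that order is an assumption, not copied from convert_uint_to_float, and slot_of, pack, and unpack are hypothetical helpers.

```python
BITS = 4
NUM_ELEM = 8              # eight 4-bit codes per 32-bit word
MASK = (1 << BITS) - 1


def slot_of(k: int, ft_reorder: bool) -> int:
    """4-bit slot that holds logical element k inside one word (assumed layout)."""
    if ft_reorder:
        # Interleaved: elements 0, 2, 4, 6 fill slots 0-3; elements 1, 3, 5, 7 fill slots 4-7.
        return (k % 2) * 4 + k // 2
    return k  # plain order: element k sits in slot k


def pack(values, ft_reorder: bool) -> int:
    word = 0
    for k, v in enumerate(values):
        word |= (v & MASK) << (slot_of(k, ft_reorder) * BITS)
    return word


def unpack(word: int, ft_reorder: bool):
    return [(word >> (slot_of(k, ft_reorder) * BITS)) & MASK for k in range(NUM_ELEM)]


print(hex(pack(list(range(8)), ft_reorder=False)))  # 0x76543210
print(hex(pack(list(range(8)), ft_reorder=True)))   # 0x75316420
```

Decoding a reordered word with the plain order returns the elements in the order 0, 2, 4, 6, 1, 3, 5, 7, so the flag used during unpacking has to match how the words were packed.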

AWQQuantizeLinear Class

A quantized replacement for nn.Linear that stores weights in packed AWQ format:

Constructor:

from tvm.relax.frontend import nn

class AWQQuantizeLinear(nn.Module):
    def __init__(self, in_features, out_features, config, bias=True, out_dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.config = config
        # Packed codes: num_elem_per_storage values per storage unit along out_features
        self.qweight = nn.Parameter(
            (in_features, out_features // config.num_elem_per_storage), config.storage_dtype)
        self.qzeros = nn.Parameter(
            (in_features // config.group_size, out_features // config.num_elem_per_storage),
            config.storage_dtype)
        self.scales = nn.Parameter(
            (in_features // config.group_size, out_features), config.model_dtype)
        if bias:
            self.bias = nn.Parameter(
                (out_features,), config.model_dtype if out_dtype is None else out_dtype)
        else:
            self.bias = None  # forward() checks `self.bias is not None`

Parameters:

  • qweight: shape (in_features, out_features // num_elem_per_storage), storage dtype; packed quantized weights
  • qzeros: shape (in_features // group_size, out_features // num_elem_per_storage), storage dtype; packed quantized zero points
  • scales: shape (in_features // group_size, out_features), model dtype; per-group scale factors
  • bias: shape (out_features,), model dtype (or out_dtype if given); optional bias vector
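The shape rules can be turned into concrete sizes. This sketch assumes int4 codes in uint32 storage (num_elem_per_storage = 8), float16 for scales and bias, group_size 128, and illustrative dimensions in_features=4096, out_features=11008; awq_linear_shapes and awq_param_bytes are hypothetical helpers, not part of the module.

```python
import math

# Assumed byte widths for the dtypes used below.
DTYPE_BYTES = {"uint32": 4, "float16": 2}


def awq_linear_shapes(in_features, out_features, group_size=128,
                      num_elem_per_storage=8, bias=True):
    """Parameter shapes following the AWQQuantizeLinear parameter list."""
    shapes = {
        "qweight": (in_features, out_features // num_elem_per_storage),
        "qzeros": (in_features // group_size, out_features // num_elem_per_storage),
        "scales": (in_features // group_size, out_features),
    }
    if bias:
        shapes["bias"] = (out_features,)
    return shapes


def awq_param_bytes(shapes, storage_dtype="uint32", model_dtype="float16"):
    """Total bytes, assuming bias (if present) is stored in model_dtype."""
    dtype_of = {"qweight": storage_dtype, "qzeros": storage_dtype,
                "scales": model_dtype, "bias": model_dtype}
    return sum(math.prod(shape) * DTYPE_BYTES[dtype_of[name]]
               for name, shape in shapes.items())


shapes = awq_linear_shapes(4096, 11008, bias=False)
print(shapes)
print(awq_param_bytes(shapes), "bytes vs", 4096 * 11008 * 2, "bytes for a dense float16 weight")
```

For these dimensions qweight is (4096, 1376), qzeros is (32, 1376), and scales is (32, 11008), for 23,425,024 bytes in total, roughly 3.8 times smaller than the 90,177,536 bytes of a dense float16 weight.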

Forward Method:

def forward(self, x: nn.Tensor) -> nn.Tensor:
    w = nn.op.tensor_expr_op(
        lambda weight, zeros, scale: self.config._dequantize(weight, zeros, scale, [...]),
        name_hint="dequantize",
        args=[self.qweight, self.qzeros, self.scales],
    )
    w = nn.op.permute_dims(w)
    x = nn.op.matmul(x, w, out_dtype=self.out_dtype)
    if self.bias is not None:
        x = x + self.bias
    return x

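The forward pass dequantizes the weights on the fly with tensor_expr_op, producing an (out_features, in_features) matrix. permute_dims transposes it to (in_features, out_features), matmul multiplies the input by it (using out_dtype as the result dtype when set), and the bias, if present, is added.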
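The same data flow can be checked numerically with NumPy. This is a sketch under stated assumptions: the packed words are taken as already unpacked (one integer code per element, in qweight's (in_features, out_features) layout), and awq_linear_forward is a hypothetical helper, not the TVM implementation.

```python
import numpy as np


def awq_linear_forward(x, q, zeros, scales, group_size, bias=None):
    """
    Dequantize, transpose, matmul, add bias, mirroring AWQQuantizeLinear.forward.

    q:      (in_features, out_features) integer codes (qweight after unpacking)
    zeros:  (in_features // group_size, out_features) integer zero points
    scales: (in_features // group_size, out_features) float scales
    """
    group_index = np.arange(q.shape[0]) // group_size  # input row k belongs to group k // group_size
    # Dequantize in the transposed (out_features, in_features) layout produced by _dequantize.
    w_t = (q.T.astype(scales.dtype) - zeros[group_index].T.astype(scales.dtype)) * scales[group_index].T
    w = w_t.T  # permute_dims: back to (in_features, out_features)
    y = x @ w
    if bias is not None:
        y = y + bias
    return y


rng = np.random.default_rng(0)
in_f, out_f, g = 8, 3, 4
q = rng.integers(0, 16, size=(in_f, out_f))
zeros = rng.integers(0, 16, size=(in_f // g, out_f))
scales = rng.random((in_f // g, out_f))
x = rng.random((2, in_f))
bias = rng.random(out_f)
y = awq_linear_forward(x, q, zeros, scales, g, bias)
print(y.shape)  # (2, 3)
```

Casting codes and zero points to the float dtype before subtracting avoids wrap-around if they arrive as unsigned integers.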

from_linear Static Method: Converts a standard nn.Linear to an AWQQuantizeLinear by copying dimensions and configuration.

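to Method Override: Overrides the default to() so that the bias is not converted when out_dtype is specified. Because forward() computes x + self.bias with x in out_dtype, converting the bias to a different dtype would cause a dtype mismatch.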
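The reasoning behind the override can be shown with a toy class. ToyAWQLinear and Param are hypothetical stand-ins (not TVM or MLC LLM APIs), and only the dtype bookkeeping is modeled.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Param:  # hypothetical stand-in for nn.Parameter
    shape: tuple
    dtype: str


class ToyAWQLinear:
    """Hypothetical toy that models only how to() treats scales and bias dtypes."""

    def __init__(self, in_features, out_features, bias=True, out_dtype=None,
                 model_dtype="float16", group_size=128):
        self.out_dtype = out_dtype
        self.scales = Param((in_features // group_size, out_features), model_dtype)
        self.bias = Param((out_features,), out_dtype or model_dtype) if bias else None

    def to(self, dtype: Optional[str] = None) -> None:
        if dtype is None:
            return
        self.scales.dtype = dtype
        # Keep the bias in out_dtype so `x + bias` in forward() sees matching dtypes.
        if self.bias is not None and self.out_dtype is None:
            self.bias.dtype = dtype


layer = ToyAWQLinear(256, 64, out_dtype="float32")
layer.to("bfloat16")
print(layer.scales.dtype, layer.bias.dtype)  # bfloat16 float32
```

Without out_dtype, the bias simply follows the requested dtype, matching the default behavior of to().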

Dependencies

  • tvm -- TVM compiler framework (DataType, DataTypeCode, te, tir, topi)
  • tvm.relax.frontend.nn -- Neural network module abstraction
  • mlc_llm.loader.QuantizeMapping -- Quantization name/function mapping
  • mlc_llm.quantization.utils -- Utility functions (convert_uint_to_float, is_final_fc, is_moe_gate)

File Location

python/mlc_llm/quantization/awq_quantization.py
