
Implementation: mit-han-lab / llm-awq / run_awq

From Leeroopedia

Overview

run_awq is the entry-point function, provided by the llm-awq library, that orchestrates the full AWQ (Activation-aware Weight Quantization) search pipeline: it collects activation statistics on calibration data and searches for the per-channel scaling factors and weight clipping ranges used to quantize the model.

Source Location

awq/quantize/pre_quant.py in the mit-han-lab/llm-awq repository (as reflected in the import path below).

Signature

@torch.no_grad()
def run_awq(
    model,
    enc,
    w_bit,
    q_config,
    n_samples=512,
    seqlen=512,
    auto_scale=True,
    mse_range=True,
    calib_data="pileval",
):

Import

from awq.quantize.pre_quant import run_awq

I/O Contract

Inputs

Parameter  | Type                | Required | Default   | Description
model      | nn.Module           | Yes      | --        | The FP16 CausalLM model to be quantized
enc        | PreTrainedTokenizer | Yes      | --        | Tokenizer for encoding calibration data
w_bit      | int                 | Yes      | --        | Target quantization bit-width (e.g., 4 for INT4)
q_config   | dict                | Yes      | --        | Quantization configuration with keys zero_point (bool) and q_group_size (int)
n_samples  | int                 | No       | 512       | Number of calibration samples to collect
seqlen     | int                 | No       | 512       | Sequence length for calibration blocks
auto_scale | bool                | No       | True      | Whether to perform the per-channel scaling search
mse_range  | bool                | No       | True      | Whether to perform the MSE-based clipping range search
calib_data | str                 | No       | "pileval" | Name of the calibration dataset
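To make the q_config fields concrete, the following self-contained sketch performs asymmetric (zero_point=True) quantization of a single weight group at w_bit=4. It is illustrative arithmetic only, not llm-awq's internal kernel; in the real pipeline each run of q_group_size consecutive weights shares one scale/zero-point pair, and the values below are toy numbers.

```python
def quantize_group(weights, w_bit=4, zero_point=True):
    """Asymmetric (zero-point) quantize-dequantize of one weight group.

    Maps floats in [min(w), max(w)] onto integers in [0, 2**w_bit - 1],
    then maps them back. Illustrative only, not llm-awq's internals.
    """
    qmax = 2 ** w_bit - 1
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zp = round(-lo / scale) if zero_point else 0
    q = [max(0, min(qmax, round(w / scale) + zp)) for w in weights]
    return [(v - zp) * scale for v in q]

# One group of weights sharing a single (scale, zero_point) pair.
group = [-0.5, -0.1, 0.0, 0.2, 0.7, 1.0, 1.3, 1.5]
deq = quantize_group(group, w_bit=4, zero_point=True)
err = max(abs(a - b) for a, b in zip(group, deq))
print(f"max reconstruction error: {err:.4f}")
```

With zero_point=False the grid would be symmetric around zero, wasting levels whenever the group's values are skewed to one side, which is why the example configuration above enables asymmetric quantization.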

Output

  • dict -- A dictionary containing:
    • "scale" -- A list of tuples, each containing (prev_op_name, layer_names, scales_tensor), representing the optimal per-channel scaling factors found for each group of linear layers.
    • "clip" -- A list of tuples, each containing (layer_name, max_val_tensor), representing the optimal clipping ranges found via MSE-based search.
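For concreteness, a toy value matching this contract can be traversed as follows. Plain Python lists stand in for the tensors, and all layer and operation names here are hypothetical.

```python
# Toy stand-in for run_awq's return value; lists replace tensors,
# and the module names are invented for illustration.
awq_results = {
    "scale": [
        # (prev_op_name, layer_names, scales_tensor)
        ("ln_1", ("attn.q_proj", "attn.k_proj", "attn.v_proj"), [1.2, 0.8, 1.1]),
        ("ln_2", ("mlp.up_proj",), [0.9, 1.4, 1.0]),
    ],
    "clip": [
        # (layer_name, max_val_tensor)
        ("attn.q_proj", [2.5, 1.8, 3.1]),
    ],
}

for prev_op, layers, scales in awq_results["scale"]:
    print(f"{prev_op} -> {', '.join(layers)}: {len(scales)} per-channel scales")
for name, max_vals in awq_results["clip"]:
    print(f"{name}: clipping thresholds for {len(max_vals)} channels")
```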

Implementation Details

The function orchestrates the full AWQ pipeline in the following steps:

  1. Load calibration data: Calls get_calib_dataset to load and tokenize calibration samples.
  2. Collect activation features: Passes calibration data through the model, hooking into each transformer block to capture input activations for every linear layer.
  3. Iterate over transformer blocks: For each block in the model:
    1. If auto_scale is enabled, calls auto_scale_block to find optimal per-channel scaling factors.
    2. If mse_range is enabled, calls auto_clip_block to find optimal weight clipping ranges.
    3. Records the scaling and clipping results.
  4. Return results: Returns the collected scaling and clipping parameters as a dictionary.
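The steps above can be sketched end-to-end with toy stand-ins. Everything in this block (block names, activation lists, the toy_auto_scale / toy_auto_clip helpers) is hypothetical scaffolding; in the real pipeline, auto_scale_block and auto_clip_block run grid searches over the activations captured in step 2.

```python
def toy_auto_scale(block_name, activations):
    # Stand-in for auto_scale_block: would grid-search per-channel scales.
    return ("prev_op", (f"{block_name}.linear",), [1.0] * len(activations))

def toy_auto_clip(block_name, activations):
    # Stand-in for auto_clip_block: would grid-search clipping thresholds.
    return (f"{block_name}.linear", [max(activations)])

blocks = ["layers.0", "layers.1"]                  # hypothetical block names
calib_acts = {b: [0.1, 0.5, 0.9] for b in blocks}  # step 2: captured activations

results = {"scale": [], "clip": []}                # step 4: collected results
for b in blocks:                                   # step 3: iterate blocks
    results["scale"].append(toy_auto_scale(b, calib_acts[b]))  # step 3.1
    results["clip"].append(toy_auto_clip(b, calib_acts[b]))    # step 3.2
print(results["clip"][0])
```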

The function operates under @torch.no_grad() since it performs only forward passes and grid search -- no gradient computation is needed.
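To illustrate the grid search that mse_range enables, here is a minimal, self-contained sketch. It is not the library's auto_clip_block (which operates on tensors, group-wise, over a finer candidate grid); it simply searches a symmetric clipping threshold for a toy weight list at a toy 2-bit width, choosing the max_val that minimizes round-trip quantization error.

```python
def quant_dequant(ws, max_val, w_bit=2):
    # Symmetric round-trip quantization with values clipped to [-max_val, max_val].
    qmax = 2 ** (w_bit - 1) - 1  # 1 for the toy 2-bit case
    scale = max_val / qmax
    out = []
    for w in ws:
        w = max(-max_val, min(max_val, w))
        out.append(round(w / scale) * scale)
    return out

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Weights cluster near +/-0.9, with one outlier at 2.0.
weights = [0.9, -0.9, 0.8, -0.85, 2.0]

# Grid-search the clipping threshold: shrinking max_val below the outlier
# sacrifices the outlier but gives the clustered values a finer grid.
candidates = [max(abs(w) for w in weights) * r for r in (1.0, 0.75, 0.5)]
best = min(candidates, key=lambda m: mse(weights, quant_dequant(weights, m)))
print(f"best max_val: {best}")  # prints "best max_val: 1.0"
```

In this toy setup the clipped threshold (1.0) beats the full range (2.0): the error added by saturating the single outlier is outweighed by the finer quantization grid the remaining weights receive.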

Usage Example

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq.quantize.pre_quant import run_awq

# Load the FP16 model and tokenizer
model_path = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Define quantization configuration
q_config = {
    "zero_point": True,     # Use asymmetric quantization
    "q_group_size": 128,    # Group size for group-wise quantization
}

# Run the AWQ search
awq_results = run_awq(
    model=model,
    enc=tokenizer,
    w_bit=4,
    q_config=q_config,
    n_samples=512,
    seqlen=512,
    auto_scale=True,
    mse_range=True,
    calib_data="pileval",
)

# Save the AWQ results for later application
torch.save(awq_results, "awq_results.pt")
print(f"Saved {len(awq_results['scale'])} scale entries and {len(awq_results['clip'])} clip entries")

Knowledge Sources

Domains

  • NLP
  • Quantization
