
Implementation:Mit han lab Llm awq Auto scale block

From Leeroopedia

Overview

auto_scale_block, provided by the llm-awq library, searches for the optimal per-channel scaling factors for the linear layers within a single transformer block.

Source Location

Signature

@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):

Import

from awq.quantize.auto_scale import auto_scale_block

I/O Contract

Inputs

All parameters are required and have no defaults:

  • module (nn.Module) -- A single transformer block (e.g., one decoder layer)
  • module_kwargs (dict) -- Keyword arguments for the block's forward pass (e.g., attention mask, position ids)
  • w_bit (int) -- Target quantization bit-width (e.g., 4)
  • q_config (dict) -- Quantization configuration with keys zero_point (bool) and q_group_size (int)
  • input_feat (dict) -- Cached activation features keyed by linear layer name; each value is a tensor of input activations collected during calibration
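The input_feat cache can be built by registering forward pre-hooks on the block's linear layers while calibration data flows through the model. A minimal sketch of that pattern; the helper name collect_input_feat is illustrative, not part of llm-awq:

```python
# Sketch: build an `input_feat`-style cache with forward pre-hooks.
# `collect_input_feat` is an illustrative helper, not llm-awq API.
import torch
import torch.nn as nn
from collections import defaultdict

def collect_input_feat(block: nn.Module, calib_inputs: torch.Tensor) -> dict:
    """Run calibration data through `block`, caching each Linear's input."""
    cache = defaultdict(list)
    hooks = []

    def make_hook(name):
        def hook(_module, args):
            x = args[0]  # first positional input to the linear layer
            cache[name].append(x.detach().reshape(-1, x.shape[-1]))
        return hook

    for name, mod in block.named_modules():
        if isinstance(mod, nn.Linear):
            hooks.append(mod.register_forward_pre_hook(make_hook(name)))
    with torch.no_grad():
        block(calib_inputs)
    for h in hooks:
        h.remove()  # always detach hooks after calibration
    return {k: torch.cat(v, dim=0) for k, v in cache.items()}

# Tiny demo: one named linear standing in for a projection layer
demo = nn.Sequential()
demo.add_module("q_proj", nn.Linear(8, 8))
feats = collect_input_feat(demo, torch.randn(2, 4, 8))
```

The cached tensors are flattened to [n_samples, hidden_size], matching the shape convention in the usage example below.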

Output

  • list -- A list of tuples, where each tuple contains:
    • prev_op_name (str) -- Name of the preceding operation whose weights are modified by the scaling
    • layer_names (tuple[str]) -- Names of the linear layers in this group that share the same input
    • scales (torch.Tensor) -- The optimal per-channel scaling factors found by grid search

Implementation Details

The function performs per-channel scaling search for all groups of linked linear layers within a single transformer block. The process is:

  1. Identify layer groups: The function defines groups of linear layers that share a common preceding operation, based on the model architecture. For example, in a LLaMA-style block:
    • LayerNorm (input_layernorm) -> Q, K, V projections (self_attn.q_proj, self_attn.k_proj, self_attn.v_proj)
    • V output -> output projection (self_attn.o_proj)
    • LayerNorm (post_attention_layernorm) -> gate and up projections (mlp.gate_proj, mlp.up_proj)
    • Up activation -> down projection (mlp.down_proj)
  2. For each group, run the scaling search:
    1. Retrieve cached input activations for the layers in the group from input_feat
    2. Compute per-channel activation maximums (x_max)
    3. Iterate over 20 candidate scaling ratios (alpha from 0 to ~1)
    4. For each candidate, apply the scaling transformation, quantize, measure block-level MSE
    5. Select the alpha with minimum MSE
    6. Record the resulting scaling factors
  3. Return all scaling results as a list of tuples.
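The per-group search in step 2 can be sketched as follows. This is a minimal, self-contained approximation: the coarse rounding stands in for llm-awq's per-group weight quantizer, the loss here is a weight-reconstruction MSE rather than the true block-output MSE, and search_scales is an illustrative name:

```python
# Sketch of the alpha grid search (not llm-awq's exact internals).
import torch

def search_scales(x, weights, n_grid=20):
    """Grid-search alpha in [0, 1); candidate scales = x_max ** alpha."""
    x_max = x.abs().amax(dim=0)          # per-channel activation maximum
    best_loss, best_scales = float("inf"), None
    for i in range(n_grid):
        alpha = i / n_grid               # candidate ratio, 0 to ~1
        scales = x_max.clamp(min=1e-4) ** alpha
        scales = scales / (scales.max() * scales.min()).sqrt()  # balance range
        loss = 0.0
        for w in weights:
            w_s = w * scales             # scale input channels up
            # crude stand-in for per-group quantization: round to a coarse grid
            step = w_s.abs().amax() / 7  # ~4-bit symmetric range
            w_q = (w_s / step).round() * step
            loss += ((w_q / scales - w) ** 2).mean().item()  # reconstruction MSE
        if loss < best_loss:
            best_loss, best_scales = loss, scales
    return best_scales
```

In the real implementation the candidate scales are evaluated by running the whole block and comparing its output against the unquantized output, which is what makes the search block-aware.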

The function operates under @torch.no_grad() since no gradient computation is needed. All weight modifications during the grid search are temporary and reverted after each candidate evaluation.
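The revert behavior amounts to a clone/modify/restore pattern: original weights are cloned once, scaled for evaluation, then restored, so every candidate alpha starts from the same state. A sketch of the general technique, not llm-awq's exact code:

```python
# Sketch: temporary weight modification with save/restore between candidates.
import torch
import torch.nn as nn

with torch.no_grad():
    layer = nn.Linear(8, 8)
    # snapshot original parameters once, before the grid search
    original = {k: v.clone() for k, v in layer.state_dict().items()}

    scales = torch.rand(8) + 0.5
    layer.weight.mul_(scales)        # temporary modification for one candidate
    # ... evaluate block-level MSE here ...
    layer.load_state_dict(original)  # revert before the next candidate
```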

Usage Example

# auto_scale_block is called internally by run_awq.
# The typical internal call pattern is:

import torch
from awq.quantize.auto_scale import auto_scale_block

# Assume `model` is a loaded CausalLM and calibration has been performed
# to collect `input_feat` for each block.

block = model.model.layers[0]  # First transformer block
module_kwargs = {
    "attention_mask": attention_mask,
    "position_ids": position_ids,
}

q_config = {
    "zero_point": True,
    "q_group_size": 128,
}

# input_feat is a dict mapping layer names to activation tensors, e.g.:
# {
#     "self_attn.q_proj": tensor of shape [n_samples, hidden_size],
#     "self_attn.k_proj": tensor of shape [n_samples, hidden_size],
#     ...
# }

scale_results = auto_scale_block(
    module=block,
    module_kwargs=module_kwargs,
    w_bit=4,
    q_config=q_config,
    input_feat=input_feat,
)

# scale_results is a list of tuples:
# [
#     ("input_layernorm", ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"), scales_tensor),
#     ("self_attn.v_proj", ("self_attn.o_proj",), scales_tensor),
#     ("post_attention_layernorm", ("mlp.gate_proj", "mlp.up_proj"), scales_tensor),
#     ("mlp.up_proj", ("mlp.down_proj",), scales_tensor),
# ]
for prev_op_name, layer_names, scales in scale_results:
    print(f"Prev op: {prev_op_name}, Layers: {layer_names}, Scales shape: {scales.shape}")
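At full precision the scaling is an exact reparameterization: the preceding op absorbs 1/scales while the downstream linear layers absorb scales, leaving the block's output unchanged. A sketch with two stand-in nn.Linear modules playing the roles of a (prev_op, layer) pair (not llm-awq API):

```python
# Sketch: fusing a scales tensor is output-preserving at full precision.
import torch
import torch.nn as nn

torch.manual_seed(0)
prev = nn.Linear(8, 8)             # stands in for prev_op (e.g., v_proj)
nxt = nn.Linear(8, 8, bias=False)  # stands in for the scaled layer (e.g., o_proj)
x = torch.randn(4, 8)

with torch.no_grad():
    ref = nxt(prev(x))                    # output before fusing the scales
    scales = torch.rand(8) + 0.5          # per-channel scaling factors
    prev.weight.div_(scales.view(-1, 1))  # divide prev output channels
    prev.bias.div_(scales)
    nxt.weight.mul_(scales.view(1, -1))   # multiply next input channels
    out = nxt(prev(x))                    # numerically unchanged output
```

The benefit appears only after quantization: the scaled weights in the protected channels quantize with less error, which is what the grid search optimizes for.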

Related Pages

Knowledge Sources

Domains

  • Quantization
  • Optimization
