Implementation:Mit_han_lab_Llm_awq_Auto_scale_block
Overview
auto_scale_block is a concrete tool, provided by the llm-awq library, for finding optimal per-channel scaling factors within a single transformer block.
Source Location
- Repository: llm-awq (https://github.com/mit-han-lab/llm-awq)
- File: awq/quantize/auto_scale.py
- Lines: 87-446
Signature
```python
@torch.no_grad()
def auto_scale_block(module, module_kwargs, w_bit, q_config, input_feat):
```
Import
```python
from awq.quantize.auto_scale import auto_scale_block
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| module | nn.Module | Yes | -- | A single transformer block (e.g., one decoder layer) |
| module_kwargs | dict | Yes | -- | Keyword arguments for the block's forward pass (e.g., attention mask, position ids) |
| w_bit | int | Yes | -- | Target quantization bit-width (e.g., 4) |
| q_config | dict | Yes | -- | Quantization configuration with keys zero_point (bool) and q_group_size (int) |
| input_feat | dict | Yes | -- | Cached activation features keyed by linear layer name. Each value is a tensor of input activations collected during calibration (see the sketch below) |
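For context, input_feat is typically assembled by hooking every nn.Linear in the block during a calibration forward pass (llm-awq does this internally as part of run_awq). Below is a minimal sketch of that pattern; collect_input_feat and its exact behavior are illustrative, not the library's API:

```python
import torch
import torch.nn as nn
from collections import defaultdict

@torch.no_grad()
def collect_input_feat(block, calib_inputs, module_kwargs):
    """Illustrative helper: cache the inputs of every linear layer in `block`."""
    input_feat = defaultdict(list)
    handles = []

    def cache_hook(name):
        def hook(_module, inputs, _output):
            # inputs[0] is the activation tensor entering the linear layer
            input_feat[name].append(inputs[0].detach().cpu())
        return hook

    for name, m in block.named_modules():
        if isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(cache_hook(name)))

    block(calib_inputs, **module_kwargs)  # one calibration pass through the block

    for h in handles:
        h.remove()
    # Concatenate per-batch caches so each entry is a single tensor
    return {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
```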
Output
- list -- A list of tuples, where each tuple contains:
- prev_op_name (str) -- Name of the preceding operation whose weights are modified by the scaling
- layer_names (tuple[str, ...]) -- Names of the linear layers in this group that share the same input
- scales (torch.Tensor) -- The optimal per-channel scaling factors found by grid search
Implementation Details
The function performs a per-channel scaling search for every group of linked linear layers within a single transformer block. The process is:
- Identify layer groups: the function defines groups of linear layers that share a common preceding operation, based on the model architecture (see the first sketch below). For example, in a LLaMA-style block:
- LayerNorm (input_layernorm) -> Q, K, V projections (self_attn.q_proj, self_attn.k_proj, self_attn.v_proj)
- V output -> output projection (self_attn.o_proj)
- LayerNorm (post_attention_layernorm) -> gate and up projections (mlp.gate_proj, mlp.up_proj)
- Up activation -> down projection (mlp.down_proj)
- For each group, run the scaling search (see the second sketch below):
- Retrieve the cached input activations for the group's layers from input_feat
- Compute per-channel activation maximums (x_max)
- Iterate over 20 candidate exponents (alpha = i/20 for i = 0, ..., 19)
- For each candidate, scale the weights by x_max ** alpha, quantize them, and measure the block-level output MSE against the original output
- Select the alpha with the minimum MSE
- Record the resulting scaling factors
- Return all scaling results as a list of tuples.
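As a concrete illustration of the group wiring for a LLaMA-style block, here is a minimal sketch; search_scale_for_group is a hypothetical stand-in for the internal per-group search helper (llm-awq implements this logic inside auto_scale_block itself, so the names and signature here are illustrative only):

```python
# Illustrative sketch only: `search_scale_for_group` is a hypothetical stand-in
# for the internal helper that grid-searches one group's scales.
def build_scale_groups(block, input_feat, module_kwargs, search_scale_for_group):
    return [
        # LayerNorm -> Q/K/V projections (all three share the same input)
        search_scale_for_group(
            prev_op=block.input_layernorm,
            layers=[block.self_attn.q_proj, block.self_attn.k_proj, block.self_attn.v_proj],
            inp=input_feat["self_attn.q_proj"],
            kwargs=module_kwargs,
        ),
        # V output -> output projection
        search_scale_for_group(
            prev_op=block.self_attn.v_proj,
            layers=[block.self_attn.o_proj],
            inp=input_feat["self_attn.o_proj"],
            kwargs=module_kwargs,
        ),
        # LayerNorm -> gate and up projections
        search_scale_for_group(
            prev_op=block.post_attention_layernorm,
            layers=[block.mlp.gate_proj, block.mlp.up_proj],
            inp=input_feat["mlp.gate_proj"],
            kwargs=module_kwargs,
        ),
        # Up activation -> down projection
        search_scale_for_group(
            prev_op=block.mlp.up_proj,
            layers=[block.mlp.down_proj],
            inp=input_feat["mlp.down_proj"],
            kwargs=module_kwargs,
        ),
    ]
```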
The function operates under @torch.no_grad() since no gradient computation is needed. All weight modifications during the grid search are temporary and are reverted after each candidate evaluation; the sketch below shows this save-and-restore pattern.
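A condensed sketch of the per-group grid search, following the structure described above (pseudo_quantize is a stand-in for llm-awq's weight pseudo-quantization routine; exact normalization and device handling in the library may differ):

```python
import torch

@torch.no_grad()
def search_module_scale(block, linears2scale, x, pseudo_quantize, n_grid=20, **kwargs):
    # Reference output with the original (unquantized) weights; x is assumed
    # to be on the same device as the block.
    org_out = block(x, **kwargs)
    if isinstance(org_out, tuple):
        org_out = org_out[0]

    # Per-channel activation magnitude over all calibration tokens
    x_max = x.abs().view(-1, x.shape[-1]).mean(0)

    # Snapshot the weights so each candidate evaluation can be reverted
    org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
    best_error, best_scales = float("inf"), None

    for i in range(n_grid):
        alpha = i / n_grid  # candidate exponent in [0, 1)
        scales = x_max.pow(alpha).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()
        for fc in linears2scale:
            # Scale up the input channels, quantize, then fold the scale back out
            fc.weight.mul_(scales.view(1, -1))
            fc.weight.data = pseudo_quantize(fc.weight.data) / scales.view(1, -1)
        out = block(x, **kwargs)
        if isinstance(out, tuple):
            out = out[0]
        loss = (org_out - out).float().pow(2).mean().item()  # block-level MSE
        if loss < best_error:
            best_error, best_scales = loss, scales.clone()
        block.load_state_dict(org_sd)  # revert the temporary weight edits

    return best_scales.detach()
```

Note how each candidate mutates the weights in place and then restores them from the saved state_dict, so the block is left unmodified when the search finishes.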
Usage Example
```python
# auto_scale_block is called internally by run_awq.
# The typical internal call pattern is:
from awq.quantize.auto_scale import auto_scale_block

# Assume `model` is a loaded CausalLM and calibration has been performed
# to collect `input_feat` for each block.
block = model.model.layers[0]  # First transformer block

module_kwargs = {
    "attention_mask": attention_mask,
    "position_ids": position_ids,
}
q_config = {
    "zero_point": True,
    "q_group_size": 128,
}

# input_feat is a dict mapping layer names to activation tensors, e.g.:
# {
#     "self_attn.q_proj": tensor of shape [n_samples, hidden_size],
#     "self_attn.k_proj": tensor of shape [n_samples, hidden_size],
#     ...
# }
scale_results = auto_scale_block(
    module=block,
    module_kwargs=module_kwargs,
    w_bit=4,
    q_config=q_config,
    input_feat=input_feat,
)

# scale_results is a list of tuples:
# [
#     ("input_layernorm", ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"), scales_tensor),
#     ("self_attn.v_proj", ("self_attn.o_proj",), scales_tensor),
#     ("post_attention_layernorm", ("mlp.gate_proj", "mlp.up_proj"), scales_tensor),
#     ("mlp.up_proj", ("mlp.down_proj",), scales_tensor),
# ]
for prev_op_name, layer_names, scales in scale_results:
    print(f"Prev op: {prev_op_name}, Layers: {layer_names}, Scales shape: {scales.shape}")
```
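auto_scale_block only searches for the scales; in the normal run_awq flow the results are then folded into the model with apply_scale from the same module. A minimal follow-up step, assuming the signature found in awq/quantize/auto_scale.py:

```python
from awq.quantize.auto_scale import apply_scale

# Fold each group's scales into its preceding op and linear layers, in place.
apply_scale(block, scale_results, input_feat_dict=input_feat)
```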
Related Pages
- Principle:Mit_han_lab_Llm_awq_Per_Channel_Scaling_Search
- Environment:Mit_han_lab_Llm_awq_Python_Runtime_Environment
- Heuristic:Mit_han_lab_Llm_awq_AWQ_Grid_Search_Tuning
- Heuristic:Mit_han_lab_Llm_awq_GPU_Memory_Management_Patterns
Knowledge Sources
- Repo: llm-awq (https://github.com/mit-han-lab/llm-awq)
- Paper: AWQ (https://arxiv.org/abs/2306.00978)
Domains
- Quantization
- Optimization