Implementation:Mit han lab Llm awq W8A8OF16LinearDynamicInputScale
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Model_Architecture |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Core quantized linear layer implementing an INT8-weight, INT8-activation (W8A8) GEMM with dynamic per-token input scaling and FP16 output.
Description
W8A8OF16LinearDynamicInputScale extends W8A8OF16LinearStaticScale to support dynamic per-token input quantization scales provided at runtime. Weights are statically quantized to INT8 at construction time with per-output-channel dequantization scales. At forward time, the layer dispatches to either awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda (when bias is present) or awq_inference_engine.w8a8_gemm_forward_cuda (without bias) for fused INT8 GEMM computation. Results are written directly into a caller-provided output buffer to avoid allocation overhead.
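The arithmetic behind the fused kernels can be illustrated with an unfused PyTorch reference. This is a minimal sketch, not the CUDA implementation: the function name `w8a8_reference`, the `weight_scale` argument name, and the INT32 accumulation are assumptions made for illustration, and the real kernels may round differently.

```python
import torch

def w8a8_reference(x_int8, input_scale, w_int8, weight_scale, bias=None):
    """Unfused reference (assumed math): INT8 x INT8 matmul, then rescale.

    x_int8:       (tokens, in_features) int8 activations
    input_scale:  (tokens,) per-token activation scales
    w_int8:       (out_features, in_features) int8 weights
    weight_scale: (out_features,) per-output-channel dequantization scales
    """
    # Widen to int32 so sums of int8 products cannot overflow.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    # Row i is scaled by token i's scale, column j by channel j's scale.
    out = acc.to(torch.float32)
    out = out * input_scale.float()[:, None] * weight_scale.float()[None, :]
    if bias is not None:
        out = out + bias.float()
    return out.to(torch.float16)
```

Unlike the layer's `forward`, this sketch returns a new tensor instead of writing into a caller-provided buffer, which keeps it easy to compare against the fused result during debugging.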
W8A8OF16LinearStaticScale is the base class that defines the INT8 weight storage, per-channel dequantization scale buffer, and the weight creation logic.
FakeW8A8Linear provides a fake-quantization wrapper for training or calibration: it simulates INT8 weight quantization in FP16 by rounding weights to the nearest quantization grid point, and applies dynamic per-token activation quantization during forward passes.
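The quantize-then-dequantize step that fake quantization relies on can be sketched as follows. The helper names, the symmetric grid, and the max-abs scale rule are assumptions for illustration; the source only states that weights are rounded to the nearest grid point and that activations are quantized per token.

```python
import torch

def fake_quantize_weight(w, wbit=8):
    """Round each output channel (row) of w onto a symmetric wbit grid.

    The result keeps w's dtype, so it can be used in a normal FP16 matmul.
    """
    qmax = 2 ** (wbit - 1) - 1  # 127 for 8 bits
    wf = w.float()
    scale = wf.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    q = torch.round(wf / scale).clamp(-qmax, qmax)
    return (q * scale).to(w.dtype)

def fake_quantize_activation(x, abit=8):
    """Same idea per token: one scale per row along the last dimension."""
    qmax = 2 ** (abit - 1) - 1
    xf = x.float()
    scale = xf.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / qmax
    q = torch.round(xf / scale).clamp(-qmax, qmax)
    return (q * scale).to(x.dtype)
```

Because the output stays in floating point, the quantization error shows up in model outputs without requiring any INT8 kernels.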
The fake_quant function traverses a model and replaces all nn.Linear modules with FakeW8A8Linear instances for quantization-aware evaluation.
The class methods from_linear and from_qkv provide factory constructors for converting pre-trained FP16 linear layers (or fused Q/K/V projection triplets) into quantized INT8 layers with computed per-channel scales.
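One plausible way to compute per-channel scales, and to fuse Q/K/V before quantizing, is sketched below. The helper names and the max-abs / 127 rule are assumptions, not taken from the repository; the source only says that per-channel scales are computed unless `s1_scale` is supplied.

```python
import torch

def quantize_per_channel(weight, scale=None):
    """Return (int8 weight, per-output-channel scale).

    weight has shape (out_features, in_features). If scale is None, it is
    derived so the largest magnitude in each row maps to 127 (assumed rule).
    """
    w = weight.float()
    if scale is None:
        scale = w.abs().amax(dim=1).clamp(min=1e-5) / 127.0
    q = torch.round(w / scale[:, None]).clamp(-128, 127).to(torch.int8)
    return q, scale

def quantize_fused_qkv(q_w, k_w, v_w):
    """Stack Q, K, V weights along the output dimension, then quantize once."""
    return quantize_per_channel(torch.cat([q_w, k_w, v_w], dim=0))
```

Since every output row gets its own scale, concatenating the three projections does not let a large value in one projection degrade the resolution of the others.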
Usage
Import W8A8OF16LinearDynamicInputScale to build quantized model layers in the W8A8 quantization pipeline. Use from_linear to convert individual layers and from_qkv to create fused QKV projection layers. Use FakeW8A8Linear and fake_quant for quantization-aware calibration.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: awq/quantize/w8a8_linear.py
- Lines: 1-276
Signature
class W8A8OF16LinearStaticScale(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 scale: Union[torch.tensor, float] = 1.0,
                 params_dtype: Optional[torch.dtype] = None): ...
    def create_weights(self) -> None: ...
    def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: ...
    def forward(self, input_) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ...

class W8A8OF16LinearDynamicInputScale(W8A8OF16LinearStaticScale):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 scale: Union[torch.tensor, float] = 1.0,
                 params_dtype: Optional[torch.dtype] = None): ...
    def apply_weights_bias(self, x: torch.Tensor, input_scale: torch.Tensor,
                           output_buffer: torch.Tensor, bias: torch.Tensor = None): ...
    def apply_weights_no_bias(self, x: torch.Tensor, input_scale: torch.Tensor,
                              output_buffer: torch.Tensor, bias: torch.Tensor = None): ...
    def forward(self, input_: torch.Tensor, input_scale: torch.Tensor,
                output_buffer: torch.Tensor) -> None: ...
    @classmethod
    def from_linear(cls, linear, init_only=False, s1_scale=None, fc1=False
                    ) -> "W8A8OF16LinearDynamicInputScale": ...
    @classmethod
    def from_qkv(cls, q, k, v, init_only=False, s1_scale=None
                 ) -> "W8A8OF16LinearDynamicInputScale": ...

class FakeW8A8Linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 wbit: int = 8): ...
    def forward(self, input: torch.Tensor) -> torch.Tensor: ...
    @classmethod
    def from_linear(cls, linear: torch.nn.Linear, wbit=8) -> "FakeW8A8Linear": ...

def fake_quant(model, wbit=8) -> None:
    """Replace all nn.Linear modules with FakeW8A8Linear for fake quantization."""
Import
from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale
I/O Contract
Inputs (W8A8OF16LinearDynamicInputScale.forward)
| Name | Type | Required | Description |
|---|---|---|---|
| input_ | torch.Tensor (int8) | Yes | Quantized input tensor of shape (batch * tokens, in_features) |
| input_scale | torch.Tensor (float16) | Yes | Per-token dynamic quantization scales of shape (batch * tokens,) |
| output_buffer | torch.Tensor (float16) | Yes | Pre-allocated output buffer of shape (batch * tokens, out_features) |
Outputs (W8A8OF16LinearDynamicInputScale.forward)
| Name | Type | Description |
|---|---|---|
| (in-place) | None | Result is written directly into output_buffer; no return value |
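To make the shape and dtype contract above concrete, the following sketch prepares the three `forward` arguments from a batch of FP16 activations. The helper `prepare_inputs` is hypothetical, and the max-abs / 127 rule for per-token scales is an assumption; the pipeline may quantize activations inside a fused upstream kernel instead.

```python
import torch

def prepare_inputs(hidden, out_features):
    """Build (input_, input_scale, output_buffer) for a W8A8 forward call.

    hidden: (batch, tokens, in_features) floating-point activations.
    """
    # Collapse batch and token dimensions into one row per token.
    x = hidden.reshape(-1, hidden.shape[-1]).float()
    # One scale per token row (assumed rule: largest magnitude maps to 127).
    input_scale = x.abs().amax(dim=1).clamp(min=1e-5) / 127.0
    x_int8 = torch.round(x / input_scale[:, None]).clamp(-128, 127).to(torch.int8)
    # The layer writes its result here instead of allocating a new tensor.
    output_buffer = torch.empty(
        x.shape[0], out_features, dtype=torch.float16, device=hidden.device
    )
    return x_int8, input_scale.to(torch.float16), output_buffer
```

Allocating `output_buffer` once and reusing it across calls with the same token count is what lets the in-place contract avoid per-call allocation.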
Inputs (from_linear)
| Name | Type | Required | Description |
|---|---|---|---|
| linear | nn.Linear | Yes | Pre-trained FP16 linear layer to quantize |
| init_only | bool | No | If True, create structure without quantizing weights (for state dict loading) |
| s1_scale | torch.Tensor | No | Optional pre-computed per-channel weight scales; auto-computed if None |
| fc1 | bool | No | Flag for FC1-specific handling (default: False) |
Inputs (from_qkv)
| Name | Type | Required | Description |
|---|---|---|---|
| q | nn.Linear | Yes | Query projection linear layer |
| k | nn.Linear | Yes | Key projection linear layer |
| v | nn.Linear | Yes | Value projection linear layer |
| init_only | bool | No | If True, create structure without quantizing weights |
| s1_scale | torch.Tensor | No | Optional pre-computed weight scales |
Inputs (fake_quant)
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Model whose nn.Linear layers will be replaced with FakeW8A8Linear |
| wbit | int | No | Weight bit width for fake quantization (default: 8) |
Usage Examples
Convert a Linear Layer to W8A8
from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale
import torch
# Convert a pre-trained linear layer
original_linear = torch.nn.Linear(4096, 4096).cuda().half()
quant_linear = W8A8OF16LinearDynamicInputScale.from_linear(original_linear)
# Use with pre-quantized INT8 input and dynamic scale
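# torch.randint excludes the upper bound, so 128 is needed to cover the full INT8 range [-128, 127]
input_int8 = torch.randint(-128, 128, (32, 4096), dtype=torch.int8, device="cuda")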
input_scale = torch.ones(32, dtype=torch.float16, device="cuda")
output_buffer = torch.empty(32, 4096, dtype=torch.float16, device="cuda")
quant_linear(input_int8, input_scale, output_buffer)
Fuse Q/K/V into Single Quantized Layer
from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale
# Fuse three separate projection layers
# (assumes `model` is an already loaded FP16 model exposing these attributes)
q_proj = model.attention.q_proj # nn.Linear(4096, 4096)
k_proj = model.attention.k_proj # nn.Linear(4096, 4096)
v_proj = model.attention.v_proj # nn.Linear(4096, 4096)
fused_qkv = W8A8OF16LinearDynamicInputScale.from_qkv(q_proj, k_proj, v_proj)
# fused_qkv has out_features = 4096 * 3 = 12288
Apply Fake Quantization for Calibration
from awq.quantize.w8a8_linear import fake_quant
# Replace all nn.Linear modules with fake-quantized versions
fake_quant(model, wbit=8)
# Model now simulates INT8 quantization effects during FP16 forward passes
output = model(input_ids)
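The module traversal that `fake_quant` is described as performing can be sketched in a few lines. This is an illustration under assumptions, not the repository's code: `MarkerLinear` is a hypothetical stand-in for `FakeW8A8Linear` that copies weights without quantizing them, and `replace_linear` and `to_marker` are illustrative helper names.

```python
import torch
from torch import nn

class MarkerLinear(nn.Linear):
    """Hypothetical stand-in for FakeW8A8Linear; only marks swapped layers."""

def replace_linear(module, make_replacement):
    """Recursively swap every nn.Linear child for make_replacement(child)."""
    # Snapshot the children so reassignment cannot disturb the iteration.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, make_replacement(child))
        else:
            replace_linear(child, make_replacement)

def to_marker(linear):
    """Build a replacement with the same shape and copy the trained weights."""
    new = MarkerLinear(linear.in_features, linear.out_features,
                       bias=linear.bias is not None)
    new.load_state_dict(linear.state_dict())
    return new
```

For example, `replace_linear(model, to_marker)` rewrites nested submodules in place. A model that is itself a bare `nn.Linear` has no children and is left unchanged by this sketch.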