

Implementation:Mit han lab Llm awq W8A8OF16LinearDynamicInputScale

From Leeroopedia
Revision as of 13:16, 16 February 2026 by Admin (auto-imported from implementations/Mit_han_lab_Llm_awq_W8A8OF16LinearDynamicInputScale.md)
Knowledge Sources
Domains: Quantization, Model_Architecture
Last Updated: 2026-02-15 00:00 GMT

Overview

Core quantized linear layer implementing a W8A8 GEMM (INT8 weights, INT8 activations) with dynamic per-token input scaling and FP16 output.

Description

W8A8OF16LinearDynamicInputScale extends W8A8OF16LinearStaticScale to support dynamic per-token input quantization scales provided at runtime. Weights are statically quantized to INT8 at construction time with per-output-channel dequantization scales. At forward time, the layer dispatches to either awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda (when bias is present) or awq_inference_engine.w8a8_gemm_forward_cuda (without bias) for fused INT8 GEMM computation. Results are written directly into a caller-provided output buffer to avoid allocation overhead.
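The arithmetic the fused kernels perform can be emulated in plain PyTorch, which is handy for checking results on small inputs. The sketch below is an assumption about the semantics, not the CUDA code: it multiplies INT8 activations (one scale per token) by INT8 weights (one scale per output channel), rescales the product, optionally adds the bias, and writes the result into a caller-provided FP16 buffer. The function name `w8a8_gemm_reference` is hypothetical.

```python
import torch


def w8a8_gemm_reference(x_int8, input_scale, w_int8, weight_scale,
                        output_buffer, bias=None):
    """Hypothetical reference for an INT8 x INT8 GEMM with FP16 output.

    x_int8:        (tokens, in_features) int8 activations
    input_scale:   (tokens,) per-token dequantization scales
    w_int8:        (out_features, in_features) int8 weights
    weight_scale:  (out_features,) per-output-channel dequantization scales
    output_buffer: (tokens, out_features) preallocated tensor, overwritten in place
    """
    # Products of int8 values summed in float64 are exact for any realistic
    # in_features, so this matches integer accumulation and runs on CPU or GPU
    acc = x_int8.double() @ w_int8.double().t()
    # Dequantize: row i is scaled by input_scale[i], column j by weight_scale[j]
    y = acc * input_scale.double()[:, None] * weight_scale.double()[None, :]
    if bias is not None:  # mirrors the bias / no-bias kernel split
        y = y + bias.double()
    output_buffer.copy_(y)  # copy_ casts to the buffer's dtype (e.g. float16)


# Usage on small random CPU tensors
torch.manual_seed(0)
x_q = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
w_q = torch.randint(-128, 128, (3, 16), dtype=torch.int8)
s_x = torch.full((4,), 0.01)
s_w = torch.full((3,), 0.02)
out = torch.empty(4, 3, dtype=torch.float16)
w8a8_gemm_reference(x_q, s_x, w_q, s_w, out)
```

Like the layer's forward, the function returns nothing; the caller reads the result from the buffer it allocated.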

W8A8OF16LinearStaticScale is the base class that defines the INT8 weight storage, per-channel dequantization scale buffer, and the weight creation logic.

FakeW8A8Linear provides a fake-quantization wrapper for training or calibration: it simulates weight quantization in FP16 (INT8 by default; the bit width is set by wbit) by rounding weights to the nearest quantization grid point, and applies dynamic per-token activation quantization during forward passes.
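Fake quantization can be sketched as a quantize-dequantize round trip that keeps everything in floating point. The version below assumes symmetric absmax quantization: weights get one scale per output channel with `2**(wbit - 1) - 1` levels on each side of zero, and activations get one INT8 scale per token. The helper names are hypothetical, and FakeW8A8Linear's exact rounding and clamping may differ.

```python
import torch
import torch.nn.functional as F


def fake_quant_weight_per_channel(w: torch.Tensor, wbit: int = 8) -> torch.Tensor:
    """Round each output row of w onto a symmetric wbit-bit grid, then dequantize."""
    qmax = 2 ** (wbit - 1) - 1  # 127 for wbit=8
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return (w / scale).round().clamp(-qmax, qmax) * scale


def fake_quant_act_per_token(x: torch.Tensor) -> torch.Tensor:
    """Round each token (row over the last dim) onto a symmetric INT8 grid, then dequantize."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127
    return (x / scale).round().clamp(-127, 127) * scale


def fake_w8a8_linear(x, weight, bias=None, wbit=8):
    # A module would round the weights once and cache them; they are
    # recomputed here only to keep the sketch short
    w_fq = fake_quant_weight_per_channel(weight, wbit)
    return F.linear(fake_quant_act_per_token(x), w_fq, bias)


torch.manual_seed(0)
x = torch.randn(4, 64)
w = torch.randn(32, 64)
y_ref = F.linear(x, w)
y_fake = fake_w8a8_linear(x, w)
# Relative error introduced by the simulated quantization (small but nonzero)
rel_err = ((y_fake - y_ref).norm() / y_ref.norm()).item()
```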

The fake_quant function traverses a model and replaces all nn.Linear modules with FakeW8A8Linear instances for quantization-aware evaluation.
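A module-replacement pass of this kind is commonly written as a recursive walk over `named_children`. The sketch assumes that pattern; `replace_linear` and `FakeQuantLinearSketch` are hypothetical stand-ins for fake_quant and FakeW8A8Linear, and the wrapper here only rounds the weights once to a per-channel grid (activation quantization is left out for brevity).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FakeQuantLinearSketch(nn.Module):
    """Hypothetical stand-in for FakeW8A8Linear: stores per-channel fake-quantized weights."""

    def __init__(self, linear: nn.Linear, wbit: int = 8):
        super().__init__()
        qmax = 2 ** (wbit - 1) - 1
        w = linear.weight.detach()
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
        self.register_buffer("weight", (w / scale).round().clamp(-qmax, qmax) * scale)
        bias = None if linear.bias is None else linear.bias.detach().clone()
        self.register_buffer("bias", bias)

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)


def replace_linear(module: nn.Module, wbit: int = 8) -> None:
    """Recursively swap every nn.Linear child for the fake-quant wrapper, in place."""
    # Materialize the children first so the module tree is not mutated mid-iteration
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            setattr(module, name, FakeQuantLinearSketch(child, wbit))
        else:
            replace_linear(child, wbit)


model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Sequential(nn.Linear(32, 8)))
replace_linear(model, wbit=8)
remaining = sum(isinstance(m, nn.Linear) for m in model.modules())  # 0 after replacement
```

Recursing into non-Linear children is what lets the pass reach layers nested inside attention and MLP blocks.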

The class methods from_linear and from_qkv provide factory constructors for converting pre-trained FP16 linear layers (or fused Q/K/V projection triplets) into quantized INT8 layers with computed per-channel scales.
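A per-channel scale computation of the kind from_linear performs can be sketched as symmetric absmax quantization of each output row. This assumes the common scheme (scale = max|W_row| / 127, round, clamp to [-127, 127]); the real method may additionally use `s1_scale` or apply FC1-specific handling, neither of which is modeled here. The function name is hypothetical.

```python
import torch


def quantize_weight_per_channel(weight: torch.Tensor):
    """Hypothetical helper: symmetric INT8 quantization, one scale per output channel.

    Returns (w_int8, scale) with w_int8.float() * scale[:, None] ~= weight.
    """
    w = weight.detach().float()
    # One scale per output row; the clamp avoids dividing by zero on all-zero rows
    scale = w.abs().amax(dim=1).clamp(min=1e-8) / 127.0
    w_int8 = torch.round(w / scale[:, None]).clamp(-127, 127).to(torch.int8)
    return w_int8, scale


torch.manual_seed(0)
linear = torch.nn.Linear(64, 16)
w_int8, w_scale = quantize_weight_per_channel(linear.weight)
w_deq = w_int8.float() * w_scale[:, None]  # dequantized approximation of the weight
```

Because each row has its own scale, the largest-magnitude weight in every row maps to ±127, and the rounding error per element is at most half a quantization step of that row.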

Usage

Import W8A8OF16LinearDynamicInputScale to build quantized model layers in the W8A8 quantization pipeline. Use from_linear to convert individual layers and from_qkv to create fused QKV projection layers. Use FakeW8A8Linear and fake_quant for quantization-aware calibration.

Code Reference

Source Location

Signature

from typing import Optional, Tuple, Union

import torch

class W8A8OF16LinearStaticScale(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 scale: Union[torch.Tensor, float] = 1.0,
                 params_dtype: Optional[torch.dtype] = None): ...
    def create_weights(self) -> None: ...
    def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: ...
    def forward(self, input_) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ...

class W8A8OF16LinearDynamicInputScale(W8A8OF16LinearStaticScale):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 scale: Union[torch.Tensor, float] = 1.0,
                 params_dtype: Optional[torch.dtype] = None): ...
    def apply_weights_bias(self, x: torch.Tensor, input_scale: torch.Tensor,
                           output_buffer: torch.Tensor,
                           bias: Optional[torch.Tensor] = None): ...
    def apply_weights_no_bias(self, x: torch.Tensor, input_scale: torch.Tensor,
                              output_buffer: torch.Tensor,
                              bias: Optional[torch.Tensor] = None): ...
    # Writes the result into output_buffer and returns None
    def forward(self, input_: torch.Tensor, input_scale: torch.Tensor,
                output_buffer: torch.Tensor) -> None: ...
    @classmethod
    def from_linear(cls, linear, init_only=False, s1_scale=None,
                    fc1=False) -> "W8A8OF16LinearDynamicInputScale": ...
    @classmethod
    def from_qkv(cls, q, k, v, init_only=False,
                 s1_scale=None) -> "W8A8OF16LinearDynamicInputScale": ...

class FakeW8A8Linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 wbit: int = 8): ...
    def forward(self, input: torch.Tensor) -> torch.Tensor: ...
    @classmethod
    def from_linear(cls, linear: torch.nn.Linear, wbit=8) -> "FakeW8A8Linear": ...

def fake_quant(model, wbit=8) -> None:
    """Replace all nn.Linear modules with FakeW8A8Linear for fake quantization."""

Import

from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale

I/O Contract

Inputs (W8A8OF16LinearDynamicInputScale.forward)

| Name | Type | Required | Description |
|---|---|---|---|
| input_ | torch.Tensor (int8) | Yes | Quantized input tensor of shape (batch * tokens, in_features) |
| input_scale | torch.Tensor (float16) | Yes | Per-token dynamic quantization scales of shape (batch * tokens,) |
| output_buffer | torch.Tensor (float16) | Yes | Pre-allocated output buffer of shape (batch * tokens, out_features) |
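The caller is responsible for producing `input_` and `input_scale`. A minimal sketch, assuming the surrounding pipeline uses symmetric per-token absmax quantization (scale = max|x_row| / 127); `quantize_per_token` is a hypothetical helper, not part of `awq.quantize.w8a8_linear`. The output shapes match the table above: (batch * tokens, in_features) for the INT8 tensor and (batch * tokens,) for the FP16 scales.

```python
import torch


def quantize_per_token(x: torch.Tensor):
    """Hypothetical helper: symmetric INT8 quantization with one scale per token.

    x: (..., in_features) activations. Returns x_int8 of shape (N, in_features)
    and an FP16 scale of shape (N,), where N = batch * tokens.
    """
    x2d = x.reshape(-1, x.shape[-1]).float()
    scale = x2d.abs().amax(dim=-1).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(x2d / scale[:, None]).clamp(-127, 127).to(torch.int8)
    return x_int8, scale.to(torch.float16)


torch.manual_seed(0)
hidden = torch.randn(2, 16, 4096).half()  # (batch, tokens, hidden)
x_int8, x_scale = quantize_per_token(hidden)
print(x_int8.shape, x_scale.shape)  # torch.Size([32, 4096]) torch.Size([32])

# On a CUDA device, with quant_linear converted via from_linear (not run here):
# output_buffer = torch.empty(32, 4096, dtype=torch.float16, device="cuda")
# quant_linear(x_int8.cuda(), x_scale.cuda(), output_buffer)
```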

Outputs (W8A8OF16LinearDynamicInputScale.forward)

| Name | Type | Description |
|---|---|---|
| (in-place) | None | Result is written directly into output_buffer; no return value |

Inputs (from_linear)

| Name | Type | Required | Description |
|---|---|---|---|
| linear | nn.Linear | Yes | Pre-trained FP16 linear layer to quantize |
| init_only | bool | No | If True, create the structure without quantizing weights (for state dict loading) |
| s1_scale | torch.Tensor | No | Optional pre-computed per-channel weight scales; auto-computed if None |
| fc1 | bool | No | Flag for FC1-specific handling (default: False) |

Inputs (from_qkv)

| Name | Type | Required | Description |
|---|---|---|---|
| q | nn.Linear | Yes | Query projection linear layer |
| k | nn.Linear | Yes | Key projection linear layer |
| v | nn.Linear | Yes | Value projection linear layer |
| init_only | bool | No | If True, create the structure without quantizing weights |
| s1_scale | torch.Tensor | No | Optional pre-computed weight scales |
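Fusing Q/K/V amounts to concatenating the three weight matrices (and biases) along the output dimension, so a single GEMM yields all three projections side by side. The FP sketch below assumes from_qkv concatenates in q, k, v argument order and that either all three layers have a bias or none do; `fuse_qkv` is a hypothetical helper, and quantization is omitted so the equivalence holds up to floating-point error.

```python
import torch
import torch.nn as nn


def fuse_qkv(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    """Hypothetical helper: stack three projections into one wider nn.Linear."""
    has_bias = q.bias is not None  # assumes k and v match q here
    fused = nn.Linear(q.in_features, q.out_features + k.out_features + v.out_features,
                      bias=has_bias, device=q.weight.device, dtype=q.weight.dtype)
    with torch.no_grad():
        # Rows [0, dq) come from q, the next dk rows from k, the last dv rows from v
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        if has_bias:
            fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused


torch.manual_seed(0)
q, k, v = (nn.Linear(64, 64) for _ in range(3))
fused = fuse_qkv(q, k, v)
x = torch.randn(5, 64)
# Split the fused output columns back into the three projections
q_out, k_out, v_out = fused(x).split([64, 64, 64], dim=-1)
```

The same column split recovers Q, K and V from the output buffer of a fused 4096 x 12288 layer.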

Inputs (fake_quant)

| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Model whose nn.Linear layers will be replaced with FakeW8A8Linear |
| wbit | int | No | Weight bit width for fake quantization (default: 8) |

Usage Examples

Convert a Linear Layer to W8A8

from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale
import torch

# Requires a CUDA GPU and the compiled awq_inference_engine extension
# Convert a pre-trained FP16 linear layer
original_linear = torch.nn.Linear(4096, 4096).cuda().half()
quant_linear = W8A8OF16LinearDynamicInputScale.from_linear(original_linear)

# Use with pre-quantized INT8 input and dynamic per-token scales
# (torch.randint's upper bound is exclusive, so 128 covers the full int8 range)
input_int8 = torch.randint(-128, 128, (32, 4096), dtype=torch.int8, device="cuda")
input_scale = torch.ones(32, dtype=torch.float16, device="cuda")
output_buffer = torch.empty(32, 4096, dtype=torch.float16, device="cuda")
quant_linear(input_int8, input_scale, output_buffer)  # fills output_buffer in place

Fuse Q/K/V into Single Quantized Layer

import torch
from awq.quantize.w8a8_linear import W8A8OF16LinearDynamicInputScale

# Three separate FP16 projection layers (in a real model these are the
# attention block's q_proj, k_proj and v_proj modules)
q_proj = torch.nn.Linear(4096, 4096).cuda().half()
k_proj = torch.nn.Linear(4096, 4096).cuda().half()
v_proj = torch.nn.Linear(4096, 4096).cuda().half()

fused_qkv = W8A8OF16LinearDynamicInputScale.from_qkv(q_proj, k_proj, v_proj)
# fused_qkv has out_features = 4096 * 3 = 12288

Apply Fake Quantization for Calibration

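from awq.quantize.w8a8_linear import fake_quant

# `model` is assumed to be an already-loaded FP16 model on the GPU,
# and `input_ids` a tensor of token IDs for it

# Replace all nn.Linear modules (in place) with fake-quantized versions
fake_quant(model, wbit=8)

# The model now simulates INT8 quantization effects during FP16 forward passes
output = model(input_ids)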
