
Implementation:Mlc ai Mlc llm Block Scale Quantization

From Leeroopedia
Revision as of 15:48, 16 February 2026 by Admin (Auto-imported from implementations/Mlc_ai_Mlc_llm_Block_Scale_Quantization.md)


Knowledge Sources
Domains Quantization, FP8, Linear Layers, MoE
Last Updated 2026-02-09 19:00 GMT

Overview

Block-scale quantization configuration and quantized module implementations for FP8 inference in MLC LLM, supporting both standard linear layers and Mixture of Experts (MoE) architectures with CUTLASS and Triton backends.

Description

The block_scale_quantization module implements the block-scale FP8 quantization pipeline for MLC LLM models. Block-scale quantization divides weight (and optionally activation) tensors into fixed-size blocks and stores one FP32 scale factor per block, which is used for dequantization. Because each scale covers only one block, an outlier value affects the precision of its own block rather than the whole tensor, while the storage overhead stays at one FP32 value per block.
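The per-block scaling described above can be illustrated with a minimal NumPy sketch. It is not the MLC LLM implementation: the function names are hypothetical, the values stay in float32 (a real kernel would cast them to FP8, which adds rounding error), and 448 is the largest finite magnitude of float8_e4m3fn.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def block_quantize(w, block=(128, 128)):
    """Scale each (bh, bw) block of `w` into the FP8 range.

    Returns the scaled values (kept in float32 for this sketch) and one
    float32 inverse scale per block. Hypothetical helper, for illustration.
    """
    bh, bw = block
    n, k = w.shape
    nb_r, nb_c = -(-n // bh), -(-k // bw)  # ceiling division
    q = np.zeros_like(w, dtype=np.float32)
    scale_inv = np.ones((nb_r, nb_c), dtype=np.float32)
    for i in range(nb_r):
        for j in range(nb_c):
            rs = slice(i * bh, (i + 1) * bh)
            cs = slice(j * bw, (j + 1) * bw)
            # Guard against an all-zero block to avoid dividing by zero.
            amax = max(float(np.abs(w[rs, cs]).max()), 1e-10)
            scale_inv[i, j] = amax / FP8_E4M3_MAX
            q[rs, cs] = np.clip(
                w[rs, cs] / scale_inv[i, j], -FP8_E4M3_MAX, FP8_E4M3_MAX
            )
    return q, scale_inv


def block_dequantize(q, scale_inv, block=(128, 128)):
    """Multiply every element by the inverse scale of its block."""
    bh, bw = block
    full = np.repeat(np.repeat(scale_inv, bh, axis=0), bw, axis=1)
    return q * full[: q.shape[0], : q.shape[1]]
```

For a 200 x 300 weight and 128 x 128 blocks this produces a 2 x 3 grid of scales; edge blocks are simply smaller.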

Core Components:

  • BlockScaleQuantize (dataclass): The main quantization configuration. Specifies the weight dtype (float8_e4m3fn or float8_e5m2), the model dtype (float16 or bfloat16), the block size, and whether to use static activation scales. The quantize_model method walks the model graph with an nn.Mutator and replaces eligible nn.Linear and MixtralExperts modules with their quantized counterparts. Final FC layers and MoE gate layers are excluded from quantization. It also handles the DeepSeek-specific w_uk/w_uv parameters for fused KV-lora attention.
  • BlockScaleQuantizeLinear: Replaces nn.Linear for block-scale FP8 inference. Stores quantized weights and per-block scale_inv parameters. The forward pass:
    • For single-token (m=1): uses a specialized GEMV path (dequantize_float8_groupwise_scaled_gemv)
    • For batched input: dynamically quantizes activations using rowwise_group_quant_fp8, then dispatches to either CUTLASS or Triton FP8 GEMM
  • BlockScaleQuantizeLinearStaticActivation: Extends BlockScaleQuantizeLinear for models with pre-computed activation scales. Stores per-group activation scales and applies them during the forward pass via static_activation_group_quant_fp8.
  • BlockScaleQuantizeMixtralExperts: Replaces MixtralExperts for block-scale FP8 MoE inference. Supports both single-token GEMV and batched group GEMM paths, dispatching to either CUTLASS or Triton backends.
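The single-token path listed above dequantizes weights while accumulating the dot product. A rough NumPy reference of that idea follows; it assumes a weight layout of [out_features, in_features] and a scale grid of [ceil(out/bh), ceil(in/bw)], both of which are assumptions for illustration, and dequant_gemv is a hypothetical name rather than the TIR kernel.

```python
import numpy as np


def dequant_gemv(x, w_q, w_scale_inv, block=(128, 128)):
    """y[n] = sum_k x[k] * w_q[n, k] * w_scale_inv[n // bh, k // bw].

    Hypothetical reference: the scale is applied once per K-block to the
    partial sums instead of materializing a dequantized weight matrix.
    """
    bh, bw = block
    n, k = w_q.shape
    y = np.zeros(n, dtype=np.float32)
    for j in range(w_scale_inv.shape[1]):  # loop over blocks along K
        ks = slice(j * bw, min((j + 1) * bw, k))
        partial = w_q[:, ks].astype(np.float32) @ x[ks].astype(np.float32)
        # Each output row uses the scale of the row-block it belongs to.
        row_scale = np.repeat(w_scale_inv[:, j], bh)[:n]
        y += partial * row_scale
    return y
```

Applying the scale to each block's partial sum is equivalent to scaling every weight first, because the scale is constant within a block.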

Utility Functions:

  • rowwise_group_quant_fp8: Dynamically quantizes activations to FP8 with per-group scales. Computes max absolute values per group, derives scale factors, and clips to FP8 range.
  • static_activation_group_quant_fp8: Quantizes activations using pre-computed static scales instead of dynamic per-token scales.
  • broadcast_activation_scale: Broadcasts 1D activation scale tensors to match the input tensor shape, with optional transpose for CUTLASS vs Triton layout differences.
  • dequantize_float8_groupwise_scaled_gemv: A TIR kernel for single-token FP8 GEMV that dequantizes weights on-the-fly and performs the dot product.
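As a sketch of the static-scale utilities above, the following NumPy code applies a pre-computed 1D scale (one entry per group of the last dimension) and broadcasts it across rows. The function names, the assumed scale shape [ceil(k / group_size)], and the float32 output are illustrative assumptions, not the library's code.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def broadcast_scale(x, activation_scale, transpose):
    """Tile a 1D per-group scale to [m, groups], or [groups, m] if transposed."""
    m = x.shape[0]
    tiled = np.broadcast_to(activation_scale, (m, activation_scale.shape[0]))
    return np.ascontiguousarray(tiled.T if transpose else tiled)


def static_group_quant(x, activation_scale, group_size):
    """Divide each group of the last axis by its fixed scale, then clip."""
    k = x.shape[-1]
    per_elem = np.repeat(activation_scale, group_size)[:k]
    return np.clip(x / per_elem, -FP8_E4M3_MAX, FP8_E4M3_MAX)
```

Unlike the dynamic path, no per-token maximum is computed here, so values larger than scale * 448 saturate at the clip bound.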

Usage

Use this module when deploying FP8-quantized models in MLC LLM. The BlockScaleQuantize configuration is specified in model quantization configs and applied automatically during model compilation. At runtime the quantized modules handle the FP8 compute path, using the CUTLASS kernels when they are available and registered and falling back to Triton otherwise.

Code Reference

Source Location

Signature

@dataclass
class BlockScaleQuantize:
    name: str
    kind: str = "block-scale"
    weight_dtype: Literal["float8_e4m3fn", "float8_e5m2"] = "float8_e4m3fn"
    model_dtype: Literal["float16", "bfloat16"] = "bfloat16"
    quantize_linear: bool = True
    weight_block_size: Optional[Tuple[int, int]] = None
    use_activation_scale: bool = False

    def quantize_model(self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str) -> nn.Module

class BlockScaleQuantizeLinear(nn.Module):
    def __init__(self, in_features, out_features, weight_dtype, block_size, bias=True, dtype=None, out_dtype=None)
    @staticmethod
    def from_linear(src: nn.Linear, config: BlockScaleQuantize, weight_block_size) -> "BlockScaleQuantizeLinear"
    def forward(self, x: nn.Tensor) -> nn.Tensor

class BlockScaleQuantizeLinearStaticActivation(BlockScaleQuantizeLinear):
    @staticmethod
    def from_linear(src: nn.Linear, config: BlockScaleQuantize, weight_block_size) -> "BlockScaleQuantizeLinearStaticActivation"
    def forward(self, x: nn.Tensor) -> nn.Tensor

class BlockScaleQuantizeMixtralExperts(nn.Module):
    def __init__(self, num_local_experts, in_features, out_features, weight_dtype, block_size)
    @staticmethod
    def from_mixtral_experts(src, config, weight_block_size) -> "BlockScaleQuantizeMixtralExperts"
    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor

def rowwise_group_quant_fp8(x, group_size, dtype, transpose_scale, eps=1e-10, keep_first_batch_dim=False) -> Tuple[nn.Tensor, nn.Tensor]
def static_activation_group_quant_fp8(x, activation_scale, group_size, dtype) -> nn.Tensor
def broadcast_activation_scale(x, activation_scale, transpose) -> nn.Tensor
def dequantize_float8_groupwise_scaled_gemv(x, w, w_scale, block_size, out_dtype) -> nn.Tensor

Import

from mlc_llm.quantization.block_scale_quantization import (
    BlockScaleQuantize,
    BlockScaleQuantizeLinear,
    BlockScaleQuantizeLinearStaticActivation,
    BlockScaleQuantizeMixtralExperts,
    rowwise_group_quant_fp8,
)

I/O Contract

BlockScaleQuantizeLinear.forward

Parameter | Type | Description
x | nn.Tensor | Input activation tensor, shape [..., in_features]

Return | Type | Description
result | nn.Tensor | Output tensor, shape [..., out_features]; dtype is out_dtype when set, otherwise the input dtype

BlockScaleQuantizeMixtralExperts.forward

Parameter | Type | Description
x | nn.Tensor | Input tensor [m, in_features]
indptr | nn.Tensor | Expert boundary pointers [num_experts + 1] (batched) or expert indices [1, k] (single token)

Return | Type | Description
result | nn.Tensor | Output tensor [m, out_features]
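To make the batched indptr convention concrete, here is a small NumPy reference of a grouped GEMM. It assumes that rows of x are already sorted by expert, so that rows indptr[e]:indptr[e + 1] belong to expert e, and that expert weights have shape [num_experts, out_features, in_features]. Dequantization is left out, and grouped_gemm_reference is a hypothetical name.

```python
import numpy as np


def grouped_gemm_reference(x, expert_weights, indptr):
    """Multiply each contiguous row range of `x` by its expert's weight."""
    m = x.shape[0]
    out = np.zeros((m, expert_weights.shape[1]), dtype=np.float32)
    for e in range(expert_weights.shape[0]):
        lo, hi = int(indptr[e]), int(indptr[e + 1])
        if hi > lo:  # experts with no assigned tokens are skipped
            out[lo:hi] = x[lo:hi] @ expert_weights[e].T
    return out
```

With indptr = [0, 2, 2, 5], expert 0 handles rows 0 to 1, expert 1 receives no rows, and expert 2 handles rows 2 to 4.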

rowwise_group_quant_fp8

Parameter | Type | Description
x | nn.Tensor | Input tensor, shape [..., k] (ndim >= 2)
group_size | int | Group size for quantization along the last dimension
dtype | str | Target FP8 dtype ("float8_e4m3fn" or "float8_e5m2")
transpose_scale | bool | Whether to transpose the scale tensor
eps | float | Epsilon for numerical stability (default 1e-10)

Return | Type | Description
x_fp8 | nn.Tensor | Quantized tensor in FP8, same shape as x
x_scale | nn.Tensor | Per-group scale factors, shape [..., ceildiv(k, group_size)]
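The shapes in this contract can be checked against a NumPy reference of dynamic group quantization. This is a sketch under assumptions: the scale is taken as amax / fp8_max with amax floored at eps, the output stays in float32 instead of FP8, transposition swaps the last two scale axes, and rowwise_group_quant_sketch is a hypothetical name.

```python
import numpy as np

# Largest finite magnitudes of the two FP8 formats.
FP8_MAX = {"float8_e4m3fn": 448.0, "float8_e5m2": 57344.0}


def rowwise_group_quant_sketch(
    x, group_size, dtype="float8_e4m3fn", transpose_scale=False, eps=1e-10
):
    fp8_max = FP8_MAX[dtype]
    *lead, k = x.shape
    n_groups = -(-k // group_size)  # ceiling division
    pad = n_groups * group_size - k
    # Zero-pad the last axis so it splits evenly into groups.
    xp = np.pad(x.astype(np.float32), [(0, 0)] * len(lead) + [(0, pad)])
    groups = xp.reshape(*lead, n_groups, group_size)
    # One max-abs value per group, floored at eps to avoid division by zero.
    amax = np.maximum(np.abs(groups).max(axis=-1), eps)
    scale = amax / fp8_max
    q = np.clip(groups / scale[..., None], -fp8_max, fp8_max)
    q = q.reshape(*lead, n_groups * group_size)[..., :k]
    if transpose_scale:
        scale = np.swapaxes(scale, -1, -2)
    return q, scale
```

For an input of shape [2, 300] and group_size=128, the quantized output keeps shape [2, 300] and the scale has shape [2, 3], or [3, 2] when transposed.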

Backend Dispatch Logic

The forward pass dynamically selects between compute backends:

Condition | Backend Used
Single token (m=1) | Dequantize GEMV (TIR kernel)
CUTLASS available and registered | CUTLASS FP8 group-scaled GEMM
CUTLASS not available | Triton FP8 block-scale GEMM
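The table reads as a priority order, which can be written as a small function. The function name and the returned labels are hypothetical; how the module actually detects CUTLASS is not shown here.

```python
def select_backend(m: int, cutlass_available: bool) -> str:
    """Return an illustrative label for the backend chosen by the table."""
    if m == 1:
        # Single-token decode takes the GEMV path regardless of CUTLASS.
        return "tir-dequantize-gemv"
    if cutlass_available:
        return "cutlass-fp8-group-scaled-gemm"
    return "triton-fp8-block-scale-gemm"
```

The single-token check comes first, so the CUTLASS/Triton choice only matters for batched input.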

Usage Examples

from mlc_llm.quantization.block_scale_quantization import (
    BlockScaleQuantize,
    rowwise_group_quant_fp8,
)

# Define quantization config
quant_config = BlockScaleQuantize(
    name="block_fp8",
    kind="block-scale-quant",
    weight_dtype="float8_e4m3fn",
    model_dtype="bfloat16",
    weight_block_size=(128, 128),
)

# Quantize the model.
# `model` (an nn.Module) and `quant_map` (a QuantizeMapping) must already exist.
quantized_model = quant_config.quantize_model(model, quant_map, name_prefix="model")

# Dynamic activation quantization.
# `x` is an existing nn.Tensor of shape [..., k] with ndim >= 2.
x_fp8, x_scale = rowwise_group_quant_fp8(
    x, group_size=128, dtype="float8_e4m3fn", transpose_scale=False
)
