Overview
Block-scale quantization configuration and quantized module implementations for FP8 inference in MLC LLM, supporting both standard linear layers and Mixture of Experts (MoE) architectures with CUTLASS and Triton backends.
Description
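The block_scale_quantization module implements the complete block-scale FP8 quantization pipeline for MLC LLM models. Block-scale quantization divides weight (and optionally activation) tensors into fixed-size blocks and computes a single FP32 scale factor per block for dequantization. Because each block has its own scale, a large-magnitude outlier only costs precision inside its own block instead of across the whole tensor. Weights can therefore be stored in 8 bits (half the footprint of float16/bfloat16, plus a small per-block scale overhead) with less accuracy loss than a single per-tensor scale would give.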
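The per-block scaling described above can be sketched in NumPy. This is a reference model of the arithmetic, not the MLC LLM implementation. It assumes 2D weights of shape [out_features, in_features], (128, 128) blocks, and a per-block scale of max|block| / 448 (the largest finite float8_e4m3fn value). It also assumes, as in DeepSeek-style checkpoints, that scale_inv is the multiplier applied at dequantization time, and it simulates the FP8 cast by rounding to the e4m3 grid in float32. The helper names (round_to_e4m3, block_quantize, block_dequantize) are hypothetical.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x):
    """Simulate a float8_e4m3fn cast in float: clip, then keep 3 mantissa bits."""
    x = np.clip(x, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    mag = np.abs(x)
    # Binade exponent of each value, floored at -6 (smallest normal exponent)
    # so that subnormals use the fixed 2**-9 spacing.
    exp = np.floor(np.log2(np.maximum(mag, 2.0**-6)))
    step = 2.0 ** (exp - 3)  # 3 mantissa bits -> 8 steps per binade
    return np.sign(x) * np.round(mag / step) * step


def block_quantize(w, block=(128, 128)):
    """Quantize a 2D weight [n, k]; returns FP8-valued weights and per-block scales."""
    n, k = w.shape
    bn, bk = block
    nb, kb = -(-n // bn), -(-k // bk)  # ceildiv
    w_q = np.empty(w.shape, dtype=np.float32)
    scale_inv = np.empty((nb, kb), dtype=np.float32)
    for i in range(nb):
        for j in range(kb):
            rows, cols = slice(i * bn, (i + 1) * bn), slice(j * bk, (j + 1) * bk)
            tile = w[rows, cols]
            # One FP32 scale per block: map the block's max |value| onto FP8_E4M3_MAX
            s = np.float32(max(float(np.abs(tile).max()), 1e-10) / FP8_E4M3_MAX)
            scale_inv[i, j] = s
            w_q[rows, cols] = round_to_e4m3(tile / s)
    return w_q, scale_inv


def block_dequantize(w_q, scale_inv, block=(128, 128)):
    """Multiply every element by the scale of the block it belongs to."""
    bn, bk = block
    s = np.repeat(np.repeat(scale_inv, bn, axis=0), bk, axis=1)
    return w_q * s[: w_q.shape[0], : w_q.shape[1]]  # crop partial edge blocks


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 384)).astype(np.float32)
w_q, scale_inv = block_quantize(w)  # scale_inv has shape (2, 3)
w_hat = block_dequantize(w_q, scale_inv)
print(scale_inv.shape, np.abs(w - w_hat).max() / np.abs(w).max())
```

Because each scale comes from its own block's maximum, the worst-case error in a block is half an FP8 step at the top of the range (16 out of 448 in scaled units, about 3.6% of that block's maximum), regardless of the values in other blocks.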
Core Components:
- BlockScaleQuantize (dataclass): The main quantization configuration. It specifies the weight dtype (float8_e4m3fn or float8_e5m2), the model dtype (float16 or bfloat16), the weight block size, and whether to use static activation scales. Its quantize_model method walks the model with an nn.Mutator and replaces eligible nn.Linear and MixtralExperts modules with their quantized counterparts. Final FC layers and MoE gate layers are excluded from quantization. It also handles the DeepSeek-specific w_uk/w_uv parameters used by fused KV-LoRA attention.
- BlockScaleQuantizeLinear: Replaces nn.Linear for block-scale FP8 inference. It stores the quantized weights and per-block scale_inv parameters. The forward pass:
  - For a single token (m = 1): uses a specialized GEMV path (dequantize_float8_groupwise_scaled_gemv).
  - For batched input: dynamically quantizes activations with rowwise_group_quant_fp8, then dispatches to either a CUTLASS or a Triton FP8 GEMM.
- BlockScaleQuantizeLinearStaticActivation: Extends BlockScaleQuantizeLinear for models with precomputed activation scales. It stores per-group activation scales and applies them in the forward pass via static_activation_group_quant_fp8.
- BlockScaleQuantizeMixtralExperts: Replaces MixtralExperts for block-scale FP8 MoE inference. It supports both a single-token GEMV path and a batched group GEMM path, dispatching to either CUTLASS or Triton backends.
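The two BlockScaleQuantizeLinear forward paths differ mainly in when the scales are applied. The NumPy sketch below is a reference model of that arithmetic, not the TIR, CUTLASS, or Triton kernels. It assumes weights of shape [out_features, in_features] with (128, 128) blocks, and activations grouped in (1, 128) row groups that line up with the weight's k-blocks, so each partial product needs exactly one activation scale and one weight scale. It skips the FP8 cast so both paths can be checked against a plain x @ w.T. The function names are hypothetical.

```python
import numpy as np

FP8_MAX = 448.0  # float8_e4m3fn range; values stay in float64 (no FP8 rounding)


def scale_blocks(a, bn, bk):
    """Scale each (bn, bk) block of a 2D array into [-FP8_MAX, FP8_MAX].

    Returns the scaled array and one dequantization multiplier per block.
    """
    n, k = a.shape
    nb, kb = -(-n // bn), -(-k // bk)  # ceildiv
    q = np.empty_like(a)
    s = np.empty((nb, kb), dtype=a.dtype)
    for i in range(nb):
        for j in range(kb):
            rows, cols = slice(i * bn, (i + 1) * bn), slice(j * bk, (j + 1) * bk)
            s[i, j] = max(np.abs(a[rows, cols]).max(), 1e-10) / FP8_MAX
            q[rows, cols] = np.clip(a[rows, cols] / s[i, j], -FP8_MAX, FP8_MAX)
    return q, s


def gemv_path(x, w_q, w_s, bn, bk):
    """m == 1: dequantize the weight on the fly, then take dot products with x."""
    full = np.repeat(np.repeat(w_s, bn, axis=0), bk, axis=1)
    w = w_q * full[: w_q.shape[0], : w_q.shape[1]]
    return x @ w.T


def gemm_path(x, w_q, w_s, bn, bk):
    """m > 1: quantize x in (1, bk) groups, then accumulate block-scaled partial GEMMs."""
    x_q, x_s = scale_blocks(x, 1, bk)  # x_s: [m, ceildiv(k, bk)]
    n = w_q.shape[0]
    out = np.zeros((x.shape[0], n))
    for j in range(x_s.shape[1]):
        cols = slice(j * bk, (j + 1) * bk)
        partial = x_q[:, cols] @ w_q[:, cols].T  # [m, n] contribution of k-block j
        w_row_scale = np.repeat(w_s[:, j], bn)[:n]  # scale of the weight block holding each output feature
        out += partial * x_s[:, j:j + 1] * w_row_scale[None, :]
    return out


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 384))
x = rng.standard_normal((4, 384))
w_q, w_s = scale_blocks(w, 128, 128)
print(np.allclose(gemm_path(x, w_q, w_s, 128, 128), x @ w.T))  # True
print(np.allclose(gemv_path(x[:1], w_q, w_s, 128, 128), x[:1] @ w.T))  # True
```

Matching the activation group size to the weight's k-block size is what lets the batched path apply both scales once per k-block partial product instead of once per element.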
Utility Functions:
- rowwise_group_quant_fp8: Dynamically quantizes activations to FP8 with per-group scales. Computes max absolute values per group, derives scale factors, and clips to FP8 range.
- static_activation_group_quant_fp8: Quantizes activations using pre-computed static scales instead of dynamic per-token scales.
- broadcast_activation_scale: Broadcasts 1D activation scale tensors to match the input tensor shape, with optional transpose for CUTLASS vs Triton layout differences.
- dequantize_float8_groupwise_scaled_gemv: A TIR kernel for single-token FP8 GEMV that dequantizes weights on-the-fly and performs the dot product.
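The static and dynamic activation paths differ only in where the per-group scale comes from. The NumPy sketch below assumes that static quantization divides each k-group by its precomputed scale and clips to the FP8 range, mirroring the dynamic path. It also assumes that the transpose option of broadcast_activation_scale simply swaps the two axes of the broadcast scale. The FP8 cast is omitted, and the function names are hypothetical stand-ins, not the module's API.

```python
import numpy as np

FP8_MAX = 448.0  # float8_e4m3fn; the final FP8 cast is omitted


def broadcast_scale(x, activation_scale, transpose):
    """Expand a 1D per-group scale [g] to [m, g], or to [g, m] when transpose=True."""
    s = np.broadcast_to(activation_scale[None, :], (x.shape[0], activation_scale.shape[0]))
    return s.T if transpose else s


def static_group_quant(x, activation_scale, group_size):
    """Quantize [m, k] activations with fixed per-group scales (no runtime max reduction)."""
    per_column = np.repeat(activation_scale, group_size)[: x.shape[-1]]
    return np.clip(x / per_column, -FP8_MAX, FP8_MAX)


x = np.array([[1.0, -2.0, 300.0, 1000.0]])
scale = np.array([0.5, 2.0])  # precomputed scales for two groups of size 2
print(static_group_quant(x, scale, group_size=2))  # 1000 / 2 exceeds 448 and is clipped
print(broadcast_scale(x, scale, transpose=True).shape)  # (2, 1)
```

Static scales save the per-token max reduction at runtime, but unlike dynamic scales (which always map each group's maximum exactly onto the FP8 limit) they can clip activations that exceed the calibrated range, as the value 1000.0 shows.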
Usage
Use this module when deploying FP8 quantized models in MLC LLM. The BlockScaleQuantize configuration is specified in model quantization configs and automatically applied during model compilation. The quantized modules handle the FP8 compute path at runtime, transparently switching between CUTLASS and Triton backends depending on availability.
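The module selection that quantize_model performs (listed under Core Components) can be pictured as a replacement plan over named modules. The sketch below is a toy in plain Python and does not use TVM's nn.Mutator. Linear and MixtralExperts are hypothetical stand-in classes, the name-based exclusion rule is an assumption (the real code identifies final FC and gate layers in its own way), and mapping use_activation_scale=True to the static-activation class is inferred from the component descriptions rather than confirmed.

```python
from dataclasses import dataclass


@dataclass
class Linear:  # hypothetical stand-in for nn.Linear
    in_features: int
    out_features: int


@dataclass
class MixtralExperts:  # hypothetical stand-in for the MoE expert container
    num_local_experts: int


# Hypothetical exclusion rule: match the last component of the module name.
EXCLUDED_LEAF_NAMES = ("lm_head", "gate")


def plan_replacements(modules, use_activation_scale):
    """Return {module name: quantized class name} for every module that would be swapped."""
    plan = {}
    for name, module in modules.items():
        if name.split(".")[-1] in EXCLUDED_LEAF_NAMES:
            continue  # final FC and MoE gate stay in the model dtype
        if isinstance(module, Linear):
            plan[name] = (
                "BlockScaleQuantizeLinearStaticActivation"
                if use_activation_scale
                else "BlockScaleQuantizeLinear"
            )
        elif isinstance(module, MixtralExperts):
            plan[name] = "BlockScaleQuantizeMixtralExperts"
    return plan


modules = {
    "layers.0.mlp.up_proj": Linear(4096, 11008),
    "layers.0.moe.gate": Linear(4096, 8),
    "layers.0.moe.experts": MixtralExperts(8),
    "lm_head": Linear(4096, 32000),
}
print(plan_replacements(modules, use_activation_scale=False))
```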
Code Reference
Source Location
Signature
```python
@dataclass
class BlockScaleQuantize:
    name: str
    kind: str = "block-scale"
    weight_dtype: Literal["float8_e4m3fn", "float8_e5m2"] = "float8_e4m3fn"
    model_dtype: Literal["float16", "bfloat16"] = "bfloat16"
    quantize_linear: bool = True
    weight_block_size: Optional[Tuple[int, int]] = None
    use_activation_scale: bool = False

    def quantize_model(self, model: nn.Module, quant_map: QuantizeMapping, name_prefix: str) -> nn.Module: ...


class BlockScaleQuantizeLinear(nn.Module):
    def __init__(self, in_features, out_features, weight_dtype, block_size, bias=True, dtype=None, out_dtype=None): ...

    @staticmethod
    def from_linear(src: nn.Linear, config: BlockScaleQuantize, weight_block_size) -> "BlockScaleQuantizeLinear": ...

    def forward(self, x: nn.Tensor) -> nn.Tensor: ...


class BlockScaleQuantizeLinearStaticActivation(BlockScaleQuantizeLinear):
    @staticmethod
    def from_linear(src: nn.Linear, config: BlockScaleQuantize, weight_block_size) -> "BlockScaleQuantizeLinearStaticActivation": ...

    def forward(self, x: nn.Tensor) -> nn.Tensor: ...


class BlockScaleQuantizeMixtralExperts(nn.Module):
    def __init__(self, num_local_experts, in_features, out_features, weight_dtype, block_size): ...

    @staticmethod
    def from_mixtral_experts(src, config, weight_block_size) -> "BlockScaleQuantizeMixtralExperts": ...

    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor: ...


def rowwise_group_quant_fp8(x, group_size, dtype, transpose_scale, eps=1e-10, keep_first_batch_dim=False) -> Tuple[nn.Tensor, nn.Tensor]: ...
def static_activation_group_quant_fp8(x, activation_scale, group_size, dtype) -> nn.Tensor: ...
def broadcast_activation_scale(x, activation_scale, transpose) -> nn.Tensor: ...
def dequantize_float8_groupwise_scaled_gemv(x, w, w_scale, block_size, out_dtype) -> nn.Tensor: ...
```
Import
```python
from mlc_llm.quantization.block_scale_quantization import (
    BlockScaleQuantize,
    BlockScaleQuantizeLinear,
    BlockScaleQuantizeLinearStaticActivation,
    BlockScaleQuantizeMixtralExperts,
    rowwise_group_quant_fp8,
)
```
I/O Contract
BlockScaleQuantizeLinear.forward
| Parameter | Type | Description |
| --- | --- | --- |
| x | nn.Tensor | Input activation tensor, shape [..., in_features] |

| Return | Type | Description |
| --- | --- | --- |
| result | nn.Tensor | Output tensor, shape [..., out_features]; dtype is out_dtype if set, otherwise the input dtype |
BlockScaleQuantizeMixtralExperts.forward
| Parameter | Type | Description |
| --- | --- | --- |
| x | nn.Tensor | Input tensor [m, in_features] |
| indptr | nn.Tensor | Expert boundary pointers [num_experts + 1] (batched) or expert indices [1, k] (single token, k = experts selected per token) |

| Return | Type | Description |
| --- | --- | --- |
| result | nn.Tensor | Output tensor [m, out_features] |
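The batched indptr contract follows CSR-style grouping: the rows of x are already sorted by expert, and expert e owns rows indptr[e] to indptr[e + 1]. The NumPy sketch below illustrates only that routing. It assumes indptr[0] == 0 and indptr[-1] == m, uses plain float weights in place of the FP8 block-scaled ones, and does not model the single-token [1, k] index path. grouped_expert_matmul is a hypothetical name.

```python
import numpy as np


def grouped_expert_matmul(x, weights, indptr):
    """Batched MoE path: rows indptr[e]:indptr[e + 1] of x are multiplied by expert e.

    x:       [m, in_features], rows already grouped (sorted) by expert
    weights: [num_experts, out_features, in_features], dequantized to float here
    indptr:  [num_experts + 1], with indptr[0] == 0 and indptr[-1] == m
    """
    out = np.empty((x.shape[0], weights.shape[1]), dtype=x.dtype)
    for e in range(weights.shape[0]):
        lo, hi = indptr[e], indptr[e + 1]
        out[lo:hi] = x[lo:hi] @ weights[e].T  # empty slice when expert e has no rows
    return out


x = np.arange(12, dtype=np.float64).reshape(4, 3)
weights = np.stack([np.eye(3), 2.0 * np.eye(3)])  # expert 0: identity, expert 1: doubles
# row 0 goes to expert 0, rows 1 to 3 go to expert 1
out = grouped_expert_matmul(x, weights, np.array([0, 1, 4]))
```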
rowwise_group_quant_fp8
| Parameter | Type | Description |
| --- | --- | --- |
| x | nn.Tensor | Input tensor, shape [..., k] (ndim >= 2) |
| group_size | int | Group size for quantization along the last dimension |
| dtype | str | Target FP8 dtype ("float8_e4m3fn" or "float8_e5m2") |
| transpose_scale | bool | Whether to return the scale tensor transposed |
| eps | float | Epsilon for numerical stability (default 1e-10) |

| Return | Type | Description |
| --- | --- | --- |
| x_fp8 | nn.Tensor | Quantized tensor in FP8, same shape as x |
| x_scale | nn.Tensor | Per-group scale factors, shape [..., ceildiv(k, group_size)] (last two dimensions swapped when transpose_scale is set) |
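A NumPy reference for the contract above helps pin down the shapes. It assumes the common choice of scale = max(max|group|, eps) / FP8_max per (1, group_size) group, which the table does not confirm. It pads the last dimension with zeros to whole groups, omits the final FP8 cast, and does not model keep_first_batch_dim. rowwise_group_quant_ref is a hypothetical name, not the library function.

```python
import numpy as np

FP8_MAX = {"float8_e4m3fn": 448.0, "float8_e5m2": 57344.0}  # largest finite values


def rowwise_group_quant_ref(x, group_size, dtype, transpose_scale, eps=1e-10):
    """Per-(1, group_size) dynamic quantization of x [..., k] (FP8 cast omitted)."""
    *lead, k = x.shape
    g = -(-k // group_size)  # ceildiv(k, group_size)
    pad = g * group_size - k
    xp = np.pad(x, [(0, 0)] * len(lead) + [(0, pad)])  # zero-pad last dim to whole groups
    groups = xp.reshape(*lead, g, group_size)
    amax = np.maximum(np.abs(groups).max(axis=-1), eps)  # [..., g]
    scale = amax / FP8_MAX[dtype]  # each group's max maps exactly onto the FP8 limit
    q = np.clip(groups / scale[..., None], -FP8_MAX[dtype], FP8_MAX[dtype])
    q = q.reshape(*lead, g * group_size)[..., :k]  # drop padding; cast to dtype would go here
    if transpose_scale:
        scale = np.swapaxes(scale, -1, -2)
    return q, scale


x = np.array([[1.0, -4.0, 2.0, 0.5, 8.0]])
q, s = rowwise_group_quant_ref(x, 2, "float8_e4m3fn", transpose_scale=False)
print(s.shape)  # (1, 3)
```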
Backend Dispatch Logic
The forward pass dynamically selects between compute backends:
| Condition | Backend used |
| --- | --- |
| Single token (m = 1) | Dequantize-on-the-fly GEMV (TIR kernel) |
| m > 1, CUTLASS available and registered | CUTLASS FP8 group-scaled GEMM |
| m > 1, CUTLASS not available | Triton FP8 block-scale GEMM |
Usage Examples
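```python
from mlc_llm.loader import QuantizeMapping
from mlc_llm.quantization.block_scale_quantization import (
    BlockScaleQuantize,
    rowwise_group_quant_fp8,
)

# Define the quantization config: FP8 (e4m3) weights in 128x128 blocks, bf16 model
quant_config = BlockScaleQuantize(
    name="block_fp8",
    kind="block-scale-quant",
    weight_dtype="float8_e4m3fn",
    model_dtype="bfloat16",
    weight_block_size=(128, 128),
)

# Quantize the model (`model` is an already-constructed nn.Module)
quant_map = QuantizeMapping({}, {})
quantized_model = quant_config.quantize_model(model, quant_map, name_prefix="model")

# Dynamic activation quantization; call this inside a module's forward(),
# where x is an nn.Tensor of shape [m, k]
x_fp8, x_scale = rowwise_group_quant_fp8(
    x, group_size=128, dtype="float8_e4m3fn", transpose_scale=False
)
```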
Related Pages