Implementation:Mlc ai Mlc llm FP8 Quantization
Overview
The FP8 Quantization module implements FP8 per-tensor quantization specifically for Mixtral-style Mixture-of-Experts (MoE) layers in MLC LLM. Located at python/mlc_llm/quantization/fp8_quantization.py (125 lines), it extends the base per-tensor quantization infrastructure with FP8-specific logic for expert weight quantization, calibration, and efficient group GEMM execution.
Purpose
This module provides an FP8 specialization of the PerTensorQuantizeMixtralExperts class, handling:
- Conversion of non-quantized MixtralExperts layers to FP8-quantized counterparts
- Weight sharding propagation for tensor-parallel deployments
- Forward pass with calibration-mode-aware quantization and dequantization
- Integration with CUTLASS group GEMM kernels for hardware-accelerated MoE execution
Key Components
FP8PerTensorQuantizeMixtralExperts Class
This class extends ptq.PerTensorQuantizeMixtralExperts to provide FP8-specific behavior for Mixtral expert layers:
class FP8PerTensorQuantizeMixtralExperts(
    ptq.PerTensorQuantizeMixtralExperts
):
    def __init__(
        self, num_local_experts, in_features, out_features,
        config: ptq.PerTensorQuantize, name: str, tensor_parallel_shards=1,
    ):
        super().__init__(num_local_experts, in_features, out_features, config, name)
        self.tensor_parallel_shards = tensor_parallel_shards
from_mixtral_experts Static Method
Converts a non-quantized MixtralExperts module to its FP8-quantized form:
@staticmethod
def from_mixtral_experts(
    src: "MixtralExperts",
    config: ptq.PerTensorQuantize,
    name: str,
) -> "FP8PerTensorQuantizeMixtralExperts":
    quantized_mistral_experts = FP8PerTensorQuantizeMixtralExperts(
        num_local_experts=src.num_local_experts,
        in_features=src.in_features,
        out_features=src.out_features,
        config=config,
        name=name,
        tensor_parallel_shards=src.tensor_parallel_shards,
    )
    if "shard_strategy" in src.weight.attrs:
        shard = src.weight.attrs["shard_strategy"]
        apply_sharding(shard, f"{shard.name}_q_weight", quantized_mistral_experts.q_weight)
    return quantized_mistral_experts
When the source weight has a sharding strategy attribute, this method propagates that strategy to the quantized weight parameter using apply_sharding from the Quantization Utils module. The scale parameter does not need sharding since it is identical across all shards.
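The propagation step can be pictured with plain Python stand-ins. This is a minimal sketch, not the MLC LLM API: ShardStrategy, Param, and propagate_sharding are hypothetical names, and it assumes that propagating a strategy means attaching a renamed copy with the same dimension and segments.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence


@dataclass(frozen=True)
class ShardStrategy:
    """Hypothetical stand-in for a single-dimension shard strategy."""
    name: str
    dim: int
    segs: Optional[Sequence[int]] = None


@dataclass
class Param:
    """Hypothetical stand-in for a parameter that carries an attrs dict."""
    shape: tuple
    attrs: Dict[str, object] = field(default_factory=dict)


def propagate_sharding(src: Param, dst: Param, suffix: str = "_q_weight") -> None:
    """Attach a renamed copy of the source strategy to dst, if src has one."""
    if "shard_strategy" not in src.attrs:
        return  # unsharded source: the quantized weight stays unsharded
    shard = src.attrs["shard_strategy"]
    dst.attrs["shard_strategy"] = ShardStrategy(
        name=f"{shard.name}{suffix}", dim=shard.dim, segs=shard.segs
    )


weight = Param((8, 64, 32), {"shard_strategy": ShardStrategy("shard_mlp_up", dim=1)})
q_weight = Param((8, 64, 32))
q_scale = Param((1,))  # per-tensor scale: a single value, nothing to split

propagate_sharding(weight, q_weight)
print(q_weight.attrs["shard_strategy"].name)  # -> shard_mlp_up_q_weight
```

Only q_weight receives a strategy here; q_scale is left untouched, which mirrors the point that a per-tensor scale is the same on every shard.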
Forward Method
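The forward pass applies an optional calibration step and then dispatches to one of three kernels, chosen by input dimensionality and CUTLASS availability. The calibration step (Path 1) only rewrites x and falls through; each of the remaining paths returns the result.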
Path 1 -- Calibration Mode ("max"):
if self.config.calibration_mode == "max":
    _, x_scale = self.config.quantize_float8(
        x, quantize_dtype=self.config.activation_dtype,
        storage_dtype=self.config.activation_dtype,
    )
    if self.config.tensor_parallel_shards > 1:
        x_scale = nn.ccl_allreduce(x_scale, "max")
    x_scale = nn.extern(
        "mlc_llm.calibration_observer",
        [f"{self.name}.q_calibration_scale", "max", x_scale],
        out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype),
    )
    x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype)
    x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype)
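In calibration mode, the layer computes a per-tensor activation scale with quantize_float8, reduces it with an all-reduce (max) when more than one tensor-parallel shard is configured, and passes it to the external mlc_llm.calibration_observer function, which records it under {name}.q_calibration_scale. The last two lines then quantize x to FP8 with that scale and immediately dequantize it back to the model dtype, so the kernels that follow operate on activations that already carry the FP8 rounding error.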
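The calibration flow can be illustrated in pure Python. The sketch below rests on several assumptions that the text above does not confirm: the scale is taken to be max|x| divided by the largest finite e4m3fn value (448), the observer is modeled as a running maximum, and the all-reduce is modeled as a max over per-shard scales. round_to_e4m3, per_tensor_scale, MaxObserver, and calibrate_step are hypothetical helpers.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite e4m3fn value


def round_to_e4m3(v):
    """Round v to the nearest e4m3fn value, saturating at +/-448."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), FP8_E4M3_MAX)
    _, ex = math.frexp(a)      # a = m * 2**ex with 0.5 <= m < 1
    e = max(ex - 1, -6)        # below 2**-6 the format is subnormal
    step = 2.0 ** (e - 3)      # 3 mantissa bits: 8 steps per power of two
    return math.copysign(min(round(a / step) * step, FP8_E4M3_MAX), v)


def per_tensor_scale(x):
    """Scale that maps the largest magnitude in x onto the FP8 maximum."""
    max_abs = max(abs(v) for row in x for v in row)
    return max(max_abs, 1e-12) / FP8_E4M3_MAX  # guard against all-zero input


class MaxObserver:
    """Hypothetical observer: keeps the running max of each named scale."""

    def __init__(self):
        self.scales = {}

    def observe(self, name, scale):
        self.scales[name] = max(self.scales.get(name, 0.0), scale)
        return self.scales[name]


def calibrate_step(shard_inputs, name, observer):
    """One calibration pass over the activations held by each shard."""
    local_scales = [per_tensor_scale(x) for x in shard_inputs]
    global_scale = max(local_scales)  # what an all-reduce with "max" yields
    scale = observer.observe(name, global_scale)
    # Fake quantization: round in FP8, then return to the original range.
    fake_quantized = [
        [[round_to_e4m3(v / scale) * scale for v in row] for row in x]
        for x in shard_inputs
    ]
    return scale, fake_quantized


observer = MaxObserver()
batch_a = [[0.5, -2.0], [1.0, 0.25]]
batch_b = [[7.0, -0.5], [3.3, 0.0]]
s1, _ = calibrate_step([batch_a], "experts.q_calibration_scale", observer)
s2, fake = calibrate_step([batch_a, batch_b], "experts.q_calibration_scale", observer)
print(s2)              # -> 0.015625 (7.0 / 448)
print(fake[1][1][0])   # -> 3.25 (3.3 after FP8 rounding)
```

The value 3.3 comes back as 3.25, which is the kind of rounding error that fake quantization exposes to the layers that follow.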
Path 2 -- Single-batch GEMV (indptr.ndim == 2):
if indptr.ndim == 2:
    assert indptr.shape[0] == 1
    return moe_matmul.dequantize_float8_gemv(
        x, w, self.q_scale, indptr, self.config.weight_dtype
    )
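When indptr is two-dimensional with a leading dimension of 1 (single-batch inference), the layer skips group GEMM and returns the result of a specialized kernel that combines FP8 weight dequantization with GEMV.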
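A rough picture of what a dequantize-and-GEMV step computes, written with Python lists. This sketch makes assumptions about the layout that the text does not state: weights are indexed as w_q[expert][out][in], x is a single row, and the single row of indptr lists the experts selected for the token. dequantize_gemv is a hypothetical helper, not the MLC LLM kernel.

```python
def dequantize_gemv(x, w_q, w_scale, indptr):
    """Per selected expert: dequantize its weights and multiply by x."""
    assert len(indptr) == 1  # mirrors the single-batch assertion
    outputs = []
    for expert in indptr[0]:
        rows = w_q[expert]
        # (q * w_scale) is the dequantized weight; the dot product is the GEMV.
        outputs.append(
            [sum((q * w_scale) * xi for q, xi in zip(row, x)) for row in rows]
        )
    return outputs


x = [1.0, 2.0]                       # one token, in_features = 2
w_q = [
    [[2.0, 4.0], [-2.0, 0.0]],       # expert 0, out_features = 2
    [[1.0, 1.0], [8.0, -4.0]],       # expert 1
    [[0.0, 2.0], [6.0, 6.0]],        # expert 2
]
result = dequantize_gemv(x, w_q, w_scale=0.5, indptr=[[2, 0]])
print(result)  # -> [[2.0, 9.0], [5.0, -1.0]]
```

Only the weights of the selected experts are touched, which is why a GEMV-style kernel suits the single-token case.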
Path 3 -- Batched Group GEMM (CUTLASS or fallback):
if extern.get_store().cutlass_group_gemm:
    if self.config.calibration_mode == "inference":
        if self.q_calibration_scale is not None:
            x /= self.q_calibration_scale.astype(x.dtype)
    x_q = nn.op.astype(x, dtype=self.config.activation_dtype)
    x_scale = self.q_calibration_scale
    scale = (
        x_scale * self.q_scale
        if self.q_scale is not None
        else nn.wrap_nested(
            relax.Constant(runtime.tensor(np.array([1.0]).astype("float32"))),
            "scale",
        )
    )
    return cutlass.group_gemm(
        x_q, w, indptr, scale, self.config.weight_dtype, self.config.model_dtype
    )
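When CUTLASS group GEMM is available, the layer quantizes the activations itself: in inference mode, and when a calibrated scale exists, x is divided by q_calibration_scale before being cast to the FP8 activation dtype. The scale passed to the CUTLASS kernel is the product of the activation scale and the weight scale; if the layer has no weight scale (q_scale is None), a constant 1.0 is used instead.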
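Why a single combined scale is enough can be checked with a small worked example. The sketch below is an illustration under assumptions, not the CUTLASS kernel: to_e4m3 is a hypothetical stand-in for an FP8 cast, weights are laid out as w_q[out][in], and the matmul is written with plain lists.

```python
import math


def to_e4m3(v):
    """Hypothetical stand-in for a cast to FP8 e4m3fn (saturating at 448)."""
    if v == 0.0:
        return 0.0
    a = min(abs(v), 448.0)
    _, ex = math.frexp(a)
    step = 2.0 ** (max(ex - 1, -6) - 3)
    return math.copysign(min(round(a / step) * step, 448.0), v)


def fp8_scaled_matmul(x, w_q, x_scale, w_scale):
    """Quantize x, multiply in the quantized domain, rescale once at the end."""
    x_q = [[to_e4m3(v / x_scale) for v in row] for row in x]
    # Mirrors the branch in the text: without a weight scale, use 1.0.
    scale = x_scale * w_scale if w_scale is not None else 1.0
    return [
        [scale * sum(a * b for a, b in zip(x_row, w_row)) for w_row in w_q]
        for x_row in x_q
    ]


x = [[1.0, -3.0]]
w_q = [[2.0, 1.0], [-4.0, 0.5]]      # already-quantized weights
x_scale, w_scale = 0.25, 0.5

y = fp8_scaled_matmul(x, w_q, x_scale, w_scale)

# Reference: dequantize the weights first, then multiply in full precision.
w = [[q * w_scale for q in row] for row in w_q]
y_ref = [[sum(a * b for a, b in zip(x[0], w_row)) for w_row in w]]
print(y, y_ref)  # -> [[-0.5, -2.75]] [[-0.5, -2.75]]
```

The two results agree here because every value in the example is exactly representable in FP8; with arbitrary inputs they would differ by the quantization error.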
Fallback Path:
w = nn.tensor_expr_op(
    self.config.dequantize_float8,
    "dequantize",
    args=[w, self.q_scale, self.config.weight_dtype],
)
return moe_matmul.group_gemm(x, w, indptr)
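If CUTLASS group GEMM is not available, the FP8 weights are first dequantized with a TVM tensor expression (config.dequantize_float8), and the dequantized weights are passed to the standard moe_matmul.group_gemm kernel together with the unquantized activations.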
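The dequantize-then-group-GEMM sequence can be sketched with Python lists. The layout here is an assumption for illustration only: token rows are taken to be sorted by expert, and indptr is taken to hold cumulative row offsets with one more entry than there are experts. dequantize and group_gemm below are hypothetical helpers, not the MLC LLM kernels.

```python
def dequantize(w_q, scale):
    """Multiply every quantized weight by the per-tensor scale."""
    return [[[q * scale for q in row] for row in expert] for expert in w_q]


def group_gemm(x, w, indptr):
    """Rows indptr[e]:indptr[e + 1] of x are multiplied by expert e's weights."""
    out = []
    for expert in range(len(w)):
        for r in range(indptr[expert], indptr[expert + 1]):
            out.append(
                [sum(a * b for a, b in zip(x[r], w_row)) for w_row in w[expert]]
            )
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three token rows, sorted by expert
w_q = [[[2.0, 4.0]], [[6.0, 8.0]]]         # two experts, out_features = 1
indptr = [0, 2, 3]                         # expert 0: rows 0-1, expert 1: row 2

w = dequantize(w_q, scale=0.5)
y = group_gemm(x, w, indptr)
print(y)  # -> [[1.0], [2.0], [7.0]]
```

Compared with the CUTLASS path, this route materializes a full dequantized copy of the weights before the matmul, which is the cost of not depending on a target-specific kernel.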
Implementation Registration
The module registers itself as the FP8 implementation for MixtralExperts quantization:
ptq.PerTensorQuantizeMixtralExperts._IMPL["fp8"] = FP8PerTensorQuantizeMixtralExperts
This pattern allows the base PerTensorQuantizeMixtralExperts class to dispatch to the appropriate implementation based on the quantization dtype.
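The registration pattern is an ordinary class-level registry. The sketch below shows the idea with hypothetical classes (BaseExperts, FP8Experts) and a hypothetical dtype check; how the real base class picks the "fp8" key is not described in this page, so that part is an assumption.

```python
from typing import Dict, Type


class BaseExperts:
    """Hypothetical base class that owns the implementation registry."""

    _IMPL: Dict[str, Type["BaseExperts"]] = {}

    def __init__(self, name: str):
        self.name = name

    @staticmethod
    def create(weight_dtype: str, name: str) -> "BaseExperts":
        # Assumed rule: any float8 dtype maps to the "fp8" implementation.
        key = "fp8" if weight_dtype.startswith("float8") else weight_dtype
        if key not in BaseExperts._IMPL:
            raise NotImplementedError(f"no implementation for {weight_dtype}")
        return BaseExperts._IMPL[key](name)


class FP8Experts(BaseExperts):
    """Hypothetical FP8 specialization."""


# Registration happens at import time of the module defining the subclass,
# so the base class never has to import the specialization itself.
BaseExperts._IMPL["fp8"] = FP8Experts

layer = BaseExperts.create("float8_e4m3fn", "experts.0")
print(type(layer).__name__)  # -> FP8Experts
```

A consequence of this design is that the specialization is only reachable once its module has been imported; until then the registry lookup fails.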
Dependencies
- numpy -- For creating constant scale arrays
- tvm -- TVM compiler framework (relax, runtime)
- tvm.relax.frontend.nn -- Neural network module abstraction
- mlc_llm.nn.MixtralExperts -- Base MixtralExperts module
- mlc_llm.op.cutlass -- CUTLASS group GEMM kernel integration
- mlc_llm.op.extern -- External function store (for checking CUTLASS availability)
- mlc_llm.op.moe_matmul -- MoE matrix multiplication kernels
- mlc_llm.quantization.per_tensor_quantization -- Base per-tensor quantization classes
- mlc_llm.quantization.utils.apply_sharding -- Sharding strategy utility
File Location
python/mlc_llm/quantization/fp8_quantization.py