Implementation:Mlc ai Mlc llm Calibrate Interface
Overview
The Calibrate Interface module provides the Python entrypoint for calibrating quantized models in MLC LLM. Located at python/mlc_llm/interface/calibrate.py, this file implements a complete calibration pipeline that loads a dataset, samples representative requests, sends them through the model engine to collect activation statistics, and saves the resulting calibration parameters. The module contains four main components: the CalibrationObserver singleton class, the sample_requests function, the send_calibration_requests coroutine, and the calibrate entrypoint function.
Purpose
Post-training quantization can benefit significantly from calibration data that captures the typical activation ranges of a model. This module orchestrates the collection of such statistics by:
- Loading and filtering a conversation dataset.
- Sampling a specified number of prompts.
- Running those prompts through the AsyncMLCEngine.
- Collecting activation statistics via a TVM global callback function.
- Saving the calibration parameters to disk.
File Location
python/mlc_llm/interface/calibrate.py
Imports and Dependencies
import asyncio
import json
import random
from typing import List, Mapping, Optional, Tuple
import numpy as np
import tqdm.asyncio
import tvm
from tvm.contrib import tvmjs
from mlc_llm.serve.engine import AsyncMLCEngine, EngineConfig
from mlc_llm.tokenizers import Tokenizer
Key external dependencies:
- numpy -- Used for numerical reduction operations on calibration tensors.
- tqdm.asyncio -- Provides progress bars for asynchronous task gathering.
- tvm -- The TVM runtime, used for tensor operations and the global function callback mechanism.
- tvmjs -- Used to serialize calibration parameters to disk in a TVM-compatible format.
- AsyncMLCEngine -- The asynchronous MLC LLM serving engine.
CalibrationObserver Class
The CalibrationObserver is a singleton that intercepts calibration data during model inference via a TVM global function callback.
class CalibrationObserver:
    """A singleton class to observe the calibration parameters."""
    instance: "CalibrationObserver" = None
    params: Mapping[str, tvm.runtime.Tensor] = {}
    @staticmethod
    def get():
        """Get the singleton instance of the class."""
        if CalibrationObserver.instance is None:
            CalibrationObserver.instance = CalibrationObserver()
        return CalibrationObserver.instance
Callback Mechanism
The callback method is registered as a TVM global function under the name "mlc_llm.calibration_observer". The TVM runtime invokes this callback during model execution when calibration is active.
    @tvm.register_global_func("mlc_llm.calibration_observer")
    @staticmethod
    def callback(
        name: str,
        mode: str,
        value: "tvm.runtime.Tensor",
        out_value: "tvm.runtime.Tensor",
    ):
        """The callback function to update the saved calibration parameters."""
        instance = CalibrationObserver.get()
        if mode == "max":
            reducer = np.maximum
        else:
            raise NotImplementedError(f"Unsupported calibration mode: {mode}")
        if name in instance.params:
            instance.params[name] = reducer(instance.params[name], value.numpy())
        else:
            instance.params[name] = value.numpy()
        out_value.copyfrom(instance.params[name])
The callback operates as follows:
- Receives a parameter name (identifying which layer/weight is being calibrated), a mode (currently only "max" is supported), the current value tensor, and an out_value tensor to write results back into.
- Selects the appropriate reducer function based on the mode. For "max", it uses np.maximum for element-wise maximum.
- If the parameter has been seen before, it applies the reducer to accumulate the running statistic. Otherwise, it initializes the parameter with the current value.
- Copies the accumulated result back into out_value so the TVM runtime can use the updated calibration data.
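The accumulation step can be illustrated without TVM. The following is a minimal sketch that assumes plain NumPy arrays in place of TVM tensors; the dictionary running_stats and the helper observe_max are hypothetical names standing in for CalibrationObserver.params and the "max" branch of the callback.

```python
import numpy as np

# Hypothetical stand-in for CalibrationObserver.params:
# maps a parameter name to its running statistic.
running_stats = {}


def observe_max(name, value):
    """Fold one observed array into the running element-wise maximum for `name`."""
    if name in running_stats:
        # Seen before: keep the larger entry at every position.
        running_stats[name] = np.maximum(running_stats[name], value)
    else:
        # First observation: copy so later changes to `value` cannot leak in.
        running_stats[name] = value.copy()
    return running_stats[name]


# Two observations of the same (hypothetical) parameter name.
observe_max("layer0.scale", np.array([0.5, 2.0, 1.0], dtype="float32"))
result = observe_max("layer0.scale", np.array([1.5, 1.0, 1.0], dtype="float32"))
print(result)
```

Because the maximum is taken per element, the order in which requests reach the callback does not change the final statistic.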
Saving Parameters
    def save_params(self, output: str):
        """Save the calibration parameters to the given output directory."""
        tvmjs.dump_tensor_cache(
            self.params,
            output,
            encode_format="f32-to-bf16",
            meta_data=None,
            show_progress=False,
            update_if_exists=True,
        )
Parameters are serialized using tvmjs.dump_tensor_cache with f32-to-bf16 encoding, which converts 32-bit floating point values to bfloat16 format for storage efficiency.
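To make the storage trade-off concrete, the sketch below shows one way a float32 value can be reduced to a 16-bit bfloat16 pattern by keeping only its upper 16 bits. This is a simplified illustration that assumes plain truncation; the encoder used by tvmjs may round differently, and the helper names f32_to_bf16_bits and bf16_bits_to_f32 are hypothetical.

```python
import numpy as np


def f32_to_bf16_bits(values):
    """Keep the upper 16 bits of each float32 (truncation, an assumption of this sketch)."""
    bits = np.ascontiguousarray(values, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)


def bf16_bits_to_f32(bits16):
    """Expand stored 16-bit patterns back to float32 by zero-filling the low bits."""
    return (bits16.astype(np.uint32) << 16).view(np.float32)


original = np.array([1.0, 3.14159, -0.001], dtype=np.float32)
stored = f32_to_bf16_bits(original)    # half the bytes of the float32 input
restored = bf16_bits_to_f32(stored)    # close to, but not always equal to, the input
print(original.nbytes, stored.nbytes)
print(restored)
```

The stored array uses half the bytes of the input, at the cost of keeping only 7 explicit mantissa bits per value.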
sample_requests Function
This function loads a JSON dataset, filters it, tokenizes prompts and completions, and returns a sampled subset suitable for calibration.
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: Tokenizer,
) -> List[Tuple[str, int, int]]:
Dataset Filtering Pipeline
The filtering applies several criteria:
- Minimum conversation length: Only conversations with at least 2 turns are kept.
- First two turns: Only the first user prompt and first assistant completion are extracted.
- Tokenization: Both prompts and completions are batch-tokenized.
- Length filtering:
  - Prompts with fewer than 4 tokens or completions with fewer than 4 tokens are pruned.
  - Prompts longer than 1024 tokens are pruned.
  - Combined prompt + completion length exceeding 2048 tokens is pruned.
- Random sampling: random.sample() selects exactly num_requests entries from the filtered set.
The function returns a list of tuples (prompt_text, prompt_length, output_length).
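The filtering criteria above can be sketched in a few lines. This is not the module's code: it assumes each conversation is already a list of turn strings, and it replaces the real Tokenizer with a hypothetical count_tokens callable (here a whitespace splitter) so the example runs on its own.

```python
import random
from typing import Callable, List, Tuple


def filter_and_sample(
    conversations: List[List[str]],
    num_requests: int,
    count_tokens: Callable[[str], int],
) -> List[Tuple[str, int, int]]:
    """Apply the length filters described above, then sample without replacement."""
    kept = []
    for turns in conversations:
        if len(turns) < 2:
            continue  # need both a prompt and a completion
        prompt, completion = turns[0], turns[1]
        prompt_len = count_tokens(prompt)
        output_len = count_tokens(completion)
        if prompt_len < 4 or output_len < 4:
            continue  # too short to be representative
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            continue  # too long
        kept.append((prompt, prompt_len, output_len))
    return random.sample(kept, num_requests)


def whitespace_tokens(text: str) -> int:
    """Hypothetical token counter: one token per whitespace-separated word."""
    return len(text.split())


toy_dataset = [
    ["only one turn"],
    ["hi", "hello there my friend"],
    ["please summarise this short note", "here is a brief summary"],
    ["what is the capital of France", "the capital of France is Paris"],
]
random.seed(0)
sampled = filter_and_sample(toy_dataset, 2, whitespace_tokens)
print(sampled)
```

Note that random.sample raises ValueError when fewer than num_requests entries survive filtering, so the dataset must be large enough for the requested sample count.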
send_calibration_requests Function
This asynchronous function sends the sampled requests to the engine concurrently, respecting a semaphore-based concurrency limit.
async def send_calibration_requests(
    async_engine: AsyncMLCEngine,
    sampled_requests: List[Tuple[str, int, int]],
    max_concurrent_requests: int,
) -> None:
    tasks = []
    semaphore = asyncio.Semaphore(max_concurrent_requests)
    async def generate_task(request_idx):
        async with semaphore:
            prompt, _, output_len = sampled_requests[request_idx]
            await async_engine.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=output_len,
                request_id=str(request_idx),
            )
    for i in range(len(sampled_requests)):
        task = asyncio.create_task(generate_task(i))
        tasks.append(task)
    await tqdm.asyncio.tqdm.gather(*tasks)
Each request is dispatched as an asyncio task. The asyncio.Semaphore limits the number of concurrent in-flight requests to max_concurrent_requests. Progress is displayed via tqdm.asyncio.tqdm.gather.
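The effect of the semaphore can be checked in isolation. The sketch below is an assumption-laden stand-in: fake_request replaces the engine call with a short sleep, plain asyncio.gather replaces tqdm.asyncio.tqdm.gather, and run_with_limit is a hypothetical helper that reports the peak number of requests in flight.

```python
import asyncio


async def run_with_limit(num_requests: int, max_concurrent: int) -> int:
    """Run fake requests under a semaphore and return the peak number in flight."""
    semaphore = asyncio.Semaphore(max_concurrent)
    in_flight = 0
    peak = 0

    async def fake_request(idx: int) -> None:
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.01)  # stands in for the engine call
            in_flight -= 1

    # All tasks are created up front; the semaphore decides when each one proceeds.
    tasks = [asyncio.create_task(fake_request(i)) for i in range(num_requests)]
    await asyncio.gather(*tasks)
    return peak


peak_in_flight = asyncio.run(run_with_limit(num_requests=10, max_concurrent=3))
print(peak_in_flight)
```

Creating every task immediately keeps the scheduling logic simple: tasks beyond the limit wait at the semaphore rather than in a separate queue.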
calibrate Entrypoint
The top-level calibrate function ties all components together.
def calibrate(
    model: str,
    device: str,
    model_lib: Optional[str],
    dataset: str,
    output: str,
    num_calibration_samples: int,
    *,
    seed: int,
    max_num_sequence: Optional[int] = None,
    max_total_sequence_length: Optional[int] = None,
    prefill_chunk_size: Optional[int] = None,
    max_history_size: Optional[int] = None,
    gpu_memory_utilization: Optional[float] = None,
) -> None:
Execution Flow
- Seed the random number generator with the provided seed for reproducibility.
- Initialize an AsyncMLCEngine in "server" mode with the specified engine configuration parameters.
- Sample calibration requests from the dataset using sample_requests().
- Run the calibration requests asynchronously via send_calibration_requests(), using max_num_sequence (defaulting to 32) as the concurrency limit.
- Terminate the engine.
- Save accumulated calibration parameters via the CalibrationObserver singleton.
Relationship to Other Modules
- AsyncMLCEngine (mlc_llm.serve.engine) -- The serving engine used to run inference during calibration.
- Tokenizer (mlc_llm.tokenizers) -- Used via async_engine.tokenizer for batch tokenization of dataset prompts.
- CLI Help (mlc_llm.interface.help) -- The CLI help strings for calibration-related arguments (calibration_dataset, num_calibration_samples, output_calibration, seed_calibrate) are defined there.
- TVM Runtime -- The CalibrationObserver.callback is registered as a TVM global function and is invoked by the compiled model code during inference.