Implementation:Mlc ai Mlc llm Dispatch KV Cache Creation
Overview
DispatchKVCacheCreation is a TVM compiler pass that rewrites the generic KV cache creation function in the IRModule into target-specific implementations. It replaces a single create_paged_kv_cache function with up to two specialized creation functions: one based on TIR (always created) and one based on FlashInfer (created when applicable). This pass also attaches KV cache metadata to the model's metadata dictionary for use during inference.
File: python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
Architecture
The pass operates in three stages:
- Extract -- Parse the generic create_paged_kv_cache function to extract all configuration arguments
- Dispatch -- Create target-specific KV cache implementations (TIR-based and optionally FlashInfer-based)
- Attach metadata -- Store KV cache configuration in the model metadata dictionary
Function: extract_creation_args
This function parses the Relax function create_paged_kv_cache to extract all KV cache configuration parameters from the function's call to mlc.create_paged_kv_cache_generic.
def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
    assert isinstance(func.body, relax.SeqExpr)
    assert len(func.body.blocks) == 1
    # ... multiple assertions validating structure ...
    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
    args = call_args[1:]
    assert len(args) == 18
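The function enforces strict structural assumptions about the IR form and extracts 18 positional arguments. Because args[1] packs five values (args[1].values[0] through args[1].values[4]), these 18 arguments yield the 22 named parameters in the table under Extracted Parameters. The attn_kind argument supports two forms:
- A single relax.StringImm value: either "mha" or "mla"
- A relax.Tuple of per-layer attention types: each can be "mha", "mla", or "mha_sliding"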
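The two accepted forms of attn_kind can be validated with a small helper. The sketch below is only an illustration: parse_attn_kind is a hypothetical name, and plain Python str and tuple values stand in for relax.StringImm and relax.Tuple so that the example runs without TVM.

```python
from typing import List, Sequence, Union

# Allowed values, as described in the text above.
SINGLE_KINDS = {"mha", "mla"}
PER_LAYER_KINDS = {"mha", "mla", "mha_sliding"}


def parse_attn_kind(value: Union[str, Sequence[str]]) -> Union[str, List[str]]:
    """Validate attn_kind and return it as a string or a list of strings.

    Hypothetical helper: a str stands in for relax.StringImm and a
    tuple/list stands in for relax.Tuple.
    """
    if isinstance(value, str):
        # Single form: one attention kind shared by every layer.
        if value not in SINGLE_KINDS:
            raise ValueError(f"unsupported attn_kind: {value!r}")
        return value
    # Per-layer form: one attention kind for each layer.
    kinds = list(value)
    for kind in kinds:
        if kind not in PER_LAYER_KINDS:
            raise ValueError(f"unsupported per-layer attn_kind: {kind!r}")
    return kinds
```

Note that "mha_sliding" is accepted only inside the per-layer form, which matches the two bullet points above.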
Extracted Parameters
| Parameter | Source Argument | Description |
|---|---|---|
| attn_kind | args[0] | Attention type (string or list of strings) |
| max_batch_size | args[1].values[0] | Maximum batch size |
| max_total_seq_len | args[1].values[1] | Maximum total sequence length |
| prefill_chunk_size | args[1].values[2] | Size of prefill chunks |
| page_size | args[1].values[3] | KV cache page size |
| support_sliding_window | args[1].values[4] | Whether sliding window attention is supported |
| layer_partition | args[2] | Layer partition shape expression |
| num_hidden_layers | args[3] | Number of hidden layers |
| num_attention_heads | args[4] | Number of attention heads |
| num_key_value_heads | args[5] | Number of key/value heads (for GQA/MQA) |
| qk_head_dim | args[6] | Query/key head dimension |
| v_head_dim | args[7] | Value head dimension |
| mla_original_qk_head_dim | args[8] | Original QK head dim for MLA |
| mla_original_v_head_dim | args[9] | Original V head dim for MLA |
| rope_mode | args[10] | RoPE (Rotary Position Embedding) mode |
| rope_scale | args[11] | RoPE scaling factor |
| rope_theta | args[12] | RoPE theta parameter |
| rope_scaling | args[13] | RoPE scaling config (JSON string) |
| rope_ext_factors | args[14] | RoPE extension factors |
| rotary_dim | args[15] | Rotary embedding dimension |
| enable_disaggregation | args[16] | Whether disaggregation is enabled |
| dtype | args[17] | Data type for KV cache |
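The mapping in the table can be sketched as a plain-Python function that turns the 18 positional arguments into 22 named parameters. This is a hedged illustration, not the pass's code: name_creation_args is a hypothetical helper, args[1] is modeled as an ordinary list rather than an object with a values attribute, and no conversion of Relax expressions to Python values is attempted.

```python
from typing import Any, Dict, List

# Names of the five values packed into args[1], in table order.
SHAPE_FIELDS = [
    "max_batch_size", "max_total_seq_len", "prefill_chunk_size",
    "page_size", "support_sliding_window",
]
# Names of args[2] through args[17], in table order.
TAIL_FIELDS = [
    "layer_partition", "num_hidden_layers", "num_attention_heads",
    "num_key_value_heads", "qk_head_dim", "v_head_dim",
    "mla_original_qk_head_dim", "mla_original_v_head_dim",
    "rope_mode", "rope_scale", "rope_theta", "rope_scaling",
    "rope_ext_factors", "rotary_dim", "enable_disaggregation", "dtype",
]


def name_creation_args(args: List[Any]) -> Dict[str, Any]:
    """Map 18 positional arguments onto the 22 parameter names."""
    assert len(args) == 18
    assert len(args[1]) == len(SHAPE_FIELDS)
    kwargs: Dict[str, Any] = {"attn_kind": args[0]}
    kwargs.update(zip(SHAPE_FIELDS, args[1]))   # args[1][0] .. args[1][4]
    kwargs.update(zip(TAIL_FIELDS, args[2:]))   # args[2] .. args[17]
    return kwargs
```

Keeping the names in two ordered lists makes the correspondence with the table easy to check by eye: one name for args[0], five for args[1], and sixteen for the remaining arguments.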
Class: DispatchKVCacheCreation
@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
class DispatchKVCacheCreation:
    def __init__(self, target: tvm.target.Target, flashinfer: bool, metadata: Dict[str, Any]) -> None:
        self.target = target
        self.flashinfer = flashinfer
        self.metadata = metadata
Parameters
| Parameter | Type | Description |
|---|---|---|
| target | tvm.target.Target | Compilation target (e.g., CUDA, Metal) |
| flashinfer | bool | Whether FlashInfer is enabled |
| metadata | Dict[str, Any] | Model metadata dict (mutated by this pass to include KV cache info) |
transform_module
The main transformation entry point:
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    func_dict = {}
    creation_func = None
    # Separate the generic creation function from all other functions.
    for g_var, func in mod.functions_items():
        if g_var.name_hint == "create_paged_kv_cache":
            creation_func = func
        else:
            func_dict[g_var] = func
    # Nothing to dispatch: return the module unchanged.
    if creation_func is None:
        return mod
    # Rebuild the module without the generic function, keeping its attributes.
    new_mod = IRModule(func_dict)
    if mod.attrs is not None:
        new_mod = new_mod.with_attrs(mod.attrs)
    kwargs = extract_creation_args(creation_func)
    self.attach_kv_cache_metadata(kwargs)
    bb = relax.BlockBuilder(new_mod)
    extern_mods = []
    extern_mods += self.create_tir_paged_kv_cache(bb, kwargs)
    extern_mods += self.create_flashinfer_paged_kv_cache(bb, kwargs)
    mod = bb.finalize()
    # Read existing attributes so earlier external modules are preserved.
    mod_attrs = dict(mod.attrs) if mod.attrs else {}
    mod = mod.with_attr("external_mods", mod_attrs.get("external_mods", []) + extern_mods)
    return mod
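The pass first finds and removes the generic create_paged_kv_cache function, then adds the specialized implementations in its place. If the module contains no create_paged_kv_cache function, it is returned unchanged. Existing module attributes are carried over to the rewritten module.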
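The find, remove, and replace flow can be illustrated without TVM by modeling the module as a dictionary from function names to function bodies. Everything in this sketch is a stand-in: rewrite_module, build_tir, and build_flashinfer_skipped are hypothetical names, and strings take the place of Relax functions.

```python
from typing import Callable, Dict, List

GENERIC_NAME = "create_paged_kv_cache"

# A builder takes the generic function and returns the new functions it adds.
Builder = Callable[[str], Dict[str, str]]


def rewrite_module(functions: Dict[str, str], builders: List[Builder]) -> Dict[str, str]:
    """Replace the generic creation function with specialized ones."""
    if GENERIC_NAME not in functions:
        return functions  # nothing to dispatch: return the input unchanged
    generic = functions[GENERIC_NAME]
    # Copy every function except the generic one into a new module.
    rewritten = {name: fn for name, fn in functions.items() if name != GENERIC_NAME}
    for build in builders:
        rewritten.update(build(generic))
    return rewritten


def build_tir(generic: str) -> Dict[str, str]:
    # Always contributes a function, mirroring the TIR-based path.
    return {"create_tir_paged_kv_cache": f"tir({generic})"}


def build_flashinfer_skipped(generic: str) -> Dict[str, str]:
    # Contributes nothing, mirroring a FlashInfer path that does not apply.
    return {}
```

Calling rewrite_module with both builders on a module that contains create_paged_kv_cache yields a new dictionary holding the untouched functions plus create_tir_paged_kv_cache; the input dictionary itself is not modified.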
attach_kv_cache_metadata
Writes a subset of the KV cache configuration into the model metadata under the "kv_cache" key. The head_dim entry is taken from qk_head_dim:
def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
    self.metadata["kv_cache"] = {
        "num_hidden_layers": kwargs["num_hidden_layers"],
        "num_attention_heads": kwargs["num_attention_heads"],
        "num_key_value_heads": kwargs["num_key_value_heads"],
        "head_dim": kwargs["qk_head_dim"],
    }
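Because the pass stores the metadata dictionary it was given rather than a copy, the caller can read the "kv_cache" entry from its own dictionary after the pass has run. The sketch below shows this by-reference behavior in isolation; MetadataRecorder is a hypothetical stand-in for the pass, and the sample values and the model_type entry are made up for illustration.

```python
from typing import Any, Dict


class MetadataRecorder:
    """Hypothetical stand-in for the pass; holds the caller's dict by reference."""

    def __init__(self, metadata: Dict[str, Any]) -> None:
        self.metadata = metadata  # same object as the caller's, not a copy

    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]) -> None:
        # Only four of the extracted parameters are recorded.
        self.metadata["kv_cache"] = {
            "num_hidden_layers": kwargs["num_hidden_layers"],
            "num_attention_heads": kwargs["num_attention_heads"],
            "num_key_value_heads": kwargs["num_key_value_heads"],
            "head_dim": kwargs["qk_head_dim"],
        }


metadata = {"model_type": "example"}  # made-up pre-existing entry
recorder = MetadataRecorder(metadata)
recorder.attach_kv_cache_metadata(
    {
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "qk_head_dim": 128,
        "page_size": 16,  # extracted but not written to the metadata
    }
)
```

After the call, metadata contains both its original entry and the new "kv_cache" entry, while parameters such as page_size are left out.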
create_tir_paged_kv_cache
Creates the TIR-based paged KV cache implementation. This is always generated regardless of target. The function creates a Relax function named create_tir_paged_kv_cache with five parameters (max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window) and delegates to kv_cache.TIRPagedKVCache.
def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]):
    # ... parameter setup ...
    with bb.function(name="create_tir_paged_kv_cache", params=[...]):
        cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
        bb.emit_func_output(cache._expr)
    return cache.extern_mods
create_flashinfer_paged_kv_cache
Creates the FlashInfer-based paged KV cache implementation. This is only generated when all of the following conditions are met:
- flashinfer is True
- The target is CUDA
- The dtype is float16 or bfloat16
- When using inline RoPE mode, rotary_dim must equal qk_head_dim and qk_head_dim must equal v_head_dim
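The conditions above can be read as a single eligibility predicate. The following sketch restates them in plain Python under a few assumptions: flashinfer_eligible is a hypothetical helper, the target is represented only by its kind name, and the string "inline" stands in for RopeMode.INLINE, whose real definition lives in TVM.

```python
from typing import Any, Dict

INLINE = "inline"  # stand-in for RopeMode.INLINE (assumption for this sketch)


def flashinfer_eligible(flashinfer: bool, target_kind: str, kwargs: Dict[str, Any]) -> bool:
    """Return True only when every listed condition holds."""
    if not flashinfer:
        return False
    if target_kind != "cuda":
        return False
    if str(kwargs["dtype"]) not in ("float16", "bfloat16"):
        return False
    if kwargs["rope_mode"] == INLINE:
        # Inline RoPE needs all three dimensions to agree.
        if kwargs["rotary_dim"] != kwargs["qk_head_dim"]:
            return False
        if kwargs["qk_head_dim"] != kwargs["v_head_dim"]:
            return False
    return True
```

Writing the checks as early returns keeps each condition on its own line, so a failing case is easy to trace to the rule that rejected it. The dimension checks apply only in inline RoPE mode; in other modes a mismatch does not disqualify FlashInfer.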
def create_flashinfer_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]):
    if (
        not self.flashinfer
        or self.target.kind.name != "cuda"
        or str(kwargs["dtype"]) not in ["float16", "bfloat16"]
        or (kwargs["rope_mode"] == RopeMode.INLINE and (...))
    ):
        return []
    try:
        with bb.function(name="create_flashinfer_paged_kv_cache", params=[...]):
            cache = kv_cache.FlashInferPagedKVCache(target=self.target, **kwargs)
            bb.emit_func_output(cache._expr)
    except Exception as e:
        logger.info("Error caught when creating FlashInfer PagedKVCache: %s\n"
                    "The model will fallback to TIR-based KV cache.", e)
        return []
    return cache.extern_mods
The FlashInfer creation is wrapped in a try/except block. If FlashInfer cache creation fails for any reason, the error is logged and the model falls back to the TIR-based cache.
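The fallback behavior follows a common pattern: a required builder always runs, while an optional builder is allowed to fail and then contributes nothing. A minimal sketch of that pattern, with hypothetical names (build_optional, collect_extern_mods) and strings in place of real external modules:

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def build_optional(build: Callable[[], List[str]]) -> List[str]:
    """Run an optional builder; on any failure, log it and return nothing."""
    try:
        return build()
    except Exception as err:  # broad on purpose: any failure means "skip"
        logger.info("Optional backend failed: %s. Falling back.", err)
        return []


def collect_extern_mods(
    required: Callable[[], List[str]], optional: Callable[[], List[str]]
) -> List[str]:
    """Combine modules from the required builder and the optional one."""
    return required() + build_optional(optional)
```

Returning an empty list on failure means the caller can concatenate results unconditionally, which is the same shape as the two += lines in transform_module.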
External Modules
Both cache creation methods return a list of tvm.runtime.Module external modules. These are accumulated and attached to the final IRModule via the "external_mods" attribute, allowing the runtime to load pre-compiled kernels needed by the cache implementations.
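The accumulation step amounts to appending the new modules to whatever list is already stored under "external_mods", treating a missing attribute as an empty list. The sketch below models module attributes as a plain dictionary; merge_external_mods is a hypothetical helper and the string entries stand in for tvm.runtime.Module objects.

```python
from typing import Any, Dict, List, Optional


def merge_external_mods(
    attrs: Optional[Dict[str, Any]], new_mods: List[Any]
) -> Dict[str, Any]:
    """Return a copy of attrs with new_mods appended to "external_mods"."""
    merged = dict(attrs) if attrs else {}
    existing = list(merged.get("external_mods", []))
    # Existing modules come first, newly created ones are appended.
    merged["external_mods"] = existing + list(new_mods)
    return merged
```

Building a new dictionary instead of mutating the input mirrors the way with_attr returns a new module rather than changing the one it is called on.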
Dependencies
- tvm -- Core TVM framework
- tvm.relax.frontend.nn.llm.kv_cache -- KV cache implementations (TIRPagedKVCache, FlashInferPagedKVCache, RopeMode)
- mlc_llm.support.logging -- MLC-LLM logging utilities