Implementation:Mlc ai Mlc llm Dispatch KV Cache Creation

From Leeroopedia


Overview

DispatchKVCacheCreation is a TVM compiler pass that rewrites the generic KV cache creation function in the IRModule into target-specific implementations. It replaces a single create_paged_kv_cache function with up to two specialized creation functions: one based on TIR (always created) and one based on FlashInfer (created when applicable). This pass also attaches KV cache metadata to the model's metadata dictionary for use during inference.

File: python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py

Architecture

The pass operates in three stages:

  1. Extract -- Parse the generic create_paged_kv_cache function to extract all configuration arguments
  2. Dispatch -- Create target-specific KV cache implementations (TIR-based and optionally FlashInfer-based)
  3. Attach metadata -- Store KV cache configuration in the model metadata dictionary

Function: extract_creation_args

This function parses the Relax function create_paged_kv_cache to extract all KV cache configuration parameters from the function's call to mlc.create_paged_kv_cache_generic.

def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
    assert isinstance(func.body, relax.SeqExpr)
    assert len(func.body.blocks) == 1
    # ... multiple assertions validating structure ...
    assert call_args[0].global_symbol == "mlc.create_paged_kv_cache_generic"
    args = call_args[1:]
    assert len(args) == 18

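The function enforces strict structural assumptions about the IR form and fails with an assertion error if the function does not match them. It extracts 18 call arguments, which expand to the 22 parameters listed below because args[1] packs five values. The attn_kind argument supports two forms: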

  • A single relax.StringImm value: either "mha" or "mla"
  • A relax.Tuple of per-layer attention types: each can be "mha", "mla", or "mha_sliding"
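The two accepted forms can be illustrated without TVM. The sketch below uses hypothetical stand-in classes (StringImm, Tuple) in place of relax.StringImm and relax.Tuple; the helper name parse_attn_kind and its validation logic are assumptions made for illustration, not the code in the source file.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class StringImm:
    """Hypothetical stand-in for relax.StringImm."""
    value: str


@dataclass
class Tuple:
    """Hypothetical stand-in for relax.Tuple."""
    fields: List[StringImm]


def parse_attn_kind(arg: Union[StringImm, Tuple]) -> Union[str, List[str]]:
    """Return a string for the single form, or a list for the per-layer form."""
    if isinstance(arg, StringImm):
        # Single form: one attention type shared by all layers.
        if arg.value not in ("mha", "mla"):
            raise ValueError(f"unsupported attn_kind: {arg.value!r}")
        return arg.value
    # Per-layer form: one attention type per layer.
    kinds = [field.value for field in arg.fields]
    for kind in kinds:
        if kind not in ("mha", "mla", "mha_sliding"):
            raise ValueError(f"unsupported per-layer attn_kind: {kind!r}")
    return kinds
```

Note that "mha_sliding" is accepted only in the per-layer form, which matches the two bullet points above.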

Extracted Parameters

Parameter                 Source Argument    Description
attn_kind                 args[0]            Attention type (string or list of strings)
max_batch_size            args[1].values[0]  Maximum batch size
max_total_seq_len         args[1].values[1]  Maximum total sequence length
prefill_chunk_size        args[1].values[2]  Size of prefill chunks
page_size                 args[1].values[3]  KV cache page size
support_sliding_window    args[1].values[4]  Whether sliding window attention is supported
layer_partition           args[2]            Layer partition shape expression
num_hidden_layers         args[3]            Number of hidden layers
num_attention_heads       args[4]            Number of attention heads
num_key_value_heads       args[5]            Number of key/value heads (for GQA/MQA)
qk_head_dim               args[6]            Query/key head dimension
v_head_dim                args[7]            Value head dimension
mla_original_qk_head_dim  args[8]            Original QK head dim for MLA
mla_original_v_head_dim   args[9]            Original V head dim for MLA
rope_mode                 args[10]           RoPE (Rotary Position Embedding) mode
rope_scale                args[11]           RoPE scaling factor
rope_theta                args[12]           RoPE theta parameter
rope_scaling              args[13]           RoPE scaling config (JSON string)
rope_ext_factors          args[14]           RoPE extension factors
rotary_dim                args[15]           Rotary embedding dimension
enable_disaggregation     args[16]           Whether disaggregation is enabled
dtype                     args[17]           Data type for KV cache
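The positional layout in the table can be sketched as a plain mapping from 18 arguments to 22 named parameters. This is a simplified illustration that assumes the arguments have already been converted to ordinary Python values; the helper name args_to_kwargs and the two field lists are hypothetical and only mirror the table order.

```python
from typing import Any, Dict, List

# Names for args[1].values[0..4], in the order listed in the table.
SHAPE_FIELDS = [
    "max_batch_size",
    "max_total_seq_len",
    "prefill_chunk_size",
    "page_size",
    "support_sliding_window",
]

# Names for args[2..17], in the order listed in the table.
SCALAR_FIELDS = [
    "layer_partition", "num_hidden_layers", "num_attention_heads",
    "num_key_value_heads", "qk_head_dim", "v_head_dim",
    "mla_original_qk_head_dim", "mla_original_v_head_dim",
    "rope_mode", "rope_scale", "rope_theta", "rope_scaling",
    "rope_ext_factors", "rotary_dim", "enable_disaggregation", "dtype",
]


def args_to_kwargs(args: List[Any]) -> Dict[str, Any]:
    """Map 18 positional arguments onto the 22 named parameters."""
    if len(args) != 18:
        raise ValueError(f"expected 18 arguments, got {len(args)}")
    kwargs: Dict[str, Any] = {"attn_kind": args[0]}
    kwargs.update(zip(SHAPE_FIELDS, args[1]))    # args[1] carries five values
    kwargs.update(zip(SCALAR_FIELDS, args[2:]))  # args[2] through args[17]
    return kwargs
```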

Class: DispatchKVCacheCreation

@tvm.transform.module_pass(opt_level=0, name="DispatchKVCacheCreation")
class DispatchKVCacheCreation:
    def __init__(self, target: tvm.target.Target, flashinfer: bool, metadata: Dict[str, Any]) -> None:
        self.target = target
        self.flashinfer = flashinfer
        self.metadata = metadata

Parameters

Parameter   Type               Description
target      tvm.target.Target  Compilation target (e.g., CUDA, Metal)
flashinfer  bool               Whether FlashInfer is enabled
metadata    Dict[str, Any]     Model metadata dict (mutated by this pass to include KV cache info)

transform_module

The main transformation entry point:

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    func_dict = {}
    creation_func = None
    for g_var, func in mod.functions_items():
        if g_var.name_hint == "create_paged_kv_cache":
            creation_func = func
        else:
            func_dict[g_var] = func

    if creation_func is None:
        return mod

    new_mod = IRModule(func_dict)
    if mod.attrs is not None:
        new_mod = new_mod.with_attrs(mod.attrs)

    kwargs = extract_creation_args(creation_func)
    self.attach_kv_cache_metadata(kwargs)

    bb = relax.BlockBuilder(new_mod)
    extern_mods = []
    extern_mods += self.create_tir_paged_kv_cache(bb, kwargs)
    extern_mods += self.create_flashinfer_paged_kv_cache(bb, kwargs)

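    mod = bb.finalize()
    # Keep any external modules that were already attached to the module.
    mod_attrs = dict(mod.attrs) if mod.attrs else {}
    mod = mod.with_attr("external_mods", mod_attrs.get("external_mods", []) + extern_mods)
    return mod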

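The pass first finds and removes the generic create_paged_kv_cache function, then replaces it with the specialized implementations. If the module contains no function with that name, the pass returns the module unchanged. Module attributes are copied to the rebuilt module, and newly produced external modules are appended to any that were already attached.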
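The find-and-remove step can be shown in isolation. In the sketch below a plain dict keyed by function name stands in for the IRModule (the real pass iterates over global variables and compares their name_hint); the helper name split_creation_func is hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def split_creation_func(
    functions: Dict[str, Any], target_name: str = "create_paged_kv_cache"
) -> Tuple[Optional[Any], Dict[str, Any]]:
    """Separate the generic creation function from all other functions."""
    creation_func = None
    remaining: Dict[str, Any] = {}
    for name, func in functions.items():
        if name == target_name:
            creation_func = func       # held aside, not copied to the new module
        else:
            remaining[name] = func     # carried over unchanged
    return creation_func, remaining
```

When the first returned value is None there is nothing to dispatch, which corresponds to the early return in transform_module.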

attach_kv_cache_metadata

Writes a subset of the KV cache configuration into the model metadata:

def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
    self.metadata["kv_cache"] = {
        "num_hidden_layers": kwargs["num_hidden_layers"],
        "num_attention_heads": kwargs["num_attention_heads"],
        "num_key_value_heads": kwargs["num_key_value_heads"],
        "head_dim": kwargs["qk_head_dim"],
    }
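Because the pass stores the caller's dictionary by reference, the caller sees the new "kv_cache" entry once the pass has run. The sketch below demonstrates that behavior with a hypothetical stand-in class (MetadataRecorder) and made-up configuration values; it does not use TVM.

```python
from typing import Any, Dict


class MetadataRecorder:
    """Hypothetical stand-in holding a reference to a caller-owned dict."""

    def __init__(self, metadata: Dict[str, Any]) -> None:
        self.metadata = metadata  # stored by reference, not copied

    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]) -> None:
        self.metadata["kv_cache"] = {
            "num_hidden_layers": kwargs["num_hidden_layers"],
            "num_attention_heads": kwargs["num_attention_heads"],
            "num_key_value_heads": kwargs["num_key_value_heads"],
            "head_dim": kwargs["qk_head_dim"],
        }


metadata: Dict[str, Any] = {"model_type": "example"}  # illustrative values
recorder = MetadataRecorder(metadata)
recorder.attach_kv_cache_metadata({
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "qk_head_dim": 128,
    "v_head_dim": 128,  # present in kwargs but not recorded
})
```

Only qk_head_dim is recorded, under the key "head_dim"; v_head_dim and the remaining parameters are not written to the metadata.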

create_tir_paged_kv_cache

Creates the TIR-based paged KV cache implementation. This is always generated regardless of target. The function creates a Relax function named create_tir_paged_kv_cache with five parameters (max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window) and delegates to kv_cache.TIRPagedKVCache.

def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]):
    # ... parameter setup ...
    with bb.function(name="create_tir_paged_kv_cache", params=[...]):
        cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)
        bb.emit_func_output(cache._expr)
    return cache.extern_mods

create_flashinfer_paged_kv_cache

Creates the FlashInfer-based paged KV cache implementation. This is only generated when all of the following conditions are met:

  • flashinfer is True
  • The target is CUDA
  • The dtype is float16 or bfloat16
  • When using inline RoPE mode, rotary_dim must equal qk_head_dim and qk_head_dim must equal v_head_dim

def create_flashinfer_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]):
    if (
        not self.flashinfer
        or self.target.kind.name != "cuda"
        or str(kwargs["dtype"]) not in ["float16", "bfloat16"]
        or (kwargs["rope_mode"] == RopeMode.INLINE and (...))
    ):
        return []

    try:
        with bb.function(name="create_flashinfer_paged_kv_cache", params=[...]):
            cache = kv_cache.FlashInferPagedKVCache(target=self.target, **kwargs)
            bb.emit_func_output(cache._expr)
    except Exception as e:
        logger.info("Error caught when creating FlashInfer PagedKVCache: %s\n"
                     "The model will fallback to TIR-based KV cache.", e)
        return []

    return cache.extern_mods

The FlashInfer creation is wrapped in a try/except block. If FlashInfer cache creation fails for any reason, the error is logged and the model falls back to the TIR-based cache.
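The eligibility conditions listed above can be written as a single predicate. The sketch below is a TVM-free approximation built from that list: the RopeMode enum is a hypothetical stand-in, and the function name flashinfer_eligible and its flat argument list are assumptions for illustration.

```python
from enum import IntEnum


class RopeMode(IntEnum):
    """Hypothetical stand-in; only INLINE matters for this check."""
    NONE = 0
    NORMAL = 1
    INLINE = 2


def flashinfer_eligible(
    flashinfer: bool,
    target_kind: str,
    dtype: str,
    rope_mode: RopeMode,
    rotary_dim: int,
    qk_head_dim: int,
    v_head_dim: int,
) -> bool:
    """Return True only when every listed condition holds."""
    if not flashinfer:
        return False
    if target_kind != "cuda":
        return False
    if dtype not in ("float16", "bfloat16"):
        return False
    if rope_mode == RopeMode.INLINE:
        # Inline RoPE requires all three dimensions to agree.
        if rotary_dim != qk_head_dim or qk_head_dim != v_head_dim:
            return False
    return True
```

For example, a CUDA target with float16 and inline RoPE is rejected when rotary_dim is 64 and qk_head_dim is 128, but accepted when all three dimensions are 128.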

External Modules

Both cache creation methods return a list of tvm.runtime.Module external modules. These are accumulated and attached to the final IRModule via the "external_mods" attribute, allowing the runtime to load pre-compiled kernels needed by the cache implementations.
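The accumulation pattern, combined with the fallback behavior of the optional builder, can be sketched as follows. Strings stand in for tvm.runtime.Module objects, and the names build_with_fallback, tir_builder, and broken_flashinfer_builder are hypothetical.

```python
import logging
from typing import Callable, List

logger = logging.getLogger(__name__)


def build_with_fallback(build: Callable[[], List[str]]) -> List[str]:
    """Run an optional builder; contribute no modules if it fails."""
    try:
        return build()
    except Exception as err:  # broad on purpose, mirroring the pass
        logger.info("Optional builder failed, falling back: %s", err)
        return []


def tir_builder() -> List[str]:
    return ["tir_kernels"]  # placeholder for real external modules


def broken_flashinfer_builder() -> List[str]:
    raise RuntimeError("FlashInfer unavailable")  # simulated failure


extern_mods: List[str] = []
extern_mods += tir_builder()                                     # always contributes
extern_mods += build_with_fallback(broken_flashinfer_builder)    # contributes [] here
```

Returning an empty list on failure keeps the accumulation code free of special cases: the caller always concatenates lists, whether or not the optional implementation was built.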

Dependencies

  • tvm -- Core TVM framework
  • tvm.relax.frontend.nn.llm.kv_cache -- KV cache implementations (TIRPagedKVCache, FlashInferPagedKVCache, RopeMode)
  • mlc_llm.support.logging -- MLC-LLM logging utilities
