Implementation:Mlc ai Mlc llm Model Runtime
| Knowledge Sources | |
|---|---|
| Domains | LLM Serving, Model Inference, KV Cache Management, GPU Computing |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Model Runtime is the core inference module of the MLC LLM serving engine. It provides token and image embedding, prefill, decode, speculative verification, KV cache management, and disaggregated inference operations.
Description
The model.cc file implements the ModelImpl class, which is the central model runtime for the MLC LLM serving engine. It resides in the mlc::llm::serve namespace and provides the complete interface for all model-related operations.
Initialization: The constructor loads the model configuration from JSON (extracting sliding_window_size, attention_sink_size, vocab_size, and model_type), initializes the function table (FunctionTable) from the compiled model library, stores tensor parallelism and pipeline parallelism parameters, and performs a reset. The static Model::Create factory method and Model::LoadModelConfig utility are provided for construction and configuration loading.
Model Computation includes several key methods:
- TokenEmbed: Converts token IDs to embeddings by copying token IDs to device memory and calling the embed function. Supports offset-based embedding into a pre-allocated destination tensor and handles sequence length padding when seqlen_padding_factor > 1.
- ImageEmbed: Computes image embeddings using the vision encoder. Calculates resize and crop dimensions based on the model type, then calls the image_embed function with the processed image tensor.
- BatchPrefill: Runs batched prefill on multiple sequences. It computes logit positions, handles sequence padding, begins the KV cache forward pass, and dispatches to either the single-batch or multi-batch prefill function. For pipeline parallelism, results are gathered from the last worker group. It also tracks whether to use the extend function for already-prefilled sequences.
- BatchDecode: Runs batched autoregressive decoding for one token per sequence. It reshapes embeddings to (num_sequence, 1, hidden_size) and dispatches to single-batch or multi-batch decode functions.
- BatchTreeDecode: Similar to BatchDecode but supports tree-structured token sequences with parent pointers for speculative decoding.
- BatchVerify: Runs verification of speculative draft tokens with token tree parent pointers.
- BatchPrefillToLastHidden and BatchDecodeToLastHidden: Variants that return hidden states instead of logits, used in multi-step speculative decoding architectures.
- GetLogits, GetMultiStepLogits, and GetMedusaLogits: Extract logits from hidden states for various decoding strategies.
- FuseEmbedHidden: Fuses embeddings with previous hidden states for architectures that require it.
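The bookkeeping that BatchPrefill performs before the forward pass can be illustrated with a minimal host-side sketch. It assumes that padding rounds the total token count up to a multiple of seqlen_padding_factor and that logits are needed only at the last token of each sequence. The helper names PadSequenceLength and LastTokenPositions are hypothetical and do not appear in model.cc.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: round `total_length` up to the next multiple of
// `padding_factor`. A factor of 1 (or less) leaves the length unchanged.
int64_t PadSequenceLength(int64_t total_length, int64_t padding_factor) {
  assert(total_length >= 0);
  if (padding_factor <= 1) return total_length;
  return (total_length + padding_factor - 1) / padding_factor * padding_factor;
}

// Hypothetical helper: for sequences concatenated along the token axis,
// return the index of the last token of each sequence. Under the stated
// assumption, these are the only positions whose logits prefill needs.
std::vector<int64_t> LastTokenPositions(const std::vector<int>& lengths) {
  std::vector<int64_t> positions;
  positions.reserve(lengths.size());
  int64_t total = 0;
  for (int length : lengths) {
    assert(length > 0);
    total += length;
    positions.push_back(total - 1);
  }
  return positions;
}
```

For example, three sequences of lengths 3, 5, and 2 occupy token indices 0 to 9, so their last-token positions are 2, 7, and 9, and a padding factor of 4 would round the total of 10 up to 12.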
KV Cache Management provides a complete lifecycle:
- CreateKVCache: Initializes PagedKVCache or RNNState based on the model's KV state kind, with configurable page size, max sequences, max total length, and prefill chunk size.
- AddNewSequence, ForkSequence, RemoveSequence: Manage sequence lifecycle in the KV cache.
- PopNFromKVCache: Removes trailing tokens from a sequence's KV data.
- CommitAcceptedTokenTreeNodesToKVCache: Commits accepted tree nodes after speculative verification.
- EnableSlidingWindowForSeq: Enables sliding window attention with attention sink for a specific sequence.
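The sequence lifecycle described above can be pictured with a toy model that tracks only token counts per sequence. This is a sketch under stated assumptions, not the PagedKVCache implementation: ToySequenceTable, Append, PopN, Contains, and Length are hypothetical names, no key/value tensors are stored, and fork_pos is assumed to lie between 0 and the parent's current length.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

// Toy stand-in for a KV cache: it tracks only how many tokens each sequence
// holds, not the key/value tensors themselves. All names are hypothetical.
class ToySequenceTable {
 public:
  // Register an empty sequence under a fresh id.
  void AddNewSequence(int64_t seq_id) {
    if (!lengths_.emplace(seq_id, 0).second) {
      throw std::invalid_argument("sequence id already in use");
    }
  }

  // Record that a forward pass appended `num_tokens` tokens.
  void Append(int64_t seq_id, int64_t num_tokens) {
    assert(num_tokens >= 0);
    lengths_.at(seq_id) += num_tokens;
  }

  // Create a child that shares the first `fork_pos` tokens of the parent.
  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos) {
    const int64_t parent_length = lengths_.at(parent_seq_id);
    if (fork_pos < 0 || fork_pos > parent_length) {
      throw std::out_of_range("fork_pos outside the parent sequence");
    }
    if (!lengths_.emplace(child_seq_id, fork_pos).second) {
      throw std::invalid_argument("sequence id already in use");
    }
  }

  // Drop the last `n` tokens, for example rejected draft tokens.
  void PopN(int64_t seq_id, int64_t n) {
    int64_t& length = lengths_.at(seq_id);
    if (n < 0 || n > length) {
      throw std::out_of_range("cannot pop more tokens than the sequence holds");
    }
    length -= n;
  }

  // Forget the sequence entirely.
  void RemoveSequence(int64_t seq_id) { lengths_.erase(seq_id); }

  bool Contains(int64_t seq_id) const { return lengths_.count(seq_id) != 0; }
  int64_t Length(int64_t seq_id) const { return lengths_.at(seq_id); }

 private:
  std::unordered_map<int64_t, int64_t> lengths_;
};
```

In this toy, removing a parent leaves its forked children intact, which mirrors the idea that each sequence id has an independent lifetime once created.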
Disaggregated Inference methods:
- DisaggPrepareKVRecv: Prepares KV cache pages for receiving data from a remote instance.
- DisaggMarkKVSend: Marks KV cache data for sending to a remote decode instance.
Utilities include methods for parameter loading, setting max sequences and prefill chunk sizes, creating LogitProcessor and Sampler instances, allocating embedding and hidden state tensors, and speculative decoding workspace operations (gather/scatter hidden states and draft probabilities).
The file also registers a global TVM function mlc.copy_embedding_to_offset for copying embedding data at specific offsets.
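The idea of copying embedding data at a specific offset can be sketched on host memory as follows. This is an illustration only: it assumes a row-major layout of shape (num_rows, hidden_size) and an offset counted in rows, and the function name CopyRowsToOffset is hypothetical. The registered TVM function operates on device tensors and its exact signature is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy every row of `src` into `dst`, starting at row
// `offset`. Both buffers are row-major with `hidden_size` values per row.
void CopyRowsToOffset(const std::vector<float>& src, std::vector<float>* dst,
                      std::size_t hidden_size, std::size_t offset) {
  assert(dst != nullptr);
  assert(hidden_size > 0);
  assert(src.size() % hidden_size == 0);
  assert(dst->size() % hidden_size == 0);
  // The copied rows must fit inside the destination buffer.
  assert(offset * hidden_size + src.size() <= dst->size());
  std::copy(src.begin(), src.end(),
            dst->begin() + static_cast<std::ptrdiff_t>(offset * hidden_size));
}
```

Writing into a pre-allocated destination at an offset lets several embedding calls fill one shared buffer without reallocating it between calls.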
Usage
Use Model Runtime as the central inference component in the MLC LLM serving engine. It is created via Model::Create and used by engine actions (prefill, decode, verify, disaggregated operations) through its interface methods.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: cpp/serve/model.cc
- Lines: 1-1141
Signature
class ModelImpl : public ModelObj {
 public:
  explicit ModelImpl(String reload_lib_path, String model_path,
                     picojson::object model_config, DLDevice device,
                     const Optional<Session>& session, int num_shards,
                     int num_stages, bool trace_enabled);

  // Model Computation
  ObjectRef TokenEmbed(IntTuple token_ids, ObjectRef* dst, int offset) final;
  ObjectRef ImageEmbed(const Tensor& image, ObjectRef* dst, int offset) final;
  Tensor BatchPrefill(const ObjectRef& embeddings,
                      const std::vector<int64_t>& seq_ids,
                      const std::vector<int>& lengths) final;
  Tensor BatchDecode(const ObjectRef& embeddings,
                     const std::vector<int64_t>& seq_ids) final;
  Tensor BatchVerify(const ObjectRef& embeddings,
                     const std::vector<int64_t>& seq_ids,
                     const std::vector<int>& lengths,
                     const std::vector<int64_t>& token_tree_parent_ptr) final;

  // KV Cache Management
  void CreateKVCache(int page_size, int max_num_sequence,
                     int64_t max_total_sequence_length,
                     int64_t prefill_chunk_size, int max_history_size) final;
  void AddNewSequence(int64_t seq_id) final;
  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id,
                    int64_t fork_pos) final;
  void RemoveSequence(int64_t seq_id) final;

  // Disaggregated Inference
  IntTuple DisaggPrepareKVRecv(int64_t seq_id, int length) final;
  void DisaggMarkKVSend(int64_t seq_id, int begin_pos,
                        IntTuple compressed_kv_append_metadata,
                        int dst_group_offset) final;

  // Utilities
  LogitProcessor CreateLogitProcessor(int max_num_token,
                                      Optional<EventTraceRecorder> trace_recorder) final;
  Sampler CreateSampler(int max_num_sample, int num_models,
                        Optional<EventTraceRecorder> trace_recorder) final;
};

// Factory methods
static Model Model::Create(String reload_lib_path, String model_path,
                           const picojson::object& model_config,
                           DLDevice device, const Optional<Session>& session,
                           int num_shards, int num_stages, bool trace_enabled);
static Result<picojson::object> Model::LoadModelConfig(const String& model_path);
Import
#include "model.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| reload_lib_path | String | Yes | Path to the compiled model library |
| model_path | String | Yes | Path to the model directory containing config and parameters |
| model_config | picojson::object | Yes | Model configuration JSON object |
| device | DLDevice | Yes | Target device for inference (CPU, CUDA, ROCm, etc.) |
| session | Optional<Session> | No | Optional distributed session for tensor/pipeline parallelism |
| num_shards | int | Yes | Number of tensor parallel shards |
| num_stages | int | Yes | Number of pipeline parallel stages |
| token_ids | IntTuple | Yes (TokenEmbed) | Token IDs to embed |
| embeddings | ObjectRef | Yes (Prefill/Decode) | Embedding tensors for forward pass |
| seq_ids | std::vector<int64_t> | Yes (Prefill/Decode) | Sequence IDs for KV cache lookup |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | Tensor | Model output logits. The shape depends on the operation: (1, num_seq, v) for prefill and (b, 1, v) for decode, where v is the vocabulary size and b is the number of sequences in the batch |
| embeddings | ObjectRef | Token or image embeddings |
| hidden_states | ObjectRef | Hidden states from last layer (for ToLastHidden variants) |
| compressed_kv_append_metadata | IntTuple | KV cache page metadata for disaggregated transfer |
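To make the decode logits shape (b, 1, v) concrete, the following sketch treats the logits as a flat row-major host buffer with one row of v scores per sequence and picks the highest-scoring token in each row. It is an illustration under those assumptions: the function name ArgmaxPerSequence is hypothetical, the real logits live in a device Tensor, and the engine selects tokens through its Sampler rather than through this helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: `logits` holds batch_size rows of vocab_size scores,
// i.e. a (b, 1, v) tensor flattened in row-major order. Returns the index of
// the largest score in each row (the first one wins on ties).
std::vector<std::size_t> ArgmaxPerSequence(const std::vector<float>& logits,
                                           std::size_t batch_size,
                                           std::size_t vocab_size) {
  assert(vocab_size > 0);
  assert(logits.size() == batch_size * vocab_size);
  std::vector<std::size_t> best(batch_size, 0);
  for (std::size_t b = 0; b < batch_size; ++b) {
    const std::size_t row = b * vocab_size;  // start of this sequence's scores
    for (std::size_t t = 1; t < vocab_size; ++t) {
      if (logits[row + t] > logits[row + best[b]]) best[b] = t;
    }
  }
  return best;
}
```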
Usage Examples
// Create a model instance
Model model = Model::Create(
    lib_path, model_path, model_config, device,
    session, num_shards, num_stages, trace_enabled);
// Load parameters and configure
model->LoadParams();
model->SetMaxNumSequence(max_num_seq);
model->SetPrefillChunkSize(prefill_chunk_size);
model->CreateKVCache(page_size, max_num_seq, max_total_seq_len,
                     prefill_chunk_size, max_history_size);
// Allocate embedding workspace
ObjectRef embeddings = model->AllocEmbeddingTensor();
// Token embedding
ObjectRef embedded = model->TokenEmbed(token_ids, &embeddings, 0);
// Prefill
model->AddNewSequence(seq_id);
Tensor logits = model->BatchPrefill(embedded, {seq_id}, {length});
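// Decode: embed the newly sampled token first (next_token_id is a placeholder for the sampler's output)
ObjectRef decode_embedded = model->TokenEmbed({next_token_id}, &embeddings, 0);
Tensor decode_logits = model->BatchDecode(decode_embedded, {seq_id});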