Implementation:Mlc ai Mlc llm Eagle Batch Draft
Overview
File: cpp/serve/engine_actions/eagle_batch_draft.cc
Purpose: Implements the EagleBatchDraftActionObj engine action, which runs the draft proposal phase of EAGLE-style speculative decoding. In speculative decoding, a small "draft" model (SSM) proposes candidate tokens that the larger model (LLM) later verifies. This action generates multiple draft tokens per request in a batched fashion, using the second model in a two-model setup.
Namespace: mlc::llm::serve
Class: EagleBatchDraftActionObj
Inherits from EngineActionObj and implements the Step method for EAGLE-based draft token generation.
Constructor
explicit EagleBatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor,
                                  Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
                                  DraftTokenWorkspaceManager draft_token_workspace_manager,
                                  EngineConfig engine_config,
                                  Optional<EventTraceRecorder> trace_recorder)
Takes the full set of components needed for draft generation: both models (LLM + SSM), logit processor, sampler, per-model workspaces, draft token workspace manager, engine configuration, and an optional trace recorder.
Step Method
Array<Request> Step(EngineState estate) final;
The main execution method. The algorithm proceeds as follows:
- Guard conditions: Returns immediately if there are not exactly 2 models or if the running queue is empty. EAGLE speculative decoding requires exactly one LLM and one SSM.
- Preemption loop: While the draft model lacks enough free KV cache pages for all running request state entries (see CanDecode), preempts the last (lowest-priority) running request state entry by calling PreemptLastRunningRequestStateEntry with the draft token workspace manager.
- Request collection: Gathers request IDs, internal IDs, generation configs, and random number generators from all running request state entries.
- Draft generation loop: For each draft round, from draft ID 1 to spec_draft_length - 1 (the first draft token is already generated during prefill/verify), the action:
  - Collects the last draft output token from each request's model state as input.
  - Computes token embeddings via the draft model.
  - Fuses the embeddings with the hidden states using FuseEmbedHidden.
  - Runs batched decode to produce new hidden states.
  - Computes logits (using the draft model's head if available, otherwise falling back to the base model's head).
  - Applies logit processing, passing draft token indices so the processor knows each request's current draft position.
  - Computes probability distributions and renormalizes them by top-p.
  - Samples tokens from the renormalized probabilities.
  - Allocates draft token slots and scatters the draft probabilities into the workspace storage.
  - Adds each sampled draft token to the respective model state, recording its parent index.
  - Records per-round draft timing metrics.
- Overall timing: Accumulates total decode time in engine metrics.
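The control flow of the draft loop can be illustrated with a small host-side simulation. This is a minimal sketch, not the engine's code: DraftState, ToyDraftProbs, RenormalizeByTopP, and RunDraftRounds are hypothetical names, the "model" is a toy rule standing in for the embedding, FuseEmbedHidden, and batched-decode pipeline, and the sampler is replaced by a greedy pick so the result is deterministic. It shows the properties the steps above rely on: round 0 is skipped, each round makes one batched call for all requests, and each new token records its parent position.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical per-request draft state (the real RequestModelState holds far more).
struct DraftState {
  std::vector<int> draft_output_tokens;  // proposed tokens, oldest first
  std::vector<int64_t> parent_indices;   // parent position of each token (-1 = root)
};

// Toy "draft model": one batched call returns a distribution per request that
// puts 0.9 of the mass on (token + 1) % vocab. Requires vocab >= 2.
std::vector<std::vector<double>> ToyDraftProbs(const std::vector<int>& inputs, int vocab) {
  std::vector<std::vector<double>> probs(
      inputs.size(), std::vector<double>(vocab, 0.1 / (vocab - 1)));
  for (std::size_t i = 0; i < inputs.size(); ++i) probs[i][(inputs[i] + 1) % vocab] = 0.9;
  return probs;
}

// Keep the smallest set of highest-probability tokens whose mass reaches top_p,
// zero out the rest, and rescale the kept mass so it sums to 1.
std::vector<double> RenormalizeByTopP(std::vector<double> p, double top_p) {
  assert(top_p > 0.0 && top_p <= 1.0);
  std::vector<std::size_t> order(p.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::sort(order.begin(), order.end(),
            [&p](std::size_t a, std::size_t b) { return p[a] > p[b]; });
  double kept = 0.0;
  std::size_t num_kept = 0;
  while (num_kept < order.size() && kept < top_p) kept += p[order[num_kept++]];
  for (std::size_t k = num_kept; k < order.size(); ++k) p[order[k]] = 0.0;
  for (double& x : p) x /= kept;
  return p;
}

// Rounds 1 .. spec_draft_length - 1; the round-0 token is assumed to exist already
// (produced by prefill/verify). All requests share one batched call per round.
void RunDraftRounds(std::vector<DraftState>& states, int spec_draft_length, int vocab,
                    double top_p) {
  for (int draft_id = 1; draft_id < spec_draft_length; ++draft_id) {
    // Input: the last draft token of every request.
    std::vector<int> inputs;
    inputs.reserve(states.size());
    for (const DraftState& s : states) inputs.push_back(s.draft_output_tokens.back());
    std::vector<std::vector<double>> probs = ToyDraftProbs(inputs, vocab);
    for (std::size_t i = 0; i < states.size(); ++i) {
      std::vector<double> renorm = RenormalizeByTopP(probs[i], top_p);
      // Greedy pick for determinism; the real sampler draws from a per-request RNG.
      int token = static_cast<int>(std::max_element(renorm.begin(), renorm.end()) -
                                   renorm.begin());
      // Parent = position of the token this one was drafted from.
      int64_t parent = static_cast<int64_t>(states[i].draft_output_tokens.size()) - 1;
      states[i].draft_output_tokens.push_back(token);
      states[i].parent_indices.push_back(parent);
    }
  }
}
```

For two requests seeded with tokens 5 and 9 and spec_draft_length = 4, the toy model extends them to 5, 6, 7, 8 and 9, 0, 1, 2, each with parent indices -1, 0, 1, 2 (a simple chain). Because every round consumes the previous round's token, the outer loop stays sequential while the inner work is batched across requests.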
Key Implementation Details
Hidden State Gathering
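Before the draft loop starts, and only when more than one draft token is requested (spec_draft_length > 1), the action gathers each request's hidden state from the draft hidden states storage, using the slot of that request's most recent draft token:
draft_token_slots_.clear();  // reused member buffer; reset before filling
if (estate->spec_draft_length > 1) {
  for (int i = 0; i < num_rsentries; ++i) {
    // Slot of the most recent draft token of request i.
    draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());
  }
  hidden_states = models_[model_id]->GatherHiddenStates(
      model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);
}
This seeds the first round of this action with the hidden states that accompany each request's latest draft token; every later round then feeds the hidden states returned by the previous round's batched decode back into FuseEmbedHidden.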
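Conceptually, the gather is a row lookup: the storage holds one hidden-state row per draft token slot, and the batch input is built by picking the row at each request's latest slot. The sketch below is a hypothetical host-side analogue; Matrix, LastSlots, and GatherRows are illustrative names, and the real GatherHiddenStates operates on device tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side stand-in for draft_hidden_states_storage:
// one hidden-state row per draft token slot.
using Matrix = std::vector<std::vector<float>>;

// Each request contributes the slot of its most recent draft token,
// analogous to mstates[i]->draft_token_slots.back().
std::vector<int> LastSlots(const std::vector<std::vector<int>>& per_request_slots) {
  std::vector<int> slots;
  slots.reserve(per_request_slots.size());
  for (const std::vector<int>& s : per_request_slots) {
    assert(!s.empty());
    slots.push_back(s.back());
  }
  return slots;
}

// Copy the storage rows at the given slots into a [batch, hidden] matrix.
Matrix GatherRows(const Matrix& storage, const std::vector<int>& slots) {
  Matrix out;
  out.reserve(slots.size());
  for (int slot : slots) {
    assert(slot >= 0 && static_cast<std::size_t>(slot) < storage.size());
    out.push_back(storage[slot]);
  }
  return out;
}
```

Because rows are addressed by slot rather than by request position, requests can hold arbitrary, non-contiguous slots in the shared storage.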
Logit Processing with Draft Context
The logit processor also receives draft_token_indices, which records the draft position each request is currently at:
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids, nullptr,
                                      &mstates, &draft_token_indices);
Prefix Cache Overlap
The action commits the prefix cache changes left over from the previous round while GPU work is still executing, overlapping CPU-side bookkeeping with GPU computation:
estate->prefix_cache->CommitSequenceExtention();
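The overlap relies on GPU work being launched asynchronously, so the host can do bookkeeping before it blocks on a result. The sketch below imitates that pattern on the CPU with std::async; ToyPrefixCache, SimulatedDeviceWork, and StepWithOverlap are hypothetical stand-ins for the prefix cache and the GPU kernels, not MLC LLM APIs.

```cpp
#include <cassert>
#include <future>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical stand-in for kernels that run asynchronously on the GPU.
double SimulatedDeviceWork(std::vector<double> logits) {
  return std::accumulate(logits.begin(), logits.end(), 0.0);
}

// Hypothetical stand-in for the prefix cache: extensions recorded by the
// previous step stay pending until they are committed.
struct ToyPrefixCache {
  std::vector<int> pending;
  std::vector<int> committed;
  void CommitPending() {  // CPU-only bookkeeping
    committed.insert(committed.end(), pending.begin(), pending.end());
    pending.clear();
  }
};

double StepWithOverlap(ToyPrefixCache& cache, std::vector<double> logits) {
  // 1. "Launch" the device work without waiting for it to finish.
  std::future<double> device =
      std::async(std::launch::async, SimulatedDeviceWork, std::move(logits));
  // 2. Do the CPU-side commit while that work is in flight.
  cache.CommitPending();
  // 3. Block only when the result is actually needed (e.g. before sampling).
  assert(device.valid());
  return device.get();
}
```

On a real GPU no extra thread is needed, since kernel launches already return before the kernels finish; the thread here only emulates that behavior.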
Private Methods
CanDecode
bool CanDecode(int num_rsentries);
Checks whether every draft model (model index 1 and above) has enough available KV cache pages to serve all running request state entries. The base model (index 0) is excluded because it does not take part in draft proposal.
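A minimal sketch of the check, under an assumption made purely for illustration: each running entry needs at most one free page per draft model for the next step. CanDecodeSketch and its per-model page counts are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical: available_pages[m] is the number of free KV cache pages of
// model m, with index 0 the base LLM. Assumption for illustration: every
// running entry needs at most one fresh page per draft model.
bool CanDecodeSketch(const std::vector<int>& available_pages, int num_rsentries) {
  assert(num_rsentries >= 0);
  // Start at 1: the base model does not take part in draft proposal.
  for (std::size_t model_id = 1; model_id < available_pages.size(); ++model_id) {
    if (num_rsentries > available_pages[model_id]) return false;
  }
  return true;
}
```

When the check fails, the preemption loop in Step removes running entries until it passes, which is why the base model's page count never blocks drafting.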
Member Variables
| Member | Type | Description |
|---|---|---|
| models_ | Array<Model> | Both the LLM (index 0) and SSM (index 1) models |
| logit_processor_ | LogitProcessor | Logit processor for transforming raw logits |
| sampler_ | Sampler | Token sampler (top-p, temperature, etc.) |
| model_workspaces_ | std::vector<ModelWorkspace> | Per-model workspace containing hidden state storage |
| draft_token_workspace_manager_ | DraftTokenWorkspaceManager | Manages allocation of draft token slots |
| engine_config_ | EngineConfig | Engine configuration (max sequence count, etc.) |
| trace_recorder_ | Optional<EventTraceRecorder> | Optional event trace recorder |
| draft_token_slots_ | std::vector<int> | Temporary buffer for current draft token slot indices |
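To make the roles of draft_token_workspace_manager_ and draft_token_slots_ concrete, here is a hypothetical fixed-capacity slot pool plus a scatter helper. SlotPool, its AllocSlots/FreeSlots methods, and ScatterRows are illustrative names, assuming the manager hands out integer slot indices into preallocated storage; the real manager's interface may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical fixed-capacity pool of slot indices into preallocated storage.
class SlotPool {
 public:
  explicit SlotPool(int capacity) {
    // Push in reverse so that slot 0 is handed out first.
    for (int s = capacity - 1; s >= 0; --s) free_.push_back(s);
  }
  // Append n free slots to *out.
  void AllocSlots(int n, std::vector<int>* out) {
    assert(n <= static_cast<int>(free_.size()));
    for (int i = 0; i < n; ++i) {
      out->push_back(free_.back());
      free_.pop_back();
    }
  }
  // Return slots to the pool (e.g. when drafts are rejected or a request is preempted).
  void FreeSlots(const std::vector<int>& slots) {
    for (int s : slots) free_.push_back(s);
  }
  int NumFree() const { return static_cast<int>(free_.size()); }

 private:
  std::vector<int> free_;
};

// Write one row per request into storage at that request's slot, the host-side
// analogue of scattering draft probabilities into workspace storage.
void ScatterRows(const std::vector<std::vector<float>>& rows, const std::vector<int>& slots,
                 std::vector<std::vector<float>>* storage) {
  assert(rows.size() == slots.size());
  for (std::size_t i = 0; i < rows.size(); ++i) {
    assert(slots[i] >= 0 && static_cast<std::size_t>(slots[i]) < storage->size());
    (*storage)[slots[i]] = rows[i];
  }
}
```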
Factory Function
EngineAction EngineAction::EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
                                           Sampler sampler,
                                           std::vector<ModelWorkspace> model_workspaces,
                                           DraftTokenWorkspaceManager draft_token_workspace_manager,
                                           EngineConfig engine_config,
                                           Optional<EventTraceRecorder> trace_recorder);
Static factory method on EngineAction that constructs an EagleBatchDraftActionObj wrapped in a TVM object reference.
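The factory hides the concrete class behind a reference-counted handle of the base type. TVM-based code typically builds such objects with make_object<T>(...) and wraps them in the matching reference class; the sketch below uses std::shared_ptr as a stand-in, and EngineActionObjSketch, EagleBatchDraftSketch, and MakeEagleBatchDraft are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins: std::shared_ptr plays the role of TVM's ObjectRef.
class EngineActionObjSketch {
 public:
  virtual ~EngineActionObjSketch() = default;
  // Returns the requests that finished during this step.
  virtual std::vector<std::string> Step(int num_running) = 0;
};

class EagleBatchDraftSketch final : public EngineActionObjSketch {
 public:
  explicit EagleBatchDraftSketch(int num_models) : num_models_(num_models) {}
  std::vector<std::string> Step(int num_running) final {
    // Guard: exactly two models (LLM + SSM) and at least one running request.
    if (num_models_ != 2 || num_running == 0) return {};
    // Draft rounds would run here; drafts stay in internal state.
    return {};  // nothing finishes during drafting
  }

 private:
  int num_models_;
};

using EngineActionSketch = std::shared_ptr<EngineActionObjSketch>;

// Factory: construct the concrete object, hand back only the base-typed handle.
EngineActionSketch MakeEagleBatchDraft(int num_models) {
  assert(num_models >= 1);
  return std::make_shared<EagleBatchDraftSketch>(num_models);
}
```

Callers such as the engine's action list then depend only on the base interface, so draft strategies can be swapped without touching the scheduling code.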
Design Notes
- The EAGLE draft model reuses the base model's language model head when it does not have its own (the CanGetLogits() check). This is a common EAGLE architecture choice, in which the SSM shares the vocabulary projection layer with the LLM.
- Draft token generation is iterative: each round depends on the previous round's sampled token and hidden states, which prevents parallelization across draft positions but allows full batching across requests.
- spec_draft_length is dynamic and stored in the engine state, allowing adaptive speculative decoding in which the draft length can change over time.
- The action returns an empty Array<Request> because draft tokens are stored internally and consumed by the subsequent verification action.
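The head fallback from the first note reduces to choosing which model's projection to apply. A hypothetical sketch follows; ToyModel, Head, ApplyHead, and DraftLogits are illustrative names, and only the CanGetLogits() decision mirrors the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical LM head: a [vocab, hidden] weight matrix.
using Head = std::vector<std::vector<float>>;

// Project one hidden state to vocabulary logits (matrix-vector product).
std::vector<float> ApplyHead(const Head& head, const std::vector<float>& hidden) {
  std::vector<float> logits;
  logits.reserve(head.size());
  for (const std::vector<float>& row : head) {
    assert(row.size() == hidden.size());
    float acc = 0.0f;
    for (std::size_t j = 0; j < row.size(); ++j) acc += row[j] * hidden[j];
    logits.push_back(acc);
  }
  return logits;
}

// Hypothetical model: lm_head is empty when the model has no head of its own.
struct ToyModel {
  std::optional<Head> lm_head;
  bool CanGetLogits() const { return lm_head.has_value(); }
};

// Use the draft model's head if it has one, otherwise the base model's head.
std::vector<float> DraftLogits(const ToyModel& draft, const ToyModel& base,
                               const std::vector<float>& hidden) {
  const ToyModel& owner = draft.CanGetLogits() ? draft : base;
  assert(owner.CanGetLogits());
  return ApplyHead(*owner.lm_head, hidden);
}
```

Sharing the head works only because the draft model's hidden states live in the same space the base head expects, which is what the EAGLE design arranges.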