
Implementation:Mlc ai Mlc llm Eagle Batch Draft

From Leeroopedia
Revision as of 15:49, 16 February 2026 by Admin (auto-imported from implementations/Mlc_ai_Mlc_llm_Eagle_Batch_Draft.md)


Overview

File: cpp/serve/engine_actions/eagle_batch_draft.cc

Purpose: Implements the EagleBatchDraftActionObj engine action, which runs the draft proposal phase of EAGLE-style speculative decoding. In speculative decoding, a small "draft" model (SSM, small speculative model) proposes candidate tokens that the larger model (LLM) later verifies. This action uses the second model of a two-model setup to generate multiple draft tokens per request, batched across all running requests.

Namespace: mlc::llm::serve

Class: EagleBatchDraftActionObj

Inherits from EngineActionObj and implements the Step method for EAGLE-based draft token generation.

Constructor

explicit EagleBatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor,
                                  Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
                                  DraftTokenWorkspaceManager draft_token_workspace_manager,
                                  EngineConfig engine_config,
                                  Optional<EventTraceRecorder> trace_recorder)

Takes the full set of components needed for draft generation: both models (LLM + SSM), logit processor, sampler, per-model workspaces, draft token workspace manager, engine configuration, and an optional trace recorder.

Step Method

Array<Request> Step(EngineState estate) final;

The main execution method. The algorithm proceeds as follows:

  1. Guard conditions: Returns an empty array immediately if there are not exactly 2 models or if the running queue is empty. EAGLE speculative decoding requires exactly one LLM and one SSM.
  2. Preemption loop: While CanDecode reports that the draft model lacks enough KV cache pages for the current running request state entries, repeatedly preempts the last (lowest-priority) running entry using PreemptLastRunningRequestStateEntry, passing in the draft token workspace manager.
  3. Request collection: Gathers request IDs, internal IDs, generation configs, and random number generators from all running request state entries.
  4. Draft generation loop: For each draft round (from draft ID 1 to spec_draft_length - 1, since the first draft token is already generated during prefill/verify):
    • Collects the last draft output token from each request's model state as input.
    • Computes token embeddings via the draft model.
    • Fuses embeddings with hidden states using FuseEmbedHidden.
    • Runs batched decode to produce new hidden states.
    • Computes logits (uses the draft model's head if available, otherwise falls back to the base model's head).
    • Updates the logits in place, passing each request's draft token index so the logit processor knows which draft position that request is at.
    • Computes probability distributions and renormalizes by top-p.
    • Samples tokens from the renormalized probabilities.
    • Allocates draft token slots and scatters draft probabilities into the workspace storage.
    • Adds each sampled draft token to the respective model state with parent index tracking.
    • Records per-round draft timing metrics.
  5. Overall timing: Accumulates total decode time in engine metrics.
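The round structure of step 4 can be sketched on the CPU with toy stand-ins. This is a minimal sketch under loud assumptions: ToyRequestState, ToyDecode, and ToySample are hypothetical substitutes for the per-request model state, the embed/fuse/decode calls, and top-p sampling, and requests are handled in an inner loop where the real action issues one batched call per round. Only the data flow follows the description: the last draft token is the input, hidden states chain from round to round, and each new token records its parent index.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one draft token with parent tracking.
struct ToyDraftToken {
  int token_id;
  std::int64_t parent_idx;  // index of the draft token this one extends (-1 = root in this toy)
};

// Hypothetical stand-in for one request's draft-model state.
struct ToyRequestState {
  std::vector<ToyDraftToken> draft_output_tokens;  // first entry comes from prefill/verify
  float hidden = 0.0f;                             // toy scalar "hidden state"
};

// Hypothetical decode: stands in for TokenEmbed + FuseEmbedHidden + batched decode.
float ToyDecode(int token_id, float hidden) {
  return 0.5f * hidden + static_cast<float>(token_id);
}

// Hypothetical sampler: deterministic in the hidden state (the real action
// samples from top-p renormalized probabilities).
int ToySample(float hidden, int vocab_size) {
  assert(vocab_size > 0 && hidden >= 0.0f);
  return static_cast<int>(hidden) % vocab_size;
}

// Runs draft rounds 1 .. spec_draft_length - 1; every request advances one
// token per round before the next round starts.
void ToyDraftRounds(std::vector<ToyRequestState>& batch, int spec_draft_length, int vocab_size) {
  for (int draft_id = 1; draft_id < spec_draft_length; ++draft_id) {
    for (ToyRequestState& state : batch) {
      assert(!state.draft_output_tokens.empty());  // first draft token must already exist
      // Input: the last draft output token of this request.
      int input_token = state.draft_output_tokens.back().token_id;
      // The new hidden state depends on the previous round's hidden state.
      state.hidden = ToyDecode(input_token, state.hidden);
      int sampled = ToySample(state.hidden, vocab_size);
      // Parent index: position of the token being extended.
      std::int64_t parent_idx = static_cast<std::int64_t>(state.draft_output_tokens.size()) - 1;
      state.draft_output_tokens.push_back({sampled, parent_idx});
    }
  }
}
```

Starting from one draft token per request, a call with spec_draft_length = 4 leaves each request with four draft tokens whose parent indices in this toy form the chain -1, 0, 1, 2.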

Key Implementation Details

Hidden State Gathering

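Before the first draft round (and only when spec_draft_length > 1, i.e. when at least one round will run), the action gathers each request's hidden state from the draft hidden states storage, using the slot of that request's most recent draft token: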

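// Reuse the member buffer: start from an empty slot list.
draft_token_slots_.clear();
if (estate->spec_draft_length > 1) {
    for (int i = 0; i < num_rsentries; ++i) {
        // Slot of request i's latest draft token in the shared workspace storage.
        draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());
    }
    // Gather one hidden-state row per request to seed the first draft round.
    hidden_states = models_[model_id]->GatherHiddenStates(
        model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);
}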

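This seeds the SSM's first draft round with the hidden states stored at each request's latest draft token slot; every later round then fuses its input embeddings with the hidden states produced by the previous round's decode.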
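Conceptually, the gather picks one row of the draft hidden-state storage per request, addressed by slot. A minimal CPU sketch, assuming the storage is a row-major [num_slots x hidden_size] matrix; GatherRows is a hypothetical helper, not the implementation of Model::GatherHiddenStates, which works on device tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// storage: row-major [num_slots x hidden_size], flattened.
// Returns a [slots.size() x hidden_size] matrix whose row i is storage row slots[i].
std::vector<float> GatherRows(const std::vector<float>& storage, std::size_t hidden_size,
                              const std::vector<int>& slots) {
  assert(hidden_size > 0 && storage.size() % hidden_size == 0);
  const std::size_t num_slots = storage.size() / hidden_size;
  std::vector<float> out;
  out.reserve(slots.size() * hidden_size);
  for (int slot : slots) {
    assert(slot >= 0 && static_cast<std::size_t>(slot) < num_slots);
    auto row_begin =
        storage.begin() + static_cast<std::ptrdiff_t>(static_cast<std::size_t>(slot) * hidden_size);
    out.insert(out.end(), row_begin, row_begin + static_cast<std::ptrdiff_t>(hidden_size));
  }
  return out;
}
```

For example, gathering slots {2, 0} from a three-slot storage with hidden size 2 returns row 2 followed by row 0.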

Logit Processing with Draft Context

The logit processor receives draft_token_indices, which tell it the draft position each request is currently at:

logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids, nullptr,
                                      &mstates, &draft_token_indices);
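The excerpt does not show how draft_token_indices is filled. A minimal sketch, assuming each entry is the position of the request's latest draft token (draft_output_tokens.size() - 1), i.e. the token the current round extends; ToyModelState is a hypothetical stand-in for the per-request draft model state.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-request draft model state.
struct ToyModelState {
  std::vector<int> draft_output_tokens;  // draft tokens proposed so far
};

// Assumed definition: index of each request's latest draft token.
std::vector<int> DraftTokenIndices(const std::vector<ToyModelState>& mstates) {
  std::vector<int> indices;
  indices.reserve(mstates.size());
  for (const ToyModelState& mstate : mstates) {
    assert(!mstate.draft_output_tokens.empty());  // the first draft token comes from prefill/verify
    indices.push_back(static_cast<int>(mstate.draft_output_tokens.size()) - 1);
  }
  return indices;
}
```

Requests in the same batch can sit at different draft positions, which is why the index is passed per request rather than as one shared round number.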

Prefix Cache Overlap

While the GPU is still executing the current round, the action commits the pending prefix cache changes from the previous action round, overlapping CPU bookkeeping with GPU work:

estate->prefix_cache->CommitSequenceExtention();
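The overlap is the usual launch, work, then wait pattern: device work starts asynchronously, host-side bookkeeping runs meanwhile, and the host blocks only when it needs the result. A generic sketch with std::async standing in for asynchronously launched GPU kernels; LaunchDraftDecode, CommitPendingCacheUpdates, and OverlappedStep are hypothetical names, not MLC LLM APIs.

```cpp
#include <cassert>
#include <future>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical stand-in for device work; std::launch::async runs it on another thread.
std::future<int> LaunchDraftDecode(std::vector<int> inputs) {
  return std::async(std::launch::async, [inputs = std::move(inputs)]() {
    return std::accumulate(inputs.begin(), inputs.end(), 0);  // pretend "kernel"
  });
}

// Hypothetical stand-in for CPU-side prefix cache bookkeeping.
int CommitPendingCacheUpdates(std::vector<int>& pending) {
  int committed = static_cast<int>(pending.size());
  pending.clear();
  return committed;
}

// Launch async work, do independent CPU work while it runs, then wait.
int OverlappedStep(std::vector<int> inputs, std::vector<int>& pending, int* committed) {
  assert(committed != nullptr);
  std::future<int> device_result = LaunchDraftDecode(std::move(inputs));
  *committed = CommitPendingCacheUpdates(pending);  // runs concurrently with the async work
  return device_result.get();                       // synchronization point
}
```

Work placed in the gap must not depend on the in-flight result; otherwise the host would have to wait anyway.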

Private Methods

CanDecode

bool CanDecode(int num_rsentries);

Checks whether each draft model (model index 1 and above; with exactly two models, only the SSM) has enough available KV cache pages for all the running request state entries. The first model (the LLM) is excluded because it is not involved in draft proposal.
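Combined with the preemption loop from step 2 of Step, the check can be sketched as below. This is a simplified model under stated assumptions: each running entry is assumed to need one free page in every draft model, preemption is assumed to return the entry's pages, and ToyEngine, free_pages, running, and pages_held are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical engine snapshot: index 0 is the LLM, indices >= 1 are draft models.
struct ToyEngine {
  std::vector<int> free_pages;  // available KV cache pages per model
  std::vector<int> running;     // running request state entry ids, in queue order
  std::vector<int> pages_held;  // pages held by each running entry (same order)
};

// Only draft models are checked; assumes one free page per running entry.
bool CanDecode(const ToyEngine& engine, int num_rsentries) {
  for (std::size_t model_id = 1; model_id < engine.free_pages.size(); ++model_id) {
    if (num_rsentries > engine.free_pages[model_id]) return false;
  }
  return true;
}

// Preempts the last running entry until the remaining batch fits.
// Returns the ids of the preempted entries.
std::vector<int> PreemptUntilDecodable(ToyEngine& engine) {
  assert(engine.pages_held.size() == engine.running.size());
  std::vector<int> preempted;
  while (!engine.running.empty() &&
         !CanDecode(engine, static_cast<int>(engine.running.size()))) {
    preempted.push_back(engine.running.back());
    // Assumption: the preempted entry's pages return to every draft model.
    for (std::size_t model_id = 1; model_id < engine.free_pages.size(); ++model_id) {
      engine.free_pages[model_id] += engine.pages_held.back();
    }
    engine.running.pop_back();
    engine.pages_held.pop_back();
  }
  return preempted;
}
```

Each preemption both shrinks the batch and frees pages, so the loop always terminates.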

Member Variables

| Member | Type | Description |
|---|---|---|
| models_ | Array<Model> | Both the LLM (index 0) and SSM (index 1) models |
| logit_processor_ | LogitProcessor | Updates raw logits in place and computes probability distributions from them |
| sampler_ | Sampler | Token sampler (top-p renormalization and sampling) |
| model_workspaces_ | std::vector<ModelWorkspace> | Per-model workspaces holding hidden-state storage (the draft hidden states live in index 0) |
| draft_token_workspace_manager_ | DraftTokenWorkspaceManager | Manages allocation of draft token slots |
| engine_config_ | EngineConfig | Engine configuration (max sequence count, etc.) |
| trace_recorder_ | Optional<EventTraceRecorder> | Optional event trace recorder |
| draft_token_slots_ | std::vector<int> | Temporary buffer for current draft token slot indices |

Factory Function

EngineAction EngineAction::EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
                                           Sampler sampler,
                                           std::vector<ModelWorkspace> model_workspaces,
                                           DraftTokenWorkspaceManager draft_token_workspace_manager,
                                           EngineConfig engine_config,
                                           Optional<EventTraceRecorder> trace_recorder);

Static factory method on EngineAction that constructs an EagleBatchDraftActionObj wrapped in a TVM object reference.
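The Obj/reference split follows TVM's object system: the ...Obj class carries state and behavior, and a lightweight reference type shares ownership of it. The sketch below only imitates that shape with std::shared_ptr; it is an analogy, not TVM's ObjectRef or make_object, and every Toy* name is hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// "Obj" side: holds state and implements behavior.
class ToyEngineActionObj {
 public:
  virtual ~ToyEngineActionObj() = default;
  virtual std::vector<int> Step() = 0;
};

// Reference side: cheap to copy, shares ownership of one Obj.
class ToyEngineAction {
 public:
  explicit ToyEngineAction(std::shared_ptr<ToyEngineActionObj> obj) : obj_(std::move(obj)) {}
  ToyEngineActionObj* operator->() const { return obj_.get(); }
  long use_count() const { return obj_.use_count(); }

  // Static factory, analogous in shape to EngineAction::EagleBatchDraft.
  static ToyEngineAction EagleBatchDraft(int spec_draft_length);

 private:
  std::shared_ptr<ToyEngineActionObj> obj_;
};

class ToyEagleBatchDraftObj final : public ToyEngineActionObj {
 public:
  explicit ToyEagleBatchDraftObj(int spec_draft_length) : spec_draft_length_(spec_draft_length) {
    assert(spec_draft_length_ >= 1);
  }
  // Like the real action, returns no requests: drafts stay in internal state.
  std::vector<int> Step() override { return {}; }

 private:
  int spec_draft_length_;
};

ToyEngineAction ToyEngineAction::EagleBatchDraft(int spec_draft_length) {
  return ToyEngineAction(std::make_shared<ToyEagleBatchDraftObj>(spec_draft_length));
}
```

Copying the reference shares the same action object, so the engine can hold the action in several places without duplicating its buffers.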

Design Notes

  • When the EAGLE draft model has no language model head of its own (checked via CanGetLogits()), it reuses the base model's head. This is a common EAGLE architecture choice: the SSM shares the vocabulary projection layer with the LLM.
  • Draft token generation is iterative: each round depends on the previous round's sampled token and hidden states, preventing parallelization across draft positions but allowing full batching across requests.
  • The spec_draft_length is dynamic and stored in the engine state, allowing adaptive speculative decoding where the draft length can change over time.
  • The action returns an empty Array<Request> since draft tokens are stored internally and consumed by the subsequent verification action.
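The head fallback in the first note reduces to a capability check before computing logits. A minimal sketch with a hypothetical ToyModel whose optional LM head is a [vocab x hidden] matrix; the real Model::CanGetLogits and Model::GetLogits operate on device tensors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model with an optional LM head ([vocab x hidden] weights).
struct ToyModel {
  std::vector<std::vector<float>> lm_head;  // empty => no head of its own

  bool CanGetLogits() const { return !lm_head.empty(); }

  // logits[v] = dot(lm_head[v], hidden)
  std::vector<float> GetLogits(const std::vector<float>& hidden) const {
    assert(CanGetLogits());
    std::vector<float> logits;
    logits.reserve(lm_head.size());
    for (const std::vector<float>& row : lm_head) {
      assert(row.size() == hidden.size());
      float dot = 0.0f;
      for (std::size_t j = 0; j < hidden.size(); ++j) dot += row[j] * hidden[j];
      logits.push_back(dot);
    }
    return logits;
  }
};

// Prefer the draft model's own head; otherwise reuse the base model's head.
std::vector<float> DraftLogits(const ToyModel& draft, const ToyModel& base,
                               const std::vector<float>& hidden) {
  return draft.CanGetLogits() ? draft.GetLogits(hidden) : base.GetLogits(hidden);
}
```

The fallback presumes the draft model's hidden states have the same width as the base model's, as required for the shared projection to apply.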
