Implementation:Ggml org Llama cpp Common Speculative Draft
| Field | Value |
|---|---|
| Implementation Name | Common Speculative Draft |
| Doc Type | API Doc |
| Workflow | Speculative_Decoding |
| Step | 5 of 5 (CORE) |
| Source Files | common/speculative.cpp, common/sampling.cpp |
Overview
Description
This implementation documents the three core functions that form the speculative generation loop: common_speculative_draft() for generating draft tokens, common_sampler_sample_and_accept_n() for verifying drafts against the target model and accepting matching tokens, and common_speculative_accept() for informing the speculation engine of how many tokens were accepted.
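Together these functions implement the draft-then-verify paradigm. The draft function cheaply proposes candidate tokens. The sampler samples from the target model at each draft position and keeps the drafts only up to the first mismatch. The accept function then reports how many drafts were accepted, so the implementation that produced them can update its internal state for the next iteration.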
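To make the draft-then-verify idea concrete, here is a self-contained toy model with no llama.cpp dependency. The two "models" are hypothetical deterministic functions from a context to the next token, standing in for greedy sampling from a draft model and a target model. The draft agrees with the target most of the time but not always. Under these assumptions each round costs one (simulated) target pass yet can emit several tokens, and with deterministic selection the output is identical to running the target alone.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Token  = int;
using Tokens = std::vector<Token>;
// Hypothetical stand-in for a model with greedy sampling: context -> next token.
using Model  = std::function<Token(const Tokens &)>;

// Reference: plain autoregressive decoding with the target model only.
Tokens generate_plain(const Model & target, Tokens ctx, size_t n_new) {
    for (size_t i = 0; i < n_new; ++i) {
        ctx.push_back(target(ctx));
    }
    return ctx;
}

// Draft-then-verify with k drafts per round; n_rounds counts simulated target passes.
Tokens generate_speculative(const Model & target, const Model & draft,
                            Tokens ctx, size_t n_new, size_t k, size_t & n_rounds) {
    const size_t n_goal = ctx.size() + n_new;
    n_rounds = 0;
    while (ctx.size() < n_goal) {
        // Draft phase: k cheap autoregressive proposals from the draft model.
        Tokens proposal;
        Tokens tmp = ctx;
        for (size_t i = 0; i < k; ++i) {
            proposal.push_back(draft(tmp));
            tmp.push_back(proposal.back());
        }
        // Verify phase: a real system gets all k+1 target predictions from one
        // batched forward pass; this toy calls the target function instead.
        ++n_rounds;
        Tokens accepted;
        Tokens check = ctx;
        size_t i = 0;
        for (; i < proposal.size(); ++i) {
            const Token t = target(check);
            accepted.push_back(t);
            check.push_back(t);
            if (proposal[i] != t) {
                break; // first mismatch: t is the target's correction
            }
        }
        if (i == proposal.size()) {
            accepted.push_back(target(check)); // every draft matched: bonus token
        }
        // Accept phase: append accepted tokens without overshooting n_new.
        for (const Token t : accepted) {
            if (ctx.size() == n_goal) {
                break;
            }
            ctx.push_back(t);
        }
    }
    return ctx;
}
```

With a draft that is wrong whenever the context length is a multiple of 4 and k = 4, 20 new tokens take 5 rounds instead of 20 target calls, and the result matches `generate_plain` exactly.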
Usage
```cpp
// Draft phase
llama_tokens drafts = common_speculative_draft(spec, params.speculative, prompt_tokens, id_last);
// Verify phase (after the target model has decoded id_last + drafts with logits at every position)
std::vector<llama_token> accepted = common_sampler_sample_and_accept_n(smpl, ctx_tgt, drafts, false);
// Accept phase: the last returned token is always sampled by the target, so it is not counted
common_speculative_accept(spec, (uint16_t) (accepted.size() - 1));
```
Code Reference
| Field | Value |
|---|---|
| common_speculative_draft | common/speculative.cpp:995-1025 |
| common_speculative_accept | common/speculative.cpp:1027-1046 |
| common_sampler_sample_and_accept_n (with idxs) | common/sampling.cpp:521-548 |
| common_sampler_sample_and_accept_n (simple) | common/sampling.cpp:551-558 |
| Import | #include "speculative.h", #include "sampling.h" |
common_speculative_draft signature and implementation:
```cpp
llama_tokens common_speculative_draft(
        common_speculative * spec,
        const common_params_speculative & params,
        const llama_tokens & prompt_tgt, // specified in target model vocab
        llama_token id_last) {
    llama_tokens result;

    spec->curr_impl = nullptr; // reset current implementation

    for (auto & impl : spec->impls) {
        {
            common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
            impl->draft(params, prompt_tgt, id_last, result);
            impl->n_call_draft++;
        }

        if (!result.empty()) {
            spec->curr_impl = impl.get();
            impl->n_gen_drafts++;
            impl->n_gen_tokens += result.size();
            break; // We have a draft, so break out of the loop
        }
    }

    return result;
}
```
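The loop in `common_speculative_draft` is a first-non-empty-wins chain: implementations are tried in order, and the first one that returns a non-empty draft becomes `curr_impl`, so the later accept call can be routed back to it. The sketch below models only that control flow outside llama.cpp; `Drafter`, `Engine`, `engine_draft`, and the `std::function` draft member are hypothetical simplifications, and timing is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Tokens = std::vector<int>;

// Hypothetical simplification of one speculative implementation.
struct Drafter {
    std::function<void(Tokens &)> draft; // fills result, or leaves it empty for "no draft"
    size_t n_call_draft = 0;
    size_t n_gen_drafts = 0;
    size_t n_gen_tokens = 0;
};

struct Engine {
    std::vector<std::unique_ptr<Drafter>> impls; // tried in order
    Drafter * curr_impl = nullptr;               // owner of the most recent draft
};

Tokens engine_draft(Engine & spec) {
    Tokens result;
    spec.curr_impl = nullptr; // nobody owns this step yet
    for (auto & impl : spec.impls) {
        impl->draft(result);
        impl->n_call_draft++;
        if (!result.empty()) {
            spec.curr_impl = impl.get();
            impl->n_gen_drafts++;
            impl->n_gen_tokens += result.size();
            break; // later implementations are never consulted
        }
    }
    // Invariant: a non-empty draft always has an owner to receive the accept call.
    assert(result.empty() || spec.curr_impl != nullptr);
    return result;
}
```

One consequence of this design is that an implementation placed later in the list only runs when every earlier one declines, so ordering acts as a priority.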
common_speculative_accept implementation:
```cpp
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
    if (n_accepted == 0) {
        return;
    }

    common_speculative_state * impl = spec->curr_impl;

    GGML_ASSERT(impl);

    {
        common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
        if (n_accepted > 0) { // always true here, given the early return above
            impl->n_acc_drafts++;
            impl->n_acc_tokens += n_accepted;
        }

        impl->accept(n_accepted);
        impl->n_call_accept++;
    }
}
```
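`common_speculative_accept` only updates counters and forwards the count, but those counters are what make acceptance measurable. Note the early return: a step in which no draft token survives neither increments `n_acc_drafts` nor calls `impl->accept`. The sketch below mirrors the four counters and derives two plausible ratios from them; `SpecCounters`, `record_step`, and the rate functions are hypothetical, and this is not necessarily what `common_speculative_print_stats` reports.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the per-implementation counters touched by draft/accept.
struct SpecCounters {
    uint64_t n_gen_drafts = 0; // non-empty drafts produced
    uint64_t n_gen_tokens = 0; // total draft tokens produced
    uint64_t n_acc_drafts = 0; // drafts with at least one accepted token
    uint64_t n_acc_tokens = 0; // total draft tokens accepted
};

// Record one step in which n_drafted tokens were proposed and n_accepted kept.
void record_step(SpecCounters & c, uint64_t n_drafted, uint16_t n_accepted) {
    assert(n_accepted <= n_drafted);
    if (n_drafted == 0) {
        return; // no implementation produced a draft this step
    }
    c.n_gen_drafts++;
    c.n_gen_tokens += n_drafted;
    if (n_accepted == 0) {
        return; // same early return as common_speculative_accept
    }
    c.n_acc_drafts++;
    c.n_acc_tokens += n_accepted;
}

// Fraction of drafted tokens the target kept (0 if nothing was drafted).
double token_acceptance_rate(const SpecCounters & c) {
    return c.n_gen_tokens == 0 ? 0.0 : double(c.n_acc_tokens) / double(c.n_gen_tokens);
}

// Fraction of drafts in which at least one token was kept.
double draft_acceptance_rate(const SpecCounters & c) {
    return c.n_gen_drafts == 0 ? 0.0 : double(c.n_acc_drafts) / double(c.n_gen_drafts);
}
```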
common_sampler_sample_and_accept_n implementation:
```cpp
std::vector<llama_token> common_sampler_sample_and_accept_n(
        struct common_sampler * gsmpl,
        struct llama_context * ctx,
        const std::vector<int> & idxs,
        const llama_tokens & draft,
        bool grammar_first) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
    result.reserve(idxs.size());

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);

        if (draft[i] != id) {
            break;
        }
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);

        common_sampler_accept(gsmpl, id, true);

        result.push_back(id);
    }

    return result;
}
```
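Stripped of the sampler machinery, the function walks the draft and the logits rows in lockstep: row `idxs[i]` holds the target's prediction for the position of `draft[i]`, and the walk stops at the first disagreement, keeping the target's token. The sketch below reproduces that control flow with a hypothetical `pick_at` callback in place of `common_sampler_sample`, so it has no grammar handling and no `common_sampler_accept` bookkeeping.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Token = int;

// Hypothetical stand-in for common_sampler_sample: the token the target
// selects from the logits row at batch index idx.
using PickAt = std::function<Token(int idx)>;

std::vector<Token> sample_and_accept_n(const PickAt & pick_at,
                                       const std::vector<int> & idxs,
                                       const std::vector<Token> & draft) {
    assert(idxs.size() == draft.size() + 1); // one extra row for the bonus token
    std::vector<Token> result;
    result.reserve(idxs.size());
    size_t i = 0;
    for (; i < draft.size(); i++) {
        const Token id = pick_at(idxs[i]);
        result.push_back(id); // the target's token is kept even when it disagrees
        if (draft[i] != id) {
            break;
        }
    }
    if (i == draft.size()) {
        result.push_back(pick_at(idxs[i])); // all drafts matched: bonus token from the last row
    }
    return result;
}
```

If the target would pick 5, 6, 7, 8 at rows 0 to 3, the draft {5, 6, 7} yields {5, 6, 7, 8}, while {5, 9, 7} yields {5, 6}.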
Simplified overload (contiguous positions):
```cpp
std::vector<llama_token> common_sampler_sample_and_accept_n(
        struct common_sampler * gsmpl,
        struct llama_context * ctx,
        const llama_tokens & draft,
        bool grammar_first) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i;
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
```
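The contiguous overload covers the layout where the batch holds `id_last` at row 0 followed by the drafts of a single sequence. When the logits rows for a sequence do not start at 0, for example if several sequences share one batch, the explicit `idxs` overload is the one to use. The helper below is a hypothetical sketch that builds `idxs` for a sequence whose `id_last` sits at batch row `offset` with its drafts in the rows right after it, assuming logits were requested for all of those rows.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: batch rows for one sequence whose id_last is at row
// `offset` and whose n_draft draft tokens occupy the following rows.
std::vector<int> make_idxs(int offset, size_t n_draft) {
    assert(offset >= 0);
    std::vector<int> idxs(n_draft + 1); // one row per draft plus the bonus row
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = offset + static_cast<int>(i);
    }
    return idxs;
}
```

With `offset = 0` this reproduces what the simplified overload builds internally; two sequences packed as rows 0 to 2 and 3 to 6 would use `make_idxs(0, 2)` and `make_idxs(3, 3)`.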
Additional lifecycle functions:
```cpp
// Optionally call at the beginning of a new generation
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);

// Print statistics about speculative decoding performance
void common_speculative_print_stats(const common_speculative * spec);
```
I/O Contract
| Direction | Name | Type | Description |
|---|---|---|---|
| common_speculative_draft | | | |
| Input | spec | common_speculative * | Speculative decoding engine |
| Input | params | const common_params_speculative & | Speculation parameters (n_max, p_min, etc.) |
| Input | prompt_tgt | const llama_tokens & | Current prompt/context tokens in target vocabulary |
| Input | id_last | llama_token | Last accepted token |
| Output | (return) | llama_tokens | Vector of draft tokens (may be empty if no implementation produced drafts) |
| common_sampler_sample_and_accept_n | | | |
| Input | gsmpl | common_sampler * | Sampler with target model distribution |
| Input | ctx | llama_context * | Target model context (after forward pass with drafts) |
| Input | idxs | const std::vector<int> & | Batch indices for each position |
| Input | draft | const llama_tokens & | Draft tokens to verify |
| Input | grammar_first | bool | Whether to apply grammar before sampling |
| Output | (return) | std::vector<llama_token> | Accepted tokens: matching prefix from drafts + one target-sampled token |
| common_speculative_accept | | | |
| Input | spec | common_speculative * | Speculative decoding engine |
| Input | n_accepted | uint16_t | Number of draft tokens that were accepted by the target model |
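Return value semantics for common_sampler_sample_and_accept_n, with k = draft.size():
- If all k draft tokens match: returns k+1 tokens (all drafts plus one bonus token sampled by the target)
- If the first j draft tokens match and draft token j+1 does not: returns j+1 tokens (the matched prefix plus the target's correction)
- The minimum return size is 1 (only the target model's own sampled token), which occurs when the first draft token is rejected or the draft is empty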
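These cases make the payoff easy to quantify under a simplifying assumption that is not part of the API: if each draft token matches independently with probability alpha, the step returns j+1 tokens where j is the number of leading matches, so with k drafts the expected number of tokens per target forward pass is the sum of alpha^j for j = 0..k, i.e. (1 - alpha^(k+1)) / (1 - alpha). The sketch below computes it in closed form and, as a cross-check, directly from the distribution of j.

```cpp
#include <cassert>
#include <cmath>

// Expected tokens returned per verify step, assuming each of the k draft tokens
// matches independently with probability alpha (a modeling assumption).
double expected_tokens_closed_form(double alpha, int k) {
    assert(alpha >= 0.0 && alpha <= 1.0 && k >= 0);
    if (alpha == 1.0) {
        return k + 1.0; // every draft accepted plus the bonus token
    }
    return (1.0 - std::pow(alpha, k + 1)) / (1.0 - alpha);
}

// Same quantity from the cases above: P(j) = alpha^j * (1 - alpha) for j < k,
// P(k) = alpha^k, and a step with j leading matches returns j + 1 tokens.
double expected_tokens_by_cases(double alpha, int k) {
    assert(alpha >= 0.0 && alpha <= 1.0 && k >= 0);
    double expected = 0.0;
    for (int j = 0; j < k; ++j) {
        expected += (j + 1) * std::pow(alpha, j) * (1.0 - alpha);
    }
    expected += (k + 1) * std::pow(alpha, k);
    return expected;
}
```

For example, alpha = 0.5 and k = 4 give 1.9375 tokens per target pass, while alpha = 0 degrades to exactly 1, ordinary one-token-per-pass decoding.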
Usage Examples
Complete speculative generation loop:
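```cpp
#include "common.h"       // common_batch_add, common_batch_clear
#include "speculative.h"
#include "sampling.h"

// Assumes params, model_tgt, ctx_tgt, prompt_tokens (already evaluated by the
// target), id_last, n_past (positions in the target KV cache) and n_predict exist.

// Initialize
common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

// Begin generation
common_speculative_begin(spec, prompt_tokens);

// One reusable batch: id_last plus at most n_max draft tokens
llama_batch batch = llama_batch_init(params.speculative.n_max + 1, 0, 1);

int n_generated = 0;
while (n_generated < n_predict) {
    // 1. Draft phase: get candidate tokens (may be empty)
    llama_tokens drafts = common_speculative_draft(
        spec, params.speculative, prompt_tokens, id_last);

    // 2. Decode id_last followed by the drafts, with logits at every position
    //    (the row layout the simple overload below expects)
    common_batch_clear(batch);
    common_batch_add(batch, id_last, n_past++, { 0 }, true);
    for (size_t i = 0; i < drafts.size(); ++i) {
        common_batch_add(batch, drafts[i], n_past + (llama_pos) i, { 0 }, true);
    }
    llama_decode(ctx_tgt, batch);

    // 3. Verify phase: sample from target and compare to drafts
    std::vector<llama_token> accepted =
        common_sampler_sample_and_accept_n(smpl, ctx_tgt, drafts, false);

    // 4. Accept phase: inform speculative engine
    //    accepted.size() - 1 because the last token is always from the target
    common_speculative_accept(spec, (uint16_t) (accepted.size() - 1));

    // Accepted drafts stay in the KV cache; rejected drafts (positions >= n_past)
    // must be removed before the next step (the call depends on the llama.cpp version)
    n_past += accepted.size() - 1;

    // 5. Process accepted tokens: the previous id_last joins the prompt and
    //    each accepted token becomes the new id_last in turn
    for (llama_token token : accepted) {
        prompt_tokens.push_back(id_last);
        id_last = token;
        // output token, check for end of generation, etc.
    }
    n_generated += accepted.size();
}

// Print performance statistics
common_speculative_print_stats(spec);

// Cleanup
llama_batch_free(batch);
common_speculative_free(spec);
common_sampler_free(smpl);
```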
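Each verify step writes `id_last` and every draft token into the target's KV cache, but only the accepted drafts should survive into the next step. The toy model below tracks that bookkeeping for one sequence, with a plain vector standing in for the cache (entry p holds the token at position p). `KvModel` and `verify_step` are hypothetical; in llama.cpp the truncation corresponds to removing the sequence's cache entries from position `n_past` onward.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of one sequence's KV cache: cells[p] is the token at position p.
struct KvModel {
    std::vector<int> cells;
};

// One verify step: decode id_last plus the drafts, then keep only the accepted
// prefix of the drafts. Returns the new n_past (number of valid positions).
size_t verify_step(KvModel & kv, int id_last, const std::vector<int> & drafts, size_t n_accepted) {
    assert(n_accepted <= drafts.size());
    // The batch writes every token, whether or not it is later accepted.
    kv.cells.push_back(id_last);
    kv.cells.insert(kv.cells.end(), drafts.begin(), drafts.end());
    // id_last and the accepted drafts remain valid; the rest is rolled back.
    const size_t n_past = kv.cells.size() - drafts.size() + n_accepted;
    kv.cells.resize(n_past);
    return n_past;
}
```

The target's correction (the last returned token) is deliberately not in the cache yet: it becomes the next `id_last` and is written at position `n_past` in the following step.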