Workflow:Romsto Speculative Decoding Speculative Decoding Inference

From Leeroopedia
Knowledge Sources
Domains: LLMs, Inference_Optimization, Speculative_Decoding
Last Updated: 2026-02-14 04:30 GMT

Overview

End-to-end process for accelerating Large Language Model inference using speculative decoding with a smaller drafter model and rejection sampling verification.

Description

This workflow implements the speculative decoding algorithm (Leviathan et al., 2023; Chen et al., 2023) to speed up transformer inference without changing the output distribution. A small, fast drafter model autoregressively generates a sequence of gamma candidate tokens (drafts), which the larger target model then verifies in a single forward pass. Drafts are checked left to right: accepted drafts are kept, and the first rejected draft is replaced by a corrected sample from an adjusted distribution, with all later drafts discarded. The process requires two models that share the same tokenizer and vocabulary, and it supports both greedy and nucleus sampling strategies.

Goal: Produce text output faster than standard autoregressive decoding while preserving the target model's output distribution.

Scope: Covers dependency installation, model loading, input tokenization, sampling configuration, speculative generation with draft-then-verify, and output decoding.

Strategy: Exploits the parallel verification property of transformers: the target model can verify all gamma draft tokens in one forward pass, amortizing the cost of the large model across multiple tokens per step.
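The amortization can be quantified. Assuming, as in the analysis of Leviathan et al. (2023), that each draft is accepted independently with the same probability α (a simplification; real acceptance varies with position and context), each iteration yields the accepted drafts plus one token sampled by the target, so the expected number of tokens per target forward pass is a geometric sum:

```latex
\mathbb{E}[\text{tokens per iteration}]
  = 1 + \sum_{k=1}^{\gamma} \Pr[\text{first } k \text{ drafts accepted}]
  = \sum_{k=0}^{\gamma} \alpha^{k}
  = \frac{1 - \alpha^{\gamma+1}}{1 - \alpha}
```

For example, with α = 0.8 and γ = 4 the expectation is (1 - 0.8^5)/0.2 = 0.67232/0.2 ≈ 3.36 tokens per target pass, versus exactly 1 for standard autoregressive decoding.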

Usage

Use this workflow when you have a large decoder-only target language model that is slow at autoregressive generation, plus a smaller drafter model from the same model family that shares its tokenizer. It is most appropriate when inference latency is the bottleneck and the drafter is substantially faster than the target, typically because decoding with the target model is memory-bandwidth-bound on your hardware.

Execution Steps

Step 1: Install Dependencies

Set up the Python environment with the required libraries. The implementation depends on PyTorch for tensor operations, HuggingFace Transformers for model loading, and bitsandbytes/accelerate for optional quantization support. Additional utilities include termcolor for debug output and tqdm for progress tracking.

Key considerations:

  • Requires Python 3.8 or later in practice, since PyTorch 2.3.0 and newer do not support Python 3.7
  • PyTorch version must be 2.3.0 or higher for compatibility
  • The bitsandbytes library enables int8 quantization to reduce memory usage
  • All dependencies are listed in the repository's requirements.txt

Step 2: Load Target and Drafter Models

Load both the target (large) model and the drafter (small) model using HuggingFace's AutoModelForCausalLM. Both models must share the same tokenizer and output the same vocabulary size. Optionally apply int8 quantization to reduce GPU memory consumption.

Key considerations:

  • The drafter model must share the same tokenizer as the target model
  • Both models should output logits of the same vocabulary size
  • The target model should be large enough that it is memory-bandwidth-bound during inference
  • The drafter model should be small enough to generate tokens faster than the target
  • Both models should be set to evaluation mode after loading
  • Optional quantization (e.g., Transformers' QuantoConfig with int8 weights, which uses the quanto backend rather than bitsandbytes) reduces memory footprint

Step 3: Prepare Input

Tokenize the user's text prompt into a list of token IDs. If using a chat-instructed model, apply the appropriate chat template before tokenization to ensure proper formatting with special tokens (e.g., system prompts, turn markers).

Key considerations:

  • The generation functions expect a flat Python list of integer token IDs, not a batched tensor
  • Chat-templated inputs require the correct template for the model family (e.g., Llama chat template)
  • Special tokens like BOS, EOS, and turn markers must be correctly inserted

Step 4: Configure Sampling Strategy

Select and instantiate a logits processor that defines how tokens are sampled from the model's probability distribution. Options include greedy decoding (argmax), multinomial sampling, top-k filtering, nucleus (top-p) filtering, or a combination of top-k and nucleus.

Key considerations:

  • Greedy decoding is deterministic and selects the highest-probability token
  • Nucleus sampling with temperature provides controlled randomness
  • The same logits processor is used for both the drafter and target models to maintain distribution consistency
  • Temperature, top-k, and top-p are the primary hyperparameters
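One way to meet the consistency requirement is to treat a logits processor as a function from raw logits to a next-token probability vector, so the same function produces both q (drafter) and p (target). The sketch below is a minimal, dependency-free illustration under that assumption; the function names are hypothetical and are not the repository's processor classes, which operate on PyTorch tensors. Greedy decoding is expressed as a one-hot distribution so it plugs into the same accept/reject test as the stochastic strategies.

```python
import math
import random
from typing import List


def _renormalize(weights: List[float]) -> List[float]:
    total = sum(weights)
    return [w / total for w in weights]


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax (temperature must be > 0)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    return _renormalize([math.exp(x - peak) for x in scaled])


def greedy_probs(logits: List[float]) -> List[float]:
    """Greedy decoding as a one-hot distribution on the argmax token."""
    best = max(range(len(logits)), key=logits.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(logits))]


def top_k_probs(logits: List[float], k: int, temperature: float = 1.0) -> List[float]:
    """Keep only the k most probable tokens, then renormalize."""
    probs = softmax(logits, temperature)
    keep = set(sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k])
    return _renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


def nucleus_probs(logits: List[float], top_p: float, temperature: float = 1.0) -> List[float]:
    """Keep the smallest set of most probable tokens whose mass reaches top_p."""
    probs = softmax(logits, temperature)
    keep, mass = set(), 0.0
    for i in sorted(range(len(probs)), key=probs.__getitem__, reverse=True):
        keep.add(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return _renormalize([p if i in keep else 0.0 for i, p in enumerate(probs)])


def sample_token(probs: List[float], rng: random.Random) -> int:
    """Draw one token id; zero-probability tokens are never selected."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

For example, nucleus_probs([2.0, 1.0, 0.0], top_p=0.9) keeps the two most likely tokens (softmax masses of roughly 0.67 and 0.24) and zeroes the third.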

Step 5: Run Speculative Generation

Execute the speculative decoding loop. In each iteration, the drafter model generates gamma draft tokens autoregressively, and the target model then scores all drafts in a single forward pass. Each draft x is accepted or rejected by rejection sampling: it is kept if a uniform random number r satisfies r ≤ min(1, p(x)/q(x)), where p is the target distribution and q the drafter distribution. At the first rejected position, a replacement token is sampled from the adjusted distribution max_fn(p - q), i.e., max(0, p - q) renormalized, and all later drafts are discarded. If every draft is accepted, one extra token is sampled from the target's distribution at the next position. The accepted tokens and this final sampled token are appended to the sequence.

What happens:

  • The drafter generates gamma candidate tokens one at a time
  • The target model processes all drafts in one forward pass, producing probability distributions for each position
  • Rejection sampling compares p(target)/q(drafter) against a uniform random variable for each draft
  • Accepted drafts are kept; the first rejection triggers a corrected sample, and any drafts after it are discarded
  • KV-caches are pruned when drafts are rejected to maintain consistency
  • The acceptance rate alpha measures how well the drafter approximates the target
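The accept/reject logic of a single iteration can be sketched in plain Python. This is a minimal illustration, not the repository's speculative_generate: it assumes hypothetical "models" that map a token prefix to an already-processed next-token probability vector, calls the toy target once per position instead of one batched forward pass, and omits KV-cache pruning. It does follow the acceptance rule min(1, p/q), the max_fn residual, and the extra token drawn when every draft is accepted.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical interface: a "model" maps a token prefix to a next-token
# probability vector (logits already passed through the logits processor).
Model = Callable[[Sequence[int]], List[float]]


def sample_token(probs: List[float], rng: random.Random) -> int:
    """Draw one token id; zero-probability tokens are never selected."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def max_fn(p: List[float], q: List[float]) -> List[float]:
    """Residual distribution max(0, p - q), renormalized to sum to 1."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    return [d / total for d in diff]


def speculative_step(
    prefix: Sequence[int], drafter: Model, target: Model, gamma: int, rng: random.Random
) -> Tuple[List[int], int]:
    """Run one draft-then-verify iteration; return (new tokens, accepted drafts)."""
    # 1) Drafting: gamma tokens, one at a time, from the small model.
    drafts: List[int] = []
    q_dists: List[List[float]] = []
    for _ in range(gamma):
        q = drafter(list(prefix) + drafts)
        drafts.append(sample_token(q, rng))
        q_dists.append(q)

    # 2) Verification: target distributions for positions 0..gamma. A real
    #    implementation gets all gamma + 1 of them from ONE forward pass.
    p_dists = [target(list(prefix) + drafts[:i]) for i in range(gamma + 1)]

    # 3) Rejection sampling, left to right.
    new_tokens: List[int] = []
    for i, token in enumerate(drafts):
        p, q = p_dists[i][token], q_dists[i][token]  # q > 0: token was drawn from q
        if rng.random() < min(1.0, p / q):
            new_tokens.append(token)
        else:
            # First rejection: resample from the residual, drop later drafts.
            new_tokens.append(sample_token(max_fn(p_dists[i], q_dists[i]), rng))
            return new_tokens, i

    # 4) Every draft accepted: take one more token from the target.
    new_tokens.append(sample_token(p_dists[gamma], rng))
    return new_tokens, gamma
```

With identical drafter and target, every draft is accepted and the step returns gamma + 1 tokens; with a drafter that always disagrees, the first draft is rejected and the residual distribution supplies the target's token instead.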

Key considerations:

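  • The gamma hyperparameter controls the number of drafts per step; higher gamma does not always mean faster generation, because each extra draft costs another drafter forward pass while later drafts are less likely to be reached and accepted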
  • The acceptance rate depends on how well the drafter approximates the target model
  • KV-cache usage is optional but can speed up generation (though it has known stability issues across model implementations)
  • Generation stops when an EOS token is produced or the maximum length is reached

Step 6: Decode Output

Convert the generated list of token IDs back into human-readable text using the tokenizer's decode method. Optionally skip special tokens in the output for cleaner results.

Key considerations:

  • Use skip_special_tokens=True for clean text output
  • The speculative_generate function returns both the generated token IDs and the acceptance rate
  • The acceptance rate is a useful diagnostic for tuning the gamma parameter
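The acceptance rate can feed back into the choice of gamma. The sketch below assumes the reported rate approximates the per-draft acceptance probability α (the repository's exact definition is not stated here) and that you have separately measured c, the time of one drafter step divided by one target forward pass; both function names are hypothetical. It evaluates the expected walltime improvement from Leviathan et al. (2023) under the same i.i.d.-acceptance simplification, (1 - α^(γ+1)) / ((1 - α)(γc + 1)), and picks the best gamma from a candidate range.

```python
from typing import Iterable, Tuple


def expected_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Expected walltime improvement over plain autoregressive decoding.

    alpha: per-draft acceptance probability (assumed i.i.d. across drafts)
    gamma: number of drafts per iteration
    cost_ratio: time of one drafter step / time of one target forward pass
    """
    if alpha >= 1.0:
        tokens_per_iteration = gamma + 1.0  # limit of the geometric sum at alpha = 1
    else:
        tokens_per_iteration = (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)
    # Each iteration costs gamma drafter steps plus one target forward pass.
    return tokens_per_iteration / (gamma * cost_ratio + 1.0)


def best_gamma(
    alpha: float, cost_ratio: float, candidates: Iterable[int] = range(1, 11)
) -> Tuple[int, float]:
    """Return the (gamma, expected speedup) pair that maximizes the estimate."""
    return max(
        ((g, expected_speedup(alpha, g, cost_ratio)) for g in candidates),
        key=lambda pair: pair[1],
    )
```

For instance, best_gamma(0.8, 0.05) favors gamma = 8 (about 3.09x by this estimate), while best_gamma(0.3, 0.1) favors gamma = 1, which illustrates why a poorly matched drafter should propose only a few tokens per step.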

GitHub URL

Workflow Repository