Implementation:Sgl project Sglang Speculative Decoding Ops

From Leeroopedia


Knowledge Sources
Domains Speculative Decoding, GPU Computing, LLM Inference
Last Updated 2026-02-10 00:00 GMT

Overview

Python interface for speculative decoding GPU kernels supporting tree-based draft verification, tree construction, and auxiliary tree mask operations.

Description

This module provides the speculative decoding operations API for the SGLang kernel library. It exposes five functions that delegate to custom CUDA kernels registered via torch.ops.sgl_kernel, enabling efficient tree-structured speculative decoding during LLM inference.

The core verification functions are tree_speculative_sampling_target_only and verify_tree_greedy. The sampling variant performs speculative sampling verification by comparing target model probabilities with draft model probabilities. Its acceptance behavior is tuned by the threshold_single and threshold_acc thresholds, and a deterministic flag (default True) controls whether the kernel runs deterministically. The greedy variant walks the tree from the root and accepts a draft token only when it matches the target model's prediction at its parent node.
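The greedy walk can be illustrated with a small pure-Python sketch for a single sequence. It assumes that retrive_next_token holds each node's first child and retrive_next_sibling its next sibling (-1 meaning none), and that node 0 is the root (the last verified token). The helper verify_tree_greedy_ref is hypothetical and only mirrors the idea; it is not the CUDA kernel's code.

```python
from typing import List, Tuple


def verify_tree_greedy_ref(
    candidates: List[int],
    next_token: List[int],
    next_sibling: List[int],
    target_predict: List[int],
) -> Tuple[List[int], List[int]]:
    """Greedy verification of one draft tree (hypothetical reference helper).

    candidates[i]     token stored at tree node i (node 0 is the root)
    next_token[i]     first child of node i, or -1
    next_sibling[i]   next sibling of node i, or -1
    target_predict[i] target model's argmax for the position after node i
    Returns (accepted node indices, emitted tokens).
    """
    accepted = [0]
    emitted: List[int] = []
    cur = 0
    while True:
        # Scan the children of `cur` for one that matches the target's choice.
        child = next_token[cur]
        while child != -1 and candidates[child] != target_predict[cur]:
            child = next_sibling[child]
        if child == -1:
            break  # no child matches: stop descending
        accepted.append(child)
        emitted.append(candidates[child])
        cur = child
    # The target's prediction after the last accepted node is always emitted.
    emitted.append(target_predict[cur])
    return accepted, emitted


# Root 0 has children 1 and 2; node 1 has child 3; node 2 has child 4.
result = verify_tree_greedy_ref(
    candidates=[100, 5, 7, 9, 11],
    next_token=[1, 3, 4, -1, -1],
    next_sibling=[-1, 2, -1, -1, -1],
    target_predict=[7, 42, 11, 0, 13],
)
print(result)  # ([0, 2, 4], [7, 11, 13])
```

Here node 1 is rejected, its sibling node 2 matches, and the walk continues through node 4 before emitting the target's token after it.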

build_tree_kernel_efficient constructs the tree data structures (mask, positions, retrieval indices, next-token pointers, and next-sibling pointers) from parent lists and selected indices, parameterized by topk, depth, draft_token_num, and tree_mask_mode.
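The relationship between a parent array and the pointer and mask outputs can be sketched in pure Python. This assumes a single sequence whose nodes are numbered so that every parent index is smaller than its child's (node 0 is the root, with parent -1), and it assumes positions are the verified length plus the node's depth. build_tree_ref is a hypothetical helper that mirrors the output layout in spirit only; it does not model topk, selected_index, the flat mask layout, or tree_mask_mode.

```python
from typing import List, Tuple


def build_tree_ref(
    parents: List[int], seq_len: int
) -> Tuple[List[int], List[int], List[int], List[List[bool]]]:
    """Derive pointers, positions, and an ancestor mask from a parent array.

    Hypothetical reference helper for one sequence. Assumes parents[0] == -1
    (node 0 is the root) and parents[i] < i for every other node.
    """
    n = len(parents)
    next_token = [-1] * n    # first child of each node
    next_sibling = [-1] * n  # next sibling of each node
    # Push each node onto the front of its parent's child list; visiting nodes
    # in descending order leaves every sibling chain in ascending index order.
    for i in range(n - 1, 0, -1):
        p = parents[i]
        next_sibling[i] = next_token[p]
        next_token[p] = i

    depth = [0] * n
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        if i > 0:
            depth[i] = depth[parents[i]] + 1  # parent was already processed
        # mask[i][j] is True when j is i itself or one of its ancestors,
        # i.e. when node i may attend to node j.
        j = i
        while j != -1:
            mask[i][j] = True
            j = parents[j]
    # Each node sits `depth` positions after the verified prefix.
    positions = [seq_len + d for d in depth]
    return next_token, next_sibling, positions, mask


next_token, next_sibling, positions, mask = build_tree_ref([-1, 0, 0, 1, 2], seq_len=128)
print(next_token)    # [1, 3, 4, -1, -1]
print(next_sibling)  # [-1, 2, -1, -1, -1]
print(positions)     # [128, 129, 129, 130, 130]
```

The first-child/next-sibling encoding keeps each node's children in a linked list, so a verifier can scan candidates at one level without a separate children array.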

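reconstruct_indices_from_tree_mask derives positions, retrieval indices, and next-token/next-sibling pointers from an existing tree mask; it is useful when the tree structure changes and these indices must be rebuilt. segment_packbits packs variable-length boolean segments into bits, eight entries per output byte, for compact tree mask storage.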
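A mask that marks each node's ancestors determines the whole tree: a node's depth is the number of ancestors in its row, and its parent is the ancestor exactly one level up. The hypothetical reconstruct_from_mask_ref below sketches that idea for one sequence with a dense square mask; the real kernel works on its own mask layout, and its exact method is not shown here.

```python
from typing import List, Tuple


def reconstruct_from_mask_ref(
    mask: List[List[bool]], seq_len: int
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Recover parents, positions, and child/sibling pointers from an ancestor mask.

    Hypothetical reference helper for one sequence: mask[i][j] is True when
    node j is node i itself or one of its ancestors; node 0 is the root.
    """
    n = len(mask)
    # A row marks the node plus all of its ancestors, so depth = count - 1.
    depth = [sum(row) - 1 for row in mask]
    parents = [-1] * n
    for i in range(n):
        for j in range(n):
            # The parent is the ancestor exactly one level above node i.
            if j != i and mask[i][j] and depth[j] == depth[i] - 1:
                parents[i] = j
                break
    next_token = [-1] * n
    next_sibling = [-1] * n
    for i in range(n - 1, -1, -1):
        p = parents[i]
        if p != -1:
            next_sibling[i] = next_token[p]
            next_token[p] = i
    positions = [seq_len + d for d in depth]
    return parents, positions, next_token, next_sibling


T, F = True, False
mask = [
    [T, F, F, F, F],
    [T, T, F, F, F],
    [T, F, T, F, F],
    [T, T, F, T, F],
    [T, F, T, F, T],
]
print(reconstruct_from_mask_ref(mask, seq_len=128))
# ([-1, 0, 0, 1, 2], [128, 129, 129, 130, 130], [1, 3, 4, -1, -1], [-1, 2, -1, -1, -1])
```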

All functions operate in-place on mutable tensor arguments and return None.

Usage

Use these operations when implementing tree-based speculative decoding in SGLang. The typical workflow is: (1) run the draft model for depth steps, keeping the topk candidates at each step together with their parent lists; (2) build the tree structure from those parent lists and the selected draft tokens with build_tree_kernel_efficient; (3) run the target model over the tree tokens using the resulting tree mask and positions; and (4) verify the candidates with either tree_speculative_sampling_target_only (stochastic) or verify_tree_greedy (greedy). The auxiliary functions reconstruct_indices_from_tree_mask and segment_packbits handle tree mask manipulation.
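A toy, pure-Python walk through the four steps for the simplest case: a chain-shaped tree (topk = 1) with greedy verification. draft_next and target_next are hypothetical stand-ins for the two models, and the target pass is simulated by one call per prefix rather than a single batched forward over the tree.

```python
from typing import Callable, List


def speculative_step(
    context: List[int],
    draft_next: Callable[[List[int]], int],
    target_next: Callable[[List[int]], int],
    depth: int,
) -> List[int]:
    """One round of chain-shaped (topk = 1) speculative decoding with greedy checks."""
    # (1) Draft: propose `depth` tokens autoregressively with the cheap model.
    draft: List[int] = []
    for _ in range(depth):
        draft.append(draft_next(context + draft))
    # (2) Tree: with topk = 1 the tree is a chain (node i's parent is node i - 1),
    #     so no explicit pointer arrays are needed here.
    # (3) Target: get the target's greedy choice after every draft prefix.
    #     A real system scores the whole tree in one forward pass.
    target_preds = [target_next(context + draft[:i]) for i in range(depth + 1)]
    # (4) Verify: keep draft tokens while they match the target's choice.
    accepted: List[int] = []
    for i, tok in enumerate(draft):
        if tok != target_preds[i]:
            break
        accepted.append(tok)
    # The target's choice after the last accepted token is always emitted.
    accepted.append(target_preds[len(accepted)])
    return accepted


def target_next(ctx: List[int]) -> int:
    return (ctx[-1] + 1) % 10  # toy "target model": count upward


def draft_next(ctx: List[int]) -> int:
    return 0 if ctx[-1] == 3 else (ctx[-1] + 1) % 10  # toy draft that errs after 3


print(speculative_step([1], draft_next, target_next, depth=4))  # [2, 3, 4]
```

The draft proposes [2, 3, 0, 1]; the first two tokens are accepted, and the target's own choice (4) replaces the first mismatch, so one round yields three tokens.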

Code Reference

Source Location

Signature

import torch

def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,           # mutable
    accept_index: torch.Tensor,       # mutable
    accept_token_num: torch.Tensor,   # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    ...

def verify_tree_greedy(
    predicts: torch.Tensor,           # mutable
    accept_index: torch.Tensor,       # mutable
    accept_token_num: torch.Tensor,   # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    ...

def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    ...

def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    ...

def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    ...

Import

from sgl_kernel import (
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
    build_tree_kernel_efficient,
    reconstruct_indices_from_tree_mask,
    segment_packbits,
)

I/O Contract

Inputs

tree_speculative_sampling_target_only

Name Type Required Description
predicts torch.Tensor Yes Output tensor for predicted token IDs (mutable, written in-place)
accept_index torch.Tensor Yes Output tensor for accepted token indices (mutable, written in-place)
accept_token_num torch.Tensor Yes Output tensor for the count of accepted tokens (mutable, written in-place)
candidates torch.Tensor Yes Draft candidate token IDs from the draft model
retrive_index torch.Tensor Yes Tree retrieval index mapping node positions
retrive_next_token torch.Tensor Yes Index of each node's first child in the tree (-1 for a leaf)
retrive_next_sibling torch.Tensor Yes Index of each node's next sibling in the tree (-1 for the last child)
uniform_samples torch.Tensor Yes Pre-generated uniform random samples for acceptance decisions
uniform_samples_for_final_sampling torch.Tensor Yes Pre-generated uniform random samples for final token sampling
target_probs torch.Tensor Yes Probability distribution from the target model
draft_probs torch.Tensor Yes Probability distribution from the draft model
threshold_single float No Per-token acceptance threshold; defaults to 1.0
threshold_acc float No Accumulated acceptance threshold; defaults to 1.0
deterministic bool No Whether to use deterministic sampling; defaults to True
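For reference, the standard single-token speculative sampling rule compares target and draft probabilities as follows: accept the draft token x with probability min(1, p(x)/q(x)); otherwise sample from the renormalized residual max(0, p - q). The sketch below is that textbook rule, not the kernel's code. How threshold_single, threshold_acc, and the tree traversal modify it is not modeled, u and u_final only play the roles suggested by uniform_samples and uniform_samples_for_final_sampling, and accept_or_resample is a hypothetical name.

```python
from typing import List, Tuple


def accept_or_resample(
    p: List[float], q: List[float], x: int, u: float, u_final: float
) -> Tuple[bool, int]:
    """Standard speculative sampling check for one draft token x (hypothetical helper).

    p, q: target and draft probabilities over the vocabulary; q[x] > 0 is assumed
    because the draft model sampled x. u, u_final: uniform samples in [0, 1).
    Returns (accepted, token).
    """
    # Accept with probability min(1, p[x] / q[x]).
    if u * q[x] <= p[x]:
        return True, x
    # On rejection, sample from the residual max(0, p - q), renormalized,
    # by inverse-CDF lookup with u_final.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    threshold = u_final * sum(residual)
    cumulative = 0.0
    for token, r in enumerate(residual):
        cumulative += r
        if threshold < cumulative:
            return False, token
    # Guard against rounding: fall back to the last token with residual mass.
    return False, max(i for i, r in enumerate(residual) if r > 0)


p = [0.1, 0.6, 0.3]  # target probabilities
q = [0.7, 0.2, 0.1]  # draft probabilities
print(accept_or_resample(p, q, x=0, u=0.1, u_final=0.5))  # (True, 0)
print(accept_or_resample(p, q, x=0, u=0.5, u_final=0.5))  # (False, 1)
```

Resampling from the residual rather than from p is what keeps the overall output distributed exactly as the target model would sample it.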

verify_tree_greedy

Name Type Required Description
predicts torch.Tensor Yes Output tensor for predicted token IDs (mutable, written in-place)
accept_index torch.Tensor Yes Output tensor for accepted token indices (mutable, written in-place)
accept_token_num torch.Tensor Yes Output tensor for the count of accepted tokens (mutable, written in-place)
candidates torch.Tensor Yes Draft candidate token IDs
retrive_index torch.Tensor Yes Tree retrieval index mapping node positions
retrive_next_token torch.Tensor Yes Index of each node's first child in the tree (-1 for a leaf)
retrive_next_sibling torch.Tensor Yes Index of each node's next sibling in the tree (-1 for the last child)
target_predict torch.Tensor Yes Target model greedy predictions (argmax token IDs)

build_tree_kernel_efficient

Name Type Required Description
parent_list torch.Tensor Yes Parent indices defining the tree topology
selected_index torch.Tensor Yes Indices of selected draft tokens
verified_seq_len torch.Tensor Yes Verified sequence lengths per batch element
tree_mask torch.Tensor Yes Output tree attention mask (mutable, written in-place)
positions torch.Tensor Yes Output position IDs for each tree node (mutable, written in-place)
retrive_index torch.Tensor Yes Output retrieval index (mutable, written in-place)
retrive_next_token torch.Tensor Yes Output first-child index per node, -1 for a leaf (mutable, written in-place)
retrive_next_sibling torch.Tensor Yes Output next-sibling index per node, -1 for the last child (mutable, written in-place)
topk int Yes Number of top-k candidates per tree level
depth int Yes Maximum depth of the draft tree
draft_token_num int Yes Total number of tree nodes, including the root (the last verified token)
tree_mask_mode int Yes Mode for generating the tree mask

reconstruct_indices_from_tree_mask

Name Type Required Description
tree_mask torch.Tensor Yes The tree attention mask
verified_seq_len torch.Tensor Yes Verified sequence lengths per batch element
positions torch.Tensor Yes Output position IDs (mutable, written in-place)
retrive_index torch.Tensor Yes Output retrieval index (mutable, written in-place)
retrive_next_token torch.Tensor Yes Output first-child index per node, -1 for a leaf (mutable, written in-place)
retrive_next_sibling torch.Tensor Yes Output next-sibling index per node, -1 for the last child (mutable, written in-place)
batch_size int Yes Number of sequences in the batch
draft_token_num int Yes Total number of draft tokens

segment_packbits

Name Type Required Description
x torch.Tensor Yes Input boolean tensor to pack
input_indptr torch.Tensor Yes Segment offsets into x (length batch_size + 1); segment i is x[input_indptr[i]:input_indptr[i+1]]
output_indptr torch.Tensor Yes Segment offsets into y, in packed bytes (length batch_size + 1)
y torch.Tensor Yes Output packed bits tensor (mutable, written in-place)
batch_size int Yes Number of segments to process
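A pure-Python sketch of segmented bit packing with CSR-style offsets: each segment x[input_indptr[i]:input_indptr[i+1]] is packed into ceil(len / 8) bytes, so the matching output_indptr is the running sum of those byte counts. The helper segment_packbits_ref is hypothetical. It returns the offsets instead of reading a precomputed output_indptr, and the little-endian bit order (element k of a group of eight goes to bit k) is an assumption.

```python
from typing import List, Tuple


def segment_packbits_ref(
    x: List[bool], input_indptr: List[int]
) -> Tuple[List[int], List[int]]:
    """Pack each boolean segment into bytes (hypothetical reference helper).

    Segment i is x[input_indptr[i]:input_indptr[i + 1]]. Each segment is padded
    to a whole number of bytes, so it occupies ceil(len / 8) output bytes.
    Returns (y, output_indptr), where y holds byte values 0..255.
    """
    y: List[int] = []
    output_indptr = [0]
    for i in range(len(input_indptr) - 1):  # batch_size = len(input_indptr) - 1
        segment = x[input_indptr[i]:input_indptr[i + 1]]
        for start in range(0, len(segment), 8):
            byte = 0
            for bit, flag in enumerate(segment[start:start + 8]):
                if flag:
                    byte |= 1 << bit  # assumed little-endian bit order
            y.append(byte)
        output_indptr.append(len(y))
    return y, output_indptr


T, F = True, False
x = [T, F, T, T, F, F, F, F, T,  # segment 0: 9 entries -> 2 bytes
     F, T]                       # segment 1: 2 entries -> 1 byte
print(segment_packbits_ref(x, input_indptr=[0, 9, 11]))  # ([13, 1, 2], [0, 2, 3])
```

Padding each segment to a byte boundary lets every segment start at a whole-byte offset, which is why output_indptr can be computed from the segment lengths alone.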

Outputs

Name Type Description
(all functions) None All functions operate in-place on mutable tensor arguments and return None

Usage Examples

import torch
from sgl_kernel import build_tree_kernel_efficient, verify_tree_greedy

# Example: build a draft tree for one sequence, then verify it greedily.
# Inputs marked as placeholders would come from the draft and target models.
device = "cuda"
batch_size = 1
topk = 4
depth = 3
draft_token_num = 8  # tree nodes including the root; configured separately from topk and depth

verified_seq_len = torch.tensor([128], dtype=torch.int64, device=device)
seq_lens_sum = int(verified_seq_len.sum())

# Placeholder draft-tree topology (normally produced by the draft model).
parent_list = torch.zeros(batch_size, topk * (depth - 1) + 1, dtype=torch.int64, device=device)
selected_index = torch.arange(draft_token_num - 1, dtype=torch.int64, device=device).unsqueeze(0)

# Output buffers, written in-place. Assumes tree_mask_mode=0 is the full-mask layout:
# one flat bool buffer covering each sequence's verified prefix plus its draft tokens.
tree_mask = torch.full(
    (seq_lens_sum * draft_token_num + draft_token_num * draft_token_num * batch_size,),
    True, dtype=torch.bool, device=device,
)
positions = torch.empty(batch_size * draft_token_num, dtype=torch.int64, device=device)
retrive_index = torch.full((batch_size, draft_token_num), -1, dtype=torch.int64, device=device)
retrive_next_token = torch.full((batch_size, draft_token_num), -1, dtype=torch.int64, device=device)
retrive_next_sibling = torch.full((batch_size, draft_token_num), -1, dtype=torch.int64, device=device)

build_tree_kernel_efficient(
    parent_list, selected_index, verified_seq_len,
    tree_mask, positions, retrive_index,
    retrive_next_token, retrive_next_sibling,
    topk=topk, depth=depth,
    draft_token_num=draft_token_num, tree_mask_mode=0,
)

# After the target model has scored the tree tokens, verify greedily.
predicts = torch.zeros(batch_size * draft_token_num, dtype=torch.int32, device=device)
accept_index = torch.full((batch_size, depth + 1), -1, dtype=torch.int32, device=device)
accept_token_num = torch.zeros(batch_size, dtype=torch.int32, device=device)
# Placeholder token IDs standing in for the draft tokens and the target's argmax predictions.
candidates = torch.randint(0, 32000, (batch_size, draft_token_num), dtype=torch.int64, device=device)
target_predict = torch.randint(0, 32000, (batch_size, draft_token_num), dtype=torch.int64, device=device)

verify_tree_greedy(
    predicts, accept_index, accept_token_num,
    candidates, retrive_index,
    retrive_next_token, retrive_next_sibling,
    target_predict,
)
