Implementation:Sgl project Sglang Speculative Decoding Ops
| Knowledge Sources | |
|---|---|
| Domains | Speculative Decoding, GPU Computing, LLM Inference |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python interface for speculative decoding GPU kernels supporting tree-based draft verification, tree construction, and auxiliary tree mask operations.
Description
This module provides the speculative decoding operations API for the SGLang kernel library. It exposes five functions that delegate to custom CUDA kernels registered via torch.ops.sgl_kernel, enabling efficient tree-structured speculative decoding during LLM inference.
The core verification functions are tree_speculative_sampling_target_only and verify_tree_greedy. The sampling variant performs speculative sampling verification using target model probabilities against draft model probabilities, with configurable threshold_single and threshold_acc acceptance thresholds and optional deterministic behavior. The greedy variant performs straightforward greedy verification by comparing draft candidate tokens against target model predictions.
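The greedy variant can be pictured as a walk over the first-child and next-sibling pointers. The following is a minimal pure-Python sketch for a single sequence, not the CUDA kernel itself: the function name is hypothetical, node 0 is assumed to be the root, the retrieval-index indirection is omitted, and the assumption that `target_predict[i]` holds the target model's argmax token following node `i` is an interpretation of the description above.

```python
def verify_tree_greedy_reference(candidates, next_token, next_sibling, target_predict):
    """Walk one draft tree and return (accepted node indices, bonus token).

    candidates[i]     : token ID proposed at tree node i (node 0 is the root)
    next_token[i]     : first child of node i, or -1 if it has none
    next_sibling[i]   : next sibling of node i, or -1 if it is the last one
    target_predict[i] : target model's argmax token following node i (assumed)
    """
    accepted = [0]  # the root is always kept
    node = next_token[0]
    while node != -1:
        if candidates[node] == target_predict[accepted[-1]]:
            accepted.append(node)      # match: descend into this node's children
            node = next_token[node]
        else:
            node = next_sibling[node]  # mismatch: try the next sibling
    # The target's prediction after the last accepted node comes for free.
    bonus_token = target_predict[accepted[-1]]
    return accepted, bonus_token


# Tree: root 0 has children 1 and 2; node 2 has child 3.
next_token = [1, -1, 3, -1]
next_sibling = [-1, 2, -1, -1]
candidates = [10, 11, 12, 13]
target_predict = [12, 99, 13, 42]

accepted, bonus_token = verify_tree_greedy_reference(
    candidates, next_token, next_sibling, target_predict
)
print(accepted, bonus_token)  # [0, 2, 3] 42
```

In this sketch the accepted list plays the role of accept_index and `len(accepted) - 1` the role of accept_token_num; the real functions write such results into the caller's tensors instead of returning them.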
build_tree_kernel_efficient constructs the tree data structures (mask, positions, retrieval indices, next-token pointers, and next-sibling pointers) from parent lists and selected indices, parameterized by topk, depth, draft_token_num, and tree_mask_mode.
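To make these data structures concrete, here is a small pure-Python sketch that derives an ancestor mask, per-node depths, and first-child and next-sibling pointers from a parent array. It is an illustration only: the function name and the simplified input (one parent index per node, -1 for the root, parents listed before their children) are assumptions and do not match the kernel's actual parent_list and selected_index layout.

```python
def build_tree_reference(parents):
    """Derive tree structures from a simplified parent array.

    parents[i] is the parent of node i; the root (node 0) has parent -1.
    Assumes every parent appears before its children.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]  # mask[i][j]: node i may attend to node j
    depth = [0] * n
    next_token = [-1] * n    # first child of each node
    next_sibling = [-1] * n  # next sibling of each node
    last_child = [-1] * n    # most recently linked child of each node

    for i in range(n):
        mask[i][i] = True  # every node sees itself
        p = parents[i]
        if p == -1:
            continue
        depth[i] = depth[p] + 1
        for j in range(n):  # a node also sees everything its parent sees
            if mask[p][j]:
                mask[i][j] = True
        if last_child[p] == -1:
            next_token[p] = i
        else:
            next_sibling[last_child[p]] = i
        last_child[p] = i
    return mask, depth, next_token, next_sibling


# Root 0 has children 1 and 2; node 2 has child 3.
mask, depth, next_token, next_sibling = build_tree_reference([-1, 0, 0, 2])
print(depth)         # [0, 1, 1, 2]
print(next_token)    # [1, -1, 3, -1]
print(next_sibling)  # [-1, 2, -1, -1]
print(mask[3])       # [True, False, True, True]
```

A plausible reading of the positions output is the verified sequence length plus each node's depth, but that is an assumption rather than something this page states.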
reconstruct_indices_from_tree_mask rebuilds retrieval indices from a tree mask, useful when the tree structure changes. segment_packbits packs boolean segments into bitwise representations for compact tree mask storage.
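The idea behind segment-wise bit packing can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the kernel: the function name is hypothetical, the bit order inside each byte (little-endian by default here) is an assumption, and the sketch computes and returns the output pointers, whereas segment_packbits receives output_indptr and y from the caller and fills y in-place.

```python
def segment_packbits_reference(x, input_indptr, bitorder="little"):
    """Pack each boolean segment of x into bytes, padding every segment separately.

    input_indptr[s]:input_indptr[s + 1] delimits segment s of x.
    Returns (packed bytes as ints, output_indptr delimiting each packed segment).
    """
    packed = []
    output_indptr = [0]
    for s in range(len(input_indptr) - 1):
        segment = x[input_indptr[s]:input_indptr[s + 1]]
        for start in range(0, len(segment), 8):
            byte = 0
            for k, bit in enumerate(segment[start:start + 8]):
                if bit:
                    shift = k if bitorder == "little" else 7 - k
                    byte |= 1 << shift
            packed.append(byte)
        output_indptr.append(len(packed))
    return packed, output_indptr


# Two segments: 3 bits and 10 bits.
x = [True, False, True] + [True] * 9 + [False]
packed, output_indptr = segment_packbits_reference(x, [0, 3, 13])
print(packed)         # [5, 255, 1]
print(output_indptr)  # [0, 1, 3]
```

Because each segment is padded to a whole number of bytes on its own, a segment of n bits occupies ceil(n / 8) bytes in the output, which is why the output needs its own pointer array.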
All functions operate in-place on mutable tensor arguments and return None.
Usage
Use these operations when implementing tree-based speculative decoding in SGLang. A typical workflow is: (1) run the draft model to produce candidate tokens together with their parent lists and selected indices, (2) build the tree structure from those with build_tree_kernel_efficient, (3) run the target model over the tree, and (4) verify the candidates with either tree_speculative_sampling_target_only (stochastic sampling) or verify_tree_greedy (greedy matching). The auxiliary functions reconstruct_indices_from_tree_mask and segment_packbits handle tree mask manipulation where needed.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/speculative.py
- Lines: 1-127
Signature
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    ...
def verify_tree_greedy(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    ...
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    ...
def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    ...
def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    ...
Import
from sgl_kernel import (
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
    build_tree_kernel_efficient,
    reconstruct_indices_from_tree_mask,
    segment_packbits,
)
I/O Contract
Inputs
tree_speculative_sampling_target_only
| Name | Type | Required | Description |
|---|---|---|---|
| predicts | torch.Tensor | Yes | Output tensor for predicted token IDs (mutable, written in-place) |
| accept_index | torch.Tensor | Yes | Output tensor for accepted token indices (mutable, written in-place) |
| accept_token_num | torch.Tensor | Yes | Output tensor for the count of accepted tokens (mutable, written in-place) |
| candidates | torch.Tensor | Yes | Draft candidate token IDs from the draft model |
| retrive_index | torch.Tensor | Yes | Tree retrieval index mapping node positions |
| retrive_next_token | torch.Tensor | Yes | Pointer to the next token in the tree traversal |
| retrive_next_sibling | torch.Tensor | Yes | Pointer to the next sibling in the tree traversal |
| uniform_samples | torch.Tensor | Yes | Pre-generated uniform random samples for acceptance decisions |
| uniform_samples_for_final_sampling | torch.Tensor | Yes | Pre-generated uniform random samples for final token sampling |
| target_probs | torch.Tensor | Yes | Probability distribution from the target model |
| draft_probs | torch.Tensor | Yes | Probability distribution from the draft model |
| threshold_single | float | No | Per-token acceptance threshold; defaults to 1.0 |
| threshold_acc | float | No | Accumulated acceptance threshold; defaults to 1.0 |
| deterministic | bool | No | Whether to use deterministic sampling; defaults to True |
verify_tree_greedy
| Name | Type | Required | Description |
|---|---|---|---|
| predicts | torch.Tensor | Yes | Output tensor for predicted token IDs (mutable, written in-place) |
| accept_index | torch.Tensor | Yes | Output tensor for accepted token indices (mutable, written in-place) |
| accept_token_num | torch.Tensor | Yes | Output tensor for the count of accepted tokens (mutable, written in-place) |
| candidates | torch.Tensor | Yes | Draft candidate token IDs |
| retrive_index | torch.Tensor | Yes | Tree retrieval index mapping node positions |
| retrive_next_token | torch.Tensor | Yes | Pointer to the next token in the tree traversal |
| retrive_next_sibling | torch.Tensor | Yes | Pointer to the next sibling in the tree traversal |
| target_predict | torch.Tensor | Yes | Target model greedy predictions (argmax token IDs) |
build_tree_kernel_efficient
| Name | Type | Required | Description |
|---|---|---|---|
| parent_list | torch.Tensor | Yes | Parent indices defining the tree topology |
| selected_index | torch.Tensor | Yes | Indices of selected draft tokens |
| verified_seq_len | torch.Tensor | Yes | Verified sequence lengths per batch element |
| tree_mask | torch.Tensor | Yes | Output tree attention mask (mutable, written in-place) |
| positions | torch.Tensor | Yes | Output position IDs for each tree node (mutable, written in-place) |
| retrive_index | torch.Tensor | Yes | Output retrieval index (mutable, written in-place) |
| retrive_next_token | torch.Tensor | Yes | Output next-token pointer (mutable, written in-place) |
| retrive_next_sibling | torch.Tensor | Yes | Output next-sibling pointer (mutable, written in-place) |
| topk | int | Yes | Top-k value: number of candidates kept at each tree level |
| depth | int | Yes | Maximum depth of the draft tree |
| draft_token_num | int | Yes | Total number of draft tokens in the tree |
| tree_mask_mode | int | Yes | Mode for generating the tree mask |
reconstruct_indices_from_tree_mask
| Name | Type | Required | Description |
|---|---|---|---|
| tree_mask | torch.Tensor | Yes | The tree attention mask |
| verified_seq_len | torch.Tensor | Yes | Verified sequence lengths per batch element |
| positions | torch.Tensor | Yes | Output position IDs (mutable, written in-place) |
| retrive_index | torch.Tensor | Yes | Output retrieval index (mutable, written in-place) |
| retrive_next_token | torch.Tensor | Yes | Output next-token pointer (mutable, written in-place) |
| retrive_next_sibling | torch.Tensor | Yes | Output next-sibling pointer (mutable, written in-place) |
| batch_size | int | Yes | Number of sequences in the batch |
| draft_token_num | int | Yes | Total number of draft tokens |
segment_packbits
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input boolean tensor to pack |
| input_indptr | torch.Tensor | Yes | Input segment boundary pointers |
| output_indptr | torch.Tensor | Yes | Output segment boundary pointers |
| y | torch.Tensor | Yes | Output packed bits tensor (mutable, written in-place) |
| batch_size | int | Yes | Number of segments to process |
Outputs
| Name | Type | Description |
|---|---|---|
| (all functions) | None | All functions operate in-place on mutable tensor arguments and return None |
Usage Examples
import torch
from sgl_kernel import build_tree_kernel_efficient, verify_tree_greedy
# Example: build a tree structure for speculative decoding.
# Shapes, dtypes, and values here are illustrative placeholders; check them
# against the kernels in your installed sgl_kernel version before relying on them.
batch_size = 1
topk = 4
depth = 3
draft_token_num = topk * depth  # 12 draft tokens
# Tree inputs (placeholder values; in practice these come from the draft model)
parent_list = torch.zeros(batch_size, draft_token_num, dtype=torch.int32, device="cuda")
selected_index = torch.arange(draft_token_num, dtype=torch.int32, device="cuda").unsqueeze(0)
verified_seq_len = torch.tensor([128], dtype=torch.int32, device="cuda")
# Tree outputs, written in-place by build_tree_kernel_efficient
tree_mask = torch.zeros(batch_size, draft_token_num, draft_token_num, dtype=torch.bool, device="cuda")
positions = torch.zeros(batch_size, draft_token_num, dtype=torch.int32, device="cuda")
retrive_index = torch.zeros(batch_size, draft_token_num, dtype=torch.int32, device="cuda")
retrive_next_token = torch.full((batch_size, draft_token_num), -1, dtype=torch.int32, device="cuda")
retrive_next_sibling = torch.full((batch_size, draft_token_num), -1, dtype=torch.int32, device="cuda")
# Build the tree
build_tree_kernel_efficient(
    parent_list, selected_index, verified_seq_len,
    tree_mask, positions, retrive_index,
    retrive_next_token, retrive_next_sibling,
    topk=topk, depth=depth,
    draft_token_num=draft_token_num, tree_mask_mode=0,
)
# After running the draft and target models, verify greedily.
# Verification outputs, written in-place by verify_tree_greedy
predicts = torch.zeros(batch_size, draft_token_num + 1, dtype=torch.int32, device="cuda")
accept_index = torch.zeros(batch_size, draft_token_num + 1, dtype=torch.int32, device="cuda")
accept_token_num = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
# Random stand-ins for the draft tokens and the target model's argmax predictions
candidates = torch.randint(0, 32000, (batch_size, draft_token_num), dtype=torch.int32, device="cuda")
target_predict = torch.randint(0, 32000, (batch_size, draft_token_num), dtype=torch.int32, device="cuda")
verify_tree_greedy(
    predicts, accept_index, accept_token_num,
    candidates, retrive_index,
    retrive_next_token, retrive_next_sibling,
    target_predict,
)