Implementation:SqueezeAILab ETS Reward Guided Search
| Knowledge Sources | |
|---|---|
| Domains | Tree_Search, Inference_Time_Compute, Mathematical_Reasoning |
| Last Updated | 2026-02-14 02:00 GMT |
Overview
A concrete tool from the ETS repository for executing the ETS reward-guided tree search on a single math problem.
Description
The reward_guided_search function is decorated with SGLang's @function decorator and orchestrates the full tree search loop for a single question. It initializes a Tree object, then iterates through depth levels calling select_and_expand and insert until the search budget is exhausted. After search completes, it collects all leaf nodes as candidate answers with their step-by-step PRM scores.
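The loop described above can be sketched as follows. This is a hypothetical, self-contained illustration of the control flow only: the `Node` class, the dummy scores, and the pruning rule are stand-ins, not the actual `Tree`, `select_and_expand`, or `insert` implementations from `rebase.py`.

```python
# Hypothetical sketch of the per-question search loop in reward_guided_search.
# Real PRM scores come from the reward server; here a dummy decay is used.
class Node:
    def __init__(self, text, score, depth):
        self.text, self.score, self.depth = text, score, depth
        self.children = []

def search_sketch(question, max_depth=3, width=2):
    root = Node(question, 1.0, 0)
    frontier = [root]
    for depth in range(1, max_depth + 1):
        expanded = []
        for node in frontier:                      # "select_and_expand" stage
            for k in range(width):
                child = Node(f"{node.text}/step{depth}.{k}",
                             node.score * 0.9, depth)  # dummy PRM-like score
                node.children.append(child)            # "insert" into the tree
                expanded.append(child)
        # budget control: keep only the `width` highest-scoring nodes
        frontier = sorted(expanded, key=lambda n: -n.score)[:width]
    # surviving leaves become the candidate answers with their step scores
    return [{"text": n.text, "step_scores": [n.score]} for n in frontier]

candidates = search_sketch("toy question")
```

The real function additionally streams generations through SGLang and writes the per-question result file; only the tree-loop shape is shown here.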
For batch execution across multiple questions, reward_guided_search.run_batch() processes questions in parallel using SGLang's multi-threaded runtime.
Usage
Called via run_batch from main() in rebase.py. Each invocation processes one question with a dedicated tree instance.
Code Reference
Source Location
- Repository: ETS
- File: rebase.py
- Lines: 639-698 (function body), 746 (run_batch invocation)
Signature
```python
@function
def reward_guided_search(s, id, question, ground_truth_answer, paras, reward_host, multimodel, model_config):
    """
    Execute tree search for a single math problem.

    Args:
        s: SGLang state object (injected by @function decorator)
        id (int): Question index for output file naming
        question (str): Math problem text
        ground_truth_answer (dict): Dictionary with "answer" key containing reference answer
        paras (dict): Hyperparameter dictionary from YAML config
        reward_host (RuntimeEndpoint): PRM server endpoint
        multimodel (SentenceTransformer or None): Embedding model for diversity (None if lambdas=0)
        model_config (AutoConfig): Policy model configuration for tokenizer initialization

    Returns:
        dict: {"id", "question", "model_answer": [{"text", "step_scores"}],
               "ground_truth_answer", "total_tokens"}
    """
```
Import
```python
from sglang import function, gen, RuntimeEndpoint
# reward_guided_search is defined in rebase.py, not imported from a package
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| s | SGLang state | Yes | Injected by @function decorator; represents current generation state |
| id | int | Yes | Question index, used for naming output file (answer_q{id}.json) |
| question | str | Yes | Math problem text to solve |
| ground_truth_answer | dict | Yes | Dictionary with "answer" key for ground truth |
| paras | dict | Yes | Hyperparameters from YAML config (width, select_method, temperatures, etc.) |
| reward_host | RuntimeEndpoint | Yes | Endpoint for PRM server (e.g., http://localhost:30020) |
| multimodel | SentenceTransformer or None | Yes | Embedding model for diversity scoring; None if lambdas=0 |
| model_config | AutoConfig | Yes | Policy model config for tokenizer initialization in Tree.__init__ |
Outputs
| Name | Type | Description |
|---|---|---|
| return value | dict | Per-question result with keys: id, question, model_answer (list of candidates), ground_truth_answer, total_tokens |
| answer_q{id}.json | File | Per-question JSON file written to paras["store_path"] |
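The output schema above can be exercised with a small round-trip. This sketch fabricates a result dict in the documented shape, writes it under the `answer_q{id}.json` naming convention, and reads it back; the "pick the candidate with the highest final step score" rule is just one reasonable selection choice, not necessarily what ETS does.

```python
import json
import os
import tempfile

# Fabricated result in the documented return-value shape
result = {
    "id": 0,
    "question": "What is 1 + 1?",
    "model_answer": [
        {"text": "answer A", "step_scores": [0.4, 0.9]},
        {"text": "answer B", "step_scores": [0.8, 0.3]},
    ],
    "ground_truth_answer": {"answer": "2"},
    "total_tokens": 128,
}

# Write to paras["store_path"]/answer_q{id}.json (temp dir stands in here)
store_path = tempfile.mkdtemp()
path = os.path.join(store_path, f"answer_q{result['id']}.json")
with open(path, "w") as f:
    json.dump(result, f)

# Read it back and pick the candidate with the highest final PRM step score
with open(path) as f:
    loaded = json.load(f)
best = max(loaded["model_answer"], key=lambda c: c["step_scores"][-1])
```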
Usage Examples
Batch Execution (from main)
```python
from sglang import RuntimeEndpoint
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig

# Build input list for batch processing
input_list_dict = []
for i, prompt in enumerate(prompts):
    input_list_dict.append({
        "id": i,
        "question": prompt,
        "ground_truth_answer": test_examples[i],
        "paras": paras,
        "reward_host": RuntimeEndpoint(args.reward_host),
        "multimodel": multimodel,
        "model_config": model_config,
    })

# Run tree search across all questions in parallel
states = reward_guided_search.run_batch(
    input_list_dict,
    backend=RuntimeEndpoint(args.policy_host),
    num_threads=paras["num_threads"],
    progress_bar=True,
)

# Collect results
results = []
for s in states:
    results.append(s.ret_value)
```
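Once per-question results are collected, a single answer per question still has to be chosen among the candidates. Score-weighted voting over the candidates' final PRM step scores is one common choice for this step; the actual selection logic in `rebase.py` may differ, so treat this as an illustrative post-processing sketch.

```python
from collections import defaultdict

def select_answer(result):
    """Pick one answer by summing each candidate's final step score per answer text."""
    votes = defaultdict(float)
    for cand in result["model_answer"]:
        # weight each occurrence of an answer by its last PRM step score
        votes[cand["text"]] += cand["step_scores"][-1]
    return max(votes, key=votes.get)

# Fabricated per-question result: "42" appears twice with moderate scores,
# "41" once with a high score; weighted voting favors "42" (0.7 + 0.6 > 0.9).
example = {"model_answer": [
    {"text": "42", "step_scores": [0.7]},
    {"text": "42", "step_scores": [0.6]},
    {"text": "41", "step_scores": [0.9]},
]}
```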