Implementation:SqueezeAILab ETS Tree Select Softmax Costmodel
| Knowledge Sources | |
|---|---|
| Domains | Tree_Search, Optimization, Integer_Linear_Programming |
| Last Updated | 2026-02-14 02:00 GMT |
Overview
Concrete tool for ILP-based node selection in the ETS tree search, provided by the Tree class in rebase.py.
Description
The Tree.select_softmax_costmodel() method implements the ETS paper's core contribution. It:
- Computes a branch-to-leaf mapping by traversing parent pointers from each leaf
- Formulates an ILP using PuLP with binary variables for each leaf and branch node
- Computes softmax-weighted outcome scores from the PRM (process reward model) scores
- Optionally adds diversity coverage constraints using SentenceTransformer embeddings and hierarchical clustering
- Solves the ILP with the CBC solver
- Re-applies softmax proportional allocation to the retained nodes to determine expansion widths
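The first step, building the branch-to-leaf mapping by walking parent pointers, can be sketched as follows. `Node` is a hypothetical stand-in for `TreeNode` that carries only a `parent` attribute (the attribute name is an assumption), and whether rebase.py also treats the root as a branch node is not confirmed here.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # eq=False keeps identity hashing so nodes can be dict keys
class Node:
    # Hypothetical stand-in for TreeNode; only the parent pointer matters here.
    name: str
    parent: Optional["Node"] = None


def branch_leaf_mapping(leaves):
    """Map every ancestor (branch) node to the indices of the leaves beneath it."""
    mapping = defaultdict(list)
    for i, leaf in enumerate(leaves):
        node = leaf.parent
        while node is not None:  # walk parent pointers up to the root
            mapping[node].append(i)
            node = node.parent
    return dict(mapping)


# Example tree: root -> a -> (l0, l1), root -> b -> l2
root = Node("root")
a, b = Node("a", root), Node("b", root)
leaves = [Node("l0", a), Node("l1", a), Node("l2", b)]
mapping = branch_leaf_mapping(leaves)
print({n.name: idx for n, idx in mapping.items()})
# {'a': [0, 1], 'root': [0, 1, 2], 'b': [2]}
```

Shared ancestors collect every leaf below them, which is what lets a single branch variable `y[j]` represent a KV-cache prefix reused by several retained leaves.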
Usage
Called automatically by Tree.select_and_expand() when select_method == "softmax_costmodel". Not called directly by user code.
Code Reference
Source Location
- Repository: ETS
- File: rebase.py
- Lines: 373-569
Signature
```python
def select_softmax_costmodel(self, node_list, node_weights, width, depth):
    """
    ILP-based node selection with cost and diversity optimization.

    Args:
        node_list (list[TreeNode]): Candidate leaf nodes at current depth
        node_weights (list[float]): PRM scores for each candidate node
        width (int): Remaining search budget
        depth (int): Current tree depth

    Returns:
        tuple[list[TreeNode], list[int]]:
            - nodes: All candidate nodes (retained + pruned) sorted by score
            - select_num: Expansion width per node (0 for pruned nodes)
    """
    ...
```
Import
```python
# Internal method of Tree class in rebase.py
# External dependencies:
from pulp import LpMaximize, LpProblem, LpVariable, lpSum, PULP_CBC_CMD
from sentence_transformers import SentenceTransformer
from scipy.cluster.hierarchy import linkage, fcluster
import numpy as np
import torch
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| node_list | list[TreeNode] | Yes | Candidate nodes at the current depth, filtered to non-leaf, non-exhausted nodes |
| node_weights | list[float] | Yes | PRM scores for each candidate (from TreeNode.get_score()) |
| width | int | Yes | Remaining search budget (self.remaining_width) |
| depth | int | Yes | Current tree depth level |
Instance attributes used:
| Attribute | Type | Description |
|---|---|---|
| self.lambdac | float | Cost penalty weight (from config) |
| self.lambdas | float | Diversity weight applied to cluster coverage in the objective (from config) |
| self.model | SentenceTransformer or None | Embedding model for diversity scoring |
| self.paras["softmax_temperature"] | float | Temperature for softmax allocation |
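When `self.model` is set, candidates are grouped by semantic similarity so the ILP can reward covering several clusters, weighted by `self.lambdas`. A minimal sketch of the grouping step, assuming embeddings such as those returned by `SentenceTransformer.encode()`; the average linkage, cosine metric, and `distance_threshold` value are illustrative assumptions, not settings taken from rebase.py.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage


def diversity_clusters(embeddings, distance_threshold=0.5):
    """Group candidates into clusters of semantically similar reasoning steps.

    embeddings: (N, d) array, one row per candidate node.
    distance_threshold: hypothetical cut-off on cosine distance.
    """
    Z = linkage(embeddings, method="average", metric="cosine")
    labels = fcluster(Z, t=distance_threshold, criterion="distance")  # 1-based labels
    clusters = {}
    for i, label in enumerate(labels):
        clusters.setdefault(int(label), []).append(i)  # cluster label -> node indices
    return list(clusters.values()), labels


# Two near-duplicate pairs of toy embeddings
emb = np.array([[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.01, 0.99]])
clusters, labels = diversity_clusters(emb)
print(clusters)  # [[0, 1], [2, 3]]
```

Each resulting cluster would then get one binary coverage variable in the ILP; how the source links those variables to the leaf decisions is not spelled out on this page.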
Outputs
| Name | Type | Description |
|---|---|---|
| nodes | list[TreeNode] | All candidate nodes sorted by score (descending) |
| select_num | list[int] | Expansion width for each node (0 = pruned) |
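The `select_num` widths come from re-applying softmax proportional allocation to the nodes the ILP retained. A sketch assuming a temperature-scaled softmax (the role of `softmax_temperature`) and Python's `round()`; the function name is hypothetical, and the exact rounding and remainder handling in rebase.py are not shown here, so these widths may not sum exactly to `width`.

```python
import math


def softmax_widths(scores, retained, width, temperature=1.0):
    """Split `width` expansions over retained nodes in proportion to softmax(score / T).

    Pruned nodes (retained[i] is False) get 0. Assumes at least one node is
    retained, which the ILP's keep-at-least-one constraint guarantees.
    """
    idx = [i for i, keep in enumerate(retained) if keep]
    logits = [scores[i] / temperature for i in idx]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    widths = [0] * len(scores)
    for i, e in zip(idx, exps):
        widths[i] = round(width * e / total)  # round() rounds half to even
    return widths


print(softmax_widths([0.9, 0.9, 0.4], [True, True, False], width=4))  # [2, 2, 0]
```

Lower temperatures concentrate the budget on the highest-scoring retained nodes; higher temperatures spread it more evenly.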
Usage Examples
Called Internally by select_and_expand
```python
# From Tree.select_and_expand() at rebase.py:L622-623
if self.paras["select_method"] == "softmax_costmodel":
    nodes, widths = self.select_softmax_costmodel(
        cand_node_list,        # non-leaf, non-exhausted nodes at current depth
        cand_node_weights,     # PRM scores
        self.remaining_width,  # budget remaining
        depth,                 # current depth
    )

# Then expand each retained node
for expand_node, width in zip(nodes, widths):
    if width >= 1:
        self.expand(expand_node, width)
```
ILP Structure (Conceptual)
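```python
# Conceptual view of the ILP formulation (a simplified sketch, not the verbatim source)
from pulp import LpMaximize, LpProblem, LpVariable, lpSum, PULP_CBC_CMD

# Assumed inputs: O (N softmax-weighted outcome scores), Cost (M branch costs followed
# by N leaf costs), branch_leaf_mapping (branch j -> leaf indices), clusters (K lists of
# leaf indices), lambdac (negative, so cost terms act as penalties), lambdas (diversity)
N, M, K = len(O), len(branch_leaf_mapping), len(clusters)

# Binary decision variables
x = [LpVariable(f"x_{i}", cat="Binary") for i in range(N)]  # keep leaf i?
y = [LpVariable(f"y_{j}", cat="Binary") for j in range(M)]  # keep branch j?
coverage = [LpVariable(f"cov_{k}", cat="Binary") for k in range(K)]  # cluster k covered?

problem = LpProblem("Tree_Selection_Problem", LpMaximize)

# Objective: outcome scores + cost penalty + diversity coverage
problem += (
    lpSum(O[i] * x[i] for i in range(N))                      # maximize outcomes
    + lpSum(lambdac * Cost[j] * y[j] for j in range(M))       # branch cost (negative lambdac)
    + lpSum(lambdac * Cost[M + i] * x[i] for i in range(N))   # leaf cost (negative lambdac)
    + lpSum(lambdas * coverage[k] for k in range(K))          # diversity reward
)

# Constraint: branch j is active iff at least one leaf beneath it is active
for j, leaf_nodes in branch_leaf_mapping.items():
    problem += y[j] <= lpSum(x[i] for i in leaf_nodes)
    for i in leaf_nodes:
        problem += y[j] >= x[i]

# Constraint: a cluster counts as covered only if one of its leaves is kept (assumed form)
for k, members in enumerate(clusters):
    problem += coverage[k] <= lpSum(x[i] for i in members)

# Constraint: keep at least one node
problem += lpSum(x[i] for i in range(N)) >= 1

problem.solve(PULP_CBC_CMD(msg=0))
kept = [i for i in range(N) if x[i].value() > 0.5]  # indices of retained leaves
```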
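To sanity-check the conceptual objective on a toy instance without a solver, it can be maximized by exhaustive enumeration. This is a brute-force stand-in for CBC (exponential in N, so only for tiny cases), not what rebase.py does. The inputs are made up for illustration, `lambdac` is negative so the cost terms act as penalties, and a cluster is assumed to count as covered whenever any of its leaves is kept.

```python
from itertools import product


def brute_force_select(O, Cost, branch_leaf_mapping, clusters, lambdac, lambdas):
    """Exhaustively maximize the conceptual tree-selection objective (tiny N only)."""
    N, M = len(O), len(branch_leaf_mapping)
    best, best_x = float("-inf"), None
    for x in product((0, 1), repeat=N):
        if sum(x) < 1:  # keep at least one node
            continue
        # y[j] is forced to 1 iff some leaf under branch j is kept
        y = [int(any(x[i] for i in branch_leaf_mapping[j])) for j in range(M)]
        cover = [int(any(x[i] for i in c)) for c in clusters]  # assumed coverage rule
        value = (
            sum(O[i] * x[i] for i in range(N))
            + sum(lambdac * Cost[j] * y[j] for j in range(M))
            + sum(lambdac * Cost[M + i] * x[i] for i in range(N))
            + sum(lambdas * c for c in cover)
        )
        if value > best:
            best, best_x = value, x
    return best_x, best


O = [0.6, 0.2, 0.2]            # made-up softmax-weighted scores for 3 leaves
Cost = [1, 1, 1, 1, 1]         # 2 branch costs, then 3 leaf costs (illustrative)
mapping = {0: [0, 1], 1: [2]}  # branch 0 covers leaves 0-1, branch 1 covers leaf 2
best_x, best = brute_force_select(O, Cost, mapping, [[0, 1], [2]], lambdac=-0.25, lambdas=0.1)
print(best_x)  # (1, 0, 0)
```

With this cost weight, keeping the second and third leaves adds less outcome score than the extra branch and leaf costs they incur, so only the top leaf survives; lowering the magnitude of `lambdac` makes the selection keep more nodes.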
Related Pages
Implements Principle
Requires Environment
Uses Heuristic