Implementation:SqueezeAILab ETS Tree Select Softmax Costmodel

From Leeroopedia
Knowledge Sources
Domains Tree_Search, Optimization, Integer_Linear_Programming
Last Updated 2026-02-14 02:00 GMT

Overview

Concrete tool for ILP-based node selection in the ETS tree search, provided by the Tree class in rebase.py.

Description

The Tree.select_softmax_costmodel() method implements the ETS paper's core contribution. It:

  1. Computes a branch-to-leaf mapping by traversing parent pointers from each leaf
  2. Formulates an ILP using PuLP with binary variables for each leaf and branch node
  3. Computes softmax-weighted outcome scores based on PRM scores
  4. Optionally adds diversity coverage constraints using SentenceTransformer embeddings and hierarchical clustering
  5. Solves the ILP with the CBC solver
  6. Re-applies softmax proportional allocation to the retained nodes to determine expansion widths
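Steps 3 and 6 both rely on temperature-scaled softmax weighting over PRM scores. A minimal sketch of softmax-proportional width allocation (the function name `allocate_widths` and the leftover-rounding rule are illustrative, not taken from rebase.py):

```python
import math

def softmax(scores, temperature=1.0):
    """Temperature-scaled softmax; lower temperature sharpens the distribution."""
    exps = [math.exp(s / temperature) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def allocate_widths(scores, budget, temperature=1.0):
    """Give each node an integer expansion width proportional to its softmax
    weight, so the widths sum exactly to the remaining budget."""
    probs = softmax(scores, temperature)
    widths = [int(p * budget) for p in probs]
    # Distribute leftover slots (from integer truncation) to the
    # highest-weighted nodes first
    leftover = budget - sum(widths)
    order = sorted(range(len(scores)), key=lambda i: probs[i], reverse=True)
    for i in order[:leftover]:
        widths[i] += 1
    return widths

print(allocate_widths([2.0, 1.0, 0.5], budget=8))  # → [6, 1, 1]
```

Nodes pruned by the ILP simply receive a width of 0 and are excluded from the allocation.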

Usage

Called automatically by Tree.select_and_expand() when select_method == "softmax_costmodel". Not called directly by user code.

Code Reference

Source Location

  • Repository: ETS
  • File: rebase.py
  • Lines: 373-569

Signature

def select_softmax_costmodel(self, node_list, node_weights, width, depth):
    """
    ILP-based node selection with cost and diversity optimization.

    Args:
        node_list (list[TreeNode]): Candidate leaf nodes at current depth
        node_weights (list[float]): PRM scores for each candidate node
        width (int): Remaining search budget
        depth (int): Current tree depth

    Returns:
        tuple[list[TreeNode], list[int]]:
            - nodes: All candidate nodes (retained + pruned) sorted by score
            - select_num: Expansion width per node (0 for pruned nodes)
    """

Import

# Internal method of Tree class in rebase.py
# External dependencies:
from pulp import LpMaximize, LpProblem, LpVariable, lpSum, PULP_CBC_CMD
from sentence_transformers import SentenceTransformer
from scipy.cluster.hierarchy import linkage, fcluster
import numpy as np
import torch

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| node_list | list[TreeNode] | Yes | Candidate leaf nodes at current depth (non-leaf, non-exhausted) |
| node_weights | list[float] | Yes | PRM scores for each candidate (from TreeNode.get_score()) |
| width | int | Yes | Remaining search budget (self.remaining_width) |
| depth | int | Yes | Current tree depth level |

Instance attributes used:

| Attribute | Type | Description |
|---|---|---|
| self.lambdac | float | Cost penalty weight (from config) |
| self.lambdas | float | Diversity penalty weight (from config) |
| self.model | SentenceTransformer or None | Embedding model for diversity scoring |
| self.paras["softmax_temperature"] | float | Temperature for softmax allocation |
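When self.model is set, the diversity term groups candidate nodes by hierarchical clustering of their embeddings. A toy sketch of that clustering step, with synthetic 2-D vectors standing in for SentenceTransformer embeddings (the thresholds and data are illustrative):

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

# Synthetic stand-ins for candidate embeddings: two tight semantic groups
embeddings = np.array([
    [1.0, 0.0],
    [0.9, 0.1],   # near the first candidate
    [0.0, 1.0],
    [0.1, 0.9],   # near the third candidate
])

# Agglomerative clustering on cosine distance, cut into at most 2 clusters
Z = linkage(embeddings, method="average", metric="cosine")
labels = fcluster(Z, t=2, criterion="maxclust")
print(labels)
```

Each resulting cluster is a semantic group; the ILP's coverage terms can then reward keeping at least one leaf per cluster rather than many near-duplicates.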

Outputs

| Name | Type | Description |
|---|---|---|
| nodes | list[TreeNode] | All candidate nodes sorted by score (descending) |
| select_num | list[int] | Expansion width for each node (0 = pruned) |

Usage Examples

Called Internally by select_and_expand

# From Tree.select_and_expand() at rebase.py:L622-623
if self.paras["select_method"] == "softmax_costmodel":
    nodes, widths = self.select_softmax_costmodel(
        cand_node_list,       # non-leaf, non-exhausted nodes at current depth
        cand_node_weights,    # PRM scores
        self.remaining_width, # budget remaining
        depth,                # current depth
    )

# Then expand each retained node
for expand_node, width in zip(nodes, widths):
    if width >= 1:
        self.expand(expand_node, width)

ILP Structure (Conceptual)

# Conceptual view of the ILP formulation
from pulp import LpMaximize, LpProblem, LpVariable, lpSum, PULP_CBC_CMD

# Binary variables
x = [LpVariable(f"x_{i}", cat="Binary") for i in range(N)]  # leaf decisions
y = [LpVariable(f"y_{j}", cat="Binary") for j in range(M)]  # branch decisions

problem = LpProblem("Tree_Selection_Problem", LpMaximize)

# Objective: outcome scores + cost penalty + diversity coverage
problem += (
    lpSum(O[i] * x[i] for i in range(N)) +       # maximize outcomes
    lpSum(lambdac * Cost[j] * y[j] for j in range(M)) +      # branch cost (lambdac < 0)
    lpSum(lambdac * Cost[M + i] * x[i] for i in range(N)) +  # leaf cost (Cost[M:] holds leaf costs)
    lpSum(lambdas * coverage[k] for k in range(K))  # diversity
)

# Constraint: branch active iff has active leaf
for j, leaf_nodes in branch_leaf_mapping.items():
    problem += y[j] <= lpSum(x[i] for i in leaf_nodes)
    for i in leaf_nodes:
        problem += y[j] >= x[i]

# Constraint: keep at least one node
problem += lpSum(x[i] for i in range(N)) >= 1

problem.solve(PULP_CBC_CMD(msg=0))
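Because the ILP is small at any one depth, its selection logic can be sanity-checked by brute force. A self-contained sketch (pure Python, illustrative scores and costs, no PuLP dependency, and without the diversity term) that enumerates leaf subsets and applies the branch-activation rule:

```python
from itertools import product

# Toy instance: 3 leaves, 2 branches; branch_leaf_mapping[j] = leaves under branch j
O = [0.9, 0.5, 0.3]            # softmax-weighted outcome score per leaf
leaf_cost = [1.0, 1.0, 1.0]    # generation cost per retained leaf
branch_cost = [1.0, 1.0]       # cost per active branch
branch_leaf_mapping = {0: [0, 1], 1: [2]}
lambdac = -0.2                 # cost penalty weight (negative)

best_value, best_x = float("-inf"), None
for x in product([0, 1], repeat=3):
    if sum(x) < 1:             # constraint: keep at least one node
        continue
    # A branch is active iff at least one of its leaves is retained
    y = [int(any(x[i] for i in leaves))
         for j, leaves in sorted(branch_leaf_mapping.items())]
    value = (sum(O[i] * x[i] for i in range(3))
             + lambdac * sum(leaf_cost[i] * x[i] for i in range(3))
             + lambdac * sum(branch_cost[j] * y[j] for j in range(2)))
    if value > best_value:
        best_value, best_x = value, x

print(best_x)  # → (1, 1, 0): leaf 2's score does not pay for activating its branch
```

The example shows why the cost model matters: leaf 2 has a positive score, but retaining it would activate a second branch whose cost outweighs that score, so it is pruned.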

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
