
Implementation:Princeton nlp Tree of thought llm Task Base Class

From Leeroopedia
Knowledge Sources
Domains Software_Design, NLP
Last Updated 2026-02-14 03:30 GMT

Overview

The Task base class, provided by the Tree of Thoughts framework, defines the abstract interface that all benchmark tasks must implement.

Description

The Task base class in base.py defines the contract for benchmark tasks. It provides stub implementations of the core methods (__init__, __len__, get_input, test_output) that subclasses override. It also defines the module-level DATA_PATH constant that all tasks use to locate their dataset files.

Three reference implementations exist:

  • Game24Task (game24.py): 92 lines, propose+value strategy, 4 BFS steps, binary validation via sympy.
  • TextTask (text.py): 99 lines, sample+vote strategy, 2 BFS steps, LLM-based coherency scoring.
  • MiniCrosswordsTask (crosswords.py): 258 lines, propose+value with environment simulation.

Usage

Subclass Task when implementing a new benchmark task for the framework. Reference Game24Task for the propose+value pattern and TextTask for the sample+vote pattern.

Code Reference

Source Location

Signature

import os
DATA_PATH = os.path.join(os.path.dirname(__file__), '..', 'data')

class Task:
    def __init__(self):
        pass

    def __len__(self) -> int:
        pass

    def get_input(self, idx: int) -> str:
        pass

    def test_output(self, idx: int, output: str):
        pass

Import

from tot.tasks.base import Task, DATA_PATH
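To illustrate how DATA_PATH resolves, the sketch below reproduces the same path arithmetic without importing the package: base.py joins the directory containing tot/tasks/base.py with '..' and 'data', so dataset files live under the package root at tot/data/. The install location shown is an assumed example, not a real path.

```python
import os

# Sketch (not framework code): how DATA_PATH resolves. base.py joins
# the directory of tot/tasks/base.py with '..' and 'data', so dataset
# files end up under the package root in tot/data/<task_name>/.
tasks_dir = '/site-packages/tot/tasks'  # hypothetical install location
data_path = os.path.normpath(os.path.join(tasks_dir, '..', 'data'))
print(data_path)  # /site-packages/tot/data
```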

I/O Contract

Inputs (subclass __init__)

Name  Type  Required       Description
file  str   Task-specific  Data file name relative to DATA_PATH

Outputs (required attributes after __init__)

Name              Type  Description
self.data         list  Loaded dataset (list of puzzles/inputs)
self.steps        int   Number of BFS depth levels
self.stops        list  Stop tokens per step (list of str or None, length = steps)
self.value_cache  dict  Cache for value evaluations (only for value-strategy tasks)
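The attribute table above amounts to a small set of invariants the solver relies on. The helper below is not part of the framework, just a sketch that makes those invariants explicit and checks them against a toy stand-in object.

```python
from types import SimpleNamespace

# Sketch of a contract check (not part of the framework): verifies the
# attributes the solver expects to find after a subclass's __init__.
def check_task_contract(task) -> bool:
    assert isinstance(task.data, list) and len(task.data) > 0
    assert isinstance(task.steps, int) and task.steps >= 1
    assert len(task.stops) == task.steps  # one stop token per BFS step
    assert all(s is None or isinstance(s, str) for s in task.stops)
    return True

# Toy stand-in mimicking TextTask's configuration (2 BFS steps).
toy = SimpleNamespace(data=['passage'], steps=2, stops=['\n\n', None])
print(check_task_contract(toy))  # True
```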

Usage Examples

Implementing a New Task (Propose + Value)

import os
from tot.tasks.base import Task, DATA_PATH
from tot.prompts.my_task import *

class MyTask(Task):
    def __init__(self, file='my_data.csv'):
        super().__init__()
        path = os.path.join(DATA_PATH, 'my_task', file)
        self.data = load_data(path)  # load_data: task-specific loader (placeholder)
        self.value_cache = {}
        self.steps = 3          # 3 levels of BFS
        self.stops = ['\n'] * 3  # newline stop per step

    def __len__(self) -> int:
        return len(self.data)

    def get_input(self, idx: int) -> str:
        return self.data[idx]

    def test_output(self, idx: int, output: str):
        # Task-specific validation
        correct = validate(output, self.data[idx])
        return {'r': int(correct)}

    @staticmethod
    def standard_prompt_wrap(x: str, y: str = '') -> str:
        return standard_prompt.format(input=x) + y

    @staticmethod
    def cot_prompt_wrap(x: str, y: str = '') -> str:
        return cot_prompt.format(input=x) + y

    @staticmethod
    def propose_prompt_wrap(x: str, y: str = '') -> str:
        return propose_prompt.format(input=x) + y

    @staticmethod
    def value_prompt_wrap(x: str, y: str) -> str:
        return value_prompt.format(input=x, partial=y)

    @staticmethod
    def value_outputs_unwrap(x: str, y: str, value_outputs: list) -> float:
        # Parse LLM value judgments into numeric scores
        score = parse_scores(value_outputs)
        return score
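The example above leaves parse_scores undefined. One way to write it, sketched below and modeled on the scoring scheme Game24Task uses, is to have the LLM end each value judgment with a keyword ("sure", "likely", or "impossible"), map each keyword to a numeric score, and sum across outputs. The function name and exact score values here are illustrative assumptions.

```python
# Hypothetical parse_scores helper (the example above assumes one
# exists), modeled on Game24Task's scheme: map the last word of each
# LLM value judgment to a numeric score and sum across outputs.
def parse_scores(value_outputs: list) -> float:
    value_map = {'impossible': 0.001, 'likely': 1.0, 'sure': 20.0}
    total = 0.0
    for output in value_outputs:
        last_word = output.strip().split()[-1].lower()
        total += value_map.get(last_word, 0.0)  # unrecognized words score 0
    return total

print(parse_scores(['10 + 14 = 24. sure', 'likely']))  # 21.0
```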

Implementing a New Task (Sample + Vote)

class MyVotingTask(Task):
    def __init__(self):
        super().__init__()
        self.data = load_data()
        self.steps = 2
        self.stops = ['\n\n', None]

    # ... __len__, get_input, test_output ...

    @staticmethod
    def vote_prompt_wrap(x: str, ys: list) -> str:
        prompt = vote_prompt
        for i, y in enumerate(ys, 1):
            prompt += f'Choice {i}:\n{y}\n'
        return prompt

    @staticmethod
    def vote_outputs_unwrap(vote_outputs: list, n_candidates: int) -> list:
        votes = [0] * n_candidates
        for output in vote_outputs:
            chosen = parse_vote(output)
            if 0 <= chosen < n_candidates:
                votes[chosen] += 1
        return votes
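The voting example likewise assumes a parse_vote helper. A minimal sketch, assuming the LLM phrases its verdict as "The best choice is N" (1-based), is a regex that extracts the number and converts it to a 0-based index:

```python
import re

# Hypothetical parse_vote helper (the voting example above assumes one
# exists): extract the winning candidate from an LLM vote output such
# as "The best choice is 2". Returns a 0-based index, or -1 on no match.
def parse_vote(output: str) -> int:
    match = re.search(r'best choice is .*?(\d+)', output, re.IGNORECASE)
    return int(match.group(1)) - 1 if match else -1

# Tally three ballots across three candidates.
ballots = ['The best choice is 2', 'The best choice is 2', 'Best choice is 1']
votes = [0, 0, 0]
for ballot in ballots:
    chosen = parse_vote(ballot)
    if 0 <= chosen < len(votes):
        votes[chosen] += 1
print(votes)  # [1, 2, 0]
```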

Related Pages

Implements Principle

Requires Environment
