
Implementation:CarperAI Trlx Random Walk Environment

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, Graph_Algorithms, Data_Generation
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete tool that generates synthetic random-walk datasets on directed graphs for evaluating RL-based language model training.

Description

The generate_random_walks function creates a random directed graph with a configurable node count and edge probability, samples walks from random start nodes to a fixed goal node, and builds a metric function that measures path optimality (the ratio of shortest-path length to actual path length). Nodes are labeled with lowercase letters, so the graph supports at most 26 nodes. The function returns the metric function, evaluation prompts (unique start nodes), training data (walks rendered as strings), and a logit mask (adjacency matrix) for constrained decoding. It serves as a simple, interpretable test environment for validating trlx PPO and ILQL training.
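To make the mechanics above concrete, here is a minimal, self-contained sketch of the same idea: build a random adjacency matrix, then sample walks toward a fixed goal node and render them as letter strings. This is an illustrative reimplementation under stated assumptions (goal fixed at node 0, dead ends truncate the walk), not the trlx source code.

```python
import numpy as np

def sample_walks_sketch(n_nodes=6, p_edge=0.5, n_walks=20, max_length=10, seed=0):
    """Illustrative sketch (not the trlx implementation): build a random
    directed graph and sample walks toward a fixed goal node 0."""
    rng = np.random.default_rng(seed)
    # Random adjacency matrix: adj[i, j] == 1 means an edge i -> j exists.
    adj = (rng.random((n_nodes, n_nodes)) < p_edge).astype(int)
    np.fill_diagonal(adj, 0)  # no self-loops
    goal = 0
    walks = []
    for _ in range(n_walks):
        node = int(rng.integers(1, n_nodes))  # random non-goal start node
        walk = [node]
        while node != goal and len(walk) < max_length:
            neighbors = np.flatnonzero(adj[node])
            if len(neighbors) == 0:
                break  # dead end: stop the walk early
            node = int(rng.choice(neighbors))
            walk.append(node)
        # Render node indices as lowercase letters, as the real function does.
        walks.append("".join(chr(ord("a") + n) for n in walk))
    return adj, walks
```

The walk strings produced this way are what the real function exposes as training data.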

Usage

Use this module to set up the random walks example experiment. The returned data is passed directly to trlx.train() for PPO, ILQL, or RFT training on shortest-path finding.

Code Reference

Source Location

Signature

def generate_random_walks(
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[Callable, List[str], List[str], torch.Tensor]:
    """
    Generate a random directed graph and sample walks on it.

    Args:
        n_nodes: Number of graph nodes (max 26, uses lowercase letters).
        max_length: Maximum walk length.
        n_walks: Number of sample walks to generate.
        p_edge: Probability of edge existence between any two nodes.
        seed: Random seed for reproducibility.
        gpt2_tokenizer: Use "|" delimiter for GPT-2 tokenizer compatibility.

    Returns:
        Tuple of (metric_fn, eval_prompts, sample_walks, logit_mask).
    """

Import

from examples.randomwalks import generate_random_walks

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| n_nodes | int | No | Number of graph nodes (default 21, max 26) |
| max_length | int | No | Maximum walk length (default 10) |
| n_walks | int | No | Number of sample walks to generate (default 1000) |
| p_edge | float | No | Edge probability between any two nodes (default 0.1) |
| seed | int | No | Random seed (default 1002) |
| gpt2_tokenizer | bool | No | Whether to format walks for the GPT-2 tokenizer (default False) |

Outputs

| Name | Type | Description |
|------|------|-------------|
| metric_fn | Callable | Function scoring path optimality (shortest_path_length / actual_length) |
| eval_prompts | List[str] | Unique start-node strings for evaluation |
| sample_walks | List[str] | Training data: node sequences as strings |
| logit_mask | torch.Tensor | Adjacency matrix for constrained decoding (vocab_size x vocab_size) |
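The optimality score described for metric_fn can be sketched as a shortest-path ratio. The function below is an illustrative reimplementation using BFS over an adjacency matrix, not the trlx code; a walk that never reaches the goal scores 0.0, and a shortest walk scores 1.0.

```python
from collections import deque

def optimality(adj, walk, goal=0):
    """Ratio of shortest-path length to actual path length (1.0 = optimal).
    `adj` is an n x n 0/1 adjacency structure; `walk` is a list of node
    indices. Illustrative sketch of the metric, not the trlx source."""
    start = walk[0]
    # BFS from the start node to find shortest distances (in edges).
    dist = {start: 0}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        for v, has_edge in enumerate(adj[u]):
            if has_edge and v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    if walk[-1] != goal or goal not in dist:
        return 0.0  # goal never reached (or unreachable from the start)
    return dist[goal] / (len(walk) - 1)
```

For example, on a graph with edges a→b, a→c, and b→c, the direct walk a→c scores 1.0 while the detour a→b→c scores 0.5.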

Usage Examples

Generate Walks and Train with PPO

import trlx
from examples.randomwalks import generate_random_walks

# 1. Generate random walk environment
metric_fn, eval_prompts, walks, logit_mask = generate_random_walks(
    n_nodes=21,
    max_length=10,
    n_walks=1000,
    p_edge=0.1,
    seed=1002,
)

# 2. Define reward function (1.0 if goal reached, 0.0 otherwise)
def reward_fn(samples, **kwargs):
    # Goal node is the last letter in the alphabet slice
    goal = chr(ord('a') + 20)  # 'u' for n_nodes=21
    return [1.0 if s.strip().endswith(goal) else 0.0 for s in samples]

# 3. Train with PPO
trainer = trlx.train(
    reward_fn=reward_fn,
    metric_fn=metric_fn,
    prompts=walks,
    eval_prompts=eval_prompts,
    config=config,  # a trlx TRLConfig for PPO (not defined here)
    logit_mask=logit_mask,
)
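The logit_mask passed to trlx.train above restricts decoding to valid graph edges. A minimal sketch of how such a mask could be applied to next-token logits; the function name and calling convention here are illustrative, not the trlx API.

```python
import torch

def apply_logit_mask(logits, logit_mask, prev_token):
    """Sketch: mask logits so only graph-adjacent tokens can be sampled.
    `logit_mask[i, j]` is 1 if token j may follow token i, else 0."""
    allowed = logit_mask[prev_token].bool()
    masked = logits.clone()
    masked[~allowed] = float("-inf")  # disallowed edges get zero probability
    return masked
```

After masking, a softmax over the returned logits assigns zero probability to every token that is not a neighbor of the previous node.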
