Implementation: CarperAI Trlx Random Walk Environment
| Knowledge Sources | Details |
|---|---|
| Domains | Reinforcement_Learning, Graph_Algorithms, Data_Generation |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
A concrete tool for generating synthetic random-walk datasets on directed graphs, used to evaluate RL-based language-model training.
Description
The generate_random_walks function creates a random directed graph with a configurable node count and edge probability, samples walks from random start nodes toward a fixed goal node, and provides a metric function that scores path optimality (the ratio of shortest-path length to actual path length). Nodes are labeled with lowercase letters, so the graph supports at most 26 nodes. The function returns the metric function, evaluation prompts (one per unique start node), the sampled walks as strings (the training data), and a logit mask (the adjacency matrix) for constrained decoding. It serves as a simple, interpretable test environment for validating trlx PPO and ILQL training.
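To make the construction concrete, here is a minimal, self-contained sketch of the same idea, not the trlx source; the helper name toy_random_walks and the choice of node 0 as the goal are assumptions for illustration.
import numpy as np

def toy_random_walks(n_nodes=6, p_edge=0.4, n_walks=4, max_length=8, seed=0):
    rng = np.random.RandomState(seed)
    adj = rng.rand(n_nodes, n_nodes) < p_edge   # directed edge i -> j
    np.fill_diagonal(adj, False)                # no self-loops
    goal = 0                                    # fixed goal node (assumption)
    labels = "abcdefghijklmnopqrstuvwxyz"
    walks = []
    for _ in range(n_walks):
        node = rng.randint(1, n_nodes)          # random non-goal start node
        path = [node]
        while node != goal and len(path) < max_length:
            neighbors = np.flatnonzero(adj[node])
            if neighbors.size == 0:             # dead end: truncate the walk
                break
            node = rng.choice(neighbors)
            path.append(node)
        walks.append("".join(labels[i] for i in path))
    return walks

print(toy_random_walks())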
Usage
Use this module to set up the random-walks example experiment. The returned data is passed directly to trlx.train() for PPO or ILQL training on shortest-path finding.
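Before wiring anything into training, the returned pieces can be sanity-checked on their own. A minimal sketch, assuming the metric function accepts a list of walk strings and returns per-sample optimality scores as documented in the I/O contract below (verify the exact return structure against the source):
from examples.randomwalks import generate_random_walks

metric_fn, eval_prompts, walks, logit_mask = generate_random_walks(seed=1002)
print(len(walks), len(eval_prompts))  # 1000 sampled walks, one prompt per unique start node
print(metric_fn(walks[:4]))           # optimality scores for a few training walks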
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: examples/randomwalks/randomwalks.py
- Lines: 1-220
Signature
def generate_random_walks(
n_nodes: int = 21,
max_length: int = 10,
n_walks: int = 1000,
p_edge: float = 0.1,
seed: int = 1002,
gpt2_tokenizer: bool = False,
) -> Tuple[Callable, List[str], List[str], torch.Tensor]:
"""
Generate a random directed graph and sample walks on it.
Args:
n_nodes: Number of graph nodes (max 26, uses lowercase letters).
max_length: Maximum walk length.
n_walks: Number of sample walks to generate.
p_edge: Probability of edge existence between any two nodes.
seed: Random seed for reproducibility.
gpt2_tokenizer: Use "|" delimiter for GPT-2 tokenizer compatibility.
Returns:
Tuple of (metric_fn, eval_prompts, sample_walks, logit_mask).
"""
Import
from examples.randomwalks import generate_random_walks
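As a quick smoke test of the import, the gpt2_tokenizer flag can be toggled to compare output formats; the exact delimiter placement is inferred from the parameter description above, so check it against the source.
from examples.randomwalks import generate_random_walks

# Default format: each walk is a bare string of single-letter node labels.
_, _, walks, _ = generate_random_walks(seed=1002)
# GPT-2 format: node labels separated by "|" so the tokenizer splits them cleanly.
_, _, walks_gpt2, _ = generate_random_walks(seed=1002, gpt2_tokenizer=True)
print(walks[0])
print(walks_gpt2[0])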
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| n_nodes | int | No | Number of graph nodes (default 21, max 26) |
| max_length | int | No | Maximum walk length (default 10) |
| n_walks | int | No | Number of sample walks to generate (default 1000) |
| p_edge | float | No | Edge probability between any two nodes (default 0.1) |
| seed | int | No | Random seed (default 1002) |
| gpt2_tokenizer | bool | No | Whether to format walks for GPT-2 tokenizer |
Outputs
| Name | Type | Description |
|---|---|---|
| metric_fn | Callable | Function scoring path optimality (shortest_path_length / actual_length) |
| eval_prompts | List[str] | Unique start node strings for evaluation |
| sample_walks | List[str] | Training data: node sequences as strings |
| logit_mask | torch.Tensor | Adjacency matrix for constrained decoding (vocab_size x vocab_size) |
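To illustrate how an adjacency-style logit mask constrains decoding (a generic sketch of the technique, not trlx internals; the helper name is hypothetical): logits of tokens that are not graph-successors of the current token are set to -inf before sampling.
import torch

def masked_next_token_logits(logits, logit_mask, current_token):
    # logits: (vocab_size,) raw next-token scores from the model.
    # logit_mask[i, j] is nonzero iff token j may follow token i.
    allowed = logit_mask[current_token].bool()
    return logits.masked_fill(~allowed, float("-inf"))

# Toy check: token 0 may only transition to token 1, so every other
# position is masked out before sampling.
mask = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
print(masked_next_token_logits(torch.zeros(3), mask, current_token=0))
# tensor([-inf, 0., -inf])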
Usage Examples
Generate Walks and Train with PPO
import trlx
from examples.randomwalks import generate_random_walks
# 1. Generate random walk environment
metric_fn, eval_prompts, walks, logit_mask = generate_random_walks(
n_nodes=21,
max_length=10,
n_walks=1000,
p_edge=0.1,
seed=1002,
)
# 2. Define a binary reward: 1.0 if the walk ends at the goal node,
#    0.0 otherwise. The environment fixes a single goal node; node 0
#    ('a') is assumed here, so verify it against the generated graph.
def reward_fn(samples, **kwargs):
    goal = "a"  # assumed fixed goal node (index 0)
    return [1.0 if s.strip().endswith(goal) else 0.0 for s in samples]
# 3. Train with PPO. `config` is a trlx TRLConfig for PPO (see
#    examples/randomwalks/ppo_randomwalks.py for the config used there).
#    Prompts are the start-node strings; the model generates the rest of
#    each walk. The pre-sampled walks are the offline dataset for ILQL
#    (see the sketch below).
trainer = trlx.train(
    reward_fn=reward_fn,
    metric_fn=metric_fn,
    prompts=eval_prompts,
    eval_prompts=eval_prompts,
    config=config,
    logit_mask=logit_mask,
)
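Train with ILQL (Sketch)
The same outputs also drive offline ILQL training, where the pre-sampled walks become the dataset and the metric supplies their rewards. This is a sketch modeled on examples/randomwalks/ilql_randomwalks.py; the dataset=(samples, rewards) keyword and the "optimality" key are assumptions to verify against your trlx version.
# Offline ILQL: learn from the pre-sampled walks, scored by the metric.
rewards = metric_fn(walks)["optimality"]  # assumed dict key; see note above
trainer = trlx.train(
    dataset=(walks, rewards),
    eval_prompts=eval_prompts,
    metric_fn=metric_fn,
    config=ilql_config,  # a trlx TRLConfig for ILQL (hypothetical name)
    logit_mask=logit_mask,
)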