Implementation: CarperAI trlx Reward Function Interface
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Reward_Modeling, NLP |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Interface specification for user-defined reward functions in trlx online PPO training.
Description
The reward function is a user-defined callable that scores batches of generated text samples during PPO rollouts. This is a Pattern Doc — it documents the interface that users must implement, not a library API. The function receives generated samples, prompts, and outputs, and must return a list of scalar rewards. trlx calls this function internally during each PPO rollout in AcceleratePPOTrainer.make_experience().
Usage
Implement this interface when using trlx.train() with online PPO. Pass the function as the reward_fn argument. The function will be called with batches of generated text during training.
Code Reference
Source Location
- Repository: trlx
- File: examples/ppo_sentiments.py
- Lines: L39-41 (canonical example)
Interface Specification
from typing import List

def reward_fn(
samples: List[str], # Full generated text (prompt + output)
prompts: List[str], # Original prompt text
outputs: List[str], # Generated output only (no prompt)
**kwargs # Additional metadata passed through PromptPipeline
) -> List[float]:
"""
Score generated text samples with scalar rewards.
Args:
samples: Complete generated strings (prompt + completion).
prompts: The original prompt strings.
outputs: The generated completion strings only.
**kwargs: Extra keys from prompt dicts (if prompts were dicts).
Returns:
List of float rewards, one per sample. Higher = better.
"""
...
Import
# No import needed — this is a user-defined function
# Pass directly to trlx.train()
import trlx
trlx.train(reward_fn=my_reward_fn, prompts=prompts, config=config)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| samples | List[str] | Yes | Full generated text strings (prompt + completion) |
| prompts | List[str] | Yes | Original prompt strings |
| outputs | List[str] | Yes | Generated completion strings only |
| **kwargs | Dict | No | Extra metadata from prompt dicts (e.g., original_output) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | List[float] | Scalar reward per sample, same length as inputs |
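To make the contract concrete, here is a minimal toy reward function that satisfies it: one float per sample, returned in the same order and with the same length as the inputs. The length-based scoring is a hypothetical placeholder, not a recommended reward.

```python
from typing import List

def length_reward_fn(
    samples: List[str], prompts: List[str], outputs: List[str], **kwargs
) -> List[float]:
    # Toy reward: number of words in the generated completion.
    # A real reward function would score quality, not length.
    return [float(len(output.split())) for output in outputs]

rewards = length_reward_fn(
    samples=["Hello world foo", "Hi there"],
    prompts=["Hello", "Hi"],
    outputs=[" world foo", " there"],
)
# One scalar per sample, same length as the input batch.
assert len(rewards) == 2
```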
Usage Examples
Sentiment-Based Reward
from transformers import pipeline
sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)
def reward_fn(samples, **kwargs):
    """Reward = raw logit of the positive-sentiment class."""
    # function_to_apply="none" returns raw logits, not probabilities.
    sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
    pipe_outputs = sentiment_fn(samples, **sent_kwargs)
    # Select the positive class by label rather than by list position,
    # since the pipeline's output ordering is not guaranteed.
    return [
        next(d["score"] for d in output if d["label"] == "POSITIVE")
        for output in pipe_outputs
    ]
import trlx
from trlx.data.default_configs import default_ppo_config
config = default_ppo_config()
trainer = trlx.train(reward_fn=reward_fn, prompts=prompts, config=config)
Delta Reward (Model-Based)
def reward_fn(samples, prompts, original_output, **kwargs):
    """Compute reward improvement over a reference output.

    `original_output` arrives via **kwargs when prompts are supplied as
    dicts containing an "original_output" key. `get_reward_scores` is a
    user-supplied scorer returning an array or tensor of scores.
    """
    generated_rewards = get_reward_scores(samples)
    # Reconstruct the reference samples (prompt + original completion).
    original_samples = [p + o for p, o in zip(prompts, original_output)]
    original_rewards = get_reward_scores(original_samples)
    # Elementwise delta; .tolist() assumes array/tensor-valued scores.
    return (generated_rewards - original_rewards).tolist()
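The delta pattern above can be sketched end to end with plain Python lists. Here `toy_scores` is a hypothetical stand-in for a real reward-model scorer, and `original_output` is passed explicitly the way trlx would forward it from dict-valued prompts:

```python
from typing import List

def toy_scores(texts: List[str]) -> List[float]:
    # Hypothetical scorer: word count stands in for a reward model.
    return [float(len(t.split())) for t in texts]

def delta_reward_fn(samples, prompts, original_output, **kwargs) -> List[float]:
    generated = toy_scores(samples)
    # Rebuild the reference samples from prompt + original completion.
    originals = toy_scores([p + o for p, o in zip(prompts, original_output)])
    # Reward is the improvement of the generated sample over the reference.
    return [g - o for g, o in zip(generated, originals)]

rewards = delta_reward_fn(
    samples=["Q: hi A: hello there friend"],
    prompts=["Q: hi "],
    original_output=["A: hello"],
)
```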