
Implementation:CarperAI Trlx Reward Function Interface

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, Reward_Modeling, NLP
Last Updated 2026-02-07 16:00 GMT

Overview

Interface specification for user-defined reward functions in trlx online PPO training.

Description

The reward function is a user-defined callable that scores batches of generated text samples during PPO rollouts. This is a Pattern Doc — it documents the interface that users must implement, not a library API. The function receives generated samples, prompts, and outputs, and must return a list of scalar rewards. trlx calls this function internally during each PPO rollout in AcceleratePPOTrainer.make_experience().

Usage

Implement this interface when using trlx.train() with online PPO. Pass the function as the reward_fn argument. The function will be called with batches of generated text during training.
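A minimal implementation satisfying this interface can be sketched as follows. This toy reward favors longer completions; it is illustrative only and is not taken from the trlx examples:

```python
from typing import List

def length_reward_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    **kwargs,
) -> List[float]:
    """Toy reward: longer completions score higher, capped at 1.0."""
    # One scalar per sample; trlx expects len(rewards) == len(samples).
    return [min(len(out) / 50.0, 1.0) for out in outputs]

# The function can be called standalone for testing, exactly as trlx
# will call it during rollouts.
rewards = length_reward_fn(
    samples=["Q: hi A: hello there"],
    prompts=["Q: hi"],
    outputs=[" A: hello there"],
)
```

Because the interface is just a callable, unit-testing a reward function before training is as simple as invoking it with hand-written strings, as above.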

Code Reference

Source Location

  • Repository: trlx
  • File: examples/ppo_sentiments.py
  • Lines: L39-41 (canonical example)

Interface Specification

def reward_fn(
    samples: List[str],       # Full generated text (prompt + output)
    prompts: List[str],       # Original prompt text
    outputs: List[str],       # Generated output only (no prompt)
    **kwargs                  # Additional metadata passed through PromptPipeline
) -> List[float]:
    """
    Score generated text samples with scalar rewards.

    Args:
        samples: Complete generated strings (prompt + completion).
        prompts: The original prompt strings.
        outputs: The generated completion strings only.
        **kwargs: Extra keys from prompt dicts (if prompts were dicts).

    Returns:
        List of float rewards, one per sample. Higher = better.
    """
    ...

Import

# No import needed — this is a user-defined function
# Pass directly to trlx.train()
import trlx
trlx.train(reward_fn=my_reward_fn, prompts=prompts, config=config)

I/O Contract

Inputs

Name Type Required Description
samples List[str] Yes Full generated text strings (prompt + completion)
prompts List[str] Yes Original prompt strings
outputs List[str] Yes Generated completion strings only
**kwargs Dict No Extra metadata from prompt dicts (e.g., original_output)

Outputs

Name Type Description
return List[float] Scalar reward per sample, same length as inputs
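When prompts are supplied to trlx as dicts, the extra keys are forwarded to the reward function as keyword arguments. A sketch of a reward function consuming such metadata (`original_output` is the key used in the trlx delta-reward examples; the comparison logic here is illustrative):

```python
from typing import List

def metadata_reward_fn(
    samples: List[str],
    prompts: List[str],
    outputs: List[str],
    **kwargs,
) -> List[float]:
    """Reward 1.0 when the generation differs from a reference, else 0.0."""
    # original_output arrives via **kwargs when the prompt dicts carry
    # that key; fall back to empty references if it is absent.
    references = kwargs.get("original_output", [""] * len(samples))
    return [1.0 if out != ref else 0.0 for out, ref in zip(outputs, references)]

rewards = metadata_reward_fn(
    samples=["p1 new"],
    prompts=["p1"],
    outputs=[" new"],
    original_output=[" old"],
)
```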

Usage Examples

Sentiment-Based Reward

from transformers import pipeline

sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def reward_fn(samples, **kwargs):
    """Reward = positive-sentiment logit for each sample."""
    sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
    pipe_outputs = sentiment_fn(samples, **sent_kwargs)
    # With top_k=None the pipeline returns all labels sorted by score
    # (descending), so select POSITIVE by label rather than by position.
    rewards = [
        next(d["score"] for d in output if d["label"] == "POSITIVE")
        for output in pipe_outputs
    ]
    return rewards

import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
trainer = trlx.train(reward_fn=reward_fn, prompts=prompts, config=config)

Delta Reward (Model-Based)

def reward_fn(samples, prompts, original_output, **kwargs):
    """Compute reward improvement over the reference output."""
    # get_reward_scores is user-supplied and should return an array-like
    # (e.g., a NumPy array or torch.Tensor) so the subtraction broadcasts.
    generated_rewards = get_reward_scores(samples)
    original_samples = [p + o for p, o in zip(prompts, original_output)]
    original_rewards = get_reward_scores(original_samples)
    return (generated_rewards - original_rewards).tolist()
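The `get_reward_scores` helper above is user-supplied. A self-contained sketch of the same delta pattern, with a hypothetical stand-in scorer (length-based, purely for illustration; a real implementation would call a reward model):

```python
from typing import List

import numpy as np

def get_reward_scores(texts: List[str]) -> np.ndarray:
    """Stand-in scorer: treats longer text as better (replace with a model)."""
    return np.array([len(t) for t in texts], dtype=np.float64)

def delta_reward_fn(samples, prompts, original_output, **kwargs) -> List[float]:
    """Reward = score(generated sample) - score(prompt + reference output)."""
    generated = get_reward_scores(samples)
    originals = get_reward_scores([p + o for p, o in zip(prompts, original_output)])
    # Element-wise difference; .tolist() yields the List[float] trlx expects.
    return (generated - originals).tolist()

rewards = delta_reward_fn(
    samples=["prompt plus a long generation"],
    prompts=["prompt"],
    original_output=[" ref"],
)
```

Returning the difference rather than the raw score centers rewards around the reference behavior, which can stabilize PPO when absolute reward scales drift.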

Related Pages

Implements Principle

Requires Environment
