
Implementation:CarperAI Trlx Create Reward Fn

From Leeroopedia


Knowledge Sources
Domains Reward_Modeling, Inference, Deployment
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete utility from the trlx HH-RLHF example for creating a reward function from a trained reward model.

Description

The create_reward_fn() factory in the HH dialogue alignment example loads a pre-trained reward model from the HuggingFace Hub, sets it up for inference, and returns a callable that conforms to trlx's reward function interface. It supports two backends: local PyTorch inference (RewardModel on the last available GPU, in half precision) and NVIDIA Triton Inference Server (via tritonclient.grpc). For local inference, it computes delta rewards, reward(generated) - reward(original), to normalize the reward signal against the dataset's reference responses.
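The delta-reward pattern described above can be sketched as follows. Here `score` is a hypothetical stand-in for the reward model's forward pass (a batch of full texts in, one scalar per text out), not part of the trlx API:

```python
from typing import Callable, List

def make_delta_reward_fn(score: Callable[[List[str]], List[float]]):
    """Wrap a raw scoring callable into a trlx-style delta reward function.

    `score` is a hypothetical stand-in for the reward model forward pass:
    it maps a batch of full texts to one scalar reward per text.
    """
    def reward_fn(samples: List[str], prompts: List[str],
                  original_output: List[str], **kwargs) -> List[float]:
        # Rebuild the reference texts (prompt + dataset response) and score
        # both batches, returning the per-sample difference so rewards are
        # centered relative to the reference response for each prompt.
        references = [p + o for p, o in zip(prompts, original_output)]
        generated_scores = score(samples)
        reference_scores = score(references)
        return [g - r for g, r in zip(generated_scores, reference_scores)]

    return reward_fn
```

With a toy length-based `score` such as `lambda xs: [float(len(x)) for x in xs]`, the wrapped function rewards generations only insofar as they score better than the dataset's own response to the same prompt.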

Usage

Call create_reward_fn() to get a reward function for PPO training on dialogue alignment tasks. The function selects the serving backend automatically: if the TRITON_HOST environment variable is set it uses Triton, otherwise it falls back to local PyTorch inference.

Code Reference

Source Location

  • Repository: trlx
  • File: examples/hh/ppo_hh.py
  • Lines: L115-205

Signature

def create_reward_fn() -> Callable:
    """
    Factory function that creates a reward function for PPO training.

    Supports two backends:
    1. Triton Inference Server (if TRITON_HOST env var is set)
    2. Local PyTorch inference (on last available GPU)

    Returns:
        reward_fn(samples, prompts, original_output, **kwargs) -> List[float]
        Delta reward function: reward(generated) - reward(original).
    """

Import

# From the HH example script
from examples.hh.ppo_hh import create_reward_fn

I/O Contract

Inputs (Factory)

| Name | Type | Required | Description |
|---|---|---|---|
| (none) | — | — | Uses environment variables and HuggingFace Hub for model loading |
| TRITON_HOST | env var | No | Triton server URL/model (e.g., "localhost:8001/reward_model") |

Outputs (Factory)

| Name | Type | Description |
|---|---|---|
| return | Callable | Reward function conforming to the trlx interface |

Inputs (Returned Function)

| Name | Type | Required | Description |
|---|---|---|---|
| samples | List[str] | Yes | Generated text samples |
| prompts | List[str] | Yes | Original prompts |
| original_output | List[str] | Yes | Reference outputs for the delta reward |

Outputs (Returned Function)

| Name | Type | Description |
|---|---|---|
| return | List[float] or Tensor | Delta rewards (generated_reward - original_reward) |

Usage Examples

Local Inference with Delta Rewards

import trlx
from trlx.data.configs import TRLConfig

# create_reward_fn loads reward model from HuggingFace Hub
reward_fn = create_reward_fn()

# Use in PPO training
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,     # List[Dict] with "prompt" and "original_output"
    eval_prompts=eval_prompts,
    config=config,
    stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
)
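For reference, each entry in `train_prompts` above is a dictionary. A minimal sketch of the expected shape (the field values here are invented):

```python
# Hypothetical HH-style prompt entries. Per the I/O contract, trlx forwards
# the extra "original_output" field to reward_fn so it can compute the delta
# against the dataset's reference response.
train_prompts = [
    {
        "prompt": "Human: How do I bake bread?\n\nAssistant:",
        "original_output": " Mix flour, water, yeast, and salt, then knead and proof.",
    },
]

# Every entry needs both keys, or the delta reward has no reference to subtract.
assert all({"prompt", "original_output"} <= set(entry) for entry in train_prompts)
```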

Triton Server Backend

import os
os.environ["TRITON_HOST"] = "localhost:8001/reward_model"

# Automatically uses Triton inference instead of local model
reward_fn = create_reward_fn()
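The "host:port/model" format of TRITON_HOST can be unpacked as sketched below; this mirrors the format shown in the I/O contract, though the exact parsing inside ppo_hh.py may differ slightly:

```python
import os

def parse_triton_host(value: str):
    # Split "localhost:8001/reward_model" into the gRPC endpoint and the
    # name of the model deployed on the Triton server.
    url, model_name = value.split("/", 1)
    return url, model_name

# Fall back to an example value so the sketch runs without a live server.
url, model_name = parse_triton_host(
    os.environ.get("TRITON_HOST", "localhost:8001/reward_model")
)
```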

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
