Overview
A concrete tool for creating a reward function from a trained reward model, provided by the trlx HH-RLHF example.
Description
The create_reward_fn() factory function in the HH dialogue-alignment example loads a pre-trained reward model from the HuggingFace Hub, sets it up for inference, and returns a callable that conforms to trlx's reward-function interface. It supports two backends: local PyTorch inference (the RewardModel placed on the last available GPU in half precision) and NVIDIA Triton Inference Server (via tritonclient.grpc). The returned callable computes delta rewards, reward(generated) - reward(original), to normalize the reward signal against the reference completions.
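Conceptually, the returned callable wraps a backend-specific scorer in delta-reward logic. A minimal sketch, where get_reward is a hypothetical stand-in for the actual scoring function:
from typing import List

def reward_fn(samples: List[str], prompts: List[str],
              original_output: List[str], **kwargs) -> List[float]:
    # Score the generated samples and the reference completions with the
    # backend-specific scorer (get_reward is a hypothetical stand-in),
    # then return the difference so rewards are relative to the baseline.
    generated_scores = get_reward(samples)
    original_scores = get_reward([p + o for p, o in zip(prompts, original_output)])
    return [g - o for g, o in zip(generated_scores, original_scores)]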
Usage
Call create_reward_fn() to obtain a reward function for PPO training on dialogue-alignment tasks. The factory selects the serving backend based on whether the TRITON_HOST environment variable is set, as sketched below.
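A minimal sketch of the backend selection, assuming the "host:port/model_name" format for TRITON_HOST documented in the I/O contract below:
import os

triton_host = os.environ.get("TRITON_HOST")
if triton_host is not None:
    # Triton backend: "localhost:8001/reward_model" -> server URL and model name
    host, model_name = triton_host.split("/")
else:
    # Local backend: load the reward model onto the last available GPU
    pass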
Code Reference
Source Location
- Repository: trlx
- File: examples/hh/ppo_hh.py
- Lines: L115-L205
Signature
def create_reward_fn() -> Callable:
"""
Factory function that creates a reward function for PPO training.
Supports two backends:
1. Triton Inference Server (if TRITON_HOST env var is set)
2. Local PyTorch inference (on last available GPU)
Returns:
reward_fn(samples, prompts, original_output, **kwargs) -> List[float]
Delta reward function: reward(generated) - reward(original).
"""
Import
# From the HH example script (not part of the installed trlx package,
# so the repository root must be on sys.path)
from examples.hh.ppo_hh import create_reward_fn
I/O Contract
Inputs (Factory)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| (none) | — | — | Uses env vars and the HuggingFace Hub for model loading |
| TRITON_HOST | env var | No | Triton server URL and model name (e.g., "localhost:8001/reward_model") |
Outputs (Factory)
| Name | Type | Description |
|------|------|-------------|
| return | Callable | Reward function conforming to the trlx reward interface |
Inputs (Returned Function)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| samples | List[str] | Yes | Generated text samples |
| prompts | List[str] | Yes | Original prompts |
| original_output | List[str] | Yes | Reference outputs for the delta-reward baseline |
Outputs (Returned Function)
| Name | Type | Description |
|------|------|-------------|
| return | List[float] or Tensor | Delta rewards (generated_reward - original_reward) |
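For illustration, a call that satisfies this contract (the dialogue strings are made up):
rewards = reward_fn(
    samples=["Human: Hi!\n\nAssistant: Hello, how can I help?"],
    prompts=["Human: Hi!\n\n"],
    original_output=["Assistant: Hey."],
)
# one delta reward per generated sample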
Usage Examples
Local Inference with Delta Rewards
import trlx
from trlx.data.configs import TRLConfig
from examples.hh.ppo_hh import create_reward_fn

# create_reward_fn loads the reward model from the HuggingFace Hub
reward_fn = create_reward_fn()

# Use in PPO training
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,  # List[Dict] with "prompt" and "original_output" keys
    eval_prompts=eval_prompts,
    config=config,  # a TRLConfig instance
    stop_sequences=["Human:", "human:", "Assistant:", "assistant:"],
)
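The prompt dictionaries can be built from a preference dataset. A minimal sketch, assuming the Dahoas/rm-static dataset with "prompt" and "chosen" fields; swap in whichever dataset your setup uses:
from datasets import load_dataset

dataset = load_dataset("Dahoas/rm-static")
train_prompts = [
    {"prompt": x["prompt"], "original_output": x["chosen"]}
    for x in dataset["train"]
]
eval_prompts = [
    {"prompt": x["prompt"], "original_output": x["chosen"]}
    for x in dataset["test"]
]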
Triton Server Backend
import os

# TRITON_HOST must be set before create_reward_fn() is called
os.environ["TRITON_HOST"] = "localhost:8001/reward_model"

from examples.hh.ppo_hh import create_reward_fn

# The factory now uses Triton inference instead of a local model
reward_fn = create_reward_fn()
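Under the Triton backend, scoring amounts to a gRPC inference call. A minimal sketch using tritonclient.grpc; the tensor names "input_ids" and "rewards" are assumptions about the deployed model's signature:
import numpy as np
import tritonclient.grpc as triton_grpc

# Host and model name come from TRITON_HOST ("host:port/model_name")
client = triton_grpc.InferenceServerClient(url="localhost:8001")

input_ids = np.zeros((2, 128), dtype=np.int32)  # placeholder tokenized batch
infer_input = triton_grpc.InferInput("input_ids", list(input_ids.shape), "INT32")
infer_input.set_data_from_numpy(input_ids)

result = client.infer("reward_model", [infer_input])
rewards = result.as_numpy("rewards")  # one scalar reward per sample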