Implementation:Huggingface Trl GRPO Reward Functions
| Property | Value |
|---|---|
| Implementation Name | GRPO Reward Functions |
| Library | Huggingface TRL |
| Type | API Doc |
| Source Files | trl/rewards/accuracy_rewards.py (L23-178), trl/rewards/format_rewards.py (L18-50), trl/rewards/other_rewards.py (L18-62) |
| Import | from trl.rewards import accuracy_reward, think_format_reward; from trl.rewards.other_rewards import get_soft_overlong_punishment |
Overview
Description
TRL provides a set of built-in reward functions for common evaluation scenarios in online RL training. These functions follow a standardized callable interface and can be composed within the GRPOTrainer. Each function evaluates a batch of completions and returns one scalar reward per completion; the accuracy rewards return None for an example whose gold solution cannot be parsed.
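The callable interface described above can be illustrated with a minimal sketch of a custom reward function. The name `keyword_reward` and its scoring rule are hypothetical and are not part of TRL; the sketch only assumes the conversational input shape documented on this page (each completion is a list holding one message dict with a "content" key) and the convention of accepting extra keyword arguments.

```python
def keyword_reward(completions, **kwargs):
    """Hypothetical reward: 1.0 if the completion mentions 'answer', else 0.0.

    Extra dataset columns (for example `solution`) arrive via **kwargs
    and are simply ignored here.
    """
    rewards = []
    for completion in completions:
        # Conversational format: one message dict per completion.
        content = completion[0]["content"]
        rewards.append(1.0 if "answer" in content.lower() else 0.0)
    return rewards


completions = [
    [{"role": "assistant", "content": "The answer is 4."}],
    [{"role": "assistant", "content": "I am not sure."}],
]
rewards = keyword_reward(completions, solution=["4", "4"])
print(rewards)
```

A function of this shape could presumably be placed in the `reward_funcs` list alongside the built-in rewards, since it returns one float per completion.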
Usage
from trl import GRPOTrainer
from trl.rewards import accuracy_reward, reasoning_accuracy_reward, think_format_reward
from trl.rewards.other_rewards import get_soft_overlong_punishment

# Use directly with GRPOTrainer.
# `dataset` is a prompt dataset loaded beforehand; accuracy_reward expects a `solution` column.
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[accuracy_reward, think_format_reward],
    train_dataset=dataset,
)

# Or create a length-penalty reward with custom thresholds
length_reward = get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256)
Code Reference
Source Location
| Function | File | Lines |
|---|---|---|
| accuracy_reward | trl/rewards/accuracy_rewards.py | L23-78 |
| reasoning_accuracy_reward | trl/rewards/accuracy_rewards.py | L81-178 |
| think_format_reward | trl/rewards/format_rewards.py | L18-50 |
| get_soft_overlong_punishment | trl/rewards/other_rewards.py | L18-62 |
Signature
def accuracy_reward(
    completions: list[list[dict[str, str]]],
    solution: list[str],
    **kwargs,
) -> list[float | None]:
    """
    Checks if the completion matches the ground truth using math verification.
    Returns 1.0 for correct, 0.0 for incorrect, None if gold is unparseable.
    Requires: math_verify package.
    """

def reasoning_accuracy_reward(
    completions: list[list[dict[str, str]]],
    solution: list[str],
    reasoning_delimiters: list[str] | None = None,
    **kwargs,
) -> list[float | None]:
    """
    Strips reasoning content before the delimiter (default: "</think>"),
    then verifies the final answer. Returns 0.0 if reasoning is incomplete
    (no closing delimiter found), None if gold is unparseable.
    """

def think_format_reward(
    completions: list[list[dict[str, str]]],
    **kwargs,
) -> list[float]:
    """
    Checks that completions follow the <think>...</think> format.
    Returns 1.0 if matched, 0.0 otherwise.
    Uses regex: r"^<think>(?!.*<think>)(.*?)</think>.*$"
    """

def get_soft_overlong_punishment(
    max_completion_len: int,
    soft_punish_cache: int,
) -> Callable:
    """
    Factory that returns a reward function penalizing overlong completions.
    The returned function accepts completion_ids (list[list[int]]) and returns list[float].
    Reward schedule (per DAPO Eq. 13):
    - |y| <= L_max - L_cache => 0.0
    - L_max - L_cache < |y| <= L_max => linear penalty from 0 to -1
    - |y| > L_max => -1.0
    """
Import
from trl.rewards import accuracy_reward
from trl.rewards import reasoning_accuracy_reward
from trl.rewards import think_format_reward
from trl.rewards.other_rewards import get_soft_overlong_punishment
I/O Contract
Inputs (accuracy_reward)
| Parameter | Type | Description |
|---|---|---|
| completions | list[list[dict[str, str]]] | Batch of completions in conversational format. Each completion is a list containing one message dict with a "content" key. |
| solution | list[str] | Ground-truth solutions in raw LaTeX text. |
| **kwargs | Any | Additional keyword arguments (ignored but required for trainer compatibility). |
Outputs (accuracy_reward)
| Output | Type | Description |
|---|---|---|
| rewards | list[float \| None] | 1.0 if mathematically correct, 0.0 if incorrect, None if the gold solution is unparseable. |
Inputs (think_format_reward)
| Parameter | Type | Description |
|---|---|---|
| completions | list[list[dict[str, str]]] | Batch of completions in conversational format. |
| **kwargs | Any | Additional keyword arguments (ignored). |
Outputs (think_format_reward)
| Output | Type | Description |
|---|---|---|
| rewards | list[float] | 1.0 if the <think>...</think> pattern is matched, 0.0 otherwise. |
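To make the format check concrete, here is a minimal sketch that applies the regex quoted in the signature section to a few completions. The helper name `think_format_check` is hypothetical, and the `re.DOTALL | re.MULTILINE` flags are an assumption about how the pattern is applied; the sketch is not the library's source.

```python
import re

# Pattern quoted in the think_format_reward docstring on this page.
THINK_PATTERN = r"^<think>(?!.*<think>)(.*?)</think>.*$"


def think_format_check(completions):
    """Hypothetical stand-in: 1.0 when a completion matches the pattern, else 0.0."""
    contents = [completion[0]["content"] for completion in completions]
    # Assumed flags: DOTALL lets "." span newlines inside the reasoning block.
    matches = [re.match(THINK_PATTERN, text, re.DOTALL | re.MULTILINE) for text in contents]
    return [1.0 if match else 0.0 for match in matches]


completions = [
    [{"role": "assistant", "content": "<think>\nstep 1\n</think>\nFinal answer"}],
    [{"role": "assistant", "content": "Final answer without tags"}],
    [{"role": "assistant", "content": "<think>a<think>b</think>c"}],
]
scores = think_format_check(completions)
print(scores)
```

Under these assumptions, the negative lookahead rejects any completion that contains a second `<think>` tag, and the leading `^` rejects text placed before the opening tag.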
Inputs (get_soft_overlong_punishment)
| Parameter | Type | Description |
|---|---|---|
| max_completion_len | int | Maximum allowed completion length (L_max). |
| soft_punish_cache | int | Size of the soft penalty region before the hard cutoff (L_cache). |
Outputs (get_soft_overlong_punishment)
| Output | Type | Description |
|---|---|---|
| reward_function | Callable | A reward function that accepts completion_ids: list[list[int]] and returns list[float]. |
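The three-region reward schedule listed in the signature docstring can be sketched as a standalone function. The name `soft_overlong_penalty` is hypothetical and the code is a reading of the documented schedule, not a copy of the library source; it assumes the linear segment runs from 0 at L_max - L_cache down to -1 at L_max.

```python
def soft_overlong_penalty(length: int, max_completion_len: int, soft_punish_cache: int) -> float:
    """Hypothetical per-completion penalty following the documented schedule."""
    soft_start = max_completion_len - soft_punish_cache
    if length <= soft_start:
        # Short enough: no penalty.
        return 0.0
    if length <= max_completion_len:
        # Inside the soft region: linear from 0 (at soft_start) to -1 (at max).
        return (soft_start - length) / soft_punish_cache
    # Past the hard cutoff: full penalty.
    return -1.0


penalties = [soft_overlong_penalty(n, max_completion_len=100, soft_punish_cache=20) for n in (80, 90, 100, 120)]
print(penalties)
```

With L_max = 100 and L_cache = 20, a 90-token completion sits halfway through the soft region, which gives (80 - 90) / 20 = -0.5.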
Usage Examples
Accuracy reward with math verification:
from trl.rewards import accuracy_reward
solutions = [r"\frac{1}{3}", r"\frac{1}{3}"]
completions = [
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
]
rewards = accuracy_reward(completions, solutions)
# returns [1.0, 0.0]
Reasoning accuracy reward that strips thinking content:
from trl.rewards import reasoning_accuracy_reward
solutions = [r"\frac{1}{3}", r"\frac{1}{3}", r"\frac{1}{3}"]
completions = [
    [{"role": "assistant", "content": r"<think>Reasoning</think> \boxed{\frac{1}{3}}"}],
    [{"role": "assistant", "content": r"<think>Reasoning</think> \boxed{\frac{1}{2}}"}],
    [{"role": "assistant", "content": r"<think>Incomplete reasoning with \boxed{\frac{1}{3}}"}],
]
rewards = reasoning_accuracy_reward(completions, solutions)
# returns [1.0, 0.0, 0.0] -- third is 0.0 because </think> is missing
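The stripping step can be pictured with a small sketch that separates the final answer from the reasoning. The helper `split_final_answer` is hypothetical and only mirrors the documented behaviour (text before the delimiter is discarded, a missing delimiter means incomplete reasoning); how the library actually performs the split is an assumption here, and the math verification step is omitted.

```python
def split_final_answer(content, reasoning_delimiters=("</think>",)):
    """Hypothetical helper: return the text after the reasoning delimiter.

    Returns None when no delimiter is present, which the reward function
    described above would score as 0.0 (incomplete reasoning).
    """
    for delimiter in reasoning_delimiters:
        if delimiter in content:
            # Keep only what follows the last occurrence of the delimiter.
            return content.split(delimiter)[-1].strip()
    return None


complete = split_final_answer(r"<think>Reasoning</think> \boxed{\frac{1}{3}}")
incomplete = split_final_answer(r"<think>Incomplete reasoning with \boxed{\frac{1}{3}}")
print(complete, incomplete)
```

This also shows why the third completion in the example scores 0.0 even though it contains the correct boxed value: without the closing delimiter there is no final-answer segment to verify.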
Soft overlong punishment:
from trl.rewards.other_rewards import get_soft_overlong_punishment
reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
rewards = reward_fn(completion_ids=[[1] * 90])
# 90 tokens is within the penalty zone (80-100), returns [-0.5]
Registry in the GRPO script:
# From trl/scripts/grpo.py
reward_funcs_registry = {
    "accuracy_reward": accuracy_reward,
    "reasoning_accuracy_reward": reasoning_accuracy_reward,
    "think_format_reward": think_format_reward,
    "get_soft_overlong_punishment": get_soft_overlong_punishment(
        max_completion_len=1280, soft_punish_cache=256
    ),
}