Implementation:Huggingface Trl GRPO Reward Functions

From Leeroopedia
Revision as of 15:11, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Huggingface_Trl_GRPO_Reward_Functions.md)


Property Value
Implementation Name GRPO Reward Functions
Library Huggingface TRL
Type API Doc
Source Files trl/rewards/accuracy_rewards.py (L23-178), trl/rewards/format_rewards.py (L18-50), trl/rewards/other_rewards.py (L18-62)
Import from trl.rewards import accuracy_reward, reasoning_accuracy_reward, think_format_reward; from trl.rewards.other_rewards import get_soft_overlong_punishment

Overview

Description

TRL ships built-in reward functions for common evaluation scenarios in online RL training. They share a standardized callable interface, so several can be passed together to GRPOTrainer through reward_funcs. Each function scores a batch of completions and returns one reward per completion: a float, or None where an accuracy reward cannot parse the gold solution.
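To make the callable interface concrete, here is a minimal sketch of a custom reward function with the same shape as the built-in ones. The name `brevity_reward` and its 50-character threshold are hypothetical and not part of TRL; only the conversational `completions` layout and the `**kwargs` catch-all follow the signatures documented on this page.

```python
def brevity_reward(completions, **kwargs):
    """Hypothetical reward: 1.0 for answers of at most 50 characters, else 0.0."""
    rewards = []
    for completion in completions:
        # Each completion is a list holding one message dict with a "content" key.
        text = completion[0]["content"]
        rewards.append(1.0 if len(text) <= 50 else 0.0)
    return rewards


batch = [
    [{"role": "assistant", "content": "42"}],
    [{"role": "assistant", "content": "x" * 80}],
]
# Extra arguments such as `solution` are absorbed by **kwargs and ignored here.
print(brevity_reward(batch, solution=["42", "42"]))
```

Accepting `**kwargs` lets a function like this sit in the same list as rewards that do need extra arguments such as `solution`.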

Usage

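from trl import GRPOTrainer
from trl.rewards import accuracy_reward, reasoning_accuracy_reward, think_format_reward
from trl.rewards.other_rewards import get_soft_overlong_punishment

# Use directly with GRPOTrainer (`dataset` is assumed to be defined elsewhere,
# with a "solution" column for accuracy_reward)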
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[accuracy_reward, think_format_reward],
    train_dataset=dataset,
)

# Or create a length-penalty reward with custom thresholds
length_reward = get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256)

Code Reference

Source Location

Function File Lines
accuracy_reward trl/rewards/accuracy_rewards.py L23-78
reasoning_accuracy_reward trl/rewards/accuracy_rewards.py L81-178
think_format_reward trl/rewards/format_rewards.py L18-50
get_soft_overlong_punishment trl/rewards/other_rewards.py L18-62

Signature

def accuracy_reward(
    completions: list[list[dict[str, str]]],
    solution: list[str],
    **kwargs,
) -> list[float | None]:
    """
    Checks if the completion matches the ground truth using math verification.
    Returns 1.0 for correct, 0.0 for incorrect, None if gold is unparseable.
    Requires: math_verify package.
    """


def reasoning_accuracy_reward(
    completions: list[list[dict[str, str]]],
    solution: list[str],
    reasoning_delimiters: list[str] | None = None,
    **kwargs,
) -> list[float | None]:
    """
    Strips reasoning content before the delimiter (default: "</think>"),
    then verifies the final answer. Returns 0.0 if reasoning is incomplete
    (no closing delimiter found), None if gold is unparseable.
    """


def think_format_reward(
    completions: list[list[dict[str, str]]],
    **kwargs,
) -> list[float]:
    """
    Checks that completions follow the <think>...</think> format.
    Returns 1.0 if matched, 0.0 otherwise.
    Uses regex: r"^<think>(?!.*<think>)(.*?)</think>.*$"
    """


def get_soft_overlong_punishment(
    max_completion_len: int,
    soft_punish_cache: int,
) -> Callable:
    """
    Factory that returns a reward function penalizing overlong completions.
    The returned function accepts completion_ids (list[list[int]]) and returns list[float].

    Reward schedule (per DAPO Eq. 13):
      - |y| <= L_max - L_cache  =>  0.0
      - L_max - L_cache < |y| <= L_max  =>  linear penalty from 0 to -1
      - |y| > L_max  =>  -1.0
    """

Import

from trl.rewards import accuracy_reward
from trl.rewards import reasoning_accuracy_reward
from trl.rewards import think_format_reward
from trl.rewards.other_rewards import get_soft_overlong_punishment

I/O Contract

Inputs (accuracy_reward)

Parameter Type Description
completions list[list[dict[str, str]]] Batch of completions in conversational format. Each completion is a list containing one message dict with a "content" key.
solution list[str] Ground-truth solutions in raw LaTeX text.
**kwargs Any Additional keyword arguments (ignored but required for trainer compatibility).

Outputs (accuracy_reward)

Output Type Description
rewards list[float | None] 1.0 if mathematically correct, 0.0 if incorrect, None if the gold solution is unparseable.
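Because accuracy_reward can return None, any code that consumes its output directly has to decide how to treat missing entries. The sketch below shows one possible policy: sum the rewards for each sample across several functions and skip None values. The helper `combine_rewards` is hypothetical and is not meant to describe what GRPOTrainer does internally.

```python
def combine_rewards(per_function_rewards):
    """Sum rewards per sample across functions, ignoring None entries.

    Hypothetical helper for illustration; not part of TRL.
    per_function_rewards: list of reward lists, one list per reward function.
    """
    combined = []
    for sample_rewards in zip(*per_function_rewards):
        valid = [r for r in sample_rewards if r is not None]
        # If every function skipped this sample, keep it as None.
        combined.append(sum(valid) if valid else None)
    return combined


accuracy = [1.0, None, 0.0]   # None: gold solution could not be parsed
formatting = [1.0, 1.0, 0.0]
print(combine_rewards([accuracy, formatting]))
```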

Inputs (think_format_reward)

Parameter Type Description
completions list[list[dict[str, str]]] Batch of completions in conversational format.
**kwargs Any Additional keyword arguments (ignored).

Outputs (think_format_reward)

Output Type Description
rewards list[float] 1.0 if the <think>...</think> pattern is matched, 0.0 otherwise.
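The pattern quoted in the docstring can be exercised on its own with the standard `re` module. The following is an illustrative re-implementation rather than the TRL source: the function name `think_format_sketch` is hypothetical, and the `re.DOTALL | re.MULTILINE` flags are an assumption made so that reasoning spanning several lines can match.

```python
import re

# Pattern quoted in the think_format_reward docstring.
THINK_PATTERN = r"^<think>(?!.*<think>)(.*?)</think>.*$"


def think_format_sketch(completions, **kwargs):
    """Illustrative re-implementation; the regex flags are an assumption."""
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        # DOTALL lets "." span newlines so multi-line reasoning can match.
        matched = re.match(THINK_PATTERN, text, re.DOTALL | re.MULTILINE)
        rewards.append(1.0 if matched else 0.0)
    return rewards


samples = [
    [{"content": "<think>\nstep 1\n</think>\nThe answer is 4."}],  # well formed
    [{"content": "The answer is 4."}],                             # no think block
    [{"content": "<think>a<think>b</think>c"}],                    # nested opening tag
]
print(think_format_sketch(samples))
```

The negative lookahead `(?!.*<think>)` is what rejects the third sample: a second opening tag anywhere after the first one prevents a match.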

Inputs (get_soft_overlong_punishment)

Parameter Type Description
max_completion_len int Maximum allowed completion length (L_max).
soft_punish_cache int Size of the soft penalty region before the hard cutoff (L_cache).

Outputs (get_soft_overlong_punishment)

Output Type Description
reward_function Callable A reward function that accepts completion_ids: list[list[int]] and returns list[float].

Usage Examples

Accuracy reward with math verification:

from trl.rewards import accuracy_reward

solutions = [r"\frac{1}{3}", r"\frac{1}{3}"]
completions = [
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
    [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
]
rewards = accuracy_reward(completions, solutions)
# returns [1.0, 0.0]

Reasoning accuracy reward that strips thinking content:

from trl.rewards import reasoning_accuracy_reward

solutions = [r"\frac{1}{3}", r"\frac{1}{3}", r"\frac{1}{3}"]
completions = [
    [{"role": "assistant", "content": r"<think>Reasoning</think> \boxed{\frac{1}{3}}"}],
    [{"role": "assistant", "content": r"<think>Reasoning</think> \boxed{\frac{1}{2}}"}],
    [{"role": "assistant", "content": r"<think>Incomplete reasoning with \boxed{\frac{1}{3}}"}],
]
rewards = reasoning_accuracy_reward(completions, solutions)
# returns [1.0, 0.0, 0.0]  -- third is 0.0 because </think> is missing
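The stripping step can be pictured with a small standalone helper. This is a sketch under assumptions, not TRL's implementation: `strip_reasoning` is a hypothetical name, and keeping the text after the last delimiter is a choice made for this example.

```python
def strip_reasoning(text, delimiters=("</think>",)):
    """Return the text after the last reasoning delimiter, or None if none is found.

    Hypothetical helper for illustration; not TRL's implementation.
    """
    for delimiter in delimiters:
        if delimiter in text:
            # Keep only what follows the final occurrence of the delimiter.
            return text.rsplit(delimiter, 1)[1].strip()
    return None  # reasoning never closed, so there is no final answer to grade


print(strip_reasoning(r"<think>Reasoning</think> \boxed{\frac{1}{3}}"))
print(strip_reasoning(r"<think>Incomplete reasoning with \boxed{\frac{1}{3}}"))
```

A None result here corresponds to the case where the documented function returns 0.0 because the closing delimiter is missing.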

Soft overlong punishment:

from trl.rewards.other_rewards import get_soft_overlong_punishment

reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
rewards = reward_fn(completion_ids=[[1] * 90])
# 90 tokens is within the penalty zone (80-100), returns [-0.5]
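The -0.5 above follows from the linear ramp in the documented schedule: (L_max - L_cache - |y|) / L_cache = (100 - 20 - 90) / 20. The sketch below reproduces the three zones in plain Python so the arithmetic can be checked without TRL installed; `soft_overlong_sketch` is a hypothetical name, and the code is written from the schedule in the docstring rather than copied from the library.

```python
def soft_overlong_sketch(max_completion_len, soft_punish_cache):
    """Illustrative factory mirroring the documented three-zone schedule."""
    soft_start = max_completion_len - soft_punish_cache  # L_max - L_cache

    def reward_fn(completion_ids, **kwargs):
        rewards = []
        for ids in completion_ids:
            length = len(ids)
            if length <= soft_start:
                rewards.append(0.0)  # short enough: no penalty
            elif length <= max_completion_len:
                # Linear ramp from 0 at soft_start down to -1 at L_max.
                rewards.append((soft_start - length) / soft_punish_cache)
            else:
                rewards.append(-1.0)  # past the hard limit
        return rewards

    return reward_fn


sketch_fn = soft_overlong_sketch(max_completion_len=100, soft_punish_cache=20)
print(sketch_fn(completion_ids=[[1] * 80, [1] * 90, [1] * 100, [1] * 120]))
```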

Registry in the GRPO script:

# From trl/scripts/grpo.py
reward_funcs_registry = {
    "accuracy_reward": accuracy_reward,
    "reasoning_accuracy_reward": reasoning_accuracy_reward,
    "think_format_reward": think_format_reward,
    "get_soft_overlong_punishment": get_soft_overlong_punishment(
        max_completion_len=1280, soft_punish_cache=256
    ),
}
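A registry of this shape lets reward functions be selected by name, for example from a list of strings in a configuration. The sketch below shows one way such a lookup could work; `resolve_reward_funcs` and the two stand-in callables are hypothetical and are not taken from trl/scripts/grpo.py.

```python
def resolve_reward_funcs(names, registry):
    """Look up reward callables by name, failing loudly on unknown names.

    Hypothetical helper for illustration; not copied from trl/scripts/grpo.py.
    """
    resolved = []
    for name in names:
        if name not in registry:
            raise ValueError(f"Unknown reward function: {name!r}. Available: {sorted(registry)}")
        resolved.append(registry[name])
    return resolved


# Stand-in callables so the sketch runs without TRL installed.
def always_one(completions, **kwargs):
    return [1.0] * len(completions)


def always_zero(completions, **kwargs):
    return [0.0] * len(completions)


demo_registry = {"always_one": always_one, "always_zero": always_zero}
funcs = resolve_reward_funcs(["always_one", "always_zero"], demo_registry)
print([f([[{"content": "hi"}]]) for f in funcs])
```

Note that the factory entry in the registry above is stored already called, so every value in the dictionary is a ready-to-use reward callable.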
