
Implementation:Huggingface Trl GRPOTrainer Train Loop

From Leeroopedia


Property Value
Implementation Name GRPOTrainer Training Loop
Library Huggingface TRL
Type API Doc
Source Files trl/trainer/grpo_trainer.py (L1030-2133), trl/generation/vllm_client.py (L57-288), trl/generation/vllm_generation.py (L168-692)
Import from trl import GRPOTrainer
Loss Types grpo, dr_grpo, dapo, bnpo, cispo, sapo

Overview

Description

The GRPOTrainer training loop orchestrates the three-phase online RL cycle: generation, scoring, and policy update. The implementation overrides several Trainer methods to inject GRPO-specific behavior while leveraging the standard Trainer infrastructure for distributed training, gradient accumulation, logging, and checkpointing.
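The three phases can be sketched as a toy loop. This is a schematic only, not TRL's implementation: the `policy` object, its `sample`/`update` methods, and the plain-list bookkeeping are illustrative stand-ins for the tensor-based pipeline described below.

```python
def grpo_cycle(policy, reward_fn, prompts, num_generations):
    """Schematic of one GRPO iteration: generate, score, update.

    Illustrative sketch; the real logic lives in
    trl/trainer/grpo_trainer.py. `policy.sample` and `policy.update`
    are hypothetical stand-ins.
    """
    # Phase 1: generation -- sample a group of completions per prompt
    groups = {p: [policy.sample(p) for _ in range(num_generations)] for p in prompts}

    # Phase 2: scoring -- reward each completion, then compute
    # group-relative advantages (reward minus the group mean)
    advantages = {}
    for p, completions in groups.items():
        rewards = [reward_fn(p, c) for c in completions]
        mean_r = sum(rewards) / len(rewards)
        advantages[p] = [r - mean_r for r in rewards]

    # Phase 3: policy update -- gradient step on the surrogate objective
    policy.update(groups, advantages)
    return advantages
```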

Usage

from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(output_dir="./output"),
    train_dataset=dataset,
)
train_output = trainer.train()

Code Reference

Source Location

Method File Lines
training_step trl/trainer/grpo_trainer.py L1030-1039
_prepare_inputs trl/trainer/grpo_trainer.py L1041-1071
_generate_and_score_completions trl/trainer/grpo_trainer.py L1524-1868
_compute_loss trl/trainer/grpo_trainer.py L1953-2133
_calculate_rewards trl/trainer/grpo_trainer.py L1073-1153
_generate trl/trainer/grpo_trainer.py L1432-1522
_generate_single_turn trl/trainer/grpo_trainer.py L1155-1264
VLLMClient trl/generation/vllm_client.py L57-288

Key Method Signatures

def training_step(self, model, inputs, num_items_in_batch):
    """
    Wraps parent training_step to track step count and timing.
    Increments self._step for batch buffering logic.
    """

def _prepare_inputs(self, generation_batch: dict) -> dict:
    """
    In training: receives generation batch (per_device_batch_size * steps_per_generation),
    generates and scores completions, splits into sub-batches, returns current slice.
    In eval: generates and scores for the single eval batch.
    """
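The buffering behavior described in the docstring can be sketched as follows. This is a hypothetical simplification, not TRL's code: generation happens once per cycle of `steps_per_generation` training steps, and subsequent steps consume cached sub-batches.

```python
class BufferingSketch:
    """Hypothetical simplification of _prepare_inputs' batch buffering:
    generate and score once per cycle, then serve one slice per step."""

    def __init__(self, steps_per_generation, generate_and_score):
        self.steps_per_generation = steps_per_generation
        self.generate_and_score = generate_and_score  # stand-in callable
        self._step = 0
        self._buffer = []

    def prepare_inputs(self, generation_batch):
        idx = self._step % self.steps_per_generation
        if idx == 0:
            # Start of a new generation cycle: run the expensive
            # generate-and-score pass, then split into sub-batches.
            scored = self.generate_and_score(generation_batch)
            n = len(scored) // self.steps_per_generation
            self._buffer = [scored[i * n:(i + 1) * n]
                            for i in range(self.steps_per_generation)]
        self._step += 1
        return self._buffer[idx]
```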

def _generate_and_score_completions(self, inputs: list[dict]) -> dict:
    """
    Full generation-scoring pipeline:
    1. Generate completions (transformers or vLLM)
    2. Compute rewards via all reward functions
    3. Compute group-relative advantages
    4. Compute old_per_token_logps for importance sampling
    5. Compute ref_per_token_logps for KL penalty
    Returns dict with prompt_ids, completion_ids, advantages, etc.
    """

def _compute_loss(self, model, inputs) -> torch.Tensor:
    """
    Computes the GRPO policy gradient loss:
    1. Forward pass for per_token_logps and entropies
    2. Importance sampling ratio: exp(log_pi_new - log_pi_old)
    3. Clipped surrogate objective (varies by loss_type)
    4. Optional KL penalty (beta * per_token_kl)
    5. Loss aggregation (varies by loss_type)
    """

class VLLMClient:
    def __init__(
        self,
        base_url: str | None = None,
        host: str = "0.0.0.0",
        server_port: int = 8000,
        group_port: int = 51216,
        connection_timeout: float = 0.0,
    ):
        """Client to interact with a TRL vLLM server."""

    def generate(
        self,
        prompts: list,
        images: list | None = None,
        n: int = 1,
        temperature: float = 1.0,
        max_tokens: int = 16,
        **kwargs,
    ) -> dict:
        """
        Generate completions from the vLLM server.
        Returns dict with 'prompt_ids', 'completion_ids', 'logprobs'.
        """
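Steps 2 and 3 of `_compute_loss` amount to the PPO-style clipped surrogate objective. The following is a plain-Python, per-token sketch; the trainer computes the same quantity on tensors of shape (B, C), and the clip range names here (`eps_low`, `eps_high`) are illustrative.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps_low=0.2, eps_high=0.2):
    """Per-token clipped surrogate loss term (sketch of steps 2-3
    of _compute_loss; parameter names are illustrative)."""
    # Importance sampling ratio: exp(log_pi_new - log_pi_old)
    ratio = math.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1.0 + eps_high), 1.0 - eps_low) * advantage
    # Pessimistic bound: take the smaller objective (i.e. larger loss)
    return -min(unclipped, clipped)
```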

Import

from trl import GRPOTrainer
from trl.generation.vllm_client import VLLMClient

I/O Contract

Inputs (training_step)

Parameter Type Description
model nn.Module The policy model (possibly wrapped by a distributed training framework).
inputs dict Generation batch from the dataloader, containing "prompt" and optional metadata columns.
num_items_in_batch int Number of items for loss scaling.

Outputs (_generate_and_score_completions)

Key Type Description
prompt_ids torch.Tensor Left-padded prompt token IDs, shape (B, P).
prompt_mask torch.Tensor Attention mask for prompts, shape (B, P).
completion_ids torch.Tensor Right-padded completion token IDs, shape (B, C).
completion_mask torch.Tensor Attention mask for completions, shape (B, C).
advantages torch.Tensor Group-relative advantages, shape (B,).
num_items_in_batch int Total completion tokens across all processes (for DAPO normalization).
old_per_token_logps torch.Tensor | None Log-probs at generation time, shape (B, C). Present when importance sampling is needed; otherwise None.
ref_per_token_logps torch.Tensor | None Reference-model log-probs, shape (B, C). Present when beta > 0; otherwise None.
importance_sampling_ratio torch.Tensor | None vLLM importance-sampling correction ratios. Present when using vLLM with IS correction; otherwise None.
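The group-relative advantages in the table above (step 3 of the pipeline) can be sketched in plain Python. Rewards arrive as a flat list with `num_generations` consecutive entries per prompt; each is centred by its group mean and, in this sketch, scaled by the group standard deviation (whether TRL applies the std scaling is configurable, so treat the scaling and the `eps` constant here as assumptions).

```python
def group_advantages(rewards, num_generations, eps=1e-4):
    """Group-relative advantages from a flat reward list (sketch).

    Each block of num_generations consecutive rewards is one group;
    std scaling and eps are illustrative assumptions.
    """
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = sum(group) / len(group)
        var = sum((r - mean) ** 2 for r in group) / len(group)
        std = var ** 0.5
        # Centre by the group mean, scale by the group std
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages
```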

Outputs (_compute_loss)

Output Type Description
loss torch.Tensor Scalar loss value for the current micro-batch.

Usage Examples

Standard training:

from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward
from datasets import load_dataset

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(
        output_dir="./output",
        num_generations=8,
        max_completion_length=512,
        loss_type="dapo",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=1e-6,
    ),
    train_dataset=dataset,
)

output = trainer.train()
print(f"Training loss: {output.training_loss}")

Training with vLLM acceleration:

# Terminal 1: Start vLLM server
trl vllm-serve --model Qwen/Qwen2.5-7B-Instruct

# Terminal 2: Run training
python train.py --use_vllm --vllm_mode server

Loss type comparison:

Loss Type Normalization Length Bias Reference
grpo Per-sequence mean Yes (shorter preferred) DeepSeekMath
dr_grpo batch_size * max_completion_length No Dr. GRPO
dapo (default) Global active tokens in accumulated batch No DAPO
bnpo Local active tokens in micro-batch No (minor variation) --
cispo Global active tokens (clipped IS weights) No MiniMax-M1
sapo Per-sequence mean (sigmoid gating) Yes (like grpo) SAPO
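The normalization differences in the table can be made concrete with a small sketch contrasting the per-sequence mean (grpo) against global active-token normalization (dapo). This is illustrative plain Python over nested lists, not the trainer's tensor code; in TRL the dapo denominator is the accumulated-batch token count passed in as num_items_in_batch.

```python
def aggregate_loss(per_token_loss, mask, loss_type):
    """Aggregate per-token losses for two loss types (sketch).

    per_token_loss and mask are lists of per-sequence lists;
    mask entries are 1 for active tokens, 0 for padding.
    """
    if loss_type == "grpo":
        # Mean within each sequence, then mean over sequences: tokens in
        # short sequences weigh more (the length bias noted above).
        seq_means = [
            sum(l * m for l, m in zip(seq, msk)) / max(sum(msk), 1)
            for seq, msk in zip(per_token_loss, mask)
        ]
        return sum(seq_means) / len(seq_means)
    if loss_type == "dapo":
        # One sum over all active tokens, divided by the total active
        # token count: every token weighs the same regardless of length.
        total = sum(l * m for seq, msk in zip(per_token_loss, mask)
                    for l, m in zip(seq, msk))
        count = sum(m for msk in mask for m in msk)
        return total / max(count, 1)
    raise ValueError(f"unsupported loss_type in this sketch: {loss_type}")
```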
