Implementation: Hugging Face TRL GRPOTrainer Training Loop
| Property | Value |
|---|---|
| Implementation Name | GRPOTrainer Training Loop |
| Library | Hugging Face TRL |
| Type | API Doc |
| Source Files | trl/trainer/grpo_trainer.py (L1030-2133), trl/generation/vllm_client.py (L57-288), trl/generation/vllm_generation.py (L168-692) |
| Import | from trl import GRPOTrainer |
| Loss Types | grpo, dr_grpo, dapo, bnpo, cispo, sapo |
Overview
Description
The GRPOTrainer training loop orchestrates the three-phase online RL cycle: generation, scoring, and policy update. The implementation overrides several Trainer methods to inject GRPO-specific behavior while leveraging the standard Trainer infrastructure for distributed training, gradient accumulation, logging, and checkpointing.
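The three-phase cycle can be sketched in plain Python. This is an illustrative outline, not TRL code: `grpo_cycle`, `generate`, `score`, and `update` are hypothetical placeholders, and the group-relative advantage here is just reward minus the group mean (TRL may additionally normalize by the group's standard deviation).

```python
def grpo_cycle(prompts, num_generations, generate, score, update):
    """Sketch of one online GRPO iteration (not a TRL API)."""
    # Phase 1: generation -- sample num_generations completions per prompt
    completions = [generate(p) for p in prompts for _ in range(num_generations)]
    # Phase 2: scoring -- reward each completion, then compute
    # group-relative advantages (reward minus its group's mean reward)
    rewards = [score(c) for c in completions]
    advantages = []
    for g in range(len(prompts)):
        group = rewards[g * num_generations:(g + 1) * num_generations]
        mean = sum(group) / len(group)
        advantages.extend(r - mean for r in group)
    # Phase 3: policy update -- gradient step(s) on the scored batch
    update(completions, advantages)
    return advantages
```

Because the advantage is centered within each group, the advantages of any one prompt's completions always sum to zero.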
Usage
```python
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(output_dir="./output"),
    train_dataset=dataset,
)
train_output = trainer.train()
```
Code Reference
Source Location
| Method | File | Lines |
|---|---|---|
| training_step | trl/trainer/grpo_trainer.py | L1030-1039 |
| _prepare_inputs | trl/trainer/grpo_trainer.py | L1041-1071 |
| _generate_and_score_completions | trl/trainer/grpo_trainer.py | L1524-1868 |
| _compute_loss | trl/trainer/grpo_trainer.py | L1953-2133 |
| _calculate_rewards | trl/trainer/grpo_trainer.py | L1073-1153 |
| _generate | trl/trainer/grpo_trainer.py | L1432-1522 |
| _generate_single_turn | trl/trainer/grpo_trainer.py | L1155-1264 |
| VLLMClient | trl/generation/vllm_client.py | L57-288 |
Key Method Signatures
```python
def training_step(self, model, inputs, num_items_in_batch):
    """
    Wraps parent training_step to track step count and timing.
    Increments self._step for batch buffering logic.
    """

def _prepare_inputs(self, generation_batch: dict) -> dict:
    """
    In training: receives generation batch (per_device_batch_size * steps_per_generation),
    generates and scores completions, splits into sub-batches, returns current slice.
    In eval: generates and scores for the single eval batch.
    """

def _generate_and_score_completions(self, inputs: list[dict]) -> dict:
    """
    Full generation-scoring pipeline:
    1. Generate completions (transformers or vLLM)
    2. Compute rewards via all reward functions
    3. Compute group-relative advantages
    4. Compute old_per_token_logps for importance sampling
    5. Compute ref_per_token_logps for KL penalty
    Returns dict with prompt_ids, completion_ids, advantages, etc.
    """

def _compute_loss(self, model, inputs) -> torch.Tensor:
    """
    Computes the GRPO policy gradient loss:
    1. Forward pass for per_token_logps and entropies
    2. Importance sampling ratio: exp(log_pi_new - log_pi_old)
    3. Clipped surrogate objective (varies by loss_type)
    4. Optional KL penalty (beta * per_token_kl)
    5. Loss aggregation (varies by loss_type)
    """

class VLLMClient:
    def __init__(
        self,
        base_url: str | None = None,
        host: str = "0.0.0.0",
        server_port: int = 8000,
        group_port: int = 51216,
        connection_timeout: float = 0.0,
    ):
        """Client to interact with a TRL vLLM server."""

    def generate(
        self,
        prompts: list,
        images: list | None = None,
        n: int = 1,
        temperature: float = 1.0,
        max_tokens: int = 16,
        **kwargs,
    ) -> dict:
        """
        Generate completions from the vLLM server.
        Returns dict with 'prompt_ids', 'completion_ids', 'logprobs'.
        """
```
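Steps 2-3 of `_compute_loss` (the importance-sampling ratio and the clipped surrogate) can be illustrated per token in plain Python. This is a sketch of the standard PPO-style clipped objective that GRPO builds on, not TRL's exact implementation; `clipped_surrogate` and `eps` are illustrative names.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """Per-token clipped policy-gradient loss (illustrative sketch).

    ratio = exp(log_pi_new - log_pi_old); the surrogate objective is
    min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), and the loss
    is its negative.
    """
    ratio = math.exp(logp_new - logp_old)
    clipped = max(min(ratio, 1.0 + eps), 1.0 - eps)
    return -min(ratio * advantage, clipped * advantage)
```

For a positive advantage the ratio is capped at 1 + eps, so a token whose probability has already grown a lot contributes no extra gradient; for a negative advantage the floor 1 - eps plays the symmetric role.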
Import
```python
from trl import GRPOTrainer
from trl.generation.vllm_client import VLLMClient
```
I/O Contract
Inputs (training_step)
| Parameter | Type | Description |
|---|---|---|
| model | nn.Module | The policy model (possibly wrapped by a distributed framework). |
| inputs | dict | Generation batch from the dataloader, containing "prompt" and optional metadata columns. |
| num_items_in_batch | int | Number of items for loss scaling. |
Outputs (_generate_and_score_completions)
| Key | Type | Description |
|---|---|---|
| prompt_ids | torch.Tensor | Left-padded prompt token IDs, shape (B, P). |
| prompt_mask | torch.Tensor | Attention mask for prompts, shape (B, P). |
| completion_ids | torch.Tensor | Right-padded completion token IDs, shape (B, C). |
| completion_mask | torch.Tensor | Attention mask for completions, shape (B, C). |
| advantages | torch.Tensor | Group-relative advantages, shape (B,). |
| num_items_in_batch | int | Total completion tokens across all processes (for DAPO normalization). |
| old_per_token_logps | torch.Tensor or None | Log-probs at generation time, shape (B, C). Present when importance sampling is needed. |
| ref_per_token_logps | torch.Tensor or None | Reference model log-probs, shape (B, C). Present when beta > 0. |
| importance_sampling_ratio | torch.Tensor or None | vLLM importance-sampling correction ratios. Present when using vLLM with IS correction. |
Outputs (_compute_loss)
| Output | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar loss value for the current micro-batch. |
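The optional KL penalty in `_compute_loss` (beta * per_token_kl) is typically estimated per token with the k3 estimator from the GRPO paper. The sketch below shows that formula; it mirrors the published estimator rather than quoting TRL's code, and `k3_kl` is an illustrative name.

```python
import math

def k3_kl(ref_logp, policy_logp):
    """k3 KL estimator: exp(ref - pi) - (ref - pi) - 1.

    Unbiased estimate of KL(pi || ref) per token, and always
    non-negative, unlike the naive (pi - ref) difference.
    """
    delta = ref_logp - policy_logp
    return math.exp(delta) - delta - 1.0
```

The penalty is zero exactly when the policy and reference assign the same log-probability to the token, and grows as they diverge in either direction.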
Usage Examples
Standard training:
```python
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward
from datasets import load_dataset

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=GRPOConfig(
        output_dir="./output",
        num_generations=8,
        max_completion_length=512,
        loss_type="dapo",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=1e-6,
    ),
    train_dataset=dataset,
)
output = trainer.train()
print(f"Training loss: {output.training_loss}")
```
Training with vLLM acceleration:
```shell
# Terminal 1: start the vLLM server
trl vllm-serve --model Qwen/Qwen2.5-7B-Instruct

# Terminal 2: run training
python train.py --use_vllm --vllm_mode server
```
Loss type comparison:
| Loss Type | Normalization | Length Bias | Reference |
|---|---|---|---|
| grpo | Per-sequence mean | Yes (shorter preferred) | DeepSeekMath |
| dr_grpo | batch_size * max_completion_length | No | Dr. GRPO |
| dapo (default) | Global active tokens in accumulated batch | No | DAPO |
| bnpo | Local active tokens in micro-batch | No (minor variation) | -- |
| cispo | Global active tokens (clipped IS weights) | No | MiniMax-M1 |
| sapo | Per-sequence mean (sigmoid gating) | Yes (like grpo) | SAPO |