Implementation: Unsloth PatchFastRL + GRPOTrainer (unslothai/unsloth)
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Optimization |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for patching TRL's GRPOTrainer with Unsloth's memory-efficient RL optimizations and running GRPO training.
Description
The GRPO training implementation consists of two components:
- PatchFastRL: A function that dynamically rewrites TRL's GRPOTrainer class methods with Unsloth-optimized replacements. This includes replacing compute_loss, _get_per_token_logps, and other critical methods defined in rl_replacements.py.
- GRPOTrainer (patched): After patching, the standard TRL GRPOTrainer is used with Unsloth's optimizations active, including chunked gradient accumulation, optimized log-probability computation, and model inference/training mode switching.
The patching system (unsloth/models/rl.py) dynamically inspects TRL's source code and rewrites functions to inject Unsloth optimizations, supporting multiple loss types: grpo, bnpo, dr_grpo, dapo, and cispo.
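The class-method rewriting that PatchFastRL performs is, at its core, Python monkey-patching: a method is replaced on the trainer class itself before any instance is created, so every subsequently constructed trainer picks up the optimized version. A toy sketch of the mechanism (the `Trainer` class and both `compute_loss` variants below are illustrative, not Unsloth's actual code):

```python
# Illustrative sketch of class-method patching; not Unsloth's actual code.
class Trainer:
    def compute_loss(self, logits):
        # Naive version: processes the whole batch at once.
        return sum(logits) / len(logits)

def patched_compute_loss(self, logits, chunk_size=2):
    # Replacement: accumulates in chunks, loosely mimicking how
    # rl_replacements.py swaps in memory-efficient variants.
    total = 0.0
    for i in range(0, len(logits), chunk_size):
        total += sum(logits[i:i + chunk_size])
    return total / len(logits)

# Rewrite the class method before any trainer is instantiated --
# this is why PatchFastRL must run before trainer creation.
Trainer.compute_loss = patched_compute_loss

trainer = Trainer()
print(trainer.compute_loss([1.0, 2.0, 3.0, 4.0]))  # 2.5
```

The real patching in unsloth/models/rl.py goes further (it inspects and rewrites TRL's source text), but the ordering constraint it imposes is the same as in this sketch.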
Usage
Call PatchFastRL("grpo") once, after the model is loaded but before the GRPOTrainer is instantiated. Then use GRPOTrainer normally with a GRPOConfig.
Code Reference
Source Location
- Repository: unsloth
- File: unsloth/models/rl.py
- Lines: L1597-1603 (PatchFastRL), L88-261 (PatchRL for mode switching), L1262-1574 (patch_functions for dynamic rewriting)
- File: unsloth/models/rl_replacements.py
- Lines: L1-1303 (replacement compute_loss, _get_per_token_logps, etc.)
Signature
def PatchFastRL(algorithm=None, FastLanguageModel=None):
"""
Patches TRL trainer classes with Unsloth optimizations.
Args:
algorithm (str): RL algorithm to patch. Options: "grpo", "dpo", "sft",
"kto", "prm", "xpo", "cpo", "orpo", "reward", "bco".
Lowercased internally. Default None.
FastLanguageModel: Optional FastLanguageModel class reference.
Auto-detected if None.
"""
# After patching, use TRL's GRPOTrainer normally:
from trl import GRPOTrainer, GRPOConfig
Import
```python
from unsloth import PatchFastRL
from trl import GRPOTrainer, GRPOConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| algorithm | str | Yes | "grpo" for GRPO training |
| model | PeftModel | Yes | vLLM-enabled LoRA model (for GRPOTrainer) |
| reward_funcs | list[Callable] | Yes | Reward functions scoring completions |
| train_dataset | Dataset | Yes | Dataset with "prompt" column |
| args.num_generations | int | No | Completions per prompt (default varies) |
| args.max_completion_length | int | No | Max generation tokens |
| args.beta | float | No | KL penalty coefficient |
| args.loss_type | str | No | "grpo", "bnpo", "dr_grpo", "dapo", "cispo" |
| args.temperature | float | No | Sampling temperature for generation |
| args.unsloth_num_chunks | int | No | Memory-efficient gradient accumulation chunks |
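The train_dataset rows can be sketched in isolation: the trainer requires a "prompt" column, and any extra columns are forwarded to the reward functions as keyword arguments. A minimal, self-contained sketch (the row contents and the "answer" column are made-up examples):

```python
# Hypothetical dataset rows: "prompt" is required by GRPOTrainer;
# the extra "answer" column is passed to reward functions as a kwarg.
rows = [
    {"prompt": "What is 2+2?",       "answer": "4"},
    {"prompt": "Capital of France?", "answer": "Paris"},
]

# With the `datasets` library installed, this becomes a Dataset:
#   from datasets import Dataset
#   dataset = Dataset.from_list(rows)
print(all("prompt" in row for row in rows))  # True
```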
Outputs
| Name | Type | Description |
|---|---|---|
| trainer.train() returns | TrainOutput | Training metrics: completion_length, kl, clip_ratio, reward per function |
| model | PeftModel | Trained model with RL-optimized LoRA weights |
Usage Examples
GRPO Training Setup
```python
from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOTrainer, GRPOConfig

# 1. Load model with vLLM fast inference enabled
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,
    fast_inference=True,
    gpu_memory_utilization=0.6,
    max_lora_rank=64,
)

# 2. Apply LoRA adapters
model = FastLanguageModel.get_peft_model(model, r=64, lora_alpha=64)

# 3. Patch GRPOTrainer with Unsloth optimizations
PatchFastRL("grpo", FastLanguageModel)

# 4. Define reward functions
def correctness_reward(prompts, completions, answer, **kwargs):
    rewards = []
    for completion, expected in zip(completions, answer):
        rewards.append(1.0 if expected in completion else 0.0)
    return rewards

# 5. Configure and train
config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,
    max_completion_length=512,
    beta=0.04,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=1,
)
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[correctness_reward],
    args=config,
    train_dataset=dataset,
)
trainer.train()
```
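Reward functions need not use every argument: a function can accept only completions and swallow the remaining columns via **kwargs. A self-contained sketch of a format-style reward (the `<think>` tag convention and the `format_reward` name are illustrative assumptions, not part of Unsloth's API):

```python
import re

def format_reward(completions, **kwargs):
    # Reward completions that wrap reasoning in <think>...</think> tags.
    # Extra dataset columns (e.g. "answer") arrive via **kwargs and are
    # simply ignored here.
    pattern = re.compile(r"<think>.*?</think>", re.DOTALL)
    return [1.0 if pattern.search(c) else 0.0 for c in completions]

print(format_reward(
    completions=["<think>2+2=4</think> The answer is 4.", "just 4"],
    answer=["4", "4"],  # forwarded column, unused by this reward
))  # [1.0, 0.0]
```

Multiple reward functions can be combined by passing them together, e.g. reward_funcs=[correctness_reward, format_reward]; the trainer reports a metric per function.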
Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic