
Implementation:Unslothai Unsloth PatchFastRL GRPOTrainer

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, NLP, Optimization
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete tool for patching TRL's GRPOTrainer with Unsloth's memory-efficient RL optimizations and running GRPO training.

Description

The GRPO training implementation consists of two components:

  1. PatchFastRL: A function that dynamically rewrites TRL's GRPOTrainer class methods with Unsloth-optimized replacements. This includes replacing compute_loss, _get_per_token_logps, and other critical methods defined in rl_replacements.py.
  2. GRPOTrainer (patched): After patching, the standard TRL GRPOTrainer is used with Unsloth's optimizations active, including chunked gradient accumulation, optimized log-probability computation, and model inference/training mode switching.

The patching system (unsloth/models/rl.py) dynamically inspects TRL's source code and rewrites functions to inject Unsloth optimizations, supporting multiple loss types: grpo, bnpo, dr_grpo, dapo, and cispo.
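The mechanism can be pictured as method-level monkey patching: a class's method is swapped for an optimized replacement at runtime, so existing call sites pick up the new behavior. The toy below is an illustration of that pattern only, not Unsloth's actual rewriting code (which inspects and rewrites TRL's source text).

```python
# Toy illustration of runtime method replacement, the pattern PatchFastRL
# applies to TRL's GRPOTrainer (names and logic here are invented for clarity).

class Trainer:
    def compute_loss(self, logits):
        # original: processes the whole sequence at once
        return sum(x * x for x in logits) / len(logits)

def chunked_compute_loss(self, logits, chunk_size=2):
    # replacement: processes the sequence in chunks to bound peak memory,
    # analogous in spirit to Unsloth's chunked gradient accumulation
    total, n = 0.0, 0
    for i in range(0, len(logits), chunk_size):
        chunk = logits[i:i + chunk_size]
        total += sum(x * x for x in chunk)
        n += len(chunk)
    return total / n

# Swap the method on the class; every existing and future instance is affected.
Trainer.compute_loss = chunked_compute_loss

t = Trainer()
print(t.compute_loss([1.0, 2.0, 3.0, 4.0]))  # 7.5, same result, chunked work
```

Because the patch replaces the method on the class object itself, code that constructs the trainer afterwards needs no changes, which is why a single `PatchFastRL` call before instantiation suffices.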

Usage

Call PatchFastRL("grpo") once, after loading the model but before instantiating the trainer. Then use GRPOTrainer normally with a GRPOConfig.

Code Reference

Source Location

  • Repository: unsloth
  • File: unsloth/models/rl.py
  • Lines: L1597-1603 (PatchFastRL), L88-261 (PatchRL for mode switching), L1262-1574 (patch_functions for dynamic rewriting)
  • File: unsloth/models/rl_replacements.py
  • Lines: L1-1303 (replacement compute_loss, _get_per_token_logps, etc.)

Signature

def PatchFastRL(algorithm=None, FastLanguageModel=None):
    """
    Patches TRL trainer classes with Unsloth optimizations.

    Args:
        algorithm (str): RL algorithm to patch. Options: "grpo", "dpo", "sft",
            "kto", "prm", "xpo", "cpo", "orpo", "reward", "bco".
            Lowercased internally. Default None.
        FastLanguageModel: Optional FastLanguageModel class reference.
            Auto-detected if None.
    """

# After patching, use TRL's GRPOTrainer normally:
from trl import GRPOTrainer, GRPOConfig

Import

from unsloth import PatchFastRL
from trl import GRPOTrainer, GRPOConfig

I/O Contract

Inputs

Name Type Required Description
algorithm str Yes "grpo" for GRPO training
model PeftModel Yes vLLM-enabled LoRA model (for GRPOTrainer)
reward_funcs list[Callable] Yes Reward functions scoring completions
train_dataset Dataset Yes Dataset with "prompt" column
args.num_generations int No Completions per prompt (default varies)
args.max_completion_length int No Max generation tokens
args.beta float No KL penalty coefficient
args.loss_type str No "grpo", "bnpo", "dr_grpo", "dapo", "cispo"
args.temperature float No Sampling temperature for generation
args.unsloth_num_chunks int No Memory-efficient gradient accumulation chunks
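The `args.*` rows above map onto GRPOConfig fields. A hedged sketch of such a config follows; exact defaults and accepted fields depend on the installed TRL version, and `unsloth_num_chunks` is injected by Unsloth's patching, so it is only accepted after PatchFastRL has run.

```python
from trl import GRPOConfig

# Sketch only: field names follow the I/O table above; values are examples,
# not recommended settings. Verify defaults against your TRL version.
config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,          # completions sampled per prompt
    max_completion_length=512,  # cap on generated tokens
    beta=0.04,                  # KL penalty coefficient
    loss_type="dr_grpo",        # one of: grpo, bnpo, dr_grpo, dapo, cispo
    temperature=0.9,            # sampling temperature during generation
)
```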

Outputs

Name Type Description
trainer.train() TrainOutput Training metrics: completion_length, kl, clip_ratio, and reward per function
model PeftModel Trained model with RL-optimized LoRA weights

Usage Examples

GRPO Training Setup

from unsloth import FastLanguageModel, PatchFastRL
from trl import GRPOTrainer, GRPOConfig

# 1. Load model with vLLM
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B-Instruct",
    max_seq_length=4096,
    load_in_4bit=True,
    fast_inference=True,
    gpu_memory_utilization=0.6,
    max_lora_rank=64,
)

# 2. Apply LoRA
model = FastLanguageModel.get_peft_model(model, r=64, lora_alpha=64)

# 3. Patch GRPOTrainer with Unsloth optimizations
PatchFastRL("grpo", FastLanguageModel)

# 4. Define reward functions
def correctness_reward(prompts, completions, answer, **kwargs):
    rewards = []
    for completion, expected in zip(completions, answer):
        rewards.append(1.0 if expected in completion else 0.0)
    return rewards

# 5. Configure and train
config = GRPOConfig(
    output_dir="./grpo_output",
    num_generations=8,
    max_completion_length=512,
    beta=0.04,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=1,
)

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[correctness_reward],
    args=config,
    train_dataset=dataset,
)

trainer.train()
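Reward functions like the one in step 4 can be sanity-checked in isolation before committing to a training run: they receive batched prompts and completions plus any extra dataset columns (here, "answer") and must return one float per completion. The snippet below restates the example's correctness_reward for a standalone check.

```python
# Same logic as the correctness_reward in the setup above: score 1.0 when
# the expected answer string appears in the completion, else 0.0.
def correctness_reward(prompts, completions, answer, **kwargs):
    return [1.0 if expected in completion else 0.0
            for completion, expected in zip(completions, answer)]

scores = correctness_reward(
    prompts=["What is 2 + 2?", "Capital of France?"],
    completions=["The answer is 4.", "Berlin."],
    answer=["4", "Paris"],
)
print(scores)  # [1.0, 0.0]
```

A substring match like this is deliberately crude; production reward functions typically parse the completion (e.g. extract a final answer span) before comparing.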
