Implementation: CarperAI Trlx Accelerate RFT Trainer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Fine_Tuning |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for rejection fine-tuning (RFT) that generates multiple completions per prompt, scores them, and trains on the best ones using progressive percentile thresholding.
Description
The AccelerateRFTTrainer extends AccelerateRLTrainer to implement rejection fine-tuning. In each epoch, it generates n_generations_per_prompt completions per prompt, scores them with the reward function, filters to those above a percentile threshold, and trains on the selected completions using standard cross-entropy loss. The percentile threshold increases linearly from start_percentile to end_percentile over n_improve_steps, gradually raising the quality bar. The companion RFTConfig dataclass defines the generation kwargs and progressive thresholding parameters.
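A rough sketch of the selection loop described above (assumed logic reconstructed from this description, not the trainer's actual code; generate and reward_fn are stand-ins):
```python
import numpy as np

def rft_make_experience(prompts, generate, reward_fn, improve_step, cfg):
    """Sketch of one RFT experience pass under the assumptions above."""
    # Percentile threshold anneals linearly from start to end over n_improve_steps
    frac = min(improve_step / cfg.n_improve_steps, 1.0)
    percentile = cfg.start_percentile + frac * (cfg.end_percentile - cfg.start_percentile)
    selected = []
    for prompt in prompts:
        # Best-of-N: sample N completions for this prompt, then score them
        completions = [generate(prompt) for _ in range(cfg.n_generations_per_prompt)]
        scores = np.asarray(reward_fn(completions))
        cutoff = np.quantile(scores, percentile)
        # Keep only completions at or above the percentile cutoff
        selected.extend(c for c, s in zip(completions, scores) if s >= cutoff)
    return selected  # these are then trained on with standard cross-entropy
```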
Usage
Use this trainer for rejection fine-tuning, i.e., when you want to train on best-of-N completions rather than use policy-gradient methods. It is registered as "AccelerateRFTTrainer" and is selected automatically when the RFT method config is used.
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/trainer/accelerate_rft_trainer.py
- Lines: 1-197
Signature
```python
@dataclass
@register_method
class RFTConfig(MethodConfig):
    """
    Rejection Fine-Tuning configuration.

    Attributes:
        gen_kwargs: dict - Generation parameters.
        start_percentile: float - Initial score percentile threshold (default 0.7).
        end_percentile: float - Final score percentile threshold (default 0.95).
        n_improve_steps: int - Steps over which the percentile grows (default 4).
        n_generations_per_prompt: int - Completions per prompt (default 32).
    """

@register_trainer
class AccelerateRFTTrainer(AccelerateRLTrainer):
    def __init__(self, config: TRLConfig, **kwargs):
        """
        Args:
            config: TRLConfig with RFT method configuration.
        """

    def make_experience(self) -> None:
        """
        Generate completions, score them, filter by percentile threshold,
        and prepare training data from the best completions.
        """

    def loss(self, batch) -> Tuple[torch.Tensor, dict]:
        """Compute cross-entropy loss on the filtered best-of-N completions."""

    def post_epoch_callback(self) -> None:
        """Generate new experience with the updated percentile threshold after each epoch."""
```
Import
```python
from trlx.trainer.accelerate_rft_trainer import AccelerateRFTTrainer, RFTConfig
```
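A minimal construction sketch, assuming the dataclass accepts its documented attributes as keyword arguments (values are illustrative, and the name field is an assumption inherited from MethodConfig):
```python
from trlx.trainer.accelerate_rft_trainer import RFTConfig

rft_method = RFTConfig(
    name="RFTConfig",  # assumed: MethodConfig subclasses carry a name field
    gen_kwargs=dict(max_new_tokens=64, temperature=1.0, do_sample=True),
    start_percentile=0.7,          # initial quality bar
    end_percentile=0.95,           # final quality bar
    n_improve_steps=4,             # threshold anneals over these steps
    n_generations_per_prompt=32,   # best-of-32 sampling
)
```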
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | TRLConfig | Yes | Full trlx configuration with RFT method config |
| reward_fn | Callable | Yes | Function scoring completions (higher = better) |
| prompts | List[str] | Yes | Training prompts |
| gen_kwargs | dict | No | Generation parameters (temperature, max_new_tokens, etc.) |
| start_percentile | float | No | Initial score threshold (default 0.7) |
| end_percentile | float | No | Final score threshold (default 0.95) |
| n_improve_steps | int | No | Progressive threshold growth steps (default 4) |
| n_generations_per_prompt | int | No | Completions per prompt (default 32) |
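With the defaults above, the assumed linear schedule raises the quality bar in equal increments; a quick worked example:
```python
start, end, n = 0.7, 0.95, 4
# Percentile used at improve steps 0..4: 0.7, 0.7625, 0.825, 0.8875, 0.95
schedule = [start + (end - start) * i / n for i in range(n + 1)]
```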
Outputs
| Name | Type | Description |
|---|---|---|
| loss returns | Tuple[Tensor, dict] | (cross-entropy loss, logging stats dict) |
| Trained model | AutoModelForCausalLM | Model fine-tuned on best-of-N completions |
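A sketch of what the loss contract implies: standard causal-LM cross-entropy over the selected completions, returning a (Tensor, stats) pair. This is an assumed shape, not the trainer's exact code (padding masking omitted for brevity):
```python
import torch
import torch.nn.functional as F

def rft_loss(model, batch):
    # batch holds tokenized best-of-N completions
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"]).logits
    # Shift so each position predicts the next token
    shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
    shift_labels = batch["input_ids"][:, 1:].reshape(-1)
    loss = F.cross_entropy(shift_logits, shift_labels)
    return loss, {"loss": loss.item()}  # (cross-entropy loss, logging stats dict)
```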
Usage Examples
Train with Rejection Fine-Tuning
```python
import trlx
from trlx.data.default_configs import TRLConfig

# 1. Define a reward function (higher = better)
def reward_fn(samples, **kwargs):
    # Score completions (e.g., with a sentiment classifier)
    return [score_sentiment(s) for s in samples]

# 2. Configure RFT
config = TRLConfig.load_yaml("configs/rft_config.yml")

# 3. Train (RFT generates N completions per prompt, trains on the best)
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,
    eval_prompts=eval_prompts,
    config=config,
)
```
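The score_sentiment helper above is a placeholder. One hypothetical way to back it (an assumption for illustration, not part of trlx) is a Hugging Face sentiment pipeline:
```python
from transformers import pipeline

# Hypothetical scorer; lvwerra/distilbert-imdb emits POSITIVE/NEGATIVE labels
sentiment = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", top_k=None)

def score_sentiment(text: str) -> float:
    # With top_k=None, the pipeline returns one {label, score} dict per label
    results = sentiment(text, truncation=True)
    return next(r["score"] for r in results if r["label"] == "POSITIVE")
```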