Implementation:CarperAI Trlx Accelerate RFT Trainer

From Leeroopedia


Knowledge Sources
Domains: Reinforcement_Learning, NLP, Fine_Tuning
Last Updated: 2026-02-07 16:00 GMT

Overview

A concrete trainer for rejection fine-tuning (RFT): it generates multiple completions per prompt, scores them with a reward function, and trains on the best ones using progressive percentile thresholding.

Description

The AccelerateRFTTrainer extends AccelerateRLTrainer to implement rejection fine-tuning. In each epoch, it generates n_generations_per_prompt completions per prompt, scores them with the reward function, filters to those above a percentile threshold, and trains on the selected completions using standard cross-entropy loss. The percentile threshold increases linearly from start_percentile to end_percentile over n_improve_steps, gradually raising the quality bar. The companion RFTConfig dataclass defines the generation kwargs and progressive thresholding parameters.
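The two ingredients of this selection loop, the linear percentile schedule and the best-of-N filter, can be sketched in plain numpy (an illustrative sketch, not trlx's implementation; the function names here are hypothetical):

```python
import numpy as np

def percentile_threshold(step, start_percentile=0.7, end_percentile=0.95,
                         n_improve_steps=4):
    """Linearly grow the score-percentile threshold from start_percentile to
    end_percentile over n_improve_steps, then hold it at end_percentile
    (sketch of the progressive thresholding described above)."""
    frac = min(step / n_improve_steps, 1.0)
    return start_percentile + frac * (end_percentile - start_percentile)

def select_best(completions, scores, percentile):
    """Keep only completions whose score reaches the given score percentile."""
    cutoff = np.percentile(scores, percentile * 100)
    return [c for c, s in zip(completions, scores) if s >= cutoff]
```

With the defaults, the threshold moves 0.7 → 0.7625 → 0.825 → 0.8875 → 0.95 over the first four steps, so later epochs train only on progressively higher-scoring completions.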

Usage

Use this trainer for rejection fine-tuning, where you want to train on the best-of-N completions rather than use policy-gradient methods. It is registered under the name "AccelerateRFTTrainer" and is selected automatically when the RFT method config is used.

Code Reference

Source Location

trlx/trainer/accelerate_rft_trainer.py

Signature

@dataclass
@register_method
class RFTConfig(MethodConfig):
    """
    Rejection Fine-Tuning configuration.

    Attributes:
        gen_kwargs: dict        - Generation parameters.
        start_percentile: float - Initial score percentile threshold (default 0.7).
        end_percentile: float   - Final score percentile threshold (default 0.95).
        n_improve_steps: int    - Steps over which percentile grows (default 4).
        n_generations_per_prompt: int - Completions per prompt (default 32).
    """


@register_trainer
class AccelerateRFTTrainer(AccelerateRLTrainer):
    def __init__(self, config: TRLConfig, **kwargs):
        """
        Args:
            config: TRLConfig with RFT method configuration.
        """

    def make_experience(self) -> None:
        """
        Generate completions, score them, filter by percentile threshold,
        and prepare training data from the best completions.
        """

    def loss(self, batch) -> Tuple[torch.Tensor, dict]:
        """Compute cross-entropy loss on filtered best-of-N completions."""

    def post_epoch_callback(self) -> None:
        """Generate new experience with updated percentile threshold after each epoch."""

Import

from trlx.trainer.accelerate_rft_trainer import AccelerateRFTTrainer, RFTConfig

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| config | TRLConfig | Yes | Full trlx configuration with RFT method config |
| reward_fn | Callable | Yes | Function scoring completions (higher = better) |
| prompts | List[str] | Yes | Training prompts |
| gen_kwargs | dict | No | Generation parameters (temperature, max_new_tokens, etc.) |
| start_percentile | float | No | Initial score threshold (default 0.7) |
| end_percentile | float | No | Final score threshold (default 0.95) |
| n_improve_steps | int | No | Progressive threshold growth steps (default 4) |
| n_generations_per_prompt | int | No | Completions per prompt (default 32) |

Outputs

| Name | Type | Description |
|------|------|-------------|
| loss | Tuple[torch.Tensor, dict] | Cross-entropy loss and logging stats dict |
| Trained model | AutoModelForCausalLM | Model fine-tuned on best-of-N completions |
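The loss output can be sketched as mean token-level cross-entropy over the kept completions. This is an illustrative numpy version, not the trainer's code, which operates on tokenized torch batches:

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean token-level cross-entropy: logits has shape (seq_len, vocab),
    targets holds the target token id for each position."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()
```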

Usage Examples

Train with Rejection Fine-Tuning

import trlx
from trlx.data.default_configs import TRLConfig

# 1. Define reward function
def reward_fn(samples, **kwargs):
    # Score completions (e.g., sentiment classifier)
    return [score_sentiment(s) for s in samples]

# 2. Configure RFT
config = TRLConfig.load_yaml("configs/rft_config.yml")

# 3. Train (RFT generates N completions per prompt, trains on best)
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=train_prompts,
    eval_prompts=eval_prompts,
    config=config,
)
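The reward_fn contract in the example above (a list of sample strings in, one float score per sample out, higher is better) can be shown with a toy stand-in for score_sentiment. Both the name toy_reward_fn and its keyword rule are hypothetical; a real reward function would call a trained classifier:

```python
def toy_reward_fn(samples, **kwargs):
    """Toy reward: 1.0 for completions containing 'great', else 0.0.
    Only the signature and the list-of-floats return type matter here."""
    return [1.0 if "great" in s.lower() else 0.0 for s in samples]
```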
