
Implementation:Huggingface Trl GRPO Model Kwargs Assembly

From Leeroopedia


Property Value
Implementation Name GRPO Model Kwargs Assembly
Library Huggingface TRL
Type Pattern Doc
Source Files trl/scripts/grpo.py (L119-133), trl/trainer/grpo_trainer.py (L270-275)
Import from transformers import AutoModelForCausalLM

Overview

Description

This page documents the pattern by which model initialization keyword arguments are assembled in the GRPO training script and then consumed by the GRPOTrainer during deferred model loading. The script builds a dictionary from the user's ModelConfig (revision, attention implementation, dtype), optionally adds quantization settings, and attaches it to GRPOConfig.model_init_kwargs. Because the trainer receives the model as a string path rather than an instantiated object, loading is deferred until trainer initialization, which lets the trainer override kwargs that conflict with the runtime environment (for example, forcing device_map=None under distributed training).

Usage

from trl import GRPOConfig, GRPOTrainer, ModelConfig, get_quantization_config, get_kbit_device_map
import torch

# Build model kwargs from ModelConfig
# (model_args is a ModelConfig parsed from the CLI, e.g. via TrlParser)
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config

# Attach to training config
training_args.model_init_kwargs = model_kwargs

# Trainer loads model from string path using these kwargs
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,  # string path, not a model object
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset["train"],
)

Code Reference

Source Location

Component File Lines
Kwargs assembly trl/scripts/grpo.py L119-133
Deferred model creation trl/trainer/grpo_trainer.py L270-275
Reference model creation trl/trainer/grpo_trainer.py L556-570

Pattern

Step 1: Assemble kwargs in the script

# trl/scripts/grpo.py L119-133
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)

if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs
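
For illustration, a hypothetical ModelConfig with revision="main", attn_implementation="sdpa", and dtype="bfloat16" (quantization disabled) would yield the dictionary below; the dtype string is resolved to a torch.dtype via getattr:

import torch

# Hypothetical assembled kwargs for the config described above;
# the values are illustrative, not taken from the TRL source.
model_kwargs = {
    "revision": "main",
    "attn_implementation": "sdpa",
    "dtype": torch.bfloat16,  # resolved by getattr(torch, "bfloat16")
}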

Step 2: Trainer consumes kwargs during init

# trl/trainer/grpo_trainer.py L270-275
if isinstance(model, str):
    model_init_kwargs = args.model_init_kwargs or {}
    # Distributed training requires device_map=None ("auto" fails)
    if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        model_init_kwargs["device_map"] = None
    model = create_model_from_path(model, **model_init_kwargs)
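
The body of create_model_from_path is not shown above. A minimal sketch of the behavior this pattern implies, assuming the helper simply forwards the assembled kwargs to the Transformers loader named in the Import row:

from transformers import AutoModelForCausalLM

def load_policy_model(model_id: str, **model_init_kwargs):
    # Sketch only: assumes the kwargs dict maps one-to-one onto
    # AutoModelForCausalLM.from_pretrained arguments (revision,
    # attn_implementation, dtype, device_map, quantization_config).
    return AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)

Note that Step 3 re-applies the same device_map override before the reference model is created, so both loading paths see consistent kwargs under distributed training.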

Step 3: Reference model created from the same config

# trl/trainer/grpo_trainer.py L556-570
if self.beta == 0.0:
    self.ref_model = None  # No reference model needed
elif is_peft_model(model):
    self.ref_model = None  # PEFT uses adapter disable for reference
else:
    model_init_kwargs = args.model_init_kwargs or {}
    if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        model_init_kwargs["device_map"] = None
    self.ref_model = create_model_from_path(
        get_config_model_id(self.model.config), **model_init_kwargs
    )
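
The PEFT branch needs no second model copy because disabling the adapters recovers the frozen base weights, which act as the implicit reference policy. A sketch of how reference logits could be obtained under that assumption, using peft's disable_adapter context manager:

import torch

def reference_logits(peft_model, input_ids: torch.Tensor) -> torch.Tensor:
    # Assumption: peft_model is a peft.PeftModel. disable_adapter()
    # temporarily deactivates the adapters so the forward pass runs
    # through the frozen base weights only.
    with torch.no_grad(), peft_model.disable_adapter():
        return peft_model(input_ids=input_ids).logits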

Import

import torch

from transformers import AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer, ModelConfig
from trl import get_quantization_config, get_kbit_device_map

I/O Contract

Inputs

Parameter Type Description
model_args.model_revision str Git revision (branch/tag/commit) of the model to load.
model_args.attn_implementation str Attention implementation to use (e.g., "flash_attention_2", "sdpa").
model_args.dtype str Model dtype as string (e.g., "bfloat16", "auto").
model_args.model_name_or_path str HuggingFace model ID or local directory path.
quantization_config BitsAndBytesConfig or None Quantization settings for QLoRA-style loading, or None when quantization is disabled.
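
As shown in the Pattern section, dtype strings other than "auto" are resolved to torch dtypes with getattr before the kwargs reach the loader, so the str in this table becomes a torch.dtype by load time. A quick self-contained check:

import torch

# Strings like "bfloat16" name dtype attributes on the torch module;
# "auto" and None are passed through to the loader unchanged.
for name in ["bfloat16", "float16", "float32"]:
    assert isinstance(getattr(torch, name), torch.dtype)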

Outputs

Output Type Description
model PreTrainedModel The instantiated causal language model on the correct device(s).
ref_model PreTrainedModel or None The reference model (loaded when beta > 0 and the policy is not a PEFT model), or None otherwise.

Usage Examples

Standard loading without quantization:

training_args = GRPOConfig(output_dir="./output")
training_args.model_init_kwargs = {
    "revision": "main",
    "attn_implementation": "flash_attention_2",
    "dtype": torch.bfloat16,
}

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=accuracy_reward,
    args=training_args,
    train_dataset=dataset,
)
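
Since model_init_kwargs is a field on GRPOConfig, the same dictionary can presumably be passed at construction time instead of assigned afterwards:

import torch
from trl import GRPOConfig

# Equivalent construction (assumes GRPOConfig accepts the field as a
# keyword argument, as dataclass-style configs typically do).
training_args = GRPOConfig(
    output_dir="./output",
    model_init_kwargs={"revision": "main", "dtype": torch.bfloat16},
)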

QLoRA loading with 4-bit quantization:

import torch
from transformers import BitsAndBytesConfig

training_args = GRPOConfig(output_dir="./output")
training_args.model_init_kwargs = {
    "revision": "main",
    "dtype": "auto",
    "device_map": {"": 0},
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
}
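
The QLoRA kwargs are consumed through the same deferred-loading path; completing the example with the trainer call (accuracy_reward and dataset carried over from the standard example above):

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",  # loaded in 4-bit using the kwargs above
    reward_funcs=accuracy_reward,
    args=training_args,
    train_dataset=dataset,
)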
