Implementation: Hugging Face TRL GRPO Model Kwargs Assembly
| Property | Value |
|---|---|
| Implementation Name | GRPO Model Kwargs Assembly |
| Library | Hugging Face TRL |
| Type | Pattern Doc |
| Source Files | trl/scripts/grpo.py (L119-133), trl/trainer/grpo_trainer.py (L270-275) |
| Import | from transformers import AutoModelForCausalLM |
Overview
Description
This page documents the pattern by which model initialization keyword arguments are assembled in the GRPO training script and then consumed by the GRPOTrainer during deferred model loading. The script constructs a dictionary from the user's ModelConfig (revision, attention implementation, dtype), optionally adds quantization settings, and attaches the result to GRPOConfig.model_init_kwargs.
Usage
```python
from trl import GRPOConfig, GRPOTrainer, ModelConfig, get_quantization_config, get_kbit_device_map
import torch

# model_args is a ModelConfig and training_args a GRPOConfig, both parsed from the CLI
# Build model kwargs from ModelConfig
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config

# Attach to the training config
training_args.model_init_kwargs = model_kwargs

# The trainer loads the model from the string path using these kwargs
trainer = GRPOTrainer(
    model=model_args.model_name_or_path,  # string path, not a model object
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset["train"],
)
```
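Passing a string instead of an instantiated model is deliberate: it lets the trainer apply the kwargs itself and override device_map when it detects a distributed launch (see Step 2 below).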
Code Reference
Source Location
| Component | File | Lines |
|---|---|---|
| Kwargs assembly | trl/scripts/grpo.py | L119-133 |
| Deferred model creation | trl/trainer/grpo_trainer.py | L270-275 |
| Reference model creation | trl/trainer/grpo_trainer.py | L556-570 |
Pattern
Step 1: Assemble kwargs in the script
```python
# trl/scripts/grpo.py L119-133
dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)
model_kwargs = dict(
    revision=model_args.model_revision,
    attn_implementation=model_args.attn_implementation,
    dtype=dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
    model_kwargs["device_map"] = get_kbit_device_map()
    model_kwargs["quantization_config"] = quantization_config
training_args.model_init_kwargs = model_kwargs
```
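The dtype line performs string-to-object resolution: "auto" and None pass through untouched (from_pretrained accepts them as-is), while any other string is looked up as an attribute on the torch module. A minimal illustration:

```python
import torch

# "bfloat16" -> torch.bfloat16; an unknown name raises AttributeError early,
# before any weights are downloaded
assert getattr(torch, "bfloat16") is torch.bfloat16
assert getattr(torch, "float16") is torch.float16
```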
Step 2: Trainer consumes kwargs during init
```python
# trl/trainer/grpo_trainer.py L270-275
if isinstance(model, str):
    model_init_kwargs = args.model_init_kwargs or {}
    # Distributed training requires device_map=None ("auto" fails)
    if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        model_init_kwargs["device_map"] = None
    model = create_model_from_path(model, **model_init_kwargs)
```
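Conceptually, the deferred load amounts to a plain from_pretrained call with the assembled kwargs. A minimal sketch of that equivalence, assuming create_model_from_path reduces to AutoModelForCausalLM.from_pretrained for text models (the real helper may do more):

```python
from transformers import AutoModelForCausalLM

def load_deferred(model_id: str, model_init_kwargs: dict, is_distributed: bool):
    # Copy so the shared config dict is not mutated by the override
    kwargs = dict(model_init_kwargs)
    if is_distributed:
        kwargs["device_map"] = None  # the launcher handles device placement
    return AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
```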
Step 3: Reference model created from the same config
```python
# trl/trainer/grpo_trainer.py L556-570
if self.beta == 0.0:
    self.ref_model = None  # No reference model needed
elif is_peft_model(model):
    self.ref_model = None  # PEFT uses adapter disable for reference
else:
    model_init_kwargs = args.model_init_kwargs or {}
    if self.args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
        model_init_kwargs["device_map"] = None
    self.ref_model = create_model_from_path(
        get_config_model_id(self.model.config), **model_init_kwargs
    )
```
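The two None branches are not equivalent: with beta == 0.0 the KL term vanishes and no reference pass happens at all, while in the PEFT case reference log-probs are still needed and come from temporarily disabling the adapter, so the frozen base weights act as the reference. A sketch of that second idea, assuming a peft PeftModel with its standard disable_adapter() context manager:

```python
import torch

def reference_logits(model, ref_model, input_ids):
    """Sketch: where reference logits come from when beta > 0."""
    with torch.no_grad():
        if ref_model is not None:
            return ref_model(input_ids=input_ids).logits
        # PEFT case: disabling the adapter recovers the frozen base model,
        # so no separate reference copy has to be kept in memory
        with model.disable_adapter():
            return model(input_ids=input_ids).logits
```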
Import
```python
from transformers import AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer, ModelConfig
from trl import get_quantization_config, get_kbit_device_map
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| model_args.model_revision | str | Git revision (branch/tag/commit) of the model to load. |
| model_args.attn_implementation | str | Attention implementation to use (e.g., "flash_attention_2", "sdpa"). |
| model_args.dtype | str | Model dtype as a string (e.g., "bfloat16", "auto"). |
| model_args.model_name_or_path | str | Hugging Face model ID or local directory path. |
| quantization_config | BitsAndBytesConfig \| None | Quantization settings for QLoRA, or None. |
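get_quantization_config derives this object from the quantization flags on ModelConfig (e.g. load_in_4bit). A rough manual equivalent for the 4-bit case, shown only to make the shape of the object concrete (the exact fields TRL sets may differ):

```python
import torch
from transformers import BitsAndBytesConfig

# Approximately what get_quantization_config produces when load_in_4bit is set
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
```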
Outputs
| Output | Type | Description |
|---|---|---|
| model | PreTrainedModel | The instantiated causal language model on the correct device(s). |
| ref_model | PreTrainedModel \| None | The reference model (if beta > 0 and not using PEFT), or None. |
Usage Examples
Standard loading without quantization:
```python
import torch
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(output_dir="./output")
training_args.model_init_kwargs = {
    "revision": "main",
    "attn_implementation": "flash_attention_2",
    "dtype": torch.bfloat16,
}
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=accuracy_reward,  # a reward function defined elsewhere
    args=training_args,
    train_dataset=dataset,
)
```
QLoRA loading with 4-bit quantization:
```python
import torch
from transformers import BitsAndBytesConfig
from trl import GRPOConfig

training_args = GRPOConfig(output_dir="./output")
training_args.model_init_kwargs = {
    "revision": "main",
    "dtype": "auto",
    "device_map": {"": 0},
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
}
```
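Quantized loading is normally paired with a LoRA adapter so the run has trainable parameters; a sketch of the corresponding trainer call (the LoraConfig values here are illustrative, not TRL defaults):

```python
from peft import LoraConfig
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    reward_funcs=accuracy_reward,  # defined elsewhere
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
)
```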