Implementation: Allenai Open Instruct GRPO ExperimentConfig
| Type | Dataclass |
|---|---|
| Source | open_instruct/grpo_utils.py:L23-232 |
| Dependencies | dataclasses, torch, enum |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete dataclass for specifying and validating all hyperparameters of a GRPO training run, provided by the Open Instruct library.
Description
ExperimentConfig is a comprehensive Python dataclass that centralizes all configuration for GRPO training. It includes parameters for optimization, algorithm selection, distributed training, checkpointing, experiment tracking, and AI2-specific infrastructure settings.
The dataclass includes a thorough `__post_init__` method that validates parameter combinations and computes derived values; a simplified sketch follows the list below. It handles:
- Incompatible parameter detection (e.g., vLLM logprobs with truncated importance sampling).
- Loss denominator validation.
- Checkpoint directory setup and calibration.
- Google Cloud Storage checkpoint path construction.
- Reference policy consistency checks.
- Automatic download of latest checkpoints from GCS for resumption.
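The exact checks live in the source file; the following is an illustrative sketch of the kinds of validations and derived-value computations `__post_init__` performs. Field names match the signature below, but the specific conditions and error messages are assumptions:

```python
import os
from dataclasses import dataclass, field


@dataclass
class ExperimentConfigSketch:
    beta: float = 0.05
    load_ref_policy: bool = True
    num_learners_per_node: list[int] = field(default_factory=lambda: [1])
    world_size: int | None = None
    checkpoint_state_freq: int = -1
    checkpoint_state_dir: str | None = None

    def __post_init__(self):
        # Reference policy consistency: a KL penalty (beta > 0) only
        # makes sense if a reference policy is actually loaded.
        if self.beta > 0 and not self.load_ref_policy:
            raise ValueError("beta > 0 requires load_ref_policy=True")
        # Derived value: total learner GPUs across all nodes.
        if self.world_size is None:
            self.world_size = sum(self.num_learners_per_node)
        # Checkpoint directory setup.
        if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir:
            os.makedirs(self.checkpoint_state_dir, exist_ok=True)
```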
The GRPOLossType enum provides type-safe loss function selection between DAPO and CISPO variants.
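For orientation, here is a simplified sketch of how the two variants typically differ, following the DAPO (asymmetric PPO clipping) and CISPO (clipped importance-sampling weights) papers; the library's actual loss code lives elsewhere and may differ in detail:

```python
import torch


def per_token_loss(logprobs, old_logprobs, advantages,
                   loss_fn="dapo", clip_lower=0.2, clip_higher=0.2):
    # Importance ratio between the current and behavior policy.
    ratio = torch.exp(logprobs - old_logprobs)
    if loss_fn == "dapo":
        # Asymmetric PPO clip: the upper bound (1 + clip_higher) may
        # exceed the lower bound (1 - clip_lower), per DAPO "clip-higher".
        clipped = torch.clamp(ratio, 1 - clip_lower, 1 + clip_higher)
        return -torch.minimum(ratio * advantages, clipped * advantages)
    elif loss_fn == "cispo":
        # Clipped importance sampling: clip and detach the ratio itself,
        # then use it to weight a REINFORCE-style log-prob term.
        weight = torch.clamp(ratio, 1 - clip_lower, 1 + clip_higher).detach()
        return -(weight * advantages * logprobs)
    raise ValueError(f"unknown loss_fn: {loss_fn}")
```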
Usage
Instantiated from command-line arguments at the beginning of a GRPO training run. The instance is then passed to virtually every component in the pipeline: learner actors, the data preparation actor, the main training loop, checkpointing functions, and evaluation logic.
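A minimal sketch of command-line instantiation, assuming a HuggingFace-style dataclass parser (`transformers.HfArgumentParser`); the library's actual entry point may use its own argument-parsing utilities:

```python
from transformers import HfArgumentParser

from open_instruct.grpo_utils import ExperimentConfig

# Parse flags such as `--learning_rate 1e-6 --beta 0.05` into a
# validated ExperimentConfig; __post_init__ runs on construction.
parser = HfArgumentParser(ExperimentConfig)
(config,) = parser.parse_args_into_dataclasses()
```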
Code Reference
Source Location
- Repository: Open Instruct
- File: open_instruct/grpo_utils.py
Signature
```python
class GRPOLossType(enum.StrEnum):
    dapo = "dapo"
    cispo = "cispo"


@dataclass
class ExperimentConfig:
    # Experiment
    exp_name: str = "grpo"
    seed: int = 1
    run_name: str | None = None

    # Optimizer
    learning_rate: float = 2e-5
    lr_scheduler_type: Literal[
        "linear", "cosine", "cosine_with_restarts",
        "polynomial", "constant", "constant_with_warmup",
    ] = "linear"
    warm_up_steps: int = 0
    warmup_ratio: float = 0.0
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0

    # Batch sizes
    per_device_train_batch_size: int = 1
    total_episodes: int = 100000
    world_size: int | None = None
    num_training_steps: int | None = None

    # Algorithm
    num_epochs: int = 1
    num_mini_batches: int = 1
    beta: float = 0.05
    clip_lower: float = 0.2
    clip_higher: float = 0.2
    kl_estimator: Literal[0, 1, 2, 3] = 2
    loss_fn: GRPOLossType = GRPOLossType.dapo
    alpha: float = 0.6
    ref_policy_update_freq: int | None = None
    load_ref_policy: bool = True

    # Ray / Distributed
    num_learners_per_node: list[int] = field(default_factory=lambda: [1])
    num_nodes: int = 1
    deepspeed_stage: int = 0
    deepspeed_zpg: int = 8
    sequence_parallel_size: int = 1
    gather_whole_model: bool = True

    # Checkpointing
    save_freq: int = 200
    output_dir: str = "output"
    push_to_hub: bool = True
    keep_last_n_checkpoints: int = 3
    checkpoint_state_freq: int = -1
    checkpoint_state_dir: str | None = None

    # Experiment tracking
    with_tracking: bool = False
    wandb_project_name: str = "open_instruct_internal"
    local_eval_every: int = 100
    verbose: bool = False

    # ... additional fields omitted for brevity
```
Import
```python
from open_instruct.grpo_utils import ExperimentConfig, GRPOLossType
```
I/O Contract
Key Fields
| Field | Type | Default | Description |
|---|---|---|---|
| exp_name | str | "grpo" | Name of the experiment for logging and output directory naming. |
| seed | int | 1 | Random seed for reproducibility across all random number generators. |
| learning_rate | float | 2e-5 | Initial learning rate for the AdamW optimizer. |
| beta | float | 0.05 | KL penalty coefficient. Set to 0 to disable KL regularization. |
| clip_lower | float | 0.2 | Lower bound of the PPO clipping range. |
| clip_higher | float | 0.2 | Upper bound of the PPO clipping range. Set higher than clip_lower for DAPO-style asymmetric clipping. |
| loss_fn | GRPOLossType | dapo | Loss function variant: "dapo" (asymmetric PPO clip) or "cispo" (clipped importance sampling). |
| num_mini_batches | int | 1 | Number of mini-batches to split each rollout batch into for gradient accumulation. |
| deepspeed_stage | int | 0 | DeepSpeed ZeRO stage (0, 2, or 3). |
| num_learners_per_node | list[int] | [1] | Number of learner GPUs per node (supports heterogeneous configurations). |
| save_freq | int | 200 | Save a lightweight model checkpoint every N training steps. |
| local_eval_every | int | 100 | Run in-loop evaluation every N training steps. Set to -1 to disable. |
| max_grad_norm | float | 1.0 | Maximum gradient norm for gradient clipping. |
| load_ref_policy | bool | True | Whether to load a reference policy for KL penalty computation. |
| alpha | float | 0.6 | Polyak averaging coefficient for reference policy updates. |
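The alpha and ref_policy_update_freq fields imply periodic Polyak-style reference-policy updates. A minimal sketch of such an update; the exact blend convention is an assumption, and the actual update code lives in the training loop:

```python
import torch


@torch.no_grad()
def polyak_update(ref_model, policy_model, alpha: float = 0.6):
    # Assumed convention: ref <- alpha * ref + (1 - alpha) * policy.
    for ref_p, pol_p in zip(ref_model.parameters(),
                            policy_model.parameters()):
        ref_p.mul_(alpha).add_(pol_p, alpha=1 - alpha)
```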
Runtime Fields (Computed)
| Field | Description |
|---|---|
| world_size | Total number of learner GPUs (sum of num_learners_per_node). |
| num_training_steps | Total training steps (derived from total_episodes and batch size). |
| run_name | Unique run identifier generated from experiment name and timestamp. |
| temperature | Sampling temperature, copied from streaming_config at runtime. |
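How num_training_steps is derived is internal to `__post_init__`; a plausible sketch, assuming episodes consumed per optimizer step equal the global batch size (the real formula may also account for rollout sizing):

```python
def derive_num_training_steps(total_episodes: int,
                              per_device_train_batch_size: int,
                              world_size: int) -> int:
    # Hypothetical derivation: global batch = per-device batch * world size.
    episodes_per_step = per_device_train_batch_size * world_size
    return total_episodes // episodes_per_step

# e.g. 100_000 episodes / (1 per device * 8 GPUs) = 12_500 steps
```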
Usage Examples
```python
from open_instruct.grpo_utils import ExperimentConfig, GRPOLossType

# Standard GRPO configuration
config = ExperimentConfig(
    exp_name="olmo_7b_gsm8k",
    seed=42,
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    beta=0.05,
    clip_lower=0.2,
    clip_higher=0.28,  # DAPO-style asymmetric clipping
    loss_fn=GRPOLossType.dapo,
    num_mini_batches=2,
    per_device_train_batch_size=1,
    total_episodes=50000,
    deepspeed_stage=2,
    num_learners_per_node=[4, 4],  # 4 learners on each of 2 nodes
    num_nodes=2,
    save_freq=100,
    local_eval_every=50,
    output_dir="/output/olmo_7b_gsm8k",
    with_tracking=True,
    wandb_project_name="grpo_experiments",
)

# Configuration without a reference model (pure reward maximization)
config_no_ref = ExperimentConfig(
    beta=0.0,
    load_ref_policy=False,
    loss_fn=GRPOLossType.cispo,
)

# Access derived values computed in __post_init__
print(f"World size: {config.world_size}")
```