
Implementation:Allenai Open instruct GRPO ExperimentConfig

From Leeroopedia


Type Dataclass
Source open_instruct/grpo_utils.py:L23-232
Dependencies dataclasses, torch, enum
Last Updated 2026-02-07 00:00 GMT

Overview

Concrete dataclass for specifying and validating all hyperparameters of a GRPO training run, provided by the Open Instruct library.

Description

ExperimentConfig is a comprehensive Python dataclass that centralizes all configuration for GRPO training. It includes parameters for optimization, algorithm selection, distributed training, checkpointing, experiment tracking, and AI2-specific infrastructure settings.

The dataclass includes a thorough __post_init__ method that validates parameter combinations and computes derived values. It handles:

  • Incompatible parameter detection (e.g., vLLM logprobs with truncated importance sampling).
  • Loss denominator validation.
  • Checkpoint directory setup and calibration.
  • Google Cloud Storage checkpoint path construction.
  • Reference policy consistency checks.
  • Automatic download of latest checkpoints from GCS for resumption.
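
The incompatible-parameter detection above can be sketched as follows. This is an illustrative reconstruction, not the actual open_instruct code; `MiniConfig` and its check are hypothetical stand-ins for two real ExperimentConfig fields:

```python
from dataclasses import dataclass


@dataclass
class MiniConfig:
    # Illustrative stand-ins for two real ExperimentConfig fields
    beta: float = 0.05          # KL penalty coefficient
    load_ref_policy: bool = True

    def __post_init__(self):
        # Incompatible-parameter detection: a nonzero KL penalty is
        # meaningless without a reference policy to measure KL against.
        if self.beta > 0 and not self.load_ref_policy:
            raise ValueError("beta > 0 requires load_ref_policy=True")
```

With this pattern, `MiniConfig(beta=0.0, load_ref_policy=False)` constructs fine, while `MiniConfig(beta=0.05, load_ref_policy=False)` raises at construction time, so misconfigurations fail fast instead of mid-run.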

The GRPOLossType enum provides type-safe loss function selection between DAPO and CISPO variants.

Usage

An ExperimentConfig is instantiated from command-line arguments at the start of a GRPO training run. The resulting instance is then passed to virtually every component in the pipeline: the learner actors, the data preparation actor, the main training loop, checkpointing functions, and evaluation logic.
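
A minimal sketch of the CLI-to-dataclass wiring, using argparse purely for illustration (the actual open-instruct entry point may use a different argument parser, and `MiniConfig` is a hypothetical subset of the real fields):

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class MiniConfig:
    # Hypothetical subset of ExperimentConfig fields
    exp_name: str = "grpo"
    seed: int = 1
    learning_rate: float = 2e-5


def parse_config(argv: list[str]) -> MiniConfig:
    # Generate one CLI flag per dataclass field, typed from its annotation
    parser = argparse.ArgumentParser()
    for f in fields(MiniConfig):
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return MiniConfig(**vars(parser.parse_args(argv)))


config = parse_config(["--exp_name", "olmo_grpo", "--learning_rate", "1e-6"])
```

Fields not passed on the command line keep their dataclass defaults, so `config.seed` above is still 1.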

Code Reference

Source Location

open_instruct/grpo_utils.py, lines 23-232

Signature

class GRPOLossType(enum.StrEnum):
    dapo = "dapo"
    cispo = "cispo"


@dataclass
class ExperimentConfig:
    # Experiment
    exp_name: str = "grpo"
    seed: int = 1
    run_name: str | None = None

    # Optimizer
    learning_rate: float = 2e-5
    lr_scheduler_type: Literal["linear", "cosine", "cosine_with_restarts",
                               "polynomial", "constant", "constant_with_warmup"] = "linear"
    warm_up_steps: int = 0
    warmup_ratio: float = 0.0
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0

    # Batch sizes
    per_device_train_batch_size: int = 1
    total_episodes: int = 100000
    world_size: int | None = None
    num_training_steps: int | None = None

    # Algorithm
    num_epochs: int = 1
    num_mini_batches: int = 1
    beta: float = 0.05
    clip_lower: float = 0.2
    clip_higher: float = 0.2
    kl_estimator: Literal[0, 1, 2, 3] = 2
    loss_fn: GRPOLossType = GRPOLossType.dapo
    alpha: float = 0.6
    ref_policy_update_freq: int | None = None
    load_ref_policy: bool = True

    # Ray / Distributed
    num_learners_per_node: list[int] = field(default_factory=lambda: [1])
    num_nodes: int = 1
    deepspeed_stage: int = 0
    deepspeed_zpg: int = 8
    sequence_parallel_size: int = 1
    gather_whole_model: bool = True

    # Checkpointing
    save_freq: int = 200
    output_dir: str = "output"
    push_to_hub: bool = True
    keep_last_n_checkpoints: int = 3
    checkpoint_state_freq: int = -1
    checkpoint_state_dir: str | None = None

    # Experiment tracking
    with_tracking: bool = False
    wandb_project_name: str = "open_instruct_internal"
    local_eval_every: int = 100
    verbose: bool = False

    # ... additional fields omitted for brevity

Import

from open_instruct.grpo_utils import ExperimentConfig, GRPOLossType

I/O Contract

Key Fields

| Field | Type | Default | Description |
| --- | --- | --- | --- |
| exp_name | str | "grpo" | Name of the experiment for logging and output directory naming. |
| seed | int | 1 | Random seed for reproducibility across all random number generators. |
| learning_rate | float | 2e-5 | Initial learning rate for the AdamW optimizer. |
| beta | float | 0.05 | KL penalty coefficient. Set to 0 to disable KL regularization. |
| clip_lower | float | 0.2 | Lower bound of the PPO clipping range. |
| clip_higher | float | 0.2 | Upper bound of the PPO clipping range. Set higher than clip_lower for DAPO-style asymmetric clipping. |
| loss_fn | GRPOLossType | dapo | Loss function variant: "dapo" (asymmetric PPO clip) or "cispo" (clipped importance sampling). |
| num_mini_batches | int | 1 | Number of mini-batches to split each rollout batch into for gradient accumulation. |
| deepspeed_stage | int | 0 | DeepSpeed ZeRO stage (0, 2, or 3). |
| num_learners_per_node | list[int] | [1] | Number of learner GPUs per node (supports heterogeneous configurations). |
| save_freq | int | 200 | Save a lightweight model checkpoint every N training steps. |
| local_eval_every | int | 100 | Run in-loop evaluation every N training steps. Set to -1 to disable. |
| max_grad_norm | float | 1.0 | Maximum gradient norm for gradient clipping. |
| load_ref_policy | bool | True | Whether to load a reference policy for KL penalty computation. |
| alpha | float | 0.6 | Polyak averaging coefficient for reference policy updates. |
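
The clip_lower/clip_higher pair implements the asymmetric clipping described above. The following is a hedged sketch of the standard PPO-style surrogate with asymmetric bounds; the exact open-instruct loss may differ in details such as token averaging and sign conventions:

```python
def dapo_clipped_objective(ratio: float, advantage: float,
                           clip_lower: float = 0.2,
                           clip_higher: float = 0.28) -> float:
    # ratio = pi_new(token) / pi_old(token); clip it into the
    # asymmetric band [1 - clip_lower, 1 + clip_higher].
    clipped = min(max(ratio, 1.0 - clip_lower), 1.0 + clip_higher)
    # PPO surrogate: take the pessimistic (lower) of the two terms.
    return min(ratio * advantage, clipped * advantage)
```

With clip_higher > clip_lower, a positive-advantage token can raise its probability further (ratio up to 1.28 here) before clipping takes effect than the symmetric bound would allow, which is the DAPO "clip-higher" idea.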

Runtime Fields (Computed)

| Field | Description |
| --- | --- |
| world_size | Total number of learner GPUs (sum of num_learners_per_node). |
| num_training_steps | Total training steps (derived from total_episodes and batch size). |
| run_name | Unique run identifier generated from experiment name and timestamp. |
| temperature | Sampling temperature, copied from streaming_config at runtime. |
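
A hedged reconstruction of how the first two derived fields might be computed in __post_init__. The `episodes_per_step` name and the floor-division rounding are assumptions for illustration; the real derivation may differ:

```python
def derive_runtime_fields(num_learners_per_node: list[int],
                          total_episodes: int,
                          episodes_per_step: int) -> tuple[int, int]:
    # world_size: total learner GPUs across all nodes
    world_size = sum(num_learners_per_node)
    # num_training_steps: how many optimizer steps fit in the episode budget
    num_training_steps = total_episodes // episodes_per_step
    return world_size, num_training_steps
```

For example, two nodes with four learners each and a 50,000-episode budget consumed 512 episodes per step would yield a world size of 8 and 97 training steps under this assumed rounding.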

Usage Examples

from open_instruct.grpo_utils import ExperimentConfig, GRPOLossType

# Standard GRPO configuration
config = ExperimentConfig(
    exp_name="olmo_7b_gsm8k",
    seed=42,
    learning_rate=1e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    beta=0.05,
    clip_lower=0.2,
    clip_higher=0.28,       # DAPO-style asymmetric
    loss_fn=GRPOLossType.dapo,
    num_mini_batches=2,
    per_device_train_batch_size=1,
    total_episodes=50000,
    deepspeed_stage=2,
    num_learners_per_node=[4, 4],  # 4 learners on each of 2 nodes
    num_nodes=2,
    save_freq=100,
    local_eval_every=50,
    output_dir="/output/olmo_7b_gsm8k",
    with_tracking=True,
    wandb_project_name="grpo_experiments",
)

# Configuration without reference model (pure reward maximization)
config_no_ref = ExperimentConfig(
    beta=0.0,
    load_ref_policy=False,
    loss_fn=GRPOLossType.cispo,
)

# Access derived values
print(f"World size: {sum(config.num_learners_per_node)}")
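
The alpha field governs Polyak-averaged reference policy updates (triggered per ref_policy_update_freq). A minimal sketch assuming the conventional exponential-moving-average form; the blend direction of alpha here is an assumption, and the actual open-instruct update may differ:

```python
def polyak_update(ref_params: list[float], policy_params: list[float],
                  alpha: float = 0.6) -> list[float]:
    # EMA: keep alpha of the old reference, blend in (1 - alpha) of the
    # current policy. alpha = 1.0 would freeze the reference entirely.
    return [alpha * r + (1.0 - alpha) * p
            for r, p in zip(ref_params, policy_params)]
```

A slowly moving reference keeps the KL penalty anchored while still letting it track the improving policy over time.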
