
Implementation:Allenai Open instruct One Training Step

From Leeroopedia


Type Function
Source open_instruct/grpo_fast.py:L1455-1695
Dependencies ray, deepspeed, torch, wandb, numpy, pandas
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete function in the Open Instruct library that executes one complete GRPO training step, including the distributed optimizer step, metric aggregation, checkpointing, and evaluation.

Description

one_training_step() orchestrates a single training iteration in the GRPO pipeline. It:

  1. Calls policy_group.models[i].step.remote() on all learner Ray actors to execute the actual forward/backward pass.
  2. Collects metrics and array metrics from all workers.
  3. If configured, triggers reference policy updates via Polyak averaging on all workers.
  4. Calls maybe_save_checkpoint() to conditionally save intermediate checkpoints.
  5. Reports training step timing to the ActorManager.
  6. Computes token-weighted averages for training metrics across workers.
  7. Calculates utilization metrics (MFU, tokens per second, etc.).
  8. Logs all metrics to Weights & Biases if tracking is enabled.
  9. Returns the total number of tokens processed in the step.

The function delegates the actual gradient computation to the PolicyTrainerRayProcess.step() method on each worker, which handles mini-batch iteration, gradient accumulation, loss computation, and optimizer steps.
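The metric aggregation in step 6 can be sketched in isolation. This is a hypothetical, simplified illustration of a token-weighted average across workers; the names (`worker_metrics`, `num_tokens`) are illustrative, not the actual open_instruct variables.

```python
# Illustrative sketch (not the real open_instruct code): average each metric
# across learner workers, weighting by the number of tokens each processed.
def token_weighted_average(worker_metrics: list[dict]) -> dict:
    total_tokens = sum(m["num_tokens"] for m in worker_metrics)
    if total_tokens == 0:
        return {}  # empty batch after packing: nothing to average
    averaged = {}
    for key in worker_metrics[0]:
        if key == "num_tokens":
            continue
        averaged[key] = (
            sum(m[key] * m["num_tokens"] for m in worker_metrics) / total_tokens
        )
    return averaged

metrics = token_weighted_average([
    {"loss/policy": 0.5, "num_tokens": 100},
    {"loss/policy": 0.3, "num_tokens": 300},
])
# metrics["loss/policy"] == 0.35, i.e. (0.5*100 + 0.3*300) / 400
```

Weighting by tokens rather than averaging per-worker means that a worker which happened to receive longer sequences contributes proportionally more, keeping the reported loss consistent with a single-process run.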

Usage

Called once per iteration in the main training loop. This function runs on the head node and coordinates all learner workers via Ray remote calls.
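The fan-out/gather pattern this coordination follows can be mocked without Ray. In the real pipeline the head node issues non-blocking `policy_group.models[i].step.remote()` calls and collects results with `ray.get()`; below, plain Python objects stand in for the Ray actors, purely for illustration.

```python
# Mock of the coordinator pattern (FakeLearner is a stand-in, not a real class).
class FakeLearner:
    def __init__(self, loss: float):
        self.loss = loss

    def step(self, training_step: int) -> dict:
        # Each learner runs its own forward/backward pass and returns metrics.
        return {"loss": self.loss, "step": training_step}

learners = [FakeLearner(0.4), FakeLearner(0.6)]
# Fan out: with Ray these would be non-blocking .step.remote(...) calls.
futures = [learner.step(training_step=1) for learner in learners]
# Gather: ray.get(futures) in the real pipeline.
results = futures
mean_loss = sum(r["loss"] for r in results) / len(results)
```

Because the remote calls are issued before any result is awaited, all learners compute their step concurrently; the head node only blocks at the gather.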

Code Reference

Source Location

open_instruct/grpo_fast.py:L1455-1695

Signature

def one_training_step(
    args: grpo_utils.ExperimentConfig,
    streaming_config: data_loader_lib.StreamingDataLoaderConfig,
    policy_group: ModelGroup,
    tokenizer: PreTrainedTokenizer,
    data_thread_metrics: dict[str, Any],
    episode: int,
    training_step: int,
    num_total_tokens: int,
    start_time: float,
    train_dataset: datasets.Dataset,
    training_start_time: float,
    wandb_url: str,
    chat_template_name: str,
    model_dims: utils.ModelDims,
    actor_manager: ActorManager | None = None,
) -> int:

Import

from open_instruct.grpo_fast import one_training_step

I/O Contract

Inputs

Name Type Description
args ExperimentConfig Experiment configuration with training hyperparameters.
streaming_config StreamingDataLoaderConfig Streaming generation configuration.
policy_group ModelGroup Group of Ray actor handles for all learner processes.
tokenizer PreTrainedTokenizer Tokenizer for decoding (used in evaluation).
data_thread_metrics dict[str, Any] Metrics from the data preparation thread.
episode int Current episode number (total completions processed).
training_step int Current training step index.
num_total_tokens int Running total of tokens processed so far.
start_time float Time the current step started.
train_dataset Dataset Training dataset (for epoch computation).
training_start_time float Time training began (for total time metrics).
wandb_url str W&B run URL for linking in checkpoints.
chat_template_name str Name of the chat template for saving.
model_dims ModelDims Model dimensions for MFU computation.
actor_manager ActorManager | None Actor manager for reporting per-step timing stats; may be None.

Outputs

Name Type Description
Return value int Number of tokens processed in this training step (sum of prompt and response tokens). Returns 0 if the batch was empty after packing.

Key Metrics Logged

Metric Description
loss/policy_avg Token-weighted average policy gradient loss.
loss/kl_avg Token-weighted average KL penalty.
loss/total_avg Token-weighted average total loss (policy + KL).
policy/clipfrac_avg Fraction of tokens where the ratio was clipped.
val/ratio Token-weighted average importance sampling ratio.
time/training Wall clock time for the training step (excluding generation).
learner_tokens_per_second_step Training throughput in tokens per second for this step.
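The utilization metrics can be approximated as follows. This is a hedged sketch, not the exact open_instruct formula: the 6 FLOPs-per-parameter-per-token estimate for a forward+backward pass is a common rule of thumb, and the parameter count and peak-FLOP figures below are made-up example values.

```python
# Illustrative utilization math (assumed formulas, not the library's own).
def throughput_metrics(num_tokens: int, step_seconds: float,
                       num_params: float, peak_flops: float) -> dict:
    tokens_per_second = num_tokens / step_seconds
    # ~6 FLOPs per parameter per token for forward + backward (rule of thumb).
    achieved_flops = 6 * num_params * tokens_per_second
    return {
        "learner_tokens_per_second_step": tokens_per_second,
        "mfu": achieved_flops / peak_flops,
    }

m = throughput_metrics(
    num_tokens=1_000_000, step_seconds=10.0,
    num_params=7e9,              # e.g. a 7B-parameter policy
    peak_flops=8 * 989e12,       # e.g. 8 GPUs at ~989 TFLOP/s BF16 each
)
# m["learner_tokens_per_second_step"] == 100000.0; m["mfu"] is roughly 0.53
```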

Usage Examples

import time
from open_instruct.grpo_fast import one_training_step

# Inside the main training loop:
for training_step in range(num_training_steps):
    start_time = time.perf_counter()

    num_step_tokens = one_training_step(
        args=args,
        streaming_config=streaming_config,
        policy_group=policy_group,
        tokenizer=tokenizer,
        data_thread_metrics=data_metrics,
        episode=episode,
        training_step=training_step,
        num_total_tokens=total_tokens,
        start_time=start_time,
        train_dataset=train_dataset,
        training_start_time=training_start_time,
        wandb_url=wandb_url,
        chat_template_name="tulu",
        model_dims=model_dims,
        actor_manager=actor_manager,
    )

    total_tokens += num_step_tokens
    episode += streaming_config.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout
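The episode increment at the end of the loop follows because each step consumes `num_unique_prompts_rollout` prompts with `num_samples_per_prompt_rollout` completions each, so episodes (total completions) advance by their product. A small worked example, with made-up values rather than the library's defaults:

```python
# Episode accounting: completions processed per step = prompts * samples/prompt.
num_unique_prompts_rollout = 32   # example value, not a default
num_samples_per_prompt_rollout = 8

episode = 0
for _ in range(3):  # three training steps
    episode += num_unique_prompts_rollout * num_samples_per_prompt_rollout
# episode == 768 after three steps (3 * 32 * 8)
```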

Related Pages

Implements Principle
