Implementation: Allenai Open Instruct One Training Step
| Field | Value |
|---|---|
| Source | open_instruct/grpo_fast.py:L1455-1695 |
| Dependencies | ray, deepspeed, torch, wandb, numpy, pandas |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete function, provided by the Open Instruct library, for executing one complete GRPO training step, including distributed training, metric aggregation, checkpointing, and evaluation.
Description
`one_training_step()` orchestrates a single training iteration in the GRPO pipeline. It:
- Calls `policy_group.models[i].step.remote()` on all learner Ray actors to execute the actual forward/backward pass.
- Collects metrics and array metrics from all workers.
- If configured, triggers reference policy updates via Polyak averaging on all workers.
- Calls `maybe_save_checkpoint()` to conditionally save intermediate checkpoints.
- Reports training step timing to the `ActorManager`.
- Computes token-weighted averages for training metrics across workers.
- Calculates utilization metrics (MFU, tokens per second, etc.).
- Logs all metrics to Weights & Biases if tracking is enabled.
- Returns the total number of tokens processed in the step.
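The Polyak-averaging reference update mentioned above can be sketched as follows. This is a minimal illustration, not Open Instruct's actual API: the `tau` name and the plain-Python parameter lists stand in for the real per-worker tensors.

```python
# Sketch of a Polyak (exponential moving average) reference-policy update.
# `tau` and the list-of-floats "parameters" are illustrative stand-ins for
# the tensors each worker holds; not Open Instruct's actual API.

def polyak_update(ref_params, policy_params, tau):
    """Move each reference parameter a fraction tau toward the policy."""
    return [
        tau * p + (1.0 - tau) * r
        for r, p in zip(ref_params, policy_params)
    ]

ref = [0.0, 0.0]
policy = [1.0, 2.0]

# With tau = 0.5 the reference moves halfway toward the policy.
ref = polyak_update(ref, policy, tau=0.5)
print(ref)  # [0.5, 1.0]
```

With `tau = 1.0` this degenerates to a hard copy of the policy weights; smaller values make the reference trail the policy more slowly.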
The function delegates the actual gradient computation to the `PolicyTrainerRayProcess.step()` method on each worker, which handles mini-batch iteration, gradient accumulation, loss computation, and optimizer steps.
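The token-weighted averaging of per-worker metrics described above can be sketched like this; the helper name and dict-free input shapes are assumptions for illustration, not the library's internal code:

```python
# Sketch of token-weighted metric averaging across workers. Each worker
# reports a metric value plus the number of tokens it processed; the
# aggregate weights each value by its token count. Names are illustrative.

def token_weighted_average(values, token_counts):
    """Average `values` weighted by `token_counts`."""
    total = sum(token_counts)
    if total == 0:
        return 0.0
    return sum(v * n for v, n in zip(values, token_counts)) / total

# Worker 0 saw 100 tokens with loss 2.0; worker 1 saw 300 tokens with loss 1.0.
avg = token_weighted_average([2.0, 1.0], [100, 300])
print(avg)  # 1.25
```

Weighting by tokens rather than averaging per-worker means workers that processed longer sequences contribute proportionally more to the reported loss.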
Usage
Called once per iteration in the main training loop. This function runs on the head node and coordinates all learner workers via Ray remote calls.
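The head-node coordination follows the standard Ray fan-out/gather shape of issuing `step.remote()` on every actor and then gathering all results. The sketch below imitates that shape with `concurrent.futures` so it runs without a Ray cluster; the `Worker` class and its metric keys are invented for illustration.

```python
# Stand-in for the Ray fan-out/gather pattern: the head process submits
# step() on every worker concurrently, then blocks until all results are
# back -- analogous to ray.get([m.step.remote(...) for m in models]).
# The Worker class and its metrics are invented for illustration.
from concurrent.futures import ThreadPoolExecutor

class Worker:
    def __init__(self, rank):
        self.rank = rank

    def step(self, training_step):
        # A real learner would run forward/backward here and return metrics.
        return {"rank": self.rank, "loss": 1.0 / (training_step + 1)}

workers = [Worker(rank) for rank in range(4)]
with ThreadPoolExecutor() as pool:
    futures = [pool.submit(w.step, training_step=0) for w in workers]
    metrics_per_worker = [f.result() for f in futures]

print([m["rank"] for m in metrics_per_worker])  # [0, 1, 2, 3]
```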
Code Reference
Source Location
- Repository: Open Instruct
- File: `open_instruct/grpo_fast.py`
Signature
```python
def one_training_step(
    args: grpo_utils.ExperimentConfig,
    streaming_config: data_loader_lib.StreamingDataLoaderConfig,
    policy_group: ModelGroup,
    tokenizer: PreTrainedTokenizer,
    data_thread_metrics: dict[str, Any],
    episode: int,
    training_step: int,
    num_total_tokens: int,
    start_time: float,
    train_dataset: datasets.Dataset,
    training_start_time: float,
    wandb_url: str,
    chat_template_name: str,
    model_dims: utils.ModelDims,
    actor_manager: ActorManager | None = None,
) -> int:
```
Import
```python
from open_instruct.grpo_fast import one_training_step
```
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| `args` | `ExperimentConfig` | Experiment configuration with training hyperparameters. |
| `streaming_config` | `StreamingDataLoaderConfig` | Streaming generation configuration. |
| `policy_group` | `ModelGroup` | Group of Ray actor handles for all learner processes. |
| `tokenizer` | `PreTrainedTokenizer` | Tokenizer for decoding (used in evaluation). |
| `data_thread_metrics` | `dict[str, Any]` | Metrics from the data preparation thread. |
| `episode` | `int` | Current episode number (total completions processed). |
| `training_step` | `int` | Current training step index. |
| `num_total_tokens` | `int` | Running total of tokens processed so far. |
| `start_time` | `float` | Time the current step started. |
| `train_dataset` | `Dataset` | Training dataset (for epoch computation). |
| `training_start_time` | `float` | Time training began (for total time metrics). |
| `wandb_url` | `str` | W&B run URL for linking in checkpoints. |
| `chat_template_name` | `str` | Name of the chat template for saving. |
| `model_dims` | `ModelDims` | Model dimensions for MFU computation. |
| `actor_manager` | `ActorManager \| None` | Actor manager for reporting timing stats. |
Outputs
| Name | Type | Description |
|---|---|---|
| Return value | `int` | Number of tokens processed in this training step (sum of prompt and response tokens). Returns 0 if the batch was empty after packing. |
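The returned token count, per the table above, is the sum of prompt and response tokens over the step's batch, with 0 for an empty batch. A minimal sketch, where the `(prompt_len, response_len)` tuple layout is invented for illustration:

```python
# Sketch of the step's token accounting: the return value is the sum of
# prompt and response token counts over the batch, or 0 for an empty batch.
# The (prompt_len, response_len) tuple layout is invented for illustration.

def count_step_tokens(batch):
    """Sum prompt + response tokens over (prompt_len, response_len) pairs."""
    if not batch:
        return 0  # empty batch after packing
    return sum(p + r for p, r in batch)

print(count_step_tokens([(10, 50), (12, 40)]))  # 112
print(count_step_tokens([]))  # 0
```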
Key Metrics Logged
| Metric | Description |
|---|---|
| `loss/policy_avg` | Token-weighted average policy gradient loss. |
| `loss/kl_avg` | Token-weighted average KL penalty. |
| `loss/total_avg` | Token-weighted average total loss (policy + KL). |
| `policy/clipfrac_avg` | Fraction of tokens where the ratio was clipped. |
| `val/ratio` | Token-weighted average importance sampling ratio. |
| `time/training` | Wall-clock time for the training step (excluding generation). |
| `learner_tokens_per_second_step` | Training throughput in tokens per second for this step. |
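Throughput and MFU can be estimated from the step's token count and timing. The sketch below uses the common approximation of roughly 6·N FLOPs per trained token for a dense transformer with N parameters (forward plus backward); the parameter count and peak-FLOPs figure are illustrative assumptions, not values from Open Instruct.

```python
# Sketch of utilization metrics. Training a dense transformer costs roughly
# 6 * n_params FLOPs per token (forward + backward), so
#   MFU = (6 * n_params * tokens / step_seconds) / peak_flops_per_second.
# The parameter count and peak-FLOPs figure below are illustrative only.

def tokens_per_second(num_step_tokens, step_seconds):
    return num_step_tokens / step_seconds

def mfu(num_step_tokens, step_seconds, n_params, peak_flops_per_second):
    achieved = 6.0 * n_params * num_step_tokens / step_seconds
    return achieved / peak_flops_per_second

tps = tokens_per_second(100_000, 10.0)
u = mfu(100_000, 10.0, n_params=7e9, peak_flops_per_second=989e12)
print(f"{tps:.0f} tok/s, MFU {u:.2%}")  # 10000 tok/s, MFU 42.47%
```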
Usage Examples
```python
import time

from open_instruct.grpo_fast import one_training_step

# Inside the main training loop:
for training_step in range(num_training_steps):
    start_time = time.perf_counter()
    num_step_tokens = one_training_step(
        args=args,
        streaming_config=streaming_config,
        policy_group=policy_group,
        tokenizer=tokenizer,
        data_thread_metrics=data_metrics,
        episode=episode,
        training_step=training_step,
        num_total_tokens=total_tokens,
        start_time=start_time,
        train_dataset=train_dataset,
        training_start_time=training_start_time,
        wandb_url=wandb_url,
        chat_template_name="tulu",
        model_dims=model_dims,
        actor_manager=actor_manager,
    )
    total_tokens += num_step_tokens
    episode += streaming_config.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout
```