Implementation:Huggingface Open r1 Get Callbacks
Overview
Concrete tool for resolving callback names to TrainerCallback instances from Open-R1's callback registry.
Description
The get_callbacks function resolves the callback names listed in the training config to TrainerCallback instances. The CALLBACKS registry currently contains one entry: "push_to_hub_revision" -> PushToHubRevisionCallback. PushToHubRevisionCallback implements on_save to push each saved checkpoint to its own Hub branch (revision) and, optionally, to trigger benchmark evaluation via Slurm. The callback does not modify the training arguments directly. Instead, it copies the values it needs into a lightweight DummyConfig object. This workaround avoids breaking the accelerator's distributed state.
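The DummyConfig workaround can be illustrated with a small, self-contained sketch. It assumes the training arguments expose hub_model_id, hub_model_revision and output_dir. Here those fields are modelled by a hypothetical FakeTrainingArgs dataclass, and checkpoint_push_config is a hypothetical helper. The "-step-{global_step:09d}" branch naming is an illustrative assumption, not a confirmed detail of Open-R1. The idea is that per-checkpoint values go into a fresh attribute bag rather than a modified copy of the real training arguments.

```python
from dataclasses import dataclass


class DummyConfig:
    """Plain attribute bag holding per-checkpoint values."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


@dataclass
class FakeTrainingArgs:
    # Hypothetical stand-in for the SFTConfig/GRPOConfig fields used below
    hub_model_id: str
    hub_model_revision: str
    output_dir: str


def checkpoint_push_config(args, global_step: int) -> DummyConfig:
    # Build a new, unrelated object instead of dataclasses.replace(args, ...)
    # or a new config instance, so the original args are only read, never mutated.
    return DummyConfig(
        hub_model_id=args.hub_model_id,
        # Assumed naming scheme: one Hub branch per checkpoint step
        hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
        output_dir=f"{args.output_dir}/checkpoint-{global_step}",
    )
```

Because checkpoint_push_config only reads from args, the object held by the trainer and accelerator is never mutated or re-instantiated.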
Usage
Call get_callbacks when constructing an SFTTrainer or GRPOTrainer to attach the callbacks named in the training config.
Code Reference
- Source
  - Repository: open-r1
  - File: src/open_r1/utils/callbacks.py, Lines: L43-92
- Signature
```python
from typing import List
from transformers import TrainerCallback

def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
    """Resolve callback names from config to TrainerCallback instances."""

class PushToHubRevisionCallback(TrainerCallback):
    def __init__(self, model_config) -> None: ...
    def on_save(self, args, state, control, **kwargs): ...
```
- Import
```python
from open_r1.utils.callbacks import get_callbacks
```
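A minimal sketch of the registry lookup that get_callbacks performs is shown below. It uses a hypothetical stand-in TrainerCallback base class so that it runs without transformers installed. The exact loop and error message are assumptions, and the real function may differ in detail.

```python
from types import SimpleNamespace
from typing import Dict, List, Type


class TrainerCallback:
    """Hypothetical stand-in for transformers.TrainerCallback."""


class PushToHubRevisionCallback(TrainerCallback):
    def __init__(self, model_config) -> None:
        # Keep the model config for later use by the callback
        self.model_config = model_config


# Registry: config string -> callback class
CALLBACKS: Dict[str, Type[TrainerCallback]] = {
    "push_to_hub_revision": PushToHubRevisionCallback,
}


def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
    """Instantiate every callback named in train_config.callbacks, in order."""
    callbacks: List[TrainerCallback] = []
    for name in train_config.callbacks:
        if name not in CALLBACKS:
            # Fail fast on a misspelled name instead of silently skipping it
            raise ValueError(f"Callback {name} not found in CALLBACKS.")
        callbacks.append(CALLBACKS[name](model_config))
    return callbacks


# For this sketch, train_config only needs a `callbacks` attribute
train_config = SimpleNamespace(callbacks=["push_to_hub_revision"])
model_config = SimpleNamespace(trust_remote_code=False)
resolved = get_callbacks(train_config, model_config)
```

With the registry kept as a plain dict, adding a callback takes one new entry. Unknown names raise an error rather than being ignored.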
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| train_config | SFTConfig or GRPOConfig | Yes | Training configuration whose callbacks field is a list of callback names (strings) |
| model_config | ModelConfig | Yes | Model configuration (including trust_remote_code), passed to each callback's constructor |
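The callbacks input is simply a list of strings on the training config. The sketch below shows one way such a field might be declared. It uses a hypothetical CallbackConfigMixin dataclass, not the real SFTConfig/GRPOConfig definitions.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class CallbackConfigMixin:
    # Hypothetical field layout; the real configs extend TRL's SFTConfig/GRPOConfig
    callbacks: List[str] = field(
        default_factory=list,
        metadata={"help": "Names of callbacks to resolve via get_callbacks."},
    )


default_cfg = CallbackConfigMixin()
yaml_like_cfg = CallbackConfigMixin(callbacks=["push_to_hub_revision"])
```

Using default_factory=list gives each instance its own empty list. A plain "= []" default would be rejected by dataclasses as a mutable default.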
Outputs
| Type | Description |
|---|---|
| list[TrainerCallback] | Callback instances to register in the training loop |
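The toy dispatcher below shows what registering callbacks in the training loop amounts to. It is deliberately simplified and is not transformers' CallbackHandler. It only forwards on_save, passes None for args and control, and uses a hypothetical ToyState carrying global_step and is_world_process_zero.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyState:
    # Hypothetical subset of transformers.TrainerState
    global_step: int
    is_world_process_zero: bool = True


class RecordingCallback:
    """Toy callback that records the steps at which a checkpoint was saved."""

    def __init__(self) -> None:
        self.saved_steps: List[int] = []

    def on_save(self, args, state, control, **kwargs):
        # Only act on the main process so a multi-process run records each save once
        if state.is_world_process_zero:
            self.saved_steps.append(state.global_step)


def toy_train(callbacks, num_steps: int, save_steps: int, is_main: bool = True) -> None:
    """Run a fake loop and fire on_save on every registered callback at each save step."""
    for step in range(1, num_steps + 1):
        if step % save_steps == 0:
            state = ToyState(global_step=step, is_world_process_zero=is_main)
            for cb in callbacks:
                cb.on_save(None, state, None)
```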
Usage Examples
Configuring callbacks in YAML:
```yaml
# sft_config.yaml
callbacks:
  - push_to_hub_revision
```
Resolving callbacks in training script:
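```python
from trl import SFTTrainer

from open_r1.utils.callbacks import get_callbacks

# train_config (SFTConfig) and model_config (ModelConfig) come from the script's argument parser
callbacks = get_callbacks(train_config, model_config)
trainer = SFTTrainer(
    model=model,
    args=train_config,
    callbacks=callbacks,
    # other arguments (e.g. train_dataset) omitted
)
```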