Implementation:Huggingface Open r1 Get Dataset
Metadata
| Field | Value |
|---|---|
| Sources | Repo: huggingface/open-r1; Doc: HuggingFace Datasets Documentation |
| Domains | NLP, Data_Engineering |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A concrete tool, provided by Open-R1, for loading and blending HuggingFace datasets. The get_dataset function is the primary data ingestion entry point for all training pipelines in the Open-R1 project.
Description
The get_dataset function supports two modes of operation:
- Single dataset loading: When dataset_name is specified in the arguments, the function loads a single dataset directly from HuggingFace Hub by name, with an optional split specification.
- Weighted mixture loading: When dataset_mixture is configured, the function iterates over multiple dataset configurations, loading each one with per-dataset column selection and fractional subsampling. The individual datasets are concatenated, shuffled with a deterministic seed, and optionally split into train and test partitions.
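The mixture steps described above (column selection, fractional subsampling, concatenation, seeded shuffle, optional split) can be sketched in plain Python. This is a minimal illustration, not the Open-R1 implementation: it works on lists of dicts instead of the `datasets` library, the names `select_columns`, `subsample`, and `blend` are hypothetical, and it assumes a weight is the fraction of rows kept from its own dataset.

```python
import random
from typing import Dict, List, Optional, Sequence

Row = Dict[str, object]


def select_columns(rows: Sequence[Row], columns: Optional[List[str]]) -> List[Row]:
    """Keep only the requested columns; keep every column when none are named."""
    if columns is None:
        return [dict(row) for row in rows]
    return [{col: row[col] for col in columns} for row in rows]


def subsample(rows: Sequence[Row], weight: Optional[float], seed: int) -> List[Row]:
    """Assumption: weight is the fraction of this dataset's rows to keep."""
    if weight is None:
        return list(rows)
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # deterministic for a given seed
    return shuffled[: int(len(shuffled) * weight)]


def blend(sources: Sequence[dict], seed: int,
          test_split_size: Optional[float] = None) -> Dict[str, List[Row]]:
    """Hypothetical stand-in for the mixture mode: select, subsample, concatenate, shuffle, split."""
    combined: List[Row] = []
    for source in sources:
        rows = select_columns(source["rows"], source.get("columns"))
        combined.extend(subsample(rows, source.get("weight"), seed))
    random.Random(seed).shuffle(combined)
    if test_split_size is None:
        return {"train": combined}
    n_test = int(len(combined) * test_split_size)
    return {"train": combined[n_test:], "test": combined[:n_test]}


# Two toy sources with different schemas, narrowed to shared columns.
math_rows = [{"prompt": f"m{i}", "source": "math"} for i in range(10)]
code_rows = [{"prompt": f"c{i}", "source": "code", "extra": i} for i in range(20)]
mixture = blend(
    [
        {"rows": math_rows, "columns": ["prompt", "source"], "weight": 0.5},
        {"rows": code_rows, "columns": ["prompt", "source"], "weight": 0.25},
    ],
    seed=42,
    test_split_size=0.2,
)
print(len(mixture["train"]), len(mixture["test"]))
```

Selecting the same columns from every source matters because rows can only be concatenated cleanly when they share a schema; the toy `extra` column is dropped for that reason.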
This dual-mode design allows the same function interface to serve both simple single-source training runs and complex multi-source curriculum training setups without requiring different code paths in the training scripts.
Usage
Import get_dataset when you need to load training data for SFT or GRPO training, especially when combining multiple dataset sources. The function is designed to be called once at the start of a training pipeline, returning a DatasetDict ready for consumption by the training loop.
Code Reference
Source
| Field | Value |
|---|---|
| Repository | open-r1 |
| File | src/open_r1/utils/data.py |
| Lines | L12-65 |
Signature
def get_dataset(args: ScriptArguments) -> DatasetDict:
    """Load a dataset or a mixture of datasets based on the configuration.

    Args:
        args (ScriptArguments): Script arguments containing dataset configuration.

    Returns:
        DatasetDict: The loaded datasets.
    """
Import
from open_r1.utils import get_dataset
or
from open_r1.utils.data import get_dataset
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| args | ScriptArguments | Yes | Script arguments containing dataset configuration. Must have either dataset_name or dataset_mixture configured (mutually exclusive). |
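The "either dataset_name or dataset_mixture" requirement can be checked before any data is downloaded. The helper below is a hypothetical pre-flight check, not part of Open-R1; it enforces the contract exactly as stated on this page, and the library's own handling of a misconfigured `args` may differ.

```python
from typing import Any, Optional


def resolve_mode(dataset_name: Optional[str], dataset_mixture: Optional[Any]) -> str:
    """Return "single" or "mixture", rejecting ambiguous or empty configurations."""
    has_name = bool(dataset_name)
    has_mixture = bool(dataset_mixture)
    if has_name and has_mixture:
        raise ValueError("Set only one of dataset_name or dataset_mixture.")
    if not has_name and not has_mixture:
        raise ValueError("Set one of dataset_name or dataset_mixture.")
    return "mixture" if has_mixture else "single"


print(resolve_mode("openai/gsm8k", None))
print(resolve_mode(None, {"openai/gsm8k": {"weight": 0.6}}))
```

Running a check like this at startup surfaces configuration mistakes immediately instead of partway through a long download.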
Key fields within ScriptArguments:
| Field | Type | Description |
|---|---|---|
| dataset_name | str | HuggingFace dataset identifier for single-dataset mode (e.g., "openai/gsm8k"). |
| dataset_mixture | dict | Dictionary defining multiple datasets with weights, columns, and split configurations. |
| dataset_split | str | Which split to load (e.g., "train"). |
| test_split_size | float | Fraction of the combined dataset to reserve as a test split (e.g., 0.1). If not set, no test split is created. |
| seed | int | Random seed for shuffling and splitting, ensuring reproducibility. |
Outputs
| Output | Type | Description |
|---|---|---|
| return value | DatasetDict | A dictionary of datasets. Always contains a "train" split. Optionally contains a "test" split if test_split_size is configured in the arguments. |
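Because the "test" split is optional, calling code should not index it unconditionally. The sketch below shows one defensive pattern, using a plain dict as a stand-in for the returned DatasetDict (which is itself a dict subclass); `unpack_splits` is a hypothetical helper name.

```python
from typing import Dict, List, Optional, Tuple


def unpack_splits(dataset_dict: Dict[str, List[dict]]) -> Tuple[List[dict], Optional[List[dict]]]:
    """Return (train, test), where test is None if no test split was created."""
    train = dataset_dict["train"]     # documented as always present
    test = dataset_dict.get("test")   # present only when a test split was requested
    return train, test


# Plain-dict stand-ins for the two possible return shapes.
train_only = {"train": [{"prompt": "a"}, {"prompt": "b"}]}
with_test = {"train": [{"prompt": "a"}], "test": [{"prompt": "b"}]}

train, test = unpack_splits(train_only)
print(len(train), test is None)
train, test = unpack_splits(with_test)
print(len(train), len(test))
```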
Usage Examples
Example 1: Single Dataset Load
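from dataclasses import dataclass
from typing import Optional

from open_r1.utils import get_dataset


# Illustrative stand-in for Open-R1's ScriptArguments; the real class may
# define these fields differently, so check the installed version.
@dataclass
class ScriptArguments:
    dataset_name: Optional[str] = "openai/gsm8k"
    dataset_split: str = "train"
    dataset_mixture: Optional[dict] = None
    test_split_size: Optional[float] = None  # unset: no test split is created
    seed: int = 42


args = ScriptArguments()
dataset_dict = get_dataset(args)

# Access the training split
train_data = dataset_dict["train"]
print(f"Loaded {len(train_data)} training examples")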
Example 2: Weighted Mixture of Multiple Datasets
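from dataclasses import dataclass, field
from typing import Optional

from open_r1.utils import get_dataset


# Illustrative stand-in for Open-R1's ScriptArguments; the real class may
# define these fields differently, so check the installed version.
@dataclass
class ScriptArguments:
    dataset_name: Optional[str] = None
    dataset_split: str = "train"
    # Each entry maps a dataset identifier to its weight and selected columns.
    dataset_mixture: dict = field(default_factory=lambda: {
        "openai/gsm8k": {
            "weight": 0.6,
            "columns": ["question", "answer"],
        },
        "code_alpaca": {
            "weight": 0.3,
            "columns": ["prompt", "completion"],
        },
        "tatsu-lab/alpaca": {
            "weight": 0.1,
            "columns": ["instruction", "output"],
        },
    })
    test_split_size: Optional[float] = 0.05  # reserve 5% as a test split
    seed: int = 42


args = ScriptArguments()
dataset_dict = get_dataset(args)

# A test split is present because test_split_size is set
train_data = dataset_dict["train"]
test_data = dataset_dict["test"]
print(f"Train: {len(train_data)} examples, Test: {len(test_data)} examples")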