Implementation:Huggingface Trl Get Dataset GRPO
| Property | Value |
|---|---|
| Implementation Name | Get Dataset for GRPO |
| Library | Huggingface TRL |
| Type | API Doc |
| Source Files | trl/scripts/grpo.py (L136-149), trl/scripts/utils.py (L414-474) |
| Import | from trl.scripts.utils import get_dataset; from trl import DatasetMixtureConfig |
Overview
Description
The get_dataset function loads the datasets listed in a DatasetMixtureConfig and combines them into a single DatasetDict. The GRPO script supports two loading paths: the mixture path (via get_dataset) and the single-dataset path (via datasets.load_dataset directly). When both are specified, the mixture path takes priority and a warning is logged.
Usage
from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig

# Mixture of datasets
mixture_config = DatasetMixtureConfig(
    datasets=[
        DatasetConfig(path="trl-lib/DeepMath-103K", split="train"),
        DatasetConfig(path="my-org/custom-prompts", split="train", columns=["prompt", "solution"]),
    ],
    test_split_size=0.05,
)
dataset = get_dataset(mixture_config)
# Returns DatasetDict with "train" and "test" splits
Code Reference
Source Location
| Function | File | Lines |
|---|---|---|
| get_dataset | trl/scripts/utils.py | L414-474 |
| Dataset loading in GRPO script | trl/scripts/grpo.py | L136-149 |
| DatasetMixtureConfig | trl/scripts/utils.py | L90-152 |
| DatasetConfig | trl/scripts/utils.py | L56-88 |
Signature
def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict:
    """
    Load a mixture of datasets based on the configuration.

    Args:
        mixture_config (DatasetMixtureConfig):
            Script arguments containing dataset configuration.

    Returns:
        DatasetDict: Combined dataset(s) from the mixture configuration,
            with optional train/test split if test_split_size is set.
    """

@dataclass
class DatasetMixtureConfig:
    datasets: list[DatasetConfig] = field(default_factory=list)
    streaming: bool = False
    test_split_size: float | None = None

@dataclass
class DatasetConfig:
    path: str
    name: str | None = None
    data_dir: str | None = None
    data_files: str | list[str] | dict[str, str] | None = None
    split: str = "train"
    columns: list[str] | None = None
Dataset loading logic in the GRPO script:
# trl/scripts/grpo.py L136-149
if dataset_args.datasets and script_args.dataset_name:
    logger.warning(
        "Both `datasets` and `dataset_name` are provided. "
        "The `datasets` argument will be used."
    )
    dataset = get_dataset(dataset_args)
elif dataset_args.datasets and not script_args.dataset_name:
    dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
    dataset = load_dataset(
        script_args.dataset_name,
        name=script_args.dataset_config,
        streaming=script_args.dataset_streaming,
    )
else:
    raise ValueError("Either `datasets` or `dataset_name` must be provided.")
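The branch order above can be isolated as a small pure function, which makes the precedence rule easy to test. This is a minimal sketch and not code from TRL: `choose_loading_path` and its return labels are hypothetical names, and the function only decides which path applies without loading any data.

```python
from typing import Optional, Sequence


def choose_loading_path(datasets: Sequence[object], dataset_name: Optional[str]) -> str:
    """Return "mixture" or "single", mirroring the branch order in the script.

    Hypothetical helper for illustration; it does not load any data.
    """
    if datasets:
        # The mixture path wins even when dataset_name is also set
        # (the script additionally logs a warning in that case).
        return "mixture"
    if dataset_name:
        return "single"
    raise ValueError("Either `datasets` or `dataset_name` must be provided.")
```

Collapsing the first two branches into one `if datasets:` check gives the same outcome as the script, since both branches call get_dataset; only the warning differs.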
Import
from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig
from datasets import load_dataset
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| mixture_config | DatasetMixtureConfig | Configuration specifying the list of datasets to load, streaming mode, and optional test split size. |
| mixture_config.datasets | list[DatasetConfig] | Each entry specifies a dataset path, optional name/config, split, and column selection. |
| mixture_config.streaming | bool | Whether to load datasets in streaming mode. |
| mixture_config.test_split_size | float or None | If set, the combined dataset is split into train/test. If None, only a "train" key is returned. |
Outputs
| Output | Type | Description |
|---|---|---|
| dataset | DatasetDict | A dictionary with at minimum a "train" key. If test_split_size is set, it also contains a "test" key. For GRPO, each split is expected to contain a "prompt" column. |
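The contract above can be illustrated without downloading anything. The sketch below imitates the described behaviour on plain lists of dicts: keep the selected columns, concatenate the sources, then optionally hold out a test split. `combine_rows` is a hypothetical stand-in and not the TRL implementation; the seeded shuffle before splitting and the single shared `columns` argument are assumptions made to keep the example small and reproducible.

```python
import random
from typing import Optional


def combine_rows(
    sources: list[list[dict]],
    columns: Optional[list[str]] = None,
    test_split_size: Optional[float] = None,
    seed: int = 0,
) -> dict[str, list[dict]]:
    """Toy analogue of a dataset mixture over in-memory rows (illustration only)."""
    combined = []
    for rows in sources:
        for row in rows:
            # Column selection: keep only the requested keys, if any were given.
            combined.append({k: row[k] for k in columns} if columns else dict(row))

    if test_split_size is None:
        # No split requested: only a "train" key is returned.
        return {"train": combined}

    # Assumption: shuffle with a fixed seed, then hold out a fraction for "test".
    random.Random(seed).shuffle(combined)
    n_test = max(1, round(len(combined) * test_split_size))
    return {"train": combined[n_test:], "test": combined[:n_test]}


math_rows = [{"prompt": f"q{i}", "solution": str(i), "source": "math"} for i in range(15)]
custom_rows = [{"prompt": f"c{i}", "solution": str(i)} for i in range(5)]

splits = combine_rows([math_rows, custom_rows], columns=["prompt", "solution"], test_split_size=0.05)
print({name: len(rows) for name, rows in splits.items()})
```

Selecting the same columns from every source keeps the rows schema-compatible before they are concatenated, which is why the extra "source" key in the first list is dropped here.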
Usage Examples
Single dataset via YAML config:
# config.yaml
dataset_name: trl-lib/DeepMath-103K
dataset_train_split: train
dataset_test_split: test
Mixture of datasets via YAML config:
# config.yaml
datasets:
  - path: trl-lib/DeepMath-103K
    split: train
    columns:
      - prompt
      - solution
  - path: my-org/custom-math-prompts
    split: train
    columns:
      - prompt
      - solution
streaming: false
test_split_size: 0.05
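Once parsed, a mixture config like the one above is a plain mapping whose `datasets` entries line up with the DatasetConfig fields. The sketch below shows that mapping with stand-in dataclasses (`DatasetConfigSketch` and `DatasetMixtureConfigSketch`, both hypothetical and trimmed to the fields used here). The dict literal stands for already-parsed YAML; how TRL itself performs the conversion is not shown.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DatasetConfigSketch:
    # Stand-in for DatasetConfig, reduced to the fields used in the YAML above.
    path: str
    split: str = "train"
    columns: Optional[list[str]] = None


@dataclass
class DatasetMixtureConfigSketch:
    # Stand-in for DatasetMixtureConfig.
    datasets: list[DatasetConfigSketch] = field(default_factory=list)
    streaming: bool = False
    test_split_size: Optional[float] = None


def mixture_from_mapping(cfg: dict) -> DatasetMixtureConfigSketch:
    """Build the stand-in config from an already-parsed YAML mapping."""
    return DatasetMixtureConfigSketch(
        datasets=[DatasetConfigSketch(**entry) for entry in cfg.get("datasets", [])],
        streaming=cfg.get("streaming", False),
        test_split_size=cfg.get("test_split_size"),
    )


parsed_yaml = {
    "datasets": [
        {"path": "trl-lib/DeepMath-103K", "split": "train", "columns": ["prompt", "solution"]},
        {"path": "my-org/custom-math-prompts", "split": "train", "columns": ["prompt", "solution"]},
    ],
    "streaming": False,
    "test_split_size": 0.05,
}
mixture = mixture_from_mapping(parsed_yaml)
print([d.path for d in mixture.datasets], mixture.test_split_size)
```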
Programmatic usage:
from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig

mixture_config = DatasetMixtureConfig(
    datasets=[DatasetConfig(path="trl-lib/tldr")]
)
dataset = get_dataset(mixture_config)
print(dataset)
# DatasetDict({
#     train: Dataset({
#         features: ['prompt', 'completion'],
#         num_rows: 116722
#     })
# })
Key constraint: GRPO datasets must contain a "prompt" column. Additional columns (such as "solution") are preserved when remove_unused_columns=False (the default in GRPOConfig) and are forwarded to reward functions during training.
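To make the constraint concrete, the sketch below checks rows for a "prompt" key and shows a reward function that consumes a forwarded "solution" column. It assumes the standard (non-conversational) format, where each completion is a plain string, and that extra columns arrive as keyword arguments holding one value per completion. `missing_prompt_rows` and `exact_match_reward` are hypothetical names.

```python
def missing_prompt_rows(rows: list[dict]) -> list[int]:
    """Return indices of rows that lack a non-empty "prompt" value."""
    return [i for i, row in enumerate(rows) if not row.get("prompt")]


def exact_match_reward(completions: list[str], solution: list[str], **kwargs) -> list[float]:
    """Score 1.0 when a completion equals its reference solution after stripping whitespace.

    `solution` is the extra dataset column; any other forwarded columns land in **kwargs.
    """
    return [1.0 if c.strip() == s.strip() else 0.0 for c, s in zip(completions, solution)]


rows = [{"prompt": "2 + 2 = ?", "solution": "4"}, {"solution": "9"}]
print(missing_prompt_rows(rows))
print(exact_match_reward(["4", " 5 "], solution=["4", "9"], prompts=["a", "b"]))
```

A function with this shape could be supplied among the trainer's reward functions. Accepting `**kwargs` keeps it tolerant of columns it does not use.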