Implementation:Huggingface Trl Get Dataset GRPO

From Leeroopedia


Implementation Name: Get Dataset for GRPO
Library: Huggingface TRL
Type: API Doc
Source Files: trl/scripts/grpo.py (L136-149), trl/scripts/utils.py (L414-474)
Import: from trl.scripts.utils import get_dataset; from trl import DatasetMixtureConfig

Overview

Description

The get_dataset function loads one or more datasets and combines them into a single DatasetDict, as described by a DatasetMixtureConfig. In the GRPO script, dataset loading supports two paths: the mixture config path (via get_dataset) and the simple single-dataset path (via datasets.load_dataset directly). When both are specified, the mixture config path takes priority and the script logs a warning that the datasets argument will be used.

Usage

from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig

# Mixture of datasets
mixture_config = DatasetMixtureConfig(
    datasets=[
        DatasetConfig(path="trl-lib/DeepMath-103K", split="train"),
        DatasetConfig(path="my-org/custom-prompts", split="train", columns=["prompt", "solution"]),
    ],
    test_split_size=0.05,
)
dataset = get_dataset(mixture_config)
# Returns DatasetDict with "train" and "test" splits

Code Reference

Source Location

get_dataset: trl/scripts/utils.py, L414-474
Dataset loading in GRPO script: trl/scripts/grpo.py, L136-149
DatasetMixtureConfig: trl/scripts/utils.py, L90-152
DatasetConfig: trl/scripts/utils.py, L56-88

Signature

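# Simplified excerpt: field help metadata is omitted; definitions follow the order in trl/scripts/utils.py.
from __future__ import annotations

from dataclasses import dataclass, field

from datasets import DatasetDict


@dataclass
class DatasetConfig:
    path: str  # Hub dataset ID or local path
    name: str | None = None  # dataset configuration name
    data_dir: str | None = None
    data_files: str | list[str] | dict[str, str] | None = None
    split: str = "train"
    columns: list[str] | None = None  # keep only these columns; None keeps all


@dataclass
class DatasetMixtureConfig:
    datasets: list[DatasetConfig] = field(default_factory=list)
    streaming: bool = False
    test_split_size: float | None = None  # None means no "test" split is created


def get_dataset(mixture_config: DatasetMixtureConfig) -> DatasetDict:
    """
    Load a mixture of datasets based on the configuration.

    Args:
        mixture_config (DatasetMixtureConfig):
            Script arguments containing dataset configuration.

    Returns:
        DatasetDict: Combined dataset(s) from the mixture configuration,
            with optional train/test split if test_split_size is set.
    """
    ...  # body omitted here; see trl/scripts/utils.py L414-474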
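To make the mixture behavior concrete, here is a minimal pure-Python sketch of what a mixture loader does: per-dataset column selection, combining, then either a train/test split or a lone "train" key. It is not TRL's implementation. ToyDatasetConfig and toy_get_dataset are hypothetical names, rows are plain dicts instead of datasets.Dataset objects, streaming is ignored, and both the "all parts must share the same columns" check and the round-up split size are assumptions modeled loosely on the datasets library.

```python
from __future__ import annotations

import math
import random
from dataclasses import dataclass

# Hypothetical stand-in: a "dataset" is a list of row dicts, not a datasets.Dataset.
Row = dict[str, object]


@dataclass
class ToyDatasetConfig:
    rows: list[Row]
    columns: list[str] | None = None  # mirrors DatasetConfig.columns


def toy_get_dataset(
    configs: list[ToyDatasetConfig],
    test_split_size: float | None = None,
    seed: int = 0,
) -> dict[str, list[Row]]:
    if not configs:
        raise ValueError("No datasets were provided in the mixture")

    parts: list[list[Row]] = []
    for cfg in configs:
        # Column selection is applied to each source before combining.
        if cfg.columns is not None:
            parts.append([{c: row[c] for c in cfg.columns} for row in cfg.rows])
        else:
            parts.append(list(cfg.rows))

    # Assumption: every part must expose the same set of columns to be combined.
    schemas = {tuple(sorted(part[0])) for part in parts if part}
    if len(schemas) > 1:
        raise ValueError(f"Column mismatch across mixture parts: {sorted(schemas)}")

    combined = [row for part in parts for row in part]

    if test_split_size is None:
        return {"train": combined}

    # Toy shuffled split; rounding the test size up is an assumption.
    shuffled = combined[:]
    random.Random(seed).shuffle(shuffled)
    n_test = math.ceil(test_split_size * len(shuffled))
    return {"train": shuffled[n_test:], "test": shuffled[:n_test]}
```

In this sketch, selecting ["prompt", "solution"] on a source that carries an extra column is what lets it be combined with a second source that lacks that column; without the selection, the schema check raises ValueError.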

Dataset loading logic in the GRPO script:

# trl/scripts/grpo.py L136-149
if dataset_args.datasets and script_args.dataset_name:
    logger.warning(
        "Both `datasets` and `dataset_name` are provided. "
        "The `datasets` argument will be used."
    )
    dataset = get_dataset(dataset_args)
elif dataset_args.datasets and not script_args.dataset_name:
    dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
    dataset = load_dataset(
        script_args.dataset_name,
        name=script_args.dataset_config,
        streaming=script_args.dataset_streaming,
    )
else:
    raise ValueError("Either `datasets` or `dataset_name` must be provided.")
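The first two branches above both call get_dataset, so the selection reduces to a three-way priority check: a non-empty mixture wins, otherwise a dataset name is used, otherwise an error is raised. The helper below is a hypothetical restatement of that logic (resolve_dataset is not a TRL function); the two loaders are injected as callables so it runs without TRL, and the warning that the script logs when both inputs are set is omitted.

```python
from __future__ import annotations

from typing import Callable, TypeVar

T = TypeVar("T")


def resolve_dataset(
    mixture_datasets: list,
    dataset_name: str | None,
    load_mixture: Callable[[], T],
    load_single: Callable[[str], T],
) -> T:
    # A non-empty mixture takes priority, whether or not dataset_name is also set.
    if mixture_datasets:
        return load_mixture()
    # Fall back to the single-dataset path.
    if dataset_name:
        return load_single(dataset_name)
    raise ValueError("Either `datasets` or `dataset_name` must be provided.")
```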

Import

from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig
from datasets import load_dataset

I/O Contract

Inputs

mixture_config (DatasetMixtureConfig): Configuration specifying the list of datasets to load, streaming mode, and optional test split size.
mixture_config.datasets (list[DatasetConfig]): Each entry specifies a dataset path, an optional config name, optional data_dir/data_files, the split, and an optional column selection.
mixture_config.streaming (bool): Whether to load datasets in streaming mode. Defaults to False.
mixture_config.test_split_size (float | None): If set, the combined dataset is split into "train" and "test". If None (the default), only a "train" key is returned.
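To estimate how large each split will be for a given test_split_size, the helper below does the arithmetic. It assumes the test size is rounded up and the train split takes the remainder, which is my understanding of how Dataset.train_test_split treats a float test size; verify this against your installed datasets version. expected_split_sizes is a hypothetical name.

```python
import math


def expected_split_sizes(num_rows: int, test_split_size: float) -> tuple[int, int]:
    """Return (train_rows, test_rows), rounding the test count up (an assumption)."""
    if not 0.0 < test_split_size < 1.0:
        raise ValueError("test_split_size must be a fraction strictly between 0 and 1")
    n_test = math.ceil(test_split_size * num_rows)
    return num_rows - n_test, n_test
```

For example, a 116,722-row combined dataset with test_split_size=0.05 would give 110,885 train rows and 5,837 test rows under this assumption.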

Outputs

dataset (DatasetDict): A dictionary with at least a "train" key, plus a "test" key when test_split_size is set. Apart from the optional per-dataset column selection, get_dataset does not add or validate columns; for GRPO training, each split must contain a "prompt" column.
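Because the prompt requirement is not enforced at load time, it can be worth checking the result before training. The hypothetical helper below validates a dict of splits whose rows are plain dicts; with a real DatasetDict the equivalent check would inspect each split's column_names instead of individual rows.

```python
def check_grpo_splits(splits: dict[str, list[dict]]) -> None:
    """Raise if the 'train' split is missing or any row lacks a 'prompt' column."""
    if "train" not in splits:
        raise KeyError(f"expected a 'train' split, found {sorted(splits)}")
    for name, rows in splits.items():
        for i, row in enumerate(rows):
            if "prompt" not in row:
                raise ValueError(f"split {name!r}, row {i}: missing 'prompt' column")
```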

Usage Examples

Single dataset via YAML config:

# config.yaml
dataset_name: trl-lib/DeepMath-103K
dataset_train_split: train
dataset_test_split: test
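The dataset_train_split and dataset_test_split keys name the splits taken from the loaded DatasetDict. The sketch below is a hypothetical illustration of that lookup, assuming the script indexes the dictionary by those names (defaults "train" and "test") and only needs the test split when evaluation is enabled. It also shows a practical consequence: a mixture loaded without test_split_size has only a "train" key, so evaluation would fail to find a "test" split. pick_splits is not a TRL function.

```python
from __future__ import annotations


def pick_splits(
    dataset: dict[str, list],
    train_split: str = "train",
    test_split: str = "test",
    evaluate: bool = False,
) -> tuple[list, list | None]:
    # The training split is always required.
    if train_split not in dataset:
        raise KeyError(f"train split {train_split!r} not found; available: {sorted(dataset)}")
    eval_rows = None
    # The test split is only looked up when evaluation is enabled.
    if evaluate:
        if test_split not in dataset:
            raise KeyError(f"test split {test_split!r} not found; available: {sorted(dataset)}")
        eval_rows = dataset[test_split]
    return dataset[train_split], eval_rows
```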

Mixture of datasets via YAML config:

# config.yaml
datasets:
  - path: trl-lib/DeepMath-103K
    split: train
    columns:
      - prompt
      - solution
  - path: my-org/custom-math-prompts
    split: train
    columns:
      - prompt
      - solution
streaming: false
test_split_size: 0.05

Programmatic usage:

from trl import DatasetMixtureConfig, get_dataset
from trl.scripts.utils import DatasetConfig

mixture_config = DatasetMixtureConfig(
    datasets=[DatasetConfig(path="trl-lib/tldr")]
)
dataset = get_dataset(mixture_config)
print(dataset)
# DatasetDict({
#     train: Dataset({
#         features: ['prompt', 'completion'],
#         num_rows: 116722
#     })
# })

Key constraint: GRPO datasets must contain a "prompt" column. Additional columns (like "solution") are preserved when remove_unused_columns=False (the GRPOConfig default) and are passed to reward functions as keyword arguments during training.
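Because extra columns reach reward functions as keyword arguments, a reward function can compare each completion against its row's solution. The sketch below assumes plain-text (non-conversational) completions and a hypothetical answer format in which the final answer appears inside \boxed{...}; boxed_match_reward is an illustrative name, not TRL's built-in accuracy reward.

```python
import re

# Matches a literal \boxed{...} with no nested braces inside.
BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def boxed_match_reward(completions: list[str], solution: list[str], **kwargs) -> list[float]:
    """Return 1.0 when the last \\boxed{...} answer equals the row's solution, else 0.0."""
    rewards = []
    for completion, target in zip(completions, solution):
        answers = BOXED.findall(completion)
        predicted = answers[-1].strip() if answers else None
        rewards.append(1.0 if predicted == target.strip() else 0.0)
    return rewards
```

The solution parameter name has to match the dataset column name, and **kwargs absorbs everything else the trainer passes (such as prompts).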
