Implementation:Huggingface Alignment handbook Get Dataset Mixture
| Knowledge Sources | |
|---|---|
| Domains | NLP, Data_Engineering, Training |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for loading and blending weighted dataset mixtures from HuggingFace Hub, provided by the alignment-handbook library.
Description
The get_dataset function in mixture mode (when args.dataset_mixture is set) iterates over a list of DatasetConfig entries, loads each dataset from HuggingFace Hub, optionally selects specific columns and subsamples by weight, then concatenates and shuffles the combined result. This is the alignment-handbook's main differentiator over vanilla TRL dataset loading.
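The steps above can be illustrated with a minimal plain-Python sketch that works on in-memory lists of rows instead of Hub datasets. The helper name `mix_rows`, its tuple-based input, and the reading of `weight` as "fraction of that source's rows to keep" are assumptions made for illustration; the library itself operates on `datasets` objects and may differ in detail.

```python
import random


def mix_rows(sources, seed=0):
    """Blend several in-memory datasets into one shuffled list of rows.

    `sources` is a list of (rows, columns, weight) tuples: `rows` is a list
    of dicts, `columns` an optional list of keys to keep, and `weight` an
    optional fraction of rows to keep. All names here are hypothetical.
    """
    rng = random.Random(seed)
    combined = []
    for rows, columns, weight in sources:
        if columns is not None:
            # Keep only the requested columns.
            rows = [{key: row[key] for key in columns} for row in rows]
        if weight is not None:
            # Assumption: weight is the fraction of this source's rows to keep.
            keep = int(len(rows) * weight)
            rows = rng.sample(rows, keep)
        combined.extend(rows)
    # Shuffle the concatenated rows so the sources are interleaved.
    rng.shuffle(combined)
    return combined


chat_rows = [{"messages": f"chat-{i}", "source": "a"} for i in range(10)]
math_rows = [{"messages": f"math-{i}", "source": "b"} for i in range(20)]
mixed = mix_rows(
    [(chat_rows, ["messages"], 0.5), (math_rows, ["messages"], None)],
    seed=42,
)
print(len(mixed))  # 25: 5 subsampled chat rows plus all 20 math rows
```

Selecting a shared column set before concatenation matters because rows from different sources can only be combined cleanly when they expose the same fields.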
The mixture configuration is parsed by ScriptArguments.__post_init__, which validates the dictionary loaded from YAML and converts it into a DatasetMixtureConfig containing a list of DatasetConfig objects.
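A rough sketch of that dictionary-to-dataclass conversion is shown below. The two dataclasses mirror the fields listed in the Signature section; the `parse_mixture` helper and its specific validation checks are hypothetical stand-ins, not the library's actual `__post_init__` logic.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetConfig:
    id: str
    config: Optional[str] = None
    split: str = "train"
    columns: Optional[list[str]] = None
    weight: Optional[float] = None


@dataclass
class DatasetMixtureConfig:
    datasets: list[DatasetConfig]
    seed: int = 0
    test_split_size: Optional[float] = None


def parse_mixture(raw: dict) -> DatasetMixtureConfig:
    """Hypothetical helper: validate a parsed-YAML dict and build the config."""
    if not isinstance(raw, dict) or "datasets" not in raw:
        raise ValueError("dataset_mixture must be a dict with a 'datasets' key")
    datasets = []
    for entry in raw["datasets"]:
        if "id" not in entry:
            raise ValueError(f"dataset entry is missing 'id': {entry!r}")
        # Unknown keys raise TypeError here, which surfaces typos early.
        datasets.append(DatasetConfig(**entry))
    return DatasetMixtureConfig(
        datasets=datasets,
        seed=raw.get("seed", 0),
        test_split_size=raw.get("test_split_size"),
    )


mixture = parse_mixture({
    "datasets": [
        {"id": "HuggingFaceTB/smoltalk2", "columns": ["messages"], "weight": 1.0},
    ],
    "seed": 42,
})
print(mixture.datasets[0].split)  # "train", the dataclass default
```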
Usage
Use this when training with multiple data sources, such as SmolLM3's 25-split SFT mixture or the 2-split mid-training reasoning mixture. The mixture is configured entirely via YAML config files.
Code Reference
Source Location
- Repository: alignment-handbook
- File: src/alignment/data.py (lines 26-79, mixture branch at lines 38-76)
- Config: src/alignment/configs.py (lines 36-131, DatasetConfig, DatasetMixtureConfig, ScriptArguments)
Signature
def get_dataset(args: ScriptArguments) -> DatasetDict:
    """Load a dataset or a mixture of datasets based on the configuration.

    Args:
        args (ScriptArguments): Script arguments containing dataset configuration.
            When args.dataset_mixture is set, operates in mixture mode.

    Returns:
        DatasetDict: The loaded datasets with 'train' split
            (and optionally 'test' split if test_split_size is set).
    """
@dataclass
class DatasetConfig:
    """Configuration for a dataset in a mixture."""
    id: str                                  # HuggingFace dataset ID
    config: Optional[str] = None             # Dataset config name
    split: str = "train"                     # Split to load
    columns: Optional[list[str]] = None      # Columns to select
    weight: Optional[float] = None           # Sampling weight (0.0-1.0)

@dataclass
class DatasetMixtureConfig:
    """Configuration for a mixture of datasets."""
    datasets: list[DatasetConfig]            # List of dataset configs
    seed: int = 0                            # Random seed for shuffling
    test_split_size: Optional[float] = None  # Optional test split fraction
Import
from alignment import get_dataset, ScriptArguments
from alignment.configs import DatasetConfig, DatasetMixtureConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| args | ScriptArguments | Yes | Script arguments with dataset_mixture set |
| args.dataset_mixture.datasets | list[DatasetConfig] | Yes | List of dataset configurations with id, config, split, columns, weight |
| args.dataset_mixture.seed | int | No | Random seed for shuffling (default: 0) |
| args.dataset_mixture.test_split_size | Optional[float] | No | Fraction for test split (e.g., 0.1 for 10%) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | DatasetDict | Dictionary with a train split, combined and shuffled from all mixture components. If test_split_size is set, it also contains a test split. |
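
The shape of the return value can be illustrated with plain lists. The helper `split_rows` below is hypothetical, and rounding the test size down is an assumption; the rounding used by the library's underlying split call may differ.

```python
import random


def split_rows(rows, test_split_size=None, seed=0):
    """Hypothetical helper mirroring the output contract with plain lists."""
    if test_split_size is None:
        # No split requested: only a 'train' entry is returned.
        return {"train": list(rows)}
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    # Assumption: the test size is rounded down.
    n_test = int(len(shuffled) * test_split_size)
    return {"train": shuffled[n_test:], "test": shuffled[:n_test]}


rows = list(range(1000))
print(sorted(split_rows(rows)))  # ['train']
result = split_rows(rows, test_split_size=0.1, seed=42)
print(len(result["train"]), len(result["test"]))  # 900 100
```

Callers that must work in both modes can check `"test" in result` before accessing the held-out split, as the Programmatic Usage example does.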
Usage Examples
YAML Mixture Configuration (SmolLM3 Mid-Training)
# From recipes/smollm3/sft/mid.yaml
dataset_mixture:
  datasets:
    - id: HuggingFaceTB/smoltalk2
      config: Llama_Nemotron
      columns:
        - messages
      weight: 1.0
    - id: HuggingFaceTB/smoltalk2
      config: OpenThoughts3
      columns:
        - messages
      weight: 1.0
  seed: 42
  test_split_size: 0.01
YAML Mixture Configuration (SmolLM3 SFT with 25 splits)
# From recipes/smollm3/sft/sft.yaml (abbreviated)
dataset_mixture:
  datasets:
    - id: HuggingFaceTB/smoltalk2
      config: everyday-conversations_think
      columns: [messages]
      weight: 0.3
    - id: HuggingFaceTB/smoltalk2
      config: smol-magpie-ultra_think
      columns: [messages]
      weight: 0.15
    # ... 23 more splits with varying weights
  seed: 42
  test_split_size: 0.005
Programmatic Usage
from alignment import get_dataset

# After TrlParser parses YAML with a dataset_mixture section,
# ScriptArguments.__post_init__ converts the dict to DatasetMixtureConfig.
# `script_args` is the resulting ScriptArguments instance.
dataset = get_dataset(script_args)

print(f"Training samples: {len(dataset['train'])}")
# Illustrative output: Training samples: 1500000 (combined from all mixture components)

# The test split exists only when test_split_size is set.
if "test" in dataset:
    print(f"Test samples: {len(dataset['test'])}")