Implementation: Microsoft DeepSpeedExamples Create Prompt Dataset
Overview
A concrete tool, provided by the DeepSpeed-Chat library, for preparing RLHF training datasets.
Description
create_prompt_dataset is the main entry point for dataset loading in DeepSpeed-Chat. It orchestrates the full data preparation pipeline for all three RLHF training phases:
- Dataset routing -- It calls `get_raw_dataset`, which maps a HuggingFace dataset name string to one of 16 dataset adapter classes (e.g., `DahoasRmstaticDataset`, `OpenaiWebgptcomparisonsDataset`, `StanfordnlpSHPDataset`, `LocalJsonFileDataset`, and others). Each adapter implements a common interface defined by the `PromptRawDataset` base class, exposing `get_prompt()`, `get_chosen()`, `get_rejected()`, `get_prompt_and_chosen()`, and `get_prompt_and_rejected()`.
- Tokenization and formatting -- Based on the `train_phase` parameter, the function calls `create_dataset_split`, which tokenizes and structures data differently per phase:
  - Phase 1 (SFT): Concatenates prompt + chosen response, appends the end-of-conversation token, tokenizes with padding and truncation, and sets labels to `-100` wherever the attention mask is inactive.
  - Phase 2 (Reward): Tokenizes the chosen and rejected responses separately, producing paired tensors for the pairwise ranking loss.
  - Phase 3 (RLHF/PPO): Tokenizes prompts only, filters out prompts exceeding `max_seq_len`, and flips the token sequence so that left-padding aligns correctly for autoregressive generation.
- Caching as `.pt` files -- Processed datasets are saved as PyTorch tensor files (`traindata_<hash>.pt` and `evaldata_<hash>.pt`) in the specified output directory. The cache key is a SHA-256 hash derived from the dataset paths, split ratios, phase number, seed, tokenizer name, and sequence length. On subsequent runs, cached files are loaded directly, skipping reprocessing.
- Distributed coordination -- The function uses `torch.distributed.all_reduce` to check whether any rank is missing the cache. Only rank 0 (or `local_rank <= 0`) creates the dataset; all other ranks wait at a `torch.distributed.barrier()` before loading from the shared cache.
- Multi-dataset blending -- When multiple dataset paths are provided, each is processed independently and then concatenated via `ConcatDataset`. The combined dataset is shuffled with a deterministic seed.
- SFT-only data augmentation -- During Phase 1, additional SFT-only datasets can be appended via the `sft_only_data_path` parameter. These are loaded with a `"10,0,0"` split (i.e., 100% allocated to Phase 1) and concatenated with the main training set.
Usage
Import `create_prompt_dataset` when setting up any of the three RLHF training phases to get properly formatted and cached datasets. It is called in:
- Step 1 (SFT): `training/step1_supervised_finetuning/main.py` with `train_phase=1`
- Step 2 (Reward Model): `training/step2_reward_model_finetuning/main.py` with `train_phase=2`
- Step 3 (RLHF/PPO): `training/step3_rlhf_finetuning/main.py` with `train_phase=3`
Code Reference
Source
| Repository | File |
|---|---|
| DeepSpeedExamples | `applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py` |
Signature
def create_prompt_dataset(
local_rank,
data_path,
data_split,
output_path,
train_phase,
seed,
tokenizer,
max_seq_len,
end_of_conversation_token="<|endoftext|>",
sft_only_data_path=[],
reload=False
) -> Tuple[Dataset, Dataset]:
Import
from dschat.utils.data.data_utils import create_prompt_dataset
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `local_rank` | `int` | Yes | GPU rank for distributed caching. Only rank `<= 0` creates datasets; others wait at a barrier. |
| `data_path` | `list[str]` | Yes | HuggingFace dataset names (e.g., `["Dahoas/rm-static"]`). Multiple paths trigger dataset blending via `ConcatDataset`. |
| `data_split` | `str` | Yes | Comma-separated split-ratio string (e.g., `"2,4,4"`). The three values correspond to the Phase 1, Phase 2, and Phase 3 proportions respectively. |
| `output_path` | `str` | Yes | Cache directory for storing processed `.pt` files and shuffle-index `.npy` files. Must be on local (non-shared) storage for each node. |
| `train_phase` | `int` | Yes | Training phase selector: 1 = SFT, 2 = Reward Model, 3 = RLHF/PPO. Determines both data formatting and which split partition to use. |
| `seed` | `int` | Yes | Random seed for deterministic shuffling and splitting. Ensures reproducibility across runs and ranks. |
| `tokenizer` | `AutoTokenizer` | Yes | HuggingFace tokenizer instance. Used for encoding text and providing `pad_token_id`. Must have `init_kwargs["name_or_path"]` set (used in cache-key computation). |
| `max_seq_len` | `int` | Yes | Maximum sequence length for tokenization. Sequences are padded or truncated to this length (Phases 1 and 2). In Phase 3, prompts exceeding this length are filtered out. |
| `end_of_conversation_token` | `str` | No | Token appended to the end of each response text before tokenization. Defaults to `"<|endoftext|>"`. |
| `sft_only_data_path` | `list[str]` | No | Additional dataset paths used exclusively in Phase 1 (SFT). Loaded with a `"10,0,0"` split and concatenated with the main training data. |
| `reload` | `bool` | No | If `True`, forces regeneration of cached datasets even when cache files exist. Defaults to `False`. |
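The `data_split` semantics can be illustrated with a small sketch. `split_bounds` is a hypothetical helper (not the library's API), assuming the described behavior: the string is split on commas and each phase receives its proportional, contiguous share of the raw data:

```python
def split_bounds(data_split: str, n_samples: int):
    # "2,4,4" -> Phase 1 gets 20% of samples, Phases 2 and 3 get 40% each.
    parts = [float(p) for p in data_split.split(",")]
    total = sum(parts)
    bounds, start = [], 0
    for p in parts:
        end = start + int(n_samples * p / total)
        bounds.append((start, end))
        start = end
    return bounds  # one (start, end) index range per training phase

print(split_bounds("2,4,4", 1000))   # [(0, 200), (200, 600), (600, 1000)]
print(split_bounds("10,0,0", 50))    # [(0, 50), (50, 50), (50, 50)]
```

The `"10,0,0"` case shows why `sft_only_data_path` datasets are loaded with that split: the entire dataset lands in the Phase 1 partition.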
Outputs
| Name | Type | Description |
|---|---|---|
| `train_dataset` | `PromptDataset` | Training split. A PyTorch `Dataset` returning phase-appropriate tensors: Phase 1 yields `{input_ids, attention_mask, labels}`; Phase 2 yields `(chosen_ids, chosen_mask, rejected_ids, rejected_mask)`; Phase 3 yields `(prompt_ids, prompt_mask, pad_token_id)`. |
| `eval_dataset` | `PromptDataset` | Evaluation split. Same structure as `train_dataset` but drawn from the evaluation partition of the underlying data. |
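A minimal sketch of how a phase-aware dataset could return the structures listed above. `PhaseAwareDataset` is a hypothetical stand-in that mirrors the described behavior of `PromptDataset`, not its exact implementation:

```python
import torch
from torch.utils.data import Dataset

class PhaseAwareDataset(Dataset):
    # Hypothetical stand-in for PromptDataset: returns the per-phase
    # structures described in the Outputs table.
    def __init__(self, prompt_data, chosen_data, rejected_data,
                 pad_token_id, train_phase):
        self.prompt_data = prompt_data      # tokenized prompts (Phase 3)
        self.chosen_data = chosen_data      # tokenized chosen responses
        self.rejected_data = rejected_data  # tokenized rejected responses
        self.pad_token_id = pad_token_id
        self.train_phase = train_phase

    def __len__(self):
        if self.train_phase == 3:
            return len(self.prompt_data)
        return len(self.chosen_data)

    def __getitem__(self, idx):
        if self.train_phase == 1:  # SFT: dict with labels
            item = self.chosen_data[idx]
            return {
                "input_ids": item["input_ids"],
                "attention_mask": item["attention_mask"],
                "labels": item["input_ids"],
            }
        if self.train_phase == 2:  # Reward: paired chosen/rejected tensors
            c, r = self.chosen_data[idx], self.rejected_data[idx]
            return (c["input_ids"], c["attention_mask"],
                    r["input_ids"], r["attention_mask"])
        # Phase 3 (RLHF): prompt tensors plus pad_token_id for the collator
        p = self.prompt_data[idx]
        return p["input_ids"], p["attention_mask"], self.pad_token_id
```

The key design point is that a single dataset class serves all three phases, with `train_phase` fixed at construction time so downstream training loops see a uniform `Dataset` interface.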
Usage Examples
Example 1: Creating a Dataset for SFT (Phase 1)
from dschat.utils.data.data_utils import create_prompt_dataset
from dschat.utils.utils import load_hf_tokenizer
tokenizer = load_hf_tokenizer("facebook/opt-1.3b", fast_tokenizer=True)
train_dataset, eval_dataset = create_prompt_dataset(
local_rank=0,
data_path=["Dahoas/rm-static"],
data_split="2,4,4",
output_path="/tmp/data_files/",
train_phase=1,
seed=1234,
tokenizer=tokenizer,
max_seq_len=512,
end_of_conversation_token=tokenizer.eos_token,
sft_only_data_path=["pvduy/sharegpt_alpaca_oa_vicuna_format"],
)
# Each sample from train_dataset returns:
# {
# "input_ids": tensor of shape (max_seq_len,),
# "attention_mask": tensor of shape (max_seq_len,),
# "labels": tensor of shape (max_seq_len,) with -100 where attention_mask is 0
# }
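The `-100` masking in the labels above can be illustrated standalone. This is a hedged sketch of the convention rather than the library's code; `-100` is the index that PyTorch's cross-entropy loss ignores by default:

```python
import torch

input_ids = torch.tensor([15496, 11, 995, 0, 0])   # last two positions are padding
attention_mask = torch.tensor([1, 1, 1, 0, 0])

# Copy input_ids, then set the label to -100 wherever attention_mask is 0,
# so padded positions contribute nothing to the cross-entropy loss.
labels = input_ids.clone()
labels[attention_mask == 0] = -100
print(labels.tolist())  # [15496, 11, 995, -100, -100]
```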
Example 2: Creating a Dataset for RLHF/PPO (Phase 3)
from dschat.utils.data.data_utils import create_prompt_dataset
from dschat.utils.utils import load_hf_tokenizer
tokenizer = load_hf_tokenizer("facebook/opt-1.3b", fast_tokenizer=True)
prompt_train_dataset, _ = create_prompt_dataset(
local_rank=0,
data_path=["Dahoas/rm-static"],
data_split="2,4,4",
output_path="/tmp/data_files/",
train_phase=3,
seed=1234,
tokenizer=tokenizer,
max_seq_len=256,
)
# Each sample from prompt_train_dataset returns:
# (
# prompt_input_ids, # tensor of variable length, token order is flipped (reversed)
# prompt_attention_mask, # tensor of variable length, flipped to match input_ids
# pad_token_id, # int, used by DataCollatorRLHF for batch padding
# )
# Note: The eval_dataset is typically discarded (assigned to _) in Phase 3
# because RLHF evaluation is done via reward scoring, not a held-out loss.
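The flipped token order returned in Phase 3 can be sketched like this. It is a hedged illustration of the idea, assuming the collator right-pads the reversed sequence and flips it back, which yields the left padding that autoregressive generation expects:

```python
import torch
import torch.nn.functional as F

pad_token_id = 0
prompt = torch.tensor([5, 6, 7])  # stored reversed: the original prompt was [7, 6, 5]
max_len = 5

# Right-pad the reversed prompt, then flip the whole sequence:
# the pad tokens end up on the left, immediately before the prompt.
padded = F.pad(prompt, (0, max_len - prompt.numel()), value=pad_token_id)
left_padded = padded.flip(0)
print(left_padded.tolist())  # [0, 0, 7, 6, 5]
```

Storing prompts reversed lets a standard right-padding collator produce left-padded batches with a single flip, instead of padding each sequence on the left individually.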