
Implementation:Microsoft DeepSpeedExamples Create Prompt Dataset

From Leeroopedia



Overview

A concrete utility provided by the DeepSpeed-Chat library for preparing RLHF training datasets.

Description

create_prompt_dataset is the main entry point for dataset loading in DeepSpeed-Chat. It orchestrates the full data preparation pipeline for all three RLHF training phases:

  1. Dataset routing -- It calls get_raw_dataset, which maps a HuggingFace dataset name string to one of 16 dataset adapter classes (e.g., DahoasRmstaticDataset, OpenaiWebgptcomparisonsDataset, StanfordnlpSHPDataset, LocalJsonFileDataset, and others). Each adapter implements a common interface defined by the PromptRawDataset base class, exposing get_prompt(), get_chosen(), get_rejected(), get_prompt_and_chosen(), and get_prompt_and_rejected().
  2. Tokenization and formatting -- Based on the train_phase parameter, the function calls create_dataset_split which tokenizes and structures data differently per phase:
    • Phase 1 (SFT): Concatenates prompt + chosen response, appends the end-of-conversation token, tokenizes with padding and truncation, and sets labels to -100 wherever the attention mask is 0, so padded positions are excluded from the loss.
    • Phase 2 (Reward): Tokenizes both the chosen and rejected responses separately, producing paired tensors for the pairwise ranking loss.
    • Phase 3 (RLHF/PPO): Tokenizes prompts only, filters out prompts exceeding max_seq_len, and flips the token sequence so that left-padding aligns correctly for autoregressive generation.
  3. Caching as .pt files -- Processed datasets are saved as PyTorch tensor files (traindata_<hash>.pt and evaldata_<hash>.pt) in the specified output directory. The cache key is a SHA-256 hash derived from the dataset paths, split ratios, phase number, seed, tokenizer name, and sequence length. On subsequent runs, cached files are loaded directly, skipping reprocessing.
  4. Distributed coordination -- The function uses torch.distributed.all_reduce to check whether any rank is missing the cache. Only rank 0 (or local_rank <= 0) creates the dataset; all other ranks wait at a torch.distributed.barrier() before loading from the shared cache.
  5. Multi-dataset blending -- When multiple dataset paths are provided, each is processed independently and then concatenated via ConcatDataset. The combined dataset is shuffled with a deterministic seed.
  6. SFT-only data augmentation -- During Phase 1, additional SFT-only datasets can be appended via the sft_only_data_path parameter. These are loaded with a "10,0,0" split (i.e., 100% allocated to Phase 1) and concatenated with the main training set.
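The phase-dependent formatting in step 2 can be illustrated with a minimal sketch. This is illustrative code only, not the DeepSpeed-Chat implementation: `toy_tokenize` is a stand-in for a real HuggingFace tokenizer, mapping characters to fake token ids.

```python
def toy_tokenize(text, max_seq_len, pad_id=0):
    """Stand-in tokenizer: one fake token id per character, then pad/truncate."""
    ids = [ord(c) for c in text][:max_seq_len]
    mask = [1] * len(ids)
    ids += [pad_id] * (max_seq_len - len(ids))
    mask += [0] * (max_seq_len - len(mask))
    return ids, mask

def format_phase1(prompt, chosen, max_seq_len, eoc="<|endoftext|>"):
    # Phase 1 (SFT): prompt + chosen response + end-of-conversation token;
    # labels copy input_ids, with -100 wherever the attention mask is 0.
    ids, mask = toy_tokenize(prompt + chosen + eoc, max_seq_len)
    labels = [i if m == 1 else -100 for i, m in zip(ids, mask)]
    return {"input_ids": ids, "attention_mask": mask, "labels": labels}

def format_phase3(prompt, max_seq_len):
    # Phase 3 (PPO): prompts only; over-length prompts are dropped, and the
    # token sequence is reversed so left-padding lines up for generation.
    if len(prompt) > max_seq_len:
        return None
    ids, mask = toy_tokenize(prompt, max_seq_len)
    return {"input_ids": ids[::-1], "attention_mask": mask[::-1]}
```

Phase 2 follows the Phase 1 tokenization path but runs it twice, once for the chosen and once for the rejected response, to build the paired tensors for the ranking loss.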
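The caching scheme in step 3 can likewise be sketched: a SHA-256 digest over every ingredient that affects preprocessing, so any change to the configuration produces a fresh cache file. Field order and exact serialization in DeepSpeed-Chat may differ; this only shows the idea.

```python
import hashlib

def cache_filename(split, data_path, data_split, train_phase, seed,
                   tokenizer_name, max_seq_len):
    # Hash every input that influences the processed tensors; sorting
    # data_path makes the key independent of argument order.
    key = "_".join([",".join(sorted(data_path)), data_split,
                    str(train_phase), str(seed), tokenizer_name,
                    str(max_seq_len)])
    digest = hashlib.sha256(key.encode()).hexdigest()
    return f"{split}data_{digest}.pt"  # e.g. traindata_<hash>.pt
```

Because the tokenizer name and sequence length are part of the key, switching models or lengths never reuses a stale cache; pass reload=True to force regeneration even when the key matches.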

Usage

Import create_prompt_dataset when setting up any of the three RLHF training phases to obtain properly formatted, cached datasets. It is called in:

  • Step 1 (SFT): training/step1_supervised_finetuning/main.py with train_phase=1
  • Step 2 (Reward Model): training/step2_reward_model_finetuning/main.py with train_phase=2
  • Step 3 (RLHF/PPO): training/step3_rlhf_finetuning/main.py with train_phase=3

Code Reference

Source

Repository File
DeepSpeedExamples applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py

Signature

def create_prompt_dataset(
    local_rank,
    data_path,
    data_split,
    output_path,
    train_phase,
    seed,
    tokenizer,
    max_seq_len,
    end_of_conversation_token="<|endoftext|>",
    sft_only_data_path=[],
    reload=False
) -> Tuple[Dataset, Dataset]:

Import

from dschat.utils.data.data_utils import create_prompt_dataset

I/O Contract

Inputs

Parameter Type Required Description
local_rank int Yes GPU rank for distributed caching. Only rank <= 0 creates datasets; others wait at barrier.
data_path list[str] Yes HuggingFace dataset names (e.g., ["Dahoas/rm-static"]). Multiple paths trigger dataset blending via ConcatDataset.
data_split str Yes Comma-separated train/eval split ratio string (e.g., "2,4,4"). The three values correspond to Phase 1, Phase 2, and Phase 3 proportions respectively.
output_path str Yes Cache directory for storing processed .pt files and shuffle index .npy files. Must be on local (non-shared) storage for each node.
train_phase int Yes Training phase selector: 1 = SFT, 2 = Reward Model, 3 = RLHF/PPO. Determines both data formatting and which split partition to use.
seed int Yes Random seed for deterministic shuffling and splitting. Ensures reproducibility across runs and ranks.
tokenizer AutoTokenizer Yes HuggingFace tokenizer instance. Used for encoding text and providing pad_token_id. Must have init_kwargs["name_or_path"] set (used in cache key computation).
max_seq_len int Yes Maximum sequence length for tokenization. Sequences are padded or truncated to this length (Phases 1 and 2). In Phase 3, prompts exceeding this length are filtered out.
end_of_conversation_token str No Token appended to the end of each response text before tokenization. Defaults to "<|endoftext|>".
sft_only_data_path list[str] No Additional dataset paths used exclusively in Phase 1 (SFT). These are loaded with a "10,0,0" split and concatenated with the main training data.
reload bool No If True, forces regeneration of cached datasets even when cache files exist. Defaults to False.
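The data_split semantics can be sketched as follows: each phase receives a contiguous, non-overlapping slice of the dataset, sized by its ratio. This is a simplified illustration of the idea; the actual index computation lives in data_utils and may differ in rounding details.

```python
def split_bounds(data_split, train_phase, n_samples):
    # "2,4,4" -> phase 1 gets 20% of samples, phases 2 and 3 get 40% each.
    # Slices are disjoint, so no sample is reused across phases.
    ratios = [float(x) for x in data_split.split(",")]
    total = sum(ratios)
    start = sum(ratios[:train_phase - 1]) / total
    end = sum(ratios[:train_phase]) / total
    return int(start * n_samples), int(end * n_samples)

# With 1000 samples and data_split="2,4,4":
# phase 1 -> (0, 200), phase 2 -> (200, 600), phase 3 -> (600, 1000)
```

Disjoint slices matter for RLHF hygiene: the reward model (phase 2) should not be trained on the same examples the policy later sees as prompts in phase 3.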

Outputs

Name Type Description
train_dataset PromptDataset Training split. A PyTorch Dataset returning phase-appropriate tensors: Phase 1 yields {input_ids, attention_mask, labels}; Phase 2 yields (chosen_ids, chosen_mask, rejected_ids, rejected_mask); Phase 3 yields (prompt_ids, prompt_mask, pad_token_id).
eval_dataset PromptDataset Evaluation split. Same structure as train_dataset but drawn from the evaluation partition of the underlying data.

Usage Examples

Example 1: Creating a Dataset for SFT (Phase 1)

from dschat.utils.data.data_utils import create_prompt_dataset
from dschat.utils.utils import load_hf_tokenizer

tokenizer = load_hf_tokenizer("facebook/opt-1.3b", fast_tokenizer=True)

train_dataset, eval_dataset = create_prompt_dataset(
    local_rank=0,
    data_path=["Dahoas/rm-static"],
    data_split="2,4,4",
    output_path="/tmp/data_files/",
    train_phase=1,
    seed=1234,
    tokenizer=tokenizer,
    max_seq_len=512,
    end_of_conversation_token=tokenizer.eos_token,
    sft_only_data_path=["pvduy/sharegpt_alpaca_oa_vicuna_format"],
)

# Each sample from train_dataset returns:
# {
#     "input_ids":      tensor of shape (max_seq_len,),
#     "attention_mask":  tensor of shape (max_seq_len,),
#     "labels":         tensor of shape (max_seq_len,) with -100 where attention_mask is 0
# }

Example 2: Creating a Dataset for RLHF/PPO (Phase 3)

from dschat.utils.data.data_utils import create_prompt_dataset
from dschat.utils.utils import load_hf_tokenizer

tokenizer = load_hf_tokenizer("facebook/opt-1.3b", fast_tokenizer=True)

prompt_train_dataset, _ = create_prompt_dataset(
    local_rank=0,
    data_path=["Dahoas/rm-static"],
    data_split="2,4,4",
    output_path="/tmp/data_files/",
    train_phase=3,
    seed=1234,
    tokenizer=tokenizer,
    max_seq_len=256,
)

# Each sample from prompt_train_dataset returns:
# (
#     prompt_input_ids,      # tensor of variable length, token order is flipped (reversed)
#     prompt_attention_mask,  # tensor of variable length, flipped to match input_ids
#     pad_token_id,          # int, used by DataCollatorRLHF for batch padding
# )
# Note: The eval_dataset is typically discarded (assigned to _) in Phase 3
# because RLHF evaluation is done via reward scoring, not a held-out loss.
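Why store Phase 3 prompts flipped? Right-padding a reversed sequence and then flipping it back yields a left-padded prompt in original order, which is what autoregressive generation needs. The real DataCollatorRLHF does this with torch padding and flip operations; the sketch below shows the same trick with plain lists.

```python
def left_pad_via_flip(flipped_ids, max_len, pad_id):
    # flipped_ids holds the prompt tokens in reversed order (as stored
    # by the Phase 3 dataset). Right-pad them, then flip back: the pads
    # end up on the left and the tokens return to original order.
    padded = flipped_ids + [pad_id] * (max_len - len(flipped_ids))
    return padded[::-1]

# Prompt [1, 2, 3] stored flipped as [3, 2, 1]:
# left_pad_via_flip([3, 2, 1], 5, 0) -> [0, 0, 1, 2, 3]
```

This is why each Phase 3 sample carries pad_token_id alongside the tensors: the collator needs it to pad variable-length prompts into a batch.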
