
Implementation: OpenRLHF SFTDataset __init__

From Leeroopedia


Knowledge Sources
Domains: Data_Processing, NLP
Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete OpenRLHF utility for constructing tokenized SFT training datasets with loss masking.

Description

The SFTDataset class processes raw datasets by applying chat templates, tokenizing prompt-response pairs, computing loss masks that zero out prompt tokens, and handling multi-turn conversations. It supports parallel data processing via HuggingFace's map function and filters out samples exceeding the maximum sequence length.
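To make the loss-masking step concrete, here is a minimal, self-contained sketch (not OpenRLHF's actual code; the helper name and token ids are illustrative) of how prompt tokens are zeroed out of the loss while response tokens are kept, and how over-length samples are filtered:

```python
def build_sft_example(prompt_ids, response_ids, max_length):
    """Concatenate prompt and response token ids and build a loss mask.

    Prompt positions get mask 0 (excluded from the loss); response
    positions get mask 1. Samples longer than max_length are dropped
    (returned as None), mirroring the length filter described above.
    """
    input_ids = prompt_ids + response_ids
    if len(input_ids) > max_length:
        return None  # sample exceeds the maximum sequence length
    loss_mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    attention_mask = [1] * len(input_ids)
    return input_ids, attention_mask, loss_mask
```

In pretrain_mode the mask would instead be all ones, so the loss covers every token, prompt included.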

Usage

Instantiate SFTDataset on the raw dataset returned by blending_datasets, then pass it to a DataLoader for training with SFTTrainer or KDTrainer.

Code Reference

Source Location

  • Repository: OpenRLHF
  • File: openrlhf/datasets/sft_dataset.py
  • Lines: L35-230 (class), L45-87 (__init__)

Signature

class SFTDataset(Dataset):
    def __init__(
        self,
        dataset,                # datasets.Dataset: raw HF dataset
        tokenizer: Callable,    # tokenizer for encoding
        max_length: int,        # maximum sequence length
        strategy,               # DeepspeedStrategy
        input_template=None,    # str: prompt formatting template
        pretrain_mode=False,    # bool: if True, loss on all tokens
        num_processors=8,       # int: parallel processing workers
        multiturn=False,        # bool: multi-turn conversation support
    ) -> None:

Import

from openrlhf.datasets import SFTDataset

I/O Contract

Inputs

Name Type Required Description
dataset datasets.Dataset Yes Raw dataset from blending_datasets
tokenizer Callable Yes HuggingFace tokenizer
max_length int Yes Maximum sequence length
strategy DeepspeedStrategy Yes Training strategy
input_template str No Prompt formatting template (default None)
pretrain_mode bool No Loss on all tokens (default False)
num_processors int No Parallel processing workers (default 8)
multiturn bool No Multi-turn conversation support (default False)

Outputs

Name Type Description
__getitem__ returns Tuple[Tensor, Tensor, Tensor] (input_ids, attention_mask, loss_mask)
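A collate function must pad these per-sample triples to a common length before batching. The sketch below is hypothetical (OpenRLHF's real collate_fn works on tensors and uses the tokenizer's pad token id); plain lists keep it self-contained:

```python
def collate_sft(batch, pad_token_id=0):
    """Right-pad (input_ids, attention_mask, loss_mask) triples to the
    longest sequence in the batch. Padding positions get attention_mask 0
    and loss_mask 0, so they are ignored by both the model and the loss."""
    max_len = max(len(ids) for ids, _, _ in batch)

    def pad(seq, value):
        return seq + [value] * (max_len - len(seq))

    input_ids = [pad(ids, pad_token_id) for ids, _, _ in batch]
    attention_masks = [pad(attn, 0) for _, attn, _ in batch]
    loss_masks = [pad(mask, 0) for _, _, mask in batch]
    return input_ids, attention_masks, loss_masks
```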

Usage Examples

from openrlhf.datasets import SFTDataset
from openrlhf.datasets.utils import blending_datasets

# Load and blend datasets
raw_dataset = blending_datasets(args.dataset, strategy=strategy)

# Create SFT dataset
train_dataset = SFTDataset(
    raw_dataset,
    tokenizer,
    args.max_len,
    strategy,
    pretrain_mode=args.pretrain_mode,
    multiturn=args.multiturn,
)

# Create dataloader
train_dataloader = strategy.setup_dataloader(
    train_dataset,
    args.micro_train_batch_size,
    pin_memory=True,
    collate_fn=train_dataset.collate_fn,
)
