Implementation: OpenRLHF SFTDataset __init__
| Knowledge Sources | |
|---|---|
| Domains | Data_Processing, NLP |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
An OpenRLHF class for constructing tokenized SFT training datasets with loss masking over prompt tokens.
Description
The SFTDataset class processes raw datasets by applying chat templates, tokenizing prompt-response pairs, computing loss masks that zero out prompt tokens, and handling multi-turn conversations. It supports parallel preprocessing via the HuggingFace Datasets `map` function and filters out samples that exceed the maximum sequence length.
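The core idea behind the loss mask can be sketched in isolation. The snippet below is a minimal illustration, not OpenRLHF's actual code: `toy_encode` stands in for a real HuggingFace tokenizer with chat templates, and `build_loss_mask` is a hypothetical helper name.

```python
# Illustrative sketch of SFT loss masking. In the real SFTDataset the
# tokenizer and chat template do the encoding; here a toy whitespace
# "tokenizer" keeps the example self-contained.

def toy_encode(text):
    """Map each whitespace-separated token to a fake integer ID."""
    return [hash(tok) % 1000 for tok in text.split()]

def build_loss_mask(prompt, response):
    """Concatenate prompt and response IDs; restrict loss to response tokens."""
    prompt_ids = toy_encode(prompt)
    response_ids = toy_encode(response)
    input_ids = prompt_ids + response_ids
    # 0 for prompt tokens (no loss), 1 for response tokens (train on these)
    loss_mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return input_ids, loss_mask

ids, mask = build_loss_mask("What is 2+2 ?", "It is 4 .")
```

With `pretrain_mode=True` the class instead keeps loss on all tokens, i.e. the mask would be all ones.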
Usage
Instantiate it after calling blending_datasets to obtain a raw dataset, then pass the resulting dataset to a DataLoader for training with SFTTrainer or KDTrainer.
Code Reference
Source Location
- Repository: OpenRLHF
- File: openrlhf/datasets/sft_dataset.py
- Lines: L35-230 (class), L45-87 (__init__)
Signature
class SFTDataset(Dataset):
    def __init__(
        self,
        dataset,               # datasets.Dataset: raw HF dataset
        tokenizer: Callable,   # tokenizer for encoding
        max_length: int,       # maximum sequence length
        strategy,              # DeepspeedStrategy
        input_template=None,   # str: prompt formatting template
        pretrain_mode=False,   # bool: if True, loss on all tokens
        num_processors=8,      # int: parallel processing workers
        multiturn=False,       # bool: multi-turn conversation support
    ) -> None:
Import
from openrlhf.datasets import SFTDataset
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | datasets.Dataset | Yes | Raw dataset from blending_datasets |
| tokenizer | Callable | Yes | HuggingFace tokenizer |
| max_length | int | Yes | Maximum sequence length |
| strategy | DeepspeedStrategy | Yes | Training strategy |
| pretrain_mode | bool | No | Loss on all tokens (default False) |
| multiturn | bool | No | Multi-turn conversation support (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| __getitem__ returns | Tuple[Tensor, Tensor, Tensor] | (input_ids, attention_mask, loss_mask) |
Usage Examples
from openrlhf.datasets import SFTDataset
from openrlhf.datasets.utils import blending_datasets

# Load and blend datasets
raw_dataset = blending_datasets(args.dataset, strategy=strategy)

# Create SFT dataset
train_dataset = SFTDataset(
    raw_dataset,
    tokenizer,
    args.max_len,
    strategy,
    pretrain_mode=args.pretrain_mode,
    multiturn=args.multiturn,
)

# Create dataloader
train_dataloader = strategy.setup_dataloader(
    train_dataset,
    args.micro_train_batch_size,
    pin_memory=True,
    collate_fn=train_dataset.collate_fn,
)
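The `collate_fn` passed above batches the per-sample `(input_ids, attention_mask, loss_mask)` tuples. A minimal sketch of what such a function does, assuming right-padding with pad ID 0 and plain Python lists in place of tensors (the real OpenRLHF implementation may differ in padding side and pad token):

```python
# Hedged sketch of a collate function: pad each field in a batch of
# (input_ids, attention_mask, loss_mask) tuples to the batch max length.
# pad value 0 is an assumption for illustration, not OpenRLHF's choice.

def pad_to(seq, length, value):
    """Right-pad seq with value up to the given length."""
    return seq + [value] * (length - len(seq))

def collate(batch):
    """batch: list of (input_ids, attention_mask, loss_mask) tuples."""
    max_len = max(len(ids) for ids, _, _ in batch)
    input_ids = [pad_to(ids, max_len, 0) for ids, _, _ in batch]
    attention_mask = [pad_to(am, max_len, 0) for _, am, _ in batch]
    loss_mask = [pad_to(lm, max_len, 0) for _, _, lm in batch]
    return input_ids, attention_mask, loss_mask

batch = [([5, 6, 7], [1, 1, 1], [0, 1, 1]),
         ([8, 9], [1, 1], [0, 1])]
ids, am, lm = collate(batch)
```

Padded positions get 0 in both the attention mask (so they are ignored by attention) and the loss mask (so they contribute no loss).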