
Implementation:Volcengine Verl SFTDataset

From Leeroopedia


Field Value
Knowledge Sources API Doc (verl dataset utilities)
Domains Supervised Fine-Tuning, Dataset Loading, Tokenization
Last Updated 2026-02-07

Overview

Description

The SFTDataset class is an in-memory PyTorch Dataset for supervised fine-tuning (SFT) in verl. It reads one or more Parquet files, extracts prompt and response columns, applies chat template formatting, tokenizes the text, and returns padded/truncated tensors with a loss mask that excludes the prompt tokens from the training loss.

The class supports configurable prompt and response keys (including nested dict access via prompt_dict_keys and response_dict_keys), three truncation strategies ("error", "left", "right"), optional random subsampling via max_samples, and optional shared memory usage for faster I/O. Each __getitem__ call returns a dictionary with input_ids, attention_mask, position_ids, and loss_mask, all as tensors of length max_length.
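The nested dict access mentioned above can be pictured with a small standalone sketch. The helper below is hypothetical (not part of the verl API); it only illustrates what drilling into a column via a key list like prompt_dict_keys looks like.

```python
def extract_nested(value, keys):
    # Hypothetical helper: drill into a nested dict one key at a time,
    # as configured by prompt_dict_keys / response_dict_keys.
    for key in keys:
        value = value[key]
    return value

# A row whose prompt/response columns hold dicts rather than plain strings.
row = {"prompt": {"content": "What is 2 + 2?"},
       "response": {"content": "4"}}

prompt_text = extract_nested(row["prompt"], ["content"])      # "What is 2 + 2?"
response_text = extract_nested(row["response"], ["content"])  # "4"
```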

Usage

from verl.utils.dataset.sft_dataset import SFTDataset

# data_config is an OmegaConf node with keys such as
# prompt_key, response_key, max_length, and truncation.
dataset = SFTDataset(
    parquet_files="~/data/sft/train.parquet",
    tokenizer=tokenizer,
    config=data_config,
    max_samples=1000,  # subsample 1000 examples; -1 uses all rows
)

Code Reference

Attribute Detail
Source Location verl/utils/dataset/sft_dataset.py, Lines 33-204
Signature class SFTDataset(Dataset)
Constructor SFTDataset(parquet_files: str | ListConfig, tokenizer, config, max_samples: int = -1)
Import from verl.utils.dataset.sft_dataset import SFTDataset

I/O Contract

Inputs

Parameter Type Description
parquet_files str or ListConfig Path(s) to Parquet files containing the SFT data
tokenizer PreTrainedTokenizer or str HuggingFace tokenizer instance or path to tokenizer
config OmegaConf Configuration object with dataset parameters
config.prompt_key str Column name for prompt data (default: "prompt")
config.response_key str Column name for response data (default: "response")
config.max_length int Maximum sequence length for padding/truncation (default: 1024)
config.truncation str Truncation strategy: "error", "left", or "right" (default: "error")
max_samples int Maximum number of samples to use; -1 for all (default: -1)
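The three truncation strategies listed for config.truncation can be summarized with a short sketch. The function below is a hypothetical stand-in for logic the class applies internally after tokenization; it is not verl code.

```python
def apply_truncation(token_ids, max_length, strategy):
    # Hypothetical illustration of the "error" / "left" / "right" strategies.
    if len(token_ids) <= max_length:
        return token_ids
    if strategy == "left":
        return token_ids[-max_length:]  # drop the leftmost (oldest) tokens
    if strategy == "right":
        return token_ids[:max_length]   # drop the rightmost (trailing) tokens
    # strategy == "error": refuse to silently lose tokens
    raise ValueError(
        f"sequence of length {len(token_ids)} exceeds max_length={max_length}"
    )

ids = list(range(10))
kept_left = apply_truncation(ids, 4, "left")    # [6, 7, 8, 9]
kept_right = apply_truncation(ids, 4, "right")  # [0, 1, 2, 3]
```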

Outputs

Output Type Shape Description
input_ids torch.Tensor [max_length] Token IDs for the concatenated prompt + response, padded to max_length
attention_mask torch.Tensor [max_length] Binary mask: 1 for real tokens, 0 for padding
position_ids torch.Tensor [max_length] Position indices computed from attention mask
loss_mask torch.Tensor [max_length] Binary mask: 1 for response tokens included in loss, 0 for prompt/padding tokens
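The table notes that position_ids are computed from the attention mask. One standard way to do this, consistent with the shapes above, is a clamped cumulative sum; the sketch below illustrates the idea and is not a claim about verl's exact internals.

```python
import torch

def position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Running count of real tokens; clamping keeps padding positions at 0.
    return torch.clamp(attention_mask.cumsum(dim=-1) - 1, min=0)

mask = torch.tensor([1, 1, 1, 1, 0, 0])  # 4 real tokens, then right padding
pos = position_ids_from_mask(mask)       # tensor([0, 1, 2, 3, 3, 3])
```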

Usage Examples

Example 1: Basic SFTDataset initialization

from verl.utils.dataset.sft_dataset import SFTDataset
from transformers import AutoTokenizer
from omegaconf import OmegaConf

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
config = OmegaConf.create({
    "prompt_key": "prompt",
    "response_key": "response",
    "max_length": 2048,
    "truncation": "right",
})

dataset = SFTDataset(
    parquet_files="~/data/sft/train.parquet",
    tokenizer=tokenizer,
    config=config,
)

print(f"Dataset size: {len(dataset)}")
sample = dataset[0]
print(f"input_ids shape: {sample['input_ids'].shape}")   # torch.Size([2048])
print(f"loss_mask shape: {sample['loss_mask'].shape}")     # torch.Size([2048])

Example 2: Accessing a single item

item = dataset[42]
# item is a dict with keys: input_ids, attention_mask, position_ids, loss_mask
# All tensors have shape [max_length]
# loss_mask is 0 for prompt tokens and 1 for response tokens (excluding padding)
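Because loss_mask is 1 exactly on response tokens, it doubles as a selector for recovering the response portion of a sample. The tensors below are synthetic stand-ins for one dataset item (real items come from dataset[i]).

```python
import torch

# Synthetic stand-in for one item; token values are arbitrary.
input_ids = torch.tensor([101, 11, 12, 13, 14, 0, 0])
loss_mask = torch.tensor([0,   0,  1,  1,  1, 0, 0])

# Boolean indexing keeps exactly the tokens that contribute to the loss.
response_ids = input_ids[loss_mask.bool()]  # tensor([12, 13, 14])
```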

Example 3: Using with DataLoader

from torch.utils.data import DataLoader, DistributedSampler

# DistributedSampler requires an initialized torch.distributed process group;
# for single-process use, pass shuffle=True to DataLoader instead.
sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)

for batch in dataloader:
    # batch["input_ids"].shape: [4, max_length]
    # batch["loss_mask"].shape: [4, max_length]
    pass
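A training loop typically consumes loss_mask by averaging the per-token cross-entropy over response positions only. The function below is one common way to do this, not verl's trainer code; the next-token shift assumes the usual causal-LM convention.

```python
import torch
import torch.nn.functional as F

def masked_sft_loss(logits, input_ids, loss_mask):
    # Shift so the logit at position t predicts the token at t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = loss_mask[:, 1:].float()
    per_token = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    ).reshape(shift_labels.shape)
    # Average only over response tokens; clamp avoids division by zero.
    return (per_token * shift_mask).sum() / shift_mask.sum().clamp(min=1)

# Synthetic batch standing in for model output and a dataloader batch.
batch_size, max_length, vocab = 2, 8, 32
logits = torch.randn(batch_size, max_length, vocab)
input_ids = torch.randint(0, vocab, (batch_size, max_length))
loss_mask = torch.zeros(batch_size, max_length, dtype=torch.long)
loss_mask[:, 3:6] = 1  # pretend positions 3-5 are response tokens
loss = masked_sft_loss(logits, input_ids, loss_mask)
```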
