Implementation:Volcengine_Verl_SFTDataset
| Field | Value |
|---|---|
| Knowledge Sources | API Doc (verl dataset utilities) |
| Domains | Supervised Fine-Tuning, Dataset Loading, Tokenization |
| Last Updated | 2026-02-07 |
Overview
Description
The SFTDataset class is an in-memory PyTorch Dataset for supervised fine-tuning (SFT) in verl. It reads one or more Parquet files, extracts prompt and response columns, applies chat template formatting, tokenizes the text, and returns padded/truncated tensors with a loss mask that excludes the prompt tokens from the training loss.
The class supports configurable prompt and response keys (including nested dict access via prompt_dict_keys and response_dict_keys), three truncation strategies ("error", "left", "right"), optional random subsampling via max_samples, and optional shared memory usage for faster I/O. Each __getitem__ call returns a dictionary with input_ids, attention_mask, position_ids, and loss_mask, all as tensors of length max_length.
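The loss-mask behaviour described above can be sketched as a minimal, self-contained illustration of masking prompt and padding tokens out of the loss (the helper name and lengths here are illustrative, not verl's API):

```python
import torch

def build_loss_mask(prompt_len: int, total_len: int, max_length: int) -> torch.Tensor:
    # 1 for response tokens, 0 for prompt and padding tokens,
    # mirroring the behaviour described above (illustrative helper, not verl's API)
    mask = torch.zeros(max_length, dtype=torch.long)
    mask[prompt_len:total_len] = 1
    return mask

# 3 prompt tokens, 4 response tokens, padded to length 10
mask = build_loss_mask(prompt_len=3, total_len=7, max_length=10)
print(mask.tolist())  # [0, 0, 0, 1, 1, 1, 1, 0, 0, 0]
```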
Usage
```python
from verl.utils.dataset.sft_dataset import SFTDataset

dataset = SFTDataset(
    parquet_files="~/data/sft/train.parquet",
    tokenizer=tokenizer,
    config=data_config,
    max_samples=1000,
)
```
Code Reference
| Attribute | Detail |
|---|---|
| Source Location | verl/utils/dataset/sft_dataset.py, lines 33-204 |
| Signature | class SFTDataset(Dataset) |
| Constructor | __init__(self, parquet_files: str \| ListConfig, tokenizer, config, max_samples: int = -1) |
| Import | from verl.utils.dataset.sft_dataset import SFTDataset |
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| parquet_files | str or ListConfig | Path(s) to Parquet files containing the SFT data |
| tokenizer | PreTrainedTokenizer or str | HuggingFace tokenizer instance or path to a tokenizer |
| config | OmegaConf | Configuration object with dataset parameters |
| config.prompt_key | str | Column name for prompt data (default: "prompt") |
| config.response_key | str | Column name for response data (default: "response") |
| config.max_length | int | Maximum sequence length for padding/truncation (default: 1024) |
| config.truncation | str | Truncation strategy: "error", "left", or "right" (default: "error") |
| max_samples | int | Maximum number of samples to use; -1 means all (default: -1) |
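For the nested-access options mentioned in the overview (prompt_dict_keys / response_dict_keys), the lookup can be pictured as plain dict traversal; the row and key list below are hypothetical examples, not verl's internal code:

```python
# Hypothetical Parquet row whose column holds a nested dict
row = {"data": {"prompt": "What is 2+2?", "response": "4"}}

# prompt_dict_keys-style traversal: follow each key into the nested structure
keys = ["data", "prompt"]
value = row
for k in keys:
    value = value[k]

print(value)  # What is 2+2?
```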
Outputs
| Output | Type | Shape | Description |
|---|---|---|---|
| input_ids | torch.Tensor | [max_length] | Token IDs for the concatenated prompt + response, padded to max_length |
| attention_mask | torch.Tensor | [max_length] | Binary mask: 1 for real tokens, 0 for padding |
| position_ids | torch.Tensor | [max_length] | Position indices computed from the attention mask |
| loss_mask | torch.Tensor | [max_length] | Binary mask: 1 for response tokens included in the loss, 0 for prompt and padding tokens |
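A common way to derive position_ids from an attention mask is a cumulative sum over the mask; the sketch below follows that convention (an assumption for illustration, the exact helper verl uses may differ):

```python
import torch

# Left-padded example: the first two positions are padding
attention_mask = torch.tensor([0, 0, 1, 1, 1])

# Cumulative sum counts only real tokens; clamping keeps padding positions at 0
position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
print(position_ids.tolist())  # [0, 0, 0, 1, 2]
```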
Usage Examples
Example 1: Basic SFTDataset initialization
```python
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from verl.utils.dataset.sft_dataset import SFTDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
config = OmegaConf.create({
    "prompt_key": "prompt",
    "response_key": "response",
    "max_length": 2048,
    "truncation": "right",
})

dataset = SFTDataset(
    parquet_files="~/data/sft/train.parquet",
    tokenizer=tokenizer,
    config=config,
)

print(f"Dataset size: {len(dataset)}")
sample = dataset[0]
print(f"input_ids shape: {sample['input_ids'].shape}")  # torch.Size([2048])
print(f"loss_mask shape: {sample['loss_mask'].shape}")  # torch.Size([2048])
```
Example 2: Accessing a single item
```python
item = dataset[42]
# item is a dict with keys: input_ids, attention_mask, position_ids, loss_mask
# All tensors have shape [max_length]
# loss_mask is 0 for prompt tokens and 1 for response tokens (excluding padding)
```
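Because loss_mask marks exactly the response tokens, it can be used to slice them back out of input_ids; the tensors below are dummy stand-ins for a real item:

```python
import torch

# Dummy item mimicking the dict returned by __getitem__ (values are made up)
item = {
    "input_ids": torch.tensor([101, 5, 6, 7, 8, 0, 0]),
    "loss_mask": torch.tensor([0, 0, 1, 1, 1, 0, 0]),
}

# Boolean indexing keeps only the tokens that contribute to the training loss
response_ids = item["input_ids"][item["loss_mask"].bool()]
print(response_ids.tolist())  # [6, 7, 8]
```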
Example 3: Using with DataLoader
```python
from torch.utils.data import DataLoader, DistributedSampler

sampler = DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    sampler=sampler,
    num_workers=8,
    pin_memory=True,
)

for batch in dataloader:
    # batch["input_ids"].shape: [4, max_length]
    # batch["loss_mask"].shape: [4, max_length]
    pass
```
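To show how a trainer typically consumes loss_mask, here is a hedged sketch of prompt-masked causal-LM cross-entropy over such a batch (toy tensors stand in for real model outputs; this is not the actual trainer code):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab = 2, 6, 11
torch.manual_seed(0)
logits = torch.randn(batch_size, seq_len, vocab)            # stand-in for model output
input_ids = torch.randint(0, vocab, (batch_size, seq_len))  # stand-in batch
loss_mask = torch.tensor([[0, 1, 1, 1, 0, 0],
                          [0, 0, 1, 1, 1, 0]])

# Shift by one: the token at position t is predicted from the logits at t-1
shift_logits = logits[:, :-1, :]
shift_labels = input_ids[:, 1:]
shift_mask = loss_mask[:, 1:].float()

per_token = F.cross_entropy(
    shift_logits.reshape(-1, vocab), shift_labels.reshape(-1), reduction="none"
).reshape(batch_size, -1)

# Average only over response positions, so prompt and padding never affect the loss
loss = (per_token * shift_mask).sum() / shift_mask.sum()
```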
Related Pages
- Principle:Volcengine_Verl_SFT_Data_Preparation
- verl/utils/dataset/sft_dataset.py -- Source file
- Implementation:Volcengine_Verl_FSDPSFTTrainer_Fit -- Trainer that consumes SFTDataset
- Implementation:Volcengine_Verl_Dataset_To_Parquet -- Upstream Parquet export