Implementation:Volcengine Verl FSDPSFTTrainer Fit
| Field | Value |
|---|---|
| Knowledge Sources | API Doc (verl trainer) |
| Domains | Distributed Training, FSDP, Supervised Fine-Tuning, Gradient Accumulation |
| Last Updated | 2026-02-07 |
Overview
Description
The FSDPSFTTrainer class is a lightweight, single-file FSDP-based SFT trainer for verl. It manages the complete training lifecycle: model construction (with optional LoRA), FSDP wrapping (supporting both FSDP1 and FSDP2 strategies), optimizer and learning rate scheduler setup, distributed data loading, gradient accumulation via micro-batching, validation, checkpointing, and experiment tracking.
The fit() method orchestrates the outer training loop across epochs and steps. For each step, it delegates to training_step(batch), which splits the batch into micro-batches of size config.data.micro_batch_size_per_gpu, computes the forward pass and loss for each micro-batch, accumulates gradients, clips gradients via clip_grad_norm_, and performs the optimizer step. The loss computation in _compute_loss_and_backward() supports both standard and sequence-parallel modes (via Ulysses sequence parallelism with remove-padding).
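The batch-splitting arithmetic behind this can be sketched in plain Python. `micro_batching_plan` is a hypothetical helper (not part of verl), and the DP size of 4 is an assumed value chosen to match the Usage example below.

```python
# Hypothetical sketch of the micro-batching arithmetic used for gradient
# accumulation; names are illustrative, not the verl API.

def micro_batching_plan(train_batch_size, dp_size, micro_batch_size_per_gpu):
    """Return (per-rank batch size, number of micro-batches per step)."""
    assert train_batch_size % dp_size == 0, "global batch must divide across DP ranks"
    per_rank = train_batch_size // dp_size
    assert per_rank % micro_batch_size_per_gpu == 0, "per-rank batch must split evenly"
    n_micro = per_rank // micro_batch_size_per_gpu
    return per_rank, n_micro

# With train_batch_size=64, an assumed 4 DP ranks, and micro_batch_size_per_gpu=4:
per_rank, n_micro = micro_batching_plan(64, 4, 4)

# Each micro-batch loss is scaled by 1 / n_micro so the accumulated gradient
# matches a single full-batch backward pass.
loss_scale = 1.0 / n_micro
```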
The trainer supports checkpoint resumption via StatefulDataLoader, enabling training to resume from any saved step without losing dataloader state.
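The resume pattern can be illustrated with a toy stand-in: `ToyLoader` below is hypothetical and only mimics the `state_dict()`/`load_state_dict()` protocol that torchdata's `StatefulDataLoader` exposes, which is what makes exact mid-epoch resumption possible.

```python
# Minimal sketch of the resume pattern; ToyLoader stands in for
# StatefulDataLoader and is NOT the real class.

class ToyLoader:
    def __init__(self, data):
        self.data, self.pos = data, 0

    def __iter__(self):
        while self.pos < len(self.data):
            item = self.data[self.pos]
            self.pos += 1
            yield item

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]

loader = ToyLoader([10, 20, 30, 40])
it = iter(loader)
next(it); next(it)              # consume two batches, then "crash"
ckpt = loader.state_dict()      # saved alongside model/optimizer state

resumed = ToyLoader([10, 20, 30, 40])
resumed.load_state_dict(ckpt)   # resume mid-epoch; no batch is repeated
remaining = list(iter(resumed))
```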
Usage
```shell
python -m verl.trainer.fsdp_sft_trainer \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    data.train_files=~/data/sft/train.parquet \
    data.val_files=~/data/sft/test.parquet \
    data.train_batch_size=64 \
    data.micro_batch_size_per_gpu=4 \
    optim.lr=1e-5 \
    optim.clip_grad=1.0 \
    model.strategy=fsdp
```
Code Reference
| Attribute | Detail |
|---|---|
| Source Location | `verl/trainer/fsdp_sft_trainer.py`, lines 96-804 |
| Class | `FSDPSFTTrainer` |
| Constructor | `FSDPSFTTrainer(config, device_mesh, ulysses_device_mesh, tokenizer, train_dataset, val_dataset)` |
| Key Methods | `fit()`, `training_step(batch)`, `validation_step(batch)`, `save_checkpoint(step)` |
| Import | `from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer` |
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| `config` | OmegaConf `DictConfig` | Full training configuration |
| `config.data.train_batch_size` | int | Global training batch size (divided by DP size internally) |
| `config.data.micro_batch_size_per_gpu` | int | Micro-batch size per GPU for gradient accumulation |
| `config.optim.lr` | float | Peak learning rate |
| `config.optim.clip_grad` | float | Maximum gradient norm for clipping |
| `config.optim.lr_warmup_steps_ratio` | float | Fraction of total steps for LR warmup |
| `config.optim.lr_scheduler` | str | LR scheduler type: `"cosine"` or `"wsd"` |
| `config.model.strategy` | str | FSDP strategy: `"fsdp"` (FSDP1) or `"fsdp2"` (FSDP2) |
| `config.trainer.total_epochs` | int | Number of training epochs |
| `config.trainer.save_freq` | int | Save a checkpoint every N steps |
| `config.trainer.test_freq` | int | Run validation every N steps |
| `device_mesh` | `DeviceMesh` | PyTorch distributed device mesh |
| `ulysses_device_mesh` | `DeviceMesh` | Device mesh for Ulysses sequence parallelism |
| `tokenizer` | `PreTrainedTokenizer` | HuggingFace tokenizer |
| `train_dataset` | `Dataset` | Training dataset (e.g., `SFTDataset`) |
| `val_dataset` | `Dataset` | Validation dataset |
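As a rough sketch, the nested keys above could be assembled as a plain dict and wrapped with `OmegaConf.create(...)` to produce the `DictConfig` the trainer expects. Values follow the Usage example; entries marked as assumed are illustrative defaults, not documented values.

```python
# Hedged sketch of the config layout; wrap with OmegaConf.create(config)
# before passing to FSDPSFTTrainer.
config = {
    "data": {
        "train_files": "~/data/sft/train.parquet",
        "val_files": "~/data/sft/test.parquet",
        "train_batch_size": 64,
        "micro_batch_size_per_gpu": 4,
    },
    "optim": {
        "lr": 1e-5,
        "clip_grad": 1.0,
        "lr_warmup_steps_ratio": 0.1,  # assumed value for illustration
        "lr_scheduler": "cosine",
    },
    "model": {
        "partial_pretrain": "Qwen/Qwen2.5-0.5B-Instruct",
        "strategy": "fsdp",
    },
    "trainer": {
        "total_epochs": 1,  # assumed
        "save_freq": 100,   # assumed
        "test_freq": 100,   # assumed
    },
}
```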
Outputs
| Output | Type | Description |
|---|---|---|
| `training_step` return | dict | `{"train/loss": float, "train/lr(1e-3)": float, "train/time(s)": float}` |
| `validation_step` return | `torch.Tensor` | Scalar validation loss averaged across DP ranks |
| Side effect | Checkpoints | Model checkpoints saved in HuggingFace format at `config.trainer.default_local_dir` |
| Side effect | Tracking logs | Training and validation metrics logged via `Tracking` (wandb, tensorboard, etc.) |
Usage Examples
Example 1: Instantiate and run the trainer
```python
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
from torch.distributed.device_mesh import init_device_mesh

# world_size = dp_size * sp_size; all ranks join a single FSDP sharding group
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
ulysses_device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_size, sp_size),
    mesh_dim_names=("dp", "sp"),
)

trainer = FSDPSFTTrainer(
    config=config,  # OmegaConf DictConfig (see I/O Contract)
    device_mesh=device_mesh,
    ulysses_device_mesh=ulysses_device_mesh,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)
trainer.fit()
```
Example 2: training_step internals
```python
def training_step(self, batch):
    self.fsdp_model.train()
    self.optimizer.zero_grad()

    # Split the per-rank batch into micro-batches for gradient accumulation.
    micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
    n_micro_batches = len(micro_batches)
    step_loss = 0
    for micro_batch in micro_batches:
        # _compute_loss_and_backward scales the loss by 1 / n_micro_batches
        # so the accumulated gradient matches a full-batch backward pass.
        loss = self._compute_loss_and_backward(
            batch=micro_batch, n_micro_batches=n_micro_batches
        )
        step_loss += loss.item()

    # Clip the accumulated gradient; skip the update if the norm is non-finite.
    grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
    if torch.isfinite(grad_norm):
        self.optimizer.step()
        self.lr_scheduler.step()
    return {"train/loss": step_loss, ...}  # plus "train/lr(1e-3)" and "train/time(s)"
```
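The `clip_grad_norm_` call above bounds the global gradient norm before the optimizer step. A minimal pure-Python sketch of that scaling rule follows, assuming a flat list of scalar gradients; the real FSDP implementation first reduces the norm across shards before scaling.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale grads so their global L2 norm is at most max_norm.

    Sketch of the rule behind clip_grad_norm_; not the FSDP implementation.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        grads = [g * scale for g in grads]
    return grads, total_norm

# A gradient of norm 5.0 gets rescaled to (approximately) norm 1.0.
clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
```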
Example 3: Using the convenience run_sft function
```python
from verl.trainer.fsdp_sft_trainer import run_sft
from omegaconf import OmegaConf

config = OmegaConf.load("config/sft_trainer.yaml")
# run_sft handles device mesh init, dataset creation, trainer init, and fit()
run_sft(config)
```
Related Pages
- Principle:Volcengine_Verl_FSDP_Distributed_Training
- Environment:Volcengine_Verl_CUDA_GPU_Environment
- Environment:Volcengine_Verl_Python_Core_Dependencies
- Heuristic:Volcengine_Verl_FSDP_Mixed_Precision_Init
- verl/trainer/fsdp_sft_trainer.py -- Source file
- Implementation:Volcengine_Verl_SFTDataset -- Dataset consumed by this trainer
- Implementation:Volcengine_Verl_Get_Peft_Model_LoRA -- LoRA integration within model building
- Implementation:Volcengine_Verl_FSDPSFTTrainer_Save_Checkpoint -- Checkpoint saving details