Implementation:Volcengine Verl FSDPSFTTrainer Fit

From Leeroopedia


Field Value
Knowledge Sources API Doc (verl trainer)
Domains Distributed Training, FSDP, Supervised Fine-Tuning, Gradient Accumulation
Last Updated 2026-02-07

Overview

Description

The FSDPSFTTrainer class is a lightweight, single-file FSDP-based SFT trainer for verl. It manages the complete training lifecycle: model construction (with optional LoRA), FSDP wrapping (supporting both FSDP1 and FSDP2 strategies), optimizer and learning rate scheduler setup, distributed data loading, gradient accumulation via micro-batching, validation, checkpointing, and experiment tracking.

The fit() method orchestrates the outer training loop across epochs and steps. For each step, it delegates to training_step(batch), which splits the batch into micro-batches of size config.data.micro_batch_size_per_gpu, computes the forward pass and loss for each micro-batch, accumulates gradients, clips gradients via clip_grad_norm_, and performs the optimizer step. The loss computation in _compute_loss_and_backward() supports both standard and sequence-parallel modes (via Ulysses sequence parallelism with remove-padding).
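
To make the micro-batching arithmetic concrete, here is a small worked sketch (the variable names are illustrative, not part of the verl API):

train_batch_size = 64          # config.data.train_batch_size (global)
dp_size = 8                    # data-parallel world size
micro_batch_size_per_gpu = 4   # config.data.micro_batch_size_per_gpu

batch_size_per_gpu = train_batch_size // dp_size                   # 8 samples per rank
n_micro_batches = batch_size_per_gpu // micro_batch_size_per_gpu   # 2 accumulation steps per optimizer step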

The trainer supports checkpoint resumption via StatefulDataLoader, enabling training to resume from any saved step without losing dataloader state.
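
A minimal sketch of the resume mechanism, assuming torchdata's StatefulDataLoader (the trainer's internal attribute names may differ):

from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(train_dataset, batch_size=8)

# At checkpoint time: persist the dataloader position alongside model state.
dl_state = loader.state_dict()

# On resume: rebuild the loader and restore its position, so no samples
# are skipped or repeated mid-epoch.
resumed_loader = StatefulDataLoader(train_dataset, batch_size=8)
resumed_loader.load_state_dict(dl_state)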

Usage

python -m verl.trainer.fsdp_sft_trainer \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    data.train_files=~/data/sft/train.parquet \
    data.val_files=~/data/sft/test.parquet \
    data.train_batch_size=64 \
    data.micro_batch_size_per_gpu=4 \
    optim.lr=1e-5 \
    optim.clip_grad=1.0 \
    model.strategy=fsdp

Code Reference

Attribute Detail
Source Location verl/trainer/fsdp_sft_trainer.py, Lines 96-804
Class FSDPSFTTrainer
Constructor FSDPSFTTrainer(config, device_mesh, ulysses_device_mesh, tokenizer, train_dataset, val_dataset)
Key Methods fit(), training_step(batch), validation_step(batch), save_checkpoint(step)
Import from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer

I/O Contract

Inputs

Parameter Type Description
config OmegaConf DictConfig Full training configuration
config.data.train_batch_size int Global training batch size (divided by DP size internally)
config.data.micro_batch_size_per_gpu int Micro-batch size per GPU for gradient accumulation
config.optim.lr float Peak learning rate
config.optim.clip_grad float Maximum gradient norm for clipping
config.optim.lr_warmup_steps_ratio float Fraction of total steps for LR warmup
config.optim.lr_scheduler str LR scheduler type: "cosine" or "wsd"
config.model.strategy str FSDP strategy: "fsdp" (FSDP1) or "fsdp2" (FSDP2)
config.trainer.total_epochs int Number of training epochs
config.trainer.save_freq int Save checkpoint every N steps
config.trainer.test_freq int Run validation every N steps
device_mesh DeviceMesh PyTorch distributed device mesh
ulysses_device_mesh DeviceMesh Device mesh for Ulysses sequence parallelism
tokenizer PreTrainedTokenizer HuggingFace tokenizer
train_dataset Dataset Training dataset (e.g., SFTDataset)
val_dataset Dataset Validation dataset
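
A partial sketch of a matching config built with OmegaConf (keys mirror the table above; values are illustrative, and a real run needs additional keys such as model.partial_pretrain and data.train_files):

from omegaconf import OmegaConf

config = OmegaConf.create({
    "data": {"train_batch_size": 64, "micro_batch_size_per_gpu": 4},
    "optim": {"lr": 1e-5, "clip_grad": 1.0,
              "lr_warmup_steps_ratio": 0.1, "lr_scheduler": "cosine"},
    "model": {"strategy": "fsdp"},
    "trainer": {"total_epochs": 2, "save_freq": 100, "test_freq": 100},
})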

Outputs

Output Type Description
training_step return dict {"train/loss": float, "train/lr(1e-3)": float, "train/time(s)": float}
validation_step return torch.Tensor Scalar validation loss averaged across DP ranks
Side effect Checkpoints Model checkpoints saved in HuggingFace format at config.trainer.default_local_dir
Side effect Tracking logs Training and validation metrics logged via Tracking (wandb, tensorboard, etc.)
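
Because checkpoints are written in HuggingFace format, a saved step can be reloaded with standard transformers APIs (the directory name below is illustrative; the exact layout under config.trainer.default_local_dir may vary):

from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "outputs/sft/global_step_100"  # hypothetical path under default_local_dir
model = AutoModelForCausalLM.from_pretrained(ckpt_dir)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)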

Usage Examples

Example 1: Instantiate and run the trainer

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer

world_size = dist.get_world_size()  # assumes torch.distributed is already initialized
sp_size = 1                         # Ulysses sequence-parallel size (1 disables SP)
dp_size = world_size // sp_size

# One 1-D mesh for FSDP sharding, plus a (dp, sp) mesh for Ulysses sequence parallelism.
device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
ulysses_device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(dp_size, sp_size),
    mesh_dim_names=("dp", "sp"),
)

# config, tokenizer, and the datasets are assumed to be prepared elsewhere.

trainer = FSDPSFTTrainer(
    config=config,
    device_mesh=device_mesh,
    ulysses_device_mesh=ulysses_device_mesh,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)
trainer.fit()
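
Note that the Ulysses mesh dimensions must multiply to the total world size (dp_size * sp_size == world_size); with sequence parallelism disabled, sp_size is 1 and the dp dimension coincides with the FSDP mesh.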

Example 2: training_step internals

def training_step(self, batch):
    self.fsdp_model.train()
    self.optimizer.zero_grad()

    # Split the per-rank batch into micro-batches for gradient accumulation.
    micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu)
    n_micro_batches = len(micro_batches)
    step_loss = 0.0

    for micro_batch in micro_batches:
        # Forward + backward for one micro-batch; gradients accumulate in place.
        loss = self._compute_loss_and_backward(
            batch=micro_batch, n_micro_batches=n_micro_batches
        )
        step_loss += loss.item()

    # Clip the accumulated gradients, then step only if the norm is finite.
    grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
    if torch.isfinite(grad_norm):
        self.optimizer.step()
    self.lr_scheduler.step()

    return {"train/loss": step_loss}  # lr and timing metrics elided here

Example 3: Using the convenience run_sft function

from verl.trainer.fsdp_sft_trainer import run_sft
from omegaconf import OmegaConf

config = OmegaConf.load("config/sft_trainer.yaml")
run_sft(config)
# This handles device mesh init, dataset creation, trainer init, and fit()
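
The python -m verl.trainer.fsdp_sft_trainer entry point shown under Usage drives this same run_sft path, so the CLI and programmatic invocations are interchangeable.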
