Implementation:Deepspeedai DeepSpeed Initialize For SFT
Overview
Concrete tool, provided by the DeepSpeed library, for initializing a DeepSpeed engine for the supervised fine-tuning (SFT) phase of the RLHF pipeline.
Description
Uses deepspeed.initialize() with a standard DeepSpeedEngine (not the hybrid engine) for the SFT phase. Configuration typically uses ZeRO Stage 2 or 3 with fp16 or bf16 mixed precision. The SFT model checkpoint produced at the end of this phase serves as the starting point for both the actor model (Step 3, RLHF) and the reward model (Step 2).
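The configuration choices described above can be sketched as a small helper. This is only an illustration, not code from DeepSpeed: `build_sft_config` and its default values are hypothetical, while the keys it emits (`zero_optimization.stage`, `fp16.enabled`, `bf16.enabled`, `train_micro_batch_size_per_gpu`, `gradient_accumulation_steps`) are standard DeepSpeed configuration keys.

```python
def build_sft_config(zero_stage=2, precision="bf16",
                     micro_batch_size=4, grad_accum_steps=1):
    """Return a DeepSpeed config dict for SFT.

    Hypothetical helper; the default values are illustrative assumptions.
    """
    if zero_stage not in (2, 3):
        raise ValueError("this sketch only covers ZeRO stage 2 or 3")
    if precision not in ("fp16", "bf16"):
        raise ValueError("precision must be 'fp16' or 'bf16'")
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accum_steps,
        "zero_optimization": {"stage": zero_stage},
        # Enable exactly one mixed-precision mode.
        "fp16": {"enabled": precision == "fp16"},
        "bf16": {"enabled": precision == "bf16"},
    }
```

The resulting dict can be passed as the `config` argument of `deepspeed.initialize()`.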
When hybrid_engine.enabled is not set (or set to False) in the DeepSpeed configuration, the initialize() function constructs a standard DeepSpeedEngine instance. This engine handles distributed data parallelism, ZeRO optimizer state partitioning, gradient communication, mixed-precision scaling, and learning rate scheduling. The SFT training loop then follows the standard pattern of calling the engine for the forward pass, engine.backward(loss) for gradient computation, and engine.step() for parameter updates.
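The engine-selection rule above can be illustrated with a check on a config dict. This is a minimal sketch that mirrors only the `hybrid_engine.enabled` condition described here; `selects_standard_engine` is a hypothetical helper, not a DeepSpeed API, and it assumes the configuration is supplied as a Python dict rather than a JSON file path.

```python
def selects_standard_engine(ds_config):
    """Return True when hybrid_engine.enabled is absent or falsy.

    Hypothetical helper: it inspects a config dict and does not call DeepSpeed.
    """
    hybrid = ds_config.get("hybrid_engine", {})
    return not hybrid.get("enabled", False)
```

For the SFT phase this check is expected to return True, since the hybrid engine is not used.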
Code Reference
| Property | Value |
|---|---|
| Repository | https://github.com/deepspeedai/DeepSpeed |
| File | deepspeed/__init__.py (L80-252), deepspeed/runtime/engine.py (L206-420, DeepSpeedEngine.__init__) |
| Signature | def initialize(args=None, model=None, optimizer=None, model_parameters=None, training_data=None, lr_scheduler=None, distributed_port=29500, mpu=None, dist_init_required=None, collate_fn=None, config=None, mesh_param=None, config_params=None) -> Tuple[DeepSpeedEngine, Optimizer, DataLoader, LRScheduler] |
| Import | import deepspeed |
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | torch.nn.Module | Yes | Pretrained language model to fine-tune |
| config | dict or str | Yes | DeepSpeed configuration (a dict or a path to a JSON file) with ZeRO settings |
| training_data | torch.utils.data.Dataset | No | Optional training dataset for DataLoader creation |
| model_parameters | iterable | No | Parameters to optimize (defaults to all model parameters) |
| optimizer | Optimizer | No | Custom optimizer (otherwise created from config) |
| lr_scheduler | LRScheduler | No | Custom learning rate scheduler |
Outputs
| Name | Type | Description |
|---|---|---|
| engine | DeepSpeedEngine | Wrapped model for distributed SFT training |
| optimizer | Optimizer | Wrapped optimizer instance |
| dataloader | DataLoader | DataLoader if training_data was provided, otherwise None |
| lr_scheduler | LRScheduler | Learning rate scheduler if configured, otherwise None |
Usage Example
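import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

sft_config = {
    "train_batch_size": 32,
    "zero_optimization": {"stage": 2},
    "bf16": {"enabled": True},
    # No optimizer is passed to initialize(), so one is defined here.
    # The learning rate is an illustrative assumption.
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    config=sft_config,
    model_parameters=model.parameters(),
)

# Standard SFT training loop.
# sft_dataloader is assumed to be defined elsewhere and to yield dicts of
# tensors containing input_ids, attention_mask, and labels.
for batch in sft_dataloader:
    batch = {k: v.to(engine.device) for k, v in batch.items()}
    loss = engine(**batch).loss  # forward pass; the model returns the loss when labels are given
    engine.backward(loss)        # gradient computation
    engine.step()                # parameter update

# Save the SFT checkpoint once training has finished.
engine.save_checkpoint("sft_checkpoint/")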
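The `train_batch_size` of 32 in the example is the global batch size. DeepSpeed expects it to equal the per-GPU micro-batch size multiplied by the gradient accumulation steps and the number of data-parallel ranks. The sketch below works through that arithmetic; `micro_batch_per_gpu` is a hypothetical helper, not part of DeepSpeed.

```python
def micro_batch_per_gpu(train_batch_size, world_size, grad_accum_steps=1):
    """Derive the per-GPU micro-batch size from the global batch size.

    Relies on: train_batch_size = micro_batch * grad_accum_steps * world_size.
    """
    divisor = world_size * grad_accum_steps
    if train_batch_size % divisor != 0:
        raise ValueError(
            "train_batch_size must be divisible by world_size * grad_accum_steps"
        )
    return train_batch_size // divisor
```

For example, a global batch of 32 on 8 GPUs without gradient accumulation gives a micro-batch of 4 per GPU, whereas 3 GPUs would not divide the batch evenly and the helper raises an error.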
Last updated: 2026-02-09 00:00 GMT