Environment: Hugging Face TRL DeepSpeed Environment
| Knowledge Sources | Details |
|---|---|
| Domains | Infrastructure, Distributed_Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Optional environment providing deepspeed >= 0.14.4 for ZeRO Stage 1/2/3 distributed training, with sharding of model parameters, gradients, and optimizer states.
Description
This environment provides the DeepSpeed library for memory-efficient distributed training. TRL ships with pre-built Accelerate configuration files for ZeRO Stage 1, 2, and 3. DeepSpeed ZeRO-3 is particularly important for training models that exceed single-GPU VRAM, as it shards model parameters, gradients, and optimizer states across GPUs. TRL includes special handling for ZeRO-3 during generation (gathering parameters for inference) and for reference model preparation.
Usage
Use this environment when training on multiple GPUs and memory-efficient model sharding is needed. It is required when using any of the TRL accelerate configs (trl/accelerate_configs/zero1.yaml, zero2.yaml, or zero3.yaml), and also when device_map=None is used for distributed training with GRPO or DPO.
System Requirements
| Category | Requirement | Notes |
|---|---|---|
| OS | Linux | DeepSpeed has limited non-Linux support |
| Hardware | Multiple NVIDIA GPUs | ZeRO sharding requires multi-GPU setup |
| Python | >= 3.10 | Must match TRL core requirements |
| CUDA | Compatible with PyTorch | DeepSpeed compiles custom CUDA kernels |
Dependencies
System Packages
- `cuda-toolkit` (matching PyTorch CUDA version)
- `libaio-dev` (for async I/O in ZeRO offloading)
Python Packages
- `deepspeed` >= 0.14.4
- `transformers` != 5.1.0 (incompatibility; see transformers#43780)
- `accelerate` >= 1.4.0 (from core)
Credentials
No additional credentials required.
Quick Install
```bash
# Install TRL with DeepSpeed support
pip install "trl[deepspeed]"

# Or install DeepSpeed separately
pip install "deepspeed>=0.14.4"
```
Code Evidence
DeepSpeed availability check from `trl/import_utils.py:30-31`:
```python
def is_deepspeed_available() -> bool:
    return _is_package_available("deepspeed")
```
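The same check can guard DeepSpeed-only code paths in user scripts. A minimal sketch, assuming `is_deepspeed_available` is importable from `trl.import_utils` as the file path above suggests:

```python
# Sketch: branch on DeepSpeed availability, reusing TRL's own check.
from trl.import_utils import is_deepspeed_available  # import path assumed from the evidence above

if is_deepspeed_available():
    import deepspeed  # imported lazily, mirroring TRL's local-import convention

    print(f"DeepSpeed {deepspeed.__version__} detected; ZeRO sharding available")
else:
    print("DeepSpeed not installed; falling back to plain data parallelism")
```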
DeepSpeed version-dependent behavior in `trl/models/utils.py:80-84`:
```python
if Version(deepspeed.__version__) >= Version("0.16.4"):
    # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847
    optimizer_offload._register_deepspeed_module(optimizer_offload.module)
else:
    optimizer_offload._register_hooks_recursively(optimizer_offload.module)
```
ZeRO-3 parameter gathering for generation from `trl/models/utils.py:118-127`:
```python
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
    if not gather_deepspeed3_params:
        yield accelerator.unwrap_model(model)
    else:
        import deepspeed

        with deepspeed.zero.GatheredParameters(model.parameters()):
            remove_hooks(model)
            yield accelerator.unwrap_model(model)
            add_hooks(model)
```
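TRL exposes this logic as a context manager (`unwrap_model_for_generation` in `trl/models/utils.py`). The usage below is a hedged sketch; the import path, signature, and model checkpoint are assumptions:

```python
# Hedged sketch: generate under ZeRO-3 by temporarily gathering sharded parameters.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.models.utils import unwrap_model_for_generation  # import path assumed

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative model choice
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained(model_id))

prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids.to(accelerator.device)
with unwrap_model_for_generation(model, accelerator, gather_deepspeed3_params=True) as unwrapped:
    # Under ZeRO-3, full parameters are materialized on each rank for .generate(),
    # then re-sharded when the context exits.
    output_ids = unwrapped.generate(prompt_ids, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```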
ZeRO-3 bucket sizing from `trl/models/utils.py:233-243`:
```python
if hidden_size is not None and stage == 3:
    config_kwargs.update(
        {
            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
        }
    )
```
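For concreteness, the derived bucket values for a hypothetical `hidden_size` of 4096 (the number is illustrative, not taken from TRL):

```python
# Arithmetic sketch of the ZeRO-3 buckets TRL derives from hidden_size.
hidden_size = 4096  # hypothetical value for illustration
reduce_bucket_size = hidden_size * hidden_size                      # 16_777_216
stage3_param_persistence_threshold = 10 * hidden_size               # 40_960
stage3_prefetch_bucket_size = int(0.9 * hidden_size * hidden_size)  # 15_099_494
print(reduce_bucket_size, stage3_param_persistence_threshold, stage3_prefetch_bucket_size)
```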
Distributed training device_map override from `trl/trainer/grpo_trainer.py:356-357`:
```python
# Distributed training requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
    model_init_kwargs["device_map"] = None
```
Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `Invalidate trace cache @ step 0: expected module 1, but got module 0` | Benign DeepSpeed ZeRO-3 message related to `stage3_prefetch_bucket_size` | Not an error; can be safely ignored |
| `The model optimizer is None` | Unwrapping the model before the first training step | Ensure the model has gone through at least one optimizer step before generation |
| `device_map="auto"` fails in distributed runs | Using `device_map="auto"` with multi-GPU DeepSpeed | TRL automatically sets `device_map=None` for distributed training |
Compatibility Notes
- DeepSpeed >= 0.16.4: uses the renamed `_register_deepspeed_module` method (formerly `_register_hooks_recursively`).
- transformers == 5.1.0: known incompatibility (see transformers#43780); excluded in TRL's optional dependencies.
- ZeRO-3 + vLLM: `ds3_gather_for_generation` must be True (the default) when using vLLM.
- ZeRO-3 + PEFT + gradient checkpointing: fixed in a recent TRL commit (f11b4c3); requires `enable_input_require_grads()`.
- DeepSpeed local import: TRL imports deepspeed locally (not at top level) so that DeepSpeed initialization does not interfere with other backends such as vLLM (see the sketch below).
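The local-import convention from the last note, as a minimal sketch (the function name and structure are illustrative, not TRL's):

```python
# Sketch: defer `import deepspeed` until a ZeRO-3 path is actually taken, so that
# merely importing this module never triggers DeepSpeed initialization.
from contextlib import nullcontext


def maybe_gathered_params(model, zero_stage: int):
    """Return a context manager that gathers ZeRO-3 shards only when needed."""
    if zero_stage != 3:
        return nullcontext()
    import deepspeed  # local import, mirroring TRL's convention

    return deepspeed.zero.GatheredParameters(model.parameters())
```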