# Environment: CarperAI trlx DeepSpeed Multi-GPU
| Knowledge Sources | |
|---|---|
| Domains | Infrastructure, Distributed_Training, Deep_Learning |
| Last Updated | 2026-02-07 16:00 GMT |
## Overview
Multi-GPU Linux environment with DeepSpeed ZeRO Stage 2/3 for training large language models (6B-20B+ parameters) that exceed single-GPU memory capacity.
## Description
This environment extends the base Python_Accelerate environment with DeepSpeed for distributed training across multiple GPUs. It provides ZeRO (Zero Redundancy Optimizer) memory optimization at Stage 2 (optimizer + gradient partitioning) or Stage 3 (full parameter partitioning). CPU offloading of optimizer states and parameters is supported for models that exceed aggregate GPU memory. The environment is orchestrated through HuggingFace Accelerate with DeepSpeed integration and includes pre-configured YAML and JSON config files for common setups.
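For orientation, an Accelerate config selecting DeepSpeed ZeRO-2 with fp16 on 8 GPUs looks roughly like the following sketch (illustrative values; the shipped `zero2-fp16.yaml` may differ in detail):

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 2
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
mixed_precision: fp16
num_machines: 1
num_processes: 8    # one process per GPU
use_cpu: false
```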
## Usage
Use this environment when training models that are too large for a single GPU (typically 6B+ parameters), when running reward model training on GPT-J/GPT-NeoX scale models, or when performing PPO optimization with large policy models. It is the mandatory prerequisite for running the GPTRewardModel training and the Create_Reward_Fn serving pipeline.
## System Requirements
| Category | Requirement | Notes |
|---|---|---|
| OS | Linux | Multi-node requires shared filesystem (NFS/Lustre) |
| Python | 3.9 - 3.11 | Same as base environment |
| Hardware | Multiple NVIDIA GPUs | Configs tested with 8 GPUs (num_processes: 8) |
| CUDA | 11.8 | Must match PyTorch CUDA version |
| GPU Memory | 16GB+ per GPU | A100 40GB/80GB recommended for 6B-20B models |
| Disk | 100GB+ SSD | High IOPS for checkpoint saving and CPU offloading |
| Network | High-bandwidth interconnect | NVLink/InfiniBand for multi-GPU communication |
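The GPU-memory figures above can be sanity-checked with the ZeRO paper's back-of-the-envelope accounting (16ψ bytes per parameter for fp16 weights, fp16 gradients, and fp32 Adam state under mixed precision). The sketch below is a rough estimate only; it ignores activations, communication buffers, and fragmentation:

```python
def zero_mem_gb(params_billion: float, gpus: int, stage: int) -> float:
    """Approximate per-GPU memory (GB) for weights + grads + Adam state.

    Follows the ZeRO paper's mixed-precision accounting: 2*psi fp16 weights,
    2*psi fp16 grads, 12*psi optimizer state. Excludes activations/buffers.
    """
    psi = params_billion * 1e9
    if stage == 0:      # plain data parallelism: everything replicated
        per_gpu = 16 * psi
    elif stage == 1:    # optimizer state partitioned
        per_gpu = 4 * psi + 12 * psi / gpus
    elif stage == 2:    # gradients partitioned as well
        per_gpu = 2 * psi + 14 * psi / gpus
    else:               # stage 3: parameters partitioned too
        per_gpu = 16 * psi / gpus
    return per_gpu / 2**30

# A 6B model on 8 GPUs needs ~21 GB/GPU of state under ZeRO-2, which is
# why 40GB A100s are recommended once activations are added on top.
print(f"{zero_mem_gb(6, 8, 2):.1f} GB")
```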
## Dependencies
### System Packages
- CUDA Toolkit 11.8
- NVIDIA drivers compatible with CUDA 11.8
- NCCL (for multi-GPU communication)
- pdsh (for multi-node launcher, optional)
### Python Packages
- All packages from Python_Accelerate environment
- `deepspeed` >= 0.8.1 (tested with 0.10.1)
### Credentials
Same as Python_Accelerate environment, plus the following environment variables (set by Accelerate or the launcher, not by the user):
- `ACCELERATE_DEEPSPEED_ZERO_STAGE`: Auto-set by Accelerate to indicate ZeRO stage (used by trlx for `synced_gpus` logic)
- `WORLD_SIZE`: Total number of training processes across all nodes
- `LOCAL_RANK`: GPU rank within the current node
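These variables drive trlx's runtime decisions (see Code Evidence). A minimal sketch of how a trainer might consume them, with `dist_settings` as a hypothetical helper name:

```python
import os

def dist_settings(env=os.environ) -> dict:
    """Derive distributed-training settings from launcher-set variables."""
    zero_stage = env.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0")
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    return {
        # ZeRO-3 shards parameters, so generation must run in lock-step
        "synced_gpus": zero_stage == "3",
        # multi-process runs bind each process to its local GPU rank
        "device": local_rank if world_size > 1 else 0,
    }
```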
## Quick Install

```bash
# Install the base trlx environment first (see Python_Accelerate)
pip install git+https://github.com/CarperAI/trlx.git

# DeepSpeed is already a core dependency; verify the installation:
ds_report

# Launch with Accelerate + DeepSpeed ZeRO-2 (fp16)
accelerate launch --config_file configs/accelerate/zero2-fp16.yaml examples/ppo_sentiments.py

# Launch with Accelerate + DeepSpeed ZeRO-3
accelerate launch --config_file configs/accelerate/zero3.yaml examples/ppo_sentiments.py
```
## Code Evidence
DeepSpeed ZeRO-3 detection for synced generation, from `trlx/trainer/accelerate_ppo_trainer.py:89-95`:

```python
generate_kwargs = dict(
    do_sample=True,
    use_cache=True,
    eos_token_id=self.tokenizer.eos_token_id,
    pad_token_id=self.tokenizer.pad_token_id,
    synced_gpus=os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3",
)
```
DeepSpeed plugin configuration extraction, from `trlx/utils/__init__.py:68-78`:

```python
if accelerator.state.deepspeed_plugin is not None:
    ds_plugin = accelerator.state.deepspeed_plugin
    dist_config.update(
        {
            "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps,
            "gradient_clipping": ds_plugin.gradient_clipping,
            "zero_stage": ds_plugin.zero_stage,
            "offload_optimizer_device": ds_plugin.offload_optimizer_device,
            "offload_param_device": ds_plugin.offload_param_device,
        }
    )
```
DeepSpeed fp16 `auto_cast` fix, from `trlx/trainer/accelerate_base_trainer.py:58-61`:

```python
if self.accelerator.state.deepspeed_plugin is not None:
    if "fp16" in self.accelerator.state.deepspeed_plugin.deepspeed_config:
        self.accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["auto_cast"] = False
```
ZeRO-3 conditional logic in the ILQL model, from `trlx/models/modeling_ilql.py:222`:

```python
os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", "0") == "3"
```
Multi-GPU detection, from `trlx/trainer/accelerate_base_trainer.py:63-64`:

```python
if int(os.environ.get("WORLD_SIZE", 1)) > 1:
    self.device = int(os.environ.get("LOCAL_RANK", 0))
```
## Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| Generation hangs with DeepSpeed ZeRO-3 | `synced_gpus` not enabled | Ensure `ACCELERATE_DEEPSPEED_ZERO_STAGE` is set (auto-set by Accelerate) or use ZeRO-3 config |
| `CUDA out of memory` during training | Model too large even with ZeRO-2 | Switch to ZeRO-3 with CPU offloading (`configs/accelerate/zero3.yaml`) |
| Slow training with CPU offloading | Optimizer/param offloading to CPU | Use SSD-backed offloading or reduce offloading scope |
| `fp16 auto_cast` issues | DeepSpeed fp16 conflicts with model casting | trlx auto-disables `auto_cast` in DeepSpeed fp16 config |
| Reward model OOM during PPO | Reward model on same GPU as policy | Use `torch.cuda.device_count() - 1` to place reward model on last GPU |
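The last row's fix can be sketched as a small placement helper (`reward_model_device` is a hypothetical name, not a trlx API; in practice you would pass the resulting device string to `.to()` when loading the reward model):

```python
def reward_model_device(num_gpus: int) -> str:
    """Pick the last visible GPU for the reward model, so it does not
    compete for memory with PPO policy shards on ranks 0..N-2."""
    return f"cuda:{num_gpus - 1}" if num_gpus >= 1 else "cpu"

# e.g. with torch: device = reward_model_device(torch.cuda.device_count())
```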
## Compatibility Notes
- ZeRO Stage 2: Partitions optimizer states and gradients. Good default for 6B-7B models on 8 GPUs. Available in both fp16 (`zero2-fp16.yaml`) and bf16 (`zero2-bf16.yaml`) variants.
- ZeRO Stage 3: Full parameter partitioning. Required for 20B+ models. Enables `zero3_save_16bit_model` for direct checkpoint saving.
- CPU Offloading: Supported for both optimizer and parameter states. Configured in DeepSpeed JSON configs (e.g., `ds_config_trlx_gptj_summarize.json` enables both).
- synced_gpus: Must be `True` for generation with ZeRO-3 to prevent deadlocks. trlx auto-detects this from `ACCELERATE_DEEPSPEED_ZERO_STAGE` environment variable.
- DDP (non-DeepSpeed): Also supported via `configs/accelerate/ddp.yaml` for simpler multi-GPU without ZeRO optimization.
- Multi-node: Supported via SLURM launcher (`scripts/slurm_train.sh`) or Accelerate multi-node config (`scripts/accelerate_train_example.sh`).
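For reference, a minimal DeepSpeed JSON in the spirit of the offloading setup above might look like this (values are illustrative, not a verbatim copy of `ds_config_trlx_gptj_summarize.json`; note that DeepSpeed only honors `offload_param` under ZeRO stage 3, so at stage 2 only the optimizer offload takes effect):

```json
{
  "train_micro_batch_size_per_gpu": 1,
  "gradient_accumulation_steps": 4,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": { "device": "cpu", "pin_memory": true },
    "allgather_bucket_size": 5e8,
    "reduce_bucket_size": 5e8
  }
}
```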
## Pre-configured Configs
| Config File | Type | Description |
|---|---|---|
| `configs/accelerate/ddp.yaml` | DDP | Multi-GPU without DeepSpeed, bf16, 8 processes |
| `configs/accelerate/zero2-fp16.yaml` | ZeRO-2 | Stage 2 with fp16, no offloading, 8 processes |
| `configs/accelerate/zero2-bf16.yaml` | ZeRO-2 | Stage 2 with bf16, no offloading, 8 processes |
| `configs/accelerate/zero3.yaml` | ZeRO-3 | Stage 3 with bf16, 16-bit model saving, 8 processes |
| `examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json` | ZeRO-2 | Stage 2 with fp16, CPU offloading for optimizer and params |