Environment:Lm_sys_FastChat_SFT_Training_Environment
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, NLP |
| Last Updated | 2026-02-07 04:00 GMT |
Overview
Multi-GPU CUDA environment with FSDP, Flash Attention 2.0+, and bfloat16 for full supervised fine-tuning of LLaMA-family models.
Description
This environment provides the full-parameter SFT training stack used by FastChat's Vicuna training pipeline. It uses PyTorch's Fully Sharded Data Parallel (FSDP) for multi-GPU distribution, Flash Attention 2.0+ for memory-efficient attention computation (requires CUDA compute capability >= 8.0), and bfloat16 mixed precision with TF32. The training scripts (`train_mem.py`, `train_xformers.py`) are thin wrappers that apply Flash Attention or xformers monkey patches before calling the main `train.py`.
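The monkey-patch pattern used by the wrapper scripts can be sketched as follows. This is a simplified illustration, not FastChat's actual code: `Attention` and `flash_forward` here are hypothetical stand-ins for the stock LLaMA attention module and its Flash Attention replacement.

```python
# Minimal sketch of the monkey-patch pattern used by train_mem.py:
# swap the attention forward for an optimized one *before* the model
# is built, so every instance dispatches to the replacement.

class Attention:
    """Stand-in for the stock LLaMA attention module."""
    def forward(self, x):
        return ("eager", x)

def flash_forward(self, x):
    """Stand-in for the Flash Attention replacement forward."""
    return ("flash", x)

def replace_attn_with_flash_attn():
    # Patch at the class level: instances created afterwards
    # (and existing ones) all pick up the new forward.
    Attention.forward = flash_forward

replace_attn_with_flash_attn()  # must run before model construction
model_attn = Attention()
```

Because the patch rebinds the class attribute rather than an instance attribute, it must be applied before `train.py` builds the model, which is exactly why `train_mem.py` and `train_xformers.py` are thin wrappers.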
Usage
Use this environment for full-parameter SFT training of causal language models. It is required for running `train.py`, `train_mem.py` (Flash Attention variant), and `train_xformers.py` (xformers variant). The reference scripts use 4 GPUs for 7B models and 8 GPUs for 13B models.
System Requirements
| Category | Requirement | Notes |
|---|---|---|
| OS | Linux (Ubuntu 20.04+) | FSDP requires Linux; macOS not supported for training |
| Hardware | 4x NVIDIA A100 40GB (7B model) | 8x A100 for 13B; CUDA compute capability >= 8.0 for Flash Attention |
| CUDA | 11.8+ | Flash Attention 2.0 requires CUDA 11.8+ |
| Disk | 200GB+ SSD | For model weights, checkpoints, and training data |
| RAM | 256GB+ | 13B training with CPU offload uses half of system RAM |
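A back-of-the-envelope estimate makes the RAM figure plausible. The assumptions below are the standard mixed-precision AdamW footprint (bf16 params and grads plus fp32 master weights and two fp32 optimizer moments, ~16 bytes per parameter); actual FSDP and offload behavior adds activation and communication buffers on top.

```python
# Rough memory estimate for full-parameter SFT state.
# Assumed per-parameter footprint (mixed-precision AdamW):
#   bf16 params (2 B) + bf16 grads (2 B)
#   + fp32 master params (4 B) + fp32 Adam m (4 B) + fp32 Adam v (4 B)
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4  # = 16 bytes

def training_state_gib(num_params: float) -> float:
    """Approximate size of params + grads + optimizer state in GiB."""
    return num_params * BYTES_PER_PARAM / 2**30

print(f"7B:  ~{training_state_gib(7e9):.0f} GiB")
print(f"13B: ~{training_state_gib(13e9):.0f} GiB")
```

Under these assumptions a 13B model carries on the order of 190 GiB of training state; with `offload` the fp32 optimizer state lives in host memory, consistent with 13B training consuming roughly half of a 256GB+ system.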
Dependencies
System Packages
- `cuda-toolkit` >= 11.8 — CUDA runtime
- `nccl` — NVIDIA Collective Communications Library (for FSDP)
Python Packages
- `torch` >= 2.0 — PyTorch with CUDA and FSDP support
- `transformers` >= 4.31.0 — Model loading and Trainer
- `flash-attn` >= 2.0 — Flash Attention (requires CUDA CC >= 8.0)
- `einops` — Tensor reshaping operations
- `wandb` — Experiment tracking (optional but recommended)
- `xformers` — Alternative attention optimization (for `train_xformers.py`)
Credentials
- `WANDB_API_KEY`: Weights & Biases API key for experiment logging (optional)
Quick Install
```shell
# Install training dependencies
pip install "fschat[model_worker,train]"

# Flash Attention (requires CUDA 11.8+ and CC >= 8.0)
pip install flash-attn --no-build-isolation

# xformers alternative
pip install xformers
```
Code Evidence
Flash Attention CUDA capability check from `fastchat/train/llama_flash_attn_monkey_patch.py:97-103`:
```python
def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
```
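Note that the check above only warns and continues. A stricter gate can be expressed as a small predicate; `supports_flash_attention` is a hypothetical helper for illustration, not part of FastChat:

```python
def supports_flash_attention(capability: tuple) -> bool:
    """Return True if a (major, minor) CUDA compute capability meets
    Flash Attention's >= 8.0 requirement (Ampere or newer)."""
    major, _minor = capability
    return major >= 8

# A100 is compute capability (8, 0) and H100 is (9, 0), so both pass;
# V100 at (7, 0) fails and should fall back to train_xformers.py.
```

In a real script the tuple would come from `torch.cuda.get_device_capability()`, as in the evidence above.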
FSDP model saving from `fastchat/train/train.py:81-89`:
```python
def trainer_save_model_safe(trainer: transformers.Trainer):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType, FullStateDictConfig

    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(
        trainer.model, StateDictType.FULL_STATE_DICT, save_policy
    ):
        trainer.save_model()
```
RoPE scaling for extended context from `fastchat/train/train.py:271-274`:
```python
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
    scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
    config.rope_scaling = {"type": "linear", "factor": scaling_factor}
```
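A worked example of this computation, with illustrative values (a 4096-token base context, as in LLaMA-2's `max_position_embeddings`, and a requested 16384-token `--model_max_length`):

```python
import math

orig_ctx_len = 4096        # illustrative: config.max_position_embeddings
model_max_length = 16384   # illustrative: requested training context

scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
rope_scaling = {"type": "linear", "factor": scaling_factor}
print(rope_scaling)  # {'type': 'linear', 'factor': 4.0}
```

Linear RoPE scaling with factor 4 compresses position indices so that 16384 positions map into the rotary range the model was pretrained on.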
use_cache disabled during training from `fastchat/train/train.py:275`:
```python
config.use_cache = False
```
Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `Flash attention is only supported on A100 or H100 GPU` | GPU compute capability < 8.0 | Use `train.py` (without Flash Attention) or `train_xformers.py` instead |
| `CUDA out of memory` during training | Batch size too large or no gradient checkpointing | Add `--gradient_checkpointing True` and reduce `--per_device_train_batch_size` |
| `WARNING: tokenization mismatch` | Token offset mismatch in conversation preprocessing | The affected sample's labels are masked out automatically so it does not contribute to the loss; see the -2 offset heuristic in Heuristic:Lm_sys_FastChat_Tokenizer_Offset_Correction |
| FSDP save failure | Incorrect state dict gathering | Use `trainer_save_model_safe()` which gathers full state dict to rank 0 with CPU offload |
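When shrinking `--per_device_train_batch_size` to avoid OOM, the global batch size can be preserved by raising `--gradient_accumulation_steps`, since the optimizer sees their product across GPUs. The concrete numbers below (per-device batch 2, 16 accumulation steps, 4 GPUs) are illustrative, not a guaranteed match for the reference scripts:

```python
def global_batch_size(per_device: int, grad_accum: int, num_gpus: int) -> int:
    """Effective (global) batch size seen by each optimizer step."""
    return per_device * grad_accum * num_gpus

# Halving the per-device batch while doubling accumulation
# keeps the effective batch size unchanged:
print(global_batch_size(2, 16, 4))  # 128
print(global_batch_size(1, 32, 4))  # 128
```

Gradient accumulation trades wall-clock time for memory: activations are held for only one micro-batch at a time.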
Compatibility Notes
- Flash Attention: Requires CUDA compute capability >= 8.0 (A100, H100). For older GPUs (V100, RTX 3090), use `train_xformers.py` with xformers instead.
- FSDP vs DeepSpeed: The SFT scripts use PyTorch FSDP (`--fsdp "full_shard auto_wrap"`). For LoRA training, DeepSpeed is used instead (see Environment:Lm_sys_FastChat_LoRA_QLoRA_Training_Environment).
- 13B models: Require `--fsdp "full_shard auto_wrap offload"` (CPU offload enabled) with 8 GPUs.
- Tokenizer: Uses `use_fast=False` (slow tokenizer) for consistent behavior across models. Padding side is `right`.
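Right-side padding appends pad tokens after the real tokens (left padding, used for batched generation, would prepend them). A minimal illustration with a hypothetical pad token id of 0:

```python
def pad_right(ids, max_len, pad_id=0):
    """Pad a token id sequence on the right to max_len.
    pad_id=0 is a hypothetical pad token id for illustration."""
    return ids + [pad_id] * (max_len - len(ids))

print(pad_right([5, 6, 7], 6))  # [5, 6, 7, 0, 0, 0]
```

With right padding, label masking and the loss computation can assume real tokens occupy a contiguous prefix of each row.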