Heuristic:Haotian liu LLaVA Flash Attention GPU Requirement
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Infrastructure |
| Last Updated | 2026-02-13 23:00 GMT |
Overview
Training with Flash Attention in LLaVA expects an NVIDIA GPU with compute capability >= 8.0, such as an A100 or H100. The requirement comes from a Flash Attention backward-pass limitation for head dimensions > 64.
Description
LLaVA provides two memory-efficient training entry points: `train_mem.py` (Flash Attention 2) and `train_xformers.py` (xformers). Flash Attention substantially reduces attention memory usage and speeds up attention computation, but it has a hardware constraint. According to the warning in LLaVA's code, the backward pass for head dimensions greater than 64 is only supported on A100 (Ampere) or H100 (Hopper) GPUs. The Flash Attention monkey patch checks the GPU compute capability and emits a warning (it does not raise an error) if the major version is below 8. xformers serves as an alternative for older GPUs.
Usage
Apply this heuristic when selecting a training entry point. If your GPU has compute capability >= 8.0 (A100, H100, or newer), use `train_mem.py` for Flash Attention 2. If your GPU is older (e.g., V100 at 7.0 or T4 at 7.5), use `train_xformers.py` for memory-efficient attention, or the standard `train.py`.
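The selection rule can be written as a small decision function. This is a minimal sketch, not part of LLaVA: `select_entry_point` and its `xformers_installed` flag are hypothetical names, and the function only encodes the rule of thumb on this page (it does not inspect which Flash Attention version is installed).

```python
def select_entry_point(capability, xformers_installed=False):
    """Pick a LLaVA training script from a GPU compute capability.

    capability: (major, minor) tuple, as returned by
        torch.cuda.get_device_capability().
    xformers_installed: hypothetical flag; the caller decides whether
        the xformers package is importable.
    """
    major, _minor = capability
    if major >= 8:
        # A100 (8.0), H100 (9.0) and other GPUs with major >= 8: Flash Attention 2
        return "train_mem.py"
    if xformers_installed:
        # Pre-8.0 GPU with xformers available: memory-efficient fallback
        return "train_xformers.py"
    # Pre-8.0 GPU without xformers: standard attention, maximum compatibility
    return "train.py"


if __name__ == "__main__":
    print(select_entry_point((8, 0)))                           # train_mem.py
    print(select_entry_point((7, 0), xformers_installed=True))  # train_xformers.py
    print(select_entry_point((7, 5)))                           # train.py
```

The function takes the capability as an argument rather than querying the GPU itself, so the decision logic can be checked on a machine without CUDA.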
The Insight (Rule of Thumb)
- Action: Check the GPU compute capability before choosing a training script. `torch.cuda.get_device_capability()` returns it as a `(major, minor)` tuple.
- Value: Compute capability >= 8.0 required (A100 = 8.0, H100 = 9.0).
- Trade-off: Flash Attention provides ~2x speedup and ~40% memory reduction versus standard attention. xformers is the fallback for older GPUs but provides less speedup.
- Fallback: Use `train_xformers.py` for memory-efficient training on GPUs with compute capability < 8.0, or standard `train.py` for maximum compatibility.
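The Action and Value bullets can be checked with a short script. This is a sketch under assumptions. `KNOWN_CAPABILITIES` is an illustrative, hand-written table based on NVIDIA's published compute capabilities. `meets_threshold` and `current_device_capability` are hypothetical helpers that mirror the `cuda_major < 8` test in LLaVA's Flash Attention monkey patch. Because that test looks only at the major version, an RTX 3090 (8.6) passes it even though the warning text names only the A100 and H100.

```python
# Illustrative compute capabilities (major, minor) for a few NVIDIA GPUs.
KNOWN_CAPABILITIES = {
    "V100": (7, 0),
    "T4": (7, 5),
    "A100": (8, 0),
    "RTX 3090": (8, 6),
    "H100": (9, 0),
}


def meets_threshold(capability):
    """True when the GPU would NOT trigger the `cuda_major < 8` warning."""
    major, _minor = capability
    return major >= 8


def current_device_capability():
    """Capability of the current CUDA device, or None if PyTorch or CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


if __name__ == "__main__":
    for name, (major, minor) in KNOWN_CAPABILITIES.items():
        print(f"{name:<9} {major}.{minor}  meets >= 8.0: {meets_threshold((major, minor))}")
    print("this machine:", current_device_capability())
```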
Reasoning
LLaMA models use a head dimension of 128, which exceeds the 64 limit that the warning cites for the Flash Attention backward pass. The constraint is enforced by the Flash Attention library itself. LLaVA's guard only approximates it: it tests the major compute capability (`cuda_major < 8`) and warns, so it does not block training. The LLaVA codebase provides three training entry points to accommodate different hardware: `train.py` (standard), `train_mem.py` (Flash Attention 2), and `train_xformers.py` (xformers). The v1.5 training scripts default to `train_mem.py` with Flash Attention 2, assuming A100/H100 hardware.
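The head-dimension arithmetic behind this reasoning is head_dim = hidden_size / num_attention_heads. The sketch below uses the published 7B and 13B LLaMA sizes, which LLaMA-2 (the base of Vicuna v1.5) shares. The dictionary, constant, and helper names are illustrative, not taken from LLaVA.

```python
# head_dim = hidden_size / num_attention_heads
LLAMA_SIZES = {
    "7B": {"hidden_size": 4096, "num_attention_heads": 32},
    "13B": {"hidden_size": 5120, "num_attention_heads": 40},
}

# Limit quoted in LLaVA's warning for the Flash Attention backward pass.
BACKWARD_HEAD_DIM_LIMIT = 64


def head_dim(hidden_size, num_attention_heads):
    """Per-head dimension; hidden_size must split evenly across heads."""
    if hidden_size % num_attention_heads:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden_size // num_attention_heads


if __name__ == "__main__":
    for name, cfg in LLAMA_SIZES.items():
        d = head_dim(**cfg)
        print(f"{name}: head_dim={d}, exceeds {BACKWARD_HEAD_DIM_LIMIT}: {d > BACKWARD_HEAD_DIM_LIMIT}")
```

Both sizes give 4096 / 32 = 5120 / 40 = 128, twice the limit. Scaling the model up therefore does not help, because the head count grows together with the hidden size.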
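The xformers attention patch contains a comment calling its attention-mask detection logic "a nasty hack". The logic checks whether `attention_mask[0, 0, 0, 1] == 0`, an element above the diagonal, to decide between no mask (all zeros) and a causal mask. This works for the masks the patch expects, but it is fragile because it infers the mask type from a single element.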
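The hack can be reproduced with plain nested lists standing in for the 4-D `(bsz, 1, q_len, kv_len)` additive mask. This is an illustrative sketch, not the patch itself. `build_mask` and `treated_as_causal` are hypothetical helpers, and `float("-inf")` stands in for the large negative fill value transformers uses. The last case shows the fragility: a non-causal padding mask is misread as causal because it happens to be nonzero at `[0, 0, 0, 1]`.

```python
NEG_INF = float("-inf")  # stand-in for the large negative fill value


def build_mask(q_len, kv_len, masked):
    """Additive mask shaped (bsz=1, 1, q_len, kv_len) as nested lists.

    masked(i, j) returns True where query i must not attend to key j.
    """
    rows = [[NEG_INF if masked(i, j) else 0.0 for j in range(kv_len)]
            for i in range(q_len)]
    return [[rows]]


def treated_as_causal(attention_mask):
    """Same test as the xformers patch: a None mask, or a zero at
    [0][0][0][1], means "no mask"; anything else means "causal"."""
    if attention_mask is None or attention_mask[0][0][0][1] == 0:
        return False
    return True


causal = build_mask(4, 4, lambda i, j: j > i)     # lower-triangular
no_mask = build_mask(4, 4, lambda i, j: False)    # all zeros
padding_only = build_mask(4, 4, lambda i, j: j == 1)  # key 1 is padding, not causal

print(treated_as_causal(causal))        # True
print(treated_as_causal(no_mask))       # False
print(treated_as_causal(padding_only))  # True (misclassified)
```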
Code Evidence
GPU capability check from `llama_flash_attn_monkey_patch.py:105-111`:
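```python
import warnings
import torch

def replace_llama_attn_with_flash_attn():
    # (major, minor), e.g. (8, 0) on A100 and (9, 0) on H100
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # Only a warning: execution continues on older GPUs
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    # ... remainder of the function (the actual patching) omitted
```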
Flash Attention import with version fallback from `llama_flash_attn_monkey_patch.py:9-13`:
```python
try:
    # flash-attn 1.x name
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    # flash-attn 2.x renamed the function; alias it back to the 1.x name
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
```
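The version fallback handles a rename between releases. flash-attn 1.x exposes `flash_attn_unpadded_qkvpacked_func`, while 2.x calls it `flash_attn_varlen_qkvpacked_func`. The import therefore tries the old name first and aliases the new one to it. A hypothetical generalization is shown below; `first_available_attr` is not part of LLaVA. It looks up a list of candidate names in order. The demo uses `math` as a stand-in so it runs without flash-attn installed.

```python
import importlib


def first_available_attr(module_name, candidates):
    """Return the first attribute from `candidates` that the module defines.

    Generalizes a try/except alias: try the old name first, then the new one.
    Raises ImportError if the module or every candidate is missing.
    """
    module = importlib.import_module(module_name)  # ImportError if module missing
    for name in candidates:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"{module_name} defines none of {candidates}")


if __name__ == "__main__":
    # Stand-in demo with the standard library. For Flash Attention one would pass
    # "flash_attn.flash_attn_interface" and the two function names from the excerpt.
    sqrt = first_available_attr("math", ["sqrt_v2_hypothetical", "sqrt"])
    print(sqrt(16.0))  # 4.0
```

With exactly two names, the explicit try/except in the original is simpler and easier to read. A helper like this only pays off when several names or modules need the same treatment.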
xformers import guard from `llama_xformers_attn_monkey_patch.py:13-16`, which logs an error instead of raising:
```python
import logging

try:
    import xformers.ops
except ImportError:
    # Logs and continues, so the module still imports without xformers
    logging.error("xformers not found! Please install it before trying to use it.")
```
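This guard only logs an error and carries on, so a missing xformers install presumably surfaces later, when the patched attention code first uses `xformers`. The sketch below instead checks availability up front with the standard library's `importlib.util.find_spec`, which locates a top-level package without importing it. `module_available` is a hypothetical helper.

```python
import importlib.util


def module_available(name):
    """True if the top-level package `name` can be found on sys.path.

    find_spec does not import a top-level package. For dotted names such as
    "xformers.ops" it would import the parent package first, so pass only the
    top-level name.
    """
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    for pkg in ("flash_attn", "xformers"):
        print(f"{pkg}: {'installed' if module_available(pkg) else 'missing'}")
```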
Flash Attention 2 training entry from `train_mem.py:1-4`:
```python
from llava.train.train import train

if __name__ == "__main__":
    train(attn_implementation="flash_attention_2")
```