Heuristic:Haotian liu LLaVA Flash Attention GPU Requirement

From Leeroopedia
Knowledge Sources
Domains: Optimization, Infrastructure
Last Updated: 2026-02-13 23:00 GMT

Overview

Training LLaVA with Flash Attention requires an NVIDIA GPU with compute capability >= 8.0 (for example, A100 or H100), because the Flash Attention backward pass for head dimensions > 64 is not supported on older architectures.

Description

LLaVA provides two memory-efficient training entry points: `train_mem.py` (Flash Attention 2) and `train_xformers.py` (xformers). Flash Attention reduces memory usage and speeds up attention computation, but it has a hardware constraint: the backward pass for head dimensions greater than 64 only works on Ampere (A100) or Hopper (H100) architecture GPUs. The code checks the GPU compute capability and emits a warning, rather than raising an error, if the major version is below 8. xformers serves as the alternative for older GPUs.

Usage

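Apply this heuristic when selecting which training entry point to use. If your GPU is an A100, H100, or newer (compute capability >= 8.0), use `train_mem.py` for Flash Attention 2. If your GPU is older (e.g., V100 at 7.0 or T4 at 7.5), use `train_xformers.py` for memory optimization or the standard `train.py`. Note that the check tests only the major version, so consumer Ampere cards such as the RTX 3090 (compute capability 8.6) pass it even though the warning message names only the A100 and H100.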

The Insight (Rule of Thumb)

  • Action: Check the GPU compute capability before choosing a training script. Use `torch.cuda.get_device_capability()` to verify it.
  • Value: Compute capability >= 8.0 required (A100 = 8.0, H100 = 9.0).
  • Trade-off: Flash Attention provides ~2x speedup and ~40% memory reduction versus standard attention. xformers is the fallback for older GPUs but provides less speedup.
  • Fallback: Use `train_xformers.py` for memory-efficient training on GPUs with compute capability < 8.0, or standard `train.py` for maximum compatibility.
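
The rule of thumb above can be sketched as a small helper. This is a minimal illustration, not code from the LLaVA repository: the function names `choose_training_script` and `detect_capability` are hypothetical, and returning `train.py` when no CUDA device is visible is an assumption made for the example. Only the threshold (major version 8) and `torch.cuda.get_device_capability()` come from the page itself.

```python
from typing import Optional, Tuple

# Threshold taken from the check quoted on this page
# (a major version below 8 triggers the warning).
MIN_FLASH_ATTN_MAJOR = 8


def choose_training_script(capability: Optional[Tuple[int, int]]) -> str:
    """Map a (major, minor) compute capability to a LLaVA training entry point.

    `capability` is None when no CUDA device is visible. Falling back to
    train.py in that case is an assumption of this sketch.
    """
    if capability is None:
        return "train.py"
    major, _minor = capability
    if major >= MIN_FLASH_ATTN_MAJOR:
        return "train_mem.py"  # Flash Attention 2
    return "train_xformers.py"  # memory-efficient fallback for older GPUs


def detect_capability() -> Optional[Tuple[int, int]]:
    """Return the current device's compute capability, or None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()
```

Calling `choose_training_script(detect_capability())` on the training machine would then suggest an entry point. Like the check it mirrors, the helper looks only at the major version.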

Reasoning

LLaMA models use a head dimension of 128, which exceeds the head dimension limit of 64 for the Flash Attention backward pass on pre-Ampere GPUs. The Flash Attention library itself enforces this constraint. The LLaVA codebase provides three training entry points to accommodate different hardware: `train.py` (standard), `train_mem.py` (Flash Attention 2), and `train_xformers.py` (xformers). The v1.5 training scripts default to `train_mem.py` with Flash Attention 2, assuming A100/H100 hardware.
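
To make the head dimension figure concrete, the arithmetic can be sketched as follows. The helper `head_dim` is hypothetical, and the hidden sizes and head counts are the published LLaMA 7B and 13B configuration values, which are not stated on this page.

```python
def head_dim(hidden_size: int, num_attention_heads: int) -> int:
    """Per-head dimension: the hidden size split evenly across attention heads."""
    if hidden_size % num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden_size // num_attention_heads


# (hidden_size, num_attention_heads) for two LLaMA sizes; assumed here
# for illustration, not taken from the LLaVA code quoted on this page.
LLAMA_CONFIGS = {"7B": (4096, 32), "13B": (5120, 40)}

for name, (hidden, heads) in LLAMA_CONFIGS.items():
    d = head_dim(hidden, heads)
    print(f"LLaMA {name}: head_dim={d}, exceeds 64: {d > 64}")
```

Both sizes give a head dimension of 128, which is why the backward-pass constraint applies to them.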

The xformers attention patch includes a comment that calls its attention-mask detection logic "a nasty hack": it checks `attention_mask[0, 0, 0, 1] == 0` to decide whether the mask is causal, treating a zero at that upper-triangular position as a sign that the whole mask is zeros. This works for the masks the patch expects but is fragile.
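
The fragility can be illustrated with a pure-Python sketch that uses nested lists in place of tensors. Everything here is hypothetical scaffolding: `make_mask` and `looks_causal` are not functions from the patch, and the additive-mask convention (0.0 means attend, a large negative value means masked) is an assumption about the masks the patch receives.

```python
NEG_INF = float("-inf")  # stand-in for the large negative value in additive masks


def make_mask(q_len, kv_len, causal, padded_keys=()):
    """Build a (1, 1, q_len, kv_len) additive mask as nested lists.

    0.0 means "attend", NEG_INF means "masked".
    """
    rows = []
    for q in range(q_len):
        row = []
        for k in range(kv_len):
            masked = (causal and k > q) or (k in padded_keys)
            row.append(NEG_INF if masked else 0.0)
        rows.append(row)
    return [[rows]]


def looks_causal(mask):
    """Mirror of the described check: a zero at [0, 0, 0, 1] is read as 'not causal'."""
    return mask is not None and mask[0][0][0][1] != 0


causal = make_mask(4, 4, causal=True)
all_zero = make_mask(4, 4, causal=False)
# Not causal, but key position 1 is masked (e.g. as padding).
padded = make_mask(4, 4, causal=False, padded_keys={1})

print(looks_causal(causal))    # causal mask is detected
print(looks_causal(all_zero))  # all-zero mask is treated as non-causal
print(looks_causal(padded))    # non-causal mask is misread as causal
```

The third case shows why the single-element probe is a heuristic: any mask that happens to block key position 1 for the first query is classified as causal. A key length of 1 would also make index 1 out of range.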

Code Evidence

GPU capability check from `llama_flash_attn_monkey_patch.py:105-111`:

def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

Flash Attention import with version fallback from `llama_flash_attn_monkey_patch.py:9-13`:

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input

xformers import guard, which logs an error if the package is missing, from `llama_xformers_attn_monkey_patch.py:13-16`:

try:
    import xformers.ops
except ImportError:
    logging.error("xformers not found! Please install it before trying to use it.")

Flash Attention 2 training entry from `train_mem.py:1-4`:

from llava.train.train import train

if __name__ == "__main__":
    train(attn_implementation="flash_attention_2")
