Heuristic:Lm sys FastChat Flash Attention GPU Requirements
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-07 04:00 GMT |
Overview
Flash Attention training in FastChat expects a GPU with CUDA compute capability >= 8.0. On older GPUs it emits a warning that cites A100/H100 and the FlashAttention backward-pass limitation for head dimensions > 64. The patches also replace the standard 4D attention mask with a 2D key padding mask.
Description
FastChat provides two Flash Attention monkey patches for LLaMA models: one for LLaMA-1 (`llama_flash_attn_monkey_patch.py`) using `flash_attn_varlen_qkvpacked_func`, and one for LLaMA-2 (`llama2_flash_attn_monkey_patch.py`) using `flash_attn_func` and `flash_attn_varlen_kvpacked_func`. Both patches replace the standard `LlamaAttention.forward` and critically also replace `LlamaModel._prepare_decoder_attention_mask` to pass through the raw attention mask instead of expanding it into a 4D causal mask. This is because Flash Attention requires the mask as a simple key padding mask, not the expanded format.
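To make the two mask formats concrete, here is a minimal pure-Python sketch (no torch). It contrasts the [bsz, seq_len] mask that the patch passes through with roughly the [bsz, 1, q_len, kv_len] additive causal mask that the unpatched `LlamaModel` builds. The helper names are hypothetical. `-inf` stands in for the `torch.finfo(dtype).min` value that HuggingFace actually uses for blocked positions.

```python
NEG_INF = float("-inf")  # HF uses torch.finfo(dtype).min; -inf here for readability


def key_padding_mask(mask_2d):
    """What the patched _prepare_decoder_attention_mask returns: the mask unchanged.

    mask_2d has shape [bsz][seq_len]; 1 marks a real token, 0 marks padding.
    """
    return mask_2d


def expanded_causal_mask(mask_2d):
    """Roughly what the unpatched LlamaModel builds: shape [bsz][1][q_len][kv_len].

    An entry is 0.0 where query q may attend to key k (k <= q and k is not padding)
    and NEG_INF otherwise, so the mask can be added to the attention scores.
    """
    out = []
    for row in mask_2d:
        n = len(row)
        grid = [
            [0.0 if (k <= q and row[k] == 1) else NEG_INF for k in range(n)]
            for q in range(n)
        ]
        out.append([grid])  # the singleton dimension broadcasts over heads
    return out


mask = [[1, 1, 1, 0]]  # one sequence: 3 real tokens followed by 1 pad
m4 = expanded_causal_mask(mask)
print(len(key_padding_mask(mask)[0]))                     # 4
print(len(m4), len(m4[0]), len(m4[0][0]), len(m4[0][0][0]))  # 1 1 4 4
print(m4[0][0][2])                                        # [0.0, 0.0, 0.0, -inf]
```

The 2D form carries O(N) values per sequence, while the expanded form carries O(N^2) values.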
Usage
Use this heuristic when deciding whether to enable Flash Attention for training, or when debugging attention-related training failures. Flash Attention substantially reduces attention memory, but it only works on GPUs with compute capability >= 8.0 (Ampere or newer). For older GPUs, use xformers instead via `train_xformers.py`.
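The decision above can be sketched as a small, hypothetical helper. It applies the same `major < 8` test that FastChat uses and falls back to the xformers path otherwise. FastChat itself does not switch backends automatically; it only warns. Treat this as an illustration of the rule of thumb, not FastChat code.

```python
def choose_attention_backend(capability):
    """Hypothetical helper: map a (major, minor) CUDA compute capability to a patch.

    Uses the same test as FastChat's `cuda_major < 8` check, so only the major
    version matters: (8, 6), e.g. an RTX 3090, counts as supported.
    """
    major, _minor = capability
    return "flash_attn" if major >= 8 else "xformers"


for cap in [(7, 0), (7, 5), (8, 0), (8, 6), (9, 0)]:
    print(cap, choose_attention_backend(cap))
```

On a CUDA machine, the tuple would come from `torch.cuda.get_device_capability()`, which returns `(major, minor)` for the current device.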
The Insight (Rule of Thumb)
- Action: Enable Flash Attention via `train_mem.py` (LLaMA-1) or `--flash_attn True` (LoRA training)
- Value: Requires CUDA compute capability >= 8.0 (A100, H100); `flash-attn >= 2.0` (LLaMA-1) or `flash-attn >= 2.1.0` (LLaMA-2 with past_key_value)
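- Trade-off: Significant memory reduction and speedup, but limited to GPUs with compute capability >= 8.0 (Ampere or newer). On older GPUs (e.g. V100 at 7.0, T4 at 7.5), a warning is printed but execution continues. The flash-attn kernels may then fail at runtime, since flash-attn 2.x does not support pre-Ampere GPUs. The check tests only the major version, so consumer Ampere cards such as the RTX 3090 (8.6) pass it without a warning.
- Critical side effect: The attention mask preparation is monkey-patched to return the raw [bsz, seq_len] mask instead of the expanded 4D causal mask. This is essential for Flash Attention. However, the patch is applied to the `LlamaModel` class, so it affects every LLaMA model in the process and breaks any code that expects the standard expanded mask format.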
Reasoning
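Flash Attention avoids materializing the full N x N attention matrix, which reduces the extra memory for attention from O(N^2) to O(N) in sequence length N. However, in the FlashAttention release referenced by FastChat's warning, the backward pass for head dimensions > 64 was only supported on A100/H100 (SM 8.0/9.0). LLaMA models use head_dim = 128 (e.g. LLaMA-7B splits a hidden size of 4096 across 32 heads), so they fall into that case. Monkey patching is needed because the Flash Attention API differs from standard PyTorch attention: its variable-length functions take packed QKV (or KV) tensors plus cumulative sequence lengths, rather than separate Q/K/V with 4D masks.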
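A back-of-envelope calculation shows why materializing the N x N matrix matters, and confirms that LLaMA falls into the head_dim > 64 case. The sketch assumes LLaMA-7B shapes (hidden size 4096, 32 heads) and fp16 scores. It counts only one [bsz, heads, N, N] tensor per layer; standard attention may keep more than one such tensor around for the backward pass.

```python
def attention_scores_bytes(bsz, n_heads, seq_len, bytes_per_elem=2):
    """Size of one layer's [bsz, n_heads, seq_len, seq_len] score tensor (fp16 by default)."""
    return bsz * n_heads * seq_len * seq_len * bytes_per_elem


# Assumed LLaMA-7B shapes: hidden size 4096 split across 32 heads.
hidden_size, n_heads = 4096, 32
head_dim = hidden_size // n_heads
print(head_dim, head_dim > 64)  # 128 True

for seq_len in (2048, 4096):
    mib = attention_scores_bytes(1, n_heads, seq_len) / 2**20
    print(seq_len, mib)  # 2048 256.0, then 4096 1024.0
```

Doubling N quadruples this per-layer cost. That is the O(N^2) term Flash Attention avoids by computing attention in tiles without storing the full matrix.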
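The attention mask change is a subtle but critical detail. The standard HuggingFace LlamaModel expands the attention mask into a float 4D tensor of shape [bsz, 1, q_len, kv_len] for additive masking. The patched attention instead uses the 2D [bsz, seq_len] 0/1 key padding mask to unpad the batch for efficient variable-length processing. Without the mask patch, the patched forward would receive the 4D mask instead of the padding mask it expects, and building that 4D mask would itself reintroduce O(N^2) memory.
When used with LoRA, Flash Attention also requires explicitly casting the norm layers, `lm_head`, and `embed_tokens` to the compute dtype (fp16/bf16). The flash-attn kernels accept only fp16/bf16 inputs, and layers left in fp32 can otherwise cause mixed-precision dtype errors.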
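The variable-length path needs per-sequence boundaries, which can be read directly from the 2D padding mask. This pure-Python sketch derives roughly the metadata that flash-attn's `unpad_input` helper (from `flash_attn.bert_padding`) computes: the flat indices of real tokens, the cumulative sequence lengths `cu_seqlens`, and the maximum length. The function name is hypothetical. The real helper works on torch tensors and also returns the gathered tokens.

```python
from itertools import accumulate


def unpad_metadata(key_padding_mask):
    """Derive varlen metadata from a [bsz][seq_len] 0/1 mask (hypothetical helper).

    Returns (indices, cu_seqlens, max_seqlen):
      indices    - flat positions (b * seq_len + t) of real tokens, used to gather them
      cu_seqlens - cumulative lengths [0, len_0, len_0 + len_1, ...]
      max_seqlen - length of the longest sequence in the batch
    """
    seq_len = len(key_padding_mask[0])
    lengths = [sum(row) for row in key_padding_mask]
    indices = [
        b * seq_len + t
        for b, row in enumerate(key_padding_mask)
        for t, keep in enumerate(row)
        if keep
    ]
    cu_seqlens = [0] + list(accumulate(lengths))
    return indices, cu_seqlens, max(lengths)


mask = [[1, 1, 1, 0],
        [1, 1, 0, 0]]
print(unpad_metadata(mask))  # ([0, 1, 2, 4, 5], [0, 3, 5], 3)
```

Sequence i then occupies positions `cu_seqlens[i]` to `cu_seqlens[i + 1]` of the packed, padding-free token array, so no compute is spent on pad tokens.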
Code Evidence
CUDA capability check from `fastchat/train/llama_flash_attn_monkey_patch.py:97-103`:
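```python
# Module-level imports from the top of the file
import warnings
import torch

def replace_llama_attn_with_flash_attn():
    cuda_major, cuda_minor = torch.cuda.get_device_capability()
    if cuda_major < 8:  # warns only; execution continues
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )
    # ... then installs the patched LlamaAttention.forward and
    # LlamaModel._prepare_decoder_attention_mask
```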
Attention mask bypass from `fastchat/train/llama_flash_attn_monkey_patch.py:88-94`:
```python
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask
```
LLaMA-2 version gate from `fastchat/train/llama2_flash_attn_monkey_patch.py:70-72`:
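```python
# flash_attn_version is the installed flash-attn version string.
# Note: this is a plain string comparison, not a semantic-version comparison.
assert (
    flash_attn_version >= "2.1.0"
), "past_key_value support requires flash-attn >= 2.1.0"
```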
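This gate compares version strings with `>=`, which is lexicographic rather than numeric. It gives the right answer for many versions but is fragile once a version component gains a digit. Below is a minimal sketch of a numeric comparison. `release_tuple` and `supports_past_key_value` are hypothetical names; in practice, `packaging.version.Version` is the usual tool for this.

```python
import re


def release_tuple(version):
    """Numeric release part of a version string, e.g. '2.3.6.post1' -> (2, 3, 6).

    Simplified stand-in for packaging.version.Version: pre/post/dev/local tags are ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def supports_past_key_value(flash_attn_version, minimum="2.1.0"):
    """Hypothetical numeric version of the FastChat gate."""
    return release_tuple(flash_attn_version) >= release_tuple(minimum)


# Lexicographic comparison misorders multi-digit components:
print("2.9.0" >= "2.10.0")                                # True, although 2.9.0 is older
print(release_tuple("2.9.0") >= release_tuple("2.10.0"))  # False
print(supports_past_key_value("2.0.9"), supports_past_key_value("2.3.6.post1"))  # False True
```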
LoRA dtype casting for Flash Attention from `fastchat/train/train_lora.py:166-172`:
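```python
if training_args.flash_attn:
    # Cast norms, lm_head and embed_tokens to the fp16/bf16 compute dtype.
    # nn.Module.to() converts parameters in place and returns the module,
    # so the reassignment to `module` is redundant but harmless.
    for name, module in model.named_modules():
        if "norm" in name:
            module = module.to(compute_dtype)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                module = module.to(compute_dtype)
```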
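The casting loop selects modules purely by substring match on their qualified names. This sketch applies the same rules to the standard HuggingFace LLaMA module names, which are assumed here and abbreviated to a single decoder layer. It shows what gets cast: every RMSNorm, the embedding, and the output head, while the attention and MLP projections are left alone. The `weight` check from the original loop is omitted, and `cast_targets` is a hypothetical name.

```python
def cast_targets(module_names):
    """Return the names that the substring rules in the train_lora.py loop would cast."""
    return [
        name
        for name in module_names
        if "norm" in name or "lm_head" in name or "embed_tokens" in name
    ]


# Standard HF LLaMA module names (assumed), abbreviated to one decoder layer.
llama_names = [
    "model",
    "model.embed_tokens",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.input_layernorm",
    "model.layers.0.post_attention_layernorm",
    "model.norm",
    "lm_head",
]
print(cast_targets(llama_names))
```

With PEFT wrapping, the names gain a prefix such as `base_model.model.`, but the substring test still matches the same modules.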