
Heuristic:Predibase Lorax Flash Attention Backend Selection

From Leeroopedia



Domains: Optimization, LLMs
Last Updated: 2026-02-08 02:30 GMT

Overview

Multi-tier attention backend selection that dispatches between FlashInfer, Flash Attention V2, Flash Attention V1, and Triton fallback based on GPU architecture, installed packages, and feature requirements (prefix caching, window attention).

Description

LoRAX implements cascading attention backend selection at import time in `flash_attn.py`. The selection determines the `attention()` function used for all subsequent model inference. The hierarchy is:

  1. FlashInfer: Used when `FLASH_INFER=1` or `PREFIX_CACHING=1`. Supports paged KV cache with block_size=1 (vs 16 for FA2), enabling fine-grained prefix caching. Required for chunked prefill.
  2. Flash Attention V2 (CUDA): Default for NVIDIA SM 8.0+ GPUs. Block_size=16 paged KV cache. Supports window attention (sliding window).
  3. Flash Attention V2 (ROCm CK): Default for AMD MI210/MI250. Uses Composable Kernel backend. Does NOT support window attention.
  4. Flash Attention V2 (ROCm Triton): Alternative ROCm backend. Enabled via `ROCM_USE_FLASH_ATTN_V2_TRITON=true`.
  5. Flash Attention V1: Fallback for SM 7.5 GPUs (Turing). Requires head expansion for MQA/GQA (V2 handles this natively). No window attention support.
  6. Intel XPU: Uses `ipex.llm.functional.varlen_attention`. No window attention.
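The cascade above can be condensed into a plain decision function. This is an illustrative sketch only: the function name, arguments, and backend labels are made up for this example and are not LoRAX's actual API.

```python
import os

def select_attention_backend(
    has_cuda: bool,
    compute_capability: tuple,  # (major, minor), e.g. (8, 0) for Ampere
    has_rocm: bool = False,
    has_xpu: bool = False,
) -> str:
    """Illustrative sketch of the priority order described above."""
    # FlashInfer wins whenever prefix caching or FLASH_INFER is requested.
    if os.environ.get("FLASH_INFER") == "1" or os.environ.get("PREFIX_CACHING") == "1":
        return "flashinfer"
    if has_cuda:
        if compute_capability >= (8, 0):   # Ampere and newer -> FA2
            return "flash_attn_v2_cuda"
        if compute_capability == (7, 5):   # Turing (e.g. T4) -> FA1 fallback
            return "flash_attn_v1"
    if has_rocm:
        if os.environ.get("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true":
            return "flash_attn_v2_rocm_triton"
        return "flash_attn_v2_rocm_ck"     # MI210/MI250 default
    if has_xpu:
        return "ipex_varlen_attention"
    raise ImportError("no supported attention backend found")
```

Note that the environment checks come before any hardware checks, which is why `PREFIX_CACHING=1` overrides the architecture-based default.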

Usage

This heuristic applies at server startup when the attention backend is selected. The choice affects all subsequent inference operations. Key decision points:

  • Need prefix caching? Must use FlashInfer (`PREFIX_CACHING=1` auto-enables it)
  • Need chunked prefill? Must use FlashInfer (`CHUNKED_PREFILL=1` requires `FLASH_INFER=1`)
  • Using sliding window models (Mistral)? Requires Flash Attn V2 or FlashInfer (not V1, not ROCm CK)
  • Using Turing GPU (T4)? Limited to Flash Attn V1 with MQA/GQA expansion overhead
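These decision points amount to a small feature-compatibility matrix. A sketch of that matrix (the backend labels and the `backends_for` helper are illustrative, not LoRAX identifiers):

```python
# Feature support distilled from the decision points above (illustrative labels).
SUPPORTS = {
    "flashinfer":                {"prefix_caching", "chunked_prefill", "window_attention"},
    "flash_attn_v2_cuda":        {"window_attention"},
    "flash_attn_v2_rocm_ck":     set(),                 # no window attention on CK
    "flash_attn_v2_rocm_triton": {"window_attention"},
    "flash_attn_v1":             set(),                 # Turing fallback, minimal features
}

def backends_for(required: set) -> list:
    """Return the backends that satisfy every required feature."""
    return [b for b, feats in SUPPORTS.items() if required <= feats]
```

For example, requiring `prefix_caching` narrows the choice to FlashInfer alone, which is exactly why `PREFIX_CACHING=1` auto-enables it.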

The Insight (Rule of Thumb)

  • Action: For production on Ampere+ GPUs, use default Flash Attention V2 unless you need prefix caching (then set `FLASH_INFER=1`). For Turing GPUs, accept Flash Attn V1 with reduced feature set.
  • Value: FlashInfer block_size=1 enables token-level prefix sharing. FA2 block_size=16 is coarser but more memory-efficient.
  • Trade-off: FlashInfer adds prefix caching capability and chunked prefill but requires the flashinfer package (cu124 build). Flash Attn V1 requires expanding KV heads for MQA/GQA models, increasing memory usage.
  • ROCm note: Window attention not supported on AMD GPUs with CK backend. Use Triton backend if needed.

Reasoning

The attention computation is the most latency-critical operation in transformer inference. Flash Attention reorganizes the computation to be I/O-aware, reducing HBM reads from O(N^2) to O(N) by tiling. Each backend variant optimizes for different hardware:

  • V2 vs V1: V2 natively handles GQA (grouped-query attention) by broadcasting KV heads internally. V1 requires explicit head expansion (doubling KV memory for GQA models). V2 also supports window attention for sliding-window models like Mistral.
  • FlashInfer vs FA2: FlashInfer uses block_size=1 paged KV cache, enabling token-granularity prefix sharing via radix tree. FA2 uses block_size=16, meaning 16 tokens must share the same cache block.
  • ROCm CK vs Triton: CK (Composable Kernel) is AMD's optimized kernel library. Triton is portable but may have lower peak performance.
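The block-size difference has a concrete effect on how much of a shared prefix can actually be reused, since only whole cache blocks are shareable. A small arithmetic sketch (the helper name is made up for this example):

```python
def reusable_prefix_tokens(shared_prefix_len: int, block_size: int) -> int:
    """Tokens reusable from a paged KV cache: only complete blocks can be shared."""
    return (shared_prefix_len // block_size) * block_size

# With a 100-token shared prefix:
#   FlashInfer (block_size=1)  reuses all 100 tokens
#   FA2        (block_size=16) reuses 6 full blocks = 96 tokens
```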

Code evidence from `server/lorax_server/utils/flash_attn.py:11-114`:

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

# Priority: FlashInfer > FA2 CUDA > FA2 ROCm CK > FA2 ROCm Triton > FA1 > Error

Flash Attn V1 MQA/GQA expansion from `server/lorax_server/utils/flash_attn.py:301-313`:

# Flash Attention V1 requires q, k, and v to have the same number of heads
original_shape = k.shape  # (tokens, kv_heads, head_dim), captured before expansion
if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:  # MQA: broadcast the single KV head across all query heads
        k = k.expand(-1, q.shape[1], -1)
    else:  # GQA: repeat each KV head group-size times
        k = (k.unsqueeze(2)
             .expand(-1, -1, q.shape[1] // k.shape[1], -1)
             .reshape(original_shape[0], -1, original_shape[2]))
# (v undergoes the same expansion)
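The memory cost of this expansion follows directly from the head counts: the KV cache grows by the GQA group size, since every KV head is materialized once per query-head group. A trivial sketch (the helper and the example head counts are illustrative, not from LoRAX):

```python
def kv_cache_growth_factor(num_q_heads: int, num_kv_heads: int) -> float:
    """How much larger the KV cache becomes after V1's head expansion.

    V2 avoids this cost entirely by broadcasting KV heads inside the kernel.
    """
    assert num_q_heads % num_kv_heads == 0, "GQA requires q heads divisible by kv heads"
    return num_q_heads / num_kv_heads

# e.g. a GQA model with 32 query heads and 8 KV heads pays 4x KV memory under V1;
# an MQA model with 32 query heads and 1 KV head pays 32x.
```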

Backend selection in state from `server/lorax_server/utils/state.py:17-38`:

# Always use flashinfer when prefix caching is enabled
FLASH_INFER = bool(int(os.environ.get("FLASH_INFER", "0"))) or PREFIX_CACHING

BLOCK_SIZE: int
if FLASH_INFER:
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16
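`BLOCK_SIZE` in turn determines how many paged-cache blocks a sequence occupies, which is where FA2's coarser blocks pay off in bookkeeping overhead. A quick sketch (the `blocks_needed` helper is illustrative):

```python
import math

def blocks_needed(seq_len: int, block_size: int) -> int:
    """Paged KV cache blocks required for a sequence (last block may be partial)."""
    return math.ceil(seq_len / block_size)

# A 1000-token sequence occupies 1000 blocks under FlashInfer (block_size=1)
# but only 63 blocks under FA2 (block_size=16).
```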
