Implementation:FlagOpen FlagEmbedding LLM Embedder Llama Patch
| Knowledge Sources | |
|---|---|
| Domains | Large_Language_Models, Flash_Attention, Model_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Monkey-patching utilities for enabling Flash Attention in LLaMA models to accelerate training and inference.
Description
This module provides runtime modifications to LLaMA models for Flash Attention 2 support:
forward() replaces the standard LLaMA attention mechanism with a Flash Attention implementation:
- Packs the query/key/value tensors into the QKV layout expected by Flash Attention
- Handles variable-length sequences with unpad/pad operations
- Applies rotary position embeddings (RoPE) before attention
- Applies causal masking, as required for autoregressive decoding
- Uses the variable-length (varlen) Flash Attention interface for padded inputs, driven by the key_padding_mask
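The unpad/pad bookkeeping can be sketched in plain PyTorch. This is a simplified illustration and not the module's code. The helper names unpad and pad are hypothetical; they mirror the role usually played by flash_attn.bert_padding.unpad_input and pad_input. A varlen Flash Attention kernel takes the packed tokens together with cu_seqlens (cumulative sequence lengths) and max_seqlen, so no compute is spent on padding.

```python
import torch


def unpad(x, key_padding_mask):
    """Drop padded tokens from x [batch, seq_len, ...] using a 2D mask (1 = real token).

    Returns packed tokens [total_tokens, ...], their flat indices,
    cumulative sequence lengths (cu_seqlens) and the longest sequence length.
    """
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)  # [batch]
    indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    x_unpad = x.flatten(0, 1)[indices]  # gather only the real tokens
    return x_unpad, indices, cu_seqlens, max_seqlen


def pad(x_unpad, indices, batch, seq_len):
    """Scatter packed tokens back to [batch, seq_len, ...], zero-filling padded slots."""
    out = x_unpad.new_zeros((batch * seq_len, *x_unpad.shape[1:]))
    out[indices] = x_unpad
    return out.view(batch, seq_len, *x_unpad.shape[1:])


# Pack Q/K/V along a new axis: [batch, seq_len, 3, heads, head_dim]
batch, seq_len, heads, head_dim = 2, 4, 2, 8
q, k, v = (torch.randn(batch, seq_len, heads, head_dim) for _ in range(3))
qkv = torch.stack([q, k, v], dim=2)

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 3 and 2 real tokens
qkv_unpad, indices, cu_seqlens, max_seqlen = unpad(qkv, mask)
print(qkv_unpad.shape, cu_seqlens.tolist(), max_seqlen)
# torch.Size([5, 3, 2, 8]) [0, 3, 5] 3
```

After the kernel runs on the packed tokens, pad() restores the [batch, seq_len, ...] layout, and the padded positions come back as zeros.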
_prepare_decoder_attention_mask() disables the standard 4D attention mask transformation. It returns the 2D [batch, seq_len] key padding mask unchanged, because Flash Attention works directly with 2D key padding masks.
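The sketch below contrasts an additive 4D causal mask with the 2D mask that is passed through unchanged. The 4D builder only approximates what transformers constructs for standard attention and is not the library's code. Both function names are hypothetical.

```python
import torch


def expand_to_4d_causal_mask(attention_mask, dtype=torch.float32):
    """Illustrative additive mask [batch, 1, q_len, k_len]: 0 where attention is
    allowed, the most negative finite value where it is blocked (causal or padding)."""
    batch, seq_len = attention_mask.shape
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    allowed = causal[None, None] & attention_mask.bool()[:, None, None, :]
    return torch.zeros(allowed.shape, dtype=dtype).masked_fill(~allowed, torch.finfo(dtype).min)


def prepare_mask_passthrough(attention_mask, input_shape, inputs_embeds, past_key_values_length):
    # Flash Attention path: keep the 2D key padding mask [batch, seq_len] as-is
    return attention_mask


mask_2d = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(expand_to_4d_causal_mask(mask_2d).shape)                   # torch.Size([1, 1, 4, 4])
print(prepare_mask_passthrough(mask_2d, (1, 4), None, 0).shape)  # torch.Size([1, 4])
```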
enable_flash_attention() applies the patches either to the LLaMA classes (model=None, before the model is instantiated) or to an existing model instance. It also checks the GPU compute capability. Training requires an Ampere or Hopper GPU because of the Flash Attention backward-pass requirements for head_dim > 64, and LLaMA uses head_dim = 128.
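The sketch below uses a hypothetical ToyAttention class as a stand-in for LlamaAttention to show class-level versus instance-level patching. It also includes a simplified compute-capability check, which is an assumption about how such a check might look. On a real GPU, the (major, minor) tuple would come from torch.cuda.get_device_capability().

```python
import types


class ToyAttention:
    """Hypothetical stand-in for transformers' LlamaAttention."""

    def forward(self, x):
        return f"eager({x})"


def flash_forward(self, x):
    """Hypothetical replacement forward (where a Flash Attention kernel would run)."""
    return f"flash({x})"


def supports_flash_backward(compute_capability):
    """Simplified check: treat compute capability 8.x and above (Ampere, Hopper) as supported."""
    major, _minor = compute_capability
    return major >= 8


# Instance-level patch: only this object gets the new forward
a, b = ToyAttention(), ToyAttention()
a.forward = types.MethodType(flash_forward, a)
print(a.forward(1), b.forward(1))  # flash(1) eager(1)

# Class-level patch: all instances, including ones created later, are affected
original_forward = ToyAttention.forward
ToyAttention.forward = flash_forward
print(b.forward(2), ToyAttention().forward(2))  # flash(2) flash(2)

ToyAttention.forward = original_forward  # restore the original method
print(supports_flash_backward((8, 0)), supports_flash_backward((7, 5)))  # True False
```

Class-level patching must happen before the relevant code runs, but it then covers every model built afterwards. Instance-level patching leaves other models in the same process untouched.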
disable_flash_attention() reverts patches by reloading the original transformers code.
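Reloading works as a revert because importlib.reload re-executes a module's source and rebinds its names to fresh, unpatched class objects. The self-contained sketch below writes a hypothetical toy module to a temporary directory and uses it in place of transformers. It also shows a general Python caveat: objects created from the old class object still see the patched method after the reload.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for transformers.models.llama.modeling_llama
MODULE_SRC = """
class ToyAttention:
    def forward(self, x):
        return "eager"
"""

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "toy_modeling.py").write_text(MODULE_SRC)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    import toy_modeling

    old_cls = toy_modeling.ToyAttention
    old_cls.forward = lambda self, x: "flash"  # class-level monkey patch
    patched = toy_modeling.ToyAttention().forward(0)

    importlib.reload(toy_modeling)  # re-executes the module source
    after_reload = toy_modeling.ToyAttention().forward(0)  # fresh, unpatched class
    stale = old_cls().forward(0)  # the old class object still carries the patch

    sys.path.remove(tmp)

print(patched, after_reload, stale)  # flash eager flash
```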
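upcast_layer_for_flash_attention() casts LoRA layers, normalization layers, and embedding layers to the target dtype (fp16/bf16). After quantization and k-bit training preparation, non-quantized layers such as the normalization layers are typically left in fp32. Flash Attention only accepts fp16/bf16 inputs, so these layers must be cast back.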
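The following is a minimal sketch of name-based casting. It assumes a hypothetical ToyDecoder with LLaMA-like module names and is not the module's implementation. It omits the LoRA check, which in practice could test isinstance against peft's LoraLayer.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical miniature model with LLaMA-like module names."""

    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(10, 8)
        self.q_proj = nn.Linear(8, 8)  # stands in for a quantized layer; left untouched
        self.input_layernorm = nn.LayerNorm(8)


def upcast_for_flash_sketch(model, torch_dtype):
    # Cast modules selected by name to the dtype Flash Attention expects
    for name, module in model.named_modules():
        if "norm" in name or "embed_tokens" in name:
            module.to(torch_dtype)
    return model


model = ToyDecoder()  # all parameters start in float32
model = upcast_for_flash_sketch(model, torch.bfloat16)
print(model.embed_tokens.weight.dtype, model.input_layernorm.weight.dtype, model.q_proj.weight.dtype)
# torch.bfloat16 torch.bfloat16 torch.float32
```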
Usage
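Use this to speed up LLaMA training and inference by replacing standard attention with Flash Attention 2. Flash Attention computes attention in tiles and never materializes the full seq_len × seq_len score matrix, which reduces memory usage and improves throughput.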
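A rough calculation shows where the memory savings come from. Standard attention materializes a [batch, heads, seq_len, seq_len] score matrix per layer, and that matrix grows quadratically with sequence length. The helper below is a hypothetical back-of-the-envelope estimate that counts only the scores, not softmax outputs or other activations.

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Memory for one layer's full [batch, heads, seq_len, seq_len] score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


# LLaMA-2-7B-like shape: 32 heads, 4096-token context, bf16 (2 bytes per element)
gib = attention_score_bytes(batch=1, heads=32, seq_len=4096) / 2**30
print(f"{gib:.1f} GiB per layer")  # 1.0 GiB per layer
```

Doubling seq_len quadruples this figure. Flash Attention avoids storing the matrix at all.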
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_embedder/src/utils/llama_patch.py
- Lines: 1-184
Signature
```python
def forward(self, hidden_states, attention_mask, position_ids, past_key_value,
            output_attentions, use_cache)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length)
def enable_flash_attention(model=None)
def disable_flash_attention(model=None)
def upcast_layer_for_flash_attention(model, torch_dtype)
```
Import
```python
from research.llm_embedder.src.utils.llama_patch import enable_flash_attention, upcast_layer_for_flash_attention
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | LlamaPreTrainedModel | No | enable_flash_attention()/disable_flash_attention(): existing LLaMA model to patch or unpatch; None targets the LLaMA classes instead |
| torch_dtype | torch.dtype | Yes | upcast_layer_for_flash_attention(): target dtype (torch.float16 or torch.bfloat16) |
| hidden_states | Tensor | Yes | forward(): input hidden states [batch, seq_len, hidden_dim] |
| attention_mask | Tensor | No | forward(): 2D key padding mask [batch, seq_len] (1 = token, 0 = padding) |
| position_ids | Tensor | No | forward(): position IDs for the rotary position embeddings (RoPE) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor | Attention output [batch, seq_len, hidden_dim] |
| attention_weights | None | Always None (Flash Attention doesn't return weights) |
| past_key_value | Tuple | Cached KV for generation (if use_cache=True) |
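A plain PyTorch reference with the same causal and padding semantics can help when debugging the patch, for example by comparing outputs against a run after disable_flash_attention(). This is a sketch based on several assumptions and is not part of llama_patch.py. reference_causal_attention is a hypothetical helper. It requires PyTorch 2.0+ for scaled_dot_product_attention, and it takes already-projected, RoPE-rotated q/k/v rather than hidden_states.

```python
import torch
import torch.nn.functional as F


def reference_causal_attention(q, k, v, key_padding_mask=None):
    """Standard (non-Flash) causal attention for checking outputs.

    q, k, v: [batch, heads, seq_len, head_dim]
    key_padding_mask: optional [batch, seq_len], 1 = real token, 0 = padding
    """
    seq_len = q.shape[-2]
    allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril()[None, None]
    if key_padding_mask is not None:
        allowed = allowed & key_padding_mask.bool()[:, None, None, :]
    # Boolean attn_mask: True means the query may attend to that key
    return F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 5, 16) for _ in range(3))
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])  # right padding
out = reference_causal_attention(q, k, v, mask)
# Merge heads back into the [batch, seq_len, hidden_dim] layout of the contract
out = out.transpose(1, 2).reshape(2, 5, 4 * 16)
print(out.shape)  # torch.Size([2, 5, 64])
```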
Usage Examples
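```python
import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from research.llm_embedder.src.utils.llama_patch import (
    disable_flash_attention,
    enable_flash_attention,
    upcast_layer_for_flash_attention,
)

model_name = "meta-llama/Llama-2-7b-hf"

# Option 1: patch before loading the model (class-level).
# Flash Attention runs only in fp16/bf16 on GPU, so load in half precision.
enable_flash_attention(model=None)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map={"": 0}
)

# Option 2: patch an existing model (instance-level), here loaded in 4-bit
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, torch_dtype=torch.bfloat16, device_map={"": 0}
)
enable_flash_attention(model=model)

# After quantization, prepare_model_for_kbit_training() casts the non-quantized
# layers to fp32, so cast the norm and embedding layers back to bf16
model = prepare_model_for_kbit_training(model)
model = upcast_layer_for_flash_attention(model, torch.bfloat16)

# Now use the model normally; Flash Attention is active
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # LLaMA defines no pad token
batch = tokenizer(["Hello", "Flash Attention in LLaMA"], return_tensors="pt", padding=True).to(model.device)
output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# To disable (e.g., for debugging)
disable_flash_attention(model=model)
```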