

Implementation:FlagOpen FlagEmbedding LLM Embedder Llama Patch

From Leeroopedia


Knowledge Sources
Domains Large_Language_Models, Flash_Attention, Model_Optimization
Last Updated 2026-02-09 00:00 GMT

Overview

Monkey-patching utilities for enabling Flash Attention in LLaMA models to accelerate training and inference.

Description

This module provides runtime modifications to LLaMA models for Flash Attention 2 support:

forward() replaces the standard LLaMA attention mechanism with a Flash Attention implementation:

  • Transforms query/key/value tensors into packed QKV format expected by Flash Attention
  • Handles variable-length sequences using unpad/pad operations
  • Applies rotary position embeddings before attention
  • Supports causal masking for autoregressive generation
  • Falls back to the varlen interface when the inputs are padded, that is, when a key_padding_mask is present
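The unpad/pad step in the list above can be illustrated without a GPU. What follows is a minimal pure-Python sketch, not the module's code: it drops padded positions, packs the remaining tokens back to back, and builds the cumulative sequence-length offsets (commonly called cu_seqlens) that variable-length attention kernels consume. The helper names unpad and repad are hypothetical, and the real module operates on tensors rather than lists.

```python
from itertools import accumulate


def unpad(batch, mask):
    """Drop padded positions and pack all sequences back to back.

    batch: list of sequences, each a list of token values (any objects).
    mask:  same shape, 1 for real tokens and 0 for padding.
    Returns (packed, indices, cu_seqlens, max_seqlen).
    """
    seq_len = len(mask[0])
    packed, indices, lengths = [], [], []
    for b, (row, row_mask) in enumerate(zip(batch, mask)):
        kept = [i for i, keep in enumerate(row_mask) if keep]
        packed.extend(row[i] for i in kept)
        # Flat index of each kept token in the padded [batch * seq_len] layout
        indices.extend(b * seq_len + i for i in kept)
        lengths.append(len(kept))
    # Offsets marking where each sequence starts and ends in `packed`
    cu_seqlens = [0] + list(accumulate(lengths))
    return packed, indices, cu_seqlens, max(lengths)


def repad(packed, indices, batch_size, seq_len, pad_value=None):
    """Scatter packed tokens back into a padded [batch, seq_len] layout."""
    flat = [pad_value] * (batch_size * seq_len)
    for value, index in zip(packed, indices):
        flat[index] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]


batch = [["a0", "a1", "a2", "pad"], ["b0", "b1", "pad", "pad"]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
packed, indices, cu_seqlens, max_seqlen = unpad(batch, mask)
restored = repad(packed, indices, batch_size=2, seq_len=4, pad_value="pad")
```

Sequence b occupies packed[cu_seqlens[b]:cu_seqlens[b + 1]], so the attention kernel never spends work on padding, and the saved indices let the output be scattered back to the padded shape afterwards.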

_prepare_decoder_attention_mask() disables the standard 4D attention mask transformation since Flash Attention works directly with 2D key padding masks.
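A minimal sketch of what such a patch might look like, assuming the replacement simply hands the 2D mask through unchanged. TinyDecoder is a hypothetical stand-in for the transformers model class; only the function signature comes from this page.

```python
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length):
    # Assumed behavior: skip the 4D causal-mask expansion and return
    # the [batch, seq_len] key padding mask unchanged.
    return attention_mask


class TinyDecoder:
    """Hypothetical stand-in for a decoder that normally builds a 4D mask."""

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                        inputs_embeds, past_key_values_length):
        batch, seq_len = input_shape
        # Expand each row to [1, seq_len, seq_len] with causal structure
        return [
            [[[row[k] if k <= q else 0 for k in range(seq_len)]
              for q in range(seq_len)]]
            for row in attention_mask
        ]


mask = [[1, 1, 0]]
decoder = TinyDecoder()
four_d = decoder._prepare_decoder_attention_mask(mask, (1, 3), None, 0)

# Swap in the pass-through version at class level
TinyDecoder._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
two_d = decoder._prepare_decoder_attention_mask(mask, (1, 3), None, 0)
```

After the swap, the attention layers receive the original 2D mask and can derive sequence lengths from it directly.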

enable_flash_attention() applies the patches either to the model class (before instantiation) or to an existing model instance. It also checks the GPU compute capability: training requires an Ampere- or Hopper-class GPU, because the Flash Attention backward pass with head_dim > 64 is only supported on those architectures.
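The difference between class-level and instance-level patching can be sketched with plain Python objects. This is an illustration under assumptions, not the module's implementation: Attention, flash_forward, enable_flash and supports_flash_training are hypothetical names, and the capability rule (major version 8 or higher) is an assumed reading of the Ampere/Hopper requirement.

```python
import types


class Attention:
    """Hypothetical stand-in for the LLaMA attention class."""

    def forward(self, hidden_states):
        return ("standard", hidden_states)


def flash_forward(self, hidden_states):
    # Placeholder for the Flash Attention forward pass
    return ("flash", hidden_states)


def supports_flash_training(capability):
    """Assumed check: compute capability major >= 8 (Ampere or newer).

    `capability` is a (major, minor) tuple, such as the value returned by
    torch.cuda.get_device_capability().
    """
    major, _minor = capability
    return major >= 8


def enable_flash(model=None):
    """Patch one instance if given, otherwise patch the class."""
    if model is None:
        # Class-level: affects every instance, including ones created later
        Attention.forward = flash_forward
    else:
        # Instance-level: bind the function to this object only
        model.forward = types.MethodType(flash_forward, model)


first, second = Attention(), Attention()
enable_flash(model=first)
instance_result = (first.forward("x")[0], second.forward("x")[0])

enable_flash()
class_result = Attention().forward("x")[0]
```

Patching the class before loading means every attention layer picks up the new forward automatically, while patching an instance leaves other models in the same process untouched.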

disable_flash_attention() reverts patches by reloading the original transformers code.
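How reloading undoes a monkey patch can be shown with a standard-library module. This sketch assumes a mechanism similar to importlib.reload; the page does not show how the module itself reloads the transformers code, and the string module is used here only because it is always available.

```python
import importlib
import string

# Monkey-patch a module-level function, as a patch would
string.capwords = lambda s, sep=None: "patched"
patched_result = string.capwords("flash attention")

# Re-executing the module source rebinds the original definition
importlib.reload(string)
restored_result = string.capwords("flash attention")
```

In general, reloading rebinds names inside the module, but objects created before the reload keep their references to the old classes, so already-instantiated objects need their methods restored separately.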

upcast_layer_for_flash_attention() ensures LoRA layers, normalization layers, and embedding layers use the correct dtype (fp16/bf16) after quantization, which is necessary for Flash Attention compatibility.
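A rough sketch of this kind of selective cast, written against a duck-typed interface so it runs without PyTorch. The function name upcast_selected_layers, the name markers, and the FakeModule/FakeModel stand-ins are all assumptions for illustration; the actual selection logic in the module may differ (for example, it may test layer types rather than names).

```python
def upcast_selected_layers(model, dtype,
                           name_markers=("lora", "norm", "embed_tokens", "lm_head")):
    """Cast modules whose qualified name contains one of `name_markers`.

    Works with any object exposing named_modules() whose modules provide
    .to(dtype), which includes torch.nn.Module. The markers are assumptions.
    """
    touched = []
    for name, module in model.named_modules():
        if any(marker in name for marker in name_markers):
            module.to(dtype)
            touched.append(name)
    return touched


class FakeModule:
    """Stand-in for a layer that only tracks its dtype."""

    def __init__(self, dtype="float32"):
        self.dtype = dtype

    def to(self, dtype):
        self.dtype = dtype
        return self


class FakeModel:
    """Stand-in for a quantized model with LoRA adapters attached."""

    def __init__(self):
        self.modules = {
            "model.embed_tokens": FakeModule(),
            "model.layers.0.self_attn.q_proj": FakeModule("int4"),
            "model.layers.0.self_attn.q_proj.lora_A": FakeModule(),
            "model.layers.0.input_layernorm": FakeModule(),
            "lm_head": FakeModule(),
        }

    def named_modules(self):
        return self.modules.items()


model = FakeModel()
touched = upcast_selected_layers(model, "bfloat16")
```

The quantized base projection is left alone, while the adapter, normalization, and embedding layers end up in the half-precision dtype that Flash Attention expects for its inputs.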

Usage

Use this module to replace standard attention with Flash Attention 2 when training or running inference with LLaMA models. The swap reduces attention memory usage and improves throughput, with the largest gains on long sequences.
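To get a feel for the memory side, the following back-of-the-envelope calculation estimates the size of the attention score matrix that standard attention materializes per layer and that Flash Attention avoids storing in full. The helper name is hypothetical, and the figures assume 32 heads (as in the 7B LLaMA configuration) and 2-byte elements (fp16/bf16); actual savings depend on the model and runtime.

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_element=2):
    """Bytes needed to materialize one layer's [batch, heads, seq, seq] scores."""
    return batch * heads * seq_len * seq_len * bytes_per_element


for seq_len in (1024, 2048, 4096):
    gib = attention_score_bytes(batch=1, heads=32, seq_len=seq_len) / 2**30
    print(f"seq_len={seq_len}: {gib:.4f} GiB per layer")
```

The cost grows with the square of the sequence length, which is why the benefit is most visible on long inputs.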

Code Reference

Source Location

Signature

def forward(self, hidden_states, attention_mask, position_ids, past_key_value,
            output_attentions, use_cache)

def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                     inputs_embeds, past_key_values_length)

def enable_flash_attention(model=None)
def disable_flash_attention(model=None)
def upcast_layer_for_flash_attention(model, torch_dtype)

Import

from research.llm_embedder.src.utils.llama_patch import enable_flash_attention, upcast_layer_for_flash_attention

I/O Contract

Inputs

Name | Type | Required | Description
model | LlamaPreTrainedModel | No | Existing LLaMA model to patch (None patches the class)
torch_dtype | torch.dtype | Yes | Target dtype for upcast (torch.float16 or torch.bfloat16)
hidden_states | Tensor | Yes | Input hidden states [batch, seq_len, hidden_dim]
attention_mask | Tensor | No | 2D key padding mask [batch, seq_len]
position_ids | Tensor | No | Position IDs for RoPE

Outputs

Name | Type | Description
output | Tensor | Attention output [batch, seq_len, hidden_dim]
attention_weights | None | Always None (Flash Attention does not return weights)
past_key_value | Tuple | Cached KV for generation (if use_cache=True)

Usage Examples

from transformers import AutoModelForCausalLM
from research.llm_embedder.src.utils.llama_patch import enable_flash_attention, upcast_layer_for_flash_attention
import torch

# Patch before loading model (class-level)
enable_flash_attention(model=None)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16
)

# Or patch existing model (instance-level)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
enable_flash_attention(model=model)

# After quantization, upcast specific layers
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)
model = upcast_layer_for_flash_attention(model, torch.bfloat16)

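# Now use the model normally; Flash Attention is active
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
output = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])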

# To disable (e.g., for debugging)
from research.llm_embedder.src.utils.llama_patch import disable_flash_attention
disable_flash_attention(model=model)
