Implementation:Turboderp org Exllamav2 Ext QAttn

From Leeroopedia
Knowledge Sources
Domains: Attention, Quantization, CUDA, Tensor_Parallelism
Last Updated: 2026-02-15 00:00 GMT

Overview

C++ extension implementing quantized multi-head attention with fused LayerNorm, RoPE, Q/K/V projections, output projection, and optional LoRA adapters, including tensor-parallel variants with paged and standard KV cache support.

Description

ext_qattn.cpp provides the core attention computation pipeline for ExLlamaV2 transformer layers. The implementation is split into construction, two forward phases, LoRA configuration, and tensor-parallel variants:

  • make_q_attn -- Constructs a QAttn object that holds the full attention layer state: layernorm weights (with optional bias), Q/K/V/O projection matrices (as QMatrix* handles), temporary buffers, RoPE configuration, optional Q/K head norm weights (q_norm, k_norm), an optional post-attention layernorm, and the residual settings. Returns an opaque uintptr_t handle. Supports both RMS norm and standard layer norm, FP32 residuals, and optional CUDA graph capture.
  • q_attn_forward_1 -- Executes the first phase of attention: applies LayerNorm to the input, projects it through the quantized Q, K, and V matrices, and applies Rotary Position Embeddings (RoPE) using precomputed sin/cos tables. Accepts a scalar past_len as well as a per-sequence past_lens tensor for variable-length batches. Optionally applies LoRA adapters during projection.
  • q_attn_forward_2 -- Executes the second phase: takes the attention output (computed externally by Flash Attention), projects through the O matrix, and adds the residual connection. Also supports LoRA on the output projection.
  • q_attn_set_loras -- Configures LoRA adapter matrices for all four projections (Q, K, V, O). Each adapter is stored as a tuple of (A matrix pointer, B matrix pointer, rank). Returns the maximum rank across all adapters for buffer sizing.
  • tp_attn_forward_paged_ -- Tensor-parallel attention with paged KV cache. Broadcasts hidden states across devices, applies per-device layernorm and Q/K/V projections, performs RoPE, calls flash_attn_2_cuda.fwd_kvcache with block tables, then gathers and projects through O with residual addition. Supports multithreaded execution across devices via a thread pool and barrier synchronization.
  • tp_attn_forward_ -- Same as the paged variant but uses standard (contiguous) KV cache tensors without block tables.
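The RoPE step that q_attn_forward_1 performs with precomputed sin/cos tables can be illustrated in plain Python. This is a minimal sketch, not the extension's CUDA kernel: it assumes the rotate-half layout (one of several possible rope_style conventions), a base of 10000, and the hypothetical helper names build_sincos, apply_rope, and rope_sequence.

```python
import math

def build_sincos(max_pos, head_dim, base=10000.0):
    """Precompute sin/cos tables of shape [max_pos][head_dim // 2] (assumed layout)."""
    half = head_dim // 2
    inv_freq = [base ** (-2.0 * i / head_dim) for i in range(half)]
    sin = [[math.sin(p * f) for f in inv_freq] for p in range(max_pos)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in range(max_pos)]
    return sin, cos

def apply_rope(vec, pos, sin, cos):
    """Rotate one head vector at absolute position pos (rotate-half layout)."""
    half = len(vec) // 2
    out = [0.0] * len(vec)
    for i in range(half):
        x1, x2 = vec[i], vec[i + half]
        s, c = sin[pos][i], cos[pos][i]
        out[i] = x1 * c - x2 * s
        out[i + half] = x2 * c + x1 * s
    return out

def rope_sequence(vectors, past_len, sin, cos):
    """Apply RoPE to new tokens that follow past_len already-cached tokens."""
    return [apply_rope(v, past_len + t, sin, cos) for t, v in enumerate(vectors)]
```

The past_len offset is what makes the cached and the new tokens share one position axis: token t of the current call is rotated as position past_len + t. With per-sequence past lengths, each sequence in the batch would use its own offset.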

Both tensor-parallel functions import the Flash Attention 2 CUDA kernel dynamically at runtime via py::module_::import("flash_attn_2_cuda").
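Because the module is resolved only at call time, a missing Flash Attention install would presumably surface on the first tensor-parallel forward pass rather than at load time. A start-up check along the following lines can catch that earlier. This is a sketch using only the Python standard library; flash_attn_available is a hypothetical helper, not part of ExLlamaV2.

```python
import importlib.util

def flash_attn_available(module_name="flash_attn_2_cuda"):
    """Return True if the named extension module can be located for import."""
    return importlib.util.find_spec(module_name) is not None
```

Calling flash_attn_available() before selecting the tp_attn_forward_* path lets the caller fall back or fail with a clear message.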

Usage

Use make_q_attn during model initialization to create the attention layer handle, then call q_attn_forward_1 and q_attn_forward_2 during the forward pass (with Flash Attention called between them). Use q_attn_set_loras when loading LoRA adapters. Use the tp_attn_forward_* variants for multi-GPU tensor-parallel inference.

Code Reference

Source Location

Signature

uintptr_t make_q_attn(
    torch::Tensor layernorm,
    torch::Tensor layernorm_bias,
    bool layernorm_is_rms,
    bool headnorm_is_rms,
    float norm_epsilon,
    uintptr_t q_q_proj,
    uintptr_t q_k_proj,
    uintptr_t q_v_proj,
    uintptr_t q_o_proj,
    torch::Tensor temp_state,
    torch::Tensor temp_dq,
    int max_rows,
    int hidden_size,
    int num_heads,
    int num_kv_heads,
    int head_dim,
    int max_seq_len,
    bool has_residual,
    int rope_style,
    int sincos_size,
    torch::Tensor q_norm,
    torch::Tensor k_norm,
    torch::Tensor post_layernorm,
    torch::Tensor post_layernorm_bias,
    bool residual_fp32,
    bool use_graphs
);

void q_attn_forward_1(
    uintptr_t q_attn,
    torch::Tensor x,
    int batch_size,
    int q_len,
    int past_len,
    torch::Tensor past_lens,
    torch::Tensor q_temp,
    torch::Tensor k_temp,
    torch::Tensor v_temp,
    torch::Tensor sin,
    torch::Tensor cos,
    const std::vector<uintptr_t>& loras,
    torch::Tensor loras_temp
);

void q_attn_forward_2(
    uintptr_t q_attn,
    torch::Tensor x,
    torch::Tensor attn_output,
    int batch_size,
    int q_len,
    const std::vector<uintptr_t>& loras,
    torch::Tensor loras_temp
);

int q_attn_set_loras(
    uintptr_t q_attn,
    std::unordered_map<uintptr_t, torch::Tensor>& q_proj_lora_a,
    std::unordered_map<uintptr_t, torch::Tensor>& q_proj_lora_b,
    std::unordered_map<uintptr_t, torch::Tensor>& k_proj_lora_a,
    std::unordered_map<uintptr_t, torch::Tensor>& k_proj_lora_b,
    std::unordered_map<uintptr_t, torch::Tensor>& v_proj_lora_a,
    std::unordered_map<uintptr_t, torch::Tensor>& v_proj_lora_b,
    std::unordered_map<uintptr_t, torch::Tensor>& o_proj_lora_a,
    std::unordered_map<uintptr_t, torch::Tensor>& o_proj_lora_b
);

void tp_attn_forward_paged_(
    uintptr_t tp_context,
    torch::Tensor hidden_states,
    const std::vector<torch::Tensor> &temp_bc0_,
    const std::vector<torch::Tensor> &temp_bc1_,
    const std::vector<torch::Tensor> &temp_bc2_,
    const std::vector<torch::Tensor> &temp_q_,
    const std::vector<torch::Tensor> &temp_k_,
    const std::vector<torch::Tensor> &temp_v_,
    const std::vector<torch::Tensor> &temp_o_,
    const std::vector<torch::Tensor> &k_cache,
    const std::vector<torch::Tensor> &v_cache,
    const std::vector<torch::Tensor> &pre_layernorm,
    float norm_epsilon,
    const std::vector<uintptr_t> &q_proj,
    const std::vector<uintptr_t> &k_proj,
    const std::vector<uintptr_t> &v_proj,
    const std::vector<uintptr_t> &o_proj,
    int head_dim,
    int rope_style,
    int batch_size,
    int q_len,
    const std::vector<torch::Tensor> &sin,
    const std::vector<torch::Tensor> &cos,
    const std::vector<torch::Tensor> &past_lens,
    const std::vector<torch::Tensor> &block_index,
    float scaling
);

void tp_attn_forward_(
    uintptr_t tp_context,
    torch::Tensor hidden_states,
    /* ... same temp buffers and projections as paged variant ... */
    const std::vector<torch::Tensor> &past_len_tp,
    float scaling
);

Import

from exllamav2.ext import exllamav2_ext as ext_c

I/O Contract

Inputs

Parameter | Type | Description
layernorm | torch.Tensor (kHalf) | Pre-attention layer norm weights
q_q_proj, q_k_proj, q_v_proj, q_o_proj | uintptr_t | QMatrix handles for Q, K, V, O projection matrices
x | torch.Tensor (kHalf or kFloat) | Input hidden states; kFloat when residual_fp32=true
batch_size, q_len, past_len | int | Batch size, query length, and past context length
sin, cos | torch.Tensor (kHalf) | Precomputed RoPE sin/cos tables
loras | std::vector<uintptr_t> | Active LoRA adapter handles
attn_output | torch.Tensor (kHalf) | Output from Flash Attention (input to forward_2)
tp_context | uintptr_t | Tensor parallelism context handle (TP variants)
block_index | std::vector<torch::Tensor> | Per-device block tables (paged TP variant)
scaling | float | Attention softmax scaling factor (typically 1/sqrt(head_dim))
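To make the role of block_index more concrete, the sketch below shows how a block table for a paged KV cache is commonly read: each entry names the physical cache page that holds one fixed-size run of tokens. The page size of 256 and the helper names locate_token and pages_needed are assumptions for illustration only; the actual page size and tensor layout are defined by the cache and by Flash Attention.

```python
def locate_token(block_table, position, page_size=256):
    """Map a logical token position to (physical_page, offset_in_page)."""
    logical_page, offset = divmod(position, page_size)
    return block_table[logical_page], offset

def pages_needed(past_len, q_len, page_size=256):
    """Number of pages required to hold past_len + q_len tokens."""
    total = past_len + q_len
    return (total + page_size - 1) // page_size
```

For example, with block_table = [7, 2, 9], token 300 lives in physical page 2 at offset 44, since 300 = 1 * 256 + 44 and the second table entry is 2.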

Outputs

Function | Return | Description
make_q_attn | uintptr_t | Opaque handle to the constructed QAttn object
q_attn_forward_1 | void | Writes Q, K, V projections into temp tensors (with RoPE applied)
q_attn_forward_2 | void | Applies O projection and adds residual; modifies x in-place
q_attn_set_loras | int | Maximum LoRA rank across all configured adapters
tp_attn_forward_paged_ | void | Writes final output (with residual) into the output tensors
tp_attn_forward_ | void | Same as paged variant but for contiguous KV cache
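The integer returned by q_attn_set_loras is meant for sizing a shared temporary buffer (loras_temp), so it has to cover the largest adapter on any projection. The bookkeeping can be sketched as follows. This is an illustration working on shapes only; it assumes A is shaped (in_features, rank) and B is shaped (rank, out_features), and collect_adapters and max_lora_rank are hypothetical names.

```python
def collect_adapters(lora_a, lora_b):
    """Pair A/B shapes per adapter id into (a_shape, b_shape, rank) tuples.

    Assumes A has shape (in_features, rank) and B has shape (rank, out_features).
    """
    adapters = {}
    for key, a_shape in lora_a.items():
        b_shape = lora_b[key]
        rank = a_shape[1]
        if b_shape[0] != rank:
            raise ValueError(f"rank mismatch for adapter {key}")
        adapters[key] = (a_shape, b_shape, rank)
    return adapters

def max_lora_rank(*projections):
    """Largest rank over every adapter of every projection (0 if there are none)."""
    return max(
        (rank for proj in projections for (_, _, rank) in proj.values()),
        default=0,
    )
```

A caller would then allocate the temporary buffer once, using the returned maximum, instead of reallocating per adapter.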

Usage Examples

from exllamav2.ext import exllamav2_ext as ext_c

# Create attention layer during model init
attn_handle = ext_c.make_q_attn(
    layernorm, layernorm_bias,
    True,   # layernorm_is_rms
    False,  # headnorm_is_rms
    1e-6,   # norm_epsilon
    q_proj_handle, k_proj_handle, v_proj_handle, o_proj_handle,
    temp_state, temp_dq, max_rows,
    hidden_size, num_heads, num_kv_heads, head_dim, max_seq_len,
    True,   # has_residual
    rope_style, sincos_size,
    q_norm, k_norm, post_layernorm, post_layernorm_bias,
    False,  # residual_fp32
    False,  # use_graphs
)

# Forward pass phase 1: LayerNorm -> Q,K,V -> RoPE
ext_c.q_attn_forward_1(
    attn_handle, hidden_states, batch_size, q_len, past_len,
    past_lens, q_temp, k_temp, v_temp, sin, cos, loras, loras_temp
)

# ... Flash Attention computes attn_output from q_temp, k_cache, v_cache ...

# Forward pass phase 2: O projection + residual
ext_c.q_attn_forward_2(
    attn_handle, hidden_states, attn_output, batch_size, q_len,
    loras, loras_temp
)
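For reference, the computation that sits between the two phases can be written out in plain Python. The sketch below is not Flash Attention and not the extension's code: it handles one query vector and one head, omits causal masking and batching, and uses the hypothetical helpers softmax, attend, and kv_head_for. It does show where the scaling factor enters and how num_heads and num_kv_heads would relate under grouped-query attention (assuming num_heads is a multiple of num_kv_heads).

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values, scaling=None):
    """Single-head attention for one query vector over cached keys/values."""
    if scaling is None:
        scaling = 1.0 / math.sqrt(len(q))  # the typical 1/sqrt(head_dim)
    scores = [scaling * sum(a * b for a, b in zip(q, k)) for k in keys]
    weights = softmax(scores)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

def kv_head_for(q_head, num_heads, num_kv_heads):
    """Grouped-query mapping: the KV head that a given query head reads."""
    return q_head // (num_heads // num_kv_heads)
```

The result of this step, gathered over all heads, corresponds to attn_output, which q_attn_forward_2 then projects through O and adds to the residual.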
