Implementation:Mit han lab Llm awq CLIPAttentionFused

From Leeroopedia
Knowledge Sources
Domains: Optimization, Vision
Last Updated: 2026-02-15 00:00 GMT

Overview

Fused vision attention module that merges separate Q, K, V projection layers into a single matrix multiply for CLIP vision encoders.

Description

CLIPAttentionFused replaces the standard CLIPAttention module by fusing the separate q_proj, k_proj, and v_proj linear layers into a single qkv_proj linear layer. This cuts the Q/K/V projection work from three GEMM calls to one per attention layer, which can improve GPU throughput because fewer kernels are launched and each operates on a larger matrix. The rest of the attention computation is preserved, including causal and padding attention masks, dropout, and optional output of the attention weights.
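The equivalence behind the fusion can be checked with a minimal sketch. It assumes the fused weight is built by stacking the Q, K, and V weights along the output dimension and that the fused output is split back into three equal chunks; the actual module may reshape the result differently. The sizes and variable names below are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 8

# Three separate projections, as in an unfused attention module.
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# One fused projection. nn.Linear stores weights as (out_features, in_features),
# so stacking along dim 0 yields a (3 * embed_dim, embed_dim) weight.
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
    qkv_proj.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))

x = torch.randn(2, 5, embed_dim)  # (batch, seq_len, embed_dim)

with torch.no_grad():
    # A single GEMM, then split the last dimension back into Q, K, V.
    q, k, v = qkv_proj(x).chunk(3, dim=-1)

    # Compare against the three separate projections.
    fused_matches = (
        torch.allclose(q, q_proj(x), atol=1e-5)
        and torch.allclose(k, k_proj(x), atol=1e-5)
        and torch.allclose(v, v_proj(x), atol=1e-5)
    )
```

Because the fused layer computes the same linear maps, the outputs agree up to floating-point rounding; only the number of kernel launches changes.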

CLIPMLP provides a standard two-layer feed-forward network with configurable activation function (from HuggingFace ACT2FN). CLIPEncoderLayer composes the attention and MLP with layer normalization and residual connections in the standard pre-norm transformer pattern.
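The pre-norm pattern mentioned above can be sketched as follows. This is a hypothetical stand-in, not the llm-awq class: it uses PyTorch's built-in nn.MultiheadAttention and nn.GELU in place of the fused attention module and the ACT2FN activation, and the class name TinyPreNormLayer is invented for illustration.

```python
import torch
import torch.nn as nn


class TinyPreNormLayer(nn.Module):
    """Hypothetical stand-in for a pre-norm encoder layer (not the llm-awq class)."""

    def __init__(self, embed_dim: int, num_heads: int, mlp_dim: int):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        # Two-layer feed-forward network with an activation in between.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention block: normalize first, then add the residual.
        residual = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        hidden_states = residual + attn_out

        # MLP block: same normalize-then-residual pattern.
        residual = hidden_states
        hidden_states = residual + self.mlp(self.layer_norm2(hidden_states))
        return hidden_states


layer = TinyPreNormLayer(embed_dim=16, num_heads=4, mlp_dim=32).eval()
x = torch.randn(2, 5, 16)  # (batch, seq_len, embed_dim)
with torch.no_grad():
    y = layer(x)
```

In the pre-norm arrangement, layer normalization is applied to the input of each sub-block, and the residual connection adds the unnormalized input back to the sub-block output.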

The make_fused_vision_attn function traverses a model's module tree, identifies all CLIPAttention instances, concatenates their Q/K/V weight matrices (and biases, if present), constructs CLIPAttentionFused replacements, and swaps them in place. After each replacement, memory is freed via garbage collection and clearing of the CUDA cache.
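The traverse-concatenate-swap steps described above can be sketched on a toy model. Everything here is an assumption made for illustration: ToyAttention, ToyAttentionFused, fuse_qkv, and swap_in_fused are hypothetical names, and the real function may walk the tree and move tensors to the device differently.

```python
import gc

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical unfused module standing in for CLIPAttention."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)


class ToyAttentionFused(nn.Module):
    """Hypothetical fused replacement holding a single QKV projection."""

    def __init__(self, qkv_proj: nn.Linear):
        super().__init__()
        self.qkv_proj = qkv_proj


def fuse_qkv(attn: ToyAttention) -> nn.Linear:
    """Concatenate Q/K/V weights (and biases, if present) into one nn.Linear."""
    projs = [attn.q_proj, attn.k_proj, attn.v_proj]
    has_bias = attn.q_proj.bias is not None
    fused = nn.Linear(
        attn.q_proj.in_features,
        sum(p.out_features for p in projs),
        bias=has_bias,
    )
    with torch.no_grad():
        fused.weight.copy_(torch.cat([p.weight for p in projs], dim=0))
        if has_bias:
            fused.bias.copy_(torch.cat([p.bias for p in projs], dim=0))
    return fused


def swap_in_fused(model: nn.Module) -> int:
    """Replace every nested ToyAttention with ToyAttentionFused; return the count."""
    # Snapshot the matches first so the tree is not mutated while iterating.
    # The root module (empty name) is skipped because it has no parent to patch.
    targets = [
        (name, m) for name, m in model.named_modules()
        if name and isinstance(m, ToyAttention)
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, ToyAttentionFused(fuse_qkv(module)))
        # Release the old projections after each replacement.
        del module
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return len(targets)


model = nn.Sequential(
    nn.Sequential(ToyAttention(8)),
    nn.Sequential(ToyAttention(8)),
)
num_swapped = swap_in_fused(model)
```

Collecting the matching modules before mutating the tree avoids changing the module dictionary while it is being iterated.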

Usage

Call make_fused_vision_attn(model, dev) after loading a CLIP-based vision model to fuse all attention projections before inference. This is a one-time model transformation step.

Code Reference

Source Location

Signature

class CLIPAttentionFused(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_proj, out_proj, dev,
                 attention_dropout=0.0):
        """Fused multi-head attention with single QKV projection."""
    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                causal_attention_mask: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = False,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

class CLIPMLP(nn.Module):
    def __init__(self, config): ...
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ...

class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: CLIPConfig): ...
    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor,
                causal_attention_mask: torch.Tensor,
                output_attentions: Optional[bool] = False) -> Tuple[torch.FloatTensor]: ...

def make_fused_vision_attn(model, dev):
    """Replace all CLIPAttention modules with CLIPAttentionFused, fusing Q/K/V projections."""

Import

from tinychat.modules.fused_vision_attn import make_fused_vision_attn

I/O Contract

Inputs (CLIPAttentionFused.forward)

| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input tensor of shape (batch, seq_len, embed_dim) |
| attention_mask | torch.Tensor | No | Padding attention mask of shape (batch, 1, tgt_len, src_len) |
| causal_attention_mask | torch.Tensor | No | Causal mask of shape (batch, 1, tgt_len, src_len) |
| output_attentions | bool | No | Whether to return attention weights (default: False) |
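The mask shapes listed above broadcast over the head dimension of the attention scores. The sketch below assumes additive masks (zero where attention is allowed, a large negative value where it is blocked), which is the convention used by the HuggingFace CLIP attention this module replaces; the tensor sizes are illustrative.

```python
import torch

torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

# Scaled dot-product scores: (batch, num_heads, tgt_len, src_len).
scores = (q @ k.transpose(-1, -2)) * head_dim ** -0.5

neg = torch.finfo(scores.dtype).min

# Causal mask: block every position above the diagonal (future tokens).
causal_attention_mask = (
    torch.full((seq_len, seq_len), neg).triu(1).expand(batch, 1, seq_len, seq_len)
)

# Padding mask: pretend the last two source positions of sample 1 are padding.
attention_mask = torch.zeros(batch, 1, seq_len, seq_len)
attention_mask[1, :, :, -2:] = neg

# Both masks have a singleton head dimension, so they broadcast across heads.
scores = scores + causal_attention_mask
scores = scores + attention_mask
attn_weights = scores.softmax(dim=-1)
```

Each row of attn_weights still sums to one, while blocked positions receive zero weight.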

Outputs (CLIPAttentionFused.forward)

| Name | Type | Description |
|---|---|---|
| attn_output | torch.Tensor | Attention output of shape (batch, seq_len, embed_dim) |
| attn_weights_reshaped | Optional[torch.Tensor] | Attention weights if output_attentions=True, else None |

Inputs (make_fused_vision_attn)

| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Model containing CLIPAttention modules to fuse |
| dev | torch.device | Yes | Target device to move the fused model to |

Outputs (make_fused_vision_attn)

| Name | Type | Description |
|---|---|---|
| (in-place) | None | Modifies model in place, replacing CLIPAttention with CLIPAttentionFused; moves model to dev |

Usage Examples

Fuse Vision Attention at Model Load Time

import torch
from tinychat.modules.fused_vision_attn import make_fused_vision_attn

# Load a CLIP-based vision model (e.g., LLaVA vision tower).
# load_clip_vision_model is a placeholder for your own loading code.
vision_model = load_clip_vision_model()

# Fuse Q/K/V projections into a single GEMM per attention layer (in place)
make_fused_vision_attn(vision_model, torch.device("cuda:0"))

# All CLIPAttention layers now use the fused QKV projection.
# pixel_values is a placeholder for a preprocessed image batch on the same device.
with torch.no_grad():
    output = vision_model(pixel_values)

Direct Construction

import torch
import torch.nn as nn
from tinychat.modules.fused_vision_attn import CLIPAttentionFused

# Manually create a fused attention layer
qkv_proj = nn.Linear(768, 768 * 3)  # fused Q+K+V: one output block per projection
out_proj = nn.Linear(768, 768)
fused_attn = CLIPAttentionFused(
    hidden_size=768, num_heads=12,  # 768 / 12 = 64 dimensions per head
    qkv_proj=qkv_proj, out_proj=out_proj,
    dev=torch.device("cuda:0")
)
