Implementation:Mit han lab Llm awq CLIPAttentionFused
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Vision |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Fused vision attention module that merges separate Q, K, V projection layers into a single matrix multiply for CLIP vision encoders.
Description
CLIPAttentionFused replaces the standard CLIPAttention module by fusing the separate q_proj, k_proj, and v_proj linear layers into a single qkv_proj linear layer. This cuts the Q/K/V projection from three GEMM calls to one per attention layer, which can improve throughput on GPU hardware. The rest of the attention computation is preserved, including causal and padding attention masks, dropout, and optional attention weight output.
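The equivalence behind the fusion can be checked without PyTorch. The following is a minimal sketch, not the module's actual code: it assumes the nn.Linear weight layout (one row per output feature), and the `linear` helper and toy values are illustrative only. Stacking the three weight matrices along the output dimension and slicing the fused result gives the same Q, K, and V as three separate projections.

```python
# Pure-Python sketch (no PyTorch) of why stacking Q/K/V weights is lossless.
# Weights follow the nn.Linear layout: one row per output feature.

def linear(x, weight, bias):
    """Compute y = x @ weight.T + bias for a single input vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

# Toy projections with embed_dim = 2 (values are arbitrary).
w_q, b_q = [[1.0, 2.0], [0.0, 1.0]], [0.5, -0.5]
w_k, b_k = [[3.0, 0.0], [1.0, 1.0]], [0.0, 1.0]
w_v, b_v = [[-1.0, 1.0], [2.0, 2.0]], [1.0, 0.0]

# Fuse: concatenate along the output dimension (rows); biases likewise.
w_qkv = w_q + w_k + w_v
b_qkv = b_q + b_k + b_v

x = [1.0, -2.0]
fused = linear(x, w_qkv, b_qkv)  # one projection instead of three

# Split the fused output back into three embed_dim-sized chunks.
embed_dim = len(w_q)
q, k, v = (fused[i * embed_dim:(i + 1) * embed_dim] for i in range(3))
```

Here `q`, `k`, and `v` match `linear(x, w_q, b_q)`, `linear(x, w_k, b_k)`, and `linear(x, w_v, b_v)` exactly, because each output row is computed independently of the others.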
CLIPMLP provides a standard two-layer feed-forward network with configurable activation function (from HuggingFace ACT2FN). CLIPEncoderLayer composes the attention and MLP with layer normalization and residual connections in the standard pre-norm transformer pattern.
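The pre-norm pattern mentioned above can be sketched as follows. This is an illustration of the data flow only: `encoder_layer` and the stand-in sublayers are hypothetical names, and the real CLIPEncoderLayer operates on tensors with learned layer norms, attention, and MLP.

```python
def encoder_layer(x, norm1, attn, norm2, mlp):
    """Pre-norm block: normalize, transform, then add the residual."""
    x = [a + b for a, b in zip(x, attn(norm1(x)))]  # attention sub-block
    x = [a + b for a, b in zip(x, mlp(norm2(x)))]   # feed-forward sub-block
    return x

# Stand-in sublayers (hypothetical, chosen only to make the flow visible).
def identity(v):
    return v

def double(v):
    return [2.0 * e for e in v]

def halve(v):
    return [0.5 * e for e in v]

out = encoder_layer([1.0, 2.0], identity, double, identity, halve)
```

Each sub-block sees a normalized copy of its input, while the residual path carries the unnormalized value forward.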
The make_fused_vision_attn function traverses a model's module tree, identifies all CLIPAttention instances, concatenates their Q/K/V weight matrices (and biases if present), constructs CLIPAttentionFused replacements, and swaps them in-place. Memory is freed via garbage collection and CUDA cache clearing after each replacement.
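The traverse-and-swap step can be illustrated with a torch-free stand-in. Everything in this sketch is hypothetical (`Node`, `Attention`, `FusedAttention`, `fuse_all`); the real function walks an nn.Module tree and concatenates actual weight tensors. The sketch only shows the pattern: find each attention child, build its fused replacement from the old projections, and rebind the attribute on the parent.

```python
class Node:
    """Hypothetical stand-in for nn.Module: children are plain attributes."""
    def children(self):
        # Snapshot as a list so attributes can be rebound while iterating.
        return [(n, c) for n, c in vars(self).items() if isinstance(c, Node)]

class Attention(Node):
    def __init__(self):
        self.q_proj, self.k_proj, self.v_proj = "Wq", "Wk", "Wv"

class FusedAttention(Node):
    def __init__(self, old):
        # Stands in for concatenating the three weight matrices.
        self.qkv_proj = (old.q_proj, old.k_proj, old.v_proj)

class Layer(Node):
    def __init__(self):
        self.self_attn = Attention()

class Encoder(Node):
    def __init__(self):
        self.layer0, self.layer1 = Layer(), Layer()

def fuse_all(root):
    """Recursively replace every Attention child with a FusedAttention."""
    replaced = 0
    for name, child in root.children():
        if isinstance(child, Attention):
            setattr(root, name, FusedAttention(child))  # in-place swap
            replaced += 1
        else:
            replaced += fuse_all(child)
    return replaced

model = Encoder()
count = fuse_all(model)
```

Rebinding on the parent is what makes the change in-place: callers that hold a reference to the top-level model see the fused layers without reloading anything.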
Usage
Call make_fused_vision_attn(model, dev) after loading a CLIP-based vision model to fuse all attention projections before inference. This is a one-time model transformation step.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/modules/fused_vision_attn.py
- Lines: 1-272
Signature
class CLIPAttentionFused(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_proj, out_proj, dev,
                 attention_dropout=0.0):
        """Fused multi-head attention with single QKV projection."""
    def forward(self, hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                causal_attention_mask: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = False
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

class CLIPMLP(nn.Module):
    def __init__(self, config): ...
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ...

class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: CLIPConfig): ...
    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor,
                causal_attention_mask: torch.Tensor,
                output_attentions: Optional[bool] = False) -> Tuple[torch.FloatTensor]: ...

def make_fused_vision_attn(model, dev):
    """Replace all CLIPAttention modules with CLIPAttentionFused, fusing Q/K/V projections."""
Import
from tinychat.modules.fused_vision_attn import make_fused_vision_attn
I/O Contract
Inputs (CLIPAttentionFused.forward)
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input tensor of shape (batch, seq_len, embed_dim) |
| attention_mask | torch.Tensor | No | Padding attention mask of shape (batch, 1, tgt_len, src_len) |
| causal_attention_mask | torch.Tensor | No | Causal mask of shape (batch, 1, tgt_len, src_len) |
| output_attentions | bool | No | Whether to return attention weights (default: False) |
Outputs (CLIPAttentionFused.forward)
| Name | Type | Description |
|---|---|---|
| attn_output | torch.Tensor | Attention output of shape (batch, seq_len, embed_dim) |
| attn_weights_reshaped | Optional[torch.Tensor] | Attention weights if output_attentions=True, else None |
Inputs (make_fused_vision_attn)
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Model containing CLIPAttention modules to fuse |
| dev | torch.device | Yes | Target device to move the fused model to |
Outputs (make_fused_vision_attn)
| Name | Type | Description |
|---|---|---|
| (in-place) | None | Modifies model in-place, replacing CLIPAttention with CLIPAttentionFused; moves model to dev |
Usage Examples
Fuse Vision Attention at Model Load Time
from tinychat.modules.fused_vision_attn import make_fused_vision_attn
import torch

# Load a CLIP-based vision model (e.g., LLaVA vision tower).
# load_clip_vision_model() is a placeholder for your own loading code.
vision_model = load_clip_vision_model()

# Fuse Q/K/V projections into a single GEMM per attention layer
make_fused_vision_attn(vision_model, torch.device("cuda:0"))

# All CLIPAttention layers now use the fused QKV projection.
# Inference proceeds as normal with improved throughput;
# pixel_values is a placeholder for a preprocessed image batch.
output = vision_model(pixel_values)
Direct Construction
from tinychat.modules.fused_vision_attn import CLIPAttentionFused
import torch
import torch.nn as nn

# Manually create a fused attention layer
qkv_proj = nn.Linear(768, 768 * 3)  # fused Q+K+V: output is 3x the hidden size
out_proj = nn.Linear(768, 768)
fused_attn = CLIPAttentionFused(
    hidden_size=768, num_heads=12,
    qkv_proj=qkv_proj, out_proj=out_proj,
    dev=torch.device("cuda:0")
)