

Principle:VainF Torch Pruning Transformer Head Pruning

From Leeroopedia


Property       Value
Papers         DepGraph; Are Sixteen Heads Really Better than One? (Michel et al., 2019)
Domains        Deep_Learning, Transformers, Pruning
Last Updated   2026-02-08 00:00 GMT

Overview

Configuration and execution of attention head pruning in transformer architectures, supporting both head dimension reduction and entire head removal.

Description

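Transformer models use multi-head attention (MHA) with Q/K/V projections. Pruning these projections requires special handling because their output channels are grouped by head: channels must be removed so that every head keeps the same size, or whole heads must be removed together. Two strategies exist: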

  1. prune_head_dims -- reduces the dimension within each head. The head_dim decreases while num_heads stays the same.
  2. prune_num_heads -- removes entire attention heads. The num_heads decreases while head_dim stays the same.
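The two strategies differ in which output channels of a Q/K/V projection survive. A minimal sketch, assuming the usual head-major layout in which channel c belongs to head c // head_dim; the helper names are hypothetical and this is not Torch-Pruning's own selection logic:

```python
def kept_channels_head_dims(num_heads: int, head_dim: int, keep_per_head: list) -> list:
    """prune_head_dims: every head keeps the same number of dims, so num_heads is unchanged."""
    assert len(keep_per_head) == num_heads
    # Equal counts per head are what keep head_dim uniform after pruning
    # (in this sketch, which dims survive may differ from head to head).
    assert len({len(dims) for dims in keep_per_head}) == 1
    return [h * head_dim + d for h, dims in enumerate(keep_per_head) for d in sorted(dims)]


def kept_channels_num_heads(num_heads: int, head_dim: int, keep_heads: list) -> list:
    """prune_num_heads: surviving heads keep all head_dim channels; the rest are dropped whole."""
    assert all(0 <= h < num_heads for h in keep_heads)
    return [h * head_dim + d for h in sorted(keep_heads) for d in range(head_dim)]


# 2 heads x 4 dims = 8 channels, laid out as [h0d0..h0d3, h1d0..h1d3].
print(kept_channels_head_dims(2, 4, [[0, 1], [1, 3]]))  # [0, 1, 5, 7]: still 2 heads, head_dim 2
print(kept_channels_num_heads(2, 4, [1]))               # [4, 5, 6, 7]: 1 head, head_dim still 4
```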

The pruner needs a num_heads dictionary mapping QKV projection layers to their head count. The structure of this dictionary varies by model framework:

  • HuggingFace models -- Q, K, and V are separate nn.Linear layers (e.g., m.query, m.key, m.value). Each must be registered individually in the num_heads dict.
  • timm models -- Q, K, and V are often fused into a single qkv projection layer. Only the fused layer needs to be registered.
  • LLMs (Llama, Phi, Qwen) -- Separate q_proj, k_proj, v_proj layers, and sometimes a fused qkv_proj. K/V may use fewer heads than Q when Grouped Query Attention (GQA) is employed.
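For the HuggingFace and timm cases, the registration can be sketched as below. The toy modules only mimic the attribute names mentioned above (separate query/key/value layers versus one fused qkv layer); they are hypothetical stand-ins rather than real HuggingFace or timm classes, and the resulting dict is the kind of mapping the pruner expects as num_heads:

```python
import torch.nn as nn


class HFStyleSelfAttention(nn.Module):
    """Hypothetical stand-in: separate Q/K/V layers, as in HuggingFace ViT/BERT."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.num_attention_heads = heads
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)


class TimmStyleAttention(nn.Module):
    """Hypothetical stand-in: one fused qkv layer, as in timm ViT."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.num_heads = heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)


def build_num_heads(model: nn.Module) -> dict:
    """Map each QKV projection layer to its head count."""
    num_heads = {}
    for m in model.modules():
        if isinstance(m, HFStyleSelfAttention):
            # Separate projections: each one is registered individually.
            for proj in (m.query, m.key, m.value):
                num_heads[proj] = m.num_attention_heads
        elif isinstance(m, TimmStyleAttention):
            # Fused projection: only the qkv layer is registered.
            num_heads[m.qkv] = m.num_heads
    return num_heads


model = nn.Sequential(HFStyleSelfAttention(64, 4), TimmStyleAttention(64, 8))
num_heads = build_num_heads(model)
print(len(num_heads))  # 4: query, key, value, plus one fused qkv
```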

Post-pruning, model attributes (num_heads, head_dim, all_head_size) must be updated manually by the user. The pruner tracks updated head counts in pruner.num_heads, but the model's own attributes are not automatically synchronized.
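A minimal sketch of that manual synchronization, assuming HuggingFace ViT-style attribute names (num_attention_heads, attention_head_size, all_head_size) and a pruner.num_heads dict keyed by the query layer; the classes and the helper are hypothetical stand-ins for already-pruned modules:

```python
class PrunedLinear:
    """Hypothetical stand-in for a pruned nn.Linear; only out_features is needed here."""

    def __init__(self, out_features: int):
        self.out_features = out_features


class SelfAttention:
    """Hypothetical stand-in for an attention module whose size attributes are now stale."""

    def __init__(self, query: PrunedLinear, num_attention_heads: int, attention_head_size: int):
        self.query = query
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = num_attention_heads * attention_head_size


def sync_attention_attrs(attn: SelfAttention, pruner_num_heads: dict) -> None:
    """Copy the pruner's head count into the module and recompute the derived sizes."""
    heads = pruner_num_heads[attn.query]          # updated count tracked by the pruner
    attn.num_attention_heads = heads
    attn.all_head_size = attn.query.out_features  # width of the concatenated heads
    attn.attention_head_size = attn.query.out_features // heads


# ViT-B-like block (12 heads x 64 dims) after prune_num_heads removed 3 heads: 576 channels remain.
q = PrunedLinear(576)
attn = SelfAttention(q, num_attention_heads=12, attention_head_size=64)
sync_attention_attrs(attn, {q: 9})
print(attn.num_attention_heads, attn.attention_head_size, attn.all_head_size)  # 9 64 576
```

With prune_head_dims instead, the pruner's count for the layer stays at 12, and the same function yields a head size of 48.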

Usage

Use when pruning Vision Transformers (ViT, DeiT, Swin), BERT, or LLMs (Llama, Phi). It is needed whenever the attention layers of a transformer are pruned. The choice between prune_head_dims and prune_num_heads depends on downstream requirements:

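  • prune_num_heads=True is preferred for LLMs because it leaves head_dim unchanged and so avoids disrupting positional encodings like RoPE, whose rotation frequencies are laid out across each head's dimensions.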
  • prune_head_dims=True is useful for vision transformers where per-head dimension reduction is acceptable.
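One way to see why head_dim matters for RoPE: rotary embeddings give each pair of dimensions a frequency theta_i = base^(-2i / head_dim), so an implementation that derives its frequencies from head_dim (as is common) assigns different frequencies to the surviving dimensions after head-dim pruning, and removing individual dimensions can also split rotated pairs. A minimal sketch, assuming the standard base of 10000; rope_inv_freq is a hypothetical helper:

```python
import math


def rope_inv_freq(head_dim: int, base: float = 10000.0) -> list:
    """Rotary frequencies theta_i = base ** (-2i / head_dim) for i = 0 .. head_dim/2 - 1."""
    assert head_dim % 2 == 0, "RoPE rotates dimensions in pairs"
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


full = rope_inv_freq(128)   # e.g. an unpruned 128-dim head
pruned = rope_inv_freq(96)  # the same head after prune_head_dims removed 32 dims

print(len(full), len(pruned))  # 64 48
# Already the second pair rotates at a different rate than during training.
print(math.isclose(full[1], pruned[1]))  # False
```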

Theoretical Basis

Multi-head attention is defined as:

$$\mathrm{MHA}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O$$

where each head computes:

$$\mathrm{head}_i = \mathrm{Attention}(Q W^Q_i,\; K W^K_i,\; V W^V_i)$$
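These two formulas can be checked numerically. A minimal NumPy sketch of self-attention (Q = K = V = X), assuming d_k = d_v and that head i owns a contiguous block of columns in W^Q, W^K, W^V; it also verifies that Concat(head_1, ..., head_h) W^O equals the sum of each head multiplied by its own block of rows of W^O, which is why a single head can be removed cleanly:

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention for one head."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v


def mha(x, w_q, w_k, w_v, w_o, h):
    """MHA(X, X, X) = Concat(head_1..head_h) W^O; also returns per-head outputs and slices."""
    d_k = w_q.shape[1] // h
    cols = [slice(i * d_k, (i + 1) * d_k) for i in range(h)]  # columns owned by head i
    heads = [attention(x @ w_q[:, c], x @ w_k[:, c], x @ w_v[:, c]) for c in cols]
    return np.concatenate(heads, axis=-1) @ w_o, heads, cols


rng = np.random.default_rng(0)
n, d_model, h = 5, 16, 4
x = rng.standard_normal((n, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) for _ in range(4))

out, heads, cols = mha(x, w_q, w_k, w_v, w_o, h)
# Block decomposition: head i only ever touches rows cols[i] of W^O.
per_head_sum = sum(head @ w_o[c, :] for head, c in zip(heads, cols))
print(out.shape, np.allclose(out, per_head_sum))  # (5, 16) True
```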

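Head-dimension pruning shrinks $d_k$, the number of columns in each per-head projection $W^Q_i, W^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$; with the usual choice $d_v = d_k$, the value projections $W^V_i$ and the matching rows of $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ shrink with them. The number of heads $h$ remains constant, but each head operates on a smaller subspace.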

Head removal eliminates entire $\mathrm{head}_i$ terms from the concatenation, along with that head's columns in $W^Q$, $W^K$, $W^V$ and its $d_v$ rows of $W^O$. The remaining heads keep their full $d_k$ dimension.
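The arithmetic below compares the two strategies on the Q/K/V and output projections of a ViT-B-sized layer (d_model = 768, 12 heads of 64 dims), assuming every projection has a bias; attn_proj_params is a hypothetical helper. Removing 3 heads and trimming every head from 64 to 48 dims both shrink the concatenated width to 576 channels, so the projection parameter counts come out identical:

```python
def attn_proj_params(d_model: int, num_heads: int, head_dim: int, bias: bool = True) -> int:
    """Parameters in W^Q, W^K, W^V (d_model -> h*d_k) plus W^O (h*d_k -> d_model)."""
    inner = num_heads * head_dim  # width of Concat(head_1, ..., head_h)
    qkv = 3 * (d_model * inner + (inner if bias else 0))
    out = inner * d_model + (d_model if bias else 0)
    return qkv + out


base = attn_proj_params(768, 12, 64)
head_dims = attn_proj_params(768, 12, 48)  # prune_head_dims: d_k 64 -> 48
num_heads = attn_proj_params(768, 9, 64)   # prune_num_heads: h 12 -> 9
print(base, head_dims, num_heads)  # 2362368 1771968 1771968
```

The budgets match, but the shapes do not: head removal leaves 9 rather than 12 n-by-n attention maps per layer, while head-dim pruning keeps all 12 maps and makes each QK^T product cheaper.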

Grouped Query Attention (GQA), used in models such as Llama-2 70B and Llama-3, lets several query heads share one K/V head, so the K/V projections have fewer heads than Q. This requires the out_channel_groups mapping so the pruner understands that multiple Q heads correspond to a single K/V head. In the codebase, this is handled by setting:

num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
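Under GQA, query head i reads K/V head i // (n_q / n_kv), matching the repeat-then-reshape expansion of K/V heads used in HuggingFace Llama-style code, so removing a K/V head keeps the model consistent only if its whole group of query heads goes with it. A minimal sketch of that bookkeeping, with hypothetical helper names and example counts of 32 query and 8 K/V heads; it illustrates the constraint and is not Torch-Pruning's internal implementation:

```python
def kv_head_of(q_head: int, n_q: int, n_kv: int) -> int:
    """Index of the K/V head that query head `q_head` attends with."""
    assert n_q % n_kv == 0, "GQA needs n_q to be a multiple of n_kv"
    return q_head // (n_q // n_kv)


def q_heads_for_kept_kv(kept_kv: list, n_q: int, n_kv: int) -> list:
    """Query heads that survive when only the K/V heads in `kept_kv` are kept."""
    group = n_q // n_kv  # query heads per K/V head
    return [kv * group + j for kv in sorted(kept_kv) for j in range(group)]


n_q, n_kv = 32, 8  # e.g. num_attention_heads / num_key_value_heads from a model config
print(kv_head_of(5, n_q, n_kv))                # 1
print(q_heads_for_kept_kv([0, 2], n_q, n_kv))  # [0, 1, 2, 3, 8, 9, 10, 11]
```

Keeping 2 of the 8 K/V heads this way leaves 8 query heads, so the 4:1 grouping ratio is preserved.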
