Principle: VainF Torch-Pruning Transformer Head Pruning
| Property | Value |
|---|---|
| Papers | DepGraph, Are Sixteen Heads Really Better than One? (Michel et al., 2019) |
| Domains | Deep_Learning, Transformers, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Configuration and execution of attention head pruning in transformer architectures, supporting both head dimension reduction and entire head removal.
Description
Transformer models use multi-head attention (MHA) with Q/K/V projections. Pruning these requires special handling because channels are grouped by heads. Two strategies exist:
- `prune_head_dims` -- reduces the dimension within each head. `head_dim` decreases while `num_heads` stays the same.
- `prune_num_heads` -- removes entire attention heads. `num_heads` decreases while `head_dim` stays the same.
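The difference between the two strategies comes down to which flat channel indices are removed from a projection's output, since channels are laid out head by head. A minimal pure-Python sketch of that index arithmetic (the helper names are illustrative, not Torch-Pruning API):

```python
def head_dim_prune_idxs(num_heads, head_dim, dims_to_drop):
    """prune_head_dims: drop the same local dims inside every head,
    so head_dim shrinks while num_heads is unchanged."""
    return [h * head_dim + d for h in range(num_heads) for d in dims_to_drop]

def head_remove_idxs(num_heads, head_dim, heads_to_drop):
    """prune_num_heads: drop whole contiguous head blocks,
    so num_heads shrinks while head_dim is unchanged."""
    return [h * head_dim + d for h in heads_to_drop for d in range(head_dim)]

# 4 heads of dim 8 -> 32 output channels per projection
assert head_dim_prune_idxs(4, 8, [6, 7]) == [6, 7, 14, 15, 22, 23, 30, 31]
assert head_remove_idxs(4, 8, [3]) == list(range(24, 32))
```

Note that head-dim pruning scatters removals across every head block, while head removal takes contiguous slices; this is why the pruner must know the head layout rather than treating the projection as a flat channel list.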
The pruner needs a `num_heads` dictionary mapping QKV projection layers to their head count. The structure of this dictionary varies by model framework:
- HuggingFace models -- Q, K, and V are separate `nn.Linear` layers (e.g., `m.query`, `m.key`, `m.value`). Each must be registered individually in the `num_heads` dict.
- timm models -- Q, K, and V are often fused into a single `qkv` projection layer. Only the fused layer needs to be registered.
- LLMs (Llama, Phi, Qwen) -- separate `q_proj`, `k_proj`, `v_proj` layers, and sometimes a fused `qkv_proj`. K/V may use fewer heads than Q when Grouped Query Attention (GQA) is employed.
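The registration pattern for the first two cases can be sketched as follows. The module objects are mocked with a plain class here so the snippet is self-contained; in practice the keys are the model's actual `nn.Linear` layers:

```python
class Linear:
    """Stand-in for nn.Linear; dict keys are the layer objects themselves."""
    def __init__(self, name):
        self.name = name

# HuggingFace-style attention: separate query/key/value projections
query, key, value = Linear("query"), Linear("key"), Linear("value")
# timm-style attention: one fused qkv projection
qkv = Linear("qkv")

n_heads = 12
num_heads = {}
for layer in (query, key, value):   # register each projection individually
    num_heads[layer] = n_heads
num_heads[qkv] = n_heads            # fused layer: a single entry suffices

assert len(num_heads) == 4
```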
Post-pruning, model attributes (`num_heads`, `head_dim`, `all_head_size`) must be updated manually by the user. The pruner tracks updated head counts in `pruner.num_heads`, but the model's own attributes are not automatically synchronized.
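That manual sync step can be sketched as below. The attribute names follow the HF BERT convention, and the pruner's bookkeeping is mocked as a plain dict; only the assignment pattern is the point:

```python
class Attention:
    """Minimal stand-in for an attention module's bookkeeping attributes."""
    def __init__(self, num_heads, head_dim):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.all_head_size = num_heads * head_dim

attn = Attention(num_heads=12, head_dim=64)
qkv_layer = object()                  # stand-in for attn's projection layer
pruner_num_heads = {qkv_layer: 8}     # pruner's updated count after head removal

# Manual sync: the pruner does NOT write these back for you.
attn.num_heads = pruner_num_heads[qkv_layer]
attn.all_head_size = attn.num_heads * attn.head_dim

assert (attn.num_heads, attn.all_head_size) == (8, 512)
```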
Usage
Use when pruning Vision Transformers (ViT, DeiT, Swin), BERT, or LLMs (Llama, Phi). Required for any transformer-based pruning workflow. The choice between `prune_head_dims` and `prune_num_heads` depends on the downstream requirements:
- `prune_num_heads=True` is preferred for LLMs because it avoids disrupting positional encodings like RoPE.
- `prune_head_dims=True` is useful for vision transformers where per-head dimension reduction is acceptable.
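Under those guidelines, a pruner configuration for an LLM might look like the fragment below. The argument names (`num_heads`, `prune_num_heads`, `prune_head_dims`, `head_pruning_ratio`) follow recent Torch-Pruning releases but should be checked against your installed version; `model`, `example_inputs`, and the `num_heads` dict are assumed to exist already, so this is a non-runnable sketch:

```python
import torch_pruning as tp

imp = tp.importance.GroupNormImportance()
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    num_heads=num_heads,     # QKV projection layer -> head count mapping
    prune_num_heads=True,    # remove whole heads (RoPE-safe for LLMs)
    prune_head_dims=False,   # keep head_dim intact
    head_pruning_ratio=0.5,  # fraction of heads to remove
)
pruner.step()
```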
Theoretical Basis
Multi-head attention is defined as:
MHA(Q, K, V) = Concat(head_1, ..., head_h) W^O
where each head computes:
head_i = Attention(Q W^Q_i, K W^K_i, V W^V_i)
Head dimension pruning reduces d_k in the projection matrices W^Q_i in R^{d_model x d_k}. The number of heads h remains constant, but each head operates on a smaller subspace.
Head removal eliminates entire head_i terms from the concatenation. The remaining heads keep their full d_k dimension.
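The two strategies can reach the same parameter budget through different shapes. A quick numeric check with d_model = 768 and 12 heads of dimension 64 (the standard BERT-base configuration):

```python
d_model, h, d_k = 768, 12, 64

# prune_head_dims: keep all 12 heads, shrink each from 64 to 48 dims
params_dim_pruned = d_model * (h * 48)
# prune_num_heads: keep 64-dim heads, drop 3 of the 12 heads
params_head_pruned = d_model * ((h - 3) * d_k)

# Both remove 25% of each projection's parameters.
assert params_dim_pruned == params_head_pruned == 768 * 576
```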
Grouped Query Attention (GQA) is used in models like Llama-2 and Llama-3, where K/V share fewer heads than Q. This requires the `out_channel_groups` mapping so the pruner understands that multiple Q heads correspond to a single K/V head. In the codebase, this is handled by setting:

```python
num_heads[m.q_proj] = model.config.num_attention_heads
num_heads[m.k_proj] = model.config.num_key_value_heads
num_heads[m.v_proj] = model.config.num_key_value_heads
```
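A consequence of this grouping is that the Q head count must remain an integer multiple of the K/V head count after pruning: removing one K/V head implies removing its whole group of Q heads. A small sketch of that constraint (illustrative arithmetic, not library code):

```python
num_attention_heads, num_key_value_heads = 32, 8   # e.g., a Llama-3-style config
assert num_attention_heads % num_key_value_heads == 0
group = num_attention_heads // num_key_value_heads  # Q heads per K/V head

# Removing K/V heads removes `group` Q heads along with each one.
kv_heads_dropped = 2
kv_heads_after = num_key_value_heads - kv_heads_dropped
q_heads_after = num_attention_heads - kv_heads_dropped * group

assert q_heads_after % kv_heads_after == 0          # grouping ratio preserved
assert q_heads_after // kv_heads_after == group
```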