Implementation: VainF Torch-Pruning Head Pruning Config
| Property | Value |
|---|---|
| Source | [Torch-Pruning](https://github.com/VainF/Torch-Pruning) |
| Domains | Deep_Learning, Transformers, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete pattern for configuring transformer head pruning in Torch-Pruning.
Description
This is a Pattern Doc. Users construct a num_heads dictionary mapping each QKV projection layer to its head count, then pass it to BasePruner along with the prune_num_heads/prune_head_dims flags. The pattern varies by model framework (timm vs. HuggingFace vs. LLM). After pruning, the model's attention attributes (e.g. num_heads, head_dim) must be updated manually, because the pruner rewrites layer weights but not Python-level attributes.
The three-step workflow is:
- Build the num_heads dict by iterating over model modules and identifying attention layers.
- Pass the dict to tp.pruner.BasePruner along with configuration flags.
- Update model attributes after pruning to reflect the new head counts and dimensions.
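For intuition about what the flags do to layer shapes, here is a back-of-the-envelope sketch (my own numbers, using ViT-Base's 12 heads of dimension 64, not taken from the repo):

```python
# Head-pruning arithmetic sketch: prune_num_heads=True removes whole heads,
# so a fused QKV projection shrinks proportionally.
num_heads_before = 12                     # ViT-Base
head_dim = 64                             # unchanged when prune_head_dims=False
head_pruning_ratio = 0.5
num_heads_after = int(num_heads_before * (1 - head_pruning_ratio))  # 6 heads remain
qkv_out_features = 3 * num_heads_after * head_dim                   # 1152, down from 2304
```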
Code Reference
Source files:
- examples/transformers/prune_timm_vit.py lines 129-136 (timm)
- examples/transformers/prune_hf_vit.py lines 99-108 (HuggingFace)
- examples/LLMs/prune_llm.py lines 295-312 (LLM)
Import: N/A (user-defined pattern)
Interface Pattern
```python
# AttentionModule / ClassificationHead stand in for the model's actual classes.
# Step 1: Build the num_heads dict
num_heads = {}
ignored_layers = []
for m in model.modules():
    if isinstance(m, AttentionModule):
        num_heads[m.qkv] = m.num_heads  # timm fused QKV
        # OR for HuggingFace separate Q/K/V:
        # num_heads[m.query] = heads
        # num_heads[m.key] = heads
        # num_heads[m.value] = heads
    if isinstance(m, ClassificationHead):
        ignored_layers.append(m)

# Step 2: Pass the dict to the pruner
pruner = tp.pruner.BasePruner(
    model, example_inputs, importance,
    num_heads=num_heads,
    prune_num_heads=True,    # remove entire heads
    prune_head_dims=False,   # don't prune within heads
    ignored_layers=ignored_layers,
)

# Step 3: After pruning, update attributes
for m in model.modules():
    if isinstance(m, AttentionModule):
        m.num_heads = pruner.num_heads[m.qkv]  # updated count
```
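The two flags select orthogonal strategies, and the examples below take opposite choices: the timm example removes whole heads (prune_num_heads=True, prune_head_dims=False), while the HuggingFace example keeps every head and instead shrinks each head's dimension (prune_head_dims=True, prune_num_heads=False).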
I/O Contract
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | num_heads | Dict[nn.Module, int] | Maps each QKV projection layer to its head count |
| Input | prune_num_heads | bool | If True, remove entire attention heads |
| Input | prune_head_dims | bool | If True, reduce dimension within each head |
| Output | pruner.num_heads | Dict[nn.Module, int] | Updated head counts after pruning has been applied |
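A small sanity check of this contract (my own sketch, not part of Torch-Pruning): snapshot the dict before pruning, then compare it against pruner.num_heads afterwards.

```python
# Hypothetical check: the same layer objects key both dicts, and head counts only shrink.
heads_before = dict(num_heads)               # shallow copy taken before pruner.step()
# ... run pruner.step() / g.prune() here ...
for layer, before in heads_before.items():
    after = pruner.num_heads[layer]
    assert after <= before, f"{layer}: head count grew from {before} to {after}"
    assert layer.out_features % after == 0   # out_features stays a multiple of the head count
```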
Usage Examples
timm Vision Transformer
From examples/transformers/prune_timm_vit.py (lines 129-155):
```python
import timm
import torch
import torch_pruning as tp

model = timm.create_model("vit_base_patch16_224", pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)  # standard ViT input for dependency tracing

num_heads = {}
ignored_layers = [model.head]
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        num_heads[m.qkv] = m.num_heads

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.GroupMagnitudeImportance(p=1),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)
for g in pruner.step(interactive=True):
    g.prune()

# Update model attributes after pruning
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        m.num_heads = pruner.num_heads[m.qkv]
        m.head_dim = m.qkv.out_features // (3 * m.num_heads)
```
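To confirm the pruning took effect, the repo's examples measure MACs and parameter counts with tp.utils.count_ops_and_params; a minimal check:

```python
# Measurement plus a forward-pass sanity check.
macs, params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {macs / 1e9:.2f} G | Params: {params / 1e6:.2f} M")
out = model(example_inputs)
print(out.shape)  # torch.Size([1, 1000]); the ignored head keeps the output shape
```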
HuggingFace Vision Transformer
From examples/transformers/prune_hf_vit.py (lines 99-147):
```python
import torch
import torch_pruning as tp
from transformers import ViTForImageClassification
from transformers.models.vit.modeling_vit import ViTSelfAttention

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
example_inputs = torch.randn(1, 3, 224, 224)

num_heads = {}
ignored_layers = [model.classifier]
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        num_heads[m.query] = m.num_attention_heads
        num_heads[m.key] = m.num_attention_heads
        num_heads[m.value] = m.num_attention_heads

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=1),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_head_dims=True,
    prune_num_heads=False,
    ignored_layers=ignored_layers,
    output_transform=lambda out: out.logits.sum(),
)
for g in pruner.step(interactive=True):
    g.prune()

# Update model attributes after pruning
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        m.num_attention_heads = pruner.num_heads[m.query]
        m.attention_head_size = m.query.out_features // m.num_attention_heads
        m.all_head_size = m.query.out_features
```
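The output_transform argument is needed because HuggingFace models return a ModelOutput object rather than a raw tensor; summing out.logits gives the pruner a single tensor to trace from. A quick post-pruning forward check (a sketch):

```python
with torch.no_grad():
    logits = model(example_inputs).logits  # HF models wrap outputs; read .logits
print(logits.shape)                        # torch.Size([1, 1000])
```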
LLM (Llama / Phi)
From examples/LLMs/prune_llm.py (lines 295-362):
```python
import torch
import torch_pruning as tp
from transformers import AutoModelForCausalLM

# model_name: a Llama- or Phi-family checkpoint id
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
# Token ids serve as tracing inputs; any short [batch, seq_len] sequence works
# (the repo script builds its own inputs; this line is a stand-in).
example_inputs = torch.randint(0, model.config.vocab_size, (1, 64)).to(model.device)

num_heads = {}
out_channel_groups = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        if hasattr(m, "q_proj"):  # Llama-style separate Q/K/V projections
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
        elif hasattr(m, "qkv_proj"):  # Phi-style fused projection
            num_heads[m.qkv_proj] = model.config.num_attention_heads
    if name.endswith("mlp"):
        if hasattr(m, "gate_up_proj"):  # fused gate+up projection prunes in 2 groups
            out_channel_groups[m.gate_up_proj] = 2

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.GroupMagnitudeImportance(p=2, group_reduction="mean"),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=0.5,
    out_channel_groups=out_channel_groups,
    ignored_layers=[model.lm_head],
)
for g in pruner.step(interactive=True):
    g.prune()

# Update model config and module attributes (Llama-style naming)
model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        m.hidden_size = m.q_proj.out_features
        m.num_heads = m.hidden_size // m.head_dim
        model.config.num_attention_heads = m.num_heads
```
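For grouped-query-attention checkpoints (num_key_value_heads < num_attention_heads), the KV head count likely needs the same treatment; a sketch that mirrors the q_proj update above, assuming k_proj was registered in num_heads as shown:

```python
# Sketch: extend the attribute update to GQA models' KV heads.
for name, m in model.named_modules():
    if name.endswith("self_attn") and hasattr(m, "k_proj"):
        m.num_key_value_heads = pruner.num_heads[m.k_proj]
        model.config.num_key_value_heads = m.num_key_value_heads
```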