Implementation:VainF Torch Pruning Head Pruning Config

Property | Value
Source | Torch-Pruning (https://github.com/VainF/Torch-Pruning)
Domains | Deep_Learning, Transformers, Pruning
Last Updated | 2026-02-08 00:00 GMT

Overview

A concrete pattern for configuring transformer head pruning in Torch-Pruning.

Description

This is a Pattern Doc. The user constructs a num_heads dictionary mapping each QKV projection layer to its head count, then passes it to BasePruner together with the prune_num_heads/prune_head_dims flags. The exact wiring varies by model family (timm vs. HuggingFace vs. LLM). After pruning, the model's attention attributes must be updated manually, because the pruner resizes the weight tensors but does not touch Python-level attributes such as num_heads or head_dim.

The three-step workflow, shown concretely in the Interface Pattern below, is:

  1. Build the num_heads dict by iterating over model modules and identifying attention layers.
  2. Pass the dict to tp.pruner.BasePruner along with configuration flags.
  3. Update model attributes after pruning to reflect the new head counts and dimensions.

Code Reference

Source files:

  • examples/transformers/prune_timm_vit.py lines 129-136 (timm)
  • examples/transformers/prune_hf_vit.py lines 99-108 (HuggingFace)
  • examples/LLMs/prune_llm.py lines 295-312 (LLM)

Import: N/A (user-defined pattern)

Interface Pattern

# Step 1: Build num_heads dict
# (AttentionModule and ClassificationHead are placeholders for the
# model-specific classes; see the Usage Examples below)
num_heads = {}
ignored_layers = []
for m in model.modules():
    if isinstance(m, AttentionModule):
        num_heads[m.qkv] = m.num_heads  # timm fused QKV
        # OR for HuggingFace separate Q/K/V:
        # num_heads[m.query] = heads
        # num_heads[m.key] = heads
        # num_heads[m.value] = heads
    if isinstance(m, ClassificationHead):
        ignored_layers.append(m)

# Step 2: Pass to pruner
pruner = tp.pruner.BasePruner(
    model, example_inputs, importance,
    num_heads=num_heads,
    prune_num_heads=True,   # remove entire heads
    prune_head_dims=False,  # don't prune within heads
    ignored_layers=ignored_layers,
)

# Step 3: Execute pruning, then update attributes
pruner.step()
for m in model.modules():
    if isinstance(m, AttentionModule):
        m.num_heads = pruner.num_heads[m.qkv]  # updated count
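
To verify that heads were actually removed, it helps to compare parameter and MAC counts before and after pruning. A minimal sketch using tp.utils.count_ops_and_params, the utility the Torch-Pruning examples use for this:

# Sanity check: measure model size before and after pruning.
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
# ... run pruner.step() and the attribute updates above ...
macs, params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Params: {base_params / 1e6:.2f}M -> {params / 1e6:.2f}M")
print(f"MACs:   {base_macs / 1e9:.2f}G -> {macs / 1e9:.2f}G")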

I/O Contract

Direction | Name | Type | Description
Input | num_heads | Dict[nn.Module, int] | Maps each QKV projection layer to its head count
Input | prune_num_heads | bool | If True, remove entire attention heads
Input | prune_head_dims | bool | If True, reduce the dimension within each head
Output | pruner.num_heads | Dict[nn.Module, int] | Updated head counts after pruning has been applied
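
For example, the same dictionary passed in as num_heads can be read back after pruning to see how many heads each projection layer retains (a minimal sketch, continuing the pattern above):

for layer, n in pruner.num_heads.items():
    print(layer, "->", n, "heads remaining")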

Usage Examples

timm Vision Transformer

From examples/transformers/prune_timm_vit.py (lines 129-155):

import timm
import torch
import torch_pruning as tp

model = timm.create_model("vit_base_patch16_224", pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

num_heads = {}
ignored_layers = [model.head]
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        num_heads[m.qkv] = m.num_heads

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.GroupMagnitudeImportance(p=1),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

for g in pruner.step(interactive=True):
    g.prune()

# Update model attributes after pruning
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        m.num_heads = pruner.num_heads[m.qkv]
        m.head_dim = m.qkv.out_features // (3 * m.num_heads)
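
A quick forward pass confirms the pruned model still runs end to end (a sketch; the stock ImageNet head keeps 1000 classes):

with torch.no_grad():
    out = model(example_inputs)
print(out.shape)  # e.g. torch.Size([1, 1000])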

HuggingFace Vision Transformer

From examples/transformers/prune_hf_vit.py (lines 99-147):

import torch
from transformers import ViTForImageClassification
from transformers.models.vit.modeling_vit import ViTSelfAttention
import torch_pruning as tp

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
example_inputs = torch.randn(1, 3, 224, 224)

num_heads = {}
ignored_layers = [model.classifier]
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        num_heads[m.query] = m.num_attention_heads
        num_heads[m.key] = m.num_attention_heads
        num_heads[m.value] = m.num_attention_heads

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.MagnitudeImportance(p=1),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_head_dims=True,
    prune_num_heads=False,
    ignored_layers=ignored_layers,
    output_transform=lambda out: out.logits.sum(),
)

for g in pruner.step(interactive=True):
    g.prune()

# Update model attributes after pruning
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        m.num_attention_heads = pruner.num_heads[m.query]
        m.attention_head_size = m.query.out_features // m.num_attention_heads
        m.all_head_size = m.query.out_features
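
The same kind of smoke test applies here; HuggingFace models return an output object rather than a raw tensor, which is why the pruner above needed output_transform (a sketch):

with torch.no_grad():
    logits = model(example_inputs).logits
print(logits.shape)  # e.g. torch.Size([1, 1000])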

LLM (Llama / Phi)

From examples/LLMs/prune_llm.py (lines 295-362):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch_pruning as tp

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder; the script reads this from CLI args
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
example_inputs = tokenizer("Hello, world", return_tensors="pt").input_ids.to(model.device)

num_heads = {}
out_channel_groups = {}
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        if hasattr(m, "q_proj"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
        elif hasattr(m, "qkv_proj"):
            num_heads[m.qkv_proj] = model.config.num_attention_heads
    if name.endswith("mlp"):
        if hasattr(m, "gate_up_proj"):
            out_channel_groups[m.gate_up_proj] = 2

pruner = tp.pruner.BasePruner(
    model, example_inputs,
    importance=tp.importance.GroupMagnitudeImportance(p=2, group_reduction="mean"),
    pruning_ratio=0.5,
    num_heads=num_heads,
    prune_num_heads=True,
    prune_head_dims=False,
    head_pruning_ratio=0.5,
    out_channel_groups=out_channel_groups,
    ignored_layers=[model.lm_head],
)

for g in pruner.step(interactive=True):
    g.prune()

# Update model config and module attributes
model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
    if name.endswith("self_attn"):
        m.hidden_size = m.q_proj.out_features
        m.num_heads = m.hidden_size // m.head_dim
        model.config.num_attention_heads = m.num_heads
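
A short generation call is a reasonable smoke test for the pruned LLM (a sketch; output quality will typically degrade at this pruning ratio without post-pruning fine-tuning):

with torch.no_grad():
    out = model.generate(example_inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))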
