

Implementation:VainF Torch Pruning LLM Config Update Pattern

From Leeroopedia


Metadata

Source: Torch-Pruning
Domains: NLP, Model_Compression
Last Updated: 2026-02-08 00:00 GMT

Overview

A concrete pattern, taken from the Torch-Pruning examples, for updating a HuggingFace LLM's configuration after structural pruning.

Description

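This is a Pattern Doc. Pruning an LLM with BasePruner shrinks its weight tensors but leaves model.config (and the per-layer attention attributes) at the original, unpruned dimensions. Users must therefore iterate through model.named_modules() and update these attributes to match the pruned weights; otherwise a model saved with save_pretrained cannot be reloaded with from_pretrained, because the config describes larger layers than the stored weights. The pattern handles: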

  • Separate QKV projections (Llama-style) -- q_proj, k_proj, v_proj as individual linear layers
  • Fused QKV (some models) -- a single combined projection
  • Separate gate/up projections -- gate_proj and up_proj as individual linear layers
  • Fused gate_up projections -- a single gate_up_proj linear layer
  • Grouped Query Attention (GQA) -- where num_key_value_heads differs from num_attention_heads
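The dimension arithmetic behind these cases can be sketched in plain Python. This is a minimal illustration, not code from the Torch-Pruning example: derive_dims and PrunedDims are hypothetical names, head_dim is assumed to be unchanged by pruning (as the pattern itself assumes), and the fused QKV branch assumes no GQA, i.e. equal Q, K and V widths.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrunedDims:
    num_attention_heads: int
    num_key_value_heads: int
    intermediate_size: int


def _exact_div(width: int, divisor: int, what: str) -> int:
    # With head_dim unchanged, pruning must remove whole heads (or whole
    # halves of a fused layer); a remainder means the assumed layout is wrong.
    if width % divisor != 0:
        raise ValueError(f"{what}: {width} is not divisible by {divisor}")
    return width // divisor


def derive_dims(head_dim: int,
                q_out: Optional[int] = None,
                k_out: Optional[int] = None,
                qkv_out: Optional[int] = None,
                gate_out: Optional[int] = None,
                gate_up_out: Optional[int] = None) -> PrunedDims:
    """Derive config values from pruned out_features (hypothetical helper)."""
    if q_out is not None:
        # Separate Q/K/V: q_proj width = num_heads * head_dim
        num_heads = _exact_div(q_out, head_dim, "q_proj")
        # GQA: k_proj width = num_key_value_heads * head_dim; otherwise same as Q
        num_kv = _exact_div(k_out, head_dim, "k_proj") if k_out is not None else num_heads
    elif qkv_out is not None:
        # Fused QKV without GQA: Q, K and V are three equal slices
        third = _exact_div(qkv_out, 3, "qkv_proj")
        num_heads = _exact_div(third, head_dim, "qkv_proj / 3")
        num_kv = num_heads
    else:
        raise ValueError("pass q_out (separate Q/K/V) or qkv_out (fused QKV)")

    if gate_out is not None:
        # Separate gate/up: gate_proj width = intermediate_size
        intermediate = gate_out
    elif gate_up_out is not None:
        # Fused gate_up: gate and up halves are concatenated along the output dim
        intermediate = _exact_div(gate_up_out, 2, "gate_up_proj")
    else:
        raise ValueError("pass gate_out (separate) or gate_up_out (fused)")
    return PrunedDims(num_heads, num_kv, intermediate)


# Hypothetical pruned widths for a GQA model with head_dim = 128
print(derive_dims(128, q_out=2048, k_out=512, gate_out=5504))
# PrunedDims(num_attention_heads=16, num_key_value_heads=4, intermediate_size=5504)
```

Raising on a remainder, instead of silently flooring as `//` alone would, surfaces the case where pruning cut through a head rather than removing it whole.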

Code Reference

Source

examples/LLMs/prune_llm.py, Lines 347-374

Interface

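# After pruning, update configuration.
# _is_gqa must be computed beforehand from the not-yet-updated config
# (see the Usage Examples section).
model.config.hidden_size = model.lm_head.in_features

for name, m in model.named_modules():
    # Update attention configs (each decoder layer owns a `self_attn` module)
    if hasattr(m, 'self_attn'):
        attn = m.self_attn
        if hasattr(attn, 'q_proj'):
            # Separate Q/K/V: q_proj.out_features == num_heads * head_dim
            attn.hidden_size = attn.q_proj.out_features
        elif hasattr(attn, 'qkv_proj'):
            # Fused QKV (assumes no GQA): Q, K and V each take one third
            attn.hidden_size = attn.qkv_proj.out_features // 3
        attn.num_heads = attn.hidden_size // attn.head_dim
        # Record inside the loop: after the loop, `m` is the last module
        # yielded by named_modules() (e.g. lm_head), which has no self_attn
        model.config.num_attention_heads = attn.num_heads
        if _is_gqa:
            # GQA branch assumes separate K/V projections
            attn.num_key_value_heads = attn.k_proj.out_features // attn.head_dim
            attn.num_key_value_groups = attn.num_heads // attn.num_key_value_heads
            model.config.num_key_value_heads = attn.num_key_value_heads

    # Update MLP configs
    if hasattr(m, 'mlp'):
        if hasattr(m.mlp, 'gate_proj'):
            # Separate gate/up projections
            model.config.intermediate_size = m.mlp.gate_proj.out_features
        elif hasattr(m.mlp, 'gate_up_proj'):
            # Fused gate_up: gate and up halves are concatenated
            model.config.intermediate_size = m.mlp.gate_up_proj.out_features // 2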

Import

N/A (user-defined pattern, from example script)

I/O Contract

Inputs

  • Pruned model -- a HuggingFace model whose weights have been structurally pruned but whose model.config still reflects the original unpruned dimensions
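  • _is_gqa (boolean) -- whether the model uses Grouped Query Attention (i.e., num_key_value_heads != num_attention_heads); it must be computed from the config before the config is updated
  • seperate_qkv (boolean) -- True if the model uses separate Q/K/V projections, False if it uses a fused QKV layer; the snippet above tests this per layer with hasattr(m.self_attn, 'q_proj')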
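A minimal sketch of how the two flags might be derived, assuming an HF-style config object and a Llama-style attention module. detect_flags is a hypothetical helper, not part of Torch-Pruning or transformers, and the toy objects below stand in for a real model.

```python
from types import SimpleNamespace


def detect_flags(config, attn_module):
    """Return (_is_gqa, seperate_qkv) for one attention module (hypothetical helper).

    Must be called on the ORIGINAL config, before the pruning config update
    overwrites num_attention_heads / num_key_value_heads.
    """
    n_heads = config.num_attention_heads
    # Defensive assumption: a missing or None value means plain multi-head attention
    n_kv = getattr(config, "num_key_value_heads", None)
    if n_kv is None:
        n_kv = n_heads
    is_gqa = n_kv != n_heads
    # Separate Q/K/V projections expose q_proj; fused layouts do not
    seperate_qkv = hasattr(attn_module, "q_proj")
    return is_gqa, seperate_qkv


# Toy stand-ins (hypothetical values): a GQA config and a Llama-style attention module
toy_cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)
toy_attn = SimpleNamespace(q_proj=None, k_proj=None, v_proj=None)
print(detect_flags(toy_cfg, toy_attn))  # (True, True)
```

For a Llama-style model this would be called as detect_flags(model.config, model.model.layers[0].self_attn), which assumes that module layout and that all layers share it.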

Outputs

  • model.config attributes updated to match pruned weight dimensions:
    • model.config.hidden_size
    • model.config.num_attention_heads
    • model.config.num_key_value_heads (if GQA)
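    • model.config.intermediate_size
  • Per-layer self_attn attributes (hidden_size, num_heads, and num_key_value_heads if GQA), which some transformers versions read directly in the attention forward pass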
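Because each config field holds a single value for the whole model, it can be worth checking, before save_pretrained, that every layer actually matches it: non-uniform pruning (for example global pruning) can leave layers with different widths, and the update loop would silently keep only the last layer's values. The check below is a hypothetical, duck-typed sketch assuming Llama-style module names and an unchanged head_dim; toy_pruned_model is an illustrative stand-in, not a real HF model.

```python
from types import SimpleNamespace
from typing import List


def _attn_heads(attn) -> int:
    # Separate Q/K/V, or fused QKV (fused case assumes no GQA)
    if hasattr(attn, "q_proj"):
        return attn.q_proj.out_features // attn.head_dim
    return attn.qkv_proj.out_features // 3 // attn.head_dim


def check_config_matches_weights(model, is_gqa: bool) -> List[str]:
    """Return human-readable mismatches between model.config and layer widths."""
    cfg = model.config
    problems: List[str] = []
    if model.lm_head.in_features != cfg.hidden_size:
        problems.append(
            f"hidden_size={cfg.hidden_size}, lm_head.in_features={model.lm_head.in_features}"
        )
    for name, m in model.named_modules():
        attn = getattr(m, "self_attn", None)
        if attn is not None:
            heads = _attn_heads(attn)
            if heads != cfg.num_attention_heads:
                problems.append(f"{name}: {heads} heads, num_attention_heads={cfg.num_attention_heads}")
            if is_gqa:
                kv = attn.k_proj.out_features // attn.head_dim
                if kv != cfg.num_key_value_heads:
                    problems.append(f"{name}: {kv} KV heads, num_key_value_heads={cfg.num_key_value_heads}")
        mlp = getattr(m, "mlp", None)
        if mlp is not None:
            if hasattr(mlp, "gate_proj"):
                inter = mlp.gate_proj.out_features
            elif hasattr(mlp, "gate_up_proj"):
                inter = mlp.gate_up_proj.out_features // 2
            else:
                continue  # unknown MLP layout: not covered by this sketch
            if inter != cfg.intermediate_size:
                problems.append(f"{name}: MLP width {inter}, intermediate_size={cfg.intermediate_size}")
    return problems


def toy_pruned_model(intermediate_in_config: int = 5504):
    """Hypothetical one-layer stand-in for a pruned GQA model."""
    def lin(out_features, in_features=0):
        return SimpleNamespace(out_features=out_features, in_features=in_features)

    layer = SimpleNamespace(
        self_attn=SimpleNamespace(q_proj=lin(2048), k_proj=lin(512), head_dim=128),
        mlp=SimpleNamespace(gate_proj=lin(5504)),
    )
    cfg = SimpleNamespace(hidden_size=2048, num_attention_heads=16,
                          num_key_value_heads=4, intermediate_size=intermediate_in_config)
    return SimpleNamespace(config=cfg, lm_head=lin(32000, 2048),
                           named_modules=lambda: [("model.layers.0", layer)])


print(check_config_matches_weights(toy_pruned_model(), is_gqa=True))  # []
print(check_config_matches_weights(toy_pruned_model(11008), is_gqa=True))
```

On a real pruned model, a non-empty result signals that the saved config would not describe the stored weights.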

Usage Examples

Full config update after LLM pruning, followed by saving the model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch_pruning as tp

# Load model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# --- Pruning step (abbreviated; head-aware arguments such as num_heads are omitted) ---
# Dummy token IDs, used only to trace the model's dependency graph
example_inputs = torch.randint(0, model.config.vocab_size, (1, 64)).to(model.device)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.MagnitudeImportance(),
    pruning_ratio=0.5,
    ignored_layers=[model.lm_head],  # keep the vocabulary (output) dimension intact
)
pruner.step()

# --- Config update step (mandatory before saving) ---
# pruner.step() does not touch model.config, so this still reflects the original model
_is_gqa = hasattr(model.config, "num_key_value_heads") and \
          model.config.num_key_value_heads != model.config.num_attention_heads

model.config.hidden_size = model.lm_head.in_features

for name, m in model.named_modules():
    if hasattr(m, 'self_attn'):
        attn = m.self_attn
        if hasattr(attn, 'q_proj'):
            attn.hidden_size = attn.q_proj.out_features
        elif hasattr(attn, 'qkv_proj'):
            # Fused QKV (assumes no GQA)
            attn.hidden_size = attn.qkv_proj.out_features // 3
        attn.num_heads = attn.hidden_size // attn.head_dim
        # Record inside the loop: after the loop, `m` is lm_head, which has no self_attn
        model.config.num_attention_heads = attn.num_heads
        if _is_gqa:
            attn.num_key_value_heads = attn.k_proj.out_features // attn.head_dim
            attn.num_key_value_groups = attn.num_heads // attn.num_key_value_heads
            model.config.num_key_value_heads = attn.num_key_value_heads

    if hasattr(m, 'mlp'):
        if hasattr(m.mlp, 'gate_proj'):
            model.config.intermediate_size = m.mlp.gate_proj.out_features
        elif hasattr(m.mlp, 'gate_up_proj'):
            model.config.intermediate_size = m.mlp.gate_up_proj.out_features // 2

# --- Save the pruned model with corrected config ---
model.save_pretrained("pruned_llama_2_7b")
tokenizer.save_pretrained("pruned_llama_2_7b")
