Implementation:VainF Torch Pruning LLM Config Update Pattern
Metadata
| Field | Value |
|---|---|
| Source | Torch-Pruning |
| Domains | NLP, Model_Compression |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A concrete pattern, taken from the Torch-Pruning examples, for updating a HuggingFace LLM's configuration after structural pruning.
Description
This is a Pattern Doc. After pruning an LLM with BasePruner, users must iterate through `model.named_modules()` to update all config attributes. The pattern handles:
- Separate Q/K/V projections (Llama-style) -- `q_proj`, `k_proj`, and `v_proj` as individual linear layers
- Fused QKV (some models) -- a single combined projection
- Separate gate/up projections -- `gate_proj` and `up_proj` as individual linear layers
- Fused gate_up projections -- a single `gate_up_proj` linear layer
- Grouped Query Attention (GQA) -- where `num_key_value_heads` differs from `num_attention_heads`
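The head-count bookkeeping behind this pattern is plain integer arithmetic: each head count is a projection's `out_features` divided by `head_dim`. A minimal sketch, using illustrative pruned dimensions that are assumptions, not values from the source:

```python
# Illustrative pruned projection widths (assumed numbers, for demonstration only)
head_dim = 128   # per-head dimension, unchanged when whole heads are pruned
q_out = 2048     # q_proj.out_features after pruning
k_out = 512      # k_proj.out_features after pruning (smaller under GQA)

num_attention_heads = q_out // head_dim   # 16
num_key_value_heads = k_out // head_dim   # 4
print(num_attention_heads, num_key_value_heads)  # 16 4
```

Under GQA, `k_proj`/`v_proj` are narrower than `q_proj`, which is why the pattern derives the two head counts from different layers.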
Code Reference
Source
examples/LLMs/prune_llm.py, Lines 347-374
Interface
```python
# After pruning, update the configuration
model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
    # Update attention configs
    if hasattr(m, 'self_attn'):
        # Separate Q/K/V
        if hasattr(m.self_attn, 'q_proj'):
            m.self_attn.num_heads = m.self_attn.q_proj.out_features // m.self_attn.head_dim
            if _is_gqa:
                m.self_attn.num_key_value_heads = m.self_attn.k_proj.out_features // m.self_attn.head_dim
            m.self_attn.hidden_size = m.self_attn.q_proj.out_features
        model.config.num_attention_heads = m.self_attn.num_heads
        if _is_gqa:
            model.config.num_key_value_heads = m.self_attn.num_key_value_heads
    # Update MLP configs
    if hasattr(m, 'mlp'):
        if hasattr(m.mlp, 'gate_proj'):
            model.config.intermediate_size = m.mlp.gate_proj.out_features
        elif hasattr(m.mlp, 'gate_up_proj'):
            model.config.intermediate_size = m.mlp.gate_up_proj.out_features // 2
```
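The snippet above needs a live HuggingFace model, but the update logic itself can be exercised on stand-in objects. A self-contained sketch of the same attribute updates on a toy decoder layer; all class layouts and shapes here are assumptions chosen to mirror a Llama-style GQA layer:

```python
from types import SimpleNamespace

def linear(in_f, out_f):
    # Stand-in for nn.Linear: only the shape attributes the pattern reads
    return SimpleNamespace(in_features=in_f, out_features=out_f)

# Toy "pruned" decoder layer (shapes are illustrative assumptions)
self_attn = SimpleNamespace(q_proj=linear(2048, 2048), k_proj=linear(2048, 512), head_dim=128)
mlp = SimpleNamespace(gate_proj=linear(2048, 5504))
layer = SimpleNamespace(self_attn=self_attn, mlp=mlp)
config = SimpleNamespace()
_is_gqa = True

# The same update logic as the interface, applied to the toy layer
for m in [layer]:
    if hasattr(m, 'self_attn'):
        if hasattr(m.self_attn, 'q_proj'):
            m.self_attn.num_heads = m.self_attn.q_proj.out_features // m.self_attn.head_dim
            if _is_gqa:
                m.self_attn.num_key_value_heads = m.self_attn.k_proj.out_features // m.self_attn.head_dim
            m.self_attn.hidden_size = m.self_attn.q_proj.out_features
        config.num_attention_heads = m.self_attn.num_heads
        if _is_gqa:
            config.num_key_value_heads = m.self_attn.num_key_value_heads
    if hasattr(m, 'mlp') and hasattr(m.mlp, 'gate_proj'):
        config.intermediate_size = m.mlp.gate_proj.out_features

print(config.num_attention_heads, config.num_key_value_heads, config.intermediate_size)
# 16 4 5504
```

This makes the data flow visible: every config value is derived from a weight shape, never the other way around, so the config can always be reconstructed from a pruned checkpoint.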
Import
N/A (user-defined pattern, from example script)
I/O Contract
Inputs
- Pruned model -- a HuggingFace model whose weights have been structurally pruned but whose `model.config` still reflects the original unpruned dimensions
- `_is_gqa` (boolean) -- whether the model uses Grouped Query Attention (i.e., `num_key_value_heads != num_attention_heads`)
- `seperate_qkv` (boolean) -- whether the model uses separate Q/K/V projections or a fused QKV layer
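Rather than being passed in, the `_is_gqa` flag can be derived from the config itself, exactly as the usage example below does. A small sketch with toy configs (the attribute values are assumptions for illustration):

```python
from types import SimpleNamespace

def is_gqa(config):
    # GQA iff the config defines num_key_value_heads and it differs from num_attention_heads
    return (hasattr(config, "num_key_value_heads")
            and config.num_key_value_heads != config.num_attention_heads)

mha_cfg = SimpleNamespace(num_attention_heads=32)                          # no KV-heads attribute: plain MHA
gqa_cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8)   # fewer KV heads: GQA

print(is_gqa(mha_cfg), is_gqa(gqa_cfg))  # False True
```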
Outputs
`model.config` attributes updated to match pruned weight dimensions:
- `model.config.hidden_size`
- `model.config.num_attention_heads`
- `model.config.num_key_value_heads` (if GQA)
- `model.config.intermediate_size`
Usage Examples
Full config update after LLM pruning, followed by saving the model:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch_pruning as tp

# Load model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# --- Pruning step (abbreviated) ---
# example_inputs here is illustrative; any representative token batch works
example_inputs = tokenizer("Hello, world", return_tensors="pt").input_ids
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.MagnitudeImportance(),
    pruning_ratio=0.5,
)
pruner.step()

# --- Config update step (mandatory before saving) ---
_is_gqa = hasattr(model.config, "num_key_value_heads") and \
    model.config.num_key_value_heads != model.config.num_attention_heads

model.config.hidden_size = model.lm_head.in_features
for name, m in model.named_modules():
    if hasattr(m, 'self_attn'):
        if hasattr(m.self_attn, 'q_proj'):
            m.self_attn.num_heads = m.self_attn.q_proj.out_features // m.self_attn.head_dim
            if _is_gqa:
                m.self_attn.num_key_value_heads = m.self_attn.k_proj.out_features // m.self_attn.head_dim
            m.self_attn.hidden_size = m.self_attn.q_proj.out_features
        model.config.num_attention_heads = m.self_attn.num_heads
        if _is_gqa:
            model.config.num_key_value_heads = m.self_attn.num_key_value_heads
    if hasattr(m, 'mlp'):
        if hasattr(m.mlp, 'gate_proj'):
            model.config.intermediate_size = m.mlp.gate_proj.out_features
        elif hasattr(m.mlp, 'gate_up_proj'):
            model.config.intermediate_size = m.mlp.gate_up_proj.out_features // 2

# --- Save the pruned model with corrected config ---
model.save_pretrained("pruned_llama_2_7b")
tokenizer.save_pretrained("pruned_llama_2_7b")
```