

Workflow:VainF Torch Pruning LLM Structural Pruning



Domains: Model_Compression, Structural_Pruning, LLMs
Last Updated: 2026-02-07 23:30 GMT

Overview

End-to-end process for structurally pruning large language models (Llama, Phi, Qwen, DeepSeek) by removing attention heads, hidden dimensions, and MLP intermediate channels, then evaluating the result with perplexity on WikiText-2.

Description

This workflow structurally prunes causal language models loaded via HuggingFace Transformers. It addresses the challenges specific to LLM pruning: separate Q/K/V projections versus a fused QKV projection; Grouped Query Attention (GQA), where the model has fewer key/value heads than query heads; fused gate-up projections in MLP layers; and updating every model configuration attribute after pruning. Parameters are ranked with group-level magnitude (L2-norm) importance, and the pruned model is evaluated by its perplexity on the WikiText-2 benchmark. Pruning whole heads is preferred over pruning head dimensions because rotary position embeddings (RoPE) rotate fixed pairs of dimensions inside each head at frequencies derived from the head size; removing whole heads leaves that per-head structure intact.
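To see why head-dimension pruning interferes with RoPE, the sketch below computes the standard RoPE inverse-frequency schedule, in which rotation pair i of a head turns at base^(-2i/head_dim). It assumes the common default base of 10000 (some models, such as Llama 3, use a different rope_theta plus extra scaling), and the head sizes are illustrative only. Shrinking head_dim changes the frequency of every pair except the first, whereas dropping whole heads leaves head_dim, and therefore every frequency, unchanged.

```python
# Standard RoPE inverse-frequency schedule (illustrative; base=10000 assumed).
def rope_inv_freq(head_dim: int, base: float = 10000.0) -> list:
    """Return one inverse frequency per rotation pair of a head of size head_dim."""
    if head_dim % 2:
        raise ValueError("RoPE requires an even head_dim")
    return [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


original = rope_inv_freq(128)         # head_dim before pruning (illustrative)
head_dim_pruned = rope_inv_freq(96)   # hypothetical: 32 dims removed from each head
heads_pruned = rope_inv_freq(128)     # removing whole heads keeps head_dim at 128

print(len(original), len(head_dim_pruned))   # 64 48
print(original[1] == head_dim_pruned[1])     # False: pair 1 now rotates at a new rate
print(original == heads_pruned)              # True: per-head geometry is unchanged
```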

Usage

Execute this workflow when you need to compress a large language model for deployment on hardware with limited memory or compute. It applies to Llama-2/3, Phi-3, Qwen-2/2.5, DeepSeek-R1-Distill, and similar decoder-only transformer LLMs loaded from HuggingFace.

Execution Steps

Step 1: Load LLM and tokenizer

Load the causal language model using AutoModelForCausalLM with float16 precision and automatic device mapping. Load the corresponding tokenizer. Set the model's sequence length from its configuration, optionally capping it to avoid out-of-memory errors.

Key considerations:

  • Use torch_dtype=torch.float16 and low_cpu_mem_usage=True for memory efficiency
  • device_map="auto" distributes layers across available GPUs
  • Set model.seqlen (a plain attribute read later by the evaluation step) from model.config.max_position_embeddings, capping it if memory is tight
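A minimal sketch of Step 1, assuming the transformers and accelerate packages are installed (device_map="auto" relies on accelerate). The helper names load_llm and resolve_seqlen and the max_seq_len argument are illustrative, not part of any library, and seqlen is simply an attribute attached to the model for the evaluation step.

```python
from typing import Optional


def resolve_seqlen(max_position_embeddings: int, max_seq_len: Optional[int] = None) -> int:
    """Evaluation context length: the model's limit, optionally capped to save memory."""
    if max_seq_len is None:
        return max_position_embeddings
    return min(max_position_embeddings, max_seq_len)


def load_llm(model_name: str, max_seq_len: Optional[int] = None):
    """Hypothetical helper: load an fp16 causal LM and its tokenizer."""
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # half precision halves weight memory vs fp32
        low_cpu_mem_usage=True,     # avoid building a full extra copy in CPU RAM
        device_map="auto",          # place layers across available GPUs (accelerate)
    )
    model.eval()
    # Plain attribute consumed later by the perplexity evaluation.
    model.seqlen = resolve_seqlen(model.config.max_position_embeddings, max_seq_len)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
```

Passing, say, max_seq_len=2048 makes the evaluation run in 2048-token chunks even when max_position_embeddings is much larger; it does not change the weights or the config.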

Step 2: Discover attention and MLP structure

Iterate through all named modules to build two dictionaries: one mapping each Q/K/V projection layer to its head count, and one mapping each fused gate-up projection to its number of output channel groups. Along the way, detect whether the model uses separate Q/K/V projections or a fused QKV projection, and whether it employs Grouped Query Attention.

Key considerations:

  • For separate Q/K/V: register num_heads for q_proj (num_attention_heads) and k_proj/v_proj (num_key_value_heads)
  • For fused QKV (qkv_proj): register total num_attention_heads
  • For fused gate-up MLP (gate_up_proj): register out_channel_groups=2 so the gate and up halves are pruned at matching positions
  • Detect GQA by comparing num_attention_heads vs num_key_value_heads
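The discovery pass can be sketched as follows, assuming HuggingFace-style module names (attention blocks ending in self_attn, MLP blocks ending in mlp) and the attribute names q_proj/k_proj/v_proj, qkv_proj and gate_up_proj described above. discover_structure is a hypothetical helper; because it only reads attributes, it works on any object exposing named_modules() and config.

```python
def discover_structure(model):
    """Hypothetical helper: return (num_heads, out_channel_groups, separate_qkv, is_gqa)."""
    cfg = model.config
    n_heads = cfg.num_attention_heads
    n_kv = getattr(cfg, "num_key_value_heads", n_heads)  # absent in some configs
    num_heads = {}            # projection module -> number of heads in its output
    out_channel_groups = {}   # fused module -> number of concatenated sub-projections
    separate_qkv = False
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            if hasattr(m, "q_proj"):          # separate projections (Llama, Qwen, ...)
                separate_qkv = True
                num_heads[m.q_proj] = n_heads
                num_heads[m.k_proj] = n_kv
                num_heads[m.v_proj] = n_kv
            elif hasattr(m, "qkv_proj"):      # fused projection (e.g. Phi-3)
                num_heads[m.qkv_proj] = n_heads
        elif name.endswith("mlp") and hasattr(m, "gate_up_proj"):
            out_channel_groups[m.gate_up_proj] = 2  # gate half + up half
    is_gqa = n_heads != n_kv
    return num_heads, out_channel_groups, separate_qkv, is_gqa
```

The dictionaries are keyed by the module objects themselves, which is how the mappings are handed to the pruner in Step 3.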

Step 3: Configure and create pruner

Create a BasePruner with the model, a tokenized example input, an importance criterion (GroupMagnitudeImportance with the L2 norm), and the head and channel-group mappings from Step 2. Configure it to prune entire heads (prune_num_heads=True) rather than head dimensions to preserve RoPE compatibility, and add lm_head to the ignored layers so its vocabulary outputs are not pruned.

Key considerations:

  • Use prune_num_heads=True, prune_head_dims=False to avoid breaking rotary position embeddings
  • Set head_pruning_ratio equal to the overall pruning_ratio
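  • Use round_to=4 so the remaining channel counts stay multiples of 4, keeping dimensions aligned
  • output_transform extracts the logits tensor from the HuggingFace model output so the pruner can trace dependencies through it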
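A sketch of the pruner setup with the settings listed above, assuming the torch_pruning package. The keyword names mirror the options named in this step, but they should be checked against the installed Torch-Pruning version. build_pruner, the example prompt, and the 0.5 default ratio are hypothetical.

```python
def build_pruner(model, tokenizer, num_heads, out_channel_groups, pruning_ratio=0.5):
    """Hypothetical helper: configure head-level structural pruning for an LLM."""
    import torch
    import torch_pruning as tp

    # Any short prompt works as the example input for dependency tracing.
    example_inputs = torch.tensor(
        [tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)],
        device=model.device,
    )
    return tp.pruner.BasePruner(
        model,
        example_inputs=example_inputs,
        importance=tp.importance.GroupMagnitudeImportance(p=2),  # L2 norm per group
        pruning_ratio=pruning_ratio,             # hidden and intermediate channels
        ignored_layers=[model.lm_head],          # keep the vocabulary outputs intact
        num_heads=num_heads,                     # from the Step 2 discovery pass
        out_channel_groups=out_channel_groups,   # e.g. fused gate_up_proj -> 2
        prune_num_heads=True,                    # remove whole heads
        prune_head_dims=False,                   # never shrink head_dim (keeps RoPE intact)
        head_pruning_ratio=pruning_ratio,        # same ratio for heads
        round_to=4,                              # remaining widths stay multiples of 4
        output_transform=lambda out: out.logits, # trace through the logits tensor
    )
```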

Step 4: Execute pruning

Run the pruner in interactive mode, iterating through all groups and pruning each one. This physically removes weights from the Q/K/V and output projections, the MLP layers, the normalization layers (RMSNorm in Llama-style models), and the embedding layers in a structurally consistent manner.

Key considerations:

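  • Groups must be processed sequentially: pruner.step(interactive=True) yields one dependency group at a time, and each group should be pruned before the next one is requested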
  • The pruner handles the complex dependencies between attention projections, MLP layers, and normalization layers
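The interactive loop itself is short. This sketch assumes a pruner built as in Step 3; run_pruning is a hypothetical helper that also counts the groups it prunes, and torch.no_grad() keeps autograd from tracking the weight slicing.

```python
def run_pruning(pruner) -> int:
    """Hypothetical helper: prune every group the pruner yields; return the count."""
    import torch

    pruned_groups = 0
    with torch.no_grad():  # no autograd tracking while weights are sliced
        for group in pruner.step(interactive=True):
            # Each group bundles coupled slices across layers (for example the
            # hidden dimension shared by embeddings, norms and projection inputs)
            # that must be removed together.
            group.prune()
            pruned_groups += 1
    return pruned_groups
```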

Step 5: Update model configuration

After pruning, update all model configuration and module attributes to reflect the new dimensions. This includes hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, and per-module attributes like hidden_size on each self-attention and MLP module. Also update num_key_value_groups for GQA models.

Key considerations:

  • hidden_size is inferred from lm_head.in_features
  • Per-attention-layer num_heads is computed as q_proj.out_features // head_dim
  • intermediate_size comes from gate_proj.out_features (or half of gate_up_proj.out_features when the gate and up projections are fused)
  • For non-GQA models, num_key_value_heads must equal num_attention_heads
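  • This step is critical: otherwise, code that reads these attributes (for example when reshaping attention outputs or repeating K/V heads under GQA) uses stale sizes, and the saved config no longer matches the pruned weights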
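The attribute updates can be sketched as below. It assumes Llama-style attribute names (head_dim on each attention module, gate_proj or gate_up_proj in each MLP) and reuses the separate_qkv and is_gqa flags from Step 2; update_config is a hypothetical helper. The fused-QKV branch assumes Q, K and V have equal widths, which holds only without GQA, and writing per-layer values into the single global config assumes every layer was pruned to the same widths.

```python
def update_config(model, separate_qkv: bool, is_gqa: bool) -> None:
    """Hypothetical helper: sync config and module attributes with pruned shapes."""
    cfg = model.config
    # lm_head keeps its vocabulary outputs, but its input width follows the
    # pruned hidden size.
    cfg.hidden_size = model.lm_head.in_features
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            if separate_qkv:
                q_width = m.q_proj.out_features
                kv_heads = m.k_proj.out_features // m.head_dim
            else:
                # Fused QKV: assumes Q, K and V have equal widths (no GQA).
                q_width = m.qkv_proj.out_features // 3
                kv_heads = q_width // m.head_dim
            m.num_heads = q_width // m.head_dim
            if not is_gqa:
                kv_heads = m.num_heads              # MHA: one K/V head per query head
            m.hidden_size = q_width                 # attention inner width
            m.num_key_value_heads = kv_heads
            m.num_key_value_groups = m.num_heads // kv_heads  # query heads per K/V head
            cfg.num_attention_heads = m.num_heads
            cfg.num_key_value_heads = kv_heads
        elif name.endswith("mlp"):
            if hasattr(m, "gate_proj"):
                m.hidden_size = m.gate_proj.in_features
                cfg.intermediate_size = m.gate_proj.out_features
            elif hasattr(m, "gate_up_proj"):
                m.hidden_size = m.gate_up_proj.in_features
                cfg.intermediate_size = m.gate_up_proj.out_features // 2
```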

Step 6: Evaluate perplexity

Evaluate the pruned model on WikiText-2 to measure perplexity degradation. Tokenize the test set, run the model on sequential chunks of the sequence length, compute per-token cross-entropy, and derive perplexity as exp(mean_loss).

Key considerations:

  • Lower perplexity is better; compare against the unpruned baseline
  • WikiText-2 is a standard benchmark for LLM pruning evaluation
  • Free GPU cache after evaluation to avoid memory issues
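A sketch of the evaluation loop, assuming the datasets package and the wikitext-2-raw-v1 configuration of the wikitext dataset on the HuggingFace Hub. Joining the test split with blank lines before tokenizing is a common convention in LLM pruning code rather than a requirement, and eval_wikitext2_ppl and perplexity are illustrative names. Because HuggingFace models shift labels internally, each chunk of seqlen tokens yields seqlen - 1 predictions, so the loop counts tokens that way before taking exp of the overall mean loss.

```python
import math


def perplexity(total_nll: float, n_tokens: int) -> float:
    """exp of the mean per-token negative log-likelihood."""
    return math.exp(total_nll / n_tokens)


def eval_wikitext2_ppl(model, tokenizer, seqlen: int) -> float:
    """Hypothetical helper: chunked WikiText-2 perplexity for a causal LM."""
    import torch
    from datasets import load_dataset

    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    ids = tokenizer("\n\n".join(test["text"]), return_tensors="pt").input_ids
    n_chunks = ids.shape[1] // seqlen       # drop the ragged tail
    total_nll = 0.0
    with torch.no_grad():
        for i in range(n_chunks):
            chunk = ids[:, i * seqlen:(i + 1) * seqlen].to(model.device)
            # labels=chunk: the model returns mean cross-entropy over seqlen - 1 targets.
            loss = model(chunk, labels=chunk).loss.float().item()
            total_nll += loss * (seqlen - 1)
    torch.cuda.empty_cache()                # release cached GPU blocks afterwards
    return perplexity(total_nll, n_chunks * (seqlen - 1))
```

Running the same function on the unpruned model gives the baseline that the pruned perplexity should be compared against.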

Step 7: Save pruned model

Save the pruned model and tokenizer using HuggingFace's save_pretrained method. This preserves the updated configuration so the model can be loaded with standard HuggingFace APIs.

Key considerations:

  • save_pretrained saves both weights and config with updated dimensions
  • The saved model can be loaded with AutoModelForCausalLM.from_pretrained
  • Optionally run zero-shot evaluation benchmarks (BoolQ, HellaSwag, etc.) for comprehensive quality assessment
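Saving and reloading can be sketched as follows; save_and_reload is a hypothetical helper, and reloading this way needs enough memory for a second copy of the model. By default, from_pretrained treats shape mismatches between the checkpoint and the config as an error, so a clean reload is a quick check that Step 5 left the config consistent with the pruned weights.

```python
def save_and_reload(model, tokenizer, out_dir: str):
    """Hypothetical helper: save the pruned model, then reload it as a sanity check."""
    from transformers import AutoModelForCausalLM

    model.save_pretrained(out_dir)       # weights plus config.json with pruned sizes
    tokenizer.save_pretrained(out_dir)
    # A clean reload means checkpoint shapes and config agree.
    reloaded = AutoModelForCausalLM.from_pretrained(
        out_dir, torch_dtype="auto", device_map="auto"
    )
    n_params = sum(p.numel() for p in reloaded.parameters())
    return reloaded, n_params
```

The returned parameter count can be compared with that of the unpruned model to confirm the size reduction.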
