Workflow:VainF Torch Pruning Vision Transformer Pruning
| Knowledge Sources | |
|---|---|
| Domains | Model_Compression, Structural_Pruning, Vision_Transformers |
| Last Updated | 2026-02-07 23:30 GMT |
Overview
End-to-end process for structurally pruning transformer models (ViT, DeiT, Swin, and BERT) from HuggingFace Transformers or timm, including management of attention head counts and per-head dimensions.
Description
This workflow addresses the unique challenges of pruning transformer architectures, where multi-head self-attention layers require coordinated pruning of Q/K/V projections. It covers loading pretrained ViT models, registering attention head groupings, choosing between head-dimension pruning (reducing the embedding size per head) and head-count pruning (removing entire attention heads), executing the pruning, and updating internal model attributes that depend on head counts and dimensions. The workflow supports models from both the timm library and HuggingFace Transformers, each requiring slightly different handling of attention internals.
Usage
Execute this workflow when you need to compress a Vision Transformer or BERT-family model for faster inference or reduced memory footprint. This is appropriate when working with ViT, DeiT, Swin Transformer, or BERT models from timm or HuggingFace, especially when the model's attention mechanism requires special handling during pruning.
Execution Steps
Step 1: Load pretrained transformer model
Load the pretrained Vision Transformer from timm (timm.create_model) or HuggingFace (ViTForImageClassification.from_pretrained). Place the model in eval mode and create example inputs. Measure baseline MACs and parameter count for comparison.
Key considerations:
- For timm models, the attention forward function may need patching to handle dynamic channel sizes after pruning
- For HuggingFace ViT, the model wraps attention in ViTSelfAttention with query/key/value projections
- Prepare ImageNet data loaders if using gradient-based importance criteria (Taylor, Hessian)
Step 2: Register attention head configuration
Iterate through all attention modules and register the number of heads for each Q/K/V projection layer in a num_heads dictionary keyed by the projection layer. This tells the pruner how to group channels within attention layers so that pruning respects head boundaries. Also add the final classifier to the ignored layers so that the number of output classes is left unchanged.
Key considerations:
- For timm ViT: register num_heads for the fused qkv projection (m.qkv)
- For HuggingFace ViT: register num_heads separately for m.query, m.key, and m.value
- Optionally use bottleneck pruning mode: add the FFN output layers (fc2 in timm, dense in HuggingFace) to the ignored layers so that only the internal dimensions of the FFN and attention blocks are pruned
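The registration loop can be sketched as follows, assuming timm-style attention modules that expose a fused `qkv` linear layer and a `num_heads` attribute. `TinyAttention`, `TinyViT`, and `collect_head_config` are hypothetical stand-ins so the snippet runs without downloading weights; with a real timm model the loop would normally match on the library's attention class rather than duck-typing. The returned dictionary and list are what would be passed to the pruner as its head configuration and ignored layers.

```python
import torch.nn as nn


class TinyAttention(nn.Module):
    """Hypothetical stand-in for a timm-style attention block."""

    def __init__(self, dim=48, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)  # fused Q/K/V projection
        self.proj = nn.Linear(dim, dim)


class TinyViT(nn.Module):
    """Hypothetical two-block model with a classifier head."""

    def __init__(self, dim=48, num_heads=4, num_classes=10):
        super().__init__()
        self.blocks = nn.ModuleList([TinyAttention(dim, num_heads) for _ in range(2)])
        self.head = nn.Linear(dim, num_classes)


def collect_head_config(model, classifier):
    """Map each fused qkv layer to its head count and protect the classifier."""
    num_heads = {}
    for m in model.modules():
        # Duck-typed check; an assumption made so the sketch needs no timm import.
        if hasattr(m, "qkv") and hasattr(m, "num_heads"):
            num_heads[m.qkv] = m.num_heads
    ignored_layers = [classifier]  # keep the number of output classes intact
    return num_heads, ignored_layers


model = TinyViT()
num_heads, ignored_layers = collect_head_config(model, model.head)
print(len(num_heads), list(num_heads.values()))
```

For HuggingFace ViT the same loop would add three entries per attention module (query, key, and value) instead of one.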
Step 3: Choose pruning strategy for heads
Decide between two complementary strategies: prune_head_dims (reduce the embedding dimension of each head while keeping all heads) or prune_num_heads (remove entire attention heads while keeping head dimensions intact). These can be combined but are typically used separately.
Pseudocode:
- If prune_head_dims=True: each head gets a smaller dimension (e.g., 64 -> 32)
- If prune_num_heads=True: some heads are entirely removed (e.g., 12 -> 8 heads)
Key considerations:
- Head dimension pruning is more fine-grained but requires updating head_dim attributes after pruning
- Head count pruning is coarser but preserves the original head dimension, which can be important for positional encodings like RoPE
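To make the difference concrete, the following sketch works out the width of a fused Q/K/V projection under each strategy, using the 12-head, 64-dimension example from above. The helper names are hypothetical and the arithmetic is purely illustrative; it does not call the pruning library.

```python
def qkv_out_features(num_heads, head_dim):
    """Output width of a fused Q/K/V projection: 3 * heads * per-head dim."""
    return 3 * num_heads * head_dim


def after_head_dim_pruning(num_heads, head_dim, keep_dim):
    """Every head survives, but each keeps only `keep_dim` channels."""
    return num_heads, keep_dim


def after_head_count_pruning(num_heads, head_dim, keep_heads):
    """Only `keep_heads` heads survive; each keeps its full dimension."""
    return keep_heads, head_dim


base = (12, 64)
dims = after_head_dim_pruning(*base, keep_dim=32)
heads = after_head_count_pruning(*base, keep_heads=8)

print(qkv_out_features(*base))   # 2304
print(qkv_out_features(*dims))   # 1152
print(qkv_out_features(*heads))  # 1536
```

Either way, the attributes the attention module uses for reshaping (head count or head dimension) no longer match the original values, which is why they must be updated after pruning.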
Step 4: Accumulate gradients for importance estimation
If using Taylor or Hessian importance, run forward-backward passes on a small calibration dataset to accumulate gradient information. For Taylor importance, compute cross-entropy loss and call backward(). For Hessian importance, compute per-sample losses and accumulate diagonal Hessian estimates.
Key considerations:
- Typically 10 batches of calibration data suffice for stable importance estimates
- For Hessian importance, each sample requires a separate backward pass with retain_graph=True
- Skip this step entirely if using magnitude-based importance (L1/L2)
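A minimal sketch of the gradient accumulation for a Taylor-style criterion, assuming a classification model and a loader that yields (images, labels) batches. The tiny model and random tensors are hypothetical placeholders for a ViT and an ImageNet calibration loader. The per-channel score computed at the end is a generic first-order Taylor estimate and may differ from the grouping or normalization the library's own importance class applies.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Placeholder model and calibration data (assumptions for illustration only).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.ReLU(), nn.Linear(16, 5))
calib = [(torch.randn(4, 3, 8, 8), torch.randint(0, 5, (4,))) for _ in range(10)]

model.zero_grad()
for images, labels in calib:
    loss = F.cross_entropy(model(images), labels)
    loss.backward()  # no zero_grad between batches, so gradients accumulate

# Generic first-order Taylor score per output channel: sum over inputs of |w * dL/dw|.
layer = model[1]
with torch.no_grad():
    taylor_scores = (layer.weight * layer.weight.grad).abs().sum(dim=1)

print(taylor_scores.shape)
```

In the actual workflow the accumulated `.grad` tensors are left on the model and read by the importance criterion when the pruner runs, so no manual scoring is needed.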
Step 5: Execute pruning and update attention attributes
Run the pruner in interactive mode to prune all groups. After pruning, iterate through all attention modules to update their num_heads, head_dim, and all_head_size attributes to reflect the new dimensions. This is critical because transformer implementations rely on these attributes for reshaping tensors during the forward pass.
Key considerations:
- For timm ViT: update m.num_heads and m.head_dim based on pruner.num_heads[m.qkv]
- For HuggingFace ViT: update m.num_attention_heads, m.attention_head_size, and m.all_head_size
- The timm attention forward function must be patched to use dynamic reshape (B, N, -1) instead of hardcoded (B, N, C)
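The sketch below illustrates why the dynamic reshape and the attribute update are both needed. `Attention` is a simplified, hypothetical stand-in for a timm-style block (no dropout or normalization), and `keep_heads` manually removes heads by slicing weights, standing in for what the pruner does; it assumes the fused qkv output is laid out as (3, num_heads, head_dim). In the real workflow the new head count would come from `pruner.num_heads[m.qkv]`.

```python
import torch
import torch.nn as nn


class Attention(nn.Module):
    """Simplified timm-style attention (hypothetical stand-in)."""

    def __init__(self, dim=48, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, _ = x.shape
        # Per-head width is inferred (-1) instead of being derived from C.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = ((q * self.scale) @ k.transpose(-2, -1)).softmax(dim=-1)
        # (B, N, -1) rather than (B, N, C): attention width may differ from C now.
        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(out)


def keep_heads(attn, heads):
    """Manually drop all heads except `heads` by slicing qkv rows and proj columns."""
    H, D = attn.num_heads, attn.head_dim
    cols = torch.cat([torch.arange(h * D, (h + 1) * D) for h in heads])
    rows = torch.cat([cols + s * H * D for s in range(3)])  # q, k, v slices
    qkv = nn.Linear(attn.qkv.in_features, len(rows))
    qkv.weight.data = attn.qkv.weight.data[rows].clone()
    qkv.bias.data = attn.qkv.bias.data[rows].clone()
    proj = nn.Linear(len(cols), attn.proj.out_features)
    proj.weight.data = attn.proj.weight.data[:, cols].clone()
    proj.bias.data = attn.proj.bias.data.clone()
    attn.qkv, attn.proj = qkv, proj


torch.manual_seed(0)
attn = Attention().eval()
x = torch.randn(2, 5, 48)

keep_heads(attn, [0, 2])

# Post-pruning attribute update, mirroring the timm bullet above.
attn.num_heads = 2
attn.head_dim = attn.qkv.out_features // (3 * attn.num_heads)

with torch.no_grad():
    y = attn(x)
print(y.shape, attn.num_heads, attn.head_dim)
```

With a hardcoded `reshape(B, N, C)` the forward pass would fail here, because the concatenated head outputs are 24 channels wide while the input embedding is still 48.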
Step 6: Evaluate and save pruned model
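Measure the pruned model's MACs and parameters. Optionally evaluate accuracy on the ImageNet validation set and benchmark inference latency. Save the pruned model with torch.save(model, path), which stores the whole model object rather than only a state dict, because the pruned layer shapes no longer match the original architecture definition.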
Key considerations:
- Compare MACs reduction and parameter reduction to the baseline
- Accuracy drop before fine-tuning is expected for aggressive pruning ratios
- Fine-tuning on ImageNet is handled by a separate fine-tuning script (finetune.py)
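A minimal sketch of saving and reloading a whole model object, using a small `nn.Sequential` as a hypothetical stand-in for the pruned network. It assumes a PyTorch version whose `torch.load` accepts the `weights_only` argument; recent releases default to loading weights only, so unpickling a full module requires passing `weights_only=False`, which should be done only for trusted files.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a pruned model.
model = nn.Sequential(nn.Linear(8, 6), nn.ReLU(), nn.Linear(6, 3)).eval()
x = torch.randn(2, 8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pruned_model.pth")
    torch.save(model, path)  # pickles the module object together with its layer shapes
    # Full-module unpickling; only use weights_only=False on files you trust.
    restored = torch.load(path, map_location="cpu", weights_only=False)

with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)
```

Loading a model saved this way also requires the model's class definitions to be importable in the loading process, which matters when the fine-tuning script runs separately from the pruning script.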