

Workflow:VainF Torch Pruning CNN Pruning

From Leeroopedia
Revision as of 11:05, 16 February 2026 by Admin (auto-imported from workflows/VainF_Torch_Pruning_CNN_Pruning.md)


Knowledge Sources
Domains: Model_Compression, Structural_Pruning, Computer_Vision
Last Updated: 2026-02-07 23:30 GMT

Overview

End-to-end process for structurally pruning convolutional neural networks (CNNs) from torchvision using the DepGraph algorithm and high-level pruners.

Description

This workflow covers the standard procedure for removing entire filters and channels from CNN architectures such as ResNet, VGG, DenseNet, EfficientNet, MobileNet, and other torchvision models. It uses Torch-Pruning's DepGraph to discover inter-layer dependencies automatically. It then groups structurally coupled parameters and prunes each group together. The process supports both local pruning (a uniform ratio per layer) and global pruning (ranking channels across layers). It also supports multiple importance criteria, including L1/L2 magnitude, Taylor expansion, Hessian-based scores, and random baselines. After pruning, the model is a genuinely smaller network with fewer parameters and FLOPs, not a masked or sparse model.
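The sketch below makes the local/global distinction concrete. It is plain Python with made-up importance scores for three hypothetical layers (conv1, conv2, conv3). It only illustrates the two selection rules. It is not Torch-Pruning's implementation, which involves more machinery such as coupled groups, score normalization, and rounding.

```python
# Made-up per-channel importance scores for three hypothetical layers.
scores = {
    "conv1": [0.9, 0.1, 0.5, 0.7],
    "conv2": [0.2, 0.3, 0.25, 0.8],
    "conv3": [2.0, 1.5, 1.8, 1.2],
}


def local_plan(scores, ratio):
    """Local pruning: remove the same fraction of channels from every layer."""
    plan = {}
    for name, layer_scores in scores.items():
        n_prune = int(len(layer_scores) * ratio)
        # Channel indices ordered from least to most important.
        order = sorted(range(len(layer_scores)), key=lambda i: layer_scores[i])
        plan[name] = sorted(order[:n_prune])
    return plan


def global_plan(scores, ratio):
    """Global pruning: rank all channels together and remove the lowest-scoring ones."""
    flat = [(v, name, i) for name, layer_scores in scores.items() for i, v in enumerate(layer_scores)]
    n_prune = int(len(flat) * ratio)
    plan = {name: [] for name in scores}
    for _, name, i in sorted(flat)[:n_prune]:
        plan[name].append(i)
    return {name: sorted(idxs) for name, idxs in plan.items()}


print("local: ", local_plan(scores, 0.5))
print("global:", global_plan(scores, 0.5))
```

Both rules use the same 50% ratio. The local plan removes two channels from every layer. The global plan removes nothing from conv3, because its scores sit on a larger scale. This is one reason global pruning is sensitive to how importance scores are scaled across layers.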

Usage

Execute this workflow when you have a pretrained CNN from torchvision (or any PyTorch CNN) and need to reduce its parameter count and computational cost for deployment on resource-constrained hardware, or to accelerate inference. This is the recommended starting point for users new to Torch-Pruning.

Execution Steps

Step 1: Load pretrained model

Load a pretrained CNN from torchvision or a custom checkpoint. Place the model in evaluation mode and prepare example inputs matching the model's expected input dimensions (typically a single batch of the correct spatial resolution). Count baseline MACs and parameters for later comparison.

Key considerations:

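  • Keep autograd enabled during dependency graph construction (do not wrap it in a torch.no_grad() context), because dependencies are traced through the autograd graph of a forward pass
  • Use the correct input resolution for the model architecture (224x224 for most models, 299x299 for Inception v3)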

Step 2: Configure ignored layers and channel groups

Identify layers that must not be pruned and register them as ignored. Typically, the final classification head (the Linear layer that outputs class logits) is ignored to preserve the output dimensionality. For detection models, task-specific heads (RPN, FPN, ROI heads) are also excluded. For torchvision's Vision Transformers, configure channel groups so that pruning respects multi-head attention, keeping the embedding dimension divisible by the number of heads.

Key considerations:

  • Always ignore the final classifier layer to preserve output class count
  • For detection/segmentation models, ignore task-specific heads that have fixed channel requirements
  • Optionally register unwrapped nn.Parameters for models with standalone parameter tensors (e.g., the class token and positional embedding in torchvision's ViT)
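One way to build these inputs is sketched below. The helper names collect_ignored_layers and collect_attention_heads are hypothetical, and the toy model stands in for a torchvision CNN.

For ResNet, ignored_layers = [model.fc] is equivalent. Matching on out_features is convenient for other architectures, but it would also catch any other Linear layer of the same width.

The head map is what the text calls channel groups. The keyword the pruner expects for it (for example channel_groups or num_heads) depends on the Torch-Pruning version, so check the signature in your release.

```python
import torch.nn as nn


def collect_ignored_layers(model: nn.Module, num_classes: int) -> list:
    """Hypothetical helper: return Linear layers whose output size equals the class count."""
    return [
        m for m in model.modules()
        if isinstance(m, nn.Linear) and m.out_features == num_classes
    ]


def collect_attention_heads(model: nn.Module) -> dict:
    """Hypothetical helper: map each multi-head attention module to its head count."""
    return {
        m: m.num_heads for m in model.modules()
        if isinstance(m, nn.MultiheadAttention)
    }


# Toy CNN with a 10-class head, standing in for a torchvision model.
toy = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
ignored_layers = collect_ignored_layers(toy, num_classes=10)
head_map = collect_attention_heads(toy)  # empty here: the toy model has no attention layers
print(ignored_layers, head_map)
```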

Step 3: Select importance criterion

Choose an importance estimator to rank which channels/filters to remove. The main options are magnitude-based (L1 or L2 norm of grouped weights), Taylor expansion (first-order gradient information), Hessian-based (second-order information), or random selection as a baseline.
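The sketch below shows what magnitude importance over grouped weights means. It takes a hypothetical conv → BatchNorm → conv chain and scores each output channel by the L2 norm of every weight that would be deleted along with it. This is a simplified plain-PyTorch illustration. Torch-Pruning's own estimators differ in details such as how per-layer scores are reduced and normalized.

```python
import torch
import torch.nn as nn


def group_l2_importance(conv: nn.Conv2d, bn: nn.BatchNorm2d, next_conv: nn.Conv2d) -> torch.Tensor:
    """Score each output channel of `conv` by the L2 norm of all weights coupled to it.

    Removing channel c deletes conv.weight[c], bn.weight[c] and next_conv.weight[:, c],
    so the three slices are scored as one group.
    """
    conv_sq = conv.weight.detach().flatten(1).pow(2).sum(dim=1)                       # [C]
    bn_sq = bn.weight.detach().pow(2)                                                  # [C]
    next_sq = next_conv.weight.detach().transpose(0, 1).flatten(1).pow(2).sum(dim=1)  # [C]
    return (conv_sq + bn_sq + next_sq).sqrt()


# Toy coupled group with 4 channels (hypothetical shapes).
torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3)
bn = nn.BatchNorm2d(4)
next_conv = nn.Conv2d(4, 8, kernel_size=3)

# Zero every weight tied to channel 2 so it becomes the obvious pruning candidate.
with torch.no_grad():
    conv.weight[2].zero_()
    bn.weight[2] = 0.0
    next_conv.weight[:, 2].zero_()

scores = group_l2_importance(conv, bn, next_conv)
print("scores:", scores)
print("least important channel:", scores.argmin().item())
```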

Key considerations:

  • L1/L2 magnitude importance requires no calibration data and is fast
  • Taylor importance requires a forward-backward pass on calibration data to accumulate gradients
  • Hessian importance requires per-sample gradients and is more expensive but can be more accurate
  • Group-level importance aggregates scores across all coupled layers in a dependency group

Step 4: Initialize the pruner

Create a high-level pruner (BasePruner, GroupNormPruner, or BNScalePruner) with the model, example inputs, importance criterion, target pruning ratio, and configuration options. The pruner internally constructs the DepGraph, discovers all pruning groups, and prepares the pruning plan.

Key considerations:

  • Set pruning_ratio to control how many channels to remove (e.g., 0.5 removes 50% of the channels in each layer with local pruning, or 50% of all prunable channels ranked together with global pruning)
  • Use global_pruning=True for cross-layer importance ranking, False for uniform per-layer ratios
  • Set round_to=8 to align channel counts to multiples of 8 for hardware acceleration
  • Set iterative_steps to reach the target pruning_ratio progressively over several pruner.step() calls

Step 5: Execute pruning

Call pruner.step() to execute one round of pruning. This scans all groups, computes importance scores, determines which indices to remove, and physically removes the weights from the model tensors. Optionally, call pruner.step(interactive=True) to iterate over the groups and inspect or skip each one before pruning it.

Key considerations:

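  • In interactive mode, process groups one at a time as they are yielded (do not collect them into a list first), since pruning one group can change the channel layout that later groups refer to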
  • After pruning, update any static model attributes that depend on channel counts (e.g., hidden_dim for ViT)
  • For iterative pruning, create the pruner once with iterative_steps set, then call pruner.step() once per iteration and fine-tune between calls

Step 6: Validate and measure pruned model

Run a forward pass to verify the pruned model produces valid outputs. Count the new MACs and parameter count and compare to the baseline. Optionally evaluate accuracy on a validation set to measure the accuracy drop from pruning.

Key considerations:

  • The pruned model is a standard PyTorch model with genuinely smaller tensors
  • Accuracy drop before fine-tuning is expected; the goal is to recover accuracy through fine-tuning
  • Use tp.utils.count_ops_and_params to count MACs and parameters consistently with the baseline measurement

Step 7: Save pruned model and fine-tune

Save the pruned model using torch.save(model, path) (whole model, not state_dict, since the architecture has changed). Fine-tune the pruned model on the training dataset using standard training procedures to recover accuracy lost during pruning.

Key considerations:

  • Must save the entire model object, not just state_dict, because layer dimensions have changed
  • In PyTorch 2.6+, torch.load defaults to weights_only=True, so pass weights_only=False to load a whole pickled model
  • Clear gradients with model.zero_grad() before saving to reduce file size
  • Fine-tuning typically uses a lower learning rate than original training
