Principle: VainF Torch-Pruning Progressive Pruning
Metadata
| Field | Value |
|---|---|
| Paper | DepGraph |
| Domains | Deep_Learning, Model_Compression, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Iteratively pruning a model in small steps until a target FLOPs reduction ratio is achieved, measured empirically after each step.
Description
Rather than pruning to a target ratio in one shot, progressive pruning applies small pruning steps repeatedly, checking the actual FLOPs reduction after each step. This approach is more robust because the relationship between channel pruning ratio and FLOPs reduction is non-linear (especially with skip connections, grouped convolutions, etc.). The loop keeps calling `pruner.step()` and measuring FLOPs via `tp.utils.count_ops_and_params` until the target speedup (FLOPs_original / FLOPs_current) is achieved.
Key characteristics of progressive pruning:
- Iterative execution -- pruning is applied in many small steps rather than a single large step
- Empirical measurement -- FLOPs are counted after each step using `tp.utils.count_ops_and_params` rather than estimated analytically
- Robustness -- handles the non-linear relationship between pruning ratio and actual FLOPs reduction
- Convergence guarantee -- each step monotonically reduces FLOPs, so the loop always terminates
- Compatibility -- works with any `BasePruner` instance and any importance criterion
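The loop and its termination logic can be sketched without any framework at all. In the sketch below, `MockPruner` and its fixed 5% FLOPs reduction per step are hypothetical stand-ins for illustration; in a real run the FLOPs drop would come from measuring the pruned model.

```python
class MockPruner:
    """Hypothetical stand-in for a real pruner: each step() cuts 5% of remaining FLOPs."""

    def __init__(self, iterative_steps):
        self.iterative_steps = iterative_steps
        self.current_step = 0
        self.flops = 1_000_000  # pretend model cost

    def step(self):
        self.current_step += 1
        self.flops = int(self.flops * 0.95)  # monotonic reduction guarantees progress


def progressive_pruning(pruner, target_speedup):
    base = pruner.flops  # "measure" original FLOPs once, up front
    speedup = 1.0
    while speedup < target_speedup:
        pruner.step()
        speedup = base / pruner.flops  # empirical check after every step
        if pruner.current_step == pruner.iterative_steps:
            break  # exhausted all pruning steps; target may be unreachable
    return speedup


pruner = MockPruner(iterative_steps=200)
achieved = progressive_pruning(pruner, target_speedup=2.0)
```

Because each step shrinks FLOPs by a fixed factor here, the loop exits after roughly `log(0.5)/log(0.95) ≈ 14` steps, well before the step budget is exhausted.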
Usage
Use in the sparse training + pruning pipeline when you need to hit a precise FLOPs target. This is common in CIFAR/ImageNet reproduction experiments where exact speedup ratios (e.g., 2x, 4x) must be achieved for fair comparison with published results.
Typical workflow:
- Train the model to convergence (or load a pretrained checkpoint)
- Configure a `BasePruner` with the desired importance criterion and iterative steps
- Call `progressive_pruning()` to prune to the target speedup
- Fine-tune the pruned model to recover accuracy
Theoretical Basis
Given target speedup S and original FLOPs F_0, the progressive pruning algorithm is:
```python
# Progressive pruning algorithm (pseudocode)
F_0 = count_ops(model)            # measure original FLOPs
current_speedup = 1.0
while current_speedup < S:
    pruner.step()                 # apply one pruning step
    F_current = count_ops(model)  # measure current FLOPs
    current_speedup = F_0 / F_current
    if pruner.current_step == pruner.iterative_steps:
        break                     # exhausted all pruning steps
return current_speedup
```
This is a greedy algorithm that converges because each call to pruner.step() removes channels/filters, which monotonically reduces FLOPs. The algorithm terminates either when the target speedup is achieved or when all iterative steps are exhausted.
The non-linearity arises from several sources:
- Skip connections -- pruning a layer may not reduce FLOPs if the skip connection constrains the channel count
- Grouped convolutions -- FLOPs scale differently with group structure
- Dependency graphs -- coupled layers must be pruned together, creating discrete jumps in FLOPs
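A back-of-the-envelope calculation (illustrative numbers, not from the source) makes the non-linearity concrete: a convolution's cost is bilinear in its input and output channel counts, so halving the channels on both sides of a layer cuts its FLOPs by 4x, not 2x.

```python
def conv_macs(c_in, c_out, k, h, w):
    """Multiply-accumulates of a dense k x k convolution over an h x w output map."""
    return c_in * c_out * k * k * h * w

# A 64 -> 64 channel 3x3 convolution on a 32x32 feature map...
base = conv_macs(64, 64, 3, 32, 32)

# ...with 50% of channels pruned on BOTH sides (as coupled layers force):
pruned = conv_macs(32, 32, 3, 32, 32)

ratio = base / pruned  # 4.0: quadratic in the channel keep-ratio, not linear
```

This is why a 30% channel-pruning ratio does not yield a 30% FLOPs reduction, and why measuring FLOPs empirically after each step is more reliable than an analytical estimate.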