Workflow:VainF Torch Pruning Sparse Training Pruning
| Field | Value |
|---|---|
| Knowledge Sources | |
| Domains | Model_Compression, Structural_Pruning, Sparse_Training |
| Last Updated | 2026-02-07 23:30 GMT |
Overview
End-to-end process for pruning neural networks using regularization-driven sparse training, where a sparsity-inducing regularizer is applied during training before pruning to produce cleaner importance signals.
Description
This workflow implements the train-with-regularization-then-prune paradigm. A specialized pruner (BNScalePruner, GroupNormPruner, or GrowingRegPruner) adds regularization gradients during training that push unimportant channels toward zero. After this sparse training phase, the pruner makes its pruning decisions from the resulting parameter magnitudes, which now separate important from unimportant channels more clearly. The approach typically retains more accuracy than one-shot magnitude pruning: during sparse training the network shifts its function onto the channels that will survive, so removing the near-zero channels afterward costs little. The workflow covers the CIFAR and ImageNet experiments from the DepGraph paper and supports several regularization methods, including BN scale regularization (Network Slimming), group-level norm regularization, and growing regularization.
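The objective minimized during sparse training can be sketched as a task loss plus a weighted penalty per prunable group. The notation below is introduced only for illustration:

- G is the set of prunable groups.
- θ_g denotes the parameters of one coupled group.
- γ_g is the BatchNorm scale of a channel.
- λ_g is the regularization coefficient of group g.

The L1 variant of group regularization would replace the L2 norm with an L1 norm.

```latex
\min_{\theta}\; \mathcal{L}_{\mathrm{task}}(\theta) + \sum_{g \in \mathcal{G}} \lambda_g \, R(\theta_g),
\qquad
R(\theta_g) =
\begin{cases}
  |\gamma_g| & \text{BN scale regularization (Network Slimming)} \\
  \lVert \theta_g \rVert_2 & \text{group-level L2 norm regularization}
\end{cases}
```

For fixed-strength methods, λ_g = λ for every group. Growing regularization instead raises λ_g over the course of training. Groups whose R(θ_g) ends up near zero are the first candidates for removal.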
Usage
Execute this workflow when you want to achieve the best possible accuracy retention after pruning, especially for research benchmarks or when one-shot pruning results in unacceptable accuracy loss. Appropriate for CIFAR-10/100 experiments with models like ResNet-56, VGG-19, DenseNet, and for ImageNet experiments with ResNet-50, MobileNetV2, ViT, and similar architectures.
Execution Steps
Step 1: Pretrain or load base model
Train the base model from scratch on the target dataset (CIFAR-10/100 or ImageNet), or load a pretrained checkpoint. The base model serves as the starting point for sparse training. Record baseline accuracy for comparison.
Key considerations:
- CIFAR models are trained with standard SGD at an initial learning rate of 0.1, using cosine or step decay
- ImageNet models may use pretrained torchvision weights as the starting point
- Ensure the model achieves expected baseline accuracy before proceeding to sparse training
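Both learning-rate schedules mentioned above have simple closed forms. The sketch below is plain Python arithmetic and is not tied to any training framework. The 200-epoch horizon and the milestones at epochs 100 and 150 are illustrative assumptions, not values taken from the DepGraph benchmarks. With `eta_min = 0`, `cosine_lr` follows the same closed form as PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`, and `step_lr` behaves like `MultiStepLR`.

```python
import math
from typing import Sequence


def cosine_lr(epoch: int, total_epochs: int, base_lr: float = 0.1) -> float:
    """Cosine annealing: base_lr at epoch 0, decaying smoothly to 0 at total_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


def step_lr(epoch: int, milestones: Sequence[int], base_lr: float = 0.1, gamma: float = 0.1) -> float:
    """Step decay: multiply base_lr by gamma once for every milestone already reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# Illustrative values only: a 200-epoch run with hypothetical milestones at 100 and 150.
for epoch in (0, 100, 150, 199):
    print(epoch, round(cosine_lr(epoch, 200), 5), round(step_lr(epoch, [100, 150]), 5))
```

Cosine decay avoids choosing milestones. Step decay keeps the learning rate high for longer and then drops it sharply.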
Step 2: Configure pruning method and regularizer
Select the pruning method (BNScale, GroupNorm, GrowingReg, or magnitude-based methods like L1) and create the corresponding pruner. Configure the regularization coefficient, pruning speed-up target, and number of iterative steps. The pruner registry maps method names to their respective pruner classes and importance estimators.
Key considerations:
- BNScalePruner regularizes BatchNorm scale factors (gamma) toward zero
- GroupNormPruner applies group-level L1/L2 regularization on coupled weight groups
- GrowingRegPruner gradually increases regularization strength over training epochs
- The speed-up target (e.g., 2x, defined as the FLOPs of the unpruned model divided by the FLOPs of the pruned model) determines how aggressively to prune
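At the gradient level, each regularizer contributes an extra term to a parameter's gradient. The functions below are a framework-free sketch that operates on Python lists:

- `bn_scale_reg_grad` is the L1 subgradient used by Network Slimming.
- `group_l2_reg_grad` is a generic group-lasso gradient, i.e. the L2 variant of group-level regularization.
- `grow_reg` illustrates a growing schedule in which less important groups receive larger coefficient increments.

All three function names are hypothetical, and so is the importance-proportional growth rule in particular. None of them is the Torch-Pruning implementation.

```python
import math
from typing import List


def bn_scale_reg_grad(gammas: List[float], reg: float) -> List[float]:
    """Subgradient of reg * sum(|gamma|): the L1 penalty on BN scales (Network Slimming)."""
    # reg * sign(gamma), choosing 0 as the subgradient at gamma == 0
    return [math.copysign(reg, g) if g != 0.0 else 0.0 for g in gammas]


def group_l2_reg_grad(group: List[float], reg: float, eps: float = 1e-8) -> List[float]:
    """Gradient of reg * ||group||_2 (group lasso): all coupled weights shrink together."""
    norm = math.sqrt(sum(w * w for w in group))
    # eps avoids division by zero once a group has collapsed to zero
    return [reg * w / (norm + eps) for w in group]


def grow_reg(regs: List[float], importance: List[float], delta: float) -> List[float]:
    """Hypothetical growing rule: the least important group's coefficient grows by the
    full delta per update, the most important group's by 0, others proportionally."""
    hi, lo = max(importance), min(importance)
    span = (hi - lo) or 1.0  # all importances equal: avoid dividing by zero
    return [r + delta * (hi - imp) / span for r, imp in zip(regs, importance)]


# Illustrative inputs only.
print(bn_scale_reg_grad([0.5, -0.2, 0.0], reg=1e-4))
print(group_l2_reg_grad([3.0, 4.0], reg=1e-4))
print(grow_reg([1e-4, 1e-4, 1e-4], importance=[1.0, 0.5, 0.0], delta=1e-4))
```

These terms are added to the parameters' gradients between `loss.backward()` and `optimizer.step()`, which is the role this workflow assigns to `pruner.regularize(model)`. The group version shrinks every coupled weight in proportion to its value, so an entire group can reach zero together and be removed as one structural unit.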
Step 3: Run sparse training phase
Train the model with the regularizer active in each training step. At each iteration: compute the forward pass, compute the loss, call loss.backward(), then call pruner.regularize(model) to add regularization gradients before optimizer.step(). Call pruner.update_regularizer() at the start of each epoch. This phase runs for a configured number of epochs (sl_total_epochs) to drive unimportant parameters toward zero.
Key considerations:
- pruner.regularize(model) must be called after loss.backward() and before optimizer.step()
- pruner.update_regularizer() must be called at the beginning of each epoch to refresh group information
- The regularization warmup period (sl_reg_warmup) gradually increases regularization strength
- Use a separate learning rate schedule for the sparse training phase
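The call order in Step 3 can be checked with a toy, framework-free simulation:

- `ToyParam` stands in for one BN scale and its gradient.
- `ToyBNScalePruner` is a hypothetical stand-in that exposes only `update_regularizer()` and `regularize(model)`, in the roles described above.
- The linear ramp over `reg_warmup` epochs is an assumed warmup rule, not Torch-Pruning's exact behavior.

A quadratic task loss pulls each scale toward a target, while an L1 penalty pushes it toward zero.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyParam:
    """Stand-in for one BN scale (gamma) and its gradient."""
    value: float
    grad: float = 0.0


class ToyBNScalePruner:
    """Hypothetical stand-in exposing only the two regularization hooks."""

    def __init__(self, reg: float, reg_warmup: int):
        self.reg = reg
        self.reg_warmup = reg_warmup
        self.epoch = -1
        self.current_reg = 0.0

    def update_regularizer(self) -> None:
        # Start of each epoch: assumed linear warmup from 0 to `reg` over `reg_warmup` epochs.
        self.epoch += 1
        scale = min(1.0, self.epoch / self.reg_warmup) if self.reg_warmup > 0 else 1.0
        self.current_reg = self.reg * scale

    def regularize(self, model: List[ToyParam]) -> None:
        # Add the L1 subgradient current_reg * sign(gamma) to the existing task gradient.
        for p in model:
            if p.value > 0:
                p.grad += self.current_reg
            elif p.value < 0:
                p.grad -= self.current_reg


def sparse_train(model, targets, weights, pruner, epochs, steps_per_epoch, lr):
    for _ in range(epochs):
        pruner.update_regularizer()                      # 1. once per epoch
        for _ in range(steps_per_epoch):
            for p, t, w in zip(model, targets, weights):
                p.grad = w * (p.value - t)               # 2. "loss.backward()" for w/2 * (v - t)^2
            pruner.regularize(model)                     # 3. after backward, before the step
            for p in model:
                p.value -= lr * p.grad                   # 4. "optimizer.step()": plain SGD
    return model


# Illustrative settings: gamma 0 is pulled toward 1.0, gamma 1 toward 0.02 (below the penalty).
model = [ToyParam(1.0), ToyParam(1.0)]
pruner = ToyBNScalePruner(reg=0.1, reg_warmup=2)
sparse_train(model, targets=[1.0, 0.02], weights=[1.0, 1.0], pruner=pruner,
             epochs=10, steps_per_epoch=50, lr=0.1)
print([round(p.value, 4) for p in model])
```

With these illustrative settings, the first scale settles near 0.9 (its target of 1.0 minus the full penalty of 0.1). The second scale's task pull is weaker than the penalty, so it ends up oscillating within about 0.01 of zero. Plain subgradient steps oscillate around zero rather than landing on it exactly. The subsequent pruning step, not the regularizer itself, is what removes such channels.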
Step 4: Execute progressive pruning
After sparse training, execute the actual pruning using progressive_pruning: repeatedly call pruner.step() until the target speed-up ratio is reached. Each step removes a small fraction of channels based on the importance scores shaped by the preceding sparse training.
Pseudocode:
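base_ops = FLOPs of the model before pruning
current_speed_up = 1
while current_speed_up < target_speed_up:
    pruner.step()
    pruned_ops = FLOPs of the model after this step
    current_speed_up = base_ops / pruned_ops
    if pruner.current_step == pruner.iterative_steps:
        break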
Key considerations:
- Progressive pruning gradually increases the pruning ratio until the target FLOPs reduction is met
- For OBDC importance, forward-backward passes over calibration data are required before each step
- The pruner tracks current_step and performs at most iterative_steps steps; if the target speed-up has not been reached by then, the loop stops early
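A runnable toy version of the progressive loop makes the stopping logic and the FLOPs arithmetic concrete. Everything here is hypothetical:

- `conv_chain_ops` counts multiply-accumulates for a plain chain of 3x3 convolutions on a 32x32 feature map.
- `ToyUniformPruner` removes the same fraction of channels from every hidden layer on a linear schedule.

Real pruners choose channels by importance and must respect the dependency graph, so only the loop structure carries over.

```python
from typing import List


def conv_chain_ops(widths: List[int], kernel: int = 3, spatial: int = 32 * 32) -> int:
    """Toy MAC count for a plain chain of kxk convolutions: sum of c_in * c_out * k*k * H*W."""
    return sum(a * b * kernel * kernel * spatial for a, b in zip(widths, widths[1:]))


class ToyUniformPruner:
    """Hypothetical pruner: after step s it keeps (1 - max_ratio * s / iterative_steps)
    of every hidden layer's channels (a linear schedule); input and output widths stay fixed."""

    def __init__(self, widths: List[int], max_ratio: float, iterative_steps: int):
        self.base = list(widths)
        self.widths = list(widths)
        self.max_ratio = max_ratio
        self.iterative_steps = iterative_steps
        self.current_step = 0

    def step(self) -> None:
        if self.current_step >= self.iterative_steps:
            return  # schedule exhausted: nothing more to prune
        self.current_step += 1
        keep = 1.0 - self.max_ratio * self.current_step / self.iterative_steps
        hidden = [max(1, round(c * keep)) for c in self.base[1:-1]]
        self.widths = [self.base[0]] + hidden + [self.base[-1]]


def progressive_pruning(pruner: ToyUniformPruner, target_speed_up: float) -> float:
    base_ops = conv_chain_ops(pruner.widths)
    current_speed_up = 1.0
    while current_speed_up < target_speed_up:
        pruner.step()
        pruned_ops = conv_chain_ops(pruner.widths)
        current_speed_up = base_ops / pruned_ops
        if pruner.current_step == pruner.iterative_steps:
            break  # stop at the step cap even if the target was not reached
    return current_speed_up


pruner = ToyUniformPruner([3, 64, 64, 64, 64, 10], max_ratio=0.5, iterative_steps=10)
speed_up = progressive_pruning(pruner, target_speed_up=2.0)
print(pruner.current_step, pruner.widths, round(speed_up, 3))

capped = ToyUniformPruner([3, 64, 64, 64, 64, 10], max_ratio=0.1, iterative_steps=5)
capped_speed_up = progressive_pruning(capped, target_speed_up=2.0)
print(capped.current_step, round(capped_speed_up, 3))
```

In the first run (a 50% ceiling spread over 10 steps), step 6 keeps 45 of 64 channels per hidden layer and gives about 1.97x. Step 7 keeps 42 channels and crosses the 2x target at about 2.25x, so the loop stops after 7 of the 10 allowed steps.

The convolution between two pruned layers costs roughly the product of their widths. Keeping a fraction k of channels therefore cuts its MACs by about k², so a 2x target needs only about 1 - 1/sqrt(2), roughly 29%, of channels removed per layer. This toy needs slightly more, because its first and last convolutions shrink only linearly.

The second run shows the iterative_steps cap. With a 10% ceiling, the loop exhausts its 5 steps at about 1.21x.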
Step 5: Fine-tune pruned model
Fine-tune the pruned model for the configured number of epochs to recover accuracy lost during pruning. Use standard training with SGD, learning rate schedule, and the full training dataset.
Key considerations:
- Fine-tuning typically uses a lower initial learning rate than pretraining
- The number of fine-tuning epochs depends on the dataset (100 for CIFAR, more for ImageNet)
- Monitor validation accuracy to ensure recovery
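Fine-tuning amounts to a standard training loop with a lowered learning rate and accuracy monitoring. The pieces in the sketch below are assumptions:

- `train_one_epoch` and `evaluate` are hypothetical callables you would supply for your framework.
- The step-decay rule (x0.1 at each milestone) is an assumed schedule.
- The stub usage relies on made-up accuracy values, purely to exercise the control flow.

```python
from typing import Callable, Sequence, Tuple


def finetune_with_monitoring(
    train_one_epoch: Callable[[int, float], None],
    evaluate: Callable[[], float],
    epochs: int,
    base_lr: float,
    milestones: Sequence[int],
    baseline_acc: float,
) -> Tuple[int, float, float]:
    """Fine-tune with step decay and track the best validation accuracy.

    Returns (best_epoch, best_acc, best_acc - baseline_acc)."""
    best_epoch, best_acc = -1, float("-inf")
    for epoch in range(epochs):
        lr = base_lr * 0.1 ** sum(1 for m in milestones if epoch >= m)
        train_one_epoch(epoch, lr)
        acc = evaluate()
        if acc > best_acc:
            best_epoch, best_acc = epoch, acc
            # A real run would save a checkpoint of the pruned model here.
    return best_epoch, best_acc, best_acc - baseline_acc


# Stub callables with made-up accuracies, only to exercise the control flow.
lrs = []
fake_accs = iter([0.80, 0.88, 0.91, 0.90])
best = finetune_with_monitoring(
    train_one_epoch=lambda epoch, lr: lrs.append(lr),
    evaluate=lambda: next(fake_accs),
    epochs=4, base_lr=0.01, milestones=[2], baseline_acc=0.93,
)
print(best, lrs)
```

Keeping the best checkpoint, rather than the last one, protects against late-epoch fluctuations. The returned gap to the baseline is the recovery figure that Step 6 reports.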
Step 6: Evaluate final model
Evaluate the final pruned and fine-tuned model on the test/validation set. Report the accuracy delta compared to the baseline, the actual speed-up achieved, and the parameter reduction.
Key considerations:
- Compare against both the pretrained baseline and other pruning methods
- The DepGraph paper's benchmark results provide reference accuracy targets
- Measure actual inference latency in addition to FLOPs for practical deployment assessment
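For the final report, the three headline numbers are simple ratios. Latency is best measured as a median over repeated calls after a warmup. `pruning_report` and `median_latency` are hypothetical helpers written in plain Python. In practice, the operation and parameter counts would come from a FLOPs counter run on both models, and `fn` would wrap one forward pass of the pruned model on a fixed input.

```python
import statistics
import time
from typing import Callable, Dict


def pruning_report(base_acc: float, pruned_acc: float, base_ops: float, pruned_ops: float,
                   base_params: float, pruned_params: float) -> Dict[str, float]:
    """Headline numbers: accuracy delta, FLOPs speed-up, and fraction of parameters removed."""
    return {
        "acc_delta": pruned_acc - base_acc,
        "speed_up": base_ops / pruned_ops,
        "param_reduction": 1.0 - pruned_params / base_params,
    }


def median_latency(fn: Callable[[], object], warmup: int = 10, repeats: int = 50) -> float:
    """Median wall-clock seconds per call of fn, after untimed warmup calls."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Stand-in workload; in practice fn would run one forward pass of the pruned model.
print(median_latency(lambda: sum(range(10_000))))
```

Reporting the median rather than the mean keeps occasional outliers, such as interference from a background process, from skewing the figure. On a GPU, call `torch.cuda.synchronize()` before each timer read, because CUDA kernels launch asynchronously.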