
Workflow:VainF Torch Pruning Sparse Training Pruning

From Leeroopedia
Revision as of 11:03, 16 February 2026 by Admin (auto-imported from workflows/VainF_Torch_Pruning_Sparse_Training_Pruning.md)


Knowledge Sources
Domains Model_Compression, Structural_Pruning, Sparse_Training
Last Updated 2026-02-07 23:30 GMT

Overview

End-to-end process for pruning neural networks using regularization-driven sparse training, where a sparsity-inducing regularizer is applied during training before pruning to produce cleaner importance signals.

Description

This workflow implements the train-with-regularization-then-prune paradigm. A specialized pruner (BNScalePruner, GroupNormPruner, or GrowingRegPruner) adds regularization gradients during training to push unimportant channels toward zero. After the sparse training phase, the pruner ranks channels by the resulting parameter magnitudes (BN scale factors or group norms), which now separate important from unimportant channels more clearly, and prunes accordingly. This approach typically retains more accuracy than one-shot magnitude pruning because the model has time to redistribute importance during training. The workflow covers the CIFAR and ImageNet experiments from the DepGraph paper and supports multiple regularization methods, including BN scale (Network Slimming), group-level norm regularization, and growing regularization.
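In their standard textbook form, the two regularizer families mentioned above can be written as follows. The symbols (task loss $\mathcal{L}_{\text{task}}$, coefficient $\lambda$, BN scale $\gamma_{l,c}$ of channel $c$ in layer $l$, and $W_g$ for all parameters of one prunable unit $g$ across its coupled layers) are notation introduced here; Torch-Pruning's exact expressions may differ.

```latex
% Network Slimming: L1 penalty on BatchNorm scale factors
\mathcal{L} = \mathcal{L}_{\text{task}}(W) + \lambda \sum_{l} \sum_{c} \lvert \gamma_{l,c} \rvert

% Group-level (group lasso) penalty over coupled parameter groups
\mathcal{L} = \mathcal{L}_{\text{task}}(W) + \lambda \sum_{g \in \mathcal{G}} \lVert W_g \rVert_2

% Gradient contributed by the L1 term (for \gamma_{l,c} \neq 0)
\frac{\partial}{\partial \gamma_{l,c}} \, \lambda \lvert \gamma_{l,c} \rvert = \lambda \, \operatorname{sign}(\gamma_{l,c})
```

The last line is why "adding regularization gradients" after the backward pass is enough: the penalty's gradient can simply be added to each parameter's existing gradient before the optimizer step.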

Usage

Execute this workflow when you want to achieve the best possible accuracy retention after pruning, especially for research benchmarks or when one-shot pruning results in unacceptable accuracy loss. Appropriate for CIFAR-10/100 experiments with models like ResNet-56, VGG-19, DenseNet, and for ImageNet experiments with ResNet-50, MobileNetV2, ViT, and similar architectures.

Execution Steps

Step 1: Pretrain or load base model

Train the base model from scratch on the target dataset (CIFAR-10/100 or ImageNet), or load a pretrained checkpoint. The base model serves as the starting point for sparse training. Record baseline accuracy for comparison.

Key considerations:

  • CIFAR models are trained with standard SGD, learning rate 0.1, cosine or step decay
  • ImageNet models may use pretrained torchvision weights as the starting point
  • Ensure the model achieves expected baseline accuracy before proceeding to sparse training
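Recording the baseline only needs a top-1 accuracy loop. This is a minimal sketch in plain PyTorch; `evaluate` is a hypothetical helper name, not a function taken from the Torch-Pruning benchmark scripts.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model: nn.Module, loader) -> float:
    """Return top-1 accuracy (in percent) of `model` over `loader`.

    `loader` is any iterable of (images, labels) batches, e.g. a DataLoader.
    """
    model.eval()  # use BatchNorm running statistics, disable dropout
    correct, total = 0, 0
    for images, labels in loader:
        logits = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return 100.0 * correct / total
```

For an ImageNet run starting from torchvision weights, the model could come from `torchvision.models.resnet50(weights="IMAGENET1K_V1")`, and the value returned by `evaluate` on the validation loader is the baseline to keep for Step 6.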

Step 2: Configure pruning method and regularizer

Select the pruning method (BNScale, GroupNorm, GrowingReg, or magnitude-based methods like L1) and create the corresponding pruner. Configure the regularization coefficient, pruning speed-up target, and number of iterative steps. The pruner registry maps method names to their respective pruner classes and importance estimators.

Key considerations:

  • BNScalePruner regularizes BatchNorm scale factors (gamma) toward zero
  • GroupNormPruner applies group-level L1/L2 regularization on coupled weight groups
  • GrowingRegPruner gradually increases regularization strength over training epochs
  • The speed-up target (e.g., 2x) determines how aggressively to prune
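To get a feel for how a speed-up target translates into pruning aggressiveness, consider a rough estimate under a strong simplifying assumption: every convolution loses the same fraction r of both its input and output channels, so its multiply-accumulates shrink by (1 - r)^2. The helper names below are hypothetical. First and last layers keep their input or output size, so in practice a somewhat larger ratio is usually needed.

```python
import math


def conv_macs(c_in: int, c_out: int, k: int, h_out: int, w_out: int) -> int:
    """Multiply-accumulate count of a dense (groups=1) k x k 2D convolution."""
    return c_in * c_out * k * k * h_out * w_out


def uniform_ratio_for_speedup(speed_up: float) -> float:
    """Channel ratio r with (1 - r)^2 == 1 / speed_up (uniform-pruning assumption)."""
    return 1.0 - 1.0 / math.sqrt(speed_up)


# A 2x target under this model means removing roughly 29% of channels per layer.
r = uniform_ratio_for_speedup(2.0)  # about 0.29
kept = round(64 * (1 - r))          # channels kept in a 64-channel layer
speed_up = conv_macs(64, 64, 3, 32, 32) / conv_macs(kept, kept, 3, 32, 32)
```

Because rounding to whole channels is required, the achieved ratio lands near, not exactly on, the target, which is one reason the workflow measures FLOPs after each pruning step instead of relying on a closed-form ratio.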

Step 3: Run sparse training phase

Train the model with the regularizer active in each training step. At each iteration: compute the forward pass, compute the loss, call loss.backward(), then call pruner.regularize(model) to add regularization gradients before optimizer.step(). Call pruner.update_regularizer() at the start of each epoch. This phase runs for a configured number of epochs (sl_total_epochs) to drive unimportant parameters toward zero.

Key considerations:

  • pruner.regularize(model) must be called after loss.backward() and before optimizer.step()
  • pruner.update_regularizer() must be called at the beginning of each epoch to refresh group information
  • The regularization warmup period (sl_reg_warmup) gradually increases regularization strength
  • Use a separate learning rate schedule for the sparse training phase
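The ordering above (backward, then regularize, then optimizer step) can be illustrated in plain PyTorch. `regularize_bn_scales` is a hypothetical, hand-written stand-in for `pruner.regularize(model)` that applies only the BN-scale (Network Slimming) L1 subgradient; the library's pruners implement their own rules. `warmup_reg` is also hypothetical, and its linear ramp is an assumption, since the workflow only states that sl_reg_warmup increases the strength gradually.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def regularize_bn_scales(model: nn.Module, reg: float) -> None:
    """Add the L1 subgradient reg * sign(gamma) to each BatchNorm scale's gradient.

    Illustrative stand-in for pruner.regularize(model); not Torch-Pruning's code.
    """
    for m in model.modules():
        if isinstance(m, BN_TYPES) and m.weight is not None and m.weight.grad is not None:
            m.weight.grad.add_(reg * torch.sign(m.weight.detach()))


def warmup_reg(base_reg: float, step: int, warmup_steps: int) -> float:
    """Hypothetical linear warmup: ramp the coefficient from 0 to base_reg."""
    if warmup_steps <= 0:
        return base_reg
    return base_reg * min(1.0, step / warmup_steps)


def sparse_train_step(model, optimizer, images, labels, reg: float) -> float:
    """One sparse-training iteration; returns the task loss."""
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(images), labels)
    loss.backward()                   # 1. task gradients
    regularize_bn_scales(model, reg)  # 2. add sparsity gradients (pruner.regularize in the workflow)
    optimizer.step()                  # 3. apply the combined update
    return loss.item()
```

In a full loop, `reg` for each iteration would come from `warmup_reg(base_reg, global_step, warmup_steps)`, and the epoch-level `pruner.update_regularizer()` call has no counterpart here because this toy regularizer keeps no group state.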

Step 4: Execute progressive pruning

After sparse training, execute the actual pruning using progressive_pruning: repeatedly call pruner.step() until the target speed-up ratio is reached. Each step removes a small fraction of channels based on the importance scores shaped by the preceding sparse training.

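Sketch (Python; assumes the Torch-Pruning pruner API and tp.utils.count_ops_and_params; stops at the target or when iterative_steps is exhausted):

import torch_pruning as tp

def progressive_pruning(pruner, model, target_speed_up, example_inputs):
    model.eval()  # count FLOPs without updating BatchNorm running statistics
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1.0
    while current_speed_up < target_speed_up and pruner.current_step < pruner.iterative_steps:
        pruner.step()  # remove the next fraction of low-importance channels
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = base_ops / pruned_ops
    return current_speed_up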

Key considerations:

  • Progressive pruning gradually increases the pruning ratio until the target FLOPs reduction is met
  • For OBDC importance, calibration data forward-backward passes are needed before each step
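  • The pruner tracks its current_step; the loop also stops once current_step reaches iterative_steps, so iterative_steps must be large enough for the target speed-up to be reachable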
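What "removing channels" means physically can be shown on the simplest coupled group, a Conv2d → BatchNorm2d → Conv2d chain: dropping output channel i of the first conv also drops BN channel i and input channel i of the second conv. Torch-Pruning's dependency graph finds and prunes such groups automatically; `prune_conv_bn_conv` is a hypothetical hand-rolled helper for illustration only, ranking channels by |gamma| as a BN-scale pruner would.

```python
import torch
import torch.nn as nn


def prune_conv_bn_conv(conv1: nn.Conv2d, bn: nn.BatchNorm2d, conv2: nn.Conv2d, n_keep: int):
    """Keep the n_keep channels with the largest |gamma| in a Conv-BN-Conv chain.

    Assumes groups=1 convolutions and an affine BatchNorm with running stats.
    Returns new, smaller layers; the originals are left untouched.
    """
    assert conv1.groups == 1 and conv2.groups == 1
    # Indices of the most important channels, kept in their original order
    keep = torch.argsort(bn.weight.detach().abs(), descending=True)[:n_keep].sort().values

    new_conv1 = nn.Conv2d(conv1.in_channels, n_keep, conv1.kernel_size, stride=conv1.stride,
                          padding=conv1.padding, dilation=conv1.dilation,
                          bias=conv1.bias is not None)
    new_bn = nn.BatchNorm2d(n_keep, eps=bn.eps, momentum=bn.momentum)
    new_conv2 = nn.Conv2d(n_keep, conv2.out_channels, conv2.kernel_size, stride=conv2.stride,
                          padding=conv2.padding, dilation=conv2.dilation,
                          bias=conv2.bias is not None)
    with torch.no_grad():
        new_conv1.weight.copy_(conv1.weight[keep])        # output channels of conv1
        if conv1.bias is not None:
            new_conv1.bias.copy_(conv1.bias[keep])
        for name in ("weight", "bias", "running_mean", "running_var"):
            getattr(new_bn, name).copy_(getattr(bn, name)[keep])
        new_conv2.weight.copy_(conv2.weight[:, keep])     # input channels of conv2
        if conv2.bias is not None:
            new_conv2.bias.copy_(conv2.bias)
    return new_conv1, new_bn, new_conv2
```

If sparse training has already driven the dropped channels' scale and shift to zero, the pruned chain reproduces the original outputs in eval mode, which is the intuition behind regularizing before pruning.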

Step 5: Fine-tune pruned model

Fine-tune the pruned model for the configured number of epochs to recover accuracy lost during pruning. Use standard training with SGD, learning rate schedule, and the full training dataset.

Key considerations:

  • Fine-tuning typically uses a lower initial learning rate than pretraining
  • The number of fine-tuning epochs depends on the dataset (100 for CIFAR, more for ImageNet)
  • Monitor validation accuracy to ensure recovery
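A fine-tuning setup with a lower starting learning rate might be configured as below. The specific values (lr 0.01, one tenth of the 0.1 pretraining rate from Step 1; step decay at epochs 60 and 80 within a 100-epoch CIFAR run; momentum 0.9; weight decay 5e-4) are illustrative assumptions, not values taken from the benchmark scripts, and `make_finetune_optimizer` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def make_finetune_optimizer(model: nn.Module, lr: float = 0.01,
                            milestones=(60, 80), weight_decay: float = 5e-4):
    """SGD with step decay for fine-tuning a pruned model (values are assumptions)."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9,
                                weight_decay=weight_decay)
    # Multiply the learning rate by 0.1 at each milestone epoch
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones),
                                                     gamma=0.1)
    return optimizer, scheduler
```

The scheduler is stepped once per epoch, after that epoch's optimizer steps, and the pruned model's parameters must be passed in after pruning, since pruning replaces the original weight tensors.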

Step 6: Evaluate final model

Evaluate the final pruned and fine-tuned model on the test/validation set. Report the accuracy delta compared to the baseline, the actual speed-up achieved, and the parameter reduction.

Key considerations:

  • Compare against both the pretrained baseline and other pruning methods
  • The DepGraph paper's benchmark results provide reference accuracy targets
  • Measure actual inference latency in addition to FLOPs for practical deployment assessment
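The report can combine the accuracy delta, the parameter reduction, and a measured latency ratio. This is a minimal CPU-timing sketch with hypothetical helper names; FLOPs would be counted separately (for example with `tp.utils.count_ops_and_params`), and on a GPU the clock should only be read after `torch.cuda.synchronize()`.

```python
import statistics
import time

import torch
import torch.nn as nn


def count_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


@torch.no_grad()
def measure_latency_ms(model: nn.Module, example_input: torch.Tensor,
                       warmup: int = 10, runs: int = 50) -> float:
    """Median wall-clock time of one forward pass, in milliseconds."""
    model.eval()
    for _ in range(warmup):  # let caches and allocators settle first
        model(example_input)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        model(example_input)
        # On GPU, call torch.cuda.synchronize() here before reading the clock.
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)


def summarize(base_acc: float, pruned_acc: float, base_model: nn.Module,
              pruned_model: nn.Module, example_input: torch.Tensor) -> dict:
    """Accuracy delta (points), fraction of parameters removed, measured speed-up."""
    return {
        "acc_delta": pruned_acc - base_acc,
        "param_reduction": 1.0 - count_params(pruned_model) / count_params(base_model),
        "latency_speed_up": measure_latency_ms(base_model, example_input)
        / measure_latency_ms(pruned_model, example_input),
    }
```

Reporting the median rather than the mean makes the latency figure less sensitive to occasional scheduler or allocator hiccups, and a latency speed-up well below the FLOPs speed-up is a common sign that the target hardware is memory-bound rather than compute-bound.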
