Implementation:VainF Torch Pruning Progressive Pruning Fn
Metadata
| Field | Value |
|---|---|
| Source | Torch-Pruning |
| Domains | Deep_Learning, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for progressive FLOPs-targeted pruning provided by the Torch-Pruning reproduce module.
Description
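progressive_pruning() repeatedly calls pruner.step() and re-measures FLOPs with tp.utils.count_ops_and_params after each step. It stops once the target speedup ratio is reached or the pruner has used all of its iterative steps, and it returns the achieved speedup, which can fall short of the target in the latter case. This is a utility function from the reproduce/main.py experiment runner, and it reads that script's module-level args (args.method, args.device) rather than taking them as parameters.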
The function handles two code paths:
- Standard path -- calls pruner.step() to apply one round of pruning
- OBDC path -- when using the OBDC (Optimal Brain Damage with Compensation) importance criterion, forward/backward passes over up to 10 training batches, with labels sampled from the model's own softmax output, update the importance criterion before each pruning step
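The label-sampling step of the OBDC path can be sketched in isolation. This is a minimal illustration, not the Torch-Pruning code itself: the linear model and random batch are hypothetical stand-ins, and only standard PyTorch calls are used.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in classifier and batch; any model that returns logits would do.
model = torch.nn.Linear(16, 10)
imgs = torch.randn(8, 16)

model.zero_grad()
logits = model(imgs)  # shape (8, 10)

# Draw one class index per sample from the model's predicted distribution,
# instead of using the ground-truth labels.
probs = F.softmax(logits.detach(), dim=1)
sampled_y = torch.multinomial(probs, num_samples=1).squeeze(1)  # shape (8,)

# Backpropagate the loss against the sampled labels to obtain gradients.
loss_sample = F.cross_entropy(logits, sampled_y)
loss_sample.backward()
```

The sketch uses squeeze(1) rather than squeeze() so that the batch dimension is kept even for a batch of one. In the real function, the importance object consumes these gradients through imp.step() after each batch.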
Code Reference
- Source file: reproduce/main.py, Lines 52-80
- Import: From reproduce/main.py (example script, not library API)
Signature
def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
    """Prune model until target speedup is achieved.

    Args:
        pruner: BasePruner instance
        model: nn.Module to prune
        speed_up: target FLOPs reduction ratio (e.g., 2.0 for 2x)
        example_inputs: tensor for FLOPs counting
        train_loader: optional, for OBDC importance

    Returns:
        current_speed_up: float - achieved speedup ratio
    """
Source Implementation
def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1
    while current_speed_up < speed_up:
        if args.method == "obdc":
            model.zero_grad()
            imp = pruner.importance
            imp._prepare_model(model, pruner)
            for k, (imgs, lbls) in enumerate(train_loader):
                if k >= 10: break
                imgs = imgs.to(args.device)
                lbls = lbls.to(args.device)
                output = model(imgs)
                sampled_y = torch.multinomial(
                    torch.nn.functional.softmax(output.cpu().data, dim=1), 1
                ).squeeze().to(args.device)
                loss_sample = F.cross_entropy(output, sampled_y)
                loss_sample.backward()
                imp.step()
            pruner.step()
            imp._rm_hooks(model)
            imp._clear_buffer()
        else:
            pruner.step()
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = float(base_ops) / pruned_ops
        if pruner.current_step == pruner.iterative_steps:
            break
    return current_speed_up
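The two exit conditions of the loop above can be exercised without a real model. The sketch below is illustrative only: StubPruner and prune_until are hypothetical names, and the stub simply removes a fixed number of operations per step instead of pruning anything.

```python
class StubPruner:
    """Hypothetical stand-in for a pruner; each step removes a fixed number of ops."""

    def __init__(self, base_ops, iterative_steps, ops_removed_per_step):
        self.ops = base_ops
        self.iterative_steps = iterative_steps
        self.current_step = 0
        self._delta = ops_removed_per_step

    def step(self):
        self.ops -= self._delta
        self.current_step += 1


def prune_until(pruner, speed_up):
    """Mirror the control flow of progressive_pruning using the stub's op count."""
    base_ops = pruner.ops
    current_speed_up = 1.0
    while current_speed_up < speed_up:
        pruner.step()
        current_speed_up = float(base_ops) / pruner.ops
        # Stop once the step budget is spent, even if the target was not reached.
        if pruner.current_step == pruner.iterative_steps:
            break
    return current_speed_up


# Enough steps available: the loop stops as soon as the target is met.
met = StubPruner(base_ops=1000, iterative_steps=20, ops_removed_per_step=50)
reached = prune_until(met, speed_up=2.0)

# Too few steps: the loop stops early and returns a smaller ratio.
exhausted = StubPruner(base_ops=1000, iterative_steps=5, ops_removed_per_step=50)
short = prune_until(exhausted, speed_up=2.0)
```

Callers that need the target to be met may want to compare the returned value against speed_up, since the second case returns without raising an error.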
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| pruner | BasePruner | Yes | A configured pruner instance with importance criterion and iterative steps |
| model | nn.Module | Yes | The neural network model to prune (modified in-place) |
| speed_up | float | Yes | Target FLOPs reduction ratio (e.g., 2.0 for 2x speedup) |
| example_inputs | Tensor | Yes | Example input tensor for FLOPs counting via count_ops_and_params |
| train_loader | DataLoader | No | Training data loader, required only for OBDC importance criterion |
Outputs
| Name | Type | Description |
|---|---|---|
| current_speed_up | float | The achieved speedup ratio (original_FLOPs / pruned_FLOPs) |
Side effect: The model is pruned in-place -- its architecture and weights are permanently modified.
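As a small worked example of the ratio in the Outputs table, the helpers below convert between operation counts, speedup, and the fraction of FLOPs removed. Both function names are hypothetical and are not part of Torch-Pruning.

```python
def achieved_speed_up(base_ops, pruned_ops):
    """Speedup ratio as defined above: original_FLOPs / pruned_FLOPs."""
    return float(base_ops) / pruned_ops


def flops_removed_fraction(speed_up):
    """Fraction of the original FLOPs removed at a given speedup ratio."""
    if speed_up < 1.0:
        raise ValueError("speed_up must be at least 1.0")
    return 1.0 - 1.0 / speed_up


ratio = achieved_speed_up(base_ops=4.0e9, pruned_ops=2.0e9)
removed = flops_removed_fraction(ratio)
```

Under this definition, a speed_up of 2.0 corresponds to removing half of the original FLOPs, and 4.0 to removing three quarters.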
Usage Examples
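import torch
import torchvision
import torch_pruning as tp

# progressive_pruning is defined in reproduce/main.py, not in the torch_pruning
# package. It reads that script's module-level args (args.method, args.device),
# so it is assumed here to be in scope together with a suitable args object.

# 1. Load a pretrained model
# (newer torchvision releases prefer the weights= argument over pretrained=True)
model = torchvision.models.resnet50(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 2. Configure the pruner
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    iterative_steps=400,
    ignored_layers=[model.fc],  # keep the 1000-class output layer intact
)

# 3. Progressive pruning to 2x speedup
achieved_speedup = progressive_pruning(
    pruner=pruner,
    model=model,
    speed_up=2.0,
    example_inputs=example_inputs,
)
print(f"Achieved speedup: {achieved_speedup:.2f}x")

# 4. Fine-tune the pruned model
# train, evaluate, train_loader and test_loader are user-supplied placeholders.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(100):
    train(model, train_loader, optimizer)
    acc = evaluate(model, test_loader)
    print(f"Epoch {epoch}: accuracy = {acc:.2f}%")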