
Implementation:VainF Torch Pruning Progressive Pruning Fn

From Leeroopedia


Metadata

Field         Value
Source        Torch-Pruning
Domains       Deep_Learning, Pruning
Last Updated  2026-02-08 00:00 GMT

Overview

A concrete utility for progressive, FLOPs-targeted pruning, provided by the Torch-Pruning reproduce experiment scripts.

Description

progressive_pruning() repeatedly calls pruner.step() and re-measures FLOPs via count_ops_and_params after each step until the target speedup ratio is reached, then returns the achieved speedup. It is a utility function defined in the reproduce/main.py experiment runner.
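
The speedup measurement relies on Torch-Pruning's count_ops_and_params utility, which returns the multiply-accumulate count and parameter count of a model for a given example input. A minimal self-contained sketch of how the ratio is computed (the toy model below is a placeholder, not from the source):

import torch
import torch.nn as nn
import torch_pruning as tp

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
example_inputs = torch.randn(1, 3, 32, 32)

# count_ops_and_params returns (MACs, parameter count) for one forward pass
base_ops, base_params = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)

# after pruning, the same call on the modified model gives the speedup ratio:
# current_speed_up = float(base_ops) / pruned_ops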

The function handles two code paths:

  • Standard path -- simply calls pruner.step() to apply one round of pruning
  • OBDC path -- when using the OBDC (Optimal Brain Damage with Compensation) importance criterion, additional forward/backward passes are required on training data before each pruning step

Code Reference

  • Source file: reproduce/main.py, Lines 52-80
  • Import: From reproduce/main.py (example script, not library API)

Signature

def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
    """Prune model until target speedup is achieved.

    Args:
        pruner: BasePruner instance
        model: nn.Module to prune
        speed_up: target FLOPs reduction ratio (e.g., 2.0 for 2x)
        example_inputs: tensor for FLOPs counting
        train_loader: optional, for OBDC importance

    Returns:
        current_speed_up: float - achieved speedup ratio
    """

Source Implementation

def progressive_pruning(pruner, model, speed_up, example_inputs, train_loader=None):
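    # NOTE: as written in reproduce/main.py, this function reads the
    # script-level `args` namespace (args.method, args.device) and assumes
    # torch, F (torch.nn.functional), and tp (torch_pruning) are imported
    # at module scope.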
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1
    while current_speed_up < speed_up:
        if args.method == "obdc":
            model.zero_grad()
            imp = pruner.importance
            imp._prepare_model(model, pruner)
            for k, (imgs, lbls) in enumerate(train_loader):
                if k >= 10: break
                imgs = imgs.to(args.device)
                lbls = lbls.to(args.device)
                output = model(imgs)
                sampled_y = torch.multinomial(
                    torch.nn.functional.softmax(output.cpu().data, dim=1), 1
                ).squeeze().to(args.device)
                loss_sample = F.cross_entropy(output, sampled_y)
                loss_sample.backward()
                imp.step()
            pruner.step()
            imp._rm_hooks(model)
            imp._clear_buffer()
        else:
            pruner.step()
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = float(base_ops) / pruned_ops
        if pruner.current_step == pruner.iterative_steps:
            break
    return current_speed_up
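
Because the original reads the script-level `args` namespace, it cannot be imported and called as-is outside the script. Below is a minimal self-contained sketch of the standard (non-OBDC) path; the name progressive_pruning_standard is ours, and only the public Torch-Pruning API shown above (pruner.step, count_ops_and_params) is used:

import torch_pruning as tp

def progressive_pruning_standard(pruner, model, speed_up, example_inputs):
    """Standard-path sketch: prune until the FLOPs speedup target is met."""
    model.eval()
    base_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
    current_speed_up = 1.0
    while current_speed_up < speed_up:
        pruner.step()  # apply one round of pruning chosen by the importance criterion
        pruned_ops, _ = tp.utils.count_ops_and_params(model, example_inputs=example_inputs)
        current_speed_up = float(base_ops) / pruned_ops
        if pruner.current_step == pruner.iterative_steps:
            break  # step budget exhausted before reaching the target
    return current_speed_up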

I/O Contract

Inputs

Parameter       Type        Required  Description
pruner          BasePruner  Yes       Configured pruner instance with an importance criterion and iterative steps
model           nn.Module   Yes       The neural network model to prune (modified in-place)
speed_up        float       Yes       Target FLOPs reduction ratio (e.g., 2.0 for 2x speedup)
example_inputs  Tensor      Yes       Example input tensor for FLOPs counting via count_ops_and_params
train_loader    DataLoader  No        Training data loader, required only for the OBDC importance criterion

Outputs

Name              Type   Description
current_speed_up  float  The achieved speedup ratio (original_FLOPs / pruned_FLOPs)

Side effect: The model is pruned in-place -- its architecture and weights are permanently modified.
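
Because the architecture itself changes, a pruned checkpoint cannot be restored into the original model definition via load_state_dict alone. A common workaround (a sketch, assuming model holds the pruned network) is to serialize the whole module object:

import torch

# the pruned layers no longer match the original definition, so save the
# entire module object rather than only its state_dict
torch.save(model, "pruned_model.pth")

# weights_only=False is needed on PyTorch >= 2.6 to unpickle a full module
model = torch.load("pruned_model.pth", weights_only=False)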

Usage Examples

import torch
import torch.nn as nn
import torchvision
import torch_pruning as tp

# 1. Load a pretrained model
model = torchvision.models.resnet50(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 2. Configure the pruner
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    iterative_steps=400,
)

# 3. Progressive pruning to 2x speedup (progressive_pruning must be copied
#    from reproduce/main.py or replaced by a standalone variant such as the
#    sketch above, since the original reads the script-level `args`)
achieved_speedup = progressive_pruning(
    pruner=pruner,
    model=model,
    speed_up=2.0,
    example_inputs=example_inputs,
)
print(f"Achieved speedup: {achieved_speedup:.2f}x")

# 4. Fine-tune the pruned model; train(), evaluate(), train_loader, and
#    test_loader are user-supplied placeholders
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(100):
    train(model, train_loader, optimizer)
    acc = evaluate(model, test_loader)
    print(f"Epoch {epoch}: accuracy = {acc:.2f}%")
