
Implementation:VainF Torch Pruning BasePruner

From Leeroopedia


Metadata

  Field          Value
  Sources        Repo: Torch-Pruning, Paper: DepGraph
  Domains        Deep_Learning, Model_Compression
  Last Updated   2026-02-08 00:00 GMT

Overview

BasePruner is the concrete tool for dependency-graph-based structural pruning provided by Torch-Pruning.

Description

BasePruner is the core pruning orchestrator in the Torch-Pruning library. It coordinates every stage of the structural pruning pipeline:

  1. Model tracing -- Passes example_inputs through the model to build a DependencyGraph, which captures all inter-layer structural dependencies.
  2. Group discovery -- Enumerates all prunable groups (sets of coupled layers and parameter slices) via DG.get_all_groups().
  3. Importance estimation -- Delegates to a pluggable importance callable (e.g., MagnitudeImportance, TaylorImportance) to score each channel.
  4. Ranking and thresholding -- Ranks channels within a configurable scope (local, global, isomorphic, or user-defined) and selects the least important channels up to the target pruning ratio.
  5. Coordinated pruning -- Removes the selected channels from all coupled layers in a single atomic group operation, keeping the network structurally valid.

Ranking scopes supported by BasePruner:

  • Local (global_pruning=False, isomorphic=False) -- Each group is ranked independently. Every root layer receives the same pruning ratio.
  • Global (global_pruning=True, isomorphic=False) -- All groups are pooled into a single ranking scope. Channels are removed network-wide by global importance rank.
  • Isomorphic (global_pruning=True, isomorphic=True) -- Groups with identical dependency-graph topology share a ranking scope, as described in Isomorphic Pruning (ECCV 2024).
  • User-defined -- Pass a pruning_ratio_dict with tuples of modules as keys to create custom ranking scopes with layer-specific ratios.

Multi-head attention is handled via:

  • num_heads -- declares which layers are multi-head attention projections.
  • prune_head_dims=True -- removes individual dimensions within each head (default).
  • prune_num_heads=True -- removes entire attention heads.
  • Grouped Query Attention (GQA) is supported: KV heads are handled separately from Q heads.

Iterative pruning divides the target ratio across iterative_steps steps, controlled by iterative_pruning_ratio_scheduler (default: linear_scheduler). Between steps the user is expected to fine-tune the model to recover accuracy.
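For intuition, the linear schedule simply hands out an evenly spaced cumulative target at each step. A re-implementation of the idea, not the library's exact function:

```python
def linear_schedule(target_ratio, steps):
    # Cumulative pruning ratio to reach after each of the `steps` calls to step().
    return [target_ratio * (i + 1) / steps for i in range(steps)]

print(linear_schedule(0.5, 5))  # [0.1, 0.2, 0.3, 0.4, 0.5]
```

With iterative_steps=5 and pruning_ratio=0.5, each step() call therefore removes roughly an additional 10% of the channels before the next fine-tuning round.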

Code Reference

Source Location

torch_pruning/pruner/algorithms/base_pruner.py, Lines 15--751

Class Signature

class BasePruner:
    def __init__(
        self,
        model: nn.Module,
        example_inputs: torch.Tensor,
        importance: typing.Callable,
        global_pruning: bool = False,
        pruning_ratio: float = 0.5,
        pruning_ratio_dict: Dict[Union[nn.Module, Tuple[nn.Module]], float] = None,
        max_pruning_ratio: float = 1.0,
        iterative_steps: int = 1,
        iterative_pruning_ratio_scheduler: Callable = linear_scheduler,
        ignored_layers: List[nn.Module] = None,
        round_to: int = None,
        isomorphic: bool = False,
        in_channel_groups: Dict[nn.Module, int] = dict(),
        out_channel_groups: Dict[nn.Module, int] = dict(),
        num_heads: Dict[nn.Module, int] = dict(),
        prune_num_heads: bool = False,
        prune_head_dims: bool = True,
        head_pruning_ratio: float = 0.0,
        head_pruning_ratio_dict: Dict[nn.Module, float] = None,
        customized_pruners: Dict = None,
        unwrapped_parameters: Dict[nn.Parameter, int] = None,
        root_module_types: List = [TORCH_CONV, TORCH_LINEAR, TORCH_LSTM],
        forward_fn: Callable = None,
        output_transform: Callable = None,
    ):

Key Method

def step(self, interactive: bool = False) -> Union[Generator, None]:
    """Execute one step of pruning.

    Args:
        interactive: If True, yields groups for manual inspection/control.
                     If False, prunes all groups automatically.

    Returns:
        Generator yielding pruning groups if interactive=True, None otherwise.
    """

Import Paths

import torch_pruning as tp
tp.pruner.BasePruner

# or directly:
from torch_pruning.pruner.algorithms import BasePruner

I/O Contract

Inputs

  • model (nn.Module) -- The PyTorch model to be pruned. Modified in-place during pruning.
  • example_inputs (torch.Tensor or List) -- Dummy input(s) used for tracing the computational graph. Must match the model's expected input shape and device.
  • importance (Callable) -- An importance estimator (e.g., tp.importance.MagnitudeImportance). Receives a pruning group and returns a 1-D importance tensor.
  • pruning_ratio (float, default 0.5) -- Target fraction of channels to remove from each prunable layer (local mode) or network-wide (global mode).
  • iterative_steps (int, default 1) -- Number of iterative pruning steps. The target ratio is distributed across steps by the scheduler.
  • ignored_layers (List[nn.Module] or None) -- Modules to exclude from pruning (e.g., the final classifier head).
  • global_pruning (bool, default False) -- If True, rank channels globally across the entire network rather than per-layer.
  • isomorphic (bool, default False) -- If True (requires global_pruning=True), apply isomorphic pruning, where structurally identical groups share a ranking scope.
  • pruning_ratio_dict (Dict or None) -- Layer-specific pruning ratios. Keys can be single modules or tuples of modules (shared scope).
  • round_to (int or None) -- Round remaining channel counts to the nearest multiple of this value (e.g., 8 for hardware alignment).
  • num_heads (Dict[nn.Module, int]) -- Declares which linear layers are multi-head attention projections and their head count.
  • customized_pruners (Dict or None) -- Module-type to pruning-function mappings for custom layer types.
  • unwrapped_parameters (Dict[nn.Parameter, int] or None) -- Standalone parameters (not inside standard modules) and their pruning dimension.

Outputs

  • The model is modified in-place -- weight tensors are sliced, and layer attributes (e.g., out_channels, out_features) are updated.
  • step(interactive=False) returns None. All groups are pruned automatically.
  • step(interactive=True) returns a Generator that yields Group objects. The caller can inspect, modify, or selectively prune each group by calling group.prune().

Usage Examples

Example 1: Basic CNN Pruning

import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(weights="IMAGENET1K_V1").eval()
example_inputs = torch.randn(1, 3, 224, 224)

# Ignore the final classifier layer
ignored_layers = []
for m in model.modules():
    if isinstance(m, nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

# Build the pruner
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

# One-shot pruning
pruner.step()

# Verify the pruned model
out = model(example_inputs)
print(out.shape)  # torch.Size([1, 1000])

Example 2: Interactive Mode with Manual Group Control

import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(weights="IMAGENET1K_V1").eval()
example_inputs = torch.randn(1, 3, 224, 224)

ignored_layers = []
for m in model.modules():
    if isinstance(m, nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

imp = tp.importance.MagnitudeImportance(p=2)
iterative_steps = 5
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    # Interactive mode: iterate over groups, inspect, then prune
    for group in pruner.step(interactive=True):
        print(group)
        group.prune()

    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M, MACs: %.2f G => %.2f G"
        % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6,
           base_macs / 1e9, macs / 1e9)
    )
    # Fine-tune the model here between steps
    # finetune(model)

Example 3: Global Pruning Mode

import torch
import torch.nn as nn
from torchvision.models import resnet50
import torch_pruning as tp

model = resnet50(weights="IMAGENET1K_V1").eval()
example_inputs = torch.randn(1, 3, 224, 224)

ignored_layers = []
for m in model.modules():
    if isinstance(m, nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

imp = tp.importance.MagnitudeImportance(p=1)
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=imp,
    pruning_ratio=0.5,
    global_pruning=True,       # rank all channels network-wide
    ignored_layers=ignored_layers,
)

pruner.step()

out = model(example_inputs)
print(out.shape)  # torch.Size([1, 1000])
