Implementation:VainF Torch Pruning BasePruner
Metadata
| Field | Value |
|---|---|
| Sources | Repo: Torch-Pruning, Paper: DepGraph |
| Domains | Deep_Learning, Model_Compression |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for dependency-graph-based structural pruning provided by Torch-Pruning.
Description
BasePruner is the core pruning orchestrator in the Torch-Pruning library. It coordinates every stage of the structural pruning pipeline:
- Model tracing -- Passes `example_inputs` through the model to build a `DependencyGraph`, which captures all inter-layer structural dependencies.
- Group discovery -- Enumerates all prunable groups (sets of coupled layers and parameter slices) via `DG.get_all_groups()`.
- Importance estimation -- Delegates to a pluggable `importance` callable (e.g., `MagnitudeImportance`, `TaylorImportance`) to score each channel.
- Ranking and thresholding -- Ranks channels within a configurable scope (local, global, isomorphic, or user-defined) and selects the least important channels up to the target pruning ratio.
- Coordinated pruning -- Removes the selected channels from all coupled layers in a single atomic group operation, keeping the network structurally valid.
Ranking scopes supported by BasePruner:
- Local (`global_pruning=False, isomorphic=False`) -- Each group is ranked independently. Every root layer receives the same pruning ratio.
- Global (`global_pruning=True, isomorphic=False`) -- All groups are pooled into a single ranking scope. Channels are removed network-wide by global importance rank.
- Isomorphic (`global_pruning=True, isomorphic=True`) -- Groups with identical dependency-graph topology share a ranking scope, as described in Isomorphic Pruning (ECCV 2024).
- User-defined -- Pass a `pruning_ratio_dict` with tuples of modules as keys to create custom ranking scopes with layer-specific ratios.
Multi-head attention is handled via:
- `num_heads` -- declares which layers are multi-head attention projections.
- `prune_head_dims=True` -- removes individual dimensions within each head (default).
- `prune_num_heads=True` -- removes entire attention heads.
- Grouped Query Attention (GQA) is supported: KV heads are handled separately from Q heads.
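As a worked example of the head-pruning arithmetic (plain Python, not library code, with hypothetical ViT-Base-like numbers): with `prune_num_heads=True`, `head_pruning_ratio` controls how many whole heads are removed, independently of the channel-level `pruning_ratio`:

```python
# Hypothetical ViT-Base-like attention shape: 12 heads of dimension 64
num_heads, head_dim = 12, 64
embed_dim = num_heads * head_dim  # 768

# prune_num_heads=True with head_pruning_ratio=0.25 removes whole heads
head_pruning_ratio = 0.25
kept_heads = num_heads - int(num_heads * head_pruning_ratio)  # 12 - 3 = 9

# The attention projections shrink to kept_heads * head_dim features
print(kept_heads, kept_heads * head_dim)  # 9 576
```

With `prune_head_dims=True` instead, `head_dim` would shrink while the head count stays fixed.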
Iterative pruning divides the target ratio across `iterative_steps` steps, controlled by `iterative_pruning_ratio_scheduler` (default: `linear_scheduler`). Between steps, the user is expected to fine-tune the model to recover accuracy.
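The linear schedule can be sketched as follows (a re-implementation of the linear ramp for illustration, not the library's own `linear_scheduler`): at step i of n, the cumulative target ratio is `ratio * i / n`, so each call to `step()` prunes only the increment over the previous step:

```python
def linear_schedule(pruning_ratio: float, steps: int) -> list:
    """Cumulative per-step pruning targets for a linear ramp (illustrative)."""
    return [pruning_ratio * (i + 1) / steps for i in range(steps)]

# A 50% target over 5 steps ramps up by 10% of the channels per step
print(linear_schedule(0.5, 5))  # [0.1, 0.2, 0.3, 0.4, 0.5]
```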
Code Reference
Source Location
torch_pruning/pruner/algorithms/base_pruner.py, Lines 15--751
Class Signature
class BasePruner:
def __init__(
self,
model: nn.Module,
example_inputs: torch.Tensor,
importance: typing.Callable,
global_pruning: bool = False,
pruning_ratio: float = 0.5,
pruning_ratio_dict: Dict[Union[nn.Module, Tuple[nn.Module]], float] = None,
max_pruning_ratio: float = 1.0,
iterative_steps: int = 1,
iterative_pruning_ratio_scheduler: Callable = linear_scheduler,
ignored_layers: List[nn.Module] = None,
round_to: int = None,
isomorphic: bool = False,
in_channel_groups: Dict[nn.Module, int] = dict(),
out_channel_groups: Dict[nn.Module, int] = dict(),
num_heads: Dict[nn.Module, int] = dict(),
prune_num_heads: bool = False,
prune_head_dims: bool = True,
head_pruning_ratio: float = 0.0,
head_pruning_ratio_dict: Dict[nn.Module, float] = None,
customized_pruners: Dict = None,
unwrapped_parameters: Dict[nn.Parameter, int] = None,
root_module_types: List = [TORCH_CONV, TORCH_LINEAR, TORCH_LSTM],
forward_fn: Callable = None,
output_transform: Callable = None,
):
Key Method
def step(self, interactive: bool = False) -> Union[Generator, None]:
"""Execute one step of pruning.
Args:
interactive: If True, yields groups for manual inspection/control.
If False, prunes all groups automatically.
Returns:
Generator yielding pruning groups if interactive=True, None otherwise.
"""
Import Paths
import torch_pruning as tp
tp.pruner.BasePruner
# or directly:
from torch_pruning.pruner.algorithms import BasePruner
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| `model` | `nn.Module` | The PyTorch model to be pruned. Modified in-place during pruning. |
| `example_inputs` | `torch.Tensor` or `List` | Dummy input(s) used for tracing the computational graph. Must match the model's expected input shape and device. |
| `importance` | `Callable` | An importance estimator (e.g., `tp.importance.MagnitudeImportance`). Receives a pruning group and returns a 1-D importance tensor. |
| `pruning_ratio` | `float` (default 0.5) | Target fraction of channels to remove from each prunable layer (local mode) or network-wide (global mode). |
| `iterative_steps` | `int` (default 1) | Number of iterative pruning steps. The target ratio is distributed across steps by the scheduler. |
| `ignored_layers` | `List[nn.Module]` or `None` | Modules to exclude from pruning (e.g., the final classifier head). |
| `global_pruning` | `bool` (default False) | If True, rank channels globally across the entire network rather than per-layer. |
| `isomorphic` | `bool` (default False) | If True (requires `global_pruning=True`), apply isomorphic pruning, where structurally identical groups share a ranking scope. |
| `pruning_ratio_dict` | `Dict` or `None` | Layer-specific pruning ratios. Keys can be single modules or tuples of modules (shared scope). |
| `round_to` | `int` or `None` | Round remaining channel counts to the nearest multiple of this value (e.g., 8 for hardware alignment). |
| `num_heads` | `Dict[nn.Module, int]` | Declares which linear layers are multi-head attention projections and their head count. |
| `customized_pruners` | `Dict` or `None` | Module-type to pruning-function mappings for custom layer types. |
| `unwrapped_parameters` | `Dict[nn.Parameter, int]` or `None` | Standalone parameters (not inside standard modules) and their pruning dimension. |
Outputs
- The model is modified in-place -- weight tensors are sliced, and layer attributes (e.g., `out_channels`, `out_features`) are updated.
- `step(interactive=False)` returns `None`. All groups are pruned automatically.
- `step(interactive=True)` returns a Generator that yields `Group` objects. The caller can inspect, modify, or selectively prune each group by calling `group.prune()`.
Usage Examples
Example 1: Basic CNN Pruning
import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(weights="IMAGENET1K_V1").eval()  # 'pretrained=True' is deprecated in torchvision >= 0.13
example_inputs = torch.randn(1, 3, 224, 224)
# Ignore the final classifier layer
ignored_layers = []
for m in model.modules():
if isinstance(m, nn.Linear) and m.out_features == 1000:
ignored_layers.append(m)
# Build the pruner
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_inputs,
importance=imp,
pruning_ratio=0.5,
ignored_layers=ignored_layers,
)
# One-shot pruning
pruner.step()
# Verify the pruned model
out = model(example_inputs)
print(out.shape) # torch.Size([1, 1000])
Example 2: Interactive Mode with Manual Group Control
import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp
model = resnet18(weights="IMAGENET1K_V1").eval()  # 'pretrained=True' is deprecated in torchvision >= 0.13
example_inputs = torch.randn(1, 3, 224, 224)
ignored_layers = []
for m in model.modules():
if isinstance(m, nn.Linear) and m.out_features == 1000:
ignored_layers.append(m)
imp = tp.importance.MagnitudeImportance(p=2)
iterative_steps = 5
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_inputs,
importance=imp,
iterative_steps=iterative_steps,
pruning_ratio=0.5,
ignored_layers=ignored_layers,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
# Interactive mode: iterate over groups, inspect, then prune
for group in pruner.step(interactive=True):
print(group)
group.prune()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(
" Iter %d/%d, Params: %.2f M => %.2f M"
% (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
)
# Fine-tune the model here between steps
# finetune(model)
Example 3: Global Pruning Mode
import torch
import torch.nn as nn
from torchvision.models import resnet50
import torch_pruning as tp
model = resnet50(weights="IMAGENET1K_V1").eval()  # 'pretrained=True' is deprecated in torchvision >= 0.13
example_inputs = torch.randn(1, 3, 224, 224)
ignored_layers = []
for m in model.modules():
if isinstance(m, nn.Linear) and m.out_features == 1000:
ignored_layers.append(m)
imp = tp.importance.MagnitudeImportance(p=1)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs=example_inputs,
importance=imp,
pruning_ratio=0.5,
global_pruning=True, # rank all channels network-wide
ignored_layers=ignored_layers,
)
pruner.step()
out = model(example_inputs)
print(out.shape) # torch.Size([1, 1000])
Related Pages
- Principle:VainF_Torch_Pruning_Structural_Pruning
- Environment:VainF_Torch_Pruning_PyTorch_Python_Core
- Heuristic:VainF_Torch_Pruning_Channel_Rounding_Alignment
- Heuristic:VainF_Torch_Pruning_Pruning_Ratio_vs_Parameter_Ratio
- Heuristic:VainF_Torch_Pruning_Over_Pruning_Prevention
- Heuristic:VainF_Torch_Pruning_GQA_Head_Pruning_Constraints
- Heuristic:VainF_Torch_Pruning_AutoGrad_Dependency_Graph