
Implementation:VainF Torch Pruning Group Prune

From Leeroopedia


Metadata

Source: Torch-Pruning
Domains: Deep_Learning, Model_Compression
Last Updated: 2026-02-08 00:00 GMT

Overview

A concrete tool provided by Torch-Pruning for executing atomic group pruning operations. The Group class represents a collection of coupled (dependency, indices) pairs that must be pruned together. Its prune() method iterates through each pair, physically removing weight tensor rows and columns to maintain network consistency across all coupled layers.

Description

Group.prune() executes pruning on all coupled layers in a dependency group. It iterates through each (dependency, indices) pair and applies the pruning function, physically removing weight tensor rows/columns. The method handles two distinct cases:

Standard modules (Conv2d, Linear, BatchNorm, etc.): The dependency's handler function is called directly with the indices. This modifies the layer's weight and bias tensors in-place by removing the specified rows or columns.
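The effect of such a handler on a standard module can be sketched as follows. The prune_linear_out_channels helper below is hypothetical, written for illustration only; Torch-Pruning's actual handlers differ in detail, but the net effect is the same: the specified output rows are removed from the weight and bias.

```python
import torch
import torch.nn as nn

def prune_linear_out_channels(layer: nn.Linear, idxs):
    """Illustrative sketch of a handler for a standard module:
    remove the given output rows from weight and bias."""
    keep = [i for i in range(layer.out_features) if i not in set(idxs)]
    # Re-register smaller parameters holding only the kept rows.
    layer.weight = nn.Parameter(layer.weight.data[keep].clone())
    if layer.bias is not None:
        layer.bias = nn.Parameter(layer.bias.data[keep].clone())
    layer.out_features = len(keep)
    return layer

layer = nn.Linear(8, 6)
prune_linear_out_channels(layer, [1, 4])  # drop output channels 1 and 4
print(layer.weight.shape)  # torch.Size([4, 8])
```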

nn.Parameter objects: These require special handling because PyTorch parameters cannot be resized in-place. Instead, Group.prune():

  1. Retrieves the parameter's registered name from the dependency graph
  2. Calls the pruning dependency to create a new, smaller parameter
  3. Navigates the module hierarchy to find the parent module
  4. Replaces the old parameter with the pruned one using setattr
  5. Updates all internal dependency graph references to point to the new parameter object
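Steps 1-4 above amount to navigating a dotted parameter name and swapping in a new object with setattr. The replace_parameter helper below is a minimal sketch of that navigation, not Torch-Pruning's exact implementation (which also updates its internal graph references, step 5).

```python
import torch
import torch.nn as nn

def replace_parameter(model: nn.Module, param_name: str, new_param: nn.Parameter):
    """Navigate the module hierarchy by the parameter's dotted name
    and replace the old parameter via setattr (illustrative sketch)."""
    *path, leaf = param_name.split(".")
    parent = model
    for attr in path:
        parent = getattr(parent, attr)  # descend to the parent module
    setattr(parent, leaf, new_param)    # re-register the smaller parameter

model = nn.Sequential(nn.Linear(8, 4))
old = model[0].weight                      # shape (4, 8)
keep = [0, 2, 3]                           # drop output row 1
pruned = nn.Parameter(old.data[keep].clone())
replace_parameter(model, "0.weight", pruned)
print(model[0].weight.shape)  # torch.Size([3, 8])
```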

After all pruning operations complete, the method optionally records the pruning history for serialization. The history captures the root module name, whether it was output-channel pruning, and the pruned indices.
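Per the description above, a history record captures the root module name, the pruning direction, and the indices. The dictionary below is only a hypothetical illustration of those three fields; the actual record layout inside Torch-Pruning may differ.

```python
# Hypothetical sketch of the information a pruning-history entry carries:
history_entry = {
    "root_module": "conv1",          # name of the module the group was built from
    "is_out_channel_pruning": True,  # output-channel vs. input-channel pruning
    "pruned_idxs": [0, 2, 6],        # indices that were removed
}
print(history_entry["pruned_idxs"])
```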

The Group class also supports:

  • Index override: Passing idxs to prune() creates a new group with the specified indices via DependencyGraph.get_pruning_group()
  • Deprecated exec(): Old interface that redirects to prune() with a deprecation warning
  • Callable interface: group() is equivalent to group.prune()

Code Reference

Source File: torch_pruning/dependency/group.py, Lines 7-143 (class), Lines 32-64 (prune method)
Import: from torch_pruning.dependency import Group (typically obtained via DependencyGraph or BasePruner, not imported directly)

Class and Method Signature:

class Group(object):
    """Group is the basic unit for pruning.
    group := [ (Dep1, Indices1), (Dep2, Indices2), ..., (DepK, IndicesK) ]
    """

    def __init__(self):
        self._group = list()
        self._DG = None  # the dependency graph that this group belongs to

    def prune(self, idxs=None, record_history=True):
        """Prune all coupled layers in the group."""

Internal Structure:

# Each item in self._group is a GroupItem namedtuple:
# GroupItem(dep=Dependency, idxs=List[_HybridIndex])

# The Dependency object contains:
#   dep.target.module  -- the nn.Module to be pruned
#   dep.handler        -- the pruning function to apply
#   dep.target.type    -- OPTYPE enum (e.g., OPTYPE.PARAMETER)

# Pruning execution for standard modules:
dep(idxs)  # calls dep.handler(dep.target.module, idxs)

# Pruning execution for nn.Parameter:
pruned_parameter = dep(idxs)  # returns a new, smaller nn.Parameter

I/O Contract

Parameters:

  • idxs (list or None, default None): Optional override indices. If provided, a new group is created with these indices and pruned instead.
  • record_history (bool, default True): Whether to record this pruning operation in the dependency graph's history for serialization.

Outputs:

  • Model weights are modified in-place -- channels are physically removed from weight tensors, reducing the tensor dimensions.
  • For nn.Parameter objects, new parameter objects are created and registered in the parent module.
  • Pruning history is appended to DependencyGraph._pruning_history if record_history=True.

Preconditions:

  • The group must be associated with a DependencyGraph (self._DG must not be None).
  • The group must contain valid dependencies with consistent index mappings.
  • The group should pass DG.check_pruning_group(group) validation.
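The core of the consistency requirement in the second precondition can be sketched in plain Python: every (dependency, indices) pair in a group must prune the same number of channels. The check_group_consistency helper below is a simplified illustration over plain tuples, not the actual DG.check_pruning_group implementation.

```python
def check_group_consistency(group):
    """Minimal sketch of a group consistency check: every
    (name, idxs) pair must carry the same number of indices."""
    lengths = {len(idxs) for _, idxs in group}
    if len(lengths) > 1:
        raise ValueError(f"Inconsistent index counts in group: {lengths}")
    return True

# A consistent group: all three coupled layers prune 3 channels.
group = [("conv1.out", [0, 2, 6]), ("bn1", [0, 2, 6]), ("conv2.in", [0, 2, 6])]
print(check_group_consistency(group))  # True
```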

Usage Examples

Interactive pruning with group.prune():

import torch
import torch.nn as nn
import torch_pruning as tp

model = YourModel()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

# Build dependency graph
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Get a pruning group for specific channels
group = DG.get_pruning_group(
    model.conv1,
    tp.prune_conv_out_channels,
    idxs=[0, 2, 6]  # channels to remove
)

# Inspect the group before pruning
print(group)
# --------------------------------
#           Pruning Group
# --------------------------------
# [0] prune_conv_out_channels on conv1 (Conv2d), len(idxs)=3
# [1] prune_batchnorm_channels on bn1 (BatchNorm2d), len(idxs)=3
# [2] prune_conv_in_channels on conv2 (Conv2d), len(idxs)=3
# --------------------------------

# Execute the pruning
group.prune()

Automatic pruning via pruner.step():

import torch_pruning as tp

model = YourModel()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MetaPruner(
    model=model,
    example_inputs=example_inputs,
    importance=imp,
    pruning_ratio=0.5,
)

# Automatic mode: pruner calls group.prune() internally
pruner.step()

# Interactive mode: user receives groups and decides when to prune
for group in pruner.step(interactive=True):
    # Optional: filter or modify groups
    dep, idxs = group[0]  # inspect root dependency
    print(f"Pruning {dep.target.module} with {len(idxs)} indices")
    group.prune()  # execute

Using index override:

# Get a full group (all channels) for importance estimation
full_group = DG.get_pruning_group(
    model.conv1,
    tp.prune_conv_out_channels,
    idxs=list(range(64))  # all 64 channels
)

# Compute importance
importance_scores = imp(full_group)

# Prune with specific indices (overrides the group's original indices)
full_group.prune(idxs=[0, 2, 6])
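The override indices are typically derived from the importance scores computed on the full group. A common pattern, sketched here with a hypothetical score vector, is to select the k least important channels via argsort and pass them to prune(idxs=...).

```python
import torch

# Hypothetical per-channel importance scores (e.g., from imp(full_group)).
scores = torch.tensor([0.9, 0.1, 0.8, 0.05, 0.7])
k = 2  # number of channels to remove

# argsort ascending: the first k positions are the least important channels.
idxs = torch.argsort(scores)[:k].tolist()
print(sorted(idxs))  # [1, 3]
```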
