
Principle:VainF Torch Pruning Dependency Group Pruning

From Leeroopedia


Metadata

Field          Value
Paper          DepGraph: Towards Any Structural Pruning (Fang et al., 2023)
Domains        Deep_Learning, Model_Compression, Pruning
Last Updated   2026-02-08 00:00 GMT

Overview

A dependency group is the atomic unit of structural pruning: it ensures that all coupled layers are pruned consistently through dependency propagation. Concretely, a group represents the minimal set of (layer, indices) pairs that must be pruned together to keep the network consistent. This concept is fundamental to the Torch-Pruning framework and underpins all structural pruning operations it supports.

Description

In neural networks, layers are interconnected in ways that impose structural constraints on pruning. For example, a Conv2d -> BN -> ReLU chain requires that the output channels of the convolution match the number of features in the BatchNorm layer. If the convolution's output channels are pruned, the BatchNorm must be adjusted accordingly. A dependency group captures this coupling.
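
The Conv -> BN coupling above can be sketched with a toy example. This is plain Python over dict-based stand-ins for layers, not the Torch-Pruning API; `prune_coupled` and the field names are invented for illustration, chosen only to show why both layers must shrink by the same indices.

```python
# Toy sketch of a Conv2d -> BN coupling (hypothetical names, not the
# Torch-Pruning API). Pruning output channels of the conv forces the
# BatchNorm, which tracks one feature per conv output channel, to be
# sliced with the same kept indices.

def prune_coupled(conv, bn, idxs):
    """Remove output channels `idxs` from conv and the matching BN features."""
    keep = [c for c in range(conv["out_channels"]) if c not in set(idxs)]
    conv["out_channels"] = len(keep)
    bn["num_features"] = len(keep)
    bn["running_mean"] = [bn["running_mean"][c] for c in keep]
    return conv, bn

conv = {"out_channels": 4}
bn = {"num_features": 4, "running_mean": [0.1, 0.2, 0.3, 0.4]}
conv, bn = prune_coupled(conv, bn, idxs=[1, 3])
# conv["out_channels"] == 2, bn["num_features"] == 2,
# bn["running_mean"] == [0.1, 0.3]
```

Pruning only the conv (or only the BN) would leave the two layers with mismatched shapes, which is exactly the inconsistency a dependency group is designed to prevent.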

The DependencyGraph traces the model's computation graph to discover these couplings. It handles complex structural patterns including:

  • Sequential chains: Conv -> BN -> ReLU where output channels must be consistent
  • Concatenation: Cat([Conv1, Conv2]) where pruning indices must be mapped across concatenated tensors
  • Split/Chunk: Operations that divide a tensor along the channel dimension
  • Reshape/View: Operations that change the tensor layout
  • Multi-head attention: Where Q, K, V projections share dimensional constraints
  • Residual connections: Where skip connections require matching dimensions at the addition point
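
As a concrete illustration of the concatenation case, the sketch below shows the index shift involved. `map_concat_indices` is a hypothetical helper, not part of Torch-Pruning; it only demonstrates that a channel index local to one branch lands at an offset position in the concatenated tensor.

```python
# Hypothetical sketch of index mapping across Cat([Conv1, Conv2]):
# channel i of a branch appears at (i + offset of that branch) in the
# concatenated tensor, so pruning indices must be shifted accordingly.

def map_concat_indices(idxs, branch_channels, branch):
    """Translate per-branch channel indices to indices in the concat output."""
    offset = sum(branch_channels[:branch])
    return [i + offset for i in idxs]

branch_channels = [8, 16]   # Conv1 has 8 output channels, Conv2 has 16
# Pruning channels [0, 2] of Conv2 maps to [8, 10] in the concat output.
mapped = map_concat_indices([0, 2], branch_channels, branch=1)
```

Split/chunk operations require the inverse translation, mapping indices in the combined tensor back to per-branch indices.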

A dependency group G is produced by BFS (breadth-first search) traversal from a root module through the dependency graph. Starting from a root layer (typically a Conv2d or Linear layer), the traversal follows dependency edges to discover all layers whose parameters must be modified when the root is pruned.

The group is structured as an ordered list of (dependency, indices) pairs:

group = [
    (Conv2d -> BN,   [0, 1, 2, 3]),   # prune output channels of Conv, adjust BN
    (BN -> ReLU,     [0, 1, 2, 3]),   # adjust downstream ReLU (passthrough)
    (ReLU -> Conv2d, [0, 1, 2, 3]),   # adjust input channels of next Conv
]

The indices do not need to be the full set of channels. For importance estimation, a full-index group is used to score all channels. For actual pruning, a subset group is created with only the channels to be removed.
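
These two uses can be sketched as follows. The scoring values and the `lowest_k` helper are invented for illustration; in Torch-Pruning an importance function would score a full-index group, and the pruner would then build a subset group containing only the lowest-scoring channels.

```python
# Toy sketch (not the Torch-Pruning API): score every channel using a
# full-index group, then keep only the bottom-k indices for the subset
# group that is actually pruned.

def lowest_k(scores, k):
    """Indices of the k smallest importance scores, in ascending index order."""
    order = sorted(range(len(scores)), key=lambda c: scores[c])
    return sorted(order[:k])

all_idxs = list(range(6))                   # full-index group: score everything
scores = [0.9, 0.1, 0.5, 0.05, 0.7, 0.3]    # hypothetical per-channel importance
prune_idxs = lowest_k(scores, k=2)          # subset group: channels to remove
# prune_idxs == [1, 3]
```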

Usage

Dependency group pruning is fundamental to all structural pruning in Torch-Pruning. Groups are used in two primary contexts:

  • Automatic pruning via BasePruner.step(): The pruner internally enumerates all groups, computes importance scores, determines which channels to remove, and calls group.prune() on each group.
  • Interactive/manual pruning: The user can obtain groups from the DependencyGraph directly and inspect or modify them before pruning.
# Automatic: pruner handles groups internally
pruner.step()

# Interactive: user controls which groups to prune
for group in pruner.step(interactive=True):
    print(group)         # inspect the group
    group.prune()        # execute pruning

# Manual: obtain a specific group from the dependency graph
# (assumes `model` and `example_inputs` are already defined)
import torch_pruning as tp

DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[0, 1])
group.prune()

Groups are also used for:

  • Importance estimation: Importance functions receive a group and return per-channel scores
  • Serialization: Pruning history is recorded per-group for model saving and loading
  • Validation: DG.check_pruning_group(group) verifies that a group is safe to prune

Theoretical Basis

A Group G is defined as:

G = {(dep_1, idx_1), (dep_2, idx_2), ..., (dep_k, idx_k)}

where dep_i represents a dependency edge in the graph (connecting a source layer to a target layer via a pruning function pair), and idx_i represents the channel indices to prune in the target layer.

Pruning G means applying each dep_i(idx_i) atomically -- that is, executing all pruning operations in the group as a single indivisible unit. This guarantees network consistency after pruning.

Group construction proceeds by BFS traversal from a root module:

  1. Select a root module and a pruning function (e.g., prune_conv_out_channels)
  2. Determine the indices to prune at the root
  3. Follow all outgoing dependency edges from the root node
  4. For each reached node, apply index mapping to translate source indices to target indices
  5. Continue traversal until all reachable nodes are visited
  6. Collect all (dependency, mapped_indices) pairs into the group

The index mapping step is critical for handling complex operations like concatenation and split, where channel indices in one layer do not directly correspond to the same indices in a connected layer.
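
The six steps above can be sketched as a small BFS over a toy graph. None of the names below are Torch-Pruning APIs; the `index_map` callables stand in for the framework's index mapping, e.g. the offset introduced by crossing a concat.

```python
from collections import deque

# Toy sketch of group construction by BFS (hypothetical names, not the
# Torch-Pruning API). Each dependency edge carries an index_map that
# translates source indices to target indices as traversal proceeds.

def build_group(graph, root, idxs):
    """graph: node -> list of (neighbor, index_map) dependency edges."""
    group, visited = [], {root}
    queue = deque([(root, idxs)])
    while queue:
        node, cur_idxs = queue.popleft()
        for neighbor, index_map in graph.get(node, []):
            if neighbor in visited:
                continue
            visited.add(neighbor)
            mapped = [index_map(i) for i in cur_idxs]   # translate indices
            group.append(((node, neighbor), mapped))
            queue.append((neighbor, mapped))
    return group

identity = lambda i: i
shift8 = lambda i: i + 8   # e.g. crossing a concat after an 8-channel branch
graph = {
    "conv1": [("bn1", identity)],
    "bn1":   [("cat", shift8)],
    "cat":   [("conv2", identity)],
}
group = build_group(graph, "conv1", [0, 2])
# group == [(("conv1", "bn1"), [0, 2]), (("bn1", "cat"), [8, 10]),
#           (("cat", "conv2"), [8, 10])]
```

Note how the indices [0, 2] are carried unchanged through the identity edges but become [8, 10] once the traversal crosses the concat, mirroring the index mapping step described above.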
