Principle: VainF Torch-Pruning Dependency Group Pruning
Metadata
| Field | Value |
|---|---|
| Paper | DepGraph: Towards Any Structural Pruning (Fang et al., 2023) |
| Domains | Deep_Learning, Model_Compression, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A dependency group is the atomic unit of structural pruning: the minimal set of (layer, indices) pairs that must be pruned together so that all coupled layers are modified consistently through dependency propagation. This concept is fundamental to the Torch-Pruning framework and underpins all structural pruning operations it supports.
Description
In neural networks, layers are interconnected in ways that impose structural constraints on pruning. For example, a Conv2d -> BN -> ReLU chain requires that the output channels of the convolution match the number of features in the BatchNorm layer. If the convolution's output channels are pruned, the BatchNorm must be adjusted accordingly. A dependency group captures this coupling.
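The coupling described above can be made concrete with a minimal sketch. The dict-based "layers" and the `prune_channels` helper are plain-Python stand-ins for illustration, not the Torch-Pruning API:

```python
# Toy stand-ins for the coupled layers (plain Python, not PyTorch):
conv = {"out_channels": list(range(8))}   # Conv2d with 8 output channels
bn = {"num_features": list(range(8))}     # BatchNorm tracking 8 features

def prune_channels(layer, key, idxs):
    """Drop the given channel indices from a layer's channel list."""
    drop = set(idxs)
    layer[key] = [c for c in layer[key] if c not in drop]

# Pruning the conv alone would leave the BN still expecting 8 features,
# so the dependency group applies both operations with the same indices.
idxs = [0, 3]
prune_channels(conv, "out_channels", idxs)
prune_channels(bn, "num_features", idxs)

assert len(conv["out_channels"]) == len(bn["num_features"]) == 6
```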
The DependencyGraph traces the model's computation graph to discover these couplings. It handles complex structural patterns including:
- Sequential chains: `Conv -> BN -> ReLU`, where output channels must be consistent
- Concatenation: `Cat([Conv1, Conv2])`, where pruning indices must be mapped across concatenated tensors
- Split/Chunk: Operations that divide a tensor along the channel dimension
- Reshape/View: Operations that change the tensor layout
- Multi-head attention: Where Q, K, V projections share dimensional constraints
- Residual connections: Where skip connections require matching dimensions at the addition point
A dependency group G is produced by BFS (breadth-first search) traversal from a root module through the dependency graph. Starting from a root layer (typically a Conv2d or Linear layer), the traversal follows dependency edges to discover all layers whose parameters must be modified when the root is pruned.
The group is structured as an ordered list of (dependency, indices) pairs:
```
group = [
    (Conv2d -> BN,   [0, 1, 2, 3]),  # prune output channels of Conv, adjust BN
    (BN -> ReLU,     [0, 1, 2, 3]),  # adjust downstream ReLU (passthrough)
    (ReLU -> Conv2d, [0, 1, 2, 3]),  # adjust input channels of next Conv
]
```
The indices do not need to be the full set of channels. For importance estimation, a full-index group is used to score all channels. For actual pruning, a subset group is created with only the channels to be removed.
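The two-phase use of groups (full-index scoring, then subset pruning) can be sketched as follows. The scoring-by-weight-norm and the helper variables here are hypothetical stand-ins, not the Torch-Pruning importance API:

```python
# Toy per-channel weight norms standing in for an importance measure.
channel_weights = [0.9, 0.1, 0.5, 0.05, 0.7, 0.3]

# Phase 1: importance estimation over the full index set.
full_idxs = list(range(len(channel_weights)))
scores = {i: channel_weights[i] for i in full_idxs}

# Phase 2: build a subset group containing only the k least important
# channels; only these are actually removed.
k = 2
subset_idxs = sorted(scores, key=scores.get)[:k]

assert sorted(subset_idxs) == [1, 3]  # the two smallest-norm channels
```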
Usage
Dependency group pruning is fundamental to all structural pruning in Torch-Pruning. Groups are used in two primary contexts:
- Automatic pruning via `BasePruner.step()`: the pruner internally enumerates all groups, computes importance scores, determines which channels to remove, and calls `group.prune()` on each group.
- Interactive/manual pruning: the user can obtain groups from the `DependencyGraph` directly and inspect or modify them before pruning.
```python
# Automatic: pruner handles groups internally
pruner.step()

# Interactive: user controls which groups to prune
for group in pruner.step(interactive=True):
    print(group)   # inspect the group
    group.prune()  # execute pruning

# Manual: obtain a specific group from the dependency graph
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[0, 1])
group.prune()
```
Groups are also used for:
- Importance estimation: Importance functions receive a group and return per-channel scores
- Serialization: Pruning history is recorded per-group for model saving and loading
- Validation: `DG.check_pruning_group(group)` verifies that a group is safe to prune
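A hedged sketch of the kind of safety check such validation performs (toy data structures, not the actual Torch-Pruning implementation): a group is unsafe if pruning it would strip every channel from some layer.

```python
# Hypothetical validation helper: reject groups that would leave a layer
# with zero channels after pruning.
def check_pruning_group(group, channel_counts):
    """group: list of (layer_name, idxs); channel_counts: name -> count."""
    for layer, idxs in group:
        if len(set(idxs)) >= channel_counts[layer]:
            return False  # would remove every channel of this layer
    return True

counts = {"conv1": 8, "bn1": 8}
safe = [("conv1", [0, 1]), ("bn1", [0, 1])]
unsafe = [("conv1", list(range(8))), ("bn1", list(range(8)))]

assert check_pruning_group(safe, counts) is True
assert check_pruning_group(unsafe, counts) is False
```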
Theoretical Basis
A group $G$ is defined as:

$$G = \{(d_1, I_1), (d_2, I_2), \dots, (d_k, I_k)\}$$

where $d_i$ represents a dependency edge in the graph (connecting a source layer to a target layer via a pruning function pair), and $I_i$ represents the channel indices to prune in the target layer.

Pruning $G$ means applying each $(d_i, I_i)$ atomically -- that is, executing all pruning operations in the group as a single indivisible unit. This guarantees network consistency after pruning.
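Atomic application can be sketched with plain-Python stand-ins (toy layers and a hypothetical `prune_group` helper, not the Torch-Pruning `Group.prune()` implementation): every pair in the group is executed in one pass, so the network is never left half-pruned.

```python
# Toy layers: each maps a name to its surviving channel indices.
layers = {
    "conv": list(range(4)),  # output channels of the conv
    "bn":   list(range(4)),  # BN features, coupled to the conv
    "next": list(range(4)),  # input channels of the next conv
}

# The group pairs each affected layer with the indices to remove.
group = [("conv", [1, 2]), ("bn", [1, 2]), ("next", [1, 2])]

def prune_group(layers, group):
    for name, idxs in group:  # apply every pair in the group together
        drop = set(idxs)
        layers[name] = [c for c in layers[name] if c not in drop]

prune_group(layers, group)
assert all(len(chs) == 2 for chs in layers.values())  # still consistent
```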
Group construction proceeds by BFS traversal from a root module:
- Select a root module and a pruning function (e.g., `prune_conv_out_channels`)
- Determine the indices to prune at the root
- Follow all outgoing dependency edges from the root node
- For each reached node, apply index mapping to translate source indices to target indices
- Continue traversal until all reachable nodes are visited
- Collect all `(dependency, mapped_indices)` pairs into the group
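The steps above can be sketched as a BFS over a toy dependency graph. The `edges` dict and identity index mapping are simplifying assumptions; the real `DependencyGraph` also applies non-trivial mappings for cat/split/reshape:

```python
from collections import deque

# Toy dependency edges: source layer -> downstream coupled layers.
edges = {
    "conv1": ["bn1"],
    "bn1":   ["relu1"],
    "relu1": ["conv2"],
    "conv2": [],
}

def get_pruning_group(root, idxs):
    """BFS from the root, collecting (dependency, mapped_indices) pairs."""
    group, visited = [], {root}
    queue = deque([(root, idxs)])
    while queue:
        node, node_idxs = queue.popleft()
        for tgt in edges[node]:
            if tgt not in visited:
                visited.add(tgt)
                mapped = node_idxs  # identity mapping in this toy graph
                group.append(((node, tgt), mapped))
                queue.append((tgt, mapped))
    return group

group = get_pruning_group("conv1", [0, 1])
assert [dep for dep, _ in group] == [
    ("conv1", "bn1"), ("bn1", "relu1"), ("relu1", "conv2")
]
```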
The index mapping step is critical for handling complex operations like concatenation and split, where channel indices in one layer do not directly correspond to the same indices in a connected layer.
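For concatenation, the mapping amounts to offsetting a branch's indices by the channel counts of the branches before it. The helper below is a hypothetical illustration, not the Torch-Pruning mapper:

```python
def map_indices_through_cat(idxs, branch, branch_channels):
    """Translate per-branch channel indices into indices of the
    concatenated tensor. branch_channels: channel count of each branch."""
    offset = sum(branch_channels[:branch])
    return [i + offset for i in idxs]

# Cat([Conv1 (16 ch), Conv2 (16 ch)]): pruning channels [0, 3] of Conv2
# corresponds to channels [16, 19] of the concatenated tensor.
mapped = map_indices_through_cat([0, 3], branch=1, branch_channels=[16, 16])
assert mapped == [16, 19]

# Indices in the first branch pass through unchanged.
assert map_indices_through_cat([0, 3], branch=0, branch_channels=[16, 16]) == [0, 3]
```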