Principle:VainF Torch Pruning Dependency Graph Representation
| Knowledge Sources | |
|---|---|
| Domains | Graph_Analysis, Pruning, Model_Architecture |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A graph-based data structure that represents each neural network layer as a typed vertex with computational connections and pruning dependency edges, enabling automatic propagation of structural changes across the network.
Description
Dependency Graph Representation addresses a fundamental challenge of structural pruning in deep neural networks: when channels are removed from one layer, every coupled layer in the network must be updated consistently. For example, pruning output channels of a Conv2d requires removing the same channels from the following BatchNorm2d and from the next Conv2d's input channels, and possibly from layers connected through skip connections or concatenations.
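The coupling in this example can be checked with simple shape arithmetic. The following minimal sketch tracks only weight shapes, using PyTorch's layouts (Conv2d weight `(out_channels, in_channels, kH, kW)`, BatchNorm2d affine weight `(num_features,)`). The helper `prune_chain` and its dictionary keys are hypothetical and not part of Torch-Pruning.

```python
from typing import Dict, Iterable, Tuple

Shape = Tuple[int, ...]


def prune_chain(shapes: Dict[str, Shape], pruned_idxs: Iterable[int]) -> Dict[str, Shape]:
    """Hypothetical helper: new weight shapes after removing `pruned_idxs`
    from conv1's output channels and from every dimension coupled to them."""
    n = len(set(pruned_idxs))  # number of distinct channels removed
    c1_out, c1_in, kh1, kw1 = shapes["conv1.weight"]
    (bn_c,) = shapes["bn.weight"]
    c2_out, c2_in, kh2, kw2 = shapes["conv2.weight"]
    # Coupling: bn normalizes conv1's outputs, and conv2 consumes them.
    assert bn_c == c1_out == c2_in, "layers must share the channel dimension"
    return {
        "conv1.weight": (c1_out - n, c1_in, kh1, kw1),  # output dim shrinks
        "bn.weight": (bn_c - n,),                       # one scale per channel
        "conv2.weight": (c2_out, c2_in - n, kh2, kw2),  # input dim shrinks
    }


shapes = {
    "conv1.weight": (64, 3, 3, 3),
    "bn.weight": (64,),
    "conv2.weight": (128, 64, 3, 3),
}
print(prune_chain(shapes, [0, 5, 7]))
# {'conv1.weight': (61, 3, 3, 3), 'bn.weight': (61,), 'conv2.weight': (128, 61, 3, 3)}
```

conv1's bias and the BatchNorm2d bias and running statistics would shrink by the same count; they are omitted to keep the sketch short.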
A dependency graph solves this by modeling the network as a directed graph where:
- Each layer (or operation) becomes a node storing the module reference, operation type, and pruning metadata.
- Computational edges (inputs/outputs) capture the forward-pass data flow, mirroring the autograd graph.
- Dependency edges capture how pruning of one node's channels affects other nodes, including the direction (input vs. output channels) and index mapping for operations like concatenation and splitting.
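A minimal sketch of the node and edge records described above, assuming plain Python dataclasses. The class names follow the description, but the field layout, the `OpType` members, the `connect` helper, and the `"in"`/`"out"` string labels are illustrative assumptions rather than Torch-Pruning's actual definitions.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, List, Optional


class OpType(Enum):
    """Hypothetical subset of operation categories."""
    CONV = auto()
    BN = auto()
    ELEMENTWISE = auto()
    CONCAT = auto()


@dataclass(eq=False)  # eq=False keeps identity equality/hashing for linked nodes
class Node:
    name: str
    type: OpType
    module: object = None  # the wrapped layer, if any
    inputs: List["Node"] = field(default_factory=list, repr=False)   # computational edges in
    outputs: List["Node"] = field(default_factory=list, repr=False)  # computational edges out
    dependencies: List["Dependency"] = field(default_factory=list, repr=False)


@dataclass(eq=False)
class Dependency:
    """Pruning `source` along `source_dim` forces pruning `target` along `target_dim`."""
    source: Node
    source_dim: str  # "in" or "out"
    target: Node
    target_dim: str
    index_map: Optional[Callable[[List[int]], List[int]]] = None  # None means identity

    def map_indices(self, idxs: List[int]) -> List[int]:
        return list(idxs) if self.index_map is None else self.index_map(idxs)


def connect(producer: Node, consumer: Node) -> None:
    """Record a computational edge producer -> consumer."""
    producer.outputs.append(consumer)
    consumer.inputs.append(producer)


conv = Node("conv1", OpType.CONV)
bn = Node("bn1", OpType.BN)
connect(conv, bn)
# Removing conv1's output channels removes the matching bn1 channels.
conv.dependencies.append(Dependency(conv, "out", bn, "in"))
print([d.target.name for d in conv.dependencies])  # ['bn1']
```

Keeping computational edges (`inputs`/`outputs`) separate from `dependencies` mirrors the dual-edge design: the first records data flow, the second records pruning constraints.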
This dual-edge representation (computational + dependency) allows the pruning system to automatically determine all layers that must change when any single layer is pruned, group them together, and apply consistent index transformations.
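Finding every group at once is a connected-components problem over the dependency edges. The sketch below swaps in a union-find (disjoint-set) structure for that step, an illustrative choice rather than necessarily how Torch-Pruning groups layers. The layer names, the `(layer, "in"/"out")` keys, and the residual topology (`conv1 -> bn1 -> add <- skip_conv`, `add -> conv2 -> fc`) are hypothetical, and index mappings are left out.

```python
from typing import Dict, Hashable, List, Tuple

Key = Tuple[str, str]  # (layer name, "in" | "out")


class DisjointSet:
    """Union-find with path halving; each root identifies one pruning group."""

    def __init__(self) -> None:
        self.parent: Dict[Hashable, Hashable] = {}

    def find(self, x: Hashable) -> Hashable:
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a: Hashable, b: Hashable) -> None:
        self.parent[self.find(a)] = self.find(b)


# Each pair means "pruning one side forces pruning the other with the same indices".
EDGES: List[Tuple[Key, Key]] = [
    (("conv1", "out"), ("bn1", "in")),
    (("bn1", "in"), ("bn1", "out")),        # BN ties its input and output channels
    (("bn1", "out"), ("add", "in")),
    (("skip_conv", "out"), ("add", "in")),  # residual branch must match
    (("add", "in"), ("add", "out")),        # element-wise add ties in and out
    (("add", "out"), ("conv2", "in")),
    (("conv2", "out"), ("fc", "in")),
]


def pruning_groups(edges: List[Tuple[Key, Key]]) -> List[List[Key]]:
    ds = DisjointSet()
    for a, b in edges:
        ds.union(a, b)
    groups: Dict[Hashable, List[Key]] = {}
    for key in list(ds.parent):
        groups.setdefault(ds.find(key), []).append(key)
    return sorted(sorted(g) for g in groups.values())


for group in pruning_groups(EDGES):
    print(group)
# [('add', 'in'), ('add', 'out'), ('bn1', 'in'), ('bn1', 'out'), ('conv1', 'out'), ('conv2', 'in'), ('skip_conv', 'out')]
# [('conv2', 'out'), ('fc', 'in')]
```

Because the residual `add` ties both branches together, `skip_conv` lands in the same group as `conv1`, while `conv2`'s output channels form a separate group with `fc`.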
Usage
Use this principle when designing a structural pruning system that must handle arbitrary network architectures including residual connections, concatenation, splitting, and other complex topologies. The graph representation is the core abstraction that makes architecture-agnostic pruning possible.
Theoretical Basis
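The dependency graph is formally a labeled directed graph $G = (V, E_c, E_d)$ where:
- $V$ = set of nodes, one per layer/operation
- $E_c \subseteq V \times V$ = computational edges (data flow)
- $E_d$ = dependency edges with channel direction labels, i.e. pairs $((u, d_u), (v, d_v))$ with $u, v \in V$ and $d_u, d_v \in \{\text{in}, \text{out}\}$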
Pseudo-code Logic:
```
# Abstract algorithm description (NOT real implementation)
# Step 1: Build computational graph via autograd tracing
for each module in network:
    node = Node(module, grad_fn, name)
    node.type = classify(module)  # OPTYPE enum
# Step 2: Connect nodes based on the autograd graph
for each (producer, consumer) in traced_connections:
    producer.add_output(consumer)
    consumer.add_input(producer)
# Step 3: Build dependency edges
for each computational edge (A -> B):
    # If pruning A's out_channels affects B's in_channels:
    add_dependency(A, out_dim, B, in_dim, index_mapping)
# Step 4: Query for pruning groups
group = graph.get_pruning_group(target_node, channel_indices)
# Returns all (node, indices) pairs that must change together
```
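Steps 2 to 4 of the pseudo-code can be turned into a small runnable sketch under simplifying assumptions: a hand-written list of traced connections stands in for autograd tracing (Step 1 is not reproduced), every computational edge couples the producer's output channels to the consumer's input channels, and the hypothetical pass-through ops `bn`, `relu`, and `add` tie their own input and output channels. Index mappings are identity here, and all names (`build_graph`, `get_pruning_group`, `PASS_THROUGH`) are illustrative, not Torch-Pruning's API.

```python
from collections import deque
from typing import Dict, List, Tuple

Key = Tuple[str, str]  # (node name, "in" | "out")

# Hypothetical op types whose input and output channel counts are tied.
PASS_THROUGH = {"bn", "relu", "add"}

# Step 1 stand-in: node types and traced (producer, consumer) edges, written by hand.
NODES = {"conv1": "conv", "bn1": "bn", "relu1": "relu", "conv2": "conv"}
TRACED = [("conv1", "bn1"), ("bn1", "relu1"), ("relu1", "conv2")]


def build_graph(nodes: Dict[str, str], traced: List[Tuple[str, str]]):
    # Step 2: computational edges.
    outputs: Dict[str, List[str]] = {n: [] for n in nodes}
    inputs: Dict[str, List[str]] = {n: [] for n in nodes}
    for producer, consumer in traced:
        outputs[producer].append(consumer)
        inputs[consumer].append(producer)

    # Step 3: dependency edges derived from computational edges and op types.
    deps: Dict[Key, List[Key]] = {}

    def add_dep(a: Key, b: Key) -> None:
        deps.setdefault(a, []).append(b)  # stored symmetrically
        deps.setdefault(b, []).append(a)

    for producer, consumer in traced:
        add_dep((producer, "out"), (consumer, "in"))  # inter-layer coupling
    for name, op in nodes.items():
        if op in PASS_THROUGH:
            add_dep((name, "in"), (name, "out"))  # intra-layer coupling
    return inputs, outputs, deps


def get_pruning_group(deps: Dict[Key, List[Key]], target: Key, idxs: List[int]):
    # Step 4: breadth-first search over dependency edges from `target`.
    seen, queue = {target}, deque([target])
    while queue:
        for nxt in deps.get(queue.popleft(), []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return sorted((name, dim, list(idxs)) for name, dim in seen)


_, _, deps = build_graph(NODES, TRACED)
for entry in get_pruning_group(deps, ("conv1", "out"), [2, 6]):
    print(entry)
```

Pruning `conv1`'s output channels pulls in both sides of `bn1` and `relu1` plus `conv2`'s input channels, while `conv2`'s own output channels stay independent.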
Key properties of each node:
- type: Operation category from the type classification system
- pruning_dim: Which tensor dimension corresponds to channels for this operation
- dependencies: List of dependency edges encoding inter-layer constraints
- enable_index_mapping: Whether index transforms are needed (for concat/split operations)
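For a concatenation along the channel axis, an index mapping translates channel positions in the concatenated output into positions local to each input. The sketch below assumes inputs are concatenated in order; `concat_index_mapping` is a hypothetical helper, not a Torch-Pruning function.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Dict, List


def concat_index_mapping(in_channels: List[int], out_idxs: List[int]) -> Dict[int, List[int]]:
    """Hypothetical helper: map pruned output channels of a channel-wise concat
    back to {input position: local channel indices in that input}."""
    offsets = [0, *accumulate(in_channels)]  # input i starts at offsets[i]
    total = offsets[-1]
    mapping: Dict[int, List[int]] = {i: [] for i in range(len(in_channels))}
    for idx in sorted(set(out_idxs)):
        if not 0 <= idx < total:
            raise IndexError(f"channel {idx} is outside a concat output of {total} channels")
        i = bisect_right(offsets, idx) - 1  # last input whose start offset is <= idx
        mapping[i].append(idx - offsets[i])
    return mapping


print(concat_index_mapping([64, 32], [3, 70, 90]))  # {0: [3], 1: [6, 26]}
```

A split runs the mapping in the other direction: local index `j` of chunk `i` corresponds to index `offsets[i] + j` of the tensor being split.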
The graph construction relies on PyTorch's autograd mechanism to trace the computational graph, then augments it with dependency edges derived from the operation types and their pruning semantics.