Implementation:VainF Torch Pruning Node
| Field | Value |
|---|---|
| Knowledge Sources | |
| Domains | Graph_Analysis, Pruning, Model_Architecture |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for representing a single neural network layer or virtual operation as a vertex in the dependency graph, provided by the Torch-Pruning library.
Description
The Node class is the fundamental building block of the dependency graph (DependencyGraph). Every layer in the traced neural network becomes a Node, including both standard PyTorch modules (Conv2d, Linear, BatchNorm, etc.) and virtual operation modules (_ConcatOp, _SplitOp, _ElementWiseOp, etc.).
Each Node stores:
- module: Reference to the actual nn.Module instance.
- grad_fn: The autograd gradient function associated with this module's output, used during graph tracing.
- inputs / outputs: Lists of connected Node objects representing the computational graph topology.
- dependencies: Adjacency list of Dependency objects that model how pruning propagates to/from this node.
- type: The OPTYPE enum value determined by ops.module2type(module).
- pruning_dim: The dimension along which pruning is applied (set dynamically during dependency construction).
- enable_index_mapping: Boolean controlling index mapping for concat/split/chunk operations.
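The following minimal sketch makes the field list concrete. It defines a node that stores the same attributes and links two nodes together. ToyNode and connect are hypothetical names. The type attribute is approximated by the class name rather than ops.module2type, and grad_fn is left as None because no autograd tracing happens here.

```python
from typing import Any, List, Optional


class ToyNode:
    """Hypothetical, stripped-down mirror of the fields a DepGraph Node stores."""

    def __init__(self, module: Any, grad_fn: Any = None, name: Optional[str] = None):
        self.module = module
        self.grad_fn = grad_fn                 # None here: no autograd tracing in this sketch
        self._name = name
        self.inputs: List["ToyNode"] = []      # predecessor nodes
        self.outputs: List["ToyNode"] = []     # successor nodes
        self.dependencies: List[Any] = []      # would hold Dependency objects
        self.type = type(module).__name__      # stand-in for ops.module2type(module)
        self.enable_index_mapping = True
        self.pruning_dim = -1                  # default before any update

    def add_input(self, node: "ToyNode") -> None:
        self.inputs.append(node)

    def add_output(self, node: "ToyNode") -> None:
        self.outputs.append(node)


def connect(src: ToyNode, dst: ToyNode) -> None:
    """Record the edge src -> dst on both endpoints."""
    src.add_output(dst)
    dst.add_input(src)


# Dummy classes stand in for nn.Conv2d / nn.BatchNorm2d so the sketch runs without PyTorch.
class Conv2d: ...
class BatchNorm2d: ...


conv = ToyNode(Conv2d(), name="0")
bn = ToyNode(BatchNorm2d(), name="1")
connect(conv, bn)
print([n._name for n in conv.outputs])  # ['1']
print([n._name for n in bn.inputs])     # ['0']
```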
Usage
Use this class when you need to inspect or traverse the dependency graph built by DependencyGraph. Users do not typically create Nodes directly; DependencyGraph.build_dependency() constructs them internally. Retrieve nodes from the built graph to examine the model topology, pruning dimensions, and dependency relationships. Importing Node itself is mainly useful for type annotations and isinstance checks.
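Traversal only needs the inputs/outputs lists. The following minimal sketch assumes nothing beyond an outputs list per node; N and reachable are hypothetical helpers, not Torch-Pruning API. It keeps a visited set because residual and concat branches let several paths reach the same node, so the graph is not a tree.

```python
from collections import deque
from typing import List


class N:
    """Hypothetical minimal node: only a name and a successor list."""

    def __init__(self, name: str):
        self.name = name
        self.outputs: List["N"] = []


def reachable(start: N) -> List[str]:
    """Breadth-first walk over .outputs, visiting each node once even if several paths reach it."""
    seen = {start}
    order: List[str] = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        order.append(node.name)
        for nxt in node.outputs:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return order


# Residual-style diamond: conv1 feeds both conv2 and add; conv2 also feeds add.
conv1, conv2, add = N("conv1"), N("conv2"), N("add")
conv1.outputs = [conv2, add]
conv2.outputs = [add]
print(reachable(conv1))  # ['conv1', 'conv2', 'add']
```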
Code Reference
Source Location
- Repository: VainF_Torch_Pruning
- File: torch_pruning/dependency/node.py
- Lines: 1-60
Signature
```python
class Node(object):
    """Node of DepGraph."""

    def __init__(
        self,
        module: nn.Module,
        grad_fn,
        name: str = None,
    ):
        """
        Args:
            module: Reference to the torch.nn.Module (or virtual op).
            grad_fn: Autograd gradient function of the module output.
            name: Optional human-readable name from named_modules().
        """
        self.inputs = []  # list[Node]: input nodes
        self.outputs = []  # list[Node]: output nodes
        self.module = module  # nn.Module
        self.grad_fn = grad_fn  # grad_fn
        self._name = name  # str or None
        self.type = ops.module2type(module)  # OPTYPE enum
        self.module_class = module.__class__
        self.dependencies = []  # list[Dependency]
        self.enable_index_mapping = True
        self.pruning_dim = -1

    @property
    def name(self) -> str:
        """Formatted display name combining module name and type."""
        ...

    def add_input(self, node: 'Node') -> None:
        """Append a node to the inputs list."""
        ...

    def add_output(self, node: 'Node') -> None:
        """Append a node to the outputs list."""
        ...

    def details(self) -> str:
        """Verbose dump of inputs, outputs, dependencies, and metadata."""
        ...
```
Import
```python
from torch_pruning.dependency import Node
```
I/O Contract
__init__
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| module | nn.Module | Yes | The PyTorch module (or virtual op) this node represents |
| grad_fn | object | Yes | Autograd gradient function of the module's output tensor |
| name | str | No | Human-readable name from model.named_modules() |
Outputs
| Name | Type | Description |
|---|---|---|
| Node instance | Node | Initialized graph node with empty inputs/outputs/dependencies |
Key Attributes
| Name | Type | Description |
|---|---|---|
| inputs | list[Node] | Predecessor nodes in the computational graph |
| outputs | list[Node] | Successor nodes in the computational graph |
| dependencies | list[Dependency] | Pruning dependency edges to/from this node |
| type | OPTYPE | Operation type classification enum value |
| pruning_dim | int | Dimension along which pruning operates (default -1, set dynamically) |
| enable_index_mapping | bool | Whether index mapping is active for concat/split operations |
| name | str (property) | Formatted display name |
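Index mapping matters at concat and split nodes because channel indices change meaning across them. For example, in a channel-wise concat of inputs with 64 and 128 channels, output channel 100 is channel 36 of the second input. The following sketch shows that offset arithmetic for a concat. split_concat_indices is a hypothetical helper. The sketch illustrates the idea and is not how Torch-Pruning implements its mapping.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Dict, List


def split_concat_indices(pruned: List[int], in_channels: List[int]) -> Dict[int, List[int]]:
    """Map pruned output-channel indices of a channel concat back to per-input local indices.

    offsets[k] is the first output channel contributed by input k.
    """
    offsets = [0] + list(accumulate(in_channels))  # [64, 128] -> [0, 64, 192]
    per_input: Dict[int, List[int]] = {k: [] for k in range(len(in_channels))}
    for idx in pruned:
        if not 0 <= idx < offsets[-1]:
            raise IndexError(f"channel {idx} outside concat of width {offsets[-1]}")
        k = bisect_right(offsets, idx) - 1  # which input owns this output channel
        per_input[k].append(idx - offsets[k])
    return per_input


# Concat of a 64-channel and a 128-channel tensor along dim 1 -> 192 channels.
print(split_concat_indices([3, 63, 64, 100, 191], [64, 128]))
# {0: [3, 63], 1: [0, 36, 127]}
```

Using bisect on the cumulative offsets also skips zero-width inputs correctly, because an index always lands on the last input whose offset does not exceed it.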
Usage Examples
Inspecting Graph Nodes
```python
import torch
import torch.nn as nn
import torch_pruning as tp

# Build a simple model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)

# Build the dependency graph by tracing the model with an example input
example_input = torch.randn(1, 3, 32, 32)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)

# Access the node for a specific module
node = DG.get_node(model[0])  # Node for the first Conv2d
print(node)  # <Node: (0 (Conv2d(3, 64, ...)))>
print(node.type)  # OPTYPE.CONV
print(node.pruning_dim)  # -1 by default; may be set during dependency construction
print(len(node.outputs))  # Number of successor nodes
```
Traversing Node Connections
```python
# Walk the graph from the first conv (reuses model and DG from the previous example)
node = DG.get_node(model[0])

# Print output nodes
for out_node in node.outputs:
    print(f" -> {out_node.name} (type={out_node.type})")

# Print dependencies
for dep in node.dependencies:
    print(f" dep: {dep}")

# Detailed debug info
print(node.details())
```
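The dependencies printed above are the edges that pruning propagates through. Pruning one node forces related pruning on the targets of its dependency edges, and then on their targets in turn. A simplified sketch of that closure follows. DepNode, DepEdge and collect_group are hypothetical names. Real Dependency objects are richer than this. This sketch follows every edge out of a node regardless of how that node was reached.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import List, Set, Tuple


# eq=False keeps identity-based hashing, so nodes can be stored in a set.
@dataclass(eq=False)
class DepNode:
    name: str
    dependencies: List["DepEdge"] = field(default_factory=list)


@dataclass(eq=False)
class DepEdge:
    target: DepNode
    action: str  # hypothetical label for the pruning applied to target


def collect_group(root: DepNode, root_action: str) -> List[Tuple[str, str]]:
    """Breadth-first closure over dependency edges; each (node, action) pair appears once."""
    seen: Set[Tuple[DepNode, str]] = {(root, root_action)}
    group = [(root.name, root_action)]
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for edge in node.dependencies:
            key = (edge.target, edge.action)
            if key not in seen:
                seen.add(key)
                group.append((edge.target.name, edge.action))
                queue.append(edge.target)
    return group


conv1, bn1, conv2 = DepNode("conv1"), DepNode("bn1"), DepNode("conv2")
conv1.dependencies.append(DepEdge(bn1, "prune_bn"))
bn1.dependencies.append(DepEdge(conv2, "prune_in_channels"))
bn1.dependencies.append(DepEdge(conv1, "prune_out_channels"))  # back-edge, already visited

print(collect_group(conv1, "prune_out_channels"))
# [('conv1', 'prune_out_channels'), ('bn1', 'prune_bn'), ('conv2', 'prune_in_channels')]
```

The visited set keyed on (node, action) is what guarantees termination when edges point back, as the bn1 -> conv1 edge does here.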