

Implementation:VainF Torch Pruning Node

From Leeroopedia


Knowledge Sources
Domains Graph_Analysis, Pruning, Model_Architecture
Last Updated 2026-02-08 00:00 GMT

Overview

Concrete implementation of a vertex in the dependency graph of the Torch-Pruning library: each Node represents a single neural network layer or virtual operation.

Description

The Node class is the fundamental building block of the dependency graph (DependencyGraph). Every layer in the traced neural network becomes a Node, including both standard PyTorch modules (Conv2d, Linear, BatchNorm, etc.) and virtual operation modules (_ConcatOp, _SplitOp, _ElementWiseOp, etc.).

Each Node stores:

  • module: Reference to the actual nn.Module instance.
  • grad_fn: The autograd gradient function associated with this module's output, used during graph tracing.
  • inputs / outputs: Lists of connected Node objects representing the computational graph topology.
  • dependencies: Adjacency list of Dependency objects that model how pruning propagates to/from this node.
  • type: The OPTYPE enum value determined by ops.module2type(module).
  • pruning_dim: The dimension along which pruning is applied (set dynamically during dependency construction).
  • enable_index_mapping: Boolean controlling index mapping for concat/split/chunk operations.

Usage

Use this class when you need to inspect or traverse the dependency graph built by DependencyGraph. Users do not typically create Nodes directly; they are constructed internally during DependencyGraph.build_dependency(). Access nodes through the graph to examine the model topology, pruning dimensions, and dependency relationships.
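The access pattern (look up a node for a module, then follow its links) can be imitated without the library. This sketch uses a hypothetical `build_chain` helper, a hypothetical `TinyNode` class, and a plain dict from module to node; it links modules in list order, as a Sequential model would run, whereas the real graph is built by tracing, which the description above says uses grad_fn.

```python
class TinyNode:
    """Hypothetical minimal node: a module plus adjacency lists."""

    def __init__(self, module, name=None):
        self.module = module
        self.name = name
        self.inputs = []
        self.outputs = []


def build_chain(named_modules):
    """Wrap each (name, module) pair in a node and link consecutive nodes.

    Returns a dict mapping each module object to its node.
    """
    module_to_node = {}
    prev = None
    for name, module in named_modules:
        node = TinyNode(module, name)
        module_to_node[module] = node
        if prev is not None:
            prev.outputs.append(node)
            node.inputs.append(prev)
        prev = node
    return module_to_node


# Plain objects stand in for Conv2d, BatchNorm2d, ReLU, Conv2d.
conv1, bn, relu, conv2 = object(), object(), object(), object()
graph = build_chain([("0", conv1), ("1", bn), ("2", relu), ("3", conv2)])
node = graph[conv1]  # look up the node for the first layer
print(node.name, [n.name for n in node.outputs])  # 0 ['1']
```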

Code Reference

Source Location

Signature

import torch.nn as nn
from torch_pruning import ops

class Node(object):
    """Node of DepGraph."""
    def __init__(
        self,
        module: nn.Module,
        grad_fn,
        name: str = None,
    ):
        """
        Args:
            module: Reference to the torch.nn.Module (or virtual op).
            grad_fn: Autograd gradient function of the module output.
            name: Optional human-readable name from named_modules().
        """
        self.inputs = []           # list[Node]: input nodes
        self.outputs = []          # list[Node]: output nodes
        self.module = module       # nn.Module
        self.grad_fn = grad_fn     # grad_fn
        self._name = name          # str or None
        self.type = ops.module2type(module)  # OPTYPE enum
        self.module_class = module.__class__
        self.dependencies = []     # list[Dependency]
        self.enable_index_mapping = True
        self.pruning_dim = -1

    @property
    def name(self) -> str:
        """Formatted display name combining module name and type."""
        ...

    def add_input(self, node: 'Node') -> None:
        """Append a node to the inputs list."""
        ...

    def add_output(self, node: 'Node') -> None:
        """Append a node to the outputs list."""
        ...

    def details(self) -> str:
        """Verbose dump of inputs, outputs, dependencies, and metadata."""
        ...
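The method bodies are elided above. A hedged sketch of what they might look like: `add_input`/`add_output` append as their docstrings state, `name` joins the optional module name with the module's representation, and `details` lists neighbours and dependencies. The format strings are assumptions chosen to match the `<Node: (...)>` output shown in the usage example below, not code copied from the library; `SketchNode` is a hypothetical stand-in.

```python
class SketchNode:
    def __init__(self, module, name=None):
        self.module = module
        self._name = name
        self.inputs, self.outputs, self.dependencies = [], [], []

    @property
    def name(self) -> str:
        # Fall back to the module's string form when no name was recorded.
        if self._name is None:
            return str(self.module)
        return f"{self._name} ({self.module})"

    def __repr__(self) -> str:
        return f"<Node: ({self.name})>"

    def add_input(self, node: "SketchNode") -> None:
        self.inputs.append(node)

    def add_output(self, node: "SketchNode") -> None:
        self.outputs.append(node)

    def details(self) -> str:
        # One line per input, output, and dependency, framed by rules.
        lines = ["-" * 40, repr(self)]
        lines += [f"    IN[{i}]: {n}" for i, n in enumerate(self.inputs)]
        lines += [f"    OUT[{i}]: {n}" for i, n in enumerate(self.outputs)]
        lines += [f"    DEP[{i}]: {d}" for i, d in enumerate(self.dependencies)]
        lines.append("-" * 40)
        return "\n".join(lines)


# Strings stand in for module objects here.
conv = SketchNode("Conv2d(3, 64)", name="0")
bn = SketchNode("BatchNorm2d(64)", name="1")
conv.add_output(bn)
bn.add_input(conv)
print(repr(conv))  # <Node: (0 (Conv2d(3, 64)))>
print(bn.details())
```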

Import

from torch_pruning.dependency import Node

I/O Contract

__init__

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| module | nn.Module | Yes | The PyTorch module (or virtual op) this node represents |
| grad_fn | object | Yes | Autograd gradient function of the module's output tensor |
| name | str | No | Human-readable name from model.named_modules() |

Outputs

| Name | Type | Description |
|---|---|---|
| Node instance | Node | Initialized graph node with empty inputs, outputs, and dependencies |

Key Attributes

| Name | Type | Description |
|---|---|---|
| inputs | list[Node] | Predecessor nodes in the computational graph |
| outputs | list[Node] | Successor nodes in the computational graph |
| dependencies | list[Dependency] | Pruning dependency edges to/from this node |
| type | OPTYPE | Operation type classification enum value |
| pruning_dim | int | Dimension along which pruning operates (default -1, set dynamically) |
| enable_index_mapping | bool | Whether index mapping is active for concat/split operations |
| name | str (property) | Formatted display name |
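Index mapping matters for concat and split because these operations shift channel indices: channel k of a concatenated output belongs to whichever input covers that offset. A minimal sketch of that arithmetic for concatenation, using hypothetical helpers `concat_index_to_input` and `split_pruned_indices` and assuming inputs are concatenated in order along the channel dimension. It illustrates the idea only and is not Torch-Pruning's index-mapping code.

```python
from bisect import bisect_right
from itertools import accumulate


def concat_index_to_input(index, channel_sizes):
    """Map a channel of a concat output to (input_position, local_index).

    Assumes inputs with the given channel counts were concatenated in order.
    """
    offsets = [0] + list(accumulate(channel_sizes))  # e.g. [0, 64, 96]
    if not 0 <= index < offsets[-1]:
        raise IndexError(f"channel {index} out of range for {offsets[-1]} channels")
    pos = bisect_right(offsets, index) - 1
    return pos, index - offsets[pos]


def split_pruned_indices(indices, channel_sizes):
    """Group pruned concat-output channels by the input they come from."""
    groups = {i: [] for i in range(len(channel_sizes))}
    for idx in indices:
        pos, local = concat_index_to_input(idx, channel_sizes)
        groups[pos].append(local)
    return groups


# A 64-channel tensor concatenated with a 32-channel tensor (96 channels total).
print(concat_index_to_input(70, [64, 32]))              # (1, 6)
print(split_pruned_indices([3, 63, 64, 95], [64, 32]))  # {0: [3, 63], 1: [0, 31]}
```

Without this mapping, pruning channel 70 of the concat output would be applied to channel 70 of an input that has only 32 channels.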

Usage Examples

Inspecting Graph Nodes

import torch
import torch.nn as nn
import torch_pruning as tp

# Build a simple model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)

# Build dependency graph
example_input = torch.randn(1, 3, 32, 32)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_input)

# Access a node for a specific module
node = DG.get_node(model[0])  # Node for first Conv2d
print(node)              # <Node: (0 (Conv2d(3, 64, ...)))>
print(node.type)         # OPTYPE.CONV
print(node.pruning_dim)  # 0 or 1 depending on context
print(len(node.outputs)) # Number of successor nodes

Traversing Node Connections

# Walk the graph from the first conv
node = DG.get_node(model[0])

# Print output nodes
for out_node in node.outputs:
    print(f"  -> {out_node.name} (type={out_node.type})")

# Print dependencies
for dep in node.dependencies:
    print(f"  dep: {dep}")

# Detailed debug info
print(node.details())
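Following `outputs` repeatedly reaches every node downstream of a starting point. The breadth-first sketch below runs on hypothetical `StubNode` objects that expose only `name` and `outputs`, the two Node attributes the walk uses; the `seen` set prevents revisiting a node reachable along several paths, such as the add that closes a residual branch.

```python
from collections import deque


class StubNode:
    # Hypothetical stand-in exposing only the attributes the walk needs.
    def __init__(self, name):
        self.name = name
        self.outputs = []


def downstream(start):
    """Return names of all nodes reachable from `start` via outputs, in BFS order."""
    seen = {start}
    order = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        order.append(node.name)
        for nxt in node.outputs:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return order


# conv0 -> bn1 -> relu2 -> conv3 -> add4, plus a skip edge conv0 -> add4.
conv0, bn1, relu2, conv3, add4 = (
    StubNode(n) for n in ["conv0", "bn1", "relu2", "conv3", "add4"]
)
conv0.outputs += [bn1, add4]
bn1.outputs.append(relu2)
relu2.outputs.append(conv3)
conv3.outputs.append(add4)
print(downstream(conv0))
```

Swapping `popleft` for `pop` turns the same loop into a depth-first walk.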
