Implementation:VainF Torch Pruning Ops

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Pruning, Graph_Analysis
Last Updated: 2026-02-08 00:00 GMT

Overview

A concrete tool, provided by the Torch-Pruning library, that classifies neural network operations into typed categories and supplies virtual operation modules and dummy pruners for the non-standard operations in the dependency graph.

Description

The ops module is the core type system of Torch-Pruning. It provides four main components:

1. Virtual Operation Modules: Lightweight nn.Module subclasses that represent operations detected during autograd graph tracing that are not standard PyTorch modules. These include _ConcatOp (torch.cat), _SplitOp (torch.split), _UnbindOp (torch.unbind), _ReshapeOp (view/reshape), _SliceOp (tensor slicing), _OutputOp (model outputs), _ElementWiseOp (add/sub/mul/div), _ExpandOp (tensor expansion), and _CustomizedOp (user-defined operations).

2. OPTYPE Enum: An IntEnum with 21 operation types that classifies every neural network layer for dependency analysis: CONV, BN, LINEAR, PRELU, DEPTHWISE_CONV, CONCAT, SPLIT, CUSTOMIZED, ELEMENTWISE, LN, EMBED, PARAMETER, MHA, LSTM, RESHAPE, GN, IN, UNBIND, EXPAND, SLICE, OUTPUT.

3. Dummy Pruners: DummyPruner (no-op base) and specialized pruners for virtual ops that handle metadata updates during pruning: ConcatPruner, SplitPruner, SlicePruner, OutputPruner, plus pass-through pruners for Unbind, Expand, Reshape, ElementWise, and Customized operations.

4. Type Mapping Functions: module2type() converts a module instance to its OPTYPE enum value. type2class() converts an OPTYPE back to the corresponding module class. Constants like TORCH_CONV, TORCH_BATCHNORM, TORCH_LINEAR alias standard PyTorch module base classes.
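Because OPTYPE is an IntEnum, its members behave like plain integers while keeping readable names. The sketch below re-declares the enum locally from the values listed in the Signature section of this page, so it runs without Torch-Pruning installed. `NORM_TYPES` and `is_norm` are hypothetical helpers introduced only for illustration; they are not part of the library.

```python
from enum import IntEnum

# Local copy of the member names listed on this page, numbered from 0.
_NAMES = (
    "CONV BN LINEAR PRELU DEPTHWISE_CONV CONCAT SPLIT CUSTOMIZED "
    "ELEMENTWISE LN EMBED PARAMETER MHA LSTM RESHAPE GN IN UNBIND "
    "EXPAND SLICE OUTPUT"
)
OPTYPE = IntEnum("OPTYPE", _NAMES, start=0)

# IntEnum members compare equal to their integer values.
assert OPTYPE.CONV == 0
assert OPTYPE.OUTPUT == 20

# Lookup by value and lookup by name return the same member.
assert OPTYPE(4) is OPTYPE.DEPTHWISE_CONV
assert OPTYPE["LN"] is OPTYPE.LN

# Hypothetical grouping (not a library constant): normalization layers.
NORM_TYPES = frozenset({OPTYPE.BN, OPTYPE.LN, OPTYPE.GN, OPTYPE.IN})


def is_norm(op_type: OPTYPE) -> bool:
    """Return True if the op type is one of the normalization types."""
    return op_type in NORM_TYPES


print(len(OPTYPE), is_norm(OPTYPE.GN), is_norm(OPTYPE.CONV))
```

Grouping members in a frozenset like this is one way to branch on a family of operation types without a chain of equality checks.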

Usage

Import this module when you need to classify PyTorch modules into operation types for dependency graph construction, or when extending Torch-Pruning to handle new operation types. It is used internally by the dependency graph builder (DependencyGraph) and pruning functions, and is also needed when registering custom operations.

Code Reference

Signature

class OPTYPE(IntEnum):
    CONV = 0
    BN = 1
    LINEAR = 2
    PRELU = 3
    DEPTHWISE_CONV = 4
    CONCAT = 5
    SPLIT = 6
    CUSTOMIZED = 7
    ELEMENTWISE = 8
    LN = 9
    EMBED = 10
    PARAMETER = 11
    MHA = 12
    LSTM = 13
    RESHAPE = 14
    GN = 15
    IN = 16
    UNBIND = 17
    EXPAND = 18
    SLICE = 19
    OUTPUT = 20

def module2type(module) -> OPTYPE:
    """Classify a PyTorch module into its OPTYPE enum value."""
    ...

def type2class(op_type: OPTYPE):
    """Convert an OPTYPE enum value back to its corresponding module class."""
    ...

class DummyPruner(object):
    """No-op base pruner for virtual operations."""
    def prune_out_channels(self, layer, idxs): ...
    def prune_in_channels(self, layer, idxs): ...
    def get_out_channels(self, layer) -> int: ...
    def get_in_channels(self, layer) -> int: ...
    def get_in_channel_groups(self, layer) -> int: ...
    def get_out_channel_groups(self, layer) -> int: ...

class ConcatPruner(DummyPruner):
    """Pruner that updates concat_sizes and offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...

class SplitPruner(DummyPruner):
    """Pruner that updates split_sizes and offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...

class SlicePruner(DummyPruner):
    """Pruner that adjusts slice start/end offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...

class OutputPruner(DummyPruner):
    """Pruner that updates output shape metadata when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...
    def get_in_channels(self, layer) -> int: ...

Import

from torch_pruning import ops
from torch_pruning.ops import OPTYPE, module2type, type2class, DummyPruner

I/O Contract

module2type

Inputs

Name | Type | Required | Description
module | nn.Module | Yes | Any PyTorch module or virtual operation module

Outputs

Name | Type | Description
return | OPTYPE | The enum value classifying the module type

type2class

Inputs

Name | Type | Required | Description
op_type | OPTYPE | Yes | An operation type enum value

Outputs

Name | Type | Description
return | type | The PyTorch module class corresponding to the op type

DummyPruner.prune_out_channels / prune_in_channels

Inputs

Name | Type | Required | Description
layer | nn.Module | Yes | The virtual operation module to prune
idxs | list[int] | Yes | Channel indices to remove

Outputs

Name | Type | Description
return | nn.Module or None | The layer with updated metadata (in-place modification)
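To make "updated metadata" concrete, here is a minimal sketch of the kind of bookkeeping a concat-style pruner performs: each concatenated segment shrinks by the number of removed indices that fall inside it, and the cumulative offsets are then rebuilt. This is an illustration under stated assumptions, not the library's implementation. `ConcatMeta`, `cumulative_offsets`, and `prune_concat_channels` are hypothetical names, and only the `concat_sizes` and `offsets` fields are taken from the ConcatPruner description above.

```python
from dataclasses import dataclass, field
from typing import Iterable, List


def cumulative_offsets(sizes: List[int]) -> List[int]:
    """Return segment boundaries, e.g. [4, 6] -> [0, 4, 10]."""
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return offsets


@dataclass
class ConcatMeta:
    """Hypothetical stand-in for a virtual concat op's metadata."""
    concat_sizes: List[int]
    offsets: List[int] = field(init=False)

    def __post_init__(self) -> None:
        self.offsets = cumulative_offsets(self.concat_sizes)


def prune_concat_channels(meta: ConcatMeta, idxs: Iterable[int]) -> ConcatMeta:
    """Update sizes and offsets in place after removing channels `idxs`."""
    removed = set(idxs)  # ignore duplicate indices
    new_sizes = list(meta.concat_sizes)
    for seg in range(len(new_sizes)):
        lo, hi = meta.offsets[seg], meta.offsets[seg + 1]
        # Count the removed channels that belong to this segment.
        new_sizes[seg] -= sum(1 for i in removed if lo <= i < hi)
    meta.concat_sizes = new_sizes
    meta.offsets = cumulative_offsets(new_sizes)
    return meta


meta = ConcatMeta(concat_sizes=[4, 6])   # offsets start as [0, 4, 10]
prune_concat_channels(meta, [1, 3, 7])   # two from segment 0, one from segment 1
print(meta.concat_sizes, meta.offsets)   # [2, 5] [0, 2, 7]
```

No tensor is touched in this sketch. A virtual op has no weights, so only the bookkeeping changes.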

Usage Examples

Classifying a Module

import torch.nn as nn
from torch_pruning.ops import module2type, OPTYPE

# Standard modules
conv = nn.Conv2d(64, 128, 3)
assert module2type(conv) == OPTYPE.CONV

bn = nn.BatchNorm2d(128)
assert module2type(bn) == OPTYPE.BN

linear = nn.Linear(512, 10)
assert module2type(linear) == OPTYPE.LINEAR

# Depthwise convolution is auto-detected
dw_conv = nn.Conv2d(64, 64, 3, groups=64)
assert module2type(dw_conv) == OPTYPE.DEPTHWISE_CONV

Converting OPTYPE Back to Class

from torch_pruning.ops import type2class, OPTYPE

cls = type2class(OPTYPE.CONV)
# cls is torch.nn.modules.conv._ConvNd

cls = type2class(OPTYPE.LINEAR)
# cls is torch.nn.Linear
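The two mapping directions are not necessarily inverses of each other. The sketch below uses stand-in classes, so it runs without torch, to show one way classification and reverse lookup could fit together. It rests on two assumptions that this page does not confirm: that a convolution counts as depthwise when `groups` equals `out_channels` (with more than one channel), and that both convolution types map back to the same class. `OpKind`, `FakeConv`, `FakeLinear`, `classify`, and `kind_to_class` are all hypothetical names.

```python
from dataclasses import dataclass
from enum import IntEnum


class OpKind(IntEnum):
    """Hypothetical miniature of OPTYPE with matching integer values."""
    CONV = 0
    LINEAR = 2
    DEPTHWISE_CONV = 4


@dataclass
class FakeConv:
    """Stand-in for a convolution module."""
    in_channels: int
    out_channels: int
    groups: int = 1


@dataclass
class FakeLinear:
    """Stand-in for a linear module."""
    in_features: int
    out_features: int


def classify(module) -> OpKind:
    """Map a module instance to an op kind (assumed depthwise rule)."""
    if isinstance(module, FakeConv):
        # Assumption: one group per output channel means depthwise.
        if module.groups == module.out_channels and module.out_channels > 1:
            return OpKind.DEPTHWISE_CONV
        return OpKind.CONV
    if isinstance(module, FakeLinear):
        return OpKind.LINEAR
    raise TypeError(f"unsupported module: {type(module).__name__}")


# Assumption: both convolution kinds share one class, so the reverse
# mapping is many-to-one.
_KIND_TO_CLASS = {
    OpKind.CONV: FakeConv,
    OpKind.DEPTHWISE_CONV: FakeConv,
    OpKind.LINEAR: FakeLinear,
}


def kind_to_class(kind: OpKind) -> type:
    """Map an op kind back to its module class."""
    return _KIND_TO_CLASS[kind]


dw = FakeConv(64, 64, groups=64)
print(classify(dw).name, kind_to_class(classify(dw)).__name__)
```

Under these assumptions, the class alone cannot tell a depthwise convolution from a regular one. The distinction exists only at the instance level, which is why classification takes a module instance rather than a class.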

Using Type Constants

import torch.nn as nn
from torch_pruning.ops import TORCH_CONV, TORCH_BATCHNORM, TORCH_LINEAR

# Check if a module is a convolution (including all Conv subclasses)
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))
for m in model.modules():
    if isinstance(m, TORCH_CONV):
        print(f"Found convolution: {m}")
    elif isinstance(m, TORCH_BATCHNORM):
        print(f"Found batchnorm: {m}")
