Implementation:VainF Torch Pruning Ops
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Pruning, Graph_Analysis |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool, provided by the Torch-Pruning library, that classifies neural network operations into typed categories and supplies virtual operation modules and dummy pruners for non-standard operations in the dependency graph.
Description
The ops module is the core type system of Torch-Pruning. It provides four main components:
1. Virtual Operation Modules: Lightweight nn.Module subclasses that represent operations detected during autograd graph tracing that are not standard PyTorch modules. These include _ConcatOp (torch.cat), _SplitOp (torch.split), _UnbindOp (torch.unbind), _ReshapeOp (view/reshape), _SliceOp (tensor slicing), _OutputOp (model outputs), _ElementWiseOp (add/sub/mul/div), _ExpandOp (tensor expansion), and _CustomizedOp (user-defined operations).
2. OPTYPE Enum: An IntEnum with 21 operation types that classifies every neural network layer for dependency analysis: CONV, BN, LINEAR, PRELU, DEPTHWISE_CONV, CONCAT, SPLIT, CUSTOMIZED, ELEMENTWISE, LN, EMBED, PARAMETER, MHA, LSTM, RESHAPE, GN, IN, UNBIND, EXPAND, SLICE, OUTPUT.
3. Dummy Pruners: DummyPruner (no-op base) and specialized pruners for virtual ops that handle metadata updates during pruning: ConcatPruner, SplitPruner, SlicePruner, OutputPruner, plus pass-through pruners for Unbind, Expand, Reshape, ElementWise, and Customized operations.
4. Type Mapping Functions: module2type() converts a module instance to its OPTYPE enum value. type2class() converts an OPTYPE back to the corresponding module class. Constants like TORCH_CONV, TORCH_BATCHNORM, TORCH_LINEAR alias standard PyTorch module base classes.
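The sketch below makes the metadata bookkeeping in item 3 concrete. It is a minimal pure-Python sketch of what a concat-style pruner has to recompute when output channels are removed, and it is not the library's code. It assumes, as the ConcatPruner docstring suggests, that a concat op stores one channel count per input (`concat_sizes`) and cumulative boundaries (`offsets`). The helper names `cumulative_offsets` and `update_concat_sizes` are hypothetical.

```python
from typing import List, Tuple


def cumulative_offsets(sizes: List[int]) -> List[int]:
    """Boundaries of each input's channel range: input i owns [offsets[i], offsets[i + 1])."""
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return offsets


def update_concat_sizes(concat_sizes: List[int], idxs: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shrink each concatenated input by the number of
    removed output-channel indices that fall inside its range."""
    offsets = cumulative_offsets(concat_sizes)
    removed = set(idxs)  # duplicate indices count once
    new_sizes = []
    for i, size in enumerate(concat_sizes):
        start, end = offsets[i], offsets[i + 1]
        hits = sum(1 for idx in removed if start <= idx < end)
        new_sizes.append(size - hits)
    # Offsets must be rebuilt from the pruned sizes, not shifted individually.
    return new_sizes, cumulative_offsets(new_sizes)


# torch.cat of a 4-channel and a 6-channel tensor gives 10 output channels.
# Remove channels 1 and 3 (first input) and 5 (second input):
print(update_concat_sizes([4, 6], [1, 3, 5]))  # ([2, 5], [0, 2, 7])
```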
Usage
Import this module when you need to classify PyTorch modules into operation types for dependency graph construction, or when extending Torch-Pruning to handle new operation types. It is used internally by the dependency graph builder (DependencyGraph) and pruning functions, and is also needed when registering custom operations.
Code Reference
Source Location
- Repository: VainF_Torch_Pruning
- File: torch_pruning/ops.py
- Lines: 1-346
Signature
```python
from enum import IntEnum


class OPTYPE(IntEnum):
    CONV = 0
    BN = 1
    LINEAR = 2
    PRELU = 3
    DEPTHWISE_CONV = 4
    CONCAT = 5
    SPLIT = 6
    CUSTOMIZED = 7
    ELEMENTWISE = 8
    LN = 9
    EMBED = 10
    PARAMETER = 11
    MHA = 12
    LSTM = 13
    RESHAPE = 14
    GN = 15
    IN = 16
    UNBIND = 17
    EXPAND = 18
    SLICE = 19
    OUTPUT = 20


def module2type(module) -> OPTYPE:
    """Classify a PyTorch module into its OPTYPE enum value."""
    ...


def type2class(op_type: OPTYPE):
    """Convert an OPTYPE enum value back to its corresponding module class."""
    ...


class DummyPruner(object):
    """No-op base pruner for virtual operations."""
    def prune_out_channels(self, layer, idxs): ...
    def prune_in_channels(self, layer, idxs): ...
    def get_out_channels(self, layer) -> int: ...
    def get_in_channels(self, layer) -> int: ...
    def get_in_channel_groups(self, layer) -> int: ...
    def get_out_channel_groups(self, layer) -> int: ...


class ConcatPruner(DummyPruner):
    """Pruner that updates concat_sizes and offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...


class SplitPruner(DummyPruner):
    """Pruner that updates split_sizes and offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...


class SlicePruner(DummyPruner):
    """Pruner that adjusts slice start/end offsets when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...


class OutputPruner(DummyPruner):
    """Pruner that updates output shape metadata when channels are removed."""
    def prune_out_channels(self, layer, idxs): ...
    def get_in_channels(self, layer) -> int: ...
```
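The SlicePruner docstring says it adjusts slice start/end offsets. The arithmetic behind that is simple: every removed channel index that lies before a boundary shifts that boundary left by one. The sketch below assumes a slice is a half-open channel range `[start, end)`. `adjust_slice` is a hypothetical name, not part of Torch-Pruning.

```python
from typing import Iterable, Tuple


def adjust_slice(start: int, end: int, idxs: Iterable[int]) -> Tuple[int, int]:
    """Hypothetical helper: remap a half-open channel slice [start, end)
    after the channels in `idxs` have been removed."""
    removed = set(idxs)
    # Each removed index strictly below a boundary moves that boundary down by one.
    new_start = start - sum(1 for i in removed if i < start)
    new_end = end - sum(1 for i in removed if i < end)
    return new_start, new_end


# Slice x[:, 4:8] of a 10-channel tensor, then prune channels 0, 5 and 9.
# Channel 0 lies before the slice, 5 inside it, and 9 after it.
print(adjust_slice(4, 8, [0, 5, 9]))  # (3, 6)
```

The slice shrinks from 4 to 3 channels because exactly one removed index (5) fell inside it. Index 9 has no effect.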
Import
```python
from torch_pruning import ops
from torch_pruning.ops import OPTYPE, module2type, type2class, DummyPruner
```
I/O Contract
module2type
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| module | nn.Module | Yes | Any PyTorch module or virtual operation module |
Outputs
| Name | Type | Description |
|---|---|---|
| return | OPTYPE | The enum value classifying the module type |
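module2type has to tell a depthwise convolution apart from an ordinary one, because both are instances of the same conv class. The sketch below reproduces that decision on a stand-in object so it runs without PyTorch. It assumes a convolution is treated as depthwise when `groups == out_channels` and there is more than one output channel. Check `torch_pruning/ops.py` for the exact condition in your version. `ConvSpec`, `OpType` (two values copied from OPTYPE), and `classify_conv` are hypothetical.

```python
from dataclasses import dataclass
from enum import IntEnum


class OpType(IntEnum):
    """Two values copied from OPTYPE for illustration."""
    CONV = 0
    DEPTHWISE_CONV = 4


@dataclass
class ConvSpec:
    """Hypothetical stand-in exposing the nn.Conv2d attributes the rule reads."""
    in_channels: int
    out_channels: int
    groups: int = 1


def classify_conv(conv: ConvSpec) -> OpType:
    # Assumed rule: one group per output channel, with more than one
    # output channel, is classified as a depthwise convolution.
    if conv.groups == conv.out_channels and conv.out_channels > 1:
        return OpType.DEPTHWISE_CONV
    return OpType.CONV


print(classify_conv(ConvSpec(64, 64, groups=64)).name)  # DEPTHWISE_CONV
print(classify_conv(ConvSpec(64, 128, groups=2)).name)  # CONV (grouped, not depthwise)
```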
type2class
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| op_type | OPTYPE | Yes | An operation type enum value |
Outputs
| Name | Type | Description |
|---|---|---|
| return | type | The PyTorch module class corresponding to the op type |
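type2class returns a class meant for isinstance checks (for example the conv base class for OPTYPE.CONV). The mapping is therefore many-to-one rather than a strict inverse of module2type: concrete subclasses such as Conv1d and Conv2d collapse onto a shared base. The pure-Python sketch below illustrates that round-trip property. `ConvBase`, `Conv1d`, `Conv2d`, `Linear`, `to_type`, and `TYPE_TO_CLASS` are hypothetical stand-ins, not Torch-Pruning or PyTorch objects.

```python
from enum import IntEnum


class OpType(IntEnum):
    """Two values copied from OPTYPE for illustration."""
    CONV = 0
    LINEAR = 2


# Hypothetical stand-ins for _ConvNd, two concrete subclasses, and Linear.
class ConvBase: ...
class Conv1d(ConvBase): ...
class Conv2d(ConvBase): ...
class Linear: ...


def to_type(module) -> OpType:
    """Hypothetical many-to-one classifier in the spirit of module2type."""
    if isinstance(module, ConvBase):
        return OpType.CONV
    if isinstance(module, Linear):
        return OpType.LINEAR
    raise TypeError(f"unclassified module: {type(module).__name__}")


# Hypothetical inverse table in the spirit of type2class: one class per type.
TYPE_TO_CLASS = {OpType.CONV: ConvBase, OpType.LINEAR: Linear}

for m in (Conv1d(), Conv2d(), Linear()):
    cls = TYPE_TO_CLASS[to_type(m)]
    # The round trip yields a base class, so isinstance holds
    # even when cls differs from type(m).
    print(type(m).__name__, "->", cls.__name__, isinstance(m, cls))
```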
DummyPruner.prune_out_channels / prune_in_channels
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| layer | nn.Module | Yes | The virtual operation module to prune |
| idxs | list[int] | Yes | Channel indices to remove |
Outputs
| Name | Type | Description |
|---|---|---|
| return | nn.Module or None | The pruned layer; specialized pruners update its metadata in place, while the no-op DummyPruner leaves it unchanged |
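The contract in the tables above can be sketched without PyTorch: a pruner receives a layer plus channel indices and hands the same object back. `VirtualOp`, `NoOpPruner`, and `CountingPruner` are hypothetical, and neither pruner is the library's implementation. `NoOpPruner` mirrors the documented no-op DummyPruner. `CountingPruner` shows what updating metadata in place means for a subclass.

```python
from typing import List


class VirtualOp:
    """Hypothetical virtual op: carries metadata only, no weights."""

    def __init__(self, channels: int):
        self.channels = channels


class NoOpPruner:
    """Mirrors the documented DummyPruner behaviour: nothing to prune."""

    def prune_out_channels(self, layer, idxs: List[int]):
        return layer  # same object, untouched

    def prune_in_channels(self, layer, idxs: List[int]):
        return layer


class CountingPruner(NoOpPruner):
    """Hypothetical subclass that updates the layer's metadata in place."""

    def prune_out_channels(self, layer, idxs: List[int]):
        layer.channels -= len(set(idxs))  # duplicate indices count once
        return layer


op = VirtualOp(channels=10)
assert NoOpPruner().prune_out_channels(op, [1, 2]) is op and op.channels == 10
assert CountingPruner().prune_out_channels(op, [1, 2, 2]) is op
print(op.channels)  # 8
```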
Usage Examples
Classifying a Module
```python
import torch.nn as nn
from torch_pruning.ops import module2type, OPTYPE

# Standard modules
conv = nn.Conv2d(64, 128, 3)
assert module2type(conv) == OPTYPE.CONV

bn = nn.BatchNorm2d(128)
assert module2type(bn) == OPTYPE.BN

linear = nn.Linear(512, 10)
assert module2type(linear) == OPTYPE.LINEAR

# Depthwise convolution is auto-detected
dw_conv = nn.Conv2d(64, 64, 3, groups=64)
assert module2type(dw_conv) == OPTYPE.DEPTHWISE_CONV
```
Converting OPTYPE Back to Class
```python
from torch_pruning.ops import type2class, OPTYPE

cls = type2class(OPTYPE.CONV)
# cls is torch.nn.modules.conv._ConvNd

cls = type2class(OPTYPE.LINEAR)
# cls is torch.nn.Linear
```
Using Type Constants
```python
import torch.nn as nn
from torch_pruning.ops import TORCH_CONV, TORCH_BATCHNORM, TORCH_LINEAR

# Check whether each module is a convolution (any Conv subclass) or a batchnorm
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))
for m in model.modules():
    if isinstance(m, TORCH_CONV):
        print(f"Found convolution: {m}")
    elif isinstance(m, TORCH_BATCHNORM):
        print(f"Found batchnorm: {m}")
```