# Heuristic: VainF Torch-Pruning AutoGrad Dependency Graph
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Debugging |
| Last Updated | 2026-02-08 12:00 GMT |
## Overview
AutoGrad must be enabled (no torch.no_grad()) when building the DependencyGraph, because Torch-Pruning uses PyTorch's autograd engine to trace the computation graph.
## Description
Torch-Pruning traces the model's forward pass using PyTorch's autograd mechanism to discover inter-layer structural dependencies. If autograd is disabled (e.g., inside a torch.no_grad() context manager or with torch.set_grad_enabled(False)), the tracing will fail because the computational graph is not recorded.
This is a common pitfall for users who habitually wrap inference code in torch.no_grad() for efficiency and forget to remove it when initializing the pruner.
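The pitfall can be reproduced with plain PyTorch (no Torch-Pruning required): under torch.no_grad() the forward pass records no autograd graph, so the output carries no grad_fn for a tracer to inspect.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
x = torch.randn(1, 3, 32, 32)

# Inside torch.no_grad() no autograd graph is recorded: the output has no
# grad_fn, so there is nothing for a dependency tracer to walk.
with torch.no_grad():
    out_no_grad = model(x)
print(out_no_grad.grad_fn)  # None

# With autograd enabled (the default), every operation is recorded.
out_traced = model(x)
print(out_traced.grad_fn is not None)  # True
```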
## Usage
Use this heuristic whenever you encounter errors during DependencyGraph.build_dependency() or BasePruner.__init__(). If the error occurs during the forward pass trace, the most likely cause is that autograd is disabled.
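One way to apply the heuristic preemptively is a fail-fast guard before building the graph. The helper below (check_autograd_enabled is hypothetical, not part of Torch-Pruning) turns the obscure trace failure into an explicit error:

```python
import torch

def check_autograd_enabled():
    # Hypothetical fail-fast guard (not part of Torch-Pruning): raise a
    # clear error instead of letting the dependency trace fail obscurely.
    if not torch.is_grad_enabled():
        raise RuntimeError(
            "Autograd is disabled (torch.no_grad()/torch.inference_mode() "
            "is active). Enable it before DependencyGraph.build_dependency()."
        )

check_autograd_enabled()  # passes silently when autograd is on
with torch.no_grad():
    try:
        check_autograd_enabled()
    except RuntimeError as e:
        print("caught:", e)
```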
## The Insight (Rule of Thumb)
- Action: Ensure that torch.no_grad() and torch.inference_mode() are NOT active when building the dependency graph or initializing a pruner.
- Value: N/A (boolean requirement).
- Trade-off: Slightly higher memory usage during graph building due to gradient tracking, but this is a one-time cost that is negligible compared to the model size.
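If the surrounding code must stay under torch.no_grad(), autograd can be switched back on locally with torch.enable_grad() for just the one forward pass that builds the graph (torch-only sketch):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

# Even inside a torch.no_grad() region, autograd can be re-enabled locally
# for the single forward pass that builds the dependency graph.
with torch.no_grad():
    with torch.enable_grad():
        out = model(x)

print(out.grad_fn is not None)  # True: the graph was recorded
```

Note that torch.enable_grad() undoes torch.no_grad() but cannot escape torch.inference_mode(); if inference mode is active, it has to be removed entirely.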
## Reasoning
Torch-Pruning's DependencyGraph.build_dependency() passes example_inputs through the model and inspects the resulting autograd graph to discover which layers are connected. Each node in the autograd graph corresponds to an operation (convolution, linear, batch norm, etc.), and the edges reveal the data flow between layers. Without autograd, no graph is recorded, and the dependency analysis cannot proceed.
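The kind of inspection described above can be illustrated by walking the grad_fn chain backward from the output: each node is one recorded operation, and next_functions are the edges a tracer can follow. This is a simplified sketch; Torch-Pruning's actual traversal is more involved.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
out = model(torch.randn(1, 3, 16, 16))

# Walk backward from the output: each grad_fn node is one recorded op
# (ReLU -> BatchNorm -> Conv), and next_functions are the edges between them.
node, chain = out.grad_fn, []
while node is not None:
    chain.append(type(node).__name__)
    node = node.next_functions[0][0] if node.next_functions else None
print(chain)
```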
This is fundamentally different from model tracing with torch.jit.trace(), which builds an execution graph even without gradients. Torch-Pruning specifically relies on the autograd graph because it provides richer structural information about parameter dependencies.
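The contrast can be seen directly: torch.jit.trace succeeds under torch.no_grad(), while the autograd graph for the same forward pass is empty (torch-only sketch):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).eval()
x = torch.randn(1, 4)

# torch.jit.trace builds an execution graph even with autograd disabled...
with torch.no_grad():
    traced = torch.jit.trace(model, x)
    out = model(x)

print(len(list(traced.graph.nodes())) > 0)  # tracing succeeded
print(out.grad_fn)  # None: autograd recorded nothing to analyze
```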
## Code Evidence
From README.md (line 95):
> Please make sure that AutoGrad is enabled since TP will analyze the model
> structure with the Pytorch AutoGrad. This means we need to remove torch.no_grad()
> or something similar when building the dependency graph.
Flexible input handling with autograd from torch_pruning/dependency/graph.py:472-475:
```python
try:
    out = model(*example_inputs)
except:
    out = model(example_inputs)
```
Note that count_ops_and_params uses @torch.no_grad() (at torch_pruning/utils/op_counter.py:21), which is fine because FLOPs counting does not need the dependency graph. But this decorator must NOT be present during graph building.
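The decorator's effect is easy to reproduce: any forward pass inside a function decorated with @torch.no_grad() leaves no autograd graph behind (count_something below is a hypothetical stand-in for the FLOPs counter, not the real op_counter code):

```python
import torch
import torch.nn as nn

# A @torch.no_grad()-decorated function (like the FLOPs counter) disables
# autograd for its whole body, so any forward pass inside it is untraceable.
@torch.no_grad()
def count_something(model, x):  # hypothetical stand-in for count_ops_and_params
    return model(x)

out = count_something(nn.Linear(4, 2), torch.randn(1, 4))
print(out.grad_fn)  # None: no graph was recorded inside the decorated call
```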