Implementation: VainF Torch-Pruning GroupTaylorImportance
Overview
GroupTaylorImportance is the concrete tool for first-order Taylor expansion importance estimation provided by the Torch-Pruning library. It computes a per-channel importance score by combining each parameter's magnitude with its gradient, enabling data-driven structural pruning decisions that account for how each channel affects the loss on real data.
Description
GroupTaylorImportance extends GroupMagnitudeImportance, inheriting its group reduction, normalization, and layer-type filtering capabilities. The key difference is in how per-channel importance is computed: instead of using weight magnitudes alone, it computes the element-wise product of the weight tensor w and its gradient dw (i.e., |w * dw|), where dw is obtained from loss.backward().
The class supports two aggregation modes controlled by the multivariable parameter:
- Standard mode (multivariable=False, default): computes (w * dw).abs().sum(1) -- takes the absolute value of each element-wise product first, then sums across the channel dimension. This is the abs-then-sum variant.
- Multivariable mode (multivariable=True): computes (w * dw).sum(1).abs() -- sums the element-wise products first, then takes the absolute value. This is the sum-then-abs variant, which allows positive and negative contributions to cancel.
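The difference between the two modes can be seen on a toy example. The sketch below uses hypothetical weight and gradient values (not taken from the library) to show how cancellation affects the multivariable score:

```python
import torch

# Hypothetical weight slice and gradient for 3 channels x 4 elements
w = torch.tensor([[1.0, -2.0, 0.5, -0.5],
                  [0.1,  0.1, 0.1,  0.1],
                  [2.0, -2.0, 1.0, -1.0]])
dw = torch.tensor([[0.5, 0.5, -1.0, 1.0],
                   [1.0, 1.0,  1.0, 1.0],
                   [0.5, 0.5,  0.5, 0.5]])

# Standard mode (multivariable=False): abs-then-sum
standard = (w * dw).abs().sum(1)        # tensor([2.5, 0.4, 3.0])

# Multivariable mode (multivariable=True): sum-then-abs
multivariable = (w * dw).sum(1).abs()   # tensor([1.5, 0.4, 0.0])
```

Note how the third channel's score drops to zero in multivariable mode: its positive and negative Taylor terms cancel exactly, whereas abs-then-sum treats every term as contributing importance.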
The scorer handles several layer types:
- Conv/Linear output channels: Indexes the weight tensor by output channel, flattens, and computes the Taylor product.
- Conv/Linear input channels: Transposes the weight tensor to index by input channel, then computes the Taylor product. Handles group convolutions by repeating importance values.
- BatchNorm: Uses the affine scale parameter and its gradient when layer.affine is True.
- LayerNorm: Uses the elementwise affine parameter and its gradient when layer.elementwise_affine is True.
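The per-layer computation described above can be sketched in plain PyTorch. This mirrors the documented indexing scheme (index by out channel, or transpose to index by in channel) rather than reproducing the library's exact code; the layer and index values are illustrative:

```python
import torch
import torch.nn as nn

# Hypothetical conv layer; run a backward pass so .grad is populated
conv = nn.Conv2d(3, 8, kernel_size=3)
conv(torch.randn(1, 3, 8, 8)).sum().backward()

idxs = [2, 6]  # output channels under consideration

# Output-channel score: index by out channel, flatten, abs-then-sum Taylor product
w = conv.weight.data[idxs].flatten(1)
dw = conv.weight.grad.data[idxs].flatten(1)
out_score = (w * dw).abs().sum(1)   # shape: (len(idxs),)

# Input-channel score: transpose so dim 0 indexes input channels, then proceed as above
in_idxs = [0, 2]
w_t = conv.weight.data.transpose(0, 1)[in_idxs].flatten(1)
dw_t = conv.weight.grad.data.transpose(0, 1)[in_idxs].flatten(1)
in_score = (w_t * dw_t).abs().sum(1)
```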
After computing local importance for each layer in the dependency group, the scores are reduced across the group (via _reduce) and normalized (via _normalize), both inherited from GroupMagnitudeImportance.
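With the default settings (group_reduction="mean", normalizer="mean"), the reduce-then-normalize step amounts to the following sketch, using made-up per-layer scores for a three-channel group:

```python
import torch

# Hypothetical local scores from two coupled layers in one group
layer_scores = [torch.tensor([3.0, 1.0, 2.0]),
                torch.tensor([1.0, 1.0, 4.0])]

# group_reduction="mean": element-wise average across layers in the group
reduced = torch.stack(layer_scores).mean(0)   # tensor([2., 1., 3.])

# normalizer="mean": divide by the mean so scores are comparable across groups
normalized = reduced / reduced.mean()         # tensor([1.0, 0.5, 1.5])
```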
PREREQUISITE: loss.backward() must be called before invoking this scorer. Without gradients populated on the model parameters, the computation will fail.
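A quick way to see why the backward pass is required: a freshly constructed module has no .grad attributes until a loss is backpropagated. The toy model and dummy data below are illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Before backward: no gradients, so w * dw cannot be formed
assert model.weight.grad is None

# A forward/backward pass on (dummy) calibration data populates .grad
loss = model(torch.randn(8, 4)).sum()
loss.backward()
assert model.weight.grad is not None
```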
Usage
Use GroupTaylorImportance when:
- Calibration data is available and you can afford a forward and backward pass to compute gradients.
- Gradient-informed importance scores are desired for more accurate pruning than magnitude-only methods.
- You are performing structural pruning using the Torch-Pruning dependency graph framework.
Code Reference
Source file: torch_pruning/pruner/importance.py, Lines 411--542
Class signature:
class GroupTaylorImportance(GroupMagnitudeImportance):
def __init__(self,
group_reduction: str = "mean",
normalizer: str = 'mean',
multivariable: bool = False,
bias: bool = False,
target_types: list = [nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm, nn.modules.LayerNorm]):
Import paths:
from torch_pruning.pruner.importance import GroupTaylorImportance
or equivalently:
import torch_pruning as tp
tp.importance.GroupTaylorImportance
I/O Contract
Constructor Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| group_reduction | str | "mean" | Reduction method for combining importance scores across layers in a group. Options include "mean", "sum", "max", "prod", "first". |
| normalizer | str | "mean" | Normalization method applied to the final importance vector. Options include "mean", "sum", "max", "standarization", "gaussian", "lamp", or None. |
| multivariable | bool | False | When False, uses abs-then-sum (standard). When True, uses sum-then-abs (multivariable). |
| bias | bool | False | Whether to include bias parameters in the importance computation. |
| target_types | list | [_ConvNd, Linear, _BatchNorm, LayerNorm] | Layer types to consider when computing importance. Layers not matching these types are skipped. |
__call__ Method
| Parameter | Type | Required | Description |
|---|---|---|---|
| group | Group | Yes | A dependency group from the Torch-Pruning DependencyGraph. Contains tuples of (dependency, indices) representing coupled pruning operations. |
Returns: A 1-D torch.Tensor of importance scores with length equal to the number of channels in the group, or None if no parameterized layers of the target types are found in the group.
PREREQUISITE: loss.backward() must be called before invoking this scorer so that .grad attributes are populated on model parameters.
Usage Examples
Basic usage with dependency graph:
import torch
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.models import resnet18
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
# Build dependency graph
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
# Forward and backward pass to accumulate gradients
inputs, labels = example_inputs, torch.randint(0, 1000, (1,))
loss = F.cross_entropy(model(inputs), labels)
loss.backward()
# Create scorer and compute importance for a group
scorer = tp.importance.GroupTaylorImportance()
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])
imp_score = scorer(group)
# imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]
min_score = imp_score.min()
Using with a pruner for iterative pruning:
import torch
import torch_pruning as tp
from torchvision.models import resnet18
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
# Set up Taylor importance scorer
imp = tp.importance.GroupTaylorImportance()
# Identify layers to ignore (e.g., final classifier)
ignored_layers = []
for m in model.modules():
if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
ignored_layers.append(m)
# Configure pruner
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=imp,
iterative_steps=5,
pruning_ratio=0.5,
ignored_layers=ignored_layers,
)
# Iterative pruning loop
for step in range(5):
    # Clear stale gradients, then compute fresh ones on calibration data
    model.zero_grad()
    loss = model(example_inputs).sum()  # replace with a real loss
    loss.backward()
    # Execute one pruning step (importance is computed internally)
    pruner.step()
    # Fine-tune the model here...
Using multivariable mode:
# Multivariable variant: sum-then-abs instead of abs-then-sum
scorer = tp.importance.GroupTaylorImportance(multivariable=True)