Implementation:VainF Torch Pruning GroupTaylorImportance

From Leeroopedia


Overview

GroupTaylorImportance is the Torch-Pruning implementation of first-order Taylor expansion importance estimation. It computes a per-channel importance score from the product of each weight and its gradient, enabling data-driven structural pruning decisions that account for how each channel affects the loss on real data.

Description

GroupTaylorImportance extends GroupMagnitudeImportance, inheriting its group reduction, normalization, and layer-type filtering capabilities. The key difference is in how per-channel importance is computed: instead of using weight magnitudes alone, it takes the absolute value of the element-wise product of the weight tensor w and its gradient dw (i.e., |w * dw|), where dw = ∂L/∂w is the gradient of the loss L with respect to w, populated by loss.backward().
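For context, the product of a weight and its gradient can be read as a first-order Taylor estimate of how much the loss changes when that weight is set to zero. A short derivation, assuming higher-order terms are negligible (the symbols \mathcal{L} and w_i are notation chosen here for illustration, not taken from the library):

```latex
\begin{aligned}
\mathcal{L}(w_i = 0) &\approx \mathcal{L}(w) + \frac{\partial \mathcal{L}}{\partial w_i}\,(0 - w_i) \\
\left|\mathcal{L}(w_i = 0) - \mathcal{L}(w)\right| &\approx \left| w_i \,\frac{\partial \mathcal{L}}{\partial w_i} \right|
\end{aligned}
```

A weight whose product with its gradient is close to zero is therefore estimated to have little effect on the loss when removed.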

The class supports two aggregation modes controlled by the multivariable parameter:

  • Standard mode (multivariable=False, default): Computes (w * dw).abs().sum(1) -- takes the absolute value of each element-wise product first, then sums over the flattened weights belonging to each channel (dimension 1). This is the abs-then-sum variant.
  • Multivariable mode (multivariable=True): Computes (w * dw).sum(1).abs() -- sums the element-wise products first, then takes the absolute value. This is the sum-then-abs variant, which allows positive and negative contributions to cancel.
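The difference between the two modes is easiest to see on a toy tensor. The following is a minimal sketch in plain PyTorch, not the library's code; the weight and gradient values are made up for illustration.

```python
import torch

# Toy "layer": 2 channels with 3 weights each (already flattened).
w = torch.tensor([[1.0, -2.0, 3.0],
                  [0.5, 0.5, 0.5]])
# Stand-in gradients; in practice these come from loss.backward().
dw = torch.tensor([[1.0, 1.0, 1.0],
                   [2.0, -2.0, 0.0]])

prod = w * dw  # element-wise Taylor terms: [[1, -2, 3], [1, -1, 0]]

# multivariable=False: abs-then-sum, no cancellation within a channel
standard = prod.abs().sum(1)

# multivariable=True: sum-then-abs, opposite-signed terms can cancel
multivariable = prod.sum(1).abs()

print(standard)       # tensor([6., 2.])
print(multivariable)  # tensor([2., 0.])
```

In this example the second channel scores 2 in standard mode but 0 in multivariable mode, because its two non-zero terms have opposite signs.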

The scorer handles several layer types:

  • Conv/Linear output channels: Indexes the weight tensor by output channel, flattens, and computes the Taylor product.
  • Conv/Linear input channels: Transposes the weight tensor to index by input channel, then computes the Taylor product. Handles group convolutions by repeating importance values.
  • BatchNorm: Uses the affine scale parameter and its gradient when layer.affine is True.
  • LayerNorm: Uses the elementwise affine parameter and its gradient when layer.elementwise_affine is True.
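The per-layer cases above can be sketched roughly as follows. This is a simplified illustration in plain PyTorch under stated assumptions: the helper names are hypothetical, only the abs-then-sum mode is shown, and grouped convolutions, transposed convolutions, and bias terms are ignored.

```python
import torch
import torch.nn as nn

def taylor_out_channels(layer, idxs):
    """Hypothetical helper: scores for selected output channels of Conv/Linear."""
    w = layer.weight.data[idxs].flatten(1)
    dw = layer.weight.grad.data[idxs].flatten(1)
    return (w * dw).abs().sum(1)

def taylor_in_channels(layer, idxs):
    """Hypothetical helper: transpose so that dim 0 indexes input channels."""
    w = layer.weight.data.transpose(0, 1)[idxs].flatten(1)
    dw = layer.weight.grad.data.transpose(0, 1)[idxs].flatten(1)
    return (w * dw).abs().sum(1)

def taylor_batchnorm(bn, idxs):
    """Hypothetical helper: affine BatchNorm has one scale parameter per channel."""
    w = bn.weight.data[idxs]
    dw = bn.weight.grad.data[idxs]
    return (w * dw).abs()

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 4, 3)
)
# Placeholder loss, used only to populate .grad on the parameters.
model(torch.randn(2, 3, 16, 16)).sum().backward()

idxs = [0, 2, 5]  # the same coupled channels seen from three layers
out_imp = taylor_out_channels(model[0], idxs)  # conv1 output channels
bn_imp = taylor_batchnorm(model[1], idxs)      # BatchNorm channels
in_imp = taylor_in_channels(model[3], idxs)    # conv2 input channels
```

Each call returns one score per selected channel, so the three results have the same length and can be combined across the group.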

After computing local importance for each layer in the dependency group, the scores are reduced across the group (via _reduce) and normalized (via _normalize), both inherited from GroupMagnitudeImportance.
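As a rough picture of the reduce-then-normalize step, here is a minimal sketch in plain PyTorch. It assumes that "mean" reduction averages the per-layer scores and that "mean" normalization divides by the average score; it stands in for _reduce and _normalize rather than reproducing them, and the score values are made up.

```python
import torch

# Local scores for the same 4 coupled channels, from three layers in one group.
local_imps = [
    torch.tensor([0.2, 1.0, 0.4, 0.4]),
    torch.tensor([0.6, 3.0, 0.0, 0.4]),
    torch.tensor([0.1, 2.0, 0.2, 0.1]),
]

stacked = torch.stack(local_imps)           # shape: (num_layers, num_channels)
group_imp = stacked.mean(dim=0)             # stand-in for group_reduction="mean"
normalized = group_imp / group_imp.mean()   # stand-in for normalizer="mean"

# Channels with the lowest normalized score are the first pruning candidates.
ranking = torch.argsort(normalized)
```

After mean normalization the scores average to 1, which keeps groups of different sizes and scales comparable when they are ranked together.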

PREREQUISITE: loss.backward() must be called before invoking this scorer. Without gradients populated on the model parameters, the computation will fail.
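One way to catch a missing backward pass early is to check for parameters whose .grad is still None before scoring. The helper below is a hypothetical sketch in plain PyTorch, not part of Torch-Pruning.

```python
import torch
import torch.nn as nn

def params_missing_grad(model: nn.Module):
    """Hypothetical helper: names of trainable parameters whose .grad is None."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

before = params_missing_grad(model)    # every parameter: no backward pass yet
loss = model(torch.randn(3, 4)).sum()  # placeholder loss
loss.backward()
after = params_missing_grad(model)     # empty once gradients are populated
```

Running such a check before calling the scorer turns a late attribute error into an explicit, readable message.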

Usage

Use GroupTaylorImportance when:

  • Calibration data is available and you can afford a forward and backward pass to compute gradients.
  • You want importance scores that reflect each channel's estimated effect on the loss, rather than weight magnitude alone.
  • You are performing structural pruning using the Torch-Pruning dependency graph framework.

Code Reference

Source file: torch_pruning/pruner/importance.py, Lines 411--542

Class signature:

class GroupTaylorImportance(GroupMagnitudeImportance):
    def __init__(self,
                 group_reduction: str = "mean",
                 normalizer: str = 'mean',
                 multivariable: bool = False,
                 bias: bool = False,
                 target_types: list = [nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm, nn.modules.LayerNorm]):

Import paths:

from torch_pruning.pruner.importance import GroupTaylorImportance

or equivalently:

import torch_pruning as tp
tp.importance.GroupTaylorImportance

I/O Contract

Constructor Parameters

Parameter | Type | Default | Description
group_reduction | str | "mean" | Reduction method for combining importance scores across layers in a group. Options include "mean", "sum", "max", "prod", "first".
normalizer | str | "mean" | Normalization method applied to the final importance vector. Options include "mean", "sum", "max", "standarization", "gaussian", "lamp", or None.
multivariable | bool | False | When False, uses abs-then-sum (standard). When True, uses sum-then-abs (multivariable).
bias | bool | False | Whether to include bias parameters in the importance computation.
target_types | list | [_ConvNd, Linear, _BatchNorm, LayerNorm] | Layer types to consider when computing importance. Layers not matching these types are skipped.

__call__ Method

Parameter | Type | Required | Description
group | Group | Yes | A dependency group from the Torch-Pruning DependencyGraph. Contains tuples of (dependency, indices) representing coupled pruning operations.

Returns: A 1-D torch.Tensor of importance scores with length equal to the number of channels in the group, or None if no parameterized layers of the target types are found in the group.

PREREQUISITE: loss.backward() must be called before invoking this scorer so that .grad attributes are populated on model parameters.

Usage Examples

Basic usage with dependency graph:

import torch
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# Build dependency graph
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# Forward and backward pass to populate gradients
# (random inputs and labels are placeholders; use real calibration data in practice)
inputs, labels = example_inputs, torch.randint(0, 1000, (1,))
loss = F.cross_entropy(model(inputs), labels)
loss.backward()

# Create scorer and compute importance for a group
scorer = tp.importance.GroupTaylorImportance()
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])
imp_score = scorer(group)
# imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]
min_score = imp_score.min()

Using with a pruner for iterative pruning:

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# Set up Taylor importance scorer
imp = tp.importance.GroupTaylorImportance()

# Identify layers to ignore (e.g., final classifier)
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)

# Configure pruner
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=5,
    pruning_ratio=0.5,
    ignored_layers=ignored_layers,
)

# Iterative pruning loop
for step in range(5):
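    # Compute gradients on calibration data
    model.zero_grad()  # clear gradients left over from the previous step
    loss = model(example_inputs).sum()  # dummy loss; replace with a real loss on calibration data
    loss.backward()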

    # Execute one pruning step (importance is computed internally)
    pruner.step()

    # Fine-tune model here...
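
In practice, gradients are often accumulated over several calibration batches before scoring, since repeated loss.backward() calls add into .grad. The following is a minimal sketch in plain PyTorch; the small model and the random calibration batches are placeholders, not part of the original example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 3))

# Hypothetical calibration set: 4 random batches of (inputs, labels).
calib_batches = [
    (torch.randn(8, 10), torch.randint(0, 3, (8,))) for _ in range(4)
]

model.zero_grad()  # discard stale gradients first
for inputs, labels in calib_batches:
    loss = F.cross_entropy(model(inputs), labels)
    loss.backward()  # .grad accumulates across batches

# Each parameter's .grad now holds the sum over all calibration batches,
# which is the state a gradient-based scorer would read.
grads_ready = all(p.grad is not None for p in model.parameters())
```

Using more than one batch generally gives a less noisy gradient estimate than a single forward and backward pass.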

Using multivariable mode:

# Multivariable variant: sum-then-abs instead of abs-then-sum
scorer = tp.importance.GroupTaylorImportance(multivariable=True)
