
Implementation:VainF Torch Pruning GrowingRegPruner

From Leeroopedia


Metadata

Field Value
Sources Torch-Pruning, Growing Regularization
Domains Deep_Learning, Regularization, Pruning
Last Updated 2026-02-08 00:00 GMT

Overview

GrowingRegPruner is Torch-Pruning's concrete tool for growing-regularization-based pruning. It extends BasePruner with progressively increasing per-channel regularization, implementing the strategy described in Neural Pruning via Growing Regularization (Wang et al., 2021). It maintains per-group regularization state and exposes three key methods for integration into a training loop: update_regularizer(), update_reg(), and regularize().

Description

GrowingRegPruner extends BasePruner with progressively increasing regularization. It maintains per-group regularization state (self.group_reg), where each group corresponds to a set of coupled layers discovered by the DependencyGraph.

The workflow for using GrowingRegPruner involves three methods:

  • update_regularizer() -- Refreshes the internal group list after a pruning step. Resets all per-group regularization coefficients to the base value. Must be called after each call to pruner.step().
  • update_reg() -- Increments the per-channel regularization penalties by delta_reg * standardized_importance. Should be called once per epoch (or at the desired regularization update frequency).
  • regularize(model, bias=False) -- Applies the accumulated regularization to the model's gradients in-place. Should be called every training iteration after loss.backward() and before optimizer.step(). Modifies weight.grad.data directly for BatchNorm, Conv, and Linear layers.

The regularization targets both output and input channels of convolution and linear layers, as well as BatchNorm affine parameters, ensuring that all coupled parameters in a dependency group are regularized consistently.
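As a plain-Python illustration of the gradient update just described (no torch; `apply_growing_reg` and `gamma` are hypothetical names, not part of the Torch-Pruning API), the in-place change that regularize() performs amounts to adding gamma[c] * weight[c] to the gradient of each output channel c:

```python
# Toy sketch of the update described above: for each output channel c,
# grad[c] += gamma[c] * weight[c]. Plain lists stand in for weight.grad.data;
# names here are illustrative, not the library's.

def apply_growing_reg(weight, grad, gamma):
    """Add gamma[c] * weight[c] to the gradient of each output channel c."""
    for c in range(len(weight)):
        grad[c] = [g + gamma[c] * w for g, w in zip(grad[c], weight[c])]
    return grad

weight = [[1.0, -2.0], [0.5, 0.5]]  # two output channels, two params each
grad = [[0.0, 0.0], [0.0, 0.0]]     # gradients after loss.backward()
gamma = [0.1, 0.0]                  # accumulated per-channel coefficients
grad = apply_growing_reg(weight, grad, gamma)
# channel 0 is pulled toward zero; channel 1 (gamma == 0) is untouched
```

In the actual pruner this update is applied consistently across every coupled parameter in a dependency group, as the paragraph above notes.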

Code Reference

Field Value
Source File torch_pruning/pruner/algorithms/growing_reg_pruner.py, Lines 9-182
Parent Class BasePruner (from torch_pruning/pruner/algorithms/base_pruner.py)
Import import torch_pruning as tp; tp.pruner.GrowingRegPruner

Constructor Signature:

class GrowingRegPruner(BasePruner):
    def __init__(
        self,
        model: nn.Module,
        example_inputs: torch.Tensor,
        importance: typing.Callable,
        reg=1e-5,
        delta_reg=1e-5,
        global_pruning: bool = False,
        pruning_ratio: float = 0.5,
        pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        max_pruning_ratio: float = 1.0,
        iterative_steps: int = 1,
        iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler,
        ignored_layers: typing.List[nn.Module] = None,
        round_to: int = None,
        isomorphic: bool = False,
        in_channel_groups: typing.Dict[nn.Module, int] = dict(),
        out_channel_groups: typing.Dict[nn.Module, int] = dict(),
        num_heads: typing.Dict[nn.Module, int] = dict(),
        prune_num_heads: bool = False,
        prune_head_dims: bool = True,
        head_pruning_ratio: float = 0.0,
        head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None,
        unwrapped_parameters: typing.Dict[nn.Parameter, int] = None,
        root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],
        forward_fn: typing.Callable = None,
        output_transform: typing.Callable = None,
        channel_groups: typing.Dict[nn.Module, int] = dict(),
        ch_sparsity: float = None,
        ch_sparsity_dict: typing.Dict[nn.Module, float] = None,
    ):

Key Methods:

def update_regularizer(self):
    """Refresh group list and reset per-group regularization after pruning."""

def update_reg(self):
    """Increment per-channel regularization by delta_reg * standardized_importance."""

def regularize(self, model, bias=False):
    """Apply accumulated regularization to weight gradients in-place."""

I/O Contract

Parameter Type Default Description
reg float 1e-5 Base regularization coefficient. All groups are initialized with this value.
delta_reg float 1e-5 Increment added to regularization each epoch, scaled by standardized importance.
bias bool False If True, regularize() also penalizes bias parameters.

Behavior:

  • update_reg() increments self.group_reg by delta_reg * standardized_importance, where standardization maps raw importance scores to [0, 1] so that lower-importance channels receive higher values and therefore accumulate penalties faster.
  • regularize() modifies weight.grad.data in-place by adding gamma * weight for each channel, where gamma is the accumulated per-channel regularization coefficient.
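A minimal plain-Python sketch of the two behaviors above (function names are illustrative assumptions, not the library's internals):

```python
# Sketch of the update_reg() behavior: raw importance is mapped to [0, 1]
# with LOW-importance channels receiving HIGH values, so their penalties
# grow fastest each epoch. All names here are hypothetical.

def standardize(scores):
    """Invert and rescale raw importance scores into [0, 1]."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(hi - s) / (hi - lo) for s in scores]

def grow_reg(group_reg, scores, delta_reg=1e-5):
    """Increment per-channel coefficients by delta_reg * standardized importance."""
    return [g + delta_reg * s for g, s in zip(group_reg, standardize(scores))]

reg = [1e-5] * 4                # every group starts at the base value `reg`
scores = [0.9, 0.1, 0.5, 0.9]   # raw channel importance
reg = grow_reg(reg, scores)
# the least important channel (index 1) now carries the largest coefficient
```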

Usage Examples

Training loop with growing regularization:

import torch
import torch.nn as nn
import torch_pruning as tp

# 1. Setup model and pruner
model = YourModel()  # placeholder: substitute any nn.Module
example_inputs = torch.randn(1, 3, 224, 224)  # dummy input for dependency tracing
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.GrowingRegPruner(
    model=model,
    example_inputs=example_inputs,
    importance=imp,
    reg=1e-5,
    delta_reg=1e-5,
    pruning_ratio=0.5,
    iterative_steps=3,
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 2. Regularization phase: train with growing regularization
for epoch in range(num_reg_epochs):  # num_reg_epochs: placeholder for your regularization budget
    pruner.update_reg()  # increment per-channel penalties each epoch
    for images, labels in train_loader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        pruner.regularize(model)  # apply regularization to gradients
        optimizer.step()

# 3. Pruning step: remove low-importance channels
pruner.step()

# 4. Reset regularizer for next round (if iterative pruning)
pruner.update_regularizer()
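With iterative_steps=3, the regularize-then-prune cycle above is repeated once per step, each step pruning toward a larger cumulative target. The sketch below reconstructs what a linear schedule would produce; the formula is an assumption about the default linear_scheduler, not verified library code:

```python
# Hypothetical reconstruction of a linear pruning schedule: with
# pruning_ratio=0.5 and iterative_steps=3, each pruner.step() prunes up
# to a growing cumulative target, reaching 0.5 at the final step.

def linear_schedule(final_ratio, steps):
    """Cumulative pruning targets after each of `steps` pruning steps."""
    return [final_ratio * i / steps for i in range(1, steps + 1)]

targets = linear_schedule(0.5, 3)
# targets climb linearly and end exactly at the requested final ratio
```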
