Implementation: VainF Torch Pruning GrowingRegPruner
Metadata
| Field | Value |
|---|---|
| Sources | Torch-Pruning, Growing Regularization |
| Domains | Deep_Learning, Regularization, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for growing regularization-based pruning provided by Torch-Pruning. GrowingRegPruner extends BasePruner with progressively increasing per-channel regularization, implementing the strategy described in Neural Pruning via Growing Regularization (Wang et al., 2021). It maintains per-group regularization state and exposes three key methods for integration into a training loop: update_regularizer(), update_reg(), and regularize().
Description
GrowingRegPruner extends BasePruner with progressively increasing regularization. It maintains per-group regularization state (self.group_reg), where each group corresponds to a set of coupled layers discovered by the DependencyGraph.
The workflow for using GrowingRegPruner involves three methods:
- update_regularizer() -- Refreshes the internal group list after a pruning step and resets all per-group regularization coefficients to the base value. Must be called after each call to pruner.step().
- update_reg() -- Increments the per-channel regularization penalties by delta_reg * standardized_importance. Should be called once per epoch (or at the desired regularization update frequency).
- regularize(model, bias=False) -- Applies the accumulated regularization to the model's gradients in-place. Should be called every training iteration, after loss.backward() and before optimizer.step(). Modifies weight.grad.data directly for BatchNorm, Conv, and Linear layers.
The regularization targets both output and input channels of convolution and linear layers, as well as BatchNorm affine parameters, ensuring that all coupled parameters in a dependency group are regularized consistently.
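To make this coupling concrete, the following sketch (an illustrative toy model, not part of the pruner itself) uses Torch-Pruning's DependencyGraph to inspect one such group: pruning the first convolution's output channels also affects the following BatchNorm and the next convolution's input channels, which is exactly the set of parameters that share a per-channel penalty:
import torch
import torch.nn as nn
import torch_pruning as tp
# Toy model: conv -> bn -> relu -> conv
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),   # output channels belong to the group
    nn.BatchNorm2d(16),               # affine parameters belong to the group
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1),  # input channels belong to the group
)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 32, 32))
# The group for channel 0 of the first conv pulls in the BatchNorm and the next conv's inputs
group = DG.get_pruning_group(model[0], tp.prune_conv_out_channels, idxs=[0])
print(group)  # lists the coupled layers discovered by the dependency graph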
Code Reference
| Field | Value |
|---|---|
| Source File | torch_pruning/pruner/algorithms/growing_reg_pruner.py, Lines 9-182 |
| Parent Class | BasePruner (from torch_pruning/pruner/algorithms/base_pruner.py) |
| Import | import torch_pruning as tp; tp.pruner.GrowingRegPruner |
Constructor Signature:
class GrowingRegPruner(BasePruner):
def __init__(
self,
model: nn.Module,
example_inputs: torch.Tensor,
importance: typing.Callable,
reg=1e-5,
delta_reg=1e-5,
global_pruning: bool = False,
pruning_ratio: float = 0.5,
pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
max_pruning_ratio: float = 1.0,
iterative_steps: int = 1,
iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler,
ignored_layers: typing.List[nn.Module] = None,
round_to: int = None,
isomorphic: bool = False,
in_channel_groups: typing.Dict[nn.Module, int] = dict(),
out_channel_groups: typing.Dict[nn.Module, int] = dict(),
num_heads: typing.Dict[nn.Module, int] = dict(),
prune_num_heads: bool = False,
prune_head_dims: bool = True,
head_pruning_ratio: float = 0.0,
head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None,
unwrapped_parameters: typing.Dict[nn.Parameter, int] = None,
root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],
forward_fn: typing.Callable = None,
output_transform: typing.Callable = None,
channel_groups: typing.Dict[nn.Module, int] = dict(),
ch_sparsity: float = None,
ch_sparsity_dict: typing.Dict[nn.Module, float] = None,
):
Key Methods:
def update_regularizer(self):
"""Refresh group list and reset per-group regularization after pruning."""
def update_reg(self):
"""Increment per-channel regularization by delta_reg * standardized_importance."""
def regularize(self, model, bias=False):
"""Apply accumulated regularization to weight gradients in-place."""
I/O Contract
| Parameter | Type | Default | Description |
|---|---|---|---|
| reg | float | 1e-5 | Base regularization coefficient. All groups are initialized with this value. |
| delta_reg | float | 1e-5 | Increment added to the regularization each epoch, scaled by standardized importance. |
| bias | bool | False | If True, regularize() also penalizes bias parameters. |
Behavior:
- update_reg() increments self.group_reg by delta_reg * standardized_importance, where the standardized importance maps raw importance scores to [0, 1] and lower-importance channels receive higher values.
- regularize() modifies weight.grad.data in-place by adding gamma * weight for each channel, where gamma is the accumulated per-channel regularization coefficient (see the sketch below).
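A minimal sketch of this arithmetic, using illustrative tensors only; the exact standardization formula is an assumption consistent with the behavior described above, not code copied from the library:
import torch
# Raw per-channel importance scores for one group (illustrative values)
imp = torch.tensor([0.9, 0.4, 0.1])
# Map to [0, 1] so that LESS important channels get values closer to 1 (assumed formula)
standardized = (imp.max() - imp) / (imp.max() - imp.min() + 1e-8)
# update_reg(): grow the per-channel coefficients from the base value `reg`
reg, delta_reg = 1e-5, 1e-5
group_reg = torch.full_like(imp, reg) + delta_reg * standardized
# regularize(): add gamma * w to the gradient of each output channel
weight = torch.randn(3, 8, requires_grad=True)
weight.sum().backward()
weight.grad.data += group_reg.view(-1, 1) * weight.data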
Usage Examples
Training loop with growing regularization:
import torch
import torch.nn as nn
import torch_pruning as tp
# 1. Setup model and pruner
model = YourModel()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.GrowingRegPruner(
model=model,
example_inputs=example_inputs,
importance=imp,
reg=1e-5,
delta_reg=1e-5,
pruning_ratio=0.5,
iterative_steps=3,
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# 2. Regularization phase: train with growing regularization
for epoch in range(num_reg_epochs):
pruner.update_reg() # increment per-channel penalties each epoch
for images, labels in train_loader:
optimizer.zero_grad()
output = model(images)
loss = criterion(output, labels)
loss.backward()
pruner.regularize(model) # apply regularization to gradients
optimizer.step()
# 3. Pruning step: remove low-importance channels
pruner.step()
# 4. Reset regularizer for next round (if iterative pruning)
pruner.update_regularizer()
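For iterative pruning (iterative_steps > 1), phases 2-4 are repeated once per step. A sketch of that outer loop, reusing the objects defined above (num_reg_epochs and train_loader are placeholders from the example):
# Repeat: grow regularization -> prune one step -> reset penalties
for round_idx in range(3):  # matches iterative_steps=3 above
    for epoch in range(num_reg_epochs):
        pruner.update_reg()  # raise per-channel penalties once per epoch
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            pruner.regularize(model)  # apply accumulated penalties to gradients
            optimizer.step()
    pruner.step()                # remove a fraction of low-importance channels
    pruner.update_regularizer()  # refresh groups and reset coefficients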