Implementation:VainF Torch Pruning BNScalePruner
Overview
Concrete tool for BN-scale-based network slimming regularization and pruning provided by Torch-Pruning.
Description
BNScalePruner extends BasePruner to implement Network Slimming. It adds L1 regularization on BatchNorm scaling factors during training. The `regularize()` method adds `reg * sign(weight)` to BN layer gradients, driving small scaling factors toward zero.
The pruner supports two regularization modes:
- Standard L1 mode (default): Iterates over all BatchNorm modules in the model and applies `weight.grad.data += reg * sign(weight.data)`.
- Group lasso mode (`group_lasso=True`): Operates at the dependency group level, computing the L2 norm of BN scales per group and applying `weight.grad.data += reg * (1 / ||group||_2) * weight.data`. This uses `MagnitudeImportance(p=2)` internally to compute group norms.
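The standard L1 update can be illustrated without the library. The sketch below is framework-free and uses hypothetical plain-Python helpers (the real pruner modifies `nn.BatchNorm*.weight.grad` tensors in PyTorch); it shows why the `reg * sign(weight)` term pulls every BN scale toward zero at a constant rate.

```python
# Framework-free sketch of the standard L1 mode (illustrative only;
# the real pruner operates on nn.BatchNorm*.weight.grad in-place).

def sign(x):
    return (x > 0) - (x < 0)

def l1_regularize(scales, grads, reg):
    # grad += reg * sign(w): a constant pull toward zero for every scale
    return [g + reg * sign(w) for w, g in zip(scales, grads)]

# One SGD step with zero task gradient and lr = 0.1:
scales = [0.8, -0.01, 0.002]          # hypothetical BN scaling factors
grads = l1_regularize(scales, [0.0, 0.0, 0.0], reg=0.001)
scales = [w - 0.1 * g for w, g in zip(scales, grads)]
# Every scale moves toward zero by lr * reg = 1e-4 per step,
# so small scales reach (near) zero quickly while large ones persist.
```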
Code Reference
Source: torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py, Lines 10-147
Signature:
```python
class BNScalePruner(BasePruner):
    def __init__(
        self,
        model: nn.Module,
        example_inputs: torch.Tensor,
        importance: typing.Callable,
        reg=1e-5,
        group_lasso=False,
        global_pruning: bool = False,
        pruning_ratio: float = 0.5,
        pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        max_pruning_ratio: float = 1.0,
        iterative_steps: int = 1,
        iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler,
        ignored_layers: typing.List[nn.Module] = None,
        round_to: int = None,
        isomorphic: bool = False,
        in_channel_groups: typing.Dict[nn.Module, int] = dict(),
        out_channel_groups: typing.Dict[nn.Module, int] = dict(),
        num_heads: typing.Dict[nn.Module, int] = dict(),
        prune_num_heads: bool = False,
        prune_head_dims: bool = True,
        head_pruning_ratio: float = 0.0,
        head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None,
        unwrapped_parameters: typing.Dict[nn.Parameter, int] = None,
        root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],
        forward_fn: typing.Callable = None,
        output_transform: typing.Callable = None,
        # ... inherits all BasePruner params
    ):
```
Key Methods:
- `update_regularizer()` -- Refreshes the internal group list from the dependency graph. Call once per epoch.
- `regularize(model, reg=None, bias=False)` -- Applies L1 (or group lasso) regularization to BN weight gradients in-place. Call after each `loss.backward()`.
- `step()` -- Inherited from `BasePruner`. Executes one round of structural pruning.
Import:
```python
import torch_pruning as tp
pruner = tp.pruner.BNScalePruner(...)
```
I/O Contract
| Parameter | Type | Default | Description |
|---|---|---|---|
| `reg` | float | 1e-5 | Regularization coefficient for BN scale penalty |
| `group_lasso` | bool | False | Enable group lasso variant (L2-based group regularization) |
| `model` | nn.Module | (required) | The model to prune |
| `example_inputs` | torch.Tensor | (required) | Dummy inputs for dependency graph tracing |
| `importance` | Callable | (required) | Group importance estimator |
Behavior:
- `regularize()` modifies BN `weight.grad.data` in-place. In standard mode, it applies `reg * sign(weight)`. In group lasso mode, it applies `reg * (1 / ||group||_2) * weight`.
- `step()` prunes the model in-place, removing channels with the smallest importance scores.
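The group lasso term is the gradient of a `reg * ||group||_2` penalty: differentiating the L2 norm with respect to one scale `w_i` gives `w_i / ||group||_2`, which is why whole dependency groups are driven toward zero together. A minimal framework-free sketch (`group_lasso_regularize` is a hypothetical helper, not a Torch-Pruning API):

```python
import math

def group_lasso_regularize(group_scales, grads, reg):
    # L2 norm over all BN scales in one dependency group
    norm = math.sqrt(sum(w * w for w in group_scales))
    # d/dw_i of reg * ||group||_2 is reg * w_i / ||group||_2,
    # so every scale in the group shrinks in proportion together.
    return [g + reg * w / norm for w, g in zip(group_scales, grads)]

grads = group_lasso_regularize([3.0, 4.0], [0.0, 0.0], reg=0.01)
# norm = 5.0, so the added terms are 0.01 * 3/5 and 0.01 * 4/5
```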
Usage Examples
```python
import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import resnet50  # any BN-based CNN works here

# 1. Setup model and pruner for Network Slimming
model = resnet50(num_classes=100)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.BNScaleImportance()
pruner = tp.pruner.BNScalePruner(
    model=model,
    example_inputs=example_inputs,
    importance=imp,
    reg=1e-5,
    group_lasso=False,  # set True for group lasso variant
    pruning_ratio=0.5,
    iterative_steps=1,
)

# 2. Sparse training loop with BN scale regularization
# (num_epochs and train_loader are user-supplied)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
for epoch in range(num_epochs):
    pruner.update_regularizer()  # refresh groups each epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        pruner.regularize(model)  # apply L1 regularization to BN scales
        optimizer.step()
        optimizer.zero_grad()

# 3. Prune the model after sparse training
pruner.step()
```
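The ranking that `step()` relies on is a plain magnitude criterion over BN scales: after sparse training, channels whose scales sit near zero are removed first. The sketch below illustrates that selection outside the library (`select_pruned_channels` is a hypothetical helper, not part of Torch-Pruning):

```python
def select_pruned_channels(bn_scales, pruning_ratio):
    # Rank channels by |scale|; the smallest fraction is pruned,
    # mirroring the magnitude criterion behind BNScaleImportance.
    n_prune = int(len(bn_scales) * pruning_ratio)
    order = sorted(range(len(bn_scales)), key=lambda i: abs(bn_scales[i]))
    return sorted(order[:n_prune])

# Scales driven near zero by L1 sparsification are pruned first:
pruned = select_pruned_channels([0.9, 0.001, 0.5, 0.02], pruning_ratio=0.5)
# → channels [1, 3]
```

In the real pruner the decision is made per dependency group rather than per layer, so coupled channels across layers are removed consistently.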