Principle: VainF Torch Pruning YOLO Architecture Adaptation
Metadata
| Field | Value |
|---|---|
| Domains | Computer_Vision, Object_Detection, Pruning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Adapting YOLO detection model architectures to be compatible with structural pruning by replacing non-traceable operations.
Description
YOLO models (particularly YOLOv8) contain C2f modules whose forward pass relies on torch.chunk(), an operation that Torch-Pruning's dependency graph (DepGraph) cannot trace correctly. The solution is to replace each C2f with a pruning-compatible C2f_v2 variant that splits the fused cv1 convolution into separate cv0 and cv1 convolutions, making the data flow explicit and traceable. A weight-transfer step preserves the original model's learned parameters, so the adapted model behaves identically. This architectural adaptation is required before pruning YOLO models.
The core problem is that torch.chunk() creates an implicit data dependency that DepGraph cannot resolve:
- The single `cv1` convolution produces a tensor that is split into two halves
- One half feeds into the bottleneck chain; the other is concatenated at the end
- DepGraph sees one convolution output going to two destinations but cannot determine which output channels map to which destination
By replacing the single convolution + chunk with two separate convolutions (cv0 and cv1), the data flow becomes explicit and fully traceable.
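To make the idea concrete, the explicit-split design can be sketched as a minimal standalone module. This is an illustrative simplification, not the actual Ultralytics/Torch-Pruning implementation: the bottleneck chain is stood in for by plain 3x3 convolutions, and the Conv+BN+SiLU wrappers of the real C2f are omitted.

```python
import torch
import torch.nn as nn


class C2f_v2(nn.Module):
    """Pruning-friendly C2f sketch: two explicit convs instead of one conv + chunk."""

    def __init__(self, c1, c2, n=1):
        super().__init__()
        c = c2 // 2  # hidden channels per branch
        self.cv0 = nn.Conv2d(c1, c, 1)  # replaces the first half of the fused cv1
        self.cv1 = nn.Conv2d(c1, c, 1)  # replaces the second half of the fused cv1
        # Simplified stand-in for the bottleneck chain of the real C2f.
        self.m = nn.ModuleList(nn.Conv2d(c, c, 3, padding=1) for _ in range(n))
        self.cv2 = nn.Conv2d((2 + n) * c, c2, 1)

    def forward(self, x):
        y = [self.cv0(x), self.cv1(x)]  # explicit split -- traceable by DepGraph
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
```

Because every branch starts from its own convolution, DepGraph can attribute each downstream consumer to a distinct set of parameters.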
Usage
This is a required preprocessing step before pruning YOLOv8 models. The adaptation must be performed after loading the model but before constructing the pruner's dependency graph.
Similar adaptations may be needed for other detection architectures with non-traceable operations, including:
- Models using `torch.split()` for multi-branch architectures
- Models using `torch.chunk()` for channel splitting
- Any architecture where a single layer's output is implicitly divided among multiple consumers
Typical workflow:
- Load the YOLOv8 model from Ultralytics
- Call `replace_c2f_with_c2f_v2(model)` to replace all C2f modules
- Verify the model produces identical outputs (weight transfer preserves behavior)
- Proceed with pruner construction and pruning
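The replacement step is typically a recursive walk over `named_children()`. The sketch below uses hypothetical toy stand-ins for `C2f` and `C2f_v2` (the real classes come from Ultralytics and the Torch-Pruning YOLOv8 example) and omits the weight transfer for brevity:

```python
import torch.nn as nn


# Toy stand-ins for illustration only; not the real Ultralytics modules.
class C2f(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, 1)


class C2f_v2(nn.Module):
    def __init__(self):
        super().__init__()
        self.cv0 = nn.Conv2d(4, 2, 1)
        self.cv1 = nn.Conv2d(4, 2, 1)


def replace_c2f_with_c2f_v2(module):
    """Recursively swap every C2f child for a C2f_v2 in place."""
    for name, child in module.named_children():
        if isinstance(child, C2f):
            # The real helper also calls a weight-transfer routine here.
            setattr(module, name, C2f_v2())
        else:
            replace_c2f_with_c2f_v2(child)
```

Swapping modules via `setattr` on the parent keeps the rest of the model graph untouched, which is why the adaptation can run after loading pretrained weights.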
Theoretical Basis
The original C2f module uses an implicit split via torch.chunk():
# Original C2f forward pass
def forward(self, x):
x = self.cv1(x) # single convolution: c1 -> 2*c channels
x = x.chunk(2, dim=1) # implicit split: [c, c] -- NOT traceable!
y = [x[0], x[1]]
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))
The replacement C2f_v2 module uses explicit separate convolutions:
# C2f_v2 forward pass
def forward(self, x):
y = [self.cv0(x), self.cv1(x)] # two separate convolutions -- traceable!
y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1))
Weight transfer maps the original fused convolution weights to the two split convolutions:
# Weight transfer scheme: slice the fused cv1 of the original C2f into the
# cv0/cv1 pair of C2f_v2. Here old_cv1 is the original module's cv1, and
# new_cv0/new_cv1 belong to the replacement C2f_v2.
c = old_cv1.conv.out_channels // 2
new_cv0.conv.weight.data = old_cv1.conv.weight.data[:c]  # first half of output channels
new_cv1.conv.weight.data = old_cv1.conv.weight.data[c:]  # second half of output channels
# Batch normalization parameters (and running statistics) are split the same way
new_cv0.bn.weight.data = old_cv1.bn.weight.data[:c]
new_cv0.bn.bias.data = old_cv1.bn.bias.data[:c]
new_cv1.bn.weight.data = old_cv1.bn.weight.data[c:]
new_cv1.bn.bias.data = old_cv1.bn.bias.data[c:]
This transformation is mathematically equivalent -- the C2f_v2 module produces identical outputs to the original C2f module when initialized with transferred weights.
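This equivalence is easy to check numerically. The following self-contained sketch (plain 1x1 convolutions, batch norm omitted) shows that a fused convolution followed by `chunk` produces the same two halves as a pair of split convolutions initialized with the sliced weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c1, c = 8, 4                     # input channels, per-branch hidden channels
fused = nn.Conv2d(c1, 2 * c, 1)  # original fused cv1: c1 -> 2c channels

# Build the two split convolutions and copy the sliced weights/biases.
cv0 = nn.Conv2d(c1, c, 1)
cv1 = nn.Conv2d(c1, c, 1)
with torch.no_grad():
    cv0.weight.copy_(fused.weight[:c])
    cv0.bias.copy_(fused.bias[:c])
    cv1.weight.copy_(fused.weight[c:])
    cv1.bias.copy_(fused.bias[c:])

x = torch.randn(2, c1, 8, 8)
a, b = fused(x).chunk(2, dim=1)              # implicit split (non-traceable)
assert torch.allclose(a, cv0(x), atol=1e-6)  # explicit split matches
assert torch.allclose(b, cv1(x), atol=1e-6)
```

Each output channel of a convolution depends only on its own slice of the weight tensor, so slicing the fused weights along the output-channel dimension changes nothing about the computed values.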